From 3a72408f42d5a1be0e53bc5489b02dd7904fbd7f Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Tue, 23 Feb 2021 19:34:48 +0800 Subject: [PATCH] [Cherry-pick][Dy2stat] Cherry-pick of PR31082 and PR31051 (#31101) Cherry-pick of #31051 and #31082 --- .../dygraph_to_static/convert_operators.py | 54 ++++- .../tensor_shape_transformer.py | 226 +++++++++++++----- .../test_convert_operators.py | 53 ++++ .../dygraph_to_static/test_tensor_shape.py | 75 +++++- .../paddle/jit/dy2static/convert_operators.py | 6 +- 5 files changed, 335 insertions(+), 79 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index 13574832bd3..779e50c3dc5 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -267,12 +267,12 @@ def convert_var_shape(x, idx=None, in_control_flow=False): A function representation of the shape of variable. """ - def has_negetive(list_shape, idx=None): + def has_negative(list_shape, idx=None): if idx is not None: return list_shape[idx] < 0 - num_negetive = sum([1 if i < 0 else 0 for i in list_shape]) - return num_negetive > 0 + num_negative = sum([1 if i < 0 else 0 for i in list_shape]) + return num_negative > 0 # When `x` is Variable, call nn.shape(x) in following cases: # (1) The shape of `x` is used in control flow condition. @@ -280,18 +280,62 @@ def convert_var_shape(x, idx=None, in_control_flow=False): # if x.shape[0] == 1: # y = XX # ``` - # (2) The dim to be used is negetive + # (2) The dim to be used is negative # ``` # # Assume x.shape=[3, -1] in static mode # y = paddle.reshape(x, shape=[1, x.shape[1]]) # ``` - if isinstance(x, Variable) and (in_control_flow or has_negetive(x.shape, + if isinstance(x, Variable) and (in_control_flow or has_negative(x.shape, idx)): return nn.shape(x) if idx is None else nn.shape(x)[idx] else: return x.shape if idx is None else x.shape[idx] +def convert_var_shape_simple(x): + """ + A function representation of the shape of variable. + """ + if isinstance(x, Variable): + return nn.shape(x) + else: + return x.shape + + +def eval_if_exist_else_none(name): + try: + return eval(name) + except: + return None + + +def choose_shape_attr_or_api(attr_shape, api_shape, idx=None): + """ + Input can be attribute `x.shape` or api `shape(x)`, this function + chooses which one to return to use in dy2stat. + + Note: sometimes users write `x.shape[3]`, so attr_shape can be an integer. + """ + if api_shape is None: + return attr_shape if idx is None else attr_shape[idx] + if not isinstance(attr_shape, (list, tuple)): + # some variables like x.shape[0] is no longer a list or tuple + if isinstance(attr_shape, int) and attr_shape < 0: + return api_shape if idx is None else api_shape[idx] + return attr_shape if idx is None else attr_shape[idx] + + def has_negative(list_shape, idx=None): + if idx is not None: + return list_shape[idx] < 0 + + num_negative = sum([1 if i < 0 else 0 for i in list_shape]) + return num_negative > 0 + + if has_negative(attr_shape, idx): + return api_shape if idx is None else api_shape[idx] + return attr_shape if idx is None else attr_shape[idx] + + def convert_shape_compare(left, *args): """ A function handles comparison difference between Paddle and Python. diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py index 6aa55042647..ddd5d84ef42 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py @@ -17,12 +17,15 @@ from __future__ import print_function import copy import gast +from paddle.fluid import unique_name from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor +STATIC_CONVERT_VAR_SHAPE_SUFFIX = '__static_convert_var_shape_suffix' + def create_convert_shape_node(var_shape_node, slice_node=None, @@ -31,13 +34,20 @@ def create_convert_shape_node(var_shape_node, if isinstance(var_shape_node, gast.Attribute): args = [ast_to_source_code(var_shape_node.value).strip()] - if slice_node: + # (1) A slice can be a simple number such as 1, -2, i.e. gast.Index + # (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index + # In (1) case, we pass the number as 'idx' argument in convert_var_shape + # In (2) case, we have to make it like `convert_var_shape(x)[slice]` + if slice_node is not None and isinstance(slice_node, gast.Index): args.append(ast_to_source_code(slice_node).strip()) convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format( ",".join(args), in_control_flow) - api_shape_node = gast.parse(convert_var_shape_func).body[0].value + + if slice_node is not None and not isinstance(slice_node, gast.Index): + return gast.Subscript( + value=api_shape_node, slice=slice_node, ctx=gast.Load()) return api_shape_node if isinstance(var_shape_node, gast.Subscript): @@ -47,6 +57,39 @@ def create_convert_shape_node(var_shape_node, return result_node +def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None): + eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}')".format( + api_shape_name) + args = [attr_shape_name, eval_exist_func] + + if slice_node is not None and isinstance(slice_node, gast.Index): + args.append(ast_to_source_code(slice_node).strip()) + choose_shape_func = "paddle.jit.dy2static.choose_shape_attr_or_api({})".format( + ",".join(args)) + choose_shape_node = gast.parse(choose_shape_func).body[0].value + if slice_node is not None and not isinstance(slice_node, gast.Index): + return gast.Subscript( + value=choose_shape_node, slice=slice_node, ctx=gast.Load()) + return choose_shape_node + + +class ShapeAttributeTransformer(gast.NodeTransformer): + """ + Input a node like `x.shape` or `x[4].shape[0]` (self._is_var_shape(node) is True), + return a new node changes input to static shape API like `convert_var_shape(x)`, + `convert_var_shape(x[4])[0]`. + """ + + def visit_Attribute(self, node): + if node.attr == 'shape': + args = ast_to_source_code(node.value).strip() + convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape_simple({})".format( + args) + api_shape_node = gast.parse(convert_var_shape_func).body[0].value + return api_shape_node + return node + + class TensorShapeTransformer(gast.NodeTransformer): """ This class transforms variable.shape used in Paddle Apis or control flow conditions into Static Graph Ast. @@ -58,6 +101,8 @@ class TensorShapeTransformer(gast.NodeTransformer): ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer." self.wrapper_root = wrapper_root self.root = wrapper_root.node + # stores origin var string name (like "x" in `x = t.shape`) to + # static shape var string name (like "x_SUFFIX" in `x_SUFFIX = shape(t)`) self.name_to_var_shape = {} self.static_analysis_visitor = StaticAnalysisVisitor(self.root) @@ -72,8 +117,11 @@ class TensorShapeTransformer(gast.NodeTransformer): self.visit(self.root) def visit_Assign(self, node): - if self._update_name_to_var_shape(node): - return node + update_static_shape_var_node = self._update_name_to_var_shape(node) + if update_static_shape_var_node is not None: + ret = [node] + ret.extend(update_static_shape_var_node) + return ret self.generic_visit(node) return node @@ -81,37 +129,44 @@ class TensorShapeTransformer(gast.NodeTransformer): value_node = node.value slice_node = node.slice if isinstance(value_node, gast.Name): - if self._is_var_shape(value_node) and self._used_by_paddle_api( - value_node): - var_shape_node = self.name_to_var_shape[value_node.id] - return create_convert_shape_node(var_shape_node, slice_node) - - if isinstance(value_node, gast.Attribute): - if self._used_by_paddle_api(value_node) and self._is_var_shape( + if value_node.id in self.name_to_var_shape and self._used_by_paddle_api( value_node): - return create_convert_shape_node(value_node, slice_node) - + return create_choose_shape_node( + value_node.id, self.name_to_var_shape[value_node.id], + slice_node) + elif isinstance(value_node, gast.Attribute): + if self._used_by_paddle_api(value_node): + value_name = ast_to_source_code(value_node).strip() + if value_name in self.name_to_var_shape: + return create_choose_shape_node( + value_name, self.name_to_var_shape[value_name], + slice_node) + if self._is_var_shape(value_node): + return create_convert_shape_node(value_node, slice_node) return node def visit_Attribute(self, node): if self._used_by_paddle_api(node): + name = ast_to_source_code(node).strip() + if name in self.name_to_var_shape: + return create_choose_shape_node(name, + self.name_to_var_shape[name]) if self._is_var_shape(node): return create_convert_shape_node(node) return node def visit_Name(self, node): - if self._is_var_shape(node): + if node.id in self.name_to_var_shape: if self._used_by_paddle_api(node): - var_shape_node = self.name_to_var_shape[node.id] - return create_convert_shape_node(var_shape_node) + return create_choose_shape_node(node.id, + self.name_to_var_shape[node.id]) return node def visit_Call(self, node): - assert isinstance(node, gast.Call) if is_paddle_api(node): # Visit gast.Attribute and gast.Name to replace var.shape if necessary. self.generic_visit(node) - + # Don't have to visit other APIs return node def visit_If(self, node): @@ -147,22 +202,23 @@ class TensorShapeTransformer(gast.NodeTransformer): return False args = node.iter.args for idx, arg in enumerate(args): - if isinstance(arg, gast.Name) and self._is_var_shape(arg): - args[idx] = create_convert_shape_node(self.name_to_var_shape[ - arg.id]) - + if isinstance(arg, gast.Name) and arg.id in self.name_to_var_shape: + args[idx] = create_choose_shape_node( + arg.id, self.name_to_var_shape[arg.id]) return True def _transform_var_shape_if_necessary(self, cond): need_transformed = False for child_node in gast.walk(cond): var_shape_node = None - if isinstance(child_node, (gast.Attribute, gast.Subscript)): - if self._is_var_shape(child_node): + if isinstance(child_node, + (gast.Name, gast.Attribute, gast.Subscript)): + child_name = ast_to_source_code(child_node).strip() + if child_name in self.name_to_var_shape: + var_shape_node = create_choose_shape_node( + child_name, self.name_to_var_shape[child_name]) + elif self._is_var_shape(child_node): var_shape_node = child_node - elif isinstance(child_node, (gast.Name)): - if self._is_var_shape(child_node): - var_shape_node = self.name_to_var_shape[child_node.id] if var_shape_node: need_transformed = True @@ -170,17 +226,23 @@ class TensorShapeTransformer(gast.NodeTransformer): parent_node = wrapper_node.parent.node for field, value in gast.iter_fields(parent_node): if child_node is value: - setattr(parent_node, field, - create_convert_shape_node(var_shape_node, None, - True)) + if var_shape_node is child_node: + setattr(parent_node, field, + create_convert_shape_node(var_shape_node, + None, True)) + else: + setattr(parent_node, field, var_shape_node) break # Some child_node may be in a list such as gast.Compare if isinstance(value, list): has_converted_shape = False for i, v in enumerate(value): if child_node is v: - value[i] = create_convert_shape_node( - var_shape_node, None, True) + if var_shape_node is child_node: + value[i] = create_convert_shape_node( + var_shape_node, None, True) + else: + value[i] = var_shape_node has_converted_shape = True break if has_converted_shape: @@ -217,19 +279,12 @@ class TensorShapeTransformer(gast.NodeTransformer): """ Return True if node is like `x.shape` or `x.shape[0]`, return False otherwise. """ - if not isinstance(node, (gast.Name, gast.Attribute, gast.Subscript)): + if not isinstance(node, (gast.Attribute, gast.Subscript)): return False - if isinstance(node, gast.Name) and node.id in self.name_to_var_shape: - return True - if isinstance(node, gast.Attribute): if node.attr != 'shape': return False - - if not isinstance(node.value, gast.Name): - return False - return True if isinstance(node, gast.Subscript): @@ -243,49 +298,94 @@ class TensorShapeTransformer(gast.NodeTransformer): target_node = node.targets[0] value_node = node.value + update_static_shape_var_node = None if isinstance(target_node, gast.Tuple): - has_updated = False + update_static_shape_var_node = [] for idx, element in enumerate(target_node.elts): target_id = ast_to_source_code(element).strip() if isinstance(value_node, gast.Name): if value_node.id in self.name_to_var_shape: + # TODO(zhhsplendid): is context a problem for the result node of gast.parse? + static_shape_var_name = unique_name.generate( + target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX) + static_shape_var_node = gast.parse( + static_shape_var_name).body[0].value + + static_shape_value_name = self.name_to_var_shape[ + value_node.id] + static_shape_value_node = gast.parse( + static_shape_value_name).body[0].value index_value_node = gast.Constant(value=idx, kind=None) slice_index_node = gast.Index(value=index_value_node) - var_shape_node = self.name_to_var_shape[value_node.id] sub_node = gast.Subscript( - value=var_shape_node, + value=static_shape_value_node, slice=slice_index_node, ctx=gast.Load()) - self.name_to_var_shape[target_id] = sub_node - has_updated = True + + update_static_shape_var_node.append( + gast.Assign( + targets=[static_shape_var_node], + value=sub_node)) + + self.name_to_var_shape[ + target_id] = static_shape_var_name if isinstance(value_node, gast.Attribute): if self._is_var_shape(value_node): # eg: x.shape + static_shape_var_name = unique_name.generate( + target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX) + static_shape_var_node = gast.parse( + static_shape_var_name).body[0].value + + static_shape_value_node = copy.deepcopy(value_node) + # x.shape becomes convert_var_shape_simple(x) + ShapeAttributeTransformer().visit( + static_shape_value_node) index_value_node = gast.Constant(value=idx, kind=None) slice_index_node = gast.Index(value=index_value_node) sub_node = gast.Subscript( - value=value_node, + value=static_shape_value_node, slice=slice_index_node, ctx=gast.Load()) - self.name_to_var_shape[target_id] = sub_node - has_updated = True - return has_updated + update_static_shape_var_node.append( + gast.Assign( + targets=[static_shape_var_node], + value=sub_node)) + self.name_to_var_shape[ + target_id] = static_shape_var_name + return update_static_shape_var_node else: target_id = ast_to_source_code(target_node).strip() - if isinstance(value_node, gast.Name): - if self._is_var_shape(value_node): - self.name_to_var_shape[target_id] = self.name_to_var_shape[ + if value_node.id in self.name_to_var_shape: + static_shape_var_name = unique_name.generate( + target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX) + static_shape_var_node = gast.parse( + static_shape_var_name).body[0].value + static_shape_value_name = self.name_to_var_shape[ value_node.id] - return True - if isinstance(value_node, gast.Attribute): - if self._is_var_shape(value_node): # eg: x.shape - self.name_to_var_shape[target_id] = value_node - return True - if isinstance(value_node, gast.Subscript): - if isinstance(value_node.value, gast.Attribute): - if self._is_var_shape(value_node.value): # eg: x.shape[0] - self.name_to_var_shape[target_id] = value_node - return True - return False + static_shape_value_node = gast.parse( + static_shape_value_name).body[0].value + + update_static_shape_var_node = [ + gast.Assign( + targets=[static_shape_var_node], + value=static_shape_value_node) + ] + self.name_to_var_shape[target_id] = static_shape_var_name + elif self._is_var_shape(value_node): # eg: x.shape or x.shape[0] + static_shape_var_name = unique_name.generate( + target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX) + static_shape_var_node = gast.parse(static_shape_var_name).body[ + 0].value + static_shape_value_node = copy.deepcopy(value_node) + # x.shape becomes convert_var_shape_simple(x) + ShapeAttributeTransformer().visit(static_shape_value_node) + update_static_shape_var_node = [ + gast.Assign( + targets=[static_shape_var_node], + value=static_shape_value_node) + ] + self.name_to_var_shape[target_id] = static_shape_var_name + return update_static_shape_var_node diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py index 28c5d220213..631cd426b32 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py @@ -136,5 +136,58 @@ class TestConvertShapeCompare(unittest.TestCase): paddle.disable_static() +class TestChooseShapeAttrOrApi(unittest.TestCase): + def test_api_shape_is_none(self): + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api([1, 2], None), + [1, 2]) + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api([1], None), [1]) + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api([2, 3, 7], None, 0), + 2) + + def test_attr_shape_is_int(self): + x = paddle.zeros([1, 3, 5, 7]) + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api(x.shape[0], + paddle.shape(x)[0]), + 1) + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api(x.shape[1], + paddle.shape(x)[1]), + 3) + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api(-1, + paddle.shape(x)[0]), + paddle.shape(x)[0]) + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api(-1, + paddle.shape(x), 0), + paddle.shape(x)[0]) + + def test_positive_attr_shape(self): + x = paddle.zeros([1, 3, 5, 7]) + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api(x.shape, + paddle.shape(x)), + x.shape) + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api(x.shape, + paddle.shape(x), 3), + x.shape[3]) + + def test_negative_attr_shape(self): + x = paddle.zeros([7]) + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api([-1], + paddle.shape(x), 0), + paddle.shape(x)[0]) + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api([-1], + paddle.shape(x)), + paddle.shape(x)) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py index 17809ea16fd..d28864aade5 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py @@ -60,6 +60,16 @@ def dyfunc_tensor_shape_5(x): return res +def dyfunc_tensor_shape_6(x): + # `res = fluid.layers.reshape(x, shape=(-1, s))` to + # `res = fluid.layers.reshape(x, shape=(-1, + # paddle.jit.dy2static.convert_var_shape(x)[0:]))` + x = fluid.dygraph.to_variable(x) + s = x.shape[0:] + res = fluid.layers.reshape(x, shape=s) + return res + + def dyfunc_tuple_shape_1(x): x = paddle.to_tensor(x) a, b = x.shape @@ -197,6 +207,14 @@ def dyfunc_with_while_4(x): return x +def dyfunc_change_shape_after_assign(x): + x = paddle.to_tensor(x) + a, b = x.shape + x = paddle.reshape(x, shape=(-1, 1)) + res = paddle.reshape(x, shape=(b, a)) + return res + + # 1. Basic tests without control flow class TestTensorShapeBasic(unittest.TestCase): def setUp(self): @@ -279,6 +297,21 @@ class TestTensorShapeBasic5(TestTensorShapeBasic): def init_test_func(self): self.dygraph_func = dyfunc_tensor_shape_5 + def _set_expected_op_num(self): + self.expected_op_num = 4 + self.expected_shape_op_num = 1 + self.expected_slice_op_num = 1 + + +class TestTensorShapeBasic6(TestTensorShapeBasic): + def init_test_func(self): + self.dygraph_func = dyfunc_tensor_shape_6 + + def _set_expected_op_num(self): + self.expected_op_num = 4 + self.expected_shape_op_num = 1 + self.expected_slice_op_num = 1 + class TestTupleShape1(TestTensorShapeBasic): def init_test_func(self): @@ -312,9 +345,9 @@ class TestTensorShapeInIf1(TestTensorShapeBasic): self.dygraph_func = dyfunc_with_if_1 def _set_expected_op_num(self): - self.expected_op_num = 26 - self.expected_shape_op_num = 2 - self.expected_slice_op_num = 2 + self.expected_op_num = 4 + self.expected_shape_op_num = 1 + self.expected_slice_op_num = 1 class TestTensorShapeInIf2(TestTensorShapeBasic): @@ -342,6 +375,11 @@ class TestTensorShapeInFor2(TestTensorShapeInFor1): def init_test_func(self): self.dygraph_func = dyfunc_with_for_2 + def _set_expected_op_num(self): + self.expected_op_num = 9 + self.expected_shape_op_num = 1 + self.expected_slice_op_num = 1 + # 4. Tests with control flow while loop class TestTensorShapeInWhile1(TestTensorShapeInFor1): @@ -353,15 +391,20 @@ class TestTensorShapeInWhile2(TestTensorShapeInFor1): def init_test_func(self): self.dygraph_func = dyfunc_with_while_2 + def _set_expected_op_num(self): + self.expected_op_num = 6 + self.expected_shape_op_num = 1 + self.expected_slice_op_num = 1 + class TestTensorShapeInWhile3(TestTensorShapeBasic): def init_test_func(self): self.dygraph_func = dyfunc_with_while_3 def _set_expected_op_num(self): - self.expected_op_num = 25 - self.expected_shape_op_num = 6 - self.expected_slice_op_num = 3 + self.expected_op_num = 2 + self.expected_shape_op_num = 0 + self.expected_slice_op_num = 0 class TestTensorShapeInWhile4(TestTensorShapeBasic): @@ -431,9 +474,9 @@ class TestOpNumWithTensorShapeTuple1(TestOpNumBasicWithTensorShape): self.dygraph_func = dyfunc_tuple_shape_1 def _set_expected_op_num(self): - self.expected_op_num = 5 - self.expected_shape_op_num = 1 - self.expected_slice_op_num = 1 + self.expected_op_num = 2 + self.expected_shape_op_num = 0 + self.expected_slice_op_num = 0 class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape): @@ -441,7 +484,7 @@ class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape): self.dygraph_func = dyfunc_with_if_1 def _set_expected_op_num(self): - self.expected_op_num = 28 + self.expected_op_num = 19 self.expected_shape_op_num = 4 self.expected_slice_op_num = 2 @@ -466,5 +509,17 @@ class TestOpNumWithTensorShapeInWhile1(TestOpNumBasicWithTensorShape): self.expected_slice_op_num = 3 +class TestChangeShapeAfterAssign(TestTensorShapeBasic): + def init_test_func(self): + self.input = numpy.ones((2, 3)).astype("int32") + self.input_spec = [paddle.static.InputSpec(shape=[2, 3], dtype="int32")] + self.dygraph_func = dyfunc_change_shape_after_assign + + def _set_expected_op_num(self): + self.expected_op_num = 3 + self.expected_shape_op_num = 0 + self.expected_slice_op_num = 0 + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/jit/dy2static/convert_operators.py b/python/paddle/jit/dy2static/convert_operators.py index fcf6a10974f..9321cf4a0b8 100644 --- a/python/paddle/jit/dy2static/convert_operators.py +++ b/python/paddle/jit/dy2static/convert_operators.py @@ -25,11 +25,15 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_print from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape_compare #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape #DEFINE_ALIAS +from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape_simple #DEFINE_ALIAS +from ...fluid.dygraph.dygraph_to_static.convert_operators import eval_if_exist_else_none #DEFINE_ALIAS +from ...fluid.dygraph.dygraph_to_static.convert_operators import choose_shape_attr_or_api #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_loop #DEFINE_ALIAS __all__ = [ 'cast_bool_if_necessary', 'convert_assert', 'convert_ifelse', 'convert_len', 'convert_logical_and', 'convert_logical_not', 'convert_logical_or', 'convert_pop', 'convert_print', 'convert_shape_compare', - 'convert_var_dtype', 'convert_var_shape', 'convert_while_loop' + 'convert_var_dtype', 'convert_var_shape', 'convert_var_shape_simple', + 'eval_if_exist_else_none', 'choose_shape_attr_or_api', 'convert_while_loop' ] -- GitLab