From cf43a321a825dd1a75f80591707bb25b3fd29091 Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Mon, 22 Feb 2021 20:21:11 +0800 Subject: [PATCH] [Dy2stat] Refactoring tensor_shape_transformer.py to Fix Change after Assign Bug (#31082) **Problem** In our old shape transformer logic, if user write: ``` s = tensor.shape ... y = paddle.some_api(s) ``` Dy2stat will change it to ``` ... y = paddle.some_api(convert_var_shape(tensor)) ``` However it will cause fatal bug if user changes the shape of `x` after assign. For example: ``` s = tensor.shape ... tensor = paddle.some_change_shape_api(tensor) ... y = paddle.some_api(s) ``` Then the Dy2stat will get wrong result because the code is translated into: ``` tensor = paddle.some_change_shape_api(tensor) ... y = paddle.some_api(convert_var_shape(tensor)) # tensor shape has been changed, not origin `s` value ``` **Solution Logic** It can not be solved in the old logic, so I refactoring tensor_shape_transformer logic. Now we will use `s` to store shape attribute and generate a var `s__STATIC_CONVERT_VAR_SHAPE_SUFFIX` to store static shape API `shape(tensor)` ``` s = tensor.shape ... y = paddle.some_api(s) ``` Dy2stat will change it to ``` s = tensor.shape s__STATIC_CONVERT_VAR_SHAPE_SUFFIX = shape(tensor) ... y = paddle.some_api(choose_shape_attr_or_api(s, s__STATIC_CONVERT_VAR_SHAPE_SUFFIX )) ``` In this case, the code is consistent with origin dygraph meaning and it fixed the change after assign bug. **Code Key Note** To help reviewers, the key change of this PR is changing `self.name_to_var_shape` from "mapping name to shape node" to "mapping name to its STATIC_CONVERT_VAR_SHAPE_SUFFIX name", then if a variable name has the SUFFIX, we can choose to use attribute shape or shape api. Other changes go with the key change. **Consideration** The issue of this PR is that we store extra static `shape` API result, will it harms the speed of Dy2stat? In some cases it will, but we argue that the benefit would be greater than the cost. 1. The extra calling to static `shape` API will happen when coder assign among shape variables. Take the following dygraph code as an instance: ``` s1 = tensor.shape s2 = s1 s3 = s2 ... ``` Then we called extra static `shape` APIs again and again, however users seldom write code like this. 2. If the shape variable is used a lot, for example: ``` s = tensor.shape y1 = paddle.some_api1(s) y2 = paddle.some_api2(s) y3 = paddle.some_api3(s) ``` Our old logic will create 3 shape APIs but now just 1. This is more common user code pattern. In fact, if reviewers take a look at the current unit test in this PR, you could see the op numbers decrease after this PR. So we argue that this PR can also improve speed in this code pattern. --- .../dygraph_to_static/convert_operators.py | 54 ++++- .../tensor_shape_transformer.py | 215 +++++++++++++----- .../test_convert_operators.py | 53 +++++ .../dygraph_to_static/test_tensor_shape.py | 60 ++++- .../paddle/jit/dy2static/convert_operators.py | 6 +- 5 files changed, 311 insertions(+), 77 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index 13574832bd..779e50c3dc 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -267,12 +267,12 @@ def convert_var_shape(x, idx=None, in_control_flow=False): A function representation of the shape of variable. """ - def has_negetive(list_shape, idx=None): + def has_negative(list_shape, idx=None): if idx is not None: return list_shape[idx] < 0 - num_negetive = sum([1 if i < 0 else 0 for i in list_shape]) - return num_negetive > 0 + num_negative = sum([1 if i < 0 else 0 for i in list_shape]) + return num_negative > 0 # When `x` is Variable, call nn.shape(x) in following cases: # (1) The shape of `x` is used in control flow condition. @@ -280,18 +280,62 @@ def convert_var_shape(x, idx=None, in_control_flow=False): # if x.shape[0] == 1: # y = XX # ``` - # (2) The dim to be used is negetive + # (2) The dim to be used is negative # ``` # # Assume x.shape=[3, -1] in static mode # y = paddle.reshape(x, shape=[1, x.shape[1]]) # ``` - if isinstance(x, Variable) and (in_control_flow or has_negetive(x.shape, + if isinstance(x, Variable) and (in_control_flow or has_negative(x.shape, idx)): return nn.shape(x) if idx is None else nn.shape(x)[idx] else: return x.shape if idx is None else x.shape[idx] +def convert_var_shape_simple(x): + """ + A function representation of the shape of variable. + """ + if isinstance(x, Variable): + return nn.shape(x) + else: + return x.shape + + +def eval_if_exist_else_none(name): + try: + return eval(name) + except: + return None + + +def choose_shape_attr_or_api(attr_shape, api_shape, idx=None): + """ + Input can be attribute `x.shape` or api `shape(x)`, this function + chooses which one to return to use in dy2stat. + + Note: sometimes users write `x.shape[3]`, so attr_shape can be an integer. + """ + if api_shape is None: + return attr_shape if idx is None else attr_shape[idx] + if not isinstance(attr_shape, (list, tuple)): + # some variables like x.shape[0] is no longer a list or tuple + if isinstance(attr_shape, int) and attr_shape < 0: + return api_shape if idx is None else api_shape[idx] + return attr_shape if idx is None else attr_shape[idx] + + def has_negative(list_shape, idx=None): + if idx is not None: + return list_shape[idx] < 0 + + num_negative = sum([1 if i < 0 else 0 for i in list_shape]) + return num_negative > 0 + + if has_negative(attr_shape, idx): + return api_shape if idx is None else api_shape[idx] + return attr_shape if idx is None else attr_shape[idx] + + def convert_shape_compare(left, *args): """ A function handles comparison difference between Paddle and Python. diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py index 98906d0158..ddd5d84ef4 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py @@ -17,12 +17,15 @@ from __future__ import print_function import copy import gast +from paddle.fluid import unique_name from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor +STATIC_CONVERT_VAR_SHAPE_SUFFIX = '__static_convert_var_shape_suffix' + def create_convert_shape_node(var_shape_node, slice_node=None, @@ -54,6 +57,39 @@ def create_convert_shape_node(var_shape_node, return result_node +def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None): + eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}')".format( + api_shape_name) + args = [attr_shape_name, eval_exist_func] + + if slice_node is not None and isinstance(slice_node, gast.Index): + args.append(ast_to_source_code(slice_node).strip()) + choose_shape_func = "paddle.jit.dy2static.choose_shape_attr_or_api({})".format( + ",".join(args)) + choose_shape_node = gast.parse(choose_shape_func).body[0].value + if slice_node is not None and not isinstance(slice_node, gast.Index): + return gast.Subscript( + value=choose_shape_node, slice=slice_node, ctx=gast.Load()) + return choose_shape_node + + +class ShapeAttributeTransformer(gast.NodeTransformer): + """ + Input a node like `x.shape` or `x[4].shape[0]` (self._is_var_shape(node) is True), + return a new node changes input to static shape API like `convert_var_shape(x)`, + `convert_var_shape(x[4])[0]`. + """ + + def visit_Attribute(self, node): + if node.attr == 'shape': + args = ast_to_source_code(node.value).strip() + convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape_simple({})".format( + args) + api_shape_node = gast.parse(convert_var_shape_func).body[0].value + return api_shape_node + return node + + class TensorShapeTransformer(gast.NodeTransformer): """ This class transforms variable.shape used in Paddle Apis or control flow conditions into Static Graph Ast. @@ -65,6 +101,8 @@ class TensorShapeTransformer(gast.NodeTransformer): ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer." self.wrapper_root = wrapper_root self.root = wrapper_root.node + # stores origin var string name (like "x" in `x = t.shape`) to + # static shape var string name (like "x_SUFFIX" in `x_SUFFIX = shape(t)`) self.name_to_var_shape = {} self.static_analysis_visitor = StaticAnalysisVisitor(self.root) @@ -79,8 +117,11 @@ class TensorShapeTransformer(gast.NodeTransformer): self.visit(self.root) def visit_Assign(self, node): - if self._update_name_to_var_shape(node): - return node + update_static_shape_var_node = self._update_name_to_var_shape(node) + if update_static_shape_var_node is not None: + ret = [node] + ret.extend(update_static_shape_var_node) + return ret self.generic_visit(node) return node @@ -88,37 +129,44 @@ class TensorShapeTransformer(gast.NodeTransformer): value_node = node.value slice_node = node.slice if isinstance(value_node, gast.Name): - if self._is_var_shape(value_node) and self._used_by_paddle_api( + if value_node.id in self.name_to_var_shape and self._used_by_paddle_api( value_node): - var_shape_node = self.name_to_var_shape[value_node.id] - return create_convert_shape_node(var_shape_node, slice_node) - - if isinstance(value_node, gast.Attribute): - if self._used_by_paddle_api(value_node) and self._is_var_shape( - value_node): - return create_convert_shape_node(value_node, slice_node) - + return create_choose_shape_node( + value_node.id, self.name_to_var_shape[value_node.id], + slice_node) + elif isinstance(value_node, gast.Attribute): + if self._used_by_paddle_api(value_node): + value_name = ast_to_source_code(value_node).strip() + if value_name in self.name_to_var_shape: + return create_choose_shape_node( + value_name, self.name_to_var_shape[value_name], + slice_node) + if self._is_var_shape(value_node): + return create_convert_shape_node(value_node, slice_node) return node def visit_Attribute(self, node): if self._used_by_paddle_api(node): + name = ast_to_source_code(node).strip() + if name in self.name_to_var_shape: + return create_choose_shape_node(name, + self.name_to_var_shape[name]) if self._is_var_shape(node): return create_convert_shape_node(node) return node def visit_Name(self, node): - if self._is_var_shape(node): + if node.id in self.name_to_var_shape: if self._used_by_paddle_api(node): - var_shape_node = self.name_to_var_shape[node.id] - return create_convert_shape_node(var_shape_node) + return create_choose_shape_node(node.id, + self.name_to_var_shape[node.id]) return node def visit_Call(self, node): - assert isinstance(node, gast.Call) if is_paddle_api(node): # Visit gast.Attribute and gast.Name to replace var.shape if necessary. self.generic_visit(node) - + # Don't have to visit other APIs return node def visit_If(self, node): @@ -154,22 +202,23 @@ class TensorShapeTransformer(gast.NodeTransformer): return False args = node.iter.args for idx, arg in enumerate(args): - if isinstance(arg, gast.Name) and self._is_var_shape(arg): - args[idx] = create_convert_shape_node(self.name_to_var_shape[ - arg.id]) - + if isinstance(arg, gast.Name) and arg.id in self.name_to_var_shape: + args[idx] = create_choose_shape_node( + arg.id, self.name_to_var_shape[arg.id]) return True def _transform_var_shape_if_necessary(self, cond): need_transformed = False for child_node in gast.walk(cond): var_shape_node = None - if isinstance(child_node, (gast.Attribute, gast.Subscript)): - if self._is_var_shape(child_node): + if isinstance(child_node, + (gast.Name, gast.Attribute, gast.Subscript)): + child_name = ast_to_source_code(child_node).strip() + if child_name in self.name_to_var_shape: + var_shape_node = create_choose_shape_node( + child_name, self.name_to_var_shape[child_name]) + elif self._is_var_shape(child_node): var_shape_node = child_node - elif isinstance(child_node, (gast.Name)): - if self._is_var_shape(child_node): - var_shape_node = self.name_to_var_shape[child_node.id] if var_shape_node: need_transformed = True @@ -177,17 +226,23 @@ class TensorShapeTransformer(gast.NodeTransformer): parent_node = wrapper_node.parent.node for field, value in gast.iter_fields(parent_node): if child_node is value: - setattr(parent_node, field, - create_convert_shape_node(var_shape_node, None, - True)) + if var_shape_node is child_node: + setattr(parent_node, field, + create_convert_shape_node(var_shape_node, + None, True)) + else: + setattr(parent_node, field, var_shape_node) break # Some child_node may be in a list such as gast.Compare if isinstance(value, list): has_converted_shape = False for i, v in enumerate(value): if child_node is v: - value[i] = create_convert_shape_node( - var_shape_node, None, True) + if var_shape_node is child_node: + value[i] = create_convert_shape_node( + var_shape_node, None, True) + else: + value[i] = var_shape_node has_converted_shape = True break if has_converted_shape: @@ -224,19 +279,12 @@ class TensorShapeTransformer(gast.NodeTransformer): """ Return True if node is like `x.shape` or `x.shape[0]`, return False otherwise. """ - if not isinstance(node, (gast.Name, gast.Attribute, gast.Subscript)): + if not isinstance(node, (gast.Attribute, gast.Subscript)): return False - if isinstance(node, gast.Name) and node.id in self.name_to_var_shape: - return True - if isinstance(node, gast.Attribute): if node.attr != 'shape': return False - - if not isinstance(node.value, gast.Name): - return False - return True if isinstance(node, gast.Subscript): @@ -250,49 +298,94 @@ class TensorShapeTransformer(gast.NodeTransformer): target_node = node.targets[0] value_node = node.value + update_static_shape_var_node = None if isinstance(target_node, gast.Tuple): - has_updated = False + update_static_shape_var_node = [] for idx, element in enumerate(target_node.elts): target_id = ast_to_source_code(element).strip() if isinstance(value_node, gast.Name): if value_node.id in self.name_to_var_shape: + # TODO(zhhsplendid): is context a problem for the result node of gast.parse? + static_shape_var_name = unique_name.generate( + target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX) + static_shape_var_node = gast.parse( + static_shape_var_name).body[0].value + + static_shape_value_name = self.name_to_var_shape[ + value_node.id] + static_shape_value_node = gast.parse( + static_shape_value_name).body[0].value index_value_node = gast.Constant(value=idx, kind=None) slice_index_node = gast.Index(value=index_value_node) - var_shape_node = self.name_to_var_shape[value_node.id] sub_node = gast.Subscript( - value=var_shape_node, + value=static_shape_value_node, slice=slice_index_node, ctx=gast.Load()) - self.name_to_var_shape[target_id] = sub_node - has_updated = True + + update_static_shape_var_node.append( + gast.Assign( + targets=[static_shape_var_node], + value=sub_node)) + + self.name_to_var_shape[ + target_id] = static_shape_var_name if isinstance(value_node, gast.Attribute): if self._is_var_shape(value_node): # eg: x.shape + static_shape_var_name = unique_name.generate( + target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX) + static_shape_var_node = gast.parse( + static_shape_var_name).body[0].value + + static_shape_value_node = copy.deepcopy(value_node) + # x.shape becomes convert_var_shape_simple(x) + ShapeAttributeTransformer().visit( + static_shape_value_node) index_value_node = gast.Constant(value=idx, kind=None) slice_index_node = gast.Index(value=index_value_node) sub_node = gast.Subscript( - value=value_node, + value=static_shape_value_node, slice=slice_index_node, ctx=gast.Load()) - self.name_to_var_shape[target_id] = sub_node - has_updated = True - return has_updated + update_static_shape_var_node.append( + gast.Assign( + targets=[static_shape_var_node], + value=sub_node)) + self.name_to_var_shape[ + target_id] = static_shape_var_name + return update_static_shape_var_node else: target_id = ast_to_source_code(target_node).strip() - if isinstance(value_node, gast.Name): - if self._is_var_shape(value_node): - self.name_to_var_shape[target_id] = self.name_to_var_shape[ + if value_node.id in self.name_to_var_shape: + static_shape_var_name = unique_name.generate( + target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX) + static_shape_var_node = gast.parse( + static_shape_var_name).body[0].value + static_shape_value_name = self.name_to_var_shape[ value_node.id] - return True - if isinstance(value_node, gast.Attribute): - if self._is_var_shape(value_node): # eg: x.shape - self.name_to_var_shape[target_id] = value_node - return True - if isinstance(value_node, gast.Subscript): - if isinstance(value_node.value, gast.Attribute): - if self._is_var_shape(value_node.value): # eg: x.shape[0] - self.name_to_var_shape[target_id] = value_node - return True - return False + static_shape_value_node = gast.parse( + static_shape_value_name).body[0].value + + update_static_shape_var_node = [ + gast.Assign( + targets=[static_shape_var_node], + value=static_shape_value_node) + ] + self.name_to_var_shape[target_id] = static_shape_var_name + elif self._is_var_shape(value_node): # eg: x.shape or x.shape[0] + static_shape_var_name = unique_name.generate( + target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX) + static_shape_var_node = gast.parse(static_shape_var_name).body[ + 0].value + static_shape_value_node = copy.deepcopy(value_node) + # x.shape becomes convert_var_shape_simple(x) + ShapeAttributeTransformer().visit(static_shape_value_node) + update_static_shape_var_node = [ + gast.Assign( + targets=[static_shape_var_node], + value=static_shape_value_node) + ] + self.name_to_var_shape[target_id] = static_shape_var_name + return update_static_shape_var_node diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py index 28c5d22021..631cd426b3 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py @@ -136,5 +136,58 @@ class TestConvertShapeCompare(unittest.TestCase): paddle.disable_static() +class TestChooseShapeAttrOrApi(unittest.TestCase): + def test_api_shape_is_none(self): + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api([1, 2], None), + [1, 2]) + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api([1], None), [1]) + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api([2, 3, 7], None, 0), + 2) + + def test_attr_shape_is_int(self): + x = paddle.zeros([1, 3, 5, 7]) + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api(x.shape[0], + paddle.shape(x)[0]), + 1) + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api(x.shape[1], + paddle.shape(x)[1]), + 3) + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api(-1, + paddle.shape(x)[0]), + paddle.shape(x)[0]) + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api(-1, + paddle.shape(x), 0), + paddle.shape(x)[0]) + + def test_positive_attr_shape(self): + x = paddle.zeros([1, 3, 5, 7]) + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api(x.shape, + paddle.shape(x)), + x.shape) + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api(x.shape, + paddle.shape(x), 3), + x.shape[3]) + + def test_negative_attr_shape(self): + x = paddle.zeros([7]) + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api([-1], + paddle.shape(x), 0), + paddle.shape(x)[0]) + self.assertEqual( + paddle.jit.dy2static.choose_shape_attr_or_api([-1], + paddle.shape(x)), + paddle.shape(x)) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py index 7a4c63894f..d28864aade 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py @@ -207,6 +207,14 @@ def dyfunc_with_while_4(x): return x +def dyfunc_change_shape_after_assign(x): + x = paddle.to_tensor(x) + a, b = x.shape + x = paddle.reshape(x, shape=(-1, 1)) + res = paddle.reshape(x, shape=(b, a)) + return res + + # 1. Basic tests without control flow class TestTensorShapeBasic(unittest.TestCase): def setUp(self): @@ -289,11 +297,21 @@ class TestTensorShapeBasic5(TestTensorShapeBasic): def init_test_func(self): self.dygraph_func = dyfunc_tensor_shape_5 + def _set_expected_op_num(self): + self.expected_op_num = 4 + self.expected_shape_op_num = 1 + self.expected_slice_op_num = 1 + class TestTensorShapeBasic6(TestTensorShapeBasic): def init_test_func(self): self.dygraph_func = dyfunc_tensor_shape_6 + def _set_expected_op_num(self): + self.expected_op_num = 4 + self.expected_shape_op_num = 1 + self.expected_slice_op_num = 1 + class TestTupleShape1(TestTensorShapeBasic): def init_test_func(self): @@ -327,9 +345,9 @@ class TestTensorShapeInIf1(TestTensorShapeBasic): self.dygraph_func = dyfunc_with_if_1 def _set_expected_op_num(self): - self.expected_op_num = 26 - self.expected_shape_op_num = 2 - self.expected_slice_op_num = 2 + self.expected_op_num = 4 + self.expected_shape_op_num = 1 + self.expected_slice_op_num = 1 class TestTensorShapeInIf2(TestTensorShapeBasic): @@ -357,6 +375,11 @@ class TestTensorShapeInFor2(TestTensorShapeInFor1): def init_test_func(self): self.dygraph_func = dyfunc_with_for_2 + def _set_expected_op_num(self): + self.expected_op_num = 9 + self.expected_shape_op_num = 1 + self.expected_slice_op_num = 1 + # 4. Tests with control flow while loop class TestTensorShapeInWhile1(TestTensorShapeInFor1): @@ -368,15 +391,20 @@ class TestTensorShapeInWhile2(TestTensorShapeInFor1): def init_test_func(self): self.dygraph_func = dyfunc_with_while_2 + def _set_expected_op_num(self): + self.expected_op_num = 6 + self.expected_shape_op_num = 1 + self.expected_slice_op_num = 1 + class TestTensorShapeInWhile3(TestTensorShapeBasic): def init_test_func(self): self.dygraph_func = dyfunc_with_while_3 def _set_expected_op_num(self): - self.expected_op_num = 25 - self.expected_shape_op_num = 6 - self.expected_slice_op_num = 3 + self.expected_op_num = 2 + self.expected_shape_op_num = 0 + self.expected_slice_op_num = 0 class TestTensorShapeInWhile4(TestTensorShapeBasic): @@ -446,9 +474,9 @@ class TestOpNumWithTensorShapeTuple1(TestOpNumBasicWithTensorShape): self.dygraph_func = dyfunc_tuple_shape_1 def _set_expected_op_num(self): - self.expected_op_num = 5 - self.expected_shape_op_num = 1 - self.expected_slice_op_num = 1 + self.expected_op_num = 2 + self.expected_shape_op_num = 0 + self.expected_slice_op_num = 0 class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape): @@ -456,7 +484,7 @@ class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape): self.dygraph_func = dyfunc_with_if_1 def _set_expected_op_num(self): - self.expected_op_num = 28 + self.expected_op_num = 19 self.expected_shape_op_num = 4 self.expected_slice_op_num = 2 @@ -481,5 +509,17 @@ class TestOpNumWithTensorShapeInWhile1(TestOpNumBasicWithTensorShape): self.expected_slice_op_num = 3 +class TestChangeShapeAfterAssign(TestTensorShapeBasic): + def init_test_func(self): + self.input = numpy.ones((2, 3)).astype("int32") + self.input_spec = [paddle.static.InputSpec(shape=[2, 3], dtype="int32")] + self.dygraph_func = dyfunc_change_shape_after_assign + + def _set_expected_op_num(self): + self.expected_op_num = 3 + self.expected_shape_op_num = 0 + self.expected_slice_op_num = 0 + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/jit/dy2static/convert_operators.py b/python/paddle/jit/dy2static/convert_operators.py index fcf6a10974..9321cf4a0b 100644 --- a/python/paddle/jit/dy2static/convert_operators.py +++ b/python/paddle/jit/dy2static/convert_operators.py @@ -25,11 +25,15 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_print from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape_compare #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape #DEFINE_ALIAS +from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape_simple #DEFINE_ALIAS +from ...fluid.dygraph.dygraph_to_static.convert_operators import eval_if_exist_else_none #DEFINE_ALIAS +from ...fluid.dygraph.dygraph_to_static.convert_operators import choose_shape_attr_or_api #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_loop #DEFINE_ALIAS __all__ = [ 'cast_bool_if_necessary', 'convert_assert', 'convert_ifelse', 'convert_len', 'convert_logical_and', 'convert_logical_not', 'convert_logical_or', 'convert_pop', 'convert_print', 'convert_shape_compare', - 'convert_var_dtype', 'convert_var_shape', 'convert_while_loop' + 'convert_var_dtype', 'convert_var_shape', 'convert_var_shape_simple', + 'eval_if_exist_else_none', 'choose_shape_attr_or_api', 'convert_while_loop' ] -- GitLab