diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index 403e77cb5ccd8d3faae999b36070d8be64322079..4126e942259434dc5035a48a7fd054a7b0433f98 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -302,19 +302,19 @@ def convert_var_shape_simple(x): return x.shape -def eval_if_exist_else_none(name, local_symbol_table): +def eval_if_exist_else_none(name, global_symbol_table): """ Args: name([str]): Expression passed into `eval`. - local_symbol_table(dict): Specified from `locals()`. DO NOT use `globals()`, - it has a higher priority and will hide away variables - from `locals()`. + local_symbol_table(dict): Specified from `globals()`. DO NOT use `locals()`, + because all STATIC_CONVERT_VAR_SHAPE_SUFFIX vars is + declared with keyword `global`. Returns: - Return the variable if found in local_symbol_table else None. + Return the variable if found in global_symbol_table else None. """ try: - return eval(name, local_symbol_table) + return eval(name, global_symbol_table) except: return None diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py index ffa1d65e6280af9e5c4d3eac8b29351c8177db69..eb53d7ec9bec894771afce2191dfe195fc53580d 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py @@ -59,7 +59,7 @@ def create_convert_shape_node(var_shape_node, def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None): - eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}', locals())".format( + eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}', globals())".format( api_shape_name) args = [attr_shape_name, eval_exist_func] @@ -293,6 +293,10 @@ class TensorShapeTransformer(gast.NodeTransformer): return False def _update_name_to_var_shape(self, node): + def replace_dot(name): + # replace all '.' into '_' + return name.replace('.', '_') + assert isinstance(node, gast.Assign) target_node = node.targets[0] value_node = node.value @@ -307,7 +311,8 @@ class TensorShapeTransformer(gast.NodeTransformer): if value_node.id in self.name_to_var_shape: # TODO(zhhsplendid): is context a problem for the result node of gast.parse? static_shape_var_name = unique_name.generate( - target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX) + replace_dot(target_id) + + STATIC_CONVERT_VAR_SHAPE_SUFFIX) static_shape_var_node = gast.parse( static_shape_var_name).body[0].value @@ -328,7 +333,8 @@ class TensorShapeTransformer(gast.NodeTransformer): if isinstance(value_node, gast.Attribute): if self._is_var_shape(value_node): # eg: x.shape static_shape_var_name = unique_name.generate( - target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX) + replace_dot(target_id) + + STATIC_CONVERT_VAR_SHAPE_SUFFIX) static_shape_var_node = gast.parse( static_shape_var_name).body[0].value @@ -341,6 +347,12 @@ class TensorShapeTransformer(gast.NodeTransformer): ast_to_source_code(static_shape_value_node).strip(), idx) sub_node = gast.parse(sub_node_str).body[0].value + # Note(Aurelius84): Becuase static_shape_var_name is used in + # eval_if_exist_else_none() as plain string, so it will not + # be pasred as argument in convert_loop/ifelse. We delcare it + # as global var because it has unique name. + update_static_shape_var_node.append( + gast.Global(names=[static_shape_var_name])) update_static_shape_var_node.append( gast.Assign( @@ -354,7 +366,8 @@ class TensorShapeTransformer(gast.NodeTransformer): if isinstance(value_node, gast.Name): if value_node.id in self.name_to_var_shape: static_shape_var_name = unique_name.generate( - target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX) + replace_dot(target_id) + + STATIC_CONVERT_VAR_SHAPE_SUFFIX) static_shape_var_node = gast.parse( static_shape_var_name).body[0].value static_shape_value_name = self.name_to_var_shape[ @@ -370,17 +383,20 @@ class TensorShapeTransformer(gast.NodeTransformer): self.name_to_var_shape[target_id] = static_shape_var_name elif self._is_var_shape(value_node): # eg: x.shape or x.shape[0] static_shape_var_name = unique_name.generate( - target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX) + replace_dot(target_id) + STATIC_CONVERT_VAR_SHAPE_SUFFIX) static_shape_var_node = gast.parse(static_shape_var_name).body[ 0].value static_shape_value_node = copy.deepcopy(value_node) # x.shape becomes convert_var_shape_simple(x) static_shape_value_node = ShapeAttributeTransformer().visit( static_shape_value_node) + # Declare static_shape_var_name as global var update_static_shape_var_node = [ + gast.Global(names=[static_shape_var_name]) + ] + update_static_shape_var_node.append( gast.Assign( targets=[static_shape_var_node], - value=static_shape_value_node) - ] + value=static_shape_value_node)) self.name_to_var_shape[target_id] = static_shape_var_name return update_static_shape_var_node diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py index 7a9bad1236f78161f31fd092785d6663244f7062..54dcc152fd6b281648991141973fc3a2b9a63f69 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py @@ -191,29 +191,44 @@ class TestChooseShapeAttrOrApi(unittest.TestCase): class TestEvaIfExistElseNone(unittest.TestCase): - def test_locals(self): + def test_globals(self): + global x_shape x_shape = [1, 2, 3] - self.assertEqual(eval_if_exist_else_none('x_shape', locals()), x_shape) + self.assertEqual(eval_if_exist_else_none('x_shape', locals()), None) + self.assertEqual(eval_if_exist_else_none('x_shape', globals()), x_shape) - def test_globals(self): + del x_shape + + def test_enclosing_scope(self): + global x_shape x_shape = [1, 2, 3] def foo(): - x_shape = [2, 3, 4] + y_shape = [2, 3, 4] + self.assertEqual( + eval_if_exist_else_none('x_shape', globals()), [1, 2, 3]) self.assertEqual( - eval_if_exist_else_none('x_shape', locals()), [2, 3, 4]) + eval_if_exist_else_none('y_shape', locals()), [2, 3, 4]) foo() + del x_shape - def test_invisible_of_func(self): + def test_global_in_func(self): x_shape = [1, 2, 3] def foo(): - x_shape = [2, 3, 4] - return x_shape + global y_shape + y_shape = [2, 3, 4] - self.assertEqual( - eval_if_exist_else_none('x_shape', locals()), [1, 2, 3]) + self.assertEqual( + eval_if_exist_else_none('y_shape', globals()), [2, 3, 4]) + self.assertEqual(eval_if_exist_else_none('x_shape', locals()), None) + self.assertEqual( + eval_if_exist_else_none('x_shape', globals()), None) + + del y_shape + + foo() def test_none(self): def foo(): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py index be571aaf2b75dd2a569b44b23e928906b2eaf196..70749c2e24447e67f267dcfe396dec18d2dcebab 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py @@ -541,5 +541,27 @@ class TestChangeShapeAfterAssign(TestTensorShapeBasic): self.expected_slice_op_num = 2 +def dyfunc_with_static_convert_var_shape(x): + # Note: this will create `batch_size__static_convert_var_shape_suffix_0` firstly. + batch_size = x.shape[0] + if len(x.shape) < 1: + res = x + else: + # Test for correctly to find `batch_size__static_convert_var_shape_suffix_0` in + # deeply nested scope. + res = fluid.layers.fill_constant( + value=8, shape=[batch_size], dtype="int32") + + return res + + +class TestFindStatiConvertVarShapeSuffixVar(unittest.TestCase): + def test(self): + x_spec = paddle.static.InputSpec(shape=[None, 10]) + func = paddle.jit.to_static(dyfunc_with_if_2, input_spec=[x_spec]) + # Call this function to trigger program translation. + func.concrete_program + + if __name__ == '__main__': unittest.main()