diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index 779e50c3dc5b535316b9e176426daf5f82c7b187..403e77cb5ccd8d3faae999b36070d8be64322079 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -302,9 +302,19 @@ def convert_var_shape_simple(x): return x.shape -def eval_if_exist_else_none(name): +def eval_if_exist_else_none(name, local_symbol_table): + """ + Args: + name([str]): Expression passed into `eval`. + local_symbol_table(dict): Specified from `locals()`. DO NOT use `globals()`, + it has a higher priority and will hide away variables + from `locals()`. + + Returns: + Return the variable if found in local_symbol_table else None. + """ try: - return eval(name) + return eval(name, local_symbol_table) except: return None diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py index ddd5d84ef421253d09747518eb9618eed3f1ac0e..7cbe86b60c81e3b023e44fc4ed297b1f2e7d3078 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py @@ -58,7 +58,8 @@ def create_convert_shape_node(var_shape_node, def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None): - eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}')".format( + # Note(Aurelius84): Add `locals()` to help `eval` to locate the variable correctly. + eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}', locals())".format( api_shape_name) args = [attr_shape_name, eval_exist_func] diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py index 631cd426b32b8bc47877d1fcabb30b4a4948e199..7a9bad1236f78161f31fd092785d6663244f7062 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py @@ -15,6 +15,7 @@ import numpy as np import paddle import unittest +from paddle.jit.dy2static.convert_operators import eval_if_exist_else_none class CallNotExist(paddle.nn.Layer): @@ -189,5 +190,61 @@ class TestChooseShapeAttrOrApi(unittest.TestCase): paddle.shape(x)) +class TestEvaIfExistElseNone(unittest.TestCase): + def test_locals(self): + x_shape = [1, 2, 3] + self.assertEqual(eval_if_exist_else_none('x_shape', locals()), x_shape) + + def test_globals(self): + x_shape = [1, 2, 3] + + def foo(): + x_shape = [2, 3, 4] + self.assertEqual( + eval_if_exist_else_none('x_shape', locals()), [2, 3, 4]) + + foo() + + def test_invisible_of_func(self): + x_shape = [1, 2, 3] + + def foo(): + x_shape = [2, 3, 4] + return x_shape + + self.assertEqual( + eval_if_exist_else_none('x_shape', locals()), [1, 2, 3]) + + def test_none(self): + def foo(): + x_shape = [2, 3, 4] + return x_shape + + self.assertEqual(eval_if_exist_else_none('x_shape', locals()), None) + + +class ShapeLayer(paddle.nn.Layer): + def __init__(self): + super(ShapeLayer, self).__init__() + + @paddle.jit.to_static(input_spec=[paddle.static.InputSpec(shape=[None, 1])]) + def forward(self, x): + x = paddle.reshape(x, [-1, x.shape[1]]) + bs = x.shape[0] # -1 + + # for trigger choos_shape_attr_or_api + out = paddle.zeros([bs, 1], dtype='float32') + return out + + +class TestChooseShapeAttrOrApiWithLayer(unittest.TestCase): + def test_tensor_shape(self): + x = paddle.zeros(shape=[4, 1], dtype='float32') + net = ShapeLayer() + out = net(x) + + self.assertTrue(np.array_equal(out.numpy(), x.numpy())) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py index d28864aade5ce64f980f8fdb38b616ce978b65ec..b84a13be9b3213d18ce8087db5bb0454aff3c1b1 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py @@ -484,7 +484,7 @@ class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape): self.dygraph_func = dyfunc_with_if_1 def _set_expected_op_num(self): - self.expected_op_num = 19 + self.expected_op_num = 28 self.expected_shape_op_num = 4 self.expected_slice_op_num = 2