未验证 提交 f4a69d58 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat] Fix eval_if_exist_else_none bug (#31261) (#31277)

* fix eval_if_exist_else_none bug

* fix typo

* fix typo

* fix test_op_num unittest
上级 52f7e773
...@@ -302,9 +302,19 @@ def convert_var_shape_simple(x): ...@@ -302,9 +302,19 @@ def convert_var_shape_simple(x):
return x.shape return x.shape
def eval_if_exist_else_none(name): def eval_if_exist_else_none(name, local_symbol_table):
"""
Args:
name([str]): Expression passed into `eval`.
local_symbol_table(dict): Specified from `locals()`. DO NOT use `globals()`,
it has a higher priority and will hide away variables
from `locals()`.
Returns:
Return the variable if found in local_symbol_table else None.
"""
try: try:
return eval(name) return eval(name, local_symbol_table)
except: except:
return None return None
......
...@@ -58,7 +58,8 @@ def create_convert_shape_node(var_shape_node, ...@@ -58,7 +58,8 @@ def create_convert_shape_node(var_shape_node,
def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None): def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None):
eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}')".format( # Note(Aurelius84): Add `locals()` to help `eval` to locate the variable correctly.
eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}', locals())".format(
api_shape_name) api_shape_name)
args = [attr_shape_name, eval_exist_func] args = [attr_shape_name, eval_exist_func]
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import numpy as np import numpy as np
import paddle import paddle
import unittest import unittest
from paddle.jit.dy2static.convert_operators import eval_if_exist_else_none
class CallNotExist(paddle.nn.Layer): class CallNotExist(paddle.nn.Layer):
...@@ -189,5 +190,61 @@ class TestChooseShapeAttrOrApi(unittest.TestCase): ...@@ -189,5 +190,61 @@ class TestChooseShapeAttrOrApi(unittest.TestCase):
paddle.shape(x)) paddle.shape(x))
class TestEvaIfExistElseNone(unittest.TestCase):
def test_locals(self):
x_shape = [1, 2, 3]
self.assertEqual(eval_if_exist_else_none('x_shape', locals()), x_shape)
def test_globals(self):
x_shape = [1, 2, 3]
def foo():
x_shape = [2, 3, 4]
self.assertEqual(
eval_if_exist_else_none('x_shape', locals()), [2, 3, 4])
foo()
def test_invisible_of_func(self):
x_shape = [1, 2, 3]
def foo():
x_shape = [2, 3, 4]
return x_shape
self.assertEqual(
eval_if_exist_else_none('x_shape', locals()), [1, 2, 3])
def test_none(self):
def foo():
x_shape = [2, 3, 4]
return x_shape
self.assertEqual(eval_if_exist_else_none('x_shape', locals()), None)
class ShapeLayer(paddle.nn.Layer):
def __init__(self):
super(ShapeLayer, self).__init__()
@paddle.jit.to_static(input_spec=[paddle.static.InputSpec(shape=[None, 1])])
def forward(self, x):
x = paddle.reshape(x, [-1, x.shape[1]])
bs = x.shape[0] # -1
# for trigger choos_shape_attr_or_api
out = paddle.zeros([bs, 1], dtype='float32')
return out
class TestChooseShapeAttrOrApiWithLayer(unittest.TestCase):
def test_tensor_shape(self):
x = paddle.zeros(shape=[4, 1], dtype='float32')
net = ShapeLayer()
out = net(x)
self.assertTrue(np.array_equal(out.numpy(), x.numpy()))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -484,7 +484,7 @@ class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape): ...@@ -484,7 +484,7 @@ class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape):
self.dygraph_func = dyfunc_with_if_1 self.dygraph_func = dyfunc_with_if_1
def _set_expected_op_num(self): def _set_expected_op_num(self):
self.expected_op_num = 19 self.expected_op_num = 28
self.expected_shape_op_num = 4 self.expected_shape_op_num = 4
self.expected_slice_op_num = 2 self.expected_slice_op_num = 2
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册