未验证 提交 def27bc8 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat]Fix bug with static_convert_var_shape in locals scope (#31556)

* Fix bug with static_convert_var_shape

* replace dot with dash
上级 49c3d2a9
...@@ -302,19 +302,19 @@ def convert_var_shape_simple(x): ...@@ -302,19 +302,19 @@ def convert_var_shape_simple(x):
return x.shape return x.shape
def eval_if_exist_else_none(name, local_symbol_table): def eval_if_exist_else_none(name, global_symbol_table):
""" """
Args: Args:
name([str]): Expression passed into `eval`. name([str]): Expression passed into `eval`.
local_symbol_table(dict): Specified from `locals()`. DO NOT use `globals()`, local_symbol_table(dict): Specified from `globals()`. DO NOT use `locals()`,
it has a higher priority and will hide away variables because all STATIC_CONVERT_VAR_SHAPE_SUFFIX vars is
from `locals()`. declared with keyword `global`.
Returns: Returns:
Return the variable if found in local_symbol_table else None. Return the variable if found in global_symbol_table else None.
""" """
try: try:
return eval(name, local_symbol_table) return eval(name, global_symbol_table)
except: except:
return None return None
......
...@@ -59,7 +59,7 @@ def create_convert_shape_node(var_shape_node, ...@@ -59,7 +59,7 @@ def create_convert_shape_node(var_shape_node,
def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None): def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None):
eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}', locals())".format( eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}', globals())".format(
api_shape_name) api_shape_name)
args = [attr_shape_name, eval_exist_func] args = [attr_shape_name, eval_exist_func]
...@@ -293,6 +293,10 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -293,6 +293,10 @@ class TensorShapeTransformer(gast.NodeTransformer):
return False return False
def _update_name_to_var_shape(self, node): def _update_name_to_var_shape(self, node):
def replace_dot(name):
# replace all '.' into '_'
return name.replace('.', '_')
assert isinstance(node, gast.Assign) assert isinstance(node, gast.Assign)
target_node = node.targets[0] target_node = node.targets[0]
value_node = node.value value_node = node.value
...@@ -307,7 +311,8 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -307,7 +311,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
if value_node.id in self.name_to_var_shape: if value_node.id in self.name_to_var_shape:
# TODO(zhhsplendid): is context a problem for the result node of gast.parse? # TODO(zhhsplendid): is context a problem for the result node of gast.parse?
static_shape_var_name = unique_name.generate( static_shape_var_name = unique_name.generate(
target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX) replace_dot(target_id) +
STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse( static_shape_var_node = gast.parse(
static_shape_var_name).body[0].value static_shape_var_name).body[0].value
...@@ -328,7 +333,8 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -328,7 +333,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
if isinstance(value_node, gast.Attribute): if isinstance(value_node, gast.Attribute):
if self._is_var_shape(value_node): # eg: x.shape if self._is_var_shape(value_node): # eg: x.shape
static_shape_var_name = unique_name.generate( static_shape_var_name = unique_name.generate(
target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX) replace_dot(target_id) +
STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse( static_shape_var_node = gast.parse(
static_shape_var_name).body[0].value static_shape_var_name).body[0].value
...@@ -341,6 +347,12 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -341,6 +347,12 @@ class TensorShapeTransformer(gast.NodeTransformer):
ast_to_source_code(static_shape_value_node).strip(), ast_to_source_code(static_shape_value_node).strip(),
idx) idx)
sub_node = gast.parse(sub_node_str).body[0].value sub_node = gast.parse(sub_node_str).body[0].value
# Note(Aurelius84): Becuase static_shape_var_name is used in
# eval_if_exist_else_none() as plain string, so it will not
# be pasred as argument in convert_loop/ifelse. We delcare it
# as global var because it has unique name.
update_static_shape_var_node.append(
gast.Global(names=[static_shape_var_name]))
update_static_shape_var_node.append( update_static_shape_var_node.append(
gast.Assign( gast.Assign(
...@@ -354,7 +366,8 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -354,7 +366,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
if isinstance(value_node, gast.Name): if isinstance(value_node, gast.Name):
if value_node.id in self.name_to_var_shape: if value_node.id in self.name_to_var_shape:
static_shape_var_name = unique_name.generate( static_shape_var_name = unique_name.generate(
target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX) replace_dot(target_id) +
STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse( static_shape_var_node = gast.parse(
static_shape_var_name).body[0].value static_shape_var_name).body[0].value
static_shape_value_name = self.name_to_var_shape[ static_shape_value_name = self.name_to_var_shape[
...@@ -370,17 +383,20 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -370,17 +383,20 @@ class TensorShapeTransformer(gast.NodeTransformer):
self.name_to_var_shape[target_id] = static_shape_var_name self.name_to_var_shape[target_id] = static_shape_var_name
elif self._is_var_shape(value_node): # eg: x.shape or x.shape[0] elif self._is_var_shape(value_node): # eg: x.shape or x.shape[0]
static_shape_var_name = unique_name.generate( static_shape_var_name = unique_name.generate(
target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX) replace_dot(target_id) + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(static_shape_var_name).body[ static_shape_var_node = gast.parse(static_shape_var_name).body[
0].value 0].value
static_shape_value_node = copy.deepcopy(value_node) static_shape_value_node = copy.deepcopy(value_node)
# x.shape becomes convert_var_shape_simple(x) # x.shape becomes convert_var_shape_simple(x)
static_shape_value_node = ShapeAttributeTransformer().visit( static_shape_value_node = ShapeAttributeTransformer().visit(
static_shape_value_node) static_shape_value_node)
# Declare static_shape_var_name as global var
update_static_shape_var_node = [ update_static_shape_var_node = [
gast.Global(names=[static_shape_var_name])
]
update_static_shape_var_node.append(
gast.Assign( gast.Assign(
targets=[static_shape_var_node], targets=[static_shape_var_node],
value=static_shape_value_node) value=static_shape_value_node))
]
self.name_to_var_shape[target_id] = static_shape_var_name self.name_to_var_shape[target_id] = static_shape_var_name
return update_static_shape_var_node return update_static_shape_var_node
...@@ -191,29 +191,44 @@ class TestChooseShapeAttrOrApi(unittest.TestCase): ...@@ -191,29 +191,44 @@ class TestChooseShapeAttrOrApi(unittest.TestCase):
class TestEvaIfExistElseNone(unittest.TestCase): class TestEvaIfExistElseNone(unittest.TestCase):
def test_locals(self): def test_globals(self):
global x_shape
x_shape = [1, 2, 3] x_shape = [1, 2, 3]
self.assertEqual(eval_if_exist_else_none('x_shape', locals()), x_shape) self.assertEqual(eval_if_exist_else_none('x_shape', locals()), None)
self.assertEqual(eval_if_exist_else_none('x_shape', globals()), x_shape)
def test_globals(self): del x_shape
def test_enclosing_scope(self):
global x_shape
x_shape = [1, 2, 3] x_shape = [1, 2, 3]
def foo(): def foo():
x_shape = [2, 3, 4] y_shape = [2, 3, 4]
self.assertEqual(
eval_if_exist_else_none('x_shape', globals()), [1, 2, 3])
self.assertEqual( self.assertEqual(
eval_if_exist_else_none('x_shape', locals()), [2, 3, 4]) eval_if_exist_else_none('y_shape', locals()), [2, 3, 4])
foo() foo()
del x_shape
def test_invisible_of_func(self): def test_global_in_func(self):
x_shape = [1, 2, 3] x_shape = [1, 2, 3]
def foo(): def foo():
x_shape = [2, 3, 4] global y_shape
return x_shape y_shape = [2, 3, 4]
self.assertEqual( self.assertEqual(
eval_if_exist_else_none('x_shape', locals()), [1, 2, 3]) eval_if_exist_else_none('y_shape', globals()), [2, 3, 4])
self.assertEqual(eval_if_exist_else_none('x_shape', locals()), None)
self.assertEqual(
eval_if_exist_else_none('x_shape', globals()), None)
del y_shape
foo()
def test_none(self): def test_none(self):
def foo(): def foo():
......
...@@ -541,5 +541,27 @@ class TestChangeShapeAfterAssign(TestTensorShapeBasic): ...@@ -541,5 +541,27 @@ class TestChangeShapeAfterAssign(TestTensorShapeBasic):
self.expected_slice_op_num = 2 self.expected_slice_op_num = 2
def dyfunc_with_static_convert_var_shape(x):
# Note: this will create `batch_size__static_convert_var_shape_suffix_0` firstly.
batch_size = x.shape[0]
if len(x.shape) < 1:
res = x
else:
# Test for correctly to find `batch_size__static_convert_var_shape_suffix_0` in
# deeply nested scope.
res = fluid.layers.fill_constant(
value=8, shape=[batch_size], dtype="int32")
return res
class TestFindStatiConvertVarShapeSuffixVar(unittest.TestCase):
def test(self):
x_spec = paddle.static.InputSpec(shape=[None, 10])
func = paddle.jit.to_static(dyfunc_with_if_2, input_spec=[x_spec])
# Call this function to trigger program translation.
func.concrete_program
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册