未验证 提交 def27bc8 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat]Fix bug with static_convert_var_shape in locals scope (#31556)

* Fix bug with static_convert_var_shape

* replace dot with dash
上级 49c3d2a9
......@@ -302,19 +302,19 @@ def convert_var_shape_simple(x):
return x.shape
def eval_if_exist_else_none(name, local_symbol_table):
def eval_if_exist_else_none(name, global_symbol_table):
"""
Args:
name([str]): Expression passed into `eval`.
local_symbol_table(dict): Specified from `locals()`. DO NOT use `globals()`,
it has a higher priority and will hide away variables
from `locals()`.
local_symbol_table(dict): Specified from `globals()`. DO NOT use `locals()`,
because all STATIC_CONVERT_VAR_SHAPE_SUFFIX vars is
declared with keyword `global`.
Returns:
Return the variable if found in local_symbol_table else None.
Return the variable if found in global_symbol_table else None.
"""
try:
return eval(name, local_symbol_table)
return eval(name, global_symbol_table)
except:
return None
......
......@@ -59,7 +59,7 @@ def create_convert_shape_node(var_shape_node,
def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None):
eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}', locals())".format(
eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}', globals())".format(
api_shape_name)
args = [attr_shape_name, eval_exist_func]
......@@ -293,6 +293,10 @@ class TensorShapeTransformer(gast.NodeTransformer):
return False
def _update_name_to_var_shape(self, node):
def replace_dot(name):
# replace all '.' into '_'
return name.replace('.', '_')
assert isinstance(node, gast.Assign)
target_node = node.targets[0]
value_node = node.value
......@@ -307,7 +311,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
if value_node.id in self.name_to_var_shape:
# TODO(zhhsplendid): is context a problem for the result node of gast.parse?
static_shape_var_name = unique_name.generate(
target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
replace_dot(target_id) +
STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(
static_shape_var_name).body[0].value
......@@ -328,7 +333,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
if isinstance(value_node, gast.Attribute):
if self._is_var_shape(value_node): # eg: x.shape
static_shape_var_name = unique_name.generate(
target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
replace_dot(target_id) +
STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(
static_shape_var_name).body[0].value
......@@ -341,6 +347,12 @@ class TensorShapeTransformer(gast.NodeTransformer):
ast_to_source_code(static_shape_value_node).strip(),
idx)
sub_node = gast.parse(sub_node_str).body[0].value
# Note(Aurelius84): Becuase static_shape_var_name is used in
# eval_if_exist_else_none() as plain string, so it will not
# be pasred as argument in convert_loop/ifelse. We delcare it
# as global var because it has unique name.
update_static_shape_var_node.append(
gast.Global(names=[static_shape_var_name]))
update_static_shape_var_node.append(
gast.Assign(
......@@ -354,7 +366,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
if isinstance(value_node, gast.Name):
if value_node.id in self.name_to_var_shape:
static_shape_var_name = unique_name.generate(
target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
replace_dot(target_id) +
STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(
static_shape_var_name).body[0].value
static_shape_value_name = self.name_to_var_shape[
......@@ -370,17 +383,20 @@ class TensorShapeTransformer(gast.NodeTransformer):
self.name_to_var_shape[target_id] = static_shape_var_name
elif self._is_var_shape(value_node): # eg: x.shape or x.shape[0]
static_shape_var_name = unique_name.generate(
target_id + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
replace_dot(target_id) + STATIC_CONVERT_VAR_SHAPE_SUFFIX)
static_shape_var_node = gast.parse(static_shape_var_name).body[
0].value
static_shape_value_node = copy.deepcopy(value_node)
# x.shape becomes convert_var_shape_simple(x)
static_shape_value_node = ShapeAttributeTransformer().visit(
static_shape_value_node)
# Declare static_shape_var_name as global var
update_static_shape_var_node = [
gast.Global(names=[static_shape_var_name])
]
update_static_shape_var_node.append(
gast.Assign(
targets=[static_shape_var_node],
value=static_shape_value_node)
]
value=static_shape_value_node))
self.name_to_var_shape[target_id] = static_shape_var_name
return update_static_shape_var_node
......@@ -191,29 +191,44 @@ class TestChooseShapeAttrOrApi(unittest.TestCase):
class TestEvaIfExistElseNone(unittest.TestCase):
def test_locals(self):
def test_globals(self):
global x_shape
x_shape = [1, 2, 3]
self.assertEqual(eval_if_exist_else_none('x_shape', locals()), x_shape)
self.assertEqual(eval_if_exist_else_none('x_shape', locals()), None)
self.assertEqual(eval_if_exist_else_none('x_shape', globals()), x_shape)
def test_globals(self):
del x_shape
def test_enclosing_scope(self):
global x_shape
x_shape = [1, 2, 3]
def foo():
x_shape = [2, 3, 4]
y_shape = [2, 3, 4]
self.assertEqual(
eval_if_exist_else_none('x_shape', globals()), [1, 2, 3])
self.assertEqual(
eval_if_exist_else_none('x_shape', locals()), [2, 3, 4])
eval_if_exist_else_none('y_shape', locals()), [2, 3, 4])
foo()
del x_shape
def test_invisible_of_func(self):
def test_global_in_func(self):
x_shape = [1, 2, 3]
def foo():
x_shape = [2, 3, 4]
return x_shape
global y_shape
y_shape = [2, 3, 4]
self.assertEqual(
eval_if_exist_else_none('x_shape', locals()), [1, 2, 3])
self.assertEqual(
eval_if_exist_else_none('y_shape', globals()), [2, 3, 4])
self.assertEqual(eval_if_exist_else_none('x_shape', locals()), None)
self.assertEqual(
eval_if_exist_else_none('x_shape', globals()), None)
del y_shape
foo()
def test_none(self):
def foo():
......
......@@ -541,5 +541,27 @@ class TestChangeShapeAfterAssign(TestTensorShapeBasic):
self.expected_slice_op_num = 2
def dyfunc_with_static_convert_var_shape(x):
# Note: this will create `batch_size__static_convert_var_shape_suffix_0` firstly.
batch_size = x.shape[0]
if len(x.shape) < 1:
res = x
else:
# Test for correctly to find `batch_size__static_convert_var_shape_suffix_0` in
# deeply nested scope.
res = fluid.layers.fill_constant(
value=8, shape=[batch_size], dtype="int32")
return res
class TestFindStatiConvertVarShapeSuffixVar(unittest.TestCase):
def test(self):
x_spec = paddle.static.InputSpec(shape=[None, 10])
func = paddle.jit.to_static(dyfunc_with_if_2, input_spec=[x_spec])
# Call this function to trigger program translation.
func.concrete_program
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册