Unverified commit 6bf02a12, authored by Huihuang Zheng, committed by GitHub

[Dy2stat] Fix Read-Only Attribute as while_loop Output (#31415)

Fix Read-Only Attribute as while_loop Output:

Usually, our generated convert_while_loop call looks like:
```
    [a, b, c] = paddle.jit.dy2static.convert_while_loop(
            condition_name, body_name, [a, b, c])
```
where a, b, c are in loop_var_names.

However, if loop_var_names contains an attribute such as foo.x, we cannot
assign the attribute as an output of convert_while_loop because a Python
property is a kind of read-only attribute. To handle this case, we replace
the attributes which are outputs of convert_while_loop with generated
variables; then, if we find at runtime that the attribute is not read-only,
we assign it. The created statements are like:
```
    [a, b, __attribute_variable_1] = paddle.jit.dy2static.convert_while_loop(
            condition_name, body_name, [a, b, foo.x])
    if not isinstance(getattr(type(foo), 'x', None), property): foo.x = __attribute_variable_1
```
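For context, here is a minimal plain-Python sketch of why the runtime guard is needed; `Foo` and `new_value` are illustrative names, not part of the change:

```python
class Foo(object):
    @property
    def x(self):            # read-only: no setter is defined
        return 10


foo = Foo()
new_value = 42              # stands in for __attribute_variable_1

# The generated guard assigns only when the attribute is NOT a property.
if not isinstance(getattr(type(foo), 'x', None), property):
    foo.x = new_value       # skipped here; it would raise AttributeError
print(foo.x)                # still 10; the guard avoided the error
```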
Parent 5b4f8aac
```diff
@@ -39,8 +39,35 @@ FOR_CONDITION_PREFIX = 'for_loop_condition'
 FOR_BODY_PREFIX = 'for_loop_body'
 GENERATE_VARIABLE_PREFIX = 'generate_variable'
+ATTRIBUTE_VARIABLE_PREFIX = '__attribute_variable'
 
 
-def create_while_node(condition_name, body_name, loop_var_names):
+def create_while_nodes(condition_name, body_name, loop_var_names):
+    """
+    Returns a list of gast.Node which represents the call to Paddle's
+    control-flow while_loop.
+
+    Usually, the list contains just one statement, such as:
+
+        [a, b, c] = paddle.jit.dy2static.convert_while_loop(
+                condition_name, body_name, [a, b, c])
+
+    where a, b, c are in loop_var_names.
+
+    However, if loop_var_names contains an attribute such as foo.x, we cannot
+    assign the attribute as an output of convert_while_loop because a Python
+    property is a kind of read-only attribute. To handle this case, we
+    replace the attributes which are outputs of convert_while_loop with
+    generated variables; then, if we find at runtime that the attribute is
+    not read-only, we assign it. The created statements are like:
+
+        [a, b, __attribute_variable_1] = paddle.jit.dy2static.convert_while_loop(
+                condition_name, body_name, [a, b, foo.x])
+        if not isinstance(getattr(type(foo), 'x', None), property): foo.x = __attribute_variable_1
+
+    There can be more than one such statement, which is why the return type
+    is a list of gast.Node.
+    """
     # NOTE(liym27):
     # It's better to parse the source code into an AST node than to customize an AST node
     # including child nodes, because it is easy to mistake the ast node type when customizing the node.
@@ -48,14 +75,37 @@ def create_while_node(condition_name, body_name, loop_var_names):
     # For example: loop_var_names = [a, b, foo.x], the type of `a` or `b` is gast.Name,
     # but the type of `foo.x` is gast.Attribute.
+    unique_name_to_origin = {}
+    # We have to keep loop_var_names and assign_loop_var_names in the same
+    # order; a set has no order, so we convert it to a list first.
+    loop_var_names = list(loop_var_names)
+    assign_loop_var_names = []
+    for name in loop_var_names:
+        if "." in name:
+            # name is an attribute variable such as foo.x
+            tmp_attr_name = unique_name.generate(ATTRIBUTE_VARIABLE_PREFIX)
+            unique_name_to_origin[tmp_attr_name] = name
+            assign_loop_var_names.append(tmp_attr_name)
+        else:
+            assign_loop_var_names.append(name)
     while_func_name = "paddle.jit.dy2static.convert_while_loop"
     while_node_str = "[{}] = {}({}, {}, [{}])".format(
-        ",".join(loop_var_names), while_func_name, condition_name, body_name,
-        ",".join(loop_var_names))
+        ",".join(assign_loop_var_names), while_func_name, condition_name,
+        body_name, ",".join(loop_var_names))
     while_node = gast.parse(while_node_str).body[0]
-    return while_node
+    ret = [while_node]
+    for tmp_attr_name in unique_name_to_origin:
+        origin_attr_var = unique_name_to_origin[tmp_attr_name]
+        dot_pos = origin_attr_var.rindex(".")
+        obj_name = origin_attr_var[0:dot_pos]
+        attr_name = origin_attr_var[dot_pos + 1:]
+        assign_if_not_prop_str = "if not isinstance(getattr(type({}), '{}', None), property): {} = {}".format(
+            obj_name, attr_name, origin_attr_var, tmp_attr_name)
+        assign_if_not_prop_node = gast.parse(assign_if_not_prop_str).body[0]
+        ret.append(assign_if_not_prop_node)
+    return ret
 
 
 class NameVisitor(gast.NodeVisitor):
```
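As a side note, here is a small standalone sketch of the parse-then-splice technique the NOTE above recommends, assuming the `gast` package is installed (`foo` and `__attribute_variable_1` are illustrative names):

```python
import gast

# Build the guarded assignment as source text, then parse it into an AST
# node, the same way create_while_nodes builds assign_if_not_prop_node.
stmt = ("if not isinstance(getattr(type(foo), 'x', None), property): "
        "foo.x = __attribute_variable_1")
node = gast.parse(stmt).body[0]
print(type(node).__name__)  # "If": a ready-made gast node to append to ret
```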
```diff
@@ -573,9 +623,9 @@ class LoopTransformer(gast.NodeTransformer):
         new_stmts.append(body_func_node)
 
         # 7. create & append while loop node
-        while_loop_node = create_while_node(condition_func_node.name,
-                                            body_func_node.name, loop_var_names)
-        new_stmts.append(while_loop_node)
+        while_loop_nodes = create_while_nodes(
+            condition_func_node.name, body_func_node.name, loop_var_names)
+        new_stmts.extend(while_loop_nodes)
         return new_stmts
 
@@ -655,7 +705,7 @@ class LoopTransformer(gast.NodeTransformer):
             name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
         new_stmts.append(body_func_node)
 
-        while_loop_node = create_while_node(condition_func_node.name,
-                                            body_func_node.name, loop_var_names)
-        new_stmts.append(while_loop_node)
+        while_loop_nodes = create_while_nodes(
+            condition_func_node.name, body_func_node.name, loop_var_names)
+        new_stmts.extend(while_loop_nodes)
         return new_stmts
```
```diff
@@ -340,8 +340,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
                 static_shape_value_node = copy.deepcopy(value_node)
                 # x.shape becomes convert_var_shape_simple(x)
-                ShapeAttributeTransformer().visit(
-                    static_shape_value_node)
+                static_shape_value_node = ShapeAttributeTransformer(
+                ).visit(static_shape_value_node)
                 index_value_node = gast.Constant(value=idx, kind=None)
                 slice_index_node = gast.Index(value=index_value_node)
                 sub_node = gast.Subscript(
 
@@ -382,7 +382,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
                 0].value
             static_shape_value_node = copy.deepcopy(value_node)
             # x.shape becomes convert_var_shape_simple(x)
-            ShapeAttributeTransformer().visit(static_shape_value_node)
+            static_shape_value_node = ShapeAttributeTransformer().visit(
+                static_shape_value_node)
             update_static_shape_var_node = [
                 gast.Assign(
                     targets=[static_shape_var_node],
```
```diff
@@ -1098,6 +1098,10 @@ def assign_skip_lod_tensor_array(input, output):
     """
     Assign input to output, but skip the process of copying LoDTensorArray unless it's created in while_block.
     """
+    if not isinstance(input, Variable) and not isinstance(input, core.VarBase):
+        output = input
+        return
+
     if input.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
         main_program = input.block.program
         parent_block = main_program.block(main_program.current_block()
```
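The early return above matters for loop variables that are plain Python values rather than framework Variables, such as the list returned by `i.shape` in the new test below: those values have no `.type` attribute to inspect. A minimal stand-in for the behavior, using a hypothetical `FakeVariable` class in place of Paddle's `Variable`/`core.VarBase`:

```python
# FakeVariable is a hypothetical stand-in; the real code checks
# isinstance(input, Variable) / isinstance(input, core.VarBase).
class FakeVariable(object):
    type = "LOD_TENSOR"


def assign_skip(input):
    if not isinstance(input, FakeVariable):
        return input                    # pass plain Python values through
    return "assign({})".format(input.type)


print(assign_skip([1]))                 # a list like `i.shape` -> [1]
print(assign_skip(FakeVariable()))      # -> assign(LOD_TENSOR)
```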
```diff
@@ -17,6 +17,7 @@ from __future__ import print_function
 import gast
 import inspect
 import numpy as np
+import paddle
 import paddle.fluid as fluid
 import unittest
 
@@ -157,6 +158,16 @@ def while_loop_class_var(x):
     return foo.c
 
 
+def loop_var_contains_property(x):
+    a = paddle.zeros(shape=[1], dtype='float32')
+    i = paddle.to_tensor(x)
+    s = i.shape
+    while i < 10 and s[0] >= 1:
+        a += i.shape[0]
+        i += 1
+    return a
+
+
 def for_loop_class_var(max_len):
     class Foo(object):
         def __init__(self):
 
@@ -240,9 +251,7 @@ class TestNameVisitor(unittest.TestCase):
         name_visitor = NameVisitor(gast_root)
         self.loop_var_names = [
-            set(["j", "two"]),
-            set(["i", "three", "b"]),
-            set(["i", "j"]),
+            set(["j", "two"]), set(["i", "three", "b"]), set(["i", "j"])
         ]
         self.create_var_names = [set(), set(["b"]), set()]
 
@@ -326,6 +335,11 @@ class TestWhileLoopClassVar(TestTransformWhileLoop):
         self.dyfunc = while_loop_class_var
 
 
+class TestLoopVarContainsProperty(TestTransformWhileLoop):
+    def _init_dyfunc(self):
+        self.dyfunc = loop_var_contains_property
+
+
 class TestTransformForLoop(unittest.TestCase):
     def setUp(self):
         self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
```
```diff
@@ -144,11 +144,6 @@ def dyfunc_with_for_2(x):
 def dyfunc_with_for_3(x):
-    # TODO(liym27):
-    # It will fail to run because `for i in range(len(x.shape))` will be transformed into Paddle while_loop.
-    # Here the python list x.shape will be added to loop_vars. However, loop_vars doesn't support python list.
-    # And the condition of `for i in range(len(x.shape))` only uses the length of x.shape, so it doesn't have to be transformed into Paddle while_loop.
-    # After the AST tranformation of for loop is improved, add TestTensorShapeInFor3.
     x = fluid.dygraph.to_variable(x)
     res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")
     # `len(x.shape)` is not transformed.
 
@@ -282,6 +277,11 @@ class TestTensorShapeBasic2(TestTensorShapeBasic):
     def init_test_func(self):
         self.dygraph_func = dyfunc_tensor_shape_2
 
+    def _set_expected_op_num(self):
+        self.expected_op_num = 3
+        self.expected_shape_op_num = 1
+        self.expected_slice_op_num = 0
+
 
 class TestTensorShapeBasic3(TestTensorShapeBasic):
     def init_test_func(self):
@@ -319,6 +319,11 @@ class TestTupleShape1(TestTensorShapeBasic):
         self.input_spec = [paddle.static.InputSpec(shape=[5, 7], dtype="int32")]
         self.dygraph_func = dyfunc_tuple_shape_1
 
+    def _set_expected_op_num(self):
+        self.expected_op_num = 6
+        self.expected_shape_op_num = 2
+        self.expected_slice_op_num = 2
+
 
 class TestTupleShape2(TestTensorShapeBasic):
     def init_test_func(self):
 
@@ -326,6 +331,11 @@ class TestTupleShape2(TestTensorShapeBasic):
         self.input_spec = [paddle.static.InputSpec(shape=[5, 7], dtype="int32")]
         self.dygraph_func = dyfunc_tuple_shape_2
 
+    def _set_expected_op_num(self):
+        self.expected_op_num = 5
+        self.expected_shape_op_num = 1
+        self.expected_slice_op_num = 2
+
 
 class TestPaddleShapeApi(TestTensorShapeBasic):
     def init_test_func(self):
 
@@ -381,6 +391,16 @@ class TestTensorShapeInFor2(TestTensorShapeInFor1):
         self.expected_slice_op_num = 1
 
 
+class TestTensorShapeInFor3(TestTensorShapeInFor1):
+    def init_test_func(self):
+        self.dygraph_func = dyfunc_with_for_3
+
+    def _set_expected_op_num(self):
+        self.expected_op_num = 25
+        self.expected_shape_op_num = 6
+        self.expected_slice_op_num = 3
+
+
 # 4. Tests with control flow while loop
 class TestTensorShapeInWhile1(TestTensorShapeInFor1):
     def init_test_func(self):
@@ -402,8 +422,8 @@ class TestTensorShapeInWhile3(TestTensorShapeBasic):
         self.dygraph_func = dyfunc_with_while_3
 
     def _set_expected_op_num(self):
-        self.expected_op_num = 2
-        self.expected_shape_op_num = 0
+        self.expected_op_num = 3
+        self.expected_shape_op_num = 1
         self.expected_slice_op_num = 0
 
@@ -474,9 +494,9 @@ class TestOpNumWithTensorShapeTuple1(TestOpNumBasicWithTensorShape):
         self.dygraph_func = dyfunc_tuple_shape_1
 
     def _set_expected_op_num(self):
-        self.expected_op_num = 2
-        self.expected_shape_op_num = 0
-        self.expected_slice_op_num = 0
+        self.expected_op_num = 7
+        self.expected_shape_op_num = 2
+        self.expected_slice_op_num = 2
 
 
 class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape):
 
@@ -516,9 +536,9 @@ class TestChangeShapeAfterAssign(TestTensorShapeBasic):
         self.dygraph_func = dyfunc_change_shape_after_assign
 
     def _set_expected_op_num(self):
-        self.expected_op_num = 3
-        self.expected_shape_op_num = 0
-        self.expected_slice_op_num = 0
+        self.expected_op_num = 7
+        self.expected_shape_op_num = 2
+        self.expected_slice_op_num = 2
 
 
 if __name__ == '__main__':
```