未验证 提交 036121b7 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2stat] Make loop_transformer supports class variable (#23478)

This CR makes two changes:

1. In the old loop_transformer, if a class variable such as "self.a" or "foo.bar" is a loop var, Dy2stat will fail because `def func(self.foo)` is not legal syntax. We support class variables by renaming them.

2. After https://github.com/PaddlePaddle/Paddle/pull/22892 was merged, `while x < 10` is supported in dygraph, so I enabled the previously skipped checks in the corresponding Dy2stat unit tests.
上级 abe3e690
......@@ -26,6 +26,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node
......@@ -36,6 +37,7 @@ WHILE_BODY_PREFIX = 'while_body'
FOR_CONDITION_PREFIX = 'for_loop_condition'
FOR_BODY_PREFIX = 'for_loop_body'
GENERATE_VARIABLE_PREFIX = 'generate_variable'
def create_while_node(condition_name, body_name, loop_var_names):
......@@ -440,7 +442,8 @@ class LoopTransformer(gast.NodeTransformer):
#
# We need to create static variable for those variables
for name in create_var_names:
new_stmts.append(create_static_variable_gast_node(name))
if "." not in name:
new_stmts.append(create_static_variable_gast_node(name))
new_stmts.append(init_stmt)
......@@ -468,6 +471,11 @@ class LoopTransformer(gast.NodeTransformer):
decorator_list=[],
returns=None,
type_comment=None)
for name in loop_var_names:
if "." in name:
rename_transformer = RenameTransformer(condition_func_node)
rename_transformer.rename(
name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
new_stmts.append(condition_func_node)
new_body = node.body
......@@ -495,6 +503,11 @@ class LoopTransformer(gast.NodeTransformer):
decorator_list=[],
returns=None,
type_comment=None)
for name in loop_var_names:
if "." in name:
rename_transformer = RenameTransformer(body_func_node)
rename_transformer.rename(
name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
new_stmts.append(body_func_node)
while_loop_node = create_while_node(condition_func_node.name,
......@@ -521,7 +534,8 @@ class LoopTransformer(gast.NodeTransformer):
#
# We need to create static variable for those variables
for name in create_var_names:
new_stmts.append(create_static_variable_gast_node(name))
if "." not in name:
new_stmts.append(create_static_variable_gast_node(name))
# while x < 10 in dygraph should be convert into static tensor < 10
for name in loop_var_names:
......@@ -550,6 +564,11 @@ class LoopTransformer(gast.NodeTransformer):
decorator_list=[],
returns=None,
type_comment=None)
for name in loop_var_names:
if "." in name:
rename_transformer = RenameTransformer(condition_func_node)
rename_transformer.rename(
name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
new_stmts.append(condition_func_node)
new_body = node.body
......@@ -576,6 +595,11 @@ class LoopTransformer(gast.NodeTransformer):
decorator_list=[],
returns=None,
type_comment=None)
for name in loop_var_names:
if "." in name:
rename_transformer = RenameTransformer(body_func_node)
rename_transformer.rename(
name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
new_stmts.append(body_func_node)
while_loop_node = create_while_node(condition_func_node.name,
......
......@@ -352,6 +352,43 @@ def index_in_list(array_list, item):
return -1
def create_assign_node(name, node):
    """
    Creates a `gast.Assign` node by given name_id as target and node as value.
    """
    store_target = generate_name_node(name, ctx=gast.Store())
    return store_target, gast.Assign(targets=[store_target], value=node)
class RenameTransformer(gast.NodeTransformer):
    """AST pass that rewrites every occurrence of one variable name.

    Works for plain names (``x``) and dotted attribute chains
    (``self.x``, ``foo.bar``): a matching ``gast.Name`` has its id changed
    in place, while a ``gast.Attribute`` whose full dotted name matches is
    replaced by a freshly parsed expression for the new name.
    """

    def __init__(self, node):
        assert isinstance(
            node, gast.AST), "RenameTransformer only accepts gast.AST as input"
        self.root = node
        self.old_name = ""
        self.new_name = ""

    def rename(self, old_name, new_name):
        """Rewrite all occurrences of *old_name* to *new_name* under root."""
        self.old_name, self.new_name = old_name, new_name
        self.visit(self.root)

    def visit_Name(self, node):
        # Rename children first, then this node if it matches.
        self.generic_visit(node)
        if node.id != self.old_name:
            return node
        node.id = self.new_name
        return node

    def visit_Attribute(self, node):
        self.generic_visit(node)
        if get_attribute_full_name(node) != self.old_name:
            return node
        # Parse the replacement so a dotted new name also yields a valid tree.
        return gast.parse(self.new_name).body[0].value
def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
"""
Transform modified AST of decorated function into python callable object.
......@@ -399,12 +436,3 @@ def ast_to_source_code(ast_node):
ast_node = gast.gast_to_ast(ast_node)
source_code = astor.to_source(ast_node)
return source_code
# NOTE(review): this duplicates the `create_assign_node` defined earlier in this
# file (the diff relocates the function); only one copy should remain.
def create_assign_node(name, node):
    """
    Creates a `gast.Assign` node by given name_id as target and node as value.
    """
    targets = generate_name_node(name, ctx=gast.Store())
    assign_node = gast.Assign(targets=[targets], value=node)
    return targets, assign_node
......@@ -233,3 +233,20 @@ def if_with_and_or_4(x, y=None):
mean_res.numpy()[0] > 0):
x = x - 1
return x
def if_with_class_var(x, y=None):
    """Dy2stat test fixture: an `if` whose condition and branches read instance
    attributes (`foo.a`, `foo.b`), exercising class-variable support in the
    if-else transformer. The exact code shape is the input under test, so it
    must not be restructured.
    """

    class Foo(object):
        def __init__(self):
            self.a = 1
            self.b = 2

    foo = Foo()
    batch_size = fluid.layers.shape(x)
    # NOTE(review): `mean_res` is computed but never used in this function —
    # presumably kept so the transformer sees an extra tensor; confirm.
    mean_res = fluid.layers.mean(x)
    if batch_size[0] > foo.a:
        x = x + foo.b
    else:
        x = x - foo.b
    return x
......@@ -124,6 +124,26 @@ def test_for_in_else(x):
return x
def while_loop_class_var(x):
    """Dy2stat test fixture: a while loop with `continue`/`break` whose loop
    variables include instance attributes (`foo.b`, `foo.c`) — exercises
    class-variable renaming in the loop transformer. The code shape is the
    input under test.
    """

    class Foo(object):
        def __init__(self):
            self.a = 3
            self.b = 4
            self.c = 5

    foo = Foo()
    i = fluid.dygraph.to_variable(x)
    while i < 10:
        foo.b = fluid.layers.zeros(shape=[1], dtype='float32')
        foo.c = foo.b + foo.a
        i += 1
        if foo.c < 0:
            continue
        if foo.c > 6:
            break
    return foo.c
class TestContinueInFor(unittest.TestCase):
def setUp(self):
self.input = np.zeros((1)).astype('int32')
......@@ -186,24 +206,15 @@ class TestContinueInWhile(TestContinueInFor):
def init_dygraph_func(self):
    # Use the `continue`-inside-while fixture for this test case.
    self.dygraph_func = test_continue_in_while
def test_transformed_static_result(self):
    # TODO: while i < 10 in dygraph will be supported after PR22892
    # so currently we just assert static result.
    # remove this overrided function after PR22892 is merged
    # NOTE(review): this override is what the commit deletes — PR22892 is
    # merged, so the base-class dygraph-vs-static comparison should run.
    static_res = self.run_static_mode()
    self.assertEqual(15, static_res[0])
class TestBreakInWhile(TestContinueInWhile):
    # Same scenario as the parent, but the fixture uses `break` in the loop.
    def init_dygraph_func(self):
        self.dygraph_func = test_break_in_while

    def test_transformed_static_result(self):
        # TODO: while i < 10 in dygraph will be supported after PR22892
        # so currently we just assert static result.
        # remove this overrided function after PR22892 is merged
        static_res = self.run_static_mode()
        self.assertEqual(15, static_res[0])
class TestWhileLoopClassVar(TestContinueInWhile):
    # Runs the continue/break pipeline on the fixture whose loop variables are
    # instance attributes (class-variable support in loop_transformer).
    def init_dygraph_func(self):
        self.dygraph_func = while_loop_class_var
if __name__ == '__main__':
......
......@@ -116,7 +116,7 @@ class TestDygraphIfElse6(TestDygraphIfElse):
self.dyfunc = dyfunc_ifExp_with_while
def dyfunc_ifExp_with_while2(x):
def dyfunc_ifExp(x):
y = [x]
def add_fn(x):
......@@ -128,16 +128,16 @@ def dyfunc_ifExp_with_while2(x):
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
# It will be converted into `layers.cond` as followed.
# map_func(lambda x: fluid.layers.cond(i==0, lambda: x, lambda: add_fn(x), y)
# `i (Tensor) == 0` is supported in dygraph.
y = map_func(lambda x: x if i == 0 else add_fn(x), y)
# map_func(lambda x: fluid.layers.cond(i==1, lambda: x, lambda: add_fn(x), y)
# `if (Tensor) == 1` is supported in dygraph.
y = map_func(lambda x: x if i == 1 else add_fn(x), y)
return y[0]
class TestDygraphIfElse7(TestDygraphIfElse):
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        # NOTE(review): two consecutive assignments to self.dyfunc — this looks
        # like a diff artifact (old line rendered next to its replacement);
        # only the last one, `dyfunc_ifExp`, takes effect. Confirm and drop
        # the first assignment.
        self.dyfunc = dyfunc_ifExp_with_while2
        self.dyfunc = dyfunc_ifExp
class TestDygraphIfElseWithAndOr(TestDygraphIfElse):
......@@ -170,6 +170,12 @@ class TestDygraphIfElseWithAndOr4(TestDygraphIfElse):
self.dyfunc = if_with_and_or_4
class TestDygraphIfElseWithClassVar(TestDygraphIfElse):
    # Exercises if/else transformation when condition/branches read instance
    # attributes (see `if_with_class_var`).
    def setUp(self):
        self.x = np.random.random([10, 16]).astype('float32')
        self.dyfunc = if_with_class_var
class TestDygraphIfElseNet(unittest.TestCase):
"""
TestCase for the transformation from control flow `if/else`
......
......@@ -78,6 +78,36 @@ def while_loop_bool_op(x):
return i
def while_loop_class_var(x):
    """Dy2stat test fixture: a while loop whose loop variables include
    instance attributes (`foo.b`, `foo.c`), exercising class-variable
    renaming in loop_transformer. The code shape is the input under test.
    """

    class Foo(object):
        def __init__(self):
            self.a = 3
            self.b = 4
            self.c = 5

    foo = Foo()
    i = fluid.dygraph.to_variable(x)
    while i < 10:
        foo.b = fluid.layers.zeros(shape=[1], dtype='float32')
        foo.c = foo.b + foo.a
        i += 1
    return foo.c
def for_loop_class_var(max_len):
    """Dy2stat test fixture: a for-range loop writing instance attributes
    (`foo.b`, `foo.c`) each iteration — class-variable support in the
    for-loop transformer. The code shape is the input under test.
    """

    class Foo(object):
        def __init__(self):
            self.a = 3
            self.b = 4
            self.c = 5

    foo = Foo()
    for i in range(max_len):
        foo.b = fluid.layers.zeros(shape=[1], dtype='float32')
        foo.c = foo.b + foo.a
    return foo.c
def var_create_in_for_loop(max_len):
for i in range(max_len):
ret = fluid.layers.zeros(shape=[3, 4, 5], dtype='float64')
......@@ -136,15 +166,8 @@ class TestTransformWhileLoop(unittest.TestCase):
def test_ast_to_func(self):
    static_numpy = self._run_static()
    # NOTE(review): the hard-coded `45` assertion and the commented-out lines
    # below appear to be the old (removed) side of the diff; with dygraph
    # `while x < 10` now supported, the dygraph-vs-static comparison at the
    # bottom is the intended check. Confirm and delete the stale lines.
    self.assertTrue(
        np.allclose(
            np.full(
                shape=(1), fill_value=45, dtype=np.int32), static_numpy))
    # Enable next lines after Paddle dygraph supports while x < 10
    #
    # self._run_dygraph()
    # self.assertTrue(np.allclose(self._run_dygraph(), self._run_static()))
    dygraph_numpy = self._run_dygraph()
    self.assertTrue(np.allclose(dygraph_numpy, static_numpy))
class TestTransformWhileLoopWithConflicVar(TestTransformWhileLoop):
......@@ -162,6 +185,11 @@ class TestWhileLoopBoolOp(TestTransformWhileLoop):
self.dyfunc = while_loop_bool_op
class TestWhileLoopClassVar(TestTransformWhileLoop):
    # While-loop transformation where loop variables are instance attributes.
    def _init_dyfunc(self):
        self.dyfunc = while_loop_class_var
class TestTransformForLoop(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
......@@ -192,6 +220,11 @@ class TestTransformForLoop(unittest.TestCase):
self.assertTrue(np.allclose(self._run_dygraph(), self._run_static()))
class TestClassVarInForLoop(TestTransformForLoop):
    # For-loop transformation where loop variables are instance attributes.
    def _init_dyfunc(self):
        self.dyfunc = for_loop_class_var
class TestVarCreateInForLoop(TestTransformForLoop):
    # A variable first created inside the loop body must still be handled by
    # the transformer (see `var_create_in_for_loop`).
    def _init_dyfunc(self):
        self.dyfunc = var_create_in_for_loop
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册