未验证 提交 824572c1 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat]Support to transform sequence assignments and multi-target...

[Dy2Stat]Support to transform sequence assignments and multi-target assignments to normal assignments (#24643)
上级 586b5875
......@@ -19,6 +19,7 @@ import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code, is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
from paddle.fluid.framework import core, Variable
from paddle.fluid.layers import array_length, array_read, array_write, create_array
from paddle.fluid.layers import assign, fill_constant, slice
......@@ -108,6 +109,7 @@ class ListTransformer(gast.NodeTransformer):
self.scope_var_type_dict = var_env.get_scope_var_type()
def transform(self):
SplitAssignTransformer(self.root).transform()
self.visit(self.root)
self.replace_list_with_tensor_array(self.root)
......@@ -264,7 +266,7 @@ class ListTransformer(gast.NodeTransformer):
def _update_list_name_to_updated(self, node):
assert isinstance(node, gast.Assign)
target_node = node.targets[0]
# TODO: Consider node has more than one target. eg: x, y = a, []
# NOTE: Code like `x, y = a, []` has been transformed to `x=a; y=[]`
try:
target_id = target_node.id
except AttributeError:
......
......@@ -927,3 +927,85 @@ class ForNodeVisitor(object):
if self.is_for_enumerate_iter():
return self.target.elts[0].id
return None
class SplitAssignTransformer(gast.NodeTransformer):
"""
This class transforms sequence assignments and multi-target assignments to normal assignments.
"""
def __init__(self, ast_node):
assert isinstance(ast_node, gast.AST)
self.ast_root = ast_node
def transform(self):
self.visit(self.ast_root)
def visit_Assign(self, node):
target_nodes = node.targets
if len(target_nodes) == 1:
node = self._parse_sequence_assign(node)
else:
node = self._parse_multi_target_assign(node)
return node
def _parse_sequence_assign(self, node):
"""
a, b = c, d
->
a = c
b = d
"""
assert isinstance(node, gast.Assign)
target_nodes = node.targets
value_node = node.value
if not isinstance(target_nodes[0], (gast.List, gast.Tuple)):
return node
if not isinstance(value_node, (gast.List, gast.Tuple)):
return node
targets = node.targets[0].elts
values = node.value.elts
if len(targets) != len(values):
return node
new_nodes = []
for target, value in zip(targets, values):
assign_node = gast.Assign(targets=[target], value=value)
new_nodes.append(assign_node)
return new_nodes
def _parse_multi_target_assign(self, node):
"""
Example 1:
a = b = c
->
b = c
a = b
Example 2:
a, b = c, d = x
->
c,d = x
a = c
b = d
"""
assert isinstance(node, gast.Assign)
target_nodes = node.targets
value_node = node.value
new_nodes = []
for target in reversed(target_nodes):
assign_node = gast.Assign(targets=[target], value=value_node)
# NOTE: Because assign_node can be sequence assign statement like `a,b = c,d`,
# it's necessary to visit this new assign_node
parsed_node = self.visit_Assign(assign_node)
if not isinstance(parsed_node, list):
parsed_node = [parsed_node]
new_nodes.extend(parsed_node)
value_node = target
return new_nodes
......@@ -16,7 +16,11 @@ from __future__ import print_function
import unittest
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
from test_program_translator import get_source_code
class TestIndexInList(unittest.TestCase):
......@@ -29,5 +33,34 @@ class TestIndexInList(unittest.TestCase):
self.assertEqual(index_in_list(list_to_test, 6), -1)
def dyfunc_assign(input):
a = b = 1
c, d = e, f = a, b
z = [3, 4]
[x, y] = m, n = z
class StaticCode():
def dyfunc_assign(input):
b = 1
a = b
e = a
f = b
c = e
d = f
z = [3, 4]
m, n = z
x = m
y = n
class TestSplitAssignTransformer(unittest.TestCase):
def test_code(self):
answer = get_source_code(StaticCode.dyfunc_assign)
program_translator = ProgramTranslator()
code = program_translator.get_code(dyfunc_assign)
self.assertEqual(answer, code)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册