From 824572c144006083e7548f0c29aae617fee9e595 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Thu, 21 May 2020 21:06:26 +0800 Subject: [PATCH] [Dy2Stat]Support to transform sequence assignments and multi-target assignments to normal assignments (#24643) --- .../dygraph_to_static/list_transformer.py | 4 +- .../fluid/dygraph/dygraph_to_static/utils.py | 82 +++++++++++++++++++ .../unittests/dygraph_to_static/test_utils.py | 33 ++++++++ 3 files changed, 118 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py index 92cc10365f7..de9acabe247 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py @@ -19,6 +19,7 @@ import gast from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code, is_control_flow_to_transform +from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer from paddle.fluid.framework import core, Variable from paddle.fluid.layers import array_length, array_read, array_write, create_array from paddle.fluid.layers import assign, fill_constant, slice @@ -108,6 +109,7 @@ class ListTransformer(gast.NodeTransformer): self.scope_var_type_dict = var_env.get_scope_var_type() def transform(self): + SplitAssignTransformer(self.root).transform() self.visit(self.root) self.replace_list_with_tensor_array(self.root) @@ -264,7 +266,7 @@ class ListTransformer(gast.NodeTransformer): def _update_list_name_to_updated(self, node): assert isinstance(node, gast.Assign) target_node = node.targets[0] - # TODO: Consider node has more than one target. eg: x, y = a, [] + # NOTE: Code like `x, y = a, []` has been transformed to `x=a; y=[]` try: target_id = target_node.id except AttributeError: diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index bad09a160d2..460435c38d3 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -927,3 +927,85 @@ class ForNodeVisitor(object): if self.is_for_enumerate_iter(): return self.target.elts[0].id return None + + +class SplitAssignTransformer(gast.NodeTransformer): + """ + This class transforms sequence assignments and multi-target assignments to normal assignments. + """ + + def __init__(self, ast_node): + assert isinstance(ast_node, gast.AST) + self.ast_root = ast_node + + def transform(self): + self.visit(self.ast_root) + + def visit_Assign(self, node): + target_nodes = node.targets + if len(target_nodes) == 1: + node = self._parse_sequence_assign(node) + else: + node = self._parse_multi_target_assign(node) + return node + + def _parse_sequence_assign(self, node): + """ + a, b = c, d + -> + a = c + b = d + """ + assert isinstance(node, gast.Assign) + + target_nodes = node.targets + value_node = node.value + if not isinstance(target_nodes[0], (gast.List, gast.Tuple)): + return node + if not isinstance(value_node, (gast.List, gast.Tuple)): + return node + + targets = node.targets[0].elts + values = node.value.elts + if len(targets) != len(values): + return node + + new_nodes = [] + for target, value in zip(targets, values): + assign_node = gast.Assign(targets=[target], value=value) + new_nodes.append(assign_node) + + return new_nodes + + def _parse_multi_target_assign(self, node): + """ + Example 1: + a = b = c + -> + b = c + a = b + + Example 2: + a, b = c, d = x + -> + c,d = x + a = c + b = d + """ + assert isinstance(node, gast.Assign) + + target_nodes = node.targets + value_node = node.value + new_nodes = [] + for target in reversed(target_nodes): + assign_node = gast.Assign(targets=[target], value=value_node) + # NOTE: Because assign_node can be sequence assign statement like `a,b = c,d`, + # it's necessary to visit this new assign_node + parsed_node = self.visit_Assign(assign_node) + if not isinstance(parsed_node, list): + parsed_node = [parsed_node] + + new_nodes.extend(parsed_node) + value_node = target + + return new_nodes diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_utils.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_utils.py index 0c8ebb163ba..08fa655c6db 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_utils.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_utils.py @@ -16,7 +16,11 @@ from __future__ import print_function import unittest +from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list +from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer + +from test_program_translator import get_source_code class TestIndexInList(unittest.TestCase): @@ -29,5 +33,34 @@ class TestIndexInList(unittest.TestCase): self.assertEqual(index_in_list(list_to_test, 6), -1) +def dyfunc_assign(input): + a = b = 1 + c, d = e, f = a, b + z = [3, 4] + [x, y] = m, n = z + + +class StaticCode(): + def dyfunc_assign(input): + b = 1 + a = b + e = a + f = b + c = e + d = f + z = [3, 4] + m, n = z + x = m + y = n + + +class TestSplitAssignTransformer(unittest.TestCase): + def test_code(self): + answer = get_source_code(StaticCode.dyfunc_assign) + program_translator = ProgramTranslator() + code = program_translator.get_code(dyfunc_assign) + self.assertEqual(answer, code) + + if __name__ == '__main__': unittest.main() -- GitLab