diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/base_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/base_transformer.py index 9df7e8d9b4f418bf48c37c200399fc3d937fbb64..ff5b522095244609c3f2a511571d22078d299cc5 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/base_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/base_transformer.py @@ -186,88 +186,6 @@ class ForLoopTuplePreTransformer(BaseTransformer): return [assign_node] -class SplitAssignTransformer(BaseTransformer): - """ - This class transforms sequence assignments and multi-target assignments to normal assignments. - """ - - def __init__(self, ast_node): - assert isinstance(ast_node, gast.AST) - self.ast_root = ast_node - - def transform(self): - self.visit(self.ast_root) - - def visit_Assign(self, node): - target_nodes = node.targets - if len(target_nodes) == 1: - node = self._parse_sequence_assign(node) - else: - node = self._parse_multi_target_assign(node) - return node - - def _parse_sequence_assign(self, node): - """ - a, b = c, d - -> - a = c - b = d - """ - assert isinstance(node, gast.Assign) - - target_nodes = node.targets - value_node = node.value - if not isinstance(target_nodes[0], (gast.List, gast.Tuple)): - return node - if not isinstance(value_node, (gast.List, gast.Tuple)): - return node - - targets = node.targets[0].elts - values = node.value.elts - if len(targets) != len(values): - return node - - new_nodes = [] - for target, value in zip(targets, values): - assign_node = gast.Assign(targets=[target], value=value) - new_nodes.append(assign_node) - - return new_nodes - - def _parse_multi_target_assign(self, node): - """ - Example 1: - a = b = c - -> - b = c - a = b - - Example 2: - a, b = c, d = x - -> - c,d = x - a = c - b = d - """ - assert isinstance(node, gast.Assign) - - target_nodes = node.targets - value_node = node.value - new_nodes = [] - for target in reversed(target_nodes): - assign_node = gast.Assign(targets=[target], value=value_node) - # NOTE: Because assign_node can be sequence assign statement like `a,b = c,d`, - # it's necessary to visit this new assign_node - parsed_node = self.visit_Assign(assign_node) - if not isinstance(parsed_node, list): - parsed_node = [parsed_node] - - new_nodes.extend(parsed_node) - value_node = target - - return new_nodes - - class ForNodeVisitor(object): """ This class parses python for statement, get transformed 3 statement components of for node diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py index 29e3ed5296806da88d25a8c69becc14fd02da259..33ec9b9be73e5afb81e040df480dae78f28540f6 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py @@ -22,7 +22,6 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import slice_is_num from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer -from paddle.fluid.dygraph.dygraph_to_static.base_transformer import SplitAssignTransformer class ListTransformer(BaseTransformer): @@ -47,7 +46,6 @@ class ListTransformer(BaseTransformer): self.scope_var_type_dict = var_env.get_scope_var_type() def transform(self): - SplitAssignTransformer(self.root).transform() self.visit(self.root) self.replace_list_with_tensor_array(self.root) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_utils.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_utils.py index 6f4fe613db714629172a03849cf110a9d9c43b9a..b2943c779ca19ffa90eded0f995a48ef3c9934ac 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_utils.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_utils.py @@ -57,15 +57,6 @@ class StaticCode(): y = n -class TestSplitAssignTransformer(unittest.TestCase): - - def test_code(self): - answer = get_source_code(StaticCode.dyfunc_assign) - program_translator = ProgramTranslator() - code = program_translator.get_code(dyfunc_assign) - self.assertEqual(answer, code) - - class TestIsPaddle(unittest.TestCase): def fake_module(self):