未验证 提交 be0ec904 编写于 作者: X xiongkun 提交者: GitHub

[ Dy2Static ] Remove assign split (#44769)

* just a test

* remove split assign test

* remove other useless code related to split assign
上级 683f8190
...@@ -186,88 +186,6 @@ class ForLoopTuplePreTransformer(BaseTransformer): ...@@ -186,88 +186,6 @@ class ForLoopTuplePreTransformer(BaseTransformer):
return [assign_node] return [assign_node]
class SplitAssignTransformer(BaseTransformer):
"""
This class transforms sequence assignments and multi-target assignments to normal assignments.
"""
def __init__(self, ast_node):
assert isinstance(ast_node, gast.AST)
self.ast_root = ast_node
def transform(self):
self.visit(self.ast_root)
def visit_Assign(self, node):
target_nodes = node.targets
if len(target_nodes) == 1:
node = self._parse_sequence_assign(node)
else:
node = self._parse_multi_target_assign(node)
return node
def _parse_sequence_assign(self, node):
"""
a, b = c, d
->
a = c
b = d
"""
assert isinstance(node, gast.Assign)
target_nodes = node.targets
value_node = node.value
if not isinstance(target_nodes[0], (gast.List, gast.Tuple)):
return node
if not isinstance(value_node, (gast.List, gast.Tuple)):
return node
targets = node.targets[0].elts
values = node.value.elts
if len(targets) != len(values):
return node
new_nodes = []
for target, value in zip(targets, values):
assign_node = gast.Assign(targets=[target], value=value)
new_nodes.append(assign_node)
return new_nodes
def _parse_multi_target_assign(self, node):
"""
Example 1:
a = b = c
->
b = c
a = b
Example 2:
a, b = c, d = x
->
c,d = x
a = c
b = d
"""
assert isinstance(node, gast.Assign)
target_nodes = node.targets
value_node = node.value
new_nodes = []
for target in reversed(target_nodes):
assign_node = gast.Assign(targets=[target], value=value_node)
# NOTE: Because assign_node can be sequence assign statement like `a,b = c,d`,
# it's necessary to visit this new assign_node
parsed_node = self.visit_Assign(assign_node)
if not isinstance(parsed_node, list):
parsed_node = [parsed_node]
new_nodes.extend(parsed_node)
value_node = target
return new_nodes
class ForNodeVisitor(object): class ForNodeVisitor(object):
""" """
This class parses python for statement, get transformed 3 statement components of for node This class parses python for statement, get transformed 3 statement components of for node
......
...@@ -22,7 +22,6 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code ...@@ -22,7 +22,6 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import slice_is_num from paddle.fluid.dygraph.dygraph_to_static.utils import slice_is_num
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import SplitAssignTransformer
class ListTransformer(BaseTransformer): class ListTransformer(BaseTransformer):
...@@ -47,7 +46,6 @@ class ListTransformer(BaseTransformer): ...@@ -47,7 +46,6 @@ class ListTransformer(BaseTransformer):
self.scope_var_type_dict = var_env.get_scope_var_type() self.scope_var_type_dict = var_env.get_scope_var_type()
def transform(self): def transform(self):
SplitAssignTransformer(self.root).transform()
self.visit(self.root) self.visit(self.root)
self.replace_list_with_tensor_array(self.root) self.replace_list_with_tensor_array(self.root)
......
...@@ -57,15 +57,6 @@ class StaticCode(): ...@@ -57,15 +57,6 @@ class StaticCode():
y = n y = n
class TestSplitAssignTransformer(unittest.TestCase):
def test_code(self):
answer = get_source_code(StaticCode.dyfunc_assign)
program_translator = ProgramTranslator()
code = program_translator.get_code(dyfunc_assign)
self.assertEqual(answer, code)
class TestIsPaddle(unittest.TestCase): class TestIsPaddle(unittest.TestCase):
def fake_module(self): def fake_module(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册