From b2c1247c6a8ec093581e35ade3314d7d779c8dae Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Fri, 8 Jul 2022 19:06:12 +0800 Subject: [PATCH] [Dy2St]Polish visit function in transformer (#44083) * Polish visit function in transformer * Call generic_visit first in visit_While/For * Remove comments * Polish utils.py, move some transformer to base_transformer --- .../dygraph_to_static/base_transformer.py | 649 +++++++++++++++++- .../break_continue_transformer.py | 2 +- .../dygraph_to_static/list_transformer.py | 2 +- .../dygraph_to_static/loop_transformer.py | 21 +- .../dygraph/dygraph_to_static/origin_info.py | 5 +- .../dygraph_to_static/return_transformer.py | 4 +- .../fluid/dygraph/dygraph_to_static/utils.py | 638 +---------------- 7 files changed, 664 insertions(+), 657 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/base_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/base_transformer.py index 127a8e92324..a3c2c0c69ef 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/base_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/base_transformer.py @@ -13,8 +13,17 @@ # limitations under the License. from paddle.utils import gast - -from paddle.fluid.dygraph.dygraph_to_static.origin_info import ORIGI_INFO +from paddle.fluid import unique_name +from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code +from paddle.fluid.dygraph.dygraph_to_static.utils import create_assign_node +from paddle.fluid.dygraph.dygraph_to_static.utils import ORIGI_INFO +from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_INDEX_PREFIX +from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_TUPLE_PREFIX +from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_TUPLE_INDEX_PREFIX +from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_VAR_LEN_PREFIX +from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_VAR_NAME_PREFIX +from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_ZIP_TO_LIST_PREFIX class BaseTransformer(gast.NodeTransformer): @@ -36,3 +45,639 @@ class BaseTransformer(gast.NodeTransformer): setattr(n, ORIGI_INFO, origin_info) return result + + +class RenameTransformer(BaseTransformer): + + def __init__(self, node): + assert isinstance( + node, gast.AST), "RenameTransformer only accepts gast.AST as input" + self.root = node + self.old_name = "" + self.new_name = "" + + def rename(self, old_name, new_name): + self.old_name = old_name + self.new_name = new_name + self.visit(self.root) + + def visit_Name(self, node): + self.generic_visit(node) + if node.id == self.old_name: + node.id = self.new_name + return node + + def visit_Attribute(self, node): + self.generic_visit(node) + attr_full_name = get_attribute_full_name(node) + if attr_full_name == self.old_name: + new_name_node = gast.parse(self.new_name).body[0].value + return new_name_node + return node + + +class NameNodeReplaceTransformer(BaseTransformer): + """ + This class replaces specified gast.Name node by replace_node. + """ + + def __init__(self, root_node, target_name, replace_node): + assert isinstance(target_name, str) + + # NOTE(liym27): + # Use gast.Name to replace gast.Name, otherwise, errors may occur. + # + # For examples: + # If using a gast.Subscript to replace gast.Name, and the original gast.Name + # is in the arguments of FunctionDef, an exception will be raised. + # + # ``` + # def func(x[i])) # x[i] can not be a argument + # # ... + # ``` + + assert isinstance(replace_node, gast.Name) + self.target_name = target_name + self.replace_node = replace_node + + self.visit(root_node) + + def visit_Name(self, node): + if node.id == self.target_name: + return self.replace_node + return node + + def visit_Nonlocal(self, node): + names = node.names + + def replace(s): + if s == self.target_name: return self.replace_node.id + return s + + node.names = list(map(replace, names)) + return node + + +class ForLoopTuplePreTransformer(BaseTransformer): + """ + ForNodeVisitor parses 3 type statements (Here var is VarBase(Tensor) or python variable): + 1). for x in range(var[*]|var.numpy()[*]) + 2). for x in var|var.numpy() + 3). for i, x in enumerate(var|var.numpy()) + + We chose these 3 types because they are easier (x can be variable name iterating in var). + However, users can write tuples in Python for loop, such as + 1). for var1, var2 in var|var.numpy() + 2). for t in enumerate(var|var.numpy()) + 2). for i, (var1, var2, va3) in enumerate(var|var.numpy()) + + To handle these case, this method will do the rewrite tuple pre-process: + 1). Non-enumerate case: for var1, var2 in var|var.numpy() will be re-written as: + for FOR_ITER_TUPLE_PREFIX_x in var | var.numpy(): + var1 = FOR_ITER_TUPLE_PREFIX_x[0] + var2 = FOR_ITER_TUPLE_PREFIX_x[1] + 2). Enumerate out tuple case: for t in enumerate(var|var.numpy) will be rewritten as: + for FOR_ITER_TUPLE_INDEX_PREFIX_x, FOR_ITER_TUPLE_PREFIX_x in enumerate(var|var.numpy): + t = (FOR_ITER_TUPLE_INDEX_PREFIX_x, FOR_ITER_TUPLE_PREFIX_x) + 3). Enumerate inner tuple case: for i, (var1, (var2, va3)) in enumerate(var|var.numpy()) will + be re-written as: + for i, FOR_ITER_TUPLE_PREFIX_x in var | var.numpy(): + var1 = FOR_ITER_TUPLE_PREFIX_x[0] + var2 = FOR_ITER_TUPLE_PREFIX_x[1][0] + var3 = FOR_ITER_TUPLE_PREFIX_x[1][1] + """ + + def __init__(self, wrapper_root): + self.wrapper_root = wrapper_root + self.root = wrapper_root.node + + def transform(self): + self.visit(self.root) + + def visit_For(self, node): + if self.is_for_enumerate_iter(node): + if isinstance(node.target, (gast.Name, gast.Attribute)): + # Out tuple case + out_tuple_name = ast_to_source_code(node.target).strip() + tuple_iter_name = unique_name.generate( + FOR_ITER_TUPLE_INDEX_PREFIX) + tuple_var_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX) + node.target = gast.Tuple(elts=[ + gast.Name(id=tuple_iter_name, + ctx=gast.Store(), + annotation=None, + type_comment=None), + gast.Name(id=tuple_var_name, + ctx=gast.Store(), + annotation=None, + type_comment=None) + ], + ctx=gast.Store()) + node.body.insert( + 0, + gast.Assign(targets=[ + gast.Name(id=out_tuple_name, + ctx=gast.Store(), + annotation=None, + type_comment=None) + ], + value=gast.Tuple(elts=[ + gast.Name(id=tuple_iter_name, + ctx=gast.Load(), + annotation=None, + type_comment=None), + gast.Name(id=tuple_var_name, + ctx=gast.Load(), + annotation=None, + type_comment=None) + ], + ctx=gast.Load()))) + elif isinstance(node.target, (gast.List, gast.Tuple)) and len( + node.target.elts) >= 2 and isinstance( + node.target.elts[1], (gast.List, gast.Tuple)): + # Inner tuple case + inner_tuple_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX) + origin_inner_tuple_node = node.target.elts[1] + node.target.elts[1] = gast.Name(id=inner_tuple_name, + ctx=gast.Store(), + annotation=None, + type_comment=None) + node.body[0:0] = self.tuple_to_stmts(origin_inner_tuple_node, + inner_tuple_name) + elif self.is_for_iter(node) and isinstance(node.target, + (gast.List, gast.Tuple)): + # Non-enumrate case: + tuple_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX) + origin_tuple_node = node.target + node.target = gast.Name(id=tuple_name, + ctx=gast.Store(), + annotation=None, + type_comment=None) + node.body[0:0] = self.tuple_to_stmts(origin_tuple_node, tuple_name) + return node + + def tuple_to_stmts(self, node, tuple_name, idx=[]): + if not isinstance(node, (gast.Tuple, gast.List)): + value_node_str = tuple_name + for i in idx: + value_node_str = value_node_str + "[{}]".format(i) + + node_str = ast_to_source_code(node).strip() + assign_node_str = "{} = {}".format(node_str, value_node_str) + assign_node = gast.parse(assign_node_str).body[0] + return [assign_node] + + # isinstance(node, (gast.Tuple, gast.List)) + ret = [] + for i, element in enumerate(node.elts): + ret += self.tuple_to_stmts(node.elts[i], tuple_name, idx + [i]) + return ret + + def is_for_iter(self, for_node): + assert isinstance(for_node, + gast.For), "Input node is not gast.For node." + if isinstance(for_node.iter, (gast.Name, gast.Attribute)): + return True + elif isinstance(for_node.iter, gast.Call) and isinstance( + for_node.iter.func, + gast.Attribute) and for_node.iter.func.attr == 'numpy': + return True + elif isinstance(for_node.iter, gast.Subscript): + return True + else: + return False + + def is_for_enumerate_iter(self, for_node): + assert isinstance(for_node, + gast.For), "Input node is not gast.For node." + return isinstance(for_node.iter, gast.Call) and isinstance( + for_node.iter.func, + gast.Name) and for_node.iter.func.id == "enumerate" + + +class SplitAssignTransformer(BaseTransformer): + """ + This class transforms sequence assignments and multi-target assignments to normal assignments. + """ + + def __init__(self, ast_node): + assert isinstance(ast_node, gast.AST) + self.ast_root = ast_node + + def transform(self): + self.visit(self.ast_root) + + def visit_Assign(self, node): + target_nodes = node.targets + if len(target_nodes) == 1: + node = self._parse_sequence_assign(node) + else: + node = self._parse_multi_target_assign(node) + return node + + def _parse_sequence_assign(self, node): + """ + a, b = c, d + -> + a = c + b = d + """ + assert isinstance(node, gast.Assign) + + target_nodes = node.targets + value_node = node.value + if not isinstance(target_nodes[0], (gast.List, gast.Tuple)): + return node + if not isinstance(value_node, (gast.List, gast.Tuple)): + return node + + targets = node.targets[0].elts + values = node.value.elts + if len(targets) != len(values): + return node + + new_nodes = [] + for target, value in zip(targets, values): + assign_node = gast.Assign(targets=[target], value=value) + new_nodes.append(assign_node) + + return new_nodes + + def _parse_multi_target_assign(self, node): + """ + Example 1: + a = b = c + -> + b = c + a = b + + Example 2: + a, b = c, d = x + -> + c,d = x + a = c + b = d + """ + assert isinstance(node, gast.Assign) + + target_nodes = node.targets + value_node = node.value + new_nodes = [] + for target in reversed(target_nodes): + assign_node = gast.Assign(targets=[target], value=value_node) + # NOTE: Because assign_node can be sequence assign statement like `a,b = c,d`, + # it's necessary to visit this new assign_node + parsed_node = self.visit_Assign(assign_node) + if not isinstance(parsed_node, list): + parsed_node = [parsed_node] + + new_nodes.extend(parsed_node) + value_node = target + + return new_nodes + + +class ForNodeVisitor(object): + """ + This class parses python for statement, get transformed 3 statement components of for node + three key statements: + 1). init_stmts: list[node], prepare nodes of for loop, may not only one + 2). cond_stmt: node, condition node to judge whether continue loop + 3). body_stmts: list[node], updated loop body, sometimes we should change + the original statement in body, not just append new statement + + In this process, the semantics of for does not change. + + Now only can parse 3 type statements (Here var is VarBase(Tensor) or python variable): + 1). for x in range(var[*]|var.numpy()[*]) + 2). for x in var|var.numpy() + 3). for i, x enumerate(var|var.numpy()) + """ + + def __init__(self, for_node): + assert isinstance( + for_node, gast.For + ), "Input node for the initialization of ForNodeVisitor is not gast.For node." + # 1. original for node + self.node = for_node + + # 2. gast.For node main parts + self.target = for_node.target + # NOTE: type may be Node or list[Node] + self.iter_args = for_node.iter if self.is_for_iter( + ) else for_node.iter.args + self.body = for_node.body + + # 3. key shared node or names + # - x: + # - for x in range(***) + # - for x in var|var.numpy() + # - for i, x enumerate(var|var.numpy()) + self.iter_var_name = self._get_iter_var_name() + + # - created index var to slice Variable: __for_loop_var_index_0 + # - for x in var|var.numpy() + # - for i, x enumerate(var|var.numpy()) + self.iter_idx_name = unique_name.generate(FOR_ITER_INDEX_PREFIX) + + # - created shape var to build loop condition: __for_loop_var_len_0 + # - for x in var|var.numpy() + # - for i, x enumerate(var|var.numpy()) + # - for x in var + self.iter_var_len_name = unique_name.generate(FOR_ITER_VAR_LEN_PREFIX) + # - created zip to list var : __for_loop_iter_zip_0 + self.iter_zip_to_list_name = unique_name.generate( + FOR_ITER_ZIP_TO_LIST_PREFIX) + + # - var.numpy()/var + # - for x in var|var.numpy() + # - for i, x enumerate(var|var.numpy()) + self.iter_node = self._get_iter_node() + + # - enumeate i: + # - for i, x enumerate(var|var.numpy()) + self.enum_idx_name = self._get_enum_idx_name() + + # - range/enumerate args length + self.args_length = None + + def parse(self): + self._args_check() + if self.is_for_range_iter(): + return self._parse_for_range_stmts() + elif self.is_for_iter(): + return self._parse_for_stmts() + elif self.is_for_enumerate_iter(): + return self._parse_for_enumerate_stmts() + else: + return None + + def is_for_range_iter(self): + return isinstance(self.node.iter, gast.Call) and isinstance( + self.node.iter.func, + gast.Name) and self.node.iter.func.id == "range" + + def is_for_iter(self): + if isinstance(self.node.iter, + (gast.Name, gast.Attribute, gast.List, gast.Tuple)): + return True + elif isinstance(self.node.iter, gast.Call) and isinstance( + self.node.iter.func, + gast.Attribute) and self.node.iter.func.attr == 'numpy': + return True + elif isinstance(self.node.iter, gast.Subscript): + return True + else: + return False + + def is_for_enumerate_iter(self): + return isinstance(self.node.iter, gast.Call) and isinstance( + self.node.iter.func, + gast.Name) and self.node.iter.func.id == "enumerate" + + def _args_check(self): + if self.is_for_range_iter(): + self.args_length = len(self.iter_args) + assert self.args_length >= 1 and self.args_length <= 3, "range() function takes 1 to 3 arguments" + elif self.is_for_enumerate_iter(): + self.args_length = len(self.iter_args) + assert self.args_length >= 1 and self.args_length <= 2, "enumerate() function takes 1 to 2 arguments" + else: + self.args_length = None + + def _parse_for_range_stmts(self): + init_stmts = [] + init_stmts.append(self._build_index_init_node()) + + compare_node = self._build_compare_node() + step_node = self._build_step_node() + cond_stmt = self._build_cond_stmt(step_node, compare_node) + + body_stmts = self.body + body_stmts.append(self._build_index_increase_node(step_node)) + + return init_stmts, cond_stmt, body_stmts + + def _parse_for_stmts(self): + init_stmts = [] + init_stmts.extend(self._build_iter_node()) + init_stmts.append(self._build_index_init_node()) + init_stmts.append(self._build_var_len_assign_node()) + + compare_node = self._build_compare_node() + step_node = self._build_step_node() + cond_stmt = self._build_cond_stmt(step_node, compare_node) + + body_stmts = self.body + + # NOTE(liym27): Here add a gast.Assign, and the target of it is gast.Name. + # In NameNodeReplaceTransformer, using gast.Name to replace gast.Name is safe. + target_node, assign_node = self._build_assign_var_slice_node() + body_stmts[0:0] = [assign_node] + for body_node in body_stmts: + NameNodeReplaceTransformer(body_node, self.iter_var_name, + target_node) + body_stmts.append(self._build_index_increase_node(step_node)) + + return init_stmts, cond_stmt, body_stmts + + def _parse_for_enumerate_stmts(self): + init_stmts = [] + init_stmts.extend(self._build_iter_node()) + init_stmts.append(self._build_index_init_node()) + init_stmts.append(self._build_var_len_assign_node()) + init_stmts.append(self._build_enum_init_node()) + + compare_node = self._build_compare_node() + step_node = self._build_step_node() + cond_stmt = self._build_cond_stmt(step_node, compare_node) + + body_stmts = self.body + + target_node, assign_node = self._build_assign_var_slice_node() + body_stmts[0:0] = [assign_node] + for body_node in body_stmts: + NameNodeReplaceTransformer(body_node, self.iter_var_name, + target_node) + + body_stmts.append(self._build_index_increase_node(step_node)) + body_stmts.append(self._build_enum_increase_node()) + + return init_stmts, cond_stmt, body_stmts + + def _build_index_init_node(self): + if self.is_for_range_iter(): + if self.args_length == 1: + index_init_value_str = '0' + else: + index_init_value_str = ast_to_source_code( + self.iter_args[0]).strip() + + index_init_var_name = self.iter_var_name + else: + index_init_value_str = '0' + index_init_var_name = self.iter_idx_name + + index_init_node_source_str = "{target} = {value}".format( + target=index_init_var_name, value=index_init_value_str) + + index_init_node = gast.parse(index_init_node_source_str).body[0] + + return index_init_node + + def _build_var_len_assign_node(self): + # get the length of iterable variable + if isinstance(self.iter_node, gast.Call) and isinstance( + self.iter_node.func, + gast.Attribute) and self.iter_node.func.attr == 'numpy': + iter_var_name = ast_to_source_code( + self.iter_node.func.value).strip() + else: + iter_var_name = ast_to_source_code(self.iter_node).strip() + + convert_len_node_source_str = '{} = _jst.Len({})'.format( + self.iter_var_len_name, iter_var_name) + + convert_len_node = gast.parse(convert_len_node_source_str).body[0] + + return convert_len_node + + def _build_iter_node(self): + """ + Process special cases for iter_node inclue: + - Case 1 (for zip): + + - for i, val in enumerate(zip(x, y)) # original code: + + - __for_loop_iter_zip_0 = list(zip(x, y)) + - for i, val in enumerate(__for_loop_iter_zip_0) + """ + new_nodes = [] + if isinstance(self.iter_node, gast.Call) and isinstance( + self.iter_node.func, gast.Name): + if self.iter_node.func.id == 'zip': + iter_var_name = ast_to_source_code(self.iter_node).strip() + zip_to_list_str = "{target} = list({value})".format( + target=self.iter_zip_to_list_name, value=iter_var_name) + zip_to_list_node = gast.parse(zip_to_list_str).body[0] + new_nodes.append(zip_to_list_node) + + self.iter_node = gast.Name(id=self.iter_zip_to_list_name, + ctx=gast.Load(), + annotation=None, + type_comment=None) + + return new_nodes + + def _build_enum_init_node(self): + if self.is_for_enumerate_iter() and self.args_length != 1: + init_value_str = ast_to_source_code(self.iter_args[1]).strip() + else: + init_value_str = '0' + + enum_init_node_source_str = "{} = {}".format(self.enum_idx_name, + init_value_str) + enum_init_node = gast.parse(enum_init_node_source_str).body[0] + return enum_init_node + + def _build_compare_node(self): + if self.is_for_range_iter(): + compare_node = self.iter_args[ + 0] if self.args_length == 1 else self.iter_args[1] + else: + compare_node = gast.Name(id=self.iter_var_len_name, + ctx=gast.Load(), + annotation=None, + type_comment=None) + return compare_node + + def _build_step_node(self): + if self.is_for_range_iter(): + step_node = self.iter_args[ + 2] if self.args_length == 3 else gast.Constant(value=1, + kind=None) + else: + step_node = gast.Constant(value=1, kind=None) + return step_node + + def _build_cond_stmt(self, step_node, compare_node): + if not isinstance(step_node, (gast.Constant, gast.UnaryOp)): + raise NotImplementedError( + "Dynamic-to-Static only supports the step value is a constant or negative constant in 'for-range' statements, " + "such as '2', '-3'. But received: '{}'. Please fix code to be compatible with Dynamic-to-Static." + .format(ast_to_source_code(step_node).strip())) + + if isinstance(step_node, gast.UnaryOp) or step_node.value < 0: + # eg: + # range(max, min, -2) + # -> + # i > min + return gast.Compare(left=gast.Name( + id=self.iter_var_name + if self.is_for_range_iter() else self.iter_idx_name, + ctx=gast.Load(), + annotation=None, + type_comment=None), + ops=[gast.Gt()], + comparators=[compare_node]) + else: + # eg: + # range(min, max, 2) + # -> + # i < max + return gast.Compare(left=gast.Name( + id=self.iter_var_name + if self.is_for_range_iter() else self.iter_idx_name, + ctx=gast.Load(), + annotation=None, + type_comment=None), + ops=[gast.Lt()], + comparators=[compare_node]) + + def _build_index_increase_node(self, step_node): + return gast.AugAssign(target=gast.Name( + id=self.iter_var_name + if self.is_for_range_iter() else self.iter_idx_name, + ctx=gast.Store(), + annotation=None, + type_comment=None), + op=gast.Add(), + value=step_node) + + def _build_assign_var_slice_node(self): + var_slice_str = "{}[{}]".format( + ast_to_source_code(self.iter_node).strip(), self.iter_idx_name) + var_slice_node = gast.parse(var_slice_str).body[0].value + new_iter_var_name = unique_name.generate(FOR_ITER_VAR_NAME_PREFIX) + target_node, assign_node = create_assign_node(new_iter_var_name, + var_slice_node) + return target_node, assign_node + + def _build_enum_increase_node(self): + return gast.AugAssign(target=gast.Name(id=self.enum_idx_name, + ctx=gast.Store(), + annotation=None, + type_comment=None), + op=gast.Add(), + value=gast.Constant(value=1, kind=None)) + + def _get_iter_var_name(self): + if self.is_for_range_iter(): + return self.target.id + elif self.is_for_iter(): + return self.target.id + elif self.is_for_enumerate_iter(): + return self.target.elts[1].id + return None + + def _get_iter_node(self): + if self.is_for_iter(): + return self.iter_args + elif self.is_for_enumerate_iter(): + return self.iter_args[0] + return None + + def _get_enum_idx_name(self): + if self.is_for_enumerate_iter(): + return self.target.elts[0].id + return None diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py index 020721e85a2..b63fe6eea5a 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py @@ -18,10 +18,10 @@ from paddle.utils import gast from paddle.fluid import unique_name from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list -from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import BaseNodeVisitor from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_bool_node from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer +from paddle.fluid.dygraph.dygraph_to_static.base_transformer import ForNodeVisitor __all__ = ['BreakContinueTransformer'] diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py index e29ec6c6e1d..29e3ed52968 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py @@ -21,8 +21,8 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrappe from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import slice_is_num from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform -from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer +from paddle.fluid.dygraph.dygraph_to_static.base_transformer import SplitAssignTransformer class ListTransformer(BaseTransformer): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index f04161f2c34..29ac905074e 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -25,14 +25,14 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysi from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name -from paddle.fluid.dygraph.dygraph_to_static.utils import ForLoopTuplePreTransformer -from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor -from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_nodes, create_get_args_node, create_set_args_node from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import ARGS_NAME from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer +from paddle.fluid.dygraph.dygraph_to_static.base_transformer import RenameTransformer +from paddle.fluid.dygraph.dygraph_to_static.base_transformer import ForLoopTuplePreTransformer +from paddle.fluid.dygraph.dygraph_to_static.base_transformer import ForNodeVisitor __all__ = ['LoopTransformer', 'NameVisitor'] @@ -489,14 +489,15 @@ class LoopTransformer(BaseTransformer): self.name_visitor = NameVisitor(self.root) self.visit(self.root) - def visit(self, node): + def visit_While(self, node): self.generic_visit(node) - # All parent nodes that may contain gast.While/gast.For - if hasattr(node, 'body'): - self.replace_stmt_list(node.body) - if hasattr(node, 'orelse'): - self.replace_stmt_list(node.orelse) - return node + new_stmts = self.get_while_stmt_nodes(node) + return new_stmts + + def visit_For(self, node): + self.generic_visit(node) + new_stmts = self.get_for_stmt_nodes(node) + return new_stmts def replace_stmt_list(self, body_list): if not isinstance(body_list, list): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py b/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py index de126777683..93f089cf8dd 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py @@ -20,16 +20,13 @@ import inspect from paddle.utils import gast from paddle.fluid import core from paddle.fluid.dygraph.dygraph_to_static.utils import unwrap +from paddle.fluid.dygraph.dygraph_to_static.utils import ORIGI_INFO from paddle.fluid.framework import Program try: from collections.abc import Sequence except: from collections import Sequence -# NOTE(liym27): Please use `getattr(ast_node, ORIGI_INFO)` instead of . operation to get the original information of ast node. -ORIGI_INFO = "Original information of source code for ast node." -ORIGI_INFO_MAP = "Original information map of source code." - class Location(object): """ diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py index 2b95f346ae2..3eadd455e10 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py @@ -188,9 +188,7 @@ class ReturnTransformer(BaseTransformer): Self-defined visit for appending ancestor """ self.ancestor_nodes.append(node) - method = 'visit_' + node.__class__.__name__ - visitor = getattr(self, method, self.generic_visit) - ret = visitor(node) + ret = super(ReturnTransformer, self).visit(node) self.ancestor_nodes.pop() return ret diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 1c507ab23c3..9f390252f3a 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -42,6 +42,8 @@ DYGRAPH_TO_STATIC_MODULE_PREFIX = 'paddle.fluid.dygraph.dygraph_to_static' GET_ARGS_FUNC_PREFIX = 'get_args' SET_ARGS_FUNC_PREFIX = 'set_args' ARGS_NAME = '__args' +# NOTE(liym27): Please use `getattr(ast_node, ORIGI_INFO)` instead of . operation to get the original information of ast node. +ORIGI_INFO = "Original information of source code for ast node." class BaseNodeVisitor(gast.NodeVisitor): @@ -541,35 +543,6 @@ def create_assign_node(name, node): return targets, assign_node -class RenameTransformer(gast.NodeTransformer): - - def __init__(self, node): - assert isinstance( - node, gast.AST), "RenameTransformer only accepts gast.AST as input" - self.root = node - self.old_name = "" - self.new_name = "" - - def rename(self, old_name, new_name): - self.old_name = old_name - self.new_name = new_name - self.visit(self.root) - - def visit_Name(self, node): - self.generic_visit(node) - if node.id == self.old_name: - node.id = self.new_name - return node - - def visit_Attribute(self, node): - self.generic_visit(node) - attr_full_name = get_attribute_full_name(node) - if attr_full_name == self.old_name: - new_name_node = gast.parse(self.new_name).body[0].value - return new_name_node - return node - - def ast_to_func(ast_root, dyfunc, delete_on_exit=True): """ Transform modified AST of decorated function into python callable object. @@ -897,613 +870,6 @@ class IsControlFlowVisitor(gast.NodeVisitor): return self._compare_node_tenor_set -class NameNodeReplaceTransformer(gast.NodeTransformer): - """ - This class replaces specified gast.Name node by replace_node. - """ - - def __init__(self, root_node, target_name, replace_node): - assert isinstance(target_name, str) - - # NOTE(liym27): - # Use gast.Name to replace gast.Name, otherwise, errors may occur. - # - # For examples: - # If using a gast.Subscript to replace gast.Name, and the original gast.Name - # is in the arguments of FunctionDef, an exception will be raised. - # - # ``` - # def func(x[i])) # x[i] can not be a argument - # # ... - # ``` - - assert isinstance(replace_node, gast.Name) - self.target_name = target_name - self.replace_node = replace_node - - self.visit(root_node) - - def visit_Name(self, node): - if node.id == self.target_name: - return self.replace_node - return node - - def visit_Nonlocal(self, node): - names = node.names - - def replace(s): - if s == self.target_name: return self.replace_node.id - return s - - node.names = list(map(replace, names)) - return node - - -class ForLoopTuplePreTransformer(gast.NodeTransformer): - """ - ForNodeVisitor parses 3 type statements (Here var is VarBase(Tensor) or python variable): - 1). for x in range(var[*]|var.numpy()[*]) - 2). for x in var|var.numpy() - 3). for i, x in enumerate(var|var.numpy()) - - We chose these 3 types because they are easier (x can be variable name iterating in var). - However, users can write tuples in Python for loop, such as - 1). for var1, var2 in var|var.numpy() - 2). for t in enumerate(var|var.numpy()) - 2). for i, (var1, var2, va3) in enumerate(var|var.numpy()) - - To handle these case, this method will do the rewrite tuple pre-process: - 1). Non-enumerate case: for var1, var2 in var|var.numpy() will be re-written as: - for FOR_ITER_TUPLE_PREFIX_x in var | var.numpy(): - var1 = FOR_ITER_TUPLE_PREFIX_x[0] - var2 = FOR_ITER_TUPLE_PREFIX_x[1] - 2). Enumerate out tuple case: for t in enumerate(var|var.numpy) will be rewritten as: - for FOR_ITER_TUPLE_INDEX_PREFIX_x, FOR_ITER_TUPLE_PREFIX_x in enumerate(var|var.numpy): - t = (FOR_ITER_TUPLE_INDEX_PREFIX_x, FOR_ITER_TUPLE_PREFIX_x) - 3). Enumerate inner tuple case: for i, (var1, (var2, va3)) in enumerate(var|var.numpy()) will - be re-written as: - for i, FOR_ITER_TUPLE_PREFIX_x in var | var.numpy(): - var1 = FOR_ITER_TUPLE_PREFIX_x[0] - var2 = FOR_ITER_TUPLE_PREFIX_x[1][0] - var3 = FOR_ITER_TUPLE_PREFIX_x[1][1] - """ - - def __init__(self, wrapper_root): - self.wrapper_root = wrapper_root - self.root = wrapper_root.node - - def transform(self): - self.visit(self.root) - - def visit_For(self, node): - if self.is_for_enumerate_iter(node): - if isinstance(node.target, (gast.Name, gast.Attribute)): - # Out tuple case - out_tuple_name = ast_to_source_code(node.target).strip() - tuple_iter_name = unique_name.generate( - FOR_ITER_TUPLE_INDEX_PREFIX) - tuple_var_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX) - node.target = gast.Tuple(elts=[ - gast.Name(id=tuple_iter_name, - ctx=gast.Store(), - annotation=None, - type_comment=None), - gast.Name(id=tuple_var_name, - ctx=gast.Store(), - annotation=None, - type_comment=None) - ], - ctx=gast.Store()) - node.body.insert( - 0, - gast.Assign(targets=[ - gast.Name(id=out_tuple_name, - ctx=gast.Store(), - annotation=None, - type_comment=None) - ], - value=gast.Tuple(elts=[ - gast.Name(id=tuple_iter_name, - ctx=gast.Load(), - annotation=None, - type_comment=None), - gast.Name(id=tuple_var_name, - ctx=gast.Load(), - annotation=None, - type_comment=None) - ], - ctx=gast.Load()))) - elif isinstance(node.target, (gast.List, gast.Tuple)) and len( - node.target.elts) >= 2 and isinstance( - node.target.elts[1], (gast.List, gast.Tuple)): - # Inner tuple case - inner_tuple_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX) - origin_inner_tuple_node = node.target.elts[1] - node.target.elts[1] = gast.Name(id=inner_tuple_name, - ctx=gast.Store(), - annotation=None, - type_comment=None) - node.body[0:0] = self.tuple_to_stmts(origin_inner_tuple_node, - inner_tuple_name) - elif self.is_for_iter(node) and isinstance(node.target, - (gast.List, gast.Tuple)): - # Non-enumrate case: - tuple_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX) - origin_tuple_node = node.target - node.target = gast.Name(id=tuple_name, - ctx=gast.Store(), - annotation=None, - type_comment=None) - node.body[0:0] = self.tuple_to_stmts(origin_tuple_node, tuple_name) - return node - - def tuple_to_stmts(self, node, tuple_name, idx=[]): - if not isinstance(node, (gast.Tuple, gast.List)): - value_node_str = tuple_name - for i in idx: - value_node_str = value_node_str + "[{}]".format(i) - - node_str = ast_to_source_code(node).strip() - assign_node_str = "{} = {}".format(node_str, value_node_str) - assign_node = gast.parse(assign_node_str).body[0] - return [assign_node] - - # isinstance(node, (gast.Tuple, gast.List)) - ret = [] - for i, element in enumerate(node.elts): - ret += self.tuple_to_stmts(node.elts[i], tuple_name, idx + [i]) - return ret - - def is_for_iter(self, for_node): - assert isinstance(for_node, - gast.For), "Input node is not gast.For node." - if isinstance(for_node.iter, (gast.Name, gast.Attribute)): - return True - elif isinstance(for_node.iter, gast.Call) and isinstance( - for_node.iter.func, - gast.Attribute) and for_node.iter.func.attr == 'numpy': - return True - elif isinstance(for_node.iter, gast.Subscript): - return True - else: - return False - - def is_for_enumerate_iter(self, for_node): - assert isinstance(for_node, - gast.For), "Input node is not gast.For node." - return isinstance(for_node.iter, gast.Call) and isinstance( - for_node.iter.func, - gast.Name) and for_node.iter.func.id == "enumerate" - - -class ForNodeVisitor(object): - """ - This class parses python for statement, get transformed 3 statement components of for node - three key statements: - 1). init_stmts: list[node], prepare nodes of for loop, may not only one - 2). cond_stmt: node, condition node to judge whether continue loop - 3). body_stmts: list[node], updated loop body, sometimes we should change - the original statement in body, not just append new statement - - In this process, the semantics of for does not change. - - Now only can parse 3 type statements (Here var is VarBase(Tensor) or python variable): - 1). for x in range(var[*]|var.numpy()[*]) - 2). for x in var|var.numpy() - 3). for i, x enumerate(var|var.numpy()) - """ - - def __init__(self, for_node): - assert isinstance( - for_node, gast.For - ), "Input node for the initialization of ForNodeVisitor is not gast.For node." - # 1. original for node - self.node = for_node - - # 2. gast.For node main parts - self.target = for_node.target - # NOTE: type may be Node or list[Node] - self.iter_args = for_node.iter if self.is_for_iter( - ) else for_node.iter.args - self.body = for_node.body - - # 3. key shared node or names - # - x: - # - for x in range(***) - # - for x in var|var.numpy() - # - for i, x enumerate(var|var.numpy()) - self.iter_var_name = self._get_iter_var_name() - - # - created index var to slice Variable: __for_loop_var_index_0 - # - for x in var|var.numpy() - # - for i, x enumerate(var|var.numpy()) - self.iter_idx_name = unique_name.generate(FOR_ITER_INDEX_PREFIX) - - # - created shape var to build loop condition: __for_loop_var_len_0 - # - for x in var|var.numpy() - # - for i, x enumerate(var|var.numpy()) - # - for x in var - self.iter_var_len_name = unique_name.generate(FOR_ITER_VAR_LEN_PREFIX) - # - created zip to list var : __for_loop_iter_zip_0 - self.iter_zip_to_list_name = unique_name.generate( - FOR_ITER_ZIP_TO_LIST_PREFIX) - - # - var.numpy()/var - # - for x in var|var.numpy() - # - for i, x enumerate(var|var.numpy()) - self.iter_node = self._get_iter_node() - - # - enumeate i: - # - for i, x enumerate(var|var.numpy()) - self.enum_idx_name = self._get_enum_idx_name() - - # - range/enumerate args length - self.args_length = None - - def parse(self): - self._args_check() - if self.is_for_range_iter(): - return self._parse_for_range_stmts() - elif self.is_for_iter(): - return self._parse_for_stmts() - elif self.is_for_enumerate_iter(): - return self._parse_for_enumerate_stmts() - else: - return None - - def is_for_range_iter(self): - return isinstance(self.node.iter, gast.Call) and isinstance( - self.node.iter.func, - gast.Name) and self.node.iter.func.id == "range" - - def is_for_iter(self): - if isinstance(self.node.iter, - (gast.Name, gast.Attribute, gast.List, gast.Tuple)): - return True - elif isinstance(self.node.iter, gast.Call) and isinstance( - self.node.iter.func, - gast.Attribute) and self.node.iter.func.attr == 'numpy': - return True - elif isinstance(self.node.iter, gast.Subscript): - return True - else: - return False - - def is_for_enumerate_iter(self): - return isinstance(self.node.iter, gast.Call) and isinstance( - self.node.iter.func, - gast.Name) and self.node.iter.func.id == "enumerate" - - def _args_check(self): - if self.is_for_range_iter(): - self.args_length = len(self.iter_args) - assert self.args_length >= 1 and self.args_length <= 3, "range() function takes 1 to 3 arguments" - elif self.is_for_enumerate_iter(): - self.args_length = len(self.iter_args) - assert self.args_length >= 1 and self.args_length <= 2, "enumerate() function takes 1 to 2 arguments" - else: - self.args_length = None - - def _parse_for_range_stmts(self): - init_stmts = [] - init_stmts.append(self._build_index_init_node()) - - compare_node = self._build_compare_node() - step_node = self._build_step_node() - cond_stmt = self._build_cond_stmt(step_node, compare_node) - - body_stmts = self.body - body_stmts.append(self._build_index_increase_node(step_node)) - - return init_stmts, cond_stmt, body_stmts - - def _parse_for_stmts(self): - init_stmts = [] - init_stmts.extend(self._build_iter_node()) - init_stmts.append(self._build_index_init_node()) - init_stmts.append(self._build_var_len_assign_node()) - - compare_node = self._build_compare_node() - step_node = self._build_step_node() - cond_stmt = self._build_cond_stmt(step_node, compare_node) - - body_stmts = self.body - - # NOTE(liym27): Here add a gast.Assign, and the target of it is gast.Name. - # In NameNodeReplaceTransformer, using gast.Name to replace gast.Name is safe. - target_node, assign_node = self._build_assign_var_slice_node() - body_stmts[0:0] = [assign_node] - for body_node in body_stmts: - NameNodeReplaceTransformer(body_node, self.iter_var_name, - target_node) - body_stmts.append(self._build_index_increase_node(step_node)) - - return init_stmts, cond_stmt, body_stmts - - def _parse_for_enumerate_stmts(self): - init_stmts = [] - init_stmts.extend(self._build_iter_node()) - init_stmts.append(self._build_index_init_node()) - init_stmts.append(self._build_var_len_assign_node()) - init_stmts.append(self._build_enum_init_node()) - - compare_node = self._build_compare_node() - step_node = self._build_step_node() - cond_stmt = self._build_cond_stmt(step_node, compare_node) - - body_stmts = self.body - - target_node, assign_node = self._build_assign_var_slice_node() - body_stmts[0:0] = [assign_node] - for body_node in body_stmts: - NameNodeReplaceTransformer(body_node, self.iter_var_name, - target_node) - - body_stmts.append(self._build_index_increase_node(step_node)) - body_stmts.append(self._build_enum_increase_node()) - - return init_stmts, cond_stmt, body_stmts - - def _build_index_init_node(self): - if self.is_for_range_iter(): - if self.args_length == 1: - index_init_value_str = '0' - else: - index_init_value_str = ast_to_source_code( - self.iter_args[0]).strip() - - index_init_var_name = self.iter_var_name - else: - index_init_value_str = '0' - index_init_var_name = self.iter_idx_name - - index_init_node_source_str = "{target} = {value}".format( - target=index_init_var_name, value=index_init_value_str) - - index_init_node = gast.parse(index_init_node_source_str).body[0] - - return index_init_node - - def _build_var_len_assign_node(self): - # get the length of iterable variable - if isinstance(self.iter_node, gast.Call) and isinstance( - self.iter_node.func, - gast.Attribute) and self.iter_node.func.attr == 'numpy': - iter_var_name = ast_to_source_code( - self.iter_node.func.value).strip() - else: - iter_var_name = ast_to_source_code(self.iter_node).strip() - - convert_len_node_source_str = '{} = _jst.Len({})'.format( - self.iter_var_len_name, iter_var_name) - - convert_len_node = gast.parse(convert_len_node_source_str).body[0] - - return convert_len_node - - def _build_iter_node(self): - """ - Process special cases for iter_node inclue: - - Case 1 (for zip): - - - for i, val in enumerate(zip(x, y)) # original code: - - - __for_loop_iter_zip_0 = list(zip(x, y)) - - for i, val in enumerate(__for_loop_iter_zip_0) - """ - new_nodes = [] - if isinstance(self.iter_node, gast.Call) and isinstance( - self.iter_node.func, gast.Name): - if self.iter_node.func.id == 'zip': - iter_var_name = ast_to_source_code(self.iter_node).strip() - zip_to_list_str = "{target} = list({value})".format( - target=self.iter_zip_to_list_name, value=iter_var_name) - zip_to_list_node = gast.parse(zip_to_list_str).body[0] - new_nodes.append(zip_to_list_node) - - self.iter_node = gast.Name(id=self.iter_zip_to_list_name, - ctx=gast.Load(), - annotation=None, - type_comment=None) - - return new_nodes - - def _build_enum_init_node(self): - if self.is_for_enumerate_iter() and self.args_length != 1: - init_value_str = ast_to_source_code(self.iter_args[1]).strip() - else: - init_value_str = '0' - - enum_init_node_source_str = "{} = {}".format(self.enum_idx_name, - init_value_str) - enum_init_node = gast.parse(enum_init_node_source_str).body[0] - return enum_init_node - - def _build_compare_node(self): - if self.is_for_range_iter(): - compare_node = self.iter_args[ - 0] if self.args_length == 1 else self.iter_args[1] - else: - compare_node = gast.Name(id=self.iter_var_len_name, - ctx=gast.Load(), - annotation=None, - type_comment=None) - return compare_node - - def _build_step_node(self): - if self.is_for_range_iter(): - step_node = self.iter_args[ - 2] if self.args_length == 3 else gast.Constant(value=1, - kind=None) - else: - step_node = gast.Constant(value=1, kind=None) - return step_node - - def _build_cond_stmt(self, step_node, compare_node): - if not isinstance(step_node, (gast.Constant, gast.UnaryOp)): - raise NotImplementedError( - "Dynamic-to-Static only supports the step value is a constant or negative constant in 'for-range' statements, " - "such as '2', '-3'. But received: '{}'. Please fix code to be compatible with Dynamic-to-Static." - .format(ast_to_source_code(step_node).strip())) - - if isinstance(step_node, gast.UnaryOp) or step_node.value < 0: - # eg: - # range(max, min, -2) - # -> - # i > min - return gast.Compare(left=gast.Name( - id=self.iter_var_name - if self.is_for_range_iter() else self.iter_idx_name, - ctx=gast.Load(), - annotation=None, - type_comment=None), - ops=[gast.Gt()], - comparators=[compare_node]) - else: - # eg: - # range(min, max, 2) - # -> - # i < max - return gast.Compare(left=gast.Name( - id=self.iter_var_name - if self.is_for_range_iter() else self.iter_idx_name, - ctx=gast.Load(), - annotation=None, - type_comment=None), - ops=[gast.Lt()], - comparators=[compare_node]) - - def _build_index_increase_node(self, step_node): - return gast.AugAssign(target=gast.Name( - id=self.iter_var_name - if self.is_for_range_iter() else self.iter_idx_name, - ctx=gast.Store(), - annotation=None, - type_comment=None), - op=gast.Add(), - value=step_node) - - def _build_assign_var_slice_node(self): - var_slice_str = "{}[{}]".format( - ast_to_source_code(self.iter_node).strip(), self.iter_idx_name) - var_slice_node = gast.parse(var_slice_str).body[0].value - new_iter_var_name = unique_name.generate(FOR_ITER_VAR_NAME_PREFIX) - target_node, assign_node = create_assign_node(new_iter_var_name, - var_slice_node) - return target_node, assign_node - - def _build_enum_increase_node(self): - return gast.AugAssign(target=gast.Name(id=self.enum_idx_name, - ctx=gast.Store(), - annotation=None, - type_comment=None), - op=gast.Add(), - value=gast.Constant(value=1, kind=None)) - - def _get_iter_var_name(self): - if self.is_for_range_iter(): - return self.target.id - elif self.is_for_iter(): - return self.target.id - elif self.is_for_enumerate_iter(): - return self.target.elts[1].id - return None - - def _get_iter_node(self): - if self.is_for_iter(): - return self.iter_args - elif self.is_for_enumerate_iter(): - return self.iter_args[0] - return None - - def _get_enum_idx_name(self): - if self.is_for_enumerate_iter(): - return self.target.elts[0].id - return None - - -class SplitAssignTransformer(gast.NodeTransformer): - """ - This class transforms sequence assignments and multi-target assignments to normal assignments. - """ - - def __init__(self, ast_node): - assert isinstance(ast_node, gast.AST) - self.ast_root = ast_node - - def transform(self): - self.visit(self.ast_root) - - def visit_Assign(self, node): - target_nodes = node.targets - if len(target_nodes) == 1: - node = self._parse_sequence_assign(node) - else: - node = self._parse_multi_target_assign(node) - return node - - def _parse_sequence_assign(self, node): - """ - a, b = c, d - -> - a = c - b = d - """ - assert isinstance(node, gast.Assign) - - target_nodes = node.targets - value_node = node.value - if not isinstance(target_nodes[0], (gast.List, gast.Tuple)): - return node - if not isinstance(value_node, (gast.List, gast.Tuple)): - return node - - targets = node.targets[0].elts - values = node.value.elts - if len(targets) != len(values): - return node - - new_nodes = [] - for target, value in zip(targets, values): - assign_node = gast.Assign(targets=[target], value=value) - new_nodes.append(assign_node) - - return new_nodes - - def _parse_multi_target_assign(self, node): - """ - Example 1: - a = b = c - -> - b = c - a = b - - Example 2: - a, b = c, d = x - -> - c,d = x - a = c - b = d - """ - assert isinstance(node, gast.Assign) - - target_nodes = node.targets - value_node = node.value - new_nodes = [] - for target in reversed(target_nodes): - assign_node = gast.Assign(targets=[target], value=value_node) - # NOTE: Because assign_node can be sequence assign statement like `a,b = c,d`, - # it's necessary to visit this new assign_node - parsed_node = self.visit_Assign(assign_node) - if not isinstance(parsed_node, list): - parsed_node = [parsed_node] - - new_nodes.extend(parsed_node) - value_node = target - - return new_nodes - - # NOTE: inspect.unwrap() exits in PY3 but not in PY2. def unwrap(func): """ -- GitLab