From 12e6dfcf1c9f0a2397e3daf66e6219cd696d6fbb Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 26 Oct 2022 17:08:49 +0800 Subject: [PATCH] [Cherry-Pick][Dy2Stat]Fix module loading OSError in multiprocess (#47302) [Dy2Stat]Fix module loading OSError in multiprocess --- .../fluid/dygraph/dygraph_to_static/utils.py | 557 +++++++++++------- 1 file changed, 333 insertions(+), 224 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 8d403300671..1d97a81e616 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -53,7 +53,7 @@ ORIGI_INFO = "Original information of source code for ast node." class BaseNodeVisitor(gast.NodeVisitor): """ - Implement customized NodeVisitor inherited from gast.NodeVisitor. + Implement customized NodeVisitor inherited from gast.NodeVisitor. Ancestor nodes are traced to easily support more operations of currently visited node. """ @@ -100,10 +100,18 @@ RE_PYMODULE = r'[a-zA-Z0-9_]+\.' # FullArgSpec is valid from Python3. Defined a Namedtuple to # to make it available in Python2. -FullArgSpec = collections.namedtuple('FullArgSpec', [ - 'args', 'varargs', 'varkw', 'defaults', 'kwonlyargs', 'kwonlydefaults', - 'annotations' -]) +FullArgSpec = collections.namedtuple( + 'FullArgSpec', + [ + 'args', + 'varargs', + 'varkw', + 'defaults', + 'kwonlyargs', + 'kwonlydefaults', + 'annotations', + ], +) def data_layer_not_check(name, shape, dtype='float32', lod_level=0): @@ -113,7 +121,7 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0): data can be various-length. This API is used in translating dygraph into static graph. - Note: + Note: The default :code:`stop_gradient` attribute of the Tensor created by this API is true, which means the gradient won't be passed backward through the data Tensor. Set :code:`var.stop_gradient = False` If @@ -124,7 +132,7 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0): for more details. shape (list|tuple): List|Tuple of integers declaring the shape. You can set "None" at a dimension to indicate the dimension can be of any - size. For example, it is useful to set changeable batch size as "None" + size. For example, it is useful to set changeable batch size as "None" dtype (np.dtype|VarType|str, optional): The type of the data. Supported dtype: bool, float16, float32, float64, int8, int16, int32, int64, uint8. Default: float32 @@ -141,20 +149,26 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0): if shape[i] is None: shape[i] = -1 - return helper.create_global_variable(name=name, - shape=shape, - dtype=dtype, - type=core.VarDesc.VarType.LOD_TENSOR, - stop_gradient=True, - lod_level=lod_level, - is_data=True, - need_check_feed=False) + return helper.create_global_variable( + name=name, + shape=shape, + dtype=dtype, + type=core.VarDesc.VarType.LOD_TENSOR, + stop_gradient=True, + lod_level=lod_level, + is_data=True, + need_check_feed=False, + ) def create_undefined_variable(): - from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM - var = data_layer_not_check(unique_name.generate("undefined_var"), [1], - "float64") + from paddle.fluid.dygraph.dygraph_to_static.return_transformer import ( + RETURN_NO_VALUE_MAGIC_NUM, + ) + + var = data_layer_not_check( + unique_name.generate("undefined_var"), [1], "float64" + ) var.stop_gradient = False # the variable is created in block(0), we append assign in block(0) either. helper = LayerHelper('create_undefined_variable', **locals()) @@ -166,17 +180,16 @@ def create_undefined_variable(): class UndefinedVar: - def __init__(self, name): self.name = name def check(self): raise UnboundLocalError( - "local variable '{}' should be created before using it.") + "local variable '{}' should be created before using it." + ) class Dygraph2StaticException(Exception): - def __init__(self, message): super().__init__(message) @@ -193,13 +206,15 @@ def getfullargspec(target): return inspect.getfullargspec(target) else: argspec = inspect.getargspec(target) - return FullArgSpec(args=argspec.args, - varargs=argspec.varargs, - varkw=argspec.keywords, - defaults=argspec.defaults, - kwonlyargs=[], - kwonlydefaults=None, - annotations={}) + return FullArgSpec( + args=argspec.args, + varargs=argspec.varargs, + varkw=argspec.keywords, + defaults=argspec.defaults, + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ) def parse_arg_and_kwargs(function): @@ -216,7 +231,7 @@ def parse_arg_and_kwargs(function): default_values = fullargspec.defaults if default_values: assert len(default_values) <= len(arg_names) - default_kwarg_names = arg_names[-len(default_values):] + default_kwarg_names = arg_names[-len(default_values) :] default_kwargs = dict(zip(default_kwarg_names, default_values)) return arg_names, default_kwargs @@ -290,8 +305,9 @@ def is_api_in_module(node, module_prefix): from paddle.fluid.dygraph import to_variable from paddle import to_tensor - return eval("_is_api_in_module_helper({}, '{}')".format( - func_str, module_prefix)) + return eval( + "_is_api_in_module_helper({}, '{}')".format(func_str, module_prefix) + ) except Exception: return False @@ -322,8 +338,10 @@ def is_numpy_api(node): func_str = astor.to_source(gast.gast_to_ast(node.func)) try: import numpy as np - module_result = eval("_is_api_in_module_helper({}, '{}')".format( - func_str, "numpy")) + + module_result = eval( + "_is_api_in_module_helper({}, '{}')".format(func_str, "numpy") + ) # BUG: np.random.uniform doesn't have module and cannot be analyzed # TODO: find a better way if not module_result: @@ -332,18 +350,19 @@ def is_numpy_api(node): return False -def is_control_flow_to_transform(node, - static_analysis_visitor=None, - var_name_to_type=None): +def is_control_flow_to_transform( + node, static_analysis_visitor=None, var_name_to_type=None +): """ Determines whether the node is a PaddlePaddle control flow statement which needs to be transformed into a static graph control flow statement. """ - assert isinstance(node, gast.AST), \ - "The type of input node must be gast.AST, but received %s." % type(node) - visitor = IsControlFlowVisitor(node, - static_analysis_visitor, - node_var_type_map=var_name_to_type) + assert isinstance( + node, gast.AST + ), "The type of input node must be gast.AST, but received %s." % type(node) + visitor = IsControlFlowVisitor( + node, static_analysis_visitor, node_var_type_map=var_name_to_type + ) need_to_transform = visitor.transform() return need_to_transform @@ -352,6 +371,7 @@ def _delete_keywords_from(node): assert isinstance(node, gast.Call) func_src = astor.to_source(gast.gast_to_ast(node.func)) import paddle.fluid as fluid + full_args = eval("inspect.getargspec({})".format(func_src)) full_args_name = full_args[0] @@ -365,7 +385,8 @@ def to_static_api(dygraph_class): else: raise NotImplementedError( "Paddle dygraph API {} cannot be converted " - "to static graph at present.".format(dygraph_class)) + "to static graph at present.".format(dygraph_class) + ) def _add_keywords_to(node, dygraph_api_name): @@ -376,8 +397,10 @@ def _add_keywords_to(node, dygraph_api_name): ast_keyword.arg = "size" node.keywords.append( - gast.keyword(arg="num_flatten_dims", - value=gast.Constant(value=-1, kind=None))) + gast.keyword( + arg="num_flatten_dims", value=gast.Constant(value=-1, kind=None) + ) + ) if dygraph_api_name == "BilinearTensorProduct": for ast_keyword in node.keywords: @@ -396,15 +419,17 @@ def to_static_ast(node, class_node): assert isinstance(class_node, gast.Call) static_api = to_static_api(class_node.func.attr) - node.func = gast.Attribute(attr=static_api, - ctx=gast.Load(), - value=gast.Attribute(attr='layers', - ctx=gast.Load(), - value=gast.Name( - ctx=gast.Load(), - id='fluid', - annotation=None, - type_comment=None))) + node.func = gast.Attribute( + attr=static_api, + ctx=gast.Load(), + value=gast.Attribute( + attr='layers', + ctx=gast.Load(), + value=gast.Name( + ctx=gast.Load(), id='fluid', annotation=None, type_comment=None + ), + ), + ) update_args_of_func(node, class_node, 'forward') @@ -427,10 +452,13 @@ def update_args_of_func(node, dygraph_node, method_name): class_src = astor.to_source(gast.gast_to_ast(dygraph_node.func)) import paddle.fluid as fluid + if method_name == "__init__" or eval( - "issubclass({}, fluid.dygraph.Layer)".format(class_src)): - full_args = eval("inspect.getargspec({}.{})".format( - class_src, method_name)) + "issubclass({}, fluid.dygraph.Layer)".format(class_src) + ): + full_args = eval( + "inspect.getargspec({}.{})".format(class_src, method_name) + ) full_args_name = [ arg_name for arg_name in full_args[0] if arg_name != "self" ] @@ -445,21 +473,24 @@ def update_args_of_func(node, dygraph_node, method_name): def create_api_shape_node(tensor_shape_node): - assert isinstance(tensor_shape_node, - (gast.Name, gast.Attribute, gast.Subscript)) + assert isinstance( + tensor_shape_node, (gast.Name, gast.Attribute, gast.Subscript) + ) if isinstance(tensor_shape_node, gast.Name): api_shape_node = gast.Call( func=gast.parse('paddle.shape').body[0].value, args=[tensor_shape_node], - keywords=[]) + keywords=[], + ) return api_shape_node if isinstance(tensor_shape_node, gast.Attribute): api_shape_node = gast.Call( func=gast.parse('paddle.shape').body[0].value, args=[tensor_shape_node.value], - keywords=[]) + keywords=[], + ) return api_shape_node if isinstance(tensor_shape_node, gast.Subscript): @@ -469,14 +500,15 @@ def create_api_shape_node(tensor_shape_node): def get_constant_variable_node(name, value, shape=[1], dtype='int64'): - return gast.parse('%s = paddle.full(%s, "%s", %s)' % - (name, str(shape), str(value), dtype)) + return gast.parse( + '%s = paddle.full(%s, "%s", %s)' % (name, str(shape), str(value), dtype) + ) def get_attribute_full_name(node): assert isinstance( - node, - gast.Attribute), "Input non-Attribute node to get attribute full name" + node, gast.Attribute + ), "Input non-Attribute node to get attribute full name" return astor.to_source(gast.gast_to_ast(node)).strip() @@ -494,15 +526,15 @@ def generate_name_node(name_ids, ctx=gast.Load(), gen_tuple_if_single=False): name_ids = [name_ids] if not isinstance(name_ids, (list, tuple, set)): raise TypeError( - 'name_ids must be list or tuple or set, but received %s' % - type(type(name_ids))) + 'name_ids must be list or tuple or set, but received %s' + % type(type(name_ids)) + ) def create_node_for_name(name): if '.' not in name: - return gast.Name(id=name, - ctx=ctx, - annotation=None, - type_comment=None) + return gast.Name( + id=name, ctx=ctx, annotation=None, type_comment=None + ) return gast.parse(name).body[0].value gast_names = [create_node_for_name(name_id) for name_id in name_ids] @@ -524,12 +556,14 @@ def create_funcDef_node(nodes, name, input_args, return_name_ids): nodes.append(gast.Return(value=generate_name_node(return_name_ids))) else: nodes.append(gast.Return(value=None)) - func_def_node = gast.FunctionDef(name=name, - args=input_args, - body=nodes, - decorator_list=[], - returns=None, - type_comment=None) + func_def_node = gast.FunctionDef( + name=name, + args=input_args, + body=nodes, + decorator_list=[], + returns=None, + type_comment=None, + ) return func_def_node @@ -554,8 +588,9 @@ def get_temp_dir(): """ Return @to_static temp directory. """ - dir_name = "paddle/to_static_tmp" + dir_name = "paddle/to_static_tmp/{pid}".format(pid=os.getpid()) temp_dir = os.path.join(os.path.expanduser('~/.cache'), dir_name) + is_windows = sys.platform.startswith('win') if is_windows: temp_dir = os.path.normpath(temp_dir) @@ -589,12 +624,14 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): source = ast_to_source_code(ast_root) source = _inject_import_statements() + source temp_dir = get_temp_dir() - f = tempfile.NamedTemporaryFile(mode='w', - prefix=func_prefix(dyfunc), - suffix='.py', - delete=False, - dir=temp_dir, - encoding='utf-8') + f = tempfile.NamedTemporaryFile( + mode='w', + prefix=func_prefix(dyfunc), + suffix='.py', + delete=False, + dir=temp_dir, + encoding='utf-8', + ) with f: module_name = os.path.basename(f.name[:-3]) f.write(source) @@ -616,8 +653,9 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): callable_func = getattr(module, func_name) else: raise ValueError( - 'Function: %s doesn\'t exist in the Module transformed from AST.' % - func_name) + 'Function: %s doesn\'t exist in the Module transformed from AST.' + % func_name + ) # After transform dygraph function into callable_func saved in tmp file, # it lost the global variables from imported statements or defined in source file. # Recovers the necessary variables by `__globals__`. @@ -628,10 +666,14 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): def _inject_import_statements(): import_statements = [ - "import paddle", "from paddle import Tensor", - "import paddle.fluid as fluid", "import paddle.jit.dy2static as _jst", - "from typing import *", "import numpy as np", "import warnings", - "warnings.filterwarnings('ignore', category=DeprecationWarning)" + "import paddle", + "from paddle import Tensor", + "import paddle.fluid as fluid", + "import paddle.jit.dy2static as _jst", + "from typing import *", + "import numpy as np", + "import warnings", + "warnings.filterwarnings('ignore', category=DeprecationWarning)", ] return '\n'.join(import_statements) + '\n' @@ -654,8 +696,10 @@ def func_to_source_code(function, dedent=True): """ if not (inspect.isfunction(function) or inspect.ismethod(function)): raise TypeError( - "The type of 'function' should be a function or method, but received {}." - .format(type(function).__name__)) + "The type of 'function' should be a function or method, but received {}.".format( + type(function).__name__ + ) + ) source_code_list, _ = inspect.getsourcelines(function) # Replace comments with blank lines so that error messages are not misplaced source_code_list = [ @@ -675,8 +719,9 @@ def ast_to_source_code(ast_node): """ if not isinstance(ast_node, (gast.AST, ast.AST)): raise TypeError( - "Type of ast_root should be gast.AST or ast.AST, but received %s." % - type(ast_node)) + "Type of ast_root should be gast.AST or ast.AST, but received %s." + % type(ast_node) + ) if isinstance(ast_node, gast.AST): ast_node = gast.gast_to_ast(ast_node) @@ -692,8 +737,17 @@ def is_candidate_node(node): """ Nodes with specified type will be dependent on tensor. """ - is_compare_node = isinstance(node, (gast.Compare, gast.BoolOp, gast.UnaryOp, - gast.For, gast.If, gast.While)) + is_compare_node = isinstance( + node, + ( + gast.Compare, + gast.BoolOp, + gast.UnaryOp, + gast.For, + gast.If, + gast.While, + ), + ) # TODO(Aurelius84): `.numpy()` may be an customized function, # and should consider a more elegant way to solve this problem. has_numpy_attr = ".numpy()" in ast_to_source_code(node) @@ -709,9 +763,9 @@ def compare_with_none(node): # node.comparators is a list. if isinstance(child, list): child = child[0] - if (isinstance(child, gast.Constant) - and child.value is None) or (isinstance(child, gast.Name) - and child.id == 'None'): + if (isinstance(child, gast.Constant) and child.value is None) or ( + isinstance(child, gast.Name) and child.id == 'None' + ): return True return False @@ -746,20 +800,22 @@ class IsControlFlowVisitor(gast.NodeVisitor): because reshape_op may be called before this statement. """ - def __init__(self, - ast_node, - static_analysis_visitor=None, - node_var_type_map=None): + def __init__( + self, ast_node, static_analysis_visitor=None, node_var_type_map=None + ): assert isinstance( ast_node, gast.AST ), "Type of input node should be gast.AST, but received %s." % type( - ast_node) + ast_node + ) self.ast_root = ast_node if static_analysis_visitor is None: from .static_analysis import StaticAnalysisVisitor + static_analysis_visitor = StaticAnalysisVisitor(ast_node) self.static_analysis_visitor = static_analysis_visitor - self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map( + self.node_to_wrapper_map = ( + self.static_analysis_visitor.get_node_to_wrapper_map() ) self.node_var_type_map = node_var_type_map @@ -788,7 +844,10 @@ class IsControlFlowVisitor(gast.NodeVisitor): if isinstance(node.iter, gast.Call): # for in range(var[0]|var.numpy()[0]) or for in enumerate(var|var.numpy()) if isinstance(node.iter.func, gast.Name): - if node.iter.func.id == "range" or node.iter.func.id == "enumerate": + if ( + node.iter.func.id == "range" + or node.iter.func.id == "enumerate" + ): for arg in node.iter.args: self.visit(arg) else: @@ -887,7 +946,9 @@ class IsControlFlowVisitor(gast.NodeVisitor): return node def _is_node_with_tensor(self, node, name_id): - from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType + from paddle.fluid.dygraph.dygraph_to_static.static_analysis import ( + NodeVarType, + ) # Look up the node_var_type_map by name_id. if self.node_var_type_map: @@ -917,7 +978,7 @@ def unwrap(func): return hasattr(f, '__wrapped__') unwrapped_f = func - while (_is_wrapped(unwrapped_f)): + while _is_wrapped(unwrapped_f): unwrapped_f = unwrapped_f.__wrapped__ return unwrapped_f @@ -941,10 +1002,12 @@ def input_specs_compatible(src_input_specs, desired_input_specs): if spec not in desired_input_specs: return False else: - for (src_spec, desired_spec) in zip(src_input_specs, - desired_input_specs): + for (src_spec, desired_spec) in zip( + src_input_specs, desired_input_specs + ): if isinstance(src_spec, paddle.static.InputSpec) or isinstance( - desired_spec, paddle.static.InputSpec): + desired_spec, paddle.static.InputSpec + ): if not _compatible_tensor_spec(src_spec, desired_spec): return False else: @@ -1029,15 +1092,14 @@ def slice_is_num(slice_node): class NameScope: - def __init__(self): - """ - A NameScope is a object which manager all the variable names. - only FunctionDef and Controlflow node will have a namescope property. + """ + A NameScope is a object which manager all the variable names. + only FunctionDef and Controlflow node will have a namescope property. - type can be "function" and "controlflow" + type can be "function" and "controlflow" - we don't analyze the read only variable because they don't affect the analysis. + we don't analyze the read only variable because they don't affect the analysis. """ self.globals = set() self.nonlocals = set() @@ -1053,8 +1115,8 @@ class NameScope: self.father = father def existed_vars(self): - """ vars existing in current scope. - they must not contain qualified names. + """vars existing in current scope. + they must not contain qualified names. """ local_vars = self.w_vars - self.globals - self.nonlocals - self.args return set(filter(lambda x: '.' not in x, local_vars)) @@ -1067,9 +1129,9 @@ class NameScope: return self.w_vars def variadic_length_vars(self): - """ + """ At present, we do not support global append, such as - + import numpy as np a = [] def func(): @@ -1083,33 +1145,38 @@ class NameScope: f"Find variable `{var}` defined in global scope" f" and call `{var}.append() or {var}.pop()`" f", which will be ignored and never be transfered into" - f" tensor array.") + f" tensor array." + ) else: non_global_push_pop_names.append(var) return set(non_global_push_pop_names) def control_flow_vars(self): valid_names = self.w_vars - tmp = self.father.global_vars & valid_names, + tmp = (self.father.global_vars & valid_names,) return {"global": tmp, "nonlocal": self.w_vars - tmp} def _is_simple_name(self, name): - if '.' in name or '[' in name: return False + if '.' in name or '[' in name: + return False return True def is_global_var(self, name): - """ + """ Return whether the name is a var created in global scope. - Search from bottom to top. If it is not created or modified, + Search from bottom to top. If it is not created or modified, it means global vars; otherwise, it means local vars. Only valid after FunctionNameLivenessAnalysis visitor. """ assert self._is_simple_name( - name), "is_global_var accept a simple name, but get `{name}`." + name + ), "is_global_var accept a simple name, but get `{name}`." ancestor = self while ancestor is not None: - if name in ancestor.globals: return True - if name in (ancestor.nonlocals | ancestor.w_vars): return False + if name in ancestor.globals: + return True + if name in (ancestor.nonlocals | ancestor.w_vars): + return False ancestor = ancestor.father return True @@ -1125,46 +1192,46 @@ class NameScope: class FunctionNameLivenessAnalysis(gast.NodeVisitor): - """ analyze the liveness of a function. - - every variables stored in this scope will be collected, - in addition with global/nonlocal information and - push_pop information. - - 1. global variable is stored in node.var_globals. - 2. nonlocal variable is stored in node.var_nonlocals. - 3. arguments is stored in node.var_args. - 4. if a variable's push and pop attribute is called, - it will be collected in push_pop_vars. They are - used for transformation to tensor_array. - NOTE: push_pop_vars **may not** in w_vars. - a.push(0) don't modify the variable a, but the content - of a. - - For example: - - def func(*args, **kargs): - a = 12 - global i,j - nonlocal x,y - print(a) - i = k - b = [] - c = [1,2,3] - for m in range(10): - q = 12 - b.push(1) - c.pop() - - After this visitor we have: - # node is the FunctionDef node with name: "func" - node.pd_scope = NameScope( - globals = ['i', 'j'], - nonlocals = ['x', 'y'], - args = ['args', 'kargs'], - wr_vars = ['a', 'i', 'q', 'm', 'c', 'b'] - push_pop_vars = ['b', 'c'] - ) + """analyze the liveness of a function. + + every variables stored in this scope will be collected, + in addition with global/nonlocal information and + push_pop information. + + 1. global variable is stored in node.var_globals. + 2. nonlocal variable is stored in node.var_nonlocals. + 3. arguments is stored in node.var_args. + 4. if a variable's push and pop attribute is called, + it will be collected in push_pop_vars. They are + used for transformation to tensor_array. + NOTE: push_pop_vars **may not** in w_vars. + a.push(0) don't modify the variable a, but the content + of a. + + For example: + + def func(*args, **kargs): + a = 12 + global i,j + nonlocal x,y + print(a) + i = k + b = [] + c = [1,2,3] + for m in range(10): + q = 12 + b.push(1) + c.pop() + + After this visitor we have: + # node is the FunctionDef node with name: "func" + node.pd_scope = NameScope( + globals = ['i', 'j'], + nonlocals = ['x', 'y'], + args = ['args', 'kargs'], + wr_vars = ['a', 'i', 'q', 'm', 'c', 'b'] + push_pop_vars = ['b', 'c'] + ) """ def __init__(self, root_node): @@ -1184,25 +1251,26 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor): return self._get_name_scope(self.scope_node_stack[-1]) def _father_name_scope(self): - if len(self.scope_node_stack) == 1: return None + if len(self.scope_node_stack) == 1: + return None return self._get_name_scope(self.scope_node_stack[-2]) def _nearest_function_scope(self): - if len(self.scope_node_stack) == 1: return None + if len(self.scope_node_stack) == 1: + return None for node in self.scope_node_stack[-2::-1]: if isinstance(node, gast.FunctionDef): return self._get_name_scope(node) def visit_ListComp(self, node): - """ [ i for i in range(10) ] - In this case, `i` will not created in FunctionScope. - We don't collect `i` by not calling generic_visit. + """[ i for i in range(10) ] + In this case, `i` will not created in FunctionScope. + We don't collect `i` by not calling generic_visit. """ pass def visit_DictComp(self, node): - """ the same as ListComp. - """ + """the same as ListComp.""" pass def visit_Name(self, node): @@ -1212,62 +1280,86 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor): self._current_name_scope().w_vars.add(node.id) def visit_FunctionDef(self, node): - def pre_func(): self._current_name_scope().args |= set( - self._get_argument_names(node)) + self._get_argument_names(node) + ) def post_func(): - """ NOTE: why we need merge w_vars and push_pop_vars here ? - because we do ifelse_transformer after loop_transformer. Loops will changed into functioons. but we know this function will be called in if. so we add w_vars to father function scope. + """NOTE: why we need merge w_vars and push_pop_vars here ? + because we do ifelse_transformer after loop_transformer. Loops will changed into functioons. but we know this function will be called in if. so we add w_vars to father function scope. """ - from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import WHILE_CONDITION_PREFIX, WHILE_BODY_PREFIX, FOR_CONDITION_PREFIX, FOR_BODY_PREFIX - from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import TRUE_FUNC_PREFIX, FALSE_FUNC_PREFIX + from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import ( + WHILE_CONDITION_PREFIX, + WHILE_BODY_PREFIX, + FOR_CONDITION_PREFIX, + FOR_BODY_PREFIX, + ) + from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import ( + TRUE_FUNC_PREFIX, + FALSE_FUNC_PREFIX, + ) + control_flow_function_def = [ - WHILE_BODY_PREFIX, WHILE_BODY_PREFIX, FOR_CONDITION_PREFIX, - FOR_BODY_PREFIX, TRUE_FUNC_PREFIX, FALSE_FUNC_PREFIX + WHILE_BODY_PREFIX, + WHILE_BODY_PREFIX, + FOR_CONDITION_PREFIX, + FOR_BODY_PREFIX, + TRUE_FUNC_PREFIX, + FALSE_FUNC_PREFIX, ] def is_control_flow_def_node(): for prefix in control_flow_function_def: - if node.name.startswith(prefix): return True + if node.name.startswith(prefix): + return True return False if self._father_name_scope() and is_control_flow_def_node(): - self._father_name_scope().w_vars |= self._current_name_scope( - ).w_vars - self._father_name_scope( - ).push_pop_vars |= self._current_name_scope().push_pop_vars + self._father_name_scope().w_vars |= ( + self._current_name_scope().w_vars + ) + self._father_name_scope().push_pop_vars |= ( + self._current_name_scope().push_pop_vars + ) self._visit_scope_node(node, pre_func, post_func) def _visit_scope_node(self, node, pre_func, post_func): - """ scope node main visit logic. - pre_func and post_func is callbacks + """scope node main visit logic. + pre_func and post_func is callbacks """ self._reset_name_scope(node) self.scope_node_stack.append(node) self._current_name_scope().set_father(self._nearest_function_scope()) - if pre_func: pre_func() + if pre_func: + pre_func() self.generic_visit(node) - if post_func: post_func() + if post_func: + post_func() self.scope_node_stack.pop() def _visit_controlflow_node(self, node): - def post_func(): self._father_name_scope().merge_from(self._current_name_scope()) self._nearest_function_scope().merge_from( - self._current_name_scope()) - self._current_name_scope().created = self._nearest_function_scope( - ).existed_vars() - node.before_created + self._current_name_scope() + ) + self._current_name_scope().created = ( + self._nearest_function_scope().existed_vars() + - node.before_created + ) # gather created vars into father and used in CreateUndefinedVarTransform - self._nearest_function_scope().created |= self._current_name_scope( - ).created + self._nearest_function_scope().created |= ( + self._current_name_scope().created + ) def pre_func(): - setattr(node, "before_created", - self._nearest_function_scope().existed_vars()) + setattr( + node, + "before_created", + self._nearest_function_scope().existed_vars(), + ) self._visit_scope_node(node, pre_func, post_func) @@ -1305,12 +1397,13 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor): self._current_name_scope().push_pop_vars.add(name) def _get_argument_names(self, node): - """ get all arguments name in the functiondef node. - this node is local to the function and shouldn't - be created. + """get all arguments name in the functiondef node. + this node is local to the function and shouldn't + be created. """ assert isinstance( - node, gast.FunctionDef), "Input node is not function define node" + node, gast.FunctionDef + ), "Input node is not function define node" names = [a for a in node.args.args] names.append(node.args.vararg) names.append(node.args.kwarg) @@ -1331,7 +1424,9 @@ def create_get_args_node(names): func_def = """ def {func_name}(): return - """.format(func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX)) + """.format( + func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX) + ) return gast.parse(textwrap.dedent(func_def)).body[0] assert isinstance(names, (list, tuple)) @@ -1350,7 +1445,8 @@ def create_get_args_node(names): func_def = template.format( func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX), nonlocal_vars=nonlocal_vars, - vars=",".join(names)) + vars=",".join(names), + ) return gast.parse(textwrap.dedent(func_def)).body[0] @@ -1367,8 +1463,9 @@ def create_set_args_node(names): func_def = """ def {func_name}({args}): pass - """.format(func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX), - args=ARGS_NAME) + """.format( + func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX), args=ARGS_NAME + ) return gast.parse(textwrap.dedent(func_def)).body[0] assert isinstance(names, (list, tuple)) @@ -1388,7 +1485,8 @@ def create_set_args_node(names): func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX), args=ARGS_NAME, nonlocal_vars=nonlocal_vars, - vars=",".join(names)) + vars=",".join(names), + ) return gast.parse(textwrap.dedent(func_def)).body[0] @@ -1398,8 +1496,8 @@ def create_nonlocal_stmt_nodes(names): mapped = list(filter(lambda n: '.' not in n, names)) mapped = list(filter(lambda n: '[' not in n, mapped)) names = sorted( - mapped, - key=mapped.index) # to keep the order, we can't use set() to unique + mapped, key=mapped.index + ) # to keep the order, we can't use set() to unique if not names: return [] func_code = "nonlocal {}".format(','.join(names)) @@ -1407,10 +1505,10 @@ def create_nonlocal_stmt_nodes(names): class GetterSetterHelper: - """ we have two classes of names in setter and getter function: - w_vars(loop_vars) + push_pop_vars - To simplify the setter logic in convert_while and convert_cond, - we extract the helper class here. + """we have two classes of names in setter and getter function: + w_vars(loop_vars) + push_pop_vars + To simplify the setter logic in convert_while and convert_cond, + we extract the helper class here. """ def __init__(self, getter_func, setter_func, *name_lists): @@ -1426,22 +1524,33 @@ class GetterSetterHelper: return self._union def get(self, names): - if names is None: names = [] + if names is None: + names = [] vars = self.getter() - if vars is None: return tuple() + if vars is None: + return tuple() for n in names: - assert n in self.name2id, "the name `{}` not in name union set`{}`.".format( - n, self.name2id.keys()) + assert ( + n in self.name2id + ), "the name `{}` not in name union set`{}`.".format( + n, self.name2id.keys() + ) return tuple(map(lambda n: vars[self.name2id[n]], names)) def set(self, names, values): - if names is None: names = [] - if values is None: values = [] + if names is None: + names = [] + if values is None: + values = [] vars = self.getter() - if vars is None: return + if vars is None: + return for n in names: - assert n in self.name2id, "the name `{}` not in name union set`{}`.".format( - n, self.name2id.keys()) + assert ( + n in self.name2id + ), "the name `{}` not in name union set`{}`.".format( + n, self.name2id.keys() + ) vars = list(vars) indices = list(map(lambda n: self.name2id[n], names)) for i, v in zip(indices, values): -- GitLab