Unverified Commit 12e6dfcf authored by Aurelius84, committed by GitHub

[Cherry-Pick][Dy2Stat]Fix module loading OSError in multiprocess (#47302)

[Dy2Stat]Fix module loading OSError in multiprocess
Parent 7c6550a6
@@ -53,7 +53,7 @@ ORIGI_INFO = "Original information of source code for ast node."
class BaseNodeVisitor(gast.NodeVisitor):
"""
Implement customized NodeVisitor inherited from gast.NodeVisitor.
Ancestor nodes are traced to easily support more operations of currently
visited node.
"""
@@ -100,10 +100,18 @@ RE_PYMODULE = r'[a-zA-Z0-9_]+\.'
# FullArgSpec is valid from Python3. Define a namedtuple
# to make it available in Python2.
FullArgSpec = collections.namedtuple('FullArgSpec', [
'args', 'varargs', 'varkw', 'defaults', 'kwonlyargs', 'kwonlydefaults',
'annotations'
])
FullArgSpec = collections.namedtuple(
'FullArgSpec',
[
'args',
'varargs',
'varkw',
'defaults',
'kwonlyargs',
'kwonlydefaults',
'annotations',
],
)
def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
@@ -113,7 +121,7 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
data can be various-length. This API is used in translating dygraph into
static graph.
Note:
The default :code:`stop_gradient` attribute of the Tensor created by
this API is true, which means the gradient won't be passed backward
through the data Tensor. Set :code:`var.stop_gradient = False` If
@@ -124,7 +132,7 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
for more details.
shape (list|tuple): List|Tuple of integers declaring the shape. You can
set "None" at a dimension to indicate the dimension can be of any
size. For example, it is useful to set changeable batch size as "None"
dtype (np.dtype|VarType|str, optional): The type of the data. Supported
dtype: bool, float16, float32, float64, int8, int16, int32, int64,
uint8. Default: float32
@@ -141,20 +149,26 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
if shape[i] is None:
shape[i] = -1
return helper.create_global_variable(name=name,
shape=shape,
dtype=dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
stop_gradient=True,
lod_level=lod_level,
is_data=True,
need_check_feed=False)
return helper.create_global_variable(
name=name,
shape=shape,
dtype=dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
stop_gradient=True,
lod_level=lod_level,
is_data=True,
need_check_feed=False,
)
def create_undefined_variable():
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
var = data_layer_not_check(unique_name.generate("undefined_var"), [1],
"float64")
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import (
RETURN_NO_VALUE_MAGIC_NUM,
)
var = data_layer_not_check(
unique_name.generate("undefined_var"), [1], "float64"
)
var.stop_gradient = False
# the variable is created in block(0), so we append the assign in block(0) as well.
helper = LayerHelper('create_undefined_variable', **locals())
@@ -166,17 +180,16 @@ def create_undefined_variable():
class UndefinedVar:
def __init__(self, name):
self.name = name
def check(self):
raise UnboundLocalError(
"local variable '{}' should be created before using it.")
"local variable '{}' should be created before using it."
)
class Dygraph2StaticException(Exception):
def __init__(self, message):
super().__init__(message)
@@ -193,13 +206,15 @@ def getfullargspec(target):
return inspect.getfullargspec(target)
else:
argspec = inspect.getargspec(target)
return FullArgSpec(args=argspec.args,
varargs=argspec.varargs,
varkw=argspec.keywords,
defaults=argspec.defaults,
kwonlyargs=[],
kwonlydefaults=None,
annotations={})
return FullArgSpec(
args=argspec.args,
varargs=argspec.varargs,
varkw=argspec.keywords,
defaults=argspec.defaults,
kwonlyargs=[],
kwonlydefaults=None,
annotations={},
)
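
A quick illustration of the shim above (the sample function below is made up; on Python 3 this simply delegates to inspect.getfullargspec):

def sample(a, b=1, *args, **kwargs):
    pass

spec = getfullargspec(sample)
# Both branches return the same named fields, so callers such as
# parse_arg_and_kwargs can stay version-agnostic.
assert spec.args == ['a', 'b'] and spec.defaults == (1,)
assert spec.varargs == 'args' and spec.varkw == 'kwargs'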
def parse_arg_and_kwargs(function):
@@ -216,7 +231,7 @@ def parse_arg_and_kwargs(function):
default_values = fullargspec.defaults
if default_values:
assert len(default_values) <= len(arg_names)
default_kwarg_names = arg_names[-len(default_values):]
default_kwarg_names = arg_names[-len(default_values) :]
default_kwargs = dict(zip(default_kwarg_names, default_values))
return arg_names, default_kwargs
@@ -290,8 +305,9 @@ def is_api_in_module(node, module_prefix):
from paddle.fluid.dygraph import to_variable
from paddle import to_tensor
return eval("_is_api_in_module_helper({}, '{}')".format(
func_str, module_prefix))
return eval(
"_is_api_in_module_helper({}, '{}')".format(func_str, module_prefix)
)
except Exception:
return False
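
The _is_api_in_module_helper used in the eval above is defined elsewhere in this file; a minimal sketch of what such a check looks like (an assumption about its body, inferred from how it is called here):

import inspect

def _is_api_in_module_helper(obj, module_prefix):
    # Resolve the module that defines `obj` and compare its dotted
    # name against the requested prefix, e.g. "paddle" or "numpy".
    m = inspect.getmodule(obj)
    return m is not None and m.__name__.startswith(module_prefix)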
@@ -322,8 +338,10 @@ def is_numpy_api(node):
func_str = astor.to_source(gast.gast_to_ast(node.func))
try:
import numpy as np
module_result = eval("_is_api_in_module_helper({}, '{}')".format(
func_str, "numpy"))
module_result = eval(
"_is_api_in_module_helper({}, '{}')".format(func_str, "numpy")
)
# BUG: np.random.uniform doesn't have module and cannot be analyzed
# TODO: find a better way
if not module_result:
@@ -332,18 +350,19 @@
return False
def is_control_flow_to_transform(node,
static_analysis_visitor=None,
var_name_to_type=None):
def is_control_flow_to_transform(
node, static_analysis_visitor=None, var_name_to_type=None
):
"""
Determines whether the node is a PaddlePaddle control flow statement which needs to
be transformed into a static graph control flow statement.
"""
assert isinstance(node, gast.AST), \
"The type of input node must be gast.AST, but received %s." % type(node)
visitor = IsControlFlowVisitor(node,
static_analysis_visitor,
node_var_type_map=var_name_to_type)
assert isinstance(
node, gast.AST
), "The type of input node must be gast.AST, but received %s." % type(node)
visitor = IsControlFlowVisitor(
node, static_analysis_visitor, node_var_type_map=var_name_to_type
)
need_to_transform = visitor.transform()
return need_to_transform
@@ -352,6 +371,7 @@ def _delete_keywords_from(node):
assert isinstance(node, gast.Call)
func_src = astor.to_source(gast.gast_to_ast(node.func))
import paddle.fluid as fluid
full_args = eval("inspect.getargspec({})".format(func_src))
full_args_name = full_args[0]
@@ -365,7 +385,8 @@ def to_static_api(dygraph_class):
else:
raise NotImplementedError(
"Paddle dygraph API {} cannot be converted "
"to static graph at present.".format(dygraph_class))
"to static graph at present.".format(dygraph_class)
)
def _add_keywords_to(node, dygraph_api_name):
@@ -376,8 +397,10 @@ def _add_keywords_to(node, dygraph_api_name):
ast_keyword.arg = "size"
node.keywords.append(
gast.keyword(arg="num_flatten_dims",
value=gast.Constant(value=-1, kind=None)))
gast.keyword(
arg="num_flatten_dims", value=gast.Constant(value=-1, kind=None)
)
)
if dygraph_api_name == "BilinearTensorProduct":
for ast_keyword in node.keywords:
@@ -396,15 +419,17 @@ def to_static_ast(node, class_node):
assert isinstance(class_node, gast.Call)
static_api = to_static_api(class_node.func.attr)
node.func = gast.Attribute(attr=static_api,
ctx=gast.Load(),
value=gast.Attribute(attr='layers',
ctx=gast.Load(),
value=gast.Name(
ctx=gast.Load(),
id='fluid',
annotation=None,
type_comment=None)))
node.func = gast.Attribute(
attr=static_api,
ctx=gast.Load(),
value=gast.Attribute(
attr='layers',
ctx=gast.Load(),
value=gast.Name(
ctx=gast.Load(), id='fluid', annotation=None, type_comment=None
),
),
)
update_args_of_func(node, class_node, 'forward')
@@ -427,10 +452,13 @@ def update_args_of_func(node, dygraph_node, method_name):
class_src = astor.to_source(gast.gast_to_ast(dygraph_node.func))
import paddle.fluid as fluid
if method_name == "__init__" or eval(
"issubclass({}, fluid.dygraph.Layer)".format(class_src)):
full_args = eval("inspect.getargspec({}.{})".format(
class_src, method_name))
"issubclass({}, fluid.dygraph.Layer)".format(class_src)
):
full_args = eval(
"inspect.getargspec({}.{})".format(class_src, method_name)
)
full_args_name = [
arg_name for arg_name in full_args[0] if arg_name != "self"
]
@@ -445,21 +473,24 @@ def update_args_of_func(node, dygraph_node, method_name):
def create_api_shape_node(tensor_shape_node):
assert isinstance(tensor_shape_node,
(gast.Name, gast.Attribute, gast.Subscript))
assert isinstance(
tensor_shape_node, (gast.Name, gast.Attribute, gast.Subscript)
)
if isinstance(tensor_shape_node, gast.Name):
api_shape_node = gast.Call(
func=gast.parse('paddle.shape').body[0].value,
args=[tensor_shape_node],
keywords=[])
keywords=[],
)
return api_shape_node
if isinstance(tensor_shape_node, gast.Attribute):
api_shape_node = gast.Call(
func=gast.parse('paddle.shape').body[0].value,
args=[tensor_shape_node.value],
keywords=[])
keywords=[],
)
return api_shape_node
if isinstance(tensor_shape_node, gast.Subscript):
@@ -469,14 +500,15 @@ def create_api_shape_node(tensor_shape_node):
def get_constant_variable_node(name, value, shape=[1], dtype='int64'):
return gast.parse('%s = paddle.full(%s, "%s", %s)' %
(name, str(shape), str(value), dtype))
return gast.parse(
'%s = paddle.full(%s, "%s", %s)' % (name, str(shape), str(value), dtype)
)
def get_attribute_full_name(node):
assert isinstance(
node,
gast.Attribute), "Input non-Attribute node to get attribute full name"
node, gast.Attribute
), "Input non-Attribute node to get attribute full name"
return astor.to_source(gast.gast_to_ast(node)).strip()
@@ -494,15 +526,15 @@ def generate_name_node(name_ids, ctx=gast.Load(), gen_tuple_if_single=False):
name_ids = [name_ids]
if not isinstance(name_ids, (list, tuple, set)):
raise TypeError(
'name_ids must be list or tuple or set, but received %s' %
type(type(name_ids)))
'name_ids must be list or tuple or set, but received %s'
% type(type(name_ids))
)
def create_node_for_name(name):
if '.' not in name:
return gast.Name(id=name,
ctx=ctx,
annotation=None,
type_comment=None)
return gast.Name(
id=name, ctx=ctx, annotation=None, type_comment=None
)
return gast.parse(name).body[0].value
gast_names = [create_node_for_name(name_id) for name_id in name_ids]
@@ -524,12 +556,14 @@ def create_funcDef_node(nodes, name, input_args, return_name_ids):
nodes.append(gast.Return(value=generate_name_node(return_name_ids)))
else:
nodes.append(gast.Return(value=None))
func_def_node = gast.FunctionDef(name=name,
args=input_args,
body=nodes,
decorator_list=[],
returns=None,
type_comment=None)
func_def_node = gast.FunctionDef(
name=name,
args=input_args,
body=nodes,
decorator_list=[],
returns=None,
type_comment=None,
)
return func_def_node
@@ -554,8 +588,9 @@ def get_temp_dir():
"""
Return @to_static temp directory.
"""
dir_name = "paddle/to_static_tmp"
dir_name = "paddle/to_static_tmp/{pid}".format(pid=os.getpid())
temp_dir = os.path.join(os.path.expanduser('~/.cache'), dir_name)
is_windows = sys.platform.startswith('win')
if is_windows:
temp_dir = os.path.normpath(temp_dir)
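
This is the substantive fix of the commit: keying the temp directory by os.getpid() gives every process its own directory for generated modules, so concurrent workers no longer create, import, and delete the same files and trip an OSError. A minimal standalone sketch of the idea (the helper name is hypothetical, not Paddle's API):

import os

def per_process_temp_dir(base="~/.cache/paddle/to_static_tmp"):
    # One directory per PID: two processes transforming the same
    # function can no longer race on each other's generated .py files.
    path = os.path.join(os.path.expanduser(base), str(os.getpid()))
    os.makedirs(path, exist_ok=True)
    return path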
@@ -589,12 +624,14 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
source = ast_to_source_code(ast_root)
source = _inject_import_statements() + source
temp_dir = get_temp_dir()
f = tempfile.NamedTemporaryFile(mode='w',
prefix=func_prefix(dyfunc),
suffix='.py',
delete=False,
dir=temp_dir,
encoding='utf-8')
f = tempfile.NamedTemporaryFile(
mode='w',
prefix=func_prefix(dyfunc),
suffix='.py',
delete=False,
dir=temp_dir,
encoding='utf-8',
)
with f:
module_name = os.path.basename(f.name[:-3])
f.write(source)
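
The generated source is written to a uniquely named .py file and then imported as a module so the transformed function can be looked up with getattr. A sketch of that load step using the standard importlib pattern (illustrative; not necessarily the exact call this file uses):

import importlib.util

def load_module_from_file(module_name, file_path):
    # Execute a generated .py file as a module object.
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module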
@@ -616,8 +653,9 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
callable_func = getattr(module, func_name)
else:
raise ValueError(
'Function: %s doesn\'t exist in the Module transformed from AST.' %
func_name)
'Function: %s doesn\'t exist in the Module transformed from AST.'
% func_name
)
# After transforming the dygraph function into callable_func saved in a tmp file,
# it loses the global variables from import statements or defined in the source file.
# Recover the necessary variables via `__globals__`.
@@ -628,10 +666,14 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
def _inject_import_statements():
import_statements = [
"import paddle", "from paddle import Tensor",
"import paddle.fluid as fluid", "import paddle.jit.dy2static as _jst",
"from typing import *", "import numpy as np", "import warnings",
"warnings.filterwarnings('ignore', category=DeprecationWarning)"
"import paddle",
"from paddle import Tensor",
"import paddle.fluid as fluid",
"import paddle.jit.dy2static as _jst",
"from typing import *",
"import numpy as np",
"import warnings",
"warnings.filterwarnings('ignore', category=DeprecationWarning)",
]
return '\n'.join(import_statements) + '\n'
@@ -654,8 +696,10 @@ def func_to_source_code(function, dedent=True):
"""
if not (inspect.isfunction(function) or inspect.ismethod(function)):
raise TypeError(
"The type of 'function' should be a function or method, but received {}."
.format(type(function).__name__))
"The type of 'function' should be a function or method, but received {}.".format(
type(function).__name__
)
)
source_code_list, _ = inspect.getsourcelines(function)
# Replace comments with blank lines so that error messages are not misplaced
source_code_list = [
@@ -675,8 +719,9 @@ def ast_to_source_code(ast_node):
"""
if not isinstance(ast_node, (gast.AST, ast.AST)):
raise TypeError(
"Type of ast_root should be gast.AST or ast.AST, but received %s." %
type(ast_node))
"Type of ast_root should be gast.AST or ast.AST, but received %s."
% type(ast_node)
)
if isinstance(ast_node, gast.AST):
ast_node = gast.gast_to_ast(ast_node)
@@ -692,8 +737,17 @@ def is_candidate_node(node):
"""
Nodes with specified type will be dependent on tensor.
"""
is_compare_node = isinstance(node, (gast.Compare, gast.BoolOp, gast.UnaryOp,
gast.For, gast.If, gast.While))
is_compare_node = isinstance(
node,
(
gast.Compare,
gast.BoolOp,
gast.UnaryOp,
gast.For,
gast.If,
gast.While,
),
)
# TODO(Aurelius84): `.numpy()` may be a customized function,
# and we should consider a more elegant way to solve this problem.
has_numpy_attr = ".numpy()" in ast_to_source_code(node)
@@ -709,9 +763,9 @@ def compare_with_none(node):
# node.comparators is a list.
if isinstance(child, list):
child = child[0]
if (isinstance(child, gast.Constant)
and child.value is None) or (isinstance(child, gast.Name)
and child.id == 'None'):
if (isinstance(child, gast.Constant) and child.value is None) or (
isinstance(child, gast.Name) and child.id == 'None'
):
return True
return False
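
A quick sanity check of the predicate above, building Compare nodes with gast directly (illustrative snippet, not part of the patch):

import gast

# `x is None` compares against a gast.Constant whose value is None.
assert compare_with_none(gast.parse("x is None").body[0].value)
# `x is y` compares two names, neither of which is None.
assert not compare_with_none(gast.parse("x is y").body[0].value)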
@@ -746,20 +800,22 @@ class IsControlFlowVisitor(gast.NodeVisitor):
because reshape_op may be called before this statement.
"""
def __init__(self,
ast_node,
static_analysis_visitor=None,
node_var_type_map=None):
def __init__(
self, ast_node, static_analysis_visitor=None, node_var_type_map=None
):
assert isinstance(
ast_node, gast.AST
), "Type of input node should be gast.AST, but received %s." % type(
ast_node)
ast_node
)
self.ast_root = ast_node
if static_analysis_visitor is None:
from .static_analysis import StaticAnalysisVisitor
static_analysis_visitor = StaticAnalysisVisitor(ast_node)
self.static_analysis_visitor = static_analysis_visitor
self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
self.node_to_wrapper_map = (
self.static_analysis_visitor.get_node_to_wrapper_map()
)
self.node_var_type_map = node_var_type_map
@@ -788,7 +844,10 @@ class IsControlFlowVisitor(gast.NodeVisitor):
if isinstance(node.iter, gast.Call):
# for in range(var[0]|var.numpy()[0]) or for in enumerate(var|var.numpy())
if isinstance(node.iter.func, gast.Name):
if node.iter.func.id == "range" or node.iter.func.id == "enumerate":
if (
node.iter.func.id == "range"
or node.iter.func.id == "enumerate"
):
for arg in node.iter.args:
self.visit(arg)
else:
@@ -887,7 +946,9 @@ class IsControlFlowVisitor(gast.NodeVisitor):
return node
def _is_node_with_tensor(self, node, name_id):
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import (
NodeVarType,
)
# Look up the node_var_type_map by name_id.
if self.node_var_type_map:
@@ -917,7 +978,7 @@ def unwrap(func):
return hasattr(f, '__wrapped__')
unwrapped_f = func
while (_is_wrapped(unwrapped_f)):
while _is_wrapped(unwrapped_f):
unwrapped_f = unwrapped_f.__wrapped__
return unwrapped_f
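
unwrap relies on functools.wraps setting __wrapped__ on each wrapper; a small made-up example of peeling a decorator:

import functools

def log_calls(fn):
    @functools.wraps(fn)  # sets wrapper.__wrapped__ = fn
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

@log_calls
def foo():
    pass

# unwrap walks the __wrapped__ chain back to the undecorated function.
assert unwrap(foo) is foo.__wrapped__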
@@ -941,10 +1002,12 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
if spec not in desired_input_specs:
return False
else:
for (src_spec, desired_spec) in zip(src_input_specs,
desired_input_specs):
for (src_spec, desired_spec) in zip(
src_input_specs, desired_input_specs
):
if isinstance(src_spec, paddle.static.InputSpec) or isinstance(
desired_spec, paddle.static.InputSpec):
desired_spec, paddle.static.InputSpec
):
if not _compatible_tensor_spec(src_spec, desired_spec):
return False
else:
@@ -1029,15 +1092,14 @@ def slice_is_num(slice_node):
class NameScope:
def __init__(self):
"""
A NameScope is a object which manager all the variable names.
only FunctionDef and Controlflow node will have a namescope property.
"""
A NameScope is a object which manager all the variable names.
only FunctionDef and Controlflow node will have a namescope property.
type can be "function" and "controlflow"
type can be "function" and "controlflow"
we don't analyze the read only variable because they don't affect the analysis.
we don't analyze the read only variable because they don't affect the analysis.
"""
self.globals = set()
self.nonlocals = set()
@@ -1053,8 +1115,8 @@ class NameScope:
self.father = father
def existed_vars(self):
""" vars existing in current scope.
they must not contain qualified names.
"""vars existing in current scope.
they must not contain qualified names.
"""
local_vars = self.w_vars - self.globals - self.nonlocals - self.args
return set(filter(lambda x: '.' not in x, local_vars))
@@ -1067,9 +1129,9 @@ class NameScope:
return self.w_vars
def variadic_length_vars(self):
"""
"""
At present, we do not support global append, such as
import numpy as np
a = []
def func():
@@ -1083,33 +1145,38 @@ class NameScope:
f"Find variable `{var}` defined in global scope"
f" and call `{var}.append() or {var}.pop()`"
f", which will be ignored and never be transfered into"
f" tensor array.")
f" tensor array."
)
else:
non_global_push_pop_names.append(var)
return set(non_global_push_pop_names)
def control_flow_vars(self):
valid_names = self.w_vars
tmp = self.father.global_vars & valid_names,
tmp = (self.father.global_vars & valid_names,)
return {"global": tmp, "nonlocal": self.w_vars - tmp}
def _is_simple_name(self, name):
if '.' in name or '[' in name: return False
if '.' in name or '[' in name:
return False
return True
def is_global_var(self, name):
"""
"""
Return whether the name is a var created in global scope.
Search from bottom to top. If it is not created or modified,
Search from bottom to top. If it is not created or modified,
it means global vars; otherwise, it means local vars.
Only valid after FunctionNameLivenessAnalysis visitor.
"""
assert self._is_simple_name(
name), "is_global_var accept a simple name, but get `{name}`."
name
), "is_global_var accept a simple name, but get `{name}`."
ancestor = self
while ancestor is not None:
if name in ancestor.globals: return True
if name in (ancestor.nonlocals | ancestor.w_vars): return False
if name in ancestor.globals:
return True
if name in (ancestor.nonlocals | ancestor.w_vars):
return False
ancestor = ancestor.father
return True
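
A hand-wired sketch of the bottom-to-top search above (scopes are normally built by FunctionNameLivenessAnalysis; here they are assembled manually for illustration):

outer, inner = NameScope(), NameScope()
inner.set_father(outer)
outer.w_vars.add('x')  # 'x' is written somewhere in the outer scope

assert not inner.is_global_var('x')  # found in an ancestor's w_vars -> local
assert inner.is_global_var('y')      # never created or modified -> global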
@@ -1125,46 +1192,46 @@ class NameScope:
class FunctionNameLivenessAnalysis(gast.NodeVisitor):
""" analyze the liveness of a function.
every variables stored in this scope will be collected,
in addition with global/nonlocal information and
push_pop information.
1. global variable is stored in node.var_globals.
2. nonlocal variable is stored in node.var_nonlocals.
3. arguments is stored in node.var_args.
4. if a variable's push and pop attribute is called,
it will be collected in push_pop_vars. They are
used for transformation to tensor_array.
NOTE: push_pop_vars **may not** in w_vars.
a.push(0) don't modify the variable a, but the content
of a.
For example:
def func(*args, **kargs):
a = 12
global i,j
nonlocal x,y
print(a)
i = k
b = []
c = [1,2,3]
for m in range(10):
q = 12
b.push(1)
c.pop()
After this visitor we have:
# node is the FunctionDef node with name: "func"
node.pd_scope = NameScope(
globals = ['i', 'j'],
nonlocals = ['x', 'y'],
args = ['args', 'kargs'],
wr_vars = ['a', 'i', 'q', 'm', 'c', 'b']
push_pop_vars = ['b', 'c']
)
"""analyze the liveness of a function.
every variables stored in this scope will be collected,
in addition with global/nonlocal information and
push_pop information.
1. global variable is stored in node.var_globals.
2. nonlocal variable is stored in node.var_nonlocals.
3. arguments is stored in node.var_args.
4. if a variable's push and pop attribute is called,
it will be collected in push_pop_vars. They are
used for transformation to tensor_array.
NOTE: push_pop_vars **may not** in w_vars.
a.push(0) don't modify the variable a, but the content
of a.
For example:
def func(*args, **kargs):
a = 12
global i,j
nonlocal x,y
print(a)
i = k
b = []
c = [1,2,3]
for m in range(10):
q = 12
b.push(1)
c.pop()
After this visitor we have:
# node is the FunctionDef node with name: "func"
node.pd_scope = NameScope(
globals = ['i', 'j'],
nonlocals = ['x', 'y'],
args = ['args', 'kargs'],
wr_vars = ['a', 'i', 'q', 'm', 'c', 'b']
push_pop_vars = ['b', 'c']
)
"""
def __init__(self, root_node):
@@ -1184,25 +1251,26 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
return self._get_name_scope(self.scope_node_stack[-1])
def _father_name_scope(self):
if len(self.scope_node_stack) == 1: return None
if len(self.scope_node_stack) == 1:
return None
return self._get_name_scope(self.scope_node_stack[-2])
def _nearest_function_scope(self):
if len(self.scope_node_stack) == 1: return None
if len(self.scope_node_stack) == 1:
return None
for node in self.scope_node_stack[-2::-1]:
if isinstance(node, gast.FunctionDef):
return self._get_name_scope(node)
def visit_ListComp(self, node):
""" [ i for i in range(10) ]
In this case, `i` will not created in FunctionScope.
We don't collect `i` by not calling generic_visit.
"""[ i for i in range(10) ]
In this case, `i` will not created in FunctionScope.
We don't collect `i` by not calling generic_visit.
"""
pass
def visit_DictComp(self, node):
""" the same as ListComp.
"""
"""the same as ListComp."""
pass
def visit_Name(self, node):
@@ -1212,62 +1280,86 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
self._current_name_scope().w_vars.add(node.id)
def visit_FunctionDef(self, node):
def pre_func():
self._current_name_scope().args |= set(
self._get_argument_names(node))
self._get_argument_names(node)
)
def post_func():
""" NOTE: why we need merge w_vars and push_pop_vars here ?
because we do ifelse_transformer after loop_transformer. Loops will changed into functioons. but we know this function will be called in if. so we add w_vars to father function scope.
"""NOTE: why we need merge w_vars and push_pop_vars here ?
because we do ifelse_transformer after loop_transformer. Loops will changed into functioons. but we know this function will be called in if. so we add w_vars to father function scope.
"""
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import WHILE_CONDITION_PREFIX, WHILE_BODY_PREFIX, FOR_CONDITION_PREFIX, FOR_BODY_PREFIX
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import TRUE_FUNC_PREFIX, FALSE_FUNC_PREFIX
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import (
WHILE_CONDITION_PREFIX,
WHILE_BODY_PREFIX,
FOR_CONDITION_PREFIX,
FOR_BODY_PREFIX,
)
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import (
TRUE_FUNC_PREFIX,
FALSE_FUNC_PREFIX,
)
control_flow_function_def = [
WHILE_BODY_PREFIX, WHILE_BODY_PREFIX, FOR_CONDITION_PREFIX,
FOR_BODY_PREFIX, TRUE_FUNC_PREFIX, FALSE_FUNC_PREFIX
WHILE_CONDITION_PREFIX,
WHILE_BODY_PREFIX,
FOR_CONDITION_PREFIX,
FOR_BODY_PREFIX,
TRUE_FUNC_PREFIX,
FALSE_FUNC_PREFIX,
]
def is_control_flow_def_node():
for prefix in control_flow_function_def:
if node.name.startswith(prefix): return True
if node.name.startswith(prefix):
return True
return False
if self._father_name_scope() and is_control_flow_def_node():
self._father_name_scope().w_vars |= self._current_name_scope(
).w_vars
self._father_name_scope(
).push_pop_vars |= self._current_name_scope().push_pop_vars
self._father_name_scope().w_vars |= (
self._current_name_scope().w_vars
)
self._father_name_scope().push_pop_vars |= (
self._current_name_scope().push_pop_vars
)
self._visit_scope_node(node, pre_func, post_func)
def _visit_scope_node(self, node, pre_func, post_func):
""" scope node main visit logic.
pre_func and post_func is callbacks
"""scope node main visit logic.
pre_func and post_func is callbacks
"""
self._reset_name_scope(node)
self.scope_node_stack.append(node)
self._current_name_scope().set_father(self._nearest_function_scope())
if pre_func: pre_func()
if pre_func:
pre_func()
self.generic_visit(node)
if post_func: post_func()
if post_func:
post_func()
self.scope_node_stack.pop()
def _visit_controlflow_node(self, node):
def post_func():
self._father_name_scope().merge_from(self._current_name_scope())
self._nearest_function_scope().merge_from(
self._current_name_scope())
self._current_name_scope().created = self._nearest_function_scope(
).existed_vars() - node.before_created
self._current_name_scope()
)
self._current_name_scope().created = (
self._nearest_function_scope().existed_vars()
- node.before_created
)
# gather created vars into father and used in CreateUndefinedVarTransform
self._nearest_function_scope().created |= self._current_name_scope(
).created
self._nearest_function_scope().created |= (
self._current_name_scope().created
)
def pre_func():
setattr(node, "before_created",
self._nearest_function_scope().existed_vars())
setattr(
node,
"before_created",
self._nearest_function_scope().existed_vars(),
)
self._visit_scope_node(node, pre_func, post_func)
@@ -1305,12 +1397,13 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
self._current_name_scope().push_pop_vars.add(name)
def _get_argument_names(self, node):
""" get all arguments name in the functiondef node.
this node is local to the function and shouldn't
be created.
"""get all arguments name in the functiondef node.
this node is local to the function and shouldn't
be created.
"""
assert isinstance(
node, gast.FunctionDef), "Input node is not function define node"
node, gast.FunctionDef
), "Input node is not function define node"
names = [a for a in node.args.args]
names.append(node.args.vararg)
names.append(node.args.kwarg)
@@ -1331,7 +1424,9 @@ def create_get_args_node(names):
func_def = """
def {func_name}():
return
""".format(func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX))
""".format(
func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX)
)
return gast.parse(textwrap.dedent(func_def)).body[0]
assert isinstance(names, (list, tuple))
@@ -1350,7 +1445,8 @@ def create_get_args_node(names):
func_def = template.format(
func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX),
nonlocal_vars=nonlocal_vars,
vars=",".join(names))
vars=",".join(names),
)
return gast.parse(textwrap.dedent(func_def)).body[0]
@@ -1367,8 +1463,9 @@ def create_set_args_node(names):
func_def = """
def {func_name}({args}):
pass
""".format(func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX),
args=ARGS_NAME)
""".format(
func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX), args=ARGS_NAME
)
return gast.parse(textwrap.dedent(func_def)).body[0]
assert isinstance(names, (list, tuple))
@@ -1388,7 +1485,8 @@ def create_set_args_node(names):
func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX),
args=ARGS_NAME,
nonlocal_vars=nonlocal_vars,
vars=",".join(names))
vars=",".join(names),
)
return gast.parse(textwrap.dedent(func_def)).body[0]
@@ -1398,8 +1496,8 @@ def create_nonlocal_stmt_nodes(names):
mapped = list(filter(lambda n: '.' not in n, names))
mapped = list(filter(lambda n: '[' not in n, mapped))
names = sorted(
mapped,
key=mapped.index) # to keep the order, we can't use set() to unique
mapped, key=mapped.index
) # to keep the order, we can't use set() to unique
if not names:
return []
func_code = "nonlocal {}".format(','.join(names))
@@ -1407,10 +1505,10 @@
class GetterSetterHelper:
""" we have two classes of names in setter and getter function:
w_vars(loop_vars) + push_pop_vars
To simplify the setter logic in convert_while and convert_cond,
we extract the helper class here.
"""we have two classes of names in setter and getter function:
w_vars(loop_vars) + push_pop_vars
To simplify the setter logic in convert_while and convert_cond,
we extract the helper class here.
"""
def __init__(self, getter_func, setter_func, *name_lists):
......@@ -1426,22 +1524,33 @@ class GetterSetterHelper:
return self._union
def get(self, names):
if names is None: names = []
if names is None:
names = []
vars = self.getter()
if vars is None: return tuple()
if vars is None:
return tuple()
for n in names:
assert n in self.name2id, "the name `{}` not in name union set`{}`.".format(
n, self.name2id.keys())
assert (
n in self.name2id
), "the name `{}` not in name union set`{}`.".format(
n, self.name2id.keys()
)
return tuple(map(lambda n: vars[self.name2id[n]], names))
def set(self, names, values):
if names is None: names = []
if values is None: values = []
if names is None:
names = []
if values is None:
values = []
vars = self.getter()
if vars is None: return
if vars is None:
return
for n in names:
assert n in self.name2id, "the name `{}` not in name union set`{}`.".format(
n, self.name2id.keys())
assert (
n in self.name2id
), "the name `{}` not in name union set`{}`.".format(
n, self.name2id.keys()
)
vars = list(vars)
indices = list(map(lambda n: self.name2id[n], names))
for i, v in zip(indices, values):
......