未验证 提交 12e6dfcf 编写于 作者: A Aurelius84 提交者: GitHub

[Cherry-Pick][Dy2Stat]Fix module loading OSError in multiprocess (#47302)

[Dy2Stat]Fix module loading OSError in multiprocess
上级 7c6550a6
...@@ -100,10 +100,18 @@ RE_PYMODULE = r'[a-zA-Z0-9_]+\.' ...@@ -100,10 +100,18 @@ RE_PYMODULE = r'[a-zA-Z0-9_]+\.'
# FullArgSpec is valid from Python3. Defined a Namedtuple to # FullArgSpec is valid from Python3. Defined a Namedtuple to
# to make it available in Python2. # to make it available in Python2.
FullArgSpec = collections.namedtuple('FullArgSpec', [ FullArgSpec = collections.namedtuple(
'args', 'varargs', 'varkw', 'defaults', 'kwonlyargs', 'kwonlydefaults', 'FullArgSpec',
'annotations' [
]) 'args',
'varargs',
'varkw',
'defaults',
'kwonlyargs',
'kwonlydefaults',
'annotations',
],
)
def data_layer_not_check(name, shape, dtype='float32', lod_level=0): def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
...@@ -141,20 +149,26 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0): ...@@ -141,20 +149,26 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
if shape[i] is None: if shape[i] is None:
shape[i] = -1 shape[i] = -1
return helper.create_global_variable(name=name, return helper.create_global_variable(
name=name,
shape=shape, shape=shape,
dtype=dtype, dtype=dtype,
type=core.VarDesc.VarType.LOD_TENSOR, type=core.VarDesc.VarType.LOD_TENSOR,
stop_gradient=True, stop_gradient=True,
lod_level=lod_level, lod_level=lod_level,
is_data=True, is_data=True,
need_check_feed=False) need_check_feed=False,
)
def create_undefined_variable(): def create_undefined_variable():
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM from paddle.fluid.dygraph.dygraph_to_static.return_transformer import (
var = data_layer_not_check(unique_name.generate("undefined_var"), [1], RETURN_NO_VALUE_MAGIC_NUM,
"float64") )
var = data_layer_not_check(
unique_name.generate("undefined_var"), [1], "float64"
)
var.stop_gradient = False var.stop_gradient = False
# the variable is created in block(0), we append assign in block(0) either. # the variable is created in block(0), we append assign in block(0) either.
helper = LayerHelper('create_undefined_variable', **locals()) helper = LayerHelper('create_undefined_variable', **locals())
...@@ -166,17 +180,16 @@ def create_undefined_variable(): ...@@ -166,17 +180,16 @@ def create_undefined_variable():
class UndefinedVar: class UndefinedVar:
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
def check(self): def check(self):
raise UnboundLocalError( raise UnboundLocalError(
"local variable '{}' should be created before using it.") "local variable '{}' should be created before using it."
)
class Dygraph2StaticException(Exception): class Dygraph2StaticException(Exception):
def __init__(self, message): def __init__(self, message):
super().__init__(message) super().__init__(message)
...@@ -193,13 +206,15 @@ def getfullargspec(target): ...@@ -193,13 +206,15 @@ def getfullargspec(target):
return inspect.getfullargspec(target) return inspect.getfullargspec(target)
else: else:
argspec = inspect.getargspec(target) argspec = inspect.getargspec(target)
return FullArgSpec(args=argspec.args, return FullArgSpec(
args=argspec.args,
varargs=argspec.varargs, varargs=argspec.varargs,
varkw=argspec.keywords, varkw=argspec.keywords,
defaults=argspec.defaults, defaults=argspec.defaults,
kwonlyargs=[], kwonlyargs=[],
kwonlydefaults=None, kwonlydefaults=None,
annotations={}) annotations={},
)
def parse_arg_and_kwargs(function): def parse_arg_and_kwargs(function):
...@@ -216,7 +231,7 @@ def parse_arg_and_kwargs(function): ...@@ -216,7 +231,7 @@ def parse_arg_and_kwargs(function):
default_values = fullargspec.defaults default_values = fullargspec.defaults
if default_values: if default_values:
assert len(default_values) <= len(arg_names) assert len(default_values) <= len(arg_names)
default_kwarg_names = arg_names[-len(default_values):] default_kwarg_names = arg_names[-len(default_values) :]
default_kwargs = dict(zip(default_kwarg_names, default_values)) default_kwargs = dict(zip(default_kwarg_names, default_values))
return arg_names, default_kwargs return arg_names, default_kwargs
...@@ -290,8 +305,9 @@ def is_api_in_module(node, module_prefix): ...@@ -290,8 +305,9 @@ def is_api_in_module(node, module_prefix):
from paddle.fluid.dygraph import to_variable from paddle.fluid.dygraph import to_variable
from paddle import to_tensor from paddle import to_tensor
return eval("_is_api_in_module_helper({}, '{}')".format( return eval(
func_str, module_prefix)) "_is_api_in_module_helper({}, '{}')".format(func_str, module_prefix)
)
except Exception: except Exception:
return False return False
...@@ -322,8 +338,10 @@ def is_numpy_api(node): ...@@ -322,8 +338,10 @@ def is_numpy_api(node):
func_str = astor.to_source(gast.gast_to_ast(node.func)) func_str = astor.to_source(gast.gast_to_ast(node.func))
try: try:
import numpy as np import numpy as np
module_result = eval("_is_api_in_module_helper({}, '{}')".format(
func_str, "numpy")) module_result = eval(
"_is_api_in_module_helper({}, '{}')".format(func_str, "numpy")
)
# BUG: np.random.uniform doesn't have module and cannot be analyzed # BUG: np.random.uniform doesn't have module and cannot be analyzed
# TODO: find a better way # TODO: find a better way
if not module_result: if not module_result:
...@@ -332,18 +350,19 @@ def is_numpy_api(node): ...@@ -332,18 +350,19 @@ def is_numpy_api(node):
return False return False
def is_control_flow_to_transform(node, def is_control_flow_to_transform(
static_analysis_visitor=None, node, static_analysis_visitor=None, var_name_to_type=None
var_name_to_type=None): ):
""" """
Determines whether the node is a PaddlePaddle control flow statement which needs to Determines whether the node is a PaddlePaddle control flow statement which needs to
be transformed into a static graph control flow statement. be transformed into a static graph control flow statement.
""" """
assert isinstance(node, gast.AST), \ assert isinstance(
"The type of input node must be gast.AST, but received %s." % type(node) node, gast.AST
visitor = IsControlFlowVisitor(node, ), "The type of input node must be gast.AST, but received %s." % type(node)
static_analysis_visitor, visitor = IsControlFlowVisitor(
node_var_type_map=var_name_to_type) node, static_analysis_visitor, node_var_type_map=var_name_to_type
)
need_to_transform = visitor.transform() need_to_transform = visitor.transform()
return need_to_transform return need_to_transform
...@@ -352,6 +371,7 @@ def _delete_keywords_from(node): ...@@ -352,6 +371,7 @@ def _delete_keywords_from(node):
assert isinstance(node, gast.Call) assert isinstance(node, gast.Call)
func_src = astor.to_source(gast.gast_to_ast(node.func)) func_src = astor.to_source(gast.gast_to_ast(node.func))
import paddle.fluid as fluid import paddle.fluid as fluid
full_args = eval("inspect.getargspec({})".format(func_src)) full_args = eval("inspect.getargspec({})".format(func_src))
full_args_name = full_args[0] full_args_name = full_args[0]
...@@ -365,7 +385,8 @@ def to_static_api(dygraph_class): ...@@ -365,7 +385,8 @@ def to_static_api(dygraph_class):
else: else:
raise NotImplementedError( raise NotImplementedError(
"Paddle dygraph API {} cannot be converted " "Paddle dygraph API {} cannot be converted "
"to static graph at present.".format(dygraph_class)) "to static graph at present.".format(dygraph_class)
)
def _add_keywords_to(node, dygraph_api_name): def _add_keywords_to(node, dygraph_api_name):
...@@ -376,8 +397,10 @@ def _add_keywords_to(node, dygraph_api_name): ...@@ -376,8 +397,10 @@ def _add_keywords_to(node, dygraph_api_name):
ast_keyword.arg = "size" ast_keyword.arg = "size"
node.keywords.append( node.keywords.append(
gast.keyword(arg="num_flatten_dims", gast.keyword(
value=gast.Constant(value=-1, kind=None))) arg="num_flatten_dims", value=gast.Constant(value=-1, kind=None)
)
)
if dygraph_api_name == "BilinearTensorProduct": if dygraph_api_name == "BilinearTensorProduct":
for ast_keyword in node.keywords: for ast_keyword in node.keywords:
...@@ -396,15 +419,17 @@ def to_static_ast(node, class_node): ...@@ -396,15 +419,17 @@ def to_static_ast(node, class_node):
assert isinstance(class_node, gast.Call) assert isinstance(class_node, gast.Call)
static_api = to_static_api(class_node.func.attr) static_api = to_static_api(class_node.func.attr)
node.func = gast.Attribute(attr=static_api, node.func = gast.Attribute(
attr=static_api,
ctx=gast.Load(), ctx=gast.Load(),
value=gast.Attribute(attr='layers', value=gast.Attribute(
attr='layers',
ctx=gast.Load(), ctx=gast.Load(),
value=gast.Name( value=gast.Name(
ctx=gast.Load(), ctx=gast.Load(), id='fluid', annotation=None, type_comment=None
id='fluid', ),
annotation=None, ),
type_comment=None))) )
update_args_of_func(node, class_node, 'forward') update_args_of_func(node, class_node, 'forward')
...@@ -427,10 +452,13 @@ def update_args_of_func(node, dygraph_node, method_name): ...@@ -427,10 +452,13 @@ def update_args_of_func(node, dygraph_node, method_name):
class_src = astor.to_source(gast.gast_to_ast(dygraph_node.func)) class_src = astor.to_source(gast.gast_to_ast(dygraph_node.func))
import paddle.fluid as fluid import paddle.fluid as fluid
if method_name == "__init__" or eval( if method_name == "__init__" or eval(
"issubclass({}, fluid.dygraph.Layer)".format(class_src)): "issubclass({}, fluid.dygraph.Layer)".format(class_src)
full_args = eval("inspect.getargspec({}.{})".format( ):
class_src, method_name)) full_args = eval(
"inspect.getargspec({}.{})".format(class_src, method_name)
)
full_args_name = [ full_args_name = [
arg_name for arg_name in full_args[0] if arg_name != "self" arg_name for arg_name in full_args[0] if arg_name != "self"
] ]
...@@ -445,21 +473,24 @@ def update_args_of_func(node, dygraph_node, method_name): ...@@ -445,21 +473,24 @@ def update_args_of_func(node, dygraph_node, method_name):
def create_api_shape_node(tensor_shape_node): def create_api_shape_node(tensor_shape_node):
assert isinstance(tensor_shape_node, assert isinstance(
(gast.Name, gast.Attribute, gast.Subscript)) tensor_shape_node, (gast.Name, gast.Attribute, gast.Subscript)
)
if isinstance(tensor_shape_node, gast.Name): if isinstance(tensor_shape_node, gast.Name):
api_shape_node = gast.Call( api_shape_node = gast.Call(
func=gast.parse('paddle.shape').body[0].value, func=gast.parse('paddle.shape').body[0].value,
args=[tensor_shape_node], args=[tensor_shape_node],
keywords=[]) keywords=[],
)
return api_shape_node return api_shape_node
if isinstance(tensor_shape_node, gast.Attribute): if isinstance(tensor_shape_node, gast.Attribute):
api_shape_node = gast.Call( api_shape_node = gast.Call(
func=gast.parse('paddle.shape').body[0].value, func=gast.parse('paddle.shape').body[0].value,
args=[tensor_shape_node.value], args=[tensor_shape_node.value],
keywords=[]) keywords=[],
)
return api_shape_node return api_shape_node
if isinstance(tensor_shape_node, gast.Subscript): if isinstance(tensor_shape_node, gast.Subscript):
...@@ -469,14 +500,15 @@ def create_api_shape_node(tensor_shape_node): ...@@ -469,14 +500,15 @@ def create_api_shape_node(tensor_shape_node):
def get_constant_variable_node(name, value, shape=[1], dtype='int64'): def get_constant_variable_node(name, value, shape=[1], dtype='int64'):
return gast.parse('%s = paddle.full(%s, "%s", %s)' % return gast.parse(
(name, str(shape), str(value), dtype)) '%s = paddle.full(%s, "%s", %s)' % (name, str(shape), str(value), dtype)
)
def get_attribute_full_name(node): def get_attribute_full_name(node):
assert isinstance( assert isinstance(
node, node, gast.Attribute
gast.Attribute), "Input non-Attribute node to get attribute full name" ), "Input non-Attribute node to get attribute full name"
return astor.to_source(gast.gast_to_ast(node)).strip() return astor.to_source(gast.gast_to_ast(node)).strip()
...@@ -494,15 +526,15 @@ def generate_name_node(name_ids, ctx=gast.Load(), gen_tuple_if_single=False): ...@@ -494,15 +526,15 @@ def generate_name_node(name_ids, ctx=gast.Load(), gen_tuple_if_single=False):
name_ids = [name_ids] name_ids = [name_ids]
if not isinstance(name_ids, (list, tuple, set)): if not isinstance(name_ids, (list, tuple, set)):
raise TypeError( raise TypeError(
'name_ids must be list or tuple or set, but received %s' % 'name_ids must be list or tuple or set, but received %s'
type(type(name_ids))) % type(type(name_ids))
)
def create_node_for_name(name): def create_node_for_name(name):
if '.' not in name: if '.' not in name:
return gast.Name(id=name, return gast.Name(
ctx=ctx, id=name, ctx=ctx, annotation=None, type_comment=None
annotation=None, )
type_comment=None)
return gast.parse(name).body[0].value return gast.parse(name).body[0].value
gast_names = [create_node_for_name(name_id) for name_id in name_ids] gast_names = [create_node_for_name(name_id) for name_id in name_ids]
...@@ -524,12 +556,14 @@ def create_funcDef_node(nodes, name, input_args, return_name_ids): ...@@ -524,12 +556,14 @@ def create_funcDef_node(nodes, name, input_args, return_name_ids):
nodes.append(gast.Return(value=generate_name_node(return_name_ids))) nodes.append(gast.Return(value=generate_name_node(return_name_ids)))
else: else:
nodes.append(gast.Return(value=None)) nodes.append(gast.Return(value=None))
func_def_node = gast.FunctionDef(name=name, func_def_node = gast.FunctionDef(
name=name,
args=input_args, args=input_args,
body=nodes, body=nodes,
decorator_list=[], decorator_list=[],
returns=None, returns=None,
type_comment=None) type_comment=None,
)
return func_def_node return func_def_node
...@@ -554,8 +588,9 @@ def get_temp_dir(): ...@@ -554,8 +588,9 @@ def get_temp_dir():
""" """
Return @to_static temp directory. Return @to_static temp directory.
""" """
dir_name = "paddle/to_static_tmp" dir_name = "paddle/to_static_tmp/{pid}".format(pid=os.getpid())
temp_dir = os.path.join(os.path.expanduser('~/.cache'), dir_name) temp_dir = os.path.join(os.path.expanduser('~/.cache'), dir_name)
is_windows = sys.platform.startswith('win') is_windows = sys.platform.startswith('win')
if is_windows: if is_windows:
temp_dir = os.path.normpath(temp_dir) temp_dir = os.path.normpath(temp_dir)
...@@ -589,12 +624,14 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): ...@@ -589,12 +624,14 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
source = ast_to_source_code(ast_root) source = ast_to_source_code(ast_root)
source = _inject_import_statements() + source source = _inject_import_statements() + source
temp_dir = get_temp_dir() temp_dir = get_temp_dir()
f = tempfile.NamedTemporaryFile(mode='w', f = tempfile.NamedTemporaryFile(
mode='w',
prefix=func_prefix(dyfunc), prefix=func_prefix(dyfunc),
suffix='.py', suffix='.py',
delete=False, delete=False,
dir=temp_dir, dir=temp_dir,
encoding='utf-8') encoding='utf-8',
)
with f: with f:
module_name = os.path.basename(f.name[:-3]) module_name = os.path.basename(f.name[:-3])
f.write(source) f.write(source)
...@@ -616,8 +653,9 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): ...@@ -616,8 +653,9 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
callable_func = getattr(module, func_name) callable_func = getattr(module, func_name)
else: else:
raise ValueError( raise ValueError(
'Function: %s doesn\'t exist in the Module transformed from AST.' % 'Function: %s doesn\'t exist in the Module transformed from AST.'
func_name) % func_name
)
# After transform dygraph function into callable_func saved in tmp file, # After transform dygraph function into callable_func saved in tmp file,
# it lost the global variables from imported statements or defined in source file. # it lost the global variables from imported statements or defined in source file.
# Recovers the necessary variables by `__globals__`. # Recovers the necessary variables by `__globals__`.
...@@ -628,10 +666,14 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): ...@@ -628,10 +666,14 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
def _inject_import_statements(): def _inject_import_statements():
import_statements = [ import_statements = [
"import paddle", "from paddle import Tensor", "import paddle",
"import paddle.fluid as fluid", "import paddle.jit.dy2static as _jst", "from paddle import Tensor",
"from typing import *", "import numpy as np", "import warnings", "import paddle.fluid as fluid",
"warnings.filterwarnings('ignore', category=DeprecationWarning)" "import paddle.jit.dy2static as _jst",
"from typing import *",
"import numpy as np",
"import warnings",
"warnings.filterwarnings('ignore', category=DeprecationWarning)",
] ]
return '\n'.join(import_statements) + '\n' return '\n'.join(import_statements) + '\n'
...@@ -654,8 +696,10 @@ def func_to_source_code(function, dedent=True): ...@@ -654,8 +696,10 @@ def func_to_source_code(function, dedent=True):
""" """
if not (inspect.isfunction(function) or inspect.ismethod(function)): if not (inspect.isfunction(function) or inspect.ismethod(function)):
raise TypeError( raise TypeError(
"The type of 'function' should be a function or method, but received {}." "The type of 'function' should be a function or method, but received {}.".format(
.format(type(function).__name__)) type(function).__name__
)
)
source_code_list, _ = inspect.getsourcelines(function) source_code_list, _ = inspect.getsourcelines(function)
# Replace comments with blank lines so that error messages are not misplaced # Replace comments with blank lines so that error messages are not misplaced
source_code_list = [ source_code_list = [
...@@ -675,8 +719,9 @@ def ast_to_source_code(ast_node): ...@@ -675,8 +719,9 @@ def ast_to_source_code(ast_node):
""" """
if not isinstance(ast_node, (gast.AST, ast.AST)): if not isinstance(ast_node, (gast.AST, ast.AST)):
raise TypeError( raise TypeError(
"Type of ast_root should be gast.AST or ast.AST, but received %s." % "Type of ast_root should be gast.AST or ast.AST, but received %s."
type(ast_node)) % type(ast_node)
)
if isinstance(ast_node, gast.AST): if isinstance(ast_node, gast.AST):
ast_node = gast.gast_to_ast(ast_node) ast_node = gast.gast_to_ast(ast_node)
...@@ -692,8 +737,17 @@ def is_candidate_node(node): ...@@ -692,8 +737,17 @@ def is_candidate_node(node):
""" """
Nodes with specified type will be dependent on tensor. Nodes with specified type will be dependent on tensor.
""" """
is_compare_node = isinstance(node, (gast.Compare, gast.BoolOp, gast.UnaryOp, is_compare_node = isinstance(
gast.For, gast.If, gast.While)) node,
(
gast.Compare,
gast.BoolOp,
gast.UnaryOp,
gast.For,
gast.If,
gast.While,
),
)
# TODO(Aurelius84): `.numpy()` may be an customized function, # TODO(Aurelius84): `.numpy()` may be an customized function,
# and should consider a more elegant way to solve this problem. # and should consider a more elegant way to solve this problem.
has_numpy_attr = ".numpy()" in ast_to_source_code(node) has_numpy_attr = ".numpy()" in ast_to_source_code(node)
...@@ -709,9 +763,9 @@ def compare_with_none(node): ...@@ -709,9 +763,9 @@ def compare_with_none(node):
# node.comparators is a list. # node.comparators is a list.
if isinstance(child, list): if isinstance(child, list):
child = child[0] child = child[0]
if (isinstance(child, gast.Constant) if (isinstance(child, gast.Constant) and child.value is None) or (
and child.value is None) or (isinstance(child, gast.Name) isinstance(child, gast.Name) and child.id == 'None'
and child.id == 'None'): ):
return True return True
return False return False
...@@ -746,20 +800,22 @@ class IsControlFlowVisitor(gast.NodeVisitor): ...@@ -746,20 +800,22 @@ class IsControlFlowVisitor(gast.NodeVisitor):
because reshape_op may be called before this statement. because reshape_op may be called before this statement.
""" """
def __init__(self, def __init__(
ast_node, self, ast_node, static_analysis_visitor=None, node_var_type_map=None
static_analysis_visitor=None, ):
node_var_type_map=None):
assert isinstance( assert isinstance(
ast_node, gast.AST ast_node, gast.AST
), "Type of input node should be gast.AST, but received %s." % type( ), "Type of input node should be gast.AST, but received %s." % type(
ast_node) ast_node
)
self.ast_root = ast_node self.ast_root = ast_node
if static_analysis_visitor is None: if static_analysis_visitor is None:
from .static_analysis import StaticAnalysisVisitor from .static_analysis import StaticAnalysisVisitor
static_analysis_visitor = StaticAnalysisVisitor(ast_node) static_analysis_visitor = StaticAnalysisVisitor(ast_node)
self.static_analysis_visitor = static_analysis_visitor self.static_analysis_visitor = static_analysis_visitor
self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map( self.node_to_wrapper_map = (
self.static_analysis_visitor.get_node_to_wrapper_map()
) )
self.node_var_type_map = node_var_type_map self.node_var_type_map = node_var_type_map
...@@ -788,7 +844,10 @@ class IsControlFlowVisitor(gast.NodeVisitor): ...@@ -788,7 +844,10 @@ class IsControlFlowVisitor(gast.NodeVisitor):
if isinstance(node.iter, gast.Call): if isinstance(node.iter, gast.Call):
# for in range(var[0]|var.numpy()[0]) or for in enumerate(var|var.numpy()) # for in range(var[0]|var.numpy()[0]) or for in enumerate(var|var.numpy())
if isinstance(node.iter.func, gast.Name): if isinstance(node.iter.func, gast.Name):
if node.iter.func.id == "range" or node.iter.func.id == "enumerate": if (
node.iter.func.id == "range"
or node.iter.func.id == "enumerate"
):
for arg in node.iter.args: for arg in node.iter.args:
self.visit(arg) self.visit(arg)
else: else:
...@@ -887,7 +946,9 @@ class IsControlFlowVisitor(gast.NodeVisitor): ...@@ -887,7 +946,9 @@ class IsControlFlowVisitor(gast.NodeVisitor):
return node return node
def _is_node_with_tensor(self, node, name_id): def _is_node_with_tensor(self, node, name_id):
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType from paddle.fluid.dygraph.dygraph_to_static.static_analysis import (
NodeVarType,
)
# Look up the node_var_type_map by name_id. # Look up the node_var_type_map by name_id.
if self.node_var_type_map: if self.node_var_type_map:
...@@ -917,7 +978,7 @@ def unwrap(func): ...@@ -917,7 +978,7 @@ def unwrap(func):
return hasattr(f, '__wrapped__') return hasattr(f, '__wrapped__')
unwrapped_f = func unwrapped_f = func
while (_is_wrapped(unwrapped_f)): while _is_wrapped(unwrapped_f):
unwrapped_f = unwrapped_f.__wrapped__ unwrapped_f = unwrapped_f.__wrapped__
return unwrapped_f return unwrapped_f
...@@ -941,10 +1002,12 @@ def input_specs_compatible(src_input_specs, desired_input_specs): ...@@ -941,10 +1002,12 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
if spec not in desired_input_specs: if spec not in desired_input_specs:
return False return False
else: else:
for (src_spec, desired_spec) in zip(src_input_specs, for (src_spec, desired_spec) in zip(
desired_input_specs): src_input_specs, desired_input_specs
):
if isinstance(src_spec, paddle.static.InputSpec) or isinstance( if isinstance(src_spec, paddle.static.InputSpec) or isinstance(
desired_spec, paddle.static.InputSpec): desired_spec, paddle.static.InputSpec
):
if not _compatible_tensor_spec(src_spec, desired_spec): if not _compatible_tensor_spec(src_spec, desired_spec):
return False return False
else: else:
...@@ -1029,7 +1092,6 @@ def slice_is_num(slice_node): ...@@ -1029,7 +1092,6 @@ def slice_is_num(slice_node):
class NameScope: class NameScope:
def __init__(self): def __init__(self):
""" """
A NameScope is a object which manager all the variable names. A NameScope is a object which manager all the variable names.
...@@ -1053,7 +1115,7 @@ class NameScope: ...@@ -1053,7 +1115,7 @@ class NameScope:
self.father = father self.father = father
def existed_vars(self): def existed_vars(self):
""" vars existing in current scope. """vars existing in current scope.
they must not contain qualified names. they must not contain qualified names.
""" """
local_vars = self.w_vars - self.globals - self.nonlocals - self.args local_vars = self.w_vars - self.globals - self.nonlocals - self.args
...@@ -1083,18 +1145,20 @@ class NameScope: ...@@ -1083,18 +1145,20 @@ class NameScope:
f"Find variable `{var}` defined in global scope" f"Find variable `{var}` defined in global scope"
f" and call `{var}.append() or {var}.pop()`" f" and call `{var}.append() or {var}.pop()`"
f", which will be ignored and never be transfered into" f", which will be ignored and never be transfered into"
f" tensor array.") f" tensor array."
)
else: else:
non_global_push_pop_names.append(var) non_global_push_pop_names.append(var)
return set(non_global_push_pop_names) return set(non_global_push_pop_names)
def control_flow_vars(self): def control_flow_vars(self):
valid_names = self.w_vars valid_names = self.w_vars
tmp = self.father.global_vars & valid_names, tmp = (self.father.global_vars & valid_names,)
return {"global": tmp, "nonlocal": self.w_vars - tmp} return {"global": tmp, "nonlocal": self.w_vars - tmp}
def _is_simple_name(self, name): def _is_simple_name(self, name):
if '.' in name or '[' in name: return False if '.' in name or '[' in name:
return False
return True return True
def is_global_var(self, name): def is_global_var(self, name):
...@@ -1105,11 +1169,14 @@ class NameScope: ...@@ -1105,11 +1169,14 @@ class NameScope:
Only valid after FunctionNameLivenessAnalysis visitor. Only valid after FunctionNameLivenessAnalysis visitor.
""" """
assert self._is_simple_name( assert self._is_simple_name(
name), "is_global_var accept a simple name, but get `{name}`." name
), "is_global_var accept a simple name, but get `{name}`."
ancestor = self ancestor = self
while ancestor is not None: while ancestor is not None:
if name in ancestor.globals: return True if name in ancestor.globals:
if name in (ancestor.nonlocals | ancestor.w_vars): return False return True
if name in (ancestor.nonlocals | ancestor.w_vars):
return False
ancestor = ancestor.father ancestor = ancestor.father
return True return True
...@@ -1125,7 +1192,7 @@ class NameScope: ...@@ -1125,7 +1192,7 @@ class NameScope:
class FunctionNameLivenessAnalysis(gast.NodeVisitor): class FunctionNameLivenessAnalysis(gast.NodeVisitor):
""" analyze the liveness of a function. """analyze the liveness of a function.
every variables stored in this scope will be collected, every variables stored in this scope will be collected,
in addition with global/nonlocal information and in addition with global/nonlocal information and
...@@ -1184,25 +1251,26 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor): ...@@ -1184,25 +1251,26 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
return self._get_name_scope(self.scope_node_stack[-1]) return self._get_name_scope(self.scope_node_stack[-1])
def _father_name_scope(self): def _father_name_scope(self):
if len(self.scope_node_stack) == 1: return None if len(self.scope_node_stack) == 1:
return None
return self._get_name_scope(self.scope_node_stack[-2]) return self._get_name_scope(self.scope_node_stack[-2])
def _nearest_function_scope(self): def _nearest_function_scope(self):
if len(self.scope_node_stack) == 1: return None if len(self.scope_node_stack) == 1:
return None
for node in self.scope_node_stack[-2::-1]: for node in self.scope_node_stack[-2::-1]:
if isinstance(node, gast.FunctionDef): if isinstance(node, gast.FunctionDef):
return self._get_name_scope(node) return self._get_name_scope(node)
def visit_ListComp(self, node): def visit_ListComp(self, node):
""" [ i for i in range(10) ] """[ i for i in range(10) ]
In this case, `i` will not created in FunctionScope. In this case, `i` will not created in FunctionScope.
We don't collect `i` by not calling generic_visit. We don't collect `i` by not calling generic_visit.
""" """
pass pass
def visit_DictComp(self, node): def visit_DictComp(self, node):
""" the same as ListComp. """the same as ListComp."""
"""
pass pass
def visit_Name(self, node): def visit_Name(self, node):
...@@ -1212,62 +1280,86 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor): ...@@ -1212,62 +1280,86 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
self._current_name_scope().w_vars.add(node.id) self._current_name_scope().w_vars.add(node.id)
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
def pre_func(): def pre_func():
self._current_name_scope().args |= set( self._current_name_scope().args |= set(
self._get_argument_names(node)) self._get_argument_names(node)
)
def post_func(): def post_func():
""" NOTE: why we need merge w_vars and push_pop_vars here ? """NOTE: why we need merge w_vars and push_pop_vars here ?
because we do ifelse_transformer after loop_transformer. Loops will changed into functioons. but we know this function will be called in if. so we add w_vars to father function scope. because we do ifelse_transformer after loop_transformer. Loops will changed into functioons. but we know this function will be called in if. so we add w_vars to father function scope.
""" """
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import WHILE_CONDITION_PREFIX, WHILE_BODY_PREFIX, FOR_CONDITION_PREFIX, FOR_BODY_PREFIX from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import (
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import TRUE_FUNC_PREFIX, FALSE_FUNC_PREFIX WHILE_CONDITION_PREFIX,
WHILE_BODY_PREFIX,
FOR_CONDITION_PREFIX,
FOR_BODY_PREFIX,
)
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import (
TRUE_FUNC_PREFIX,
FALSE_FUNC_PREFIX,
)
control_flow_function_def = [ control_flow_function_def = [
WHILE_BODY_PREFIX, WHILE_BODY_PREFIX, FOR_CONDITION_PREFIX, WHILE_BODY_PREFIX,
FOR_BODY_PREFIX, TRUE_FUNC_PREFIX, FALSE_FUNC_PREFIX WHILE_BODY_PREFIX,
FOR_CONDITION_PREFIX,
FOR_BODY_PREFIX,
TRUE_FUNC_PREFIX,
FALSE_FUNC_PREFIX,
] ]
def is_control_flow_def_node(): def is_control_flow_def_node():
for prefix in control_flow_function_def: for prefix in control_flow_function_def:
if node.name.startswith(prefix): return True if node.name.startswith(prefix):
return True
return False return False
if self._father_name_scope() and is_control_flow_def_node(): if self._father_name_scope() and is_control_flow_def_node():
self._father_name_scope().w_vars |= self._current_name_scope( self._father_name_scope().w_vars |= (
).w_vars self._current_name_scope().w_vars
self._father_name_scope( )
).push_pop_vars |= self._current_name_scope().push_pop_vars self._father_name_scope().push_pop_vars |= (
self._current_name_scope().push_pop_vars
)
self._visit_scope_node(node, pre_func, post_func) self._visit_scope_node(node, pre_func, post_func)
def _visit_scope_node(self, node, pre_func, post_func): def _visit_scope_node(self, node, pre_func, post_func):
""" scope node main visit logic. """scope node main visit logic.
pre_func and post_func is callbacks pre_func and post_func is callbacks
""" """
self._reset_name_scope(node) self._reset_name_scope(node)
self.scope_node_stack.append(node) self.scope_node_stack.append(node)
self._current_name_scope().set_father(self._nearest_function_scope()) self._current_name_scope().set_father(self._nearest_function_scope())
if pre_func: pre_func() if pre_func:
pre_func()
self.generic_visit(node) self.generic_visit(node)
if post_func: post_func() if post_func:
post_func()
self.scope_node_stack.pop() self.scope_node_stack.pop()
def _visit_controlflow_node(self, node): def _visit_controlflow_node(self, node):
def post_func(): def post_func():
self._father_name_scope().merge_from(self._current_name_scope()) self._father_name_scope().merge_from(self._current_name_scope())
self._nearest_function_scope().merge_from( self._nearest_function_scope().merge_from(
self._current_name_scope()) self._current_name_scope()
self._current_name_scope().created = self._nearest_function_scope( )
).existed_vars() - node.before_created self._current_name_scope().created = (
self._nearest_function_scope().existed_vars()
- node.before_created
)
# gather created vars into father and used in CreateUndefinedVarTransform # gather created vars into father and used in CreateUndefinedVarTransform
self._nearest_function_scope().created |= self._current_name_scope( self._nearest_function_scope().created |= (
).created self._current_name_scope().created
)
def pre_func(): def pre_func():
setattr(node, "before_created", setattr(
self._nearest_function_scope().existed_vars()) node,
"before_created",
self._nearest_function_scope().existed_vars(),
)
self._visit_scope_node(node, pre_func, post_func) self._visit_scope_node(node, pre_func, post_func)
...@@ -1305,12 +1397,13 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor): ...@@ -1305,12 +1397,13 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
self._current_name_scope().push_pop_vars.add(name) self._current_name_scope().push_pop_vars.add(name)
def _get_argument_names(self, node): def _get_argument_names(self, node):
""" get all arguments name in the functiondef node. """get all arguments name in the functiondef node.
this node is local to the function and shouldn't this node is local to the function and shouldn't
be created. be created.
""" """
assert isinstance( assert isinstance(
node, gast.FunctionDef), "Input node is not function define node" node, gast.FunctionDef
), "Input node is not function define node"
names = [a for a in node.args.args] names = [a for a in node.args.args]
names.append(node.args.vararg) names.append(node.args.vararg)
names.append(node.args.kwarg) names.append(node.args.kwarg)
...@@ -1331,7 +1424,9 @@ def create_get_args_node(names): ...@@ -1331,7 +1424,9 @@ def create_get_args_node(names):
func_def = """ func_def = """
def {func_name}(): def {func_name}():
return return
""".format(func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX)) """.format(
func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX)
)
return gast.parse(textwrap.dedent(func_def)).body[0] return gast.parse(textwrap.dedent(func_def)).body[0]
assert isinstance(names, (list, tuple)) assert isinstance(names, (list, tuple))
...@@ -1350,7 +1445,8 @@ def create_get_args_node(names): ...@@ -1350,7 +1445,8 @@ def create_get_args_node(names):
func_def = template.format( func_def = template.format(
func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX), func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX),
nonlocal_vars=nonlocal_vars, nonlocal_vars=nonlocal_vars,
vars=",".join(names)) vars=",".join(names),
)
return gast.parse(textwrap.dedent(func_def)).body[0] return gast.parse(textwrap.dedent(func_def)).body[0]
...@@ -1367,8 +1463,9 @@ def create_set_args_node(names): ...@@ -1367,8 +1463,9 @@ def create_set_args_node(names):
func_def = """ func_def = """
def {func_name}({args}): def {func_name}({args}):
pass pass
""".format(func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX), """.format(
args=ARGS_NAME) func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX), args=ARGS_NAME
)
return gast.parse(textwrap.dedent(func_def)).body[0] return gast.parse(textwrap.dedent(func_def)).body[0]
assert isinstance(names, (list, tuple)) assert isinstance(names, (list, tuple))
...@@ -1388,7 +1485,8 @@ def create_set_args_node(names): ...@@ -1388,7 +1485,8 @@ def create_set_args_node(names):
func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX), func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX),
args=ARGS_NAME, args=ARGS_NAME,
nonlocal_vars=nonlocal_vars, nonlocal_vars=nonlocal_vars,
vars=",".join(names)) vars=",".join(names),
)
return gast.parse(textwrap.dedent(func_def)).body[0] return gast.parse(textwrap.dedent(func_def)).body[0]
...@@ -1398,8 +1496,8 @@ def create_nonlocal_stmt_nodes(names): ...@@ -1398,8 +1496,8 @@ def create_nonlocal_stmt_nodes(names):
mapped = list(filter(lambda n: '.' not in n, names)) mapped = list(filter(lambda n: '.' not in n, names))
mapped = list(filter(lambda n: '[' not in n, mapped)) mapped = list(filter(lambda n: '[' not in n, mapped))
names = sorted( names = sorted(
mapped, mapped, key=mapped.index
key=mapped.index) # to keep the order, we can't use set() to unique ) # to keep the order, we can't use set() to unique
if not names: if not names:
return [] return []
func_code = "nonlocal {}".format(','.join(names)) func_code = "nonlocal {}".format(','.join(names))
...@@ -1407,7 +1505,7 @@ def create_nonlocal_stmt_nodes(names): ...@@ -1407,7 +1505,7 @@ def create_nonlocal_stmt_nodes(names):
class GetterSetterHelper: class GetterSetterHelper:
""" we have two classes of names in setter and getter function: """we have two classes of names in setter and getter function:
w_vars(loop_vars) + push_pop_vars w_vars(loop_vars) + push_pop_vars
To simplify the setter logic in convert_while and convert_cond, To simplify the setter logic in convert_while and convert_cond,
we extract the helper class here. we extract the helper class here.
...@@ -1426,22 +1524,33 @@ class GetterSetterHelper: ...@@ -1426,22 +1524,33 @@ class GetterSetterHelper:
return self._union return self._union
def get(self, names): def get(self, names):
if names is None: names = [] if names is None:
names = []
vars = self.getter() vars = self.getter()
if vars is None: return tuple() if vars is None:
return tuple()
for n in names: for n in names:
assert n in self.name2id, "the name `{}` not in name union set`{}`.".format( assert (
n, self.name2id.keys()) n in self.name2id
), "the name `{}` not in name union set`{}`.".format(
n, self.name2id.keys()
)
return tuple(map(lambda n: vars[self.name2id[n]], names)) return tuple(map(lambda n: vars[self.name2id[n]], names))
def set(self, names, values): def set(self, names, values):
if names is None: names = [] if names is None:
if values is None: values = [] names = []
if values is None:
values = []
vars = self.getter() vars = self.getter()
if vars is None: return if vars is None:
return
for n in names: for n in names:
assert n in self.name2id, "the name `{}` not in name union set`{}`.".format( assert (
n, self.name2id.keys()) n in self.name2id
), "the name `{}` not in name union set`{}`.".format(
n, self.name2id.keys()
)
vars = list(vars) vars = list(vars)
indices = list(map(lambda n: self.name2id[n], names)) indices = list(map(lambda n: self.name2id[n], names))
for i, v in zip(indices, values): for i, v in zip(indices, values):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册