未验证 提交 8279dfea 编写于 作者: X xiongkun 提交者: GitHub

[Dy2Static] Add non-local for while and for. (#43864)

* merge and add base support for non-local for

* for and while non-local support

* fix ci errors: v1

* fix bug

* fix

* fix code

* fix

* fix

* fix
上级 f720e231
...@@ -24,7 +24,7 @@ from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_ ...@@ -24,7 +24,7 @@ from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar
def convert_while_loop(cond, body, loop_vars): def convert_while_loop(cond, body, getter, setter):
""" """
A function representation of a Python ``while`` statement. A function representation of a Python ``while`` statement.
...@@ -39,25 +39,36 @@ def convert_while_loop(cond, body, loop_vars): ...@@ -39,25 +39,36 @@ def convert_while_loop(cond, body, loop_vars):
# NOTE: It may be slower if cond is very expensive, but usually cond is just O(1). # NOTE: It may be slower if cond is very expensive, but usually cond is just O(1).
# If loop_vars is changed during cond callable, then it causes bug, but current logical_and/logical_not/... doesn't change the loop_vars. # If loop_vars is changed during cond callable, then it causes bug, but current logical_and/logical_not/... doesn't change the loop_vars.
pred = cond(*loop_vars) pred = cond()
if isinstance(pred, Variable): if isinstance(pred, Variable):
loop_vars = _run_paddle_while_loop(cond, body, loop_vars) loop_vars = _run_paddle_while(cond, body, getter, setter)
else: else:
loop_vars = _run_py_while(cond, body, loop_vars) loop_vars = _run_py_while(cond, body, getter, setter)
return loop_vars return loop_vars
def _run_paddle_while_loop(cond, body, loop_vars): def _run_paddle_while(cond, body, getter, setter):
# NOTE: loop_vars of Paddle op `control_flow.while_loop` must be Paddle Tensors. # NOTE: loop_vars of Paddle op `control_flow.while_loop` must be Paddle Tensors.
loop_vars = [to_static_variable(var) for var in loop_vars] def to_list(x):
if isinstance(x, (tuple, list)): return x
return [x]
# UndefinedVar will become data layer not check.
loop_vars = [to_static_variable(var) for var in to_list(getter())]
setter(loop_vars if len(loop_vars) > 1 else
loop_vars[0]) # change the non-local var to variable
# variable maybe modified to inner var. change it into
loop_vars = control_flow.while_loop(cond, body, loop_vars) loop_vars = control_flow.while_loop(cond, body, loop_vars)
setter(loop_vars if len(loop_vars) > 1 else
loop_vars[0]) # change the non-local var to variable
return loop_vars return loop_vars
def _run_py_while(cond, body, loop_vars): def _run_py_while(cond, body, getter, setter):
while cond(*loop_vars): loop_vars = getter()
loop_vars = body(*loop_vars) while cond():
loop_vars = body()
return loop_vars return loop_vars
......
...@@ -31,7 +31,8 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import create_assign_node ...@@ -31,7 +31,8 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import create_assign_node
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_nonlocal_stmt_node from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_node
from paddle.fluid.dygraph.dygraph_to_static.utils import create_get_args_node, create_set_args_node
TRUE_FUNC_PREFIX = 'true_fn' TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn' FALSE_FUNC_PREFIX = 'false_fn'
...@@ -415,17 +416,22 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict, ...@@ -415,17 +416,22 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
# modified vars # modified vars
body_modified_vars = _modified_vars(if_vars_dict, parent_vars_dict) body_modified_vars = _modified_vars(if_vars_dict, parent_vars_dict)
body_modified_vars = set(
filter(lambda x: x != ARGS_NAME, body_modified_vars))
orelse_modified_vars = _modified_vars(else_vars_dict, parent_vars_dict) orelse_modified_vars = _modified_vars(else_vars_dict, parent_vars_dict)
orelse_modified_vars = set(
filter(lambda x: x != ARGS_NAME, orelse_modified_vars))
modified_vars = body_modified_vars | orelse_modified_vars modified_vars = body_modified_vars | orelse_modified_vars
# new vars # new vars
# TODO(remove __args when new FunctionScopeAnalysis has been used.)
body_new_vars = set([ body_new_vars = set([
var for var in _vars_with_store(if_vars_dict) var for var in _vars_with_store(if_vars_dict)
if var not in parent_vars_dict if var not in parent_vars_dict and var != ARGS_NAME
]) ])
orelse_new_vars = set([ orelse_new_vars = set([
var for var in _vars_with_store(else_vars_dict) var for var in _vars_with_store(else_vars_dict)
if var not in parent_vars_dict if var not in parent_vars_dict and var != ARGS_NAME
]) ])
new_vars_in_body_or_orelse = body_new_vars | orelse_new_vars new_vars_in_body_or_orelse = body_new_vars | orelse_new_vars
new_vars_in_one_of_body_or_orelse = body_new_vars ^ orelse_new_vars new_vars_in_one_of_body_or_orelse = body_new_vars ^ orelse_new_vars
...@@ -511,11 +517,11 @@ def transform_if_else(node, root): ...@@ -511,11 +517,11 @@ def transform_if_else(node, root):
if any([not isinstance(ctx, gast.Load) for ctx in ctxs]): if any([not isinstance(ctx, gast.Load) for ctx in ctxs]):
parent_ids_set.add(k) parent_ids_set.add(k)
trun_args = parse_cond_args(parent_ids_set, body_name_ids, true_args = parse_cond_args(parent_ids_set, body_name_ids,
modified_name_ids_from_parent) modified_name_ids_from_parent)
false_args = parse_cond_args(parent_ids_set, orelse_name_ids, false_args = parse_cond_args(parent_ids_set, orelse_name_ids,
modified_name_ids_from_parent) modified_name_ids_from_parent)
nonlocal_names = list(trun_args | false_args | new_vars_to_create) nonlocal_names = list(true_args | false_args | new_vars_to_create)
nonlocal_names.sort() nonlocal_names.sort()
# NOTE: All var in return_name_ids should be in nonlocal_names. # NOTE: All var in return_name_ids should be in nonlocal_names.
nonlocal_names = _valid_nonlocal_names(return_name_ids, nonlocal_names) nonlocal_names = _valid_nonlocal_names(return_name_ids, nonlocal_names)
...@@ -552,70 +558,6 @@ def transform_if_else(node, root): ...@@ -552,70 +558,6 @@ def transform_if_else(node, root):
return create_new_vars_in_parent_stmts, true_func_node, false_func_node, get_args_node, set_args_node, return_name_ids return create_new_vars_in_parent_stmts, true_func_node, false_func_node, get_args_node, set_args_node, return_name_ids
def create_get_args_node(names):
"""
Create get_args function as follows:
def get_args_0():
nonlocal x, y
return x, y
"""
def empty_node():
func_def = """
def {func_name}():
return
""".format(func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX))
return gast.parse(textwrap.dedent(func_def)).body[0]
assert isinstance(names, (list, tuple))
if not names:
return empty_node()
template = """
def {func_name}():
nonlocal {vars}
return {vars}
"""
func_def = template.format(
func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX),
vars=",".join(names))
return gast.parse(textwrap.dedent(func_def)).body[0]
def create_set_args_node(names):
"""
Create set_args function as follows:
def set_args_0(__args):
nonlocal x, y
x, y = __args
"""
def empty_node():
func_def = """
def {func_name}({args}):
pass
""".format(func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX),
args=ARGS_NAME)
return gast.parse(textwrap.dedent(func_def)).body[0]
assert isinstance(names, (list, tuple))
if not names:
return empty_node()
template = """
def {func_name}({args}):
nonlocal {vars}
{vars} = {args}
"""
func_def = template.format(
func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX),
args=ARGS_NAME,
vars=",".join(names))
return gast.parse(textwrap.dedent(func_def)).body[0]
def create_convert_ifelse_node(return_name_ids, def create_convert_ifelse_node(return_name_ids,
pred, pred,
true_func, true_func,
......
...@@ -28,7 +28,10 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name ...@@ -28,7 +28,10 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ForLoopTuplePreTransformer from paddle.fluid.dygraph.dygraph_to_static.utils import ForLoopTuplePreTransformer
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_node, create_get_args_node, create_set_args_node
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import ARGS_NAME
__all__ = ['LoopTransformer', 'NameVisitor'] __all__ = ['LoopTransformer', 'NameVisitor']
...@@ -37,12 +40,10 @@ WHILE_BODY_PREFIX = 'while_body' ...@@ -37,12 +40,10 @@ WHILE_BODY_PREFIX = 'while_body'
FOR_CONDITION_PREFIX = 'for_loop_condition' FOR_CONDITION_PREFIX = 'for_loop_condition'
FOR_BODY_PREFIX = 'for_loop_body' FOR_BODY_PREFIX = 'for_loop_body'
GENERATE_VARIABLE_PREFIX = 'generate_variable'
ATTRIBUTE_VARIABLE_PREFIX = '__attribute_variable'
def create_while_nodes(condition_name, body_name, loop_var_names, getter_name,
def create_while_nodes(condition_name, body_name, loop_var_names): setter_name):
""" """
Returns a list of gast.Node which represents the calling of Paddle Returns a list of gast.Node which represents the calling of Paddle
controlflow while_loop. controlflow while_loop.
...@@ -74,37 +75,20 @@ def create_while_nodes(condition_name, body_name, loop_var_names): ...@@ -74,37 +75,20 @@ def create_while_nodes(condition_name, body_name, loop_var_names):
# #
# For example: loop_var_names = [a, b, foo.x], the type of `a` or `b` is gast.Name, # For example: loop_var_names = [a, b, foo.x], the type of `a` or `b` is gast.Name,
# but the type of `foo.x` gast.Attribute. # but the type of `foo.x` gast.Attribute.
unique_name_to_origin = {}
# We have to make loop_var_names and assign_loop_var_names with same order # We have to make loop_var_names and assign_loop_var_names with same order
# set doesn't have order so we convert it to list # set doesn't have order so we convert it to list
loop_var_names = list(loop_var_names) loop_var_names = list(loop_var_names)
assign_loop_var_names = [] assign_loop_var_names = []
for name in (loop_var_names): for name in (loop_var_names):
if "." in name: assign_loop_var_names.append(name)
# name is an attribute variable such as foo.x
tmp_attr_name = unique_name.generate(ATTRIBUTE_VARIABLE_PREFIX)
unique_name_to_origin[tmp_attr_name] = name
assign_loop_var_names.append(tmp_attr_name)
else:
assign_loop_var_names.append(name)
while_func_name = "_jst.While" while_func_name = "_jst.While"
while_node_str = "[{}] = {}({}, {}, [{}])".format( while_node_str = "{}({}, {}, {}, {})".format(while_func_name,
",".join(assign_loop_var_names), while_func_name, condition_name, condition_name, body_name,
body_name, ",".join(loop_var_names)) getter_name, setter_name)
while_node = gast.parse(while_node_str).body[0] while_node = gast.parse(while_node_str).body[0]
ret = [while_node] ret = [while_node]
for tmp_attr_name in unique_name_to_origin:
origin_attr_var = unique_name_to_origin[tmp_attr_name]
dot_pos = origin_attr_var.rindex(".")
obj_name = origin_attr_var[0:dot_pos]
attr_name = origin_attr_var[dot_pos + 1:]
assign_if_not_prop_str = "if not isinstance(getattr(type({}), '{}', None), property): {} = {}".format(
obj_name, attr_name, origin_attr_var, tmp_attr_name)
assign_if_not_prop_node = gast.parse(assign_if_not_prop_str).body[0]
ret.append(assign_if_not_prop_node)
return ret return ret
...@@ -117,8 +101,10 @@ class NameScope: ...@@ -117,8 +101,10 @@ class NameScope:
self.globals = set() self.globals = set()
self.nonlocals = set() self.nonlocals = set()
self.args = set() self.args = set()
self.w_vars = set() # all vars been stored, # all vars been stored,
# may be globals or non-locals # may be globals or non-locals
self.w_vars = set()
def created_vars(self): def created_vars(self):
return self.w_vars - self.globals - self.nonlocals - self.args return self.w_vars - self.globals - self.nonlocals - self.args
...@@ -282,9 +268,7 @@ class NameVisitor(gast.NodeVisitor): ...@@ -282,9 +268,7 @@ class NameVisitor(gast.NodeVisitor):
# If this var is a basic variable and read-only and not # If this var is a basic variable and read-only and not
# condition var, it may not be loop_var else it should # condition var, it may not be loop_var else it should
# be in loop_var as input # be in loop_var as input
if (not name in condition_names) and ( if (not name in condition_names) and (not name in write_names):
not name in write_names
) and self._node_var_type_is_basic(name_to_type[name]):
continue continue
loop_var_names.add(name) loop_var_names.add(name)
...@@ -645,7 +629,6 @@ class LoopTransformer(gast.NodeTransformer): ...@@ -645,7 +629,6 @@ class LoopTransformer(gast.NodeTransformer):
if stmts_tuple is None: if stmts_tuple is None:
return [node] return [node]
init_stmts, cond_stmt, body_stmts = stmts_tuple init_stmts, cond_stmt, body_stmts = stmts_tuple
# 2. get original loop vars # 2. get original loop vars
loop_var_names, create_var_names = self.name_visitor.get_loop_var_names( loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
node) node)
...@@ -672,7 +655,16 @@ class LoopTransformer(gast.NodeTransformer): ...@@ -672,7 +655,16 @@ class LoopTransformer(gast.NodeTransformer):
# We need to create static variable for those variables # We need to create static variable for those variables
for name in create_var_names: for name in create_var_names:
if "." not in name: if "." not in name:
new_stmts.append(create_fill_constant_node(name)) new_stmts.append(create_undefined_var(name))
# create non-local statement for body and cond.
nonlocal_names = list(loop_var_names | create_var_names)
nonlocal_names.sort()
# TODO(dev): Need a better way to deal this.
if ARGS_NAME in nonlocal_names:
nonlocal_names.remove(ARGS_NAME)
nonlocal_stmt_node = [create_nonlocal_stmt_node(nonlocal_names)]
# 4. append init statements # 4. append init statements
new_stmts.extend(init_stmts) new_stmts.extend(init_stmts)
...@@ -680,63 +672,54 @@ class LoopTransformer(gast.NodeTransformer): ...@@ -680,63 +672,54 @@ class LoopTransformer(gast.NodeTransformer):
# 5. create & append condition function node # 5. create & append condition function node
condition_func_node = gast.FunctionDef( condition_func_node = gast.FunctionDef(
name=unique_name.generate(FOR_CONDITION_PREFIX), name=unique_name.generate(FOR_CONDITION_PREFIX),
args=gast.arguments(args=[ args=gast.arguments(args=[],
gast.Name(id=name,
ctx=gast.Param(),
annotation=None,
type_comment=None) for name in loop_var_names
],
posonlyargs=[], posonlyargs=[],
vararg=None, vararg=gast.Name(id=ARGS_NAME,
ctx=gast.Param(),
annotation=None,
type_comment=None),
kwonlyargs=[], kwonlyargs=[],
kw_defaults=None, kw_defaults=None,
kwarg=None, kwarg=None,
defaults=[]), defaults=[]),
body=[gast.Return(value=cond_stmt)], body=nonlocal_stmt_node + [gast.Return(value=cond_stmt)],
decorator_list=[], decorator_list=[],
returns=None, returns=None,
type_comment=None) type_comment=None)
for name in loop_var_names:
if "." in name:
rename_transformer = RenameTransformer(condition_func_node)
rename_transformer.rename(
name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
new_stmts.append(condition_func_node) new_stmts.append(condition_func_node)
# 6. create & append loop body function node # 6. create & append loop body function node
# append return values for loop body # append return values for loop body
body_stmts.append( body_stmts.append(
gast.Return(value=generate_name_node( gast.Return(value=generate_name_node(
loop_var_names, ctx=gast.Load(), gen_tuple_if_single=True))) nonlocal_names, ctx=gast.Load(), gen_tuple_if_single=True)))
body_func_node = gast.FunctionDef( body_func_node = gast.FunctionDef(
name=unique_name.generate(FOR_BODY_PREFIX), name=unique_name.generate(FOR_BODY_PREFIX),
args=gast.arguments(args=[ args=gast.arguments(args=[],
gast.Name(id=name,
ctx=gast.Param(),
annotation=None,
type_comment=None) for name in loop_var_names
],
posonlyargs=[], posonlyargs=[],
vararg=None, vararg=gast.Name(id=ARGS_NAME,
ctx=gast.Param(),
annotation=None,
type_comment=None),
kwonlyargs=[], kwonlyargs=[],
kw_defaults=None, kw_defaults=None,
kwarg=None, kwarg=None,
defaults=[]), defaults=[]),
body=body_stmts, body=nonlocal_stmt_node + body_stmts,
decorator_list=[], decorator_list=[],
returns=None, returns=None,
type_comment=None) type_comment=None)
for name in loop_var_names:
if "." in name:
rename_transformer = RenameTransformer(body_func_node)
rename_transformer.rename(
name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
new_stmts.append(body_func_node) new_stmts.append(body_func_node)
get_args_node = create_get_args_node(nonlocal_names)
set_args_node = create_set_args_node(nonlocal_names)
# 7. create & append while loop node # 7. create & append while loop node
while_loop_nodes = create_while_nodes(condition_func_node.name, while_loop_nodes = create_while_nodes(condition_func_node.name,
body_func_node.name, body_func_node.name,
loop_var_names) nonlocal_names,
get_args_node.name,
set_args_node.name)
new_stmts.extend([get_args_node, set_args_node])
new_stmts.extend(while_loop_nodes) new_stmts.extend(while_loop_nodes)
return new_stmts return new_stmts
...@@ -746,6 +729,15 @@ class LoopTransformer(gast.NodeTransformer): ...@@ -746,6 +729,15 @@ class LoopTransformer(gast.NodeTransformer):
node) node)
new_stmts = [] new_stmts = []
# create non-local statement for body and cond.
nonlocal_names = list(loop_var_names | create_var_names)
nonlocal_names.sort()
# TODO(dev): Need a better way to deal this.
if ARGS_NAME in nonlocal_names:
nonlocal_names.remove(ARGS_NAME)
nonlocal_stmt_node = [create_nonlocal_stmt_node(nonlocal_names)]
# Python can create variable in loop and use it out of loop, E.g. # Python can create variable in loop and use it out of loop, E.g.
# #
# while x < 10: # while x < 10:
...@@ -760,61 +752,52 @@ class LoopTransformer(gast.NodeTransformer): ...@@ -760,61 +752,52 @@ class LoopTransformer(gast.NodeTransformer):
condition_func_node = gast.FunctionDef( condition_func_node = gast.FunctionDef(
name=unique_name.generate(WHILE_CONDITION_PREFIX), name=unique_name.generate(WHILE_CONDITION_PREFIX),
args=gast.arguments(args=[ args=gast.arguments(args=[],
gast.Name(id=name,
ctx=gast.Param(),
annotation=None,
type_comment=None) for name in loop_var_names
],
posonlyargs=[], posonlyargs=[],
vararg=None, vararg=gast.Name(id=ARGS_NAME,
ctx=gast.Param(),
annotation=None,
type_comment=None),
kwonlyargs=[], kwonlyargs=[],
kw_defaults=None, kw_defaults=None,
kwarg=None, kwarg=None,
defaults=[]), defaults=[]),
body=[gast.Return(value=node.test)], body=nonlocal_stmt_node + [gast.Return(value=node.test)],
decorator_list=[], decorator_list=[],
returns=None, returns=None,
type_comment=None) type_comment=None)
for name in loop_var_names:
if "." in name:
rename_transformer = RenameTransformer(condition_func_node)
rename_transformer.rename(
name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
new_stmts.append(condition_func_node) new_stmts.append(condition_func_node)
new_body = node.body new_body = node.body
new_body.append( new_body.append(
gast.Return(value=generate_name_node( gast.Return(value=generate_name_node(
loop_var_names, ctx=gast.Load(), gen_tuple_if_single=True))) nonlocal_names, ctx=gast.Load(), gen_tuple_if_single=True)))
body_func_node = gast.FunctionDef( body_func_node = gast.FunctionDef(
name=unique_name.generate(WHILE_BODY_PREFIX), name=unique_name.generate(WHILE_BODY_PREFIX),
args=gast.arguments(args=[ args=gast.arguments(args=[],
gast.Name(id=name,
ctx=gast.Param(),
annotation=None,
type_comment=None) for name in loop_var_names
],
posonlyargs=[], posonlyargs=[],
vararg=None, vararg=gast.Name(id=ARGS_NAME,
ctx=gast.Param(),
annotation=None,
type_comment=None),
kwonlyargs=[], kwonlyargs=[],
kw_defaults=None, kw_defaults=None,
kwarg=None, kwarg=None,
defaults=[]), defaults=[]),
body=new_body, body=nonlocal_stmt_node + new_body,
decorator_list=[], decorator_list=[],
returns=None, returns=None,
type_comment=None) type_comment=None)
for name in loop_var_names:
if "." in name:
rename_transformer = RenameTransformer(body_func_node)
rename_transformer.rename(
name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
new_stmts.append(body_func_node) new_stmts.append(body_func_node)
get_args_node = create_get_args_node(nonlocal_names)
set_args_node = create_set_args_node(nonlocal_names)
while_loop_nodes = create_while_nodes(condition_func_node.name, while_loop_nodes = create_while_nodes(condition_func_node.name,
body_func_node.name, body_func_node.name,
loop_var_names) nonlocal_names,
get_args_node.name,
set_args_node.name)
new_stmts.extend([get_args_node, set_args_node])
new_stmts.extend(while_loop_nodes) new_stmts.extend(while_loop_nodes)
return new_stmts return new_stmts
...@@ -30,6 +30,8 @@ import numpy as np ...@@ -30,6 +30,8 @@ import numpy as np
import paddle import paddle
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.data_feeder import convert_dtype from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid import core
# Note(Aurelius): Do not forget the dot `.` to distinguish other # Note(Aurelius): Do not forget the dot `.` to distinguish other
# module such as paddlenlp. # module such as paddlenlp.
...@@ -59,6 +61,51 @@ class BaseNodeVisitor(gast.NodeVisitor): ...@@ -59,6 +61,51 @@ class BaseNodeVisitor(gast.NodeVisitor):
return ret return ret
def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
"""
This function creates a Tensor on the global block. The created Tensor
doesn't check the dtype and the shape of feed data because dygraph input
data can be various-length. This API is used in translating dygraph into
static graph.
Note:
The default :code:`stop_gradient` attribute of the Tensor created by
this API is true, which means the gradient won't be passed backward
through the data Tensor. Set :code:`var.stop_gradient = False` If
user would like to pass backward gradient.
Args:
name (str): The name/alias of the Tensor, see :ref:`api_guide_Name`
for more details.
shape (list|tuple): List|Tuple of integers declaring the shape. You can
set "None" at a dimension to indicate the dimension can be of any
size. For example, it is useful to set changeable batch size as "None"
dtype (np.dtype|VarType|str, optional): The type of the data. Supported
dtype: bool, float16, float32, float64, int8, int16, int32, int64,
uint8. Default: float32
lod_level (int, optional): The LoD level of the LoDTensor. Usually users
don't have to set this value. For more details about when and how to
use LoD level, see :ref:`user_guide_lod_tensor` . Default: 0
Returns:
Tensor: The global Tensor that gives access to the data.
"""
helper = LayerHelper('data', **locals())
shape = list(shape)
for i in six.moves.range(len(shape)):
if shape[i] is None:
shape[i] = -1
return helper.create_variable(name=name,
shape=shape,
dtype=dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
stop_gradient=True,
lod_level=lod_level,
is_data=True,
need_check_feed=False)
# imp is deprecated in python3 # imp is deprecated in python3
from importlib.machinery import SourceFileLoader from importlib.machinery import SourceFileLoader
...@@ -412,10 +459,16 @@ def generate_name_node(name_ids, ctx=gast.Load(), gen_tuple_if_single=False): ...@@ -412,10 +459,16 @@ def generate_name_node(name_ids, ctx=gast.Load(), gen_tuple_if_single=False):
raise TypeError( raise TypeError(
'name_ids must be list or tuple or set, but received %s' % 'name_ids must be list or tuple or set, but received %s' %
type(type(name_ids))) type(type(name_ids)))
gast_names = [
gast.Name(id=name_id, ctx=ctx, annotation=None, type_comment=None) def create_node_for_name(name):
for name_id in name_ids if '.' not in name:
] return gast.Name(id=name,
ctx=ctx,
annotation=None,
type_comment=None)
return gast.parse(name).body[0].value
gast_names = [create_node_for_name(name_id) for name_id in name_ids]
if len(gast_names) == 1 and not gen_tuple_if_single: if len(gast_names) == 1 and not gen_tuple_if_single:
name_node = gast_names[0] name_node = gast_names[0]
else: else:
...@@ -842,6 +895,16 @@ class NameNodeReplaceTransformer(gast.NodeTransformer): ...@@ -842,6 +895,16 @@ class NameNodeReplaceTransformer(gast.NodeTransformer):
return self.replace_node return self.replace_node
return node return node
def visit_Nonlocal(self, node):
names = node.names
def replace(s):
if s == self.target_name: return self.replace_node.id
return s
node.names = list(map(replace, names))
return node
class ForLoopTuplePreTransformer(gast.NodeTransformer): class ForLoopTuplePreTransformer(gast.NodeTransformer):
""" """
...@@ -1527,3 +1590,93 @@ def slice_is_num(slice_node): ...@@ -1527,3 +1590,93 @@ def slice_is_num(slice_node):
return True return True
return False return False
def create_get_args_node(names):
"""
Create get_args function as follows:
def get_args_0():
nonlocal x, y
return x, y
"""
def empty_node():
func_def = """
def {func_name}():
return
""".format(func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX))
return gast.parse(textwrap.dedent(func_def)).body[0]
assert isinstance(names, (list, tuple))
if not names:
return empty_node()
mapped = list(filter(lambda n: '.' not in n, names))
nonlocal_names = sorted(
mapped,
key=mapped.index) # to keep the order, we can't use set() to unique
template = """
def {func_name}():
nonlocal {nonlocal_vars}
return {vars}
"""
func_def = template.format(
func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX),
nonlocal_vars=','.join(nonlocal_names),
vars=",".join(names))
return gast.parse(textwrap.dedent(func_def)).body[0]
GET_ARGS_FUNC_PREFIX = 'get_args'
SET_ARGS_FUNC_PREFIX = 'set_args'
ARGS_NAME = '__args'
def create_set_args_node(names):
"""
Create set_args function as follows:
def set_args_0(__args):
nonlocal x, y
x, y = __args
"""
def empty_node():
func_def = """
def {func_name}({args}):
pass
""".format(func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX),
args=ARGS_NAME)
return gast.parse(textwrap.dedent(func_def)).body[0]
assert isinstance(names, (list, tuple))
if not names:
return empty_node()
mapped = list(filter(lambda n: '.' not in n, names))
nonlocal_names = sorted(
mapped,
key=mapped.index) # to keep the order, we can't use set() to unique
template = """
def {func_name}({args}):
nonlocal {nonlocal_vars}
{vars} = {args}
"""
func_def = template.format(
func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX),
args=ARGS_NAME,
nonlocal_vars=','.join(nonlocal_names),
vars=",".join(names))
return gast.parse(textwrap.dedent(func_def)).body[0]
def create_nonlocal_stmt_node(names):
assert isinstance(names, (list, tuple))
mapped = list(filter(lambda n: '.' not in n, names))
names = sorted(
mapped,
key=mapped.index) # to keep the order, we can't use set() to unique
func_code = "nonlocal {}".format(','.join(names))
return gast.parse(func_code).body[0]
...@@ -16,15 +16,17 @@ from __future__ import print_function ...@@ -16,15 +16,17 @@ from __future__ import print_function
import six import six
import paddle import paddle
import textwrap
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid import core
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.framework import Variable from paddle.fluid.framework import Variable
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, data_layer_not_check
__all__ = [ __all__ = [
'create_bool_as_type', 'create_fill_constant_node', 'to_static_variable', 'create_bool_as_type',
'create_undefined_var' 'create_fill_constant_node',
'to_static_variable',
'create_undefined_var',
] ]
...@@ -33,12 +35,6 @@ def create_undefined_var(name): ...@@ -33,12 +35,6 @@ def create_undefined_var(name):
return gast.parse(func_code).body[0] return gast.parse(func_code).body[0]
def create_nonlocal_stmt_node(names):
assert isinstance(names, (list, tuple))
func_code = "nonlocal {}".format(','.join(names))
return gast.parse(func_code).body[0]
def create_fill_constant_node(name, value=0): def create_fill_constant_node(name, value=0):
func_code = "{} = paddle.full(shape=[1], ".format(name) func_code = "{} = paddle.full(shape=[1], ".format(name)
if isinstance(value, bool): if isinstance(value, bool):
...@@ -66,7 +62,9 @@ def to_static_variable(x): ...@@ -66,7 +62,9 @@ def to_static_variable(x):
return paddle.full(shape=[1], dtype='float64', fill_value=x) return paddle.full(shape=[1], dtype='float64', fill_value=x)
if isinstance(x, six.integer_types): if isinstance(x, six.integer_types):
return paddle.full(shape=[1], dtype='int64', fill_value=x) return paddle.full(shape=[1], dtype='int64', fill_value=x)
if isinstance(x, UndefinedVar):
return data_layer_not_check(unique_name.generator("loop_undefined_var"),
[-1])
return x return x
......
...@@ -177,7 +177,6 @@ def test_list_pop_in_for_loop(x, iter_num): ...@@ -177,7 +177,6 @@ def test_list_pop_in_for_loop(x, iter_num):
one = fluid.layers.ones(shape=[1], dtype="int32") one = fluid.layers.ones(shape=[1], dtype="int32")
for i in range(one.numpy()[0]): for i in range(one.numpy()[0]):
item = a.pop() item = a.pop()
return a[0], item, b[1] return a[0], item, b[1]
......
...@@ -270,7 +270,7 @@ class TestNameVisitor(unittest.TestCase): ...@@ -270,7 +270,7 @@ class TestNameVisitor(unittest.TestCase):
self.loop_var_names = [ self.loop_var_names = [
set(["j", "two"]), set(["j", "two"]),
set(["i", "three", "b"]), set(["i", "three", "b"]),
set(["i", "j"]) set(["i"])
] ]
self.create_var_names = [set(), set(["b"]), set()] self.create_var_names = [set(), set(["b"]), set()]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册