未验证 提交 f9198372 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Support nonlocal mechanism in IF ast transformer (#43666)

* [Dy2Stat]Support nonlocal mechanism in IF ast transformer

* support prune return vars in cond

* fix unittest

* fix unittest

* fix static check
上级 295f289a
......@@ -21,6 +21,7 @@ from paddle.fluid.layers import assign, fill_constant, slice, reduce_all, reduce
from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn
from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_VAR_NAME
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar
def convert_while_loop(cond, body, loop_vars):
......@@ -188,7 +189,8 @@ def _run_py_logical_not(x):
return not x
def convert_ifelse(pred, true_fn, false_fn, true_args, false_args):
def convert_ifelse(pred, true_fn, false_fn, get_args, set_args,
return_name_ids):
"""
A function representation of a Python ``if/else`` statement.
......@@ -196,17 +198,18 @@ def convert_ifelse(pred, true_fn, false_fn, true_args, false_args):
pred(bool|Tensor): A boolean Tensor which determines whether to return the result of ``true_fn`` or ``false_fn`` .
true_fn(callable): A callable to be performed if ``pred`` is true.
false_fn(callable): A callable to be performed if ``pred`` is false.
true_args(tuple): Parameters of ``true_fn``.
false_args(tuple): Parameters of ``false_fn``.
get_args(callable): Get all arguments that needed in true_fn and false_fn.
set_args(callable): Update arguments that modified in trure_fn and false_fn.
Returns:
``true_fn(true_args)`` if the predicate ``pred`` is true else ``false_fn(false_args)`` .
``true_fn()`` if the predicate ``pred`` is true else ``false_fn()`` .
"""
if isinstance(pred, Variable):
out = _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args)
out = _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
return_name_ids)
else:
out = _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args)
out = _run_py_ifelse(pred, true_fn, false_fn)
return _remove_no_value_return_var(out)
......@@ -244,14 +247,59 @@ def _remove_no_value_return_var(out):
return out
def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args):
def _check_no_undefined_var(outs, names, branch_name):
if names is None: return
if not isinstance(outs, (list, tuple)):
outs = [outs]
for var, name in zip(list(outs), names):
if isinstance(var, UndefinedVar):
raise ValueError(
"Required '{}' must be initialized both in if-else branch, but found it not initialized in '{}'."
.format(name, branch_name))
def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
return_name_ids):
"""
Paddle cond API will evaluate both ture_fn and false_fn codes.
"""
pred = cast_bool_if_necessary(pred)
return control_flow.cond(pred, lambda: true_fn(*true_args),
lambda: false_fn(*false_args))
init_args = get_args()
def new_true_fn():
set_args(init_args)
outs = true_fn()
_check_no_undefined_var(outs, return_name_ids, 'if_body')
return outs
def new_false_fn():
set_args(init_args)
outs = false_fn()
_check_no_undefined_var(outs, return_name_ids, 'else_body')
return outs
cond_outs = control_flow.cond(pred, new_true_fn, new_false_fn)
# IfExpr's return_name_ids maybe None
if return_name_ids is None:
return cond_outs
# recover args state
num_outs = len(return_name_ids)
num_args = 1 if not isinstance(init_args, tuple) else len(init_args)
assert num_outs <= num_args
if num_args == 1:
final_outs = cond_outs
else:
cond_outs = (cond_outs, ) if num_outs == 1 else cond_outs
final_outs = cond_outs + init_args[num_outs:]
set_args(final_outs)
return final_outs
def _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args):
return true_fn(*true_args) if pred else false_fn(*false_args)
def _run_py_ifelse(pred, true_fn, false_fn):
return true_fn() if pred else false_fn()
def convert_len(var):
......
......@@ -16,6 +16,7 @@ from __future__ import print_function
import six
import copy
import textwrap
from collections import defaultdict
# gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST).
......@@ -29,10 +30,14 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node, as
from paddle.fluid.dygraph.dygraph_to_static.utils import create_assign_node
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_nonlocal_stmt_node
TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn'
GET_ARGS_FUNC_PREFIX = 'get_args'
SET_ARGS_FUNC_PREFIX = 'set_args'
ARGS_NAME = '__args'
class IfElseTransformer(gast.NodeTransformer):
......@@ -56,13 +61,16 @@ class IfElseTransformer(gast.NodeTransformer):
def visit_If(self, node):
self.generic_visit(node)
new_vars_stmts, true_func_node, false_func_node, return_name_ids = transform_if_else(
new_vars_stmts, true_func_node, false_func_node, get_args_node, set_args_node, return_name_ids = transform_if_else(
node, self.root)
new_node = create_convert_ifelse_node(return_name_ids, node.test,
true_func_node, false_func_node)
true_func_node, false_func_node,
get_args_node, set_args_node)
return new_vars_stmts + [true_func_node, false_func_node] + [new_node]
return new_vars_stmts + [
get_args_node, set_args_node, true_func_node, false_func_node
] + [new_node]
def visit_Call(self, node):
# Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]`
......@@ -80,7 +88,7 @@ class IfElseTransformer(gast.NodeTransformer):
self.generic_visit(node)
new_node = create_convert_ifelse_node(None, node.test, node.body,
node.orelse, True)
node.orelse, None, None, True)
# Note: A blank line will be added separately if transform gast.Expr
# into source code. Using gast.Expr.value instead to avoid syntax error
# in python.
......@@ -192,6 +200,12 @@ class NameVisitor(gast.NodeVisitor):
self.generic_visit(node)
def visit_FunctionDef(self, node):
# NOTE: We skip to visit names of get_args and set_args, because they contains
# nonlocal statement such as 'nonlocal x, self' where 'self' should not be
# parsed as returned value in contron flow.
if GET_ARGS_FUNC_PREFIX in node.name or SET_ARGS_FUNC_PREFIX in node.name:
return
if not self._in_range:
self.generic_visit(node)
return
......@@ -269,7 +283,7 @@ def get_name_ids(nodes, after_node=None, end_node=None):
return name_visitor.name_ids
def parse_cond_args(parent_ids_dict,
def parse_cond_args(parent_ids,
var_ids_dict,
modified_ids_dict=None,
ctx=gast.Load):
......@@ -307,24 +321,9 @@ def parse_cond_args(parent_ids_dict,
# ```
#
# In the above case, `v` should not be in the args of cond()
arg_name_ids = list(set(arg_name_ids) & set(parent_ids_dict))
arg_name_ids.sort()
args = [
gast.Name(id=name_id,
ctx=gast.Load(),
annotation=None,
type_comment=None) for name_id in arg_name_ids
]
arguments = gast.arguments(args=args,
posonlyargs=[],
vararg=None,
kwonlyargs=[],
kw_defaults=None,
kwarg=None,
defaults=[])
arg_name_ids = set(arg_name_ids) & set(parent_ids)
return arguments
return arg_name_ids
def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
......@@ -454,10 +453,35 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
return return_ids, modified_vars_from_parent, new_vars_to_create
def _valid_nonlocal_names(return_name_ids, nonlocal_names):
"""
All var in return_name_ids should be in nonlocal_names.
Moreover, we will always put return_name_ids in front of nonlocal_names.
For Example:
return_name_ids: [x, y]
nonlocal_names : [a, y, b, x]
Return:
nonlocal_names : [x, y, a, b]
"""
assert isinstance(return_name_ids, list)
for name in return_name_ids:
if name not in nonlocal_names:
raise ValueError(
"Required returned var '{}' must be in 'nonlocal' statement '', but not found."
.format(name))
nonlocal_names.remove(name)
return return_name_ids + nonlocal_names
def transform_if_else(node, root):
"""
Transform ast.If into control flow statement of Paddle static graph.
"""
# TODO(liym27): Consider variable like `self.a` modified in if/else node.
parent_name_ids = get_name_ids([root], end_node=node)
body_name_ids = get_name_ids(node.body)
......@@ -480,73 +504,134 @@ def transform_if_else(node, root):
for name in new_vars_to_create:
# NOTE: Consider variable like `self.a` modified in if/else node.
if "." not in name:
create_new_vars_in_parent_stmts.append(
create_static_variable_gast_node(name))
modified_name_ids = modified_name_ids_from_parent | new_vars_to_create
create_new_vars_in_parent_stmts.append(create_undefined_var(name))
parent_ids_set = set()
for k, ctxs in parent_name_ids.items():
if any([not isinstance(ctx, gast.Load) for ctx in ctxs]):
parent_ids_set.add(k)
trun_args = parse_cond_args(parent_ids_set, body_name_ids,
modified_name_ids_from_parent)
false_args = parse_cond_args(parent_ids_set, orelse_name_ids,
modified_name_ids_from_parent)
nonlocal_names = list(trun_args | false_args | new_vars_to_create)
nonlocal_names.sort()
# NOTE: All var in return_name_ids should be in nonlocal_names.
nonlocal_names = _valid_nonlocal_names(return_name_ids, nonlocal_names)
# TODO(dev): Need a better way to deal this.
if ARGS_NAME in nonlocal_names:
nonlocal_names.remove(ARGS_NAME)
nonlocal_stmt_node = [create_nonlocal_stmt_node(nonlocal_names)]
empty_arg_node = gast.arguments(args=[],
posonlyargs=[],
vararg=None,
kwonlyargs=[],
kw_defaults=None,
kwarg=None,
defaults=[])
true_func_node = create_funcDef_node(
node.body,
nonlocal_stmt_node + node.body,
name=unique_name.generate(TRUE_FUNC_PREFIX),
input_args=parse_cond_args(parent_name_ids, body_name_ids,
modified_name_ids),
input_args=empty_arg_node,
return_name_ids=return_name_ids)
false_func_node = create_funcDef_node(
node.orelse,
nonlocal_stmt_node + node.orelse,
name=unique_name.generate(FALSE_FUNC_PREFIX),
input_args=parse_cond_args(parent_name_ids, orelse_name_ids,
modified_name_ids),
input_args=empty_arg_node,
return_name_ids=return_name_ids)
return create_new_vars_in_parent_stmts, true_func_node, false_func_node, return_name_ids
get_args_node = create_get_args_node(nonlocal_names)
set_args_node = create_set_args_node(nonlocal_names)
return create_new_vars_in_parent_stmts, true_func_node, false_func_node, get_args_node, set_args_node, return_name_ids
def create_get_args_node(names):
"""
Create get_args function as follows:
def get_args_0():
nonlocal x, y
"""
assert isinstance(names, (list, tuple))
template = """
def {func_name}():
nonlocal {vars}
return {vars}
"""
func_def = template.format(
func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX),
vars=",".join(names))
return gast.parse(textwrap.dedent(func_def)).body[0]
def create_set_args_node(names):
"""
Create set_args function as follows:
def set_args_0(__args):
nonlocal x, y
x, y = __args
"""
assert isinstance(names, (list, tuple))
template = """
def {func_name}({args}):
nonlocal {vars}
{vars} = {args}
"""
func_def = template.format(
func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX),
args=ARGS_NAME,
vars=",".join(names))
return gast.parse(textwrap.dedent(func_def)).body[0]
def create_convert_ifelse_node(return_name_ids,
pred,
true_func,
false_func,
get_args_func,
set_args_func,
is_if_expr=False):
"""
Create `paddle.jit.dy2static.convert_ifelse(
pred, true_fn, false_fn, true_args, false_args)`
pred, true_fn, false_fn, get_args, set_args, return_name_ids)`
to replace original `python if/else` statement.
"""
def create_name_nodes(name_ids):
def create_name_str(name_ids):
"""
Return "('x', 'y')" for [x, y]
"""
if not name_ids:
return gast.Tuple(elts=[], ctx=gast.Load())
return 'None'
gast_names = [
gast.Name(id=name_id,
ctx=gast.Load(),
annotation=None,
type_comment=None) for name_id in name_ids
]
name_node = gast.Tuple(elts=gast_names, ctx=gast.Load())
return name_node
names_str = ["'%s'" % name for name in name_ids]
return "(%s, )" % ','.join(names_str)
if is_if_expr:
true_args = gast.Tuple(elts=[], ctx=gast.Load())
false_args = gast.Tuple(elts=[], ctx=gast.Load())
true_func_source = "lambda : {}".format(ast_to_source_code(true_func))
false_func_source = "lambda : {}".format(ast_to_source_code(false_func))
else:
true_args = gast.Tuple(elts=true_func.args.args, ctx=gast.Load())
false_args = gast.Tuple(elts=false_func.args.args, ctx=gast.Load())
true_func_source = true_func.name
false_func_source = false_func.name
convert_ifelse_layer = gast.parse(
'_jst.convert_ifelse('
'{pred}, {true_fn}, {false_fn}, {true_args}, {false_args})'.format(
'{pred}, {true_fn}, {false_fn}, {get_args}, {set_args}, {return_name_ids})'
.format(
pred=ast_to_source_code(pred),
true_fn=true_func_source,
false_fn=false_func_source,
true_args=ast_to_source_code(true_args),
false_args=ast_to_source_code(false_args))).body[0].value
if return_name_ids:
_, cond_node = create_assign_node(return_name_ids, convert_ifelse_layer)
else: # No variables can be returned if no assign statement in if.body.
cond_node = gast.Expr(value=convert_ifelse_layer)
get_args=get_args_func.name if not is_if_expr else
'lambda: None', #TODO: better way to deal with this
set_args=set_args_func.name
if not is_if_expr else 'lambda args: None',
return_name_ids=create_name_str(return_name_ids))).body[0]
return cond_node
return convert_ifelse_layer
......@@ -87,6 +87,23 @@ FullArgSpec = collections.namedtuple('FullArgSpec', [
])
class UndefinedVar:
def __init__(self, name):
self.name = name
def check(self):
raise UnboundLocalError(
"local variable '{}' should be created before using it.")
def saw(x):
if isinstance(x, UndefinedVar):
return x.check()
else:
return x
def getfullargspec(target):
if hasattr(inspect, "getfullargspec"):
return inspect.getfullargspec(target)
......
......@@ -25,7 +25,7 @@ from paddle.fluid.layer_helper import LayerHelper
__all__ = [
'create_bool_as_type', 'create_fill_constant_node',
'create_static_variable_gast_node', 'data_layer_not_check',
'to_static_variable', 'to_static_variable_gast_node'
'to_static_variable', 'to_static_variable_gast_node', 'create_undefined_var'
]
......@@ -74,6 +74,17 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
need_check_feed=False)
def create_undefined_var(name):
func_code = "{} = _jst.UndefinedVar('{}')".format(name, name)
return gast.parse(func_code).body[0]
def create_nonlocal_stmt_node(names):
assert isinstance(names, (list, tuple))
func_code = "nonlocal {}".format(','.join(names))
return gast.parse(func_code).body[0]
def to_static_variable_gast_node(name):
func_code = "{} = _jst.to_static_variable({})".format(name, name)
return gast.parse(func_code).body[0]
......
......@@ -72,33 +72,51 @@ class StaticCode1():
name='__return_value_init_0')
__return_value_0 = __return_value_init_0
def true_fn_0(x_v):
def get_args_0():
nonlocal x_v
return x_v
def set_args_0(__args):
nonlocal x_v
x_v = __args
def true_fn_0():
nonlocal x_v
x_v = x_v - 1
return x_v
def false_fn_0(x_v):
def false_fn_0():
nonlocal x_v
x_v = x_v + 1
return x_v
x_v = _jst.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ),
(x_v, ))
_jst.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, get_args_0,
set_args_0, ('x_v', ))
def get_args_1():
nonlocal __return_value_0, label, x_v
return __return_value_0, label, x_v
def set_args_1(__args):
nonlocal __return_value_0, label, x_v
__return_value_0, label, x_v = __args
def true_fn_1(__return_value_0, label, x_v):
def true_fn_1():
nonlocal __return_value_0, label, x_v
loss = fluid.layers.cross_entropy(x_v, label)
__return_0 = _jst.create_bool_as_type(label is not None, True)
__return_value_0 = loss
return __return_value_0
def false_fn_1(__return_value_0, label, x_v):
def false_fn_1():
nonlocal __return_value_0, label, x_v
__return_1 = _jst.create_bool_as_type(label is not None, True)
__return_value_0 = x_v
return __return_value_0
__return_value_0 = _jst.convert_ifelse(label is not None, true_fn_1,
false_fn_1,
(__return_value_0, label, x_v),
(__return_value_0, label, x_v))
_jst.convert_ifelse(label is not None, true_fn_1, false_fn_1,
get_args_1, set_args_1, ('__return_value_0', ))
return __return_value_0
......@@ -111,33 +129,51 @@ class StaticCode2():
name='__return_value_init_1')
__return_value_1 = __return_value_init_1
def true_fn_2(x_v):
def get_args_2():
nonlocal x_v
return x_v
def set_args_2(__args):
nonlocal x_v
x_v = __args
def true_fn_2():
nonlocal x_v
x_v = x_v - 1
return x_v
def false_fn_2(x_v):
def false_fn_2():
nonlocal x_v
x_v = x_v + 1
return x_v
x_v = _jst.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5, true_fn_2, false_fn_2, (x_v, ),
(x_v, ))
_jst.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5, true_fn_2, false_fn_2, get_args_2,
set_args_2, ('x_v', ))
def get_args_3():
nonlocal __return_value_1, label, x_v
return __return_value_1, label, x_v
def set_args_3(__args):
nonlocal __return_value_1, label, x_v
__return_value_1, label, x_v = __args
def true_fn_3(__return_value_1, label, x_v):
def true_fn_3():
nonlocal __return_value_1, label, x_v
loss = fluid.layers.cross_entropy(x_v, label)
__return_2 = _jst.create_bool_as_type(label is not None, True)
__return_value_1 = loss
return __return_value_1
def false_fn_3(__return_value_1, label, x_v):
def false_fn_3():
nonlocal __return_value_1, label, x_v
__return_3 = _jst.create_bool_as_type(label is not None, True)
__return_value_1 = x_v
return __return_value_1
__return_value_1 = _jst.convert_ifelse(label is not None, true_fn_3,
false_fn_3,
(__return_value_1, label, x_v),
(__return_value_1, label, x_v))
_jst.convert_ifelse(label is not None, true_fn_3, false_fn_3,
get_args_3, set_args_3, ('__return_value_1', ))
return __return_value_1
......@@ -166,6 +202,7 @@ class TestDygraphToStaticCode(unittest.TestCase):
answer = get_source_code(StaticCode2.dyfunc_with_if_else)
program_translator = ProgramTranslator()
code = program_translator.get_code(dyfunc_with_if_else)
# print(code)
self.assertEqual(answer, code)
......
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import saw
from .base import UndefinedVar
from .convert_call_func import convert_call # noqa: F401
from .convert_operators import cast_bool_if_necessary # noqa: F401
from .convert_operators import convert_assert # noqa: F401
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from ...fluid.dygraph.dygraph_to_static.utils import saw # noqa: F401
from ...fluid.dygraph.dygraph_to_static.utils import UndefinedVar # noqa: F401
__all__ = []
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册