未验证 提交 5e8e6dad 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2stat] Support Various-Length Return Grammar in Dy2stat (#25249)

Support Various-Length Return Grammar in Dy2stat. This PR is a follow-up of https://github.com/PaddlePaddle/Paddle/pull/25176 .

The basic idea is putting no-value placeholder variables at `return` statement to make all `return` statement have same length, after that the static graph can have fixed fetch output (code at return_transformer.py). Then remove those no-value placeholder when we finally return dygraph result (code at partial_program.py).

However, various length return in Bert model is still not supported. The dy2stat can change the code as I wish but some ops which check shape at compile time (e.g. Reshape, MatMul) will throw error because of the no-value-placeholder may not have the required shape. Is this a matter? To me, those no-value placeholder will be replaced as really values meeting shape requirements at run time, so I think the solution should be some way to do the compile-time checking. By the way, every time when we have dynamic shape, it often causes problem in dy2stat. We should find a way to handle it in the future.

Fixing various return in Bert is my TODO thing and I will also find some other existing models for verification.
上级 de27569e
......@@ -19,9 +19,10 @@ import logging
from paddle.fluid import log_helper
from paddle.fluid import framework, backward, core
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.dygraph.base import switch_to_static_graph
import paddle.compat as cpt
_logger = log_helper.get_logger(
......@@ -184,7 +185,8 @@ class PartialProgramLayer(layers.Layer):
'is_test': not self.training
})
return self._restore_out(out_vars)
restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out)
def _prepare(self, inputs):
"""
......@@ -239,11 +241,44 @@ class PartialProgramLayer(layers.Layer):
for i, idx in enumerate(self._outputs.var_ids):
flatten_outputs[idx] = out_vars[i]
outs = self._outputs.restore(flatten_outputs)
if len(outs) == 1:
if outs is not None and len(outs) == 1:
outs = outs[0]
return outs
def _is_no_value(self, var):
if isinstance(var, core.VarBase):
if var.shape == [1] and var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
return True
return False
def _remove_no_value(self, out_vars):
"""
Removes invalid value for various-length return statement
"""
if isinstance(out_vars, core.VarBase):
if self._is_no_value(out_vars):
return None
return out_vars
elif isinstance(out_vars, (tuple, list)):
if isinstance(out_vars, tuple):
res = tuple(
var for var in out_vars if not self._is_no_value(var))
else:
# isinstance(out_vars, list)
res = [var for var in out_vars if not self._is_no_value(var)]
has_removed = (len(out_vars) > len(res))
# len(out_vars) > len(res) means we have removed var. This is
# preventing out_vars is empty or just one element at the beginning
if len(res) == 0 and has_removed:
return None
elif len(res) == 1 and has_removed:
return res[0]
return res
return out_vars
def _set_grad_type(self, params):
# NOTE: if user set sparse gradient mode, the param's gradient
# will be SelectedRows, not LoDTensor. But tracer will just
......
......@@ -278,8 +278,9 @@ class ConcreteProgram(object):
with param_guard(func_spec.parameters(False)), param_guard(
func_spec.buffers(False)):
outputs = static_func(*inputs)
if not isinstance(outputs, (tuple, list)):
outputs = [outputs] if outputs else []
if not isinstance(outputs,
(tuple, list)) and outputs is not None:
outputs = [outputs]
return ConcreteProgram(
inputs=inputs,
......
......@@ -21,7 +21,9 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import ForToWhileTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
__all__ = ['ReturnTransformer']
__all__ = [
'RETURN_NO_VALUE_MAGIC_NUM', 'RETURN_NO_VALUE_VAR_NAME', 'ReturnTransformer'
]
# Constant for the name of the variable which stores the boolean state that we
# should return
......@@ -30,10 +32,56 @@ RETURN_PREFIX = '__return'
# Constant for the name of the variable which stores the final return value
RETURN_VALUE_PREFIX = '__return_value'
# Constant for the name of variables to initialize the __return_value
RETURN_VALUE_INIT_NAME = '__return_value_init'
class ReturnPreAnalysisVisitor(gast.NodeVisitor):
# Constant magic number representing returning no value. This constant amis to
# support returning various lengths of variables. Static graph must have fixed
# size of fetched output while dygraph can have flexible lengths of output, to
# solve it in dy2stat, we put float64 value with this magic number at Static
# graph as a place holder to indicate the returning placeholder means no value
# should return.
RETURN_NO_VALUE_MAGIC_NUM = 1.77113e+279
RETURN_NO_VALUE_VAR_NAME = "__no_value_return_var"
def get_return_size(return_node):
assert isinstance(return_node, gast.Return), "Input is not gast.Return node"
return_length = 0
if return_node.value is not None:
if isinstance(return_node.value, gast.Tuple):
return_length = len(return_node.value.elts)
else:
return_length = 1
return return_length
class ReplaceReturnNoneTransformer(gast.NodeTransformer):
"""
Replace 'return None' to 'return' because 'None' cannot be a valid input
in control flow. In ReturnTransformer single 'Return' will be appended no
value placeholder
"""
def __init__(self, root_node):
self.root = root_node
def transform(self):
self.visit(self.root)
def visit_Return(self, node):
if isinstance(node.value, gast.Name) and node.value.id == 'None':
node.value = None
return node
if isinstance(node.value, gast.Constant) and node.value.value == None:
node.value = None
return node
return node
class ReturnAnalysisVisitor(gast.NodeVisitor):
"""
Visits gast Tree and pre-analyze the information about 'return'.
Visits gast Tree and analyze the information about 'return'.
"""
def __init__(self, root_node):
......@@ -45,11 +93,16 @@ class ReturnPreAnalysisVisitor(gast.NodeVisitor):
# Mapping from gast.FunctionDef node to the number of return statements
# Python allows define function inside function so we have to handle it
self.count_return = {}
# Mapping from gast.FunctionDef node to the maximum number of variables
# returned by the function's return statement
self.max_return_length = {}
self.visit(self.root)
def visit_FunctionDef(self, node):
self.function_def.append(node)
self.count_return[node] = 0
self.max_return_length[node] = 0
self.generic_visit(node)
self.function_def.pop()
return node
......@@ -62,13 +115,21 @@ class ReturnPreAnalysisVisitor(gast.NodeVisitor):
self.count_return[cur_func] += 1
else:
self.count_return[cur_func] = 1
return_length = get_return_size(node)
if cur_func in self.max_return_length:
self.max_return_length[cur_func] = max(
self.max_return_length[cur_func], return_length)
else:
self.max_return_length[cur_func] = return_length
self.generic_visit(node)
def get_func_return_count(self, func_node):
return self.count_return[func_node]
def set_func_return_count(self, func_node, count):
self.count_return[func_node] = count
def get_func_max_return_length(self, func_node):
return self.max_return_length[func_node]
class ReturnTransformer(gast.NodeTransformer):
......@@ -83,17 +144,25 @@ class ReturnTransformer(gast.NodeTransformer):
def __init__(self, wrapper_root):
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.ancestor_nodes = []
pre_transformer = ReplaceReturnNoneTransformer(self.root)
pre_transformer.transform()
self.ancestor_nodes = []
# The name of the variable which stores the final return value
# Mapping from FunctionDef node to string
self.return_value_name = {}
# The names of the variable which stores the boolean state that skip
# statments. Mapping from FunctionDef node to list
self.return_name = {}
# The names of the variable which is placeholder to handle various-
# length return. Mapping from FunctionDef node to list
self.return_no_value_name = {}
# A list of FunctionDef to store where the current function is.
self.function_def = []
self.pre_analysis = None
def transform(self):
self.visit(self.root)
......@@ -125,13 +194,19 @@ class ReturnTransformer(gast.NodeTransformer):
self.function_def.append(node)
self.return_value_name[node] = None
self.return_name[node] = []
self.return_no_value_name[node] = []
pre_analysis = ReturnPreAnalysisVisitor(node)
while pre_analysis.get_func_return_count(node) > 1:
self.pre_analysis = ReturnAnalysisVisitor(node)
max_return_length = self.pre_analysis.get_func_max_return_length(node)
while self.pre_analysis.get_func_return_count(node) > 1:
self.generic_visit(node)
pre_analysis = ReturnPreAnalysisVisitor(node)
self.pre_analysis = ReturnAnalysisVisitor(node)
# prepend initialization of final return and append final return statement
if max_return_length == 0:
self.function_def.pop()
return node
# Prepend initialization of final return and append final return statement
value_name = self.return_value_name[node]
if value_name is not None:
node.body.append(
......@@ -140,12 +215,51 @@ class ReturnTransformer(gast.NodeTransformer):
ctx=gast.Load(),
annotation=None,
type_comment=None)))
assign_zero_node = create_fill_constant_node(value_name, 0.0)
node.body.insert(0, assign_zero_node)
init_names = [
unique_name.generate(RETURN_VALUE_INIT_NAME)
for i in range(max_return_length)
]
assign_zero_nodes = [
create_fill_constant_node(iname, 0.0) for iname in init_names
]
if len(init_names) == 1:
return_value_nodes = gast.Name(
id=init_names[0],
ctx=gast.Load(),
annotation=None,
type_comment=None)
else:
# We need to initialize return value as a tuple because control
# flow requires some inputs or outputs have same structure
return_value_nodes = gast.Tuple(
elts=[
gast.Name(
id=iname,
ctx=gast.Load(),
annotation=None,
type_comment=None) for iname in init_names
],
ctx=gast.Load())
assign_return_value_node = gast.Assign(
targets=[
gast.Name(
id=value_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=return_value_nodes)
node.body.insert(0, assign_return_value_node)
node.body[:0] = assign_zero_nodes
# Prepend control flow boolean nodes such as '__return@1 = False'
for name in self.return_name[node]:
assign_false_node = create_fill_constant_node(name, False)
node.body.insert(0, assign_false_node)
# Prepend no value placeholders
for name in self.return_no_value_name[node]:
assign_no_value_node = create_fill_constant_node(
name, RETURN_NO_VALUE_MAGIC_NUM)
node.body.insert(0, assign_no_value_node)
self.function_def.pop()
return node
......@@ -154,21 +268,24 @@ class ReturnTransformer(gast.NodeTransformer):
cur_func_node = self.function_def[-1]
return_name = unique_name.generate(RETURN_PREFIX)
self.return_name[cur_func_node].append(return_name)
max_return_length = self.pre_analysis.get_func_max_return_length(
cur_func_node)
for ancestor_index in reversed(range(len(self.ancestor_nodes) - 1)):
ancestor = self.ancestor_nodes[ancestor_index]
cur_node = self.ancestor_nodes[ancestor_index + 1]
if hasattr(ancestor,
"body") and index_in_list(ancestor.body, cur_node) != -1:
if cur_node == node:
self._replace_return_in_stmt_list(ancestor.body, cur_node,
return_name)
self._replace_return_in_stmt_list(
ancestor.body, cur_node, return_name, max_return_length)
self._replace_after_node_to_if_in_stmt_list(
ancestor.body, cur_node, return_name)
elif hasattr(ancestor, "orelse") and index_in_list(ancestor.orelse,
cur_node) != -1:
if cur_node == node:
self._replace_return_in_stmt_list(ancestor.orelse, cur_node,
return_name)
return_name,
max_return_length)
self._replace_after_node_to_if_in_stmt_list(
ancestor.orelse, cur_node, return_name)
......@@ -203,16 +320,81 @@ class ReturnTransformer(gast.NodeTransformer):
break
# return_node is replaced so we shouldn't return here
def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name):
def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name,
max_return_length):
assert max_return_length >= 0, "Input illegal max_return_length"
i = index_in_list(stmt_list, return_node)
if i == -1:
return False
assign_nodes = [create_fill_constant_node(return_name, True)]
if return_node.value is not None:
cur_func_node = self.function_def[-1]
return_length = get_return_size(return_node)
if return_length < max_return_length:
# In this case we should append RETURN_NO_VALUE placeholder
#
# max_return_length must be >= 1 here because return_length will be
# 0 at least.
if self.return_value_name[cur_func_node] is None:
self.return_value_name[cur_func_node] = unique_name.generate(
RETURN_VALUE_PREFIX)
no_value_names = [
unique_name.generate(RETURN_NO_VALUE_VAR_NAME)
for j in range(max_return_length - return_length)
]
self.return_no_value_name[cur_func_node].extend(no_value_names)
# Handle tuple/non-tuple case
if max_return_length == 1:
assign_nodes.append(
gast.Assign(
targets=[
gast.Name(
id=self.return_value_name[cur_func_node],
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=gast.Name(
id=no_value_names[0],
ctx=gast.Load(),
annotation=None,
type_comment=None)))
else:
# max_return_length > 1 which means we should assign tuple
fill_tuple = [
gast.Name(
id=n,
ctx=gast.Load(),
annotation=None,
type_comment=None) for n in no_value_names
]
if return_node.value is not None:
if isinstance(return_node.value, gast.Tuple):
fill_tuple[:0] = return_node.value.elts
else:
fill_tuple.insert(0, return_node.value)
assign_nodes.append(
gast.Assign(
targets=[
gast.Name(
id=self.return_value_name[cur_func_node],
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=gast.Tuple(
elts=fill_tuple, ctx=gast.Load())))
else:
# In this case we should NOT append RETURN_NO_VALUE placeholder
if return_node.value is not None:
cur_func_node = self.function_def[-1]
if self.return_value_name[cur_func_node] is None:
self.return_value_name[
cur_func_node] = unique_name.generate(
RETURN_VALUE_PREFIX)
assign_nodes.append(
gast.Assign(
targets=[
......@@ -223,6 +405,7 @@ class ReturnTransformer(gast.NodeTransformer):
type_comment=None)
],
value=return_node.value))
stmt_list[i:] = assign_nodes
return True
......
......@@ -67,8 +67,9 @@ class StaticCode1():
shape=[1], dtype='bool', value=False)
__return_0 = fluid.layers.fill_constant(
shape=[1], dtype='bool', value=False)
__return_value_0 = fluid.layers.fill_constant(
__return_value_init_0 = fluid.layers.fill_constant(
shape=[1], dtype='float64', value=0.0)
__return_value_0 = __return_value_init_0
def true_fn_0(x_v):
x_v = x_v - 1
......@@ -123,8 +124,9 @@ class StaticCode2():
shape=[1], dtype='bool', value=False)
__return_2 = fluid.layers.fill_constant(
shape=[1], dtype='bool', value=False)
__return_value_1 = fluid.layers.fill_constant(
__return_value_init_1 = fluid.layers.fill_constant(
shape=[1], dtype='float64', value=0.0)
__return_value_1 = __return_value_init_1
def true_fn_3(x_v):
x_v = x_v - 1
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.dygraph import declarative
from paddle.fluid.dygraph import ProgramTranslator
......@@ -96,6 +97,56 @@ def test_recursive_return(x):
return dyfunc_with_if_else(x)
@declarative
def test_return_different_length_if_body(x):
x = fluid.dygraph.to_variable(x)
y = x + 1
if x > 0:
# x = to_variable(np.ones(1)) so it will return here
return x, y
else:
return x
@declarative
def test_return_different_length_else(x):
x = fluid.dygraph.to_variable(x)
y = x + 1
if x < 0:
return x, y
else:
# x = to_variable(np.ones(1)) so it will return here
return x
@declarative
def test_no_return(x):
x = fluid.dygraph.to_variable(x)
y = x + 1
@declarative
def test_return_none(x):
x = fluid.dygraph.to_variable(x)
y = x + 1
if x > 0:
# x = to_variable(np.ones(1)) so it will return here
return None
else:
return x, y
@declarative
def test_return_no_variable(x):
x = fluid.dygraph.to_variable(x)
y = x + 1
if x < 0:
return x, y
else:
# x = to_variable(np.ones(1)) so it will return here
return
class TestReturnBase(unittest.TestCase):
def setUp(self):
self.input = np.ones((1)).astype('int32')
......@@ -111,21 +162,41 @@ class TestReturnBase(unittest.TestCase):
self.program_translator.enable(False)
with fluid.dygraph.guard():
res = self.dygraph_func(self.input)
if isinstance(res, (tuple)):
return tuple(r.numpy() for r in res)
elif isinstance(res, core.VarBase):
return res.numpy()
return res
def run_static_mode(self):
self.program_translator.enable(True)
with fluid.dygraph.guard():
res = self.dygraph_func(self.input)
if isinstance(res, tuple):
return tuple(r.numpy() for r in res)
elif isinstance(res, core.VarBase):
return res.numpy()
return res
def test_transformed_static_result(self):
static_res = self.run_static_mode()
dygraph_res = self.run_dygraph_mode()
static_res = self.run_static_mode()
if isinstance(dygraph_res, tuple):
self.assertTrue(isinstance(static_res, tuple))
self.assertEqual(len(dygraph_res), len(static_res))
for i in range(len(dygraph_res)):
self.assertTrue(
np.allclose(dygraph_res[i], static_res[i]),
msg='dygraph res is {}\nstatic_res is {}'.format(
dygraph_res[i], static_res[i]))
elif isinstance(dygraph_res, np.ndarray):
self.assertTrue(
np.allclose(dygraph_res, static_res),
msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res,
static_res))
else:
self.assertEqual(dygraph_res, static_res)
class TestInsideFuncBase(TestReturnBase):
......@@ -159,5 +230,30 @@ class TestRecursiveReturn(TestReturnBase):
self.dygraph_func = test_recursive_return
class TestReturnDifferentLengthIfBody(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_different_length_if_body
class TestReturnDifferentLengthElse(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_different_length_else
class TestNoReturn(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_no_return
class TestReturnNone(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_none
class TestReturnNoVariable(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_no_variable
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册