未验证 提交 5e8e6dad 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2stat] Support Various-Length Return Grammar in Dy2stat (#25249)

Support Various-Length Return Grammar in Dy2stat. This PR is a follow-up of https://github.com/PaddlePaddle/Paddle/pull/25176 .

The basic idea is putting no-value placeholder variables at `return` statement to make all `return` statement have same length, after that the static graph can have fixed fetch output (code at return_transformer.py). Then remove those no-value placeholder when we finally return dygraph result (code at partial_program.py).

However, various length return in Bert model is still not supported. The dy2stat can change the code as I wish but some ops which check shape at compile time (e.g. Reshape, MatMul) will throw error because of the no-value-placeholder may not have the required shape. Is this a matter? To me, those no-value placeholder will be replaced as really values meeting shape requirements at run time, so I think the solution should be some way to do the compile-time checking. By the way, every time when we have dynamic shape, it often causes problem in dy2stat. We should find a way to handle it in the future.

Fixing various return in Bert is my TODO thing and I will also find some other existing models for verification.
上级 de27569e
...@@ -19,9 +19,10 @@ import logging ...@@ -19,9 +19,10 @@ import logging
from paddle.fluid import log_helper from paddle.fluid import log_helper
from paddle.fluid import framework, backward, core from paddle.fluid import framework, backward, core
from paddle.fluid.dygraph import layers from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
from paddle.fluid.layers.utils import flatten from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.dygraph.base import switch_to_static_graph
import paddle.compat as cpt import paddle.compat as cpt
_logger = log_helper.get_logger( _logger = log_helper.get_logger(
...@@ -184,7 +185,8 @@ class PartialProgramLayer(layers.Layer): ...@@ -184,7 +185,8 @@ class PartialProgramLayer(layers.Layer):
'is_test': not self.training 'is_test': not self.training
}) })
return self._restore_out(out_vars) restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out)
def _prepare(self, inputs): def _prepare(self, inputs):
""" """
...@@ -239,11 +241,44 @@ class PartialProgramLayer(layers.Layer): ...@@ -239,11 +241,44 @@ class PartialProgramLayer(layers.Layer):
for i, idx in enumerate(self._outputs.var_ids): for i, idx in enumerate(self._outputs.var_ids):
flatten_outputs[idx] = out_vars[i] flatten_outputs[idx] = out_vars[i]
outs = self._outputs.restore(flatten_outputs) outs = self._outputs.restore(flatten_outputs)
if len(outs) == 1: if outs is not None and len(outs) == 1:
outs = outs[0] outs = outs[0]
return outs return outs
def _is_no_value(self, var):
if isinstance(var, core.VarBase):
if var.shape == [1] and var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
return True
return False
def _remove_no_value(self, out_vars):
"""
Removes invalid value for various-length return statement
"""
if isinstance(out_vars, core.VarBase):
if self._is_no_value(out_vars):
return None
return out_vars
elif isinstance(out_vars, (tuple, list)):
if isinstance(out_vars, tuple):
res = tuple(
var for var in out_vars if not self._is_no_value(var))
else:
# isinstance(out_vars, list)
res = [var for var in out_vars if not self._is_no_value(var)]
has_removed = (len(out_vars) > len(res))
# len(out_vars) > len(res) means we have removed var. This is
# preventing out_vars is empty or just one element at the beginning
if len(res) == 0 and has_removed:
return None
elif len(res) == 1 and has_removed:
return res[0]
return res
return out_vars
def _set_grad_type(self, params): def _set_grad_type(self, params):
# NOTE: if user set sparse gradient mode, the param's gradient # NOTE: if user set sparse gradient mode, the param's gradient
# will be SelectedRows, not LoDTensor. But tracer will just # will be SelectedRows, not LoDTensor. But tracer will just
......
...@@ -278,8 +278,9 @@ class ConcreteProgram(object): ...@@ -278,8 +278,9 @@ class ConcreteProgram(object):
with param_guard(func_spec.parameters(False)), param_guard( with param_guard(func_spec.parameters(False)), param_guard(
func_spec.buffers(False)): func_spec.buffers(False)):
outputs = static_func(*inputs) outputs = static_func(*inputs)
if not isinstance(outputs, (tuple, list)): if not isinstance(outputs,
outputs = [outputs] if outputs else [] (tuple, list)) and outputs is not None:
outputs = [outputs]
return ConcreteProgram( return ConcreteProgram(
inputs=inputs, inputs=inputs,
......
...@@ -21,7 +21,9 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list ...@@ -21,7 +21,9 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import ForToWhileTransformer from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import ForToWhileTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
__all__ = ['ReturnTransformer'] __all__ = [
'RETURN_NO_VALUE_MAGIC_NUM', 'RETURN_NO_VALUE_VAR_NAME', 'ReturnTransformer'
]
# Constant for the name of the variable which stores the boolean state that we # Constant for the name of the variable which stores the boolean state that we
# should return # should return
...@@ -30,10 +32,56 @@ RETURN_PREFIX = '__return' ...@@ -30,10 +32,56 @@ RETURN_PREFIX = '__return'
# Constant for the name of the variable which stores the final return value # Constant for the name of the variable which stores the final return value
RETURN_VALUE_PREFIX = '__return_value' RETURN_VALUE_PREFIX = '__return_value'
# Constant for the name of variables to initialize the __return_value
RETURN_VALUE_INIT_NAME = '__return_value_init'
class ReturnPreAnalysisVisitor(gast.NodeVisitor): # Constant magic number representing returning no value. This constant amis to
# support returning various lengths of variables. Static graph must have fixed
# size of fetched output while dygraph can have flexible lengths of output, to
# solve it in dy2stat, we put float64 value with this magic number at Static
# graph as a place holder to indicate the returning placeholder means no value
# should return.
RETURN_NO_VALUE_MAGIC_NUM = 1.77113e+279
RETURN_NO_VALUE_VAR_NAME = "__no_value_return_var"
def get_return_size(return_node):
assert isinstance(return_node, gast.Return), "Input is not gast.Return node"
return_length = 0
if return_node.value is not None:
if isinstance(return_node.value, gast.Tuple):
return_length = len(return_node.value.elts)
else:
return_length = 1
return return_length
class ReplaceReturnNoneTransformer(gast.NodeTransformer):
"""
Replace 'return None' to 'return' because 'None' cannot be a valid input
in control flow. In ReturnTransformer single 'Return' will be appended no
value placeholder
"""
def __init__(self, root_node):
self.root = root_node
def transform(self):
self.visit(self.root)
def visit_Return(self, node):
if isinstance(node.value, gast.Name) and node.value.id == 'None':
node.value = None
return node
if isinstance(node.value, gast.Constant) and node.value.value == None:
node.value = None
return node
return node
class ReturnAnalysisVisitor(gast.NodeVisitor):
""" """
Visits gast Tree and pre-analyze the information about 'return'. Visits gast Tree and analyze the information about 'return'.
""" """
def __init__(self, root_node): def __init__(self, root_node):
...@@ -45,11 +93,16 @@ class ReturnPreAnalysisVisitor(gast.NodeVisitor): ...@@ -45,11 +93,16 @@ class ReturnPreAnalysisVisitor(gast.NodeVisitor):
# Mapping from gast.FunctionDef node to the number of return statements # Mapping from gast.FunctionDef node to the number of return statements
# Python allows define function inside function so we have to handle it # Python allows define function inside function so we have to handle it
self.count_return = {} self.count_return = {}
# Mapping from gast.FunctionDef node to the maximum number of variables
# returned by the function's return statement
self.max_return_length = {}
self.visit(self.root) self.visit(self.root)
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
self.function_def.append(node) self.function_def.append(node)
self.count_return[node] = 0 self.count_return[node] = 0
self.max_return_length[node] = 0
self.generic_visit(node) self.generic_visit(node)
self.function_def.pop() self.function_def.pop()
return node return node
...@@ -62,13 +115,21 @@ class ReturnPreAnalysisVisitor(gast.NodeVisitor): ...@@ -62,13 +115,21 @@ class ReturnPreAnalysisVisitor(gast.NodeVisitor):
self.count_return[cur_func] += 1 self.count_return[cur_func] += 1
else: else:
self.count_return[cur_func] = 1 self.count_return[cur_func] = 1
return_length = get_return_size(node)
if cur_func in self.max_return_length:
self.max_return_length[cur_func] = max(
self.max_return_length[cur_func], return_length)
else:
self.max_return_length[cur_func] = return_length
self.generic_visit(node) self.generic_visit(node)
def get_func_return_count(self, func_node): def get_func_return_count(self, func_node):
return self.count_return[func_node] return self.count_return[func_node]
def set_func_return_count(self, func_node, count): def get_func_max_return_length(self, func_node):
self.count_return[func_node] = count return self.max_return_length[func_node]
class ReturnTransformer(gast.NodeTransformer): class ReturnTransformer(gast.NodeTransformer):
...@@ -83,17 +144,25 @@ class ReturnTransformer(gast.NodeTransformer): ...@@ -83,17 +144,25 @@ class ReturnTransformer(gast.NodeTransformer):
def __init__(self, wrapper_root): def __init__(self, wrapper_root):
self.wrapper_root = wrapper_root self.wrapper_root = wrapper_root
self.root = wrapper_root.node self.root = wrapper_root.node
self.ancestor_nodes = []
pre_transformer = ReplaceReturnNoneTransformer(self.root)
pre_transformer.transform()
self.ancestor_nodes = []
# The name of the variable which stores the final return value # The name of the variable which stores the final return value
# Mapping from FunctionDef node to string # Mapping from FunctionDef node to string
self.return_value_name = {} self.return_value_name = {}
# The names of the variable which stores the boolean state that skip # The names of the variable which stores the boolean state that skip
# statments. Mapping from FunctionDef node to list # statments. Mapping from FunctionDef node to list
self.return_name = {} self.return_name = {}
# The names of the variable which is placeholder to handle various-
# length return. Mapping from FunctionDef node to list
self.return_no_value_name = {}
# A list of FunctionDef to store where the current function is. # A list of FunctionDef to store where the current function is.
self.function_def = [] self.function_def = []
self.pre_analysis = None
def transform(self): def transform(self):
self.visit(self.root) self.visit(self.root)
...@@ -125,13 +194,19 @@ class ReturnTransformer(gast.NodeTransformer): ...@@ -125,13 +194,19 @@ class ReturnTransformer(gast.NodeTransformer):
self.function_def.append(node) self.function_def.append(node)
self.return_value_name[node] = None self.return_value_name[node] = None
self.return_name[node] = [] self.return_name[node] = []
self.return_no_value_name[node] = []
pre_analysis = ReturnPreAnalysisVisitor(node) self.pre_analysis = ReturnAnalysisVisitor(node)
while pre_analysis.get_func_return_count(node) > 1: max_return_length = self.pre_analysis.get_func_max_return_length(node)
while self.pre_analysis.get_func_return_count(node) > 1:
self.generic_visit(node) self.generic_visit(node)
pre_analysis = ReturnPreAnalysisVisitor(node) self.pre_analysis = ReturnAnalysisVisitor(node)
# prepend initialization of final return and append final return statement if max_return_length == 0:
self.function_def.pop()
return node
# Prepend initialization of final return and append final return statement
value_name = self.return_value_name[node] value_name = self.return_value_name[node]
if value_name is not None: if value_name is not None:
node.body.append( node.body.append(
...@@ -140,12 +215,51 @@ class ReturnTransformer(gast.NodeTransformer): ...@@ -140,12 +215,51 @@ class ReturnTransformer(gast.NodeTransformer):
ctx=gast.Load(), ctx=gast.Load(),
annotation=None, annotation=None,
type_comment=None))) type_comment=None)))
assign_zero_node = create_fill_constant_node(value_name, 0.0) init_names = [
node.body.insert(0, assign_zero_node) unique_name.generate(RETURN_VALUE_INIT_NAME)
for i in range(max_return_length)
]
assign_zero_nodes = [
create_fill_constant_node(iname, 0.0) for iname in init_names
]
if len(init_names) == 1:
return_value_nodes = gast.Name(
id=init_names[0],
ctx=gast.Load(),
annotation=None,
type_comment=None)
else:
# We need to initialize return value as a tuple because control
# flow requires some inputs or outputs have same structure
return_value_nodes = gast.Tuple(
elts=[
gast.Name(
id=iname,
ctx=gast.Load(),
annotation=None,
type_comment=None) for iname in init_names
],
ctx=gast.Load())
assign_return_value_node = gast.Assign(
targets=[
gast.Name(
id=value_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=return_value_nodes)
node.body.insert(0, assign_return_value_node)
node.body[:0] = assign_zero_nodes
# Prepend control flow boolean nodes such as '__return@1 = False' # Prepend control flow boolean nodes such as '__return@1 = False'
for name in self.return_name[node]: for name in self.return_name[node]:
assign_false_node = create_fill_constant_node(name, False) assign_false_node = create_fill_constant_node(name, False)
node.body.insert(0, assign_false_node) node.body.insert(0, assign_false_node)
# Prepend no value placeholders
for name in self.return_no_value_name[node]:
assign_no_value_node = create_fill_constant_node(
name, RETURN_NO_VALUE_MAGIC_NUM)
node.body.insert(0, assign_no_value_node)
self.function_def.pop() self.function_def.pop()
return node return node
...@@ -154,21 +268,24 @@ class ReturnTransformer(gast.NodeTransformer): ...@@ -154,21 +268,24 @@ class ReturnTransformer(gast.NodeTransformer):
cur_func_node = self.function_def[-1] cur_func_node = self.function_def[-1]
return_name = unique_name.generate(RETURN_PREFIX) return_name = unique_name.generate(RETURN_PREFIX)
self.return_name[cur_func_node].append(return_name) self.return_name[cur_func_node].append(return_name)
max_return_length = self.pre_analysis.get_func_max_return_length(
cur_func_node)
for ancestor_index in reversed(range(len(self.ancestor_nodes) - 1)): for ancestor_index in reversed(range(len(self.ancestor_nodes) - 1)):
ancestor = self.ancestor_nodes[ancestor_index] ancestor = self.ancestor_nodes[ancestor_index]
cur_node = self.ancestor_nodes[ancestor_index + 1] cur_node = self.ancestor_nodes[ancestor_index + 1]
if hasattr(ancestor, if hasattr(ancestor,
"body") and index_in_list(ancestor.body, cur_node) != -1: "body") and index_in_list(ancestor.body, cur_node) != -1:
if cur_node == node: if cur_node == node:
self._replace_return_in_stmt_list(ancestor.body, cur_node, self._replace_return_in_stmt_list(
return_name) ancestor.body, cur_node, return_name, max_return_length)
self._replace_after_node_to_if_in_stmt_list( self._replace_after_node_to_if_in_stmt_list(
ancestor.body, cur_node, return_name) ancestor.body, cur_node, return_name)
elif hasattr(ancestor, "orelse") and index_in_list(ancestor.orelse, elif hasattr(ancestor, "orelse") and index_in_list(ancestor.orelse,
cur_node) != -1: cur_node) != -1:
if cur_node == node: if cur_node == node:
self._replace_return_in_stmt_list(ancestor.orelse, cur_node, self._replace_return_in_stmt_list(ancestor.orelse, cur_node,
return_name) return_name,
max_return_length)
self._replace_after_node_to_if_in_stmt_list( self._replace_after_node_to_if_in_stmt_list(
ancestor.orelse, cur_node, return_name) ancestor.orelse, cur_node, return_name)
...@@ -203,16 +320,81 @@ class ReturnTransformer(gast.NodeTransformer): ...@@ -203,16 +320,81 @@ class ReturnTransformer(gast.NodeTransformer):
break break
# return_node is replaced so we shouldn't return here # return_node is replaced so we shouldn't return here
def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name): def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name,
max_return_length):
assert max_return_length >= 0, "Input illegal max_return_length"
i = index_in_list(stmt_list, return_node) i = index_in_list(stmt_list, return_node)
if i == -1: if i == -1:
return False return False
assign_nodes = [create_fill_constant_node(return_name, True)] assign_nodes = [create_fill_constant_node(return_name, True)]
if return_node.value is not None:
cur_func_node = self.function_def[-1] cur_func_node = self.function_def[-1]
return_length = get_return_size(return_node)
if return_length < max_return_length:
# In this case we should append RETURN_NO_VALUE placeholder
#
# max_return_length must be >= 1 here because return_length will be
# 0 at least.
if self.return_value_name[cur_func_node] is None: if self.return_value_name[cur_func_node] is None:
self.return_value_name[cur_func_node] = unique_name.generate( self.return_value_name[cur_func_node] = unique_name.generate(
RETURN_VALUE_PREFIX) RETURN_VALUE_PREFIX)
no_value_names = [
unique_name.generate(RETURN_NO_VALUE_VAR_NAME)
for j in range(max_return_length - return_length)
]
self.return_no_value_name[cur_func_node].extend(no_value_names)
# Handle tuple/non-tuple case
if max_return_length == 1:
assign_nodes.append(
gast.Assign(
targets=[
gast.Name(
id=self.return_value_name[cur_func_node],
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=gast.Name(
id=no_value_names[0],
ctx=gast.Load(),
annotation=None,
type_comment=None)))
else:
# max_return_length > 1 which means we should assign tuple
fill_tuple = [
gast.Name(
id=n,
ctx=gast.Load(),
annotation=None,
type_comment=None) for n in no_value_names
]
if return_node.value is not None:
if isinstance(return_node.value, gast.Tuple):
fill_tuple[:0] = return_node.value.elts
else:
fill_tuple.insert(0, return_node.value)
assign_nodes.append(
gast.Assign(
targets=[
gast.Name(
id=self.return_value_name[cur_func_node],
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=gast.Tuple(
elts=fill_tuple, ctx=gast.Load())))
else:
# In this case we should NOT append RETURN_NO_VALUE placeholder
if return_node.value is not None:
cur_func_node = self.function_def[-1]
if self.return_value_name[cur_func_node] is None:
self.return_value_name[
cur_func_node] = unique_name.generate(
RETURN_VALUE_PREFIX)
assign_nodes.append( assign_nodes.append(
gast.Assign( gast.Assign(
targets=[ targets=[
...@@ -223,6 +405,7 @@ class ReturnTransformer(gast.NodeTransformer): ...@@ -223,6 +405,7 @@ class ReturnTransformer(gast.NodeTransformer):
type_comment=None) type_comment=None)
], ],
value=return_node.value)) value=return_node.value))
stmt_list[i:] = assign_nodes stmt_list[i:] = assign_nodes
return True return True
......
...@@ -67,8 +67,9 @@ class StaticCode1(): ...@@ -67,8 +67,9 @@ class StaticCode1():
shape=[1], dtype='bool', value=False) shape=[1], dtype='bool', value=False)
__return_0 = fluid.layers.fill_constant( __return_0 = fluid.layers.fill_constant(
shape=[1], dtype='bool', value=False) shape=[1], dtype='bool', value=False)
__return_value_0 = fluid.layers.fill_constant( __return_value_init_0 = fluid.layers.fill_constant(
shape=[1], dtype='float64', value=0.0) shape=[1], dtype='float64', value=0.0)
__return_value_0 = __return_value_init_0
def true_fn_0(x_v): def true_fn_0(x_v):
x_v = x_v - 1 x_v = x_v - 1
...@@ -123,8 +124,9 @@ class StaticCode2(): ...@@ -123,8 +124,9 @@ class StaticCode2():
shape=[1], dtype='bool', value=False) shape=[1], dtype='bool', value=False)
__return_2 = fluid.layers.fill_constant( __return_2 = fluid.layers.fill_constant(
shape=[1], dtype='bool', value=False) shape=[1], dtype='bool', value=False)
__return_value_1 = fluid.layers.fill_constant( __return_value_init_1 = fluid.layers.fill_constant(
shape=[1], dtype='float64', value=0.0) shape=[1], dtype='float64', value=0.0)
__return_value_1 = __return_value_init_1
def true_fn_3(x_v): def true_fn_3(x_v):
x_v = x_v - 1 x_v = x_v - 1
......
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.dygraph import declarative from paddle.fluid.dygraph import declarative
from paddle.fluid.dygraph import ProgramTranslator from paddle.fluid.dygraph import ProgramTranslator
...@@ -96,6 +97,56 @@ def test_recursive_return(x): ...@@ -96,6 +97,56 @@ def test_recursive_return(x):
return dyfunc_with_if_else(x) return dyfunc_with_if_else(x)
@declarative
def test_return_different_length_if_body(x):
x = fluid.dygraph.to_variable(x)
y = x + 1
if x > 0:
# x = to_variable(np.ones(1)) so it will return here
return x, y
else:
return x
@declarative
def test_return_different_length_else(x):
x = fluid.dygraph.to_variable(x)
y = x + 1
if x < 0:
return x, y
else:
# x = to_variable(np.ones(1)) so it will return here
return x
@declarative
def test_no_return(x):
x = fluid.dygraph.to_variable(x)
y = x + 1
@declarative
def test_return_none(x):
x = fluid.dygraph.to_variable(x)
y = x + 1
if x > 0:
# x = to_variable(np.ones(1)) so it will return here
return None
else:
return x, y
@declarative
def test_return_no_variable(x):
x = fluid.dygraph.to_variable(x)
y = x + 1
if x < 0:
return x, y
else:
# x = to_variable(np.ones(1)) so it will return here
return
class TestReturnBase(unittest.TestCase): class TestReturnBase(unittest.TestCase):
def setUp(self): def setUp(self):
self.input = np.ones((1)).astype('int32') self.input = np.ones((1)).astype('int32')
...@@ -111,21 +162,41 @@ class TestReturnBase(unittest.TestCase): ...@@ -111,21 +162,41 @@ class TestReturnBase(unittest.TestCase):
self.program_translator.enable(False) self.program_translator.enable(False)
with fluid.dygraph.guard(): with fluid.dygraph.guard():
res = self.dygraph_func(self.input) res = self.dygraph_func(self.input)
if isinstance(res, (tuple)):
return tuple(r.numpy() for r in res)
elif isinstance(res, core.VarBase):
return res.numpy() return res.numpy()
return res
def run_static_mode(self): def run_static_mode(self):
self.program_translator.enable(True) self.program_translator.enable(True)
with fluid.dygraph.guard(): with fluid.dygraph.guard():
res = self.dygraph_func(self.input) res = self.dygraph_func(self.input)
if isinstance(res, tuple):
return tuple(r.numpy() for r in res)
elif isinstance(res, core.VarBase):
return res.numpy() return res.numpy()
return res
def test_transformed_static_result(self): def test_transformed_static_result(self):
static_res = self.run_static_mode()
dygraph_res = self.run_dygraph_mode() dygraph_res = self.run_dygraph_mode()
static_res = self.run_static_mode()
if isinstance(dygraph_res, tuple):
self.assertTrue(isinstance(static_res, tuple))
self.assertEqual(len(dygraph_res), len(static_res))
for i in range(len(dygraph_res)):
self.assertTrue(
np.allclose(dygraph_res[i], static_res[i]),
msg='dygraph res is {}\nstatic_res is {}'.format(
dygraph_res[i], static_res[i]))
elif isinstance(dygraph_res, np.ndarray):
self.assertTrue( self.assertTrue(
np.allclose(dygraph_res, static_res), np.allclose(dygraph_res, static_res),
msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res, msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res,
static_res)) static_res))
else:
self.assertEqual(dygraph_res, static_res)
class TestInsideFuncBase(TestReturnBase): class TestInsideFuncBase(TestReturnBase):
...@@ -159,5 +230,30 @@ class TestRecursiveReturn(TestReturnBase): ...@@ -159,5 +230,30 @@ class TestRecursiveReturn(TestReturnBase):
self.dygraph_func = test_recursive_return self.dygraph_func = test_recursive_return
class TestReturnDifferentLengthIfBody(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_different_length_if_body
class TestReturnDifferentLengthElse(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_different_length_else
class TestNoReturn(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_no_return
class TestReturnNone(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_none
class TestReturnNoVariable(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_no_variable
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册