未验证 提交 08b09f64 编写于 作者: A Aurelius84 提交者: GitHub

Support if/else in dygraph_to_static (#22540)

* support nested if/else

* support to derivate returns the parameter list automatically

* polish tranform function of slice

* fix modify x.numpy()[i] slice function

* support to transform ast.node into callable function

* fix get_name_ids bug and add more unittest test=develop

* fix requirements.txt test=develop

* remove useless import statement test=develop

* Fixed version compatibility issues in param of function test=develop

* use decorater to test ast_to_func test=develop

* add textwrap.dedent for source_code test=develop

* polish code comment

* fix compatibility with python2 and python3 test=develop

* fix gast version error test=develop

* fix gast repo test=develop

* polish transfer_from_node_type code test=develop

* add nested_if_else unittest test=develop

* split IfElseTransformer test=develop

* specify gast version test=develop

* fix ast_to_func root type test=develop
上级 7a4c29e0
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -15,11 +15,78 @@ ...@@ -15,11 +15,78 @@
from __future__ import print_function from __future__ import print_function
import gast import gast
# gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST).
# It provides a compatibility layer between the AST of various Python versions,
# as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/
from .ast_utils import is_control_flow_if, create_cond_node, transform_if_else
from .static_analysis import AstNodeWrapper, StaticAnalysisVisitor from .static_analysis import AstNodeWrapper, StaticAnalysisVisitor
__all__ = ['DygraphToStaticAst'] __all__ = ['DygraphToStaticAst']
DECORATOR_NAME = 'dygraph_to_static_output'
class IfElseTransformer(gast.NodeTransformer):
"""
Transform if/else statement of Dygraph into Static Graph.
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Type of input node should be AstNodeWrapper, but received %s ." % type(
wrapper_root)
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.new_func_nodes = []
def ast_visit(self):
"""
Main function to transform AST.
"""
self.visit(self.root)
self.after_visit(self.root)
def visit_If(self, node):
assert isinstance(node, gast.If)
self.generic_visit(node)
if is_control_flow_if(node.test):
pred_node = node.test
true_func_node, false_func_node, return_name_ids = transform_if_else(
node, self.root)
self.new_func_nodes += [true_func_node, false_func_node]
# create layers.cond
new_node = create_cond_node(return_name_ids, pred_node,
true_func_node, false_func_node)
return new_node
else:
return node
def visit_Call(self, node):
# Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]`
# Todo: should be removed. it may be considered as basic api transformation.
if isinstance(node.func, gast.Attribute):
attribute = node.func
if attribute.attr == 'numpy':
node = attribute.value
return node
def after_visit(self, node):
"""
This function will add some postprocessing operations with node.
It can be used to add the created `true_fn/false_fn` in front of
the node.body before they are called in cond layer.
"""
assert hasattr(node, 'body')
# add new ast.funcDef of `if/else`
if self.new_func_nodes:
node.body = self.new_func_nodes + node.body
def get_new_func_nodes(self):
return self.new_func_nodes
class DygraphToStaticAst(gast.NodeTransformer): class DygraphToStaticAst(gast.NodeTransformer):
""" """
...@@ -31,8 +98,33 @@ class DygraphToStaticAst(gast.NodeTransformer): ...@@ -31,8 +98,33 @@ class DygraphToStaticAst(gast.NodeTransformer):
self.root = root self.root = root
self.static_analysis_root = StaticAnalysisVisitor( self.static_analysis_root = StaticAnalysisVisitor(
root).get_node_wrapper_root() root).get_node_wrapper_root()
self.decorate_func_name = None
self.transfer_from_node_type(self.static_analysis_root) self.transfer_from_node_type(self.static_analysis_root)
return self.static_analysis_root return self.static_analysis_root
def transfer_from_node_type(self, node): def transfer_from_node_type(self, node):
print("Not implemented") # Generic transformation
self.visit(node.node)
# Transform all if/else statement of Dygraph into Static Graph.
IfElseTransformer(node).ast_visit()
def visit_FunctionDef(self, node):
if self.decorate_func_name is None:
self.decorate_func_name = node.name
self.generic_visit(node)
# Remove the decorated name of dygraph_to_static
if hasattr(node, 'decorator_list'):
decorator_list = [
d for d in node.decorator_list if d.id != DECORATOR_NAME
]
node.decorator_list = decorator_list
return node
def get_module_name(self):
"""
Return the main function name which will be used as module name
in ast_to_func.
"""
# Should consider BaseAPITransformer which add new module name in Yamei's PR.
assert self.decorate_func_name, "decorate_func_name shall not be None."
return self.decorate_func_name
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import astor
import ast
import gast
import six
import copy
import tempfile
import imp
import os
import atexit
from collections import defaultdict
from paddle.fluid import unique_name
TRUE_FUNC_PRFIX = 'true_fn'
FALSE_FUNC_PRFIX = 'false_fn'
def is_control_flow_if(node):
"""
Determine whether the node is a plain python `if statement` or
control flow in Paddle.
"""
return True
def get_name_ids(nodes, not_name_set=None, node_black_list=None):
"""
Return all ast.Name.id of python variable in nodes.
"""
if not isinstance(nodes, (list, tuple, set)):
raise ValueError(
"nodes must be one of list, tuple, set, but received %s" %
type(nodes))
if not_name_set is None:
not_name_set = set()
def update(old_dict, new_dict):
for k, v in new_dict.items():
old_dict[k].extend(v)
name_ids = defaultdict(list)
for node in nodes:
if node_black_list and node in node_black_list: continue
if isinstance(node, gast.AST):
# In two case, the ast.Name should be filtered.
# 1. Function name like `my_func` of my_func(x)
# 2. api prefix like `fluid` of `fluid.layers.mean`
if isinstance(node, gast.Return):
continue
elif isinstance(node, gast.Call) and isinstance(node.func,
gast.Name):
not_name_set.add(node.func.id)
elif isinstance(node, gast.Attribute) and isinstance(node.value,
gast.Name):
not_name_set.add(node.value.id)
if isinstance(
node, gast.Name
) and node.id not in name_ids and node.id not in not_name_set:
if isinstance(node.ctx, (gast.Store, gast.Load, gast.Param)):
name_ids[node.id].append(node.ctx)
else:
if isinstance(node, gast.Assign):
node = copy.copy(node)
node._fields = ('value', 'targets')
for field, value in gast.iter_fields(node):
value = value if isinstance(value, list) else [value]
update(name_ids,
get_name_ids(value, not_name_set, node_black_list))
return name_ids
def parse_cond_args(var_ids_dict, return_ids=None, ctx=gast.Load):
"""
Find out the ast.Name.id list of input by analyzing node's AST information.
"""
name_ids = [
var_id for var_id, var_ctx in var_ids_dict.items()
if isinstance(var_ctx[0], ctx)
]
if return_ids:
new_args = set(return_ids) - set(name_ids)
name_ids.extend(list(new_args))
name_ids.sort()
args = [
gast.Name(
id=name_id, ctx=gast.Load(), annotation=None, type_comment=None)
for name_id in name_ids
]
arguments = gast.arguments(
args=args,
posonlyargs=[],
vararg=None,
kwonlyargs=[],
kw_defaults=None,
kwarg=None,
defaults=[])
return arguments
def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict):
"""
Find out the ast.Name list of output by analyzing node's AST information.
Following conditions should be satisfied while determining whether a variable is a return value:
1. the var in parent scope is modified in if/else node.
2. new var is both created in if and else node.
If different var is modified in if and else node, it should add the var in return_ids
of different node.
For example:
x, y = 5, 10
if x > 4:
x = x+1
z = x*x
else:
y = y - 1
z = y*y
The return_ids should be (x, y, z) for `if` and `else`node.
"""
def _is_return_var(ctxs):
for ctx in ctxs:
if isinstance(ctx, (gast.Store, gast.Param)):
return True
return False
def _vars_with_store(ids_dict):
vars = []
for k, ctxs in ids_dict.items():
if _is_return_var(ctxs):
vars.append(k)
return vars
def _candidate_vars(child_dict, parent_dict):
return set([
var for var in _vars_with_store(child_dict) if var in parent_dict
])
# 1. the var in parent_ids is modified in if/else node.
if_candidate_vars = _candidate_vars(if_vars_dict, parent_vars_dict)
else_candidate_vars = _candidate_vars(else_vars_dict, parent_vars_dict)
# 2. new var is both created in if and else node.
if_new_vars = set([
var for var in _vars_with_store(if_vars_dict)
if var not in parent_vars_dict
])
else_new_vars = set([
var for var in _vars_with_store(else_vars_dict)
if var not in parent_vars_dict
])
new_vars = if_new_vars & else_new_vars
# generate return_ids of if/else node.
modified_vars = if_candidate_vars | else_candidate_vars
return_ids = list(modified_vars | new_vars)
return_ids.sort()
return return_ids, list(modified_vars - new_vars)
def generate_name_node(name_ids, ctx=gast.Load()):
"""
Generate list or gast.Tuple of ast.Name for Return statement.
"""
if isinstance(name_ids, six.string_types):
name_ids = [name_ids]
if not isinstance(name_ids, (list, tuple, set)):
raise TypeError('name_ids must be list or tuple or set, but received %s'
% type(type(name_ids)))
gast_names = [
gast.Name(
id=name_id, ctx=ctx, annotation=None, type_comment=None)
for name_id in name_ids
]
if len(gast_names) == 1:
name_node = gast_names[0]
else:
name_node = gast.Tuple(elts=gast_names, ctx=ctx)
return name_node
def create_funcDef_node(nodes, name, input_args, return_name_ids):
"""
Wrapper all statements of nodes into one ast.FunctionDef, which can be
called by ast.Call.
"""
nodes = copy.copy(nodes)
# add return statement
nodes.append(gast.Return(value=generate_name_node(return_name_ids)))
func_def_node = gast.FunctionDef(
name=name,
args=input_args,
body=nodes,
decorator_list=[],
returns=None,
type_comment=None)
return func_def_node
def transform_if_else(node, root):
"""
Transform ast.If into control flow statement of Paddle static graph.
"""
parent_name_ids = get_name_ids([root], node_black_list=[node])
if_name_ids = get_name_ids(node.body)
else_name_ids = get_name_ids(node.orelse)
return_name_ids, modified_name_ids = parse_cond_return(
parent_name_ids, if_name_ids, else_name_ids)
true_func_node = create_funcDef_node(
node.body,
name=unique_name.generate(TRUE_FUNC_PRFIX),
input_args=parse_cond_args(if_name_ids, modified_name_ids),
return_name_ids=return_name_ids)
false_func_node = create_funcDef_node(
node.orelse,
name=unique_name.generate(FALSE_FUNC_PRFIX),
input_args=parse_cond_args(else_name_ids, modified_name_ids),
return_name_ids=return_name_ids)
return true_func_node, false_func_node, return_name_ids
def create_cond_node(return_name_ids, pred, true_func, false_func):
"""
Create `fluid.layers.cond(pred, true_fn, false_fn)` to replace
original `python if/else` statement.
"""
# TODO(Aurelius84): should replace the api hard code.
cond_api = gast.parse('fluid.layers.cond').body[0].value
true_func_lambda = gast.Lambda(
args=gast.arguments(
args=[],
posonlyargs=[],
vararg=None,
kwonlyargs=[],
kw_defaults=None,
kwarg=None,
defaults=[]),
body=gast.Call(
func=gast.Name(
id=true_func.name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
args=[true_func.args],
keywords=[]))
false_func_lambda = gast.Lambda(
args=gast.arguments(
args=[],
posonlyargs=[],
vararg=None,
kwonlyargs=[],
kw_defaults=None,
kwarg=None,
defaults=[]),
body=gast.Call(
func=gast.Name(
id=false_func.name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
args=[false_func.args],
keywords=[]))
cond_layer = gast.Call(
func=cond_api,
args=[pred, true_func_lambda, false_func_lambda],
keywords=[])
targets = [generate_name_node(return_name_ids, ctx=gast.Store())]
assign_node = gast.Assign(targets=targets, value=cond_layer)
return assign_node
def ast_to_func(ast_root, func_name, delete_on_exit=True):
"""
Transform modified AST of decorated function into python callable object.
"""
if not isinstance(ast_root, (gast.AST, ast.AST)):
raise TypeError(
"Type of ast_root should be gast.AST or ast.AST, but received %s." %
type(ast_root))
if isinstance(ast_root, gast.AST):
ast_root = gast.gast_to_ast(ast_root)
source = astor.to_source(ast_root)
if six.PY2:
source = source.encode('utf-8')
f = tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False)
else:
f = tempfile.NamedTemporaryFile(
mode='w', suffix='.py', delete=False, encoding='utf-8')
# TODO(Aurelius84): more elegent way to transform ast into callable object
import_str = "import paddle\n" \
"import paddle.fluid as fluid\n" \
"import paddle.fluid.layers as layers\n"
with f:
module_name = os.path.basename(f.name[:-3])
f.write(import_str)
f.write(source)
if delete_on_exit:
atexit.register(lambda: os.remove(f.name))
module = imp.load_source(module_name, f.name)
if not hasattr(module, func_name):
raise ValueError(
'Function: %s doesn\'t exist in the Module transformed from AST.' %
func_name)
return getattr(module, func_name), f.name
...@@ -16,10 +16,12 @@ __all__ = ['TracedLayer', 'dygraph_to_static_output'] ...@@ -16,10 +16,12 @@ __all__ = ['TracedLayer', 'dygraph_to_static_output']
import gast import gast
import inspect import inspect
import textwrap
from ..wrapped_decorator import wrap_decorator from ..wrapped_decorator import wrap_decorator
from .base import program_desc_tracing_guard, switch_to_static_graph from .base import program_desc_tracing_guard, switch_to_static_graph
from .dygraph_to_static import DygraphToStaticAst from .dygraph_to_static import DygraphToStaticAst
from .dygraph_to_static.ast_utils import ast_to_func
from .layers import Layer from .layers import Layer
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode
...@@ -54,14 +56,15 @@ def _dygraph_to_static_output_(dygraph_func): ...@@ -54,14 +56,15 @@ def _dygraph_to_static_output_(dygraph_func):
def __impl__(*args, **kwargs): def __impl__(*args, **kwargs):
# Get AST from dygraph function # Get AST from dygraph function
dygraph_code = inspect.getsource(dygraph_func) dygraph_code = inspect.getsource(dygraph_func)
dygraph_code = textwrap.dedent(dygraph_code)
root = gast.parse(dygraph_code) root = gast.parse(dygraph_code)
# Transform AST
dygraph_to_static = DygraphToStaticAst()
root_wrapper = dygraph_to_static.get_static_ast(root)
func_name = dygraph_to_static.get_module_name()
root = DygraphToStaticAst().get_static_ast(root) static_func, file_name = ast_to_func(root_wrapper.node, func_name)
# TODO static_func should a callable from AST, like
# static_func = ast_to_func(root)
# currently just use dygraph_func
static_func = dygraph_func
return static_func(*args, **kwargs) return static_func(*args, **kwargs)
return __impl__ return __impl__
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import textwrap
import gast
import inspect
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static.ast_utils import get_name_ids, ast_to_func
class TestGetNameIds(unittest.TestCase):
"""
Test for parsing the ast.Name list from the ast.Nodes
"""
def setUp(self):
self.source = """
def test_fn(x):
return x+1
"""
self.all_name_ids = {'x': [gast.Param()]}
def test_get_name_ids(self):
source = textwrap.dedent(self.source)
root = gast.parse(source)
all_name_ids = get_name_ids([root])
self.assertDictEqual(
self.transfer_dict(self.all_name_ids),
self.transfer_dict(all_name_ids))
def transfer_dict(self, name_ids_dict):
new_dict = {}
for name, ctxs in name_ids_dict.items():
new_dict[name] = [type(ctx) for ctx in ctxs]
return new_dict
class TestGetNameIds2(TestGetNameIds):
def setUp(self):
self.source = """
def test_fn(x, y):
a = 1
x = y + a
if x > y:
z = x * x
z = z + a
else:
z = y * y
return z
"""
self.all_name_ids = {
'x': [
gast.Param(), gast.Store(), gast.Load(), gast.Load(),
gast.Load()
],
'a': [gast.Store(), gast.Load(), gast.Load()],
'y':
[gast.Param(), gast.Load(), gast.Load(), gast.Load(), gast.Load()],
'z': [gast.Store(), gast.Load(), gast.Store(), gast.Store()]
}
class TestGetNameIds3(TestGetNameIds):
def setUp(self):
self.source = """
def test_fn(x, y):
z = 1
if x > y:
z = x * x
z = z + y
return z
"""
self.all_name_ids = {
'x': [gast.Param(), gast.Load(), gast.Load(), gast.Load()],
'y': [gast.Param(), gast.Load(), gast.Load()],
'z': [gast.Store(), gast.Store(), gast.Load(), gast.Store()]
}
def dyfunc_with_if_else(x_v):
if fluid.layers.mean(x_v).numpy()[0] > 5:
x_v = x_v - 1
else:
x_v = x_v + 1
return x_v
def dyfunc_with_if_else2(x):
i, j = 0, 0
if fluid.layers.reduce_mean(x).numpy()[0] > x.numpy()[i][j]:
y = fluid.layers.relu(x)
else:
x_pow = fluid.layers.pow(x, 2)
y = fluid.layers.tanh(x_pow)
return y
class TestAST2Func(unittest.TestCase):
"""
TestCase for the transformation from ast.AST into python callable function.
"""
def _ast2func(self, func):
source = inspect.getsource(func)
source = textwrap.dedent(source)
ast_root = gast.parse(source)
transformed_func, _ = ast_to_func(ast_root, func.__name__)
return transformed_func
def test_ast2func(self):
def func(x, y):
return x + y
x, y = 10, 20
self.assertEqual(func(x, y), self._ast2func(func)(x, y))
def test_ast2func_dygraph(self):
func = dyfunc_with_if_else
x_data = np.random.random([10, 16]).astype('float32')
with fluid.dygraph.guard():
x_v = fluid.dygraph.to_variable(x_data)
true_ret = func(x_v).numpy()
test_ret = self._ast2func(func)(x_v).numpy()
self.assertTrue((true_ret == test_ret).all())
def test_ast2func_static(self):
def func(x):
y = fluid.layers.relu(x)
loss = fluid.layers.mean(y)
return loss
x_data = np.random.random([10, 16]).astype('float32')
main_program = fluid.Program()
with fluid.program_guard(main_program):
x_v = fluid.layers.assign(x_data)
true_ret = func(x_v)
test_ret = self._ast2func(func)(x_v)
exe = fluid.Executor(fluid.CPUPlace())
ret = exe.run(main_program, fetch_list=[true_ret, test_ret])
self.assertTrue((ret[0] == ret[1]).all())
def test_ast2func_error(self):
with self.assertRaises(Exception) as e:
self.assertRaises(TypeError, ast_to_func("x = a + b", 'foo'))
self.assertTrue("Type of ast_root should be gast.AST or ast.AST" in
str(e.exception))
if __name__ == '__main__':
unittest.main()
...@@ -16,8 +16,6 @@ from __future__ import print_function ...@@ -16,8 +16,6 @@ from __future__ import print_function
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
import unittest import unittest
from paddle.fluid.dygraph.jit import dygraph_to_static_output from paddle.fluid.dygraph.jit import dygraph_to_static_output
...@@ -25,37 +23,85 @@ from paddle.fluid.dygraph.jit import dygraph_to_static_output ...@@ -25,37 +23,85 @@ from paddle.fluid.dygraph.jit import dygraph_to_static_output
np.random.seed(1) np.random.seed(1)
def dyfunc(a, b): def dyfunc_with_if_else(x_v):
if fluid.layers.mean(x_v).numpy()[0] > 5:
x_v = x_v - 1
else:
x_v = x_v + 1
return x_v
def dyfunc_with_if_else2(x):
i, j = 0, 0
if fluid.layers.reduce_mean(x).numpy()[0] > x.numpy()[i][j]:
y = fluid.layers.relu(x)
else:
x_pow = fluid.layers.pow(x, 2)
y = fluid.layers.tanh(x_pow)
return y
def nested_if_else(x_v):
batch_size = x_v.shape[0]
feat_size = x_v.shape[-1]
bias = fluid.layers.fill_constant([feat_size], dtype='float32', value=1)
if fluid.layers.mean(x_v).numpy()[0] < 0:
y = x_v + bias
w = fluid.layers.fill_constant([feat_size], dtype='float32', value=10)
if y.numpy()[0] < 10:
tmp = y * w
y = fluid.layers.relu(tmp)
if fluid.layers.mean(y).numpy()[0] < batch_size:
y = fluid.layers.abs(y)
else:
tmp = fluid.layers.fill_constant(
[feat_size], dtype='float32', value=-1)
y = y - tmp
else:
y = x_v - bias
return y
class TestDygraphIfElse(unittest.TestCase):
"""
TestCase for the transformation from control flow `if/else`
dependent on tensor in Dygraph into Static `fluid.layers.cond`.
"""
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = dyfunc_with_if_else
def _run_static(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
x_v = fluid.layers.assign(self.x)
# Transform into static graph
out = dygraph_to_static_output(self.dyfunc)(x_v)
exe = fluid.Executor(fluid.CPUPlace())
ret = exe.run(main_program, fetch_list=out)
return ret
def _run_dygraph(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(a) x_v = fluid.dygraph.to_variable(self.x)
y = fluid.dygraph.to_variable(b) ret = self.dyfunc(x_v)
x.stop_gradient = False return ret.numpy()
y.stop_gradient = False
def test_ast_to_func(self):
inputs = {'X': [x], 'Y': [y]} self.assertTrue((self._run_dygraph() == self._run_static()).all())
loss = core.ops.elementwise_mul(inputs)['Out'][0]
loss.backward() class TestDygraphIfElse2(TestDygraphIfElse):
x_grad = x.gradient() def setUp(self):
y_grad = y.gradient() self.x = np.random.random([10, 16]).astype('float32')
return x_grad, y_grad self.dyfunc = dyfunc_with_if_else2
@dygraph_to_static_output class TestDygraphIfElse3(TestDygraphIfElse):
def dyfunc_to_static(a, b): def setUp(self):
return dyfunc(a, b) self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = nested_if_else
class TestBasicModel(unittest.TestCase):
def test_dygraph_static_same_output(self):
a = np.random.uniform(
low=0.1, high=1, size=(3, 4, 5)).astype(np.float32)
b = np.random.uniform(
low=0.1, high=1, size=(3, 4, 5)).astype(np.float32)
dy_output = dyfunc(a, b)
static_output = dyfunc_to_static(a, b)
self.assertTrue(np.array_equal(dy_output[0], static_output[0]))
self.assertTrue(np.array_equal(dy_output[1], static_output[1]))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -2,6 +2,7 @@ requests>=2.20.0 ...@@ -2,6 +2,7 @@ requests>=2.20.0
numpy>=1.12, <=1.16.4 ; python_version<"3.5" numpy>=1.12, <=1.16.4 ; python_version<"3.5"
numpy>=1.12 ; python_version>="3.5" numpy>=1.12 ; python_version>="3.5"
protobuf>=3.1.0 protobuf>=3.1.0
gast>=0.3.3
matplotlib<=2.2.4 ; python_version<"3.6" matplotlib<=2.2.4 ; python_version<"3.6"
scipy>=0.19.0, <=1.2.1 ; python_version<"3.5" scipy>=0.19.0, <=1.2.1 ; python_version<"3.5"
nltk>=3.2.2, <=3.4 ; python_version<"3.5" nltk>=3.2.2, <=3.4 ; python_version<"3.5"
...@@ -17,5 +18,4 @@ pyyaml ...@@ -17,5 +18,4 @@ pyyaml
decorator decorator
prettytable prettytable
objgraph objgraph
gast
astor astor
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册