From 08b09f64475c3a4876f1e53eb6574e1a9eff94da Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 20 Feb 2020 16:04:49 +0800 Subject: [PATCH] Support if/else in dygraph_to_static (#22540) * support nested if/else * support to derivate returns the parameter list automatically * polish tranform function of slice * fix modify x.numpy()[i] slice function * support to transform ast.node into callable function * fix get_name_ids bug and add more unittest test=develop * fix requirements.txt test=develop * remove useless import statement test=develop * Fixed version compatibility issues in param of function test=develop * use decorater to test ast_to_func test=develop * add textwrap.dedent for source_code test=develop * polish code comment * fix compatibility with python2 and python3 test=develop * fix gast version error test=develop * fix gast repo test=develop * polish transfer_from_node_type code test=develop * add nested_if_else unittest test=develop * split IfElseTransformer test=develop * specify gast version test=develop * fix ast_to_func root type test=develop --- .../dygraph_to_static/ast_transformer.py | 98 +++++- .../dygraph/dygraph_to_static/ast_utils.py | 329 ++++++++++++++++++ python/paddle/fluid/dygraph/jit.py | 13 +- .../fluid/tests/unittests/test_ast_util.py | 165 +++++++++ .../unittests/test_dygraph_to_static_basic.py | 100 ++++-- python/requirements.txt | 2 +- 6 files changed, 671 insertions(+), 36 deletions(-) create mode 100644 python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py create mode 100644 python/paddle/fluid/tests/unittests/test_ast_util.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index 258bca47ae9..d793dcecb8b 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,11 +15,78 @@ from __future__ import print_function import gast +# gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST). +# It provides a compatibility layer between the AST of various Python versions, +# as produced by ast.parse from the standard ast module. +# See details in https://github.com/serge-sans-paille/gast/ +from .ast_utils import is_control_flow_if, create_cond_node, transform_if_else from .static_analysis import AstNodeWrapper, StaticAnalysisVisitor __all__ = ['DygraphToStaticAst'] +DECORATOR_NAME = 'dygraph_to_static_output' + + +class IfElseTransformer(gast.NodeTransformer): + """ + Transform if/else statement of Dygraph into Static Graph. + """ + + def __init__(self, wrapper_root): + assert isinstance( + wrapper_root, AstNodeWrapper + ), "Type of input node should be AstNodeWrapper, but received %s ." % type( + wrapper_root) + self.wrapper_root = wrapper_root + self.root = wrapper_root.node + self.new_func_nodes = [] + + def ast_visit(self): + """ + Main function to transform AST. + """ + self.visit(self.root) + self.after_visit(self.root) + + def visit_If(self, node): + assert isinstance(node, gast.If) + self.generic_visit(node) + if is_control_flow_if(node.test): + pred_node = node.test + true_func_node, false_func_node, return_name_ids = transform_if_else( + node, self.root) + self.new_func_nodes += [true_func_node, false_func_node] + # create layers.cond + new_node = create_cond_node(return_name_ids, pred_node, + true_func_node, false_func_node) + return new_node + else: + return node + + def visit_Call(self, node): + # Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]` + # Todo: should be removed. it may be considered as basic api transformation. + if isinstance(node.func, gast.Attribute): + attribute = node.func + if attribute.attr == 'numpy': + node = attribute.value + return node + + def after_visit(self, node): + """ + This function will add some postprocessing operations with node. + It can be used to add the created `true_fn/false_fn` in front of + the node.body before they are called in cond layer. + """ + assert hasattr(node, 'body') + # add new ast.funcDef of `if/else` + if self.new_func_nodes: + node.body = self.new_func_nodes + node.body + + def get_new_func_nodes(self): + return self.new_func_nodes + class DygraphToStaticAst(gast.NodeTransformer): """ @@ -27,12 +94,37 @@ class DygraphToStaticAst(gast.NodeTransformer): """ def get_static_ast(self, root): - # save root for some analysis may need global AST + # save root for some analysis may need global AST self.root = root self.static_analysis_root = StaticAnalysisVisitor( root).get_node_wrapper_root() + self.decorate_func_name = None self.transfer_from_node_type(self.static_analysis_root) return self.static_analysis_root def transfer_from_node_type(self, node): - print("Not implemented") + # Generic transformation + self.visit(node.node) + # Transform all if/else statement of Dygraph into Static Graph. + IfElseTransformer(node).ast_visit() + + def visit_FunctionDef(self, node): + if self.decorate_func_name is None: + self.decorate_func_name = node.name + self.generic_visit(node) + # Remove the decorated name of dygraph_to_static + if hasattr(node, 'decorator_list'): + decorator_list = [ + d for d in node.decorator_list if d.id != DECORATOR_NAME + ] + node.decorator_list = decorator_list + return node + + def get_module_name(self): + """ + Return the main function name which will be used as module name + in ast_to_func. + """ + # Should consider BaseAPITransformer which add new module name in Yamei's PR. + assert self.decorate_func_name, "decorate_func_name shall not be None." + return self.decorate_func_name diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py new file mode 100644 index 00000000000..d472f64e6f2 --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py @@ -0,0 +1,329 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import astor +import ast +import gast +import six +import copy +import tempfile +import imp +import os +import atexit +from collections import defaultdict + +from paddle.fluid import unique_name + +TRUE_FUNC_PRFIX = 'true_fn' +FALSE_FUNC_PRFIX = 'false_fn' + + +def is_control_flow_if(node): + """ + Determine whether the node is a plain python `if statement` or + control flow in Paddle. + """ + return True + + +def get_name_ids(nodes, not_name_set=None, node_black_list=None): + """ + Return all ast.Name.id of python variable in nodes. + """ + if not isinstance(nodes, (list, tuple, set)): + raise ValueError( + "nodes must be one of list, tuple, set, but received %s" % + type(nodes)) + if not_name_set is None: + not_name_set = set() + + def update(old_dict, new_dict): + for k, v in new_dict.items(): + old_dict[k].extend(v) + + name_ids = defaultdict(list) + for node in nodes: + if node_black_list and node in node_black_list: continue + if isinstance(node, gast.AST): + # In two case, the ast.Name should be filtered. + # 1. Function name like `my_func` of my_func(x) + # 2. api prefix like `fluid` of `fluid.layers.mean` + if isinstance(node, gast.Return): + continue + elif isinstance(node, gast.Call) and isinstance(node.func, + gast.Name): + not_name_set.add(node.func.id) + elif isinstance(node, gast.Attribute) and isinstance(node.value, + gast.Name): + not_name_set.add(node.value.id) + if isinstance( + node, gast.Name + ) and node.id not in name_ids and node.id not in not_name_set: + if isinstance(node.ctx, (gast.Store, gast.Load, gast.Param)): + name_ids[node.id].append(node.ctx) + else: + if isinstance(node, gast.Assign): + node = copy.copy(node) + node._fields = ('value', 'targets') + for field, value in gast.iter_fields(node): + value = value if isinstance(value, list) else [value] + update(name_ids, + get_name_ids(value, not_name_set, node_black_list)) + return name_ids + + +def parse_cond_args(var_ids_dict, return_ids=None, ctx=gast.Load): + """ + Find out the ast.Name.id list of input by analyzing node's AST information. + """ + + name_ids = [ + var_id for var_id, var_ctx in var_ids_dict.items() + if isinstance(var_ctx[0], ctx) + ] + if return_ids: + new_args = set(return_ids) - set(name_ids) + name_ids.extend(list(new_args)) + name_ids.sort() + args = [ + gast.Name( + id=name_id, ctx=gast.Load(), annotation=None, type_comment=None) + for name_id in name_ids + ] + arguments = gast.arguments( + args=args, + posonlyargs=[], + vararg=None, + kwonlyargs=[], + kw_defaults=None, + kwarg=None, + defaults=[]) + return arguments + + +def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict): + """ + Find out the ast.Name list of output by analyzing node's AST information. + Following conditions should be satisfied while determining whether a variable is a return value: + 1. the var in parent scope is modified in if/else node. + 2. new var is both created in if and else node. + + If different var is modified in if and else node, it should add the var in return_ids + of different node. + For example: + x, y = 5, 10 + if x > 4: + x = x+1 + z = x*x + else: + y = y - 1 + z = y*y + + The return_ids should be (x, y, z) for `if` and `else`node. + """ + + def _is_return_var(ctxs): + for ctx in ctxs: + if isinstance(ctx, (gast.Store, gast.Param)): + return True + return False + + def _vars_with_store(ids_dict): + vars = [] + for k, ctxs in ids_dict.items(): + if _is_return_var(ctxs): + vars.append(k) + return vars + + def _candidate_vars(child_dict, parent_dict): + return set([ + var for var in _vars_with_store(child_dict) if var in parent_dict + ]) + + # 1. the var in parent_ids is modified in if/else node. + if_candidate_vars = _candidate_vars(if_vars_dict, parent_vars_dict) + else_candidate_vars = _candidate_vars(else_vars_dict, parent_vars_dict) + + # 2. new var is both created in if and else node. + if_new_vars = set([ + var for var in _vars_with_store(if_vars_dict) + if var not in parent_vars_dict + ]) + else_new_vars = set([ + var for var in _vars_with_store(else_vars_dict) + if var not in parent_vars_dict + ]) + new_vars = if_new_vars & else_new_vars + + # generate return_ids of if/else node. + modified_vars = if_candidate_vars | else_candidate_vars + return_ids = list(modified_vars | new_vars) + return_ids.sort() + + return return_ids, list(modified_vars - new_vars) + + +def generate_name_node(name_ids, ctx=gast.Load()): + """ + Generate list or gast.Tuple of ast.Name for Return statement. + """ + if isinstance(name_ids, six.string_types): + name_ids = [name_ids] + if not isinstance(name_ids, (list, tuple, set)): + raise TypeError('name_ids must be list or tuple or set, but received %s' + % type(type(name_ids))) + gast_names = [ + gast.Name( + id=name_id, ctx=ctx, annotation=None, type_comment=None) + for name_id in name_ids + ] + if len(gast_names) == 1: + name_node = gast_names[0] + else: + name_node = gast.Tuple(elts=gast_names, ctx=ctx) + return name_node + + +def create_funcDef_node(nodes, name, input_args, return_name_ids): + """ + Wrapper all statements of nodes into one ast.FunctionDef, which can be + called by ast.Call. + """ + nodes = copy.copy(nodes) + # add return statement + nodes.append(gast.Return(value=generate_name_node(return_name_ids))) + func_def_node = gast.FunctionDef( + name=name, + args=input_args, + body=nodes, + decorator_list=[], + returns=None, + type_comment=None) + return func_def_node + + +def transform_if_else(node, root): + """ + Transform ast.If into control flow statement of Paddle static graph. + """ + parent_name_ids = get_name_ids([root], node_black_list=[node]) + if_name_ids = get_name_ids(node.body) + else_name_ids = get_name_ids(node.orelse) + + return_name_ids, modified_name_ids = parse_cond_return( + parent_name_ids, if_name_ids, else_name_ids) + + true_func_node = create_funcDef_node( + node.body, + name=unique_name.generate(TRUE_FUNC_PRFIX), + input_args=parse_cond_args(if_name_ids, modified_name_ids), + return_name_ids=return_name_ids) + false_func_node = create_funcDef_node( + node.orelse, + name=unique_name.generate(FALSE_FUNC_PRFIX), + input_args=parse_cond_args(else_name_ids, modified_name_ids), + return_name_ids=return_name_ids) + + return true_func_node, false_func_node, return_name_ids + + +def create_cond_node(return_name_ids, pred, true_func, false_func): + """ + Create `fluid.layers.cond(pred, true_fn, false_fn)` to replace + original `python if/else` statement. + """ + # TODO(Aurelius84): should replace the api hard code. + cond_api = gast.parse('fluid.layers.cond').body[0].value + true_func_lambda = gast.Lambda( + args=gast.arguments( + args=[], + posonlyargs=[], + vararg=None, + kwonlyargs=[], + kw_defaults=None, + kwarg=None, + defaults=[]), + body=gast.Call( + func=gast.Name( + id=true_func.name, + ctx=gast.Load(), + annotation=None, + type_comment=None), + args=[true_func.args], + keywords=[])) + false_func_lambda = gast.Lambda( + args=gast.arguments( + args=[], + posonlyargs=[], + vararg=None, + kwonlyargs=[], + kw_defaults=None, + kwarg=None, + defaults=[]), + body=gast.Call( + func=gast.Name( + id=false_func.name, + ctx=gast.Load(), + annotation=None, + type_comment=None), + args=[false_func.args], + keywords=[])) + cond_layer = gast.Call( + func=cond_api, + args=[pred, true_func_lambda, false_func_lambda], + keywords=[]) + targets = [generate_name_node(return_name_ids, ctx=gast.Store())] + assign_node = gast.Assign(targets=targets, value=cond_layer) + + return assign_node + + +def ast_to_func(ast_root, func_name, delete_on_exit=True): + """ + Transform modified AST of decorated function into python callable object. + """ + if not isinstance(ast_root, (gast.AST, ast.AST)): + raise TypeError( + "Type of ast_root should be gast.AST or ast.AST, but received %s." % + type(ast_root)) + if isinstance(ast_root, gast.AST): + ast_root = gast.gast_to_ast(ast_root) + source = astor.to_source(ast_root) + if six.PY2: + source = source.encode('utf-8') + f = tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) + else: + f = tempfile.NamedTemporaryFile( + mode='w', suffix='.py', delete=False, encoding='utf-8') + + # TODO(Aurelius84): more elegent way to transform ast into callable object + import_str = "import paddle\n" \ + "import paddle.fluid as fluid\n" \ + "import paddle.fluid.layers as layers\n" + with f: + module_name = os.path.basename(f.name[:-3]) + f.write(import_str) + f.write(source) + + if delete_on_exit: + atexit.register(lambda: os.remove(f.name)) + module = imp.load_source(module_name, f.name) + if not hasattr(module, func_name): + raise ValueError( + 'Function: %s doesn\'t exist in the Module transformed from AST.' % + func_name) + + return getattr(module, func_name), f.name diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 223f9d3c557..43012c08af9 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -16,10 +16,12 @@ __all__ = ['TracedLayer', 'dygraph_to_static_output'] import gast import inspect +import textwrap from ..wrapped_decorator import wrap_decorator from .base import program_desc_tracing_guard, switch_to_static_graph from .dygraph_to_static import DygraphToStaticAst +from .dygraph_to_static.ast_utils import ast_to_func from .layers import Layer from paddle.fluid import core from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode @@ -54,14 +56,15 @@ def _dygraph_to_static_output_(dygraph_func): def __impl__(*args, **kwargs): # Get AST from dygraph function dygraph_code = inspect.getsource(dygraph_func) + dygraph_code = textwrap.dedent(dygraph_code) root = gast.parse(dygraph_code) + # Transform AST + dygraph_to_static = DygraphToStaticAst() + root_wrapper = dygraph_to_static.get_static_ast(root) + func_name = dygraph_to_static.get_module_name() - root = DygraphToStaticAst().get_static_ast(root) + static_func, file_name = ast_to_func(root_wrapper.node, func_name) - # TODO static_func should a callable from AST, like - # static_func = ast_to_func(root) - # currently just use dygraph_func - static_func = dygraph_func return static_func(*args, **kwargs) return __impl__ diff --git a/python/paddle/fluid/tests/unittests/test_ast_util.py b/python/paddle/fluid/tests/unittests/test_ast_util.py new file mode 100644 index 00000000000..9276d663b2f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_ast_util.py @@ -0,0 +1,165 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import textwrap +import gast +import inspect +import numpy as np +import paddle.fluid as fluid +from paddle.fluid.dygraph.dygraph_to_static.ast_utils import get_name_ids, ast_to_func + + +class TestGetNameIds(unittest.TestCase): + """ + Test for parsing the ast.Name list from the ast.Nodes + """ + + def setUp(self): + self.source = """ + def test_fn(x): + return x+1 + """ + self.all_name_ids = {'x': [gast.Param()]} + + def test_get_name_ids(self): + source = textwrap.dedent(self.source) + root = gast.parse(source) + all_name_ids = get_name_ids([root]) + self.assertDictEqual( + self.transfer_dict(self.all_name_ids), + self.transfer_dict(all_name_ids)) + + def transfer_dict(self, name_ids_dict): + new_dict = {} + for name, ctxs in name_ids_dict.items(): + new_dict[name] = [type(ctx) for ctx in ctxs] + return new_dict + + +class TestGetNameIds2(TestGetNameIds): + def setUp(self): + self.source = """ + def test_fn(x, y): + a = 1 + x = y + a + if x > y: + z = x * x + z = z + a + else: + z = y * y + return z + """ + self.all_name_ids = { + 'x': [ + gast.Param(), gast.Store(), gast.Load(), gast.Load(), + gast.Load() + ], + 'a': [gast.Store(), gast.Load(), gast.Load()], + 'y': + [gast.Param(), gast.Load(), gast.Load(), gast.Load(), gast.Load()], + 'z': [gast.Store(), gast.Load(), gast.Store(), gast.Store()] + } + + +class TestGetNameIds3(TestGetNameIds): + def setUp(self): + self.source = """ + def test_fn(x, y): + z = 1 + if x > y: + z = x * x + z = z + y + return z + """ + self.all_name_ids = { + 'x': [gast.Param(), gast.Load(), gast.Load(), gast.Load()], + 'y': [gast.Param(), gast.Load(), gast.Load()], + 'z': [gast.Store(), gast.Store(), gast.Load(), gast.Store()] + } + + +def dyfunc_with_if_else(x_v): + if fluid.layers.mean(x_v).numpy()[0] > 5: + x_v = x_v - 1 + else: + x_v = x_v + 1 + return x_v + + +def dyfunc_with_if_else2(x): + i, j = 0, 0 + if fluid.layers.reduce_mean(x).numpy()[0] > x.numpy()[i][j]: + y = fluid.layers.relu(x) + else: + x_pow = fluid.layers.pow(x, 2) + y = fluid.layers.tanh(x_pow) + return y + + +class TestAST2Func(unittest.TestCase): + """ + TestCase for the transformation from ast.AST into python callable function. + """ + + def _ast2func(self, func): + source = inspect.getsource(func) + source = textwrap.dedent(source) + ast_root = gast.parse(source) + transformed_func, _ = ast_to_func(ast_root, func.__name__) + return transformed_func + + def test_ast2func(self): + def func(x, y): + return x + y + + x, y = 10, 20 + self.assertEqual(func(x, y), self._ast2func(func)(x, y)) + + def test_ast2func_dygraph(self): + func = dyfunc_with_if_else + x_data = np.random.random([10, 16]).astype('float32') + with fluid.dygraph.guard(): + x_v = fluid.dygraph.to_variable(x_data) + true_ret = func(x_v).numpy() + test_ret = self._ast2func(func)(x_v).numpy() + self.assertTrue((true_ret == test_ret).all()) + + def test_ast2func_static(self): + def func(x): + y = fluid.layers.relu(x) + loss = fluid.layers.mean(y) + return loss + + x_data = np.random.random([10, 16]).astype('float32') + main_program = fluid.Program() + with fluid.program_guard(main_program): + x_v = fluid.layers.assign(x_data) + true_ret = func(x_v) + test_ret = self._ast2func(func)(x_v) + exe = fluid.Executor(fluid.CPUPlace()) + ret = exe.run(main_program, fetch_list=[true_ret, test_ret]) + self.assertTrue((ret[0] == ret[1]).all()) + + def test_ast2func_error(self): + with self.assertRaises(Exception) as e: + self.assertRaises(TypeError, ast_to_func("x = a + b", 'foo')) + self.assertTrue("Type of ast_root should be gast.AST or ast.AST" in + str(e.exception)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic.py b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic.py index 39c8fbe6fc5..1c5e298552d 100644 --- a/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic.py +++ b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic.py @@ -16,8 +16,6 @@ from __future__ import print_function import numpy as np import paddle.fluid as fluid -import paddle.fluid.layers as layers -import paddle.fluid.core as core import unittest from paddle.fluid.dygraph.jit import dygraph_to_static_output @@ -25,37 +23,85 @@ from paddle.fluid.dygraph.jit import dygraph_to_static_output np.random.seed(1) -def dyfunc(a, b): - with fluid.dygraph.guard(): - x = fluid.dygraph.to_variable(a) - y = fluid.dygraph.to_variable(b) - x.stop_gradient = False - y.stop_gradient = False +def dyfunc_with_if_else(x_v): + if fluid.layers.mean(x_v).numpy()[0] > 5: + x_v = x_v - 1 + else: + x_v = x_v + 1 + return x_v - inputs = {'X': [x], 'Y': [y]} - loss = core.ops.elementwise_mul(inputs)['Out'][0] - loss.backward() - x_grad = x.gradient() - y_grad = y.gradient() - return x_grad, y_grad +def dyfunc_with_if_else2(x): + i, j = 0, 0 + if fluid.layers.reduce_mean(x).numpy()[0] > x.numpy()[i][j]: + y = fluid.layers.relu(x) + else: + x_pow = fluid.layers.pow(x, 2) + y = fluid.layers.tanh(x_pow) + return y -@dygraph_to_static_output -def dyfunc_to_static(a, b): - return dyfunc(a, b) +def nested_if_else(x_v): + batch_size = x_v.shape[0] + feat_size = x_v.shape[-1] + bias = fluid.layers.fill_constant([feat_size], dtype='float32', value=1) + if fluid.layers.mean(x_v).numpy()[0] < 0: + y = x_v + bias + w = fluid.layers.fill_constant([feat_size], dtype='float32', value=10) + if y.numpy()[0] < 10: + tmp = y * w + y = fluid.layers.relu(tmp) + if fluid.layers.mean(y).numpy()[0] < batch_size: + y = fluid.layers.abs(y) + else: + tmp = fluid.layers.fill_constant( + [feat_size], dtype='float32', value=-1) + y = y - tmp + else: + y = x_v - bias + return y -class TestBasicModel(unittest.TestCase): - def test_dygraph_static_same_output(self): - a = np.random.uniform( - low=0.1, high=1, size=(3, 4, 5)).astype(np.float32) - b = np.random.uniform( - low=0.1, high=1, size=(3, 4, 5)).astype(np.float32) - dy_output = dyfunc(a, b) - static_output = dyfunc_to_static(a, b) - self.assertTrue(np.array_equal(dy_output[0], static_output[0])) - self.assertTrue(np.array_equal(dy_output[1], static_output[1])) +class TestDygraphIfElse(unittest.TestCase): + """ + TestCase for the transformation from control flow `if/else` + dependent on tensor in Dygraph into Static `fluid.layers.cond`. + """ + + def setUp(self): + self.x = np.random.random([10, 16]).astype('float32') + self.dyfunc = dyfunc_with_if_else + + def _run_static(self): + main_program = fluid.Program() + with fluid.program_guard(main_program): + x_v = fluid.layers.assign(self.x) + # Transform into static graph + out = dygraph_to_static_output(self.dyfunc)(x_v) + exe = fluid.Executor(fluid.CPUPlace()) + ret = exe.run(main_program, fetch_list=out) + return ret + + def _run_dygraph(self): + with fluid.dygraph.guard(): + x_v = fluid.dygraph.to_variable(self.x) + ret = self.dyfunc(x_v) + return ret.numpy() + + def test_ast_to_func(self): + self.assertTrue((self._run_dygraph() == self._run_static()).all()) + + +class TestDygraphIfElse2(TestDygraphIfElse): + def setUp(self): + self.x = np.random.random([10, 16]).astype('float32') + self.dyfunc = dyfunc_with_if_else2 + + +class TestDygraphIfElse3(TestDygraphIfElse): + def setUp(self): + self.x = np.random.random([10, 16]).astype('float32') + self.dyfunc = nested_if_else if __name__ == '__main__': diff --git a/python/requirements.txt b/python/requirements.txt index 6c82f5a1a62..29bf92607d5 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -2,6 +2,7 @@ requests>=2.20.0 numpy>=1.12, <=1.16.4 ; python_version<"3.5" numpy>=1.12 ; python_version>="3.5" protobuf>=3.1.0 +gast>=0.3.3 matplotlib<=2.2.4 ; python_version<"3.6" scipy>=0.19.0, <=1.2.1 ; python_version<"3.5" nltk>=3.2.2, <=3.4 ; python_version<"3.5" @@ -17,5 +18,4 @@ pyyaml decorator prettytable objgraph -gast astor -- GitLab