diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py index 446132fc0b4c52a212590ed7db77a956087ffb2e..af10c65400ee2c90c8281faf5b06cc2d8a367626 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py @@ -19,9 +19,10 @@ import logging from paddle.fluid import log_helper from paddle.fluid import framework, backward, core from paddle.fluid.dygraph import layers +from paddle.fluid.dygraph.base import switch_to_static_graph +from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM from paddle.fluid.layers.utils import flatten from paddle.fluid.layers.utils import pack_sequence_as -from paddle.fluid.dygraph.base import switch_to_static_graph import paddle.compat as cpt _logger = log_helper.get_logger( @@ -184,7 +185,8 @@ class PartialProgramLayer(layers.Layer): 'is_test': not self.training }) - return self._restore_out(out_vars) + restored_nest_out = self._restore_out(out_vars) + return self._remove_no_value(restored_nest_out) def _prepare(self, inputs): """ @@ -239,11 +241,44 @@ class PartialProgramLayer(layers.Layer): for i, idx in enumerate(self._outputs.var_ids): flatten_outputs[idx] = out_vars[i] outs = self._outputs.restore(flatten_outputs) - if len(outs) == 1: + if outs is not None and len(outs) == 1: outs = outs[0] return outs + def _is_no_value(self, var): + if isinstance(var, core.VarBase): + if var.shape == [1] and var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM: + return True + return False + + def _remove_no_value(self, out_vars): + """ + Removes invalid value for various-length return statement + """ + if isinstance(out_vars, core.VarBase): + if self._is_no_value(out_vars): + return None + return out_vars + elif isinstance(out_vars, (tuple, list)): + if isinstance(out_vars, tuple): + res = tuple( + var for var in out_vars if not self._is_no_value(var)) + else: + # isinstance(out_vars, list) + res = [var for var in out_vars if not self._is_no_value(var)] + + has_removed = (len(out_vars) > len(res)) + # len(out_vars) > len(res) means we have removed var. This is + # preventing out_vars is empty or just one element at the beginning + if len(res) == 0 and has_removed: + return None + elif len(res) == 1 and has_removed: + return res[0] + return res + + return out_vars + def _set_grad_type(self, params): # NOTE: if user set sparse gradient mode, the param's gradient # will be SelectedRows, not LoDTensor. But tracer will just diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index 57358818eae151ae8dce4eab76caf05228d7b493..5f1ed75735197606bc01ca31d27de5a128feca93 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -278,8 +278,9 @@ class ConcreteProgram(object): with param_guard(func_spec.parameters(False)), param_guard( func_spec.buffers(False)): outputs = static_func(*inputs) - if not isinstance(outputs, (tuple, list)): - outputs = [outputs] if outputs else [] + if not isinstance(outputs, + (tuple, list)) and outputs is not None: + outputs = [outputs] return ConcreteProgram( inputs=inputs, diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py index 5125a190fb9341d130d8ebf92a87e18de2e300ea..ef03e63dbbbb6253e6a4a337f5d1375476165a38 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/return_transformer.py @@ -21,7 +21,9 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import ForToWhileTransformer from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node -__all__ = ['ReturnTransformer'] +__all__ = [ + 'RETURN_NO_VALUE_MAGIC_NUM', 'RETURN_NO_VALUE_VAR_NAME', 'ReturnTransformer' +] # Constant for the name of the variable which stores the boolean state that we # should return @@ -30,10 +32,56 @@ RETURN_PREFIX = '__return' # Constant for the name of the variable which stores the final return value RETURN_VALUE_PREFIX = '__return_value' +# Constant for the name of variables to initialize the __return_value +RETURN_VALUE_INIT_NAME = '__return_value_init' + +# Constant magic number representing returning no value. This constant amis to +# support returning various lengths of variables. Static graph must have fixed +# size of fetched output while dygraph can have flexible lengths of output, to +# solve it in dy2stat, we put float64 value with this magic number at Static +# graph as a place holder to indicate the returning placeholder means no value +# should return. +RETURN_NO_VALUE_MAGIC_NUM = 1.77113e+279 +RETURN_NO_VALUE_VAR_NAME = "__no_value_return_var" + + +def get_return_size(return_node): + assert isinstance(return_node, gast.Return), "Input is not gast.Return node" + return_length = 0 + if return_node.value is not None: + if isinstance(return_node.value, gast.Tuple): + return_length = len(return_node.value.elts) + else: + return_length = 1 + return return_length + -class ReturnPreAnalysisVisitor(gast.NodeVisitor): +class ReplaceReturnNoneTransformer(gast.NodeTransformer): """ - Visits gast Tree and pre-analyze the information about 'return'. + Replace 'return None' to 'return' because 'None' cannot be a valid input + in control flow. In ReturnTransformer single 'Return' will be appended no + value placeholder + """ + + def __init__(self, root_node): + self.root = root_node + + def transform(self): + self.visit(self.root) + + def visit_Return(self, node): + if isinstance(node.value, gast.Name) and node.value.id == 'None': + node.value = None + return node + if isinstance(node.value, gast.Constant) and node.value.value == None: + node.value = None + return node + return node + + +class ReturnAnalysisVisitor(gast.NodeVisitor): + """ + Visits gast Tree and analyze the information about 'return'. """ def __init__(self, root_node): @@ -45,11 +93,16 @@ class ReturnPreAnalysisVisitor(gast.NodeVisitor): # Mapping from gast.FunctionDef node to the number of return statements # Python allows define function inside function so we have to handle it self.count_return = {} + + # Mapping from gast.FunctionDef node to the maximum number of variables + # returned by the function's return statement + self.max_return_length = {} self.visit(self.root) def visit_FunctionDef(self, node): self.function_def.append(node) self.count_return[node] = 0 + self.max_return_length[node] = 0 self.generic_visit(node) self.function_def.pop() return node @@ -62,13 +115,21 @@ class ReturnPreAnalysisVisitor(gast.NodeVisitor): self.count_return[cur_func] += 1 else: self.count_return[cur_func] = 1 + + return_length = get_return_size(node) + if cur_func in self.max_return_length: + self.max_return_length[cur_func] = max( + self.max_return_length[cur_func], return_length) + else: + self.max_return_length[cur_func] = return_length + self.generic_visit(node) def get_func_return_count(self, func_node): return self.count_return[func_node] - def set_func_return_count(self, func_node, count): - self.count_return[func_node] = count + def get_func_max_return_length(self, func_node): + return self.max_return_length[func_node] class ReturnTransformer(gast.NodeTransformer): @@ -83,17 +144,25 @@ class ReturnTransformer(gast.NodeTransformer): def __init__(self, wrapper_root): self.wrapper_root = wrapper_root self.root = wrapper_root.node - self.ancestor_nodes = [] + pre_transformer = ReplaceReturnNoneTransformer(self.root) + pre_transformer.transform() + + self.ancestor_nodes = [] # The name of the variable which stores the final return value # Mapping from FunctionDef node to string self.return_value_name = {} # The names of the variable which stores the boolean state that skip # statments. Mapping from FunctionDef node to list self.return_name = {} + # The names of the variable which is placeholder to handle various- + # length return. Mapping from FunctionDef node to list + self.return_no_value_name = {} # A list of FunctionDef to store where the current function is. self.function_def = [] + self.pre_analysis = None + def transform(self): self.visit(self.root) @@ -125,13 +194,19 @@ class ReturnTransformer(gast.NodeTransformer): self.function_def.append(node) self.return_value_name[node] = None self.return_name[node] = [] + self.return_no_value_name[node] = [] - pre_analysis = ReturnPreAnalysisVisitor(node) - while pre_analysis.get_func_return_count(node) > 1: + self.pre_analysis = ReturnAnalysisVisitor(node) + max_return_length = self.pre_analysis.get_func_max_return_length(node) + while self.pre_analysis.get_func_return_count(node) > 1: self.generic_visit(node) - pre_analysis = ReturnPreAnalysisVisitor(node) + self.pre_analysis = ReturnAnalysisVisitor(node) + + if max_return_length == 0: + self.function_def.pop() + return node - # prepend initialization of final return and append final return statement + # Prepend initialization of final return and append final return statement value_name = self.return_value_name[node] if value_name is not None: node.body.append( @@ -140,12 +215,51 @@ class ReturnTransformer(gast.NodeTransformer): ctx=gast.Load(), annotation=None, type_comment=None))) - assign_zero_node = create_fill_constant_node(value_name, 0.0) - node.body.insert(0, assign_zero_node) + init_names = [ + unique_name.generate(RETURN_VALUE_INIT_NAME) + for i in range(max_return_length) + ] + assign_zero_nodes = [ + create_fill_constant_node(iname, 0.0) for iname in init_names + ] + if len(init_names) == 1: + return_value_nodes = gast.Name( + id=init_names[0], + ctx=gast.Load(), + annotation=None, + type_comment=None) + else: + # We need to initialize return value as a tuple because control + # flow requires some inputs or outputs have same structure + return_value_nodes = gast.Tuple( + elts=[ + gast.Name( + id=iname, + ctx=gast.Load(), + annotation=None, + type_comment=None) for iname in init_names + ], + ctx=gast.Load()) + assign_return_value_node = gast.Assign( + targets=[ + gast.Name( + id=value_name, + ctx=gast.Store(), + annotation=None, + type_comment=None) + ], + value=return_value_nodes) + node.body.insert(0, assign_return_value_node) + node.body[:0] = assign_zero_nodes # Prepend control flow boolean nodes such as '__return@1 = False' for name in self.return_name[node]: assign_false_node = create_fill_constant_node(name, False) node.body.insert(0, assign_false_node) + # Prepend no value placeholders + for name in self.return_no_value_name[node]: + assign_no_value_node = create_fill_constant_node( + name, RETURN_NO_VALUE_MAGIC_NUM) + node.body.insert(0, assign_no_value_node) self.function_def.pop() return node @@ -154,21 +268,24 @@ class ReturnTransformer(gast.NodeTransformer): cur_func_node = self.function_def[-1] return_name = unique_name.generate(RETURN_PREFIX) self.return_name[cur_func_node].append(return_name) + max_return_length = self.pre_analysis.get_func_max_return_length( + cur_func_node) for ancestor_index in reversed(range(len(self.ancestor_nodes) - 1)): ancestor = self.ancestor_nodes[ancestor_index] cur_node = self.ancestor_nodes[ancestor_index + 1] if hasattr(ancestor, "body") and index_in_list(ancestor.body, cur_node) != -1: if cur_node == node: - self._replace_return_in_stmt_list(ancestor.body, cur_node, - return_name) + self._replace_return_in_stmt_list( + ancestor.body, cur_node, return_name, max_return_length) self._replace_after_node_to_if_in_stmt_list( ancestor.body, cur_node, return_name) elif hasattr(ancestor, "orelse") and index_in_list(ancestor.orelse, cur_node) != -1: if cur_node == node: self._replace_return_in_stmt_list(ancestor.orelse, cur_node, - return_name) + return_name, + max_return_length) self._replace_after_node_to_if_in_stmt_list( ancestor.orelse, cur_node, return_name) @@ -203,26 +320,92 @@ class ReturnTransformer(gast.NodeTransformer): break # return_node is replaced so we shouldn't return here - def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name): + def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name, + max_return_length): + assert max_return_length >= 0, "Input illegal max_return_length" i = index_in_list(stmt_list, return_node) if i == -1: return False assign_nodes = [create_fill_constant_node(return_name, True)] - if return_node.value is not None: - cur_func_node = self.function_def[-1] + cur_func_node = self.function_def[-1] + return_length = get_return_size(return_node) + if return_length < max_return_length: + # In this case we should append RETURN_NO_VALUE placeholder + # + # max_return_length must be >= 1 here because return_length will be + # 0 at least. if self.return_value_name[cur_func_node] is None: self.return_value_name[cur_func_node] = unique_name.generate( RETURN_VALUE_PREFIX) - assign_nodes.append( - gast.Assign( - targets=[ - gast.Name( - id=self.return_value_name[cur_func_node], - ctx=gast.Store(), + + no_value_names = [ + unique_name.generate(RETURN_NO_VALUE_VAR_NAME) + for j in range(max_return_length - return_length) + ] + self.return_no_value_name[cur_func_node].extend(no_value_names) + + # Handle tuple/non-tuple case + if max_return_length == 1: + assign_nodes.append( + gast.Assign( + targets=[ + gast.Name( + id=self.return_value_name[cur_func_node], + ctx=gast.Store(), + annotation=None, + type_comment=None) + ], + value=gast.Name( + id=no_value_names[0], + ctx=gast.Load(), annotation=None, - type_comment=None) - ], - value=return_node.value)) + type_comment=None))) + else: + # max_return_length > 1 which means we should assign tuple + fill_tuple = [ + gast.Name( + id=n, + ctx=gast.Load(), + annotation=None, + type_comment=None) for n in no_value_names + ] + if return_node.value is not None: + if isinstance(return_node.value, gast.Tuple): + fill_tuple[:0] = return_node.value.elts + else: + fill_tuple.insert(0, return_node.value) + + assign_nodes.append( + gast.Assign( + targets=[ + gast.Name( + id=self.return_value_name[cur_func_node], + ctx=gast.Store(), + annotation=None, + type_comment=None) + ], + value=gast.Tuple( + elts=fill_tuple, ctx=gast.Load()))) + else: + # In this case we should NOT append RETURN_NO_VALUE placeholder + if return_node.value is not None: + cur_func_node = self.function_def[-1] + if self.return_value_name[cur_func_node] is None: + self.return_value_name[ + cur_func_node] = unique_name.generate( + RETURN_VALUE_PREFIX) + + assign_nodes.append( + gast.Assign( + targets=[ + gast.Name( + id=self.return_value_name[cur_func_node], + ctx=gast.Store(), + annotation=None, + type_comment=None) + ], + value=return_node.value)) + stmt_list[i:] = assign_nodes return True diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py index 1cba65bc8e48fd44b41a72486b78e17a456d7f1a..873d9ecb53549e9d6a3982ca4528e63526bd3a0d 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py @@ -67,8 +67,9 @@ class StaticCode1(): shape=[1], dtype='bool', value=False) __return_0 = fluid.layers.fill_constant( shape=[1], dtype='bool', value=False) - __return_value_0 = fluid.layers.fill_constant( + __return_value_init_0 = fluid.layers.fill_constant( shape=[1], dtype='float64', value=0.0) + __return_value_0 = __return_value_init_0 def true_fn_0(x_v): x_v = x_v - 1 @@ -123,8 +124,9 @@ class StaticCode2(): shape=[1], dtype='bool', value=False) __return_2 = fluid.layers.fill_constant( shape=[1], dtype='bool', value=False) - __return_value_1 = fluid.layers.fill_constant( + __return_value_init_1 = fluid.layers.fill_constant( shape=[1], dtype='float64', value=0.0) + __return_value_1 = __return_value_init_1 def true_fn_3(x_v): x_v = x_v - 1 diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py index 59ea532832e8637b469841b823302cd76f037d5a..1f4f82146645ded6f345abc7d17b1724d9c3a8b9 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import numpy as np import paddle.fluid as fluid +import paddle.fluid.core as core from paddle.fluid.dygraph import declarative from paddle.fluid.dygraph import ProgramTranslator @@ -96,6 +97,56 @@ def test_recursive_return(x): return dyfunc_with_if_else(x) +@declarative +def test_return_different_length_if_body(x): + x = fluid.dygraph.to_variable(x) + y = x + 1 + if x > 0: + # x = to_variable(np.ones(1)) so it will return here + return x, y + else: + return x + + +@declarative +def test_return_different_length_else(x): + x = fluid.dygraph.to_variable(x) + y = x + 1 + if x < 0: + return x, y + else: + # x = to_variable(np.ones(1)) so it will return here + return x + + +@declarative +def test_no_return(x): + x = fluid.dygraph.to_variable(x) + y = x + 1 + + +@declarative +def test_return_none(x): + x = fluid.dygraph.to_variable(x) + y = x + 1 + if x > 0: + # x = to_variable(np.ones(1)) so it will return here + return None + else: + return x, y + + +@declarative +def test_return_no_variable(x): + x = fluid.dygraph.to_variable(x) + y = x + 1 + if x < 0: + return x, y + else: + # x = to_variable(np.ones(1)) so it will return here + return + + class TestReturnBase(unittest.TestCase): def setUp(self): self.input = np.ones((1)).astype('int32') @@ -111,21 +162,41 @@ class TestReturnBase(unittest.TestCase): self.program_translator.enable(False) with fluid.dygraph.guard(): res = self.dygraph_func(self.input) - return res.numpy() + if isinstance(res, (tuple)): + return tuple(r.numpy() for r in res) + elif isinstance(res, core.VarBase): + return res.numpy() + return res def run_static_mode(self): self.program_translator.enable(True) with fluid.dygraph.guard(): res = self.dygraph_func(self.input) - return res.numpy() + if isinstance(res, tuple): + return tuple(r.numpy() for r in res) + elif isinstance(res, core.VarBase): + return res.numpy() + return res def test_transformed_static_result(self): - static_res = self.run_static_mode() dygraph_res = self.run_dygraph_mode() - self.assertTrue( - np.allclose(dygraph_res, static_res), - msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res, - static_res)) + static_res = self.run_static_mode() + if isinstance(dygraph_res, tuple): + self.assertTrue(isinstance(static_res, tuple)) + self.assertEqual(len(dygraph_res), len(static_res)) + for i in range(len(dygraph_res)): + self.assertTrue( + np.allclose(dygraph_res[i], static_res[i]), + msg='dygraph res is {}\nstatic_res is {}'.format( + dygraph_res[i], static_res[i])) + + elif isinstance(dygraph_res, np.ndarray): + self.assertTrue( + np.allclose(dygraph_res, static_res), + msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res, + static_res)) + else: + self.assertEqual(dygraph_res, static_res) class TestInsideFuncBase(TestReturnBase): @@ -159,5 +230,30 @@ class TestRecursiveReturn(TestReturnBase): self.dygraph_func = test_recursive_return +class TestReturnDifferentLengthIfBody(TestReturnBase): + def init_dygraph_func(self): + self.dygraph_func = test_return_different_length_if_body + + +class TestReturnDifferentLengthElse(TestReturnBase): + def init_dygraph_func(self): + self.dygraph_func = test_return_different_length_else + + +class TestNoReturn(TestReturnBase): + def init_dygraph_func(self): + self.dygraph_func = test_no_return + + +class TestReturnNone(TestReturnBase): + def init_dygraph_func(self): + self.dygraph_func = test_return_none + + +class TestReturnNoVariable(TestReturnBase): + def init_dygraph_func(self): + self.dygraph_func = test_return_no_variable + + if __name__ == '__main__': unittest.main()