diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 1f024fc10993affb81cdea1ec560d8d2ecdf40dc..e68166f4148dfda2d69835cd3888c18734ceea5e 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -12542,7 +12542,9 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): out_list = [out] elif isinstance(out, tuple): out_list = list(out) - elif not isinstance(x, (list, tuple, Variable)): + elif isinstance(out, list): + out_list = out + else: raise TypeError( 'Output must be Variable/list(Variable)/tuple(Variable)') diff --git a/python/paddle/fluid/tests/unittests/test_py_func_op.py b/python/paddle/fluid/tests/unittests/test_py_func_op.py index b7bff4eae23e7b7b4e879bf6f25924c107b4ea02..6045f2d713627cedfe169b9e066222904244311a 100644 --- a/python/paddle/fluid/tests/unittests/test_py_func_op.py +++ b/python/paddle/fluid/tests/unittests/test_py_func_op.py @@ -34,6 +34,10 @@ def dummy_func_with_no_output(x): pass +def dummy_func_with_multi_input_output(x, y): + return np.array(x), np.array(y) + + def tanh(x): return np.tanh(x) @@ -109,6 +113,24 @@ def simple_fc_net(img, label, use_py_func_op): loss += dummy_var fluid.layers.py_func(func=dummy_func_with_no_output, x=loss, out=None) + loss_out = fluid.default_main_program().current_block().create_var( + dtype='float32', shape=[-1, 1]) + dummy_var_out = fluid.default_main_program().current_block().create_var( + dtype='float32', shape=[1]) + fluid.layers.py_func( + func=dummy_func_with_multi_input_output, + x=(loss, dummy_var), + out=(loss_out, dummy_var_out)) + assert loss == loss_out and dummy_var == dummy_var_out, \ + "py_func failed with multi input and output" + + fluid.layers.py_func( + func=dummy_func_with_multi_input_output, + x=[loss, dummy_var], + out=[loss_out, dummy_var_out]) + assert loss == loss_out and dummy_var == dummy_var_out, \ + "py_func failed with multi input and output" + loss = fluid.layers.mean(loss) return loss