From d86007775f479358765a0437846ac77384e54bae Mon Sep 17 00:00:00 2001 From: zhouwei25 <52485244+zhouwei25@users.noreply.github.com> Date: Mon, 17 Feb 2020 19:06:25 +0800 Subject: [PATCH] fix py_func bug when out is list and add unittest case (#22595) --- python/paddle/fluid/layers/nn.py | 4 +++- .../fluid/tests/unittests/test_py_func_op.py | 22 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 1f024fc1099..e68166f4148 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -12542,7 +12542,9 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): out_list = [out] elif isinstance(out, tuple): out_list = list(out) - elif not isinstance(x, (list, tuple, Variable)): + elif isinstance(out, list): + out_list = out + else: raise TypeError( 'Output must be Variable/list(Variable)/tuple(Variable)') diff --git a/python/paddle/fluid/tests/unittests/test_py_func_op.py b/python/paddle/fluid/tests/unittests/test_py_func_op.py index b7bff4eae23..6045f2d7136 100644 --- a/python/paddle/fluid/tests/unittests/test_py_func_op.py +++ b/python/paddle/fluid/tests/unittests/test_py_func_op.py @@ -34,6 +34,10 @@ def dummy_func_with_no_output(x): pass +def dummy_func_with_multi_input_output(x, y): + return np.array(x), np.array(y) + + def tanh(x): return np.tanh(x) @@ -109,6 +113,24 @@ def simple_fc_net(img, label, use_py_func_op): loss += dummy_var fluid.layers.py_func(func=dummy_func_with_no_output, x=loss, out=None) + loss_out = fluid.default_main_program().current_block().create_var( + dtype='float32', shape=[-1, 1]) + dummy_var_out = fluid.default_main_program().current_block().create_var( + dtype='float32', shape=[1]) + fluid.layers.py_func( + func=dummy_func_with_multi_input_output, + x=(loss, dummy_var), + out=(loss_out, dummy_var_out)) + assert loss == loss_out and dummy_var == dummy_var_out, \ + "py_func failed with multi input and output" + + fluid.layers.py_func( + func=dummy_func_with_multi_input_output, + x=[loss, dummy_var], + out=[loss_out, dummy_var_out]) + assert loss == loss_out and dummy_var == dummy_var_out, \ + "py_func failed with multi input and output" + loss = fluid.layers.mean(loss) return loss -- GitLab