未验证 提交 d8600777 编写于 作者: Z zhouwei25 提交者: GitHub

fix py_func bug when out is list and add unittest case (#22595)

上级 d9f0c9f5
......@@ -12542,7 +12542,9 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
out_list = [out]
elif isinstance(out, tuple):
out_list = list(out)
elif not isinstance(x, (list, tuple, Variable)):
elif isinstance(out, list):
out_list = out
else:
raise TypeError(
'Output must be Variable/list(Variable)/tuple(Variable)')
......
......@@ -34,6 +34,10 @@ def dummy_func_with_no_output(x):
pass
def dummy_func_with_multi_input_output(x, y):
return np.array(x), np.array(y)
def tanh(x):
return np.tanh(x)
......@@ -109,6 +113,24 @@ def simple_fc_net(img, label, use_py_func_op):
loss += dummy_var
fluid.layers.py_func(func=dummy_func_with_no_output, x=loss, out=None)
loss_out = fluid.default_main_program().current_block().create_var(
dtype='float32', shape=[-1, 1])
dummy_var_out = fluid.default_main_program().current_block().create_var(
dtype='float32', shape=[1])
fluid.layers.py_func(
func=dummy_func_with_multi_input_output,
x=(loss, dummy_var),
out=(loss_out, dummy_var_out))
assert loss == loss_out and dummy_var == dummy_var_out, \
"py_func failed with multi input and output"
fluid.layers.py_func(
func=dummy_func_with_multi_input_output,
x=[loss, dummy_var],
out=[loss_out, dummy_var_out])
assert loss == loss_out and dummy_var == dummy_var_out, \
"py_func failed with multi input and output"
loss = fluid.layers.mean(loss)
return loss
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册