提交 e92d1135 编写于 作者: Z zhouwei25 提交者: Tao Luo

fix bug that tuple(Variable) is converted to list(Variable) uncorrectly (#21687)

上级 a5159d84
......@@ -1814,8 +1814,8 @@ class Operator(object):
"The type of '%s' in operator %s should be "
"one of [basestring(), str, Varibale] in python2, "
"or one of [str, bytes, Variable] in python3."
"but received : " % (in_proto.name, type),
arg)
"but received : %s" %
(in_proto.name, type, arg))
self.desc.set_input(in_proto.name, in_arg_names)
else:
self.desc.set_input(in_proto.name, [])
......
......@@ -12280,16 +12280,18 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
x = []
elif isinstance(x, Variable):
x = [x]
elif not isinstance(x, (list, tuple)):
elif isinstance(x, tuple):
x = list(x)
elif not isinstance(x, (list, tuple, Variable)):
raise TypeError('Input must be Variable/list(Variable)/tuple(Variable)')
if out is None:
out_list = []
elif isinstance(out, Variable):
out_list = [out]
elif isinstance(out, (list, tuple)):
out_list = out
else:
elif isinstance(out, tuple):
out_list = list(out)
elif not isinstance(x, (list, tuple, Variable)):
raise TypeError(
'Output must be Variable/list(Variable)/tuple(Variable)')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册