提交 e92d1135 编写于 作者: Z zhouwei25 提交者: Tao Luo

fix bug that tuple(Variable) is converted to list(Variable) uncorrectly (#21687)

上级 a5159d84
...@@ -1814,8 +1814,8 @@ class Operator(object): ...@@ -1814,8 +1814,8 @@ class Operator(object):
"The type of '%s' in operator %s should be " "The type of '%s' in operator %s should be "
"one of [basestring(), str, Varibale] in python2, " "one of [basestring(), str, Varibale] in python2, "
"or one of [str, bytes, Variable] in python3." "or one of [str, bytes, Variable] in python3."
"but received : " % (in_proto.name, type), "but received : %s" %
arg) (in_proto.name, type, arg))
self.desc.set_input(in_proto.name, in_arg_names) self.desc.set_input(in_proto.name, in_arg_names)
else: else:
self.desc.set_input(in_proto.name, []) self.desc.set_input(in_proto.name, [])
......
...@@ -12280,16 +12280,18 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): ...@@ -12280,16 +12280,18 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
x = [] x = []
elif isinstance(x, Variable): elif isinstance(x, Variable):
x = [x] x = [x]
elif not isinstance(x, (list, tuple)): elif isinstance(x, tuple):
x = list(x)
elif not isinstance(x, (list, tuple, Variable)):
raise TypeError('Input must be Variable/list(Variable)/tuple(Variable)') raise TypeError('Input must be Variable/list(Variable)/tuple(Variable)')
if out is None: if out is None:
out_list = [] out_list = []
elif isinstance(out, Variable): elif isinstance(out, Variable):
out_list = [out] out_list = [out]
elif isinstance(out, (list, tuple)): elif isinstance(out, tuple):
out_list = out out_list = list(out)
else: elif not isinstance(x, (list, tuple, Variable)):
raise TypeError( raise TypeError(
'Output must be Variable/list(Variable)/tuple(Variable)') 'Output must be Variable/list(Variable)/tuple(Variable)')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册