提交 b96df362 编写于 作者: G guohongzilong

add parser of case which parameter in tuple in run_op function

上级 286c1d77
......@@ -329,6 +329,10 @@ def _run_op(obj, op_name, args):
if hasattr(arg, '__parameter__'):
op_inputs.append(arg.default_input)
op_mask[i] = 1
elif isinstance(arg, tuple):
convert = lambda x: x.default_input if hasattr(x, '__parameter__') else x
args_ = tuple(convert(x) for x in arg)
op_inputs.append(args_)
else:
op_inputs.append(arg)
output = real_run_op(obj, op_name, tuple(op_inputs), tuple(op_mask))
......
......@@ -16,6 +16,7 @@
import numpy as np
from mindspore.common.api import ms_function
from mindspore.common.parameter import Parameter
from mindspore.ops import Primitive
from mindspore.ops import composite as C
from mindspore.ops import operations as P
......@@ -24,6 +25,7 @@ from ...ut_filter import non_graph_engine
tensor_add = P.TensorAdd()
op_add = P.AddN()
scala_add = Primitive('scalar_add')
add = C.MultitypeFuncGraph('add')
......@@ -50,5 +52,14 @@ def test_multitype_tensor():
mainf(tensor1, tensor2)
@non_graph_engine
def test_multitype_tuple():
tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
params1 = Parameter(tensor1, name="params1")
tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
output = op_add((params1, tensor2))
assert output == Tensor(np.array([[2.4, 4.2], [4.4, 6.4]]).astype('float32'))
def test_multitype_scalar():
mainf(1, 2)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册