diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index 24c81003bd252a2d0fd0f4c69d43d71d2a90c9ea..78e8778c52e03645072d80f83c3ebddb22ec65a4 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -329,6 +329,10 @@ def _run_op(obj, op_name, args): if hasattr(arg, '__parameter__'): op_inputs.append(arg.default_input) op_mask[i] = 1 + elif isinstance(arg, tuple): + convert = lambda x: x.default_input if hasattr(x, '__parameter__') else x + args_ = tuple(convert(x) for x in arg) + op_inputs.append(args_) else: op_inputs.append(arg) output = real_run_op(obj, op_name, tuple(op_inputs), tuple(op_mask)) diff --git a/tests/ut/python/pynative_mode/ops/test_multitype.py b/tests/ut/python/pynative_mode/ops/test_multitype.py index 0073041b96b033a01a99945f52299aef45077472..58fd31256d0b098fd88dd7598c4e4da1ec98cb7d 100644 --- a/tests/ut/python/pynative_mode/ops/test_multitype.py +++ b/tests/ut/python/pynative_mode/ops/test_multitype.py @@ -16,6 +16,7 @@ import numpy as np from mindspore.common.api import ms_function +from mindspore.common.parameter import Parameter from mindspore.ops import Primitive from mindspore.ops import composite as C from mindspore.ops import operations as P @@ -24,6 +25,7 @@ from ...ut_filter import non_graph_engine tensor_add = P.TensorAdd() +op_add = P.AddN() scala_add = Primitive('scalar_add') add = C.MultitypeFuncGraph('add') @@ -50,5 +52,14 @@ def test_multitype_tensor(): mainf(tensor1, tensor2) +@non_graph_engine +def test_multitype_tuple(): + tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')) + params1 = Parameter(tensor1, name="params1") + tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')) + output = op_add((params1, tensor2)) + assert output == Tensor(np.array([[2.4, 4.2], [4.4, 6.4]]).astype('float32')) + + def test_multitype_scalar(): mainf(1, 2)