From b96df362f8dd00ae7152d9dbf94c4da63c7d0202 Mon Sep 17 00:00:00 2001 From: guohongzilong <2713219276@qq.com> Date: Wed, 29 Apr 2020 17:55:12 +0800 Subject: [PATCH] add parser of case which parameter in tuple in run_op function --- mindspore/ops/primitive.py | 4 ++++ tests/ut/python/pynative_mode/ops/test_multitype.py | 11 +++++++++++ 2 files changed, 15 insertions(+) diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index 24c81003b..78e8778c5 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -329,6 +329,10 @@ def _run_op(obj, op_name, args): if hasattr(arg, '__parameter__'): op_inputs.append(arg.default_input) op_mask[i] = 1 + elif isinstance(arg, tuple): + convert = lambda x: x.default_input if hasattr(x, '__parameter__') else x + args_ = tuple(convert(x) for x in arg) + op_inputs.append(args_) else: op_inputs.append(arg) output = real_run_op(obj, op_name, tuple(op_inputs), tuple(op_mask)) diff --git a/tests/ut/python/pynative_mode/ops/test_multitype.py b/tests/ut/python/pynative_mode/ops/test_multitype.py index 0073041b9..58fd31256 100644 --- a/tests/ut/python/pynative_mode/ops/test_multitype.py +++ b/tests/ut/python/pynative_mode/ops/test_multitype.py @@ -16,6 +16,7 @@ import numpy as np from mindspore.common.api import ms_function +from mindspore.common.parameter import Parameter from mindspore.ops import Primitive from mindspore.ops import composite as C from mindspore.ops import operations as P @@ -24,6 +25,7 @@ from ...ut_filter import non_graph_engine tensor_add = P.TensorAdd() +op_add = P.AddN() scala_add = Primitive('scalar_add') add = C.MultitypeFuncGraph('add') @@ -50,5 +52,14 @@ def test_multitype_tensor(): mainf(tensor1, tensor2) +@non_graph_engine +def test_multitype_tuple(): + tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')) + params1 = Parameter(tensor1, name="params1") + tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')) + output = op_add((params1, tensor2)) + assert output == Tensor(np.array([[2.4, 4.2], [4.4, 6.4]]).astype('float32')) + + def test_multitype_scalar(): mainf(1, 2) -- GitLab