diff --git a/mindspore/ccsrc/pynative/base.h b/mindspore/ccsrc/pynative/base.h index 7405f621cb080c1f8cf48606400889ad85b41232..d8675adc9ce9d6cc0eca870da6f67b9fa509e693 100644 --- a/mindspore/ccsrc/pynative/base.h +++ b/mindspore/ccsrc/pynative/base.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -59,7 +60,7 @@ struct OpExecInfo { using OpExecInfoPtr = std::shared_ptr; OpExecInfoPtr GenerateOpExecInfo(const py::args& args); -const std::unordered_set ignore_infer_prim = {"partial"}; +const std::set ignore_infer_prim = {"partial", "make_ref"}; } // namespace pynative } // namespace mindspore diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 98665dd27aa5855eebe06089a8ac6fe2a41b1099..a3df6b7fbab8ad29e2e00cbd643352b704de71b7 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -24,7 +24,7 @@ from ..._checkparam import Rel from ...common import dtype as mstype from ...common.tensor import Tensor from .._utils import _get_broadcast_shape -from ..primitive import PrimitiveWithInfer, prim_attr_register +from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op def _infer_shape_reduce(x, axis, keep_dims, prim_name): @@ -225,6 +225,11 @@ class _Reduce(PrimitiveWithInfer): validator.check_value_type('keep_dims', keep_dims, [bool], self.name) self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y']) + def __call__(self, x, axis=()): + args = [x, axis] + output = _run_op(self, self.name, args) + return output + def do_infer(self, input_x, axis, valid_dtype=mstype.number_type): axis_v = axis['value'] input_shp = input_x['shape']