提交 a55d8a14 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!470 pynative support reducemean

Merge pull request !470 from JoyLvliang/pynative-support-reducemean
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <utility> #include <utility>
#include <string> #include <string>
#include <memory> #include <memory>
#include <set>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
...@@ -59,7 +60,7 @@ struct OpExecInfo { ...@@ -59,7 +60,7 @@ struct OpExecInfo {
using OpExecInfoPtr = std::shared_ptr<OpExecInfo>; using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
OpExecInfoPtr GenerateOpExecInfo(const py::args& args); OpExecInfoPtr GenerateOpExecInfo(const py::args& args);
const std::unordered_set<std::string> ignore_infer_prim = {"partial"}; const std::set<std::string> ignore_infer_prim = {"partial", "make_ref"};
} // namespace pynative } // namespace pynative
} // namespace mindspore } // namespace mindspore
......
...@@ -24,7 +24,7 @@ from ..._checkparam import Rel ...@@ -24,7 +24,7 @@ from ..._checkparam import Rel
from ...common import dtype as mstype from ...common import dtype as mstype
from ...common.tensor import Tensor from ...common.tensor import Tensor
from .._utils import _get_broadcast_shape from .._utils import _get_broadcast_shape
from ..primitive import PrimitiveWithInfer, prim_attr_register from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op
def _infer_shape_reduce(x, axis, keep_dims, prim_name): def _infer_shape_reduce(x, axis, keep_dims, prim_name):
...@@ -225,6 +225,11 @@ class _Reduce(PrimitiveWithInfer): ...@@ -225,6 +225,11 @@ class _Reduce(PrimitiveWithInfer):
validator.check_value_type('keep_dims', keep_dims, [bool], self.name) validator.check_value_type('keep_dims', keep_dims, [bool], self.name)
self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y']) self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y'])
def __call__(self, x, axis=()):
args = [x, axis]
output = _run_op(self, self.name, args)
return output
def do_infer(self, input_x, axis, valid_dtype=mstype.number_type): def do_infer(self, input_x, axis, valid_dtype=mstype.number_type):
axis_v = axis['value'] axis_v = axis['value']
input_shp = input_x['shape'] input_shp = input_x['shape']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册