提交 295038d3 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3324 add reduce_any op for vm

Merge pull request !3324 from fangzehua/reduce_any
......@@ -127,6 +127,7 @@ static std::map<string, string> tbe_func_adapter_map = {
{"apply_rms_prop", "apply_rms_prop_d"},
{"cum_prod", "cumprod_d"},
{"reduce_all", "reduce_all_d"},
{"reduce_any", "reduce_any_d"},
{"sparse_apply_adagrad", "sparse_apply_adagrad_d"},
{"unsorted_segment_min", "unsorted_segment_min_d"},
{"reduce_prod", "reduce_prod_d"},
......
......@@ -46,6 +46,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
Register(prim::kPrimCumSum->name(), {1});
Register(prim::kPrimCumProd->name(), {1});
Register(prim::kPrimReduceAll->name(), {1});
Register(prim::kPrimReduceAny->name(), {1});
Register(prim::kPrimUnsortedSegmentMin->name(), {2});
Register(kSparseGatherV2, {2});
Register(kUnsortedSegmentProdOpName, {2});
......
......@@ -34,6 +34,7 @@ inline const PrimitivePtr kPrimMinimumGrad = std::make_shared<Primitive>("Minimu
inline const PrimitivePtr kPrimReduceMean = std::make_shared<Primitive>("ReduceMean");
inline const PrimitivePtr kPrimReduceSum = std::make_shared<Primitive>("ReduceSum");
inline const PrimitivePtr kPrimReduceAll = std::make_shared<Primitive>("ReduceAll");
inline const PrimitivePtr kPrimReduceAny = std::make_shared<Primitive>("ReduceAny");
inline const PrimitivePtr kPrimReduceMax = std::make_shared<Primitive>("ReduceMax");
inline const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMin");
inline const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg");
......
......@@ -641,6 +641,16 @@ def get_bprop_reduceall(self):
return bprop
@bprop_getters.register(P.ReduceAny)
def get_bprop_reduceany(self):
"""Grad definition for `ReduceAny` operation."""
def bprop(x, axis, out, dout):
return zeros_like(x), zeros_like(axis)
return bprop
@bprop_getters.register(P.ReduceMax)
def get_bprop_reducemax(self):
"""Grad definition for `Max` operation."""
......
......@@ -246,6 +246,7 @@ from .bitwise_and import _bitwise_and_tbe
from .bitwise_or import _bitwise_or_tbe
from .bitwise_xor import _bitwise_xor_tbe
from .reduce_all import _reduce_all_tbe
from .reduce_any import _reduce_any_tbe
from .sparse_apply_adagrad import _sparse_apply_adagrad_tbe
from .unsorted_segment_min import _unsorted_segment_min_tbe
from .asin import _asin_tbe
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReduceAny op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
reduce_any_op_info = TBERegOp("ReduceAny") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("reduce_any_d.so") \
.compute_cost(10) \
.kernel_name("reduce_any_d") \
.partial_flag(True) \
.attr("axis", "required", "listInt", "all") \
.attr("keep_dims", "optional", "bool", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("reduce") \
.dtype_format(DataType.BOOL_None, DataType.BOOL_None) \
.get_op_info()
@op_info_register(reduce_any_op_info)
def _reduce_any_tbe():
"""ReduceAny TBE register"""
return
......@@ -44,7 +44,7 @@ from .inner_ops import ScalarCast
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr,
BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub,
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd,
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, ReduceAny,
Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil,
Acosh, Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
LogicalNot, LogicalOr, MatMul, Maximum,
......@@ -215,6 +215,7 @@ __all__ = [
'CTCLoss',
'RNNTLoss',
'ReduceAll',
'ReduceAny',
'ScalarToArray',
'ScalarToTensor',
'TupleToArray',
......
......@@ -405,6 +405,42 @@ class ReduceAll(_Reduce):
return self.do_infer(input_x, axis, (mstype.bool_,))
class ReduceAny(_Reduce):
"""
Reduce a dimension of a tensor by the "logical or" of all elements in the dimension.
The dtype of the tensor to be reduced is bool.
Args:
keep_dims (bool): If True, keep these reduced dimensions and the length is 1.
If False, don't keep these dimensions.
Default : False, don't keep these reduced dimensions.
Inputs:
- **input_x** (Tensor[bool]) - The input tensor.
- **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions.
Only constant value is allowed.
Outputs:
Tensor, the dtype is bool.
- If axis is (), and keep_dims is false,
the output is a 0-D tensor representing the "logical or" of of all elements in the input tensor.
- If axis is int, set as 2, and keep_dims is false,
and keep_dims is false, the shape of output is :math:`(x_1, x_3, ..., x_R)`.
- If axis is tuple(int), set as (2, 3), and keep_dims is false,
the shape of output is :math:`(x_1, x_4, ..., x_R)`.
Examples:
>>> input_x = Tensor(np.array([[True, False], [True, True]]))
>>> op = P.ReduceAny(keep_dims=True)
>>> output = op(input_x, 1)
"""
def __infer__(self, input_x, axis):
return self.do_infer(input_x, axis, (mstype.bool_,))
class ReduceMax(_Reduce):
"""
Reduce a dimension of a tensor by the maximum value in this dimension.
......
......@@ -1186,6 +1186,11 @@ test_case_math_ops = [
'desc_const': [1],
'desc_inputs': [Tensor(np.array([[True, False], [True, True]]))],
'desc_bprop': []}),
('ReduceAny', {
'block': P.ReduceAny(),
'desc_const': [1],
'desc_inputs': [Tensor(np.array([[True, False], [True, True]]))],
'desc_bprop': []}),
('BesselI0e', {
'block': P.BesselI0e(),
'desc_inputs': [[2, 3]],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册