提交 18c94950 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1228 Adapt tbe op UnsortedSegmentMin for GE.

Merge pull request !1228 from liuxiao/UnsortedSegmentMin
...@@ -138,6 +138,7 @@ const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size"); ...@@ -138,6 +138,7 @@ const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size");
const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax"); const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax");
const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack"); const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack");
const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum"); const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum");
const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin");
const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset"); const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset");
const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape"); const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape");
const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile");
......
...@@ -143,6 +143,7 @@ extern const PrimitivePtr kPrimSize; ...@@ -143,6 +143,7 @@ extern const PrimitivePtr kPrimSize;
extern const PrimitivePtr kPrimArgMax; extern const PrimitivePtr kPrimArgMax;
extern const PrimitivePtr kPrimPack; extern const PrimitivePtr kPrimPack;
extern const PrimitivePtr kPrimUnpack; extern const PrimitivePtr kPrimUnpack;
extern const PrimitivePtr kPrimUnsortedSegmentMin;
extern const PrimitivePtr kPrimUnsortedSegmentSum; extern const PrimitivePtr kPrimUnsortedSegmentSum;
extern const PrimitivePtr kPrimConcatOffset; extern const PrimitivePtr kPrimConcatOffset;
extern const PrimitivePtr kPrimReshape; extern const PrimitivePtr kPrimReshape;
......
...@@ -341,6 +341,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma ...@@ -341,6 +341,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{prim::kPrimGelu->name(), ADPT_DESC(Gelu)}, {prim::kPrimGelu->name(), ADPT_DESC(Gelu)},
{prim::kPrimGeluGrad->name(), ADPT_DESC(GeluGrad)}, {prim::kPrimGeluGrad->name(), ADPT_DESC(GeluGrad)},
{string(kNameStridedSlice), ADPT_DESC(StridedSlice)}, {string(kNameStridedSlice), ADPT_DESC(StridedSlice)},
{prim::kPrimUnsortedSegmentMin->name(), ADPT_DESC(UnsortedSegmentMinD)},
{prim::kPrimUnsortedSegmentSum->name(), ADPT_DESC(UnsortedSegmentSumD)}, {prim::kPrimUnsortedSegmentSum->name(), ADPT_DESC(UnsortedSegmentSumD)},
{string(kNameExpandDims), ADPT_DESC(ExpandDims)}, {string(kNameExpandDims), ADPT_DESC(ExpandDims)},
{prim::kPrimSqueeze->name(), ADPT_DESC(Squeeze)}, {prim::kPrimSqueeze->name(), ADPT_DESC(Squeeze)},
......
...@@ -1053,6 +1053,12 @@ INPUT_ATTR_MAP(UnsortedSegmentSumD) = {{3, ATTR_DESC(num_segments, AnyTraits<int ...@@ -1053,6 +1053,12 @@ INPUT_ATTR_MAP(UnsortedSegmentSumD) = {{3, ATTR_DESC(num_segments, AnyTraits<int
ATTR_MAP(UnsortedSegmentSumD) = EMPTY_ATTR_MAP; ATTR_MAP(UnsortedSegmentSumD) = EMPTY_ATTR_MAP;
OUTPUT_MAP(UnsortedSegmentSumD) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(UnsortedSegmentSumD) = {{0, OUTPUT_DESC(y)}};
// UnsortedSegmentMin
INPUT_MAP(UnsortedSegmentMinD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}};
INPUT_ATTR_MAP(UnsortedSegmentMinD) = {{3, ATTR_DESC(num_segments, AnyTraits<int64_t>())}};
ATTR_MAP(UnsortedSegmentMinD) = EMPTY_ATTR_MAP;
OUTPUT_MAP(UnsortedSegmentMinD) = {{0, OUTPUT_DESC(y)}};
// ExpandDims // ExpandDims
INPUT_MAP(ExpandDims) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axis)}}; INPUT_MAP(ExpandDims) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axis)}};
ATTR_MAP(ExpandDims) = EMPTY_ATTR_MAP; ATTR_MAP(ExpandDims) = EMPTY_ATTR_MAP;
......
...@@ -283,6 +283,9 @@ DECLARE_OP_USE_OUTPUT(StridedSlice) ...@@ -283,6 +283,9 @@ DECLARE_OP_USE_OUTPUT(StridedSlice)
DECLARE_OP_ADAPTER(UnsortedSegmentSumD) DECLARE_OP_ADAPTER(UnsortedSegmentSumD)
DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentSumD) DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentSumD)
DECLARE_OP_USE_OUTPUT(UnsortedSegmentSumD) DECLARE_OP_USE_OUTPUT(UnsortedSegmentSumD)
DECLARE_OP_ADAPTER(UnsortedSegmentMinD)
DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentMinD)
DECLARE_OP_USE_OUTPUT(UnsortedSegmentMinD)
DECLARE_OP_ADAPTER(ExpandDims) DECLARE_OP_ADAPTER(ExpandDims)
DECLARE_OP_USE_OUTPUT(ExpandDims) DECLARE_OP_USE_OUTPUT(ExpandDims)
DECLARE_OP_ADAPTER(Squeeze) DECLARE_OP_ADAPTER(Squeeze)
......
...@@ -22,6 +22,7 @@ from .. import functional as F ...@@ -22,6 +22,7 @@ from .. import functional as F
from .grad_base import bprop_getters from .grad_base import bprop_getters
from ..primitive import constexpr from ..primitive import constexpr
from ... import context from ... import context
from ...common import dtype as mstype
reduce_sum = P.ReduceSum() reduce_sum = P.ReduceSum()
unsorted_segment_sum = P.UnsortedSegmentSum() unsorted_segment_sum = P.UnsortedSegmentSum()
...@@ -29,6 +30,7 @@ transpose = P.Transpose() ...@@ -29,6 +30,7 @@ transpose = P.Transpose()
shape_op = P.Shape() shape_op = P.Shape()
reshape = P.Reshape() reshape = P.Reshape()
invert_permutation = P.InvertPermutation() invert_permutation = P.InvertPermutation()
logical_and = P.LogicalAnd()
@bprop_getters.register(P.Fill) @bprop_getters.register(P.Fill)
...@@ -456,6 +458,57 @@ def get_bprop_diag_part(self): ...@@ -456,6 +458,57 @@ def get_bprop_diag_part(self):
return bprop return bprop
def _GatherDropNegatives(params,
ids,
zero_clipped_indices=None,
is_positive=None):
"""Helper function for unsorted segment ops."""
maximum = P.Maximum()
gather = P.GatherV2()
greater_equal = P.GreaterEqual()
rank = P.Rank()
fill = P.Fill()
select = P.Select()
if zero_clipped_indices is None:
zero_clipped_indices = maximum(ids, zeros_like(ids))
gathered = gather(params, zero_clipped_indices, 0)
if is_positive is None:
is_positive = greater_equal(ids, 0)
is_positive_shape = shape_op(is_positive)
broadcastable_shape = is_positive_shape
for _ in range(rank(gathered) - rank(is_positive)):
broadcastable_shape += (1,)
is_positive = reshape(is_positive, broadcastable_shape)
gathered_shape = shape_op(gathered)
is_positive = logical_and(is_positive, fill(mstype.bool_, gathered_shape, 1))
zero_slice = zeros_like(gathered)
return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)
@bprop_getters.register(P.UnsortedSegmentMin)
def get_bprop_unsorted_segment_min(self):
"""Generate bprop for UnsortedSegmentMin"""
equal = P.Equal()
cast = P.Cast()
divide = P.RealDiv()
get_dtype = P.DType()
select = P.Select()
def bprop(x, segment_ids, num_segments, out, dout):
gathered_outputs, zero_clipped_indices, is_positive = _GatherDropNegatives(out, segment_ids)
is_selected = equal(x, gathered_outputs)
is_selected = logical_and(is_selected, is_positive)
num_selected = unsorted_segment_sum(cast(is_selected, get_dtype(dout)),
segment_ids, num_segments)
weighted_grads = divide(dout, num_selected)
gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None,
zero_clipped_indices, is_positive)
zeros = zeros_like(gathered_grads)
return select(is_selected, gathered_grads, zeros), zeros_like(segment_ids), zeros_like(num_segments)
return bprop
@bprop_getters.register(P.SpaceToBatch) @bprop_getters.register(P.SpaceToBatch)
def get_bprop_space_to_batch(self): def get_bprop_space_to_batch(self):
"""Generate bprop for SpaceToBatch""" """Generate bprop for SpaceToBatch"""
......
...@@ -28,7 +28,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, ...@@ -28,7 +28,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split, Shape, Size, Slice, Split,
Squeeze, StridedSlice, Tile, Squeeze, StridedSlice, Tile,
Transpose, TruncatedNormal, TupleToArray, Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace) UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast, from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
_MirrorOperator, ReduceOp, _VirtualDataset, _MirrorOperator, ReduceOp, _VirtualDataset,
...@@ -96,6 +96,7 @@ __all__ = [ ...@@ -96,6 +96,7 @@ __all__ = [
'MaxPool', 'MaxPool',
'TopK', 'TopK',
'Adam', 'Adam',
'Softplus',
'Softmax', 'Softmax',
'LogSoftmax', 'LogSoftmax',
'SoftmaxCrossEntropyWithLogits', 'SoftmaxCrossEntropyWithLogits',
...@@ -210,6 +211,7 @@ __all__ = [ ...@@ -210,6 +211,7 @@ __all__ = [
'Size', 'Size',
'DepthwiseConv2dNative', 'DepthwiseConv2dNative',
'UnsortedSegmentSum', 'UnsortedSegmentSum',
'UnsortedSegmentMin',
"AllGather", "AllGather",
"AllReduce", "AllReduce",
"ReduceScatter", "ReduceScatter",
......
...@@ -1253,6 +1253,54 @@ class UnsortedSegmentSum(PrimitiveWithInfer): ...@@ -1253,6 +1253,54 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
return out return out
class UnsortedSegmentMin(PrimitiveWithInfer):
"""
Computes the minimum along segments of a tensor.
If the given segment_ids is negative, the value will be ignored.
Inputs:
- **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
- **segment_ids** (Tensor) - A `1-D` tensor whose shape is a prefix of `x_shape`.
- **num_segments** (int) - The value spcifies the number of distinct `segment_ids`.
Outputs:
Tensor, Set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`.
Examples:
>>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32))
>>> segment_ids = Tensor(np.array([0, 1, 1]).np.int32)
>>> num_segments = 2
>>> unsorted_segment_min = P.UnsortedSegmentMin()
>>> unsorted_segment_min(input_x, segment_ids, num_segments)
[[1., 2., 3.], [4., 2., 1.]]
"""
@prim_attr_register
def __init__(self):
"""init UnsortedSegmentMin"""
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
def __infer__(self, x, segment_ids, num_segments):
x_type = x['dtype']
x_shape = x['shape']
segment_ids_shape = segment_ids['shape']
valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name)
num_segments_v = num_segments['value']
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name)
segment_ids_shape_len = len(segment_ids_shape)
out_shape = [num_segments_v]
out_shape += x_shape[segment_ids_shape_len:]
out = {'shape': out_shape,
'dtype': x_type,
'value': None}
return out
class Concat(PrimitiveWithInfer): class Concat(PrimitiveWithInfer):
r""" r"""
Concat tensor in specified axis. Concat tensor in specified axis.
......
...@@ -778,6 +778,11 @@ test_case_nn_ops = [ ...@@ -778,6 +778,11 @@ test_case_nn_ops = [
'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([[0, 1], [0, 1], [0, 1]]).astype(np.int32))], 'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([[0, 1], [0, 1], [0, 1]]).astype(np.int32))],
'desc_bprop': [[4, 1, 3]], 'desc_bprop': [[4, 1, 3]],
'skip': ['backward']}), 'skip': ['backward']}),
('UnsortedSegmentMin', {
'block': P.UnsortedSegmentMin(),
'desc_const': [4],
'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([1, 2, 3]).astype(np.int32))],
'desc_bprop': [[4, 2, 1, 3]]}),
('DropoutGenMask', { ('DropoutGenMask', {
'block': P.DropoutGenMask(), 'block': P.DropoutGenMask(),
'desc_const': [(2, 2), Tensor(0.5, mstype.float32)], 'desc_const': [(2, 2), Tensor(0.5, mstype.float32)],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册