提交 2b0ecfd2 编写于 作者: L liuxiao93

Add TBE op UnsortedSegmentProd for VM.

上级 7304f024
......@@ -84,6 +84,7 @@ static std::map<string, string> tbe_func_adapter_map = {
{"transpose", "transpose_d"},
{"fill", "fill_d"},
{"unsorted_segment_sum", "unsorted_segment_sum_d"},
{"unsorted_segment_prod", "unsorted_segment_prod_d"},
{"concat", "concat_d"},
{"slice", "slice_d"},
{"reduce_sum", "reduce_sum_d"},
......
......@@ -625,6 +625,36 @@ def get_bprop_unsorted_segment_min(self):
return bprop
@bprop_getters.register(P.UnsortedSegmentProd)
def get_bprop_unsorted_segment_prod(self):
"""Generate bprop for UnsortedSegmentProd"""
equal = P.Equal()
cast = P.Cast()
select = P.Select()
gather = P.GatherV2()
greater = P.Greater()
ones_like = P.OnesLike()
maximum = P.Maximum()
unsorted_segment_prod = P.UnsortedSegmentProd()
def bprop(x, segment_ids, num_segments, out, dout):
is_zero = equal(x, 0)
num_zero = unsorted_segment_sum(cast(is_zero, mstype.int32), segment_ids, num_segments)
grad = select(greater(num_zero, 1), zeros_like(dout), dout)
non_zero_data = select(is_zero, ones_like(x), x)
non_zero_prod = unsorted_segment_prod(non_zero_data, segment_ids, num_segments)
zero_clipped_indices = maximum(segment_ids, zeros_like(segment_ids))
gathered_prod = gather(out, zero_clipped_indices, 0)
gathered_non_zero_prod = gather(non_zero_prod, zero_clipped_indices, 0)
prod_divided_by_x = gathered_prod / x
partial_derivative = select(is_zero, gathered_non_zero_prod, prod_divided_by_x)
gathered_grad, _, _ = _GatherDropNegatives(grad, segment_ids, zero_clipped_indices)
dx = gathered_grad * partial_derivative
return dx, zeros_like(segment_ids), zeros_like(num_segments)
return bprop
@bprop_getters.register(P.SpaceToBatch)
def get_bprop_space_to_batch(self):
"""Generate bprop for SpaceToBatch"""
......
......@@ -133,6 +133,7 @@ from .sparse_apply_proximal_adagrad import _sparse_apply_proximal_adagrad
from .apply_proximal_adagrad import _apply_proximal_adagrad
from .transpose_d import _transpose_d_tbe
from .unsorted_segment_sum import _unsorted_segment_sum_tbe
from .unsorted_segment_prod import _unsorted_segment_prod_tbe
from .logsoftmax_grad import _logsoftmax_grad_tbe
from .logsoftmax import _logsoftmax_tbe
from .select import _select_tbe
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""UnsortedSegmentProdD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
unsorted_segment_prod_d_op_info = TBERegOp("UnsortedSegmentProd") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("unsorted_segment_prod_d.so") \
.compute_cost(10) \
.kernel_name("unsorted_segment_prod_d") \
.partial_flag(True) \
.attr("num_segments", "required", "int", "all") \
.input(0, "data", False, "required", "all") \
.input(1, "segment_ids", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.I32_Default, DataType.F16_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.I32_Default, DataType.F16_FracZ) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.I32_Default, DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_5HD, DataType.I32_Default, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.I32_Default, DataType.F32_FracZ) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.I32_Default, DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_Default, DataType.I32_5HD) \
.dtype_format(DataType.I32_FracZ, DataType.I32_Default, DataType.I32_FracZ) \
.dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_Default, DataType.I32_C1HWNCoC0) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(unsorted_segment_prod_d_op_info)
def _unsorted_segment_prod_tbe():
"""UnsortedSegmentProdD TBE register"""
return
......@@ -29,7 +29,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split, TransShape,
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
......@@ -249,6 +249,7 @@ __all__ = [
'DepthwiseConv2dNative',
'UnsortedSegmentSum',
'UnsortedSegmentMin',
'UnsortedSegmentProd',
"AllGather",
"AllReduce",
"ReduceScatter",
......
......@@ -1412,6 +1412,58 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
return out
class UnsortedSegmentProd(PrimitiveWithInfer):
"""
Computes the product along segments of a tensor.
Inputs:
- **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
With float16, float32 or int32 data type.
- **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`. Data type must be int32.
- **num_segments** (int) - The value spcifies the number of distinct `segment_ids`,
should be greater than 0.
Outputs:
Tensor, Set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`.
Examples:
>>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32))
>>> segment_ids = Tensor(np.array([0, 1, 0]).astype(np.int32))
>>> num_segments = 2
>>> unsorted_segment_prod = P.UnsortedSegmentProd()
>>> unsorted_segment_prod(input_x, segment_ids, num_segments)
[[4., 4., 3.], [4., 5., 6.]]
"""
@prim_attr_register
def __init__(self):
"""init UnsortedSegmentProd"""
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
def __infer__(self, x, segment_ids, num_segments):
x_type = x['dtype']
x_shape = x['shape']
segment_ids_shape = segment_ids['shape']
validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
validator.check_value_type("x_shape", x_shape, [list], self.name)
valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name)
validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
num_segments_v = num_segments['value']
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name)
segment_ids_shape_len = len(segment_ids_shape)
out_shape = [num_segments_v]
out_shape += x_shape[segment_ids_shape_len:]
out = {'shape': out_shape,
'dtype': mstype.tensor_type(x_type.element_type()),
'value': None}
return out
class Concat(PrimitiveWithInfer):
r"""
Concat tensor in specified axis.
......
......@@ -1318,6 +1318,11 @@ test_case_nn_ops = [
'desc_const': [4],
'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([1, 2, 3]).astype(np.int32))],
'desc_bprop': [[4, 2, 1, 3]]}),
('UnsortedSegmentProd', {
'block': P.UnsortedSegmentProd(),
'desc_const': [4],
'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([0, 1, 0]).astype(np.int32))],
'desc_bprop': [[4, 2, 1, 3]]}),
('DropoutGenMask', {
'block': P.DropoutGenMask(),
'desc_const': [(2, 2), Tensor(0.5, mstype.float32)],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册