提交 c14048d0 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1338 Chage op UnsortedSegmentMinD to UnsortedSegmentMin for GE.

Merge pull request !1338 from liuxiao/UnsortedSegmentMinD-UnsortedSegmentMin
......@@ -343,7 +343,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{prim::kPrimGelu->name(), ADPT_DESC(Gelu)},
{prim::kPrimGeluGrad->name(), ADPT_DESC(GeluGrad)},
{string(kNameStridedSlice), ADPT_DESC(StridedSlice)},
{prim::kPrimUnsortedSegmentMin->name(), ADPT_DESC(UnsortedSegmentMinD)},
{prim::kPrimUnsortedSegmentMin->name(), ADPT_DESC(UnsortedSegmentMin)},
{prim::kPrimUnsortedSegmentSum->name(), ADPT_DESC(UnsortedSegmentSumD)},
{string(kNameExpandDims), ADPT_DESC(ExpandDims)},
{prim::kPrimSqueeze->name(), ADPT_DESC(Squeeze)},
......
......@@ -1059,10 +1059,9 @@ ATTR_MAP(UnsortedSegmentSumD) = EMPTY_ATTR_MAP;
OUTPUT_MAP(UnsortedSegmentSumD) = {{0, OUTPUT_DESC(y)}};
// UnsortedSegmentMin
INPUT_MAP(UnsortedSegmentMinD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}};
INPUT_ATTR_MAP(UnsortedSegmentMinD) = {{3, ATTR_DESC(num_segments, AnyTraits<int64_t>())}};
ATTR_MAP(UnsortedSegmentMinD) = EMPTY_ATTR_MAP;
OUTPUT_MAP(UnsortedSegmentMinD) = {{0, OUTPUT_DESC(y)}};
INPUT_MAP(UnsortedSegmentMin) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}, {3, INPUT_DESC(num_segments)}};
ATTR_MAP(UnsortedSegmentMin) = EMPTY_ATTR_MAP;
OUTPUT_MAP(UnsortedSegmentMin) = {{0, OUTPUT_DESC(y)}};
// ExpandDims
INPUT_MAP(ExpandDims) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axis)}};
......
......@@ -285,9 +285,8 @@ DECLARE_OP_USE_OUTPUT(StridedSlice)
DECLARE_OP_ADAPTER(UnsortedSegmentSumD)
DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentSumD)
DECLARE_OP_USE_OUTPUT(UnsortedSegmentSumD)
DECLARE_OP_ADAPTER(UnsortedSegmentMinD)
DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentMinD)
DECLARE_OP_USE_OUTPUT(UnsortedSegmentMinD)
DECLARE_OP_ADAPTER(UnsortedSegmentMin)
DECLARE_OP_USE_OUTPUT(UnsortedSegmentMin)
DECLARE_OP_ADAPTER(ExpandDims)
DECLARE_OP_USE_OUTPUT(ExpandDims)
DECLARE_OP_ADAPTER(Squeeze)
......
......@@ -1271,7 +1271,7 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
Inputs:
- **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
- **segment_ids** (Tensor) - A `1-D` tensor whose shape is a prefix of `x_shape`.
- **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`.
- **num_segments** (int) - The value spcifies the number of distinct `segment_ids`.
Outputs:
......@@ -1279,7 +1279,7 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
Examples:
>>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32))
>>> segment_ids = Tensor(np.array([0, 1, 1]).np.int32)
>>> segment_ids = Tensor(np.array([0, 1, 1]).astype(np.int32))
>>> num_segments = 2
>>> unsorted_segment_min = P.UnsortedSegmentMin()
>>> unsorted_segment_min(input_x, segment_ids, num_segments)
......@@ -1299,6 +1299,8 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name)
validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
num_segments_v = num_segments['value']
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册