From 68aca9ab523b13e67e2a7b456102435ea7849112 Mon Sep 17 00:00:00 2001 From: liuxiao Date: Thu, 21 May 2020 21:25:17 +0800 Subject: [PATCH] UnsortedSegmentMinD->UnsortedSegmentMin --- mindspore/ccsrc/transform/convert.cc | 2 +- mindspore/ccsrc/transform/op_declare.cc | 7 +++---- mindspore/ccsrc/transform/op_declare.h | 5 ++--- mindspore/ops/operations/array_ops.py | 6 ++++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc index dcb0d4f0b..8c6fcf02f 100644 --- a/mindspore/ccsrc/transform/convert.cc +++ b/mindspore/ccsrc/transform/convert.cc @@ -343,7 +343,7 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma {prim::kPrimGelu->name(), ADPT_DESC(Gelu)}, {prim::kPrimGeluGrad->name(), ADPT_DESC(GeluGrad)}, {string(kNameStridedSlice), ADPT_DESC(StridedSlice)}, - {prim::kPrimUnsortedSegmentMin->name(), ADPT_DESC(UnsortedSegmentMinD)}, + {prim::kPrimUnsortedSegmentMin->name(), ADPT_DESC(UnsortedSegmentMin)}, {prim::kPrimUnsortedSegmentSum->name(), ADPT_DESC(UnsortedSegmentSumD)}, {string(kNameExpandDims), ADPT_DESC(ExpandDims)}, {prim::kPrimSqueeze->name(), ADPT_DESC(Squeeze)}, diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc index 92743b9e8..e6eb08120 100644 --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -1059,10 +1059,9 @@ ATTR_MAP(UnsortedSegmentSumD) = EMPTY_ATTR_MAP; OUTPUT_MAP(UnsortedSegmentSumD) = {{0, OUTPUT_DESC(y)}}; // UnsortedSegmentMin -INPUT_MAP(UnsortedSegmentMinD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}}; -INPUT_ATTR_MAP(UnsortedSegmentMinD) = {{3, ATTR_DESC(num_segments, AnyTraits())}}; -ATTR_MAP(UnsortedSegmentMinD) = EMPTY_ATTR_MAP; -OUTPUT_MAP(UnsortedSegmentMinD) = {{0, OUTPUT_DESC(y)}}; +INPUT_MAP(UnsortedSegmentMin) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}, {3, INPUT_DESC(num_segments)}}; +ATTR_MAP(UnsortedSegmentMin) = EMPTY_ATTR_MAP; +OUTPUT_MAP(UnsortedSegmentMin) = {{0, OUTPUT_DESC(y)}}; // ExpandDims INPUT_MAP(ExpandDims) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axis)}}; diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h index 2d7d0b159..b517a4bb7 100755 --- a/mindspore/ccsrc/transform/op_declare.h +++ b/mindspore/ccsrc/transform/op_declare.h @@ -285,9 +285,8 @@ DECLARE_OP_USE_OUTPUT(StridedSlice) DECLARE_OP_ADAPTER(UnsortedSegmentSumD) DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentSumD) DECLARE_OP_USE_OUTPUT(UnsortedSegmentSumD) -DECLARE_OP_ADAPTER(UnsortedSegmentMinD) -DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentMinD) -DECLARE_OP_USE_OUTPUT(UnsortedSegmentMinD) +DECLARE_OP_ADAPTER(UnsortedSegmentMin) +DECLARE_OP_USE_OUTPUT(UnsortedSegmentMin) DECLARE_OP_ADAPTER(ExpandDims) DECLARE_OP_USE_OUTPUT(ExpandDims) DECLARE_OP_ADAPTER(Squeeze) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 45e04b83f..e3c0a4865 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1271,7 +1271,7 @@ class UnsortedSegmentMin(PrimitiveWithInfer): Inputs: - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`. - - **segment_ids** (Tensor) - A `1-D` tensor whose shape is a prefix of `x_shape`. + - **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`. - **num_segments** (int) - The value spcifies the number of distinct `segment_ids`. Outputs: @@ -1279,7 +1279,7 @@ class UnsortedSegmentMin(PrimitiveWithInfer): Examples: >>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32)) - >>> segment_ids = Tensor(np.array([0, 1, 1]).np.int32) + >>> segment_ids = Tensor(np.array([0, 1, 1]).astype(np.int32)) >>> num_segments = 2 >>> unsorted_segment_min = P.UnsortedSegmentMin() >>> unsorted_segment_min(input_x, segment_ids, num_segments) @@ -1299,6 +1299,8 @@ class UnsortedSegmentMin(PrimitiveWithInfer): validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name) + validator.check(f'first shape of input_x', x_shape[0], + 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) num_segments_v = num_segments['value'] validator.check_value_type('num_segments', num_segments_v, [int], self.name) validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name) -- GitLab