diff --git a/mace/kernels/strided_slice.h b/mace/kernels/strided_slice.h index 9486e2fdbfd8d620eba743220838875c3a077234..efdd9141663eea93bd0ac554118858541631b8fa 100644 --- a/mace/kernels/strided_slice.h +++ b/mace/kernels/strided_slice.h @@ -32,12 +32,16 @@ struct StridedSliceFunctor { int end_mask, int ellipsis_mask, int new_axis_mask, - int shrink_axis_mask) + int shrink_axis_mask, + bool is_slice = false) : begin_mask_(begin_mask), end_mask_(end_mask), ellipsis_mask_(ellipsis_mask), new_axis_mask_(new_axis_mask), - shrink_axis_mask_(shrink_axis_mask) {} + shrink_axis_mask_(shrink_axis_mask), + is_slice_(is_slice), + tmp_strides_tensor_(GetDeviceAllocator(D), + DataTypeToEnum::v()) {} MaceStatus operator()(const Tensor *input, const Tensor *begin_indices, @@ -49,6 +53,14 @@ struct StridedSliceFunctor { MACE_CHECK(ellipsis_mask_ == 0 && new_axis_mask_ == 0, "ellipsis_mask and new_axis_mask are not supported yet."); + if (strides == nullptr) { + tmp_strides_tensor_.Resize({begin_indices->size()}); + Tensor::MappingGuard strides_guard(&tmp_strides_tensor_); + int32_t *strides_data = tmp_strides_tensor_.mutable_data(); + std::fill(strides_data, strides_data + tmp_strides_tensor_.size(), 1); + strides = &tmp_strides_tensor_; + } + Tensor::MappingGuard input_guard(input); Tensor::MappingGuard begin_indices_guard(begin_indices); Tensor::MappingGuard end_indices_guard(end_indices); @@ -56,6 +68,19 @@ struct StridedSliceFunctor { const T *input_data = input->data(); const int32_t *begin_indices_data = begin_indices->data(); const int32_t *end_indices_data = end_indices->data(); + std::vector slice_end_data; + if (is_slice_) { + // if this op is slice, the end_indices_data is size actually + slice_end_data.resize(end_indices->size()); + for (int i = 0; i < slice_end_data.size(); ++i) { + if (end_indices_data[i] == -1) { + slice_end_data[i] = input->dim(i); + } else { + slice_end_data[i] = begin_indices_data[i] + end_indices_data[i]; + } + } + end_indices_data = slice_end_data.data(); + } const int32_t *strides_data = strides->data(); std::vector output_shape; @@ -152,6 +177,8 @@ struct StridedSliceFunctor { int ellipsis_mask_; int new_axis_mask_; int shrink_axis_mask_; + bool is_slice_; + Tensor tmp_strides_tensor_; }; } // namespace kernels diff --git a/mace/ops/strided_slice.h b/mace/ops/strided_slice.h index e3e25db543f32e3b3361422a762159d50aeec69e..57653359c2b0d4333ed8e04517c699e60b7439b3 100644 --- a/mace/ops/strided_slice.h +++ b/mace/ops/strided_slice.h @@ -30,13 +30,18 @@ class StridedSliceOp : public Operator { OperatorBase::GetOptionalArg("end_mask", 0), OperatorBase::GetOptionalArg("ellipsis_mask", 0), OperatorBase::GetOptionalArg("new_axis_mask", 0), - OperatorBase::GetOptionalArg("shrink_axis_mask", 0)) {} + OperatorBase::GetOptionalArg("shrink_axis_mask", 0), + OperatorBase::GetOptionalArg("slice", + false)) {} MaceStatus Run(StatsFuture *future) override { const Tensor *input = this->Input(INPUT); const Tensor *begin_indices = this->Input(BEGIN); const Tensor *end_indices = this->Input(END); - const Tensor *strides = this->Input(STRIDES); + const Tensor *strides = nullptr; + if (this->InputSize() > 3) { + strides = this->Input(STRIDES); + } Tensor *output = this->Output(OUTPUT); return functor_(input, begin_indices, end_indices, strides, output, future); diff --git a/mace/ops/strided_slice_test.cc b/mace/ops/strided_slice_test.cc index 6cd46f4e110e0cb001932a18db6db2a5c69d866b..8662367eddc775a2142878430c39cc93c364ba15 100644 --- a/mace/ops/strided_slice_test.cc +++ b/mace/ops/strided_slice_test.cc @@ -23,32 +23,27 @@ class StridedSliceOpTest : public OpsTestBase {}; namespace { -void TestSlice(const std::vector &input_shape, - const std::vector &input, - const std::vector &begin_indices, - const std::vector &end_indices, - const std::vector &strides, - const int begin_mask, - const int end_mask, - const int ellipsis_mask, - const int new_axis_mask, - const int shrink_axis_mask, - const std::vector &output_shape, - const std::vector &output) { +void TestStridedSlice(const std::vector &input_shape, + const std::vector &input, + const std::vector &begin_indices, + const std::vector &end_indices, + const std::vector &strides, + const int begin_mask, + const int end_mask, + const int ellipsis_mask, + const int new_axis_mask, + const int shrink_axis_mask, + const std::vector &output_shape, + const std::vector &output) { OpsTestNet net; net.AddInputFromArray("Input", input_shape, input); - net.AddInputFromArray("BeginIndices", - {static_cast( - input_shape.size())}, - begin_indices); - net.AddInputFromArray("EndIndices", - {static_cast( - input_shape.size())}, - end_indices); - net.AddInputFromArray("Strides", - {static_cast( - input_shape.size())}, - strides); + net.AddInputFromArray( + "BeginIndices", {static_cast(input_shape.size())}, + begin_indices); + net.AddInputFromArray( + "EndIndices", {static_cast(input_shape.size())}, end_indices); + net.AddInputFromArray( + "Strides", {static_cast(input_shape.size())}, strides); OpDefBuilder("StridedSlice", "StridedSliceOpTest") .Input("Input") @@ -70,47 +65,92 @@ void TestSlice(const std::vector &input_shape, *net.GetOutput("Output")); } +void TestSlice(const std::vector &input_shape, + const std::vector &input, + const std::vector &begin_indices, + const std::vector &indices_size, + const std::vector &output_shape, + const std::vector &output) { + OpsTestNet net; + net.AddInputFromArray("Input", input_shape, input); + net.AddInputFromArray( + "BeginIndices", {static_cast(input_shape.size())}, + begin_indices); + net.AddInputFromArray( + "IndicesSize", {static_cast(indices_size.size())}, indices_size); + + OpDefBuilder("StridedSlice", "StridedSliceOpTest") + .Input("Input") + .Input("BeginIndices") + .Input("IndicesSize") + .Output("Output") + .AddIntArg("slice", 1) + .Finalize(net.NewOperatorDef()); + + net.RunOp(); + net.AddInputFromArray("ExpectedOutput", output_shape, output); + ExpectTensorNear(*net.GetOutput("ExpectedOutput"), + *net.GetOutput("Output")); +} + } // namespace -TEST_F(StridedSliceOpTest, TestSliceByFirstAxis) { - TestSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 0, 0}, - {2, 3, 2}, {1, 1, 1}, 0, 0, 0, 0, 0, {1, 3, 2}, - {7, 8, 9, 10, 11, 12}); - TestSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 0, 0}, - {2, 3, 2}, {1, 1, 1}, 0, 0, 0, 0, 1, {3, 2}, {7, 8, 9, 10, 11, 12}); - TestSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 1, 2}, - {2, 3, 2}, {1, 1, 1}, 6, 6, 0, 0, 0, {1, 3, 2}, - {7, 8, 9, 10, 11, 12}); +TEST_F(StridedSliceOpTest, TestStridedSliceByFirstAxis) { + TestStridedSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + {1, 0, 0}, {2, 3, 2}, {1, 1, 1}, 0, 0, 0, 0, 0, {1, 3, 2}, + {7, 8, 9, 10, 11, 12}); + TestStridedSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + {1, 0, 0}, {2, 3, 2}, {1, 1, 1}, 0, 0, 0, 0, 1, {3, 2}, + {7, 8, 9, 10, 11, 12}); + TestStridedSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + {1, 1, 2}, {2, 3, 2}, {1, 1, 1}, 6, 6, 0, 0, 0, {1, 3, 2}, + {7, 8, 9, 10, 11, 12}); +} + +TEST_F(StridedSliceOpTest, TestStridedSliceRank1) { + TestStridedSlice({4}, {1, 2, 3, 4}, {1}, {3}, {1}, 0, 0, 0, 0, 0, {2}, + {2, 3}); + TestStridedSlice({4}, {1, 2, 3, 4}, {-3}, {3}, {1}, 0, 0, 0, 0, 0, {2}, + {2, 3}); + TestStridedSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 0, 0, 0, 0, 0, {2}, + {3, 2}); + TestStridedSlice({4}, {1, 2, 3, 4}, {-1}, {-4}, {-2}, 0, 0, 0, 0, 0, {2}, + {4, 2}); + TestStridedSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 1, 0, 0, 0, 0, {3}, + {4, 3, 2}); + TestStridedSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 0, 1, 0, 0, 0, {3}, + {3, 2, 1}); + TestStridedSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 1, 1, 0, 0, 0, {4}, + {4, 3, 2, 1}); + TestStridedSlice({4}, {1, 2, 3, 4}, {2}, {4}, {2}, 1, 1, 0, 0, 0, {2}, + {1, 3}); + TestStridedSlice({4}, {1, 2, 3, 4}, {2}, {3}, {1}, 0, 0, 0, 0, 1, {}, {3}); } -TEST_F(StridedSliceOpTest, TestSliceRank1) { - TestSlice({4}, {1, 2, 3, 4}, {1}, {3}, {1}, 0, 0, 0, 0, 0, {2}, {2, 3}); - TestSlice({4}, {1, 2, 3, 4}, {-3}, {3}, {1}, 0, 0, 0, 0, 0, {2}, {2, 3}); - TestSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 0, 0, 0, 0, 0, {2}, {3, 2}); - TestSlice({4}, {1, 2, 3, 4}, {-1}, {-4}, {-2}, 0, 0, 0, 0, 0, {2}, {4, 2}); - TestSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 1, 0, 0, 0, 0, {3}, {4, 3, 2}); - TestSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 0, 1, 0, 0, 0, {3}, {3, 2, 1}); - TestSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 1, 1, 0, 0, 0, {4}, - {4, 3, 2, 1}); - TestSlice({4}, {1, 2, 3, 4}, {2}, {4}, {2}, 1, 1, 0, 0, 0, {2}, {1, 3}); - TestSlice({4}, {1, 2, 3, 4}, {2}, {3}, {1}, 0, 0, 0, 0, 1, {}, {3}); +TEST_F(StridedSliceOpTest, TestStridedSliceRank2) { + TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {1, 1}, 0, 0, 0, + 0, 0, {2, 3}, {1, 2, 3, 4, 5, 6}); + TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1}, {2, 3}, {1, 1}, 0, 0, 0, + 0, 0, {1, 2}, {5, 6}); + TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {1, 2}, 0, 0, 0, + 0, 0, {2, 2}, {1, 3, 4, 6}); + TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2}, {0, 0}, {-1, -1}, 0, 0, + 0, 0, 0, {1, 2}, {6, 5}); + TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2}, {0, 0}, {-1, -1}, 3, 3, + 0, 0, 0, {2, 3}, {6, 5, 4, 3, 2, 1}); + TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 0}, {2, 3}, {1, 1}, 0, 0, 0, + 0, 1, {3}, {4, 5, 6}); + TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2}, {2, 3}, {1, 1}, 0, 0, 0, + 0, 3, {}, {6}); } -TEST_F(StridedSliceOpTest, TestSliceRank2) { - TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {1, 1}, 0, 0, 0, 0, 0, - {2, 3}, {1, 2, 3, 4, 5, 6}); - TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1}, {2, 3}, {1, 1}, 0, 0, 0, 0, 0, - {1, 2}, {5, 6}); - TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {1, 2}, 0, 0, 0, 0, 0, - {2, 2}, {1, 3, 4, 6}); - TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2}, {0, 0}, {-1, -1}, 0, 0, 0, 0, 0, - {1, 2}, {6, 5}); - TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2}, {0, 0}, {-1, -1}, 3, 3, 0, 0, 0, - {2, 3}, {6, 5, 4, 3, 2, 1}); - TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 0}, {2, 3}, {1, 1}, 0, 0, 0, 0, 1, - {3}, {4, 5, 6}); - TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2}, {2, 3}, {1, 1}, 0, 0, 0, 0, 3, - {}, {6}); +TEST_F(StridedSliceOpTest, TestSlice) { + TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {2, 3}, + {1, 2, 3, 4, 5, 6}); + TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 0}, {1, 2}, {1, 2}, {4, 5}); + TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, -1}, {2, 3}, + {1, 2, 3, 4, 5, 6}); + TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 1}, {2, -1}, {2, 2}, {2, 3, 5, 6}); } } // namespace test diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 8c8987bf9d758944dcf8f2656672a969dd9558ac..02207cf1afe70341f2ba1bf76b0c28b2e5e2f80e 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -1155,5 +1155,6 @@ class Transformer(base_converter.ConverterInterface): print("Final ops:") for op in net.op: - print("%s (%s)" % (op.name, op.type)) + print("%s (%s): %s" % (op.name, op.type, [ + out_shape.dims for out_shape in op.output_shape])) return False