diff --git a/mace/core/operator.cc b/mace/core/operator.cc index e5355b2d842059f23ba33f000b164ada1804f01f..4bc3cfade4fcd129b68687b0ac03b08961a11d6d 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -100,8 +100,11 @@ extern void Register_Quantize(OperatorRegistry *op_registry); extern void Register_Requantize(OperatorRegistry *op_registry); extern void Register_Reshape(OperatorRegistry *op_registry); extern void Register_ResizeBilinear(OperatorRegistry *op_registry); +extern void Register_Shape(OperatorRegistry *op_registry); extern void Register_Slice(OperatorRegistry *op_registry); extern void Register_Softmax(OperatorRegistry *op_registry); +extern void Register_Stack(OperatorRegistry *op_registry); +extern void Register_StridedSlice(OperatorRegistry *op_registry); extern void Register_SpaceToBatchND(OperatorRegistry *op_registry); extern void Register_SpaceToDepth(OperatorRegistry *op_registry); extern void Register_Transpose(OperatorRegistry *op_registry); @@ -142,8 +145,11 @@ OperatorRegistry::OperatorRegistry() { ops::Register_Requantize(this); ops::Register_Reshape(this); ops::Register_ResizeBilinear(this); + ops::Register_Shape(this); ops::Register_Slice(this); ops::Register_Softmax(this); + ops::Register_Stack(this); + ops::Register_StridedSlice(this); ops::Register_SpaceToBatchND(this); ops::Register_SpaceToDepth(this); ops::Register_Transpose(this); diff --git a/mace/kernels/stack.h b/mace/kernels/stack.h new file mode 100644 index 0000000000000000000000000000000000000000..873b84ad5252adfbe14c734a5521972e92fd047b --- /dev/null +++ b/mace/kernels/stack.h @@ -0,0 +1,81 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_KERNELS_STACK_H_ +#define MACE_KERNELS_STACK_H_ + +#include +#include +#include +#include + +#include "mace/core/future.h" +#include "mace/core/tensor.h" +#include "mace/public/mace.h" + +namespace mace { +namespace kernels { + +template +struct StackFunctor { + explicit StackFunctor(int axis) : axis_(axis) {} + + MaceStatus operator()(const std::vector &inputs, + Tensor *output, + StatsFuture *future) { + MACE_UNUSED(future); + + MACE_CHECK(!inputs.empty(), "stack inputs are empty."); + std::vector input_shape = inputs[0]->shape(); + MACE_CHECK(axis_ >= -(inputs[0]->dim_size() + 1) && + axis_ < inputs[0]->dim_size() + 1, + "axis out of bound."); + if (axis_ < 0) { + axis_ += inputs[0]->dim_size() + 1; + } + std::vector output_shape = input_shape; + output_shape.insert(output_shape.begin() + axis_, inputs.size()); + MACE_RETURN_IF_ERROR(output->Resize(output_shape)); + + // On host, no need to map data + T *output_data = output->mutable_data(); + std::vector input_data(inputs.size()); + for (int i = 0; i < inputs.size(); ++i) { + input_data[i] = inputs[i]->data(); + } + + index_t high_dim_elem_size = + std::accumulate(input_shape.begin(), input_shape.begin() + axis_, 1, + std::multiplies()); + index_t low_dim_elem_size = + std::accumulate(input_shape.begin() + axis_, input_shape.end(), 1, + std::multiplies()); + for (index_t h = 0; h < high_dim_elem_size; ++h) { + for (index_t i = 0; i < inputs.size(); ++i) { + memcpy(output_data, input_data[i] + h * low_dim_elem_size, + sizeof(T) * low_dim_elem_size); + output_data += low_dim_elem_size; + } + } + + return MACE_SUCCESS; + } + + int axis_; +}; + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_STACK_H_ diff --git a/mace/kernels/strided_slice.h b/mace/kernels/strided_slice.h new file mode 100644 index 0000000000000000000000000000000000000000..9486e2fdbfd8d620eba743220838875c3a077234 --- /dev/null +++ b/mace/kernels/strided_slice.h @@ -0,0 +1,160 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_KERNELS_STRIDED_SLICE_H_ +#define MACE_KERNELS_STRIDED_SLICE_H_ + +#include +#include +#include + +#include "mace/core/future.h" +#include "mace/core/tensor.h" +#include "mace/public/mace.h" + +namespace mace { +namespace kernels { + +template +struct StridedSliceFunctor { + StridedSliceFunctor(int begin_mask, + int end_mask, + int ellipsis_mask, + int new_axis_mask, + int shrink_axis_mask) + : begin_mask_(begin_mask), + end_mask_(end_mask), + ellipsis_mask_(ellipsis_mask), + new_axis_mask_(new_axis_mask), + shrink_axis_mask_(shrink_axis_mask) {} + + MaceStatus operator()(const Tensor *input, + const Tensor *begin_indices, + const Tensor *end_indices, + const Tensor *strides, + Tensor *output, + StatsFuture *future) { + MACE_UNUSED(future); + MACE_CHECK(ellipsis_mask_ == 0 && new_axis_mask_ == 0, + "ellipsis_mask and new_axis_mask are not supported yet."); + + Tensor::MappingGuard input_guard(input); + Tensor::MappingGuard begin_indices_guard(begin_indices); + Tensor::MappingGuard end_indices_guard(end_indices); + Tensor::MappingGuard strides_guard(strides); + const T *input_data = input->data(); + const int32_t *begin_indices_data = begin_indices->data(); + const int32_t *end_indices_data = end_indices->data(); + const int32_t *strides_data = strides->data(); + + std::vector output_shape; + std::vector real_begin_indices(input->dim_size(), 0); + std::vector real_end_indices(input->dim_size(), 0); + for (index_t d = 0; d < input->dim_size(); ++d) { + index_t dim_len = input->dim(d); + if (begin_mask_ & (1 << d)) { + real_begin_indices[d] = strides_data[d] > 0 ? 0 : dim_len - 1; + } else { + real_begin_indices[d] = (begin_indices_data[d] + dim_len) % dim_len; + } + if (end_mask_ & (1 << d)) { + real_end_indices[d] = strides_data[d] > 0 ? dim_len : -1; + } else { + real_end_indices[d] = + end_indices_data[d] < -dim_len + ? -1 + : (end_indices_data[d] < 0 + ? (end_indices_data[d] + dim_len) + : std::min(static_cast(end_indices_data[d]), + dim_len)); + } + + int32_t out_dim_len = std::max( + 0.f, std::ceil((real_end_indices[d] - real_begin_indices[d]) / + static_cast(strides_data[d]))); + if (!(shrink_axis_mask_ & (1 << d))) { + output_shape.push_back(out_dim_len); + } else { + MACE_CHECK(out_dim_len == 1, + "cannot shrink axis that has len > 1, dim(", d, "): [", + real_begin_indices[d], ", ", real_end_indices[d], "]"); + } + } + + std::vector dim_stride(input->dim_size(), 1); + for (index_t d = input->dim_size() - 2; d >= 0; --d) { + dim_stride[d] = dim_stride[d + 1] * input->dim(d + 1); + } + + MACE_RETURN_IF_ERROR(output->Resize(output_shape)); + Tensor::MappingGuard output_guard(output); + T *output_data = output->mutable_data(); + + bool slice_by_first_axis = true; + if (strides_data[0] != 1) { + slice_by_first_axis = false; + } else { + for (index_t d = 1; d < input->dim_size(); ++d) { + if (strides_data[d] != 1 || real_begin_indices[d] != 0 || + real_end_indices[d] != input->dim(d)) { + slice_by_first_axis = false; + break; + } + } + } + + if (slice_by_first_axis) { + memcpy(output_data, input_data + real_begin_indices[0] * dim_stride[0], + sizeof(T) * (real_end_indices[0] - real_begin_indices[0]) * + dim_stride[0]); + } else { + if (input->dim_size() == 1) { + for (index_t i = real_begin_indices[0]; + strides_data[0] > 0 ? i < real_end_indices[0] + : i > real_end_indices[0]; + i += strides_data[0]) { + *output_data++ = input_data[i]; + } + + } else if (input->dim_size() == 2) { + for (index_t i = real_begin_indices[0]; + strides_data[0] > 0 ? i < real_end_indices[0] + : i > real_end_indices[0]; + i += strides_data[0]) { + for (index_t j = real_begin_indices[1]; + strides_data[1] > 0 ? j < real_end_indices[1] + : j > real_end_indices[1]; + j += strides_data[1]) { + *output_data++ = input_data[i * dim_stride[0] + j]; + } + } + } else { + MACE_NOT_IMPLEMENTED; + } + } + + return MACE_SUCCESS; + } + + int begin_mask_; + int end_mask_; + int ellipsis_mask_; + int new_axis_mask_; + int shrink_axis_mask_; +}; + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_STRIDED_SLICE_H_ diff --git a/mace/ops/shape.cc b/mace/ops/shape.cc new file mode 100644 index 0000000000000000000000000000000000000000..c65586e6e366197c4bd3d154bdfa73e66ca728a8 --- /dev/null +++ b/mace/ops/shape.cc @@ -0,0 +1,29 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/shape.h" + +namespace mace { +namespace ops { + +void Register_Shape(OperatorRegistry *op_registry) { + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Shape") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + ShapeOp); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/shape.h b/mace/ops/shape.h new file mode 100644 index 0000000000000000000000000000000000000000..aaac1b39447a6ee53bcd326733d3b9b61526d16a --- /dev/null +++ b/mace/ops/shape.h @@ -0,0 +1,57 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_OPS_SHAPE_H_ +#define MACE_OPS_SHAPE_H_ + +#include + +#include "mace/core/operator.h" + +namespace mace { +namespace ops { + +template +class ShapeOp : public Operator { + public: + ShapeOp(const OperatorDef &op_def, Workspace *ws) + : Operator(op_def, ws) {} + + MaceStatus Run(StatsFuture *future) override { + const Tensor *input = this->Input(INPUT); + Tensor *output = this->Output(OUTPUT); + if (input->dim_size() > 0) { + MACE_RETURN_IF_ERROR(output->Resize({input->dim_size()})); + } else { + output->Resize({}); + } + Tensor::MappingGuard output_guard(output); + int32_t *output_data = output->mutable_data(); + + for (index_t i = 0; i < input->dim_size(); ++i) { + output_data[i] = input->dim(i); + } + + return MACE_SUCCESS; + } + + private: + MACE_OP_INPUT_TAGS(INPUT); + MACE_OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_SHAPE_H_ diff --git a/mace/ops/shape_test.cc b/mace/ops/shape_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5798be7f8309970445cb3c8bf10e6327c2f52144 --- /dev/null +++ b/mace/ops/shape_test.cc @@ -0,0 +1,62 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class ShapeOpTest : public OpsTestBase {}; + +namespace { + +void TestShapeOp(const std::vector &input_shape) { + OpsTestNet net; + net.AddRandomInput("Input", input_shape); + OpDefBuilder("Shape", "ShapeOpTest") + .Input("Input") + .Output("Output") + .OutputType({DataTypeToEnum::v()}) + .Finalize(net.NewOperatorDef()); + + net.RunOp(); + + // we need to convert vector to vector + std::vector expected_input_shape(input_shape.begin(), + input_shape.end()); + if (!expected_input_shape.empty()) { + net.AddInputFromArray("ExpectedOutput", {input_shape.size()}, + expected_input_shape); + } else { + net.AddInputFromArray("ExpectedOutput", {}, {0}); + } + + ExpectTensorNear(*net.GetOutput("ExpectedOutput"), + *net.GetOutput("Output")); +} + +} // namespace + +TEST_F(ShapeOpTest, TestShape) { + TestShapeOp({1, 2, 3}); + TestShapeOp({2, 3}); + TestShapeOp({3}); + TestShapeOp({}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/stack.cc b/mace/ops/stack.cc new file mode 100644 index 0000000000000000000000000000000000000000..f951460a665404c7e34a17a4c71ee38b79c428a0 --- /dev/null +++ b/mace/ops/stack.cc @@ -0,0 +1,29 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/stack.h" + +namespace mace { +namespace ops { + +void Register_Stack(OperatorRegistry *op_registry) { + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Stack") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + StackOp); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/stack.h b/mace/ops/stack.h new file mode 100644 index 0000000000000000000000000000000000000000..27a90fc32ae6870dcac2c0c52fc42d17cb769f93 --- /dev/null +++ b/mace/ops/stack.h @@ -0,0 +1,50 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_OPS_STACK_H_ +#define MACE_OPS_STACK_H_ + +#include + +#include "mace/core/operator.h" +#include "mace/kernels/stack.h" + +namespace mace { +namespace ops { + +template +class StackOp : public Operator { + public: + StackOp(const OperatorDef &operator_def, Workspace *ws) + : Operator(operator_def, ws), + functor_(OperatorBase::GetOptionalArg("axis", 0)) {} + + MaceStatus Run(StatsFuture *future) override { + const std::vector &inputs = this->Inputs(); + Tensor *output = this->Output(OUTPUT); + + return functor_(inputs, output, future); + } + + private: + kernels::StackFunctor functor_; + + protected: + MACE_OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_STACK_H_ diff --git a/mace/ops/stack_test.cc b/mace/ops/stack_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d91c4a022f38130b259d64629f8d2eac2a6a2d9f --- /dev/null +++ b/mace/ops/stack_test.cc @@ -0,0 +1,76 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class StackOpTest : public OpsTestBase {}; + +namespace { + +void TestStack(const std::vector &input_shape, + const std::vector> &inputs, + int axis, + const std::vector &output_shape, + const std::vector &output) { + OpsTestNet net; + for (int i = 0; i < inputs.size(); ++i) { + net.AddInputFromArray(MakeString("Input", i), input_shape, + inputs[i]); + } + + auto op_builder = OpDefBuilder("Stack", "StackOpTest") + .Output("Output") + .AddIntArg("axis", axis); + + for (int i = 0; i < inputs.size(); ++i) { + op_builder.Input(MakeString("Input", i)); + } + op_builder.Finalize(net.NewOperatorDef()); + + net.RunOp(); + + net.AddInputFromArray("ExpectedOutput", output_shape, output); + ExpectTensorNear(*net.GetOutput("ExpectedOutput"), + *net.GetOutput("Output")); +} + +} // namespace + +TEST_F(StackOpTest, TestStackScalar) { + TestStack({}, {{1}, {2}, {3}}, 0, {3}, {1, 2, 3}); +} + +TEST_F(StackOpTest, TestStackVector) { + TestStack({2}, {{1, 4}, {2, 5}, {3, 6}}, 0, {3, 2}, {1, 4, 2, 5, 3, 6}); + TestStack({2}, {{1, 4}, {2, 5}, {3, 6}}, -2, {3, 2}, {1, 4, 2, 5, 3, 6}); + TestStack({2}, {{1, 4}, {2, 5}, {3, 6}}, 1, {2, 3}, {1, 2, 3, 4, 5, 6}); +} + +TEST_F(StackOpTest, TestStackHighRank) { + TestStack({2, 3}, {{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}}, -3, {2, 2, 3}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + TestStack({2, 3}, {{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}}, 1, {2, 2, 3}, + {1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}); + TestStack({2, 3}, {{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}}, 2, {2, 3, 2}, + {1, 7, 2, 8, 3, 9, 4, 10, 5, 11, 6, 12}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/strided_slice.cc b/mace/ops/strided_slice.cc new file mode 100644 index 0000000000000000000000000000000000000000..674c766f75a6bfa967266a428cb816557eef86fa --- /dev/null +++ b/mace/ops/strided_slice.cc @@ -0,0 +1,29 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/strided_slice.h" + +namespace mace { +namespace ops { + +void Register_StridedSlice(OperatorRegistry *op_registry) { + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("StridedSlice") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + StridedSliceOp); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/strided_slice.h b/mace/ops/strided_slice.h new file mode 100644 index 0000000000000000000000000000000000000000..e3e25db543f32e3b3361422a762159d50aeec69e --- /dev/null +++ b/mace/ops/strided_slice.h @@ -0,0 +1,56 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_OPS_STRIDED_SLICE_H_ +#define MACE_OPS_STRIDED_SLICE_H_ + +#include "mace/core/operator.h" +#include "mace/kernels/strided_slice.h" + +namespace mace { +namespace ops { + +template +class StridedSliceOp : public Operator { + public: + StridedSliceOp(const OperatorDef &operator_def, Workspace *ws) + : Operator(operator_def, ws), + functor_(OperatorBase::GetOptionalArg("begin_mask", 0), + OperatorBase::GetOptionalArg("end_mask", 0), + OperatorBase::GetOptionalArg("ellipsis_mask", 0), + OperatorBase::GetOptionalArg("new_axis_mask", 0), + OperatorBase::GetOptionalArg("shrink_axis_mask", 0)) {} + + MaceStatus Run(StatsFuture *future) override { + const Tensor *input = this->Input(INPUT); + const Tensor *begin_indices = this->Input(BEGIN); + const Tensor *end_indices = this->Input(END); + const Tensor *strides = this->Input(STRIDES); + Tensor *output = this->Output(OUTPUT); + + return functor_(input, begin_indices, end_indices, strides, output, future); + } + + private: + kernels::StridedSliceFunctor functor_; + + protected: + MACE_OP_INPUT_TAGS(INPUT, BEGIN, END, STRIDES); + MACE_OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_STRIDED_SLICE_H_ diff --git a/mace/ops/strided_slice_test.cc b/mace/ops/strided_slice_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2aa4af2820488a7ea7fb0a293f05e7b7ad1802bf --- /dev/null +++ b/mace/ops/strided_slice_test.cc @@ -0,0 +1,111 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class StridedSliceOpTest : public OpsTestBase {}; + +namespace { + +void TestSlice(const std::vector &input_shape, + const std::vector &input, + const std::vector &begin_indices, + const std::vector &end_indices, + const std::vector &strides, + const int begin_mask, + const int end_mask, + const int ellipsis_mask, + const int new_axis_mask, + const int shrink_axis_mask, + const std::vector &output_shape, + const std::vector &output) { + OpsTestNet net; + net.AddInputFromArray("Input", input_shape, input); + net.AddInputFromArray("BeginIndices", {input_shape.size()}, + begin_indices); + net.AddInputFromArray("EndIndices", {input_shape.size()}, + end_indices); + net.AddInputFromArray("Strides", {input_shape.size()}, strides); + + OpDefBuilder("StridedSlice", "StridedSliceOpTest") + .Input("Input") + .Input("BeginIndices") + .Input("EndIndices") + .Input("Strides") + .Output("Output") + .AddIntArg("begin_mask", begin_mask) + .AddIntArg("end_mask", end_mask) + .AddIntArg("ellipsis_mask", ellipsis_mask) + .AddIntArg("new_axis_mask", new_axis_mask) + .AddIntArg("shrink_axis_mask", shrink_axis_mask) + .Finalize(net.NewOperatorDef()); + + net.RunOp(); + + net.AddInputFromArray("ExpectedOutput", output_shape, output); + ExpectTensorNear(*net.GetOutput("ExpectedOutput"), + *net.GetOutput("Output")); +} + +} // namespace + +TEST_F(StridedSliceOpTest, TestSliceByFirstAxis) { + TestSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 0, 0}, + {2, 3, 2}, {1, 1, 1}, 0, 0, 0, 0, 0, {1, 3, 2}, + {7, 8, 9, 10, 11, 12}); + TestSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 0, 0}, + {2, 3, 2}, {1, 1, 1}, 0, 0, 0, 0, 1, {3, 2}, {7, 8, 9, 10, 11, 12}); + TestSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 1, 2}, + {2, 3, 2}, {1, 1, 1}, 6, 6, 0, 0, 0, {1, 3, 2}, + {7, 8, 9, 10, 11, 12}); +} + +TEST_F(StridedSliceOpTest, TestSliceRank1) { + TestSlice({4}, {1, 2, 3, 4}, {1}, {3}, {1}, 0, 0, 0, 0, 0, {2}, {2, 3}); + TestSlice({4}, {1, 2, 3, 4}, {-3}, {3}, {1}, 0, 0, 0, 0, 0, {2}, {2, 3}); + TestSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 0, 0, 0, 0, 0, {2}, {3, 2}); + TestSlice({4}, {1, 2, 3, 4}, {-1}, {-4}, {-2}, 0, 0, 0, 0, 0, {2}, {4, 2}); + TestSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 1, 0, 0, 0, 0, {3}, {4, 3, 2}); + TestSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 0, 1, 0, 0, 0, {3}, {3, 2, 1}); + TestSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 1, 1, 0, 0, 0, {4}, + {4, 3, 2, 1}); + TestSlice({4}, {1, 2, 3, 4}, {2}, {4}, {2}, 1, 1, 0, 0, 0, {2}, {1, 3}); + TestSlice({4}, {1, 2, 3, 4}, {2}, {3}, {1}, 0, 0, 0, 0, 1, {}, {3}); +} + +TEST_F(StridedSliceOpTest, TestSliceRank2) { + TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {1, 1}, 0, 0, 0, 0, 0, + {2, 3}, {1, 2, 3, 4, 5, 6}); + TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1}, {2, 3}, {1, 1}, 0, 0, 0, 0, 0, + {1, 2}, {5, 6}); + TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {1, 2}, 0, 0, 0, 0, 0, + {2, 2}, {1, 3, 4, 6}); + TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2}, {0, 0}, {-1, -1}, 0, 0, 0, 0, 0, + {1, 2}, {6, 5}); + TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2}, {0, 0}, {-1, -1}, 3, 3, 0, 0, 0, + {2, 3}, {6, 5, 4, 3, 2, 1}); + TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 0}, {2, 3}, {1, 1}, 0, 0, 0, 0, 1, + {3}, {4, 5, 6}); + TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2}, {2, 3}, {1, 1}, 0, 0, 0, 0, 3, + {}, {6}); +} + +} // namespace test +} // namespace ops +} // namespace mace