diff --git a/mace/ops/cumsum.cc b/mace/ops/cumsum.cc index cb9ecb83bb1132d081aded153a0276c0339f3fa3..f0117270c80ce25bda50ab8e8461302b521c484e 100644 --- a/mace/ops/cumsum.cc +++ b/mace/ops/cumsum.cc @@ -19,26 +19,6 @@ namespace mace { namespace ops { -namespace { -void PlusOne(int* val) { - ++(*val); -} - -void SubOne(int* val) { - --(*val); -} - -bool LessThan(const int& val, const int& boundary) { - return val < boundary; -} - -bool NotLessThanZero(const int& val, const int& boundary) { - MACE_UNUSED(boundary); - return val >= 0; -} - -} // namespace - template class CumsumOp; @@ -47,9 +27,10 @@ class CumsumOp : public Operation { public: explicit CumsumOp(OpConstructContext *context) : Operation(context), - axis_(Operation::GetOptionalArg("axis", 3)), + axis_(Operation::GetOptionalArg("axis", 0)), exclusive_(Operation::GetOptionalArg("exclusive", false)), - reverse_(Operation::GetOptionalArg("reverse", false)) {} + reverse_(Operation::GetOptionalArg("reverse", false)), + checked_(false) {} void Validate() { const int32_t input_dims = this->Input(0)->dim_size(); @@ -64,9 +45,9 @@ class CumsumOp : public Operation { MACE_UNUSED(context); if (!checked_) { Validate(); - auto df = static_cast(Operation::GetOptionalArg( - "data_format", DataFormat::DF_NONE)); - if (df == DataFormat::NHWC && this->Input(0)->dim_size() == 4) { + bool has_data_format = Operation::GetOptionalArg( + "has_data_format", 0); + if (has_data_format && this->Input(0)->dim_size() == 4) { if (axis_ == 3) axis_ = 1; else if (axis_ == 2) axis_ = 3; else if (axis_ == 1) axis_ = 2; @@ -75,6 +56,7 @@ class CumsumOp : public Operation { } const Tensor *input = this->Input(0); + const std::vector input_shape = input->shape(); Tensor *output = this->Output(0); MACE_RETURN_IF_ERROR(output->ResizeLike(input)); @@ -85,66 +67,70 @@ class CumsumOp : public Operation { const float *input_ptr = input->data(); float *output_ptr = output->mutable_data(); - std::function next = reverse_ ? SubOne : PlusOne; - std::function previous = reverse_ ? PlusOne : SubOne; - std::function boundary = - reverse_ ? NotLessThanZero : LessThan; - - if (input->dim_size() == 4) { - const int batch = input->dim(0); - const int channel = input->dim(1); - const int height = input->dim(2); - const int width = input->dim(3); - - const int axis_dim_size = input->dim(axis_); - - for (int n = reverse_ ? batch - 1 : 0; boundary(n, batch); next(&n)) { - for (int c = reverse_ ? channel - 1 : 0; boundary(c, channel); - next(&c)) { - for (int h = reverse_ ? height - 1 : 0; boundary(h, height); - next(&h)) { - for (int w = reverse_ ? width - 1 : 0; boundary(w, width); - next(&w)) { - int dims[4] = {n, c, h, w}; - if (!reverse_ && dims[axis_] == 0) { - if (exclusive_) { - output_ptr[((n * channel + c) * height + h) * width + w] = 0; - } else { - continue; - } - } else if (reverse_ && dims[axis_] == axis_dim_size - 1) { - if (exclusive_) { - output_ptr[((n * channel + c) * height + h) * width + w] = 0; - } else { - continue; - } - } else { - previous(&dims[axis_]); - if (exclusive_) { - output_ptr[((n * channel + c) * height + h) * width + w] = - input_ptr[((dims[0] * channel + dims[1]) * height + - dims[2]) * - width + - dims[3]] + - output_ptr[((dims[0] * channel + dims[1]) * height + - dims[2]) * - width + - dims[3]]; - } else { - output_ptr[((n * channel + c) * height + h) * width + w] = - input_ptr[((n * channel + c) * height + h) * width + w] + - output_ptr[((dims[0] * channel + dims[1]) * height + - dims[2]) * - width + - dims[3]]; - } - } + const index_t outer_size = std::accumulate(input_shape.begin(), + input_shape.begin() + axis_, + 1, + std::multiplies()); + const index_t inner_size = std::accumulate(input_shape.begin() + axis_ + 1, + input_shape.end(), + 1, + std::multiplies()); + const index_t cum_size = input_shape[axis_]; + + if (!reverse_) { +#pragma omp parallel for + for (index_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) { + index_t start_idx = outer_idx * cum_size * inner_size; + for (index_t cum_idx = 0; cum_idx < cum_size; ++cum_idx) { + if (cum_idx == 0) { + if (exclusive_) { + std::memset(output_ptr + start_idx, + 0, + sizeof(T) * inner_size); + } else { + std::memcpy(output_ptr + start_idx, + input_ptr + start_idx, + sizeof(T) * inner_size); + } + } else { + index_t cur_idx = start_idx + cum_idx * inner_size; + index_t pre_idx = start_idx + (cum_idx - 1) * inner_size; + index_t input_idx = exclusive_ ? pre_idx : cur_idx; + for (index_t inner_idx = 0; inner_idx < inner_size; ++inner_idx) { + output_ptr[cur_idx + inner_idx] = + output_ptr[pre_idx + inner_idx] + + input_ptr[input_idx + inner_idx]; } } } } } else { - MACE_NOT_IMPLEMENTED; +#pragma omp parallel for + for (index_t outer_idx = outer_size - 1; outer_idx >= 0; --outer_idx) { + index_t start_idx = outer_idx * cum_size * inner_size; + for (index_t cum_idx = cum_size - 1; cum_idx >= 0; --cum_idx) { + index_t cur_idx = start_idx + cum_idx * inner_size; + if (cum_idx == cum_size - 1) { + if (exclusive_) { + std::memset(output_ptr + cur_idx, + 0, + sizeof(T) * inner_size); + } else { + std::memcpy(output_ptr + cur_idx, + input_ptr + cur_idx, + sizeof(T) * inner_size); + } + } else { + index_t pre_idx = start_idx + (cum_idx + 1) * inner_size; + index_t input_idx = exclusive_ ? pre_idx : cur_idx; + for (index_t inner_idx = 0; inner_idx < inner_size; ++inner_idx) { + output_ptr[cur_idx + inner_idx] = + output_ptr[pre_idx + inner_idx] + + input_ptr[input_idx + inner_idx]; + } + } + } + } } return MaceStatus::MACE_SUCCESS; diff --git a/mace/ops/cumsum_benchmark.cc b/mace/ops/cumsum_benchmark.cc new file mode 100644 index 0000000000000000000000000000000000000000..8ca59fa0501fe92a35fbe0a02141cdd23a7c1198 --- /dev/null +++ b/mace/ops/cumsum_benchmark.cc @@ -0,0 +1,90 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class CumsumOpTest : public OpsTestBase {}; + +namespace { +template +void Cumsum(int iters, int batch, int channels, int height, int width) { + mace::testing::StopTiming(); + + // Construct graph + OpsTestNet net; + + // Add input data + if (D == DeviceType::CPU) { + net.AddRandomInput("Input", {batch, channels, height, width}); + } else { + MACE_NOT_IMPLEMENTED; + } + + OpDefBuilder("Cumsum", "CumsumTest") + .Input("Input") + .Output("Output") + .AddIntArg("axis", 0) + .AddIntArg("exclusive", 0) + .AddIntArg("reverse", 0) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + } + net.Sync(); +} +} // namespace + +#define MACE_BM_CUMSUM_MACRO(N, C, H, W, TYPE, DEVICE) \ + static void MACE_BM_CUMSUM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * C * H * W; \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + Cumsum(iters, N, C, H, W); \ + } \ + MACE_BENCHMARK(MACE_BM_CUMSUM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) + +#define MACE_BM_CUMSUM(N, C, H, W) \ + MACE_BM_CUMSUM_MACRO(N, C, H, W, float, CPU); + +MACE_BM_CUMSUM(1, 1, 512, 512); +MACE_BM_CUMSUM(1, 3, 128, 128); +MACE_BM_CUMSUM(1, 3, 512, 512); +MACE_BM_CUMSUM(1, 32, 112, 112); +MACE_BM_CUMSUM(1, 64, 256, 256); +MACE_BM_CUMSUM(1, 64, 512, 512); +MACE_BM_CUMSUM(1, 128, 56, 56); +MACE_BM_CUMSUM(1, 128, 256, 256); +MACE_BM_CUMSUM(1, 256, 14, 14); +MACE_BM_CUMSUM(1, 512, 14, 14); +MACE_BM_CUMSUM(1, 1024, 7, 7); +MACE_BM_CUMSUM(32, 1, 256, 256); +MACE_BM_CUMSUM(32, 3, 256, 256); + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/cumsum_test.cc b/mace/ops/cumsum_test.cc index 473b8802a67f962a0bec44d391b7584165218c9a..8b111540c9040a391ae419d86e3c042b23954b5e 100644 --- a/mace/ops/cumsum_test.cc +++ b/mace/ops/cumsum_test.cc @@ -12,9 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include - #include "mace/ops/ops_test_util.h" namespace mace { @@ -24,33 +21,69 @@ namespace test { class CumsumOpTest : public OpsTestBase {}; namespace { -void SimpleTest() { +template +void SimpleTestWithDataFormat(const std::vector &shape, + const std::vector &input, + const int axis, + const int exclusive, + const int reverse, + const std::vector &output) { // Construct graph OpsTestNet net; - net.AddInputFromArray("Input", {2, 2, 2, 2}, - {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + net.AddInputFromArray("Input", shape, input); + net.TransformDataFormat("Input", NHWC, "InputNCHW", + NCHW); OpDefBuilder("Cumsum", "CumsumTest") - .Input("Input") - .Output("Output") - .AddIntArg("axis", 1) - .AddIntArg("exclusive", 1) - .AddIntArg("reverse", 1) - .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Input("InputNCHW") + .Output("OutputNCHW") + .AddIntArg("axis", axis) + .AddIntArg("exclusive", exclusive) + .AddIntArg("reverse", reverse) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .AddIntArg("has_data_format", 1) .Finalize(net.NewOperatorDef()); // Run net.RunOp(DeviceType::CPU); - auto expected = net.CreateTensor({2, 2, 2, 2}, - {4., 5., 6., 7., 0., 0., 0., 0., 12., 13., 14., 15., 0., 0., 0., 0.}); - ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); + net.TransformDataFormat("OutputNCHW", NCHW, "Output", + NHWC); + + net.AddInputFromArray("ExpectedOutput", shape, output); + ExpectTensorNear(*net.GetOutput("ExpectedOutput"), + *net.GetOutput("Output")); } } // namespace -TEST_F(CumsumOpTest, CPU) { - SimpleTest(); +TEST_F(CumsumOpTest, HasDataFormatCPU) { + SimpleTestWithDataFormat( + {2, 2, 2, 2}, + {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}, + 0, 0, 0, + {0., 1., 2., 3., 4., 5., 6., 7., 8., 10., 12., 14., 16., 18., 20., 22.}); + SimpleTestWithDataFormat( + {2, 2, 2, 2}, + {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}, + 1, 0, 0, + {0., 1., 2., 3., 4., 6., 8., 10., 8., 9., 10., 11., 20., 22., 24., 26.}); + SimpleTestWithDataFormat( + {2, 2, 2, 2}, + {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}, + 0, 1, 0, + {0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 2., 3., 4., 5., 6., 7.}); + SimpleTestWithDataFormat( + {2, 2, 2, 2}, + {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}, + 0, 0, 1, + {8., 10., 12., 14., 16., 18., 20., 22., 8., 9., 10., 11., 12., 13., 14., + 15.}); + SimpleTestWithDataFormat( + {2, 2, 2, 2}, + {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}, + 1, 1, 1, + {4., 5., 6., 7., 0., 0., 0., 0., 12., 13., 14., 15., 0., 0., 0., 0.}); } } // namespace test diff --git a/mace/ops/strided_slice.cc b/mace/ops/strided_slice.cc index 6087ae50478a70dc7b7b74a0457264d629991074..c10914f27fb87e7e1159749eb990a66bb6506f42 100644 --- a/mace/ops/strided_slice.cc +++ b/mace/ops/strided_slice.cc @@ -17,6 +17,7 @@ #include #include "mace/core/operator.h" +#include "mace/utils/math.h" namespace mace { namespace ops { @@ -33,6 +34,7 @@ class StridedSliceOp : public Operation { shrink_axis_mask_( Operation::GetOptionalArg("shrink_axis_mask", 0)), is_slice_(Operation::GetOptionalArg("slice", false)), + has_data_format_(Operation::GetOptionalArg("has_data_format", 0)), checked_(false) { MACE_CHECK(ellipsis_mask_ == 0 && new_axis_mask_ == 0, "ellipsis_mask and new_axis_mask are not supported yet."); @@ -62,14 +64,21 @@ class StridedSliceOp : public Operation { (*dims)[3] = w; } + void TransposeDimsFromNCHWToNHWC(std::vector* dims) { + int32_t c = (*dims)[1]; + int32_t h = (*dims)[2]; + int32_t w = (*dims)[3]; + + (*dims)[1] = h; + (*dims)[2] = w; + (*dims)[3] = c; + } + MaceStatus Run(OpContext *context) override { MACE_UNUSED(context); - auto df = static_cast(Operation::GetOptionalArg( - "data_format", DataFormat::DF_NONE)); - if (!checked_) { - if (df == DataFormat::NHWC && this->Input(0)->dim_size() == 4) { + if (has_data_format_ && this->Input(0)->dim_size() == 4) { TransposeMaskValueFromNHWCToNCHW(&begin_mask_); TransposeMaskValueFromNHWCToNCHW(&end_mask_); TransposeMaskValueFromNHWCToNCHW(&ellipsis_mask_); @@ -78,14 +87,15 @@ class StridedSliceOp : public Operation { } checked_ = true; } + const Tensor *input = this->Input(INPUT); const Tensor *begin_indices = this->Input(BEGIN); const Tensor *end_indices = this->Input(END); const Tensor *strides = nullptr; + if (this->InputSize() > 3) { strides = this->Input(STRIDES); } - Tensor *output = this->Output(OUTPUT); if (strides == nullptr) { tmp_strides_tensor_.Resize({begin_indices->size()}); Tensor::MappingGuard strides_guard(&tmp_strides_tensor_); @@ -94,6 +104,11 @@ class StridedSliceOp : public Operation { strides = &tmp_strides_tensor_; } + MACE_CHECK(begin_indices->dim_size() == 1 && + end_indices->dim_size() == 1 && + strides->dim_size() == 1, + "Expected begin, end, and strides to be 1D tensor"); + Tensor::MappingGuard input_guard(input); Tensor::MappingGuard begin_indices_guard(begin_indices); Tensor::MappingGuard end_indices_guard(end_indices); @@ -102,107 +117,145 @@ class StridedSliceOp : public Operation { const int32_t *begin_indices_data = begin_indices->data(); const int32_t *end_indices_data = end_indices->data(); const int32_t *strides_data = strides->data(); - std::vector pad_begin_indices(input->dim_size(), 0); - std::vector pad_end_indices(input->dim_size(), 0); - std::vector pad_strides_indices(input->dim_size(), 1); - - if (begin_indices->size() < input->dim_size()) { - for (index_t i = 0; i < begin_indices->size(); ++i) { - pad_begin_indices[i] = begin_indices_data[i]; - pad_end_indices[i] = end_indices_data[i]; - pad_strides_indices[i] = strides_data[i]; - } - for (index_t i = begin_indices->size(); i < input->dim_size(); ++i) { - pad_end_indices[i] = input->dim(i); - } - begin_indices_data = pad_begin_indices.data(); - end_indices_data = pad_end_indices.data(); - strides_data = pad_strides_indices.data(); - } + std::vector begin_indices_vec( + begin_indices_data, begin_indices_data + begin_indices->size()); + std::vector end_indices_vec( + end_indices_data, end_indices_data + end_indices->size()); + std::vector strides_indices_vec( + strides_data, strides_data + strides->size()); - std::vector transpose_begin_indices(input->dim_size(), 0); - std::vector transpose_end_indices(input->dim_size(), 0); - std::vector transpose_strides_indices(input->dim_size(), 1); - if (df == DataFormat::NHWC && this->Input(0)->dim_size() == 4) { - for (index_t i = 0; i < begin_indices->size(); ++i) { - transpose_begin_indices[i] = begin_indices_data[i]; - transpose_end_indices[i] = end_indices_data[i]; - transpose_strides_indices[i] = strides_data[i]; - } - TransposeDimsFromNHWCToNCHW(&transpose_begin_indices); - TransposeDimsFromNHWCToNCHW(&transpose_end_indices); - TransposeDimsFromNHWCToNCHW(&transpose_strides_indices); + MACE_CHECK(input->size() > 0 && input->dim_size() > 0 && + input->dim_size() <= 4, + "The input size should larger than 0." + " And input dims should be an integer in (0, 4]."); - begin_indices_data = transpose_begin_indices.data(); - end_indices_data = transpose_end_indices.data(); - strides_data = transpose_strides_indices.data(); - } + std::vector output_shape = {}; - std::vector slice_end_data; + const size_t input_dims = input->dim_size(); if (is_slice_) { - // if this op is slice, the end_indices_data is size actually - slice_end_data.resize(end_indices->size()); - for (size_t i = 0; i < slice_end_data.size(); ++i) { - if (end_indices_data[i] == -1) { - slice_end_data[i] = input->dim(i); - } else { - slice_end_data[i] = begin_indices_data[i] + end_indices_data[i]; + MACE_CHECK(begin_indices_vec.size() == input_dims && + end_indices_vec.size() == input_dims, + "In slice, begin and size elements num should be equal"); + + // transpose + if (has_data_format_ && this->Input(0)->dim_size() == 4) { + TransposeDimsFromNHWCToNCHW(&begin_indices_vec); + TransposeDimsFromNHWCToNCHW(&end_indices_vec); + TransposeDimsFromNHWCToNCHW(&strides_indices_vec); + } + + for (size_t i = 0; i < input_dims; ++i) { + if (end_indices_vec[i] == -1) { + end_indices_vec[i] = input->dim(i) - begin_indices_vec[i]; } } - end_indices_data = slice_end_data.data(); - } - std::vector output_shape; - std::vector real_begin_indices(input->dim_size(), 0); - std::vector real_end_indices(input->dim_size(), 0); - for (index_t d = 0; d < input->dim_size(); ++d) { - index_t dim_len = input->dim(d); - if (begin_mask_ & (1 << d)) { - real_begin_indices[d] = strides_data[d] > 0 ? 0 : dim_len - 1; - } else { - real_begin_indices[d] = (begin_indices_data[d] + dim_len) % dim_len; + for (size_t i = 0; i < input_dims; ++i) { + int32_t b = begin_indices_vec[i]; + int32_t s = end_indices_vec[i]; + int32_t input_i = input->dim(i); + MACE_CHECK(0 <= b && b <= input_i, + "In Slice, expected begin[", i, "] in [0, ", input_i, + "], but got ", b); + MACE_CHECK(0 <= s && b + s <= input_i, + "In Slice, expected size[", i, "] in [0, ", + input_i - b, "], but got", s); + end_indices_vec[i] = b + s; + output_shape.push_back(s); } - if (end_mask_ & (1 << d)) { - real_end_indices[d] = strides_data[d] > 0 ? dim_len : -1; - } else { - real_end_indices[d] = - end_indices_data[d] < -dim_len - ? -1 - : (end_indices_data[d] < 0 - ? (end_indices_data[d] + dim_len) - : std::min(static_cast(end_indices_data[d]), - dim_len)); + } else { + MACE_CHECK(begin_indices_vec.size() == end_indices_vec.size() && + end_indices_vec.size() == strides_indices_vec.size(), + "In strided_slice, expected begin, end, and strides to be", + " equal size tensors"); + for (index_t i = 0; i < strides->size(); ++i) { + MACE_CHECK(strides_indices_vec[i] != 0, "strides data cannot be 0!"); } - int32_t out_dim_len = std::max( - 0.f, std::ceil((real_end_indices[d] - real_begin_indices[d]) / - static_cast(strides_data[d]))); - if (!(shrink_axis_mask_ & (1 << d))) { - output_shape.push_back(out_dim_len); - } else { - MACE_CHECK(out_dim_len == 1, - "cannot shrink axis that has len > 1, dim(", d, "): [", - real_begin_indices[d], ", ", real_end_indices[d], "]"); + // pad + begin_indices_vec.resize(input_dims, 0); + strides_indices_vec.resize(input_dims, 1); + std::vector tmp_input_dims(input->shape().begin(), + input->shape().end()); + if (has_data_format_ && input_dims == 4) { + TransposeDimsFromNCHWToNHWC(&tmp_input_dims); + } + for (size_t i = end_indices_vec.size(); i < input_dims; ++i) { + end_indices_vec.push_back(tmp_input_dims[i]); + } + + // transpose + if (has_data_format_ && this->Input(0)->dim_size() == 4) { + TransposeDimsFromNHWCToNCHW(&begin_indices_vec); + TransposeDimsFromNHWCToNCHW(&end_indices_vec); + TransposeDimsFromNHWCToNCHW(&strides_indices_vec); + } + + // mask and shrink + for (index_t d = 0; d < input->dim_size(); ++d) { + index_t dim_len = input->dim(d); + const std::vector valid_range = { + strides_indices_vec[d] > 0 ? 0 : -1, + strides_indices_vec[d] > 0 ? dim_len : dim_len - 1}; + + auto format_indices = [valid_range, d, dim_len](index_t indice) { + index_t forward = indice < 0 ? indice + dim_len : indice; + return Clamp(forward, valid_range[0], valid_range[1]); + }; + + if (!(shrink_axis_mask_ & (1 << d))) { + if (begin_mask_ & (1 << d)) { + begin_indices_vec[d] = strides_indices_vec[d] > 0 ? 0 : dim_len - 1; + } else { + begin_indices_vec[d] = format_indices(begin_indices_vec[d]); + } + if (end_mask_ & (1 << d)) { + end_indices_vec[d] = strides_indices_vec[d] > 0 ? dim_len : -1; + } else { + end_indices_vec[d] = format_indices(end_indices_vec[d]); + } + + int32_t out_dim_len = std::max( + 0.f, std::ceil((end_indices_vec[d] - begin_indices_vec[d]) / + static_cast(strides_indices_vec[d]))); + output_shape.push_back(out_dim_len); + } else { + begin_indices_vec[d] = begin_indices_vec[d] < 0 + ? begin_indices_vec[d] + dim_len + : begin_indices_vec[d]; + end_indices_vec[d] = begin_indices_vec[d] + 1; + MACE_CHECK( + begin_indices_vec[d] >= 0 && begin_indices_vec[d] < dim_len, + "slice begin indice of dimension '", d, "': ", + begin_indices_vec[d], ", is out of bound"); + } } } + for (size_t i = 0; i < output_shape.size(); ++i) { + MACE_CHECK(output_shape[i] > 0, + "Expected output_shape[", i, "] larger than 0, but got ", + output_shape[i]); + } + std::vector dim_stride(input->dim_size(), 1); for (index_t d = input->dim_size() - 2; d >= 0; --d) { dim_stride[d] = dim_stride[d + 1] * input->dim(d + 1); } + Tensor *output = this->Output(OUTPUT); MACE_RETURN_IF_ERROR(output->Resize(output_shape)); Tensor::MappingGuard output_guard(output); T *output_data = output->mutable_data(); bool slice_by_first_axis = true; - if (strides_data[0] != 1) { + if (strides_indices_vec[0] != 1) { slice_by_first_axis = false; } else { for (index_t d = 1; d < input->dim_size(); ++d) { - if (strides_data[d] != 1 || real_begin_indices[d] != 0 || - real_end_indices[d] != input->dim(d)) { + if (strides_indices_vec[d] != 1 || begin_indices_vec[d] != 0 || + end_indices_vec[d] != input->dim(d)) { slice_by_first_axis = false; break; } @@ -210,64 +263,64 @@ class StridedSliceOp : public Operation { } if (slice_by_first_axis) { - memcpy(output_data, input_data + real_begin_indices[0] * dim_stride[0], - sizeof(T) * (real_end_indices[0] - real_begin_indices[0]) * + memcpy(output_data, input_data + begin_indices_vec[0] * dim_stride[0], + sizeof(T) * (end_indices_vec[0] - begin_indices_vec[0]) * dim_stride[0]); } else { if (input->dim_size() == 1) { - for (index_t i = real_begin_indices[0]; - strides_data[0] > 0 ? i < real_end_indices[0] - : i > real_end_indices[0]; - i += strides_data[0]) { + for (index_t i = begin_indices_vec[0]; + strides_indices_vec[0] > 0 ? i < end_indices_vec[0] + : i > end_indices_vec[0]; + i += strides_indices_vec[0]) { *output_data++ = input_data[i]; } } else if (input->dim_size() == 2) { - for (index_t i = real_begin_indices[0]; - strides_data[0] > 0 ? i < real_end_indices[0] - : i > real_end_indices[0]; - i += strides_data[0]) { - for (index_t j = real_begin_indices[1]; - strides_data[1] > 0 ? j < real_end_indices[1] - : j > real_end_indices[1]; - j += strides_data[1]) { + for (index_t i = begin_indices_vec[0]; + strides_indices_vec[0] > 0 ? i < end_indices_vec[0] + : i > end_indices_vec[0]; + i += strides_indices_vec[0]) { + for (index_t j = begin_indices_vec[1]; + strides_indices_vec[1] > 0 ? j < end_indices_vec[1] + : j > end_indices_vec[1]; + j += strides_indices_vec[1]) { *output_data++ = input_data[i * input->dim(1) + j]; } } } else if (input->dim_size() == 3) { - for (index_t i = real_begin_indices[0]; - strides_data[0] > 0 ? i < real_end_indices[0] - : i > real_end_indices[0]; - i += strides_data[0]) { - for (index_t j = real_begin_indices[1]; - strides_data[1] > 0 ? j < real_end_indices[1] - : j > real_end_indices[1]; - j += strides_data[1]) { - for (index_t k = real_begin_indices[2]; - strides_data[2] > 0 ? k < real_end_indices[2] - : k > real_end_indices[2]; - k += strides_data[2]) { + for (index_t i = begin_indices_vec[0]; + strides_indices_vec[0] > 0 ? i < end_indices_vec[0] + : i > end_indices_vec[0]; + i += strides_indices_vec[0]) { + for (index_t j = begin_indices_vec[1]; + strides_indices_vec[1] > 0 ? j < end_indices_vec[1] + : j > end_indices_vec[1]; + j += strides_indices_vec[1]) { + for (index_t k = begin_indices_vec[2]; + strides_indices_vec[2] > 0 ? k < end_indices_vec[2] + : k > end_indices_vec[2]; + k += strides_indices_vec[2]) { *output_data++ = input_data[(i * input->dim(1) + j) * input->dim(2) + k]; } } } } else if (input->dim_size() == 4) { - for (index_t i = real_begin_indices[0]; - strides_data[0] > 0 ? i < real_end_indices[0] - : i > real_end_indices[0]; - i += strides_data[0]) { - for (index_t j = real_begin_indices[1]; - strides_data[1] > 0 ? j < real_end_indices[1] - : j > real_end_indices[1]; - j += strides_data[1]) { - for (index_t k = real_begin_indices[2]; - strides_data[2] > 0 ? k < real_end_indices[2] - : k > real_end_indices[2]; - k += strides_data[2]) { - for (index_t l = real_begin_indices[3]; - strides_data[3] > 0 ? l < real_end_indices[3] - : l > real_end_indices[3]; - l += strides_data[3]) { + for (index_t i = begin_indices_vec[0]; + strides_indices_vec[0] > 0 ? i < end_indices_vec[0] + : i > end_indices_vec[0]; + i += strides_indices_vec[0]) { + for (index_t j = begin_indices_vec[1]; + strides_indices_vec[1] > 0 ? j < end_indices_vec[1] + : j > end_indices_vec[1]; + j += strides_indices_vec[1]) { + for (index_t k = begin_indices_vec[2]; + strides_indices_vec[2] > 0 ? k < end_indices_vec[2] + : k > end_indices_vec[2]; + k += strides_indices_vec[2]) { + for (index_t l = begin_indices_vec[3]; + strides_indices_vec[3] > 0 ? l < end_indices_vec[3] + : l > end_indices_vec[3]; + l += strides_indices_vec[3]) { *output_data++ = input_data[((i * input->dim(1) + j) * input->dim(2) + k) * input->dim(3) + l]; @@ -289,6 +342,7 @@ class StridedSliceOp : public Operation { int new_axis_mask_; int shrink_axis_mask_; bool is_slice_; + int has_data_format_; bool checked_; Tensor tmp_strides_tensor_; diff --git a/mace/ops/strided_slice_test.cc b/mace/ops/strided_slice_test.cc index df691ce682f2a0a55db2f93a9077a265f61cbef0..8b085fe532694f7c343e0cfda735d91332aea294 100644 --- a/mace/ops/strided_slice_test.cc +++ b/mace/ops/strided_slice_test.cc @@ -64,6 +64,54 @@ void TestStridedSlice(const std::vector &input_shape, *net.GetOutput("Output")); } +void TestStridedSliceWithDataFormat(const std::vector &input_shape, + const std::vector &input, + const std::vector &begin_indices, + const std::vector &end_indices, + const std::vector &strides, + const int begin_mask, + const int end_mask, + const int ellipsis_mask, + const int new_axis_mask, + const int shrink_axis_mask, + const std::vector &output_shape, + const std::vector &output) { + OpsTestNet net; + net.AddInputFromArray("Input", input_shape, input); + net.AddInputFromArray( + "BeginIndices", {static_cast(begin_indices.size())}, + begin_indices); + net.AddInputFromArray( + "EndIndices", {static_cast(end_indices.size())}, end_indices); + net.AddInputFromArray( + "Strides", {static_cast(strides.size())}, strides); + + net.TransformDataFormat("Input", NHWC, "InputNCHW", + NCHW); + + OpDefBuilder("StridedSlice", "StridedSliceOpTest") + .Input("InputNCHW") + .Input("BeginIndices") + .Input("EndIndices") + .Input("Strides") + .Output("OutputNCHW") + .AddIntArg("begin_mask", begin_mask) + .AddIntArg("end_mask", end_mask) + .AddIntArg("ellipsis_mask", ellipsis_mask) + .AddIntArg("new_axis_mask", new_axis_mask) + .AddIntArg("shrink_axis_mask", shrink_axis_mask) + .AddIntArg("has_data_format", 1) + .Finalize(net.NewOperatorDef()); + + net.RunOp(); + + net.TransformDataFormat("OutputNCHW", NCHW, "Output", + NHWC); + net.AddInputFromArray("ExpectedOutput", output_shape, output); + ExpectTensorNear(*net.GetOutput("ExpectedOutput"), + *net.GetOutput("Output")); +} + void TestSlice(const std::vector &input_shape, const std::vector &input, const std::vector &begin_indices, @@ -92,6 +140,41 @@ void TestSlice(const std::vector &input_shape, *net.GetOutput("Output")); } +void TestSliceWithDataFormat(const std::vector &input_shape, + const std::vector &input, + const std::vector &begin_indices, + const std::vector &indices_size, + const std::vector &output_shape, + const std::vector &output) { + OpsTestNet net; + net.AddInputFromArray("Input", input_shape, input); + net.AddInputFromArray( + "BeginIndices", {static_cast(input_shape.size())}, + begin_indices); + net.AddInputFromArray( + "IndicesSize", {static_cast(indices_size.size())}, indices_size); + + net.TransformDataFormat("Input", NHWC, "InputNCHW", + NCHW); + + OpDefBuilder("StridedSlice", "StridedSliceOpTest") + .Input("InputNCHW") + .Input("BeginIndices") + .Input("IndicesSize") + .Output("OutputNCHW") + .AddIntArg("slice", 1) + .AddIntArg("has_data_format", 1) + .Finalize(net.NewOperatorDef()); + + net.RunOp(); + + net.TransformDataFormat("OutputNCHW", NCHW, "Output", + NHWC); + net.AddInputFromArray("ExpectedOutput", output_shape, output); + ExpectTensorNear(*net.GetOutput("ExpectedOutput"), + *net.GetOutput("Output")); +} + } // namespace TEST_F(StridedSliceOpTest, TestStridedSliceByFirstAxis) { @@ -157,6 +240,66 @@ TEST_F(StridedSliceOpTest, TestStridedSliceRank3) { 1, 2}, {1, 1, 3, 3}); } + +TEST_F(StridedSliceOpTest, TestStridedSliceRank4) { + TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0}, + {2, 2, 2, 2}, {1, 1, 1, 1}, 0, 0, 0, 0, 0, {1, 2, 1, 2}, + {15, 16, 21, 22}); + TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0}, + {2, 2, 2, 2}, {1, 1, 1, 1}, 3, 0, 0, 0, 0, {2, 2, 1, 2}, + {3, 4, 9, 10, 15, 16, 21, 22}); + TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0}, + {2, 2, 2, 2}, {1, 1, 1, 1}, 0, 8, 0, 0, 0, {1, 2, 1, 3}, + {15, 16, 17, 21, 22, 23}); + TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0}, + {2, 2, 2, 2}, {1, 1, 1, 1}, 0, 8, 0, 0, 8, {1, 2, 1}, + {15, 21}); + TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0}, + {2, 2, 2, 2}, {1, 1, 1, 1}, 0, 8, 0, 0, 15, {}, {15}); + TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {-1, 2, 1, 3}, + {0, 0, 0, 0}, {-1, -1, -1, -1}, 0, 0, 0, 0, 0, {1, 1, 1, 2}, + {23, 22}); +} + +TEST_F(StridedSliceOpTest, TestStridedSliceWithDataFormat) { + TestStridedSliceWithDataFormat( + {2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0}, + {2, 2, 2, 2}, {1, 1, 1, 1}, 0, 0, 0, 0, 0, {1, 2, 1, 2}, + {15, 16, 21, 22}); + TestStridedSliceWithDataFormat( + {2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0}, + {2, 2, 2, 2}, {1, 1, 1, 1}, 3, 0, 0, 0, 0, {2, 2, 1, 2}, + {3, 4, 9, 10, 15, 16, 21, 22}); + TestStridedSliceWithDataFormat( + {2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0}, + {2, 2, 2, 2}, {1, 1, 1, 1}, 0, 8, 0, 0, 0, {1, 2, 1, 3}, + {15, 16, 17, 21, 22, 23}); + TestStridedSliceWithDataFormat( + {2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0}, + {2, 1}, {1, 1}, 0, 8, 0, 0, 0, {1, 1, 2, 3}, + {12, 13, 14, 15, 16, 17}); + TestStridedSliceWithDataFormat( + {2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0}, + {2, 1}, {1, 1}, 0, 2, 0, 0, 0, {1, 2, 2, 3}, + {12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); + TestStridedSliceWithDataFormat( + {2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {-1, 2, 1, 3}, + {0, 0, 0, 0}, {-1, -1, -1, -1}, 0, 0, 0, 0, 0, {1, 1, 1, 2}, + {23, 22}); +} + TEST_F(StridedSliceOpTest, TestSlice) { TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {2, 3}, {1, 2, 3, 4, 5, 6}); @@ -166,6 +309,17 @@ TEST_F(StridedSliceOpTest, TestSlice) { TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 1}, {2, -1}, {2, 2}, {2, 3, 5, 6}); } +TEST_F(StridedSliceOpTest, TestSliceWithDataFormat) { + TestSliceWithDataFormat({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + {1, 0, 1, 0}, {1, 2, 1, 2}, {1, 2, 1, 2}, + {15, 16, 21, 22}); + TestSliceWithDataFormat({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + {1, 0, 1, 0}, {-1, -1, -1, -1}, {1, 2, 1, 3}, + {15, 16, 17, 21, 22, 23}); +} + } // namespace test } // namespace ops } // namespace mace