diff --git a/mace/ops/cumsum.cc b/mace/ops/cumsum.cc new file mode 100644 index 0000000000000000000000000000000000000000..f0117270c80ce25bda50ab8e8461302b521c484e --- /dev/null +++ b/mace/ops/cumsum.cc @@ -0,0 +1,152 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mace/core/operator.h" + +namespace mace { +namespace ops { + +template +class CumsumOp; + +template +class CumsumOp : public Operation { + public: + explicit CumsumOp(OpConstructContext *context) + : Operation(context), + axis_(Operation::GetOptionalArg("axis", 0)), + exclusive_(Operation::GetOptionalArg("exclusive", false)), + reverse_(Operation::GetOptionalArg("reverse", false)), + checked_(false) {} + + void Validate() { + const int32_t input_dims = this->Input(0)->dim_size(); + axis_ = + axis_ < 0 ? axis_ + input_dims : axis_; + MACE_CHECK((0 <= axis_ && axis_ < input_dims), + "Expected concatenating axis in the range [", -input_dims, ", ", + input_dims, "], but got ", axis_); + } + + MaceStatus Run(OpContext *context) override { + MACE_UNUSED(context); + if (!checked_) { + Validate(); + bool has_data_format = Operation::GetOptionalArg( + "has_data_format", 0); + if (has_data_format && this->Input(0)->dim_size() == 4) { + if (axis_ == 3) axis_ = 1; + else if (axis_ == 2) axis_ = 3; + else if (axis_ == 1) axis_ = 2; + } + checked_ = true; + } + + const Tensor *input = this->Input(0); + const std::vector input_shape = input->shape(); + + Tensor *output = this->Output(0); + MACE_RETURN_IF_ERROR(output->ResizeLike(input)); + + Tensor::MappingGuard input_mapper(input); + Tensor::MappingGuard output_mapper(output); + + const float *input_ptr = input->data(); + float *output_ptr = output->mutable_data(); + + const index_t outer_size = std::accumulate(input_shape.begin(), + input_shape.begin() + axis_, + 1, + std::multiplies()); + const index_t inner_size = std::accumulate(input_shape.begin() + axis_ + 1, + input_shape.end(), + 1, + std::multiplies()); + const index_t cum_size = input_shape[axis_]; + + if (!reverse_) { +#pragma omp parallel for + for (index_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) { + index_t start_idx = outer_idx * cum_size * inner_size; + for (index_t cum_idx = 0; cum_idx < cum_size; ++cum_idx) { + if (cum_idx == 0) { + if (exclusive_) { + std::memset(output_ptr + start_idx, + 0, + sizeof(T) * inner_size); + } else { + std::memcpy(output_ptr + start_idx, + input_ptr + start_idx, + sizeof(T) * inner_size); + } + } else { + index_t cur_idx = start_idx + cum_idx * inner_size; + index_t pre_idx = start_idx + (cum_idx - 1) * inner_size; + index_t input_idx = exclusive_ ? pre_idx : cur_idx; + for (index_t inner_idx = 0; inner_idx < inner_size; ++inner_idx) { + output_ptr[cur_idx + inner_idx] = + output_ptr[pre_idx + inner_idx] + + input_ptr[input_idx + inner_idx]; + } + } + } + } + } else { +#pragma omp parallel for + for (index_t outer_idx = outer_size - 1; outer_idx >= 0; --outer_idx) { + index_t start_idx = outer_idx * cum_size * inner_size; + for (index_t cum_idx = cum_size - 1; cum_idx >= 0; --cum_idx) { + index_t cur_idx = start_idx + cum_idx * inner_size; + if (cum_idx == cum_size - 1) { + if (exclusive_) { + std::memset(output_ptr + cur_idx, + 0, + sizeof(T) * inner_size); + } else { + std::memcpy(output_ptr + cur_idx, + input_ptr + cur_idx, + sizeof(T) * inner_size); + } + } else { + index_t pre_idx = start_idx + (cum_idx + 1) * inner_size; + index_t input_idx = exclusive_ ? pre_idx : cur_idx; + for (index_t inner_idx = 0; inner_idx < inner_size; ++inner_idx) { + output_ptr[cur_idx + inner_idx] = + output_ptr[pre_idx + inner_idx] + + input_ptr[input_idx + inner_idx]; + } + } + } + } + } + + return MaceStatus::MACE_SUCCESS; + } + + private: + int32_t axis_; + bool exclusive_; + bool reverse_; + bool checked_; +}; + +void RegisterCumsum(OpRegistryBase *op_registry) { + MACE_REGISTER_OP(op_registry, "Cumsum", CumsumOp, + DeviceType::CPU, float); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/cumsum_benchmark.cc b/mace/ops/cumsum_benchmark.cc new file mode 100644 index 0000000000000000000000000000000000000000..8ca59fa0501fe92a35fbe0a02141cdd23a7c1198 --- /dev/null +++ b/mace/ops/cumsum_benchmark.cc @@ -0,0 +1,90 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class CumsumOpTest : public OpsTestBase {}; + +namespace { +template +void Cumsum(int iters, int batch, int channels, int height, int width) { + mace::testing::StopTiming(); + + // Construct graph + OpsTestNet net; + + // Add input data + if (D == DeviceType::CPU) { + net.AddRandomInput("Input", {batch, channels, height, width}); + } else { + MACE_NOT_IMPLEMENTED; + } + + OpDefBuilder("Cumsum", "CumsumTest") + .Input("Input") + .Output("Output") + .AddIntArg("axis", 0) + .AddIntArg("exclusive", 0) + .AddIntArg("reverse", 0) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + } + net.Sync(); +} +} // namespace + +#define MACE_BM_CUMSUM_MACRO(N, C, H, W, TYPE, DEVICE) \ + static void MACE_BM_CUMSUM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * C * H * W; \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + Cumsum(iters, N, C, H, W); \ + } \ + MACE_BENCHMARK(MACE_BM_CUMSUM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) + +#define MACE_BM_CUMSUM(N, C, H, W) \ + MACE_BM_CUMSUM_MACRO(N, C, H, W, float, CPU); + +MACE_BM_CUMSUM(1, 1, 512, 512); +MACE_BM_CUMSUM(1, 3, 128, 128); +MACE_BM_CUMSUM(1, 3, 512, 512); +MACE_BM_CUMSUM(1, 32, 112, 112); +MACE_BM_CUMSUM(1, 64, 256, 256); +MACE_BM_CUMSUM(1, 64, 512, 512); +MACE_BM_CUMSUM(1, 128, 56, 56); +MACE_BM_CUMSUM(1, 128, 256, 256); +MACE_BM_CUMSUM(1, 256, 14, 14); +MACE_BM_CUMSUM(1, 512, 14, 14); +MACE_BM_CUMSUM(1, 1024, 7, 7); +MACE_BM_CUMSUM(32, 1, 256, 256); +MACE_BM_CUMSUM(32, 3, 256, 256); + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/cumsum_test.cc b/mace/ops/cumsum_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8b111540c9040a391ae419d86e3c042b23954b5e --- /dev/null +++ b/mace/ops/cumsum_test.cc @@ -0,0 +1,91 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class CumsumOpTest : public OpsTestBase {}; + +namespace { +template +void SimpleTestWithDataFormat(const std::vector &shape, + const std::vector &input, + const int axis, + const int exclusive, + const int reverse, + const std::vector &output) { + // Construct graph + OpsTestNet net; + + net.AddInputFromArray("Input", shape, input); + net.TransformDataFormat("Input", NHWC, "InputNCHW", + NCHW); + + OpDefBuilder("Cumsum", "CumsumTest") + .Input("InputNCHW") + .Output("OutputNCHW") + .AddIntArg("axis", axis) + .AddIntArg("exclusive", exclusive) + .AddIntArg("reverse", reverse) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .AddIntArg("has_data_format", 1) + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(DeviceType::CPU); + + net.TransformDataFormat("OutputNCHW", NCHW, "Output", + NHWC); + + net.AddInputFromArray("ExpectedOutput", shape, output); + ExpectTensorNear(*net.GetOutput("ExpectedOutput"), + *net.GetOutput("Output")); +} +} // namespace + +TEST_F(CumsumOpTest, HasDataFormatCPU) { + SimpleTestWithDataFormat( + {2, 2, 2, 2}, + {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}, + 0, 0, 0, + {0., 1., 2., 3., 4., 5., 6., 7., 8., 10., 12., 14., 16., 18., 20., 22.}); + SimpleTestWithDataFormat( + {2, 2, 2, 2}, + {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}, + 1, 0, 0, + {0., 1., 2., 3., 4., 6., 8., 10., 8., 9., 10., 11., 20., 22., 24., 26.}); + SimpleTestWithDataFormat( + {2, 2, 2, 2}, + {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}, + 0, 1, 0, + {0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 2., 3., 4., 5., 6., 7.}); + SimpleTestWithDataFormat( + {2, 2, 2, 2}, + {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}, + 0, 0, 1, + {8., 10., 12., 14., 16., 18., 20., 22., 8., 9., 10., 11., 12., 13., 14., + 15.}); + SimpleTestWithDataFormat( + {2, 2, 2, 2}, + {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}, + 1, 1, 1, + {4., 5., 6., 7., 0., 0., 0., 0., 12., 13., 14., 15., 0., 0., 0., 0.}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/ops_registry.cc b/mace/ops/ops_registry.cc index 52f22e4a36fdacf2eaa4caef066bfe00600bd069..b7f807352ee453ab56ff9454e6408acaa525c066 100644 --- a/mace/ops/ops_registry.cc +++ b/mace/ops/ops_registry.cc @@ -29,6 +29,7 @@ extern void RegisterChannelShuffle(OpRegistryBase *op_registry); extern void RegisterConcat(OpRegistryBase *op_registry); extern void RegisterConv2D(OpRegistryBase *op_registry); extern void RegisterCrop(OpRegistryBase *op_registry); +extern void RegisterCumsum(OpRegistryBase *op_registry); extern void RegisterDeconv2D(OpRegistryBase *op_registry); extern void RegisterDepthToSpace(OpRegistryBase *op_registry); extern void RegisterDepthwiseConv2d(OpRegistryBase *op_registry); @@ -95,6 +96,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() { ops::RegisterConcat(this); ops::RegisterConv2D(this); ops::RegisterCrop(this); + ops::RegisterCumsum(this); ops::RegisterDeconv2D(this); ops::RegisterDepthToSpace(this); ops::RegisterDepthwiseConv2d(this); diff --git a/mace/ops/strided_slice.cc b/mace/ops/strided_slice.cc index 221a75d46442afd1b3f385350b6ddd943bdb5db9..c10914f27fb87e7e1159749eb990a66bb6506f42 100644 --- a/mace/ops/strided_slice.cc +++ b/mace/ops/strided_slice.cc @@ -17,6 +17,7 @@ #include #include "mace/core/operator.h" +#include "mace/utils/math.h" namespace mace { namespace ops { @@ -32,21 +33,69 @@ class StridedSliceOp : public Operation { new_axis_mask_(Operation::GetOptionalArg("new_axis_mask", 0)), shrink_axis_mask_( Operation::GetOptionalArg("shrink_axis_mask", 0)), - is_slice_(Operation::GetOptionalArg("slice", false)) { + is_slice_(Operation::GetOptionalArg("slice", false)), + has_data_format_(Operation::GetOptionalArg("has_data_format", 0)), + checked_(false) { MACE_CHECK(ellipsis_mask_ == 0 && new_axis_mask_ == 0, "ellipsis_mask and new_axis_mask are not supported yet."); } + void TransposeMaskValueFromNHWCToNCHW(int* mask_value) { + size_t dims[4]; + int count; + for (count = 0; count < 4; ++count) { + dims[count] = *mask_value & 1; + *mask_value >>= 1; + } + size_t new_dims[4] = {dims[0], dims[3], dims[1], dims[2]}; + for (count = 3; count >= 0; --count) { + *mask_value <<= 1; + *mask_value += new_dims[count]; + } + } + + void TransposeDimsFromNHWCToNCHW(std::vector* dims) { + int32_t h = (*dims)[1]; + int32_t w = (*dims)[2]; + int32_t c = (*dims)[3]; + + (*dims)[1] = c; + (*dims)[2] = h; + (*dims)[3] = w; + } + + void TransposeDimsFromNCHWToNHWC(std::vector* dims) { + int32_t c = (*dims)[1]; + int32_t h = (*dims)[2]; + int32_t w = (*dims)[3]; + + (*dims)[1] = h; + (*dims)[2] = w; + (*dims)[3] = c; + } + MaceStatus Run(OpContext *context) override { MACE_UNUSED(context); + + if (!checked_) { + if (has_data_format_ && this->Input(0)->dim_size() == 4) { + TransposeMaskValueFromNHWCToNCHW(&begin_mask_); + TransposeMaskValueFromNHWCToNCHW(&end_mask_); + TransposeMaskValueFromNHWCToNCHW(&ellipsis_mask_); + TransposeMaskValueFromNHWCToNCHW(&new_axis_mask_); + TransposeMaskValueFromNHWCToNCHW(&shrink_axis_mask_); + } + checked_ = true; + } + const Tensor *input = this->Input(INPUT); const Tensor *begin_indices = this->Input(BEGIN); const Tensor *end_indices = this->Input(END); const Tensor *strides = nullptr; + if (this->InputSize() > 3) { strides = this->Input(STRIDES); } - Tensor *output = this->Output(OUTPUT); if (strides == nullptr) { tmp_strides_tensor_.Resize({begin_indices->size()}); Tensor::MappingGuard strides_guard(&tmp_strides_tensor_); @@ -55,6 +104,11 @@ class StridedSliceOp : public Operation { strides = &tmp_strides_tensor_; } + MACE_CHECK(begin_indices->dim_size() == 1 && + end_indices->dim_size() == 1 && + strides->dim_size() == 1, + "Expected begin, end, and strides to be 1D tensor"); + Tensor::MappingGuard input_guard(input); Tensor::MappingGuard begin_indices_guard(begin_indices); Tensor::MappingGuard end_indices_guard(end_indices); @@ -63,88 +117,145 @@ class StridedSliceOp : public Operation { const int32_t *begin_indices_data = begin_indices->data(); const int32_t *end_indices_data = end_indices->data(); const int32_t *strides_data = strides->data(); - std::vector pad_begin_indices(input->dim_size(), 0); - std::vector pad_end_indices(input->dim_size(), 0); - std::vector pad_strides_indices(input->dim_size(), 1); - - if (begin_indices->size() < input->dim_size()) { - for (index_t i = 0; i < begin_indices->size(); ++i) { - pad_begin_indices[i] = begin_indices_data[i]; - pad_end_indices[i] = end_indices_data[i]; - pad_strides_indices[i] = strides_data[i]; - } - for (index_t i = begin_indices->size(); i < input->dim_size(); ++i) { - pad_end_indices[i] = input->dim(i); - } - begin_indices_data = pad_begin_indices.data(); - end_indices_data = pad_end_indices.data(); - strides_data = pad_strides_indices.data(); - } - std::vector slice_end_data; + std::vector begin_indices_vec( + begin_indices_data, begin_indices_data + begin_indices->size()); + std::vector end_indices_vec( + end_indices_data, end_indices_data + end_indices->size()); + std::vector strides_indices_vec( + strides_data, strides_data + strides->size()); + + MACE_CHECK(input->size() > 0 && input->dim_size() > 0 && + input->dim_size() <= 4, + "The input size should larger than 0." + " And input dims should be an integer in (0, 4]."); + + std::vector output_shape = {}; + + const size_t input_dims = input->dim_size(); if (is_slice_) { - // if this op is slice, the end_indices_data is size actually - slice_end_data.resize(end_indices->size()); - for (size_t i = 0; i < slice_end_data.size(); ++i) { - if (end_indices_data[i] == -1) { - slice_end_data[i] = input->dim(i); - } else { - slice_end_data[i] = begin_indices_data[i] + end_indices_data[i]; + MACE_CHECK(begin_indices_vec.size() == input_dims && + end_indices_vec.size() == input_dims, + "In slice, begin and size elements num should be equal"); + + // transpose + if (has_data_format_ && this->Input(0)->dim_size() == 4) { + TransposeDimsFromNHWCToNCHW(&begin_indices_vec); + TransposeDimsFromNHWCToNCHW(&end_indices_vec); + TransposeDimsFromNHWCToNCHW(&strides_indices_vec); + } + + for (size_t i = 0; i < input_dims; ++i) { + if (end_indices_vec[i] == -1) { + end_indices_vec[i] = input->dim(i) - begin_indices_vec[i]; } } - end_indices_data = slice_end_data.data(); - } - std::vector output_shape; - std::vector real_begin_indices(input->dim_size(), 0); - std::vector real_end_indices(input->dim_size(), 0); - for (index_t d = 0; d < input->dim_size(); ++d) { - index_t dim_len = input->dim(d); - if (begin_mask_ & (1 << d)) { - real_begin_indices[d] = strides_data[d] > 0 ? 0 : dim_len - 1; - } else { - real_begin_indices[d] = (begin_indices_data[d] + dim_len) % dim_len; + for (size_t i = 0; i < input_dims; ++i) { + int32_t b = begin_indices_vec[i]; + int32_t s = end_indices_vec[i]; + int32_t input_i = input->dim(i); + MACE_CHECK(0 <= b && b <= input_i, + "In Slice, expected begin[", i, "] in [0, ", input_i, + "], but got ", b); + MACE_CHECK(0 <= s && b + s <= input_i, + "In Slice, expected size[", i, "] in [0, ", + input_i - b, "], but got", s); + end_indices_vec[i] = b + s; + output_shape.push_back(s); } - if (end_mask_ & (1 << d)) { - real_end_indices[d] = strides_data[d] > 0 ? dim_len : -1; - } else { - real_end_indices[d] = - end_indices_data[d] < -dim_len - ? -1 - : (end_indices_data[d] < 0 - ? (end_indices_data[d] + dim_len) - : std::min(static_cast(end_indices_data[d]), - dim_len)); + } else { + MACE_CHECK(begin_indices_vec.size() == end_indices_vec.size() && + end_indices_vec.size() == strides_indices_vec.size(), + "In strided_slice, expected begin, end, and strides to be", + " equal size tensors"); + for (index_t i = 0; i < strides->size(); ++i) { + MACE_CHECK(strides_indices_vec[i] != 0, "strides data cannot be 0!"); } - int32_t out_dim_len = std::max( - 0.f, std::ceil((real_end_indices[d] - real_begin_indices[d]) / - static_cast(strides_data[d]))); - if (!(shrink_axis_mask_ & (1 << d))) { - output_shape.push_back(out_dim_len); - } else { - MACE_CHECK(out_dim_len == 1, - "cannot shrink axis that has len > 1, dim(", d, "): [", - real_begin_indices[d], ", ", real_end_indices[d], "]"); + // pad + begin_indices_vec.resize(input_dims, 0); + strides_indices_vec.resize(input_dims, 1); + std::vector tmp_input_dims(input->shape().begin(), + input->shape().end()); + if (has_data_format_ && input_dims == 4) { + TransposeDimsFromNCHWToNHWC(&tmp_input_dims); + } + for (size_t i = end_indices_vec.size(); i < input_dims; ++i) { + end_indices_vec.push_back(tmp_input_dims[i]); + } + + // transpose + if (has_data_format_ && this->Input(0)->dim_size() == 4) { + TransposeDimsFromNHWCToNCHW(&begin_indices_vec); + TransposeDimsFromNHWCToNCHW(&end_indices_vec); + TransposeDimsFromNHWCToNCHW(&strides_indices_vec); + } + + // mask and shrink + for (index_t d = 0; d < input->dim_size(); ++d) { + index_t dim_len = input->dim(d); + const std::vector valid_range = { + strides_indices_vec[d] > 0 ? 0 : -1, + strides_indices_vec[d] > 0 ? dim_len : dim_len - 1}; + + auto format_indices = [valid_range, d, dim_len](index_t indice) { + index_t forward = indice < 0 ? indice + dim_len : indice; + return Clamp(forward, valid_range[0], valid_range[1]); + }; + + if (!(shrink_axis_mask_ & (1 << d))) { + if (begin_mask_ & (1 << d)) { + begin_indices_vec[d] = strides_indices_vec[d] > 0 ? 0 : dim_len - 1; + } else { + begin_indices_vec[d] = format_indices(begin_indices_vec[d]); + } + if (end_mask_ & (1 << d)) { + end_indices_vec[d] = strides_indices_vec[d] > 0 ? dim_len : -1; + } else { + end_indices_vec[d] = format_indices(end_indices_vec[d]); + } + + int32_t out_dim_len = std::max( + 0.f, std::ceil((end_indices_vec[d] - begin_indices_vec[d]) / + static_cast(strides_indices_vec[d]))); + output_shape.push_back(out_dim_len); + } else { + begin_indices_vec[d] = begin_indices_vec[d] < 0 + ? begin_indices_vec[d] + dim_len + : begin_indices_vec[d]; + end_indices_vec[d] = begin_indices_vec[d] + 1; + MACE_CHECK( + begin_indices_vec[d] >= 0 && begin_indices_vec[d] < dim_len, + "slice begin indice of dimension '", d, "': ", + begin_indices_vec[d], ", is out of bound"); + } } } + for (size_t i = 0; i < output_shape.size(); ++i) { + MACE_CHECK(output_shape[i] > 0, + "Expected output_shape[", i, "] larger than 0, but got ", + output_shape[i]); + } + std::vector dim_stride(input->dim_size(), 1); for (index_t d = input->dim_size() - 2; d >= 0; --d) { dim_stride[d] = dim_stride[d + 1] * input->dim(d + 1); } + Tensor *output = this->Output(OUTPUT); MACE_RETURN_IF_ERROR(output->Resize(output_shape)); Tensor::MappingGuard output_guard(output); T *output_data = output->mutable_data(); bool slice_by_first_axis = true; - if (strides_data[0] != 1) { + if (strides_indices_vec[0] != 1) { slice_by_first_axis = false; } else { for (index_t d = 1; d < input->dim_size(); ++d) { - if (strides_data[d] != 1 || real_begin_indices[d] != 0 || - real_end_indices[d] != input->dim(d)) { + if (strides_indices_vec[d] != 1 || begin_indices_vec[d] != 0 || + end_indices_vec[d] != input->dim(d)) { slice_by_first_axis = false; break; } @@ -152,47 +263,71 @@ class StridedSliceOp : public Operation { } if (slice_by_first_axis) { - memcpy(output_data, input_data + real_begin_indices[0] * dim_stride[0], - sizeof(T) * (real_end_indices[0] - real_begin_indices[0]) * + memcpy(output_data, input_data + begin_indices_vec[0] * dim_stride[0], + sizeof(T) * (end_indices_vec[0] - begin_indices_vec[0]) * dim_stride[0]); } else { if (input->dim_size() == 1) { - for (index_t i = real_begin_indices[0]; - strides_data[0] > 0 ? i < real_end_indices[0] - : i > real_end_indices[0]; - i += strides_data[0]) { + for (index_t i = begin_indices_vec[0]; + strides_indices_vec[0] > 0 ? i < end_indices_vec[0] + : i > end_indices_vec[0]; + i += strides_indices_vec[0]) { *output_data++ = input_data[i]; } } else if (input->dim_size() == 2) { - for (index_t i = real_begin_indices[0]; - strides_data[0] > 0 ? i < real_end_indices[0] - : i > real_end_indices[0]; - i += strides_data[0]) { - for (index_t j = real_begin_indices[1]; - strides_data[1] > 0 ? j < real_end_indices[1] - : j > real_end_indices[1]; - j += strides_data[1]) { + for (index_t i = begin_indices_vec[0]; + strides_indices_vec[0] > 0 ? i < end_indices_vec[0] + : i > end_indices_vec[0]; + i += strides_indices_vec[0]) { + for (index_t j = begin_indices_vec[1]; + strides_indices_vec[1] > 0 ? j < end_indices_vec[1] + : j > end_indices_vec[1]; + j += strides_indices_vec[1]) { *output_data++ = input_data[i * input->dim(1) + j]; } } } else if (input->dim_size() == 3) { - for (index_t i = real_begin_indices[0]; - strides_data[0] > 0 ? i < real_end_indices[0] - : i > real_end_indices[0]; - i += strides_data[0]) { - for (index_t j = real_begin_indices[1]; - strides_data[1] > 0 ? j < real_end_indices[1] - : j > real_end_indices[1]; - j += strides_data[1]) { - for (index_t k = real_begin_indices[2]; - strides_data[2] > 0 ? k < real_end_indices[2] - : k > real_end_indices[2]; - k += strides_data[2]) { + for (index_t i = begin_indices_vec[0]; + strides_indices_vec[0] > 0 ? i < end_indices_vec[0] + : i > end_indices_vec[0]; + i += strides_indices_vec[0]) { + for (index_t j = begin_indices_vec[1]; + strides_indices_vec[1] > 0 ? j < end_indices_vec[1] + : j > end_indices_vec[1]; + j += strides_indices_vec[1]) { + for (index_t k = begin_indices_vec[2]; + strides_indices_vec[2] > 0 ? k < end_indices_vec[2] + : k > end_indices_vec[2]; + k += strides_indices_vec[2]) { *output_data++ = input_data[(i * input->dim(1) + j) * input->dim(2) + k]; } } } + } else if (input->dim_size() == 4) { + for (index_t i = begin_indices_vec[0]; + strides_indices_vec[0] > 0 ? i < end_indices_vec[0] + : i > end_indices_vec[0]; + i += strides_indices_vec[0]) { + for (index_t j = begin_indices_vec[1]; + strides_indices_vec[1] > 0 ? j < end_indices_vec[1] + : j > end_indices_vec[1]; + j += strides_indices_vec[1]) { + for (index_t k = begin_indices_vec[2]; + strides_indices_vec[2] > 0 ? k < end_indices_vec[2] + : k > end_indices_vec[2]; + k += strides_indices_vec[2]) { + for (index_t l = begin_indices_vec[3]; + strides_indices_vec[3] > 0 ? l < end_indices_vec[3] + : l > end_indices_vec[3]; + l += strides_indices_vec[3]) { + *output_data++ = + input_data[((i * input->dim(1) + j) * input->dim(2) + k) + * input->dim(3) + l]; + } + } + } + } } else { MACE_NOT_IMPLEMENTED; } @@ -207,6 +342,8 @@ class StridedSliceOp : public Operation { int new_axis_mask_; int shrink_axis_mask_; bool is_slice_; + int has_data_format_; + bool checked_; Tensor tmp_strides_tensor_; MACE_OP_INPUT_TAGS(INPUT, BEGIN, END, STRIDES); diff --git a/mace/ops/strided_slice_test.cc b/mace/ops/strided_slice_test.cc index df691ce682f2a0a55db2f93a9077a265f61cbef0..8b085fe532694f7c343e0cfda735d91332aea294 100644 --- a/mace/ops/strided_slice_test.cc +++ b/mace/ops/strided_slice_test.cc @@ -64,6 +64,54 @@ void TestStridedSlice(const std::vector &input_shape, *net.GetOutput("Output")); } +void TestStridedSliceWithDataFormat(const std::vector &input_shape, + const std::vector &input, + const std::vector &begin_indices, + const std::vector &end_indices, + const std::vector &strides, + const int begin_mask, + const int end_mask, + const int ellipsis_mask, + const int new_axis_mask, + const int shrink_axis_mask, + const std::vector &output_shape, + const std::vector &output) { + OpsTestNet net; + net.AddInputFromArray("Input", input_shape, input); + net.AddInputFromArray( + "BeginIndices", {static_cast(begin_indices.size())}, + begin_indices); + net.AddInputFromArray( + "EndIndices", {static_cast(end_indices.size())}, end_indices); + net.AddInputFromArray( + "Strides", {static_cast(strides.size())}, strides); + + net.TransformDataFormat("Input", NHWC, "InputNCHW", + NCHW); + + OpDefBuilder("StridedSlice", "StridedSliceOpTest") + .Input("InputNCHW") + .Input("BeginIndices") + .Input("EndIndices") + .Input("Strides") + .Output("OutputNCHW") + .AddIntArg("begin_mask", begin_mask) + .AddIntArg("end_mask", end_mask) + .AddIntArg("ellipsis_mask", ellipsis_mask) + .AddIntArg("new_axis_mask", new_axis_mask) + .AddIntArg("shrink_axis_mask", shrink_axis_mask) + .AddIntArg("has_data_format", 1) + .Finalize(net.NewOperatorDef()); + + net.RunOp(); + + net.TransformDataFormat("OutputNCHW", NCHW, "Output", + NHWC); + net.AddInputFromArray("ExpectedOutput", output_shape, output); + ExpectTensorNear(*net.GetOutput("ExpectedOutput"), + *net.GetOutput("Output")); +} + void TestSlice(const std::vector &input_shape, const std::vector &input, const std::vector &begin_indices, @@ -92,6 +140,41 @@ void TestSlice(const std::vector &input_shape, *net.GetOutput("Output")); } +void TestSliceWithDataFormat(const std::vector &input_shape, + const std::vector &input, + const std::vector &begin_indices, + const std::vector &indices_size, + const std::vector &output_shape, + const std::vector &output) { + OpsTestNet net; + net.AddInputFromArray("Input", input_shape, input); + net.AddInputFromArray( + "BeginIndices", {static_cast(input_shape.size())}, + begin_indices); + net.AddInputFromArray( + "IndicesSize", {static_cast(indices_size.size())}, indices_size); + + net.TransformDataFormat("Input", NHWC, "InputNCHW", + NCHW); + + OpDefBuilder("StridedSlice", "StridedSliceOpTest") + .Input("InputNCHW") + .Input("BeginIndices") + .Input("IndicesSize") + .Output("OutputNCHW") + .AddIntArg("slice", 1) + .AddIntArg("has_data_format", 1) + .Finalize(net.NewOperatorDef()); + + net.RunOp(); + + net.TransformDataFormat("OutputNCHW", NCHW, "Output", + NHWC); + net.AddInputFromArray("ExpectedOutput", output_shape, output); + ExpectTensorNear(*net.GetOutput("ExpectedOutput"), + *net.GetOutput("Output")); +} + } // namespace TEST_F(StridedSliceOpTest, TestStridedSliceByFirstAxis) { @@ -157,6 +240,66 @@ TEST_F(StridedSliceOpTest, TestStridedSliceRank3) { 1, 2}, {1, 1, 3, 3}); } + +TEST_F(StridedSliceOpTest, TestStridedSliceRank4) { + TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0}, + {2, 2, 2, 2}, {1, 1, 1, 1}, 0, 0, 0, 0, 0, {1, 2, 1, 2}, + {15, 16, 21, 22}); + TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0}, + {2, 2, 2, 2}, {1, 1, 1, 1}, 3, 0, 0, 0, 0, {2, 2, 1, 2}, + {3, 4, 9, 10, 15, 16, 21, 22}); + TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0}, + {2, 2, 2, 2}, {1, 1, 1, 1}, 0, 8, 0, 0, 0, {1, 2, 1, 3}, + {15, 16, 17, 21, 22, 23}); + TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0}, + {2, 2, 2, 2}, {1, 1, 1, 1}, 0, 8, 0, 0, 8, {1, 2, 1}, + {15, 21}); + TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0}, + {2, 2, 2, 2}, {1, 1, 1, 1}, 0, 8, 0, 0, 15, {}, {15}); + TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {-1, 2, 1, 3}, + {0, 0, 0, 0}, {-1, -1, -1, -1}, 0, 0, 0, 0, 0, {1, 1, 1, 2}, + {23, 22}); +} + +TEST_F(StridedSliceOpTest, TestStridedSliceWithDataFormat) { + TestStridedSliceWithDataFormat( + {2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0}, + {2, 2, 2, 2}, {1, 1, 1, 1}, 0, 0, 0, 0, 0, {1, 2, 1, 2}, + {15, 16, 21, 22}); + TestStridedSliceWithDataFormat( + {2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0}, + {2, 2, 2, 2}, {1, 1, 1, 1}, 3, 0, 0, 0, 0, {2, 2, 1, 2}, + {3, 4, 9, 10, 15, 16, 21, 22}); + TestStridedSliceWithDataFormat( + {2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0}, + {2, 2, 2, 2}, {1, 1, 1, 1}, 0, 8, 0, 0, 0, {1, 2, 1, 3}, + {15, 16, 17, 21, 22, 23}); + TestStridedSliceWithDataFormat( + {2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0}, + {2, 1}, {1, 1}, 0, 8, 0, 0, 0, {1, 1, 2, 3}, + {12, 13, 14, 15, 16, 17}); + TestStridedSliceWithDataFormat( + {2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0}, + {2, 1}, {1, 1}, 0, 2, 0, 0, 0, {1, 2, 2, 3}, + {12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); + TestStridedSliceWithDataFormat( + {2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {-1, 2, 1, 3}, + {0, 0, 0, 0}, {-1, -1, -1, -1}, 0, 0, 0, 0, 0, {1, 1, 1, 2}, + {23, 22}); +} + TEST_F(StridedSliceOpTest, TestSlice) { TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {2, 3}, {1, 2, 3, 4, 5, 6}); @@ -166,6 +309,17 @@ TEST_F(StridedSliceOpTest, TestSlice) { TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 1}, {2, -1}, {2, 2}, {2, 3, 5, 6}); } +TEST_F(StridedSliceOpTest, TestSliceWithDataFormat) { + TestSliceWithDataFormat({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + {1, 0, 1, 0}, {1, 2, 1, 2}, {1, 2, 1, 2}, + {15, 16, 21, 22}); + TestSliceWithDataFormat({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + {1, 0, 1, 0}, {-1, -1, -1, -1}, {1, 2, 1, 3}, + {15, 16, 17, 21, 22, 23}); +} + } // namespace test } // namespace ops } // namespace mace diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index 3c34d09c0aaca04c077c49eeb090dfca180aebf1..93de5ef3ec53f835dfa331978ff1e83908e933a7 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -159,6 +159,7 @@ MaceSupportedOps = [ 'Transpose', 'WinogradInverseTransform', 'WinogradTransform', + 'Cumsum', ] MaceOp = Enum('MaceOp', [(op, op) for op in MaceSupportedOps], type=str) @@ -237,6 +238,8 @@ class MaceKeyword(object): mace_find_range_every_time = 'find_range_every_time' mace_non_zero = 'non_zero' mace_pad_type_str = 'pad_type' + mace_exclusive_str = 'exclusive' + mace_reverse_str = 'reverse' class TransformerRule(Enum): diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index 8e12714edb2e64f29ef552846d3ab53e742651e5..2ef33d30e3c8e5735e43431672a63c28f305f75a 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -116,6 +116,7 @@ TFSupportedOps = [ 'FloorDiv', 'Sqrt', 'MirrorPad', + 'Cumsum', ] TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str) @@ -253,6 +254,7 @@ class TensorflowConverter(base_converter.ConverterInterface): TFOpType.FloorDiv.name: self.convert_elementwise, TFOpType.Sqrt.name: self.convert_elementwise, TFOpType.MirrorPad.name: self.convert_pad, + TFOpType.Cumsum.name: self.convert_cumsum, } self._option = option self._mace_net_def = mace_pb2.NetDef() @@ -1012,3 +1014,23 @@ class TensorflowConverter(base_converter.ConverterInterface): self._skip_tensor.add(tf_op.inputs[1].name) self._skip_tensor.add(tf_op.inputs[2].name) + + def convert_cumsum(self, tf_op): + op = self.convert_general_op(tf_op) + op.type = MaceOp.Cumsum.name + + axis = tf_op.inputs[1].eval().astype(np.int32) + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_axis_str + axis_arg.i = axis + del op.input[1] + + exclusive = tf_op.get_attr('exclusive') + exclusive_arg = op.arg.add() + exclusive_arg.name = MaceKeyword.mace_exclusive_str + exclusive_arg.i = int(exclusive) + + reverse = tf_op.get_attr('reverse') + reverse_arg = op.arg.add() + reverse_arg.name = MaceKeyword.mace_reverse_str + reverse_arg.i = int(reverse)