提交 a14a6cb4 编写于 作者: Y yejianwu

refactor strided_slice and fix bound check, update cumsum impl

上级 a85f052d
......@@ -19,26 +19,6 @@
namespace mace {
namespace ops {
namespace {
// Advances the integer pointed to by `val` by one, in place.
void PlusOne(int* val) {
  *val += 1;
}
// Moves the integer pointed to by `val` back by one, in place.
void SubOne(int* val) {
  *val -= 1;
}
// Loop-termination predicate for forward iteration: true while `val` is
// strictly below `boundary`.
bool LessThan(const int& val, const int& boundary) {
  return boundary > val;
}
// Loop-termination predicate for backward iteration: true while `val` has not
// dropped below zero. `boundary` is accepted only so the signature matches
// LessThan; its value is ignored.
bool NotLessThanZero(const int& val, const int& boundary) {
  MACE_UNUSED(boundary);
  return !(val < 0);
}
} // namespace
// Primary template for the cumulative-sum operator, parameterised on the
// device type and the element type. Only declared here; the CPU
// specialisation follows.
template <DeviceType D, typename T>
class CumsumOp;
......@@ -47,9 +27,10 @@ class CumsumOp<DeviceType::CPU, T> : public Operation {
public:
explicit CumsumOp(OpConstructContext *context)
: Operation(context),
axis_(Operation::GetOptionalArg<int>("axis", 3)),
axis_(Operation::GetOptionalArg<int>("axis", 0)),
exclusive_(Operation::GetOptionalArg<bool>("exclusive", false)),
reverse_(Operation::GetOptionalArg<bool>("reverse", false)) {}
reverse_(Operation::GetOptionalArg<bool>("reverse", false)),
checked_(false) {}
void Validate() {
const int32_t input_dims = this->Input(0)->dim_size();
......@@ -64,9 +45,9 @@ class CumsumOp<DeviceType::CPU, T> : public Operation {
MACE_UNUSED(context);
if (!checked_) {
Validate();
auto df = static_cast<DataFormat>(Operation::GetOptionalArg<int>(
"data_format", DataFormat::DF_NONE));
if (df == DataFormat::NHWC && this->Input(0)->dim_size() == 4) {
bool has_data_format = Operation::GetOptionalArg<int>(
"has_data_format", 0);
if (has_data_format && this->Input(0)->dim_size() == 4) {
if (axis_ == 3) axis_ = 1;
else if (axis_ == 2) axis_ = 3;
else if (axis_ == 1) axis_ = 2;
......@@ -75,6 +56,7 @@ class CumsumOp<DeviceType::CPU, T> : public Operation {
}
const Tensor *input = this->Input(0);
const std::vector<index_t> input_shape = input->shape();
Tensor *output = this->Output(0);
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
......@@ -85,66 +67,70 @@ class CumsumOp<DeviceType::CPU, T> : public Operation {
const float *input_ptr = input->data<float>();
float *output_ptr = output->mutable_data<float>();
std::function<void(int*)> next = reverse_ ? SubOne : PlusOne;
std::function<void(int*)> previous = reverse_ ? PlusOne : SubOne;
std::function<bool(const int&, const int&)> boundary =
reverse_ ? NotLessThanZero : LessThan;
if (input->dim_size() == 4) {
const int batch = input->dim(0);
const int channel = input->dim(1);
const int height = input->dim(2);
const int width = input->dim(3);
const int axis_dim_size = input->dim(axis_);
for (int n = reverse_ ? batch - 1 : 0; boundary(n, batch); next(&n)) {
for (int c = reverse_ ? channel - 1 : 0; boundary(c, channel);
next(&c)) {
for (int h = reverse_ ? height - 1 : 0; boundary(h, height);
next(&h)) {
for (int w = reverse_ ? width - 1 : 0; boundary(w, width);
next(&w)) {
int dims[4] = {n, c, h, w};
if (!reverse_ && dims[axis_] == 0) {
if (exclusive_) {
output_ptr[((n * channel + c) * height + h) * width + w] = 0;
} else {
continue;
}
} else if (reverse_ && dims[axis_] == axis_dim_size - 1) {
if (exclusive_) {
output_ptr[((n * channel + c) * height + h) * width + w] = 0;
} else {
continue;
}
} else {
previous(&dims[axis_]);
if (exclusive_) {
output_ptr[((n * channel + c) * height + h) * width + w] =
input_ptr[((dims[0] * channel + dims[1]) * height +
dims[2]) *
width +
dims[3]] +
output_ptr[((dims[0] * channel + dims[1]) * height +
dims[2]) *
width +
dims[3]];
} else {
output_ptr[((n * channel + c) * height + h) * width + w] =
input_ptr[((n * channel + c) * height + h) * width + w] +
output_ptr[((dims[0] * channel + dims[1]) * height +
dims[2]) *
width +
dims[3]];
}
}
const index_t outer_size = std::accumulate(input_shape.begin(),
input_shape.begin() + axis_,
1,
std::multiplies<index_t>());
const index_t inner_size = std::accumulate(input_shape.begin() + axis_ + 1,
input_shape.end(),
1,
std::multiplies<index_t>());
const index_t cum_size = input_shape[axis_];
if (!reverse_) {
#pragma omp parallel for
for (index_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
index_t start_idx = outer_idx * cum_size * inner_size;
for (index_t cum_idx = 0; cum_idx < cum_size; ++cum_idx) {
if (cum_idx == 0) {
if (exclusive_) {
std::memset(output_ptr + start_idx,
0,
sizeof(T) * inner_size);
} else {
std::memcpy(output_ptr + start_idx,
input_ptr + start_idx,
sizeof(T) * inner_size);
}
} else {
index_t cur_idx = start_idx + cum_idx * inner_size;
index_t pre_idx = start_idx + (cum_idx - 1) * inner_size;
index_t input_idx = exclusive_ ? pre_idx : cur_idx;
for (index_t inner_idx = 0; inner_idx < inner_size; ++inner_idx) {
output_ptr[cur_idx + inner_idx] =
output_ptr[pre_idx + inner_idx] +
input_ptr[input_idx + inner_idx];
}
}
}
}
} else {
MACE_NOT_IMPLEMENTED;
#pragma omp parallel for
for (index_t outer_idx = outer_size - 1; outer_idx >= 0; --outer_idx) {
index_t start_idx = outer_idx * cum_size * inner_size;
for (index_t cum_idx = cum_size - 1; cum_idx >= 0; --cum_idx) {
index_t cur_idx = start_idx + cum_idx * inner_size;
if (cum_idx == cum_size - 1) {
if (exclusive_) {
std::memset(output_ptr + cur_idx,
0,
sizeof(T) * inner_size);
} else {
std::memcpy(output_ptr + cur_idx,
input_ptr + cur_idx,
sizeof(T) * inner_size);
}
} else {
index_t pre_idx = start_idx + (cum_idx + 1) * inner_size;
index_t input_idx = exclusive_ ? pre_idx : cur_idx;
for (index_t inner_idx = 0; inner_idx < inner_size; ++inner_idx) {
output_ptr[cur_idx + inner_idx] =
output_ptr[pre_idx + inner_idx] +
input_ptr[input_idx + inner_idx];
}
}
}
}
}
return MaceStatus::MACE_SUCCESS;
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
// Test fixture type derived from OpsTestBase.
// NOTE(review): nothing in the benchmark code shown here references this
// fixture; it may have been carried over from the unit test - confirm.
class CumsumOpTest : public OpsTestBase {};
namespace {
// Benchmark driver for the Cumsum op.
//
// Builds a single-op net with a random 4-D input of shape
// {batch, channels, height, width}, configured with axis = 0, exclusive = 0
// and reverse = 0, runs five untimed warm-up passes and then `iters` timed
// passes. Only DeviceType::CPU is supported; any other device hits
// MACE_NOT_IMPLEMENTED.
template <DeviceType D, typename T>
void Cumsum(int iters, int batch, int channels, int height, int width) {
  // Setup must not be counted towards the measured time.
  mace::testing::StopTiming();

  OpsTestNet net;

  // Register the input tensor.
  if (D != DeviceType::CPU) {
    MACE_NOT_IMPLEMENTED;
  } else {
    net.AddRandomInput<D, T>("Input", {batch, channels, height, width});
  }

  // Describe the operator under test.
  const int data_type = static_cast<int>(DataTypeToEnum<T>::value);
  OpDefBuilder("Cumsum", "CumsumTest")
      .Input("Input")
      .Output("Output")
      .AddIntArg("axis", 0)
      .AddIntArg("exclusive", 0)
      .AddIntArg("reverse", 0)
      .AddIntArg("T", data_type)
      .Finalize(net.NewOperatorDef());

  // Untimed warm-up passes.
  int warm_up_left = 5;
  while (warm_up_left > 0) {
    net.RunOp(D);
    --warm_up_left;
  }
  net.Sync();

  // Timed passes.
  mace::testing::StartTiming();
  for (; iters != 0; --iters) {
    net.RunOp(D);
  }
  net.Sync();
}
} // namespace
// Defines a benchmark function whose name encodes the input shape (N, C, H, W),
// the element TYPE and the DEVICE, and registers it with MACE_BENCHMARK.
// The generated function reports iters * N * C * H * W * sizeof(TYPE) as the
// number of bytes processed and then calls Cumsum<DEVICE, TYPE>.
#define MACE_BM_CUMSUM_MACRO(N, C, H, W, TYPE, DEVICE) \
static void MACE_BM_CUMSUM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
Cumsum<DEVICE, TYPE>(iters, N, C, H, W); \
} \
MACE_BENCHMARK(MACE_BM_CUMSUM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
// Convenience wrapper: registers the float/CPU Cumsum benchmark for the given
// input shape. The expansion already ends with ';', so a ';' written at the
// call site becomes an extra empty declaration.
#define MACE_BM_CUMSUM(N, C, H, W) \
MACE_BM_CUMSUM_MACRO(N, C, H, W, float, CPU);
// Registered benchmark input shapes, given as (N, C, H, W).
MACE_BM_CUMSUM(1, 1, 512, 512);
MACE_BM_CUMSUM(1, 3, 128, 128);
MACE_BM_CUMSUM(1, 3, 512, 512);
MACE_BM_CUMSUM(1, 32, 112, 112);
MACE_BM_CUMSUM(1, 64, 256, 256);
MACE_BM_CUMSUM(1, 64, 512, 512);
MACE_BM_CUMSUM(1, 128, 56, 56);
MACE_BM_CUMSUM(1, 128, 256, 256);
MACE_BM_CUMSUM(1, 256, 14, 14);
MACE_BM_CUMSUM(1, 512, 14, 14);
MACE_BM_CUMSUM(1, 1024, 7, 7);
// Batched inputs (N = 32).
MACE_BM_CUMSUM(32, 1, 256, 256);
MACE_BM_CUMSUM(32, 3, 256, 256);
} // namespace test
} // namespace ops
} // namespace mace
......@@ -12,9 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <functional>
#include <vector>
#include "mace/ops/ops_test_util.h"
namespace mace {
......@@ -24,33 +21,69 @@ namespace test {
// Test fixture for the Cumsum op tests registered with TEST_F below.
class CumsumOpTest : public OpsTestBase {};
namespace {
void SimpleTest() {
template <typename T>
void SimpleTestWithDataFormat(const std::vector<index_t> &shape,
const std::vector<float> &input,
const int axis,
const int exclusive,
const int reverse,
const std::vector<float> &output) {
// Construct graph
OpsTestNet net;
net.AddInputFromArray<DeviceType::CPU, float>("Input", {2, 2, 2, 2},
{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.});
net.AddInputFromArray<CPU, T>("Input", shape, input);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
NCHW);
OpDefBuilder("Cumsum", "CumsumTest")
.Input("Input")
.Output("Output")
.AddIntArg("axis", 1)
.AddIntArg("exclusive", 1)
.AddIntArg("reverse", 1)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<float>::value))
.Input("InputNCHW")
.Output("OutputNCHW")
.AddIntArg("axis", axis)
.AddIntArg("exclusive", exclusive)
.AddIntArg("reverse", reverse)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("has_data_format", 1)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(DeviceType::CPU);
auto expected = net.CreateTensor<float>({2, 2, 2, 2},
{4., 5., 6., 7., 0., 0., 0., 0., 12., 13., 14., 15., 0., 0., 0., 0.});
ExpectTensorNear<float, float>(*expected, *net.GetOutput("Output"), 1e-5);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output",
NHWC);
net.AddInputFromArray<CPU, T>("ExpectedOutput", shape, output);
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(CumsumOpTest, CPU) {
SimpleTest();
TEST_F(CumsumOpTest, HasDataFormatCPU) {
SimpleTestWithDataFormat<float>(
{2, 2, 2, 2},
{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.},
0, 0, 0,
{0., 1., 2., 3., 4., 5., 6., 7., 8., 10., 12., 14., 16., 18., 20., 22.});
SimpleTestWithDataFormat<float>(
{2, 2, 2, 2},
{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.},
1, 0, 0,
{0., 1., 2., 3., 4., 6., 8., 10., 8., 9., 10., 11., 20., 22., 24., 26.});
SimpleTestWithDataFormat<float>(
{2, 2, 2, 2},
{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.},
0, 1, 0,
{0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 2., 3., 4., 5., 6., 7.});
SimpleTestWithDataFormat<float>(
{2, 2, 2, 2},
{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.},
0, 0, 1,
{8., 10., 12., 14., 16., 18., 20., 22., 8., 9., 10., 11., 12., 13., 14.,
15.});
SimpleTestWithDataFormat<float>(
{2, 2, 2, 2},
{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.},
1, 1, 1,
{4., 5., 6., 7., 0., 0., 0., 0., 12., 13., 14., 15., 0., 0., 0., 0.});
}
} // namespace test
......
......@@ -17,6 +17,7 @@
#include <vector>
#include "mace/core/operator.h"
#include "mace/utils/math.h"
namespace mace {
namespace ops {
......@@ -33,6 +34,7 @@ class StridedSliceOp : public Operation {
shrink_axis_mask_(
Operation::GetOptionalArg<int>("shrink_axis_mask", 0)),
is_slice_(Operation::GetOptionalArg<bool>("slice", false)),
has_data_format_(Operation::GetOptionalArg<int>("has_data_format", 0)),
checked_(false) {
MACE_CHECK(ellipsis_mask_ == 0 && new_axis_mask_ == 0,
"ellipsis_mask and new_axis_mask are not supported yet.");
......@@ -62,14 +64,21 @@ class StridedSliceOp : public Operation {
(*dims)[3] = w;
}
// Reorders a 4-element per-dimension vector in place from NCHW order to NHWC
// order: element 0 (N) stays put, and elements 1..3 change from [C, H, W]
// to [H, W, C]. `dims` must hold at least four elements.
void TransposeDimsFromNCHWToNHWC(std::vector<int32_t>* dims) {
  std::vector<int32_t> &values = *dims;
  // Save C, shift H and W one slot forward, then place C last.
  const int32_t channel = values[1];
  values[1] = values[2];
  values[2] = values[3];
  values[3] = channel;
}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
auto df = static_cast<DataFormat>(Operation::GetOptionalArg<int>(
"data_format", DataFormat::DF_NONE));
if (!checked_) {
if (df == DataFormat::NHWC && this->Input(0)->dim_size() == 4) {
if (has_data_format_ && this->Input(0)->dim_size() == 4) {
TransposeMaskValueFromNHWCToNCHW(&begin_mask_);
TransposeMaskValueFromNHWCToNCHW(&end_mask_);
TransposeMaskValueFromNHWCToNCHW(&ellipsis_mask_);
......@@ -78,14 +87,15 @@ class StridedSliceOp : public Operation {
}
checked_ = true;
}
const Tensor *input = this->Input(INPUT);
const Tensor *begin_indices = this->Input(BEGIN);
const Tensor *end_indices = this->Input(END);
const Tensor *strides = nullptr;
if (this->InputSize() > 3) {
strides = this->Input(STRIDES);
}
Tensor *output = this->Output(OUTPUT);
if (strides == nullptr) {
tmp_strides_tensor_.Resize({begin_indices->size()});
Tensor::MappingGuard strides_guard(&tmp_strides_tensor_);
......@@ -94,6 +104,11 @@ class StridedSliceOp : public Operation {
strides = &tmp_strides_tensor_;
}
MACE_CHECK(begin_indices->dim_size() == 1 &&
end_indices->dim_size() == 1 &&
strides->dim_size() == 1,
"Expected begin, end, and strides to be 1D tensor");
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard begin_indices_guard(begin_indices);
Tensor::MappingGuard end_indices_guard(end_indices);
......@@ -102,107 +117,145 @@ class StridedSliceOp : public Operation {
const int32_t *begin_indices_data = begin_indices->data<int32_t>();
const int32_t *end_indices_data = end_indices->data<int32_t>();
const int32_t *strides_data = strides->data<int32_t>();
std::vector<int32_t> pad_begin_indices(input->dim_size(), 0);
std::vector<int32_t> pad_end_indices(input->dim_size(), 0);
std::vector<int32_t> pad_strides_indices(input->dim_size(), 1);
if (begin_indices->size() < input->dim_size()) {
for (index_t i = 0; i < begin_indices->size(); ++i) {
pad_begin_indices[i] = begin_indices_data[i];
pad_end_indices[i] = end_indices_data[i];
pad_strides_indices[i] = strides_data[i];
}
for (index_t i = begin_indices->size(); i < input->dim_size(); ++i) {
pad_end_indices[i] = input->dim(i);
}
begin_indices_data = pad_begin_indices.data();
end_indices_data = pad_end_indices.data();
strides_data = pad_strides_indices.data();
}
std::vector<int32_t> begin_indices_vec(
begin_indices_data, begin_indices_data + begin_indices->size());
std::vector<int32_t> end_indices_vec(
end_indices_data, end_indices_data + end_indices->size());
std::vector<int32_t> strides_indices_vec(
strides_data, strides_data + strides->size());
std::vector<int32_t> transpose_begin_indices(input->dim_size(), 0);
std::vector<int32_t> transpose_end_indices(input->dim_size(), 0);
std::vector<int32_t> transpose_strides_indices(input->dim_size(), 1);
if (df == DataFormat::NHWC && this->Input(0)->dim_size() == 4) {
for (index_t i = 0; i < begin_indices->size(); ++i) {
transpose_begin_indices[i] = begin_indices_data[i];
transpose_end_indices[i] = end_indices_data[i];
transpose_strides_indices[i] = strides_data[i];
}
TransposeDimsFromNHWCToNCHW(&transpose_begin_indices);
TransposeDimsFromNHWCToNCHW(&transpose_end_indices);
TransposeDimsFromNHWCToNCHW(&transpose_strides_indices);
MACE_CHECK(input->size() > 0 && input->dim_size() > 0 &&
input->dim_size() <= 4,
"The input size should larger than 0."
" And input dims should be an integer in (0, 4].");
begin_indices_data = transpose_begin_indices.data();
end_indices_data = transpose_end_indices.data();
strides_data = transpose_strides_indices.data();
}
std::vector<index_t> output_shape = {};
std::vector<int32_t> slice_end_data;
const size_t input_dims = input->dim_size();
if (is_slice_) {
// If this op is a slice, end_indices_data actually holds sizes rather than end indices.
slice_end_data.resize(end_indices->size());
for (size_t i = 0; i < slice_end_data.size(); ++i) {
if (end_indices_data[i] == -1) {
slice_end_data[i] = input->dim(i);
} else {
slice_end_data[i] = begin_indices_data[i] + end_indices_data[i];
MACE_CHECK(begin_indices_vec.size() == input_dims &&
end_indices_vec.size() == input_dims,
"In slice, begin and size elements num should be equal");
// transpose
if (has_data_format_ && this->Input(0)->dim_size() == 4) {
TransposeDimsFromNHWCToNCHW(&begin_indices_vec);
TransposeDimsFromNHWCToNCHW(&end_indices_vec);
TransposeDimsFromNHWCToNCHW(&strides_indices_vec);
}
for (size_t i = 0; i < input_dims; ++i) {
if (end_indices_vec[i] == -1) {
end_indices_vec[i] = input->dim(i) - begin_indices_vec[i];
}
}
end_indices_data = slice_end_data.data();
}
std::vector<index_t> output_shape;
std::vector<index_t> real_begin_indices(input->dim_size(), 0);
std::vector<index_t> real_end_indices(input->dim_size(), 0);
for (index_t d = 0; d < input->dim_size(); ++d) {
index_t dim_len = input->dim(d);
if (begin_mask_ & (1 << d)) {
real_begin_indices[d] = strides_data[d] > 0 ? 0 : dim_len - 1;
} else {
real_begin_indices[d] = (begin_indices_data[d] + dim_len) % dim_len;
for (size_t i = 0; i < input_dims; ++i) {
int32_t b = begin_indices_vec[i];
int32_t s = end_indices_vec[i];
int32_t input_i = input->dim(i);
MACE_CHECK(0 <= b && b <= input_i,
"In Slice, expected begin[", i, "] in [0, ", input_i,
"], but got ", b);
MACE_CHECK(0 <= s && b + s <= input_i,
"In Slice, expected size[", i, "] in [0, ",
input_i - b, "], but got", s);
end_indices_vec[i] = b + s;
output_shape.push_back(s);
}
if (end_mask_ & (1 << d)) {
real_end_indices[d] = strides_data[d] > 0 ? dim_len : -1;
} else {
real_end_indices[d] =
end_indices_data[d] < -dim_len
? -1
: (end_indices_data[d] < 0
? (end_indices_data[d] + dim_len)
: std::min(static_cast<index_t>(end_indices_data[d]),
dim_len));
} else {
MACE_CHECK(begin_indices_vec.size() == end_indices_vec.size() &&
end_indices_vec.size() == strides_indices_vec.size(),
"In strided_slice, expected begin, end, and strides to be",
" equal size tensors");
for (index_t i = 0; i < strides->size(); ++i) {
MACE_CHECK(strides_indices_vec[i] != 0, "strides data cannot be 0!");
}
int32_t out_dim_len = std::max(
0.f, std::ceil((real_end_indices[d] - real_begin_indices[d]) /
static_cast<float>(strides_data[d])));
if (!(shrink_axis_mask_ & (1 << d))) {
output_shape.push_back(out_dim_len);
} else {
MACE_CHECK(out_dim_len == 1,
"cannot shrink axis that has len > 1, dim(", d, "): [",
real_begin_indices[d], ", ", real_end_indices[d], "]");
// pad
begin_indices_vec.resize(input_dims, 0);
strides_indices_vec.resize(input_dims, 1);
std::vector<int32_t> tmp_input_dims(input->shape().begin(),
input->shape().end());
if (has_data_format_ && input_dims == 4) {
TransposeDimsFromNCHWToNHWC(&tmp_input_dims);
}
for (size_t i = end_indices_vec.size(); i < input_dims; ++i) {
end_indices_vec.push_back(tmp_input_dims[i]);
}
// transpose
if (has_data_format_ && this->Input(0)->dim_size() == 4) {
TransposeDimsFromNHWCToNCHW(&begin_indices_vec);
TransposeDimsFromNHWCToNCHW(&end_indices_vec);
TransposeDimsFromNHWCToNCHW(&strides_indices_vec);
}
// mask and shrink
for (index_t d = 0; d < input->dim_size(); ++d) {
index_t dim_len = input->dim(d);
const std::vector<index_t> valid_range = {
strides_indices_vec[d] > 0 ? 0 : -1,
strides_indices_vec[d] > 0 ? dim_len : dim_len - 1};
auto format_indices = [valid_range, d, dim_len](index_t indice) {
index_t forward = indice < 0 ? indice + dim_len : indice;
return Clamp(forward, valid_range[0], valid_range[1]);
};
if (!(shrink_axis_mask_ & (1 << d))) {
if (begin_mask_ & (1 << d)) {
begin_indices_vec[d] = strides_indices_vec[d] > 0 ? 0 : dim_len - 1;
} else {
begin_indices_vec[d] = format_indices(begin_indices_vec[d]);
}
if (end_mask_ & (1 << d)) {
end_indices_vec[d] = strides_indices_vec[d] > 0 ? dim_len : -1;
} else {
end_indices_vec[d] = format_indices(end_indices_vec[d]);
}
int32_t out_dim_len = std::max(
0.f, std::ceil((end_indices_vec[d] - begin_indices_vec[d]) /
static_cast<float>(strides_indices_vec[d])));
output_shape.push_back(out_dim_len);
} else {
begin_indices_vec[d] = begin_indices_vec[d] < 0
? begin_indices_vec[d] + dim_len
: begin_indices_vec[d];
end_indices_vec[d] = begin_indices_vec[d] + 1;
MACE_CHECK(
begin_indices_vec[d] >= 0 && begin_indices_vec[d] < dim_len,
"slice begin indice of dimension '", d, "': ",
begin_indices_vec[d], ", is out of bound");
}
}
}
for (size_t i = 0; i < output_shape.size(); ++i) {
MACE_CHECK(output_shape[i] > 0,
"Expected output_shape[", i, "] larger than 0, but got ",
output_shape[i]);
}
std::vector<index_t> dim_stride(input->dim_size(), 1);
for (index_t d = input->dim_size() - 2; d >= 0; --d) {
dim_stride[d] = dim_stride[d + 1] * input->dim(d + 1);
}
Tensor *output = this->Output(OUTPUT);
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
Tensor::MappingGuard output_guard(output);
T *output_data = output->mutable_data<T>();
bool slice_by_first_axis = true;
if (strides_data[0] != 1) {
if (strides_indices_vec[0] != 1) {
slice_by_first_axis = false;
} else {
for (index_t d = 1; d < input->dim_size(); ++d) {
if (strides_data[d] != 1 || real_begin_indices[d] != 0 ||
real_end_indices[d] != input->dim(d)) {
if (strides_indices_vec[d] != 1 || begin_indices_vec[d] != 0 ||
end_indices_vec[d] != input->dim(d)) {
slice_by_first_axis = false;
break;
}
......@@ -210,64 +263,64 @@ class StridedSliceOp : public Operation {
}
if (slice_by_first_axis) {
memcpy(output_data, input_data + real_begin_indices[0] * dim_stride[0],
sizeof(T) * (real_end_indices[0] - real_begin_indices[0]) *
memcpy(output_data, input_data + begin_indices_vec[0] * dim_stride[0],
sizeof(T) * (end_indices_vec[0] - begin_indices_vec[0]) *
dim_stride[0]);
} else {
if (input->dim_size() == 1) {
for (index_t i = real_begin_indices[0];
strides_data[0] > 0 ? i < real_end_indices[0]
: i > real_end_indices[0];
i += strides_data[0]) {
for (index_t i = begin_indices_vec[0];
strides_indices_vec[0] > 0 ? i < end_indices_vec[0]
: i > end_indices_vec[0];
i += strides_indices_vec[0]) {
*output_data++ = input_data[i];
}
} else if (input->dim_size() == 2) {
for (index_t i = real_begin_indices[0];
strides_data[0] > 0 ? i < real_end_indices[0]
: i > real_end_indices[0];
i += strides_data[0]) {
for (index_t j = real_begin_indices[1];
strides_data[1] > 0 ? j < real_end_indices[1]
: j > real_end_indices[1];
j += strides_data[1]) {
for (index_t i = begin_indices_vec[0];
strides_indices_vec[0] > 0 ? i < end_indices_vec[0]
: i > end_indices_vec[0];
i += strides_indices_vec[0]) {
for (index_t j = begin_indices_vec[1];
strides_indices_vec[1] > 0 ? j < end_indices_vec[1]
: j > end_indices_vec[1];
j += strides_indices_vec[1]) {
*output_data++ = input_data[i * input->dim(1) + j];
}
}
} else if (input->dim_size() == 3) {
for (index_t i = real_begin_indices[0];
strides_data[0] > 0 ? i < real_end_indices[0]
: i > real_end_indices[0];
i += strides_data[0]) {
for (index_t j = real_begin_indices[1];
strides_data[1] > 0 ? j < real_end_indices[1]
: j > real_end_indices[1];
j += strides_data[1]) {
for (index_t k = real_begin_indices[2];
strides_data[2] > 0 ? k < real_end_indices[2]
: k > real_end_indices[2];
k += strides_data[2]) {
for (index_t i = begin_indices_vec[0];
strides_indices_vec[0] > 0 ? i < end_indices_vec[0]
: i > end_indices_vec[0];
i += strides_indices_vec[0]) {
for (index_t j = begin_indices_vec[1];
strides_indices_vec[1] > 0 ? j < end_indices_vec[1]
: j > end_indices_vec[1];
j += strides_indices_vec[1]) {
for (index_t k = begin_indices_vec[2];
strides_indices_vec[2] > 0 ? k < end_indices_vec[2]
: k > end_indices_vec[2];
k += strides_indices_vec[2]) {
*output_data++ =
input_data[(i * input->dim(1) + j) * input->dim(2) + k];
}
}
}
} else if (input->dim_size() == 4) {
for (index_t i = real_begin_indices[0];
strides_data[0] > 0 ? i < real_end_indices[0]
: i > real_end_indices[0];
i += strides_data[0]) {
for (index_t j = real_begin_indices[1];
strides_data[1] > 0 ? j < real_end_indices[1]
: j > real_end_indices[1];
j += strides_data[1]) {
for (index_t k = real_begin_indices[2];
strides_data[2] > 0 ? k < real_end_indices[2]
: k > real_end_indices[2];
k += strides_data[2]) {
for (index_t l = real_begin_indices[3];
strides_data[3] > 0 ? l < real_end_indices[3]
: l > real_end_indices[3];
l += strides_data[3]) {
for (index_t i = begin_indices_vec[0];
strides_indices_vec[0] > 0 ? i < end_indices_vec[0]
: i > end_indices_vec[0];
i += strides_indices_vec[0]) {
for (index_t j = begin_indices_vec[1];
strides_indices_vec[1] > 0 ? j < end_indices_vec[1]
: j > end_indices_vec[1];
j += strides_indices_vec[1]) {
for (index_t k = begin_indices_vec[2];
strides_indices_vec[2] > 0 ? k < end_indices_vec[2]
: k > end_indices_vec[2];
k += strides_indices_vec[2]) {
for (index_t l = begin_indices_vec[3];
strides_indices_vec[3] > 0 ? l < end_indices_vec[3]
: l > end_indices_vec[3];
l += strides_indices_vec[3]) {
*output_data++ =
input_data[((i * input->dim(1) + j) * input->dim(2) + k)
* input->dim(3) + l];
......@@ -289,6 +342,7 @@ class StridedSliceOp : public Operation {
int new_axis_mask_;
int shrink_axis_mask_;
bool is_slice_;
int has_data_format_;
bool checked_;
Tensor tmp_strides_tensor_;
......
......@@ -64,6 +64,54 @@ void TestStridedSlice(const std::vector<index_t> &input_shape,
*net.GetOutput("Output"));
}
// Runs the StridedSlice op with "has_data_format" set to 1 and checks the
// result.
//
// `input`, the index vectors, the masks and the expected `output` are all
// supplied by the caller for an NHWC layout. The helper transforms the input
// to NCHW before running the op, transforms the op output back to NHWC, and
// compares it against `output` reshaped to `output_shape`.
void TestStridedSliceWithDataFormat(const std::vector<index_t> &input_shape,
                                    const std::vector<float> &input,
                                    const std::vector<int32_t> &begin_indices,
                                    const std::vector<int32_t> &end_indices,
                                    const std::vector<int32_t> &strides,
                                    const int begin_mask,
                                    const int end_mask,
                                    const int ellipsis_mask,
                                    const int new_axis_mask,
                                    const int shrink_axis_mask,
                                    const std::vector<index_t> &output_shape,
                                    const std::vector<float> &output) {
  // Length of a 1-D index tensor that stores `values`.
  const auto element_count = [](const std::vector<int32_t> &values) {
    return static_cast<int32_t>(values.size());
  };

  OpsTestNet net;

  // Data tensor plus the three 1-D index tensors.
  net.AddInputFromArray<CPU, float>("Input", input_shape, input);
  net.AddInputFromArray<CPU, int32_t>(
      "BeginIndices", {element_count(begin_indices)}, begin_indices);
  net.AddInputFromArray<CPU, int32_t>(
      "EndIndices", {element_count(end_indices)}, end_indices);
  net.AddInputFromArray<CPU, int32_t>(
      "Strides", {element_count(strides)}, strides);

  // The op consumes the NCHW version of the input.
  net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
                                                  NCHW);

  OpDefBuilder("StridedSlice", "StridedSliceOpTest")
      .Input("InputNCHW")
      .Input("BeginIndices")
      .Input("EndIndices")
      .Input("Strides")
      .Output("OutputNCHW")
      .AddIntArg("begin_mask", begin_mask)
      .AddIntArg("end_mask", end_mask)
      .AddIntArg("ellipsis_mask", ellipsis_mask)
      .AddIntArg("new_axis_mask", new_axis_mask)
      .AddIntArg("shrink_axis_mask", shrink_axis_mask)
      .AddIntArg("has_data_format", 1)
      .Finalize(net.NewOperatorDef());

  net.RunOp();

  // Bring the result back to NHWC so it can be compared with `output`.
  net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output",
                                                  NHWC);

  net.AddInputFromArray<CPU, float>("ExpectedOutput", output_shape, output);
  ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
                          *net.GetOutput("Output"));
}
void TestSlice(const std::vector<index_t> &input_shape,
const std::vector<float> &input,
const std::vector<int32_t> &begin_indices,
......@@ -92,6 +140,41 @@ void TestSlice(const std::vector<index_t> &input_shape,
*net.GetOutput("Output"));
}
// Runs the StridedSlice op in slice mode ("slice" = 1) with
// "has_data_format" set to 1 and checks the result.
//
// `input`, `begin_indices`, `indices_size` and the expected `output` are all
// supplied by the caller for an NHWC layout. The helper transforms the input
// to NCHW before running the op, transforms the op output back to NHWC, and
// compares it against `output` reshaped to `output_shape`.
void TestSliceWithDataFormat(const std::vector<index_t> &input_shape,
                             const std::vector<float> &input,
                             const std::vector<int32_t> &begin_indices,
                             const std::vector<int32_t> &indices_size,
                             const std::vector<index_t> &output_shape,
                             const std::vector<float> &output) {
  OpsTestNet net;
  net.AddInputFromArray<CPU, float>("Input", input_shape, input);
  // Each 1-D index tensor is shaped from its own element count, so the
  // tensor shape always matches the data actually supplied.
  net.AddInputFromArray<CPU, int32_t>(
      "BeginIndices", {static_cast<int32_t>(begin_indices.size())},
      begin_indices);
  net.AddInputFromArray<CPU, int32_t>(
      "IndicesSize", {static_cast<int32_t>(indices_size.size())}, indices_size);

  // The op consumes the NCHW version of the input.
  net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
                                                  NCHW);

  OpDefBuilder("StridedSlice", "StridedSliceOpTest")
      .Input("InputNCHW")
      .Input("BeginIndices")
      .Input("IndicesSize")
      .Output("OutputNCHW")
      .AddIntArg("slice", 1)
      .AddIntArg("has_data_format", 1)
      .Finalize(net.NewOperatorDef());

  net.RunOp();

  // Bring the result back to NHWC so it can be compared with `output`.
  net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output",
                                                  NHWC);
  net.AddInputFromArray<CPU, float>("ExpectedOutput", output_shape, output);
  ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
                          *net.GetOutput("Output"));
}
} // namespace
TEST_F(StridedSliceOpTest, TestStridedSliceByFirstAxis) {
......@@ -157,6 +240,66 @@ TEST_F(StridedSliceOpTest, TestStridedSliceRank3) {
1, 2}, {1, 1, 3, 3});
}
// Rank-4 strided-slice cases on a {2, 2, 2, 3} input holding 0..23.
// NOTE(review): the per-case notes below assume TestStridedSlice takes
// (begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask,
// shrink_axis_mask, output_shape, output) after the input - confirm against
// its definition.
TEST_F(StridedSliceOpTest, TestStridedSliceRank4) {
// No masks.
TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0},
{2, 2, 2, 2}, {1, 1, 1, 1}, 0, 0, 0, 0, 0, {1, 2, 1, 2},
{15, 16, 21, 22});
// begin_mask = 3.
TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0},
{2, 2, 2, 2}, {1, 1, 1, 1}, 3, 0, 0, 0, 0, {2, 2, 1, 2},
{3, 4, 9, 10, 15, 16, 21, 22});
// end_mask = 8.
TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0},
{2, 2, 2, 2}, {1, 1, 1, 1}, 0, 8, 0, 0, 0, {1, 2, 1, 3},
{15, 16, 17, 21, 22, 23});
// end_mask = 8 and shrink_axis_mask = 8: the expected output is rank 3.
TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0},
{2, 2, 2, 2}, {1, 1, 1, 1}, 0, 8, 0, 0, 8, {1, 2, 1},
{15, 21});
// shrink_axis_mask = 15: the expected output shape is empty (one element).
TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0},
{2, 2, 2, 2}, {1, 1, 1, 1}, 0, 8, 0, 0, 15, {}, {15});
// Negative strides with a negative begin index and out-of-range begin values.
TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {-1, 2, 1, 3},
{0, 0, 0, 0}, {-1, -1, -1, -1}, 0, 0, 0, 0, 0, {1, 1, 1, 2},
{23, 22});
}
// Strided-slice cases run through the has_data_format path. Every case uses
// the same {2, 2, 2, 3} input holding 0..23.
TEST_F(StridedSliceOpTest, TestStridedSliceWithDataFormat) {
  const std::vector<index_t> shape = {2, 2, 2, 3};
  const std::vector<float> values = {0,  1,  2,  3,  4,  5,  6,  7,
                                     8,  9,  10, 11, 12, 13, 14, 15,
                                     16, 17, 18, 19, 20, 21, 22, 23};

  // Full-rank indices, no masks.
  TestStridedSliceWithDataFormat(
      shape, values, {1, 0, 1, 0}, {2, 2, 2, 2}, {1, 1, 1, 1},
      0, 0, 0, 0, 0, {1, 2, 1, 2}, {15, 16, 21, 22});

  // begin_mask = 3.
  TestStridedSliceWithDataFormat(
      shape, values, {1, 0, 1, 0}, {2, 2, 2, 2}, {1, 1, 1, 1},
      3, 0, 0, 0, 0, {2, 2, 1, 2}, {3, 4, 9, 10, 15, 16, 21, 22});

  // end_mask = 8.
  TestStridedSliceWithDataFormat(
      shape, values, {1, 0, 1, 0}, {2, 2, 2, 2}, {1, 1, 1, 1},
      0, 8, 0, 0, 0, {1, 2, 1, 3}, {15, 16, 17, 21, 22, 23});

  // Only two indices supplied for a rank-4 input, end_mask = 8.
  TestStridedSliceWithDataFormat(
      shape, values, {1, 0}, {2, 1}, {1, 1},
      0, 8, 0, 0, 0, {1, 1, 2, 3}, {12, 13, 14, 15, 16, 17});

  // Only two indices supplied for a rank-4 input, end_mask = 2.
  TestStridedSliceWithDataFormat(
      shape, values, {1, 0}, {2, 1}, {1, 1},
      0, 2, 0, 0, 0, {1, 2, 2, 3},
      {12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});

  // Negative strides with a negative begin index and out-of-range begin values.
  TestStridedSliceWithDataFormat(
      shape, values, {-1, 2, 1, 3}, {0, 0, 0, 0}, {-1, -1, -1, -1},
      0, 0, 0, 0, 0, {1, 1, 1, 2}, {23, 22});
}
TEST_F(StridedSliceOpTest, TestSlice) {
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {2, 3},
{1, 2, 3, 4, 5, 6});
......@@ -166,6 +309,17 @@ TEST_F(StridedSliceOpTest, TestSlice) {
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 1}, {2, -1}, {2, 2}, {2, 3, 5, 6});
}
// Slice-mode cases run through the has_data_format path. Both cases use the
// same {2, 2, 2, 3} input holding 0..23 and the same begin indices.
TEST_F(StridedSliceOpTest, TestSliceWithDataFormat) {
  const std::vector<index_t> shape = {2, 2, 2, 3};
  const std::vector<float> values = {0,  1,  2,  3,  4,  5,  6,  7,
                                     8,  9,  10, 11, 12, 13, 14, 15,
                                     16, 17, 18, 19, 20, 21, 22, 23};
  const std::vector<int32_t> begin = {1, 0, 1, 0};

  // Explicit sizes on every dimension.
  TestSliceWithDataFormat(shape, values, begin, {1, 2, 1, 2}, {1, 2, 1, 2},
                          {15, 16, 21, 22});

  // A size of -1 on every dimension.
  TestSliceWithDataFormat(shape, values, begin, {-1, -1, -1, -1},
                          {1, 2, 1, 3}, {15, 16, 17, 21, 22, 23});
}
} // namespace test
} // namespace ops
} // namespace mace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册