Commit 41a6d777 authored by 李寅

Merge branch 'support_cumsum_and_stridedslice' into 'master'

Support tf.cumsum and 4-D strided_slice on CPU

See merge request !999
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <functional>
#include "mace/core/operator.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class CumsumOp;
template <typename T>
class CumsumOp<DeviceType::CPU, T> : public Operation {
public:
explicit CumsumOp(OpConstructContext *context)
: Operation(context),
axis_(Operation::GetOptionalArg<int>("axis", 0)),
exclusive_(Operation::GetOptionalArg<bool>("exclusive", false)),
reverse_(Operation::GetOptionalArg<bool>("reverse", false)),
checked_(false) {}
void Validate() {
const int32_t input_dims = this->Input(0)->dim_size();
axis_ = axis_ < 0 ? axis_ + input_dims : axis_;
MACE_CHECK((0 <= axis_ && axis_ < input_dims),
"Expected cumsum axis in the range [", -input_dims, ", ",
input_dims, "), but got ", axis_);
}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
if (!checked_) {
Validate();
bool has_data_format = Operation::GetOptionalArg<int>(
"has_data_format", 0);
if (has_data_format && this->Input(0)->dim_size() == 4) {
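// The model was converted with NHWC semantics, but MACE CPU tensors are
// NCHW, so remap the axis: NHWC (H=1, W=2, C=3) maps to NCHW (C=1, H=2, W=3).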
if (axis_ == 3) axis_ = 1;
else if (axis_ == 2) axis_ = 3;
else if (axis_ == 1) axis_ = 2;
}
checked_ = true;
}
const Tensor *input = this->Input(0);
const std::vector<index_t> input_shape = input->shape();
Tensor *output = this->Output(0);
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
Tensor::MappingGuard input_mapper(input);
Tensor::MappingGuard output_mapper(output);
const T *input_ptr = input->data<T>();
T *output_ptr = output->mutable_data<T>();
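// Collapse the shape to (outer, cum, inner) around the scan axis; the
// outer * inner 1-D scans below are independent of each other.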
const index_t outer_size = std::accumulate(input_shape.begin(),
input_shape.begin() + axis_,
1,
std::multiplies<index_t>());
const index_t inner_size = std::accumulate(input_shape.begin() + axis_ + 1,
input_shape.end(),
1,
std::multiplies<index_t>());
const index_t cum_size = input_shape[axis_];
if (!reverse_) {
#pragma omp parallel for
for (index_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
index_t start_idx = outer_idx * cum_size * inner_size;
for (index_t cum_idx = 0; cum_idx < cum_size; ++cum_idx) {
if (cum_idx == 0) {
if (exclusive_) {
std::memset(output_ptr + start_idx,
0,
sizeof(T) * inner_size);
} else {
std::memcpy(output_ptr + start_idx,
input_ptr + start_idx,
sizeof(T) * inner_size);
}
} else {
index_t cur_idx = start_idx + cum_idx * inner_size;
index_t pre_idx = start_idx + (cum_idx - 1) * inner_size;
index_t input_idx = exclusive_ ? pre_idx : cur_idx;
for (index_t inner_idx = 0; inner_idx < inner_size; ++inner_idx) {
output_ptr[cur_idx + inner_idx] =
output_ptr[pre_idx + inner_idx] +
input_ptr[input_idx + inner_idx];
}
}
}
}
} else {
#pragma omp parallel for
for (index_t outer_idx = outer_size - 1; outer_idx >= 0; --outer_idx) {
index_t start_idx = outer_idx * cum_size * inner_size;
for (index_t cum_idx = cum_size - 1; cum_idx >= 0; --cum_idx) {
index_t cur_idx = start_idx + cum_idx * inner_size;
if (cum_idx == cum_size - 1) {
if (exclusive_) {
std::memset(output_ptr + cur_idx,
0,
sizeof(T) * inner_size);
} else {
std::memcpy(output_ptr + cur_idx,
input_ptr + cur_idx,
sizeof(T) * inner_size);
}
} else {
index_t pre_idx = start_idx + (cum_idx + 1) * inner_size;
index_t input_idx = exclusive_ ? pre_idx : cur_idx;
for (index_t inner_idx = 0; inner_idx < inner_size; ++inner_idx) {
output_ptr[cur_idx + inner_idx] =
output_ptr[pre_idx + inner_idx] +
input_ptr[input_idx + inner_idx];
}
}
}
}
}
return MaceStatus::MACE_SUCCESS;
}
private:
int32_t axis_;
bool exclusive_;
bool reverse_;
bool checked_;
};
void RegisterCumsum(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Cumsum", CumsumOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
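For reference, the recurrence above reduces in the 1-D case to the following minimal sketch (a standalone illustration with hypothetical names, not part of the MACE sources): exclusive shifts the running sum back by one element, and reverse runs the scan from the end.

#include <cstddef>
#include <vector>

// Minimal 1-D cumsum reference (illustration only).
std::vector<float> CumsumRef(const std::vector<float> &x,
                             bool exclusive, bool reverse) {
  const std::size_t n = x.size();
  std::vector<float> y(n, 0.f);
  float acc = 0.f;
  for (std::size_t step = 0; step < n; ++step) {
    // Walk forward or backward depending on `reverse`.
    const std::size_t i = reverse ? n - 1 - step : step;
    y[i] = exclusive ? acc : acc + x[i];
    acc += x[i];
  }
  return y;
}

// CumsumRef({1, 2, 3}, false, false) == {1, 3, 6}
// CumsumRef({1, 2, 3}, true,  false) == {0, 1, 3}
// CumsumRef({1, 2, 3}, false, true)  == {6, 5, 3}
// CumsumRef({1, 2, 3}, true,  true)  == {5, 3, 0}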
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class CumsumOpTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void Cumsum(int iters, int batch, int channels, int height, int width) {
mace::testing::StopTiming();
// Construct graph
OpsTestNet net;
// Add input data
if (D == DeviceType::CPU) {
net.AddRandomInput<D, T>("Input", {batch, channels, height, width});
} else {
MACE_NOT_IMPLEMENTED;
}
OpDefBuilder("Cumsum", "CumsumTest")
.Input("Input")
.Output("Output")
.AddIntArg("axis", 0)
.AddIntArg("exclusive", 0)
.AddIntArg("reverse", 0)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
} // namespace
#define MACE_BM_CUMSUM_MACRO(N, C, H, W, TYPE, DEVICE) \
static void MACE_BM_CUMSUM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::BytesProcessed(tot * sizeof(TYPE)); \
Cumsum<DEVICE, TYPE>(iters, N, C, H, W); \
} \
MACE_BENCHMARK(MACE_BM_CUMSUM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define MACE_BM_CUMSUM(N, C, H, W) \
MACE_BM_CUMSUM_MACRO(N, C, H, W, float, CPU);
MACE_BM_CUMSUM(1, 1, 512, 512);
MACE_BM_CUMSUM(1, 3, 128, 128);
MACE_BM_CUMSUM(1, 3, 512, 512);
MACE_BM_CUMSUM(1, 32, 112, 112);
MACE_BM_CUMSUM(1, 64, 256, 256);
MACE_BM_CUMSUM(1, 64, 512, 512);
MACE_BM_CUMSUM(1, 128, 56, 56);
MACE_BM_CUMSUM(1, 128, 256, 256);
MACE_BM_CUMSUM(1, 256, 14, 14);
MACE_BM_CUMSUM(1, 512, 14, 14);
MACE_BM_CUMSUM(1, 1024, 7, 7);
MACE_BM_CUMSUM(32, 1, 256, 256);
MACE_BM_CUMSUM(32, 3, 256, 256);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class CumsumOpTest : public OpsTestBase {};
namespace {
template <typename T>
void SimpleTestWithDataFormat(const std::vector<index_t> &shape,
const std::vector<float> &input,
const int axis,
const int exclusive,
const int reverse,
const std::vector<float> &output) {
// Construct graph
OpsTestNet net;
net.AddInputFromArray<CPU, T>("Input", shape, input);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
NCHW);
OpDefBuilder("Cumsum", "CumsumTest")
.Input("InputNCHW")
.Output("OutputNCHW")
.AddIntArg("axis", axis)
.AddIntArg("exclusive", exclusive)
.AddIntArg("reverse", reverse)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("has_data_format", 1)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(DeviceType::CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output",
NHWC);
net.AddInputFromArray<CPU, T>("ExpectedOutput", shape, output);
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(CumsumOpTest, HasDataFormatCPU) {
SimpleTestWithDataFormat<float>(
{2, 2, 2, 2},
{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.},
0, 0, 0,
{0., 1., 2., 3., 4., 5., 6., 7., 8., 10., 12., 14., 16., 18., 20., 22.});
SimpleTestWithDataFormat<float>(
{2, 2, 2, 2},
{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.},
1, 0, 0,
{0., 1., 2., 3., 4., 6., 8., 10., 8., 9., 10., 11., 20., 22., 24., 26.});
SimpleTestWithDataFormat<float>(
{2, 2, 2, 2},
{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.},
0, 1, 0,
{0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 2., 3., 4., 5., 6., 7.});
SimpleTestWithDataFormat<float>(
{2, 2, 2, 2},
{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.},
0, 0, 1,
{8., 10., 12., 14., 16., 18., 20., 22., 8., 9., 10., 11., 12., 13., 14.,
15.});
SimpleTestWithDataFormat<float>(
{2, 2, 2, 2},
{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.},
1, 1, 1,
{4., 5., 6., 7., 0., 0., 0., 0., 12., 13., 14., 15., 0., 0., 0., 0.});
}
} // namespace test
} // namespace ops
} // namespace mace
@@ -29,6 +29,7 @@ extern void RegisterChannelShuffle(OpRegistryBase *op_registry);
extern void RegisterConcat(OpRegistryBase *op_registry);
extern void RegisterConv2D(OpRegistryBase *op_registry);
extern void RegisterCrop(OpRegistryBase *op_registry);
extern void RegisterCumsum(OpRegistryBase *op_registry);
extern void RegisterDeconv2D(OpRegistryBase *op_registry);
extern void RegisterDepthToSpace(OpRegistryBase *op_registry);
extern void RegisterDepthwiseConv2d(OpRegistryBase *op_registry);
@@ -95,6 +96,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
ops::RegisterConcat(this);
ops::RegisterConv2D(this);
ops::RegisterCrop(this);
ops::RegisterCumsum(this);
ops::RegisterDeconv2D(this);
ops::RegisterDepthToSpace(this);
ops::RegisterDepthwiseConv2d(this);
......
@@ -17,6 +17,7 @@
#include <vector>
#include "mace/core/operator.h"
#include "mace/utils/math.h"
namespace mace {
namespace ops {
@@ -32,21 +33,69 @@ class StridedSliceOp : public Operation {
new_axis_mask_(Operation::GetOptionalArg<int>("new_axis_mask", 0)),
shrink_axis_mask_(
Operation::GetOptionalArg<int>("shrink_axis_mask", 0)),
is_slice_(Operation::GetOptionalArg<bool>("slice", false)) {
is_slice_(Operation::GetOptionalArg<bool>("slice", false)),
has_data_format_(Operation::GetOptionalArg<int>("has_data_format", 0)),
checked_(false) {
MACE_CHECK(ellipsis_mask_ == 0 && new_axis_mask_ == 0,
"ellipsis_mask and new_axis_mask are not supported yet.");
}
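// The begin/end/ellipsis/new_axis/shrink_axis masks carry one bit per
// dimension, so transposing a 4-D input from NHWC to NCHW requires
// permuting the mask bits the same way as the dimensions. For example,
// begin_mask 0b1000 (dim 3 = C in NHWC) becomes 0b0010 (dim 1 = C in NCHW).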
void TransposeMaskValueFromNHWCToNCHW(int* mask_value) {
size_t dims[4];
int count;
for (count = 0; count < 4; ++count) {
dims[count] = *mask_value & 1;
*mask_value >>= 1;
}
size_t new_dims[4] = {dims[0], dims[3], dims[1], dims[2]};
for (count = 3; count >= 0; --count) {
*mask_value <<= 1;
*mask_value += new_dims[count];
}
}
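// Per-dimension index vectors are reordered the same way:
// NHWC (N, H, W, C) -> NCHW (N, C, H, W), and back for the inverse.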
void TransposeDimsFromNHWCToNCHW(std::vector<int32_t>* dims) {
int32_t h = (*dims)[1];
int32_t w = (*dims)[2];
int32_t c = (*dims)[3];
(*dims)[1] = c;
(*dims)[2] = h;
(*dims)[3] = w;
}
void TransposeDimsFromNCHWToNHWC(std::vector<int32_t>* dims) {
int32_t c = (*dims)[1];
int32_t h = (*dims)[2];
int32_t w = (*dims)[3];
(*dims)[1] = h;
(*dims)[2] = w;
(*dims)[3] = c;
}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
if (!checked_) {
if (has_data_format_ && this->Input(0)->dim_size() == 4) {
TransposeMaskValueFromNHWCToNCHW(&begin_mask_);
TransposeMaskValueFromNHWCToNCHW(&end_mask_);
TransposeMaskValueFromNHWCToNCHW(&ellipsis_mask_);
TransposeMaskValueFromNHWCToNCHW(&new_axis_mask_);
TransposeMaskValueFromNHWCToNCHW(&shrink_axis_mask_);
}
checked_ = true;
}
const Tensor *input = this->Input(INPUT);
const Tensor *begin_indices = this->Input(BEGIN);
const Tensor *end_indices = this->Input(END);
const Tensor *strides = nullptr;
if (this->InputSize() > 3) {
strides = this->Input(STRIDES);
}
if (strides == nullptr) {
tmp_strides_tensor_.Resize({begin_indices->size()});
Tensor::MappingGuard strides_guard(&tmp_strides_tensor_);
@@ -55,6 +104,11 @@ class StridedSliceOp : public Operation {
strides = &tmp_strides_tensor_;
}
MACE_CHECK(begin_indices->dim_size() == 1 &&
end_indices->dim_size() == 1 &&
strides->dim_size() == 1,
"Expected begin, end, and strides to be 1D tensor");
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard begin_indices_guard(begin_indices);
Tensor::MappingGuard end_indices_guard(end_indices);
@@ -63,88 +117,145 @@ class StridedSliceOp : public Operation {
const int32_t *begin_indices_data = begin_indices->data<int32_t>();
const int32_t *end_indices_data = end_indices->data<int32_t>();
const int32_t *strides_data = strides->data<int32_t>();
std::vector<int32_t> begin_indices_vec(
begin_indices_data, begin_indices_data + begin_indices->size());
std::vector<int32_t> end_indices_vec(
end_indices_data, end_indices_data + end_indices->size());
std::vector<int32_t> strides_indices_vec(
strides_data, strides_data + strides->size());
MACE_CHECK(input->size() > 0 && input->dim_size() > 0 &&
input->dim_size() <= 4,
"The input size should be larger than 0,"
" and the input rank should be in (0, 4].");
std::vector<index_t> output_shape = {};
const size_t input_dims = input->dim_size();
if (is_slice_) {
MACE_CHECK(begin_indices_vec.size() == input_dims &&
end_indices_vec.size() == input_dims,
"In Slice, begin and size must have one element per input dim");
// transpose
if (has_data_format_ && this->Input(0)->dim_size() == 4) {
TransposeDimsFromNHWCToNCHW(&begin_indices_vec);
TransposeDimsFromNHWCToNCHW(&end_indices_vec);
TransposeDimsFromNHWCToNCHW(&strides_indices_vec);
}
for (size_t i = 0; i < input_dims; ++i) {
if (end_indices_vec[i] == -1) {
end_indices_vec[i] = input->dim(i) - begin_indices_vec[i];
}
}
for (size_t i = 0; i < input_dims; ++i) {
int32_t b = begin_indices_vec[i];
int32_t s = end_indices_vec[i];
int32_t input_i = input->dim(i);
MACE_CHECK(0 <= b && b <= input_i,
"In Slice, expected begin[", i, "] in [0, ", input_i,
"], but got ", b);
MACE_CHECK(0 <= s && b + s <= input_i,
"In Slice, expected size[", i, "] in [0, ",
input_i - b, "], but got ", s);
end_indices_vec[i] = b + s;
output_shape.push_back(s);
}
} else {
MACE_CHECK(begin_indices_vec.size() == end_indices_vec.size() &&
end_indices_vec.size() == strides_indices_vec.size(),
"In strided_slice, expected begin, end, and strides to be",
" equal size tensors");
for (index_t i = 0; i < strides->size(); ++i) {
MACE_CHECK(strides_indices_vec[i] != 0, "strides data cannot be 0!");
}
// pad
begin_indices_vec.resize(input_dims, 0);
strides_indices_vec.resize(input_dims, 1);
std::vector<int32_t> tmp_input_dims(input->shape().begin(),
input->shape().end());
if (has_data_format_ && input_dims == 4) {
TransposeDimsFromNCHWToNHWC(&tmp_input_dims);
}
for (size_t i = end_indices_vec.size(); i < input_dims; ++i) {
end_indices_vec.push_back(tmp_input_dims[i]);
}
// transpose
if (has_data_format_ && this->Input(0)->dim_size() == 4) {
TransposeDimsFromNHWCToNCHW(&begin_indices_vec);
TransposeDimsFromNHWCToNCHW(&end_indices_vec);
TransposeDimsFromNHWCToNCHW(&strides_indices_vec);
}
// mask and shrink
for (index_t d = 0; d < input->dim_size(); ++d) {
index_t dim_len = input->dim(d);
const std::vector<index_t> valid_range = {
strides_indices_vec[d] > 0 ? 0 : -1,
strides_indices_vec[d] > 0 ? dim_len : dim_len - 1};
auto format_indices = [valid_range, d, dim_len](index_t indice) {
index_t forward = indice < 0 ? indice + dim_len : indice;
return Clamp(forward, valid_range[0], valid_range[1]);
};
if (!(shrink_axis_mask_ & (1 << d))) {
if (begin_mask_ & (1 << d)) {
begin_indices_vec[d] = strides_indices_vec[d] > 0 ? 0 : dim_len - 1;
} else {
begin_indices_vec[d] = format_indices(begin_indices_vec[d]);
}
if (end_mask_ & (1 << d)) {
end_indices_vec[d] = strides_indices_vec[d] > 0 ? dim_len : -1;
} else {
end_indices_vec[d] = format_indices(end_indices_vec[d]);
}
int32_t out_dim_len = std::max(
0.f, std::ceil((end_indices_vec[d] - begin_indices_vec[d]) /
static_cast<float>(strides_indices_vec[d])));
output_shape.push_back(out_dim_len);
} else {
begin_indices_vec[d] = begin_indices_vec[d] < 0
? begin_indices_vec[d] + dim_len
: begin_indices_vec[d];
end_indices_vec[d] = begin_indices_vec[d] + 1;
MACE_CHECK(
begin_indices_vec[d] >= 0 && begin_indices_vec[d] < dim_len,
"slice begin index of dimension '", d, "': ",
begin_indices_vec[d], ", is out of bound");
}
}
}
for (size_t i = 0; i < output_shape.size(); ++i) {
MACE_CHECK(output_shape[i] > 0,
"Expected output_shape[", i, "] to be larger than 0, but got ",
output_shape[i]);
}
std::vector<index_t> dim_stride(input->dim_size(), 1);
for (index_t d = input->dim_size() - 2; d >= 0; --d) {
dim_stride[d] = dim_stride[d + 1] * input->dim(d + 1);
}
Tensor *output = this->Output(OUTPUT);
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
Tensor::MappingGuard output_guard(output);
T *output_data = output->mutable_data<T>();
bool slice_by_first_axis = true;
if (strides_indices_vec[0] != 1) {
slice_by_first_axis = false;
} else {
for (index_t d = 1; d < input->dim_size(); ++d) {
if (strides_indices_vec[d] != 1 || begin_indices_vec[d] != 0 ||
end_indices_vec[d] != input->dim(d)) {
slice_by_first_axis = false;
break;
}
@@ -152,47 +263,71 @@ class StridedSliceOp : public Operation {
}
if (slice_by_first_axis) {
memcpy(output_data, input_data + begin_indices_vec[0] * dim_stride[0],
sizeof(T) * (end_indices_vec[0] - begin_indices_vec[0]) *
dim_stride[0]);
} else {
if (input->dim_size() == 1) {
for (index_t i = begin_indices_vec[0];
strides_indices_vec[0] > 0 ? i < end_indices_vec[0]
: i > end_indices_vec[0];
i += strides_indices_vec[0]) {
*output_data++ = input_data[i];
}
} else if (input->dim_size() == 2) {
for (index_t i = begin_indices_vec[0];
strides_indices_vec[0] > 0 ? i < end_indices_vec[0]
: i > end_indices_vec[0];
i += strides_indices_vec[0]) {
for (index_t j = begin_indices_vec[1];
strides_indices_vec[1] > 0 ? j < end_indices_vec[1]
: j > end_indices_vec[1];
j += strides_indices_vec[1]) {
*output_data++ = input_data[i * input->dim(1) + j];
}
}
} else if (input->dim_size() == 3) {
for (index_t i = begin_indices_vec[0];
strides_indices_vec[0] > 0 ? i < end_indices_vec[0]
: i > end_indices_vec[0];
i += strides_indices_vec[0]) {
for (index_t j = begin_indices_vec[1];
strides_indices_vec[1] > 0 ? j < end_indices_vec[1]
: j > end_indices_vec[1];
j += strides_indices_vec[1]) {
for (index_t k = begin_indices_vec[2];
strides_indices_vec[2] > 0 ? k < end_indices_vec[2]
: k > end_indices_vec[2];
k += strides_indices_vec[2]) {
*output_data++ =
input_data[(i * input->dim(1) + j) * input->dim(2) + k];
}
}
}
} else if (input->dim_size() == 4) {
for (index_t i = begin_indices_vec[0];
strides_indices_vec[0] > 0 ? i < end_indices_vec[0]
: i > end_indices_vec[0];
i += strides_indices_vec[0]) {
for (index_t j = begin_indices_vec[1];
strides_indices_vec[1] > 0 ? j < end_indices_vec[1]
: j > end_indices_vec[1];
j += strides_indices_vec[1]) {
for (index_t k = begin_indices_vec[2];
strides_indices_vec[2] > 0 ? k < end_indices_vec[2]
: k > end_indices_vec[2];
k += strides_indices_vec[2]) {
for (index_t l = begin_indices_vec[3];
strides_indices_vec[3] > 0 ? l < end_indices_vec[3]
: l > end_indices_vec[3];
l += strides_indices_vec[3]) {
*output_data++ =
input_data[((i * input->dim(1) + j) * input->dim(2) + k)
* input->dim(3) + l];
}
}
}
}
} else {
MACE_NOT_IMPLEMENTED;
}
@@ -207,6 +342,8 @@ class StridedSliceOp : public Operation {
int new_axis_mask_;
int shrink_axis_mask_;
bool is_slice_;
int has_data_format_;
bool checked_;
Tensor tmp_strides_tensor_;
MACE_OP_INPUT_TAGS(INPUT, BEGIN, END, STRIDES);
......
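As a cross-check on the shape arithmetic in the mask-and-shrink loop above, one dimension of the output-shape computation can be sketched as follows (a standalone illustration under the same clamping semantics, with a hypothetical helper name; not part of the MACE sources):

#include <algorithm>
#include <cmath>
#include <cstdint>

// One dimension of strided_slice shape inference (illustration only):
// negative indices wrap once, indices clamp to the valid range for the
// stride sign, and the output length is ceil((end - begin) / stride),
// floored at zero.
int32_t OutDimLen(int32_t begin, int32_t end, int32_t stride,
                  int32_t dim_len) {
  const int32_t lo = stride > 0 ? 0 : -1;
  const int32_t hi = stride > 0 ? dim_len : dim_len - 1;
  const auto format = [&](int32_t i) {
    const int32_t forward = i < 0 ? i + dim_len : i;
    return std::min(std::max(forward, lo), hi);
  };
  begin = format(begin);
  end = format(end);
  return static_cast<int32_t>(std::max(
      0.f, std::ceil((end - begin) / static_cast<float>(stride))));
}

// OutDimLen(0, 3, 2, 3)   == 2  (elements 0 and 2)
// OutDimLen(-1, 0, -1, 3) == 2  (elements 2 and 1, walking backwards)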
@@ -64,6 +64,54 @@ void TestStridedSlice(const std::vector<index_t> &input_shape,
*net.GetOutput("Output"));
}
void TestStridedSliceWithDataFormat(const std::vector<index_t> &input_shape,
const std::vector<float> &input,
const std::vector<int32_t> &begin_indices,
const std::vector<int32_t> &end_indices,
const std::vector<int32_t> &strides,
const int begin_mask,
const int end_mask,
const int ellipsis_mask,
const int new_axis_mask,
const int shrink_axis_mask,
const std::vector<index_t> &output_shape,
const std::vector<float> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, float>("Input", input_shape, input);
net.AddInputFromArray<CPU, int32_t>(
"BeginIndices", {static_cast<int32_t>(begin_indices.size())},
begin_indices);
net.AddInputFromArray<CPU, int32_t>(
"EndIndices", {static_cast<int32_t>(end_indices.size())}, end_indices);
net.AddInputFromArray<CPU, int32_t>(
"Strides", {static_cast<int32_t>(strides.size())}, strides);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
NCHW);
OpDefBuilder("StridedSlice", "StridedSliceOpTest")
.Input("InputNCHW")
.Input("BeginIndices")
.Input("EndIndices")
.Input("Strides")
.Output("OutputNCHW")
.AddIntArg("begin_mask", begin_mask)
.AddIntArg("end_mask", end_mask)
.AddIntArg("ellipsis_mask", ellipsis_mask)
.AddIntArg("new_axis_mask", new_axis_mask)
.AddIntArg("shrink_axis_mask", shrink_axis_mask)
.AddIntArg("has_data_format", 1)
.Finalize(net.NewOperatorDef());
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output",
NHWC);
net.AddInputFromArray<CPU, float>("ExpectedOutput", output_shape, output);
ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
void TestSlice(const std::vector<index_t> &input_shape,
const std::vector<float> &input,
const std::vector<int32_t> &begin_indices,
@@ -92,6 +140,41 @@ void TestSlice(const std::vector<index_t> &input_shape,
*net.GetOutput("Output"));
}
void TestSliceWithDataFormat(const std::vector<index_t> &input_shape,
const std::vector<float> &input,
const std::vector<int32_t> &begin_indices,
const std::vector<int32_t> &indices_size,
const std::vector<index_t> &output_shape,
const std::vector<float> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, float>("Input", input_shape, input);
net.AddInputFromArray<CPU, int32_t>(
"BeginIndices", {static_cast<int32_t>(input_shape.size())},
begin_indices);
net.AddInputFromArray<CPU, int32_t>(
"IndicesSize", {static_cast<int32_t>(indices_size.size())}, indices_size);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
NCHW);
OpDefBuilder("StridedSlice", "StridedSliceOpTest")
.Input("InputNCHW")
.Input("BeginIndices")
.Input("IndicesSize")
.Output("OutputNCHW")
.AddIntArg("slice", 1)
.AddIntArg("has_data_format", 1)
.Finalize(net.NewOperatorDef());
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output",
NHWC);
net.AddInputFromArray<CPU, float>("ExpectedOutput", output_shape, output);
ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(StridedSliceOpTest, TestStridedSliceByFirstAxis) {
@@ -157,6 +240,66 @@ TEST_F(StridedSliceOpTest, TestStridedSliceRank3) {
1, 2}, {1, 1, 3, 3});
}
TEST_F(StridedSliceOpTest, TestStridedSliceRank4) {
TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0},
{2, 2, 2, 2}, {1, 1, 1, 1}, 0, 0, 0, 0, 0, {1, 2, 1, 2},
{15, 16, 21, 22});
TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0},
{2, 2, 2, 2}, {1, 1, 1, 1}, 3, 0, 0, 0, 0, {2, 2, 1, 2},
{3, 4, 9, 10, 15, 16, 21, 22});
TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0},
{2, 2, 2, 2}, {1, 1, 1, 1}, 0, 8, 0, 0, 0, {1, 2, 1, 3},
{15, 16, 17, 21, 22, 23});
TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0},
{2, 2, 2, 2}, {1, 1, 1, 1}, 0, 8, 0, 0, 8, {1, 2, 1},
{15, 21});
TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0},
{2, 2, 2, 2}, {1, 1, 1, 1}, 0, 8, 0, 0, 15, {}, {15});
TestStridedSlice({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {-1, 2, 1, 3},
{0, 0, 0, 0}, {-1, -1, -1, -1}, 0, 0, 0, 0, 0, {1, 1, 1, 2},
{23, 22});
}
TEST_F(StridedSliceOpTest, TestStridedSliceWithDataFormat) {
TestStridedSliceWithDataFormat(
{2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0},
{2, 2, 2, 2}, {1, 1, 1, 1}, 0, 0, 0, 0, 0, {1, 2, 1, 2},
{15, 16, 21, 22});
TestStridedSliceWithDataFormat(
{2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0},
{2, 2, 2, 2}, {1, 1, 1, 1}, 3, 0, 0, 0, 0, {2, 2, 1, 2},
{3, 4, 9, 10, 15, 16, 21, 22});
TestStridedSliceWithDataFormat(
{2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0, 1, 0},
{2, 2, 2, 2}, {1, 1, 1, 1}, 0, 8, 0, 0, 0, {1, 2, 1, 3},
{15, 16, 17, 21, 22, 23});
TestStridedSliceWithDataFormat(
{2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0},
{2, 1}, {1, 1}, 0, 8, 0, 0, 0, {1, 1, 2, 3},
{12, 13, 14, 15, 16, 17});
TestStridedSliceWithDataFormat(
{2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {1, 0},
{2, 1}, {1, 1}, 0, 2, 0, 0, 0, {1, 2, 2, 3},
{12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
TestStridedSliceWithDataFormat(
{2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, {-1, 2, 1, 3},
{0, 0, 0, 0}, {-1, -1, -1, -1}, 0, 0, 0, 0, 0, {1, 1, 1, 2},
{23, 22});
}
TEST_F(StridedSliceOpTest, TestSlice) {
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {2, 3},
{1, 2, 3, 4, 5, 6});
@@ -166,6 +309,17 @@ TEST_F(StridedSliceOpTest, TestSlice) {
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 1}, {2, -1}, {2, 2}, {2, 3, 5, 6});
}
TEST_F(StridedSliceOpTest, TestSliceWithDataFormat) {
TestSliceWithDataFormat({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
{1, 0, 1, 0}, {1, 2, 1, 2}, {1, 2, 1, 2},
{15, 16, 21, 22});
TestSliceWithDataFormat({2, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
{1, 0, 1, 0}, {-1, -1, -1, -1}, {1, 2, 1, 3},
{15, 16, 17, 21, 22, 23});
}
} // namespace test
} // namespace ops
} // namespace mace
@@ -159,6 +159,7 @@ MaceSupportedOps = [
'Transpose',
'WinogradInverseTransform',
'WinogradTransform',
'Cumsum',
]
MaceOp = Enum('MaceOp', [(op, op) for op in MaceSupportedOps], type=str)
@@ -237,6 +238,8 @@ class MaceKeyword(object):
mace_find_range_every_time = 'find_range_every_time'
mace_non_zero = 'non_zero'
mace_pad_type_str = 'pad_type'
mace_exclusive_str = 'exclusive'
mace_reverse_str = 'reverse'
class TransformerRule(Enum):
......
@@ -116,6 +116,7 @@ TFSupportedOps = [
'FloorDiv',
'Sqrt',
'MirrorPad',
'Cumsum',
]
TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str)
@@ -253,6 +254,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.FloorDiv.name: self.convert_elementwise,
TFOpType.Sqrt.name: self.convert_elementwise,
TFOpType.MirrorPad.name: self.convert_pad,
TFOpType.Cumsum.name: self.convert_cumsum,
}
self._option = option
self._mace_net_def = mace_pb2.NetDef()
@@ -1012,3 +1014,23 @@ class TensorflowConverter(base_converter.ConverterInterface):
self._skip_tensor.add(tf_op.inputs[1].name)
self._skip_tensor.add(tf_op.inputs[2].name)
def convert_cumsum(self, tf_op):
op = self.convert_general_op(tf_op)
op.type = MaceOp.Cumsum.name
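# The axis input is a constant tensor in the TF graph; evaluate it at
# conversion time and carry it as an op argument instead of an input.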
axis = tf_op.inputs[1].eval().astype(np.int32)
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = axis
del op.input[1]
exclusive = tf_op.get_attr('exclusive')
exclusive_arg = op.arg.add()
exclusive_arg.name = MaceKeyword.mace_exclusive_str
exclusive_arg.i = int(exclusive)
reverse = tf_op.get_attr('reverse')
reverse_arg = op.arg.add()
reverse_arg.name = MaceKeyword.mace_reverse_str
reverse_arg.i = int(reverse)