Commit a85f052d authored by yejianwu

Support tf.cumsum and 4-D strided_slice on CPU

Parent 04db6237
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <functional>
#include "mace/core/operator.h"
namespace mace {
namespace ops {
namespace {
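// Iteration helpers: cumsum walks each dimension either forward
// (PlusOne/LessThan) or backward (SubOne/NotLessThanZero), depending on
// the `reverse` attribute.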
void PlusOne(int* val) {
++(*val);
}
void SubOne(int* val) {
--(*val);
}
bool LessThan(const int& val, const int& boundary) {
return val < boundary;
}
bool NotLessThanZero(const int& val, const int& boundary) {
MACE_UNUSED(boundary);
return val >= 0;
}
} // namespace
template <DeviceType D, typename T>
class CumsumOp;
template <typename T>
class CumsumOp<DeviceType::CPU, T> : public Operation {
public:
  explicit CumsumOp(OpConstructContext *context)
      : Operation(context),
        axis_(Operation::GetOptionalArg<int>("axis", 3)),
        exclusive_(Operation::GetOptionalArg<bool>("exclusive", false)),
        reverse_(Operation::GetOptionalArg<bool>("reverse", false)),
        checked_(false) {}
void Validate() {
const int32_t input_dims = this->Input(0)->dim_size();
    axis_ = axis_ < 0 ? axis_ + input_dims : axis_;
    MACE_CHECK((0 <= axis_ && axis_ < input_dims),
               "Expected cumsum axis in the range [", -input_dims, ", ",
               input_dims, "), but got ", axis_);
}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
if (!checked_) {
Validate();
auto df = static_cast<DataFormat>(Operation::GetOptionalArg<int>(
"data_format", DataFormat::DF_NONE));
if (df == DataFormat::NHWC && this->Input(0)->dim_size() == 4) {
if (axis_ == 3) axis_ = 1;
else if (axis_ == 2) axis_ = 3;
else if (axis_ == 1) axis_ = 2;
}
checked_ = true;
}
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
Tensor::MappingGuard input_mapper(input);
Tensor::MappingGuard output_mapper(output);
const float *input_ptr = input->data<float>();
float *output_ptr = output->mutable_data<float>();
std::function<void(int*)> next = reverse_ ? SubOne : PlusOne;
std::function<void(int*)> previous = reverse_ ? PlusOne : SubOne;
std::function<bool(const int&, const int&)> boundary =
reverse_ ? NotLessThanZero : LessThan;
if (input->dim_size() == 4) {
const int batch = input->dim(0);
const int channel = input->dim(1);
const int height = input->dim(2);
const int width = input->dim(3);
const int axis_dim_size = input->dim(axis_);
for (int n = reverse_ ? batch - 1 : 0; boundary(n, batch); next(&n)) {
for (int c = reverse_ ? channel - 1 : 0; boundary(c, channel);
next(&c)) {
for (int h = reverse_ ? height - 1 : 0; boundary(h, height);
next(&h)) {
for (int w = reverse_ ? width - 1 : 0; boundary(w, width);
next(&w)) {
int dims[4] = {n, c, h, w};
            if (!reverse_ && dims[axis_] == 0) {
              // First element along the axis: an exclusive sum starts at
              // zero; an inclusive sum starts at the input value itself.
              output_ptr[((n * channel + c) * height + h) * width + w] =
                  exclusive_
                      ? 0.f
                      : input_ptr[((n * channel + c) * height + h) * width
                                  + w];
            } else if (reverse_ && dims[axis_] == axis_dim_size - 1) {
              output_ptr[((n * channel + c) * height + h) * width + w] =
                  exclusive_
                      ? 0.f
                      : input_ptr[((n * channel + c) * height + h) * width
                                  + w];
} else {
previous(&dims[axis_]);
if (exclusive_) {
output_ptr[((n * channel + c) * height + h) * width + w] =
input_ptr[((dims[0] * channel + dims[1]) * height +
dims[2]) *
width +
dims[3]] +
output_ptr[((dims[0] * channel + dims[1]) * height +
dims[2]) *
width +
dims[3]];
} else {
output_ptr[((n * channel + c) * height + h) * width + w] =
input_ptr[((n * channel + c) * height + h) * width + w] +
output_ptr[((dims[0] * channel + dims[1]) * height +
dims[2]) *
width +
dims[3]];
}
}
}
}
}
}
} else {
MACE_NOT_IMPLEMENTED;
}
return MaceStatus::MACE_SUCCESS;
}
private:
int32_t axis_;
bool exclusive_;
bool reverse_;
bool checked_;
};
void RegisterCumsum(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Cumsum", CumsumOp,
DeviceType::CPU, float);
}
} // namespace ops
} // namespace mace
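For reference, the exclusive/reverse semantics implemented above reduce to
this minimal 1-D sketch (a standalone illustration under an assumed name,
Cumsum1D; it is not part of the patch):

// Hypothetical 1-D reference for tf.cumsum semantics.
#include <cstddef>
#include <vector>

std::vector<float> Cumsum1D(const std::vector<float> &in,
                            bool exclusive, bool reverse) {
  const std::size_t n = in.size();
  std::vector<float> out(n, 0.f);
  float acc = 0.f;
  for (std::size_t step = 0; step < n; ++step) {
    const std::size_t i = reverse ? n - 1 - step : step;
    if (exclusive) {
      out[i] = acc;   // sum of strictly preceding elements
      acc += in[i];
    } else {
      acc += in[i];
      out[i] = acc;   // sum including the current element
    }
  }
  return out;
}

// E.g. Cumsum1D({0, 4}, /*exclusive=*/true, /*reverse=*/true) yields {4, 0},
// matching the per-axis behaviour exercised by the unit test below.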
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <functional>
#include <vector>
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class CumsumOpTest : public OpsTestBase {};
namespace {
void SimpleTest() {
// Construct graph
OpsTestNet net;
net.AddInputFromArray<DeviceType::CPU, float>("Input", {2, 2, 2, 2},
{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.});
OpDefBuilder("Cumsum", "CumsumTest")
.Input("Input")
.Output("Output")
.AddIntArg("axis", 1)
.AddIntArg("exclusive", 1)
.AddIntArg("reverse", 1)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<float>::value))
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(DeviceType::CPU);
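  // With axis=1, exclusive=1 and reverse=1, each output slice along axis 1
  // holds the sum of the slices after it, excluding itself:
  // output[:, 0] = input[:, 1] and output[:, 1] = 0.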
auto expected = net.CreateTensor<float>({2, 2, 2, 2},
{4., 5., 6., 7., 0., 0., 0., 0., 12., 13., 14., 15., 0., 0., 0., 0.});
ExpectTensorNear<float, float>(*expected, *net.GetOutput("Output"), 1e-5);
}
} // namespace
TEST_F(CumsumOpTest, CPU) {
SimpleTest();
}
} // namespace test
} // namespace ops
} // namespace mace
@@ -29,6 +29,7 @@ extern void RegisterChannelShuffle(OpRegistryBase *op_registry);
extern void RegisterConcat(OpRegistryBase *op_registry);
extern void RegisterConv2D(OpRegistryBase *op_registry);
extern void RegisterCrop(OpRegistryBase *op_registry);
extern void RegisterCumsum(OpRegistryBase *op_registry);
extern void RegisterDeconv2D(OpRegistryBase *op_registry);
extern void RegisterDepthToSpace(OpRegistryBase *op_registry);
extern void RegisterDepthwiseConv2d(OpRegistryBase *op_registry);
@@ -95,6 +96,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
ops::RegisterConcat(this);
ops::RegisterConv2D(this);
ops::RegisterCrop(this);
ops::RegisterCumsum(this);
ops::RegisterDeconv2D(this);
ops::RegisterDepthToSpace(this);
ops::RegisterDepthwiseConv2d(this);
@@ -32,13 +32,52 @@ class StridedSliceOp : public Operation {
new_axis_mask_(Operation::GetOptionalArg<int>("new_axis_mask", 0)),
shrink_axis_mask_(
Operation::GetOptionalArg<int>("shrink_axis_mask", 0)),
is_slice_(Operation::GetOptionalArg<bool>("slice", false)) {
is_slice_(Operation::GetOptionalArg<bool>("slice", false)),
checked_(false) {
MACE_CHECK(ellipsis_mask_ == 0 && new_axis_mask_ == 0,
"ellipsis_mask and new_axis_mask are not supported yet.");
}
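  // Remaps a per-dimension bit mask (begin_mask_, end_mask_, etc.) from
  // NHWC bit order to NCHW: bit i refers to dimension i, so e.g. an NHWC
  // mask 0b0010 (H) becomes the NCHW mask 0b0100.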
void TransposeMaskValueFromNHWCToNCHW(int* mask_value) {
size_t dims[4];
int count;
for (count = 0; count < 4; ++count) {
dims[count] = *mask_value & 1;
*mask_value >>= 1;
}
size_t new_dims[4] = {dims[0], dims[3], dims[1], dims[2]};
for (count = 3; count >= 0; --count) {
*mask_value <<= 1;
*mask_value += new_dims[count];
}
}
void TransposeDimsFromNHWCToNCHW(std::vector<int32_t>* dims) {
int32_t h = (*dims)[1];
int32_t w = (*dims)[2];
int32_t c = (*dims)[3];
(*dims)[1] = c;
(*dims)[2] = h;
(*dims)[3] = w;
}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
auto df = static_cast<DataFormat>(Operation::GetOptionalArg<int>(
"data_format", DataFormat::DF_NONE));
if (!checked_) {
if (df == DataFormat::NHWC && this->Input(0)->dim_size() == 4) {
TransposeMaskValueFromNHWCToNCHW(&begin_mask_);
TransposeMaskValueFromNHWCToNCHW(&end_mask_);
TransposeMaskValueFromNHWCToNCHW(&ellipsis_mask_);
TransposeMaskValueFromNHWCToNCHW(&new_axis_mask_);
TransposeMaskValueFromNHWCToNCHW(&shrink_axis_mask_);
}
checked_ = true;
}
const Tensor *input = this->Input(INPUT);
const Tensor *begin_indices = this->Input(BEGIN);
const Tensor *end_indices = this->Input(END);
@@ -76,11 +115,30 @@ class StridedSliceOp : public Operation {
for (index_t i = begin_indices->size(); i < input->dim_size(); ++i) {
pad_end_indices[i] = input->dim(i);
}
begin_indices_data = pad_begin_indices.data();
end_indices_data = pad_end_indices.data();
strides_data = pad_strides_indices.data();
}
std::vector<int32_t> transpose_begin_indices(input->dim_size(), 0);
std::vector<int32_t> transpose_end_indices(input->dim_size(), 0);
std::vector<int32_t> transpose_strides_indices(input->dim_size(), 1);
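    // On CPU the 4-D input has already been transposed to NCHW, so the
    // begin/end/stride vectors (given in NHWC order) must be permuted to
    // match before slicing.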
if (df == DataFormat::NHWC && this->Input(0)->dim_size() == 4) {
      for (index_t i = 0; i < input->dim_size(); ++i) {
transpose_begin_indices[i] = begin_indices_data[i];
transpose_end_indices[i] = end_indices_data[i];
transpose_strides_indices[i] = strides_data[i];
}
TransposeDimsFromNHWCToNCHW(&transpose_begin_indices);
TransposeDimsFromNHWCToNCHW(&transpose_end_indices);
TransposeDimsFromNHWCToNCHW(&transpose_strides_indices);
begin_indices_data = transpose_begin_indices.data();
end_indices_data = transpose_end_indices.data();
strides_data = transpose_strides_indices.data();
}
std::vector<int32_t> slice_end_data;
if (is_slice_) {
      // If this op is Slice, end_indices_data actually holds sizes,
      // not end indices.
@@ -193,6 +251,30 @@ class StridedSliceOp : public Operation {
}
}
}
} else if (input->dim_size() == 4) {
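      // 4-D case: walk every dimension from begin to end with its stride
      // (which may be negative) and copy elements in output order.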
for (index_t i = real_begin_indices[0];
strides_data[0] > 0 ? i < real_end_indices[0]
: i > real_end_indices[0];
i += strides_data[0]) {
for (index_t j = real_begin_indices[1];
strides_data[1] > 0 ? j < real_end_indices[1]
: j > real_end_indices[1];
j += strides_data[1]) {
for (index_t k = real_begin_indices[2];
strides_data[2] > 0 ? k < real_end_indices[2]
: k > real_end_indices[2];
k += strides_data[2]) {
for (index_t l = real_begin_indices[3];
strides_data[3] > 0 ? l < real_end_indices[3]
: l > real_end_indices[3];
l += strides_data[3]) {
*output_data++ =
input_data[((i * input->dim(1) + j) * input->dim(2) + k)
* input->dim(3) + l];
}
}
}
}
} else {
MACE_NOT_IMPLEMENTED;
}
@@ -207,6 +289,7 @@ class StridedSliceOp : public Operation {
int new_axis_mask_;
int shrink_axis_mask_;
bool is_slice_;
bool checked_;
Tensor tmp_strides_tensor_;
MACE_OP_INPUT_TAGS(INPUT, BEGIN, END, STRIDES);
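As a sanity check of the mask remapping above, this minimal standalone
sketch (hypothetical name, not part of the patch) reproduces the same bit
permutation and can be compiled on its own:

#include <cstdio>

// Bit i of a strided-slice mask refers to dimension i, so converting a
// mask from NHWC to NCHW permutes bits (N, H, W, C) -> (N, C, H, W).
int TransposeMaskNHWCToNCHW(int mask) {
  int bits[4];
  for (int i = 0; i < 4; ++i) {
    bits[i] = (mask >> i) & 1;
  }
  return bits[0] | (bits[3] << 1) | (bits[1] << 2) | (bits[2] << 3);
}

int main() {
  // An NHWC mask 0b1010 (H and C set) becomes 0b0110 in NCHW
  // (C is now bit 1, H is now bit 2).
  std::printf("0x%x -> 0x%x\n", 0xA, TransposeMaskNHWCToNCHW(0xA));
  return 0;
}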
@@ -161,6 +161,7 @@ MaceSupportedOps = [
'Transpose',
'WinogradInverseTransform',
'WinogradTransform',
'Cumsum',
]
MaceOp = Enum('MaceOp', [(op, op) for op in MaceSupportedOps], type=str)
@@ -239,6 +240,8 @@ class MaceKeyword(object):
mace_find_range_every_time = 'find_range_every_time'
mace_non_zero = 'non_zero'
mace_pad_type_str = 'pad_type'
mace_exclusive_str = 'exclusive'
mace_reverse_str = 'reverse'
class TransformerRule(Enum):
@@ -117,6 +117,7 @@ TFSupportedOps = [
'FloorDiv',
'Sqrt',
'MirrorPad',
'Cumsum',
]
TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str)
@@ -254,6 +255,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.FloorDiv.name: self.convert_elementwise,
TFOpType.Sqrt.name: self.convert_elementwise,
TFOpType.MirrorPad.name: self.convert_pad,
TFOpType.Cumsum.name: self.convert_cumsum,
}
self._option = option
self._mace_net_def = mace_pb2.NetDef()
@@ -1007,3 +1009,23 @@ class TensorflowConverter(base_converter.ConverterInterface):
self._skip_tensor.add(tf_op.inputs[1].name)
self._skip_tensor.add(tf_op.inputs[2].name)
    def convert_cumsum(self, tf_op):
        op = self.convert_general_op(tf_op)
        op.type = MaceOp.Cumsum.name

        axis = tf_op.inputs[1].eval().astype(np.int32)
        axis_arg = op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str
        axis_arg.i = axis
        del op.input[1]

        exclusive = tf_op.get_attr('exclusive')
        exclusive_arg = op.arg.add()
        exclusive_arg.name = MaceKeyword.mace_exclusive_str
        exclusive_arg.i = int(exclusive)

        reverse = tf_op.get_attr('reverse')
        reverse_arg = op.arg.add()
        reverse_arg.name = MaceKeyword.mace_reverse_str
        reverse_arg.i = int(reverse)