提交 255296b8 编写于 作者: L liutuo

deconv support caffe and tf.nn.conv2d_transpose and crop op for caffe

上级 fc1c855e
......@@ -6,47 +6,48 @@ Operator lists
:header: "Operator","Supported","Remark"
"AVERAGE_POOL_2D","Y",""
"ARGMAX","Y","Only CPU and tensorflow is supported"
"BATCH_NORM","Y","Fusion with activation is supported"
"ARGMAX","Y","Only CPU and TensorFlow is supported."
"BATCH_NORM","Y","Fusion with activation is supported."
"BATCH_TO_SPACE_ND","Y",""
"BIAS_ADD","Y",""
"CAST","Y","Only CPU and tensorflow model is supported"
"CAST","Y","Only CPU and TensorFlow model is supported."
"CHANNEL_SHUFFLE","Y",""
"CONCATENATION","Y","Only support channel axis concatenation"
"CONV_2D","Y","Fusion with BN and activation layer is supported"
"DECONV_2D","Y","Only tensorflow model is supported"
"DEPTHWISE_CONV_2D","Y","Only multiplier = 1 is supported; Fusion is supported"
"CONCATENATION","Y","Only support channel axis concatenation."
"CONV_2D","Y","Fusion with BN and activation layer is supported."
"CROP","Y","Only Caffe's crop layer is supported (in GPU, offset on channel-dim should be dividable by 4)."
"DECONV_2D","Y","Supports Caffe's Deconvolution and TensorFlow's tf.layers.conv2d_transpose."
"DEPTHWISE_CONV_2D","Y","Only multiplier = 1 is supported; Fusion is supported."
"DEPTH_TO_SPACE","Y",""
"DEQUANTIZE","Y","Model quantization will be supported later"
"DEQUANTIZE","Y","Model quantization will be supported later."
"ELEMENT_WISE","Y","ADD/MUL/DIV/MIN/MAX/NEG/ABS/SQR_DIFF/POW/RSQRT/EQUAL"
"EMBEDDING_LOOKUP","Y","Only support channel axis concatenation"
"EMBEDDING_LOOKUP","Y","Only support channel axis concatenation."
"FULLY_CONNECTED","Y",""
"GROUP_CONV_2D","","Caffe model with group count = channel count is supported"
"IDENTITY","Y","Only tensorflow model is supported"
"GROUP_CONV_2D","","Caffe model with group count = channel count is supported."
"IDENTITY","Y","Only TensorFlow model is supported."
"LOCAL_RESPONSE_NORMALIZATION","Y",""
"LOGISTIC","Y",""
"LSTM","",""
"MATMUL","Y","Only CPU is supported"
"MATMUL","Y","Only CPU is supported."
"MAX_POOL_2D","Y",""
"PAD","Y",""
"PSROI_ALIGN","Y",""
"PRELU","Y","Only caffe model is supported"
"REDUCE_MEAN","Y","Only tensorflow model is supported. For GPU only H + W axis reduce is supported"
"PRELU","Y","Only Caffe model is supported"
"REDUCE_MEAN","Y","Only TensorFlow model is supported. For GPU only H + W axis reduce is supported."
"RELU","Y",""
"RELU1","Y",""
"RELU6","Y",""
"RELUX","Y",""
"RESHAPE","Y","Limited support: GPU is full supported, for CPU only supports softmax-like usage"
"RESHAPE","Y","Limited support: GPU is full supported, for CPU only supports softmax-like usage."
"RESIZE_BILINEAR","Y",""
"RNN","",""
"RPN_PROPOSAL_LAYER","Y",""
"SHAPE","Y","Only CPU and tensorflow is supported"
"STACK","Y","Only CPU and tensorflow is supported"
"STRIDEDSLICE","Y","Only CPU and tensorflow is supported"
"SLICE","Y","In tensorflow, this op is equivalent to SPLIT; Only support channel axis slice"
"SHAPE","Y","Only CPU and TensorFlow is supported."
"STACK","Y","Only CPU and TensorFlow is supported."
"STRIDEDSLICE","Y","Only CPU and TensorFlow is supported."
"SLICE","Y","In TensorFlow, this op is equivalent to SPLIT; Only support channel axis slice."
"SOFTMAX","Y",""
"SPACE_TO_BATCH_ND", "Y",""
"SPACE_TO_DEPTH","Y",""
"SQEEZE","Y","Only CPU and tensorflow is supported"
"SQEEZE","Y","Only CPU and TensorFlow is supported."
"TANH","Y",""
"TRANSPOSE","Y","Only CPU and tensorflow is supported"
"TRANSPOSE","Y","Only CPU and TensorFlow is supported."
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_CROP_H_
#define MACE_KERNELS_CROP_H_
#include <cstring>
#include <memory>
#include <vector>
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/core/types.h"
#include "mace/public/mace.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/cl2_header.h"
#endif // MACE_ENABLE_OPENCL
namespace mace {
namespace kernels {
// Holds the crop configuration shared by the device-specific CropFunctor
// implementations.
struct CropFunctorBase {
  // axis:   first dimension that is cropped.
  // offset: crop offsets; the vector is copied into the functor.
  CropFunctorBase(const int axis,
                  const std::vector<int> &offset)
      : axis_(axis),
        offset_(offset) {}
  // First dimension to crop; dimensions before it are left untouched.
  const int axis_;
  // Crop offsets as supplied by the operator definition.
  std::vector<int> offset_;
};
// Generic (host memory) crop functor.
//
// Crops the first input to the size of the second input, starting at
// dimension `axis_`. Dimensions before `axis_` keep the size of the first
// input; dimensions at or after `axis_` take the size of the second input and
// are read starting at the configured offset. Only 4-D tensors are supported.
template <DeviceType D, typename T>
struct CropFunctor : CropFunctorBase {
  CropFunctor(const int axis, const std::vector<int> &offset)
      : CropFunctorBase(axis, offset) {}

  // Copies the cropped window from `input_data` to `output_data`.
  //
  // Both buffers are treated as contiguous 4-D arrays whose last dimension is
  // the fastest varying one. `offsets` holds one start offset per dimension.
  // Each innermost row of the output is contiguous in the input as well, so it
  // is copied with a single memcpy.
  void crop_copy(const T* input_data, T* output_data,
                 const std::vector<index_t> &input_shape,
                 const std::vector<index_t> &output_shape,
                 const int32_t* offsets) {
    const index_t out_img_size =
        output_shape[1] * output_shape[2] * output_shape[3];
    const index_t out_hw = output_shape[2] * output_shape[3];
    const index_t in_img_size =
        input_shape[1] * input_shape[2] * input_shape[3];
    const index_t in_hw = input_shape[2] * input_shape[3];
#pragma omp parallel for collapse(3)
    for (int b = 0; b < output_shape[0]; ++b) {
      for (int c = 0; c < output_shape[1]; ++c) {
        for (int h = 0; h < output_shape[2]; ++h) {
          T* out_ptr =
              output_data + b * out_img_size + c * out_hw + h * output_shape[3];
          const T* in_ptr_bch =
              input_data + (b + offsets[0]) * in_img_size +
              (c + offsets[1]) * in_hw +
              (h + offsets[2]) * input_shape[3] + offsets[3];
          memcpy(out_ptr, in_ptr_bch,
                 output_shape[3] * sizeof(T));
        }
      }
    }
  }

  // Runs the crop.
  //
  // input_list: exactly two 4-D tensors; [0] is the data to crop, [1] only
  //             supplies the target size of the cropped dimensions.
  // output:     resized to the cropped shape and filled with the result.
  // future:     unused by this implementation.
  // Returns MACE_SUCCESS, or the error reported by Tensor::Resize.
  MaceStatus operator()(const std::vector<const Tensor *> &input_list,
                        Tensor *output,
                        StatsFuture *future) {
    MACE_UNUSED(future);
    MACE_CHECK(input_list.size() == 2, "Crop op needs two inputs.");
    const Tensor *input0 = input_list[0];
    const Tensor *input1 = input_list[1];
    const uint32_t in0_dims = static_cast<uint32_t>(input0->dim_size());
    // Validate the rank of the reference tensor itself (previously this read
    // input0 twice, so input1 was never checked).
    const uint32_t in1_dims = static_cast<uint32_t>(input1->dim_size());
    MACE_CHECK(in0_dims == 4 && in1_dims == 4,
               "crop op only supports 4-dims inputs now.");

    std::vector<int32_t> offsets(in0_dims, 0);
    std::vector<index_t> output_shape(input0->shape());
    for (index_t i = 0; i < in0_dims; ++i) {
      int32_t crop_offset = 0;
      index_t new_size = input0->dim(i);
      if (i >= axis_) {
        new_size = input1->dim(i);
        if (offset_.size() == 1) {
          // A single offset applies to every cropped dimension.
          crop_offset = offset_[0];
        } else if (offset_.size() > 1) {
          // One offset per cropped dimension, counted from axis_.
          const index_t offset_idx = i - axis_;
          MACE_CHECK(offset_idx >= 0 &&
                     offset_idx < static_cast<index_t>(offset_.size()))
              << "no crop offset provided for dimension " << i;
          crop_offset = offset_[offset_idx];
        }
        // A negative offset would read before the start of the input buffer.
        MACE_CHECK(crop_offset >= 0 &&
                   input0->dim(i) - crop_offset >= input1->dim(i))
            << "the crop for dimension " << i << " is out of bound with size "
            << input1->dim(i) << " and offset " << crop_offset;
      }
      output_shape[i] = new_size;
      offsets[i] = crop_offset;
    }

    MACE_RETURN_IF_ERROR(output->Resize(output_shape));
    T *output_data = output->mutable_data<T>();
    const T *input_data = input0->data<T>();
    crop_copy(input_data, output_data, input0->shape(),
              output_shape, offsets.data());
    return MACE_SUCCESS;
  }
};
#ifdef MACE_ENABLE_OPENCL
// GPU specialization of the crop functor. Only declares the call operator;
// the definition is provided out of line.
template <typename T>
struct CropFunctor<DeviceType::GPU, T> : CropFunctorBase {
  CropFunctor(const int axis, const std::vector<int> &offset)
      : CropFunctorBase(axis, offset) {}
  // Crops input_list[0] to the size given by input_list[1] and writes the
  // result to `output`.
  MaceStatus operator()(const std::vector<const Tensor *> &input_list,
                        Tensor *output,
                        StatsFuture *future);
  // OpenCL kernel object used by the call operator.
  cl::Kernel kernel_;
  // Work-group size associated with kernel_.
  uint32_t kwg_size_;
  // Buffer used to report kernel-side errors.
  std::unique_ptr<BufferBase> kernel_error_;
  // Input shape remembered between calls.
  std::vector<index_t> input_shape_;
};
#endif // MACE_ENABLE_OPENCL
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_CROP_H_
......@@ -95,13 +95,15 @@ struct Deconv2dFunctorBase {
const std::vector<int> &paddings,
const std::vector<index_t> &output_shape,
const ActivationType activation,
const float relux_max_limit)
const float relux_max_limit,
const bool from_caffe)
: strides_(strides),
padding_type_(padding_type),
paddings_(paddings),
output_shape_(output_shape),
activation_(activation),
relux_max_limit_(relux_max_limit) {}
relux_max_limit_(relux_max_limit),
from_caffe_(from_caffe) {}
static void CalcDeconvOutputSize(
const index_t *input_shape, // NHWC
......@@ -121,16 +123,13 @@ struct Deconv2dFunctorBase {
const index_t in_height = isNCHW ? input_shape[2] : input_shape[1];
const index_t in_width = isNCHW ? input_shape[3] : input_shape[2];
const index_t extended_input_height =
(in_height - 1) * strides[0] + 1 + padding_size[0];
const index_t extended_input_width =
(in_width - 1) * strides[1] + 1 + padding_size[1];
const index_t filter_h = filter_shape[2];
const index_t filter_w = filter_shape[3];
index_t out_height = extended_input_height - filter_h + 1;
index_t out_width = extended_input_width - filter_w + 1;
index_t out_height =
(in_height - 1) * strides[0] + filter_h -padding_size[0];
index_t out_width =
(in_width - 1) * strides[1] + filter_w -padding_size[1];
output_shape[0] = input_shape[0];
if (isNCHW) {
......@@ -209,6 +208,7 @@ struct Deconv2dFunctorBase {
std::vector<index_t> output_shape_;
const ActivationType activation_;
const float relux_max_limit_;
const bool from_caffe_;
};
template <DeviceType D, typename T>
......@@ -218,13 +218,15 @@ struct Deconv2dFunctor : Deconv2dFunctorBase {
const std::vector<int> &paddings,
const std::vector<index_t> &output_shape,
const ActivationType activation,
const float relux_max_limit)
const float relux_max_limit,
const bool from_caffe)
: Deconv2dFunctorBase(strides,
padding_type,
paddings,
output_shape,
activation,
relux_max_limit) {}
relux_max_limit,
from_caffe) {}
MaceStatus operator()(const Tensor *input, // NCHW
const Tensor *filter, // OIHW
......@@ -236,8 +238,8 @@ struct Deconv2dFunctor : Deconv2dFunctorBase {
MACE_CHECK_NOTNULL(filter);
MACE_CHECK_NOTNULL(output);
std::vector<index_t> output_shape(4);
if (output_shape_.size() == 4) {
if (!from_caffe_) { // tensorflow
std::vector<index_t> output_shape(4);
output_shape[0] = output_shape_[0];
output_shape[1] = output_shape_[3];
output_shape[2] = output_shape_[1];
......@@ -251,7 +253,7 @@ struct Deconv2dFunctor : Deconv2dFunctorBase {
output_shape.data(),
paddings_.data(), true);
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
} else {
} else { // caffe
output_shape_.clear();
output_shape_ = std::vector<index_t>(4, 0);
CalcDeconvOutputSize(input->shape().data(),
......@@ -268,7 +270,7 @@ struct Deconv2dFunctor : Deconv2dFunctorBase {
const index_t kernel_hw[2] = {kernel_h, kernel_w};
MACE_CHECK(filter->dim(0) == out_shape[1], filter->dim(0), " != ",
output_shape[1]);
out_shape[1]);
MACE_CHECK(filter->dim(1) == in_shape[1], filter->dim(1), " != ",
in_shape[1]);
MACE_CHECK(in_shape[0] == out_shape[0], "Input/Output batch size mismatch");
......@@ -311,13 +313,15 @@ struct Deconv2dFunctor<DeviceType::GPU, T> : Deconv2dFunctorBase {
const std::vector<int> &paddings,
const std::vector<index_t> &output_shape,
const ActivationType activation,
const float relux_max_limit)
const float relux_max_limit,
const bool from_caffe)
: Deconv2dFunctorBase(strides,
padding_type,
paddings,
output_shape,
activation,
relux_max_limit) {}
relux_max_limit,
from_caffe) {}
MaceStatus operator()(const Tensor *input,
const Tensor *filter,
......
#include <common.h>
// Copies a cropped window of `input` into `output`.
//
// Each work-item handles one output texel, addressed as
//   x = channel_block * width + w,  y = batch * height + h.
// The source texel is found by shifting batch, height, width and channel
// block by the corresponding offsets. `out_width` is accepted but not read.
__kernel void crop(KERNEL_ERROR_PARAMS
                   GLOBAL_WORK_GROUP_SIZE_DIM3
                   __read_only image2d_t input,
                   __private const int offset_b,
                   __private const int offset_h,
                   __private const int offset_w,
                   __private const int offset_chan_blk,
                   __private const int in_height,
                   __private const int in_width,
                   __private const int out_height,
                   __private const int out_width,
                   __write_only image2d_t output) {
  const int dst_chan_blk = get_global_id(0);
  const int dst_w = get_global_id(1);
  const int dst_hb = get_global_id(2);

#ifndef NON_UNIFORM_WORK_GROUP
  // The launch grid may be padded; drop work-items outside the real range.
  if (dst_chan_blk >= global_size_dim0 || dst_w >= global_size_dim1
      || dst_hb >= global_size_dim2) {
    return;
  }
  const int width = global_size_dim1;
#else
  const int width = get_global_size(1);
#endif

  // Split the fused (batch, height) index, then shift every coordinate by
  // its crop offset to locate the source texel.
  const int src_b = dst_hb / out_height + offset_b;
  const int src_h = dst_hb % out_height + offset_h;
  const int src_y = mad24(src_b, in_height, src_h);
  const int src_x =
      mad24(dst_chan_blk + offset_chan_blk, in_width, dst_w + offset_w);

  DATA_TYPE4 texel = READ_IMAGET(input, SAMPLER, (int2)(src_x, src_y));

  const int dst_x = mad24(dst_chan_blk, width, dst_w);
  WRITE_IMAGET(output, (int2)(dst_x, dst_hb), texel);
}
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/crop.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
namespace {
// Chooses the local work size for a 3-D launch.
//
// gws:      global work sizes; only gws[1] is read.
// kwg_size: maximum work-group size of the kernel.
// Returns a 4-element vector {lws0, lws1, lws2, 0}.
std::vector<uint32_t> LocalWS(const uint32_t *gws, const uint32_t kwg_size) {
  const uint64_t mem_cache_bytes =
      OpenCLRuntime::Global()->device_global_mem_cache_size();
  // Scale factor derived from the device cache size, never below 1.
  const uint32_t cache_factor =
      std::max<uint32_t>(mem_cache_bytes / kBaseGPUMemCacheSize, 1);

  // Dimension 1 gets as much of the work-group as it can use.
  const uint32_t local_dim1 = std::min<uint32_t>(gws[1], kwg_size);
  const uint32_t local_dim0 =
      std::min<uint32_t>(cache_factor, kwg_size / local_dim1);
  // Whatever is left of the work-group goes to dimension 2 (at least 1).
  const uint32_t leftover = kwg_size / (local_dim0 * local_dim1);
  const uint32_t local_dim2 =
      std::max<uint32_t>(std::min<uint32_t>(cache_factor, leftover), 1);

  return {local_dim0, local_dim1, local_dim2, 0u};
}
} // namespace
// Crops input_list[0] on the GPU to the size given by input_list[1].
//
// `axis_` and `offset_` are remapped below into a 4-element `offsets` vector
// indexed like the tensor dimensions, where index 3 is the channel dimension
// (its offset must be a multiple of 4 because the kernel works on blocks of
// four channels).
//
// input_list: at least two 4-D image tensors; [1] only supplies sizes.
// output:     resized to the cropped shape and written by the kernel.
// future:     forwarded to the kernel launch.
// Returns MACE_SUCCESS, or the first error from resizing/allocation.
template <typename T>
MaceStatus CropFunctor<DeviceType::GPU, T>::operator()(
    const std::vector<const Tensor *> &input_list,
    Tensor *output,
    StatsFuture *future) {
  const int32_t inputs_count = static_cast<int32_t>(input_list.size());
  MACE_CHECK(inputs_count >= 2)
      << "Crop opencl kernel only support 2 elements input";
  const Tensor *input0 = input_list[0];
  const Tensor *input1 = input_list[1];
  const uint32_t in0_dims = static_cast<uint32_t>(input0->dim_size());
  // Validate the rank of the reference tensor itself (previously this read
  // input0 twice, so input1 was never checked).
  const uint32_t in1_dims = static_cast<uint32_t>(input1->dim_size());
  MACE_CHECK(in0_dims == 4 && in1_dims == 4,
             "Crop op only supports 4-dims inputs now.");

  std::vector<int32_t> offsets(4, 0);
  std::vector<index_t> output_shape(input0->shape());
  // Translate (axis_, offset_) into per-dimension offsets and pick up the
  // cropped sizes from input1. A single offset is broadcast to all cropped
  // dimensions; otherwise one offset per cropped dimension is expected.
  switch (axis_) {
    case 0:
      if (offset_.size() == 1) {
        offsets[0] = offset_[0];
        offsets[1] = offset_[0];
        offsets[2] = offset_[0];
        offsets[3] = offset_[0];
      } else if (offset_.size() == 4) {
        offsets[0] = offset_[0];
        offsets[1] = offset_[2];
        offsets[2] = offset_[3];
        offsets[3] = offset_[1];
      }
      for (int i = 0; i < 4; ++i) {
        output_shape[i] = input1->dim(i);
      }
      break;
    case 1:
      if (offset_.size() == 1) {
        offsets[1] = offset_[0];
        offsets[2] = offset_[0];
        offsets[3] = offset_[0];
      } else if (offset_.size() == 3) {
        offsets[1] = offset_[1];
        offsets[2] = offset_[2];
        offsets[3] = offset_[0];
      }
      for (int i = 1; i < 4; ++i) {
        output_shape[i] = input1->dim(i);
      }
      break;
    case 2:
      if (offset_.size() == 1) {
        offsets[1] = offset_[0];
        offsets[2] = offset_[0];
      } else if (offset_.size() == 2) {
        offsets[1] = offset_[0];
        offsets[2] = offset_[1];
      }
      output_shape[1] = input1->dim(1);
      output_shape[2] = input1->dim(2);
      break;
    case 3:
      if (offset_.size() == 1) {
        offsets[2] = offset_[0];
      }
      output_shape[2] = input1->dim(2);
      break;
    default:
      MACE_CHECK(axis_ >= 0 && axis_ < 4, "axis is out of boundary.");
      break;
  }

  MACE_CHECK(offsets[3] % 4 == 0,
             "MACE opencl only supports cropping channel offset divisible by 4.");
  // Bound-check against the actual output shape so that dimensions which are
  // not cropped are not compared with input1, and reject negative offsets,
  // which would address texels outside the input image.
  for (index_t i = 0; i < 4; ++i) {
    MACE_CHECK(offsets[i] >= 0 &&
               input0->dim(i) - offsets[i] >= output_shape[i])
        << "the crop for dimension " << i << " is out of bound with size "
        << output_shape[i] << " and offset " << offsets[i];
  }

  std::vector<size_t> image_shape;
  CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape);
  MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape, image_shape));

  const index_t offset_chan_blk = RoundUpDiv4(offsets[3]);
  const index_t channel_blk = RoundUpDiv4(output->dim(3));
  const uint32_t gws[3] = {
      static_cast<uint32_t>(channel_blk), static_cast<uint32_t>(output->dim(2)),
      static_cast<uint32_t>(output->dim(0) * output->dim(1))
  };

  auto runtime = OpenCLRuntime::Global();

  // Build the kernel lazily on first use.
  if (kernel_.get() == nullptr) {
    std::set<std::string> built_options;
    std::string kernel_name = MACE_OBFUSCATE_SYMBOL("crop");
    built_options.emplace("-Dcrop=" + kernel_name);
    auto dt = DataTypeToEnum<T>::value;
    built_options.emplace("-DDATA_TYPE=" + DtToCLDt(dt));
    built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(dt));
    if (runtime->IsOutOfRangeCheckEnabled()) {
      built_options.emplace("-DOUT_OF_RANGE_CHECK");
      kernel_error_ = std::move(std::unique_ptr<Buffer>(
          new Buffer(GetDeviceAllocator(DeviceType::GPU))));
      MACE_RETURN_IF_ERROR(kernel_error_->Allocate(1));
      kernel_error_->Map(nullptr);
      *(kernel_error_->mutable_data<char>()) = 0;
      kernel_error_->UnMap();
    }
    if (runtime->IsNonUniformWorkgroupsSupported()) {
      built_options.emplace("-DNON_UNIFORM_WORK_GROUP");
    }
    kernel_ = runtime->BuildKernel("crop", kernel_name, built_options);

    kwg_size_ =
        static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
  }

  // Kernel arguments are only refreshed when the shape of input0 changes.
  if (!IsVecEqual(input_shape_, input0->shape())) {
    uint32_t idx = 0;
    if (runtime->IsOutOfRangeCheckEnabled()) {
      kernel_.setArg(idx++,
                     *(static_cast<cl::Buffer *>(kernel_error_->buffer())));
    }
    if (!runtime->IsNonUniformWorkgroupsSupported()) {
      kernel_.setArg(idx++, gws[0]);
      kernel_.setArg(idx++, gws[1]);
      kernel_.setArg(idx++, gws[2]);
    }
    kernel_.setArg(idx++, *(input0->opencl_image()));
    kernel_.setArg(idx++, static_cast<int>(offsets[0]));
    kernel_.setArg(idx++, static_cast<int>(offsets[1]));
    kernel_.setArg(idx++, static_cast<int>(offsets[2]));
    kernel_.setArg(idx++, static_cast<int>(offset_chan_blk));
    kernel_.setArg(idx++, static_cast<int>(input0->dim(1)));
    kernel_.setArg(idx++, static_cast<int>(input0->dim(2)));
    kernel_.setArg(idx++, static_cast<int>(output->dim(1)));
    kernel_.setArg(idx++, static_cast<int>(output->dim(2)));
    kernel_.setArg(idx++, *(output->opencl_image()));

    input_shape_ = input0->shape();
  }

  const std::vector<uint32_t> lws = LocalWS(gws, kwg_size_);
  std::string tuning_key =
      Concat("crop_opencl_kernel", output->dim(0), output->dim(1),
             output->dim(2), output->dim(3));
  TuningOrRun3DKernel(kernel_, tuning_key, gws, lws, future);

  if (runtime->IsOutOfRangeCheckEnabled()) {
    kernel_error_->Map(nullptr);
    char *kerror_code = kernel_error_->mutable_data<char>();
    MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
    kernel_error_->UnMap();
  }

  return MACE_SUCCESS;
}
// Explicit instantiations of the GPU crop functor for the element types
// used by the GPU operator (float and half).
template struct CropFunctor<DeviceType::GPU, float>;
template struct CropFunctor<DeviceType::GPU, half>;
} // namespace kernels
} // namespace mace
......@@ -173,7 +173,7 @@ MaceStatus Deconv2dFunctor<DeviceType::GPU, T>::operator()(
MACE_CHECK_NOTNULL(filter);
MACE_CHECK_NOTNULL(output);
if (output_shape_.size() == 4) {
if (!from_caffe_) {
paddings_.clear();
paddings_ = std::vector<int>(2, 0);
CalcDeconvPaddingAndInputSize(input->shape().data(), filter->shape().data(),
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/crop.h"
namespace mace {
namespace ops {
// Registers the "Crop" operator with `op_registry`: a float implementation
// for CPU and, when MACE_ENABLE_OPENCL is defined, float and half
// implementations for GPU.
void Register_Crop(OperatorRegistryBase *op_registry) {
  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Crop")
                                          .Device(DeviceType::CPU)
                                          .TypeConstraint<float>("T")
                                          .Build(),
                         CropOp<DeviceType::CPU, float>);
#ifdef MACE_ENABLE_OPENCL
  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Crop")
                                          .Device(DeviceType::GPU)
                                          .TypeConstraint<float>("T")
                                          .Build(),
                         CropOp<DeviceType::GPU, float>);
  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Crop")
                                          .Device(DeviceType::GPU)
                                          .TypeConstraint<half>("T")
                                          .Build(),
                         CropOp<DeviceType::GPU, half>);
#endif  // MACE_ENABLE_OPENCL
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_CROP_H_
#define MACE_OPS_CROP_H_
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/crop.h"
namespace mace {
namespace ops {
// Operator wrapper around kernels::CropFunctor.
//
// Reads the optional integer argument "axis" (default 2) and the repeated
// integer argument "offset" from the operator definition and hands all
// inputs to the functor when run.
template <DeviceType D, typename T>
class CropOp : public Operator<D, T> {
 public:
  CropOp(const OperatorDef &op_def, Workspace *ws)
      : Operator<D, T>(op_def, ws),
        functor_(OperatorBase::GetOptionalArg<int>("axis", 2),
                 OperatorBase::GetRepeatedArgs<int>("offset")) {}

  // Runs the crop on this operator's inputs and writes output 0.
  MaceStatus Run(StatsFuture *future) override {
    MACE_CHECK(this->InputSize() >= 2)
        << "There must be two inputs to crop";

    const std::vector<const Tensor *> sources = this->Inputs();
    Tensor *cropped = this->Output(0);
    return functor_(sources, cropped, future);
  }

 private:
  kernels::CropFunctor<D, T> functor_;
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_CROP_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template <DeviceType D, typename T>
void CropHelper(int iters, int crop_axis, int dim1, int offset) {
mace::testing::StopTiming();
OpsTestNet net;
OpDefBuilder("Crop", "CropBM")
.Input("Input0")
.Input("Input1")
.AddIntArg("axis", crop_axis)
.AddIntsArg("offset", {offset})
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
const int kDim0 = 100;
net.AddRandomInput<DeviceType::CPU, T>("Input0", {1, kDim0, dim1, dim1, });
net.AddRandomInput<DeviceType::CPU, T>("Input1",
{1, kDim0 / 2, dim1 / 2, dim1 / 2});
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
const int64_t tot = static_cast<int64_t>(iters) * kDim0 * dim1 * dim1;
mace::testing::MaccProcessed(tot);
testing::BytesProcessed(tot * sizeof(T));
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
}
} // namespace
// Defines and registers one CPU float crop benchmark. The benchmark name
// encodes the crop axis, the spatial size and the crop offset.
#define MACE_BM_CROP_CPU_MACRO(AXIS, DIM, OFFSET) \
  static void MACE_BM_CROP_CPU_##AXIS##_##DIM##_##OFFSET(int iters) { \
    CropHelper<DeviceType::CPU, float>(iters, AXIS, DIM, OFFSET); \
  } \
  MACE_BENCHMARK(MACE_BM_CROP_CPU_##AXIS##_##DIM##_##OFFSET)
// CPU benchmark configurations: (axis, spatial size, offset).
MACE_BM_CROP_CPU_MACRO(1, 256, 3);
MACE_BM_CROP_CPU_MACRO(2, 256, 3);
MACE_BM_CROP_CPU_MACRO(3, 512, 3);
MACE_BM_CROP_CPU_MACRO(2, 512, 6);
namespace {
// Benchmarks the Crop operator on the GPU with element type T.
//
// iters:     number of timed runs.
// shape0:    shape of the tensor being cropped.
// shape1:    shape of the tensor that supplies the target size.
// crop_axis: value of the "axis" argument.
// offset:    single crop offset passed as the "offset" argument.
template <typename T>
void OpenclCropHelper(int iters,
                      const std::vector<index_t> &shape0,
                      const std::vector<index_t> &shape1,
                      int crop_axis,
                      int offset) {
  mace::testing::StopTiming();

  OpsTestNet net;

  // Random float buffers, converted to images before the op consumes them.
  net.AddRandomInput<DeviceType::GPU, float>("Input0", shape0);
  net.AddRandomInput<DeviceType::GPU, float>("Input1", shape1);
  BufferToImage<DeviceType::GPU, T>(&net, "Input0", "InputImage0",
                                    kernels::BufferType::IN_OUT_CHANNEL);
  BufferToImage<DeviceType::GPU, T>(&net, "Input1", "InputImage1",
                                    kernels::BufferType::IN_OUT_CHANNEL);

  OpDefBuilder("Crop", "CropBM")
      .Input("InputImage0")
      .Input("InputImage1")
      .AddIntArg("axis", crop_axis)
      .AddIntsArg("offset", {offset})
      .Output("OutputImage")
      .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
      .Finalize(net.NewOperatorDef());

  // Five untimed warm-up runs.
  int warmup_runs = 5;
  while (warmup_runs-- > 0) {
    net.RunOp(DeviceType::GPU);
  }

  // Work accounting is based on the element count of both inputs.
  const int64_t elements_per_run =
      net.GetTensor("Input0")->size() + net.GetTensor("Input1")->size();
  const int64_t processed = static_cast<int64_t>(iters) * elements_per_run;
  mace::testing::MaccProcessed(processed);
  testing::BytesProcessed(processed * sizeof(T));

  mace::testing::StartTiming();
  for (; iters != 0; --iters) {
    net.RunOp(DeviceType::GPU);
  }
}
} // namespace
// Defines and registers one GPU crop benchmark. The first input has shape
// {N, H, W, C}; the second input has every dimension halved. The benchmark
// name encodes the shape, the crop axis, the offset and the element type.
#define MACE_BM_CROP_GPU_MACRO(N, H, W, C, AXIS, OFFSET, TYPE) \
  static void MACE_BM_CROP_GPU_##N##_##H##_##W##_##C##_##AXIS##_##OFFSET##\
_##TYPE(int iters) { \
    std::vector<index_t> shape0 = {N, H, W, C}; \
    std::vector<index_t> shape1 = {N / 2, H / 2, W / 2, C / 2}; \
    OpenclCropHelper<TYPE>(iters, shape0, shape1, AXIS, OFFSET); \
  } \
  MACE_BENCHMARK(MACE_BM_CROP_GPU_##N##_##H##_##W##_##C##_##AXIS##_##OFFSET\
##_##TYPE)
// GPU benchmark configurations: (N, H, W, C, axis, offset, type).
MACE_BM_CROP_GPU_MACRO(4, 32, 32, 32, 2, 4, float);
MACE_BM_CROP_GPU_MACRO(8, 32, 32, 64, 1, 0, float);
MACE_BM_CROP_GPU_MACRO(8, 32, 32, 128, 0, 0, float);
MACE_BM_CROP_GPU_MACRO(8, 32, 32, 256, 2, 4, float);
MACE_BM_CROP_GPU_MACRO(4, 32, 32, 32, 2, 4, half);
MACE_BM_CROP_GPU_MACRO(8, 32, 32, 64, 1, 0, half);
MACE_BM_CROP_GPU_MACRO(8, 32, 32, 128, 0, 0, half);
MACE_BM_CROP_GPU_MACRO(8, 32, 32, 256, 2, 4, half);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
// Test fixture for the Crop operator tests; adds nothing to OpsTestBase.
class CropTest : public OpsTestBase {};
namespace {
// Runs the Crop operator on device D and compares the result with the
// expected tensor.
//
// input_shape / input_data: tensor to crop.
// input_shape2:             shape of the randomly filled reference tensor.
// offset, axis:             operator arguments "offset" and "axis".
// expected_shape / expected_data: expected output.
template <DeviceType D>
void RunCrop(const std::vector<index_t> &input_shape,
             const std::vector<float> &input_data,
             const std::vector<index_t> &input_shape2,
             const std::vector<int> &offset,
             const int axis,
             const std::vector<index_t> &expected_shape,
             const std::vector<float> &expected_data) {
  OpsTestNet net;
  net.AddInputFromArray<D, float>("Input0", input_shape, input_data);
  net.AddRandomInput<D, float>("Input1", input_shape2);

  // Builds the Crop op definition for the given tensor names.
  const auto add_crop_op = [&](const char *first_input,
                               const char *second_input,
                               const char *result) {
    OpDefBuilder("Crop", "CropTest")
        .Input(first_input)
        .Input(second_input)
        .Output(result)
        .AddIntsArg("offset", offset)
        .AddIntArg("axis", axis)
        .Finalize(net.NewOperatorDef());
  };

  const bool on_gpu = (D == GPU);
  const bool on_cpu = (D == CPU);

  // Prepare device-specific inputs, then define the op on top of them.
  if (on_gpu) {
    BufferToImage<D, float>(&net, "Input0", "InputImage0",
                            kernels::BufferType::IN_OUT_CHANNEL);
    BufferToImage<D, float>(&net, "Input1", "InputImage1",
                            kernels::BufferType::IN_OUT_CHANNEL);
    add_crop_op("InputImage0", "InputImage1", "OutputImage");
  } else if (on_cpu) {
    net.TransformDataFormat<DeviceType::CPU, float>("Input0",
                                                    NHWC,
                                                    "InputNCHW0",
                                                    NCHW);
    net.TransformDataFormat<DeviceType::CPU, float>("Input1",
                                                    NHWC,
                                                    "InputNCHW1",
                                                    NCHW);
    add_crop_op("InputNCHW0", "InputNCHW1", "OutputNCHW");
  }

  // Run
  net.RunOp(D);

  // Bring the result back into a plain NHWC buffer named "Output".
  if (on_gpu) {
    ImageToBuffer<D, float>(&net, "OutputImage", "Output",
                            kernels::BufferType::IN_OUT_CHANNEL);
  } else if (on_cpu) {
    net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW,
                                                    "Output", NHWC);
  }

  // Check
  auto expected = CreateTensor<float>(expected_shape, expected_data);
  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"));
}
} // namespace
// Crops a 1x10x10x3 tensor to 1x5x5x3 with offset {2, 2} on axis 2 (CPU).
// Every input row is identical: ten pixels valued 1,1,1,2,2,2,3,3,3,4, each
// value repeated over the three channels. The expected rows therefore hold
// pixels 2..6 of an input row: 1,2,2,2,3.
TEST_F(CropTest, SimpleCPU) {
  // Expands a row of per-pixel values into `rows` identical rows with
  // `channels` copies of each value.
  const auto tile_rows = [](const std::vector<float> &pixel_values,
                            const int rows,
                            const int channels) {
    std::vector<float> data;
    for (int r = 0; r < rows; ++r) {
      for (const float value : pixel_values) {
        for (int c = 0; c < channels; ++c) {
          data.push_back(value);
        }
      }
    }
    return data;
  };

  const std::vector<float> input_data = tile_rows(
      {1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0}, 10, 3);
  const std::vector<float> expected_data = tile_rows(
      {1.0, 2.0, 2.0, 2.0, 3.0}, 5, 3);

  RunCrop<DeviceType::CPU>({1, 10, 10, 3}, input_data,
                           {1, 5, 5, 3}, {2, 2}, 2,
                           {1, 5, 5, 3}, expected_data);
}
// Crops a 1x10x10x3 tensor to 1x5x5x3 with offset {2, 2} on axis 2 (GPU).
// Every input row is identical: ten pixels valued 1,1,1,2,2,2,3,3,3,4, each
// value repeated over the three channels. The expected rows therefore hold
// pixels 2..6 of an input row: 1,2,2,2,3.
TEST_F(CropTest, SimpleGPU) {
  // Expands a row of per-pixel values into `rows` identical rows with
  // `channels` copies of each value.
  const auto tile_rows = [](const std::vector<float> &pixel_values,
                            const int rows,
                            const int channels) {
    std::vector<float> data;
    for (int r = 0; r < rows; ++r) {
      for (const float value : pixel_values) {
        for (int c = 0; c < channels; ++c) {
          data.push_back(value);
        }
      }
    }
    return data;
  };

  const std::vector<float> input_data = tile_rows(
      {1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0}, 10, 3);
  const std::vector<float> expected_data = tile_rows(
      {1.0, 2.0, 2.0, 2.0, 3.0}, 5, 3);

  RunCrop<DeviceType::GPU>({1, 10, 10, 3}, input_data,
                           {1, 5, 5, 3}, {2, 2}, 2,
                           {1, 5, 5, 3}, expected_data);
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -34,7 +34,8 @@ class Deconv2dOp : public ConvPool2dOpBase<D, T> {
this->paddings_,
OperatorBase::GetRepeatedArgs<index_t>("output_shape"),
kernels::ActivationType::NOOP,
0.0f) {}
0.0f,
OperatorBase::GetOptionalArg<bool>("from_caffe", false)) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
......
......@@ -41,7 +41,7 @@ void RunTestSimple(const std::vector<index_t> &input_shape,
net.AddInputFromArray<D, float>("Input", input_shape, input_data);
net.AddInputFromArray<D, float>("Filter", filter_shape, filter_data);
net.TransformDataFormat<D, float>("Filter", HWOI, "FilterOIHW", OIHW);
bool from_caffe = output_shape.size() != 4;
if (D == DeviceType::GPU) {
BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
......@@ -55,6 +55,7 @@ void RunTestSimple(const std::vector<index_t> &input_shape,
.AddIntArg("padding", padding)
.AddIntsArg("padding_values", padding_size)
.AddIntsArg("output_shape", output_shape)
.AddIntArg("from_caffe", from_caffe)
.Finalize(net.NewOperatorDef());
net.RunOp(D);
......@@ -73,6 +74,7 @@ void RunTestSimple(const std::vector<index_t> &input_shape,
.AddIntArg("padding", padding)
.AddIntsArg("padding_values", padding_size)
.AddIntsArg("output_shape", output_shape)
.AddIntArg("from_caffe", from_caffe)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
......@@ -206,17 +208,6 @@ void TestNHWCSimple3x3VALID_S1() {
366, 399, 432, 234, 252, 270, 146, 157, 168, 354, 378, 402, 630,
669, 708, 502, 530, 558, 294, 309, 324, 133, 140, 147, 306, 321,
336, 522, 546, 570, 398, 415, 432, 225, 234, 243});
RunTestSimple<D>(
{1, 3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}, 1, Padding::VALID, {4, 4}, {0},
{3, 3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27},
{1, 5, 5, 3},
{1, 2, 3, 6, 9, 12, 18, 24, 30, 26, 31, 36, 21,
24, 27, 14, 19, 24, 54, 66, 78, 126, 147, 168, 130, 146,
162, 90, 99, 108, 66, 78, 90, 198, 225, 252, 405, 450, 495,
366, 399, 432, 234, 252, 270, 146, 157, 168, 354, 378, 402, 630,
669, 708, 502, 530, 558, 294, 309, 324, 133, 140, 147, 306, 321,
336, 522, 546, 570, 398, 415, 432, 225, 234, 243});
}
template <DeviceType D>
......@@ -342,6 +333,7 @@ void TestComplexDeconvNxNS12(const int batch,
paddings.push_back(padding);
paddings.push_back(padding);
}
bool from_caffe = output_shape.size() != 4;
// Construct graph
OpDefBuilder("Deconv2D", "Deconv2dTest")
.Input("InputNCHW")
......@@ -352,6 +344,7 @@ void TestComplexDeconvNxNS12(const int batch,
.AddIntArg("padding", type)
.AddIntsArg("padding_values", paddings)
.AddIntsArg("output_shape", output_shape)
.AddIntArg("from_caffe", from_caffe)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
......@@ -382,6 +375,7 @@ void TestComplexDeconvNxNS12(const int batch,
.AddIntArg("padding", type)
.AddIntsArg("padding_values", paddings)
.AddIntsArg("output_shape", output_shape)
.AddIntArg("from_caffe", from_caffe)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run on device
......
......@@ -28,6 +28,7 @@ extern void Register_Cast(OperatorRegistryBase *op_registry);
extern void Register_ChannelShuffle(OperatorRegistryBase *op_registry);
extern void Register_Concat(OperatorRegistryBase *op_registry);
extern void Register_Conv2D(OperatorRegistryBase *op_registry);
extern void Register_Crop(OperatorRegistryBase *op_registry);
extern void Register_Deconv2D(OperatorRegistryBase *op_registry);
extern void Register_DepthToSpace(OperatorRegistryBase *op_registry);
extern void Register_DepthwiseConv2d(OperatorRegistryBase *op_registry);
......@@ -78,6 +79,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
ops::Register_ChannelShuffle(this);
ops::Register_Concat(this);
ops::Register_Conv2D(this);
ops::Register_Crop(this);
ops::Register_Deconv2D(this);
ops::Register_DepthToSpace(this);
ops::Register_DepthwiseConv2d(this);
......
......@@ -80,6 +80,7 @@ MaceSupportedOps = [
'ChannelShuffle',
'Concat',
'Conv2D',
'Crop',
'Deconv2D',
'DepthToSpace',
'DepthwiseConv2d',
......@@ -160,6 +161,8 @@ class MaceKeyword(object):
mace_transpose_a_str = 'transpose_a'
mace_transpose_b_str = 'transpose_b'
mace_op_data_type_str = 'T'
mace_offset_str = 'offset'
mace_from_caffe_str = 'from_caffe'
class TransformerRule(Enum):
......
......@@ -167,6 +167,7 @@ class CaffeConverter(base_converter.ConverterInterface):
self._op_converters = {
'Input': self.convert_nop,
'Convolution': self.convert_conv2d,
'Deconvolution': self.convert_deconv2d,
'Eltwise': self.convert_elementwise,
'Add': self.convert_add,
'ReLU': self.convert_activation,
......@@ -179,6 +180,7 @@ class CaffeConverter(base_converter.ConverterInterface):
'Softmax': self.convert_softmax,
'InnerProduct': self.convert_fully_connected,
'BatchNorm': self.convert_folded_batchnorm,
'Crop': self.convert_crop,
}
self._option = option
self._mace_net_def = mace_pb2.NetDef()
......@@ -397,6 +399,55 @@ class CaffeConverter(base_converter.ConverterInterface):
bias_data)
op.input.extend([bias_tensor_name])
def convert_deconv2d(self, caffe_op):
    """Convert a Caffe ``Deconvolution`` layer into a MACE ``Deconv2D`` op.

    The resulting op carries a ``from_caffe`` argument set to 1, the
    stride/pad/kernel arguments added by ``add_stride_pad_kernel_arg``, an
    optional dilations argument, and the filter (plus bias, when the layer
    has two blobs) registered as tensors and appended to the op inputs.

    Grouped and depthwise deconvolutions are rejected through
    ``mace_check``; so is any dilation other than 1.
    """
    op = self.convert_general_op(caffe_op)
    param = caffe_op.layer.convolution_param
    is_depthwise = False

    # A prototxt may spell out `group: 1` explicitly. That is still an
    # ordinary deconvolution, so only a group count other than 1 goes
    # through the grouped/depthwise checks below.
    if param.HasField(caffe_group_str) and param.group != 1:
        filter_data = caffe_op.blobs[0]
        mace_check(param.group == filter_data.shape[0] and
                   filter_data.shape[1] == 1,
                   "Mace does not support group deconvolution yet")
        is_depthwise = True
    mace_check(is_depthwise is False,
               "Mace do not support depthwise deconvolution yet")

    op.type = MaceOp.Deconv2D.name
    from_caffe_arg = op.arg.add()
    from_caffe_arg.name = MaceKeyword.mace_from_caffe_str
    from_caffe_arg.i = 1
    self.add_stride_pad_kernel_arg(param, op)

    # dilation is specific for convolution in caffe
    dilations = [1, 1]
    if len(param.dilation) > 0:
        dilation_arg = op.arg.add()
        dilation_arg.name = MaceKeyword.mace_dilations_str
        if len(param.dilation) == 1:
            # A single value applies to both spatial dimensions.
            dilations = [param.dilation[0], param.dilation[0]]
        elif len(param.dilation) == 2:
            dilations = [param.dilation[0], param.dilation[1]]
        mace_check(dilations[0] == 1 and dilations[1] == 1,
                   "Mace only supports dilation == 1 deconvolution.")
        dilation_arg.ints.extend(dilations)

    filter_tensor_name = op.name + '_filter'
    filter_data = caffe_op.blobs[0]
    self.add_tensor(filter_tensor_name, filter_data.shape,
                    mace_pb2.DT_FLOAT, filter_data)
    op.input.extend([filter_tensor_name])

    if len(caffe_op.blobs) == 2:
        bias_tensor_name = op.name + '_bias'
        bias_data = caffe_op.blobs[1]
        # caffe of old version has 4-dimension bias, so reshape it
        # to single dimension
        self.add_tensor(bias_tensor_name, bias_data.reshape(-1).shape,
                        mace_pb2.DT_FLOAT,
                        bias_data)
        op.input.extend([bias_tensor_name])
def convert_elementwise(self, caffe_op):
op = self.convert_general_op(caffe_op)
param = caffe_op.layer.eltwise_param
......@@ -475,6 +526,24 @@ class CaffeConverter(base_converter.ConverterInterface):
def convert_softmax(self, caffe_op):
self.convert_general_op(caffe_op)
def convert_crop(self, caffe_op):
    """Convert a Caffe ``Crop`` layer into a MACE ``Crop`` op.

    Two arguments are attached to the op, in this order:

    * axis   -- ``crop_param.axis`` when the field is set, otherwise 2;
                a negative value is shifted by 4 to make it non-negative.
    * offset -- the ``crop_param.offset`` values as an int list, or the
                scalar 0 when the layer lists no offsets.
    """
    op = self.convert_general_op(caffe_op)
    crop_param = caffe_op.layer.crop_param
    op.type = MaceOp.Crop.name

    # Resolve the final axis value first, then store it once.
    axis = crop_param.axis if crop_param.HasField('axis') else 2
    if axis < 0:
        axis = 4 + axis
    axis_arg = op.arg.add()
    axis_arg.name = MaceKeyword.mace_axis_str
    axis_arg.i = axis

    offset_arg = op.arg.add()
    offset_arg.name = MaceKeyword.mace_offset_str
    offsets = list(crop_param.offset)
    if len(offsets) > 0:
        offset_arg.ints.extend(offsets)
    else:
        offset_arg.i = 0
def convert_concat(self, caffe_op):
op = self.convert_general_op(caffe_op)
param = caffe_op.layer.concat_param
......
......@@ -32,6 +32,7 @@ class ShapeInference(object):
def __init__(self, net, input_nodes):
self._op_shape_inference = {
MaceOp.Conv2D.name: self.infer_shape_conv_pool_shape,
MaceOp.Deconv2D.name: self.infer_shape_deconv,
MaceOp.DepthwiseConv2d.name: self.infer_shape_conv_pool_shape,
MaceOp.Eltwise.name: self.infer_shape_general,
MaceOp.FoldedBatchNorm.name: self.infer_shape_general,
......@@ -42,6 +43,7 @@ class ShapeInference(object):
MaceOp.Slice.name: self.infer_shape_slice,
MaceOp.Softmax.name: self.infer_shape_general,
MaceOp.FullyConnected.name: self.infer_shape_fully_connected,
MaceOp.Crop.name: self.infer_shape_crop,
}
self._net = net
......@@ -139,6 +141,44 @@ class ShapeInference(object):
self.add_output_shape(op, [output_shape])
def infer_shape_deconv(self, op):
    """Infer and record the output shape of a Deconv2D op.

    Only NCHW data combined with an OIHW net filter format is handled;
    any other combination is reported through ``mace_check``. Each spatial
    output extent is computed as

        (in - 1) * stride + (k - 1) * (dilation - 1) + k - padding

    where ``padding`` is the corresponding entry of the op's padding
    values argument. Dilations default to [1, 1] when the op has no
    dilations argument.
    """
    input_shape = self._output_shape_cache[op.input[0]]
    output_shape = np.zeros_like(input_shape)
    filter_shape = self._output_shape_cache[op.input[1]]

    paddings = ConverterUtil.get_arg(
        op, MaceKeyword.mace_padding_values_str).ints
    strides = ConverterUtil.get_arg(op, MaceKeyword.mace_strides_str).ints
    dilations_arg = ConverterUtil.get_arg(op,
                                          MaceKeyword.mace_dilations_str)
    dilations = [1, 1] if dilations_arg is None else dilations_arg.ints

    output_shape[0] = input_shape[0]
    if ConverterUtil.data_format(op) == DataFormat.NCHW \
            and ConverterUtil.filter_format(self._net) == FilterFormat.OIHW:  # noqa
        # filter format: IOHW
        output_shape[1] = filter_shape[1]
        # Axes 2 and 3 are the spatial axes; the per-axis arguments
        # (strides, dilations, paddings) are indexed from 0.
        for axis in (2, 3):
            k = axis - 2
            extent = ((input_shape[axis] - 1) * strides[k] +
                      (filter_shape[axis] - 1) * (dilations[k] - 1) +
                      filter_shape[axis] - paddings[k])
            output_shape[axis] = int(math.floor(extent))
    else:
        mace_check(False,
                   "Mace can only infer shape for"
                   " NCHW input and OIHW filter")

    print("deconv layer %s (%s) input:%s filter:%s output:%s" %
          (op.name, op.type, input_shape, filter_shape, output_shape))

    self.add_output_shape(op, [output_shape])
def infer_shape_concat(self, op):
output_shape = self._output_shape_cache[op.input[0]]
axis = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str).i
......@@ -166,3 +206,8 @@ class ShapeInference(object):
mace_check(False, "format %s is not supported"
% ConverterUtil.data_format(op))
self.add_output_shape(op, [output_shape])
def infer_shape_crop(self, op):
    """Infer and record the output shape of a Crop op.

    A Caffe Crop layer keeps the dimensions of its first input that come
    before ``axis`` and takes the dimensions from ``axis`` onward from its
    second (reference) input. The axis is read from the op's axis
    argument and falls back to 2 when the argument is absent.
    """
    mace_check(len(op.input) == 2, "crop layer needs two inputs")
    input_shape = self._output_shape_cache[op.input[0]]
    reference_shape = self._output_shape_cache[op.input[1]]
    mace_check(len(input_shape) == len(reference_shape),
               "crop layer inputs must have the same number of dimensions")

    axis_arg = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str)
    axis = axis_arg.i if axis_arg is not None else 2

    # Build a fresh list so the cached shapes of the inputs are never
    # aliased or mutated by later passes.
    output_shape = list(input_shape[:axis]) + list(reference_shape[axis:])
    self.add_output_shape(op, [output_shape])
......@@ -363,15 +363,19 @@ class TensorflowConverter(base_converter.ConverterInterface):
dilation_val = [1, 1]
dilation_arg.ints.extend(dilation_val)
else:
del op.input[1:]
output_shape_arg = op.arg.add()
output_shape_arg.name = MaceKeyword.mace_output_shape_str
output_shape_value = tf_op.inputs[0].eval().astype(np.int32).flat
output_shape_arg.ints.extend(output_shape_value)
self._skip_tensor.add(tf_op.inputs[0].name)
del op.input[0]
if len(tf_op.inputs) >= 3:
del op.input[1:]
output_shape_value =\
tf_op.inputs[0].eval().astype(np.int32).flat
output_shape_arg.ints.extend(output_shape_value)
self._skip_tensor.add(tf_op.inputs[0].name)
del op.input[0]
op.input.extend([tf_op.inputs[2].name, tf_op.inputs[1].name])
else:
output_shape_value = tf_op.get_attr(tf_strides_str)
output_shape_arg.ints.extend(output_shape_value)
def convert_elementwise(self, tf_op):
op = self.convert_general_op(tf_op)
......
......@@ -889,10 +889,7 @@ class Transformer(base_converter.ConverterInterface):
filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape(
filter.dims)
if op.type == MaceOp.Deconv2D.name:
filter_data = filter_data.transpose(2, 3, 0, 1)
else:
filter_data = filter_data.transpose(3, 2, 0, 1)
filter_data = filter_data.transpose(3, 2, 0, 1)
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
if (op.type == MaceOp.MatMul.name and
......@@ -905,6 +902,14 @@ class Transformer(base_converter.ConverterInterface):
filter.dims[:] = filter_data.shape
self.set_filter_format(FilterFormat.OIHW)
for op in net.op:
if op.type == MaceOp.Deconv2D.name:
filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape(
filter.dims)
filter_data = filter_data.transpose(1, 0, 2, 3)
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册