Commit b5816ca8 authored by 李寅

Merge branch 'conv2d-neon' into 'master'

Finish depthwise conv2d 3x3 stride 1/2 NEON kernel.

See merge request !48
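
For context, the SAME-padding arithmetic that this change adds in CalPaddingSize (and mirrors in CalOutputSize) can be sanity-checked with a small standalone C++ sketch. The concrete numbers below (32x32 input, 3x3 filter, stride 2, dilation 1) are illustrative assumptions, not values taken from the kernels:

#include <cstdio>

int main() {
  // Minimal sketch of the SAME-padding arithmetic, assuming a 32x32 input,
  // 3x3 filter, stride 2, dilation 1.
  const int input = 32, kernel = 3, stride = 2, dilation = 1;
  const int k_extent = (kernel - 1) * dilation + 1;              // effective kernel extent
  const int output = (input - 1) / stride + 1;                   // SAME output size: 16
  const int padding = (output - 1) * stride + k_extent - input;  // total padding: 1
  // As with TensorFlow, the odd pixel is padded on the bottom/right:
  // pad_begin = padding / 2 = 0, pad_end = padding - padding / 2 = 1.
  std::printf("output = %d, total padding = %d\n", output, padding);
  return 0;
}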
......@@ -71,6 +71,51 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
output_shape[3] = output_width;
}
void CalPaddingSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW
const int *dilations,
const int *strides,
Padding padding,
int *padding_size) {
MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
"Invalid dilations, must >= 1");
MACE_CHECK((dilations[0] == 1 || strides[0] == 1) &&
(dilations[1] == 1 || strides[1] == 1),
"If dilations > 1, strides should be 1");
MACE_CHECK_NOTNULL(padding_size);
index_t output_height, output_width;
index_t k_extent_height = (filter_shape[2] - 1) * dilations[0] + 1;
index_t k_extent_width = (filter_shape[3] - 1) * dilations[1] + 1;
switch (padding) {
case VALID:
output_height = (input_shape[2] - k_extent_height) / strides[0] + 1;
output_width = (input_shape[3] - k_extent_width) / strides[1] + 1;
break;
case SAME:
output_height = (input_shape[2] - 1) / strides[0] + 1;
output_width = (input_shape[3] - 1) / strides[1] + 1;
break;
case FULL:
output_height = (input_shape[2] + k_extent_height - 2) / strides[0] + 1;
output_width = (input_shape[3] + k_extent_width - 2) / strides[1] + 1;
break;
default:
MACE_CHECK(false, "Unsupported padding type: ", padding);
}
// Note: TensorFlow may pad one more pixel on the right/bottom side.
// TODO: it may be better to also truncate the left/top to
// utilize the more centered features. We need to benchmark this
// based on model accuracy.
padding_size[0] =
(output_height - 1) * strides[0] + k_extent_height - input_shape[2];
padding_size[1] =
(output_width - 1) * strides[1] + k_extent_width - input_shape[3];
}
void ConstructInputWithPadding(const float *input,
const index_t *input_shape,
const int *paddings,
......
......@@ -25,6 +25,13 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
index_t *output_shape,
int *padding_size);
void CalPaddingSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW
const int *dilations,
const int *strides,
Padding padding,
int *padding_size);
void ConstructInputWithPadding(const float *input,
const index_t *input_shape,
const int *paddings,
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_DEPTHWISE_CONV_H_
#define MACE_KERNELS_DEPTHWISE_CONV_H_
#include "mace/proto/mace.pb.h"
#include "mace/core/common.h"
#include "mace/kernels/conv_pool_2d_util.h"
namespace mace {
namespace kernels {
template<DeviceType D, typename T>
class DepthwiseConv2dFunctor {
public:
DepthwiseConv2dFunctor(const index_t *input_shape,
const index_t *filter_shape,
const int *strides,
const Padding padding,
const int *dilations) :
strides_(strides),
paddings_(2, 0),
dilations_(dilations) {
CalPaddingSize(input_shape, filter_shape, dilations_, strides_, padding, paddings_.data());
}
DepthwiseConv2dFunctor(const int *strides,
const std::vector<int> &paddings,
const int *dilations) :
strides_(strides),
paddings_(paddings),
dilations_(dilations) {}
void operator()(const T *input, // NCHW
const index_t *input_shape,
const T *filter, // c_out, c_in, kernel_h, kernel_w
const index_t *filter_shape,
const T *bias, // c_out
T *output, // NCHW
const index_t *output_shape) {
MACE_CHECK_NOTNULL(output);
index_t batch = output_shape[0];
index_t channels = output_shape[1];
index_t height = output_shape[2];
index_t width = output_shape[3];
index_t input_batch = input_shape[0];
index_t input_channels = input_shape[1];
index_t input_height = input_shape[2];
index_t input_width = input_shape[3];
index_t kernel_h = filter_shape[2];
index_t kernel_w = filter_shape[3];
int stride_h = strides_[0];
int stride_w = strides_[1];
int dilation_h = dilations_[0];
int dilation_w = dilations_[1];
MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");
// The top-left offset of the padded input
int padded_h_start = 0 - paddings_[0] / 2;
int padded_w_start = 0 - paddings_[1] / 2;
index_t padded_h_stop = input_height + paddings_[0] - paddings_[0] / 2;
index_t padded_w_stop = input_width + paddings_[1] - paddings_[1] / 2;
index_t kernel_size = filter_shape[1] * kernel_h * kernel_w;
index_t multiplier = channels / input_channels;
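// Depthwise mapping: output channel c reads input channel c / multiplier, and
// the filter holds one kernel_h x kernel_w kernel per output channel.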
#pragma omp parallel for collapse(2)
for (int n = 0; n < batch; ++n) {
for (int c = 0; c < channels; ++c) {
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
index_t offset = n * channels * height * width +
c * height * width + h * width + w;
T sum = 0;
const T *filter_ptr = filter + c * kernel_size;
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int inh = padded_h_start + h * stride_h + dilation_h * kh;
int inw = padded_w_start + w * stride_w + dilation_w * kw;
if (inh < 0 || inh >= input_height || inw < 0 ||
inw >= input_width) {
MACE_CHECK(inh >= padded_h_start && inh < padded_h_stop &&
inw >= padded_w_start && inw < padded_w_stop,
"Out of range read from input: ", inh, ", ",
inw);
// else padding with 0:
// sum += 0;
} else {
index_t input_offset =
n * input_channels * input_height * input_width +
(c / multiplier) * input_height * input_width + inh * input_width +
inw;
sum += input[input_offset] * *filter_ptr;
}
++filter_ptr;
}
}
output[offset] = sum + bias[c];
}
}
}
}
}
private:
const int *strides_; // [stride_h, stride_w]
std::vector<int> paddings_; // [padding_h, padding_w]
const int *dilations_; // [dilation_h, dilation_w]
};
template<>
void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(const float *input,
const index_t *input_shape,
const float *filter,
const index_t *filter_shape,
const float *bias,
float *output,
const index_t *output_shape);
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_DEPTHWISE_CONV_H_
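As a rough illustration of how the CPU reference functor above would be driven (the shapes, buffers, and the ExampleDepthwiseConv2d helper below are made-up example values; the real call site is DepthwiseConv2dOp::Run further down in this diff):

#include <vector>
// Assumes mace/kernels/depthwise_conv2d.h is included and types from mace are visible.
void ExampleDepthwiseConv2d() {
  using namespace mace;
  using namespace mace::kernels;
  // 1x2x4x4 input, depth multiplier 1, 3x3 filter already reshaped to
  // (out_channels, 1, kh, kw), stride 1, VALID padding -> 1x2x2x2 output.
  const index_t input_shape[4] = {1, 2, 4, 4};
  const index_t filter_shape[4] = {2, 1, 3, 3};
  const index_t output_shape[4] = {1, 2, 2, 2};
  const int strides[2] = {1, 1};
  const int dilations[2] = {1, 1};
  std::vector<float> input(1 * 2 * 4 * 4, 1.0f), filter(2 * 1 * 3 * 3, 1.0f);
  std::vector<float> bias(2, 0.0f), output(1 * 2 * 2 * 2);
  DepthwiseConv2dFunctor<DeviceType::CPU, float> functor(
      input_shape, filter_shape, strides, Padding::VALID, dilations);
  functor(input.data(), input_shape, filter.data(), filter_shape,
          bias.data(), output.data(), output_shape);
  // With all-ones input and filter, every output element is 9.0f.
}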
......@@ -11,6 +11,7 @@ namespace kernels {
extern void Conv2dNeonK1x1S1(const float *input,
const index_t *input_shape,
const float *filter,
const index_t *filter_shape,
const float *bias,
float *output,
const index_t *output_shape);
......@@ -18,6 +19,7 @@ extern void Conv2dNeonK1x1S1(const float *input,
extern void Conv2dNeonK3x3S1(const float *input,
const index_t *input_shape,
const float *filter,
const index_t *filter_shape,
const float *bias,
float *output,
const index_t *output_shape);
......@@ -25,6 +27,7 @@ extern void Conv2dNeonK3x3S1(const float *input,
extern void Conv2dNeonK3x3S2(const float *input,
const index_t *input_shape,
const float *filter,
const index_t *filter_shape,
const float *bias,
float *output,
const index_t *output_shape);
......@@ -32,6 +35,7 @@ extern void Conv2dNeonK3x3S2(const float *input,
extern void Conv2dNeonK5x5S1(const float *input,
const index_t *input_shape,
const float *filter,
const index_t *filter_shape,
const float *bias,
float *output,
const index_t *output_shape);
......@@ -48,6 +52,7 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float *input,
const float *input,
const index_t *input_shape,
const float *filter,
const index_t *filter_shape,
const float *bias,
float *output,
const index_t *output_shape);
......@@ -81,7 +86,7 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float *input,
input_shape = padded_input.shape().data();
}
auto conv2d_neon_func = selector[kernel_h - 1][strides_[0] - 1];
conv2d_neon_func(input, input_shape, filter, bias, output, output_shape);
conv2d_neon_func(input, input_shape, filter, nullptr, bias, output, output_shape);
}
} // namespace kernels
......
......@@ -8,12 +8,13 @@
namespace mace {
namespace kernels {
void Conv2dNeonK1x1S1(const float *input, // NCHW
const index_t *input_shape,
const float *filter, // c_out, c_in, kernel_h, kernel_w
const index_t *filter_shape,
const float *bias, // c_out
float *output, // NCHW
const index_t *output_shape) {
const index_t batch = output_shape[0];
const index_t channels = output_shape[1];
const index_t height = output_shape[2];
......@@ -25,7 +26,7 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW
const index_t input_width = input_shape[3];
MACE_CHECK(input_batch == batch && input_height == height &&
input_width == width);
const index_t total_pixels = height * width;
// Process 4 * 2 = 8 pixels for each innermost loop
......@@ -35,17 +36,17 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW
// benchmark omp collapse(2)
for (index_t n = 0; n < batch; ++n) {
const float *filter_ptr = filter;
#pragma omp parallel for
for (index_t c = 0; c < channels; ++c) {
// TODO Will GCC opt these out?
float *channel_output_start =
output + n * channels * height * width + c * height * width;
const float *input_ptr =
input + n * input_channels * input_height * input_width;
// Fill with bias
float *output_ptr = channel_output_start;
for (index_t ptr = 0; ptr < total_pixels; ++ptr) {
output_ptr[ptr] = bias[c]; // TODO can we avoid this?
}
......@@ -53,15 +54,15 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW
index_t inc = 0;
// Process 4 input channels in batch
for (; inc + 3 < input_channels; inc += 4) {
float *output_ptr = channel_output_start;
// The beginning of each input feature map channel
MACE_ASSERT(input_ptr ==
input + n * input_channels * input_height * input_width +
inc * input_height * input_width);
const float *input_ptr1 = input_ptr + total_pixels;
const float *input_ptr2 = input_ptr1 + total_pixels;
const float *input_ptr3 = input_ptr2 + total_pixels;
// filter is in c_out, c_in, 1, 1 order
MACE_ASSERT(filter_ptr == filter + c * input_channels + inc);
......@@ -139,10 +140,10 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW
}
// Process the remaining channels
for (; inc < input_channels; ++inc) {
float *output_ptr = channel_output_start;
MACE_ASSERT(input_ptr ==
input + n * input_channels * input_height * input_width +
inc * input_height * input_width);
MACE_ASSERT(filter_ptr == filter + c * input_channels + inc);
const float k0 = filter_ptr[0];
......
......@@ -17,30 +17,35 @@ namespace kernels {
int input_channels = input_shape[1]; \
int input_height = input_shape[2]; \
int input_width = input_shape[3]; \
int multiplier = filter_shape == nullptr ? 0 : (filter_shape[0] / input_channels); \
int filter_in_channels = filter_shape == nullptr ? input_channels : filter_shape[1]; \
for (int b = 0; b < output_batch; ++b) { \
float *output_ptr_base = output + b * output_channels * output_height * output_width; \
for (int oc = 0; oc < output_channels; ++oc) { \
const float *filter_ptr = filter + oc * filter_in_channels * kFilterSize; \
const float *input_ptr = input + b * input_channels * input_height * input_width; \
if (filter_shape != nullptr) { \
input_ptr += (oc / multiplier) * input_height * input_width; \
} \
float *output_ptr = output_ptr_base + oc * output_height * output_width; \
std::fill(output_ptr, output_ptr + output_height * output_width, bias[oc]); \
for (int ic = 0; ic < filter_in_channels; ++ic) { \
float32x4_t n_filter_v[3] = {vld1q_f32(filter_ptr), vld1q_f32(filter_ptr+3), vld1q_f32(filter_ptr+6)};
#define KERNEL_TAIL_CODE \
filter_ptr += kFilterSize; \
input_ptr += input_height * input_width; \
} \
} \
}
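// When filter_shape is non-null the macros run in depthwise mode: each output
// channel convolves a single input slice selected by oc / multiplier
// (filter_in_channels is 1 for the reshaped depthwise filter); with a null
// filter_shape they behave as a dense 3x3 convolution.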
static const int kRegisterSize = 4;
static const int kFilterSize = 9;
void Conv2dNeonK3x3S1(const float *input, // NCHW
const index_t *input_shape,
const float *filter, // c_out, c_in, kernel_h, kernel_w
const index_t *filter_shape,
const float *bias, // c_out
float *output, // NCHW
const index_t *output_shape) {
......@@ -213,6 +218,7 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW
void Conv2dNeonK3x3S2(const float *input, // NCHW
const index_t *input_shape,
const float *filter, // c_out, c_in, kernel_h, kernel_w
const index_t *filter_shape,
const float *bias, // c_out
float *output, // NCHW
const index_t *output_shape) {
......@@ -287,7 +293,6 @@ void Conv2dNeonK3x3S2(const float *input, // NCHW
KERNEL_TAIL_CODE
}
#undef KERNEL_HEAD_CODE
#undef KERNEL_TAIL_CODE
......
......@@ -10,12 +10,13 @@
namespace mace {
namespace kernels {
void Conv2dNeonK5x5S1(const float *input, // NCHW
const index_t *input_shape,
const float *filter, // c_out, c_in, kernel_h, kernel_w
const index_t *filter_shape,
const float *bias, // c_out
float *output, // NCHW
const index_t *output_shape) {
const index_t batch = output_shape[0];
const index_t channels = output_shape[1];
const index_t height = output_shape[2];
......@@ -39,9 +40,9 @@ void Conv2dNeonK5x5S1(const float* input, // NCHW
#pragma omp parallel for collapse(2)
for (index_t n = 0; n < batch; ++n) {
for (index_t c = 0; c < channels; ++c) {
float *output_ptr = output + n * output_total_pixels_per_batch +
c * output_total_pixels_per_channel;
const float *input_ptr = input + n * input_total_pixels_per_batch;
// Fill with bias
for (index_t i = 0; i < output_total_pixels_per_channel; ++i) {
......@@ -49,24 +50,24 @@ void Conv2dNeonK5x5S1(const float* input, // NCHW
}
for (index_t inc = 0; inc < input_channels; ++inc) {
float *outptr = output_ptr;
float *outptr2 = outptr + width;
const float *inptr = input_ptr + inc * input_total_pixels_per_channel;
const float *filter_ptr = filter + c * patch_size + inc * 25;
const float *r0 = inptr;
const float *r1 = inptr + input_width;
const float *r2 = inptr + input_width * 2;
const float *r3 = inptr + input_width * 3;
const float *r4 = inptr + input_width * 4;
const float *r5 = inptr + input_width * 5;
const float *k0 = filter_ptr;
const float *k1 = filter_ptr + 5;
const float *k2 = filter_ptr + 10;
const float *k3 = filter_ptr + 15;
const float *k4 = filter_ptr + 20;
float32x4_t _k0123 = vld1q_f32(filter_ptr);
float32x4_t _k4567 = vld1q_f32(filter_ptr + 4);
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/depthwise_conv2d.h"
#include "mace/kernels/conv_2d.h"
namespace mace {
namespace kernels {
extern void Conv2dNeonK3x3S1(const float *input,
const index_t *input_shape,
const float *filter,
const index_t *filter_shape,
const float *bias,
float *output,
const index_t *output_shape);
extern void Conv2dNeonK3x3S2(const float *input,
const index_t *input_shape,
const float *filter,
const index_t *filter_shape,
const float *bias,
float *output,
const index_t *output_shape);
template<>
void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(const float *input, // NCHW
const index_t *input_shape,
const float *filter, // c_out, c_in, kernel_h, kernel_w
const index_t *filter_shape,
const float *bias, // c_out
float *output, // NCHW
const index_t *output_shape) {
typedef void (*Conv2dNeonFunction)(
const float *input,
const index_t *input_shape,
const float *filter,
const index_t *filter_shape,
const float *bias,
float *output,
const index_t *output_shape);
// Selection matrix: kernel_size x stride_size
static const Conv2dNeonFunction selector[5][2] = {
{nullptr, nullptr},
{nullptr, nullptr},
{Conv2dNeonK3x3S1, Conv2dNeonK3x3S2},
{nullptr, nullptr},
{nullptr, nullptr}};
// nullptr entries in the selector are not implemented yet
index_t kernel_h = filter_shape[2];
index_t kernel_w = filter_shape[3];
if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] ||
strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 ||
selector[kernel_h - 1][strides_[0] - 1] == nullptr) {
LOG(WARNING) << "Depthwise-Conv2d NEON kernel with "
<< "filter " << kernel_h << "x" << kernel_w << ","
<< " stride " << strides_[0] << "x" << strides_[1]
<< " is not implemented yet; using the slow version";
DepthwiseConv2dFunctor<DeviceType::CPU, float>(strides_, paddings_, dilations_)(
input, input_shape, filter, filter_shape, bias, output, output_shape);
return;
}
// Keep this alive during kernel execution
Tensor padded_input;
if (paddings_[0] > 0 || paddings_[1] > 0) {
ConstructInputWithPadding(input, input_shape, paddings_.data(), &padded_input);
input = padded_input.data<float>();
input_shape = padded_input.shape().data();
}
auto conv2d_neon_func = selector[kernel_h - 1][strides_[0] - 1];
conv2d_neon_func(input, input_shape, filter, filter_shape, bias, output, output_shape);
}
} // namespace kernels
} // namespace mace
\ No newline at end of file
......@@ -3,7 +3,6 @@
//
#include "mace/ops/conv_2d.h"
#include "mace/proto/mace.pb.h"
namespace mace {
......
......@@ -13,17 +13,17 @@
namespace mace {
template<DeviceType D, typename T>
class Conv2dOp : public ConvPool2dOpBase<D, T> {
public:
Conv2dOp(const OperatorDef &op_def, Workspace *ws)
: ConvPool2dOpBase<D, T>(op_def, ws) {};
bool Run() override {
const Tensor *input = this->Input(INPUT);
const Tensor *filter = this->Input(FILTER);
const Tensor *bias = this->Input(BIAS);
Tensor *output = this->Output(OUTPUT);
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
......
......@@ -56,7 +56,7 @@ static void Conv2d(int iters,
#define BM_CONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, P, OC, TYPE, DEVICE) \
static void \
BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \
......@@ -64,7 +64,7 @@ static void Conv2d(int iters,
Conv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE, mace::Padding::P, OC); \
} \
BENCHMARK( \
BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE)
#define BM_CONV_2D(N, C, H, W, KH, KW, S, P, OC, TYPE) \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, CPU); \
......@@ -74,9 +74,11 @@ BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 33, 31, 1, 1, 1, VALID, 128, float); // Test bad alignments
BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128, float);
BM_CONV_2D(1, 3, 512, 512, 3, 3, 1, VALID, 3, float);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128, float);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 128, float);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 128, float);
BM_CONV_2D(1, 3, 512, 512, 3, 3, 2, VALID, 3, float);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 128, float);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 128, float);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 128, float);
......
......@@ -173,10 +173,10 @@ TEST_F(Conv2dOpTest, ConvNxNS12) {
// generate random input
index_t batch = 1 + rand() % 10;
index_t input_channels = 1 + rand() % 10;
index_t height = 107;
index_t width = 113;
index_t output_channels = 1 + rand() % 10;
// Construct graph
auto& net = test_net();
OpDefBuilder("Conv2d", "Conv2dTest")
......
......@@ -18,8 +18,49 @@ class ConvPool2dOpBase : public Operator<D, T> {
strides_(OperatorBase::GetRepeatedArgument<int>("strides")),
padding_(static_cast<Padding>(OperatorBase::GetSingleArgument<int>(
"padding", static_cast<int>(SAME)))),
dilations_(OperatorBase::GetRepeatedArgument<int>("dilations", {1, 1})) {}
void CalOutputSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW
index_t *output_shape) {
MACE_CHECK(dilations_[0] > 0 && dilations_[1] > 0,
"Invalid dilations, must >= 1");
MACE_CHECK((dilations_[0] == 1 || strides_[0] == 1) &&
(dilations_[1] == 1 || strides_[1] == 1),
"If dilations > 1, strides should be 1");
MACE_CHECK_NOTNULL(output_shape);
/*
* Convolution/pooling arithmetic:
* o = (i + 2 * p - k - (k - 1) * (d - 1)) / s + 1
* For details, see https://arxiv.org/pdf/1603.07285.pdf or
* http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html
*/
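// Example (illustrative): i = 32, k = 3, s = 2, d = 1 with SAME padding
// gives o = (32 - 1) / 2 + 1 = 16.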
index_t output_height, output_width;
switch (padding_) {
case VALID:
output_height = (input_shape[2] - (filter_shape[2] - 1) * dilations_[0] - 1) / strides_[0] + 1;
output_width = (input_shape[3] - (filter_shape[3] - 1) * dilations_[1] - 1) / strides_[1] + 1;
break;
case SAME:
output_height = (input_shape[2] - 1) / strides_[0] + 1;
output_width = (input_shape[3] - 1) / strides_[1] + 1;
break;
case FULL:
output_height = (input_shape[2] + (filter_shape[2] - 1) * dilations_[0] - 1) / strides_[0] + 1;
output_width = (input_shape[3] + (filter_shape[3] - 1) * dilations_[1] - 1) / strides_[1] + 1;
break;
default:
MACE_CHECK(false, "Unsupported padding type: ", padding_);
}
output_shape[0] = input_shape[0];
output_shape[1] = filter_shape[0];
output_shape[2] = output_height;
output_shape[3] = output_width;
}
protected:
std::vector<int> strides_;
Padding padding_;
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/depthwise_conv2d.h"
namespace mace {
REGISTER_CPU_OPERATOR(DepthwiseConv2d, DepthwiseConv2dOp<DeviceType::CPU, float>);
#if __ARM_NEON
REGISTER_NEON_OPERATOR(DepthwiseConv2d, DepthwiseConv2dOp<DeviceType::NEON, float>);
#endif // __ARM_NEON
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPS_DEPTHWISE_CONV_H_
#define MACE_OPS_DEPTHWISE_CONV_H_
#include <memory>
#include "mace/core/operator.h"
#include "mace/kernels/conv_2d.h"
#include "mace/ops/conv_pool_2d_base.h"
#include "mace/kernels/depthwise_conv2d.h"
namespace mace {
template<DeviceType D, typename T>
class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
public:
DepthwiseConv2dOp(const OperatorDef &op_def, Workspace *ws)
: ConvPool2dOpBase<D, T>(op_def, ws),
functor_(this->Input(INPUT)->shape().data(),
this->Input(FILTER)->shape().data(),
this->strides_.data(), this->padding_, this->dilations_.data()) {};
bool Run() override {
const Tensor *input = this->Input(INPUT);
const Tensor *filter = this->Input(FILTER);
const Tensor *bias = this->Input(BIAS);
Tensor *output = this->Output(OUTPUT);
// Reshape the filter from (multiplier, in_channels, kh, kw) to
// (multiplier * in_channels, 1, kh, kw) so each output channel has its own 2-D kernel.
std::vector<index_t> filter_shape(filter->shape().begin(), filter->shape().end());
filter_shape[0] *= filter_shape[1];
filter_shape[1] = 1;
std::vector<index_t> output_shape(4);
this->CalOutputSize(input->shape().data(), filter_shape.data(), output_shape.data());
output->Resize(output_shape);
functor_(input->data<T>(), input->shape().data(), filter->data<T>(),
filter_shape.data(), bias->data<T>(), output->mutable_data<T>(),
output->shape().data());
return true;
}
private:
kernels::DepthwiseConv2dFunctor<D, T> functor_;
protected:
OP_INPUT_TAGS(INPUT, FILTER, BIAS);
OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace mace
#endif // MACE_OPS_DEPTHWISE_CONV_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/conv_2d.h"
#include "mace/ops/ops_test_util.h"
using namespace mace;
class DepthwiseConv2dOpTest : public OpsTestBase {};
TEST_F(DepthwiseConv2dOpTest, Simple_VALID) {
// Construct graph
auto& net = test_net();
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.Finalize(net.operator_def());
// Add args
net.AddIntsArg("strides", {1, 1});
net.AddIntArg("padding", Padding::VALID);
net.AddIntsArg("dilations", {1, 1});
// Add input data
net.AddInputFromArray<float>(
"Input", {1, 2, 2, 3},
{1, 3, 5, 7, 9, 11, 2, 4, 6, 8, 10, 12});
net.AddInputFromArray<float>(
"Filter", {2, 2, 2, 2},
{1.0f, 5.0f, 9.0f, 13.0f,
2.0f, 6.0f, 10.0f, 14.0f,
3.0f, 7.0f, 11.0f, 15.0f,
4.0f, 8.0f, 12.0f, 16.0f});
net.AddInputFromArray<float>("Bias", {4}, {.1f, .2f, .3f, .4f});
// Run
net.RunOp();
// Check
auto expected = CreateTensor<float>({1, 4, 1, 2},
{196.1f, 252.1f, 216.2f, 280.2f,
272.3f, 344.3f, 296.4f, 376.4f});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
TEST_F(DepthwiseConv2dOpTest, ConvNxNS12) {
testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
Padding type) {
srand(time(NULL));
// generate random input
index_t batch = 2 + rand() % 10;
index_t input_channels = 3 + rand() % 10;
index_t height = 107;
index_t width = 113;
index_t multiplier = 3 + rand() % 10;
// Construct graph
auto& net = test_net();
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.Finalize(net.operator_def());
// Add args
net.AddIntsArg("strides", {stride_h, stride_w});
net.AddIntArg("padding", type);
net.AddIntsArg("dilations", {1, 1});
// Add input data
net.AddRandomInput<float>("Input", {batch, input_channels, height, width});
net.AddRandomInput<float>(
"Filter", {multiplier, input_channels, kernel_h, kernel_w});
net.AddRandomInput<float>("Bias", {multiplier * input_channels});
// run cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// Run NEON
net.RunOp(DeviceType::NEON);
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 1e-3);
};
for (int kernel_size : {3}) {
for (int stride : {1, 2}) {
func(kernel_size, kernel_size, stride, stride, VALID);
func(kernel_size, kernel_size, stride, stride, SAME);
}
}
}
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <algorithm>
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/conv_2d.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
template <DeviceType D, typename T>
static void DepthwiseConv2d(int iters,
int batch,
int channels,
int height,
int width,
int kernel_h,
int kernel_w,
int stride,
Padding padding,
int output_channels) {
mace::testing::StopTiming();
OpsTestNet net;
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.Finalize(net.operator_def());
// Add args
net.AddIntsArg("strides", {stride, stride});
net.AddIntArg("padding", padding);
net.AddIntsArg("dilations", {1, 1});
// Add input data
net.AddRandomInput<float>("Input", {batch, channels, height, width});
net.AddRandomInput<float>("Filter",
{output_channels, channels, kernel_h, kernel_w});
net.AddRandomInput<float>("Bias", {output_channels});
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
}
#define BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, P, OC, TYPE, DEVICE) \
static void \
BM_DEPTHWISE_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot*(sizeof(TYPE))); \
DepthwiseConv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE, mace::Padding::P, OC); \
} \
BENCHMARK( \
BM_DEPTHWISE_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE)
#define BM_DEPTHWISE_CONV_2D(N, C, H, W, KH, KW, S, P, OC, TYPE) \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, CPU); \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, NEON);
BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float);
BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128, float);
BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128, float);
BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 128, float);
BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, VALID, 3, float);
BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, SAME, 3, float);
BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 128, float);
BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 128, float);
BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 128, float);
BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 128, float);
BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, VALID, 3, float);
BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, SAME, 3, float);
} // namespace mace