Commit 8bb5716a authored by Liangliang He

Merge branch 'conv2d-neon' into 'master'

Neon conv2d 3x3 stride 2 kernel.

See merge request !44
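Editor's note: the heart of this change is how the stride-2 kernel feeds the vector units. `vld2q_f32` loads eight consecutive pixels and deinterleaves them into even lanes (`val[0]`) and odd lanes (`val[1]`), so each of the three taps of a 3x3 filter row is available as a full vector and four outputs are computed per iteration. A minimal, self-contained sketch of just that trick (hypothetical names; single channel, single filter row; `width_out` assumed to be a multiple of 4, and the input row must extend a few pixels past the last output, as the real kernel also requires):

```cpp
#include <arm_neon.h>

// One output row of a single-channel, 3-tap horizontal convolution at
// stride 2. Four outputs per iteration consume eight input pixels.
void ConvRow3TapStride2(const float *in, const float *k3, float *out,
                        int width_out) {
  for (int i = 0; i < width_out; i += 4) {
    float32x4x2_t x = vld2q_f32(in);        // x.val[0] = {x0,x2,x4,x6}
                                            // x.val[1] = {x1,x3,x5,x7}
    float32x4_t next = vld1q_f32(in + 8);   // {x8,x9,x10,x11}
    float32x4_t tap2 = vextq_f32(x.val[0], next, 1);  // {x2,x4,x6,x8}
    float32x4_t sum = vmulq_n_f32(x.val[0], k3[0]);   // k0 * even pixels
    sum = vmlaq_n_f32(sum, x.val[1], k3[1]);          // + k1 * odd pixels
    sum = vmlaq_n_f32(sum, tap2, k3[2]);              // + k2 * shifted evens
    vst1q_f32(out, sum);
    in += 8;   // 4 stride-2 outputs advance the input by 8 pixels
    out += 4;
  }
}
```

The kernel in this MR applies the same pattern to all three filter rows and accumulates the result across input channels.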
@@ -22,6 +22,13 @@ extern void Conv2dNeonK3x3S1(const float *input,
                              float *output,
                              const index_t *output_shape);
 
+extern void Conv2dNeonK3x3S2(const float *input,
+                             const index_t *input_shape,
+                             const float *filter,
+                             const float *bias,
+                             float *output,
+                             const index_t *output_shape);
+
 extern void Conv2dNeonK5x5S1(const float *input,
                              const index_t *input_shape,
                              const float *filter,
@@ -30,27 +37,25 @@ extern void Conv2dNeonK5x5S1(const float *input,
                              const index_t *output_shape);
 
 template <>
-void Conv2dFunctor<DeviceType::NEON,
-                   float>::
-operator()(const float *input,  // NCHW
-           const index_t *input_shape,
-           const float *filter,  // c_out, c_in, kernel_h, kernel_w
-           const index_t *filter_shape,
-           const float *bias,  // c_out
-           float *output,  // NCHW
-           const index_t *output_shape) {
+void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float *input,
+                                                        const index_t *input_shape,
+                                                        const float *filter,
+                                                        const index_t *filter_shape,
+                                                        const float *bias,
+                                                        float *output,
+                                                        const index_t *output_shape) {
   typedef void (*Conv2dNeonFunction)(
-      const float *input,  // NCHW
+      const float *input,
       const index_t *input_shape,
-      const float *filter,  // c_out, c_in, kernel_h, kernel_w
-      const float *bias,  // c_out
-      float *output,  // NCHW
+      const float *filter,
+      const float *bias,
+      float *output,
       const index_t *output_shape);
   // Selection matrix: kernel_size x stride_size
   static const Conv2dNeonFunction selector[5][2] = {
       {Conv2dNeonK1x1S1, nullptr},
       {nullptr, nullptr},
-      {Conv2dNeonK3x3S1, nullptr},
+      {Conv2dNeonK3x3S1, Conv2dNeonK3x3S2},
       {nullptr, nullptr},
       {Conv2dNeonK5x5S1, nullptr}};
   // not implement yet
@@ -59,7 +64,10 @@ operator()(const float *input,  // NCHW
   if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] ||
       strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 ||
       selector[kernel_h - 1][strides_[0] - 1] == nullptr) {
-    LOG(WARNING) << "NEON conv2d kernel not implementated, using slow vesion";
+    LOG(WARNING) << "NEON conv2d kernel with "
+                 << "filter" << kernel_h << "x" << kernel_w << ","
+                 << " stride " << strides_[0] << "x" << strides_[1]
+                 << " is not implemented yet, using slow version";
     Conv2dFunctor<DeviceType::CPU, float>(strides_, paddings_, dilations_)(
         input, input_shape, filter, filter_shape, bias, output, output_shape);
     return;
...
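Editor's note: the dispatch above is a plain function-pointer table indexed by `[kernel_size - 1][stride - 1]`, where a `nullptr` slot means "no NEON kernel, fall back to the generic CPU functor". A self-contained sketch of the same pattern, under simplified and hypothetical signatures:

```cpp
#include <cstdio>

typedef void (*Conv2dKernel)(const float *input, float *output);

void K1x1S1(const float *, float *) { std::printf("NEON 1x1 s1\n"); }
void K3x3S1(const float *, float *) { std::printf("NEON 3x3 s1\n"); }
void K3x3S2(const float *, float *) { std::printf("NEON 3x3 s2\n"); }
void K5x5S1(const float *, float *) { std::printf("NEON 5x5 s1\n"); }
void GenericConv2d(const float *, float *) { std::printf("slow generic\n"); }

void Dispatch(int kernel_size, int stride, const float *in, float *out) {
  // Row: kernel size 1..5; column: stride 1..2. nullptr = not vectorized.
  static const Conv2dKernel selector[5][2] = {
      {K1x1S1, nullptr},
      {nullptr, nullptr},
      {K3x3S1, K3x3S2},  // this MR fills in the 3x3 / stride-2 slot
      {nullptr, nullptr},
      {K5x5S1, nullptr}};
  Conv2dKernel fn = nullptr;
  if (kernel_size >= 1 && kernel_size <= 5 && stride >= 1 && stride <= 2) {
    fn = selector[kernel_size - 1][stride - 1];
  }
  (fn != nullptr ? fn : GenericConv2d)(in, out);  // nullptr -> slow path
}
```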
@@ -8,221 +8,288 @@
 namespace mace {
 namespace kernels {
 
+#define KERNEL_HEAD_CODE \
+  int output_batch = output_shape[0]; \
+  int output_channels = output_shape[1]; \
+  int output_height = output_shape[2]; \
+  int output_width = output_shape[3]; \
+  int input_batch = input_shape[0]; \
+  int input_channels = input_shape[1]; \
+  int input_height = input_shape[2]; \
+  int input_width = input_shape[3]; \
+  int kernel_h = 3; \
+  int kernel_w = 3; \
+  for (int b = 0; b < output_batch; ++b) { \
+    float* output_ptr_base = output + b * output_channels * output_height * output_width; \
+    for (int oc = 0; oc < output_channels; ++oc) { \
+      const float* filter_ptr = filter + oc * input_channels * kernel_h * kernel_w; \
+      const float* input_ptr = input + b * input_channels * input_height * input_width; \
+      float* output_ptr = output_ptr_base + oc * output_height * output_width; \
+      std::fill(output_ptr, output_ptr + output_height * output_width, bias[oc]); \
+      for (int ic = 0; ic < input_channels; ++ic) { \
+        float32x4_t n_filter_v[3] = {vld1q_f32(filter_ptr), vld1q_f32(filter_ptr+3), vld1q_f32(filter_ptr+6)};
+
+#define KERNEL_TAIL_CODE \
+        filter_ptr += 9; \
+        input_ptr += input_height * input_width; \
+      } \
+    } \
+  }
+
 static const int kRegisterSize = 4;
 
-void Conv2dNeonK3x3S1(const float* input,  // NCHW
-                      const index_t* input_shape,
-                      const float* filter,  // c_out, c_in, kernel_h, kernel_w
-                      const float* bias,  // c_out
-                      float* output,  // NCHW
-                      const index_t* output_shape) {
-  int batch = output_shape[0];
-  int channels = output_shape[1];
-  int height = output_shape[2];
-  int width = output_shape[3];
-
-  int input_batch = input_shape[0];
-  int input_channels = input_shape[1];
-  int input_height = input_shape[2];
-  int input_width = input_shape[3];
-
-  int kernel_h = 3;
-  int kernel_w = 3;
-
-  int height_count = (height >> 1) << 1;
-
-  for (int b = 0; b < batch; ++b) {
-    float* output_ptr_base = output + b * channels * height * width;
-    for (int oc = 0; oc < channels; ++oc) {
-      const float* filter_ptr =
-          filter + oc * input_channels * kernel_h * kernel_w;
-      const float* input_ptr =
-          input + b * input_channels * input_height * input_width;
-      float* output_ptr = output_ptr_base + oc * height * width;
-      std::fill(output_ptr, output_ptr + height * width, bias[oc]);
-      for (int ic = 0; ic < input_channels; ++ic) {
-        float32x4_t filter0 = vld1q_f32(filter_ptr);
-        float32x4_t filter3 = vld1q_f32(filter_ptr + 3);
-        float32x4_t filter6 = vld1q_f32(filter_ptr + 6);
-
-        const float* row[kRegisterSize] = {input_ptr, input_ptr + input_width,
-                                           input_ptr + 2 * input_width,
-                                           input_ptr + 3 * input_width};
-
-        float* output_ptr1 = output_ptr;
-        float* output_ptr2 = output_ptr + width;
-
-        for (int h = 0; h < height_count; h += 2) {
-          int count = width >> 2;
-          int remain_count = width & 3;
-          for (; count > 0; --count) {
-            float32x4_t sum0 = vdupq_n_f32(.0f);
-            float32x4_t sum1 = vdupq_n_f32(.0f);
-
-            float32x4_t row0_ext_0 = vld1q_f32(row[0]);  // 0123
-            float32x4_t row0_latter =
-                vld1q_f32(row[0] + kRegisterSize);  // 4567
-            float32x4_t row0_ext_1 =
-                vextq_f32(row0_ext_0, row0_latter, 1);  // 1234
-            float32x4_t row0_ext_2 =
-                vextq_f32(row0_ext_0, row0_latter, 2);  // 2345
-
-            sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter0, 0);
-            sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter0, 1);
-            sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter0, 2);
-
-            float32x4_t row1_ext_0 = vld1q_f32(row[1]);  // 0123
-            float32x4_t row1_latter =
-                vld1q_f32(row[1] + kRegisterSize);  // 4567
-            float32x4_t row1_ext_1 =
-                vextq_f32(row1_ext_0, row1_latter, 1);  // 1234
-            float32x4_t row1_ext_2 =
-                vextq_f32(row1_ext_0, row1_latter, 2);  // 2345
-
-            sum0 = vfmaq_laneq_f32(sum0, row1_ext_0, filter3, 0);
-            sum0 = vfmaq_laneq_f32(sum0, row1_ext_1, filter3, 1);
-            sum0 = vfmaq_laneq_f32(sum0, row1_ext_2, filter3, 2);
-
-            row0_ext_0 = vld1q_f32(row[2]);  // 0123
-            row0_latter = vld1q_f32(row[2] + kRegisterSize);  // 4567
-            row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1);  // 1234
-            row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2);  // 2345
-
-            sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter6, 0);
-            sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter6, 1);
-            sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter6, 2);
-
-            // second row
-            sum1 = vfmaq_laneq_f32(sum1, row1_ext_0, filter0, 0);
-            sum1 = vfmaq_laneq_f32(sum1, row1_ext_1, filter0, 1);
-            sum1 = vfmaq_laneq_f32(sum1, row1_ext_2, filter0, 2);
-
-            sum1 = vfmaq_laneq_f32(sum1, row0_ext_0, filter3, 0);
-            sum1 = vfmaq_laneq_f32(sum1, row0_ext_1, filter3, 1);
-            sum1 = vfmaq_laneq_f32(sum1, row0_ext_2, filter3, 2);
-
-            row1_ext_0 = vld1q_f32(row[3]);  // 0123
-            row1_latter = vld1q_f32(row[3] + kRegisterSize);  // 4567
-            row1_ext_1 = vextq_f32(row1_ext_0, row1_latter, 1);  // 1234
-            row1_ext_2 = vextq_f32(row1_ext_0, row1_latter, 2);  // 2345
-
-            sum1 = vfmaq_laneq_f32(sum1, row1_ext_0, filter6, 0);
-            sum1 = vfmaq_laneq_f32(sum1, row1_ext_1, filter6, 1);
-            sum1 = vfmaq_laneq_f32(sum1, row1_ext_2, filter6, 2);
-
-            float32x4_t output_row0 = vld1q_f32(output_ptr1);
-            float32x4_t output_row1 = vld1q_f32(output_ptr2);
-            output_row0 = vaddq_f32(output_row0, sum0);
-            output_row1 = vaddq_f32(output_row1, sum1);
-            vst1q_f32(output_ptr1, output_row0);
-            vst1q_f32(output_ptr2, output_row1);
-
-            output_ptr1 += kRegisterSize;
-            output_ptr2 += kRegisterSize;
-            for (int i = 0; i < kRegisterSize; ++i) {
-              row[i] += kRegisterSize;
-            }
-          }
-          for (; remain_count > 0; --remain_count) {
-            float32x4_t row0 = vld1q_f32(row[0]);  // 0123
-            float32x4_t row1 = vld1q_f32(row[1]);  // 0123
-            float32x4_t row2 = vld1q_f32(row[2]);  // 0123
-            float32x4_t row3 = vld1q_f32(row[3]);  // 0123
-
-            float32x4_t sum = vmulq_f32(row0, filter0);
-            sum = vmlaq_f32(sum, row1, filter3);
-            sum = vmlaq_f32(sum, row2, filter6);
-            sum = vsetq_lane_f32(*output_ptr1, sum, 3);
-            *output_ptr1 = vaddvq_f32(sum);
-
-            sum = vmulq_f32(row1, filter0);
-            sum = vmlaq_f32(sum, row2, filter3);
-            sum = vmlaq_f32(sum, row3, filter6);
-            sum = vsetq_lane_f32(*output_ptr2, sum, 3);
-            *output_ptr2 = vaddvq_f32(sum);
-
-            ++output_ptr1;
-            ++output_ptr2;
-            for (int i = 0; i < kRegisterSize; ++i) {
-              row[i] += 1;
-            }
-          }
-          output_ptr1 += width;
-          output_ptr2 += width;
-          for (int i = 0; i < kRegisterSize; ++i) {
-            row[i] += 2 + input_width;
-          }
-        }
-        if (height != height_count) {
-          int count = width >> 2;
-          int remain_count = width & 3;
-          for (; count > 0; --count) {
-            float32x4_t sum0 = vdupq_n_f32(.0f);
-            float32x4_t row0_ext_0 = vld1q_f32(row[0]);  // 0123
-            float32x4_t row0_latter =
-                vld1q_f32(row[0] + kRegisterSize);  // 4567
-            float32x4_t row0_ext_1 =
-                vextq_f32(row0_ext_0, row0_latter, 1);  // 1234
-            float32x4_t row0_ext_2 =
-                vextq_f32(row0_ext_0, row0_latter, 2);  // 2345
-
-            sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter0, 0);
-            sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter0, 1);
-            sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter0, 2);
-
-            float32x4_t row1_ext_0 = vld1q_f32(row[1]);  // 0123
-            float32x4_t row1_latter =
-                vld1q_f32(row[1] + kRegisterSize);  // 4567
-            float32x4_t row1_ext_1 =
-                vextq_f32(row1_ext_0, row1_latter, 1);  // 1234
-            float32x4_t row1_ext_2 =
-                vextq_f32(row1_ext_0, row1_latter, 2);  // 2345
-
-            sum0 = vfmaq_laneq_f32(sum0, row1_ext_0, filter3, 0);
-            sum0 = vfmaq_laneq_f32(sum0, row1_ext_1, filter3, 1);
-            sum0 = vfmaq_laneq_f32(sum0, row1_ext_2, filter3, 2);
-
-            row0_ext_0 = vld1q_f32(row[2]);  // 0123
-            row0_latter = vld1q_f32(row[2] + kRegisterSize);  // 4567
-            row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1);  // 1234
-            row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2);  // 2345
-
-            sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter6, 0);
-            sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter6, 1);
-            sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter6, 2);
-
-            float32x4_t output_row0 = vld1q_f32(output_ptr1);
-            output_row0 = vaddq_f32(output_row0, sum0);
-            vst1q_f32(output_ptr1, output_row0);
-            output_ptr1 += kRegisterSize;
-            for (int i = 0; i < 3; ++i) {
-              row[i] += kRegisterSize;
-            }
-          }
-          for (; remain_count > 0; --remain_count) {
-            float32x4_t row0 = vld1q_f32(row[0]);  // 0123
-            float32x4_t row1 = vld1q_f32(row[1]);  // 0123
-            float32x4_t row2 = vld1q_f32(row[2]);  // 0123
-
-            float32x4_t sum = vmulq_f32(row0, filter0);
-            sum = vmlaq_f32(sum, row1, filter3);
-            sum = vmlaq_f32(sum, row2, filter6);
-            sum = vsetq_lane_f32(*output_ptr1, sum, 3);
-            *output_ptr1 = vaddvq_f32(sum);
-
-            ++output_ptr1;
-            for (int i = 0; i < 3; ++i) {
-              row[i] += 1;
-            }
-          }
-        }
-        filter_ptr += 9;
-        input_ptr += input_height * input_width;
-      }
-    }
-  }
-}
+void Conv2dNeonK3x3S1(const float *input,  // NCHW
+                      const index_t *input_shape,
+                      const float *filter,  // c_out, c_in, kernel_h, kernel_w
+                      const float *bias,  // c_out
+                      float *output,  // NCHW
+                      const index_t *output_shape) {
+  int height_count = (output_shape[2] >> 1) << 1;
+
+  KERNEL_HEAD_CODE
+
+  const float *row_ptr_v[kRegisterSize] = {
+      input_ptr, input_ptr + input_width,
+      input_ptr + 2 * input_width, input_ptr + 3 * input_width
+  };
+  float *output_ptr_v[] = {output_ptr, output_ptr + output_width};
+
+  for (int h = 0; h < height_count; h += 2) {
+    int count = output_width >> 2;
+    int remain_count = output_width & 3;
+    for (; count > 0; --count) {
+      float32x4_t n_sum0 = vdupq_n_f32(.0f);
+
+      float32x4_t n_row_former = vld1q_f32(row_ptr_v[0]);
+      float32x4_t n_row_latter = vld1q_f32(row_ptr_v[0] + kRegisterSize);
+      float32x4_t n_row_ext0 = vextq_f32(n_row_former, n_row_latter, 1);
+      float32x4_t n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 2);
+      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_former, n_filter_v[0], 0);
+      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_ext0, n_filter_v[0], 1);
+      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_ext1, n_filter_v[0], 2);
+
+      float32x4_t n_row1_former = vld1q_f32(row_ptr_v[1]);
+      float32x4_t n_row1_latter = vld1q_f32(row_ptr_v[1] + kRegisterSize);
+      float32x4_t n_row1_ext0 = vextq_f32(n_row1_former, n_row1_latter, 1);
+      float32x4_t n_row1_ext1 = vextq_f32(n_row1_former, n_row1_latter, 2);
+      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_former, n_filter_v[1], 0);
+      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_ext0, n_filter_v[1], 1);
+      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_ext1, n_filter_v[1], 2);
+
+      n_row_former = vld1q_f32(row_ptr_v[2]);
+      n_row_latter = vld1q_f32(row_ptr_v[2] + kRegisterSize);
+      n_row_ext0 = vextq_f32(n_row_former, n_row_latter, 1);
+      n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 2);
+      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_former, n_filter_v[2], 0);
+      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_ext0, n_filter_v[2], 1);
+      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_ext1, n_filter_v[2], 2);
+
+      // second row
+      float32x4_t n_sum1 = vdupq_n_f32(.0f);
+      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_former, n_filter_v[0], 0);
+      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_ext0, n_filter_v[0], 1);
+      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_ext1, n_filter_v[0], 2);
+
+      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row_former, n_filter_v[1], 0);
+      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row_ext0, n_filter_v[1], 1);
+      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row_ext1, n_filter_v[1], 2);
+
+      n_row1_former = vld1q_f32(row_ptr_v[3]);
+      n_row1_latter = vld1q_f32(row_ptr_v[3] + kRegisterSize);
+      n_row1_ext0 = vextq_f32(n_row1_former, n_row1_latter, 1);
+      n_row1_ext1 = vextq_f32(n_row1_former, n_row1_latter, 2);
+      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_former, n_filter_v[2], 0);
+      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_ext0, n_filter_v[2], 1);
+      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_ext1, n_filter_v[2], 2);
+
+      float32x4_t n_output_row = vld1q_f32(output_ptr_v[0]);
+      float32x4_t n_output_row1 = vld1q_f32(output_ptr_v[1]);
+      n_output_row = vaddq_f32(n_output_row, n_sum0);
+      n_output_row1 = vaddq_f32(n_output_row1, n_sum1);
+      vst1q_f32(output_ptr_v[0], n_output_row);
+      vst1q_f32(output_ptr_v[1], n_output_row1);
+      output_ptr_v[0] += kRegisterSize;
+      output_ptr_v[1] += kRegisterSize;
+      for (int i = 0; i < kRegisterSize; ++i) {
+        row_ptr_v[i] += kRegisterSize;
+      }
+    }
+    for (; remain_count > 0; --remain_count) {
+      float32x4_t n_row_v[] = {
+          vld1q_f32(row_ptr_v[0]),
+          vld1q_f32(row_ptr_v[1]),
+          vld1q_f32(row_ptr_v[2])
+      };
+      float32x4_t n_sum0 = vmulq_f32(n_row_v[0], n_filter_v[0]);
+      n_sum0 = vmlaq_f32(n_sum0, n_row_v[1], n_filter_v[1]);
+      n_sum0 = vmlaq_f32(n_sum0, n_row_v[2], n_filter_v[2]);
+      n_sum0 = vsetq_lane_f32(*output_ptr_v[0], n_sum0, 3);
+      *output_ptr_v[0] = vaddvq_f32(n_sum0);
+
+      float32x4_t n_row3 = vld1q_f32(row_ptr_v[3]);
+      float32x4_t n_sum1 = vmulq_f32(n_row_v[1], n_filter_v[0]);
+      n_sum1 = vmlaq_f32(n_sum1, n_row_v[2], n_filter_v[1]);
+      n_sum1 = vmlaq_f32(n_sum1, n_row3, n_filter_v[2]);
+      n_sum1 = vsetq_lane_f32(*output_ptr_v[1], n_sum1, 3);
+      *output_ptr_v[1] = vaddvq_f32(n_sum1);
+
+      ++output_ptr_v[0];
+      ++output_ptr_v[1];
+      for (int i = 0; i < kRegisterSize; ++i) {
+        row_ptr_v[i] += 1;
+      }
+    }
+    output_ptr_v[0] += output_width;
+    output_ptr_v[1] += output_width;
+    for (int i = 0; i < kRegisterSize; ++i) {
+      row_ptr_v[i] += 2 + input_width;
+    }
+  }
+
+  if (output_height != height_count) {
+    int count = output_width >> 2;
+    int remain_count = output_width & 3;
+    for (; count > 0; --count) {
+      float32x4_t n_sum = vdupq_n_f32(.0f);
+      float32x4_t n_row_former = vld1q_f32(row_ptr_v[0]);
+      float32x4_t n_row_latter = vld1q_f32(row_ptr_v[0] + kRegisterSize);
+      float32x4_t n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 1);
+      float32x4_t n_row_ext2 = vextq_f32(n_row_former, n_row_latter, 2);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_former, n_filter_v[0], 0);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext1, n_filter_v[0], 1);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext2, n_filter_v[0], 2);
+
+      n_row_former = vld1q_f32(row_ptr_v[1]);
+      n_row_latter = vld1q_f32(row_ptr_v[1] + kRegisterSize);
+      n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 1);
+      n_row_ext2 = vextq_f32(n_row_former, n_row_latter, 2);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_former, n_filter_v[1], 0);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext1, n_filter_v[1], 1);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext2, n_filter_v[1], 2);
+
+      n_row_former = vld1q_f32(row_ptr_v[2]);
+      n_row_latter = vld1q_f32(row_ptr_v[2] + kRegisterSize);
+      n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 1);
+      n_row_ext2 = vextq_f32(n_row_former, n_row_latter, 2);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_former, n_filter_v[2], 0);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext1, n_filter_v[2], 1);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext2, n_filter_v[2], 2);
+
+      float32x4_t n_output_row = vld1q_f32(output_ptr_v[0]);
+      n_output_row = vaddq_f32(n_output_row, n_sum);
+      vst1q_f32(output_ptr_v[0], n_output_row);
+      output_ptr_v[0] += kRegisterSize;
+      for (int i = 0; i < 3; ++i) {
+        row_ptr_v[i] += kRegisterSize;
+      }
+    }
+    for (; remain_count > 0; --remain_count) {
+      float32x4_t n_row_v[] = {
+          vld1q_f32(row_ptr_v[0]),
+          vld1q_f32(row_ptr_v[1]),
+          vld1q_f32(row_ptr_v[2]),
+      };
+      float32x4_t n_sum = vmulq_f32(n_row_v[0], n_filter_v[0]);
+      n_sum = vmlaq_f32(n_sum, n_row_v[1], n_filter_v[1]);
+      n_sum = vmlaq_f32(n_sum, n_row_v[2], n_filter_v[2]);
+      n_sum = vsetq_lane_f32(*output_ptr_v[0], n_sum, 3);
+      *output_ptr_v[0] = vaddvq_f32(n_sum);
+      ++output_ptr_v[0];
+      for (int i = 0; i < 3; ++i) {
+        row_ptr_v[i] += 1;
+      }
+    }
+  }
+
+  KERNEL_TAIL_CODE
+}
+
+void Conv2dNeonK3x3S2(const float *input,  // NCHW
+                      const index_t *input_shape,
+                      const float *filter,  // c_out, c_in, kernel_h, kernel_w
+                      const float *bias,  // c_out
+                      float *output,  // NCHW
+                      const index_t *output_shape) {
+  int tail_step = 2 * (input_shape[3] - output_shape[3]);
+
+  KERNEL_HEAD_CODE
+
+  const float *row_ptr_v[3] = {
+      input_ptr, input_ptr + input_width, input_ptr + 2 * input_width
+  };
+  float *output_ptr_inner = output_ptr;
+
+  for (int h = 0; h < output_height; ++h) {
+    int count = output_width >> 2;
+    int remain_count = output_width & 3;
+    for (; count > 0; --count) {
+      float32x4_t n_sum = vdupq_n_f32(.0f);
+      float32x4x2_t n_row_former = vld2q_f32(row_ptr_v[0]);
+      float32x4_t n_row_latter = vld1q_f32(row_ptr_v[0] + 8);
+      float32x4_t n_row_ext = vextq_f32(n_row_former.val[0], n_row_latter, 1);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_former.val[0], n_filter_v[0], 0);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_former.val[1], n_filter_v[0], 1);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext, n_filter_v[0], 2);
+
+      float32x4x2_t n_row1_former = vld2q_f32(row_ptr_v[1]);
+      float32x4_t n_row1_latter = vld1q_f32(row_ptr_v[1] + 8);
+      float32x4_t n_row1_ext = vextq_f32(n_row1_former.val[0], n_row1_latter, 1);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row1_former.val[0], n_filter_v[1], 0);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row1_former.val[1], n_filter_v[1], 1);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row1_ext, n_filter_v[1], 2);
+
+      float32x4x2_t n_row2_former = vld2q_f32(row_ptr_v[2]);
+      float32x4_t n_row2_latter = vld1q_f32(row_ptr_v[2] + 8);
+      float32x4_t n_row2_ext = vextq_f32(n_row2_former.val[0], n_row2_latter, 1);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row2_former.val[0], n_filter_v[2], 0);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row2_former.val[1], n_filter_v[2], 1);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row2_ext, n_filter_v[2], 2);
+
+      float32x4_t n_output_row = vld1q_f32(output_ptr_inner);
+      n_output_row = vaddq_f32(n_output_row, n_sum);
+      vst1q_f32(output_ptr_inner, n_output_row);
+      output_ptr_inner += kRegisterSize;
+      for (int i = 0; i < 3; ++i) {
+        row_ptr_v[i] += 2 * kRegisterSize;
+      }
+    }
+    for (; remain_count > 0; --remain_count) {
+      float32x4_t n_row_v[] = {
+          vld1q_f32(row_ptr_v[0]),
+          vld1q_f32(row_ptr_v[1]),
+          vld1q_f32(row_ptr_v[2])
+      };
+      float32x4_t n_sum = vmulq_f32(n_row_v[0], n_filter_v[0]);
+      n_sum = vmlaq_f32(n_sum, n_row_v[1], n_filter_v[1]);
+      n_sum = vmlaq_f32(n_sum, n_row_v[2], n_filter_v[2]);
+      n_sum = vsetq_lane_f32(*output_ptr_inner, n_sum, 3);
+      *output_ptr_inner = vaddvq_f32(n_sum);
+      ++output_ptr_inner;
+      for (int i = 0; i < 3; ++i) {
+        row_ptr_v[i] += 2;
+      }
+    }
+    for (int i = 0; i < 3; ++i) {
+      row_ptr_v[i] += tail_step;
+    }
+  }
+
+  KERNEL_TAIL_CODE
+}
+
+#undef KERNEL_HEAD_CODE
+#undef KERNEL_TAIL_CODE
 
 }  // namespace kernels
 }  // namespace mace
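Editor's note: the tests further down compare the NEON kernels against the generic CPU path. A naive scalar VALID convolution for one channel and one 3x3 filter, usable as a correctness oracle, might look like the sketch below (a hypothetical helper, not the project's CPU functor; `index_t` here stands in for the project's index type):

```cpp
// index_t stands in for the project's 64-bit index type (assumption).
typedef long long index_t;

// out[h][w] = bias + sum over the 3x3 window at stride `stride`, no padding,
// so out_h = (in_h - 3) / stride + 1 and likewise for the width.
void RefConv3x3(const float *in, index_t in_w,
                const float *k9, float bias, int stride,
                float *out, index_t out_h, index_t out_w) {
  for (index_t h = 0; h < out_h; ++h) {
    for (index_t w = 0; w < out_w; ++w) {
      float sum = bias;
      for (int kh = 0; kh < 3; ++kh) {
        for (int kw = 0; kw < 3; ++kw) {
          sum += in[(h * stride + kh) * in_w + (w * stride + kw)] *
                 k9[kh * 3 + kw];
        }
      }
      out[h * out_w + w] = sum;
    }
  }
}
```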
@@ -61,8 +61,7 @@ static void Conv2d(int iters,
     const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
     mace::testing::ItemsProcessed(tot); \
     mace::testing::BytesProcessed(tot*(sizeof(TYPE))); \
-    Conv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE, mace::Padding::P, \
-                         OC); \
+    Conv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE, mace::Padding::P, OC); \
   } \
   BENCHMARK( \
       BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_OC##_##TYPE##_##DEVICE)
@@ -77,6 +76,10 @@ BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float);
 BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128, float);
 BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128, float);
 BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 128, float);
+BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 128, float);
+BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 128, float);
+BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 128, float);
+BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 128, float);
 BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, VALID, 128, float);
 BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, VALID, 128, float);
 BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, SAME, 128, float);
...
@@ -174,8 +174,8 @@ TEST_F(Conv2dOpTest, ConvNxNS12) {
     // generate random input
     index_t batch = 1 + rand() % 10;
     index_t input_channels = 1 + rand() % 50;
-    index_t height = 7 + rand() % 100;
-    index_t width = 7 + rand() % 100;
+    index_t height = 11 + rand() % 100;
+    index_t width = 11 + rand() % 100;
     index_t output_channels = 1 + rand() % 50;
     // Construct graph
     auto& net = test_net();
...
@@ -155,9 +155,9 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
   net.RunOp(DeviceType::NEON);
 
   // Check
-  Tensor expected = CreateTensor<float>({1, 1, 2, 3}, {6, 8, 9, 16, 18, 19});
+  auto expected = CreateTensor<float>({1, 1, 2, 3}, {6, 8, 9, 16, 18, 19});
 
-  ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.001);
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
 }
 
 TEST_F(PoolingOpTest, MAX_k3x3s2x2) {
@@ -183,7 +183,7 @@ TEST_F(PoolingOpTest, MAX_k3x3s2x2) {
   net.RunOp(DeviceType::NEON);
 
   // Check
-  Tensor expected = CreateTensor<float>({1, 1, 2, 3}, {11, 13, 14, 16, 18, 19});
+  auto expected = CreateTensor<float>({1, 1, 2, 3}, {11, 13, 14, 16, 18, 19});
-  ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.001);
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
 }