Commit 55541ad1 authored by H hjchen2

Refine int8 3x3 depthwise conv

Parent 9513ad79
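This change reworks the int8 3x3 depthwise convolution: the per-column DepthwiseConv3x3ValidCol border helpers and the separate inline-asm filter broadcast are removed; the 3x3 int8 filter is widened to int16 once with intrinsics, passed into the main asm loops as the read-only operands [ker0]/[ker1] and into DepthwiseConv3x3NormalRow as an int16x4_t array, and the left/right padded output columns are handled by new intrinsic "pad left"/"pad right" blocks. A minimal sketch of the filter widening used throughout the new code (same intrinsics as in the diff; packaged as a helper here only for illustration):

#include <arm_neon.h>
// Widen the 3x3 int8 filter (row-major, 9 values) to int16 once, so that
// multiply-accumulates can pick kernel taps by lane. Note: like the code in
// the diff, each vld1_s8 reads 8 bytes, i.e. past the 3 taps of a row; only
// the low 3 lanes are used.
static inline void WidenFilterS8(const int8_t *filter, int16x4_t ker[3],
                                 int16x8_t *ker0, int16x8_t *ker1) {
  int16x4_t k0 = vget_low_s16(vmovl_s8(vld1_s8(filter)));      // k00 k01 k02 .
  int16x4_t k1 = vget_low_s16(vmovl_s8(vld1_s8(filter + 3)));  // k10 k11 k12 .
  int16x4_t k2 = vget_low_s16(vmovl_s8(vld1_s8(filter + 6)));  // k20 k21 k22 .
  ker[0] = k0;
  ker[1] = k1;
  ker[2] = k2;
  *ker0 = vcombine_s16(k0, k1);  // rows 0 and 1 -> asm operand [ker0]
  *ker1 = vcombine_s16(k2, k2);  // row 2 (duplicated) -> asm operand [ker1]
}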
@@ -14,185 +14,13 @@ limitations under the License. */
#if defined(__ARM_NEON__) && !defined(__aarch64__) #if defined(__ARM_NEON__) && !defined(__aarch64__)
#include "operators/math/depthwise_conv3x3.h"
#ifdef __ARM_NEON__
#include <arm_neon.h> #include <arm_neon.h>
#endif #include "operators/math/depthwise_conv3x3.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
namespace math { namespace math {
template <int Stride>
inline void Depth3x3ValidColLoadInput(const int8_t *input, const int input_w,
const int valid_cols, int16x8_t *y0,
int16x8_t *y1, int16x8_t *y2) {
PADDLE_MOBILE_THROW_EXCEPTION("Stride %d is not supported.", Stride);
}
template <>
inline void Depth3x3ValidColLoadInput<1>(const int8_t *input, const int input_w,
const int valid_cols, int16x8_t *y0,
int16x8_t *y1, int16x8_t *y2) {
int8_t fake_input[3][8];
if (valid_cols == 1) {
for (int i = 0; i < 8; ++i, input += input_w) {
fake_input[0][i] = input[0];
}
} else if (valid_cols == 2) {
for (int i = 0; i < 8; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
}
} else {
for (int i = 0; i < 8; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
fake_input[2][i] = input[2];
}
}
int8x8_t input0 = vld1_s8(fake_input[0]);
int8x8_t input1 = vld1_s8(fake_input[1]);
int8x8_t input2 = vld1_s8(fake_input[2]);
y0[0] = vmovl_s8(input0);
y1[0] = vmovl_s8(input1);
y2[0] = vmovl_s8(input2);
y0[1] = vextq_s16(y0[0], y0[0], 1);
y0[2] = vextq_s16(y0[0], y0[0], 2);
y1[1] = vextq_s16(y1[0], y1[0], 1);
y1[2] = vextq_s16(y1[0], y1[0], 2);
y2[1] = vextq_s16(y2[0], y2[0], 1);
y2[2] = vextq_s16(y2[0], y2[0], 2);
}
template <>
inline void Depth3x3ValidColLoadInput<2>(const int8_t *input, const int input_w,
const int valid_cols, int16x8_t *y0,
int16x8_t *y1, int16x8_t *y2) {
int8_t fake_input[3][13];
if (valid_cols == 1) {
for (int i = 0; i < 13; ++i, input += input_w) {
fake_input[0][i] = input[0];
}
} else if (valid_cols == 2) {
for (int i = 0; i < 13; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
}
} else {
for (int i = 0; i < 13; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
fake_input[2][i] = input[2];
}
}
int8x8x2_t input0 = vld2_s8(fake_input[0]);
int8x8x2_t input1 = vld2_s8(fake_input[1]);
int8x8x2_t input2 = vld2_s8(fake_input[2]);
y0[0] = vmovl_s8(input0.val[0]);
y0[1] = vmovl_s8(input0.val[1]);
y0[2] = vextq_s16(y0[0], y0[0], 1);
y1[0] = vmovl_s8(input1.val[0]);
y1[1] = vmovl_s8(input1.val[1]);
y1[2] = vextq_s16(y1[0], y1[0], 1);
y2[0] = vmovl_s8(input2.val[0]);
y2[1] = vmovl_s8(input2.val[1]);
y2[2] = vextq_s16(y2[0], y2[0], 1);
}
template <int Stride_h, int Stride_w>
inline void DepthwiseConv3x3ValidCol(const int8_t *input, const int8_t *filter,
const int h_output, const int h_output_end,
const int w_output, const int input_h,
const int input_w, const int padding_h,
const int padding_w, const int output_w,
int32_t *output) {
const int w_in_start = -padding_w + w_output * Stride_w;
const int w_in_end = w_in_start + 3;
const int w_start = w_in_start > 0 ? w_in_start : 0;
const int w_end = w_in_end < input_w ? w_in_end : input_w;
int remain_start = h_output;
#ifdef __ARM_NEON__
int output_tiles = (h_output_end - h_output) / 6;
remain_start = h_output + output_tiles * 6;
int input_h_start = h_output * Stride_h - padding_h;
size_t input_offset = input_h_start * input_w + w_start;
size_t output_offset = h_output * output_w + w_output;
int16x8_t _input[3][3];
int16x4_t _kernel[3];
int32x4_t _sum0, _sum1;
const int8_t *filter_ptr = filter;
asm volatile(
"mov r0, #3 \n"
"vld1.s8 d10, [%[filter]], r0 \n"
"vld1.s8 d11, [%[filter]], r0 \n"
"vld1.s8 d12, [%[filter]] \n"
"vtrn.8 d10, d11 \n"
"vtrn.8 d12, d13 \n"
"vtrn.16 d10, d12 \n"
"vtrn.16 d11, d13 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d11 \n"
"vmovl.s8 q9, d12 \n"
"vmov.32 %[_kernel0], d14 \n"
"vmov.32 %[_kernel1], d16 \n"
"vmov.32 %[_kernel2], d18 \n"
: [_kernel0] "+w"(_kernel[0]), [_kernel1] "+w"(_kernel[1]),
[_kernel2] "+w"(_kernel[2])
: [filter] "r"(filter_ptr)
: "memory", "q5", "q6", "q7", "q8", "q9", "r0");
int valid_cols = w_end - w_start;
for (int h = 0; h < output_tiles * 6; h += 6) {
int32_t *output0 = output + output_offset;
int32_t *output1 = output0 + output_w;
int32_t *output2 = output1 + output_w;
int32_t *output3 = output2 + output_w;
int32_t *output4 = output3 + output_w;
int32_t *output5 = output4 + output_w;
Depth3x3ValidColLoadInput<Stride_w>(input + input_offset, input_w,
valid_cols, _input[0], _input[1],
_input[2]);
_sum0 = veorq_s32(_sum0, _sum0);
_sum1 = veorq_s32(_sum1, _sum1);
for (int w_in = 0; w_in < valid_cols; ++w_in) {
int index = w_in + w_start - w_in_start;
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_input[w_in][0]),
_kernel[index], 0);
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_input[w_in][1]),
_kernel[index], 1);
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_input[w_in][2]),
_kernel[index], 2);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_input[w_in][0]),
_kernel[index], 0);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_input[w_in][1]),
_kernel[index], 1);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_input[w_in][2]),
_kernel[index], 2);
}
vst1q_lane_s32(output0, _sum0, 0);
vst1q_lane_s32(output1, _sum0, 1);
vst1q_lane_s32(output2, _sum0, 2);
vst1q_lane_s32(output3, _sum0, 3);
vst1q_lane_s32(output4, _sum1, 0);
vst1q_lane_s32(output5, _sum1, 1);
input_offset += 6 * Stride_h * input_w;
output_offset += 6 * output_w;
}
#endif
for (int h = remain_start; h < h_output_end; ++h) {
int32_t value = 0;
const int h_in_start = -padding_h + h * Stride_h;
for (int i = 0; i < 3; ++i) {
for (int w_in = w_start; w_in < w_end; ++w_in) {
value += filter[i * 3 + (w_in - w_in_start)] *
input[(h_in_start + i) * input_w + w_in];
}
}
output[h * output_w + w_output] = value;
}
}
#define DEPTHWISE_CONV_NORMAL_BORDER(start, end) \ #define DEPTHWISE_CONV_NORMAL_BORDER(start, end) \
for (int w = start; w < end; ++w) { \ for (int w = start; w < end; ++w) { \
const int w_in_start = -padding_w + w * Stride_w; \ const int w_in_start = -padding_w + w * Stride_w; \
@@ -209,34 +37,19 @@ inline void DepthwiseConv3x3ValidCol(const int8_t *input, const int8_t *filter,
output_ptr[w] = value; \ output_ptr[w] = value; \
} }
template <int Stride> template <int Stride = 1>
inline void Depth3x3NormalRowLoadInput(const int8_t *input, inline void Depth3x3NormalRowLoadInput(const int8_t *input, int16x8_t *y) {
int16x8_t &y0, // NOLINT y[0] = vmovl_s8(vld1_s8(input));
int16x8_t &y1, // NOLINT y[1] = vextq_s16(y[0], y[0], 1);
int16x8_t &y2) { // NOLINT y[2] = vextq_s16(y[1], y[1], 1);
PADDLE_MOBILE_THROW_EXCEPTION("Stride %d is not supported.", Stride);
}
template <>
inline void Depth3x3NormalRowLoadInput<1>(const int8_t *input,
int16x8_t &y0, // NOLINT
int16x8_t &y1, // NOLINT
int16x8_t &y2) { // NOLINT
int8x8_t x0 = vld1_s8(input);
y0 = vmovl_s8(x0);
y1 = vextq_s16(y0, y0, 1);
y2 = vextq_s16(y1, y1, 1);
} }
template <> template <>
inline void Depth3x3NormalRowLoadInput<2>(const int8_t *input, inline void Depth3x3NormalRowLoadInput<2>(const int8_t *input, int16x8_t *y) {
int16x8_t &y0, // NOLINT
int16x8_t &y1, // NOLINT
int16x8_t &y2) { // NOLINT
int8x8x2_t x0 = vld2_s8(input); int8x8x2_t x0 = vld2_s8(input);
y0 = vmovl_s8(x0.val[0]); y[0] = vmovl_s8(x0.val[0]);
y1 = vmovl_s8(x0.val[1]); y[1] = vmovl_s8(x0.val[1]);
y2 = vextq_s16(y0, y0, 1); y[2] = vextq_s16(y[0], y[0], 1);
} }
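The row loaders above are also simplified: the primary template now defaults to Stride = 1 and builds the three overlapping windows x, x+1, x+2 from one vld1_s8 load plus two vextq_s16 shifts, while the stride-2 specialization keeps vld2_s8, which loads 16 bytes and de-interleaves them so that val[0] holds the even input columns and val[1] the odd ones; the third window is again a one-lane vextq_s16. A small self-contained illustration of the de-interleave (values made up for the example):

#include <arm_neon.h>
// vld2_s8 on the bytes 0..15 yields val[0] = {0,2,4,...,14} and
// val[1] = {1,3,5,...,15}, exactly the x / x+1 column split needed
// when sampling with stride 2.
static inline int8x8x2_t Vld2Example() {
  int8_t buf[16];
  for (int i = 0; i < 16; ++i) buf[i] = static_cast<int8_t>(i);
  return vld2_s8(buf);
}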
template <int Stride_h, int Stride_w> template <int Stride_h, int Stride_w>
@@ -244,15 +57,14 @@ inline void DepthwiseConv3x3NormalRow(const int8_t *input, const int8_t *filter,
const int h_output, const int input_h, const int h_output, const int input_h,
const int input_w, const int padding_h, const int input_w, const int padding_h,
const int padding_w, const int output_w, const int padding_w, const int output_w,
int32_t *output) { int32_t *output, int16x4_t *ker) {
const int h_in_start = -padding_h + h_output * Stride_h; const int h_in_start = -padding_h + h_output * Stride_h;
const int h_in_end = h_in_start + 3; const int h_in_end = h_in_start + 3;
const int h_start = h_in_start > 0 ? h_in_start : 0; const int h_start = h_in_start > 0 ? h_in_start : 0;
const int h_end = h_in_end < input_h ? h_in_end : input_h; const int h_end = h_in_end < input_h ? h_in_end : input_h;
int valid_w_start = (padding_w + Stride_w - 1) / Stride_w; const int valid_w_start = (padding_w + Stride_w - 1) / Stride_w;
int valid_w_end = output_w - valid_w_start; const int valid_w_end = (input_w + padding_w - 3) / Stride_w + 1;
int32_t *output_ptr = output + h_output * output_w; int32_t *output_ptr = output + h_output * output_w;
// border left // border left
DEPTHWISE_CONV_NORMAL_BORDER(0, valid_w_start) DEPTHWISE_CONV_NORMAL_BORDER(0, valid_w_start)
@@ -262,14 +74,7 @@ inline void DepthwiseConv3x3NormalRow(const int8_t *input, const int8_t *filter,
int output_tiles = (valid_w_end - valid_w_start) / 6; int output_tiles = (valid_w_end - valid_w_start) / 6;
remain_start = valid_w_start + output_tiles * 6; remain_start = valid_w_start + output_tiles * 6;
int32x4_t _sum0, _sum1; int32x4_t _sum0, _sum1;
int16x8_t y0, y1, y2; int16x8_t _y[3];
int16x4_t _kernel[3];
for (int h_in = h_start; h_in < h_end; ++h_in) {
int index = h_in - h_in_start;
int8x8_t w0 = vld1_s8(filter + index * 3);
int16x8_t w1 = vmovl_s8(w0);
_kernel[index] = vget_low_s16(w1);
}
for (int w = 0; w < output_tiles * 6; w += 6) { for (int w = 0; w < output_tiles * 6; w += 6) {
_sum0 = veorq_s32(_sum0, _sum0); _sum0 = veorq_s32(_sum0, _sum0);
_sum1 = veorq_s32(_sum1, _sum1); _sum1 = veorq_s32(_sum1, _sum1);
@@ -278,19 +83,18 @@ inline void DepthwiseConv3x3NormalRow(const int8_t *input, const int8_t *filter,
for (int h_in = h_start; h_in < h_end; ++h_in) { for (int h_in = h_start; h_in < h_end; ++h_in) {
int index = h_in - h_in_start; int index = h_in - h_in_start;
Depth3x3NormalRowLoadInput<Stride_w>( Depth3x3NormalRowLoadInput<Stride_w>(
input + h_in * input_w + input_w_offset, y0, y1, y2); input + h_in * input_w + input_w_offset, _y);
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(y0), _kernel[index], 0); _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_y[0]), ker[index], 0);
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(y1), _kernel[index], 1); _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_y[1]), ker[index], 1);
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(y2), _kernel[index], 2); _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_y[2]), ker[index], 2);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(y0), _kernel[index], 0); _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_y[0]), ker[index], 0);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(y1), _kernel[index], 1); _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_y[1]), ker[index], 1);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(y2), _kernel[index], 2); _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_y[2]), ker[index], 2);
} }
vst1q_s32(output_ptr + output_offset, _sum0); vst1q_s32(output_ptr + output_offset, _sum0);
vst1q_lane_s32(output_ptr + output_offset + 4, _sum1, 0); vst1_s32(output_ptr + output_offset + 4, vget_low_s32(_sum1));
vst1q_lane_s32(output_ptr + output_offset + 5, _sum1, 1);
} }
#endif #endif // __ARM_NEON__
for (int w = remain_start; w < valid_w_end; ++w) { for (int w = remain_start; w < valid_w_end; ++w) {
int32_t value = 0; int32_t value = 0;
int input_start = -padding_w + w * Stride_w; int input_start = -padding_w + w * Stride_w;
@@ -306,14 +110,6 @@ inline void DepthwiseConv3x3NormalRow(const int8_t *input, const int8_t *filter,
DEPTHWISE_CONV_NORMAL_BORDER(valid_w_end, output_w) DEPTHWISE_CONV_NORMAL_BORDER(valid_w_end, output_w)
} }
// template<>
// void DepthwiseConv3x3<int8_t, int32_t>(
// const framework::Tensor *input, const framework::Tensor *filter,
// const std::vector<int> &strides, framework::Tensor *output) {
// PADDLE_MOBILE_THROW_EXCEPTION(
// "Depthwise conv with generic strides has not been implemented.");
// }
template <> template <>
void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input, void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
const framework::Tensor &filter, const framework::Tensor &filter,
@@ -342,29 +138,22 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
const int8_t *input_ptr = input_data + g * image_size; const int8_t *input_ptr = input_data + g * image_size;
const int8_t *filter_ptr = filter_data + g * 9; const int8_t *filter_ptr = filter_data + g * 9;
int32_t *output_ptr = out_data + g * out_image_size; int32_t *output_ptr = out_data + g * out_image_size;
const int8_t *filter_ptr0 = filter_ptr;
const int8_t *filter_ptr1 = filter_ptr0 + 3;
const int8_t *filter_ptr2 = filter_ptr1 + 3;
int16x4_t _k0 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr0)));
int16x4_t _k1 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr1)));
int16x4_t _k2 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr2)));
int16x8_t _ker0 = vcombine_s16(_k0, _k1);
int16x8_t _ker1 = vcombine_s16(_k2, _k2);
int16x4_t zero = vdup_n_s16(0);
int16x4_t _ker[3] = {_k0, _k1, _k2};
// top // top
for (int h = 0; h < valid_h_start; ++h) { for (int h = 0; h < valid_h_start; ++h) {
DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h, DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w, input_w, padding_h, padding_w, output_w,
output_ptr); output_ptr, _ker);
}
// left
for (int w = 0; w < valid_w_start; ++w) {
DepthwiseConv3x3ValidCol<1, 1>(
input_ptr, filter_ptr, valid_h_start, valid_h_end, w, input_h,
input_w, padding_h, padding_w, output_w, output_ptr);
}
// right
for (int w = valid_w_end; w < output_w; ++w) {
DepthwiseConv3x3ValidCol<1, 1>(
input_ptr, filter_ptr, valid_h_start, valid_h_end, w, input_h,
input_w, padding_h, padding_w, output_w, output_ptr);
}
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr);
} }
// valid // valid
int output_w_tiles = valid_w / 6; int output_w_tiles = valid_w / 6;
@@ -376,32 +165,63 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
const int8_t *input_ptr3 = input_ptr2 + input_w; const int8_t *input_ptr3 = input_ptr2 + input_w;
const int8_t *input_ptr4 = input_ptr3 + input_w; const int8_t *input_ptr4 = input_ptr3 + input_w;
const int8_t *input_ptr5 = input_ptr4 + input_w; const int8_t *input_ptr5 = input_ptr4 + input_w;
int32_t *output_ptr0 = output_ptr + h * output_w + valid_w_start; int32_t *output_ptr0 = output_ptr + h * output_w;
int32_t *output_ptr1 = output_ptr0 + output_w; int32_t *output_ptr1 = output_ptr0 + output_w;
int32_t *output_ptr2 = output_ptr1 + output_w; int32_t *output_ptr2 = output_ptr1 + output_w;
int32_t *output_ptr3 = output_ptr2 + output_w; int32_t *output_ptr3 = output_ptr2 + output_w;
// pad left
if (padding_w) {
int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0)));
int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1)));
int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2)));
int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3)));
int16x4_t row4 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr4)));
int16x4_t row5 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr5)));
int32x4_t acc;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 3) {
output_ptr0[w] = 0;
output_ptr1[w] = 0;
output_ptr2[w] = 0;
output_ptr3[w] = 0;
} else {
row0 = vext_s16(zero, row0, 3);
row1 = vext_s16(zero, row1, 3);
row2 = vext_s16(zero, row2, 3);
row3 = vext_s16(zero, row3, 3);
row4 = vext_s16(zero, row4, 3);
row5 = vext_s16(zero, row5, 3);
acc = vmull_s16(row0, _ker[0]);
acc = vmlal_s16(acc, row1, _ker[1]);
acc = vmlal_s16(acc, row2, _ker[2]);
output_ptr0[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2);
acc = vmull_s16(row1, _ker[0]);
acc = vmlal_s16(acc, row2, _ker[1]);
acc = vmlal_s16(acc, row3, _ker[2]);
output_ptr1[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2);
acc = vmull_s16(row2, _ker[0]);
acc = vmlal_s16(acc, row3, _ker[1]);
acc = vmlal_s16(acc, row4, _ker[2]);
output_ptr2[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2);
acc = vmull_s16(row3, _ker[0]);
acc = vmlal_s16(acc, row4, _ker[1]);
acc = vmlal_s16(acc, row5, _ker[2]);
output_ptr3[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2);
}
}
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
output_ptr2 += valid_w_start;
output_ptr3 += valid_w_start;
}
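This new "pad left" block replaces the old per-column DepthwiseConv3x3ValidCol path for the left border. For each padded output column it shifts a zero into the front of the widened input row (vext_s16(zero, rowN, 3)), multiplies all four lanes by the matching kernel row, and keeps only accumulator lanes 1 and 2, which correspond to the kernel taps that still overlap real input. A scalar sketch of what one such column computes when padding_w == 1 (the window covers columns -1, 0, 1, so only filter columns 1 and 2 contribute; rows and k are hypothetical names for the three input row pointers and the row-major filter):

// Scalar equivalent of one left-padded output value for padding_w == 1.
static inline int32_t PadLeftColumn(const int8_t *rows[3], const int8_t *k) {
  int32_t out = 0;
  for (int kh = 0; kh < 3; ++kh) {
    out += k[kh * 3 + 1] * rows[kh][0];  // filter column 1 hits input column 0
    out += k[kh * 3 + 2] * rows[kh][1];  // filter column 2 hits input column 1
  }
  // The NEON code above reads the same two terms off acc lanes 1 and 2:
  // vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2).
  return out;
}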
// valid
int loop = output_w_tiles; int loop = output_w_tiles;
asm volatile( asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
"vmovl.s8 q15, d1 \n"
"vdup.s16 d0, d28[0] \n"
"vdup.s16 d1, d28[1] \n"
"vdup.s16 d2, d28[2] \n"
"vdup.s16 d3, d28[3] \n"
"vdup.s16 d4, d29[0] \n"
"vdup.s16 d5, d29[1] \n"
"vdup.s16 d6, d29[2] \n"
"vdup.s16 d7, d29[3] \n"
"vdup.s16 d8, d30[0] \n"
:
: [filter_ptr] "r"(filter_ptr)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile(
"mov r0, #6 \n"
"cmp %[loop], #0 \n" "cmp %[loop], #0 \n"
"ble start_remain_%= \n" "ble start_remain_%= \n"
// loop 6 widths "mov r0, #6 \n"
// loop 6 width
"loop_4h6w_%=: \n" "loop_4h6w_%=: \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n" "vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vld1.32 {d10}, [%[input_ptr1]], r0 \n" "vld1.32 {d10}, [%[input_ptr1]], r0 \n"
@@ -411,59 +231,59 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n" "vmull.s16 q10, d14, %e[ker0][0] \n"
"vmlal.s16 q10, d16, d1 \n" "vmlal.s16 q10, d16, %e[ker0][1] \n"
"vmlal.s16 q10, d18, d2 \n" "vmlal.s16 q10, d18, %e[ker0][2] \n"
"vmull.s16 q11, d15, d0 \n" "vmull.s16 q11, d15, %e[ker0][0] \n"
"vmlal.s16 q11, d17, d1 \n" "vmlal.s16 q11, d17, %e[ker0][1] \n"
"vmlal.s16 q11, d19, d2 \n" "vmlal.s16 q11, d19, %e[ker0][2] \n"
"vext.s8 d12, d10, d10, #1 \n" "vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n" "vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n" "vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n" "vmlal.s16 q10, d14, %f[ker0][0] \n"
"vmlal.s16 q10, d16, d4 \n" "vmlal.s16 q10, d16, %f[ker0][1] \n"
"vmlal.s16 q10, d18, d5 \n" "vmlal.s16 q10, d18, %f[ker0][2] \n"
"vmlal.s16 q11, d15, d3 \n" "vmlal.s16 q11, d15, %f[ker0][0] \n"
"vmlal.s16 q11, d17, d4 \n" "vmlal.s16 q11, d17, %f[ker0][1] \n"
"vmlal.s16 q11, d19, d5 \n" "vmlal.s16 q11, d19, %f[ker0][2] \n"
"vmull.s16 q12, d14, d0 \n" "vmull.s16 q12, d14, %e[ker0][0] \n"
"vmlal.s16 q12, d16, d1 \n" "vmlal.s16 q12, d16, %e[ker0][1] \n"
"vmlal.s16 q12, d18, d2 \n" "vmlal.s16 q12, d18, %e[ker0][2] \n"
"vmull.s16 q13, d15, d0 \n" "vmull.s16 q13, d15, %e[ker0][0] \n"
"vmlal.s16 q13, d17, d1 \n" "vmlal.s16 q13, d17, %e[ker0][1] \n"
"vmlal.s16 q13, d19, d2 \n" "vmlal.s16 q13, d19, %e[ker0][2] \n"
"vext.s8 d12, d11, d11, #1 \n" "vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n" "vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n" "vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n" "vmlal.s16 q10, d14, %e[ker1][0] \n"
"vmlal.s16 q10, d16, d7 \n" "vmlal.s16 q10, d16, %e[ker1][1] \n"
"vmlal.s16 q10, d18, d8 \n" "vmlal.s16 q10, d18, %e[ker1][2] \n"
"vmlal.s16 q11, d15, d6 \n" "vmlal.s16 q11, d15, %e[ker1][0] \n"
"vmlal.s16 q11, d17, d7 \n" "vmlal.s16 q11, d17, %e[ker1][1] \n"
"vmlal.s16 q11, d19, d8 \n" "vmlal.s16 q11, d19, %e[ker1][2] \n"
// store row 0, reuse q10/q11 // store row 0, reuse q10/q11
"vst1.32 {d20-d22}, [%[output_ptr0]]! \n" "vst1.32 {d20-d22}, [%[output_ptr0]]! \n"
"vmlal.s16 q12, d14, d3 \n" "vmlal.s16 q12, d14, %f[ker0][0] \n"
"vmlal.s16 q12, d16, d4 \n" "vmlal.s16 q12, d16, %f[ker0][1] \n"
"vmlal.s16 q12, d18, d5 \n" "vmlal.s16 q12, d18, %f[ker0][2] \n"
"vmlal.s16 q13, d15, d3 \n" "vmlal.s16 q13, d15, %f[ker0][0] \n"
"vmlal.s16 q13, d17, d4 \n" "vmlal.s16 q13, d17, %f[ker0][1] \n"
"vmlal.s16 q13, d19, d5 \n" "vmlal.s16 q13, d19, %f[ker0][2] \n"
"vmull.s16 q14, d14, d0 \n" "vmull.s16 q14, d14, %e[ker0][0] \n"
"vmlal.s16 q14, d16, d1 \n" "vmlal.s16 q14, d16, %e[ker0][1] \n"
"vmlal.s16 q14, d18, d2 \n" "vmlal.s16 q14, d18, %e[ker0][2] \n"
"vmull.s16 q15, d15, d0 \n" "vmull.s16 q15, d15, %e[ker0][0] \n"
"vmlal.s16 q15, d17, d1 \n" "vmlal.s16 q15, d17, %e[ker0][1] \n"
"vmlal.s16 q15, d19, d2 \n" "vmlal.s16 q15, d19, %e[ker0][2] \n"
"vld1.32 {d9}, [%[input_ptr3]], r0 \n" "vld1.32 {d9}, [%[input_ptr3]], r0 \n"
"vld1.32 {d10}, [%[input_ptr4]], r0 \n" "vld1.32 {d10}, [%[input_ptr4]], r0 \n"
@@ -473,61 +293,61 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d14, d6 \n" "vmlal.s16 q12, d14, %e[ker1][0] \n"
"vmlal.s16 q12, d16, d7 \n" "vmlal.s16 q12, d16, %e[ker1][1] \n"
"vmlal.s16 q12, d18, d8 \n" "vmlal.s16 q12, d18, %e[ker1][2] \n"
"vmlal.s16 q13, d15, d6 \n" "vmlal.s16 q13, d15, %e[ker1][0] \n"
"vmlal.s16 q13, d17, d7 \n" "vmlal.s16 q13, d17, %e[ker1][1] \n"
"vmlal.s16 q13, d19, d8 \n" "vmlal.s16 q13, d19, %e[ker1][2] \n"
// store row 1 // store row 1
"vst1.32 {d24-d26}, [%[output_ptr1]]! \n" "vst1.32 {d24-d26}, [%[output_ptr1]]! \n"
"vmlal.s16 q14, d14, d3 \n" "vmlal.s16 q14, d14, %f[ker0][0] \n"
"vmlal.s16 q14, d16, d4 \n" "vmlal.s16 q14, d16, %f[ker0][1] \n"
"vmlal.s16 q14, d18, d5 \n" "vmlal.s16 q14, d18, %f[ker0][2] \n"
"vmlal.s16 q15, d15, d3 \n" "vmlal.s16 q15, d15, %f[ker0][0] \n"
"vmlal.s16 q15, d17, d4 \n" "vmlal.s16 q15, d17, %f[ker0][1] \n"
"vmlal.s16 q15, d19, d5 \n" "vmlal.s16 q15, d19, %f[ker0][2] \n"
"vmull.s16 q10, d14, d0 \n" "vmull.s16 q10, d14, %e[ker0][0] \n"
"vmlal.s16 q10, d16, d1 \n" "vmlal.s16 q10, d16, %e[ker0][1] \n"
"vmlal.s16 q10, d18, d2 \n" "vmlal.s16 q10, d18, %e[ker0][2] \n"
"vmull.s16 q11, d15, d0 \n" "vmull.s16 q11, d15, %e[ker0][0] \n"
"vmlal.s16 q11, d17, d1 \n" "vmlal.s16 q11, d17, %e[ker0][1] \n"
"vmlal.s16 q11, d19, d2 \n" "vmlal.s16 q11, d19, %e[ker0][2] \n"
"vext.s8 d12, d10, d10, #1 \n" "vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n" "vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n" "vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmlal.s16 q14, d14, d6 \n" "vmlal.s16 q14, d14, %e[ker1][0] \n"
"vmlal.s16 q14, d16, d7 \n" "vmlal.s16 q14, d16, %e[ker1][1] \n"
"vmlal.s16 q14, d18, d8 \n" "vmlal.s16 q14, d18, %e[ker1][2] \n"
"vmlal.s16 q15, d15, d6 \n" "vmlal.s16 q15, d15, %e[ker1][0] \n"
"vmlal.s16 q15, d17, d7 \n" "vmlal.s16 q15, d17, %e[ker1][1] \n"
"vmlal.s16 q15, d19, d8 \n" "vmlal.s16 q15, d19, %e[ker1][2] \n"
// store row 2 // store row 2
"vst1.32 {d28-d30}, [%[output_ptr2]]! \n" "vst1.32 {d28-d30}, [%[output_ptr2]]! \n"
"vmlal.s16 q10, d14, d3 \n" "vmlal.s16 q10, d14, %f[ker0][0] \n"
"vmlal.s16 q10, d16, d4 \n" "vmlal.s16 q10, d16, %f[ker0][1] \n"
"vmlal.s16 q10, d18, d5 \n" "vmlal.s16 q10, d18, %f[ker0][2] \n"
"vmlal.s16 q11, d15, d3 \n" "vmlal.s16 q11, d15, %f[ker0][0] \n"
"vmlal.s16 q11, d17, d4 \n" "vmlal.s16 q11, d17, %f[ker0][1] \n"
"vmlal.s16 q11, d19, d5 \n" "vmlal.s16 q11, d19, %f[ker0][2] \n"
"vext.s8 d12, d11, d11, #1 \n" "vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n" "vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n" "vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n" "vmlal.s16 q10, d14, %e[ker1][0] \n"
"vmlal.s16 q10, d16, d7 \n" "vmlal.s16 q10, d16, %e[ker1][1] \n"
"vmlal.s16 q10, d18, d8 \n" "vmlal.s16 q10, d18, %e[ker1][2] \n"
"vmlal.s16 q11, d15, d6 \n" "vmlal.s16 q11, d15, %e[ker1][0] \n"
"vmlal.s16 q11, d17, d7 \n" "vmlal.s16 q11, d17, %e[ker1][1] \n"
"vmlal.s16 q11, d19, d8 \n" "vmlal.s16 q11, d19, %e[ker1][2] \n"
// store row 3 // store row 3
"vst1.32 {d20-d22}, [%[output_ptr3]]! \n" "vst1.32 {d20-d22}, [%[output_ptr3]]! \n"
@@ -538,125 +358,126 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
"cmp %[remain], #0 \n" "cmp %[remain], #0 \n"
"ble end_%= \n" "ble end_%= \n"
"vld1.32 {d9}, [%[input_ptr0]] \n" "mov r0, %[remain] \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n" "vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n" "vmovl.s8 q9, d9 \n"
"vmull.s16 q10, d14, d0 \n" "vmull.s16 q10, d14, %e[ker0][0] \n"
"vmlal.s16 q10, d16, d1 \n" "vmlal.s16 q10, d16, %e[ker0][1] \n"
"vmlal.s16 q10, d18, d2 \n" "vmlal.s16 q10, d18, %e[ker0][2] \n"
"vld1.32 {d9}, [%[input_ptr1]] \n" "vld1.32 {d9}, [%[input_ptr1]], r0 \n"
"vmull.s16 q11, d15, d0 \n" "vmull.s16 q11, d15, %e[ker0][0] \n"
"vmlal.s16 q11, d17, d1 \n" "vmlal.s16 q11, d17, %e[ker0][1] \n"
"vmlal.s16 q11, d19, d2 \n" "vmlal.s16 q11, d19, %e[ker0][2] \n"
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n" "vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n" "vmovl.s8 q9, d9 \n"
"vmlal.s16 q10, d14, d3 \n" "vmlal.s16 q10, d14, %f[ker0][0] \n"
"vmlal.s16 q10, d16, d4 \n" "vmlal.s16 q10, d16, %f[ker0][1] \n"
"vmlal.s16 q10, d18, d5 \n" "vmlal.s16 q10, d18, %f[ker0][2] \n"
"vmlal.s16 q11, d15, d3 \n" "vmlal.s16 q11, d15, %f[ker0][0] \n"
"vmlal.s16 q11, d17, d4 \n" "vmlal.s16 q11, d17, %f[ker0][1] \n"
"vmlal.s16 q11, d19, d5 \n" "vmlal.s16 q11, d19, %f[ker0][2] \n"
"vmull.s16 q12, d14, d0 \n" "vmull.s16 q12, d14, %e[ker0][0] \n"
"vmlal.s16 q12, d16, d1 \n" "vmlal.s16 q12, d16, %e[ker0][1] \n"
"vmlal.s16 q12, d18, d2 \n" "vmlal.s16 q12, d18, %e[ker0][2] \n"
"vld1.32 {d9}, [%[input_ptr2]] \n" "vld1.32 {d9}, [%[input_ptr2]], r0 \n"
"vmull.s16 q13, d15, d0 \n" "vmull.s16 q13, d15, %e[ker0][0] \n"
"vmlal.s16 q13, d17, d1 \n" "vmlal.s16 q13, d17, %e[ker0][1] \n"
"vmlal.s16 q13, d19, d2 \n" "vmlal.s16 q13, d19, %e[ker0][2] \n"
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n" "vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n" "vmovl.s8 q9, d9 \n"
"vmlal.s16 q10, d14, d6 \n" "vmlal.s16 q10, d14, %e[ker1][0] \n"
"vmlal.s16 q10, d16, d7 \n" "vmlal.s16 q10, d16, %e[ker1][1] \n"
"vmlal.s16 q10, d18, d8 \n" "vmlal.s16 q10, d18, %e[ker1][2] \n"
"vmlal.s16 q11, d15, d6 \n" "vmlal.s16 q11, d15, %e[ker1][0] \n"
"vmlal.s16 q11, d17, d7 \n" "vmlal.s16 q11, d17, %e[ker1][1] \n"
"vmlal.s16 q11, d19, d8 \n" "vmlal.s16 q11, d19, %e[ker1][2] \n"
"vmlal.s16 q12, d14, d3 \n" "vmlal.s16 q12, d14, %f[ker0][0] \n"
"vmlal.s16 q12, d16, d4 \n" "vmlal.s16 q12, d16, %f[ker0][1] \n"
"vmlal.s16 q12, d18, d5 \n" "vmlal.s16 q12, d18, %f[ker0][2] \n"
"vmlal.s16 q13, d15, d3 \n" "vmlal.s16 q13, d15, %f[ker0][0] \n"
"vmlal.s16 q13, d17, d4 \n" "vmlal.s16 q13, d17, %f[ker0][1] \n"
"vmlal.s16 q13, d19, d5 \n" "vmlal.s16 q13, d19, %f[ker0][2] \n"
"vmull.s16 q14, d14, d0 \n" "vmull.s16 q14, d14, %e[ker0][0] \n"
"vmlal.s16 q14, d16, d1 \n" "vmlal.s16 q14, d16, %e[ker0][1] \n"
"vmlal.s16 q14, d18, d2 \n" "vmlal.s16 q14, d18, %e[ker0][2] \n"
"vld1.32 {d9}, [%[input_ptr3]] \n" "vld1.32 {d9}, [%[input_ptr3]], r0 \n"
"vmull.s16 q15, d15, d0 \n" "vmull.s16 q15, d15, %e[ker0][0] \n"
"vmlal.s16 q15, d17, d1 \n" "vmlal.s16 q15, d17, %e[ker0][1] \n"
"vmlal.s16 q15, d19, d2 \n" "vmlal.s16 q15, d19, %e[ker0][2] \n"
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n" "vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n" "vmovl.s8 q9, d9 \n"
"vmlal.s16 q12, d14, d6 \n" "vmlal.s16 q12, d14, %e[ker1][0] \n"
"vmlal.s16 q12, d16, d7 \n" "vmlal.s16 q12, d16, %e[ker1][1] \n"
"vmlal.s16 q12, d18, d8 \n" "vmlal.s16 q12, d18, %e[ker1][2] \n"
"vmlal.s16 q13, d15, d6 \n" "vmlal.s16 q13, d15, %e[ker1][0] \n"
"vmlal.s16 q13, d17, d7 \n" "vmlal.s16 q13, d17, %e[ker1][1] \n"
"vmlal.s16 q13, d19, d8 \n" "vmlal.s16 q13, d19, %e[ker1][2] \n"
"vmlal.s16 q14, d14, d3 \n" "vmlal.s16 q14, d14, %f[ker0][0] \n"
"vmlal.s16 q14, d16, d4 \n" "vmlal.s16 q14, d16, %f[ker0][1] \n"
"vmlal.s16 q14, d18, d5 \n" "vmlal.s16 q14, d18, %f[ker0][2] \n"
"vmlal.s16 q15, d15, d3 \n" "vmlal.s16 q15, d15, %f[ker0][0] \n"
"vmlal.s16 q15, d17, d4 \n" "vmlal.s16 q15, d17, %f[ker0][1] \n"
"vmlal.s16 q15, d19, d5 \n" "vmlal.s16 q15, d19, %f[ker0][2] \n"
"vmull.s16 q5, d14, d0 \n" "vmull.s16 q5, d14, %e[ker0][0] \n"
"vmlal.s16 q5, d16, d1 \n" "vmlal.s16 q5, d16, %e[ker0][1] \n"
"vmlal.s16 q5, d18, d2 \n" "vmlal.s16 q5, d18, %e[ker0][2] \n"
"vld1.32 {d9}, [%[input_ptr4]] \n" "vld1.32 {d9}, [%[input_ptr4]], r0 \n"
"vmull.s16 q6, d15, d0 \n" "vmull.s16 q6, d15, %e[ker0][0] \n"
"vmlal.s16 q6, d17, d1 \n" "vmlal.s16 q6, d17, %e[ker0][1] \n"
"vmlal.s16 q6, d19, d2 \n" "vmlal.s16 q6, d19, %e[ker0][2] \n"
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n" "vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n" "vmovl.s8 q9, d9 \n"
"vmlal.s16 q14, d14, d6 \n" "vmlal.s16 q14, d14, %e[ker1][0] \n"
"vmlal.s16 q14, d16, d7 \n" "vmlal.s16 q14, d16, %e[ker1][1] \n"
"vmlal.s16 q14, d18, d8 \n" "vmlal.s16 q14, d18, %e[ker1][2] \n"
"vmlal.s16 q15, d15, d6 \n" "vmlal.s16 q15, d15, %e[ker1][0] \n"
"vmlal.s16 q15, d17, d7 \n" "vmlal.s16 q15, d17, %e[ker1][1] \n"
"vmlal.s16 q15, d19, d8 \n" "vmlal.s16 q15, d19, %e[ker1][2] \n"
"vmlal.s16 q5, d14, d3 \n" "vmlal.s16 q5, d14, %f[ker0][0] \n"
"vmlal.s16 q5, d16, d4 \n" "vmlal.s16 q5, d16, %f[ker0][1] \n"
"vmlal.s16 q5, d18, d5 \n" "vmlal.s16 q5, d18, %f[ker0][2] \n"
"vld1.32 {d9}, [%[input_ptr5]] \n" "vld1.32 {d9}, [%[input_ptr5]], r0 \n"
"vmlal.s16 q6, d15, d3 \n" "vmlal.s16 q6, d15, %f[ker0][0] \n"
"vmlal.s16 q6, d17, d4 \n" "vmlal.s16 q6, d17, %f[ker0][1] \n"
"vmlal.s16 q6, d19, d5 \n" "vmlal.s16 q6, d19, %f[ker0][2] \n"
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n" "vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n" "vmovl.s8 q9, d9 \n"
"vmlal.s16 q5, d14, d6 \n" "vmlal.s16 q5, d14, %e[ker1][0] \n"
"vmlal.s16 q5, d16, d7 \n" "vmlal.s16 q5, d16, %e[ker1][1] \n"
"vmlal.s16 q5, d18, d8 \n" "vmlal.s16 q5, d18, %e[ker1][2] \n"
"vmlal.s16 q6, d15, d6 \n" "vmlal.s16 q6, d15, %e[ker1][0] \n"
"vmlal.s16 q6, d17, d7 \n" "vmlal.s16 q6, d17, %e[ker1][1] \n"
"vmlal.s16 q6, d19, d8 \n" "vmlal.s16 q6, d19, %e[ker1][2] \n"
"cmp %[remain], #4 \n" "cmp %[remain], #4 \n"
"blt store_4h2w_%= \n" "blt store_4h2w_%= \n"
@@ -701,9 +522,62 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3), [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
[input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5), [input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5),
[loop] "+r"(loop) [loop] "+r"(loop)
: [remain] "r"(output_w_remain) : [remain] "r"(output_w_remain), [ker0] "w"(_ker0), [ker1] "w"(_ker1)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); "q12", "q13", "q14", "q15", "r0");
// pad right
if (padding_w) {
int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0 - 2)));
int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1 - 2)));
int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2 - 2)));
int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3 - 2)));
int16x4_t row4 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr4 - 2)));
int16x4_t row5 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr5 - 2)));
row0 = vext_s16(row0, zero, 2);
row1 = vext_s16(row1, zero, 2);
row2 = vext_s16(row2, zero, 2);
row3 = vext_s16(row3, zero, 2);
row4 = vext_s16(row4, zero, 2);
row5 = vext_s16(row5, zero, 2);
int32x4_t acc;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0;
*output_ptr1 = 0;
*output_ptr2 = 0;
*output_ptr3 = 0;
} else {
acc = vmull_s16(row0, _ker[0]);
acc = vmlal_s16(acc, row1, _ker[1]);
acc = vmlal_s16(acc, row2, _ker[2]);
*output_ptr0 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1);
acc = vmull_s16(row1, _ker[0]);
acc = vmlal_s16(acc, row2, _ker[1]);
acc = vmlal_s16(acc, row3, _ker[2]);
*output_ptr1 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1);
acc = vmull_s16(row2, _ker[0]);
acc = vmlal_s16(acc, row3, _ker[1]);
acc = vmlal_s16(acc, row4, _ker[2]);
*output_ptr2 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1);
acc = vmull_s16(row3, _ker[0]);
acc = vmlal_s16(acc, row4, _ker[1]);
acc = vmlal_s16(acc, row5, _ker[2]);
*output_ptr3 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1);
row0 = vext_s16(row0, zero, 1);
row1 = vext_s16(row1, zero, 1);
row2 = vext_s16(row2, zero, 1);
row3 = vext_s16(row3, zero, 1);
row4 = vext_s16(row4, zero, 1);
row5 = vext_s16(row5, zero, 1);
}
output_ptr0++;
output_ptr1++;
output_ptr2++;
output_ptr3++;
}
}
} }
// remain height // remain height
int start_h = valid_h_start + (valid_h & 0xFFFC); int start_h = valid_h_start + (valid_h & 0xFFFC);
@@ -712,25 +586,40 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
const int8_t *input_ptr1 = input_ptr0 + input_w; const int8_t *input_ptr1 = input_ptr0 + input_w;
const int8_t *input_ptr2 = input_ptr1 + input_w; const int8_t *input_ptr2 = input_ptr1 + input_w;
const int8_t *input_ptr3 = input_ptr2 + input_w; const int8_t *input_ptr3 = input_ptr2 + input_w;
int32_t *output_ptr0 = output_ptr + h * output_w + valid_w_start; int32_t *output_ptr0 = output_ptr + h * output_w;
int32_t *output_ptr1 = output_ptr0 + output_w; int32_t *output_ptr1 = output_ptr0 + output_w;
// pad left
if (padding_w) {
int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0)));
int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1)));
int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2)));
int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3)));
int32x4_t acc;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 3) {
output_ptr0[w] = 0;
output_ptr1[w] = 0;
} else {
row0 = vext_s16(zero, row0, 3);
row1 = vext_s16(zero, row1, 3);
row2 = vext_s16(zero, row2, 3);
row3 = vext_s16(zero, row3, 3);
acc = vmull_s16(row0, _ker[0]);
acc = vmlal_s16(acc, row1, _ker[1]);
acc = vmlal_s16(acc, row2, _ker[2]);
output_ptr0[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2);
acc = vmull_s16(row1, _ker[0]);
acc = vmlal_s16(acc, row2, _ker[1]);
acc = vmlal_s16(acc, row3, _ker[2]);
output_ptr1[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2);
}
}
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
}
// valid
int loop = output_w_tiles; int loop = output_w_tiles;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
"vmovl.s8 q15, d1 \n"
"vdup.s16 d0, d28[0] \n"
"vdup.s16 d1, d28[1] \n"
"vdup.s16 d2, d28[2] \n"
"vdup.s16 d3, d28[3] \n"
"vdup.s16 d4, d29[0] \n"
"vdup.s16 d5, d29[1] \n"
"vdup.s16 d6, d29[2] \n"
"vdup.s16 d7, d29[3] \n"
"vdup.s16 d8, d30[0] \n"
:
: [filter_ptr] "r"(filter_ptr)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile( asm volatile(
"cmp %[loop], #0 \n" "cmp %[loop], #0 \n"
"ble start_remain_%= \n" "ble start_remain_%= \n"
@@ -745,52 +634,52 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n" "vmull.s16 q10, d14, %e[ker0][0] \n"
"vmlal.s16 q10, d16, d1 \n" "vmlal.s16 q10, d16, %e[ker0][1] \n"
"vmlal.s16 q10, d18, d2 \n" "vmlal.s16 q10, d18, %e[ker0][2] \n"
"vmull.s16 q11, d15, d0 \n" "vmull.s16 q11, d15, %e[ker0][0] \n"
"vmlal.s16 q11, d17, d1 \n" "vmlal.s16 q11, d17, %e[ker0][1] \n"
"vmlal.s16 q11, d19, d2 \n" "vmlal.s16 q11, d19, %e[ker0][2] \n"
"vext.s8 d12, d10, d10, #1 \n" "vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n" "vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n" "vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n" "vmlal.s16 q10, d14, %f[ker0][0] \n"
"vmlal.s16 q10, d16, d4 \n" "vmlal.s16 q10, d16, %f[ker0][1] \n"
"vmlal.s16 q10, d18, d5 \n" "vmlal.s16 q10, d18, %f[ker0][2] \n"
"vmlal.s16 q11, d15, d3 \n" "vmlal.s16 q11, d15, %f[ker0][0] \n"
"vmlal.s16 q11, d17, d4 \n" "vmlal.s16 q11, d17, %f[ker0][1] \n"
"vmlal.s16 q11, d19, d5 \n" "vmlal.s16 q11, d19, %f[ker0][2] \n"
"vmull.s16 q12, d14, d0 \n" "vmull.s16 q12, d14, %e[ker0][0] \n"
"vmlal.s16 q12, d16, d1 \n" "vmlal.s16 q12, d16, %e[ker0][1] \n"
"vmlal.s16 q12, d18, d2 \n" "vmlal.s16 q12, d18, %e[ker0][2] \n"
"vmull.s16 q13, d15, d0 \n" "vmull.s16 q13, d15, %e[ker0][0] \n"
"vmlal.s16 q13, d17, d1 \n" "vmlal.s16 q13, d17, %e[ker0][1] \n"
"vmlal.s16 q13, d19, d2 \n" "vmlal.s16 q13, d19, %e[ker0][2] \n"
"vext.s8 d12, d11, d11, #1 \n" "vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n" "vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n" "vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n" "vmlal.s16 q10, d14, %e[ker1][0] \n"
"vmlal.s16 q10, d16, d7 \n" "vmlal.s16 q10, d16, %e[ker1][1] \n"
"vmlal.s16 q10, d18, d8 \n" "vmlal.s16 q10, d18, %e[ker1][2] \n"
"vmlal.s16 q11, d15, d6 \n" "vmlal.s16 q11, d15, %e[ker1][0] \n"
"vmlal.s16 q11, d17, d7 \n" "vmlal.s16 q11, d17, %e[ker1][1] \n"
"vmlal.s16 q11, d19, d8 \n" "vmlal.s16 q11, d19, %e[ker1][2] \n"
// store row 0, reuse q10/q11 // store row 0, reuse q10/q11
"vst1.32 {d20-d22}, [%[output_ptr0]]! \n" "vst1.32 {d20-d22}, [%[output_ptr0]]! \n"
"vmlal.s16 q12, d14, d3 \n" "vmlal.s16 q12, d14, %f[ker0][0] \n"
"vmlal.s16 q12, d16, d4 \n" "vmlal.s16 q12, d16, %f[ker0][1] \n"
"vmlal.s16 q12, d18, d5 \n" "vmlal.s16 q12, d18, %f[ker0][2] \n"
"vmlal.s16 q13, d15, d3 \n" "vmlal.s16 q13, d15, %f[ker0][0] \n"
"vmlal.s16 q13, d17, d4 \n" "vmlal.s16 q13, d17, %f[ker0][1] \n"
"vmlal.s16 q13, d19, d5 \n" "vmlal.s16 q13, d19, %f[ker0][2] \n"
"vld1.32 {d9}, [%[input_ptr3]], r0 \n" "vld1.32 {d9}, [%[input_ptr3]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n" "vext.s8 d12, d9, d9, #1 \n"
@@ -798,12 +687,12 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d14, d6 \n" "vmlal.s16 q12, d14, %e[ker1][0] \n"
"vmlal.s16 q12, d16, d7 \n" "vmlal.s16 q12, d16, %e[ker1][1] \n"
"vmlal.s16 q12, d18, d8 \n" "vmlal.s16 q12, d18, %e[ker1][2] \n"
"vmlal.s16 q13, d15, d6 \n" "vmlal.s16 q13, d15, %e[ker1][0] \n"
"vmlal.s16 q13, d17, d7 \n" "vmlal.s16 q13, d17, %e[ker1][1] \n"
"vmlal.s16 q13, d19, d8 \n" "vmlal.s16 q13, d19, %e[ker1][2] \n"
// store row 1 // store row 1
"vst1.32 {d24-d26}, [%[output_ptr1]]! \n" "vst1.32 {d24-d26}, [%[output_ptr1]]! \n"
@@ -814,71 +703,72 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
"cmp %[remain], #0 \n" "cmp %[remain], #0 \n"
"ble end_%= \n" "ble end_%= \n"
"vld1.32 {d9}, [%[input_ptr0]] \n" "mov r0, %[remain] \n"
"vld1.32 {d10}, [%[input_ptr1]] \n" "vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vld1.32 {d11}, [%[input_ptr2]] \n" "vld1.32 {d10}, [%[input_ptr1]], r0 \n"
"vld1.32 {d11}, [%[input_ptr2]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n" "vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n" "vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n" "vmull.s16 q10, d14, %e[ker0][0] \n"
"vmlal.s16 q10, d16, d1 \n" "vmlal.s16 q10, d16, %e[ker0][1] \n"
"vmlal.s16 q10, d18, d2 \n" "vmlal.s16 q10, d18, %e[ker0][2] \n"
"vmull.s16 q11, d15, d0 \n" "vmull.s16 q11, d15, %e[ker0][0] \n"
"vmlal.s16 q11, d17, d1 \n" "vmlal.s16 q11, d17, %e[ker0][1] \n"
"vmlal.s16 q11, d19, d2 \n" "vmlal.s16 q11, d19, %e[ker0][2] \n"
"vext.s8 d12, d10, d10, #1 \n" "vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n" "vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n" "vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n" "vmlal.s16 q10, d14, %f[ker0][0] \n"
"vmlal.s16 q10, d16, d4 \n" "vmlal.s16 q10, d16, %f[ker0][1] \n"
"vmlal.s16 q10, d18, d5 \n" "vmlal.s16 q10, d18, %f[ker0][2] \n"
"vmlal.s16 q11, d15, d3 \n" "vmlal.s16 q11, d15, %f[ker0][0] \n"
"vmlal.s16 q11, d17, d4 \n" "vmlal.s16 q11, d17, %f[ker0][1] \n"
"vmlal.s16 q11, d19, d5 \n" "vmlal.s16 q11, d19, %f[ker0][2] \n"
"vmull.s16 q12, d14, d0 \n" "vmull.s16 q12, d14, %e[ker0][0] \n"
"vmlal.s16 q12, d16, d1 \n" "vmlal.s16 q12, d16, %e[ker0][1] \n"
"vmlal.s16 q12, d18, d2 \n" "vmlal.s16 q12, d18, %e[ker0][2] \n"
"vmull.s16 q13, d15, d0 \n" "vmull.s16 q13, d15, %e[ker0][0] \n"
"vmlal.s16 q13, d17, d1 \n" "vmlal.s16 q13, d17, %e[ker0][1] \n"
"vmlal.s16 q13, d19, d2 \n" "vmlal.s16 q13, d19, %e[ker0][2] \n"
"vext.s8 d12, d11, d11, #1 \n" "vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n" "vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n" "vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n" "vmlal.s16 q10, d14, %e[ker1][0] \n"
"vmlal.s16 q10, d16, d7 \n" "vmlal.s16 q10, d16, %e[ker1][1] \n"
"vmlal.s16 q10, d18, d8 \n" "vmlal.s16 q10, d18, %e[ker1][2] \n"
"vmlal.s16 q11, d15, d6 \n" "vmlal.s16 q11, d15, %e[ker1][0] \n"
"vmlal.s16 q11, d17, d7 \n" "vmlal.s16 q11, d17, %e[ker1][1] \n"
"vmlal.s16 q11, d19, d8 \n" "vmlal.s16 q11, d19, %e[ker1][2] \n"
"vmlal.s16 q12, d14, d3 \n" "vmlal.s16 q12, d14, %f[ker0][0] \n"
"vmlal.s16 q12, d16, d4 \n" "vmlal.s16 q12, d16, %f[ker0][1] \n"
"vmlal.s16 q12, d18, d5 \n" "vmlal.s16 q12, d18, %f[ker0][2] \n"
"vmlal.s16 q13, d15, d3 \n" "vmlal.s16 q13, d15, %f[ker0][0] \n"
"vmlal.s16 q13, d17, d4 \n" "vmlal.s16 q13, d17, %f[ker0][1] \n"
"vmlal.s16 q13, d19, d5 \n" "vmlal.s16 q13, d19, %f[ker0][2] \n"
"vld1.32 {d9}, [%[input_ptr3]] \n" "vld1.32 {d9}, [%[input_ptr3]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n" "vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n" "vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d14, d6 \n" "vmlal.s16 q12, d14, %e[ker1][0] \n"
"vmlal.s16 q12, d16, d7 \n" "vmlal.s16 q12, d16, %e[ker1][1] \n"
"vmlal.s16 q12, d18, d8 \n" "vmlal.s16 q12, d18, %e[ker1][2] \n"
"vmlal.s16 q13, d15, d6 \n" "vmlal.s16 q13, d15, %e[ker1][0] \n"
"vmlal.s16 q13, d17, d7 \n" "vmlal.s16 q13, d17, %e[ker1][1] \n"
"vmlal.s16 q13, d19, d8 \n" "vmlal.s16 q13, d19, %e[ker1][2] \n"
"cmp %[remain], #4 \n" "cmp %[remain], #4 \n"
"blt store_2h2w_%= \n" "blt store_2h2w_%= \n"
@@ -911,9 +801,44 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
[input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3), [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
[loop] "+r"(loop) [loop] "+r"(loop)
: [remain] "r"(output_w_remain) : [remain] "r"(output_w_remain), [ker0] "w"(_ker0), [ker1] "w"(_ker1)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
"q8", "q9", "q10", "q11", "q12", "q13", "r0"); "q12", "q13", "q14", "q15", "r0");
// pad right
if (padding_w) {
int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0 - 2)));
int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1 - 2)));
int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2 - 2)));
int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3 - 2)));
row0 = vext_s16(row0, zero, 2);
row1 = vext_s16(row1, zero, 2);
row2 = vext_s16(row2, zero, 2);
row3 = vext_s16(row3, zero, 2);
int32x4_t acc;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0;
*output_ptr1 = 0;
} else {
acc = vmull_s16(row0, _ker[0]);
acc = vmlal_s16(acc, row1, _ker[1]);
acc = vmlal_s16(acc, row2, _ker[2]);
*output_ptr0 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1);
acc = vmull_s16(row1, _ker[0]);
acc = vmlal_s16(acc, row2, _ker[1]);
acc = vmlal_s16(acc, row3, _ker[2]);
*output_ptr1 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1);
row0 = vext_s16(row0, zero, 1);
row1 = vext_s16(row1, zero, 1);
row2 = vext_s16(row2, zero, 1);
row3 = vext_s16(row3, zero, 1);
}
output_ptr0++;
output_ptr1++;
}
}
} }
start_h = valid_h_start + (valid_h & 0xFFFE); start_h = valid_h_start + (valid_h & 0xFFFE);
@@ -921,24 +846,31 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
const int8_t *input_ptr0 = input_ptr + (start_h - padding_h) * input_w; const int8_t *input_ptr0 = input_ptr + (start_h - padding_h) * input_w;
const int8_t *input_ptr1 = input_ptr0 + input_w; const int8_t *input_ptr1 = input_ptr0 + input_w;
const int8_t *input_ptr2 = input_ptr1 + input_w; const int8_t *input_ptr2 = input_ptr1 + input_w;
int32_t *output_ptr0 = output_ptr + start_h * output_w + valid_w_start; int32_t *output_ptr0 = output_ptr + start_h * output_w;
// pad left
if (padding_w) {
int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0)));
int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1)));
int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2)));
int32x4_t acc;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 3) {
output_ptr0[w] = 0;
} else {
row0 = vext_s16(zero, row0, 3);
row1 = vext_s16(zero, row1, 3);
row2 = vext_s16(zero, row2, 3);
acc = vmull_s16(row0, _ker[0]);
acc = vmlal_s16(acc, row1, _ker[1]);
acc = vmlal_s16(acc, row2, _ker[2]);
output_ptr0[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2);
}
}
output_ptr0 += valid_w_start;
}
// valid
int loop = output_w_tiles; int loop = output_w_tiles;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
"vmovl.s8 q15, d1 \n"
"vdup.s16 d0, d28[0] \n"
"vdup.s16 d1, d28[1] \n"
"vdup.s16 d2, d28[2] \n"
"vdup.s16 d3, d28[3] \n"
"vdup.s16 d4, d29[0] \n"
"vdup.s16 d5, d29[1] \n"
"vdup.s16 d6, d29[2] \n"
"vdup.s16 d7, d29[3] \n"
"vdup.s16 d8, d30[0] \n"
:
: [filter_ptr] "r"(filter_ptr)
: "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile( asm volatile(
"cmp %[loop], #0 \n" "cmp %[loop], #0 \n"
"ble start_remain_%= \n" "ble start_remain_%= \n"
@@ -953,36 +885,36 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n" "vmull.s16 q10, d14, %e[ker0][0] \n"
"vmlal.s16 q10, d16, d1 \n" "vmlal.s16 q10, d16, %e[ker0][1] \n"
"vmlal.s16 q10, d18, d2 \n" "vmlal.s16 q10, d18, %e[ker0][2] \n"
"vmull.s16 q11, d15, d0 \n" "vmull.s16 q11, d15, %e[ker0][0] \n"
"vmlal.s16 q11, d17, d1 \n" "vmlal.s16 q11, d17, %e[ker0][1] \n"
"vmlal.s16 q11, d19, d2 \n" "vmlal.s16 q11, d19, %e[ker0][2] \n"
"vext.s8 d12, d10, d10, #1 \n" "vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n" "vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n" "vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n" "vmlal.s16 q10, d14, %f[ker0][0] \n"
"vmlal.s16 q10, d16, d4 \n" "vmlal.s16 q10, d16, %f[ker0][1] \n"
"vmlal.s16 q10, d18, d5 \n" "vmlal.s16 q10, d18, %f[ker0][2] \n"
"vmlal.s16 q11, d15, d3 \n" "vmlal.s16 q11, d15, %f[ker0][0] \n"
"vmlal.s16 q11, d17, d4 \n" "vmlal.s16 q11, d17, %f[ker0][1] \n"
"vmlal.s16 q11, d19, d5 \n" "vmlal.s16 q11, d19, %f[ker0][2] \n"
"vext.s8 d12, d11, d11, #1 \n" "vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n" "vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n" "vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n" "vmlal.s16 q10, d14, %e[ker1][0] \n"
"vmlal.s16 q10, d16, d7 \n" "vmlal.s16 q10, d16, %e[ker1][1] \n"
"vmlal.s16 q10, d18, d8 \n" "vmlal.s16 q10, d18, %e[ker1][2] \n"
"vmlal.s16 q11, d15, d6 \n" "vmlal.s16 q11, d15, %e[ker1][0] \n"
"vmlal.s16 q11, d17, d7 \n" "vmlal.s16 q11, d17, %e[ker1][1] \n"
"vmlal.s16 q11, d19, d8 \n" "vmlal.s16 q11, d19, %e[ker1][2] \n"
// store row 0, reuse q10/q11 // store row 0, reuse q10/q11
"vst1.32 {d20-d22}, [%[output_ptr0]]! \n" "vst1.32 {d20-d22}, [%[output_ptr0]]! \n"
@@ -992,45 +924,46 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
"start_remain_%=: \n" "start_remain_%=: \n"
"cmp %[remain], #0 \n" "cmp %[remain], #0 \n"
"ble end_%= \n" "ble end_%= \n"
"mov r0, %[remain] \n"
"vld1.32 {d9}, [%[input_ptr0]] \n" "vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vld1.32 {d10}, [%[input_ptr1]] \n" "vld1.32 {d10}, [%[input_ptr1]], r0 \n"
"vld1.32 {d11}, [%[input_ptr2]] \n" "vld1.32 {d11}, [%[input_ptr2]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n" "vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n" "vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n" "vmull.s16 q10, d14, %e[ker0][0] \n"
"vmlal.s16 q10, d16, d1 \n" "vmlal.s16 q10, d16, %e[ker0][1] \n"
"vmlal.s16 q10, d18, d2 \n" "vmlal.s16 q10, d18, %e[ker0][2] \n"
"vmull.s16 q11, d15, d0 \n" "vmull.s16 q11, d15, %e[ker0][0] \n"
"vmlal.s16 q11, d17, d1 \n" "vmlal.s16 q11, d17, %e[ker0][1] \n"
"vmlal.s16 q11, d19, d2 \n" "vmlal.s16 q11, d19, %e[ker0][2] \n"
"vext.s8 d12, d10, d10, #1 \n" "vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n" "vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n" "vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n" "vmlal.s16 q10, d14, %f[ker0][0] \n"
"vmlal.s16 q10, d16, d4 \n" "vmlal.s16 q10, d16, %f[ker0][1] \n"
"vmlal.s16 q10, d18, d5 \n" "vmlal.s16 q10, d18, %f[ker0][2] \n"
"vmlal.s16 q11, d15, d3 \n" "vmlal.s16 q11, d15, %f[ker0][0] \n"
"vmlal.s16 q11, d17, d4 \n" "vmlal.s16 q11, d17, %f[ker0][1] \n"
"vmlal.s16 q11, d19, d5 \n" "vmlal.s16 q11, d19, %f[ker0][2] \n"
"vext.s8 d12, d11, d11, #1 \n" "vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n" "vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n" "vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n" "vmlal.s16 q10, d14, %e[ker1][0] \n"
"vmlal.s16 q10, d16, d7 \n" "vmlal.s16 q10, d16, %e[ker1][1] \n"
"vmlal.s16 q10, d18, d8 \n" "vmlal.s16 q10, d18, %e[ker1][2] \n"
"vmlal.s16 q11, d15, d6 \n" "vmlal.s16 q11, d15, %e[ker1][0] \n"
"vmlal.s16 q11, d17, d7 \n" "vmlal.s16 q11, d17, %e[ker1][1] \n"
"vmlal.s16 q11, d19, d8 \n" "vmlal.s16 q11, d19, %e[ker1][2] \n"
"cmp %[remain], #4 \n" "cmp %[remain], #4 \n"
"blt store_1h2w_%= \n" "blt store_1h2w_%= \n"
@@ -1057,9 +990,41 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
: [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0), : [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0),
[input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2), [input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2),
[loop] "+r"(loop) [loop] "+r"(loop)
: [remain] "r"(output_w_remain) : [remain] "r"(output_w_remain), [ker0] "w"(_ker0), [ker1] "w"(_ker1)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
"q8", "q9", "q10", "q11", "r0"); "q12", "q13", "q14", "q15", "r0");
// pad right
if (padding_w) {
int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0 - 2)));
int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1 - 2)));
int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2 - 2)));
row0 = vext_s16(row0, zero, 2);
row1 = vext_s16(row1, zero, 2);
row2 = vext_s16(row2, zero, 2);
int32x4_t acc;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0;
} else {
acc = vmull_s16(row0, _ker[0]);
acc = vmlal_s16(acc, row1, _ker[1]);
acc = vmlal_s16(acc, row2, _ker[2]);
*output_ptr0 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1);
row0 = vext_s16(row0, zero, 1);
row1 = vext_s16(row1, zero, 1);
row2 = vext_s16(row2, zero, 1);
}
output_ptr0++;
}
}
}
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr, _ker);
} }
} }
} }
@@ -1081,11 +1046,13 @@ void DepthwiseConv3x3S2<int8_t, int32_t>(const framework::Tensor &input,
int image_size = input_h * input_w; int image_size = input_h * input_w;
int out_image_size = output_h * output_w; int out_image_size = output_h * output_w;
int valid_h_start = (padding_h + 1) / 2; int valid_h_start = (padding_h + 1) / 2;
int valid_h_end = output_h - valid_h_start; int valid_h_end = (input_h + padding_h - 1) / 2;
int valid_h = valid_h_end - valid_h_start; int valid_h = valid_h_end - valid_h_start;
int valid_w_start = (padding_w + 1) / 2; int valid_w_start = (padding_w + 1) / 2;
int valid_w_end = output_w - valid_w_start; int valid_w_end = (input_w + padding_w - 1) / 2;
int valid_w = valid_w_end - valid_w_start; int valid_w = valid_w_end - valid_w_start;
// for pad left
int valid_input_w_start = (valid_w_start << 1) - padding_w;
// DLOG << "valid_h_start: " << valid_h_start; // DLOG << "valid_h_start: " << valid_h_start;
// DLOG << "valid_h_end: " << valid_h_end; // DLOG << "valid_h_end: " << valid_h_end;
@@ -1097,170 +1064,203 @@ void DepthwiseConv3x3S2<int8_t, int32_t>(const framework::Tensor &input,
    const int8_t *input_ptr = input_data + g * image_size;
    const int8_t *filter_ptr = filter_data + g * 9;
    int32_t *output_ptr = out_data + g * out_image_size;
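    // Widen the three filter rows to int16 once per channel: _ker0 packs rows
    // 0 and 1 for the NEON loops, _ker1 carries row 2 (duplicated into both
    // halves), and _ker[] keeps the per-row vectors used by the scalar padding
    // paths below.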
const int8_t *filter_ptr0 = filter_ptr;
const int8_t *filter_ptr1 = filter_ptr0 + 3;
const int8_t *filter_ptr2 = filter_ptr1 + 3;
int16x4_t _k0 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr0)));
int16x4_t _k1 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr1)));
int16x4_t _k2 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr2)));
int16x8_t _ker0 = vcombine_s16(_k0, _k1);
int16x8_t _ker1 = vcombine_s16(_k2, _k2);
int16x4_t _ker[3] = {_k0, _k1, _k2};
    // top
    for (int h = 0; h < valid_h_start; ++h) {
      DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h,
                                      input_w, padding_h, padding_w, output_w,
-                                     output_ptr);
+                                     output_ptr, _ker);
}
// left
for (int w = 0; w < valid_w_start; ++w) {
DepthwiseConv3x3ValidCol<2, 2>(
input_ptr, filter_ptr, valid_h_start, valid_h_end, w, input_h,
input_w, padding_h, padding_w, output_w, output_ptr);
}
// right
for (int w = valid_w_end; w < output_w; ++w) {
DepthwiseConv3x3ValidCol<2, 2>(
input_ptr, filter_ptr, valid_h_start, valid_h_end, w, input_h,
input_w, padding_h, padding_w, output_w, output_ptr);
}
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr);
    }
    // valid
    int input_w_start = 2 * valid_w_start - padding_w;
    int output_w_tiles = valid_w / 6;
    int output_w_remain = valid_w - output_w_tiles * 6;
    for (int h = valid_h_start; h < valid_h_end - 2; h += 3) {
-     size_t offset = (2 * h - padding_h) * input_w + input_w_start;
-     const int8_t *input_ptr0 = input_ptr + offset;
+     const int8_t *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w;
      const int8_t *input_ptr1 = input_ptr0 + input_w;
      const int8_t *input_ptr2 = input_ptr1 + input_w;
      const int8_t *input_ptr3 = input_ptr2 + input_w;
      const int8_t *input_ptr4 = input_ptr3 + input_w;
      const int8_t *input_ptr5 = input_ptr4 + input_w;
      const int8_t *input_ptr6 = input_ptr5 + input_w;
-     int32_t *output_ptr0 = output_ptr + h * output_w + valid_w_start;
+     int32_t *output_ptr0 = output_ptr + h * output_w;
      int32_t *output_ptr1 = output_ptr0 + output_w;
      int32_t *output_ptr2 = output_ptr1 + output_w;
// pad left
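      // Left-edge columns for stride 2: `padding` input columns of the 3-wide
      // window fall before column 0. Each row is loaded `padding` bytes early
      // so that lane i lines up with filter column i; lane 2 (plus lane 1 when
      // padding == 1) supplies the in-range products, and fully padded columns
      // are written as 0.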
if (padding_w) {
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - (w << 1);
if (padding >= 3) {
output_ptr0[w] = 0;
output_ptr1[w] = 0;
output_ptr2[w] = 0;
} else {
int16x4_t row0 =
vget_low_s16(vmovl_s8(vld1_s8(input_ptr0 - padding)));
int16x4_t row1 =
vget_low_s16(vmovl_s8(vld1_s8(input_ptr1 - padding)));
int16x4_t row2 =
vget_low_s16(vmovl_s8(vld1_s8(input_ptr2 - padding)));
int16x4_t row3 =
vget_low_s16(vmovl_s8(vld1_s8(input_ptr3 - padding)));
int16x4_t row4 =
vget_low_s16(vmovl_s8(vld1_s8(input_ptr4 - padding)));
int16x4_t row5 =
vget_low_s16(vmovl_s8(vld1_s8(input_ptr5 - padding)));
int16x4_t row6 =
vget_low_s16(vmovl_s8(vld1_s8(input_ptr6 - padding)));
int32x4_t acc0 = vmull_s16(row0, _ker[0]);
acc0 = vmlal_s16(acc0, row1, _ker[1]);
acc0 = vmlal_s16(acc0, row2, _ker[2]);
int32x4_t acc1 = vmull_s16(row2, _ker[0]);
acc1 = vmlal_s16(acc1, row3, _ker[1]);
acc1 = vmlal_s16(acc1, row4, _ker[2]);
int32x4_t acc2 = vmull_s16(row4, _ker[0]);
acc2 = vmlal_s16(acc2, row5, _ker[1]);
acc2 = vmlal_s16(acc2, row6, _ker[2]);
int32_t sum0 = vgetq_lane_s32(acc0, 2);
int32_t sum1 = vgetq_lane_s32(acc1, 2);
int32_t sum2 = vgetq_lane_s32(acc2, 2);
if (padding == 1) {
sum0 += vgetq_lane_s32(acc0, 1);
sum1 += vgetq_lane_s32(acc1, 1);
sum2 += vgetq_lane_s32(acc2, 1);
}
output_ptr0[w] = sum0;
output_ptr1[w] = sum1;
output_ptr2[w] = sum2;
}
}
input_ptr0 += valid_input_w_start;
input_ptr1 += valid_input_w_start;
input_ptr2 += valid_input_w_start;
input_ptr3 += valid_input_w_start;
input_ptr4 += valid_input_w_start;
input_ptr5 += valid_input_w_start;
input_ptr6 += valid_input_w_start;
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
output_ptr2 += valid_w_start;
}
// valid
      int loop = output_w_tiles;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
"vmovl.s8 q15, d1 \n"
"vdup.s16 d0, d28[0] \n"
"vdup.s16 d1, d28[1] \n"
"vdup.s16 d2, d28[2] \n"
"vdup.s16 d3, d28[3] \n"
"vdup.s16 d4, d29[0] \n"
"vdup.s16 d5, d29[1] \n"
"vdup.s16 d6, d29[2] \n"
"vdup.s16 d7, d29[3] \n"
"vdup.s16 d8, d30[0] \n"
:
: [filter_ptr] "r"(filter_ptr)
: "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
      asm volatile(
          "cmp %[loop], #0 \n"
          "ble start_remain_%= \n"
          "mov r0, #12 \n"
          // loop 6 widths
          "loop_3h6w_%=: \n"
-         "vld2.8 {d10, d11}, [%[input_ptr0]], r0 \n"
-         "vld2.8 {d12, d13}, [%[input_ptr1]], r0 \n"
-         "vld2.8 {d14, d15}, [%[input_ptr2]], r0 \n"
+         "vld2.8 {d10-d11}, [%[input_ptr0]], r0 \n"
+         "vld2.8 {d12-d13}, [%[input_ptr1]], r0 \n"
+         "vld2.8 {d14-d15}, [%[input_ptr2]], r0 \n"
          "vext.s8 d9, d10, d10, #1 \n"
          "vmovl.s8 q10, d9 \n"
          "vmovl.s8 q8, d10 \n"
          "vmovl.s8 q9, d11 \n"
-         "vmull.s16 q11, d16, d0 \n"
-         "vmlal.s16 q11, d18, d1 \n"
-         "vmlal.s16 q11, d20, d2 \n"
-         "vmull.s16 q12, d17, d0 \n"
-         "vmlal.s16 q12, d19, d1 \n"
-         "vmlal.s16 q12, d21, d2 \n"
+         "vmull.s16 q11, d16, %e[ker0][0] \n"
+         "vmlal.s16 q11, d18, %e[ker0][1] \n"
+         "vmlal.s16 q11, d20, %e[ker0][2] \n"
+         "vmull.s16 q12, d17, %e[ker0][0] \n"
+         "vmlal.s16 q12, d19, %e[ker0][1] \n"
+         "vmlal.s16 q12, d21, %e[ker0][2] \n"
          "vext.s8 d9, d12, d12, #1 \n"
          "vmovl.s8 q10, d9 \n"
          "vmovl.s8 q8, d12 \n"
          "vmovl.s8 q9, d13 \n"
-         "vmlal.s16 q11, d16, d3 \n"
-         "vmlal.s16 q11, d18, d4 \n"
-         "vmlal.s16 q11, d20, d5 \n"
-         "vmlal.s16 q12, d17, d3 \n"
-         "vmlal.s16 q12, d19, d4 \n"
-         "vmlal.s16 q12, d21, d5 \n"
+         "vmlal.s16 q11, d16, %f[ker0][0] \n"
+         "vmlal.s16 q11, d18, %f[ker0][1] \n"
+         "vmlal.s16 q11, d20, %f[ker0][2] \n"
+         "vmlal.s16 q12, d17, %f[ker0][0] \n"
+         "vmlal.s16 q12, d19, %f[ker0][1] \n"
+         "vmlal.s16 q12, d21, %f[ker0][2] \n"
          "vext.s8 d9, d14, d14, #1 \n"
          "vmovl.s8 q10, d9 \n"
          "vmovl.s8 q8, d14 \n"
          "vmovl.s8 q9, d15 \n"
-         "vmlal.s16 q11, d16, d6 \n"
-         "vmlal.s16 q11, d18, d7 \n"
-         "vmlal.s16 q11, d20, d8 \n"
-         "vmlal.s16 q12, d17, d6 \n"
-         "vmlal.s16 q12, d19, d7 \n"
-         "vmlal.s16 q12, d21, d8 \n"
+         "vmlal.s16 q11, d16, %e[ker1][0] \n"
+         "vmlal.s16 q11, d18, %e[ker1][1] \n"
+         "vmlal.s16 q11, d20, %e[ker1][2] \n"
+         "vmlal.s16 q12, d17, %e[ker1][0] \n"
+         "vmlal.s16 q12, d19, %e[ker1][1] \n"
+         "vmlal.s16 q12, d21, %e[ker1][2] \n"
          // store row 0, reuse q11/q12
          "vst1.32 {d22-d24}, [%[output_ptr0]]! \n"
-         "vmull.s16 q13, d16, d0 \n"
-         "vmlal.s16 q13, d18, d1 \n"
-         "vmlal.s16 q13, d20, d2 \n"
-         "vmull.s16 q14, d17, d0 \n"
-         "vmlal.s16 q14, d19, d1 \n"
-         "vmlal.s16 q14, d21, d2 \n"
-         "vld2.8 {d10, d11}, [%[input_ptr3]], r0 \n"
-         "vld2.8 {d12, d13}, [%[input_ptr4]], r0 \n"
-         "vld2.8 {d14, d15}, [%[input_ptr5]], r0 \n"
+         "vmull.s16 q13, d16, %e[ker0][0] \n"
+         "vmlal.s16 q13, d18, %e[ker0][1] \n"
+         "vmlal.s16 q13, d20, %e[ker0][2] \n"
+         "vmull.s16 q14, d17, %e[ker0][0] \n"
+         "vmlal.s16 q14, d19, %e[ker0][1] \n"
+         "vmlal.s16 q14, d21, %e[ker0][2] \n"
+         "vld2.8 {d10-d11}, [%[input_ptr3]], r0 \n"
+         "vld2.8 {d12-d13}, [%[input_ptr4]], r0 \n"
+         "vld2.8 {d14-d15}, [%[input_ptr5]], r0 \n"
          "vext.s8 d9, d10, d10, #1 \n"
          "vmovl.s8 q10, d9 \n"
          "vmovl.s8 q8, d10 \n"
          "vmovl.s8 q9, d11 \n"
-         "vmlal.s16 q13, d16, d3 \n"
-         "vmlal.s16 q13, d18, d4 \n"
-         "vmlal.s16 q13, d20, d5 \n"
-         "vmlal.s16 q14, d17, d3 \n"
-         "vmlal.s16 q14, d19, d4 \n"
-         "vmlal.s16 q14, d21, d5 \n"
+         "vmlal.s16 q13, d16, %f[ker0][0] \n"
+         "vmlal.s16 q13, d18, %f[ker0][1] \n"
+         "vmlal.s16 q13, d20, %f[ker0][2] \n"
+         "vmlal.s16 q14, d17, %f[ker0][0] \n"
+         "vmlal.s16 q14, d19, %f[ker0][1] \n"
+         "vmlal.s16 q14, d21, %f[ker0][2] \n"
          "vext.s8 d9, d12, d12, #1 \n"
          "vmovl.s8 q10, d9 \n"
          "vmovl.s8 q8, d12 \n"
          "vmovl.s8 q9, d13 \n"
-         "vmlal.s16 q13, d16, d6 \n"
-         "vmlal.s16 q13, d18, d7 \n"
-         "vmlal.s16 q13, d20, d8 \n"
-         "vmlal.s16 q14, d17, d6 \n"
-         "vmlal.s16 q14, d19, d7 \n"
-         "vmlal.s16 q14, d21, d8 \n"
+         "vmlal.s16 q13, d16, %e[ker1][0] \n"
+         "vmlal.s16 q13, d18, %e[ker1][1] \n"
+         "vmlal.s16 q13, d20, %e[ker1][2] \n"
+         "vmlal.s16 q14, d17, %e[ker1][0] \n"
+         "vmlal.s16 q14, d19, %e[ker1][1] \n"
+         "vmlal.s16 q14, d21, %e[ker1][2] \n"
          // store row 1
          "vst1.32 {d26-d28}, [%[output_ptr1]]! \n"
-         "vmull.s16 q11, d16, d0 \n"
-         "vmlal.s16 q11, d18, d1 \n"
-         "vmlal.s16 q11, d20, d2 \n"
-         "vmull.s16 q12, d17, d0 \n"
-         "vmlal.s16 q12, d19, d1 \n"
-         "vmlal.s16 q12, d21, d2 \n"
+         "vmull.s16 q11, d16, %e[ker0][0] \n"
+         "vmlal.s16 q11, d18, %e[ker0][1] \n"
+         "vmlal.s16 q11, d20, %e[ker0][2] \n"
+         "vmull.s16 q12, d17, %e[ker0][0] \n"
+         "vmlal.s16 q12, d19, %e[ker0][1] \n"
+         "vmlal.s16 q12, d21, %e[ker0][2] \n"
          "vext.s8 d9, d14, d14, #1 \n"
          "vmovl.s8 q10, d9 \n"
          "vmovl.s8 q8, d14 \n"
          "vmovl.s8 q9, d15 \n"
-         "vmlal.s16 q11, d16, d3 \n"
-         "vmlal.s16 q11, d18, d4 \n"
-         "vmlal.s16 q11, d20, d5 \n"
-         "vmlal.s16 q12, d17, d3 \n"
-         "vmlal.s16 q12, d19, d4 \n"
-         "vmlal.s16 q12, d21, d5 \n"
-         "vld2.8 {d10, d11}, [%[input_ptr6]], r0 \n"
+         "vmlal.s16 q11, d16, %f[ker0][0] \n"
+         "vmlal.s16 q11, d18, %f[ker0][1] \n"
+         "vmlal.s16 q11, d20, %f[ker0][2] \n"
+         "vmlal.s16 q12, d17, %f[ker0][0] \n"
+         "vmlal.s16 q12, d19, %f[ker0][1] \n"
+         "vmlal.s16 q12, d21, %f[ker0][2] \n"
+         "vld2.8 {d10-d11}, [%[input_ptr6]], r0 \n"
          "vext.s8 d9, d10, d10, #1 \n"
          "vmovl.s8 q10, d9 \n"
          "vmovl.s8 q8, d10 \n"
          "vmovl.s8 q9, d11 \n"
-         "vmlal.s16 q11, d16, d6 \n"
-         "vmlal.s16 q11, d18, d7 \n"
-         "vmlal.s16 q11, d20, d8 \n"
-         "vmlal.s16 q12, d17, d6 \n"
-         "vmlal.s16 q12, d19, d7 \n"
-         "vmlal.s16 q12, d21, d8 \n"
+         "vmlal.s16 q11, d16, %e[ker1][0] \n"
+         "vmlal.s16 q11, d18, %e[ker1][1] \n"
+         "vmlal.s16 q11, d20, %e[ker1][2] \n"
+         "vmlal.s16 q12, d17, %e[ker1][0] \n"
+         "vmlal.s16 q12, d19, %e[ker1][1] \n"
+         "vmlal.s16 q12, d21, %e[ker1][2] \n"
          // store row 2
          "vst1.32 {d22-d24}, [%[output_ptr2]]! \n"
@@ -1270,104 +1270,105 @@ void DepthwiseConv3x3S2<int8_t, int32_t>(const framework::Tensor &input,
"start_remain_%=: \n" "start_remain_%=: \n"
"cmp %[remain], #0 \n" "cmp %[remain], #0 \n"
"ble end_%= \n" "ble end_%= \n"
"mov r0, %[remain], lsl #1 \n"
"vld2.8 {d10, d11}, [%[input_ptr0]] \n" "vld2.8 {d10-d11}, [%[input_ptr0]], r0 \n"
"vld2.8 {d12, d13}, [%[input_ptr1]] \n" "vld2.8 {d12-d13}, [%[input_ptr1]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n" "vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q9, d9 \n" "vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d10 \n" "vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d11 \n" "vmovl.s8 q8, d11 \n"
"vmull.s16 q10, d14, d0 \n" "vmull.s16 q10, d14, %e[ker0][0] \n"
"vmlal.s16 q10, d16, d1 \n" "vmlal.s16 q10, d16, %e[ker0][1] \n"
"vmlal.s16 q10, d18, d2 \n" "vmlal.s16 q10, d18, %e[ker0][2] \n"
"vmull.s16 q11, d15, d0 \n" "vmull.s16 q11, d15, %e[ker0][0] \n"
"vmlal.s16 q11, d17, d1 \n" "vmlal.s16 q11, d17, %e[ker0][1] \n"
"vmlal.s16 q11, d19, d2 \n" "vmlal.s16 q11, d19, %e[ker0][2] \n"
"vext.s8 d9, d12, d12, #1 \n" "vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q9, d9 \n" "vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d12 \n" "vmovl.s8 q7, d12 \n"
"vmovl.s8 q8, d13 \n" "vmovl.s8 q8, d13 \n"
"vmlal.s16 q10, d14, d3 \n" "vmlal.s16 q10, d14, %f[ker0][0] \n"
"vmlal.s16 q10, d16, d4 \n" "vmlal.s16 q10, d16, %f[ker0][1] \n"
"vmlal.s16 q10, d18, d5 \n" "vmlal.s16 q10, d18, %f[ker0][2] \n"
"vmlal.s16 q11, d15, d3 \n" "vmlal.s16 q11, d15, %f[ker0][0] \n"
"vmlal.s16 q11, d17, d4 \n" "vmlal.s16 q11, d17, %f[ker0][1] \n"
"vmlal.s16 q11, d19, d5 \n" "vmlal.s16 q11, d19, %f[ker0][2] \n"
"vld2.8 {d10, d11}, [%[input_ptr2]] \n" "vld2.8 {d10-d11}, [%[input_ptr2]], r0 \n"
"vld2.8 {d12, d13}, [%[input_ptr3]] \n" "vld2.8 {d12-d13}, [%[input_ptr3]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n" "vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q9, d9 \n" "vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d10 \n" "vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d11 \n" "vmovl.s8 q8, d11 \n"
"vmlal.s16 q10, d14, d6 \n" "vmlal.s16 q10, d14, %e[ker1][0] \n"
"vmlal.s16 q10, d16, d7 \n" "vmlal.s16 q10, d16, %e[ker1][1] \n"
"vmlal.s16 q10, d18, d8 \n" "vmlal.s16 q10, d18, %e[ker1][2] \n"
"vmlal.s16 q11, d15, d6 \n" "vmlal.s16 q11, d15, %e[ker1][0] \n"
"vmlal.s16 q11, d17, d7 \n" "vmlal.s16 q11, d17, %e[ker1][1] \n"
"vmlal.s16 q11, d19, d8 \n" "vmlal.s16 q11, d19, %e[ker1][2] \n"
"vmull.s16 q12, d14, d0 \n" "vmull.s16 q12, d14, %e[ker0][0] \n"
"vmlal.s16 q12, d16, d1 \n" "vmlal.s16 q12, d16, %e[ker0][1] \n"
"vmlal.s16 q12, d18, d2 \n" "vmlal.s16 q12, d18, %e[ker0][2] \n"
"vmull.s16 q13, d15, d0 \n" "vmull.s16 q13, d15, %e[ker0][0] \n"
"vmlal.s16 q13, d17, d1 \n" "vmlal.s16 q13, d17, %e[ker0][1] \n"
"vmlal.s16 q13, d19, d2 \n" "vmlal.s16 q13, d19, %e[ker0][2] \n"
"vext.s8 d9, d12, d12, #1 \n" "vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q9, d9 \n" "vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d12 \n" "vmovl.s8 q7, d12 \n"
"vmovl.s8 q8, d13 \n" "vmovl.s8 q8, d13 \n"
"vmlal.s16 q12, d14, d3 \n" "vmlal.s16 q12, d14, %f[ker0][0] \n"
"vmlal.s16 q12, d16, d4 \n" "vmlal.s16 q12, d16, %f[ker0][1] \n"
"vmlal.s16 q12, d18, d5 \n" "vmlal.s16 q12, d18, %f[ker0][2] \n"
"vmlal.s16 q13, d15, d3 \n" "vmlal.s16 q13, d15, %f[ker0][0] \n"
"vmlal.s16 q13, d17, d4 \n" "vmlal.s16 q13, d17, %f[ker0][1] \n"
"vmlal.s16 q13, d19, d5 \n" "vmlal.s16 q13, d19, %f[ker0][2] \n"
"vld2.8 {d10, d11}, [%[input_ptr4]] \n" "vld2.8 {d10-d11}, [%[input_ptr4]], r0 \n"
"vld2.8 {d12, d13}, [%[input_ptr5]] \n" "vld2.8 {d12-d13}, [%[input_ptr5]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n" "vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q9, d9 \n" "vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d10 \n" "vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d11 \n" "vmovl.s8 q8, d11 \n"
"vmlal.s16 q12, d14, d6 \n" "vmlal.s16 q12, d14, %e[ker1][0] \n"
"vmlal.s16 q12, d16, d7 \n" "vmlal.s16 q12, d16, %e[ker1][1] \n"
"vmlal.s16 q12, d18, d8 \n" "vmlal.s16 q12, d18, %e[ker1][2] \n"
"vmlal.s16 q13, d15, d6 \n" "vmlal.s16 q13, d15, %e[ker1][0] \n"
"vmlal.s16 q13, d17, d7 \n" "vmlal.s16 q13, d17, %e[ker1][1] \n"
"vmlal.s16 q13, d19, d8 \n" "vmlal.s16 q13, d19, %e[ker1][2] \n"
"vmull.s16 q14, d14, d0 \n" "vmull.s16 q14, d14, %e[ker0][0] \n"
"vmlal.s16 q14, d16, d1 \n" "vmlal.s16 q14, d16, %e[ker0][1] \n"
"vmlal.s16 q14, d18, d2 \n" "vmlal.s16 q14, d18, %e[ker0][2] \n"
"vmull.s16 q15, d15, d0 \n" "vmull.s16 q15, d15, %e[ker0][0] \n"
"vmlal.s16 q15, d17, d1 \n" "vmlal.s16 q15, d17, %e[ker0][1] \n"
"vmlal.s16 q15, d19, d2 \n" "vmlal.s16 q15, d19, %e[ker0][2] \n"
"vext.s8 d9, d12, d12, #1 \n" "vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q9, d9 \n" "vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d12 \n" "vmovl.s8 q7, d12 \n"
"vmovl.s8 q8, d13 \n" "vmovl.s8 q8, d13 \n"
"vmlal.s16 q14, d14, d3 \n" "vmlal.s16 q14, d14, %f[ker0][0] \n"
"vmlal.s16 q14, d16, d4 \n" "vmlal.s16 q14, d16, %f[ker0][1] \n"
"vmlal.s16 q14, d18, d5 \n" "vmlal.s16 q14, d18, %f[ker0][2] \n"
"vmlal.s16 q15, d15, d3 \n" "vmlal.s16 q15, d15, %f[ker0][0] \n"
"vmlal.s16 q15, d17, d4 \n" "vmlal.s16 q15, d17, %f[ker0][1] \n"
"vmlal.s16 q15, d19, d5 \n" "vmlal.s16 q15, d19, %f[ker0][2] \n"
"vld2.8 {d10, d11}, [%[input_ptr6]] \n" "vld2.8 {d10-d11}, [%[input_ptr6]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n" "vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q9, d9 \n" "vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d10 \n" "vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d11 \n" "vmovl.s8 q8, d11 \n"
"vmlal.s16 q14, d14, d6 \n" "vmlal.s16 q14, d14, %e[ker1][0] \n"
"vmlal.s16 q14, d16, d7 \n" "vmlal.s16 q14, d16, %e[ker1][1] \n"
"vmlal.s16 q14, d18, d8 \n" "vmlal.s16 q14, d18, %e[ker1][2] \n"
"vmlal.s16 q15, d15, d6 \n" "vmlal.s16 q15, d15, %e[ker1][0] \n"
"vmlal.s16 q15, d17, d7 \n" "vmlal.s16 q15, d17, %e[ker1][1] \n"
"vmlal.s16 q15, d19, d8 \n" "vmlal.s16 q15, d19, %e[ker1][2] \n"
"cmp %[remain], #4 \n" "cmp %[remain], #4 \n"
"blt store_3h2w_%= \n" "blt store_3h2w_%= \n"
@@ -1407,35 +1408,90 @@ void DepthwiseConv3x3S2<int8_t, int32_t>(const framework::Tensor &input,
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3), [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
[input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5), [input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5),
[loop] "+r"(loop) [loop] "+r"(loop)
: [remain] "r"(output_w_remain) : [remain] "r"(output_w_remain), [ker0] "w"(_ker0), [ker1] "w"(_ker1)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); "q12", "q13", "q14", "q15", "r0");
// pad right
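      // Right-edge columns for stride 2: `padding` counts how many of the
      // three taps fall past the last input column. Lane 0 of each row
      // accumulation (plus lane 1 when padding == 1) is kept, and columns that
      // are fully inside the padding are written as 0.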
if (padding_w > 0) {
int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0)));
int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1)));
int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2)));
int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3)));
int16x4_t row4 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr4)));
int16x4_t row5 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr5)));
int16x4_t row6 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr6)));
int32x4_t acc0, acc1, acc2;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = 2 * w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0;
*output_ptr1 = 0;
*output_ptr2 = 0;
} else {
acc0 = vmull_s16(row0, _ker[0]);
acc0 = vmlal_s16(acc0, row1, _ker[1]);
acc0 = vmlal_s16(acc0, row2, _ker[2]);
acc1 = vmull_s16(row2, _ker[0]);
acc1 = vmlal_s16(acc1, row3, _ker[1]);
acc1 = vmlal_s16(acc1, row4, _ker[2]);
acc2 = vmull_s16(row4, _ker[0]);
acc2 = vmlal_s16(acc2, row5, _ker[1]);
acc2 = vmlal_s16(acc2, row6, _ker[2]);
int32_t sum0 = vgetq_lane_s32(acc0, 0);
int32_t sum1 = vgetq_lane_s32(acc1, 0);
int32_t sum2 = vgetq_lane_s32(acc2, 0);
if (padding == 1) {
sum0 += vgetq_lane_s32(acc0, 1);
sum1 += vgetq_lane_s32(acc1, 1);
sum2 += vgetq_lane_s32(acc2, 1);
            }
*output_ptr0 = sum0;
*output_ptr1 = sum1;
*output_ptr2 = sum2;
}
output_ptr0++;
output_ptr1++;
output_ptr2++;
}
}
}
// remain height
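    // Rows left over when valid_h is not a multiple of 3 are processed one
    // output row at a time with the same pad-left / valid / pad-right split.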
    int start_h = valid_h_start + valid_h / 3 * 3;
    for (int h = start_h; h < valid_h_end; ++h) {
-     size_t offset = (2 * h - padding_h) * input_w + input_w_start;
-     const int8_t *input_ptr0 = input_ptr + offset;
+     const int8_t *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w;
      const int8_t *input_ptr1 = input_ptr0 + input_w;
      const int8_t *input_ptr2 = input_ptr1 + input_w;
-     int32_t *output_ptr0 = output_ptr + h * output_w + valid_w_start;
+     int32_t *output_ptr0 = output_ptr + h * output_w;
// pad left
if (padding_w) {
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - (w << 1);
if (padding >= 3) {
output_ptr0[w] = 0;
} else {
int16x4_t row0 =
vget_low_s16(vmovl_s8(vld1_s8(input_ptr0 - padding)));
int16x4_t row1 =
vget_low_s16(vmovl_s8(vld1_s8(input_ptr1 - padding)));
int16x4_t row2 =
vget_low_s16(vmovl_s8(vld1_s8(input_ptr2 - padding)));
int32x4_t acc = vmull_s16(row0, _ker[0]);
acc = vmlal_s16(acc, row1, _ker[1]);
acc = vmlal_s16(acc, row2, _ker[2]);
int32_t sum0 = vgetq_lane_s32(acc, 2);
if (padding == 1) {
sum0 += vgetq_lane_s32(acc, 1);
}
output_ptr0[w] = sum0;
}
}
input_ptr0 += valid_input_w_start;
input_ptr1 += valid_input_w_start;
input_ptr2 += valid_input_w_start;
output_ptr0 += valid_w_start;
}
// valid
      int loop = output_w_tiles;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
"vmovl.s8 q15, d1 \n"
"vdup.s16 d0, d28[0] \n"
"vdup.s16 d1, d28[1] \n"
"vdup.s16 d2, d28[2] \n"
"vdup.s16 d3, d28[3] \n"
"vdup.s16 d4, d29[0] \n"
"vdup.s16 d5, d29[1] \n"
"vdup.s16 d6, d29[2] \n"
"vdup.s16 d7, d29[3] \n"
"vdup.s16 d8, d30[0] \n"
:
: [filter_ptr] "r"(filter_ptr)
: "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
      asm volatile(
          "cmp %[loop], #0 \n"
          "ble start_remain_%= \n"
@@ -1449,34 +1505,34 @@ void DepthwiseConv3x3S2<int8_t, int32_t>(const framework::Tensor &input,
"vmovl.s8 q10, d9 \n" "vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d10 \n" "vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n" "vmovl.s8 q9, d11 \n"
"vmull.s16 q11, d16, d0 \n" "vmull.s16 q11, d16, %e[ker0][0] \n"
"vmlal.s16 q11, d18, d1 \n" "vmlal.s16 q11, d18, %e[ker0][1] \n"
"vmlal.s16 q11, d20, d2 \n" "vmlal.s16 q11, d20, %e[ker0][2] \n"
"vmull.s16 q12, d17, d0 \n" "vmull.s16 q12, d17, %e[ker0][0] \n"
"vmlal.s16 q12, d19, d1 \n" "vmlal.s16 q12, d19, %e[ker0][1] \n"
"vmlal.s16 q12, d21, d2 \n" "vmlal.s16 q12, d21, %e[ker0][2] \n"
"vext.s8 d9, d12, d12, #1 \n" "vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q10, d9 \n" "vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmlal.s16 q11, d16, d3 \n" "vmlal.s16 q11, d16, %f[ker0][0] \n"
"vmlal.s16 q11, d18, d4 \n" "vmlal.s16 q11, d18, %f[ker0][1] \n"
"vmlal.s16 q11, d20, d5 \n" "vmlal.s16 q11, d20, %f[ker0][2] \n"
"vmlal.s16 q12, d17, d3 \n" "vmlal.s16 q12, d17, %f[ker0][0] \n"
"vmlal.s16 q12, d19, d4 \n" "vmlal.s16 q12, d19, %f[ker0][1] \n"
"vmlal.s16 q12, d21, d5 \n" "vmlal.s16 q12, d21, %f[ker0][2] \n"
"vext.s8 d9, d14, d14, #1 \n" "vext.s8 d9, d14, d14, #1 \n"
"vmovl.s8 q10, d9 \n" "vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d14 \n" "vmovl.s8 q8, d14 \n"
"vmovl.s8 q9, d15 \n" "vmovl.s8 q9, d15 \n"
"vmlal.s16 q11, d16, d6 \n" "vmlal.s16 q11, d16, %e[ker1][0] \n"
"vmlal.s16 q11, d18, d7 \n" "vmlal.s16 q11, d18, %e[ker1][1] \n"
"vmlal.s16 q11, d20, d8 \n" "vmlal.s16 q11, d20, %e[ker1][2] \n"
"vmlal.s16 q12, d17, d6 \n" "vmlal.s16 q12, d17, %e[ker1][0] \n"
"vmlal.s16 q12, d19, d7 \n" "vmlal.s16 q12, d19, %e[ker1][1] \n"
"vmlal.s16 q12, d21, d8 \n" "vmlal.s16 q12, d21, %e[ker1][2] \n"
// store row 0 // store row 0
"vst1.32 {d22-d24}, [%[output_ptr0]]! \n" "vst1.32 {d22-d24}, [%[output_ptr0]]! \n"
@@ -1486,41 +1542,43 @@ void DepthwiseConv3x3S2<int8_t, int32_t>(const framework::Tensor &input,
"start_remain_%=: \n" "start_remain_%=: \n"
"cmp %[remain], #0 \n" "cmp %[remain], #0 \n"
"ble end_%= \n" "ble end_%= \n"
"vld2.8 {d10, d11}, [%[input_ptr0]] \n" "mov r0, %[remain], lsl #1 \n"
"vld2.8 {d12, d13}, [%[input_ptr1]] \n"
"vld2.8 {d14, d15}, [%[input_ptr2]] \n" "vld2.8 {d10, d11}, [%[input_ptr0]], r0 \n"
"vld2.8 {d12, d13}, [%[input_ptr1]], r0 \n"
"vld2.8 {d14, d15}, [%[input_ptr2]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n" "vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q10, d9 \n" "vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d10 \n" "vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n" "vmovl.s8 q9, d11 \n"
"vmull.s16 q11, d16, d0 \n" "vmull.s16 q11, d16, %e[ker0][0] \n"
"vmlal.s16 q11, d18, d1 \n" "vmlal.s16 q11, d18, %e[ker0][1] \n"
"vmlal.s16 q11, d20, d2 \n" "vmlal.s16 q11, d20, %e[ker0][2] \n"
"vmull.s16 q12, d17, d0 \n" "vmull.s16 q12, d17, %e[ker0][0] \n"
"vmlal.s16 q12, d19, d1 \n" "vmlal.s16 q12, d19, %e[ker0][1] \n"
"vmlal.s16 q12, d21, d2 \n" "vmlal.s16 q12, d21, %e[ker0][2] \n"
"vext.s8 d9, d12, d12, #1 \n" "vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q10, d9 \n" "vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmlal.s16 q11, d16, d3 \n" "vmlal.s16 q11, d16, %f[ker0][0] \n"
"vmlal.s16 q11, d18, d4 \n" "vmlal.s16 q11, d18, %f[ker0][1] \n"
"vmlal.s16 q11, d20, d5 \n" "vmlal.s16 q11, d20, %f[ker0][2] \n"
"vmlal.s16 q12, d17, d3 \n" "vmlal.s16 q12, d17, %f[ker0][0] \n"
"vmlal.s16 q12, d19, d4 \n" "vmlal.s16 q12, d19, %f[ker0][1] \n"
"vmlal.s16 q12, d21, d5 \n" "vmlal.s16 q12, d21, %f[ker0][2] \n"
"vext.s8 d9, d14, d14, #1 \n" "vext.s8 d9, d14, d14, #1 \n"
"vmovl.s8 q10, d9 \n" "vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d14 \n" "vmovl.s8 q8, d14 \n"
"vmovl.s8 q9, d15 \n" "vmovl.s8 q9, d15 \n"
"vmlal.s16 q11, d16, d6 \n" "vmlal.s16 q11, d16, %e[ker1][0] \n"
"vmlal.s16 q11, d18, d7 \n" "vmlal.s16 q11, d18, %e[ker1][1] \n"
"vmlal.s16 q11, d20, d8 \n" "vmlal.s16 q11, d20, %e[ker1][2] \n"
"vmlal.s16 q12, d17, d6 \n" "vmlal.s16 q12, d17, %e[ker1][0] \n"
"vmlal.s16 q12, d19, d7 \n" "vmlal.s16 q12, d19, %e[ker1][1] \n"
"vmlal.s16 q12, d21, d8 \n" "vmlal.s16 q12, d21, %e[ker1][2] \n"
"cmp %[remain], #4 \n" "cmp %[remain], #4 \n"
"blt store_1h2w_%= \n" "blt store_1h2w_%= \n"
@@ -1547,9 +1605,38 @@ void DepthwiseConv3x3S2<int8_t, int32_t>(const framework::Tensor &input,
: [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0), : [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0),
[input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2), [input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2),
[loop] "+r"(loop) [loop] "+r"(loop)
: [remain] "r"(output_w_remain) : [remain] "r"(output_w_remain), [ker0] "w"(_ker0), [ker1] "w"(_ker1)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
"q8", "q9", "q10", "q11", "q12", "r0"); "q12", "q13", "q14", "q15", "r0");
// pad right
if (padding_w > 0) {
int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0)));
int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1)));
int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2)));
int32x4_t acc;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = 2 * w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0;
} else {
acc = vmull_s16(row0, _ker[0]);
acc = vmlal_s16(acc, row1, _ker[1]);
acc = vmlal_s16(acc, row2, _ker[2]);
int32_t sum0 = vgetq_lane_s32(acc, 0);
if (padding == 1) {
sum0 += vgetq_lane_s32(acc, 1);
}
*output_ptr0 = sum0;
}
output_ptr0++;
}
}
}
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr, _ker);
    }
  }
}
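For readers of this patch, the per-channel result both stride variants compute can be summarized by the following scalar sketch; the function and parameter names are illustrative and not part of the patch, and stride is 1 for DepthwiseConv3x3S1 and 2 for DepthwiseConv3x3S2.

#include <cstdint>

// Illustrative scalar reference for a zero-padded 3x3 depthwise convolution on
// one channel: int8 input and filter, int32 accumulation.
inline void DepthwiseConv3x3Reference(const int8_t *input, const int8_t *filter,
                                      int32_t *output, int input_h, int input_w,
                                      int padding_h, int padding_w, int stride,
                                      int output_h, int output_w) {
  for (int oh = 0; oh < output_h; ++oh) {
    for (int ow = 0; ow < output_w; ++ow) {
      int32_t acc = 0;
      for (int kh = 0; kh < 3; ++kh) {
        for (int kw = 0; kw < 3; ++kw) {
          const int ih = oh * stride - padding_h + kh;
          const int iw = ow * stride - padding_w + kw;
          // Taps that fall outside the input contribute zero (zero padding).
          if (ih >= 0 && ih < input_h && iw >= 0 && iw < input_w) {
            acc += static_cast<int32_t>(input[ih * input_w + iw]) *
                   static_cast<int32_t>(filter[kh * 3 + kw]);
          }
        }
      }
      output[oh * output_w + ow] = acc;
    }
  }
}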