未验证 提交 d92d1c2f 编写于 作者: S smilejames 提交者: GitHub

Merge pull request #737 from yangfei963158659/develop

implement multithreading 3x3 s2 depth_conv
...@@ -28,13 +28,13 @@ Paddle-Moible是PaddlePaddle组织下的项目,是一个致力于嵌入式平 ...@@ -28,13 +28,13 @@ Paddle-Moible是PaddlePaddle组织下的项目,是一个致力于嵌入式平
|mobilenet arm v7|1线程|2线程|4线程| |mobilenet arm v7|1线程|2线程|4线程|
|------------|----|-----|-----| |------------|----|-----|-----|
|麒麟960(ms)|110.586|72.474|49.833| |麒麟960(ms)|110.586|70.897|47.474|
||||| |||||
|mobilenetssd arm v7|1线程|2线程|4线程| |mobilenetssd arm v7|1线程|2线程|4线程|
|麒麟960(ms)|224.464|142.544|96.068| |麒麟960(ms)|222.124|138.952|90.856|
||||| |||||
|googlenet(v1) arm v7|1线程|2线程|4线程| |googlenet(v1) arm v7|1线程|2线程|4线程|
|麒麟960(ms)|348.018|242.689|169.998| |麒麟960(ms)|348.018|240.304|169.998|
arm cpu是paddle-mobile的主要支持方向,cpu的通用性一直是其优势。嵌入式深度学习,需要大量的cpu汇编实现。我们正在紧锣密鼓的编码,为的是能充分硬件的每一点加速能力。 arm cpu是paddle-mobile的主要支持方向,cpu的通用性一直是其优势。嵌入式深度学习,需要大量的cpu汇编实现。我们正在紧锣密鼓的编码,为的是能充分硬件的每一点加速能力。
arm cpu的优化工作还在进行中,现在使用了常规的cpu优化。在arm a73上paddle-mobile arm-v7现在单核运行一次mobilenet1.0是110+ms,显然这不是我们的最终目标,我们正在用大量的汇编改写,后续性能仍会有巨大提升空间, 目前只支持armv7, 未来我们也会支持armv8。 arm cpu的优化工作还在进行中,现在使用了常规的cpu优化。在arm a73上paddle-mobile arm-v7现在单核运行一次mobilenet1.0是110+ms,显然这不是我们的最终目标,我们正在用大量的汇编改写,后续性能仍会有巨大提升空间, 目前只支持armv7, 未来我们也会支持armv8。
......
...@@ -613,7 +613,6 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -613,7 +613,6 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
} }
int m; int m;
for (m = 1; m < output_width - 4; m += 4) { for (m = 1; m < output_width - 4; m += 4) {
float *output_ptr = output_data + m; float *output_ptr = output_data + m;
float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0; float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
...@@ -637,7 +636,8 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -637,7 +636,8 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
} }
vst1q_f32(output_ptr, out0); vst1q_f32(output_ptr, out0);
} }
for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
for (m = 1; (m + 3) < output_width - 1; m += 4) {
} }
for (int j = m; j < output_width - 1; j++) { for (int j = m; j < output_width - 1; j++) {
output_data[j] = input_data[j - 1] * w10 + input_data[j] * w11 + output_data[j] = input_data[j - 1] * w10 + input_data[j] * w11 +
...@@ -652,7 +652,7 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -652,7 +652,7 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
} }
} }
for (m = 1; (m + 3) < output_width - 1; m = m + 4) { for (m = 1; m < output_width - 4; m += 4) {
float *output_ptr = float *output_ptr =
output_data + (output_height - 1) * output_width + m; output_data + (output_height - 1) * output_width + m;
...@@ -769,305 +769,295 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -769,305 +769,295 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
} }
/* /*
const float *input_data = input->data<float>(); const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>(); const float *filter_data = filter->data<float>();
float *output_data = output->data<float>(); float *output_data = output->data<float>();
const float *newscale_data = new_scale->data<float>(); const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>(); const float *newbias_data = new_bias->data<float>();
const int h = static_cast<int>(input->dims()[2]); const int h = static_cast<int>(input->dims()[2]);
const int w = static_cast<int>(input->dims()[3]); const int w = static_cast<int>(input->dims()[3]);
const int l = h; const int l = h;
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]); const int c = static_cast<int>(input->dims()[1]);
const int hxw = h * w; const int hxw = h * w;
float32x4_t vnewbias = vdupq_n_f32(0.0); float32x4_t vnewbias = vdupq_n_f32(0.0);
float32x4_t vnewscale = vdupq_n_f32(1.0); float32x4_t vnewscale = vdupq_n_f32(1.0);
float32x4_t vzero = vdupq_n_f32(0); float32x4_t vzero = vdupq_n_f32(0);
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
const float *filter_data_tmp = filter_data; const float *filter_data_tmp = filter_data;
for (int j = 0; j < c; ++j) { for (int j = 0; j < c; ++j) {
vnewbias = vdupq_n_f32(newbias_data[j]); vnewbias = vdupq_n_f32(newbias_data[j]);
vnewscale = vdupq_n_f32(newscale_data[j]); vnewscale = vdupq_n_f32(newscale_data[j]);
int l_mid = l - 2; // l=1->l_mid=-1,l=2->l_mid=0 int l_mid = l - 2; // l=1->l_mid=-1,l=2->l_mid=0
float w00 = filter_data_tmp[0]; float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1]; float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2]; float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3]; float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4]; float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5]; float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6]; float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7]; float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8]; float w22 = filter_data_tmp[8];
output_data[0] = w11 * input_data[0] + w12 * input_data[1] + output_data[0] = w11 * input_data[0] + w12 * input_data[1] +
w21 * input_data[l] + w22 * input_data[l + 1]; w21 * input_data[l] + w22 * input_data[l + 1];
output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l - 1] + output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l -
w20 * input_data[2 * l - 2] + 1] + w20 * input_data[2 * l - 2] + w21 * input_data[2 * l - 1];
w21 * input_data[2 * l - 1];
output_data[(l - 1) * l] =
output_data[(l - 1) * l] = w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l +
w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] + 1] + w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1];
w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1]; output_data[l * l - 1] = w00 * input_data[(l - 2) * (l + 1)] +
output_data[l * l - 1] = w00 * input_data[(l - 2) * (l + 1)] + w01 * input_data[(l - 2) * (l + 1) + 1] +
w01 * input_data[(l - 2) * (l + 1) + 1] + w10 * input_data[l * l - 2] +
w10 * input_data[l * l - 2] + w11 * input_data[l * l - 1];
w11 * input_data[l * l - 1]; output_data[0] = output_data[0] * newscale_data[j] +
output_data[0] = output_data[0] * newscale_data[j] + newbias_data[j]; newbias_data[j]; output_data[l - 1] = output_data[l - 1] *
output_data[l - 1] = newscale_data[j] + newbias_data[j]; output_data[(l - 1) * l] =
output_data[l - 1] * newscale_data[j] + newbias_data[j]; output_data[(l - 1) * l] * newscale_data[j] + newbias_data[j];
output_data[(l - 1) * l] = output_data[l * l - 1] =
output_data[(l - 1) * l] * newscale_data[j] + newbias_data[j]; output_data[l * l - 1] * newscale_data[j] + newbias_data[j];
output_data[l * l - 1] =
output_data[l * l - 1] * newscale_data[j] + newbias_data[j];
if (if_relu) {
output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
output_data[l - 1] = output_data[l - 1] < 0 ? 0 : output_data[l - 1];
output_data[(l - 1) * l] =
output_data[(l - 1) * l] < 0 ? 0 : output_data[(l - 1) * l];
output_data[l * l - 1] =
output_data[l * l - 1] < 0 ? 0 : output_data[l * l - 1];
}
for (int i = 1; i < l - 1; ++i) {
output_data[i * l] =
w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] +
w11 * input_data[i * l] + w12 * input_data[i * l + 1] +
w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1];
output_data[i * l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] +
w01 * input_data[i * l + l - 1 - l] +
w10 * input_data[i * l + l - 1 - 1] +
w11 * input_data[i * l + l - 1] +
w20 * input_data[i * l + l - 1 + l - 1] +
w21 * input_data[i * l + l - 1 + l];
output_data[i * l] =
output_data[i * l] * newscale_data[j] + newbias_data[j];
output_data[i * l + l - 1] =
output_data[i * l + l - 1] * newscale_data[j] + newbias_data[j];
if (if_relu) {
output_data[i * l] = output_data[i * l] < 0 ? 0 : output_data[i *
l]; output_data[i * l + l - 1] =
output_data[i * l + l - 1] < 0 ? 0 : output_data[i * l + l - 1];
}
}
// top 1 row and bottom 1 row
const float *input_tmp = input_data;
float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
tmp3, tmp4, tmp5, out0;
in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + l);
const float *input_tmp_end = input_tmp + (l - 2) * l;
in4 = vld1q_f32(input_tmp_end);
in6 = vld1q_f32(input_tmp_end + l);
int c_mid = l_mid;
auto output_ptr = output_data + 1;
for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + l + 4);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
in5 = vld1q_f32(input_tmp_end + 4);
in7 = vld1q_f32(input_tmp_end + l + 4);
tmp0 = vextq_f32(in4, in5, 1);
tmp1 = vextq_f32(in4, in5, 2);
tmp2 = vextq_f32(in6, in7, 1);
tmp3 = vextq_f32(in6, in7, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr + (l - 1) * l, out0);
// can optimize to each 8 stride.
input_tmp += 4;
input_tmp_end += 4;
output_ptr += 4;
in0 = in1;
in2 = in3;
in4 = in5;
in6 = in7;
}
// top right pad if (if_relu) {
float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]); output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]); output_data[l - 1] = output_data[l - 1] < 0 ? 0 : output_data[l -
1]; output_data[(l - 1) * l] = output_data[(l - 1) * l] < 0 ? 0 :
output_data[(l - 1) * l]; output_data[l * l - 1] = output_data[l * l - 1]
< 0 ? 0 : output_data[l * l - 1];
}
for (int i = 1; i < l - 1; ++i) {
output_data[i * l] =
w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1]
+ w11 * input_data[i * l] + w12 * input_data[i * l + 1] + w21 *
input_data[i * l + l] + w22 * input_data[i * l + l + 1]; output_data[i *
l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] + w01 * input_data[i
* l + l - 1 - l] + w10 * input_data[i * l + l - 1 - 1] + w11 *
input_data[i * l + l - 1] + w20 * input_data[i * l + l - 1 + l - 1] + w21
* input_data[i * l + l - 1 + l]; output_data[i * l] = output_data[i * l]
* newscale_data[j] + newbias_data[j]; output_data[i * l + l - 1] =
output_data[i * l + l - 1] * newscale_data[j] +
newbias_data[j];
if (if_relu) {
output_data[i * l] = output_data[i * l] < 0 ? 0 : output_data[i
* l]; output_data[i * l + l - 1] = output_data[i * l + l - 1] < 0 ? 0 :
output_data[i * l + l - 1];
}
}
tmp0 = vextq_f32(in0, pad0, 1); // top 1 row and bottom 1 row
tmp1 = vextq_f32(in0, pad0, 2); const float *input_tmp = input_data;
tmp2 = vextq_f32(in2, pad1, 1);
tmp3 = vextq_f32(in2, pad1, 2); float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1,
tmp2, tmp3, tmp4, tmp5, out0; in0 = vld1q_f32(input_tmp); in2 =
vld1q_f32(input_tmp + l); const float *input_tmp_end = input_tmp + (l -
2) * l; in4 = vld1q_f32(input_tmp_end); in6 = vld1q_f32(input_tmp_end +
l); int c_mid = l_mid; auto output_ptr = output_data + 1; for (; c_mid >
3; c_mid -= 4) { in1 = vld1q_f32(input_tmp + 4); in3 =
vld1q_f32(input_tmp + l + 4);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
in5 = vld1q_f32(input_tmp_end + 4);
in7 = vld1q_f32(input_tmp_end + l + 4);
tmp0 = vextq_f32(in4, in5, 1);
tmp1 = vextq_f32(in4, in5, 2);
tmp2 = vextq_f32(in6, in7, 1);
tmp3 = vextq_f32(in6, in7, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr + (l - 1) * l, out0);
// can optimize to each 8 stride.
input_tmp += 4;
input_tmp_end += 4;
output_ptr += 4;
in0 = in1;
in2 = in3;
in4 = in5;
in6 = in7;
}
out0 = vmulq_n_f32(in0, w10); // top right pad
out0 = vmlaq_n_f32(out0, tmp0, w11); float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]);
out0 = vmlaq_n_f32(out0, tmp1, w12); float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21); tmp0 = vextq_f32(in0, pad0, 1);
out0 = vmlaq_n_f32(out0, tmp3, w22); tmp1 = vextq_f32(in0, pad0, 2);
out0 = vmlaq_f32(vnewbias, vnewscale, out0); tmp2 = vextq_f32(in2, pad1, 1);
if (if_relu) { tmp3 = vextq_f32(in2, pad1, 2);
out0 = vmaxq_f32(out0, vzero);
} out0 = vmulq_n_f32(in0, w10);
for (int i = 0; i < c_mid; ++i) { out0 = vmlaq_n_f32(out0, tmp0, w11);
if (i == 0) { out0 = vmlaq_n_f32(out0, tmp1, w12);
vst1q_lane_f32(output_ptr + i, out0, 0); out0 = vmlaq_n_f32(out0, in2, w20);
} out0 = vmlaq_n_f32(out0, tmp2, w21);
if (i == 1) { out0 = vmlaq_n_f32(out0, tmp3, w22);
vst1q_lane_f32(output_ptr + i, out0, 1); out0 = vmlaq_f32(vnewbias, vnewscale, out0);
} if (if_relu) {
if (i == 2) { out0 = vmaxq_f32(out0, vzero);
vst1q_lane_f32(output_ptr + i, out0, 2); }
} for (int i = 0; i < c_mid; ++i) {
} if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
// bottom right pad // bottom right pad
float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]); float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]);
float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]); float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]);
tmp0 = vextq_f32(in4, pad2, 1); tmp0 = vextq_f32(in4, pad2, 1);
tmp1 = vextq_f32(in4, pad2, 2); tmp1 = vextq_f32(in4, pad2, 2);
tmp2 = vextq_f32(in6, pad3, 1); tmp2 = vextq_f32(in6, pad3, 1);
tmp3 = vextq_f32(in6, pad3, 2); tmp3 = vextq_f32(in6, pad3, 2);
out0 = vmulq_n_f32(in4, w00); out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2);
}
}
// mid
for (int i = 0; i < l - 2; ++i) {
auto output_ptr = output_data + (i + 1) * l + 1;
input_tmp = input_data + i * l;
auto in0_tmp = vld1q_f32(input_tmp);
auto in2_tmp = vld1q_f32(input_tmp + l);
auto in4_tmp = vld1q_f32(input_tmp + l + l);
c_mid = l_mid;
for (; c_mid > 3; c_mid -= 4) {
auto in1_tmp = vld1q_f32(input_tmp + 4);
auto in3_tmp = vld1q_f32(input_tmp + l + 4);
auto in5_tmp = vld1q_f32(input_tmp + l + l + 4);
tmp0 = vextq_f32(in0_tmp, in1_tmp, 1);
tmp1 = vextq_f32(in0_tmp, in1_tmp, 2);
tmp2 = vextq_f32(in2_tmp, in3_tmp, 1);
tmp3 = vextq_f32(in2_tmp, in3_tmp, 2);
tmp4 = vextq_f32(in4_tmp, in5_tmp, 1);
tmp5 = vextq_f32(in4_tmp, in5_tmp, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01); out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02); out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10); out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11); out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12); out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0); out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) { if (if_relu) {
out0 = vmaxq_f32(out0, vzero); out0 = vmaxq_f32(out0, vzero);
} }
vst1q_f32(output_ptr, out0); for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
output_ptr += 4; vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0);
input_tmp += 4; }
in0_tmp = in1_tmp; if (i == 1) {
in2_tmp = in3_tmp; vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1);
in4_tmp = in5_tmp; }
} if (i == 2) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2);
float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]); }
float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]); }
float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]); // mid
for (int i = 0; i < l - 2; ++i) {
auto output_ptr = output_data + (i + 1) * l + 1;
input_tmp = input_data + i * l;
auto in0_tmp = vld1q_f32(input_tmp);
auto in2_tmp = vld1q_f32(input_tmp + l);
auto in4_tmp = vld1q_f32(input_tmp + l + l);
c_mid = l_mid;
for (; c_mid > 3; c_mid -= 4) {
auto in1_tmp = vld1q_f32(input_tmp + 4);
auto in3_tmp = vld1q_f32(input_tmp + l + 4);
auto in5_tmp = vld1q_f32(input_tmp + l + l + 4);
tmp0 = vextq_f32(in0_tmp, in1_tmp, 1);
tmp1 = vextq_f32(in0_tmp, in1_tmp, 2);
tmp2 = vextq_f32(in2_tmp, in3_tmp, 1);
tmp3 = vextq_f32(in2_tmp, in3_tmp, 2);
tmp4 = vextq_f32(in4_tmp, in5_tmp, 1);
tmp5 = vextq_f32(in4_tmp, in5_tmp, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
tmp0 = vextq_f32(in0_tmp, pad0, 1); output_ptr += 4;
tmp1 = vextq_f32(in0_tmp, pad0, 2); input_tmp += 4;
tmp2 = vextq_f32(in2_tmp, pad1, 1); in0_tmp = in1_tmp;
tmp3 = vextq_f32(in2_tmp, pad1, 2); in2_tmp = in3_tmp;
tmp4 = vextq_f32(in4_tmp, pad2, 1); in4_tmp = in5_tmp;
tmp5 = vextq_f32(in4_tmp, pad2, 2); }
out0 = vmulq_n_f32(in0_tmp, w00); float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]);
out0 = vmlaq_n_f32(out0, tmp0, w01); float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]);
out0 = vmlaq_n_f32(out0, tmp1, w02); float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11); tmp0 = vextq_f32(in0_tmp, pad0, 1);
out0 = vmlaq_n_f32(out0, tmp3, w12); tmp1 = vextq_f32(in0_tmp, pad0, 2);
out0 = vmlaq_n_f32(out0, in4_tmp, w20); tmp2 = vextq_f32(in2_tmp, pad1, 1);
out0 = vmlaq_n_f32(out0, tmp4, w21); tmp3 = vextq_f32(in2_tmp, pad1, 2);
out0 = vmlaq_n_f32(out0, tmp5, w22); tmp4 = vextq_f32(in4_tmp, pad2, 1);
out0 = vmlaq_f32(vnewbias, vnewscale, out0); tmp5 = vextq_f32(in4_tmp, pad2, 2);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero); out0 = vmulq_n_f32(in0_tmp, w00);
} out0 = vmlaq_n_f32(out0, tmp0, w01);
for (int i = 0; i < c_mid; ++i) { out0 = vmlaq_n_f32(out0, tmp1, w02);
if (i == 0) { out0 = vmlaq_n_f32(out0, in2_tmp, w10);
vst1q_lane_f32(output_ptr + i, out0, 0); out0 = vmlaq_n_f32(out0, tmp2, w11);
} out0 = vmlaq_n_f32(out0, tmp3, w12);
if (i == 1) { out0 = vmlaq_n_f32(out0, in4_tmp, w20);
vst1q_lane_f32(output_ptr + i, out0, 1); out0 = vmlaq_n_f32(out0, tmp4, w21);
} out0 = vmlaq_n_f32(out0, tmp5, w22);
if (i == 2) { out0 = vmlaq_f32(vnewbias, vnewscale, out0);
vst1q_lane_f32(output_ptr + i, out0, 2); if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
} }
output_data += hxw;
input_data += hxw;
filter_data_tmp += 9;
} }
} }
output_data += hxw; */
input_data += hxw;
filter_data_tmp += 9;
}
}
*/
#endif #endif
} }
...@@ -1482,230 +1472,421 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, ...@@ -1482,230 +1472,421 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
const float *newscale_data = new_scale->data<float>(); const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>(); const float *newbias_data = new_bias->data<float>();
float32x4_t vnewbias = vdupq_n_f32(0.0);
float32x4_t vnewscale = vdupq_n_f32(1.0);
const int in_h = static_cast<int>(input->dims()[2]);
const int in_w = static_cast<int>(input->dims()[3]);
const int out_h = static_cast<int>(output->dims()[2]);
const int out_w = static_cast<int>(output->dims()[3]);
const int out_l = out_h;
const int in_l = in_h;
const int inhxw = in_h * in_w;
const int outhxw = out_h * out_w;
const int if_pad = in_l - 1 == (out_l - 1) * 2 ? 1 : 0;
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]); const int input_channel = static_cast<int>(input->dims()[1]);
const float *input_row_ptr;
float *output_row_ptr;
const int w_times = (out_w - 2) / 3; const int input_height = static_cast<int>(input->dims()[2]);
const int input_width = static_cast<int>(input->dims()[3]);
const int output_height = static_cast<int>(output->dims()[2]);
const int output_width = static_cast<int>(output->dims()[3]);
const int inhxw = input_height * input_width;
const int outhxw = output_height * output_width;
float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1]; float32x4_t vnewbias = vdupq_n_f32(0.0);
float32x4_t elewise_res0, elewise_res1, elewise_res2, res3; float32x4_t vnewscale = vdupq_n_f32(1.0);
int out2in_mid;
float32x4_t zero = vdupq_n_f32(0.0); float32x4_t zero = vdupq_n_f32(0.0);
for (int b = batch_size; b > 0; --b) { for (int b = 0; b < batch_size; b++) {
const float *filter_data_tmp = filter_data; filter_data = filter->data<float>();
for (int j = 0; j < c; ++j) { for (int c = 0; c < input_channel; c++) {
auto output_data_tmp = output_data + j * out_h * out_w; vnewbias = vdupq_n_f32(newbias_data[c]);
auto input_data_tmp = input_data + j * in_h * in_w; vnewscale = vdupq_n_f32(newscale_data[c]);
auto input_const = input_data_tmp;
vnewbias = vdupq_n_f32(newbias_data[j]);
vnewscale = vdupq_n_f32(newscale_data[j]);
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
int h_mid = 0;
for (; h_mid < out_h - 1; h_mid++) {
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
for (int w4 = 0; w4 < w_times + 1; w4++) {
if (h_mid == 0) {
elewise_res1 = zero;
elewise_res0 = zero;
elewise_res2 = zero;
} else {
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01);
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00);
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02);
}
input_buff_mid = vld2q_f32(input_row_ptr);
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); float w00 = filter_data[0];
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); float w01 = filter_data[1];
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); float w02 = filter_data[2];
float w10 = filter_data[3];
float w11 = filter_data[4];
float w12 = filter_data[5];
float w20 = filter_data[6];
float w21 = filter_data[7];
float w22 = filter_data[8];
elewise_res1 = int m;
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); for (m = 1; m < output_width - 2; m = m + 3) {
elewise_res0 = float *output_ptr = output_data + m;
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); float32x4x2_t input_buff_mid{}, input_buff_bottom{};
elewise_res2 = float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); input_buff_mid = vld2q_f32(input_data + (2 * m - 1));
input_buff_bottom = vld2q_f32(input_data + input_width + (2 * m - 1));
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), in0 = input_buff_mid.val[0];
vaddq_f32(elewise_res0, elewise_res1)); tmp0 = input_buff_mid.val[1];
res3 = vmlaq_f32(vnewbias, vnewscale, res3); tmp1 = vextq_f32(in0, zero, 1);
if (if_relu) { in2 = input_buff_bottom.val[0];
res3 = vmaxq_f32(res3, zero); tmp2 = input_buff_bottom.val[1];
} tmp3 = vextq_f32(in2, zero, 1);
vst1q_f32(output_row_ptr, res3);
input_row_ptr += 6; out0 = vmulq_n_f32(in0, w10);
output_row_ptr += 3; out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
vst1q_f32(output_ptr, out0);
}
for (m = 1; m < output_width - 2; m += 3) {
}
for (int j = m; j < output_width; j++) {
output_data[j] = input_data[2 * j - 1] * w10 + input_data[2 * j] * w11 +
input_data[2 * j + 1] * w12 +
input_data[2 * j - 1 + input_width] * w20 +
input_data[2 * j + input_width] * w21 +
input_data[2 * j + 1 + input_width] * w22;
output_data[j] = newscale_data[c] * output_data[j] + newbias_data[c];
if (if_relu) {
output_data[j] = output_data[j] < 0 ? 0 : output_data[j];
} }
} }
clock();
input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; #pragma omp parallel for
output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
for (int w4 = 0; w4 < w_times + 1; w4++) { for (int i = 1; i < output_height; i += 1) {
elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); for (int m = 1; m < output_width - 2; m += 3) {
elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); float *output_ptr = output_data + i * output_width + m;
elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); float32x4x2_t input_buff_top{}, input_buff_mid{}, input_buff_bottom{};
float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3,
tmp4, tmp5, out0;
input_buff_top =
vld2q_f32(input_data + (2 * i - 1) * input_width + (2 * m - 1));
input_buff_mid =
vld2q_f32(input_data + (2 * i) * input_width + (2 * m - 1));
input_buff_bottom =
vld2q_f32(input_data + (2 * i + 1) * input_width + (2 * m - 1));
input_buff_mid = vld2q_f32(input_row_ptr); in0 = input_buff_top.val[0];
input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); tmp0 = input_buff_top.val[1];
tmp1 = vextq_f32(in0, zero, 1);
elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); in2 = input_buff_mid.val[0];
elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); tmp2 = input_buff_mid.val[1];
elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); tmp3 = vextq_f32(in2, zero, 1);
if (!if_pad) { in4 = input_buff_bottom.val[0];
elewise_res1 = tmp4 = input_buff_bottom.val[1];
vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); tmp5 = vextq_f32(in4, zero, 1);
elewise_res0 =
vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20);
elewise_res2 =
vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22);
}
res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1),
vaddq_f32(elewise_res0, elewise_res1));
res3 = vmlaq_f32(vnewbias, vnewscale, res3);
if (if_relu) { out0 = vmulq_n_f32(in0, w00);
res3 = vmaxq_f32(res3, zero); out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, zero);
}
vst1q_f32(output_ptr, out0);
} }
if ((w4 != w_times)) { int m;
vst1q_f32(output_row_ptr, res3); for (m = 1; m < output_width - 2; m += 3) {
} else { }
if (out_l - 2 - w_times * 3 == 1) { for (int j = m; j < output_width; j++) {
vst1q_lane_f32(output_row_ptr, res3, 0); output_data[i * output_width + j] =
} else if (out_l - 2 - w_times * 3 == 2) { input_data[(2 * i - 1) * input_width + 2 * j - 1] * w00 +
vst1q_lane_f32(output_row_ptr, res3, 0); input_data[(2 * i - 1) * input_width + 2 * j] * w01 +
vst1q_lane_f32(output_row_ptr + 1, res3, 1); input_data[(2 * i - 1) * input_width + 2 * j + 1] * w02 +
input_data[(2 * i) * input_width + 2 * j - 1] * w10 +
input_data[(2 * i) * input_width + 2 * j] * w11 +
input_data[(2 * i) * input_width + 2 * j + 1] * w12 +
input_data[(2 * i + 1) * input_width + 2 * j - 1] * w20 +
input_data[(2 * i + 1) * input_width + 2 * j] * w21 +
input_data[(2 * i + 1) * input_width + 2 * j + 1] * w22;
output_data[i * output_width + j] =
newscale_data[c] * output_data[i * output_width + j] +
newbias_data[c];
if (if_relu) {
output_data[i * output_width + j] =
output_data[i * output_width + j] < 0
? 0
: output_data[i * output_width + j];
} }
} }
input_row_ptr += 6;
output_row_ptr += 3;
} }
output_data[0] = input_data[0] * w11 + input_data[1] * w12 +
input_data[input_height] * w21 +
input_data[input_height + 1] * w22;
output_data_tmp[0] = input_const[0] * w11 + input_const[1] * w12 + output_data[0] = newscale_data[c] * output_data[0] + newbias_data[c];
input_const[in_l] * w21 +
input_const[in_l + 1] * w22;
out2in_mid = (out_l - 1) * 2;
output_data_tmp[out_l - 1] =
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
(1 - if_pad) * (w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
out2in_mid = (out_l - 1) * 2 * in_w;
output_data_tmp[out_l * (out_l - 1)] =
w01 * input_const[out2in_mid - in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] +
(1 - if_pad) * (w21 * input_const[out2in_mid + in_w] +
w22 * input_const[out2in_mid + in_w + 1]);
out2in_mid = (out_l - 1) * 2 * in_w + (out_l - 1) * 2;
output_data_tmp[out_l * out_l - 1] =
w00 * input_const[out2in_mid - in_w - 1] +
w01 * input_const[out2in_mid - in_w] +
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
(1 - if_pad) * (w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
w02 * input_const[out2in_mid - in_w + 1] +
w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
output_data_tmp[0] =
output_data_tmp[0] * newscale_data[j] + newbias_data[j];
output_data_tmp[out_l - 1] =
output_data_tmp[out_l - 1] * newscale_data[j] + newbias_data[j];
output_data_tmp[out_l * (out_l - 1)] =
output_data_tmp[out_l * (out_l - 1)] * newscale_data[j] +
newbias_data[j];
output_data_tmp[out_l * out_l - 1] =
output_data_tmp[out_l * out_l - 1] * newscale_data[j] +
newbias_data[j];
if (if_relu) { if (if_relu) {
output_data_tmp[0] = output_data_tmp[0] < 0 ? 0 : output_data_tmp[0]; output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
output_data_tmp[out_l - 1] =
output_data_tmp[out_l - 1] < 0 ? 0 : output_data_tmp[out_l - 1];
output_data_tmp[out_l * (out_l - 1)] =
output_data_tmp[out_l * (out_l - 1)] < 0
? 0
: output_data_tmp[out_l * (out_l - 1)];
output_data_tmp[out_l * out_l - 1] =
output_data_tmp[out_l * out_l - 1] < 0
? 0
: output_data_tmp[out_l * out_l - 1];
} }
for (int i = 1; i < out_h - 1; i++) { for (int i = 1; i < output_height; i++) {
out2in_mid = i * 2 * in_w; output_data[i * output_width] =
output_data_tmp[i * out_l] = w01 * input_const[out2in_mid - in_w] + input_data[(2 * i - 1) * input_width] * w01 +
w02 * input_const[out2in_mid - in_w + 1] + input_data[(2 * i - 1) * input_width + 1] * w02 +
w11 * input_const[out2in_mid] + input_data[(2 * i) * input_width] * w11 +
w12 * input_const[out2in_mid + 1] + input_data[(2 * i) * input_width + 1] * w12 +
w21 * input_const[out2in_mid + in_w] + input_data[(2 * i + 1) * input_width] * w21 +
w22 * input_const[out2in_mid + in_w + 1]; input_data[(2 * i + 1) * input_width + 1] * w22;
out2in_mid = i * 2 * in_w + (out_l - 1) * 2; output_data[i * output_width] =
output_data_tmp[i * out_l + out_l - 1] = newscale_data[c] * output_data[i * output_width] + newbias_data[c];
w00 * input_const[out2in_mid - in_w - 1] +
w01 * input_const[out2in_mid - in_w] +
w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] +
w20 * input_const[out2in_mid + in_w - 1] +
w21 * input_const[out2in_mid + in_w] +
(1 - if_pad) * (w02 * input_const[out2in_mid - in_w + 1] +
w12 * input_const[out2in_mid + 1] +
w22 * input_const[out2in_mid + in_w + 1]);
output_data_tmp[i * out_l] =
output_data_tmp[i * out_l] * newscale_data[j] + newbias_data[j];
output_data_tmp[i * out_l + out_l - 1] =
output_data_tmp[i * out_l + out_l - 1] * newscale_data[j] +
newbias_data[j];
if (if_relu) { if (if_relu) {
output_data_tmp[i * out_l] = output_data[i * output_width] = output_data[i * output_width] < 0
output_data_tmp[i * out_l] < 0 ? 0 : output_data_tmp[i * out_l]; ? 0
output_data_tmp[i * out_l + out_l - 1] = : output_data[i * output_width];
output_data_tmp[i * out_l + out_l - 1] < 0
? 0
: output_data_tmp[i * out_l + out_l - 1];
} }
} }
filter_data_tmp += 9;
input_data = input_data + inhxw;
output_data = output_data + outhxw;
filter_data = filter_data + 9;
} }
input_data += inhxw * c;
output_data += outhxw * c;
} }
// const float *input_data = input->data<float>();
// const float *filter_data = filter->data<float>();
// float *output_data = output->data<float>();
// const float *newscale_data = new_scale->data<float>();
// const float *newbias_data = new_bias->data<float>();
//
// float32x4_t vnewbias = vdupq_n_f32(0.0);
// float32x4_t vnewscale = vdupq_n_f32(1.0);
//
// const int in_h = static_cast<int>(input->dims()[2]);
// const int in_w = static_cast<int>(input->dims()[3]);
// const int out_h = static_cast<int>(output->dims()[2]);
// const int out_w = static_cast<int>(output->dims()[3]);
// const int out_l = out_h;
// const int in_l = in_h;
// const int inhxw = in_h * in_w;
// const int outhxw = out_h * out_w;
// const int if_pad = in_l - 1 == (out_l - 1) * 2 ? 1 : 0;
// const int batch_size = static_cast<int>(input->dims()[0]);
// const int c = static_cast<int>(input->dims()[1]);
// const float *input_row_ptr;
// float *output_row_ptr;
//
// const int w_times = (out_w - 2) / 3;
//
// float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1];
// float32x4_t elewise_res0, elewise_res1, elewise_res2, res3;
// int out2in_mid;
// float32x4_t zero = vdupq_n_f32(0.0);
// for (int b = batch_size; b > 0; --b) {
// const float *filter_data_tmp = filter_data;
// for (int j = 0; j < c; ++j) {
// auto output_data_tmp = output_data + j * out_h * out_w;
// auto input_data_tmp = input_data + j * in_h * in_w;
// auto input_const = input_data_tmp;
//
// vnewbias = vdupq_n_f32(newbias_data[j]);
// vnewscale = vdupq_n_f32(newscale_data[j]);
//
// float w00 = filter_data_tmp[0];
// float w01 = filter_data_tmp[1];
// float w02 = filter_data_tmp[2];
// float w10 = filter_data_tmp[3];
// float w11 = filter_data_tmp[4];
// float w12 = filter_data_tmp[5];
// float w20 = filter_data_tmp[6];
// float w21 = filter_data_tmp[7];
// float w22 = filter_data_tmp[8];
//
// int h_mid = 0;
//
// for (; h_mid < out_h - 1; h_mid++) {
// input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
// output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
//
// for (int w4 = 0; w4 < w_times + 1; w4++) {
// if (h_mid == 0) {
// elewise_res1 = zero;
// elewise_res0 = zero;
// elewise_res2 = zero;
// } else {
// elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01);
// elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00);
// elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02);
// }
// input_buff_mid = vld2q_f32(input_row_ptr);
// input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w);
//
// elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1],
// w11); elewise_res0 = vmlaq_n_f32(elewise_res0,
// input_buff_mid.val[0], w10); elewise_res2 =
// vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12);
//
// elewise_res1 =
// vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1],
// w21);
// elewise_res0 =
// vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0],
// w20);
// elewise_res2 =
// vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0],
// w22);
//
// res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1),
// vaddq_f32(elewise_res0, elewise_res1));
// res3 = vmlaq_f32(vnewbias, vnewscale, res3);
//
// if (if_relu) {
// res3 = vmaxq_f32(res3, zero);
// }
// vst1q_f32(output_row_ptr, res3);
//
// input_row_ptr += 6;
// output_row_ptr += 3;
// }
// }
// clock();
//
// input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
// output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
//
// for (int w4 = 0; w4 < w_times + 1; w4++) {
// elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01);
// elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00);
// elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02);
//
// input_buff_mid = vld2q_f32(input_row_ptr);
// input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w);
//
// elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1],
// w11); elewise_res0 = vmlaq_n_f32(elewise_res0,
// input_buff_mid.val[0], w10); elewise_res2 =
// vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12);
//
// if (!if_pad) {
// elewise_res1 =
// vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1],
// w21);
// elewise_res0 =
// vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0],
// w20);
// elewise_res2 =
// vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0],
// w22);
// }
// res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1),
// vaddq_f32(elewise_res0, elewise_res1));
// res3 = vmlaq_f32(vnewbias, vnewscale, res3);
//
// if (if_relu) {
// res3 = vmaxq_f32(res3, zero);
// }
// if ((w4 != w_times)) {
// vst1q_f32(output_row_ptr, res3);
// } else {
// if (out_l - 2 - w_times * 3 == 1) {
// vst1q_lane_f32(output_row_ptr, res3, 0);
// } else if (out_l - 2 - w_times * 3 == 2) {
// vst1q_lane_f32(output_row_ptr, res3, 0);
// vst1q_lane_f32(output_row_ptr + 1, res3, 1);
// }
// }
// input_row_ptr += 6;
// output_row_ptr += 3;
// }
//
// output_data_tmp[0] = input_const[0] * w11 + input_const[1] * w12 +
// input_const[in_l] * w21 +
// input_const[in_l + 1] * w22;
//
// out2in_mid = (out_l - 1) * 2;
// output_data_tmp[out_l - 1] =
// w10 * input_const[out2in_mid - 1] + w11 *
// input_const[out2in_mid] + w20 * input_const[out2in_mid + in_w -
// 1] + w21 * input_const[out2in_mid + in_w] + (1 - if_pad) * (w12
// * input_const[out2in_mid + 1] +
// w22 * input_const[out2in_mid + in_w + 1]);
//
// out2in_mid = (out_l - 1) * 2 * in_w;
//
// output_data_tmp[out_l * (out_l - 1)] =
// w01 * input_const[out2in_mid - in_w] +
// w02 * input_const[out2in_mid - in_w + 1] +
// w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid +
// 1] + (1 - if_pad) * (w21 * input_const[out2in_mid + in_w] +
// w22 * input_const[out2in_mid + in_w + 1]);
// out2in_mid = (out_l - 1) * 2 * in_w + (out_l - 1) * 2;
//
// output_data_tmp[out_l * out_l - 1] =
// w00 * input_const[out2in_mid - in_w - 1] +
// w01 * input_const[out2in_mid - in_w] +
// w10 * input_const[out2in_mid - 1] + w11 *
// input_const[out2in_mid] + (1 - if_pad) * (w20 *
// input_const[out2in_mid + in_w - 1] +
// w21 * input_const[out2in_mid + in_w] +
// w02 * input_const[out2in_mid - in_w + 1] +
// w12 * input_const[out2in_mid + 1] +
// w22 * input_const[out2in_mid + in_w + 1]);
// output_data_tmp[0] =
// output_data_tmp[0] * newscale_data[j] + newbias_data[j];
// output_data_tmp[out_l - 1] =
// output_data_tmp[out_l - 1] * newscale_data[j] + newbias_data[j];
// output_data_tmp[out_l * (out_l - 1)] =
// output_data_tmp[out_l * (out_l - 1)] * newscale_data[j] +
// newbias_data[j];
// output_data_tmp[out_l * out_l - 1] =
// output_data_tmp[out_l * out_l - 1] * newscale_data[j] +
// newbias_data[j];
// if (if_relu) {
// output_data_tmp[0] = output_data_tmp[0] < 0 ? 0 :
// output_data_tmp[0]; output_data_tmp[out_l - 1] =
// output_data_tmp[out_l - 1] < 0 ? 0 : output_data_tmp[out_l -
// 1];
// output_data_tmp[out_l * (out_l - 1)] =
// output_data_tmp[out_l * (out_l - 1)] < 0
// ? 0
// : output_data_tmp[out_l * (out_l - 1)];
// output_data_tmp[out_l * out_l - 1] =
// output_data_tmp[out_l * out_l - 1] < 0
// ? 0
// : output_data_tmp[out_l * out_l - 1];
// }
// for (int i = 1; i < out_h - 1; i++) {
// out2in_mid = i * 2 * in_w;
// output_data_tmp[i * out_l] = w01 * input_const[out2in_mid - in_w]
// +
// w02 * input_const[out2in_mid - in_w +
// 1] + w11 * input_const[out2in_mid] +
// w12 * input_const[out2in_mid + 1] +
// w21 * input_const[out2in_mid + in_w]
// + w22 * input_const[out2in_mid + in_w
// + 1];
//
// out2in_mid = i * 2 * in_w + (out_l - 1) * 2;
// output_data_tmp[i * out_l + out_l - 1] =
// w00 * input_const[out2in_mid - in_w - 1] +
// w01 * input_const[out2in_mid - in_w] +
// w10 * input_const[out2in_mid - 1] + w11 *
// input_const[out2in_mid] + w20 * input_const[out2in_mid + in_w
// - 1] + w21 * input_const[out2in_mid + in_w] + (1 - if_pad) *
// (w02 * input_const[out2in_mid - in_w + 1] +
// w12 * input_const[out2in_mid + 1] +
// w22 * input_const[out2in_mid + in_w + 1]);
// output_data_tmp[i * out_l] =
// output_data_tmp[i * out_l] * newscale_data[j] +
// newbias_data[j];
// output_data_tmp[i * out_l + out_l - 1] =
// output_data_tmp[i * out_l + out_l - 1] * newscale_data[j] +
// newbias_data[j];
// if (if_relu) {
// output_data_tmp[i * out_l] =
// output_data_tmp[i * out_l] < 0 ? 0 : output_data_tmp[i *
// out_l];
// output_data_tmp[i * out_l + out_l - 1] =
// output_data_tmp[i * out_l + out_l - 1] < 0
// ? 0
// : output_data_tmp[i * out_l + out_l - 1];
// }
// }
// filter_data_tmp += 9;
// }
// input_data += inhxw * c;
// output_data += outhxw * c;
// }
#endif #endif
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册