提交 310b1dbd 编写于 作者: H hjchen2

Optimize pooling which efficiency has increased by 30% for googlenet, Fix pooling3x3 for stride 2

上级 2e0735e6
...@@ -110,7 +110,7 @@ inline void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) { ...@@ -110,7 +110,7 @@ inline void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) {
break; break;
case 3: case 3:
vst1_f32(output, vget_low_f32(r0)); vst1_f32(output, vget_low_f32(r0));
vst1_lane_f32(output, vget_high_f32(r0), 0); vst1q_lane_f32(output, r0, 2);
break; break;
} }
} }
......
...@@ -53,7 +53,7 @@ struct PoolingVal<AVG> { ...@@ -53,7 +53,7 @@ struct PoolingVal<AVG> {
++count; ++count;
return *this; return *this;
} }
inline float Value() { return (count > 0) ? val / count : 0.f; } inline float Value() { return (count > 0) ? val * (1.f / count) : 0.f; }
}; };
#if defined(__ARM_NEON) || defined(__ARM_NEON__) #if defined(__ARM_NEON) || defined(__ARM_NEON__)
...@@ -67,6 +67,16 @@ inline float32x4_t vPoolInitq_f32<AVG>() { ...@@ -67,6 +67,16 @@ inline float32x4_t vPoolInitq_f32<AVG>() {
return vdupq_n_f32(0.f); return vdupq_n_f32(0.f);
} }
template <PoolingType P = MAX>
inline float32x2_t vPoolInit_f32() {
return vdup_n_f32(-std::numeric_limits<float>::max());
}
template <>
inline float32x2_t vPoolInit_f32<AVG>() {
return vdup_n_f32(0.f);
}
template <PoolingType P = MAX> template <PoolingType P = MAX>
inline float32x4_t vPoolPreq_f32(const float32x4_t &x1, const float32x4_t &x2) { inline float32x4_t vPoolPreq_f32(const float32x4_t &x1, const float32x4_t &x2) {
return vmaxq_f32(x1, x2); return vmaxq_f32(x1, x2);
...@@ -78,6 +88,28 @@ inline float32x4_t vPoolPreq_f32<AVG>(const float32x4_t &x1, ...@@ -78,6 +88,28 @@ inline float32x4_t vPoolPreq_f32<AVG>(const float32x4_t &x1,
return vaddq_f32(x1, x2); return vaddq_f32(x1, x2);
} }
template <PoolingType P = MAX>
inline float32x2_t vPoolPre_f32(const float32x2_t &x1, const float32x2_t &x2) {
return vmax_f32(x1, x2);
}
template <>
inline float32x2_t vPoolPre_f32<AVG>(const float32x2_t &x1,
const float32x2_t &x2) {
return vadd_f32(x1, x2);
}
template <PoolingType P = MAX>
inline float32x2_t vpPoolPre_f32(const float32x2_t &x1, const float32x2_t &x2) {
return vpmax_f32(x1, x2);
}
template <>
inline float32x2_t vpPoolPre_f32<AVG>(const float32x2_t &x1,
const float32x2_t &x2) {
return vpadd_f32(x1, x2);
}
template <PoolingType P = MAX> template <PoolingType P = MAX>
inline float32x4_t vPoolPostq_f32(const float32x4_t &x, inline float32x4_t vPoolPostq_f32(const float32x4_t &x,
const float32x4_t &post) { const float32x4_t &post) {
...@@ -89,6 +121,18 @@ inline float32x4_t vPoolPostq_f32<AVG>(const float32x4_t &x, ...@@ -89,6 +121,18 @@ inline float32x4_t vPoolPostq_f32<AVG>(const float32x4_t &x,
const float32x4_t &post) { const float32x4_t &post) {
return vmulq_f32(x, post); return vmulq_f32(x, post);
} }
template <PoolingType P = MAX>
inline float32x2_t vPoolPost_f32(const float32x2_t &x,
const float32x2_t &post) {
return x;
}
template <>
inline float32x2_t vPoolPost_f32<AVG>(const float32x2_t &x,
const float32x2_t &post) {
return vmul_f32(x, post);
}
#endif // __ARM_NEON__ #endif // __ARM_NEON__
template <PoolingType P = MAX> template <PoolingType P = MAX>
......
...@@ -40,7 +40,7 @@ namespace math { ...@@ -40,7 +40,7 @@ namespace math {
template <PoolingType P, int Stride = 1> template <PoolingType P, int Stride = 1>
struct Pooling2x2NormalRowLoadInput { struct Pooling2x2NormalRowLoadInput {
inline void operator()(const float *input, float32x4_t *x0, float32x4_t *x1) { void operator()(const float *input, float32x4_t *x0, float32x4_t *x1) {
x0[0] = vld1q_f32(input); x0[0] = vld1q_f32(input);
x0[1] = vld1q_f32(input + 4); x0[1] = vld1q_f32(input + 4);
x1[0] = vextq_f32(x0[0], x0[1], 1); x1[0] = vextq_f32(x0[0], x0[1], 1);
...@@ -50,7 +50,7 @@ struct Pooling2x2NormalRowLoadInput { ...@@ -50,7 +50,7 @@ struct Pooling2x2NormalRowLoadInput {
template <PoolingType P> template <PoolingType P>
struct Pooling2x2NormalRowLoadInput<P, 2> { struct Pooling2x2NormalRowLoadInput<P, 2> {
inline void operator()(const float *input, float32x4_t *x0, float32x4_t *x1) { void operator()(const float *input, float32x4_t *x0, float32x4_t *x1) {
float32x4x2_t t0 = vld2q_f32(input); float32x4x2_t t0 = vld2q_f32(input);
float32x4x2_t t1 = vld2q_f32(input + 8); float32x4x2_t t1 = vld2q_f32(input + 8);
x0[0] = t0.val[0]; x0[0] = t0.val[0];
......
...@@ -14,10 +14,10 @@ limitations under the License. */ ...@@ -14,10 +14,10 @@ limitations under the License. */
#ifdef POOL_OP #ifdef POOL_OP
#include "operators/math/pooling.h"
#if defined(__ARM_NEON) || defined(__ARM_NEON__) #if defined(__ARM_NEON) || defined(__ARM_NEON__)
#include <arm_neon.h> #include <arm_neon.h>
#endif // __ARM_NEON #include "operators/math/pooling.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
...@@ -38,87 +38,6 @@ namespace math { ...@@ -38,87 +38,6 @@ namespace math {
output_ptr[w] = val.Value(); \ output_ptr[w] = val.Value(); \
} }
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
template <PoolingType P, int Stride = 1>
struct Pooling3x3ValidColLoadInput {
inline void operator()(const float *input, const int input_w,
const int valid_cols, float32x4x2_t &x0, // NOLINT
float32x4x2_t &x1, float32x4x2_t &x2, // NOLINT
float32x4x2_t &y0) { // NOLINT
float fake_input[3][8];
if (valid_cols == 1) {
for (int i = 0; i < 8; ++i, input += input_w) {
fake_input[0][i] = input[0];
}
} else if (valid_cols == 2) {
for (int i = 0; i < 8; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
}
} else {
for (int i = 0; i < 8; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
fake_input[2][i] = input[2];
}
}
y0.val[0] = vPoolInitq_f32<P>();
y0.val[1] = vPoolInitq_f32<P>();
for (int i = 0; i < valid_cols; ++i) {
x0.val[0] = vld1q_f32(fake_input[i]);
x0.val[1] = vld1q_f32(fake_input[i] + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x1.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x1.val[1], y0.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x2.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x2.val[1], y0.val[1]);
}
}
};
template <PoolingType P>
struct Pooling3x3ValidColLoadInput<P, 2> {
inline void operator()(const float *input, const int input_w,
const int valid_cols, float32x4x2_t &x0, // NOLINT
float32x4x2_t &x1, float32x4x2_t &x2, // NOLINT
float32x4x2_t &y0) { // NOLINT
float fake_input[3][13];
if (valid_cols == 1) {
for (int i = 0; i < 13; ++i, input += input_w) {
fake_input[0][i] = input[0];
}
} else if (valid_cols == 2) {
for (int i = 0; i < 13; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
}
} else {
for (int i = 0; i < 13; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
fake_input[2][i] = input[2];
}
}
for (int i = 0; i < valid_cols; ++i) {
x0 = vld2q_f32(fake_input[i]);
x1 = vld2q_f32(fake_input[i] + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
}
}
};
template <PoolingType P, int Stride = 1> template <PoolingType P, int Stride = 1>
struct Pooling3x3NormalRowLoadInput { struct Pooling3x3NormalRowLoadInput {
inline void operator()(const float *input, float32x4x2_t &x0, // NOLINT inline void operator()(const float *input, float32x4x2_t &x0, // NOLINT
...@@ -156,62 +75,6 @@ struct Pooling3x3NormalRowLoadInput<P, 2> { ...@@ -156,62 +75,6 @@ struct Pooling3x3NormalRowLoadInput<P, 2> {
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]); y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
} }
}; };
#endif // __ARM_NEON__
template <PoolingType P, int Stride>
inline void Pooling3x3ValidCol(const float *input, const int h_output,
const int h_output_end, const int w_output,
const int input_h, const int input_w,
const int padding_h, const int padding_w,
const int output_w, float *output) {
const int w_in_start = -padding_w + w_output * Stride;
const int w_in_end = w_in_start + 3;
const int w_start = w_in_start > 0 ? w_in_start : 0;
const int w_end = w_in_end < input_w ? w_in_end : input_w;
int remain_start = h_output;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
int output_tiles = (h_output_end - h_output) / 6;
remain_start = h_output + output_tiles * 6;
int input_h_start = h_output * Stride - padding_h;
size_t input_offset = input_h_start * input_w + w_start;
size_t output_offset = h_output * output_w + w_output;
int valid_cols = w_end - w_start;
Pooling3x3ValidColLoadInput<P, Stride> PoolingCompute;
float32x4x2_t x0, x1, x2, y0;
float32x4_t avg = vdupq_n_f32(1.f / (3 * valid_cols));
for (int h = 0; h < output_tiles * 6; h += 6) {
float *output0 = output + output_offset;
float *output1 = output0 + output_w;
float *output2 = output1 + output_w;
float *output3 = output2 + output_w;
float *output4 = output3 + output_w;
float *output5 = output4 + output_w;
y0.val[0] = vPoolInitq_f32<P>();
y0.val[1] = vPoolInitq_f32<P>();
PoolingCompute(input + input_offset, input_w, valid_cols, x0, x1, x2, y0);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], avg);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], avg);
vst1q_lane_f32(output0, y0.val[0], 0);
vst1q_lane_f32(output1, y0.val[0], 1);
vst1q_lane_f32(output2, y0.val[0], 2);
vst1q_lane_f32(output3, y0.val[0], 3);
vst1q_lane_f32(output4, y0.val[1], 0);
vst1q_lane_f32(output5, y0.val[1], 1);
input_offset += 6 * Stride * input_w;
output_offset += 6 * output_w;
}
#endif
for (int h = remain_start; h < h_output_end; ++h) {
PoolingVal<P> val;
const int h_in_start = -padding_h + h * Stride;
for (int i = 0; i < 3; ++i) {
for (int w_in = w_start; w_in < w_end; ++w_in) {
val += input[(h_in_start + i) * input_w + w_in];
}
}
output[h * output_w + w_output] = val.Value();
}
}
template <PoolingType P, int Stride> template <PoolingType P, int Stride>
inline void Pooling3x3NormalRow(const float *input, const int h_output, inline void Pooling3x3NormalRow(const float *input, const int h_output,
...@@ -223,21 +86,25 @@ inline void Pooling3x3NormalRow(const float *input, const int h_output, ...@@ -223,21 +86,25 @@ inline void Pooling3x3NormalRow(const float *input, const int h_output,
const int h_start = h_in_start > 0 ? h_in_start : 0; const int h_start = h_in_start > 0 ? h_in_start : 0;
const int h_end = h_in_end < input_h ? h_in_end : input_h; const int h_end = h_in_end < input_h ? h_in_end : input_h;
int valid_w_start = (padding_w + Stride - 1) / Stride;
int valid_w_end = (input_w - 3) / Stride + 1 + valid_w_start;
float *output_ptr = output + h_output * output_w; float *output_ptr = output + h_output * output_w;
if (h_end - h_start <= 0) {
memset(output_ptr, 0, output_w * sizeof(float));
return;
}
const int valid_w_start = (padding_w + Stride - 1) / Stride;
const int valid_w_end = (input_w + padding_w - 3) / Stride + 1;
const int valid_w = valid_w_end - valid_w_start;
// border left // border left
POOLING3X3_NORMAL_BORDER(0, valid_w_start) POOLING3X3_NORMAL_BORDER(0, valid_w_start)
// middle // middle
int remain_start = valid_w_start;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
int output_tiles = (valid_w_end - valid_w_start) / 6; int output_tiles = (valid_w_end - valid_w_start) / 6;
remain_start = valid_w_start + output_tiles * 6; int output_tiles_w = output_tiles * 6;
Pooling3x3NormalRowLoadInput<P, Stride> PoolingCompute; Pooling3x3NormalRowLoadInput<P, Stride> PoolingCompute;
float32x4x2_t x0, x1, x2, y0; float32x4x2_t x0, x1, x2, y0;
float32x4_t post = vdupq_n_f32(1.f / (3 * (h_end - h_start))); float32x4_t post = vdupq_n_f32(1.f / (3 * (h_end - h_start)));
for (int w = 0; w < output_tiles * 6; w += 6) { for (int w = 0; w < output_tiles_w; w += 6) {
int output_offset = valid_w_start + w; int output_offset = valid_w_start + w;
int input_w_offset = output_offset * Stride - padding_w; int input_w_offset = output_offset * Stride - padding_w;
y0.val[0] = vPoolInitq_f32<P>(); y0.val[0] = vPoolInitq_f32<P>();
...@@ -250,16 +117,37 @@ inline void Pooling3x3NormalRow(const float *input, const int h_output, ...@@ -250,16 +117,37 @@ inline void Pooling3x3NormalRow(const float *input, const int h_output,
vst1q_f32(output_ptr + output_offset, y0.val[0]); vst1q_f32(output_ptr + output_offset, y0.val[0]);
vst1_f32(output_ptr + output_offset + 4, vget_low_f32(y0.val[1])); vst1_f32(output_ptr + output_offset + 4, vget_low_f32(y0.val[1]));
} }
#endif // __ARM_NEON__ int remain = valid_w - output_tiles_w;
for (int w = remain_start; w < valid_w_end; ++w) { if (remain > 0) {
PoolingVal<P> val; int remain_start = valid_w_start + output_tiles_w;
int input_start = -padding_w + w * Stride; int input_w_offset = remain_start * Stride - padding_w;
float *output_ptr0 = output_ptr + remain_start;
y0.val[0] = vPoolInitq_f32<P>();
y0.val[1] = vPoolInitq_f32<P>();
for (int h_in = h_start; h_in < h_end; ++h_in) { for (int h_in = h_start; h_in < h_end; ++h_in) {
for (int j = 0; j < 3; ++j) { PoolingCompute(input + h_in * input_w + input_w_offset, x0, x1, x2, y0);
val += input[h_in * input_w + j + input_start];
} }
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
switch (remain) {
case 1:
vst1q_lane_f32(output_ptr0, y0.val[0], 0);
break;
case 2:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
break;
case 3:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
vst1q_lane_f32(output_ptr0 + 2, y0.val[0], 2);
break;
case 4:
vst1q_f32(output_ptr0, y0.val[0]);
break;
case 5:
vst1q_f32(output_ptr0, y0.val[0]);
vst1q_lane_f32(output_ptr0 + 4, y0.val[1], 0);
break;
} }
output_ptr[w] = val.Value();
} }
// border right // border right
POOLING3X3_NORMAL_BORDER(valid_w_end, output_w) POOLING3X3_NORMAL_BORDER(valid_w_end, output_w)
...@@ -286,7 +174,6 @@ struct Pooling3x3<P, 1> { ...@@ -286,7 +174,6 @@ struct Pooling3x3<P, 1> {
int valid_w_start = padding_w; int valid_w_start = padding_w;
int valid_w = input_w - 2; int valid_w = input_w - 2;
int valid_w_end = valid_w_start + valid_w; int valid_w_end = valid_w_start + valid_w;
float avg = 1.f / 9;
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int batch = 0; batch < output->dims()[0]; ++batch) { for (int batch = 0; batch < output->dims()[0]; ++batch) {
...@@ -299,23 +186,6 @@ struct Pooling3x3<P, 1> { ...@@ -299,23 +186,6 @@ struct Pooling3x3<P, 1> {
Pooling3x3NormalRow<P, 1>(input_ptr, h, input_h, input_w, padding_h, Pooling3x3NormalRow<P, 1>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr); padding_w, output_w, output_ptr);
} }
// left
for (int w = 0; w < valid_w_start; ++w) {
Pooling3x3ValidCol<P, 1>(input_ptr, valid_h_start, valid_h_end, w,
input_h, input_w, padding_h, padding_w,
output_w, output_ptr);
}
// right
for (int w = valid_w_end; w < output_w; ++w) {
Pooling3x3ValidCol<P, 1>(input_ptr, valid_h_start, valid_h_end, w,
input_h, input_w, padding_h, padding_w,
output_w, output_ptr);
}
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
Pooling3x3NormalRow<P, 1>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
}
// valid // valid
int output_w_tiles = valid_w / 6; int output_w_tiles = valid_w / 6;
int output_w_remain = valid_w - output_w_tiles * 6; int output_w_remain = valid_w - output_w_tiles * 6;
...@@ -326,12 +196,61 @@ struct Pooling3x3<P, 1> { ...@@ -326,12 +196,61 @@ struct Pooling3x3<P, 1> {
const float *input_ptr3 = input_ptr2 + input_w; const float *input_ptr3 = input_ptr2 + input_w;
const float *input_ptr4 = input_ptr3 + input_w; const float *input_ptr4 = input_ptr3 + input_w;
const float *input_ptr5 = input_ptr4 + input_w; const float *input_ptr5 = input_ptr4 + input_w;
float *output_ptr0 = output_ptr + h * output_w + valid_w_start; float *output_ptr0 = output_ptr + h * output_w;
float *output_ptr1 = output_ptr0 + output_w; float *output_ptr1 = output_ptr0 + output_w;
float *output_ptr2 = output_ptr1 + output_w; float *output_ptr2 = output_ptr1 + output_w;
float *output_ptr3 = output_ptr2 + output_w; float *output_ptr3 = output_ptr2 + output_w;
int remain = output_w_remain; // pad left
#if defined(__ARM_NEON__) || defined(__ARM_NEON) if (padding_w) {
float32x2_t row0 = vld1_f32(input_ptr0);
float32x2_t row1 = vld1_f32(input_ptr1);
float32x2_t row2 = vld1_f32(input_ptr2);
float32x2_t row3 = vld1_f32(input_ptr3);
float32x2_t row4 = vld1_f32(input_ptr4);
float32x2_t row5 = vld1_f32(input_ptr5);
float32x2_t pad0 = vPoolInit_f32<P>();
float32x2_t acc0, acc1, acc2, acc3, acc12, acc34, post;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 3) {
output_ptr0[w] = 0.f;
output_ptr1[w] = 0.f;
output_ptr2[w] = 0.f;
output_ptr3[w] = 0.f;
} else {
post = vdup_n_f32(1.f / (3 * (3 - padding)));
acc12 = vPoolPre_f32<P>(row1, row2);
acc34 = vPoolPre_f32<P>(row3, row4);
acc0 = vPoolPre_f32<P>(row0, acc12);
acc1 = vPoolPre_f32<P>(row3, acc12);
acc2 = vPoolPre_f32<P>(row2, acc34);
acc3 = vPoolPre_f32<P>(row5, acc34);
acc0 = vpPoolPre_f32<P>(acc0, acc0);
acc1 = vpPoolPre_f32<P>(acc1, acc1);
acc2 = vpPoolPre_f32<P>(acc2, acc2);
acc3 = vpPoolPre_f32<P>(acc3, acc3);
acc0 = vPoolPost_f32<P>(acc0, post);
acc1 = vPoolPost_f32<P>(acc1, post);
acc2 = vPoolPost_f32<P>(acc2, post);
acc3 = vPoolPost_f32<P>(acc3, post);
vst1_lane_f32(output_ptr0 + w, acc0, 0);
vst1_lane_f32(output_ptr1 + w, acc1, 0);
vst1_lane_f32(output_ptr2 + w, acc2, 0);
vst1_lane_f32(output_ptr3 + w, acc3, 0);
row0 = vext_f32(pad0, row0, 1);
row1 = vext_f32(pad0, row1, 1);
row2 = vext_f32(pad0, row2, 1);
row3 = vext_f32(pad0, row3, 1);
row4 = vext_f32(pad0, row4, 1);
row5 = vext_f32(pad0, row5, 1);
}
}
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
output_ptr2 += valid_w_start;
output_ptr3 += valid_w_start;
}
// valid
float32x4x2_t x0, x1, x2; float32x4x2_t x0, x1, x2;
float32x4x2_t y0, y1, y2; float32x4x2_t y0, y1, y2;
float32x4_t post = vdupq_n_f32(1.f / 9); float32x4_t post = vdupq_n_f32(1.f / 9);
...@@ -446,100 +365,198 @@ struct Pooling3x3<P, 1> { ...@@ -446,100 +365,198 @@ struct Pooling3x3<P, 1> {
output_ptr3 += 6; output_ptr3 += 6;
} }
// remain width // remain width
if (remain >= 4) { if (output_w_remain > 0) {
float32x4x2_t y3;
x0.val[0] = vld1q_f32(input_ptr0); x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4); x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
x0.val[0] = vld1q_f32(input_ptr1); x0.val[0] = vld1q_f32(input_ptr1);
x0.val[1] = vld1q_f32(input_ptr1 + 4); x0.val[1] = vld1q_f32(input_ptr1 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y1.val[1], y0.val[1]);
x0.val[0] = vld1q_f32(input_ptr2); x0.val[0] = vld1q_f32(input_ptr2);
x0.val[1] = vld1q_f32(input_ptr2 + 4); x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); y2.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y2.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y1.val[0] = vPoolPreq_f32<P>(y2.val[0], y1.val[0]); y1.val[0] = vPoolPreq_f32<P>(y2.val[0], y1.val[0]);
y1.val[1] = vPoolPreq_f32<P>(y2.val[1], y1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(y2.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(y2.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y2.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post); y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr0, y0.val[0]); y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
x0.val[0] = vld1q_f32(input_ptr3); x0.val[0] = vld1q_f32(input_ptr3);
x0.val[1] = vld1q_f32(input_ptr3 + 4); x0.val[1] = vld1q_f32(input_ptr3 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]); y3.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y2.val[0] = vPoolPreq_f32<P>(y0.val[0], y2.val[0]); y3.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y1.val[0] = vPoolPreq_f32<P>(y3.val[0], y1.val[0]);
y1.val[1] = vPoolPreq_f32<P>(y3.val[1], y1.val[1]);
y2.val[0] = vPoolPreq_f32<P>(y3.val[0], y2.val[0]);
y2.val[1] = vPoolPreq_f32<P>(y3.val[1], y2.val[1]);
y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post); y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
vst1q_f32(output_ptr1, y1.val[0]); y1.val[1] = vPoolPostq_f32<P>(y1.val[1], post);
x0.val[0] = vld1q_f32(input_ptr4); x0.val[0] = vld1q_f32(input_ptr4);
x0.val[1] = vld1q_f32(input_ptr4 + 4); x0.val[1] = vld1q_f32(input_ptr4 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y3.val[0] = vPoolPreq_f32<P>(x0.val[0], y3.val[0]);
y3.val[1] = vPoolPreq_f32<P>(x0.val[1], y3.val[1]);
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]); y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]);
y2.val[1] = vPoolPreq_f32<P>(x0.val[1], y2.val[1]);
y2.val[0] = vPoolPostq_f32<P>(y2.val[0], post); y2.val[0] = vPoolPostq_f32<P>(y2.val[0], post);
vst1q_f32(output_ptr2, y2.val[0]); y2.val[1] = vPoolPostq_f32<P>(y2.val[1], post);
x0.val[0] = vld1q_f32(input_ptr5); x0.val[0] = vld1q_f32(input_ptr5);
x0.val[1] = vld1q_f32(input_ptr5 + 4); x0.val[1] = vld1q_f32(input_ptr5 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post); y3.val[0] = vPoolPreq_f32<P>(x0.val[0], y3.val[0]);
vst1q_f32(output_ptr3, y0.val[0]); y3.val[1] = vPoolPreq_f32<P>(x0.val[1], y3.val[1]);
y3.val[0] = vPoolPostq_f32<P>(y3.val[0], post);
input_ptr0 += 4; y3.val[1] = vPoolPostq_f32<P>(y3.val[1], post);
input_ptr1 += 4;
input_ptr2 += 4; switch (output_w_remain) {
input_ptr3 += 4; case 1:
input_ptr4 += 4; vst1q_lane_f32(output_ptr0, y0.val[0], 0);
input_ptr5 += 4; vst1q_lane_f32(output_ptr1, y1.val[0], 0);
output_ptr0 += 4; vst1q_lane_f32(output_ptr2, y2.val[0], 0);
output_ptr1 += 4; vst1q_lane_f32(output_ptr3, y3.val[0], 0);
output_ptr2 += 4; break;
output_ptr3 += 4; case 2:
remain -= 4; vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
vst1_f32(output_ptr1, vget_low_f32(y1.val[0]));
vst1_f32(output_ptr2, vget_low_f32(y2.val[0]));
vst1_f32(output_ptr3, vget_low_f32(y3.val[0]));
break;
case 3:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
vst1_f32(output_ptr1, vget_low_f32(y1.val[0]));
vst1_f32(output_ptr2, vget_low_f32(y2.val[0]));
vst1_f32(output_ptr3, vget_low_f32(y3.val[0]));
vst1q_lane_f32(output_ptr0 + 2, y0.val[0], 2);
vst1q_lane_f32(output_ptr1 + 2, y1.val[0], 2);
vst1q_lane_f32(output_ptr2 + 2, y2.val[0], 2);
vst1q_lane_f32(output_ptr3 + 2, y3.val[0], 2);
break;
case 4:
vst1q_f32(output_ptr0, y0.val[0]);
vst1q_f32(output_ptr1, y1.val[0]);
vst1q_f32(output_ptr2, y2.val[0]);
vst1q_f32(output_ptr3, y3.val[0]);
break;
case 5:
vst1q_f32(output_ptr0, y0.val[0]);
vst1q_f32(output_ptr1, y1.val[0]);
vst1q_f32(output_ptr2, y2.val[0]);
vst1q_f32(output_ptr3, y3.val[0]);
vst1q_lane_f32(output_ptr0 + 4, y0.val[1], 0);
vst1q_lane_f32(output_ptr1 + 4, y1.val[1], 0);
vst1q_lane_f32(output_ptr2 + 4, y2.val[1], 0);
vst1q_lane_f32(output_ptr3 + 4, y3.val[1], 0);
break;
}
input_ptr0 += output_w_remain;
input_ptr1 += output_w_remain;
input_ptr2 += output_w_remain;
input_ptr3 += output_w_remain;
input_ptr4 += output_w_remain;
input_ptr5 += output_w_remain;
output_ptr0 += output_w_remain;
output_ptr1 += output_w_remain;
output_ptr2 += output_w_remain;
output_ptr3 += output_w_remain;
}
// pad right
if (padding_w) {
float32x2_t row0 = vld1_f32(input_ptr0);
float32x2_t row1 = vld1_f32(input_ptr1);
float32x2_t row2 = vld1_f32(input_ptr2);
float32x2_t row3 = vld1_f32(input_ptr3);
float32x2_t row4 = vld1_f32(input_ptr4);
float32x2_t row5 = vld1_f32(input_ptr5);
float32x2_t pad0 = vPoolInit_f32<P>();
float32x2_t acc0, acc1, acc2, acc3, acc12, acc34, post;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0.f;
*output_ptr1 = 0.f;
*output_ptr2 = 0.f;
*output_ptr3 = 0.f;
} else {
post = vdup_n_f32(1.f / (3 * (3 - padding)));
acc12 = vPoolPre_f32<P>(row1, row2);
acc34 = vPoolPre_f32<P>(row3, row4);
acc0 = vPoolPre_f32<P>(row0, acc12);
acc1 = vPoolPre_f32<P>(row3, acc12);
acc2 = vPoolPre_f32<P>(row2, acc34);
acc3 = vPoolPre_f32<P>(row5, acc34);
acc0 = vpPoolPre_f32<P>(acc0, acc0);
acc1 = vpPoolPre_f32<P>(acc1, acc1);
acc2 = vpPoolPre_f32<P>(acc2, acc2);
acc3 = vpPoolPre_f32<P>(acc3, acc3);
acc0 = vPoolPost_f32<P>(acc0, post);
acc1 = vPoolPost_f32<P>(acc1, post);
acc2 = vPoolPost_f32<P>(acc2, post);
acc3 = vPoolPost_f32<P>(acc3, post);
vst1_lane_f32(output_ptr0, acc0, 0);
vst1_lane_f32(output_ptr1, acc1, 0);
vst1_lane_f32(output_ptr2, acc2, 0);
vst1_lane_f32(output_ptr3, acc3, 0);
row0 = vext_f32(row0, pad0, 1);
row1 = vext_f32(row1, pad0, 1);
row2 = vext_f32(row2, pad0, 1);
row3 = vext_f32(row3, pad0, 1);
row4 = vext_f32(row4, pad0, 1);
row5 = vext_f32(row5, pad0, 1);
}
output_ptr0++;
output_ptr1++;
output_ptr2++;
output_ptr3++;
} }
#endif // __ARM_NEON__
for (int r = 0; r < remain; ++r) {
float m0 = PoolPre<P>(input_ptr0[r], input_ptr0[r + 1]);
m0 = PoolPre<P>(m0, input_ptr0[r + 2]);
float m1 = PoolPre<P>(input_ptr1[r], input_ptr1[r + 1]);
m1 = PoolPre<P>(m1, input_ptr1[r + 2]);
float m2 = PoolPre<P>(input_ptr2[r], input_ptr2[r + 1]);
m2 = PoolPre<P>(m2, input_ptr2[r + 2]);
float m3 = PoolPre<P>(input_ptr3[r], input_ptr3[r + 1]);
m3 = PoolPre<P>(m3, input_ptr3[r + 2]);
float m4 = PoolPre<P>(input_ptr4[r], input_ptr4[r + 1]);
m4 = PoolPre<P>(m4, input_ptr4[r + 2]);
float m5 = PoolPre<P>(input_ptr5[r], input_ptr5[r + 1]);
m5 = PoolPre<P>(m5, input_ptr5[r + 2]);
m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2);
m1 = PoolPre<P>(PoolPre<P>(m1, m2), m3);
m2 = PoolPre<P>(PoolPre<P>(m2, m3), m4);
m3 = PoolPre<P>(PoolPre<P>(m3, m4), m5);
output_ptr0[r] = PoolPost<P>(m0, avg);
output_ptr1[r] = PoolPost<P>(m1, avg);
output_ptr2[r] = PoolPost<P>(m2, avg);
output_ptr3[r] = PoolPost<P>(m3, avg);
} }
} }
// remain height // remain height
...@@ -548,9 +565,33 @@ struct Pooling3x3<P, 1> { ...@@ -548,9 +565,33 @@ struct Pooling3x3<P, 1> {
const float *input_ptr0 = input_ptr + (h - padding_h) * input_w; const float *input_ptr0 = input_ptr + (h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w; const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w; const float *input_ptr2 = input_ptr1 + input_w;
float *output_ptr0 = output_ptr + h * output_w + valid_w_start; float *output_ptr0 = output_ptr + h * output_w;
int remain = output_w_remain; // pad left
#if defined(__ARM_NEON__) || defined(__ARM_NEON) if (padding_w) {
float32x2_t row0 = vld1_f32(input_ptr0);
float32x2_t row1 = vld1_f32(input_ptr1);
float32x2_t row2 = vld1_f32(input_ptr2);
float32x2_t pad0 = vPoolInit_f32<P>();
float32x2_t acc0, post;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 3) {
output_ptr0[w] = 0.f;
} else {
post = vdup_n_f32(1.f / (3 * (3 - padding)));
acc0 = vPoolPre_f32<P>(row0, row1);
acc0 = vPoolPre_f32<P>(acc0, row2);
acc0 = vpPoolPre_f32<P>(acc0, acc0);
acc0 = vPoolPost_f32<P>(acc0, post);
vst1_lane_f32(output_ptr0 + w, acc0, 0);
row0 = vext_f32(pad0, row0, 1);
row1 = vext_f32(pad0, row1, 1);
row2 = vext_f32(pad0, row2, 1);
}
}
output_ptr0 += valid_w_start;
}
// valid
float32x4x2_t x0, x1, x2, y0; float32x4x2_t x0, x1, x2, y0;
float32x4_t post = vdupq_n_f32(1.f / 9); float32x4_t post = vdupq_n_f32(1.f / 9);
for (int loop = 0; loop < output_w_tiles; ++loop) { for (int loop = 0; loop < output_w_tiles; ++loop) {
...@@ -601,50 +642,100 @@ struct Pooling3x3<P, 1> { ...@@ -601,50 +642,100 @@ struct Pooling3x3<P, 1> {
output_ptr0 += 6; output_ptr0 += 6;
} }
// remain width // remain width
if (remain >= 4) { if (output_w_remain > 0) {
x0.val[0] = vld1q_f32(input_ptr0); x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4); x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
x0.val[0] = vld1q_f32(input_ptr1); x0.val[0] = vld1q_f32(input_ptr1);
x0.val[1] = vld1q_f32(input_ptr1 + 4); x0.val[1] = vld1q_f32(input_ptr1 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
x0.val[0] = vld1q_f32(input_ptr2); x0.val[0] = vld1q_f32(input_ptr2);
x0.val[1] = vld1q_f32(input_ptr2 + 4); x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post); y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
// restore
switch (output_w_remain) {
case 1:
vst1q_lane_f32(output_ptr0, y0.val[0], 0);
break;
case 2:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
break;
case 3:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
vst1q_lane_f32(output_ptr0 + 2, y0.val[0], 2);
break;
case 4:
vst1q_f32(output_ptr0, y0.val[0]); vst1q_f32(output_ptr0, y0.val[0]);
break;
input_ptr0 += 4; case 5:
input_ptr1 += 4; vst1q_f32(output_ptr0, y0.val[0]);
input_ptr2 += 4; vst1q_lane_f32(output_ptr0 + 4, y0.val[1], 0);
output_ptr0 += 4; break;
remain -= 4;
} }
#endif // __ARM_NEON__ input_ptr0 += output_w_remain;
for (int r = 0; r < remain; ++r) { input_ptr1 += output_w_remain;
float m0 = PoolPre<P>(input_ptr0[r], input_ptr0[r + 1]); input_ptr2 += output_w_remain;
m0 = PoolPre<P>(m0, input_ptr0[r + 2]); output_ptr0 += output_w_remain;
float m1 = PoolPre<P>(input_ptr1[r], input_ptr1[r + 1]); }
m1 = PoolPre<P>(m1, input_ptr1[r + 2]); // pad right
float m2 = PoolPre<P>(input_ptr2[r], input_ptr2[r + 1]); if (padding_w) {
m2 = PoolPre<P>(m2, input_ptr2[r + 2]); float32x2_t row0 = vld1_f32(input_ptr0);
float32x2_t row1 = vld1_f32(input_ptr1);
m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2); float32x2_t row2 = vld1_f32(input_ptr2);
output_ptr0[r] = PoolPost<P>(m0, avg); float32x2_t pad0 = vPoolInit_f32<P>();
float32x2_t acc0, post;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0.f;
} else {
post = vdup_n_f32(1.f / (3 * (3 - padding)));
acc0 = vPoolPre_f32<P>(row0, row1);
acc0 = vPoolPre_f32<P>(acc0, row2);
acc0 = vpPoolPre_f32<P>(acc0, acc0);
acc0 = vPoolPost_f32<P>(acc0, post);
vst1_lane_f32(output_ptr0, acc0, 0);
row0 = vext_f32(row0, pad0, 1);
row1 = vext_f32(row1, pad0, 1);
row2 = vext_f32(row2, pad0, 1);
} }
output_ptr0++;
}
}
}
// pad bottom
for (int h = valid_h_end; h < output_h; ++h) {
Pooling3x3NormalRow<P, 1>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
} }
} }
} }
...@@ -667,12 +758,22 @@ struct Pooling3x3<P, 2> { ...@@ -667,12 +758,22 @@ struct Pooling3x3<P, 2> {
int image_size = input_h * input_w; int image_size = input_h * input_w;
int out_image_size = output_h * output_w; int out_image_size = output_h * output_w;
int valid_h_start = (padding_h + 1) / 2; int valid_h_start = (padding_h + 1) / 2;
int valid_h = (input_h - 3) / 2 + 1; int valid_h_end = (input_h + padding_h - 1) / 2;
int valid_h_end = valid_h_start + valid_h; int valid_h = valid_h_end - valid_h_start;
int valid_w_start = (padding_w + 1) / 2; int valid_w_start = (padding_w + 1) / 2;
int valid_w = (input_w - 3) / 2 + 1; int valid_w_end = (input_w + padding_w - 1) / 2;
int valid_w_end = valid_w_start + valid_w; int valid_w = valid_w_end - valid_w_start;
float avg = 1.f / 9;
int padding_height = input_h + 2 * padding_h;
int padding_width = input_w + 2 * padding_w;
bool ceil_mode = (((padding_height - 1) / 2) < output_h) ||
(((padding_width - 1) / 2) < output_w);
int padding_b =
padding_h + (ceil_mode ? 2 * output_h - (padding_height - 1) : 0);
int padding_r =
padding_w + (ceil_mode ? 2 * output_w - (padding_width - 1) : 0);
// for pad left
int valid_input_w_start = (valid_w_start << 1) - padding_w;
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int batch = 0; batch < output->dims()[0]; ++batch) { for (int batch = 0; batch < output->dims()[0]; ++batch) {
...@@ -685,41 +786,70 @@ struct Pooling3x3<P, 2> { ...@@ -685,41 +786,70 @@ struct Pooling3x3<P, 2> {
Pooling3x3NormalRow<P, 2>(input_ptr, h, input_h, input_w, padding_h, Pooling3x3NormalRow<P, 2>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr); padding_w, output_w, output_ptr);
} }
// left
for (int w = 0; w < valid_w_start; ++w) {
Pooling3x3ValidCol<P, 2>(input_ptr, valid_h_start, valid_h_end, w,
input_h, input_w, padding_h, padding_w,
output_w, output_ptr);
}
// right
for (int w = valid_w_end; w < output_w; ++w) {
Pooling3x3ValidCol<P, 2>(input_ptr, valid_h_start, valid_h_end, w,
input_h, input_w, padding_h, padding_w,
output_w, output_ptr);
}
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
Pooling3x3NormalRow<P, 2>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
}
// valid // valid
int input_w_start = 2 * valid_w_start - padding_w;
int output_w_tiles = valid_w / 6; int output_w_tiles = valid_w / 6;
int output_w_remain = valid_w - output_w_tiles * 6; int output_w_remain = valid_w - output_w_tiles * 6;
for (int h = valid_h_start; h < valid_h_end - 2; h += 3) { for (int h = valid_h_start; h < valid_h_end - 2; h += 3) {
size_t offset = (2 * h - padding_h) * input_w + input_w_start; const float *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w;
const float *input_ptr0 = input_ptr + offset;
const float *input_ptr1 = input_ptr0 + input_w; const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w; const float *input_ptr2 = input_ptr1 + input_w;
const float *input_ptr3 = input_ptr2 + input_w; const float *input_ptr3 = input_ptr2 + input_w;
const float *input_ptr4 = input_ptr3 + input_w; const float *input_ptr4 = input_ptr3 + input_w;
const float *input_ptr5 = input_ptr4 + input_w; const float *input_ptr5 = input_ptr4 + input_w;
const float *input_ptr6 = input_ptr5 + input_w; const float *input_ptr6 = input_ptr5 + input_w;
float *output_ptr0 = output_ptr + h * output_w + valid_w_start; float *output_ptr0 = output_ptr + h * output_w;
float *output_ptr1 = output_ptr0 + output_w; float *output_ptr1 = output_ptr0 + output_w;
float *output_ptr2 = output_ptr1 + output_w; float *output_ptr2 = output_ptr1 + output_w;
int remain = output_w_remain; // pad left
#if defined(__ARM_NEON__) || defined(__ARM_NEON) if (padding_w) {
float32x2_t row0 = vld1_f32(input_ptr0);
float32x2_t row1 = vld1_f32(input_ptr1);
float32x2_t row2 = vld1_f32(input_ptr2);
float32x2_t row3 = vld1_f32(input_ptr3);
float32x2_t row4 = vld1_f32(input_ptr4);
float32x2_t row5 = vld1_f32(input_ptr5);
float32x2_t row6 = vld1_f32(input_ptr6);
float32x2_t pad0 = vPoolInit_f32<P>();
float32x2_t acc0, acc1, acc2, post;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - (w << 1);
if (padding >= 3) {
output_ptr0[w] = 0.f;
output_ptr1[w] = 0.f;
output_ptr2[w] = 0.f;
} else {
post = vdup_n_f32(1.f / (3 * (3 - padding)));
acc0 = vPoolPre_f32<P>(row0, row1);
acc1 = vPoolPre_f32<P>(row2, row3);
acc2 = vPoolPre_f32<P>(row4, row5);
acc0 = vPoolPre_f32<P>(acc0, row2);
acc1 = vPoolPre_f32<P>(acc1, row4);
acc2 = vPoolPre_f32<P>(acc2, row6);
if (padding == 1) {
acc0 = vpPoolPre_f32<P>(acc0, acc0);
acc1 = vpPoolPre_f32<P>(acc1, acc1);
acc2 = vpPoolPre_f32<P>(acc2, acc2);
}
acc0 = vPoolPost_f32<P>(acc0, post);
acc1 = vPoolPost_f32<P>(acc1, post);
acc2 = vPoolPost_f32<P>(acc2, post);
vst1_lane_f32(output_ptr0 + w, acc0, 0);
vst1_lane_f32(output_ptr1 + w, acc1, 0);
vst1_lane_f32(output_ptr2 + w, acc2, 0);
}
}
input_ptr0 += valid_input_w_start;
input_ptr1 += valid_input_w_start;
input_ptr2 += valid_input_w_start;
input_ptr3 += valid_input_w_start;
input_ptr4 += valid_input_w_start;
input_ptr5 += valid_input_w_start;
input_ptr6 += valid_input_w_start;
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
output_ptr2 += valid_w_start;
}
// valid
float32x4x2_t x0, x1, x2; float32x4x2_t x0, x1, x2;
float32x4x2_t y0, y1, y2; float32x4x2_t y0, y1, y2;
float32x4_t post = vdupq_n_f32(1.f / 9); float32x4_t post = vdupq_n_f32(1.f / 9);
...@@ -823,108 +953,210 @@ struct Pooling3x3<P, 2> { ...@@ -823,108 +953,210 @@ struct Pooling3x3<P, 2> {
output_ptr2 += 6; output_ptr2 += 6;
} }
// remain width // remain width
if (remain >= 4) { if (output_w_remain > 0) {
x0 = vld2q_f32(input_ptr0); x0 = vld2q_f32(input_ptr0);
x1.val[0] = vdupq_n_f32(input_ptr0[8]); x1 = vld2q_f32(input_ptr0 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
x0 = vld2q_f32(input_ptr1); x0 = vld2q_f32(input_ptr1);
x1.val[0] = vdupq_n_f32(input_ptr1[8]); x1 = vld2q_f32(input_ptr1 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
x0 = vld2q_f32(input_ptr2); x0 = vld2q_f32(input_ptr2);
x1.val[0] = vdupq_n_f32(input_ptr2[8]); x1 = vld2q_f32(input_ptr2 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y1.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post); y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr0, y0.val[0]); y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
x0 = vld2q_f32(input_ptr3); x0 = vld2q_f32(input_ptr3);
x1.val[0] = vdupq_n_f32(input_ptr3[8]); x1 = vld2q_f32(input_ptr3 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y1.val[0] = vPoolPreq_f32<P>(x0.val[0], y1.val[0]); y1.val[0] = vPoolPreq_f32<P>(x0.val[0], y1.val[0]);
y1.val[1] = vPoolPreq_f32<P>(x0.val[1], y1.val[1]);
x0 = vld2q_f32(input_ptr4); x0 = vld2q_f32(input_ptr4);
x1.val[0] = vdupq_n_f32(input_ptr4[8]); x1 = vld2q_f32(input_ptr4 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]); y2.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y2.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y1.val[0] = vPoolPreq_f32<P>(y2.val[0], y1.val[0]);
y1.val[1] = vPoolPreq_f32<P>(y2.val[1], y1.val[1]);
y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post); y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
vst1q_f32(output_ptr1, y1.val[0]); y1.val[1] = vPoolPostq_f32<P>(y1.val[1], post);
x0 = vld2q_f32(input_ptr5); x0 = vld2q_f32(input_ptr5);
x1.val[0] = vdupq_n_f32(input_ptr5[8]); x1 = vld2q_f32(input_ptr5 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]);
y2.val[1] = vPoolPreq_f32<P>(x0.val[1], y2.val[1]);
x0 = vld2q_f32(input_ptr6); x0 = vld2q_f32(input_ptr6);
x1.val[0] = vdupq_n_f32(input_ptr6[8]); x1 = vld2q_f32(input_ptr6 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post); y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]);
vst1q_f32(output_ptr2, y0.val[0]); y2.val[1] = vPoolPreq_f32<P>(x0.val[1], y2.val[1]);
y2.val[0] = vPoolPostq_f32<P>(y2.val[0], post);
y2.val[1] = vPoolPostq_f32<P>(y2.val[1], post);
input_ptr0 += 8; switch (output_w_remain) {
input_ptr1 += 8; case 1:
input_ptr2 += 8; vst1q_lane_f32(output_ptr0, y0.val[0], 0);
input_ptr3 += 8; vst1q_lane_f32(output_ptr1, y1.val[0], 0);
input_ptr4 += 8; vst1q_lane_f32(output_ptr2, y2.val[0], 0);
input_ptr5 += 8; break;
input_ptr6 += 8; case 2:
output_ptr0 += 4; vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
output_ptr1 += 4; vst1_f32(output_ptr1, vget_low_f32(y1.val[0]));
output_ptr2 += 4; vst1_f32(output_ptr2, vget_low_f32(y2.val[0]));
remain -= 4; break;
case 3:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
vst1_f32(output_ptr1, vget_low_f32(y1.val[0]));
vst1_f32(output_ptr2, vget_low_f32(y2.val[0]));
vst1q_lane_f32(output_ptr0 + 2, y0.val[0], 2);
vst1q_lane_f32(output_ptr1 + 2, y1.val[0], 2);
vst1q_lane_f32(output_ptr2 + 2, y2.val[0], 2);
break;
case 4:
vst1q_f32(output_ptr0, y0.val[0]);
vst1q_f32(output_ptr1, y1.val[0]);
vst1q_f32(output_ptr2, y2.val[0]);
break;
case 5:
vst1q_f32(output_ptr0, y0.val[0]);
vst1q_f32(output_ptr1, y1.val[0]);
vst1q_f32(output_ptr2, y2.val[0]);
vst1q_lane_f32(output_ptr0 + 4, y0.val[1], 0);
vst1q_lane_f32(output_ptr1 + 4, y1.val[1], 0);
vst1q_lane_f32(output_ptr2 + 4, y2.val[1], 0);
break;
}
input_ptr0 += (output_w_remain << 1);
input_ptr1 += (output_w_remain << 1);
input_ptr2 += (output_w_remain << 1);
input_ptr3 += (output_w_remain << 1);
input_ptr4 += (output_w_remain << 1);
input_ptr5 += (output_w_remain << 1);
input_ptr6 += (output_w_remain << 1);
output_ptr0 += output_w_remain;
output_ptr1 += output_w_remain;
output_ptr2 += output_w_remain;
}
// pad right
if (padding_r > 0) {
float32x2_t row0 = vld1_f32(input_ptr0);
float32x2_t row1 = vld1_f32(input_ptr1);
float32x2_t row2 = vld1_f32(input_ptr2);
float32x2_t row3 = vld1_f32(input_ptr3);
float32x2_t row4 = vld1_f32(input_ptr4);
float32x2_t row5 = vld1_f32(input_ptr5);
float32x2_t row6 = vld1_f32(input_ptr6);
float32x2_t pad0 = vPoolInit_f32<P>();
float32x2_t acc0, acc1, acc2, post;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = 2 * w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0.f;
*output_ptr1 = 0.f;
*output_ptr2 = 0.f;
} else {
post = vdup_n_f32(1.f / (3 * (3 - padding)));
acc0 = vPoolPre_f32<P>(row0, row1);
acc1 = vPoolPre_f32<P>(row2, row3);
acc2 = vPoolPre_f32<P>(row4, row5);
acc0 = vPoolPre_f32<P>(acc0, row2);
acc1 = vPoolPre_f32<P>(acc1, row4);
acc2 = vPoolPre_f32<P>(acc2, row6);
if (padding == 1) {
acc0 = vpPoolPre_f32<P>(acc0, acc0);
acc1 = vpPoolPre_f32<P>(acc1, acc1);
acc2 = vpPoolPre_f32<P>(acc2, acc2);
}
acc0 = vPoolPost_f32<P>(acc0, post);
acc1 = vPoolPost_f32<P>(acc1, post);
acc2 = vPoolPost_f32<P>(acc2, post);
vst1_lane_f32(output_ptr0, acc0, 0);
vst1_lane_f32(output_ptr1, acc1, 0);
vst1_lane_f32(output_ptr2, acc2, 0);
}
output_ptr0++;
output_ptr1++;
output_ptr2++;
} }
#endif // __ARM_NEON__
for (int r = 0; r < remain; ++r) {
float m0 = PoolPre<P>(input_ptr0[2 * r], input_ptr0[2 * r + 1]);
m0 = PoolPre<P>(m0, input_ptr0[2 * r + 2]);
float m1 = PoolPre<P>(input_ptr1[2 * r], input_ptr1[2 * r + 1]);
m1 = PoolPre<P>(m1, input_ptr1[2 * r + 2]);
float m2 = PoolPre<P>(input_ptr2[2 * r], input_ptr2[2 * r + 1]);
m2 = PoolPre<P>(m2, input_ptr2[2 * r + 2]);
float m3 = PoolPre<P>(input_ptr3[2 * r], input_ptr3[2 * r + 1]);
m3 = PoolPre<P>(m3, input_ptr3[2 * r + 2]);
float m4 = PoolPre<P>(input_ptr4[2 * r], input_ptr4[2 * r + 1]);
m4 = PoolPre<P>(m4, input_ptr4[2 * r + 2]);
float m5 = PoolPre<P>(input_ptr5[2 * r], input_ptr5[2 * r + 1]);
m5 = PoolPre<P>(m5, input_ptr5[2 * r + 2]);
float m6 = PoolPre<P>(input_ptr6[2 * r], input_ptr6[2 * r + 1]);
m6 = PoolPre<P>(m6, input_ptr6[2 * r + 2]);
m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2);
m1 = PoolPre<P>(PoolPre<P>(m2, m3), m4);
m2 = PoolPre<P>(PoolPre<P>(m4, m5), m6);
output_ptr0[r] = PoolPost<P>(m0, avg);
output_ptr1[r] = PoolPost<P>(m1, avg);
output_ptr2[r] = PoolPost<P>(m2, avg);
} }
} }
// remain height // remain height
int start_h = valid_h_start + valid_h / 3 * 3; int start_h = valid_h_start + valid_h / 3 * 3;
for (int h = start_h; h < valid_h_end; ++h) { for (int h = start_h; h < valid_h_end; ++h) {
size_t offset = (2 * h - padding_h) * input_w + input_w_start; const float *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w;
const float *input_ptr0 = input_ptr + offset;
const float *input_ptr1 = input_ptr0 + input_w; const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w; const float *input_ptr2 = input_ptr1 + input_w;
float *output_ptr0 = output_ptr + h * output_w + valid_w_start; float *output_ptr0 = output_ptr + h * output_w;
int remain = output_w_remain; // pad left
#if defined(__ARM_NEON__) || defined(__ARM_NEON) if (padding_w) {
float32x2_t row0 = vld1_f32(input_ptr0);
float32x2_t row1 = vld1_f32(input_ptr1);
float32x2_t row2 = vld1_f32(input_ptr2);
float32x2_t pad0 = vPoolInit_f32<P>();
float32x2_t acc0, post;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - (w << 1);
if (padding >= 3) {
output_ptr0[w] = 0.f;
} else {
post = vdup_n_f32(1.f / (3 * (3 - padding)));
acc0 = vPoolPre_f32<P>(row0, row1);
acc0 = vPoolPre_f32<P>(acc0, row2);
if (padding == 1) {
acc0 = vpPoolPre_f32<P>(acc0, acc0);
}
acc0 = vPoolPost_f32<P>(acc0, post);
vst1_lane_f32(output_ptr0 + w, acc0, 0);
}
}
input_ptr0 += valid_input_w_start;
input_ptr1 += valid_input_w_start;
input_ptr2 += valid_input_w_start;
output_ptr0 += valid_w_start;
}
// valid
float32x4x2_t x0, x1, x2, y0; float32x4x2_t x0, x1, x2, y0;
float32x4_t post = vdupq_n_f32(1.f / 9); float32x4_t post = vdupq_n_f32(1.f / 9);
for (int loop = 0; loop < output_w_tiles; ++loop) { for (int loop = 0; loop < output_w_tiles; ++loop) {
...@@ -969,47 +1201,93 @@ struct Pooling3x3<P, 2> { ...@@ -969,47 +1201,93 @@ struct Pooling3x3<P, 2> {
output_ptr0 += 6; output_ptr0 += 6;
} }
// remain width // remain width
if (remain >= 4) { if (output_w_remain > 0) {
x0 = vld2q_f32(input_ptr0); x0 = vld2q_f32(input_ptr0);
x1.val[0] = vdupq_n_f32(input_ptr0[8]); x1 = vld2q_f32(input_ptr0 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
x0 = vld2q_f32(input_ptr1); x0 = vld2q_f32(input_ptr1);
x1.val[0] = vdupq_n_f32(input_ptr1[8]); x1 = vld2q_f32(input_ptr1 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
x0 = vld2q_f32(input_ptr2); x0 = vld2q_f32(input_ptr2);
x1.val[0] = vdupq_n_f32(input_ptr2[8]); x1 = vld2q_f32(input_ptr2 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post); y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
// restore
switch (output_w_remain) {
case 1:
vst1q_lane_f32(output_ptr0, y0.val[0], 0);
break;
case 2:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
break;
case 3:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
vst1q_lane_f32(output_ptr0 + 2, y0.val[0], 2);
break;
case 4:
vst1q_f32(output_ptr0, y0.val[0]); vst1q_f32(output_ptr0, y0.val[0]);
break;
input_ptr0 += 8; case 5:
input_ptr1 += 8; vst1q_f32(output_ptr0, y0.val[0]);
input_ptr2 += 8; vst1q_lane_f32(output_ptr0 + 4, y0.val[1], 0);
output_ptr0 += 4; break;
remain -= 4;
} }
#endif // __ARM_NEON__ input_ptr0 += (output_w_remain << 1);
for (int r = 0; r < remain; ++r) { input_ptr1 += (output_w_remain << 1);
float m0 = PoolPre<P>(input_ptr0[2 * r], input_ptr0[2 * r + 1]); input_ptr2 += (output_w_remain << 1);
m0 = PoolPre<P>(m0, input_ptr0[2 * r + 2]); output_ptr0 += output_w_remain;
float m1 = PoolPre<P>(input_ptr1[2 * r], input_ptr1[2 * r + 1]); }
m1 = PoolPre<P>(m1, input_ptr1[2 * r + 2]); // pad right
float m2 = PoolPre<P>(input_ptr2[2 * r], input_ptr2[2 * r + 1]); if (padding_r > 0) {
m2 = PoolPre<P>(m2, input_ptr2[2 * r + 2]); float32x2_t row0 = vld1_f32(input_ptr0);
float32x2_t row1 = vld1_f32(input_ptr1);
m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2); float32x2_t row2 = vld1_f32(input_ptr2);
output_ptr0[r] = PoolPost<P>(m0, avg); float32x2_t pad0 = vPoolInit_f32<P>();
float32x2_t acc0, post;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = 2 * w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0.f;
} else {
post = vdup_n_f32(1.f / (3 * (3 - padding)));
acc0 = vPoolPre_f32<P>(row0, row1);
acc0 = vPoolPre_f32<P>(acc0, row2);
if (padding == 1) {
acc0 = vpPoolPre_f32<P>(acc0, acc0);
} }
acc0 = vPoolPost_f32<P>(acc0, post);
vst1_lane_f32(output_ptr0, acc0, 0);
}
output_ptr0++;
}
}
}
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
Pooling3x3NormalRow<P, 2>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
} }
} }
} }
...@@ -1025,4 +1303,5 @@ template struct Pooling3x3<AVG, 2>; ...@@ -1025,4 +1303,5 @@ template struct Pooling3x3<AVG, 2>;
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
#endif // __ARM_NEON
#endif // POOL_OP #endif // POOL_OP
...@@ -169,55 +169,55 @@ int main(int argc, char *argv[]) { ...@@ -169,55 +169,55 @@ int main(int argc, char *argv[]) {
<< "float, pooling_type=avg, kernel=3, pad=5, stride=2"; << "float, pooling_type=avg, kernel=3, pad=5, stride=2";
paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO) // LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=2, pad=0, stride=1"; // << "float, pooling_type=max, kernel=2, pad=0, stride=1";
paddle_mobile::TestPoolOp<0, 2, 0, 1>(in_channels, in_height, in_width); // paddle_mobile::TestPoolOp<0, 2, 0, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO) // LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=2, pad=1, stride=1"; // << "float, pooling_type=max, kernel=2, pad=1, stride=1";
paddle_mobile::TestPoolOp<0, 2, 1, 1>(in_channels, in_height, in_width); // paddle_mobile::TestPoolOp<0, 2, 1, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO) // LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=2, pad=2, stride=1"; // << "float, pooling_type=max, kernel=2, pad=2, stride=1";
paddle_mobile::TestPoolOp<0, 2, 2, 1>(in_channels, in_height, in_width); // paddle_mobile::TestPoolOp<0, 2, 2, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO) // LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=2, pad=5, stride=1"; // << "float, pooling_type=max, kernel=2, pad=5, stride=1";
paddle_mobile::TestPoolOp<0, 2, 5, 1>(in_channels, in_height, in_width); // paddle_mobile::TestPoolOp<0, 2, 5, 1>(in_channels, in_height, in_width);
//
LOG(paddle_mobile::kLOG_INFO) // LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=2, pad=0, stride=1"; // << "float, pooling_type=avg, kernel=2, pad=0, stride=1";
paddle_mobile::TestPoolOp<1, 2, 0, 1>(in_channels, in_height, in_width); // paddle_mobile::TestPoolOp<1, 2, 0, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO) // LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=2, pad=1, stride=1"; // << "float, pooling_type=avg, kernel=2, pad=1, stride=1";
paddle_mobile::TestPoolOp<1, 2, 1, 1>(in_channels, in_height, in_width); // paddle_mobile::TestPoolOp<1, 2, 1, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO) // LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=2, pad=2, stride=1"; // << "float, pooling_type=avg, kernel=2, pad=2, stride=1";
paddle_mobile::TestPoolOp<1, 2, 2, 1>(in_channels, in_height, in_width); // paddle_mobile::TestPoolOp<1, 2, 2, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO) // LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=2, pad=5, stride=1"; // << "float, pooling_type=avg, kernel=2, pad=5, stride=1";
paddle_mobile::TestPoolOp<1, 2, 5, 1>(in_channels, in_height, in_width); // paddle_mobile::TestPoolOp<1, 2, 5, 1>(in_channels, in_height, in_width);
//
LOG(paddle_mobile::kLOG_INFO) // LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=2, pad=0, stride=2"; // << "float, pooling_type=max, kernel=2, pad=0, stride=2";
paddle_mobile::TestPoolOp<0, 2, 0, 2>(in_channels, in_height, in_width); // paddle_mobile::TestPoolOp<0, 2, 0, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO) // LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=2, pad=1, stride=2"; // << "float, pooling_type=max, kernel=2, pad=1, stride=2";
paddle_mobile::TestPoolOp<0, 2, 1, 2>(in_channels, in_height, in_width); // paddle_mobile::TestPoolOp<0, 2, 1, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO) // LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=2, pad=2, stride=2"; // << "float, pooling_type=max, kernel=2, pad=2, stride=2";
paddle_mobile::TestPoolOp<0, 2, 2, 2>(in_channels, in_height, in_width); // paddle_mobile::TestPoolOp<0, 2, 2, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO) // LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=2, pad=5, stride=2"; // << "float, pooling_type=max, kernel=2, pad=5, stride=2";
paddle_mobile::TestPoolOp<0, 2, 5, 2>(in_channels, in_height, in_width); // paddle_mobile::TestPoolOp<0, 2, 5, 2>(in_channels, in_height, in_width);
//
LOG(paddle_mobile::kLOG_INFO) // LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=2, pad=0, stride=2"; // << "float, pooling_type=avg, kernel=2, pad=0, stride=2";
paddle_mobile::TestPoolOp<1, 2, 0, 2>(in_channels, in_height, in_width); // paddle_mobile::TestPoolOp<1, 2, 0, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO) // LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=2, pad=1, stride=2"; // << "float, pooling_type=avg, kernel=2, pad=1, stride=2";
paddle_mobile::TestPoolOp<1, 2, 1, 2>(in_channels, in_height, in_width); // paddle_mobile::TestPoolOp<1, 2, 1, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO) // LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=2, pad=2, stride=2"; // << "float, pooling_type=avg, kernel=2, pad=2, stride=2";
paddle_mobile::TestPoolOp<1, 2, 2, 2>(in_channels, in_height, in_width); // paddle_mobile::TestPoolOp<1, 2, 2, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO) // LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=2, pad=5, stride=2"; // << "float, pooling_type=avg, kernel=2, pad=5, stride=2";
paddle_mobile::TestPoolOp<1, 2, 5, 2>(in_channels, in_height, in_width); // paddle_mobile::TestPoolOp<1, 2, 5, 2>(in_channels, in_height, in_width);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册