提交 eb1e983c 编写于 作者: H hjchen2

Optimize pooling3x3

上级 26769ad7
...@@ -60,18 +60,12 @@ void Pooling<P>::operator()(const framework::Tensor &input, ...@@ -60,18 +60,12 @@ void Pooling<P>::operator()(const framework::Tensor &input,
wstart = std::max(wstart, 0); wstart = std::max(wstart, 0);
PoolingVal<P> val; PoolingVal<P> val;
// std::cout << "output[" << ph * output_width + pw << "]:"
// << std::endl;
for (int h = hstart; h < hend; ++h) { for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) { for (int w = wstart; w < wend; ++w) {
val += input_ptr[h * input_width + w]; val += input_ptr[h * input_width + w];
// std::cout << "input[" << h << "][" << w << "] = "
// << input_ptr[h * input_width + w] << std::endl;
} }
} }
output_ptr[ph * output_width + pw] = val.Value(); output_ptr[ph * output_width + pw] = val.Value();
// std::cout << "output[" << ph * output_width + pw << "] = "
// << val.Value() << std::endl;
} }
} }
} }
......
...@@ -34,45 +34,39 @@ template <PoolingType P = Max> ...@@ -34,45 +34,39 @@ template <PoolingType P = Max>
struct PoolingVal { struct PoolingVal {
float val; float val;
int count; int count;
PoolingVal() { PoolingVal() : count(0) { val = -std::numeric_limits<float>::max(); }
val = -std::numeric_limits<float>::max();
count = 0;
}
inline PoolingVal<P> &operator+=(const float &x) { inline PoolingVal<P> &operator+=(const float &x) {
val = std::max(val, x); val = std::max(val, x);
count += 1; ++count;
return *this; return *this;
} }
float Value() const { inline float Value() { return (count > 0) ? val : 0.f; }
if (count > 0) {
return val;
}
return 0.f;
}
}; };
template <> template <>
struct PoolingVal<Avg> { struct PoolingVal<Avg> {
float val; float val;
int count; int count;
PoolingVal() { PoolingVal() : val(0.f), count(0) {}
val = 0.f;
count = 0;
}
inline PoolingVal<Avg> &operator+=(const float &x) { inline PoolingVal<Avg> &operator+=(const float &x) {
val += x; val += x;
count += 1; ++count;
return *this; return *this;
} }
float Value() const { inline float Value() { return (count > 0) ? val * (1.f / count) : 0.f; }
if (count > 0) {
return val / count;
}
return 0.f;
}
}; };
#if defined(__ARM_NEON) || defined(__ARM_NEON__) #if defined(__ARM_NEON) || defined(__ARM_NEON__)
template <PoolingType P = Max>
inline float32x4_t vPoolInitq_f32() {
return vdupq_n_f32(-std::numeric_limits<float>::max());
}
template <>
inline float32x4_t vPoolInitq_f32<Avg>() {
return vdupq_n_f32(0.f);
}
template <PoolingType P = Max> template <PoolingType P = Max>
inline float32x4_t vPoolPreq_f32(const float32x4_t &x1, const float32x4_t &x2) { inline float32x4_t vPoolPreq_f32(const float32x4_t &x1, const float32x4_t &x2) {
return vmaxq_f32(x1, x2); return vmaxq_f32(x1, x2);
...@@ -85,14 +79,15 @@ inline float32x4_t vPoolPreq_f32<Avg>(const float32x4_t &x1, ...@@ -85,14 +79,15 @@ inline float32x4_t vPoolPreq_f32<Avg>(const float32x4_t &x1,
} }
template <PoolingType P = Max> template <PoolingType P = Max>
inline float32x4_t vPoolPostq_f32(const float32x4_t &x) { inline float32x4_t vPoolPostq_f32(const float32x4_t &x,
const float32x4_t &post) {
return x; return x;
} }
template <> template <>
inline float32x4_t vPoolPostq_f32<Avg>(const float32x4_t &x) { inline float32x4_t vPoolPostq_f32<Avg>(const float32x4_t &x,
float32x4_t avg = vdupq_n_f32(1.f / 9); const float32x4_t &post) {
return vmulq_f32(avg, x); return vmulq_f32(x, post);
} }
#endif // __ARM_NEON__ #endif // __ARM_NEON__
...@@ -107,13 +102,13 @@ inline float PoolPre<Avg>(const float &x1, const float &x2) { ...@@ -107,13 +102,13 @@ inline float PoolPre<Avg>(const float &x1, const float &x2) {
} }
template <PoolingType P = Max> template <PoolingType P = Max>
inline float PoolPost(const float &x) { inline float PoolPost(const float &x, const float &post) {
return x; return x;
} }
template <> template <>
inline float PoolPost<Avg>(const float &x) { inline float PoolPost<Avg>(const float &x, const float &post) {
return 1.f / 9 * x; return x * post;
} }
template <PoolingType P> template <PoolingType P>
......
...@@ -38,6 +38,126 @@ namespace math { ...@@ -38,6 +38,126 @@ namespace math {
output_ptr[w] = val.Value(); \ output_ptr[w] = val.Value(); \
} }
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
template <PoolingType P, int Stride = 1>
struct Pooling3x3ValidColLoadInput {
inline void operator()(const float *input, const int input_w,
const int valid_cols, float32x4x2_t &x0, // NOLINT
float32x4x2_t &x1, float32x4x2_t &x2, // NOLINT
float32x4x2_t &y0) { // NOLINT
float fake_input[3][8];
if (valid_cols == 1) {
for (int i = 0; i < 8; ++i, input += input_w) {
fake_input[0][i] = input[0];
}
} else if (valid_cols == 2) {
for (int i = 0; i < 8; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
}
} else {
for (int i = 0; i < 8; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
fake_input[2][i] = input[2];
}
}
y0.val[0] = vPoolInitq_f32<P>();
y0.val[1] = vPoolInitq_f32<P>();
for (int i = 0; i < valid_cols; ++i) {
x0.val[0] = vld1q_f32(fake_input[i]);
x0.val[1] = vld1q_f32(fake_input[i] + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x1.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x1.val[1], y0.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x2.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x2.val[1], y0.val[1]);
}
}
};
template <PoolingType P>
struct Pooling3x3ValidColLoadInput<P, 2> {
inline void operator()(const float *input, const int input_w,
const int valid_cols, float32x4x2_t &x0, // NOLINT
float32x4x2_t &x1, float32x4x2_t &x2, // NOLINT
float32x4x2_t &y0) { // NOLINT
float fake_input[3][13];
if (valid_cols == 1) {
for (int i = 0; i < 13; ++i, input += input_w) {
fake_input[0][i] = input[0];
}
} else if (valid_cols == 2) {
for (int i = 0; i < 13; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
}
} else {
for (int i = 0; i < 13; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
fake_input[2][i] = input[2];
}
}
for (int i = 0; i < valid_cols; ++i) {
x0 = vld2q_f32(fake_input[i]);
x1 = vld2q_f32(fake_input[i] + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
}
}
};
template <PoolingType P, int Stride = 1>
struct Pooling3x3NormalRowLoadInput {
inline void operator()(const float *input, float32x4x2_t &x0, // NOLINT
float32x4x2_t &x1, float32x4x2_t &x2, // NOLINT
float32x4x2_t &y0) { // NOLINT
x0.val[0] = vld1q_f32(input);
x0.val[1] = vld1q_f32(input + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x1.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x1.val[1], y0.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x2.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x2.val[1], y0.val[1]);
}
};
template <PoolingType P>
struct Pooling3x3NormalRowLoadInput<P, 2> {
inline void operator()(const float *input, float32x4x2_t &x0, // NOLINT
float32x4x2_t &x1, float32x4x2_t &x2, // NOLINT
float32x4x2_t &y0) { // NOLINT
x0 = vld2q_f32(input);
x1 = vld2q_f32(input + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
}
};
#endif // __ARM_NEON__
template <PoolingType P, int Stride> template <PoolingType P, int Stride>
inline void Pooling3x3ValidCol(const float *input, const int h_output, inline void Pooling3x3ValidCol(const float *input, const int h_output,
const int h_output_end, const int w_output, const int h_output_end, const int w_output,
...@@ -48,7 +168,38 @@ inline void Pooling3x3ValidCol(const float *input, const int h_output, ...@@ -48,7 +168,38 @@ inline void Pooling3x3ValidCol(const float *input, const int h_output,
const int w_in_end = w_in_start + 3; const int w_in_end = w_in_start + 3;
const int w_start = w_in_start > 0 ? w_in_start : 0; const int w_start = w_in_start > 0 ? w_in_start : 0;
const int w_end = w_in_end < input_w ? w_in_end : input_w; const int w_end = w_in_end < input_w ? w_in_end : input_w;
for (int h = h_output; h < h_output_end; ++h) { int remain_start = h_output;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
int output_tiles = (h_output_end - h_output) / 6;
remain_start = h_output + output_tiles * 6;
int input_h_start = h_output * Stride - padding_h;
size_t input_offset = input_h_start * input_w + w_start;
size_t output_offset = h_output * output_w + w_output;
int valid_cols = w_end - w_start;
Pooling3x3ValidColLoadInput<P, Stride> PoolingCompute;
float32x4x2_t x0, x1, x2, y0;
float32x4_t avg = vdupq_n_f32(1.f / (3 * valid_cols));
for (int h = 0; h < output_tiles * 6; h += 6) {
float *output0 = output + output_offset;
float *output1 = output0 + output_w;
float *output2 = output1 + output_w;
float *output3 = output2 + output_w;
float *output4 = output3 + output_w;
float *output5 = output4 + output_w;
y0.val[0] = vPoolInitq_f32<P>();
y0.val[1] = vPoolInitq_f32<P>();
PoolingCompute(input + input_offset, input_w, valid_cols, x0, x1, x2, y0);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], avg);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], avg);
vst1q_lane_f32(output0, y0.val[0], 0);
vst1q_lane_f32(output1, y0.val[0], 1);
vst1q_lane_f32(output2, y0.val[0], 2);
vst1q_lane_f32(output3, y0.val[0], 3);
vst1q_lane_f32(output4, y0.val[1], 0);
vst1q_lane_f32(output5, y0.val[1], 1);
}
#endif
for (int h = remain_start; h < h_output_end; ++h) {
PoolingVal<P> val; PoolingVal<P> val;
const int h_in_start = -padding_h + h * Stride; const int h_in_start = -padding_h + h * Stride;
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
...@@ -77,7 +228,28 @@ inline void Pooling3x3NormalRow(const float *input, const int h_output, ...@@ -77,7 +228,28 @@ inline void Pooling3x3NormalRow(const float *input, const int h_output,
// border left // border left
POOLING3X3_NORMAL_BORDER(0, valid_w_start) POOLING3X3_NORMAL_BORDER(0, valid_w_start)
// middle // middle
for (int w = valid_w_start; w < valid_w_end; ++w) { int remain_start = valid_w_start;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
int output_tiles = (valid_w_end - valid_w_start) / 6;
remain_start = valid_w_start + output_tiles * 6;
Pooling3x3NormalRowLoadInput<P, Stride> PoolingCompute;
float32x4x2_t x0, x1, x2, y0;
float32x4_t post = vdupq_n_f32(1.f / (3 * (h_end - h_start)));
for (int w = 0; w < output_tiles * 6; w += 6) {
int output_offset = valid_w_start + w;
int input_w_offset = output_offset * Stride - padding_w;
y0.val[0] = vPoolInitq_f32<P>();
y0.val[1] = vPoolInitq_f32<P>();
for (int h_in = h_start; h_in < h_end; ++h_in) {
PoolingCompute(input + h_in * input_w + input_w_offset, x0, x1, x2, y0);
}
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr + output_offset, y0.val[0]);
vst1_f32(output_ptr + output_offset + 4, vget_low_f32(y0.val[1]));
}
#endif // __ARM_NEON__
for (int w = remain_start; w < valid_w_end; ++w) {
PoolingVal<P> val; PoolingVal<P> val;
int input_start = -padding_w + w * Stride; int input_start = -padding_w + w * Stride;
for (int h_in = h_start; h_in < h_end; ++h_in) { for (int h_in = h_start; h_in < h_end; ++h_in) {
...@@ -112,6 +284,7 @@ struct Pooling3x3<P, 1> { ...@@ -112,6 +284,7 @@ struct Pooling3x3<P, 1> {
int valid_w_start = padding_w; int valid_w_start = padding_w;
int valid_w_end = output_w - valid_w_start; int valid_w_end = output_w - valid_w_start;
int valid_w = valid_w_end - valid_w_start; int valid_w = valid_w_end - valid_w_start;
float avg = 1.f / 9;
#pragma omp parallel for #pragma omp parallel for
for (int c = 0; c < output->dims()[1]; ++c) { for (int c = 0; c < output->dims()[1]; ++c) {
...@@ -157,6 +330,7 @@ struct Pooling3x3<P, 1> { ...@@ -157,6 +330,7 @@ struct Pooling3x3<P, 1> {
#if defined(__ARM_NEON__) || defined(__ARM_NEON) #if defined(__ARM_NEON__) || defined(__ARM_NEON)
float32x4x2_t x0, x1, x2; float32x4x2_t x0, x1, x2;
float32x4x2_t y0, y1, y2; float32x4x2_t y0, y1, y2;
float32x4_t post = vdupq_n_f32(1.f / 9);
for (int loop = 0; loop < output_w_tiles; ++loop) { for (int loop = 0; loop < output_w_tiles; ++loop) {
x0.val[0] = vld1q_f32(input_ptr0); x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4); x0.val[1] = vld1q_f32(input_ptr0 + 4);
...@@ -196,8 +370,8 @@ struct Pooling3x3<P, 1> { ...@@ -196,8 +370,8 @@ struct Pooling3x3<P, 1> {
y1.val[1] = vPoolPreq_f32<P>(y2.val[1], y1.val[1]); y1.val[1] = vPoolPreq_f32<P>(y2.val[1], y1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(y2.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(y2.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y2.val[1], y0.val[1]); y0.val[1] = vPoolPreq_f32<P>(y2.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0]); y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1]); y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr0, y0.val[0]); vst1q_f32(output_ptr0, y0.val[0]);
vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1])); vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
...@@ -215,8 +389,8 @@ struct Pooling3x3<P, 1> { ...@@ -215,8 +389,8 @@ struct Pooling3x3<P, 1> {
y1.val[1] = vPoolPreq_f32<P>(y0.val[1], y1.val[1]); y1.val[1] = vPoolPreq_f32<P>(y0.val[1], y1.val[1]);
y2.val[0] = vPoolPreq_f32<P>(y0.val[0], y2.val[0]); y2.val[0] = vPoolPreq_f32<P>(y0.val[0], y2.val[0]);
y2.val[1] = vPoolPreq_f32<P>(y0.val[1], y2.val[1]); y2.val[1] = vPoolPreq_f32<P>(y0.val[1], y2.val[1]);
y1.val[0] = vPoolPostq_f32<P>(y1.val[0]); y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
y1.val[1] = vPoolPostq_f32<P>(y1.val[1]); y1.val[1] = vPoolPostq_f32<P>(y1.val[1], post);
vst1q_f32(output_ptr1, y1.val[0]); vst1q_f32(output_ptr1, y1.val[0]);
vst1_f32(output_ptr1 + 4, vget_low_f32(y1.val[1])); vst1_f32(output_ptr1 + 4, vget_low_f32(y1.val[1]));
...@@ -234,8 +408,8 @@ struct Pooling3x3<P, 1> { ...@@ -234,8 +408,8 @@ struct Pooling3x3<P, 1> {
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]); y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]); y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]);
y2.val[1] = vPoolPreq_f32<P>(x0.val[1], y2.val[1]); y2.val[1] = vPoolPreq_f32<P>(x0.val[1], y2.val[1]);
y2.val[0] = vPoolPostq_f32<P>(y2.val[0]); y2.val[0] = vPoolPostq_f32<P>(y2.val[0], post);
y2.val[1] = vPoolPostq_f32<P>(y2.val[1]); y2.val[1] = vPoolPostq_f32<P>(y2.val[1], post);
vst1q_f32(output_ptr2, y2.val[0]); vst1q_f32(output_ptr2, y2.val[0]);
vst1_f32(output_ptr2 + 4, vget_low_f32(y2.val[1])); vst1_f32(output_ptr2 + 4, vget_low_f32(y2.val[1]));
...@@ -251,8 +425,8 @@ struct Pooling3x3<P, 1> { ...@@ -251,8 +425,8 @@ struct Pooling3x3<P, 1> {
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]); x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]); y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0]); y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1]); y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr3, y0.val[0]); vst1q_f32(output_ptr3, y0.val[0]);
vst1_f32(output_ptr3 + 4, vget_low_f32(y0.val[1])); vst1_f32(output_ptr3 + 4, vget_low_f32(y0.val[1]));
...@@ -292,7 +466,7 @@ struct Pooling3x3<P, 1> { ...@@ -292,7 +466,7 @@ struct Pooling3x3<P, 1> {
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); y2.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[0] = vPoolPreq_f32<P>(y2.val[0], y1.val[0]); y1.val[0] = vPoolPreq_f32<P>(y2.val[0], y1.val[0]);
y0.val[0] = vPoolPreq_f32<P>(y2.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(y2.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0]); y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr0, y0.val[0]); vst1q_f32(output_ptr0, y0.val[0]);
x0.val[0] = vld1q_f32(input_ptr3); x0.val[0] = vld1q_f32(input_ptr3);
...@@ -303,7 +477,7 @@ struct Pooling3x3<P, 1> { ...@@ -303,7 +477,7 @@ struct Pooling3x3<P, 1> {
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]); y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
y2.val[0] = vPoolPreq_f32<P>(y0.val[0], y2.val[0]); y2.val[0] = vPoolPreq_f32<P>(y0.val[0], y2.val[0]);
y1.val[0] = vPoolPostq_f32<P>(y1.val[0]); y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
vst1q_f32(output_ptr1, y1.val[0]); vst1q_f32(output_ptr1, y1.val[0]);
x0.val[0] = vld1q_f32(input_ptr4); x0.val[0] = vld1q_f32(input_ptr4);
...@@ -314,7 +488,7 @@ struct Pooling3x3<P, 1> { ...@@ -314,7 +488,7 @@ struct Pooling3x3<P, 1> {
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]); y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]);
y2.val[0] = vPoolPostq_f32<P>(y2.val[0]); y2.val[0] = vPoolPostq_f32<P>(y2.val[0], post);
vst1q_f32(output_ptr2, y2.val[0]); vst1q_f32(output_ptr2, y2.val[0]);
x0.val[0] = vld1q_f32(input_ptr5); x0.val[0] = vld1q_f32(input_ptr5);
...@@ -324,7 +498,7 @@ struct Pooling3x3<P, 1> { ...@@ -324,7 +498,7 @@ struct Pooling3x3<P, 1> {
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0]); y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr3, y0.val[0]); vst1q_f32(output_ptr3, y0.val[0]);
input_ptr0 += 4; input_ptr0 += 4;
...@@ -358,10 +532,10 @@ struct Pooling3x3<P, 1> { ...@@ -358,10 +532,10 @@ struct Pooling3x3<P, 1> {
m1 = PoolPre<P>(PoolPre<P>(m1, m2), m3); m1 = PoolPre<P>(PoolPre<P>(m1, m2), m3);
m2 = PoolPre<P>(PoolPre<P>(m2, m3), m4); m2 = PoolPre<P>(PoolPre<P>(m2, m3), m4);
m3 = PoolPre<P>(PoolPre<P>(m3, m4), m5); m3 = PoolPre<P>(PoolPre<P>(m3, m4), m5);
output_ptr0[r] = PoolPost<P>(m0); output_ptr0[r] = PoolPost<P>(m0, avg);
output_ptr1[r] = PoolPost<P>(m1); output_ptr1[r] = PoolPost<P>(m1, avg);
output_ptr2[r] = PoolPost<P>(m2); output_ptr2[r] = PoolPost<P>(m2, avg);
output_ptr3[r] = PoolPost<P>(m3); output_ptr3[r] = PoolPost<P>(m3, avg);
} }
} }
// remain h // remain h
...@@ -374,6 +548,7 @@ struct Pooling3x3<P, 1> { ...@@ -374,6 +548,7 @@ struct Pooling3x3<P, 1> {
int remain = output_w_remain; int remain = output_w_remain;
#if defined(__ARM_NEON__) || defined(__ARM_NEON) #if defined(__ARM_NEON__) || defined(__ARM_NEON)
float32x4x2_t x0, x1, x2, y0; float32x4x2_t x0, x1, x2, y0;
float32x4_t post = vdupq_n_f32(1.f / 9);
for (int loop = 0; loop < output_w_tiles; ++loop) { for (int loop = 0; loop < output_w_tiles; ++loop) {
x0.val[0] = vld1q_f32(input_ptr0); x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4); x0.val[1] = vld1q_f32(input_ptr0 + 4);
...@@ -411,8 +586,8 @@ struct Pooling3x3<P, 1> { ...@@ -411,8 +586,8 @@ struct Pooling3x3<P, 1> {
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]); x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]); y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0]); y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1]); y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr0, y0.val[0]); vst1q_f32(output_ptr0, y0.val[0]);
vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1])); vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
...@@ -445,7 +620,7 @@ struct Pooling3x3<P, 1> { ...@@ -445,7 +620,7 @@ struct Pooling3x3<P, 1> {
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0]); y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr0, y0.val[0]); vst1q_f32(output_ptr0, y0.val[0]);
input_ptr0 += 4; input_ptr0 += 4;
...@@ -464,7 +639,7 @@ struct Pooling3x3<P, 1> { ...@@ -464,7 +639,7 @@ struct Pooling3x3<P, 1> {
m2 = PoolPre<P>(m2, input_ptr2[r + 2]); m2 = PoolPre<P>(m2, input_ptr2[r + 2]);
m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2); m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2);
output_ptr0[r] = PoolPost<P>(m0); output_ptr0[r] = PoolPost<P>(m0, avg);
} }
} }
} }
...@@ -492,6 +667,7 @@ struct Pooling3x3<P, 2> { ...@@ -492,6 +667,7 @@ struct Pooling3x3<P, 2> {
int valid_w_start = (padding_w + 1) / 2; int valid_w_start = (padding_w + 1) / 2;
int valid_w_end = output_w - valid_w_start; int valid_w_end = output_w - valid_w_start;
int valid_w = valid_w_end - valid_w_start; int valid_w = valid_w_end - valid_w_start;
float avg = 1.f / 9;
#pragma omp parallel for #pragma omp parallel for
for (int c = 0; c < output->dims()[1]; ++c) { for (int c = 0; c < output->dims()[1]; ++c) {
...@@ -539,6 +715,7 @@ struct Pooling3x3<P, 2> { ...@@ -539,6 +715,7 @@ struct Pooling3x3<P, 2> {
#if defined(__ARM_NEON__) || defined(__ARM_NEON) #if defined(__ARM_NEON__) || defined(__ARM_NEON)
float32x4x2_t x0, x1, x2; float32x4x2_t x0, x1, x2;
float32x4x2_t y0, y1, y2; float32x4x2_t y0, y1, y2;
float32x4_t post = vdupq_n_f32(1.f / 9);
for (int loop = 0; loop < output_w_tiles; ++loop) { for (int loop = 0; loop < output_w_tiles; ++loop) {
x0 = vld2q_f32(input_ptr0); x0 = vld2q_f32(input_ptr0);
x1 = vld2q_f32(input_ptr0 + 8); x1 = vld2q_f32(input_ptr0 + 8);
...@@ -570,8 +747,8 @@ struct Pooling3x3<P, 2> { ...@@ -570,8 +747,8 @@ struct Pooling3x3<P, 2> {
y1.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]); y1.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y1.val[1], y0.val[1]); y0.val[1] = vPoolPreq_f32<P>(y1.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0]); y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1]); y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr0, y0.val[0]); vst1q_f32(output_ptr0, y0.val[0]);
vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1])); vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
...@@ -596,8 +773,8 @@ struct Pooling3x3<P, 2> { ...@@ -596,8 +773,8 @@ struct Pooling3x3<P, 2> {
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]); y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]); y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
y1.val[1] = vPoolPreq_f32<P>(y0.val[1], y1.val[1]); y1.val[1] = vPoolPreq_f32<P>(y0.val[1], y1.val[1]);
y1.val[0] = vPoolPostq_f32<P>(y1.val[0]); y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
y1.val[1] = vPoolPostq_f32<P>(y1.val[1]); y1.val[1] = vPoolPostq_f32<P>(y1.val[1], post);
vst1q_f32(output_ptr1, y1.val[0]); vst1q_f32(output_ptr1, y1.val[0]);
vst1_f32(output_ptr1 + 4, vget_low_f32(y1.val[1])); vst1_f32(output_ptr1 + 4, vget_low_f32(y1.val[1]));
...@@ -622,8 +799,8 @@ struct Pooling3x3<P, 2> { ...@@ -622,8 +799,8 @@ struct Pooling3x3<P, 2> {
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]); x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]); y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0]); y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1]); y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr2, y0.val[0]); vst1q_f32(output_ptr2, y0.val[0]);
vst1_f32(output_ptr2 + 4, vget_low_f32(y0.val[1])); vst1_f32(output_ptr2 + 4, vget_low_f32(y0.val[1]));
...@@ -659,7 +836,7 @@ struct Pooling3x3<P, 2> { ...@@ -659,7 +836,7 @@ struct Pooling3x3<P, 2> {
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0]); y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr0, y0.val[0]); vst1q_f32(output_ptr0, y0.val[0]);
x0 = vld2q_f32(input_ptr3); x0 = vld2q_f32(input_ptr3);
...@@ -675,7 +852,7 @@ struct Pooling3x3<P, 2> { ...@@ -675,7 +852,7 @@ struct Pooling3x3<P, 2> {
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]); y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
y1.val[0] = vPoolPostq_f32<P>(y1.val[0]); y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
vst1q_f32(output_ptr1, y1.val[0]); vst1q_f32(output_ptr1, y1.val[0]);
x0 = vld2q_f32(input_ptr5); x0 = vld2q_f32(input_ptr5);
...@@ -691,7 +868,7 @@ struct Pooling3x3<P, 2> { ...@@ -691,7 +868,7 @@ struct Pooling3x3<P, 2> {
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0]); y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr2, y0.val[0]); vst1q_f32(output_ptr2, y0.val[0]);
input_ptr0 += 8; input_ptr0 += 8;
...@@ -726,9 +903,9 @@ struct Pooling3x3<P, 2> { ...@@ -726,9 +903,9 @@ struct Pooling3x3<P, 2> {
m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2); m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2);
m1 = PoolPre<P>(PoolPre<P>(m2, m3), m4); m1 = PoolPre<P>(PoolPre<P>(m2, m3), m4);
m2 = PoolPre<P>(PoolPre<P>(m4, m5), m6); m2 = PoolPre<P>(PoolPre<P>(m4, m5), m6);
output_ptr0[r] = PoolPost<P>(m0); output_ptr0[r] = PoolPost<P>(m0, avg);
output_ptr1[r] = PoolPost<P>(m1); output_ptr1[r] = PoolPost<P>(m1, avg);
output_ptr2[r] = PoolPost<P>(m2); output_ptr2[r] = PoolPost<P>(m2, avg);
} }
} }
// remain h // remain h
...@@ -742,6 +919,7 @@ struct Pooling3x3<P, 2> { ...@@ -742,6 +919,7 @@ struct Pooling3x3<P, 2> {
int remain = output_w_remain; int remain = output_w_remain;
#if defined(__ARM_NEON__) || defined(__ARM_NEON) #if defined(__ARM_NEON__) || defined(__ARM_NEON)
float32x4x2_t x0, x1, x2, y0; float32x4x2_t x0, x1, x2, y0;
float32x4_t post = vdupq_n_f32(1.f / 9);
for (int loop = 0; loop < output_w_tiles; ++loop) { for (int loop = 0; loop < output_w_tiles; ++loop) {
x0 = vld2q_f32(input_ptr0); x0 = vld2q_f32(input_ptr0);
x1 = vld2q_f32(input_ptr0 + 8); x1 = vld2q_f32(input_ptr0 + 8);
...@@ -773,8 +951,8 @@ struct Pooling3x3<P, 2> { ...@@ -773,8 +951,8 @@ struct Pooling3x3<P, 2> {
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]); x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]); y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0]); y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1]); y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr0, y0.val[0]); vst1q_f32(output_ptr0, y0.val[0]);
vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1])); vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
...@@ -804,7 +982,7 @@ struct Pooling3x3<P, 2> { ...@@ -804,7 +982,7 @@ struct Pooling3x3<P, 2> {
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0]); y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr0, y0.val[0]); vst1q_f32(output_ptr0, y0.val[0]);
input_ptr0 += 8; input_ptr0 += 8;
...@@ -823,7 +1001,7 @@ struct Pooling3x3<P, 2> { ...@@ -823,7 +1001,7 @@ struct Pooling3x3<P, 2> {
m2 = PoolPre<P>(m2, input_ptr2[2 * r + 2]); m2 = PoolPre<P>(m2, input_ptr2[2 * r + 2]);
m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2); m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2);
output_ptr0[r] = PoolPost<P>(m0); output_ptr0[r] = PoolPost<P>(m0, avg);
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册