提交 25c10df2 编写于 作者: W wangliu

optimize pool3x3 kernel

上级 90344f8c
...@@ -63,8 +63,20 @@ void PoolKernel<CPU, float>::Compute(const PoolParam &param) const { ...@@ -63,8 +63,20 @@ void PoolKernel<CPU, float>::Compute(const PoolParam &param) const {
} }
} else if (ksize[0] == 3 && ksize[0] == ksize[1]) { } else if (ksize[0] == 3 && ksize[0] == ksize[1]) {
if (pooling_type == "max") { if (pooling_type == "max") {
if (strides[0] == strides[1] && strides[0] == 1 &&
paddings[0] == paddings[1] && paddings[1] == 1) {
math::Pool3x3Maxs1p1(in_x, out);
} else {
math::Pool3x3Max(strides, paddings, in_x, out);
}
math::Pool3x3Max(strides, paddings, in_x, out); math::Pool3x3Max(strides, paddings, in_x, out);
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
if (strides[0] == strides[1] && strides[0] == 1 &&
paddings[0] == paddings[1] && paddings[1] == 1) {
math::Pool3x3Avgs1p1(in_x, out);
} else {
math::Pool3x3Avg(strides, paddings, in_x, out);
}
math::Pool3x3Avg(strides, paddings, in_x, out); math::Pool3x3Avg(strides, paddings, in_x, out);
} }
......
...@@ -13,13 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,13 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef POOL_OP #ifdef POOL_OP
#define __ARM_NEON true #include "operators/math/pool_3x3.h"
#include "pool_3x3.h"
#include "framework/tensor.h"
#if __ARM_NEON
#include <arm_neon.h>
#endif // __ARM_NEON
#include <climits> #include <climits>
#include "framework/tensor.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
namespace math { namespace math {
...@@ -27,6 +23,481 @@ using framework::Tensor; ...@@ -27,6 +23,481 @@ using framework::Tensor;
using std::max; using std::max;
using std::min; using std::min;
using std::vector; using std::vector;
void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
#if __ARM_NEON
const int batch_size = input->dims()[0];
const int h_in = input->dims()[2];
const int w_in = input->dims()[3];
const int output_channels = output->dims()[1];
const int h_out = output->dims()[2];
const int w_out = output->dims()[3];
const int outputdata_channel_stride = h_out * w_out;
const int inputdata_channel_stride = h_in * w_in;
float *out_data = output->data<float>();
const float *input_data = input->data<float>();
const float coef = 1.0 / 9.0;
for (int k = 0; k < batch_size; ++k) {
for (int c = 0; c < output_channels; ++c) {
// four corner point
out_data[0] = (input_data[0] + input_data[1] + input_data[w_in] +
input_data[w_in + 1]) *
coef;
out_data[w_out - 1] =
(input_data[w_in - 2] + input_data[w_in - 1] +
input_data[w_in * 2 - 2] + input_data[2 * w_in - 1]) *
coef;
out_data[(h_out - 1) * w_out] =
(input_data[(h_in - 2) * w_in] + input_data[(h_in - 2) * w_in + 1] +
input_data[(h_in - 1) * w_in] + input_data[(h_in - 1) * w_in + 1]) *
coef;
out_data[h_out * w_out - 1] =
(input_data[h_in * w_in - 1] + input_data[h_in * w_in - 2] +
input_data[(h_in - 1) * w_in - 1] +
input_data[(h_in - 1) * w_in - 2]) *
coef;
// left side & right side
for (int i = 1; i < h_in - 1; ++i) {
out_data[i * w_out] =
(input_data[i * w_in - w_in] + input_data[i * w_in - w_in + 1] +
input_data[i * w_in] + input_data[i * w_in + 1] +
input_data[i * w_in + w_in] + input_data[i * w_in + w_in + 1]) *
coef;
out_data[i * w_out + w_out - 1] =
(input_data[i * w_in - w_in + w_in - 2] +
input_data[i * w_in - w_in + 1 + w_in - 2] +
input_data[i * w_in + w_in - 2] +
input_data[i * w_in + 1 + w_in - 2] +
input_data[i * w_in + w_in + w_in - 2] +
input_data[i * w_in + w_in + 1 + w_in - 2]) *
coef;
}
// top 1 row & bottom 1 row
const float *input_tmp = input_data;
float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
tmp3, tmp4, tmp5, sum, out0;
float32x4_t v_coef = vdupq_n_f32(coef);
in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + w_in);
const float *input_tmp_end = input_tmp + (h_in - 2) * w_in;
in4 = vld1q_f32(input_tmp_end);
in6 = vld1q_f32(input_tmp_end + w_in);
int c_mid = w_out - 2;
auto output_ptr = out_data + 1;
for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + w_in + 4);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
sum = vaddq_f32(in0, tmp0);
sum = vaddq_f32(sum, tmp1);
sum = vaddq_f32(sum, in2);
sum = vaddq_f32(sum, tmp2);
sum = vaddq_f32(sum, tmp3);
vst1q_f32(output_ptr, vmulq_f32(sum, v_coef));
in5 = vld1q_f32(input_tmp_end + 4);
in7 = vld1q_f32(input_tmp_end + w_in + 4);
tmp0 = vextq_f32(in4, in5, 1);
tmp1 = vextq_f32(in4, in5, 2);
tmp2 = vextq_f32(in6, in7, 1);
tmp3 = vextq_f32(in6, in7, 2);
sum = vaddq_f32(in0, tmp0);
sum = vaddq_f32(sum, tmp1);
sum = vaddq_f32(sum, in2);
sum = vaddq_f32(sum, tmp2);
sum = vaddq_f32(sum, tmp3);
vst1q_f32(output_ptr + (h_out - 1) * w_out, vmulq_f32(sum, v_coef));
// can optimize to each 8 stride.
input_tmp += 4;
input_tmp_end += 4;
output_ptr += 4;
in0 = in1;
in2 = in3;
in4 = in5;
in6 = in7;
}
// top right remain
float32x4_t pad0 = vdupq_n_f32(input_data[w_in - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[2 * w_in - 1]);
tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2);
tmp2 = vextq_f32(in2, pad1, 2);
tmp3 = vextq_f32(in2, pad1, 2);
sum = vaddq_f32(in0, tmp0);
sum = vaddq_f32(sum, tmp1);
sum = vaddq_f32(sum, in2);
sum = vaddq_f32(sum, tmp2);
sum = vaddq_f32(sum, tmp3);
out0 = vmulq_f32(sum, v_coef);
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
// bottom_right remain
float32x4_t pad2 = vdupq_n_f32(input_data[(h_in - 1) * w_in - 1]);
float32x4_t pad3 = vdupq_n_f32(input_data[h_in * w_in - 1]);
tmp0 = vextq_f32(in4, pad2, 1);
tmp1 = vextq_f32(in4, pad2, 2);
tmp2 = vextq_f32(in6, pad3, 2);
tmp3 = vextq_f32(in6, pad3, 2);
sum = vaddq_f32(in4, tmp0);
sum = vaddq_f32(sum, tmp1);
sum = vaddq_f32(sum, in6);
sum = vaddq_f32(sum, tmp2);
sum = vaddq_f32(sum, tmp3);
out0 = vmulq_f32(sum, v_coef);
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 2);
}
}
// mid
for (int j = 0; j < h_out - 2; ++j) {
output_ptr = out_data + w_out * (j + 1) + 1;
input_tmp = input_data + j * w_in;
in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + w_in);
in4 = vld1q_f32(input_tmp + 2 * w_in);
c_mid = w_out - 2;
for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + w_in + 4);
in5 = vld1q_f32(input_tmp + 2 * w_in + 4);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
tmp4 = vextq_f32(in4, in5, 1);
tmp5 = vextq_f32(in4, in5, 2);
sum = vaddq_f32(in0, tmp0);
sum = vaddq_f32(sum, tmp1);
sum = vaddq_f32(sum, in2);
sum = vaddq_f32(sum, tmp2);
sum = vaddq_f32(sum, tmp3);
sum = vaddq_f32(sum, in4);
sum = vaddq_f32(sum, tmp4);
sum = vaddq_f32(sum, tmp5);
out0 = vmulq_f32(sum, v_coef);
vst1q_f32(output_ptr, out0);
output_ptr += 4;
input_tmp += 4;
in0 = in1;
in2 = in3;
in4 = in5;
}
// mid remain
float32x4_t pad0 = vdupq_n_f32(input_data[(j + 1) * w_in - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[(j + 2) * w_in - 1]);
float32x4_t pad2 = vdupq_n_f32(input_data[(j + 2) * w_in - 1]);
tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2);
tmp2 = vextq_f32(in2, pad1, 1);
tmp3 = vextq_f32(in2, pad1, 2);
tmp4 = vextq_f32(in4, pad2, 1);
tmp5 = vextq_f32(in4, pad2, 2);
sum = vaddq_f32(in0, tmp0);
sum = vaddq_f32(sum, tmp1);
sum = vaddq_f32(sum, in2);
sum = vaddq_f32(sum, tmp2);
sum = vaddq_f32(sum, tmp3);
sum = vaddq_f32(sum, in4);
sum = vaddq_f32(sum, tmp4);
sum = vaddq_f32(sum, tmp5);
out0 = vmulq_f32(sum, v_coef);
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
}
input_data += inputdata_channel_stride;
out_data += outputdata_channel_stride;
}
}
#endif
}
void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
#if __ARM_NEON
const int batch_size = input->dims()[0];
const int h_in = input->dims()[2];
const int w_in = input->dims()[3];
const int output_channels = output->dims()[1];
const int h_out = output->dims()[2];
const int w_out = output->dims()[3];
const int outputdata_channel_stride = h_out * w_out;
const int inputdata_channel_stride = h_in * w_in;
float *out_data = output->data<float>();
const float *input_data = input->data<float>();
for (int k = 0; k < batch_size; ++k) {
for (int c = 0; c < output_channels; ++c) {
// four corner point
out_data[0] = std::max(std::max(input_data[0], input_data[1]),
std::max(input_data[w_in], input_data[w_in + 1]));
out_data[w_out - 1] = std::max(
std::max(input_data[w_in - 2], input_data[w_in - 1]),
std::max(input_data[w_in * 2 - 2], input_data[2 * w_in - 1]));
out_data[(h_out - 1) * w_out] =
std::max(std::max(input_data[(h_in - 2) * w_in],
input_data[(h_in - 2) * w_in + 1]),
std::max(input_data[(h_in - 1) * w_in],
input_data[(h_in - 1) * w_in + 1]));
out_data[h_out * w_out - 1] = std::max(
std::max(input_data[(h_in - 1) * w_in - 1],
input_data[(h_in - 1) * w_in - 2]),
std::max(input_data[h_in * w_in - 1], input_data[h_in * w_in - 2]));
// left side & right side
for (int i = 1; i < h_in - 1; ++i) {
float max1 = std::max(input_data[i * w_in - w_in],
input_data[i * w_in - w_in + 1]);
float max2 = std::max(input_data[i * w_in], input_data[i * w_in + 1]);
float max3 = std::max(input_data[i * w_in + w_in],
input_data[i * w_in + w_in + 1]);
out_data[i * w_out] = std::max(std::max(max1, max2), max3);
max1 = std::max(input_data[i * w_in - w_in + w_in - 2],
input_data[i * w_in - w_in + 1 + w_in - 2]);
max2 = std::max(input_data[i * w_in + w_in - 2],
input_data[i * w_in + 1 + w_in - 2]);
max3 = std::max(input_data[i * w_in + w_in + w_in - 2],
input_data[i * w_in + w_in + 1 + w_in - 2]);
out_data[i * w_out + w_out - 1] = std::max(std::max(max1, max2), max3);
}
// top 1 row & bottom 1 row
const float *input_tmp = input_data;
float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
tmp3, tmp4, tmp5, max;
in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + w_in);
const float *input_tmp_end = input_tmp + (h_in - 2) * w_in;
in4 = vld1q_f32(input_tmp_end);
in6 = vld1q_f32(input_tmp_end + w_in);
int c_mid = w_out - 2;
auto output_ptr = out_data + 1;
for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + w_in + 4);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
max = vmaxq_f32(in0, tmp0);
max = vmaxq_f32(max, tmp1);
max = vmaxq_f32(max, in2);
max = vmaxq_f32(max, tmp2);
max = vmaxq_f32(max, tmp3);
vst1q_f32(output_ptr, max);
in5 = vld1q_f32(input_tmp_end + 4);
in7 = vld1q_f32(input_tmp_end + w_in + 4);
tmp0 = vextq_f32(in4, in5, 1);
tmp1 = vextq_f32(in4, in5, 2);
tmp2 = vextq_f32(in6, in7, 1);
tmp3 = vextq_f32(in6, in7, 2);
max = vmaxq_f32(in4, tmp0);
max = vmaxq_f32(max, tmp1);
max = vmaxq_f32(max, in6);
max = vmaxq_f32(max, tmp2);
max = vmaxq_f32(max, tmp3);
vst1q_f32(output_ptr + (h_out - 1) * w_out, max);
input_tmp += 4;
input_tmp_end += 4;
output_ptr += 4;
in0 = in1;
in2 = in3;
in4 = in5;
in6 = in7;
}
// top right remain
float32x4_t pad0 = vdupq_n_f32(input_data[w_in - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[2 * w_in - 1]);
tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2);
tmp2 = vextq_f32(in2, pad1, 1);
tmp3 = vextq_f32(in2, pad1, 2);
max = vmaxq_f32(in0, tmp0);
max = vmaxq_f32(max, tmp1);
max = vmaxq_f32(max, in2);
max = vmaxq_f32(max, tmp2);
max = vmaxq_f32(max, tmp3);
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, max, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, max, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, max, 2);
}
}
// bottom_right remain
float32x4_t pad2 = vdupq_n_f32(input_data[(h_in - 1) * w_in - 1]);
float32x4_t pad3 = vdupq_n_f32(input_data[h_in * w_in - 1]);
tmp0 = vextq_f32(in4, pad2, 1);
tmp1 = vextq_f32(in4, pad2, 2);
tmp2 = vextq_f32(in6, pad3, 1);
tmp3 = vextq_f32(in6, pad3, 2);
max = vmaxq_f32(in4, tmp0);
max = vmaxq_f32(max, tmp1);
max = vmaxq_f32(max, in6);
max = vmaxq_f32(max, tmp2);
max = vmaxq_f32(max, tmp3);
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, max, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, max, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, max, 2);
}
}
// mid
for (int j = 0; j < h_out - 2; ++j) {
output_ptr = out_data + (j + 1) * w_out + 1;
input_tmp = input_data + j * w_in;
in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + w_in);
in4 = vld1q_f32(input_tmp + 2 * w_in);
c_mid = w_out - 2;
for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + w_in + 4);
in5 = vld1q_f32(input_tmp + 2 * w_in + 4);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
tmp4 = vextq_f32(in4, in5, 1);
tmp5 = vextq_f32(in4, in5, 2);
max = vmaxq_f32(in0, tmp0);
max = vmaxq_f32(max, tmp1);
max = vmaxq_f32(max, in2);
max = vmaxq_f32(max, tmp2);
max = vmaxq_f32(max, tmp3);
max = vmaxq_f32(max, in4);
max = vmaxq_f32(max, tmp4);
max = vmaxq_f32(max, tmp5);
vst1q_f32(output_ptr, max);
output_ptr += 4;
input_tmp += 4;
in0 = in1;
in2 = in3;
in4 = in5;
}
// mid remain
float32x4_t pad0 = vdupq_n_f32(input_data[(j + 1) * w_in - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[(j + 2) * w_in - 1]);
float32x4_t pad2 = vdupq_n_f32(input_data[(j + 3) * w_in - 1]);
tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2);
tmp2 = vextq_f32(in2, pad1, 1);
tmp3 = vextq_f32(in2, pad1, 2);
tmp4 = vextq_f32(in4, pad2, 1);
tmp5 = vextq_f32(in4, pad2, 2);
max = vmaxq_f32(in0, tmp0);
max = vmaxq_f32(max, tmp1);
max = vmaxq_f32(max, in2);
max = vmaxq_f32(max, tmp2);
max = vmaxq_f32(max, tmp3);
max = vmaxq_f32(max, in4);
max = vmaxq_f32(max, tmp4);
max = vmaxq_f32(max, tmp5);
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, max, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, max, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, max, 2);
}
}
}
input_data += inputdata_channel_stride;
out_data += outputdata_channel_stride;
}
}
#endif
}
void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input, void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input,
Tensor *output) { Tensor *output) {
......
...@@ -15,7 +15,8 @@ limitations under the License. */ ...@@ -15,7 +15,8 @@ limitations under the License. */
#ifdef POOL_OP #ifdef POOL_OP
#pragma once #pragma once
#include <algorithm>
#include <vector>
#include "framework/tensor.h" #include "framework/tensor.h"
#if __ARM_NEON #if __ARM_NEON
#include <arm_neon.h> #include <arm_neon.h>
...@@ -26,7 +27,8 @@ namespace operators { ...@@ -26,7 +27,8 @@ namespace operators {
namespace math { namespace math {
using framework::Tensor; using framework::Tensor;
using std::vector; using std::vector;
void Pool3x3Avgs1p1(const Tensor *input, Tensor *output);
void Pool3x3Maxs1p1(const Tensor *input, Tensor *output);
void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input, void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input,
Tensor *output); Tensor *output);
......
...@@ -30,12 +30,12 @@ int main() { ...@@ -30,12 +30,12 @@ int main() {
std::vector<int64_t> dims{1, 3, 224, 224}; std::vector<int64_t> dims{1, 3, 224, 224};
GetInput<float>(g_test_image_1x3x224x224, &input, dims); GetInput<float>(g_test_image_1x3x224x224, &input, dims);
auto time3 = time(); auto time3 = time();
int count = 1;
for (int i = 0; i < 10; ++i) { for (int i = 0; i < count; i++) {
executor.Predict(input, dims); executor.Predict(input, dims);
} }
auto time4 = time(); auto time4 = time();
DLOG << "predict cost :" << time_diff(time3, time4) << "ms\n"; DLOG << "avg predict cost :" << time_diff(time3, time4) / count << "ms\n";
return 0; return 0;
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册