diff --git a/src/operators/kernel/arm/pool_kernel.cpp b/src/operators/kernel/arm/pool_kernel.cpp
index 5c92d5be014faf4007c0853bde08e450ebc4f79a..38e6c5f3f071d8eb0385d742fb819564309eeef6 100644
--- a/src/operators/kernel/arm/pool_kernel.cpp
+++ b/src/operators/kernel/arm/pool_kernel.cpp
@@ -63,8 +63,20 @@ void PoolKernel<CPU, float>::Compute(const PoolParam &param) const {
     }
   } else if (ksize[0] == 3 && ksize[0] == ksize[1]) {
     if (pooling_type == "max") {
-      math::Pool3x3Max(strides, paddings, in_x, out);
+      // Stride 1 / padding 1 has a dedicated kernel.
+      if (strides[0] == strides[1] && strides[0] == 1 &&
+          paddings[0] == paddings[1] && paddings[1] == 1) {
+        math::Pool3x3Maxs1p1(in_x, out);
+      } else {
+        math::Pool3x3Max(strides, paddings, in_x, out);
+      }
     } else if (pooling_type == "avg") {
-      math::Pool3x3Avg(strides, paddings, in_x, out);
+      // Stride 1 / padding 1 has a dedicated kernel.
+      if (strides[0] == strides[1] && strides[0] == 1 &&
+          paddings[0] == paddings[1] && paddings[1] == 1) {
+        math::Pool3x3Avgs1p1(in_x, out);
+      } else {
+        math::Pool3x3Avg(strides, paddings, in_x, out);
+      }
     }
 
diff --git a/src/operators/math/pool_3x3.cpp b/src/operators/math/pool_3x3.cpp
index 0259565377386a1415d27b0794580a6a223a88d4..fb91528b473418849d9005a2c0a5a52d9d033e58 100644
--- a/src/operators/math/pool_3x3.cpp
+++ b/src/operators/math/pool_3x3.cpp
@@ -13,13 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #ifdef POOL_OP
-#define __ARM_NEON true
-#include "pool_3x3.h"
-#include "framework/tensor.h"
-#if __ARM_NEON
-#include <arm_neon.h>
-#endif  // __ARM_NEON
+#include "operators/math/pool_3x3.h"
 #include <climits>
+#include "framework/tensor.h"
 namespace paddle_mobile {
 namespace operators {
 namespace math {
@@ -27,6 +23,489 @@ using framework::Tensor;
 using std::max;
 using std::min;
 using std::vector;
+// 3x3 average pooling specialised for stride 1 / padding 1 on float tensors
+// indexed as (batch, channel, height, width); the output is assumed to have
+// the same height/width as the input. Every window sum is scaled by 1/9,
+// including border windows that only cover 4 or 6 input elements. The body
+// is compiled only when __ARM_NEON is set; otherwise this is a no-op.
+void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
+#if __ARM_NEON
+  const int batch_size = input->dims()[0];
+
+  const int h_in = input->dims()[2];
+
+  const int w_in = input->dims()[3];
+
+  const int output_channels = output->dims()[1];
+
+  const int h_out = output->dims()[2];
+  const int w_out = output->dims()[3];
+  const int outputdata_channel_stride = h_out * w_out;
+  const int inputdata_channel_stride = h_in * w_in;
+  float *out_data = output->data<float>();
+  const float *input_data = input->data<float>();
+  const float coef = 1.0 / 9.0;
+  for (int k = 0; k < batch_size; ++k) {
+    for (int c = 0; c < output_channels; ++c) {
+      // four corner point
+      out_data[0] = (input_data[0] + input_data[1] + input_data[w_in] +
+                     input_data[w_in + 1]) *
+                    coef;
+      out_data[w_out - 1] =
+          (input_data[w_in - 2] + input_data[w_in - 1] +
+           input_data[w_in * 2 - 2] + input_data[2 * w_in - 1]) *
+          coef;
+      out_data[(h_out - 1) * w_out] =
+          (input_data[(h_in - 2) * w_in] + input_data[(h_in - 2) * w_in + 1] +
+           input_data[(h_in - 1) * w_in] + input_data[(h_in - 1) * w_in + 1]) *
+          coef;
+      out_data[h_out * w_out - 1] =
+          (input_data[h_in * w_in - 1] + input_data[h_in * w_in - 2] +
+           input_data[(h_in - 1) * w_in - 1] +
+           input_data[(h_in - 1) * w_in - 2]) *
+          coef;
+      // left side & right side
+      for (int i = 1; i < h_in - 1; ++i) {
+        out_data[i * w_out] =
+            (input_data[i * w_in - w_in] + input_data[i * w_in - w_in + 1] +
+             input_data[i * w_in] + input_data[i * w_in + 1] +
+             input_data[i * w_in + w_in] + input_data[i * w_in + w_in + 1]) *
+            coef;
+        out_data[i * w_out + w_out - 1] =
+            (input_data[i * w_in - w_in + w_in - 2] +
+             input_data[i * w_in - w_in + 1 + w_in - 2] +
+             input_data[i * w_in + w_in - 2] +
+             input_data[i * w_in + 1 + w_in - 2] +
+             input_data[i * w_in + w_in + w_in - 2] +
+             input_data[i * w_in + w_in + 1 + w_in - 2]) *
+            coef;
+      }
+      // top 1 row & bottom 1 row
+      const float *input_tmp = input_data;
+
+      float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
+          tmp3, tmp4, tmp5, sum, out0;
+      float32x4_t v_coef = vdupq_n_f32(coef);
+      in0 = vld1q_f32(input_tmp);
+      in2 = vld1q_f32(input_tmp + w_in);
+      const float *input_tmp_end = input_tmp + (h_in - 2) * w_in;
+      in4 = vld1q_f32(input_tmp_end);
+      in6 = vld1q_f32(input_tmp_end + w_in);
+      int c_mid = w_out - 2;
+      auto output_ptr = out_data + 1;
+      for (; c_mid > 3; c_mid -= 4) {
+        in1 = vld1q_f32(input_tmp + 4);
+        in3 = vld1q_f32(input_tmp + w_in + 4);
+
+        tmp0 = vextq_f32(in0, in1, 1);
+        tmp1 = vextq_f32(in0, in1, 2);
+
+        tmp2 = vextq_f32(in2, in3, 1);
+        tmp3 = vextq_f32(in2, in3, 2);
+
+        sum = vaddq_f32(in0, tmp0);
+        sum = vaddq_f32(sum, tmp1);
+        sum = vaddq_f32(sum, in2);
+        sum = vaddq_f32(sum, tmp2);
+        sum = vaddq_f32(sum, tmp3);
+
+        vst1q_f32(output_ptr, vmulq_f32(sum, v_coef));
+
+        in5 = vld1q_f32(input_tmp_end + 4);
+        in7 = vld1q_f32(input_tmp_end + w_in + 4);
+
+        tmp0 = vextq_f32(in4, in5, 1);
+        tmp1 = vextq_f32(in4, in5, 2);
+        tmp2 = vextq_f32(in6, in7, 1);
+        tmp3 = vextq_f32(in6, in7, 2);
+
+        sum = vaddq_f32(in4, tmp0);
+        sum = vaddq_f32(sum, tmp1);
+        sum = vaddq_f32(sum, in6);
+        sum = vaddq_f32(sum, tmp2);
+        sum = vaddq_f32(sum, tmp3);
+
+        vst1q_f32(output_ptr + (h_out - 1) * w_out, vmulq_f32(sum, v_coef));
+
+        // can optimize to each 8 stride.
+        input_tmp += 4;
+        input_tmp_end += 4;
+        output_ptr += 4;
+        in0 = in1;
+        in2 = in3;
+        in4 = in5;
+        in6 = in7;
+      }
+      // top right remain
+      float32x4_t pad0 = vdupq_n_f32(input_data[w_in - 1]);
+      float32x4_t pad1 = vdupq_n_f32(input_data[2 * w_in - 1]);
+
+      tmp0 = vextq_f32(in0, pad0, 1);
+      tmp1 = vextq_f32(in0, pad0, 2);
+      tmp2 = vextq_f32(in2, pad1, 1);
+      tmp3 = vextq_f32(in2, pad1, 2);
+
+      sum = vaddq_f32(in0, tmp0);
+      sum = vaddq_f32(sum, tmp1);
+      sum = vaddq_f32(sum, in2);
+      sum = vaddq_f32(sum, tmp2);
+      sum = vaddq_f32(sum, tmp3);
+      out0 = vmulq_f32(sum, v_coef);
+
+      for (int i = 0; i < c_mid; ++i) {
+        if (i == 0) {
+          vst1q_lane_f32(output_ptr + i, out0, 0);
+        }
+        if (i == 1) {
+          vst1q_lane_f32(output_ptr + i, out0, 1);
+        }
+        if (i == 2) {
+          vst1q_lane_f32(output_ptr + i, out0, 2);
+        }
+      }
+
+      // bottom_right remain
+      float32x4_t pad2 = vdupq_n_f32(input_data[(h_in - 1) * w_in - 1]);
+      float32x4_t pad3 = vdupq_n_f32(input_data[h_in * w_in - 1]);
+
+      tmp0 = vextq_f32(in4, pad2, 1);
+      tmp1 = vextq_f32(in4, pad2, 2);
+      tmp2 = vextq_f32(in6, pad3, 1);
+      tmp3 = vextq_f32(in6, pad3, 2);
+
+      sum = vaddq_f32(in4, tmp0);
+      sum = vaddq_f32(sum, tmp1);
+      sum = vaddq_f32(sum, in6);
+      sum = vaddq_f32(sum, tmp2);
+      sum = vaddq_f32(sum, tmp3);
+      out0 = vmulq_f32(sum, v_coef);
+
+      for (int i = 0; i < c_mid; ++i) {
+        if (i == 0) {
+          vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 0);
+        }
+        if (i == 1) {
+          vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 1);
+        }
+        if (i == 2) {
+          vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 2);
+        }
+      }
+      // mid
+      for (int j = 0; j < h_out - 2; ++j) {
+        output_ptr = out_data + w_out * (j + 1) + 1;
+        input_tmp = input_data + j * w_in;
+
+        in0 = vld1q_f32(input_tmp);
+        in2 = vld1q_f32(input_tmp + w_in);
+        in4 = vld1q_f32(input_tmp + 2 * w_in);
+        c_mid = w_out - 2;
+        for (; c_mid > 3; c_mid -= 4) {
+          in1 = vld1q_f32(input_tmp + 4);
+          in3 = vld1q_f32(input_tmp + w_in + 4);
+          in5 = vld1q_f32(input_tmp + 2 * w_in + 4);
+
+          tmp0 = vextq_f32(in0, in1, 1);
+          tmp1 = vextq_f32(in0, in1, 2);
+          tmp2 = vextq_f32(in2, in3, 1);
+          tmp3 = vextq_f32(in2, in3, 2);
+          tmp4 = vextq_f32(in4, in5, 1);
+          tmp5 = vextq_f32(in4, in5, 2);
+
+          sum = vaddq_f32(in0, tmp0);
+          sum = vaddq_f32(sum, tmp1);
+          sum = vaddq_f32(sum, in2);
+          sum = vaddq_f32(sum, tmp2);
+          sum = vaddq_f32(sum, tmp3);
+          sum = vaddq_f32(sum, in4);
+          sum = vaddq_f32(sum, tmp4);
+          sum = vaddq_f32(sum, tmp5);
+
+          out0 = vmulq_f32(sum, v_coef);
+          vst1q_f32(output_ptr, out0);
+          output_ptr += 4;
+          input_tmp += 4;
+          in0 = in1;
+          in2 = in3;
+          in4 = in5;
+        }
+        // mid remain
+        float32x4_t pad0 = vdupq_n_f32(input_data[(j + 1) * w_in - 1]);
+        float32x4_t pad1 = vdupq_n_f32(input_data[(j + 2) * w_in - 1]);
+        float32x4_t pad2 = vdupq_n_f32(input_data[(j + 3) * w_in - 1]);
+
+        tmp0 = vextq_f32(in0, pad0, 1);
+        tmp1 = vextq_f32(in0, pad0, 2);
+        tmp2 = vextq_f32(in2, pad1, 1);
+        tmp3 = vextq_f32(in2, pad1, 2);
+        tmp4 = vextq_f32(in4, pad2, 1);
+        tmp5 = vextq_f32(in4, pad2, 2);
+
+        sum = vaddq_f32(in0, tmp0);
+        sum = vaddq_f32(sum, tmp1);
+        sum = vaddq_f32(sum, in2);
+        sum = vaddq_f32(sum, tmp2);
+        sum = vaddq_f32(sum, tmp3);
+        sum = vaddq_f32(sum, in4);
+        sum = vaddq_f32(sum, tmp4);
+        sum = vaddq_f32(sum, tmp5);
+        out0 = vmulq_f32(sum, v_coef);
+
+        for (int i = 0; i < c_mid; ++i) {
+          if (i == 0) {
+            vst1q_lane_f32(output_ptr + i, out0, 0);
+          }
+          if (i == 1) {
+            vst1q_lane_f32(output_ptr + i, out0, 1);
+          }
+          if (i == 2) {
+            vst1q_lane_f32(output_ptr + i, out0, 2);
+          }
+        }
+      }
+      input_data += inputdata_channel_stride;
+      out_data += outputdata_channel_stride;
+    }
+  }
+#endif
+}
+
+// 3x3 max pooling specialised for stride 1 / padding 1; same layout and
+// traversal as Pool3x3Avgs1p1. Border windows take the max over the 4 or 6
+// in-range input elements only. Compiled to a no-op unless __ARM_NEON is set.
+void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
+#if __ARM_NEON
+  const int batch_size = input->dims()[0];
+
+  const int h_in = input->dims()[2];
+
+  const int w_in = input->dims()[3];
+
+  const int output_channels = output->dims()[1];
+
+  const int h_out = output->dims()[2];
+  const int w_out = output->dims()[3];
+  const int outputdata_channel_stride = h_out * w_out;
+  const int inputdata_channel_stride = h_in * w_in;
+  float *out_data = output->data<float>();
+  const float *input_data = input->data<float>();
+  for (int k = 0; k < batch_size; ++k) {
+    for (int c = 0; c < output_channels; ++c) {
+      // four corner point
+      out_data[0] = std::max(std::max(input_data[0], input_data[1]),
+                             std::max(input_data[w_in], input_data[w_in + 1]));
+      out_data[w_out - 1] = std::max(
+          std::max(input_data[w_in - 2], input_data[w_in - 1]),
+          std::max(input_data[w_in * 2 - 2], input_data[2 * w_in - 1]));
+      out_data[(h_out - 1) * w_out] =
+          std::max(std::max(input_data[(h_in - 2) * w_in],
+                            input_data[(h_in - 2) * w_in + 1]),
+                   std::max(input_data[(h_in - 1) * w_in],
+                            input_data[(h_in - 1) * w_in + 1]));
+      out_data[h_out * w_out - 1] = std::max(
+          std::max(input_data[(h_in - 1) * w_in - 1],
+                   input_data[(h_in - 1) * w_in - 2]),
+          std::max(input_data[h_in * w_in - 1], input_data[h_in * w_in - 2]));
+      // left side & right side
+      for (int i = 1; i < h_in - 1; ++i) {
+        float max1 = std::max(input_data[i * w_in - w_in],
+                              input_data[i * w_in - w_in + 1]);
+        float max2 = std::max(input_data[i * w_in], input_data[i * w_in + 1]);
+        float max3 = std::max(input_data[i * w_in + w_in],
+                              input_data[i * w_in + w_in + 1]);
+        out_data[i * w_out] = std::max(std::max(max1, max2), max3);
+
+        max1 = std::max(input_data[i * w_in - w_in + w_in - 2],
+                        input_data[i * w_in - w_in + 1 + w_in - 2]);
+        max2 = std::max(input_data[i * w_in + w_in - 2],
+                        input_data[i * w_in + 1 + w_in - 2]);
+        max3 = std::max(input_data[i * w_in + w_in + w_in - 2],
+                        input_data[i * w_in + w_in + 1 + w_in - 2]);
+        out_data[i * w_out + w_out - 1] = std::max(std::max(max1, max2), max3);
+      }
+      // top 1 row & bottom 1 row
+      const float *input_tmp = input_data;
+
+      float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
+          tmp3, tmp4, tmp5, max;
+      in0 = vld1q_f32(input_tmp);
+      in2 = vld1q_f32(input_tmp + w_in);
+      const float *input_tmp_end = input_tmp + (h_in - 2) * w_in;
+      in4 = vld1q_f32(input_tmp_end);
+      in6 = vld1q_f32(input_tmp_end + w_in);
+      int c_mid = w_out - 2;
+      auto output_ptr = out_data + 1;
+      for (; c_mid > 3; c_mid -= 4) {
+        in1 = vld1q_f32(input_tmp + 4);
+        in3 = vld1q_f32(input_tmp + w_in + 4);
+
+        tmp0 = vextq_f32(in0, in1, 1);
+        tmp1 = vextq_f32(in0, in1, 2);
+
+        tmp2 = vextq_f32(in2, in3, 1);
+        tmp3 = vextq_f32(in2, in3, 2);
+
+        max = vmaxq_f32(in0, tmp0);
+        max = vmaxq_f32(max, tmp1);
+        max = vmaxq_f32(max, in2);
+        max = vmaxq_f32(max, tmp2);
+        max = vmaxq_f32(max, tmp3);
+
+        vst1q_f32(output_ptr, max);
+
+        in5 = vld1q_f32(input_tmp_end + 4);
+        in7 = vld1q_f32(input_tmp_end + w_in + 4);
+
+        tmp0 = vextq_f32(in4, in5, 1);
+        tmp1 = vextq_f32(in4, in5, 2);
+        tmp2 = vextq_f32(in6, in7, 1);
+        tmp3 = vextq_f32(in6, in7, 2);
+
+        max = vmaxq_f32(in4, tmp0);
+        max = vmaxq_f32(max, tmp1);
+        max = vmaxq_f32(max, in6);
+        max = vmaxq_f32(max, tmp2);
+        max = vmaxq_f32(max, tmp3);
+
+        vst1q_f32(output_ptr + (h_out - 1) * w_out, max);
+
+        input_tmp += 4;
+        input_tmp_end += 4;
+        output_ptr += 4;
+        in0 = in1;
+        in2 = in3;
+        in4 = in5;
+        in6 = in7;
+      }
+      // top right remain
+      float32x4_t pad0 = vdupq_n_f32(input_data[w_in - 1]);
+      float32x4_t pad1 = vdupq_n_f32(input_data[2 * w_in - 1]);
+
+      tmp0 = vextq_f32(in0, pad0, 1);
+      tmp1 = vextq_f32(in0, pad0, 2);
+      tmp2 = vextq_f32(in2, pad1, 1);
+      tmp3 = vextq_f32(in2, pad1, 2);
+
+      max = vmaxq_f32(in0, tmp0);
+      max = vmaxq_f32(max, tmp1);
+      max = vmaxq_f32(max, in2);
+      max = vmaxq_f32(max, tmp2);
+      max = vmaxq_f32(max, tmp3);
+
+      for (int i = 0; i < c_mid; ++i) {
+        if (i == 0) {
+          vst1q_lane_f32(output_ptr + i, max, 0);
+        }
+        if (i == 1) {
+          vst1q_lane_f32(output_ptr + i, max, 1);
+        }
+        if (i == 2) {
+          vst1q_lane_f32(output_ptr + i, max, 2);
+        }
+      }
+
+      // bottom_right remain
+      float32x4_t pad2 = vdupq_n_f32(input_data[(h_in - 1) * w_in - 1]);
+      float32x4_t pad3 = vdupq_n_f32(input_data[h_in * w_in - 1]);
+
+      tmp0 = vextq_f32(in4, pad2, 1);
+      tmp1 = vextq_f32(in4, pad2, 2);
+      tmp2 = vextq_f32(in6, pad3, 1);
+      tmp3 = vextq_f32(in6, pad3, 2);
+
+      max = vmaxq_f32(in4, tmp0);
+      max = vmaxq_f32(max, tmp1);
+      max = vmaxq_f32(max, in6);
+      max = vmaxq_f32(max, tmp2);
+      max = vmaxq_f32(max, tmp3);
+
+      for (int i = 0; i < c_mid; ++i) {
+        if (i == 0) {
+          vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, max, 0);
+        }
+        if (i == 1) {
+          vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, max, 1);
+        }
+        if (i == 2) {
+          vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, max, 2);
+        }
+      }
+      // mid
+      for (int j = 0; j < h_out - 2; ++j) {
+        output_ptr = out_data + (j + 1) * w_out + 1;
+        input_tmp = input_data + j * w_in;
+
+        in0 = vld1q_f32(input_tmp);
+        in2 = vld1q_f32(input_tmp + w_in);
+        in4 = vld1q_f32(input_tmp + 2 * w_in);
+        c_mid = w_out - 2;
+        for (; c_mid > 3; c_mid -= 4) {
+          in1 = vld1q_f32(input_tmp + 4);
+          in3 = vld1q_f32(input_tmp + w_in + 4);
+          in5 = vld1q_f32(input_tmp + 2 * w_in + 4);
+
+          tmp0 = vextq_f32(in0, in1, 1);
+          tmp1 = vextq_f32(in0, in1, 2);
+          tmp2 = vextq_f32(in2, in3, 1);
+          tmp3 = vextq_f32(in2, in3, 2);
+          tmp4 = vextq_f32(in4, in5, 1);
+          tmp5 = vextq_f32(in4, in5, 2);
+
+          max = vmaxq_f32(in0, tmp0);
+          max = vmaxq_f32(max, tmp1);
+          max = vmaxq_f32(max, in2);
+          max = vmaxq_f32(max, tmp2);
+          max = vmaxq_f32(max, tmp3);
+          max = vmaxq_f32(max, in4);
+          max = vmaxq_f32(max, tmp4);
+          max = vmaxq_f32(max, tmp5);
+
+          vst1q_f32(output_ptr, max);
+          output_ptr += 4;
+          input_tmp += 4;
+          in0 = in1;
+          in2 = in3;
+          in4 = in5;
+        }
+        // mid remain
+        float32x4_t pad0 = vdupq_n_f32(input_data[(j + 1) * w_in - 1]);
+        float32x4_t pad1 = vdupq_n_f32(input_data[(j + 2) * w_in - 1]);
+        float32x4_t pad2 = vdupq_n_f32(input_data[(j + 3) * w_in - 1]);
+
+        tmp0 = vextq_f32(in0, pad0, 1);
+        tmp1 = vextq_f32(in0, pad0, 2);
+        tmp2 = vextq_f32(in2, pad1, 1);
+        tmp3 = vextq_f32(in2, pad1, 2);
+        tmp4 = vextq_f32(in4, pad2, 1);
+        tmp5 = vextq_f32(in4, pad2, 2);
+
+        max = vmaxq_f32(in0, tmp0);
+        max = vmaxq_f32(max, tmp1);
+        max = vmaxq_f32(max, in2);
+        max = vmaxq_f32(max, tmp2);
+        max = vmaxq_f32(max, tmp3);
+        max = vmaxq_f32(max, in4);
+        max = vmaxq_f32(max, tmp4);
+        max = vmaxq_f32(max, tmp5);
+
+        for (int i = 0; i < c_mid; ++i) {
+          if (i == 0) {
+            vst1q_lane_f32(output_ptr + i, max, 0);
+          }
+          if (i == 1) {
+            vst1q_lane_f32(output_ptr + i, max, 1);
+          }
+          if (i == 2) {
+            vst1q_lane_f32(output_ptr + i, max, 2);
+          }
+        }
+      }
+      input_data += inputdata_channel_stride;
+      out_data += outputdata_channel_stride;
+    }
+  }
+#endif
+}
 
 void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input,
                 Tensor *output) {
diff --git a/src/operators/math/pool_3x3.h b/src/operators/math/pool_3x3.h
index 22a398084390701aefc8815c9aa93b82b4c4ec7b..53d39b81cc158f02601a352f0ec2996f1d444304 100644
--- a/src/operators/math/pool_3x3.h
+++ b/src/operators/math/pool_3x3.h
@@ -15,7 +15,8 @@ limitations under the License. */
 #ifdef POOL_OP
 
 #pragma once
-
+#include <algorithm>
+#include <vector>
 #include "framework/tensor.h"
 #if __ARM_NEON
 #include <arm_neon.h>
@@ -26,7 +27,8 @@ namespace operators {
 namespace math {
 using framework::Tensor;
 using std::vector;
-
+void Pool3x3Avgs1p1(const Tensor *input, Tensor *output);
+void Pool3x3Maxs1p1(const Tensor *input, Tensor *output);
 void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input,
                 Tensor *output);
 
diff --git a/test/net/test_googlenet.cpp b/test/net/test_googlenet.cpp
index 1695995a8d60d20e0d6c5f8911c39a948426a82a..d0a5979516d3256d2ddfc7b042a18ab119d6a202 100644
--- a/test/net/test_googlenet.cpp
+++ b/test/net/test_googlenet.cpp
@@ -30,12 +30,12 @@ int main() {
   std::vector<int64_t> dims{1, 3, 224, 224};
   GetInput<float>(g_test_image_1x3x224x224, &input, dims);
   auto time3 = time();
-
-  for (int i = 0; i < 10; ++i) {
+  int count = 1;
+  for (int i = 0; i < count; i++) {
     executor.Predict(input, dims);
   }
 
   auto time4 = time();
-  DLOG << "predict cost :" << time_diff(time3, time4) << "ms\n";
+  DLOG << "avg predict cost :" << time_diff(time3, time4) / count << "ms\n";
   return 0;
 }