diff --git a/src/operators/kernel/central-arm-func/conv_add_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_arm_func.h
index 988f0b0f03b84c25a2e17e9d14054f99dcce4916..d65f89ede012ea083e4a73e4647079c248e33fe0 100644
--- a/src/operators/kernel/central-arm-func/conv_add_arm_func.h
+++ b/src/operators/kernel/central-arm-func/conv_add_arm_func.h
@@ -25,6 +25,7 @@ limitations under the License. */
 
 namespace paddle_mobile {
 namespace operators {
+
 void ConvAddBasic(const FusionConvAddParam<CPU> &param) {
   const Tensor *input = param.Input();
   Tensor filter = *param.Filter();
@@ -106,9 +107,9 @@ void ConvAddBasic(const FusionConvAddParam<CPU> &param) {
       // gemm
       Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-      math::matmul(filter_slice, false, col_matrix, false,
-                   static_cast<float>(1), &out_slice,
-                   static_cast<float>(1), false, biase_data);
+      math::matmul<float, float>(filter_slice, false, col_matrix, false,
+                                 static_cast<float>(1), &out_slice,
+                                 static_cast<float>(1), false, biase_data);
     }
   }
 }
diff --git a/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h
index 62497d793ee94a8ccf6fb65c6c38c88b084e10ff..860a3746e2918b05aa0f09f1536589b7dc62899c 100644
--- a/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h
+++ b/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h
@@ -25,15 +25,15 @@ limitations under the License. */
 namespace paddle_mobile {
 namespace operators {
 
-template <typename P, typename S>
+template <typename Itype, typename Otype>
 void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
   const Tensor *input = param.Input();
   Tensor filter = *param.Filter();
   Tensor bias = *param.Bias();
   int32_t axis = param.Axis();
-  S *bias_data = bias.data<S>();
+  Otype *bias_data = bias.data<Otype>();
   Tensor *output = param.Output();
-  output->mutable_data<P>();
+  output->mutable_data<Otype>();
   float alpha = 1.0f;
   float beta = 1.0f;
 
@@ -64,7 +64,7 @@ void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
   Tensor col;
   Tensor col_matrix;
   if (is_expand) {
-    col.mutable_data<P>(col_shape);
+    col.mutable_data<Itype>(col_shape);
     col_matrix.ShareDataWith(col);
     col_matrix.Resize(col_matrix_shape);
   }
@@ -83,8 +83,8 @@ void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
   int32_t in_step = static_cast<int32_t>(input->dims()[1]) / groups;
   int32_t out_step = static_cast<int32_t>(output->dims()[1]) / groups;
 
-  math::Vol2ColFunctor<CPU, P> vol2col;
-  math::Im2ColFunctor<math::ColFormat::kCFO, CPU, P> im2col;
+  math::Vol2ColFunctor<CPU, Itype> vol2col;
+  math::Im2ColFunctor<math::ColFormat::kCFO, CPU, Itype> im2col;
 
   for (int32_t i = 0; i < batch_size; i++) {
     Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
@@ -112,8 +112,8 @@ void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
      Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
 
-      math::matmul(filter_slice, false, col_matrix, false, alpha, &out_slice,
-                   beta, true, bias_data);
+      math::matmul<Itype, Otype>(filter_slice, false, col_matrix, false, alpha,
+                                 &out_slice, beta, true, bias_data);
     }
   }
 }
diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h
index 11667dfcc9cf2e25712a8f5c57d665cd41e9a9c6..32c66c2e13aebf3dae00cd324299d4e169f636dc 100644
--- a/src/operators/kernel/central-arm-func/conv_arm_func.h
+++ b/src/operators/kernel/central-arm-func/conv_arm_func.h
@@ -106,9 +106,10 @@ inline void GemmConv(const ConvParam<CPU> &param) {
       // gemm
       Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-      math::matmul(filter_slice, false, col_matrix, false,
-                   static_cast<float>(1), &out_slice, static_cast<float>(0),
-                   false, static_cast<Otype *>(nullptr));
+      math::matmul<Itype, Otype>(filter_slice, false, col_matrix, false,
+                                 static_cast<float>(1), &out_slice,
+                                 static_cast<float>(0), false,
+                                 static_cast<Otype *>(nullptr));
     }
   }
 }
diff --git a/src/operators/kernel/central-arm-func/conv_transpose_arm_func.h b/src/operators/kernel/central-arm-func/conv_transpose_arm_func.h
index 1bb3aac3e9619da9e6cb9e4dac5061a7d9115014..300cb8e84b0703951b5305d684eb2f7bb652d669 100644
--- a/src/operators/kernel/central-arm-func/conv_transpose_arm_func.h
+++ b/src/operators/kernel/central-arm-func/conv_transpose_arm_func.h
@@ -93,8 +93,8 @@ void ConvTransposeCompute(const ConvTransposeParam<CPU> &param) {
       Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
       Tensor out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step);
 
-      math::matmul(filter_slice, true, in_slice, false, static_cast<P>(1.0),
-                   &col_matrix, static_cast<P>(0.0));
+      math::matmul<P, P>(filter_slice, true, in_slice, false,
+                         static_cast<P>(1.0), &col_matrix, static_cast<P>(0.0));
       if (data_dim == 2U) {
         col2im(col, dilations, strides,
                std::vector<int>{paddings[0], paddings[1], paddings[0],
diff --git a/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h b/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h
index c48c6ee8c62d5ded6fa5228fd91d945f8313a84d..30eb30ca3cc0ef416a70f657c0c2e6bde5e7e9ba 100644
--- a/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h
+++ b/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h
@@ -23,20 +23,16 @@ limitations under the License. */
 namespace paddle_mobile {
 namespace operators {
 
-template <typename P, typename S>
+template <typename Itype, typename Otype>
 void FusionFcCompute(const FusionFcParam<CPU> &param) {
   const Tensor *input_x = param.InputX();
   const Tensor *input_y = param.InputY();
   Tensor *input_z = param.InputZ();
-  S *input_z_data = input_z->data<S>();
+  Otype *input_z_data = input_z->data<Otype>();
   int axis = param.Axis();
   Tensor *out = param.Out();
-  // int m = out->dims()[0];
-  // int n = out->dims()[1];
-  auto *out_data = out->mutable_data<P>();
+  auto *out_data = out->mutable_data<Otype>();
 
-  float alpha = 1.0f;
-  float beta = 1.0f;
   const Tensor x_matrix =
       input_x->dims().size() > 2
           ? framework::ReshapeToMatrix(*input_x, param.XNumColDims())
@@ -59,11 +55,11 @@ void FusionFcCompute(const FusionFcParam<CPU> &param) {
   // the dimension of bias_data matches the second dimension of out
   int64_t classes = input_z->numel();
   for (int i = 0; i < out_dim[0]; i++) {
-    memory::Copy(out_data + i * classes, input_z_data, sizeof(float) * classes);
+    memory::Copy(out_data + i * classes, input_z_data, sizeof(Otype) * classes);
   }
-
-  math::matmul<float>(x_matrix, false, y_matrix, false, alpha, out, beta,
-                      false);
+  math::matmul<Itype, Otype>(x_matrix, false, y_matrix, false,
+                             static_cast<float>(1), out, static_cast<float>(1),
+                             false);
 }
 
 }  // namespace operators
diff --git a/src/operators/kernel/central-arm-func/mul_arm_func.h b/src/operators/kernel/central-arm-func/mul_arm_func.h
index 8b9dad90a0b02ebf761bcd44fabc18905b056e6e..316f78a43f27c17eec1b31741b2b6bc678c41af2 100644
--- a/src/operators/kernel/central-arm-func/mul_arm_func.h
+++ b/src/operators/kernel/central-arm-func/mul_arm_func.h
@@ -73,14 +73,14 @@ void MulCompute(const MulParam<CPU> &param) {
   }
   if (param.InputX()->type() == typeid(int8_t)) {
     out->mutable_data<int32_t>();
-    math::matmul<float, int32_t>(x_matrix, false, y_matrix, false,
-                                 static_cast<float>(1), out,
-                                 static_cast<float>(0));
-
+    math::matmul<int8_t, int32_t>(x_matrix, false, y_matrix, false,
+                                  static_cast<float>(1), out,
+                                  static_cast<float>(0));
   } else {
     out->mutable_data<float>();
-    math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
-                        out, static_cast<float>(0));
+    math::matmul<float, float>(x_matrix, false, y_matrix, false,
+                               static_cast<float>(1), out,
+                               static_cast<float>(0));
   }
   if (out_dim.size() != 2) {
     out->Resize(out_dim);
diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp
index 4365bf5716b8b5811f6ac66217b2fe74ae116f52..b9ce977e0c84b148b27a02624baa05e6ab150672 100644
--- a/src/operators/math/math_function.cpp
+++ b/src/operators/math/math_function.cpp
@@ -41,10 +41,10 @@ void set_constant(framework::Tensor *tensor, float value) {
 }
 
 template <>
-void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
-                   const framework::Tensor &matrix_b, bool trans_b, float alpha,
-                   framework::Tensor *matrix_out, float beta, bool relu,
-                   float *bias) {
+void matmul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
+                          const framework::Tensor &matrix_b, bool trans_b,
+                          float alpha, framework::Tensor *matrix_out,
+                          float beta, bool relu, float *bias) {
   auto dim_a = matrix_a.dims();
   auto dim_b = matrix_b.dims();
   auto dim_out = matrix_out->dims();
diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h
index 16c39221db5b94dd8ed323c9cced430a58e32e47..3b682eab2acf96ba70de563aba415a19ad4a66b6 100644
--- a/src/operators/math/math_function.h
+++ b/src/operators/math/math_function.h
@@ -24,24 +24,24 @@ namespace math {
 
 void set_constant(framework::Tensor *tensor, float value);
 
-template <typename T>
+template <typename Itype, typename Otype>
 void matmul(const framework::Tensor &matrix_a, bool trans_a,
-            const framework::Tensor &matrix_b, bool trans_b, T alpha,
-            framework::Tensor *matrix_out, T beta, bool relu = false,
-            float *bias = nullptr);
+            const framework::Tensor &matrix_b, bool trans_b, float alpha,
+            framework::Tensor *matrix_out, float beta, bool relu = false,
+            Otype *bias = nullptr);
 
-template <typename T, typename S>
+template <typename Itype, typename Otype>
 void matmul(const framework::Tensor &matrix_a, bool trans_a,
-            const framework::Tensor &matrix_b, bool trans_b, T alpha,
-            framework::Tensor *matrix_out, T beta, bool relu = false,
-            S *bias = nullptr, bool addOnRow = false);
+            const framework::Tensor &matrix_b, bool trans_b, float alpha,
+            framework::Tensor *matrix_out, float beta, bool relu, Otype *bias,
+            bool addOnRow);
 
 template <typename T>
 void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
-                  const framework::Tensor &matrix_b, bool trans_b, T alpha,
-                  framework::Tensor *matrix_out, T beta, bool relu,
+                  const framework::Tensor &matrix_b, bool trans_b, float alpha,
+                  framework::Tensor *matrix_out, float beta, bool relu,
                   framework::Tensor *new_scale, framework::Tensor *new_bias,
-                  int group, float *bias = nullptr);
+                  int group, T *bias = nullptr);
 
 void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
                      const framework::Tensor &matrix_b, bool trans_b,
diff --git a/src/operators/math/math_function_int8.cpp b/src/operators/math/math_function_int8.cpp
index b7f634b36fe8d009c06008aada971c61b70b4a46..6b3dd3f00a33cb015891a801801964eca1c5dcf5 100644
--- a/src/operators/math/math_function_int8.cpp
+++ b/src/operators/math/math_function_int8.cpp
@@ -22,10 +22,11 @@ namespace operators {
 namespace math {
 
 template <>
-void matmul(const framework::Tensor &matrix_a, bool trans_a,
-            const framework::Tensor &matrix_b, bool trans_b, float alpha,
-            framework::Tensor *matrix_out, float beta, bool relu, int32_t *bias,
-            bool addOnRow) {
+void matmul<int8_t, int32_t>(const framework::Tensor &matrix_a, bool trans_a,
+                             const framework::Tensor &matrix_b, bool trans_b,
+                             float alpha, framework::Tensor *matrix_out,
+                             float beta, bool relu, int32_t *bias,
+                             bool addOnRow) {
   auto dim_a = matrix_a.dims();
   auto dim_b = matrix_b.dims();
   auto dim_out = matrix_out->dims();
@@ -93,6 +94,16 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a,
 #endif
   }
 }
+
+template <>
+void matmul<int8_t, int32_t>(const framework::Tensor &matrix_a, bool trans_a,
+                             const framework::Tensor &matrix_b, bool trans_b,
+                             float alpha, framework::Tensor *matrix_out,
+                             float beta, bool relu, int32_t *bias) {
+  matmul<int8_t, int32_t>(matrix_a, trans_a, matrix_b, trans_b, alpha,
+                          matrix_out, beta, relu, bias, false);
+}
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle_mobile
diff --git a/src/operators/math/pooling.h b/src/operators/math/pooling.h
index 8b60759003f33295edbc5acf1ef238cbd6dc827b..4a3b2b8389eeac708df73d97672682615d1f7912 100644
--- a/src/operators/math/pooling.h
+++ b/src/operators/math/pooling.h
@@ -53,7 +53,7 @@ struct PoolingVal {
     ++count;
     return *this;
   }
-  inline float Value() { return (count > 0) ? val * (1.f / count) : 0.f; }
+  inline float Value() { return (count > 0) ? val / count : 0.f; }
 };
 
 #if defined(__ARM_NEON) || defined(__ARM_NEON__)
diff --git a/src/operators/math/pooling3x3.cpp b/src/operators/math/pooling3x3.cpp
index 879893b3752cf3456c7effd56073fa373c48ed43..e556768ce04cd4326d11c10f00aec00e2bd263f8 100644
--- a/src/operators/math/pooling3x3.cpp
+++ b/src/operators/math/pooling3x3.cpp
@@ -288,360 +288,363 @@ struct Pooling3x3<P, 1> {
     int valid_w_end = valid_w_start + valid_w;
     float avg = 1.f / 9;
 
-    #pragma omp parallel for
-    for (int c = 0; c < output->dims()[1]; ++c) {
-      const float *input_ptr = input_data + c * image_size;
-      float *output_ptr = output_data + c * out_image_size;
-      // top
-      for (int h = 0; h < valid_h_start; ++h) {
-        Pooling3x3NormalRow<P>(input_ptr, h, input_h, input_w, padding_h,
-                               padding_w, output_w, output_ptr);
-      }
-      // left
-      for (int w = 0; w < valid_w_start; ++w) {
-        Pooling3x3ValidCol<P>(input_ptr, valid_h_start, valid_h_end, w,
-                              input_h, input_w, padding_h, padding_w,
-                              output_w, output_ptr);
-      }
-      // right
-      for (int w = valid_w_end; w < output_w; ++w) {
-        Pooling3x3ValidCol<P>(input_ptr, valid_h_start, valid_h_end, w,
-                              input_h, input_w, padding_h, padding_w,
-                              output_w, output_ptr);
-      }
-      // bottom
-      for (int h = valid_h_end; h < output_h; ++h) {
-        Pooling3x3NormalRow<P>(input_ptr, h, input_h, input_w, padding_h,
-                               padding_w, output_w, output_ptr);
-      }
-      // valid
-      int output_w_tiles = valid_w / 6;
-      int output_w_remain = valid_w - output_w_tiles * 6;
-      for (int h = valid_h_start; h < valid_h_end - 3; h += 4) {
-        const float *input_ptr0 = input_ptr + (h - padding_h) * input_w;
-        const float *input_ptr1 = input_ptr0 + input_w;
-        const float *input_ptr2 = input_ptr1 + input_w;
-        const float *input_ptr3 = input_ptr2 + input_w;
-        const float *input_ptr4 = input_ptr3 + input_w;
-        const float *input_ptr5 = input_ptr4 + input_w;
-        float *output_ptr0 = output_ptr + h * output_w + valid_w_start;
-        float *output_ptr1 = output_ptr0 + output_w;
-        float *output_ptr2 = output_ptr1 + output_w;
-        float *output_ptr3 = output_ptr2 + output_w;
-        int remain = output_w_remain;
-#if defined(__ARM_NEON__) || defined(__ARM_NEON)
-        float32x4x2_t x0, x1, x2;
-        float32x4x2_t y0, y1, y2;
-        float32x4_t post = vdupq_n_f32(1.f / 9);
-        for (int loop = 0; loop < output_w_tiles; ++loop) {
-          x0.val[0] = vld1q_f32(input_ptr0);
-          x0.val[1] = vld1q_f32(input_ptr0 + 4);
-          x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
-          x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
-          x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
-          x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
-          x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
-
-          x0.val[0] = vld1q_f32(input_ptr1);
-          x0.val[1] = vld1q_f32(input_ptr1 + 4);
-          x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
-          x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
-          x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
-          x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
-          x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
-          y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y1.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
-          y0.val[1] = vPoolPreq_f32<P>(y1.val[1], y0.val[1]);
-
-          x0.val[0] = vld1q_f32(input_ptr2);
-          x0.val[1] = vld1q_f32(input_ptr2 + 4);
-          x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
-          x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
-          x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
-          x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
-          x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
-          y2.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y2.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
-          y1.val[0] = vPoolPreq_f32<P>(y2.val[0], y1.val[0]);
-          y1.val[1] = vPoolPreq_f32<P>(y2.val[1], y1.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(y2.val[0], y0.val[0]);
-          y0.val[1] = vPoolPreq_f32<P>(y2.val[1], y0.val[1]);
-          y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
-          y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
-          vst1q_f32(output_ptr0, y0.val[0]);
-          vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
-
-          x0.val[0] = vld1q_f32(input_ptr3);
-          x0.val[1] = vld1q_f32(input_ptr3 + 4);
-          x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
-          x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
-          x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
-          x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
-          x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
-          y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
-          y1.val[1] = vPoolPreq_f32<P>(y0.val[1], y1.val[1]);
-          y2.val[0] = vPoolPreq_f32<P>(y0.val[0], y2.val[0]);
-          y2.val[1] = vPoolPreq_f32<P>(y0.val[1], y2.val[1]);
-          y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
-          y1.val[1] = vPoolPostq_f32<P>(y1.val[1], post);
-          vst1q_f32(output_ptr1, y1.val[0]);
-          vst1_f32(output_ptr1 + 4, vget_low_f32(y1.val[1]));
-
-          x0.val[0] = vld1q_f32(input_ptr4);
-          x0.val[1] = vld1q_f32(input_ptr4 + 4);
-          x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
-          x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
-          x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
-          x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
-          x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
-          y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
-          y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]);
-          y2.val[1] = vPoolPreq_f32<P>(x0.val[1], y2.val[1]);
-          y2.val[0] = vPoolPostq_f32<P>(y2.val[0], post);
-          y2.val[1] = vPoolPostq_f32<P>(y2.val[1], post);
-          vst1q_f32(output_ptr2, y2.val[0]);
-          vst1_f32(output_ptr2 + 4, vget_low_f32(y2.val[1]));
-
-          x0.val[0] = vld1q_f32(input_ptr5);
-          x0.val[1] = vld1q_f32(input_ptr5 + 4);
-          x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
-          x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
-          x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
-          x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
-          x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
-          y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
-          y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
-          y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
-          vst1q_f32(output_ptr3, y0.val[0]);
-          vst1_f32(output_ptr3 + 4, vget_low_f32(y0.val[1]));
-
-          input_ptr0 += 6;
-          input_ptr1 += 6;
-          input_ptr2 += 6;
-          input_ptr3 += 6;
-          input_ptr4 += 6;
-          input_ptr5 += 6;
-          output_ptr0 += 6;
-          output_ptr1 += 6;
-          output_ptr2 += 6;
-          output_ptr3 += 6;
+    #pragma omp parallel for collapse(2)
+    for (int batch = 0; batch < output->dims()[0]; ++batch) {
+      for (int c = 0; c < output->dims()[1]; ++c) {
+        int channel = batch * output->dims()[1] + c;
+        const float *input_ptr = input_data + channel * image_size;
+        float *output_ptr = output_data + channel * out_image_size;
+        // top
+        for (int h = 0; h < valid_h_start; ++h) {
+          Pooling3x3NormalRow<P>(input_ptr, h, input_h, input_w, padding_h,
+                                 padding_w, output_w, output_ptr);
         }
-        // remain w
-        if (remain >= 4) {
-          x0.val[0] = vld1q_f32(input_ptr0);
-          x0.val[1] = vld1q_f32(input_ptr0 + 4);
-          x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
-          x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-
-          x0.val[0] = vld1q_f32(input_ptr1);
-          x0.val[1] = vld1q_f32(input_ptr1 + 4);
-          x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
-          x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
-          y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
-
-          x0.val[0] = vld1q_f32(input_ptr2);
-          x0.val[1] = vld1q_f32(input_ptr2 + 4);
-          x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
-          x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
-          y2.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y1.val[0] = vPoolPreq_f32<P>(y2.val[0], y1.val[0]);
-          y0.val[0] = vPoolPreq_f32<P>(y2.val[0], y0.val[0]);
-          y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
-          vst1q_f32(output_ptr0, y0.val[0]);
-
-          x0.val[0] = vld1q_f32(input_ptr3);
-          x0.val[1] = vld1q_f32(input_ptr3 + 4);
-          x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
-          x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
-          y2.val[0] = vPoolPreq_f32<P>(y0.val[0], y2.val[0]);
-          y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
-          vst1q_f32(output_ptr1, y1.val[0]);
-
-          x0.val[0] = vld1q_f32(input_ptr4);
-          x0.val[1] = vld1q_f32(input_ptr4 + 4);
-          x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
-          x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
-          y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]);
-          y2.val[0] = vPoolPostq_f32<P>(y2.val[0], post);
-          vst1q_f32(output_ptr2, y2.val[0]);
-
-          x0.val[0] = vld1q_f32(input_ptr5);
-          x0.val[1] = vld1q_f32(input_ptr5 + 4);
-          x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
-          x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
-          y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
-          vst1q_f32(output_ptr3, y0.val[0]);
-
-          input_ptr0 += 4;
-          input_ptr1 += 4;
-          input_ptr2 += 4;
-          input_ptr3 += 4;
-          input_ptr4 += 4;
-          input_ptr5 += 4;
-          output_ptr0 += 4;
-          output_ptr1 += 4;
-          output_ptr2 += 4;
-          output_ptr3 += 4;
-          remain -= 4;
+        // left
+        for (int w = 0; w < valid_w_start; ++w) {
+          Pooling3x3ValidCol<P>(input_ptr, valid_h_start, valid_h_end, w,
+                                input_h, input_w, padding_h, padding_w,
+                                output_w, output_ptr);
         }
-#endif  // __ARM_NEON__
-        for (int r = 0; r < remain; ++r) {
-          float m0 = PoolPre<P>(input_ptr0[r], input_ptr0[r + 1]);
-          m0 = PoolPre<P>(m0, input_ptr0[r + 2]);
-          float m1 = PoolPre<P>(input_ptr1[r], input_ptr1[r + 1]);
-          m1 = PoolPre<P>(m1, input_ptr1[r + 2]);
-          float m2 = PoolPre<P>(input_ptr2[r], input_ptr2[r + 1]);
-          m2 = PoolPre<P>(m2, input_ptr2[r + 2]);
-          float m3 = PoolPre<P>(input_ptr3[r], input_ptr3[r + 1]);
-          m3 = PoolPre<P>(m3, input_ptr3[r + 2]);
-          float m4 = PoolPre<P>(input_ptr4[r], input_ptr4[r + 1]);
-          m4 = PoolPre<P>(m4, input_ptr4[r + 2]);
-          float m5 = PoolPre<P>(input_ptr5[r], input_ptr5[r + 1]);
-          m5 = PoolPre<P>(m5, input_ptr5[r + 2]);
-
-          m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2);
-          m1 = PoolPre<P>(PoolPre<P>(m1, m2), m3);
-          m2 = PoolPre<P>(PoolPre<P>(m2, m3), m4);
-          m3 = PoolPre<P>(PoolPre<P>(m3, m4), m5);
-          output_ptr0[r] = PoolPost<P>(m0, avg);
-          output_ptr1[r] = PoolPost<P>(m1, avg);
-          output_ptr2[r] = PoolPost<P>(m2, avg);
-          output_ptr3[r] = PoolPost<P>(m3, avg);
+        // right
+        for (int w = valid_w_end; w < output_w; ++w) {
+          Pooling3x3ValidCol<P>(input_ptr, valid_h_start, valid_h_end, w,
+                                input_h, input_w, padding_h, padding_w,
+                                output_w, output_ptr);
         }
-      }
-      // remain h
-      int start_h = valid_h_start + (valid_h & 0xFFFC);
-      for (int h = start_h; h < valid_h_end; ++h) {
-        const float *input_ptr0 = input_ptr + (h - padding_h) * input_w;
-        const float *input_ptr1 = input_ptr0 + input_w;
-        const float *input_ptr2 = input_ptr1 + input_w;
-        float *output_ptr0 = output_ptr + h * output_w + valid_w_start;
-        int remain = output_w_remain;
-#if defined(__ARM_NEON__) || defined(__ARM_NEON)
-        float32x4x2_t x0, x1, x2, y0;
-        float32x4_t post = vdupq_n_f32(1.f / 9);
-        for (int loop = 0; loop < output_w_tiles; ++loop) {
-          x0.val[0] = vld1q_f32(input_ptr0);
-          x0.val[1] = vld1q_f32(input_ptr0 + 4);
-          x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
-          x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
-          x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
-          x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
-          x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
-
-          x0.val[0] = vld1q_f32(input_ptr1);
-          x0.val[1] = vld1q_f32(input_ptr1 + 4);
-          x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
-          x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
-          x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
-          x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
-          x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
-          y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
-
-          x0.val[0] = vld1q_f32(input_ptr2);
-          x0.val[1] = vld1q_f32(input_ptr2 + 4);
-          x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
-          x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
-          x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
-          x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
-          x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
-          y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
-          y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
-          y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
-          vst1q_f32(output_ptr0, y0.val[0]);
-          vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
-
-          input_ptr0 += 6;
-          input_ptr1 += 6;
-          input_ptr2 += 6;
-          output_ptr0 += 6;
+        // bottom
+        for (int h = valid_h_end; h < output_h; ++h) {
+          Pooling3x3NormalRow<P>(input_ptr, h, input_h, input_w, padding_h,
+                                 padding_w, output_w, output_ptr);
         }
-        // remain w
-        if (remain >= 4) {
-          x0.val[0] = vld1q_f32(input_ptr0);
-          x0.val[1] = vld1q_f32(input_ptr0 + 4);
-          x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
-          x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-
-          x0.val[0] = vld1q_f32(input_ptr1);
-          x0.val[1] = vld1q_f32(input_ptr1 + 4);
-          x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
-          x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
-
-          x0.val[0] = vld1q_f32(input_ptr2);
-          x0.val[1] = vld1q_f32(input_ptr2 + 4);
-          x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
-          x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
-          y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
-          vst1q_f32(output_ptr0, y0.val[0]);
-
-          input_ptr0 += 4;
-          input_ptr1 += 4;
-          input_ptr2 += 4;
-          output_ptr0 += 4;
-          remain -= 4;
-        }
+        // valid
+        int output_w_tiles = valid_w / 6;
+        int output_w_remain = valid_w - output_w_tiles * 6;
+        for (int h = valid_h_start; h < valid_h_end - 3; h += 4) {
+          const float *input_ptr0 = input_ptr + (h - padding_h) * input_w;
+          const float *input_ptr1 = input_ptr0 + input_w;
+          const float *input_ptr2 = input_ptr1 + input_w;
+          const float *input_ptr3 = input_ptr2 + input_w;
+          const float *input_ptr4 = input_ptr3 + input_w;
+          const float *input_ptr5 = input_ptr4 + input_w;
+          float *output_ptr0 = output_ptr + h * output_w + valid_w_start;
+          float *output_ptr1 = output_ptr0 + output_w;
+          float *output_ptr2 = output_ptr1 + output_w;
+          float *output_ptr3 = output_ptr2 + output_w;
+          int remain = output_w_remain;
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+          float32x4x2_t x0, x1, x2;
+          float32x4x2_t y0, y1, y2;
+          float32x4_t post = vdupq_n_f32(1.f / 9);
+          for (int loop = 0; loop < output_w_tiles; ++loop) {
+            x0.val[0] = vld1q_f32(input_ptr0);
+            x0.val[1] = vld1q_f32(input_ptr0 + 4);
+            x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
+            x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
+            x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
+            x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
+            x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
+
+            x0.val[0] = vld1q_f32(input_ptr1);
+            x0.val[1] = vld1q_f32(input_ptr1 + 4);
+            x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
+            x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
+            x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
+            x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
+            x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
+            y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            y1.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
+            y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
+            y0.val[1] = vPoolPreq_f32<P>(y1.val[1], y0.val[1]);
+
+            x0.val[0] = vld1q_f32(input_ptr2);
+            x0.val[1] = vld1q_f32(input_ptr2 + 4);
+            x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
+            x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
+            x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
+            x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
+            x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
+            y2.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            y2.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
+            y1.val[0] = vPoolPreq_f32<P>(y2.val[0], y1.val[0]);
+            y1.val[1] = vPoolPreq_f32<P>(y2.val[1], y1.val[1]);
+            y0.val[0] = vPoolPreq_f32<P>(y2.val[0], y0.val[0]);
+            y0.val[1] = vPoolPreq_f32<P>(y2.val[1], y0.val[1]);
+            y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
+            y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
+            vst1q_f32(output_ptr0, y0.val[0]);
+            vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
+
+            x0.val[0] = vld1q_f32(input_ptr3);
+            x0.val[1] = vld1q_f32(input_ptr3 + 4);
+            x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
+            x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
+            x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
+            x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
+            x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
+            y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
+            y1.val[1] = vPoolPreq_f32<P>(y0.val[1], y1.val[1]);
+            y2.val[0] = vPoolPreq_f32<P>(y0.val[0], y2.val[0]);
+            y2.val[1] = vPoolPreq_f32<P>(y0.val[1], y2.val[1]);
+            y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
+            y1.val[1] = vPoolPostq_f32<P>(y1.val[1], post);
+            vst1q_f32(output_ptr1, y1.val[0]);
+            vst1_f32(output_ptr1 + 4, vget_low_f32(y1.val[1]));
+
+            x0.val[0] = vld1q_f32(input_ptr4);
+            x0.val[1] = vld1q_f32(input_ptr4 + 4);
+            x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
+            x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
+            x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
+            x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
+            x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
+            y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
+            y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]);
+            y2.val[1] = vPoolPreq_f32<P>(x0.val[1], y2.val[1]);
+            y2.val[0] = vPoolPostq_f32<P>(y2.val[0], post);
+            y2.val[1] = vPoolPostq_f32<P>(y2.val[1], post);
+            vst1q_f32(output_ptr2, y2.val[0]);
+            vst1_f32(output_ptr2 + 4, vget_low_f32(y2.val[1]));
+
+            x0.val[0] = vld1q_f32(input_ptr5);
+            x0.val[1] = vld1q_f32(input_ptr5 + 4);
+            x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
+            x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
+            x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
+            x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
+            x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
+            y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
+            y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
+            y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
+            vst1q_f32(output_ptr3, y0.val[0]);
+            vst1_f32(output_ptr3 + 4, vget_low_f32(y0.val[1]));
+
+            input_ptr0 += 6;
+            input_ptr1 += 6;
+            input_ptr2 += 6;
+            input_ptr3 += 6;
+            input_ptr4 += 6;
+            input_ptr5 += 6;
+            output_ptr0 += 6;
+            output_ptr1 += 6;
+            output_ptr2 += 6;
+            output_ptr3 += 6;
+          }
+          // remain width
+          if (remain >= 4) {
+            x0.val[0] = vld1q_f32(input_ptr0);
+            x0.val[1] = vld1q_f32(input_ptr0 + 4);
+            x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
+            x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+
+            x0.val[0] = vld1q_f32(input_ptr1);
+            x0.val[1] = vld1q_f32(input_ptr1 + 4);
+            x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
+            x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
+            y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
+
+            x0.val[0] = vld1q_f32(input_ptr2);
+            x0.val[1] = vld1q_f32(input_ptr2 + 4);
+            x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
+            x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
+            y2.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            y1.val[0] = vPoolPreq_f32<P>(y2.val[0], y1.val[0]);
+            y0.val[0] = vPoolPreq_f32<P>(y2.val[0], y0.val[0]);
+            y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
+            vst1q_f32(output_ptr0, y0.val[0]);
+
+            x0.val[0] = vld1q_f32(input_ptr3);
+            x0.val[1] = vld1q_f32(input_ptr3 + 4);
+            x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
+            x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
+            y2.val[0] = vPoolPreq_f32<P>(y0.val[0], y2.val[0]);
+            y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
+            vst1q_f32(output_ptr1, y1.val[0]);
+
+            x0.val[0] = vld1q_f32(input_ptr4);
+            x0.val[1] = vld1q_f32(input_ptr4 + 4);
+            x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
+            x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
+            y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]);
+            y2.val[0] = vPoolPostq_f32<P>(y2.val[0], post);
+            vst1q_f32(output_ptr2, y2.val[0]);
+
+            x0.val[0] = vld1q_f32(input_ptr5);
+            x0.val[1] = vld1q_f32(input_ptr5 + 4);
+            x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
+            x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
+            y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
+            vst1q_f32(output_ptr3, y0.val[0]);
+
+            input_ptr0 += 4;
+            input_ptr1 += 4;
+            input_ptr2 += 4;
+            input_ptr3 += 4;
+            input_ptr4 += 4;
+            input_ptr5 += 4;
+            output_ptr0 += 4;
+            output_ptr1 += 4;
+            output_ptr2 += 4;
+            output_ptr3 += 4;
+            remain -= 4;
+          }
+#endif  // __ARM_NEON__
+          for (int r = 0; r < remain; ++r) {
+            float m0 = PoolPre<P>(input_ptr0[r], input_ptr0[r + 1]);
+            m0 = PoolPre<P>(m0, input_ptr0[r + 2]);
+            float m1 = PoolPre<P>(input_ptr1[r], input_ptr1[r + 1]);
+            m1 = PoolPre<P>(m1, input_ptr1[r + 2]);
+            float m2 = PoolPre<P>(input_ptr2[r], input_ptr2[r + 1]);
+            m2 = PoolPre<P>(m2, input_ptr2[r + 2]);
+            float m3 = PoolPre<P>(input_ptr3[r], input_ptr3[r + 1]);
+            m3 = PoolPre<P>(m3, input_ptr3[r + 2]);
+            float m4 = PoolPre<P>(input_ptr4[r], input_ptr4[r + 1]);
+            m4 = PoolPre<P>(m4, input_ptr4[r + 2]);
+            float m5 = PoolPre<P>(input_ptr5[r], input_ptr5[r + 1]);
+            m5 = PoolPre<P>(m5, input_ptr5[r + 2]);
+
+            m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2);
+            m1 = PoolPre<P>(PoolPre<P>(m1, m2), m3);
+            m2 = PoolPre<P>(PoolPre<P>(m2, m3), m4);
+            m3 = PoolPre<P>(PoolPre<P>(m3, m4), m5);
+            output_ptr0[r] = PoolPost<P>(m0, avg);
+            output_ptr1[r] = PoolPost<P>(m1, avg);
+            output_ptr2[r] = PoolPost<P>(m2, avg);
+            output_ptr3[r] = PoolPost<P>(m3, avg);
+          }
+        }
+        // remain height
+        int start_h = valid_h_start + (valid_h & 0xFFFC);
+        for (int h = start_h; h < valid_h_end; ++h) {
+          const float *input_ptr0 = input_ptr + (h - padding_h) * input_w;
+          const float *input_ptr1 = input_ptr0 + input_w;
+          const float *input_ptr2 = input_ptr1 + input_w;
+          float *output_ptr0 = output_ptr + h * output_w + valid_w_start;
+          int remain = output_w_remain;
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+          float32x4x2_t x0, x1, x2, y0;
+          float32x4_t post = vdupq_n_f32(1.f / 9);
+          for (int loop = 0; loop < output_w_tiles; ++loop) {
+            x0.val[0] = vld1q_f32(input_ptr0);
+            x0.val[1] = vld1q_f32(input_ptr0 + 4);
+            x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
+            x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
+            x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
+            x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
+            x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
+
+            x0.val[0] = vld1q_f32(input_ptr1);
+            x0.val[1] = vld1q_f32(input_ptr1 + 4);
+            x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
+            x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
+            x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
+            x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
+            x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
+            y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
+
+            x0.val[0] = vld1q_f32(input_ptr2);
+            x0.val[1] = vld1q_f32(input_ptr2 + 4);
+            x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
+            x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
+            x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
+            x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
+            x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
+            y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
+            y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
+            y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
+            vst1q_f32(output_ptr0, y0.val[0]);
+            vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
+
+            input_ptr0 += 6;
+            input_ptr1 += 6;
+            input_ptr2 += 6;
+            output_ptr0 += 6;
+          }
+          // remain width
+          if (remain >= 4) {
+            x0.val[0] = vld1q_f32(input_ptr0);
+            x0.val[1] = vld1q_f32(input_ptr0 + 4);
+            x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
+            x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+
+            x0.val[0] = vld1q_f32(input_ptr1);
+            x0.val[1] = vld1q_f32(input_ptr1 + 4);
+            x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
+            x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
+
+            x0.val[0] = vld1q_f32(input_ptr2);
+            x0.val[1] = vld1q_f32(input_ptr2 + 4);
+            x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
+            x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
+            y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
+            vst1q_f32(output_ptr0, y0.val[0]);
+
+            input_ptr0 += 4;
+            input_ptr1 += 4;
+            input_ptr2 += 4;
+            output_ptr0 += 4;
+            remain -= 4;
+          }
+#endif  // __ARM_NEON__
+          for (int r = 0; r < remain; ++r) {
+            float m0 = PoolPre<P>(input_ptr0[r], input_ptr0[r + 1]);
+            m0 = PoolPre<P>(m0, input_ptr0[r + 2]);
+            float m1 = PoolPre<P>(input_ptr1[r], input_ptr1[r + 1]);
+            m1 = PoolPre<P>(m1, input_ptr1[r + 2]);
+            float m2 = PoolPre<P>(input_ptr2[r], input_ptr2[r + 1]);
+            m2 = PoolPre<P>(m2, input_ptr2[r + 2]);
+
+            m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2);
+            output_ptr0[r] = PoolPost<P>(m0, avg);
+          }
         }
       }
     }
@@ -671,339 +674,342 @@ struct Pooling3x3<P, 2> {
     int valid_w_end = valid_w_start + valid_w;
     float avg = 1.f / 9;
 
-    #pragma omp parallel for
-    for (int c = 0; c < output->dims()[1]; ++c) {
-      const float *input_ptr = input_data + c * image_size;
-      float *output_ptr = output_data + c * out_image_size;
-      // top
-      for (int h = 0; h < valid_h_start; ++h) {
-        Pooling3x3NormalRow<P>(input_ptr, h, input_h, input_w, padding_h,
-                               padding_w, output_w, output_ptr);
-      }
-      // left
-      for (int w = 0; w < valid_w_start; ++w) {
-        Pooling3x3ValidCol<P>(input_ptr, valid_h_start, valid_h_end, w,
-                              input_h, input_w, padding_h, padding_w,
-                              output_w, output_ptr);
-      }
-      // right
-      for (int w = valid_w_end; w < output_w; ++w) {
-        Pooling3x3ValidCol<P>(input_ptr, valid_h_start, valid_h_end, w,
-                              input_h, input_w, padding_h, padding_w,
-                              output_w, output_ptr);
-      }
-      // bottom
-      for (int h = valid_h_end; h < output_h; ++h) {
-        Pooling3x3NormalRow<P>(input_ptr, h, input_h, input_w, padding_h,
-                               padding_w, output_w, output_ptr);
-      }
-      // valid
-      int input_w_start = 2 * valid_w_start - padding_w;
-      int output_w_tiles = valid_w / 6;
-      int output_w_remain = valid_w - output_w_tiles * 6;
-      for (int h = valid_h_start; h < valid_h_end - 2; h += 3) {
-        size_t offset = (2 * h - padding_h) * input_w + input_w_start;
-        const float *input_ptr0 = input_ptr + offset;
-        const float *input_ptr1 = input_ptr0 + input_w;
-        const float *input_ptr2 = input_ptr1 + input_w;
-        const float *input_ptr3 = input_ptr2 + input_w;
-        const float *input_ptr4 = input_ptr3 + input_w;
-        const float *input_ptr5 = input_ptr4 + input_w;
-        const float *input_ptr6 = input_ptr5 + input_w;
-        float *output_ptr0 = output_ptr + h * output_w + valid_w_start;
-        float *output_ptr1 = output_ptr0 + output_w;
-        float *output_ptr2 = output_ptr1 + output_w;
-        int remain = output_w_remain;
-#if defined(__ARM_NEON__) || defined(__ARM_NEON)
-        float32x4x2_t x0, x1, x2;
-        float32x4x2_t y0, y1, y2;
-        float32x4_t post = vdupq_n_f32(1.f / 9);
-        for (int loop = 0; loop < output_w_tiles; ++loop) {
-          x0 = vld2q_f32(input_ptr0);
-          x1 = vld2q_f32(input_ptr0 + 8);
-          x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
-          x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
-          x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
-
-          x0 = vld2q_f32(input_ptr1);
-          x1 = vld2q_f32(input_ptr1 + 8);
-          x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
-          x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
-          x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
-          y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
-
-          x0 = vld2q_f32(input_ptr2);
-          x1 = vld2q_f32(input_ptr2 + 8);
-          x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
-          x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
-          x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
-          y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y1.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
-          y0.val[1] = vPoolPreq_f32<P>(y1.val[1], y0.val[1]);
-          y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
-          y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
-          vst1q_f32(output_ptr0, y0.val[0]);
-          vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
-
-          x0 = vld2q_f32(input_ptr3);
-          x1 = vld2q_f32(input_ptr3 + 8);
-          x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
-          x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
-          x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
-          y1.val[0] = vPoolPreq_f32<P>(x0.val[0], y1.val[0]);
-          y1.val[1] = vPoolPreq_f32<P>(x0.val[1], y1.val[1]);
-
-          x0 = vld2q_f32(input_ptr4);
-          x1 = vld2q_f32(input_ptr4 + 8);
-          x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
-          x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
-          x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
-          y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
-          y1.val[1] = vPoolPreq_f32<P>(y0.val[1], y1.val[1]);
-          y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
-          y1.val[1] = vPoolPostq_f32<P>(y1.val[1], post);
-          vst1q_f32(output_ptr1, y1.val[0]);
-          vst1_f32(output_ptr1 + 4, vget_low_f32(y1.val[1]));
-
-          x0 = vld2q_f32(input_ptr5);
-          x1 = vld2q_f32(input_ptr5 + 8);
-          x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
-          x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
-          x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
-          y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
-
-          x0 = vld2q_f32(input_ptr6);
-          x1 = vld2q_f32(input_ptr6 + 8);
-          x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
-          x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
-          x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
-          y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
-          y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
-          y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
-          vst1q_f32(output_ptr2, y0.val[0]);
-          vst1_f32(output_ptr2 + 4, vget_low_f32(y0.val[1]));
-
-          input_ptr0 += 12;
-          input_ptr1 += 12;
-          input_ptr2 += 12;
-          input_ptr3 += 12;
-          input_ptr4 += 12;
-          input_ptr5 += 12;
-          input_ptr6 += 12;
-          output_ptr0 += 6;
-          output_ptr1 += 6;
-          output_ptr2 += 6;
+    #pragma omp parallel for collapse(2)
+    for (int batch = 0; batch < output->dims()[0]; ++batch) {
+      for (int c = 0; c < output->dims()[1]; ++c) {
+        int channel = batch * output->dims()[1] + c;
+        const float *input_ptr = input_data + channel * image_size;
+        float *output_ptr = output_data + channel * out_image_size;
+        // top
+        for (int h = 0; h < valid_h_start; ++h) {
+          Pooling3x3NormalRow<P>(input_ptr, h, input_h, input_w, padding_h,
+                                 padding_w, output_w, output_ptr);
         }
-        // remain w
-        if (remain >= 4) {
-          x0 = vld2q_f32(input_ptr0);
-          x1.val[0] = vdupq_n_f32(input_ptr0[8]);
-          x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-
-          x0 = vld2q_f32(input_ptr1);
-          x1.val[0] = vdupq_n_f32(input_ptr1[8]);
-          x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
-
-          x0 = vld2q_f32(input_ptr2);
-          x1.val[0] = vdupq_n_f32(input_ptr2[8]);
-          x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
-          y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
-          y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
-          vst1q_f32(output_ptr0, y0.val[0]);
-
-          x0 = vld2q_f32(input_ptr3);
-          x1.val[0] = vdupq_n_f32(input_ptr3[8]);
-          x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y1.val[0] = vPoolPreq_f32<P>(x0.val[0], y1.val[0]);
-
-          x0 = vld2q_f32(input_ptr4);
-          x1.val[0] = vdupq_n_f32(input_ptr4[8]);
-          x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
-          y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
-          vst1q_f32(output_ptr1, y1.val[0]);
-
-          x0 = vld2q_f32(input_ptr5);
-          x1.val[0] = vdupq_n_f32(input_ptr5[8]);
-          x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
-
-          x0 = vld2q_f32(input_ptr6);
-          x1.val[0] = vdupq_n_f32(input_ptr6[8]);
-          x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
-          y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
-          vst1q_f32(output_ptr2, y0.val[0]);
-
-          input_ptr0 += 8;
-          input_ptr1 += 8;
-          input_ptr2 += 8;
-          input_ptr3 += 8;
-          input_ptr4 += 8;
-          input_ptr5 += 8;
-          input_ptr6 += 8;
-          output_ptr0 += 4;
-          output_ptr1 += 4;
-          output_ptr2 += 4;
-          remain -= 4;
+        // left
+        for (int w = 0; w < valid_w_start; ++w) {
+          Pooling3x3ValidCol<P>(input_ptr, valid_h_start, valid_h_end, w,
+                                input_h, input_w, padding_h, padding_w,
+                                output_w, output_ptr);
         }
-#endif  // __ARM_NEON__
-        for (int r = 0; r < remain; ++r) {
-          float m0 = PoolPre<P>(input_ptr0[2 * r], input_ptr0[2 * r + 1]);
-          m0 = PoolPre<P>(m0, input_ptr0[2 * r + 2]);
-          float m1 = PoolPre<P>(input_ptr1[2 * r], input_ptr1[2 * r + 1]);
-          m1 = PoolPre<P>(m1, input_ptr1[2 * r + 2]);
-          float m2 = PoolPre<P>(input_ptr2[2 * r], input_ptr2[2 * r + 1]);
-          m2 = PoolPre<P>(m2, input_ptr2[2 * r + 2]);
-          float m3 = PoolPre<P>(input_ptr3[2 * r], input_ptr3[2 * r + 1]);
-          m3 = PoolPre<P>(m3, input_ptr3[2 * r + 2]);
-          float m4 = PoolPre<P>(input_ptr4[2 * r], input_ptr4[2 * r + 1]);
-          m4 = PoolPre<P>(m4, input_ptr4[2 * r + 2]);
-          float m5 = PoolPre<P>(input_ptr5[2 * r], input_ptr5[2 * r + 1]);
-          m5 = PoolPre<P>(m5, input_ptr5[2 * r + 2]);
-          float m6 = PoolPre<P>(input_ptr6[2 * r], input_ptr6[2 * r + 1]);
-          m6 = PoolPre<P>(m6, input_ptr6[2 * r + 2]);
-
-          m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2);
-          m1 = PoolPre<P>(PoolPre<P>(m2, m3), m4);
-          m2 = PoolPre<P>(PoolPre<P>(m4, m5), m6);
-          output_ptr0[r] = PoolPost<P>(m0, avg);
-          output_ptr1[r] = PoolPost<P>(m1, avg);
-          output_ptr2[r] = PoolPost<P>(m2, avg);
+        // right
+        for (int w = valid_w_end; w < output_w; ++w) {
+          Pooling3x3ValidCol<P>(input_ptr, valid_h_start, valid_h_end, w,
+                                input_h, input_w, padding_h, padding_w,
+                                output_w, output_ptr);
         }
-      }
-      // remain h
-      int start_h = valid_h_start + valid_h / 3 * 3;
-      for (int h = start_h; h < valid_h_end; ++h) {
-        size_t offset = (2 * h - padding_h) * input_w + input_w_start;
-        const float *input_ptr0 = input_ptr + offset;
-        const float *input_ptr1 = input_ptr0 + input_w;
-        const float *input_ptr2 = input_ptr1 + input_w;
-        float *output_ptr0 = output_ptr + h * output_w + valid_w_start;
-        int remain = output_w_remain;
-#if defined(__ARM_NEON__) || defined(__ARM_NEON)
-        float32x4x2_t x0, x1, x2, y0;
-        float32x4_t post = vdupq_n_f32(1.f / 9);
-        for (int loop = 0; loop < output_w_tiles; ++loop) {
-          x0 = vld2q_f32(input_ptr0);
-          x1 = vld2q_f32(input_ptr0 + 8);
-          x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
-          x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
-          x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
-
-          x0 = vld2q_f32(input_ptr1);
-          x1 = vld2q_f32(input_ptr1 + 8);
-          x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
-          x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
-          x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
-          y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
-
-          x0 = vld2q_f32(input_ptr2);
-          x1 = vld2q_f32(input_ptr2 + 8);
-          x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
-          x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
-          x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
-          y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
-          y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
-          y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
-          vst1q_f32(output_ptr0, y0.val[0]);
-          vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
-
-          input_ptr0 += 12;
-          input_ptr1 += 12;
-          input_ptr2 += 12;
-          output_ptr0 += 6;
+        // bottom
+        for (int h = valid_h_end; h < output_h; ++h) {
+          Pooling3x3NormalRow<P>(input_ptr, h, input_h, input_w, padding_h,
+                                 padding_w, output_w, output_ptr);
         }
-        // remain w
-        if (remain >= 4) {
-          x0 = vld2q_f32(input_ptr0);
-          x1.val[0] = vdupq_n_f32(input_ptr0[8]);
-          x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-
-          x0 = vld2q_f32(input_ptr1);
-          x1.val[0] = vdupq_n_f32(input_ptr1[8]);
-          x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
-
-          x0 = vld2q_f32(input_ptr2);
-          x1.val[0] = vdupq_n_f32(input_ptr2[8]);
-          x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
-          x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
-          y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
-          y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
-          vst1q_f32(output_ptr0, y0.val[0]);
-
-          input_ptr0 += 8;
-          input_ptr1 += 8;
-          input_ptr2 += 8;
-          output_ptr0 += 4;
-          remain -= 4;
-        }
+        // valid
+        int input_w_start = 2 * valid_w_start - padding_w;
+        int output_w_tiles = valid_w / 6;
+        int output_w_remain = valid_w - output_w_tiles * 6;
+        for (int h = valid_h_start; h < valid_h_end - 2; h += 3) {
+          size_t offset = (2 * h - padding_h) * input_w + input_w_start;
+          const float *input_ptr0 = input_ptr + offset;
+          const float *input_ptr1 = input_ptr0 + input_w;
+          const float *input_ptr2 = input_ptr1 + input_w;
+          const float *input_ptr3 = input_ptr2 + input_w;
+          const float *input_ptr4 = input_ptr3 + input_w;
+          const float *input_ptr5 = input_ptr4 + input_w;
+          const float *input_ptr6 = input_ptr5 + input_w;
+          float *output_ptr0 = output_ptr + h * output_w + valid_w_start;
+          float *output_ptr1 = output_ptr0 + output_w;
+          float *output_ptr2 = output_ptr1 + output_w;
+          int remain = output_w_remain;
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+          float32x4x2_t x0, x1, x2;
+          float32x4x2_t y0, y1, y2;
+          float32x4_t post = vdupq_n_f32(1.f / 9);
+          for (int loop = 0; loop < output_w_tiles; ++loop) {
+            x0 = vld2q_f32(input_ptr0);
+            x1 = vld2q_f32(input_ptr0 + 8);
+            x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
+            x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
+            x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
+
+            x0 = vld2q_f32(input_ptr1);
+            x1 = vld2q_f32(input_ptr1 + 8);
+            x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
+            x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
+            x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
+            y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
+
+            x0 = vld2q_f32(input_ptr2);
+            x1 = vld2q_f32(input_ptr2 + 8);
+            x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
+            x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
+            x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
+            y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            y1.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
+            y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
+            y0.val[1] = vPoolPreq_f32<P>(y1.val[1], y0.val[1]);
+            y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
+            y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
+            vst1q_f32(output_ptr0, y0.val[0]);
+            vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
+
+            x0 = vld2q_f32(input_ptr3);
+            x1 = vld2q_f32(input_ptr3 + 8);
+            x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
+            x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
+            x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
+            y1.val[0] = vPoolPreq_f32<P>(x0.val[0], y1.val[0]);
+            y1.val[1] = vPoolPreq_f32<P>(x0.val[1], y1.val[1]);
+
+            x0 = vld2q_f32(input_ptr4);
+            x1 = vld2q_f32(input_ptr4 + 8);
+            x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
+            x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
+            x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
+            y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
+            y1.val[1] = vPoolPreq_f32<P>(y0.val[1], y1.val[1]);
+            y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
+            y1.val[1] = vPoolPostq_f32<P>(y1.val[1], post);
+            vst1q_f32(output_ptr1, y1.val[0]);
+            vst1_f32(output_ptr1 + 4, vget_low_f32(y1.val[1]));
+
+            x0 = vld2q_f32(input_ptr5);
+            x1 = vld2q_f32(input_ptr5 + 8);
+            x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
+            x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
+            x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
+            y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
+
+            x0 = vld2q_f32(input_ptr6);
+            x1 = vld2q_f32(input_ptr6 + 8);
+            x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
+            x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
+            x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
+            y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
+            y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
+            y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
+            vst1q_f32(output_ptr2, y0.val[0]);
+            vst1_f32(output_ptr2 + 4, vget_low_f32(y0.val[1]));
+
+            input_ptr0 += 12;
+            input_ptr1 += 12;
+            input_ptr2 += 12;
+            input_ptr3 += 12;
+            input_ptr4 += 12;
+            input_ptr5 += 12;
+            input_ptr6 += 12;
+            output_ptr0 += 6;
+            output_ptr1 += 6;
+            output_ptr2 += 6;
+          }
+          // remain width
+          if (remain >= 4) {
+            x0 = vld2q_f32(input_ptr0);
+            x1.val[0] = vdupq_n_f32(input_ptr0[8]);
+            x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+
+            x0 = vld2q_f32(input_ptr1);
+            x1.val[0] = vdupq_n_f32(input_ptr1[8]);
+            x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
+            x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
+            y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
+
+            x0 = vld2q_f32(input_ptr2);
+            x1.val[0] = vdupq_n_f32(input_ptr2[8]);
+            x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
+            x0.val[0] = vPoolPreq_f32

(x0.val[0], x0.val[1]); + y1.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + y0.val[0] = vPoolPreq_f32

(y1.val[0], y0.val[0]); + y0.val[0] = vPoolPostq_f32

(y0.val[0], post); + vst1q_f32(output_ptr0, y0.val[0]); + + x0 = vld2q_f32(input_ptr3); + x1.val[0] = vdupq_n_f32(input_ptr3[8]); + x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x0.val[1]); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + y1.val[0] = vPoolPreq_f32

(x0.val[0], y1.val[0]); + + x0 = vld2q_f32(input_ptr4); + x1.val[0] = vdupq_n_f32(input_ptr4[8]); + x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x0.val[1]); + y0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + y1.val[0] = vPoolPreq_f32

(y0.val[0], y1.val[0]); + y1.val[0] = vPoolPostq_f32

(y1.val[0], post); + vst1q_f32(output_ptr1, y1.val[0]); + + x0 = vld2q_f32(input_ptr5); + x1.val[0] = vdupq_n_f32(input_ptr5[8]); + x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x0.val[1]); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); + + x0 = vld2q_f32(input_ptr6); + x1.val[0] = vdupq_n_f32(input_ptr6[8]); + x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x0.val[1]); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); + y0.val[0] = vPoolPostq_f32

(y0.val[0], post); + vst1q_f32(output_ptr2, y0.val[0]); + + input_ptr0 += 8; + input_ptr1 += 8; + input_ptr2 += 8; + input_ptr3 += 8; + input_ptr4 += 8; + input_ptr5 += 8; + input_ptr6 += 8; + output_ptr0 += 4; + output_ptr1 += 4; + output_ptr2 += 4; + remain -= 4; + } +#endif // __ARM_NEON__ + for (int r = 0; r < remain; ++r) { + float m0 = PoolPre

(input_ptr0[2 * r], input_ptr0[2 * r + 1]); + m0 = PoolPre

(m0, input_ptr0[2 * r + 2]); + float m1 = PoolPre

(input_ptr1[2 * r], input_ptr1[2 * r + 1]); + m1 = PoolPre

(m1, input_ptr1[2 * r + 2]); + float m2 = PoolPre

(input_ptr2[2 * r], input_ptr2[2 * r + 1]); + m2 = PoolPre

(m2, input_ptr2[2 * r + 2]); + float m3 = PoolPre

(input_ptr3[2 * r], input_ptr3[2 * r + 1]); + m3 = PoolPre

(m3, input_ptr3[2 * r + 2]); + float m4 = PoolPre

(input_ptr4[2 * r], input_ptr4[2 * r + 1]); + m4 = PoolPre

(m4, input_ptr4[2 * r + 2]); + float m5 = PoolPre

(input_ptr5[2 * r], input_ptr5[2 * r + 1]); + m5 = PoolPre

(m5, input_ptr5[2 * r + 2]); + float m6 = PoolPre

(input_ptr6[2 * r], input_ptr6[2 * r + 1]); + m6 = PoolPre

(m6, input_ptr6[2 * r + 2]); + + m0 = PoolPre

(PoolPre

(m0, m1), m2); + m1 = PoolPre

(PoolPre

(m2, m3), m4); + m2 = PoolPre

(PoolPre

(m4, m5), m6); + output_ptr0[r] = PoolPost

(m0, avg); + output_ptr1[r] = PoolPost

(m1, avg); + output_ptr2[r] = PoolPost

(m2, avg); + } } + // remain height + int start_h = valid_h_start + valid_h / 3 * 3; + for (int h = start_h; h < valid_h_end; ++h) { + size_t offset = (2 * h - padding_h) * input_w + input_w_start; + const float *input_ptr0 = input_ptr + offset; + const float *input_ptr1 = input_ptr0 + input_w; + const float *input_ptr2 = input_ptr1 + input_w; + float *output_ptr0 = output_ptr + h * output_w + valid_w_start; + int remain = output_w_remain; +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + float32x4x2_t x0, x1, x2, y0; + float32x4_t post = vdupq_n_f32(1.f / 9); + for (int loop = 0; loop < output_w_tiles; ++loop) { + x0 = vld2q_f32(input_ptr0); + x1 = vld2q_f32(input_ptr0 + 8); + x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x0.val[1]); + x0.val[1] = vPoolPreq_f32

(x1.val[0], x1.val[1]); + y0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + y0.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); + + x0 = vld2q_f32(input_ptr1); + x1 = vld2q_f32(input_ptr1 + 8); + x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x0.val[1]); + x0.val[1] = vPoolPreq_f32

(x1.val[0], x1.val[1]); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + x0.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); + y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); + y0.val[1] = vPoolPreq_f32

(x0.val[1], y0.val[1]); + + x0 = vld2q_f32(input_ptr2); + x1 = vld2q_f32(input_ptr2 + 8); + x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x0.val[1]); + x0.val[1] = vPoolPreq_f32

(x1.val[0], x1.val[1]); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + x0.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); + y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); + y0.val[1] = vPoolPreq_f32

(x0.val[1], y0.val[1]); + y0.val[0] = vPoolPostq_f32

(y0.val[0], post); + y0.val[1] = vPoolPostq_f32

(y0.val[1], post); + vst1q_f32(output_ptr0, y0.val[0]); + vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1])); + + input_ptr0 += 12; + input_ptr1 += 12; + input_ptr2 += 12; + output_ptr0 += 6; + } + // remain width + if (remain >= 4) { + x0 = vld2q_f32(input_ptr0); + x1.val[0] = vdupq_n_f32(input_ptr0[8]); + x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x0.val[1]); + y0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + + x0 = vld2q_f32(input_ptr1); + x1.val[0] = vdupq_n_f32(input_ptr1[8]); + x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x0.val[1]); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); + + x0 = vld2q_f32(input_ptr2); + x1.val[0] = vdupq_n_f32(input_ptr2[8]); + x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x0.val[1]); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); + y0.val[0] = vPoolPostq_f32

(y0.val[0], post); + vst1q_f32(output_ptr0, y0.val[0]); + + input_ptr0 += 8; + input_ptr1 += 8; + input_ptr2 += 8; + output_ptr0 += 4; + remain -= 4; + } #endif // __ARM_NEON__ - for (int r = 0; r < remain; ++r) { - float m0 = PoolPre

(input_ptr0[2 * r], input_ptr0[2 * r + 1]); - m0 = PoolPre

(m0, input_ptr0[2 * r + 2]); - float m1 = PoolPre

(input_ptr1[2 * r], input_ptr1[2 * r + 1]); - m1 = PoolPre

(m1, input_ptr1[2 * r + 2]); - float m2 = PoolPre

(input_ptr2[2 * r], input_ptr2[2 * r + 1]); - m2 = PoolPre

(m2, input_ptr2[2 * r + 2]); - - m0 = PoolPre

(PoolPre

(m0, m1), m2); - output_ptr0[r] = PoolPost

(m0, avg); + for (int r = 0; r < remain; ++r) { + float m0 = PoolPre

(input_ptr0[2 * r], input_ptr0[2 * r + 1]); + m0 = PoolPre

(m0, input_ptr0[2 * r + 2]); + float m1 = PoolPre

(input_ptr1[2 * r], input_ptr1[2 * r + 1]); + m1 = PoolPre

(m1, input_ptr1[2 * r + 2]); + float m2 = PoolPre

(input_ptr2[2 * r], input_ptr2[2 * r + 1]); + m2 = PoolPre

(m2, input_ptr2[2 * r + 2]); + + m0 = PoolPre

(PoolPre

(m0, m1), m2); + output_ptr0[r] = PoolPost

(m0, avg); + } } } } diff --git a/test/common/test_gemm_perf.cpp b/test/common/test_gemm_perf.cpp index 92b78a57e9a0236ce2e1c6627b150d4c246c5413..17f9b77b7039cbe5eb26645ffa8dc97a164bc808 100644 --- a/test/common/test_gemm_perf.cpp +++ b/test/common/test_gemm_perf.cpp @@ -73,14 +73,14 @@ int main() { // float // warm-up 10 times for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::matmul( aa, false, bb, false, static_cast(1), &cc, static_cast(0), false, nullptr); } auto time_start0 = time(); for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::matmul( aa, false, bb, false, static_cast(1), &cc, static_cast(0), false, nullptr); } @@ -91,14 +91,14 @@ int main() { // int8_t without bias // warm-up 10 times for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::matmul( aa_int8, false, bb_int8, false, static_cast(1), &cc_int32, static_cast(0)); } auto time_start1 = time(); for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::matmul( aa_int8, false, bb_int8, false, static_cast(1), &cc_int32, static_cast(0)); } @@ -109,13 +109,13 @@ int main() { // int8_t with bias, column element wise add // warm-up 10 times for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::matmul( aa_int8, false, bb_int8, false, static_cast(0.618), &cc_int8, static_cast(0), false, bias_data_col, false); } auto time_start2 = time(); for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::matmul( aa_int8, false, bb_int8, false, static_cast(0.618), &cc_int8, static_cast(0), false, bias_data_col, false); } @@ -126,13 +126,13 @@ int main() { // int8_t with bias, row element wise add // warm-up 10 times for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::matmul( aa_int8, false, bb_int8, false, static_cast(0.618), &cc_int8, static_cast(0), false, bias_data_row, true); } auto time_start3 = time(); for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::matmul( aa_int8, false, bb_int8, false, static_cast(0.618), &cc_int8, static_cast(0), false, bias_data_row, true); } @@ -143,13 +143,13 @@ int main() { // int8_t with bias&relu // warm-up 10 times for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::matmul( aa_int8, false, bb_int8, false, static_cast(0.618), &cc_int8, static_cast(0), true, bias_data_col, false); } auto time_start4 = time(); for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::matmul( aa_int8, false, bb_int8, false, static_cast(0.618), &cc_int8, static_cast(0), true, bias_data_col, false); } diff --git a/test/operators/test_pool_op.cpp b/test/operators/test_pool_op.cpp index b38123a7d623d614fb04077daf351cad7d1aa8f2..3668b8cb2846ba4e3cb7f3b0728c1356a343cf4c 100644 --- a/test/operators/test_pool_op.cpp +++ b/test/operators/test_pool_op.cpp @@ -59,7 +59,8 @@ int TestPoolOp(int in_channels, int in_height, int in_width) { attrs["ksize"].Set>(std::vector({kernel_h, kernel_w})); attrs["strides"].Set>(std::vector({stride_h, stride_w})); attrs["paddings"].Set>(std::vector({pad_h, pad_w})); - attrs["ceil_mode"].Set(false); + attrs["ceil_mode"].Set(true); + // attrs["ceil_mode"].Set(false); 
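The pooling kernels in the hunk above are generic over a pooling policy P: PoolPre / vPoolPreq_f32 fold one more sample into the current window (the larger value for max pooling, a running sum for average pooling), and PoolPost / vPoolPostq_f32 finalize a window (identity for max, multiplication by the `post` vector of 1/9 for a 3x3 average). The sketch below illustrates how such helpers can be specialized; the tag types MaxPool/AvgPool and the exact signatures are illustrative assumptions, not the repository's actual definitions (which may use a non-type PoolingType template parameter instead).

// Minimal sketch only — assumed tag types, not paddle_mobile's definitions.
#include <algorithm>

struct MaxPool {};
struct AvgPool {};

template <typename P>
inline float PoolPre(float a, float b);
template <>
inline float PoolPre<MaxPool>(float a, float b) {
  return std::max(a, b);  // max pooling: keep the larger sample
}
template <>
inline float PoolPre<AvgPool>(float a, float b) {
  return a + b;  // average pooling: accumulate the window sum
}

template <typename P>
inline float PoolPost(float v, float scale);
template <>
inline float PoolPost<MaxPool>(float v, float /*scale*/) {
  return v;  // max pooling needs no post-processing
}
template <>
inline float PoolPost<AvgPool>(float v, float scale) {
  return v * scale;  // scale = 1 / window size, e.g. 1.f / 9 for 3x3
}

#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
// Four-lane NEON counterparts used by the vectorized paths above.
template <typename P>
inline float32x4_t vPoolPreq_f32(float32x4_t a, float32x4_t b);
template <>
inline float32x4_t vPoolPreq_f32<MaxPool>(float32x4_t a, float32x4_t b) {
  return vmaxq_f32(a, b);  // lane-wise max
}
template <>
inline float32x4_t vPoolPreq_f32<AvgPool>(float32x4_t a, float32x4_t b) {
  return vaddq_f32(a, b);  // lane-wise sum
}

template <typename P>
inline float32x4_t vPoolPostq_f32(float32x4_t v, float32x4_t post);
template <>
inline float32x4_t vPoolPostq_f32<MaxPool>(float32x4_t v, float32x4_t /*post*/) {
  return v;
}
template <>
inline float32x4_t vPoolPostq_f32<AvgPool>(float32x4_t v, float32x4_t post) {
  return vmulq_f32(v, post);  // scale the four sums by 1/9
}
#endif  // __ARM_NEON__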
attrs["global_pooling"].Set(false); auto *op = new operators::PoolOp("pool2d", inputs, outputs, attrs, @@ -116,57 +117,57 @@ int main(int argc, char *argv[]) { int in_channels = atoi(argv[1]); int in_height = atoi(argv[2]); int in_width = atoi(argv[3]); - LOG(paddle_mobile::kLOG_INFO) - << "float, pooling_type=max, kernel=3, pad=0, stride=1"; - paddle_mobile::TestPoolOp<0, 3, 0, 1>(in_channels, in_height, in_width); - LOG(paddle_mobile::kLOG_INFO) - << "float, pooling_type=max, kernel=3, pad=1, stride=1"; - paddle_mobile::TestPoolOp<0, 3, 1, 1>(in_channels, in_height, in_width); - LOG(paddle_mobile::kLOG_INFO) - << "float, pooling_type=max, kernel=3, pad=2, stride=1"; - paddle_mobile::TestPoolOp<0, 3, 2, 1>(in_channels, in_height, in_width); - LOG(paddle_mobile::kLOG_INFO) - << "float, pooling_type=max, kernel=3, pad=5, stride=1"; - paddle_mobile::TestPoolOp<0, 3, 5, 1>(in_channels, in_height, in_width); - - LOG(paddle_mobile::kLOG_INFO) - << "float, pooling_type=avg, kernel=3, pad=0, stride=1"; - paddle_mobile::TestPoolOp<1, 3, 0, 1>(in_channels, in_height, in_width); - LOG(paddle_mobile::kLOG_INFO) - << "float, pooling_type=avg, kernel=3, pad=1, stride=1"; - paddle_mobile::TestPoolOp<1, 3, 1, 1>(in_channels, in_height, in_width); - LOG(paddle_mobile::kLOG_INFO) - << "float, pooling_type=avg, kernel=3, pad=2, stride=1"; - paddle_mobile::TestPoolOp<1, 3, 2, 1>(in_channels, in_height, in_width); - LOG(paddle_mobile::kLOG_INFO) - << "float, pooling_type=avg, kernel=3, pad=5, stride=1"; - paddle_mobile::TestPoolOp<1, 3, 5, 1>(in_channels, in_height, in_width); + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=max, kernel=3, pad=0, stride=1"; + // paddle_mobile::TestPoolOp<0, 3, 0, 1>(in_channels, in_height, in_width); + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=max, kernel=3, pad=1, stride=1"; + // paddle_mobile::TestPoolOp<0, 3, 1, 1>(in_channels, in_height, in_width); + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=max, kernel=3, pad=2, stride=1"; + // paddle_mobile::TestPoolOp<0, 3, 2, 1>(in_channels, in_height, in_width); + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=max, kernel=3, pad=5, stride=1"; + // paddle_mobile::TestPoolOp<0, 3, 5, 1>(in_channels, in_height, in_width); + // + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=avg, kernel=3, pad=0, stride=1"; + // paddle_mobile::TestPoolOp<1, 3, 0, 1>(in_channels, in_height, in_width); + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=avg, kernel=3, pad=1, stride=1"; + // paddle_mobile::TestPoolOp<1, 3, 1, 1>(in_channels, in_height, in_width); + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=avg, kernel=3, pad=2, stride=1"; + // paddle_mobile::TestPoolOp<1, 3, 2, 1>(in_channels, in_height, in_width); + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=avg, kernel=3, pad=5, stride=1"; + // paddle_mobile::TestPoolOp<1, 3, 5, 1>(in_channels, in_height, in_width); LOG(paddle_mobile::kLOG_INFO) << "float, pooling_type=max, kernel=3, pad=0, stride=2"; paddle_mobile::TestPoolOp<0, 3, 0, 2>(in_channels, in_height, in_width); - LOG(paddle_mobile::kLOG_INFO) - << "float, pooling_type=max, kernel=3, pad=1, stride=2"; - paddle_mobile::TestPoolOp<0, 3, 1, 2>(in_channels, in_height, in_width); - LOG(paddle_mobile::kLOG_INFO) - << "float, pooling_type=max, kernel=3, pad=2, stride=2"; - paddle_mobile::TestPoolOp<0, 3, 2, 2>(in_channels, in_height, in_width); - LOG(paddle_mobile::kLOG_INFO) - << "float, pooling_type=max, kernel=3, 
-      << "float, pooling_type=max, kernel=3, pad=5, stride=2";
-  paddle_mobile::TestPoolOp<0, 3, 5, 2>(in_channels, in_height, in_width);
-
-  LOG(paddle_mobile::kLOG_INFO)
-      << "float, pooling_type=avg, kernel=3, pad=0, stride=2";
-  paddle_mobile::TestPoolOp<1, 3, 0, 2>(in_channels, in_height, in_width);
-  LOG(paddle_mobile::kLOG_INFO)
-      << "float, pooling_type=avg, kernel=3, pad=1, stride=2";
-  paddle_mobile::TestPoolOp<1, 3, 1, 2>(in_channels, in_height, in_width);
-  LOG(paddle_mobile::kLOG_INFO)
-      << "float, pooling_type=avg, kernel=3, pad=2, stride=2";
-  paddle_mobile::TestPoolOp<1, 3, 2, 2>(in_channels, in_height, in_width);
-  LOG(paddle_mobile::kLOG_INFO)
-      << "float, pooling_type=avg, kernel=3, pad=5, stride=2";
-  paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width);
+  // LOG(paddle_mobile::kLOG_INFO)
+  //     << "float, pooling_type=max, kernel=3, pad=1, stride=2";
+  // paddle_mobile::TestPoolOp<0, 3, 1, 2>(in_channels, in_height, in_width);
+  // LOG(paddle_mobile::kLOG_INFO)
+  //     << "float, pooling_type=max, kernel=3, pad=2, stride=2";
+  // paddle_mobile::TestPoolOp<0, 3, 2, 2>(in_channels, in_height, in_width);
+  // LOG(paddle_mobile::kLOG_INFO)
+  //     << "float, pooling_type=max, kernel=3, pad=5, stride=2";
+  // paddle_mobile::TestPoolOp<0, 3, 5, 2>(in_channels, in_height, in_width);
+  //
+  // LOG(paddle_mobile::kLOG_INFO)
+  //     << "float, pooling_type=avg, kernel=3, pad=0, stride=2";
+  // paddle_mobile::TestPoolOp<1, 3, 0, 2>(in_channels, in_height, in_width);
+  // LOG(paddle_mobile::kLOG_INFO)
+  //     << "float, pooling_type=avg, kernel=3, pad=1, stride=2";
+  // paddle_mobile::TestPoolOp<1, 3, 1, 2>(in_channels, in_height, in_width);
+  // LOG(paddle_mobile::kLOG_INFO)
+  //     << "float, pooling_type=avg, kernel=3, pad=2, stride=2";
+  // paddle_mobile::TestPoolOp<1, 3, 2, 2>(in_channels, in_height, in_width);
+  // LOG(paddle_mobile::kLOG_INFO)
+  //     << "float, pooling_type=avg, kernel=3, pad=5, stride=2";
+  // paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width);
   //
   // kernel = 5, pad = 0, stride = 1
   // LOG(paddle_mobile::kLOG_INFO)
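For reference on the test driver just modified: TestPoolOp's template arguments are <PoolType, Kernel, Pad, Stride>, where PoolType 0 selects max pooling and 1 average pooling, as the log strings above indicate; with the other cases commented out, only the kernel=3, pad=0, stride=2 configuration remains active. A hypothetical invocation (argument order follows the atoi calls above; the binary name depends on the build target for test/operators/test_pool_op.cpp):

./test-pool-op 32 112 112   # in_channels=32, in_height=112, in_width=112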