From e5ef6e7e71a20b9678d426b426395658041f0d39 Mon Sep 17 00:00:00 2001 From: eclipsess Date: Tue, 10 Jul 2018 14:07:41 +0800 Subject: [PATCH] gemm bn relu v1 --- src/operators/fusion_conv_add.h | 10 +++---- src/operators/fusion_conv_add_bn_relu_op.h | 10 +++---- .../central-arm-func/conv_add_arm_func.h | 8 +++--- .../central-arm-func/conv_add_bn_relu_func.h | 27 ++++++++++--------- src/operators/math/depthwise_conv_3x3.cpp | 3 +-- src/operators/math/math_function.h | 2 +- test/common/test_gemm.cpp | 4 +-- 7 files changed, 33 insertions(+), 31 deletions(-) diff --git a/src/operators/fusion_conv_add.h b/src/operators/fusion_conv_add.h index ae030ba576..3aa419d139 100644 --- a/src/operators/fusion_conv_add.h +++ b/src/operators/fusion_conv_add.h @@ -68,11 +68,11 @@ class FusionConvAddOp : public framework::OperatorWithKernel< #ifdef PADDLE_MOBILE_CPU -#ifndef CONV_ADD_REGISTER -static framework::FusionOpRegistrar convadd_registrar( - new FusionConvAddMatcher()); -#define CONV_ADD_REGISTER -#endif +//#ifndef CONV_ADD_REGISTER +//static framework::FusionOpRegistrar convadd_registrar( +// new FusionConvAddMatcher()); +//#define CONV_ADD_REGISTER +//#endif #endif diff --git a/src/operators/fusion_conv_add_bn_relu_op.h b/src/operators/fusion_conv_add_bn_relu_op.h index 753ce39598..af252fabb1 100644 --- a/src/operators/fusion_conv_add_bn_relu_op.h +++ b/src/operators/fusion_conv_add_bn_relu_op.h @@ -79,11 +79,11 @@ class FusionConvAddBNReluOp #ifdef PADDLE_MOBILE_CPU -//#ifndef FUSION_CONV_ADD_BN_RELU_REGISTER -// static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar( -// new FusionConvAddBNReluMatcher()); -//#define FUSION_CONV_ADD_BN_RELU_REGISTER -//#endif +#ifndef FUSION_CONV_ADD_BN_RELU_REGISTER + static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar( + new FusionConvAddBNReluMatcher()); +#define FUSION_CONV_ADD_BN_RELU_REGISTER +#endif #endif diff --git a/src/operators/kernel/central-arm-func/conv_add_arm_func.h 
b/src/operators/kernel/central-arm-func/conv_add_arm_func.h index d364dfdd61..a6e19d30e2 100644 --- a/src/operators/kernel/central-arm-func/conv_add_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_arm_func.h @@ -125,10 +125,10 @@ void ConvAddCompute(const FusionConvAddParam &param) { param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { - // math::DepthwiseConv3x3(param.Input(), param.Strides(), - // param.Paddings(), - // param.Filter(), param.Bias(), param.Output(), - // false); +// math::DepthwiseConv3x3(param.Input(), param.Strides(), +// param.Paddings(), +// param.Filter(), param.Bias(), param.Output(), +// false); math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), param.Output(), *param.Bias(), true); diff --git a/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h b/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h index e8aed3fd7d..115574bcc0 100644 --- a/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h @@ -106,22 +106,25 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) { // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(filter_slice, false, col_matrix, false, +// math::matmul(filter_slice, false, col_matrix, false, +// static_cast(1), &out_slice, +// static_cast(0)); + math::matmulWithBn(filter_slice, false, col_matrix, false, static_cast(1), &out_slice, - static_cast(0)); + static_cast(0),true, &new_scale,&new_bias); } } /// todo : use neon in special case instead of 2for(300ms) - auto output_ptr = output->data(); - for (int c = 0; c < output_matrix_shape[0]; c++) { - int start = c * output_matrix_shape[1]; - for (int j = 0; j < output_matrix_shape[1]; j++) { - output_ptr[start + j] = - 
output_ptr[start + j] * new_scale_ptr[c] + new_bias_ptr[c]; - output_ptr[start + j] = - output_ptr[start + j] < 0 ? 0 : output_ptr[start + j]; - } - } +// auto output_ptr = output->data(); +// for (int c = 0; c < output_matrix_shape[0]; c++) { +// int start = c * output_matrix_shape[1]; +// for (int j = 0; j < output_matrix_shape[1]; j++) { +// output_ptr[start + j] = +// output_ptr[start + j] * new_scale_ptr[c] + new_bias_ptr[c]; +// output_ptr[start + j] = +// output_ptr[start + j] < 0 ? 0 : output_ptr[start + j]; +// } +// } } template void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) { diff --git a/src/operators/math/depthwise_conv_3x3.cpp b/src/operators/math/depthwise_conv_3x3.cpp index c8e332c7dd..c8a6473567 100644 --- a/src/operators/math/depthwise_conv_3x3.cpp +++ b/src/operators/math/depthwise_conv_3x3.cpp @@ -1036,8 +1036,7 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, float32x4_t vbias = vdupq_n_f32(0.0); - float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1], - input_buff_top[w_times + 1]; + float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1]; float32x4_t elewise_res0, elewise_res1, elewise_res2, res3; int out2in_mid; float32x4_t zero = vdupq_n_f32(0.0); diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h index 0ca7815fc2..cb0b96e256 100644 --- a/src/operators/math/math_function.h +++ b/src/operators/math/math_function.h @@ -30,7 +30,7 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a, template void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, T alpha, - framework::Tensor *matrix_out, T beta, bool relu, + framework::Tensor *matrix_out, T beta, bool relu = true, framework::Tensor *new_scale, framework::Tensor *new_bias); } // namespace math } // namespace operators diff --git a/test/common/test_gemm.cpp b/test/common/test_gemm.cpp index aaf3c183f3..1841de0be4 100644 --- 
a/test/common/test_gemm.cpp +++ b/test/common/test_gemm.cpp @@ -52,8 +52,8 @@ int main() { } auto time1 = time(); - paddle_mobile::operators::math::sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3, c, - ldc); +// paddle_mobile::operators::math::Sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3, c, +// ldc); auto time2 = time(); DLOG << "gemm cost :" << time_diff(time1, time2) << "ms\n"; for (int i = 0; i < m * n; ++i) { -- GitLab