From 898f73767d0596158479b08693bc43a9df139e57 Mon Sep 17 00:00:00 2001
From: eclipsess
Date: Tue, 10 Jul 2018 16:30:17 +0800
Subject: [PATCH] add fusion bnrelu

---
 src/operators/fusion_conv_add.h               | 10 +++++-----
 src/operators/fusion_conv_add_bn_relu_op.h    |  2 +-
 .../central-arm-func/conv_add_arm_func.h      |  8 ++++----
 .../central-arm-func/conv_add_bn_relu_func.h  | 23 ++++-------------------
 src/operators/math/math_function.h            |  2 +-
 test/common/test_gemm.cpp                     |  5 +++--
 6 files changed, 18 insertions(+), 32 deletions(-)

diff --git a/src/operators/fusion_conv_add.h b/src/operators/fusion_conv_add.h
index 3aa419d139..ae030ba576 100644
--- a/src/operators/fusion_conv_add.h
+++ b/src/operators/fusion_conv_add.h
@@ -68,11 +68,11 @@ class FusionConvAddOp : public framework::OperatorWithKernel<
 
 #ifdef PADDLE_MOBILE_CPU
 
-//#ifndef CONV_ADD_REGISTER
-//static framework::FusionOpRegistrar convadd_registrar(
-//    new FusionConvAddMatcher());
-//#define CONV_ADD_REGISTER
-//#endif
+#ifndef CONV_ADD_REGISTER
+static framework::FusionOpRegistrar convadd_registrar(
+    new FusionConvAddMatcher());
+#define CONV_ADD_REGISTER
+#endif
 
 #endif
 
diff --git a/src/operators/fusion_conv_add_bn_relu_op.h b/src/operators/fusion_conv_add_bn_relu_op.h
index af252fabb1..389c76cc83 100644
--- a/src/operators/fusion_conv_add_bn_relu_op.h
+++ b/src/operators/fusion_conv_add_bn_relu_op.h
@@ -80,7 +80,7 @@ class FusionConvAddBNReluOp
 
 #ifdef PADDLE_MOBILE_CPU
 #ifndef FUSION_CONV_ADD_BN_RELU_REGISTER
-  static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
+static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
     new FusionConvAddBNReluMatcher());
 #define FUSION_CONV_ADD_BN_RELU_REGISTER
 #endif
diff --git a/src/operators/kernel/central-arm-func/conv_add_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_arm_func.h
index a6e19d30e2..15f9b4a178 100644
--- a/src/operators/kernel/central-arm-func/conv_add_arm_func.h
+++ b/src/operators/kernel/central-arm-func/conv_add_arm_func.h
@@ -125,10 +125,10 @@ void ConvAddCompute(const FusionConvAddParam &param) {
       param.Input()->dims()[1] == param.Output()->dims()[1] &&
       param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
       param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
-//    math::DepthwiseConv3x3(param.Input(), param.Strides(),
-//    param.Paddings(),
-//                           param.Filter(), param.Bias(), param.Output(),
-//                           false);
+    //    math::DepthwiseConv3x3(param.Input(), param.Strides(),
+    //    param.Paddings(),
+    //                           param.Filter(), param.Bias(),
+    //                           param.Output(), false);
     math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
                                  param.Output(), *param.Bias(), true);
diff --git a/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h b/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h
index 115574bcc0..fb49a33c67 100644
--- a/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h
+++ b/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h
@@ -26,8 +26,6 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
   Tensor bias = *param.Bias();
   Tensor new_bias = *param.NewBias();
   Tensor new_scale = *param.NewScale();
-  auto new_bias_ptr = new_bias.data<float>();
-  auto new_scale_ptr = new_scale.data<float>();
   int axis = param.Axis();
   Tensor *output = param.Output();
   math::expand_bias(bias, axis, output->dims());
@@ -106,25 +104,12 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
       // gemm
       Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-      //      math::matmul<float>(filter_slice, false, col_matrix, false,
-      //                          static_cast<float>(1), &out_slice,
-      //                          static_cast<float>(0));
-      math::matmulWithBn<float>(filter_slice, false, col_matrix, false,
-                                static_cast<float>(1), &out_slice,
-                                static_cast<float>(0),true, &new_scale,&new_bias);
+
+      math::matmulWithBn<float>(
+          filter_slice, false, col_matrix, false, static_cast<float>(1),
+          &out_slice, static_cast<float>(0), true, &new_scale, &new_bias);
     }
   }
-  /// todo : use neon in special case instead of 2for(300ms)
-  //  auto output_ptr = output->data<float>();
-  //  for (int c = 0; c < output_matrix_shape[0]; c++) {
-  //    int start = c * output_matrix_shape[1];
-  //    for (int j = 0; j < output_matrix_shape[1]; j++) {
-  //      output_ptr[start + j] =
-  //          output_ptr[start + j] * new_scale_ptr[c] + new_bias_ptr[c];
-  //      output_ptr[start + j] =
-  //          output_ptr[start + j] < 0 ? 0 : output_ptr[start + j];
-  //    }
-  //  }
 }
 template <typename P>
 void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h
index cb0b96e256..0ca7815fc2 100644
--- a/src/operators/math/math_function.h
+++ b/src/operators/math/math_function.h
@@ -30,7 +30,7 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a,
 template <typename T>
 void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
                   const framework::Tensor &matrix_b, bool trans_b, T alpha,
-                  framework::Tensor *matrix_out, T beta, bool relu = true,
+                  framework::Tensor *matrix_out, T beta, bool relu,
                   framework::Tensor *new_scale, framework::Tensor *new_bias);
 }  // namespace math
 }  // namespace operators
diff --git a/test/common/test_gemm.cpp b/test/common/test_gemm.cpp
index 1841de0be4..8cb778c458 100644
--- a/test/common/test_gemm.cpp
+++ b/test/common/test_gemm.cpp
@@ -52,8 +52,9 @@ int main() {
   }
 
   auto time1 = time();
-//  paddle_mobile::operators::math::Sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3, c,
-//  ldc);
+  //  paddle_mobile::operators::math::Sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3,
+  //  c,
+  //  ldc);
   auto time2 = time();
   DLOG << "gemm cost :" << time_diff(time1, time2) << "ms\n";
   for (int i = 0; i < m * n; ++i) {
--
GitLab
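Note on the fused epilogue (an illustrative sketch, not part of the patch): the block
deleted at the end of ConvAddBNReluBasic applied the folded batch-norm transform and
ReLU, y = max(x * new_scale[c] + new_bias[c], 0), in a second full pass over the
output; the removed "/// todo : use neon in special case instead of 2for(300ms)"
comment refers to the cost of those two loops. The patch instead hands new_scale and
new_bias to math::matmulWithBn<float>, which applies the same transform as the GEMM
writes its output, and drops the "relu = true" default from the declaration in
math_function.h so callers must pass the flag explicitly. A minimal standalone C++
sketch of the epilogue being fused follows; bn_relu_epilogue, rows, and cols are
hypothetical names chosen for illustration, not paddle-mobile API:

// Per-channel batch-norm + ReLU epilogue equivalent to the removed loop.
// out points at one group's output matrix: rows = output channels of the
// group, cols = output spatial size (H * W). new_scale/new_bias hold one
// folded batch-norm coefficient pair per output channel.
#include <algorithm>

void bn_relu_epilogue(float *out, const float *new_scale,
                      const float *new_bias, int rows, int cols) {
  for (int c = 0; c < rows; ++c) {
    const float scale = new_scale[c];
    const float bias = new_bias[c];
    float *row = out + c * cols;
    for (int j = 0; j < cols; ++j) {
      // y = max(scale * x + bias, 0): BN affine transform, then ReLU clamp
      row[j] = std::max(row[j] * scale + bias, 0.0f);
    }
  }
}

Fusing this into the GEMM's write-back saves one read-modify-write pass over the
whole output tensor per convolution, presumably the ~300 ms the todo comment cites.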