diff --git a/src/operators/fusion_conv_add.h b/src/operators/fusion_conv_add.h
index 3aa419d13953e26f9ed971aaab17a19943bae141..ae030ba5767e4039cfa3effe0a7ded4886f261cf 100644
--- a/src/operators/fusion_conv_add.h
+++ b/src/operators/fusion_conv_add.h
@@ -68,11 +68,11 @@ class FusionConvAddOp : public framework::OperatorWithKernel<
 
 #ifdef PADDLE_MOBILE_CPU
 
-//#ifndef CONV_ADD_REGISTER
-//static framework::FusionOpRegistrar convadd_registrar(
-//    new FusionConvAddMatcher());
-//#define CONV_ADD_REGISTER
-//#endif
+#ifndef CONV_ADD_REGISTER
+static framework::FusionOpRegistrar convadd_registrar(
+    new FusionConvAddMatcher());
+#define CONV_ADD_REGISTER
+#endif
 
 #endif
 
diff --git a/src/operators/fusion_conv_add_bn_relu_op.h b/src/operators/fusion_conv_add_bn_relu_op.h
index af252fabb118f0885ed878986d28d0cfc04bd385..389c76cc83a532fe706d911903a8412bb8bfb4ca 100644
--- a/src/operators/fusion_conv_add_bn_relu_op.h
+++ b/src/operators/fusion_conv_add_bn_relu_op.h
@@ -80,7 +80,7 @@ class FusionConvAddBNReluOp
 
 #ifdef PADDLE_MOBILE_CPU
 #ifndef FUSION_CONV_ADD_BN_RELU_REGISTER
-  static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
+static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
     new FusionConvAddBNReluMatcher());
 #define FUSION_CONV_ADD_BN_RELU_REGISTER
 #endif
diff --git a/src/operators/kernel/central-arm-func/conv_add_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_arm_func.h
index a6e19d30e2306dbc293e7a582f6343de7537eb12..15f9b4a17889b77da1884253f9e982d8f14ad131 100644
--- a/src/operators/kernel/central-arm-func/conv_add_arm_func.h
+++ b/src/operators/kernel/central-arm-func/conv_add_arm_func.h
@@ -125,10 +125,10 @@ void ConvAddCompute(const FusionConvAddParam &param) {
              param.Input()->dims()[1] == param.Output()->dims()[1] &&
              param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
              param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
-//    math::DepthwiseConv3x3(param.Input(), param.Strides(),
-//    param.Paddings(),
-//                           param.Filter(), param.Bias(), param.Output(),
-//                           false);
+    //    math::DepthwiseConv3x3(param.Input(), param.Strides(),
+    //    param.Paddings(),
+    //                           param.Filter(), param.Bias(),
+    //                           param.Output(), false);
     math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
                                  param.Output(), *param.Bias(), true);
diff --git a/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h b/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h
index 115574bcc0cca35af871c4f993768c5344a3b8ee..fb49a33c67face81a2615516bffd6aa151868fe3 100644
--- a/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h
+++ b/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h
@@ -26,8 +26,6 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
   Tensor bias = *param.Bias();
   Tensor new_bias = *param.NewBias();
   Tensor new_scale = *param.NewScale();
-  auto new_bias_ptr = new_bias.data<float>();
-  auto new_scale_ptr = new_scale.data<float>();
   int axis = param.Axis();
   Tensor *output = param.Output();
   math::expand_bias(bias, axis, output->dims());
@@ -106,25 +104,12 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
     // gemm
     Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
     Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-//    math::matmul<float>(filter_slice, false, col_matrix, false,
-//    static_cast<float>(1), &out_slice,
-//    static_cast<float>(0));
-    math::matmulWithBn<float>(filter_slice, false, col_matrix, false,
-                              static_cast<float>(1), &out_slice,
-                              static_cast<float>(0),true, &new_scale,&new_bias);
+
+    math::matmulWithBn<float>(
+        filter_slice, false, col_matrix, false, static_cast<float>(1),
+        &out_slice, static_cast<float>(0), true, &new_scale, &new_bias);
   }
 }
-  /// todo : use neon in special case instead of 2for(300ms)
-//  auto output_ptr = output->data<float>();
-//  for (int c = 0; c < output_matrix_shape[0]; c++) {
-//    int start = c * output_matrix_shape[1];
-//    for (int j = 0; j < output_matrix_shape[1]; j++) {
-//      output_ptr[start + j] =
-//          output_ptr[start + j] * new_scale_ptr[c] + new_bias_ptr[c];
-//      output_ptr[start + j] =
-//          output_ptr[start + j] < 0 ? 0 : output_ptr[start + j];
-//    }
-//  }
 }
 template <typename P>
 void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h
index cb0b96e2566fcec2d89791fb9c9acd012a39c4d2..0ca7815fc2bcff2be0345b581d3dfb26cf55794c 100644
--- a/src/operators/math/math_function.h
+++ b/src/operators/math/math_function.h
@@ -30,7 +30,7 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a,
 template <typename T>
 void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
                   const framework::Tensor &matrix_b, bool trans_b, T alpha,
-                  framework::Tensor *matrix_out, T beta, bool relu = true,
+                  framework::Tensor *matrix_out, T beta, bool relu,
                   framework::Tensor *new_scale, framework::Tensor *new_bias);
 }  // namespace math
 }  // namespace operators
diff --git a/test/common/test_gemm.cpp b/test/common/test_gemm.cpp
index 1841de0be41920506ed39a735af36359241de0c1..8cb778c458034aecf6cea89fcf0d3e2a3d8118ba 100644
--- a/test/common/test_gemm.cpp
+++ b/test/common/test_gemm.cpp
@@ -52,8 +52,9 @@ int main() {
   }
 
   auto time1 = time();
-//  paddle_mobile::operators::math::Sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3, c,
-//  ldc);
+  //  paddle_mobile::operators::math::Sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3,
+  //  c,
+  //  ldc);
   auto time2 = time();
   DLOG << "gemm cost :" << time_diff(time1, time2) << "ms\n";
   for (int i = 0; i < m * n; ++i) {
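
Note on the fused GEMM epilogue in conv_add_bn_relu_func.h: the removed "todo : use neon ... 2for(300ms)" block applied the batch-norm scale/bias and the ReLU as a separate per-element pass over the whole output, while the new code folds both into the math::matmulWithBn call. A minimal sketch of what that fusion computes, using a hypothetical flat-pointer MatmulWithBnRelu stand-in (not the real Tensor-based math::matmulWithBn signature):

#include <algorithm>
#include <vector>

// Row-major GEMM C = alpha * A(mxk) * B(kxn) + beta * C, with the
// batch-norm transform y = y * new_scale[i] + new_bias[i] and an optional
// ReLU applied per output row (one row per output channel, matching the
// filter_slice x col_matrix layout in ConvAddBNReluBasic).
void MatmulWithBnRelu(const float *a, const float *b, float *c, int m, int n,
                      int k, float alpha, float beta, bool relu,
                      const float *new_scale, const float *new_bias) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += a[i * k + p] * b[p * n + j];
      float y = alpha * acc + beta * c[i * n + j];
      y = y * new_scale[i] + new_bias[i];          // fused BN epilogue
      c[i * n + j] = relu ? std::max(y, 0.f) : y;  // fused ReLU
    }
  }
}

int main() {
  const int m = 2, n = 3, k = 4;
  std::vector<float> a(m * k, 1.f), b(k * n, 0.5f), c(m * n, 0.f);
  std::vector<float> scale = {2.f, -1.f}, bias = {0.f, 1.f};
  MatmulWithBnRelu(a.data(), b.data(), c.data(), m, n, k, 1.f, 0.f,
                   /*relu=*/true, scale.data(), bias.data());
  // Row 0 becomes 2*2+0 = 4.0 everywhere; row 1 becomes ReLU(2*-1+1) = 0.
  return 0;
}

Applying the transform inside the GEMM's output write touches each element once while it is still hot in cache, instead of making a second full pass over the output tensor; that also lines up with dropping the "bool relu = true" default in math_function.h, so every call site now has to state explicitly whether the ReLU epilogue runs.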