From 587fd892d83ab0b731ce8126db23730907c3f64c Mon Sep 17 00:00:00 2001 From: eclipsess Date: Mon, 16 Jul 2018 15:49:25 +0800 Subject: [PATCH] fix basicConv+bnrelu in dwgroup --- .../kernel/central-arm-func/dwconv_bn_relu_arm_func.h | 4 ++-- src/operators/math/math_function.cpp | 3 ++- src/operators/math/math_function.h | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h index 84b2142de6..7693da2a84 100644 --- a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h @@ -109,14 +109,14 @@ void DWConvBNReluBasic(const FusionDWConvBNReluParam ¶m) { } template void DWConvBNReluCompute(const FusionDWConvBNReluParam ¶m) { - if (0&¶m.Groups() == param.Input()->dims()[1] && + if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), param.Output(), param.NewScale(), param.NewBias(), true); - } else if (0&¶m.Groups() == param.Input()->dims()[1] && + } else if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp index 519bc1a556..d881014ccb 100644 --- a/src/operators/math/math_function.cpp +++ b/src/operators/math/math_function.cpp @@ -71,7 +71,8 @@ void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, SgemmWithBn(M, N, K, alpha, matrix_a.data(), K, matrix_b.data(), N, beta, matrix_out->data(), N, relu, - new_scale->data()+group, new_bias->data()+group); + new_scale->data() + group, + new_bias->data() + group); } } // namespace math diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h index 04e43a4954..b5179458a2 100644 --- a/src/operators/math/math_function.h +++ b/src/operators/math/math_function.h @@ -31,7 +31,8 @@ template void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, T alpha, framework::Tensor *matrix_out, T beta, bool relu, - framework::Tensor *new_scale, framework::Tensor *new_bias, int group); + framework::Tensor *new_scale, framework::Tensor *new_bias, + int group); } // namespace math } // namespace operators } // namespace paddle_mobile -- GitLab