From a58aea9e7c85dc6562ab56cc2f84f6ef78453ec1 Mon Sep 17 00:00:00 2001 From: eclipsess Date: Mon, 16 Jul 2018 15:46:31 +0800 Subject: [PATCH] fix basicConv+bnrelu in dwgroup --- .../kernel/central-arm-func/conv_add_bn_relu_arm_func.h | 2 +- .../kernel/central-arm-func/dwconv_bn_relu_arm_func.h | 7 +++---- src/operators/math/math_function.cpp | 4 ++-- src/operators/math/math_function.h | 2 +- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h index b74ea66fe2..d3b5bc6976 100644 --- a/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h @@ -104,7 +104,7 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam ¶m) { math::matmulWithBn( filter_slice, false, col_matrix, false, static_cast(1), - &out_slice, static_cast(0), true, &new_scale, &new_bias); + &out_slice, static_cast(0), true, &new_scale, &new_bias, g); } } } diff --git a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h index 737165884b..84b2142de6 100644 --- a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h @@ -101,23 +101,22 @@ void DWConvBNReluBasic(const FusionDWConvBNReluParam ¶m) { // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - std::cout << "***************" << std::endl; math::matmulWithBn( filter_slice, false, col_matrix, false, static_cast(1), - &out_slice, static_cast(0), false, &new_scale, &new_bias); + &out_slice, static_cast(0), true, &new_scale, &new_bias, g); } } } template void DWConvBNReluCompute(const FusionDWConvBNReluParam ¶m) { - if (param.Groups() == param.Input()->dims()[1] && + if (0&¶m.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), param.Output(), param.NewScale(), param.NewBias(), true); - } else if (param.Groups() == param.Input()->dims()[1] && + } else if (0&¶m.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp index ca5367788e..519bc1a556 100644 --- a/src/operators/math/math_function.cpp +++ b/src/operators/math/math_function.cpp @@ -50,7 +50,7 @@ void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, float alpha, framework::Tensor *matrix_out, float beta, bool relu, framework::Tensor *new_scale, - framework::Tensor *new_bias) { + framework::Tensor *new_bias, int group) { auto dim_a = matrix_a.dims(); auto dim_b = matrix_b.dims(); auto dim_out = matrix_out->dims(); @@ -71,7 +71,7 @@ void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, SgemmWithBn(M, N, K, alpha, matrix_a.data(), K, matrix_b.data(), N, beta, matrix_out->data(), N, relu, - new_scale->data(), new_bias->data()); + new_scale->data()+group, new_bias->data()+group); } } // namespace math diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h index 0ca7815fc2..04e43a4954 100644 --- a/src/operators/math/math_function.h +++ b/src/operators/math/math_function.h @@ -31,7 +31,7 @@ template void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, T alpha, framework::Tensor *matrix_out, T beta, bool relu, - framework::Tensor *new_scale, framework::Tensor *new_bias); + framework::Tensor *new_scale, framework::Tensor *new_bias, int group); } // namespace math } // namespace operators } // namespace paddle_mobile -- GitLab