diff --git a/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h
index b74ea66fe28fbae0ffd6e6d3d4e503f5d739251b..d3b5bc69760797c4efcc3fb77831d54676d7d5b1 100644
--- a/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h
+++ b/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h
@@ -104,7 +104,7 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
       math::matmulWithBn<float>(
           filter_slice, false, col_matrix, false, static_cast<float>(1),
-          &out_slice, static_cast<float>(0), true, &new_scale, &new_bias);
+          &out_slice, static_cast<float>(0), true, &new_scale, &new_bias, g);
     }
   }
 }
diff --git a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h
index 737165884bfb89feeebfe7cf38c58edb44bc3e83..7693da2a84c15b8f7b6953eb51e2765b5ea159f8 100644
--- a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h
+++ b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h
@@ -101,10 +101,9 @@ void DWConvBNReluBasic(const FusionDWConvBNReluParam &param) {
       // gemm
       Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-      std::cout << "***************" << std::endl;
       math::matmulWithBn<float>(
           filter_slice, false, col_matrix, false, static_cast<float>(1),
-          &out_slice, static_cast<float>(0), false, &new_scale, &new_bias);
+          &out_slice, static_cast<float>(0), true, &new_scale, &new_bias, g);
     }
   }
 }
diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp
index ca5367788ed87da070dd19900e8d546e51caf337..d881014ccb3f29393ca73fa0e7f4792d4c0d65c7 100644
--- a/src/operators/math/math_function.cpp
+++ b/src/operators/math/math_function.cpp
@@ -50,7 +50,7 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
                          const framework::Tensor &matrix_b, bool trans_b,
                          float alpha, framework::Tensor *matrix_out, float beta,
                          bool relu, framework::Tensor *new_scale,
-                         framework::Tensor *new_bias) {
+                         framework::Tensor *new_bias, int group) {
   auto dim_a = matrix_a.dims();
   auto dim_b = matrix_b.dims();
   auto dim_out = matrix_out->dims();
@@ -71,7 +71,8 @@
   SgemmWithBn(M, N, K, alpha, matrix_a.data<float>(), K,
               matrix_b.data<float>(), N, beta, matrix_out->data<float>(), N, relu,
-              new_scale->data<float>(), new_bias->data<float>());
+              new_scale->data<float>() + group,
+              new_bias->data<float>() + group);
 }
 
 }  // namespace math
diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h
index 0ca7815fc2bcff2be0345b581d3dfb26cf55794c..b5179458a2bf9e6817366c7bd4ea1f536fd21642 100644
--- a/src/operators/math/math_function.h
+++ b/src/operators/math/math_function.h
@@ -31,7 +31,8 @@ template <typename T>
 void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
                   const framework::Tensor &matrix_b, bool trans_b, T alpha,
                   framework::Tensor *matrix_out, T beta, bool relu,
-                  framework::Tensor *new_scale, framework::Tensor *new_bias);
+                  framework::Tensor *new_scale, framework::Tensor *new_bias,
+                  int group);
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle_mobile
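
Note on the change above: matmulWithBn runs once per convolution group g, and in the depthwise case each group corresponds to a single output channel, so the fused batch-norm must read that channel's folded scale and bias rather than always reading element 0 of new_scale/new_bias. Threading group through the call and advancing the pointers (new_scale->data<float>() + group) accomplishes this; the dwconv path also corrects its relu flag from false to true and drops a leftover debug std::cout. Below is a minimal standalone sketch of the pointer-offset contract; apply_bn_relu is hypothetical (it stands in for the repo's SgemmWithBn epilogue) and only illustrates why the caller offsets by the group index.

    #include <algorithm>
    #include <cstddef>

    // Fused BN + ReLU over one group's output slice. Every element in the
    // slice shares a single folded scale/bias pair, so the caller selects
    // the right pair by offsetting the scale/bias pointers by the group
    // index before the call.
    void apply_bn_relu(float *out, std::size_t n, const float *scale,
                       const float *bias, bool relu) {
      for (std::size_t i = 0; i < n; ++i) {
        float v = out[i] * *scale + *bias;
        out[i] = relu ? std::max(v, 0.0f) : v;
      }
    }

    // Per-group call, mirroring the patched call sites (names hypothetical):
    //   apply_bn_relu(out_slice_ptr, n, new_scale_ptr + g, new_bias_ptr + g,
    //                 /*relu=*/true);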