diff --git a/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h
index b74ea66fe28fbae0ffd6e6d3d4e503f5d739251b..d3b5bc69760797c4efcc3fb77831d54676d7d5b1 100644
--- a/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h
+++ b/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h
@@ -104,7 +104,7 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
       math::matmulWithBn<float>(
           filter_slice, false, col_matrix, false, static_cast<float>(1),
-          &out_slice, static_cast<float>(0), true, &new_scale, &new_bias);
+          &out_slice, static_cast<float>(0), true, &new_scale, &new_bias, g);
     }
   }
 }
diff --git a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h
index 737165884bfb89feeebfe7cf38c58edb44bc3e83..84b2142de6e652a0f85151ec178fce4fb0feb7d8 100644
--- a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h
+++ b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h
@@ -101,23 +101,22 @@ void DWConvBNReluBasic(const FusionDWConvBNReluParam &param) {
       // gemm
       Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-      std::cout << "***************" << std::endl;
       math::matmulWithBn<float>(
           filter_slice, false, col_matrix, false, static_cast<float>(1),
-          &out_slice, static_cast<float>(0), false, &new_scale, &new_bias);
+          &out_slice, static_cast<float>(0), true, &new_scale, &new_bias, g);
     }
   }
 }
 
 template <typename P>
 void DWConvBNReluCompute(const FusionDWConvBNReluParam &param) {
-  if (param.Groups() == param.Input()->dims()[1] &&
+  if (0 && param.Groups() == param.Input()->dims()[1] &&
       param.Input()->dims()[1] == param.Output()->dims()[1] &&
       param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
       param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
     math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(),
         param.Output(), param.NewScale(), param.NewBias(), true);
-  } else if (param.Groups() == param.Input()->dims()[1] &&
+  } else if (0 && param.Groups() == param.Input()->dims()[1] &&
              param.Input()->dims()[1] == param.Output()->dims()[1] &&
              param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
              param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp
index ca5367788ed87da070dd19900e8d546e51caf337..519bc1a55611d1cf26599ea2efb3707e081c7f9c 100644
--- a/src/operators/math/math_function.cpp
+++ b/src/operators/math/math_function.cpp
@@ -50,7 +50,7 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
                          const framework::Tensor &matrix_b, bool trans_b, float alpha,
                          framework::Tensor *matrix_out, float beta, bool relu,
                          framework::Tensor *new_scale,
-                         framework::Tensor *new_bias) {
+                         framework::Tensor *new_bias, int group) {
   auto dim_a = matrix_a.dims();
   auto dim_b = matrix_b.dims();
   auto dim_out = matrix_out->dims();
@@ -71,7 +71,7 @@
 
   SgemmWithBn(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(),
               N, beta, matrix_out->data<float>(), N, relu,
-              new_scale->data<float>(), new_bias->data<float>());
+              new_scale->data<float>() + group, new_bias->data<float>() + group);
 }
 
 }  // namespace math
diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h
index 0ca7815fc2bcff2be0345b581d3dfb26cf55794c..04e43a49547cf770555d126cbe97ab0e91e2b637 100644
--- a/src/operators/math/math_function.h
+++ b/src/operators/math/math_function.h
@@ -31,7 +31,7 @@ template <typename T>
 void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
                   const framework::Tensor &matrix_b, bool trans_b, T alpha,
                   framework::Tensor *matrix_out, T beta, bool relu,
-                  framework::Tensor *new_scale, framework::Tensor *new_bias);
+                  framework::Tensor *new_scale, framework::Tensor *new_bias, int group);
 
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle_mobile
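For context on what the new `group` argument does: `SgemmWithBn` folds the fused batch-norm transform into the GEMM epilogue, indexing `new_scale`/`new_bias` per output row. Before this patch every per-group GEMM slice read the parameters from element 0; advancing the pointers by `group` lets each slice pick up its own channel's scale and bias, which matters for the depthwise path where each group produces exactly one output channel. Below is a minimal sketch of that epilogue logic; `ApplyBnRelu` is a hypothetical stand-in for the real `SgemmWithBn` internals, assuming a row-major M x N output with one scale/bias value per output row.

#include <algorithm>

// Hypothetical epilogue: after one group's M x N output slice is computed,
// apply y = y * scale + bias per row, then an optional ReLU.
void ApplyBnRelu(float *out, int M, int N, const float *scale,
                 const float *bias, bool relu) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float v = out[m * N + n] * scale[m] + bias[m];
      out[m * N + n] = relu ? std::max(v, 0.0f) : v;
    }
  }
}

// Call pattern mirroring the patch: for group g of a depthwise conv
// (one output channel per group, so M == 1), pass the scale/bias
// pointers advanced by g so row 0 reads channel g's parameters:
//   ApplyBnRelu(out_slice, M, N, new_scale + g, new_bias + g, true);

Without the offset, the slice for every group would be normalized with channel 0's scale and bias, as is evident from the removed `new_scale->data<float>()` call.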