diff --git a/src/operators/kernel/central-arm-func/conv_add_relu_int8_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_relu_int8_arm_func.h index fb431dc279d071a98a22d1e58a3f6b3fc26664c6..b4c5b86b499f98444a4c4a289c96e7f60e38f7fa 100644 --- a/src/operators/kernel/central-arm-func/conv_add_relu_int8_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_relu_int8_arm_func.h @@ -113,8 +113,8 @@ void ConvAddReluInt8Compute(const FusionConvAddReluInt8Param<CPU> &param) { Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul_int8(filter_slice, false, col_matrix, false, scale_v, - &out_slice, static_cast<float>(0), true, biase_data); + math::matmul(filter_slice, false, col_matrix, false, scale_v, &out_slice, + static_cast<float>(0), true, biase_data); } } } diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index ce111ed78f7b81affffc646b49a00e6d15cbb697..f746eae470ede7f6cc21b8abde462eafd46ab89e 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -108,13 +108,13 @@ inline void GemmConv(const ConvParam<CPU> &param) { Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); if (param.Input()->type() == typeid(int8_t)) { - math::matmul_int8(filter_slice, false, col_matrix, false, - static_cast<float>(1), &out_slice, - static_cast<float>(0)); + math::matmul(filter_slice, false, col_matrix, false, + static_cast<float>(1), &out_slice, static_cast<float>(0), + false, static_cast<int32_t *>(nullptr)); } else { - math::matmul(filter_slice, false, col_matrix, false, - static_cast<float>(1), &out_slice, - static_cast<float>(0)); + math::matmul(filter_slice, false, col_matrix, false, + static_cast<float>(1), &out_slice, static_cast<float>(0), + false, static_cast<float *>(nullptr)); } } } diff --git a/src/operators/kernel/central-arm-func/mul_arm_func.h b/src/operators/kernel/central-arm-func/mul_arm_func.h 
index 62e8ae03d9119cafc3c5716042569a90f077325c..60f6bca611a2c053b4e491964cd473de70031a92 100644 --- a/src/operators/kernel/central-arm-func/mul_arm_func.h +++ b/src/operators/kernel/central-arm-func/mul_arm_func.h @@ -73,13 +73,13 @@ void MulCompute(const MulParam<CPU> &param) { } if (param.InputX()->type() == typeid(int8_t)) { out->mutable_data<int32_t>(); - math::matmul_int8(x_matrix, false, y_matrix, false, static_cast<float>(1), - out, static_cast<float>(0)); + math::matmul(x_matrix, false, y_matrix, false, static_cast<float>(1), out, + static_cast<float>(0), false, static_cast<int32_t *>(nullptr)); } else { out->mutable_data<float>(); - math::matmul(x_matrix, false, y_matrix, false, static_cast<float>(1), - out, static_cast<float>(0)); + math::matmul(x_matrix, false, y_matrix, false, static_cast<float>(1), out, + static_cast<float>(0), false, static_cast<float *>(nullptr)); } if (out_dim.size() != 2) { out->Resize(out_dim); diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h index 9661b2d4c22ed49ef0c078fac0872c7643057430..c58e8035940c65646851961bc2b9d12307f37e7a 100644 --- a/src/operators/math/math_function.h +++ b/src/operators/math/math_function.h @@ -30,10 +30,11 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a, framework::Tensor *matrix_out, T beta, bool relu = false, float *bias = nullptr); -void matmul_int8(const framework::Tensor &matrix_a, bool trans_a, - const framework::Tensor &matrix_b, bool trans_b, float alpha, - framework::Tensor *matrix_out, float beta, bool relu = false, - int32_t *bias = nullptr); +template <typename T, typename S> +void matmul(const framework::Tensor &matrix_a, bool trans_a, + const framework::Tensor &matrix_b, bool trans_b, T alpha, + framework::Tensor *matrix_out, T beta, bool relu = false, + S *bias = nullptr); template <typename T> void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, diff --git a/src/operators/math/math_function_int8.cpp b/src/operators/math/math_function_int8.cpp index e1998e8e12062fe02fa9140b2f4a57bd8121724a..fe6b05ae1a19f36a0c3f14acc89cf2401af7f610 --- 
a/src/operators/math/math_function_int8.cpp +++ b/src/operators/math/math_function_int8.cpp @@ -20,10 +20,12 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { namespace math { -void matmul_int8(const framework::Tensor &matrix_a, bool trans_a, - const framework::Tensor &matrix_b, bool trans_b, float alpha, - framework::Tensor *matrix_out, float beta, bool relu, - int32_t *bias) { + +template <> +void matmul(const framework::Tensor &matrix_a, bool trans_a, + const framework::Tensor &matrix_b, bool trans_b, float alpha, + framework::Tensor *matrix_out, float beta, bool relu, + int32_t *bias) { auto dim_a = matrix_a.dims(); auto dim_b = matrix_b.dims(); auto dim_out = matrix_out->dims(); diff --git a/test/common/test_gemm_perf.cpp b/test/common/test_gemm_perf.cpp index 5c5f4026fd6022002ddb26f811e0dfe63e53a980..c8081e2d47061efdfea8db6c0393dc236d33b3e7 100644 --- a/test/common/test_gemm_perf.cpp +++ b/test/common/test_gemm_perf.cpp @@ -85,16 +85,16 @@ int main() { // int8_t without bias // warm-up 10 times for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul_int8( + paddle_mobile::operators::math::matmul( aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32, - static_cast<float>(0), false, nullptr); + static_cast<float>(0), false, static_cast<int32_t *>(nullptr)); } auto time3 = time(); for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul_int8( + paddle_mobile::operators::math::matmul( aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32, - static_cast<float>(0), false, nullptr); + static_cast<float>(0), false, static_cast<int32_t *>(nullptr)); } auto time4 = time(); std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n"; @@ -102,15 +102,15 @@ int main() { // int8_t with bias&relu // warm-up 10 times for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul_int8( + paddle_mobile::operators::math::matmul( aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8, - static_cast<float>(0), true, &bias_data[0]); 
+ static_cast<float>(0), true, bias_data); } auto time5 = time(); for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul_int8( + paddle_mobile::operators::math::matmul( aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8, - static_cast<float>(0), true, &bias_data[0]); + static_cast<float>(0), true, bias_data); } auto time6 = time(); std::cout << "int8_t gemm_with_bias_relu cost :"