diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp
index 6cd0514832e40c567448df0e16dbf03ec992a634..4a35dd2a57f3bab262dd28a60bac63cb7b7e8f77 100644
--- a/src/operators/math/math_function.cpp
+++ b/src/operators/math/math_function.cpp
@@ -71,34 +71,11 @@ void MatMul(const framework::Tensor &matrix_a, bool trans_a,
         a[index++] = tmp[i * n + j];
       }
     }
-    if (M == 1) {
-#ifdef _OPENMP
-      gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<float>(), N, beta,
-                     matrix_out->data<float>(), N, relu, bias);
-#else
-      gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<float>(), N, beta,
-                 matrix_out->data<float>(), N, relu, bias);
-#endif
-    } else {
-      cblas_sgemm(false, false, M, N, K, alpha, a, K, matrix_b.data<float>(), N,
-                  beta, matrix_out->data<float>(), N);
-    }
+    cblas_sgemm(false, false, M, N, K, alpha, a, K, matrix_b.data<float>(), N,
+                beta, matrix_out->data<float>(), N);
   } else {
-    if (M == 1) {
-#ifdef _OPENMP
-      gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<float>(), K,
-                     matrix_b.data<float>(), N, beta, matrix_out->data<float>(),
-                     N, relu, bias);
-#else
-      gemm.Sgemm(M, N, K, alpha, matrix_a.data<float>(), K,
-                 matrix_b.data<float>(), N, beta, matrix_out->data<float>(), N,
-                 relu, bias);
-#endif
-    } else {
-      cblas_sgemm(false, false, M, N, K, alpha, matrix_a.data<float>(), K,
-                  matrix_b.data<float>(), N, beta, matrix_out->data<float>(),
-                  N);
-    }
+    cblas_sgemm(false, false, M, N, K, alpha, matrix_a.data<float>(), K,
+                matrix_b.data<float>(), N, beta, matrix_out->data<float>(), N);
   }
 }
 
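This change drops the `M == 1` special case in both branches of `MatMul` (the `trans_a` path, which first transposes A into the temporary buffer `a`, and the direct path): instead of dispatching to `gemm.Sgemm` / `gemm.Sgemm_omp` for single-row outputs, every shape now goes through the same `cblas_sgemm` call. For reference, below is a minimal self-contained sketch of what that unified call computes, assuming the in-tree `cblas_sgemm(false, false, M, N, K, alpha, a, K, b, N, beta, c, N)` follows standard row-major GEMM semantics (C = alpha * A * B + beta * C, no transposes, lda = K, ldb = ldc = N). The naive loop and the `sgemm_ref` name are illustrative only, not the optimized kernel:

```cpp
#include <cstdio>
#include <vector>

// Reference semantics of the unified call
//   cblas_sgemm(false, false, M, N, K, alpha, a, K, b, N, beta, c, N);
// i.e. row-major C = alpha * A * B + beta * C with no transposes.
// Naive triple loop for illustration; the real kernel is blocked/vectorized.
void sgemm_ref(int M, int N, int K, float alpha, const float *a, int lda,
               const float *b, int ldb, float beta, float *c, int ldc) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        acc += a[i * lda + k] * b[k * ldb + j];
      }
      c[i * ldc + j] = alpha * acc + beta * c[i * ldc + j];
    }
  }
}

int main() {
  const int M = 2, N = 2, K = 3;
  std::vector<float> a = {1, 2, 3, 4, 5, 6};     // 2x3 A, row-major
  std::vector<float> b = {7, 8, 9, 10, 11, 12};  // 3x2 B, row-major
  std::vector<float> c(M * N, 0.f);              // 2x2 C
  sgemm_ref(M, N, K, 1.f, a.data(), K, b.data(), N, 0.f, c.data(), N);
  printf("%g %g\n%g %g\n", c[0], c[1], c[2], c[3]);  // expect: 58 64 / 139 154
  return 0;
}
```

One consequence visible in the diff itself: the removed `Sgemm` / `Sgemm_omp` calls received `relu` and `bias`, while the surviving `cblas_sgemm` call does not, so after this change the `M == 1` case relies on whatever ReLU/bias handling the `cblas_sgemm` path (or its callers) provides, matching the pre-existing behavior for all other shapes.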