diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index 86e49469912256d305e6a5d19f1736aef51a754b..89935829ab35a52dd85bcaf906b53e41d576cf3f 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -126,14 +126,9 @@ inline void Blas::GEMM( CUDA_R_32F, algo)); #else // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - const half h_alpha = static_cast(alpha); - const half h_beta = static_cast(beta); - const half *h_A = reinterpret_cast(A); - const half *h_B = reinterpret_cast(B); - half *h_C = reinterpret_cast(C); - - CUBlas(context_.cublas_handle(), cuTransB, cuTransA, N, M, - K, &h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C, N); + CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, + N, M, K, &h_alpha, h_B, ldb, h_A, lda, + &h_beta, h_C, N); #endif // CUDA_VERSION >= 8000 }