Unverified commit 954be40d authored by sneaxiy, committed by GitHub

fix gemm compute_type (#47613)

Parent b8ae3858
@@ -1452,7 +1452,11 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
           << FLAGS_gemm_use_half_precision_compute_type;
   auto fp = std::is_same<T, float>::value ? CUDA_R_32F : CUDA_R_16F;
-  cudaDataType_t compute_type = CUDA_R_32F;
+#if CUDA_VERSION >= 11000
+  auto compute_type = CUBLAS_COMPUTE_32F;
+#else
+  auto compute_type = CUDA_R_32F;
+#endif
   float h_alpha = static_cast<float>(alpha);
   float h_beta = static_cast<float>(beta);
@@ -1463,7 +1467,11 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
               std::is_same<T, phi::dtype::float16>::value) {
     a = static_cast<void *>(&alpha);
     b = static_cast<void *>(&beta);
+#if CUDA_VERSION >= 11000
+    compute_type = CUBLAS_COMPUTE_16F;
+#else
     compute_type = CUDA_R_16F;
+#endif
   }
   context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
......
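For context on why this change is needed: starting with CUDA 11, cublasGemmStridedBatchedEx (and the other *GemmEx entry points) take their compute type as a cublasComputeType_t (CUBLAS_COMPUTE_*) rather than a cudaDataType_t (CUDA_R_*), so passing CUDA_R_16F or CUDA_R_32F on a CUDA 11+ toolkit hands the library a value from the wrong enum. Below is a minimal self-contained sketch of the same version-gating pattern; it is not Paddle's code, and the helper name RunBatchedHalfGemm, the fixed column-major layouts, and the alpha/beta handling are illustrative assumptions.

// Minimal sketch, not Paddle's implementation: version-gated compute type
// for cublasGemmStridedBatchedEx. Helper name and layouts are hypothetical.
#include <cuda.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>

cublasStatus_t RunBatchedHalfGemm(cublasHandle_t handle,
                                  int m, int n, int k, int batch,
                                  const __half *A, const __half *B, __half *C,
                                  bool use_half_compute) {
  // Scaling factors must match the compute type: float for 32F, half for 16F.
  float alpha_f = 1.0f, beta_f = 0.0f;
  __half alpha_h = __float2half(1.0f), beta_h = __float2half(0.0f);
  const void *alpha = use_half_compute ? static_cast<const void *>(&alpha_h)
                                       : static_cast<const void *>(&alpha_f);
  const void *beta = use_half_compute ? static_cast<const void *>(&beta_h)
                                      : static_cast<const void *>(&beta_f);

  // CUDA 11+ expects cublasComputeType_t; older toolkits expect cudaDataType_t.
#if CUDA_VERSION >= 11000
  cublasComputeType_t compute_type =
      use_half_compute ? CUBLAS_COMPUTE_16F : CUBLAS_COMPUTE_32F;
#else
  cudaDataType_t compute_type = use_half_compute ? CUDA_R_16F : CUDA_R_32F;
#endif

  // Column-major, non-transposed: A is m x k, B is k x n, C is m x n.
  return cublasGemmStridedBatchedEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, alpha,
      A, CUDA_R_16F, m, static_cast<long long>(m) * k,
      B, CUDA_R_16F, k, static_cast<long long>(k) * n, beta,
      C, CUDA_R_16F, m, static_cast<long long>(m) * n,
      batch, compute_type, CUBLAS_GEMM_DEFAULT);
}

With this gating, the same call compiles against both pre-11 and 11+ CUDA toolchains, which mirrors what the diff above does for Paddle's BatchedGEMM.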