提交 caa4027d 编写于 作者: Y Yu Yang

Follow comments

上级 4db43c6c
...@@ -126,14 +126,9 @@ inline void Blas<platform::CUDADeviceContext>::GEMM( ...@@ -126,14 +126,9 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
CUDA_R_32F, algo)); CUDA_R_32F, algo));
#else #else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
const half h_alpha = static_cast<const half>(alpha); CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
const half h_beta = static_cast<const half>(beta); N, M, K, &h_alpha, h_B, ldb, h_A, lda,
const half *h_A = reinterpret_cast<const half *>(A); &h_beta, h_C, N);
const half *h_B = reinterpret_cast<const half *>(B);
half *h_C = reinterpret_cast<half *>(C);
CUBlas<platform::float16>(context_.cublas_handle(), cuTransB, cuTransA, N, M,
K, &h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C, N);
#endif // CUDA_VERSION >= 8000 #endif // CUDA_VERSION >= 8000
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册