未验证 提交 30d1ff3b 编写于 作者: Z Zhang Ting 提交者: GitHub

call cublasGemmStridedBatchedEx when using fp16, test=develop (#25553)

上级 9df18b08
......@@ -428,7 +428,8 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
const int64_t strideC = M * N;
#if CUDA_VERSION >= 9010
if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
if ((FLAGS_enable_cublas_tensor_op_math && (std::is_same<T, float>::value)) ||
std::is_same<T, paddle::platform::float16>::value) {
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
......@@ -437,11 +438,11 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
VLOG(5) << "use_tensor_op_math: "
<< (use_tensor_op_math ? "True" : "False");
auto fp = std::is_same<T, float>::value ? CUDA_R_32F : CUDA_R_16F;
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx(
handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb,
strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc,
strideC, batchCount, CUDA_R_32F, algo));
handle, cuTransB, cuTransA, N, M, K, &alpha, B, fp, ldb, strideB, A,
fp, lda, strideA, &beta, C, fp, ldc, strideC, batchCount, fp, algo));
});
} else {
#endif // CUDA_VERSION >= 9010
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册