From 2eeaaa7d70ffe143bb89fa1d13d8acb42d0cabf0 Mon Sep 17 00:00:00 2001
From: Yiqun Liu
Date: Mon, 27 Feb 2023 11:23:38 +0800
Subject: [PATCH] Add PADDLE_THROW in ToCudaDataType and polish codes. (#50922)

---
 paddle/phi/backends/gpu/cuda/cuda_helper.h     |  7 +++++++
 paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h | 18 +++---------------
 2 files changed, 10 insertions(+), 15 deletions(-)

diff --git a/paddle/phi/backends/gpu/cuda/cuda_helper.h b/paddle/phi/backends/gpu/cuda/cuda_helper.h
index 8d5ffec14e8..ab8facad4de 100644
--- a/paddle/phi/backends/gpu/cuda/cuda_helper.h
+++ b/paddle/phi/backends/gpu/cuda/cuda_helper.h
@@ -18,7 +18,9 @@
 #include <cuda.h>  // NOLINT

 #include "paddle/phi/common/bfloat16.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/float16.h"
+#include "paddle/phi/core/enforce.h"

 namespace phi {
 namespace backends {
@@ -87,6 +89,11 @@ cudaDataType_t ToCudaDataType() {
   } else if (std::is_same<T, phi::dtype::bfloat16>::value) {
     return CUDA_R_16BF;
 #endif
+  } else {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "DataType %s is unsupported for CUDA.",
+        paddle::experimental::DataTypeToString(
+            paddle::experimental::CppTypeToDataType<T>::Type())));
   }
 }

diff --git a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h
index d3a2ead2847..80b409a5b7e 100644
--- a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h
+++ b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h
@@ -18,6 +18,7 @@ limitations under the License.
 */

 #include <cuda_runtime_api.h>
 #include "cuda.h"  // NOLINT
+#include "paddle/phi/backends/gpu/cuda/cuda_helper.h"
 #include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/kernels/autotune/cache.h"
 #include "paddle/phi/kernels/autotune/gpu_timer.h"
@@ -27,19 +28,6 @@ namespace funcs {

 enum MatmulImplType { kImplWithCublas = 1, kImplWithCublasLt = 2 };

-template <typename T>
-cudaDataType_t ConvertToCudaDataType() {
-  if (std::is_same<T, float>::value) {
-    return CUDA_R_32F;
-  } else if (std::is_same<T, double>::value) {
-    return CUDA_R_64F;
-  } else if (std::is_same<T, phi::dtype::float16>::value) {
-    return CUDA_R_16F;
-  } else if (std::is_same<T, phi::dtype::bfloat16>::value) {
-    return CUDA_R_16BF;
-  }
-}
-
 template <typename T>
 cublasComputeType_t GetCudaComputeType() {
   if (std::is_same<T, double>::value) {
@@ -68,8 +56,8 @@ struct MatmulDescriptor {
                    int64_t stride_out = 0) {
     using MT = typename phi::dtype::MPTypeTrait<T>::Type;

-    cudaDataType_t mat_type = ConvertToCudaDataType<T>();
-    cudaDataType_t scale_type = ConvertToCudaDataType<MT>();
+    cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType<T>();
+    cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType<MT>();
     cublasComputeType_t compute_type = GetCudaComputeType<T>();

     // Create operation desciriptor; see cublasLtMatmulDescAttributes_t for
--
GitLab