diff --git a/paddle/phi/backends/gpu/cuda/cuda_helper.h b/paddle/phi/backends/gpu/cuda/cuda_helper.h index 8d5ffec14e8a6626ae3adc57656137c415f5bcd0..ab8facad4de0519c89506d8e96bcf4ac13d09c59 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_helper.h +++ b/paddle/phi/backends/gpu/cuda/cuda_helper.h @@ -18,7 +18,9 @@ #include // NOLINT #include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/float16.h" +#include "paddle/phi/core/enforce.h" namespace phi { namespace backends { @@ -87,6 +89,11 @@ cudaDataType_t ToCudaDataType() { } else if (std::is_same::value) { return CUDA_R_16BF; #endif + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "DataType %s is unsupported for CUDA.", + paddle::experimental::DataTypeToString( + paddle::experimental::CppTypeToDataType::Type()))); } } diff --git a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h index d3a2ead2847339091b889ebc661c774c5e74be1e..80b409a5b7ee4ee2e2c2ad4988e3035903672cc4 100644 --- a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "cuda.h" // NOLINT +#include "paddle/phi/backends/gpu/cuda/cuda_helper.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/kernels/autotune/cache.h" #include "paddle/phi/kernels/autotune/gpu_timer.h" @@ -27,19 +28,6 @@ namespace funcs { enum MatmulImplType { kImplWithCublas = 1, kImplWithCublasLt = 2 }; -template -cudaDataType_t ConvertToCudaDataType() { - if (std::is_same::value) { - return CUDA_R_32F; - } else if (std::is_same::value) { - return CUDA_R_64F; - } else if (std::is_same::value) { - return CUDA_R_16F; - } else if (std::is_same::value) { - return CUDA_R_16BF; - } -} - template cublasComputeType_t GetCudaComputeType() { if (std::is_same::value) { @@ -68,8 +56,8 @@ struct MatmulDescriptor { int64_t stride_out = 0) { using MT = typename phi::dtype::MPTypeTrait::Type; - cudaDataType_t mat_type = ConvertToCudaDataType(); - cudaDataType_t scale_type = ConvertToCudaDataType(); + cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType(); + cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType(); cublasComputeType_t compute_type = GetCudaComputeType(); // Create operation desciriptor; see cublasLtMatmulDescAttributes_t for