Unverified commit 2eeaaa7d, authored by Yiqun Liu, committed by GitHub

Add PADDLE_THROW in ToCudaDataType and polish codes. (#50922)

Parent: 3669868d
@@ -18,7 +18,9 @@
 #include <cuda_runtime.h>  // NOLINT

 #include "paddle/phi/common/bfloat16.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/float16.h"
+#include "paddle/phi/core/enforce.h"

 namespace phi {
 namespace backends {
@@ -87,6 +89,11 @@ cudaDataType_t ToCudaDataType() {
   } else if (std::is_same<T, phi::dtype::bfloat16>::value) {
     return CUDA_R_16BF;
 #endif
+  } else {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "DataType %s is unsupported for CUDA.",
+        paddle::experimental::DataTypeToString(
+            paddle::experimental::CppTypeToDataType<T>::Type())));
   }
 }
......
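Before this patch, ToCudaDataType<T>() simply fell off the end of the function when T matched none of the branches, which is undefined behavior for a non-void function and hands the caller an indeterminate cudaDataType_t. The added else branch turns that silent failure into a diagnosable PADDLE_THROW. The standalone sketch below reproduces the pattern outside Paddle; the name ToCudaDataTypeDemo and the reduced set of branches are illustrative only, and the only external dependency assumed is the CUDA toolkit header library_types.h, which defines cudaDataType_t and the CUDA_R_* enums.

    #include <library_types.h>  // CUDA toolkit header: cudaDataType_t, CUDA_R_*
    #include <stdexcept>
    #include <type_traits>

    // Hypothetical standalone reduction of the fixed pattern: unsupported
    // element types now raise instead of falling off the end of the function.
    template <typename T>
    cudaDataType_t ToCudaDataTypeDemo() {
      if (std::is_same<T, float>::value) {
        return CUDA_R_32F;
      } else if (std::is_same<T, double>::value) {
        return CUDA_R_64F;
      } else {
        // Mirrors the PADDLE_THROW added by this commit.
        throw std::invalid_argument("unsupported element type for CUDA");
      }
    }

With the guard in place, ToCudaDataTypeDemo<float>() still returns CUDA_R_32F, while e.g. ToCudaDataTypeDemo<int>() throws instead of returning a garbage enum value.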
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <cuda_runtime_api.h>
 #include "cuda.h"  // NOLINT

+#include "paddle/phi/backends/gpu/cuda/cuda_helper.h"
 #include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/kernels/autotune/cache.h"
 #include "paddle/phi/kernels/autotune/gpu_timer.h"
@@ -27,19 +28,6 @@ namespace funcs {
 enum MatmulImplType { kImplWithCublas = 1, kImplWithCublasLt = 2 };

-template <typename T>
-cudaDataType_t ConvertToCudaDataType() {
-  if (std::is_same<T, float>::value) {
-    return CUDA_R_32F;
-  } else if (std::is_same<T, double>::value) {
-    return CUDA_R_64F;
-  } else if (std::is_same<T, phi::dtype::float16>::value) {
-    return CUDA_R_16F;
-  } else if (std::is_same<T, phi::dtype::bfloat16>::value) {
-    return CUDA_R_16BF;
-  }
-}
-
 template <typename T>
 cublasComputeType_t GetCudaComputeType() {
   if (std::is_same<T, double>::value) {
@@ -68,8 +56,8 @@ struct MatmulDescriptor {
                     int64_t stride_out = 0) {
     using MT = typename phi::dtype::MPTypeTrait<T>::Type;
-    cudaDataType_t mat_type = ConvertToCudaDataType<T>();
-    cudaDataType_t scale_type = ConvertToCudaDataType<MT>();
+    cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType<T>();
+    cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType<MT>();
     cublasComputeType_t compute_type = GetCudaComputeType<T>();

     // Create operation descriptor; see cublasLtMatmulDescAttributes_t for
......
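The second file's change deletes the local ConvertToCudaDataType duplicate in favor of the shared helper, so unsupported dtypes in the matmul path now hit the new PADDLE_THROW as well. For context on why two distinct types are resolved there: in cuBLASLt, the element type of the A/B/C matrices and the type of the alpha/beta scalars are configured separately. The fragment below is a hedged sketch of how the three values resolved in MatmulDescriptor plausibly flow into descriptor creation; the cublasLtMatmulDescCreate and cublasLtMatrixLayoutCreate signatures are the public cuBLASLt API, while the function name, placeholder dimensions, and omitted error checking are this sketch's own simplifications.

    #include <cublasLt.h>

    // Hedged sketch, not the verbatim Paddle code path.
    template <typename T>
    void CreateDescriptorsSketch(uint64_t M, uint64_t K, int64_t lda) {
      using MT = typename phi::dtype::MPTypeTrait<T>::Type;
      cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType<T>();     // A/B/C element type
      cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType<MT>();  // alpha/beta type
      cublasComputeType_t compute_type = GetCudaComputeType<T>();            // accumulation precision

      // The scalar type parameterizes the matmul descriptor ...
      cublasLtMatmulDesc_t op_desc = nullptr;
      cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type);

      // ... while the matrix element type parameterizes each layout descriptor.
      cublasLtMatrixLayout_t x_layout = nullptr;
      cublasLtMatrixLayoutCreate(&x_layout, mat_type, M, K, lda);
    }

Separating scale_type from mat_type is what lets float16/bfloat16 matmuls keep float alpha/beta and accumulation (the MT mixed-precision type) without giving up half-precision storage.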