未验证 提交 2eeaaa7d 编写于 作者: Y Yiqun Liu 提交者: GitHub

Add PADDLE_THROW in ToCudaDataType and polish codes. (#50922)

上级 3669868d
......@@ -18,7 +18,9 @@
#include <cuda_runtime.h> // NOLINT
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/enforce.h"
namespace phi {
namespace backends {
......@@ -87,6 +89,11 @@ cudaDataType_t ToCudaDataType() {
} else if (std::is_same<T, phi::dtype::bfloat16>::value) {
return CUDA_R_16BF;
#endif
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"DataType %s is unsupported for CUDA.",
paddle::experimental::DataTypeToString(
paddle::experimental::CppTypeToDataType<T>::Type())));
}
}
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <cuda_runtime_api.h>
#include "cuda.h" // NOLINT
#include "paddle/phi/backends/gpu/cuda/cuda_helper.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/autotune/cache.h"
#include "paddle/phi/kernels/autotune/gpu_timer.h"
......@@ -27,19 +28,6 @@ namespace funcs {
enum MatmulImplType { kImplWithCublas = 1, kImplWithCublasLt = 2 };
template <typename T>
cudaDataType_t ConvertToCudaDataType() {
if (std::is_same<T, float>::value) {
return CUDA_R_32F;
} else if (std::is_same<T, double>::value) {
return CUDA_R_64F;
} else if (std::is_same<T, phi::dtype::float16>::value) {
return CUDA_R_16F;
} else if (std::is_same<T, phi::dtype::bfloat16>::value) {
return CUDA_R_16BF;
}
}
template <typename T>
cublasComputeType_t GetCudaComputeType() {
if (std::is_same<T, double>::value) {
......@@ -68,8 +56,8 @@ struct MatmulDescriptor {
int64_t stride_out = 0) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
cudaDataType_t mat_type = ConvertToCudaDataType<T>();
cudaDataType_t scale_type = ConvertToCudaDataType<MT>();
cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType<T>();
cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType<MT>();
cublasComputeType_t compute_type = GetCudaComputeType<T>();
// Create operation desciriptor; see cublasLtMatmulDescAttributes_t for
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册