Unverified commit 89da2f19, authored by LiYuRio, committed by GitHub

fix nccl version (#53942)

Parent 73d706ce
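Every hunk below makes the same one-line change: the bfloat16 kernel registrations and NCCL dtype mappings were previously guarded only by `NCCL_VERSION_CODE >= 21000` (NCCL 2.10.0, the first release that defines `ncclBfloat16`), and the guard now also requires `CUDA_VERSION >= 11000`, since CUDA's native bfloat16 support only ships with CUDA 11.0 and newer. Below is a minimal sketch of the resulting guard, with `SimpleDType` and `ToNcclDType` as hypothetical stand-ins for the dtype enums and converter functions touched in the hunks that follow:

```cpp
// Minimal sketch of the guard pattern applied in this commit (assumes nccl.h
// and cuda.h are on the include path; SimpleDType is a hypothetical stand-in
// for Paddle's dtype enums).
#include <cuda.h>   // defines CUDA_VERSION (11000 == CUDA 11.0)
#include <nccl.h>   // defines NCCL_VERSION_CODE (21000 == NCCL 2.10.0)

#include <stdexcept>

enum class SimpleDType { kFloat32, kFloat16, kBFloat16 };

inline ncclDataType_t ToNcclDType(SimpleDType type) {
  if (type == SimpleDType::kFloat32) {
    return ncclFloat;
  } else if (type == SimpleDType::kFloat16) {
    return ncclFloat16;
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
  // ncclBfloat16 only exists in NCCL >= 2.10.0, and the bf16 host/device types
  // it is paired with require CUDA >= 11.0, so both versions are checked.
  } else if (type == SimpleDType::kBFloat16) {
    return ncclBfloat16;
#endif
  }
  throw std::invalid_argument("dtype has no NCCL equivalent in this build");
}
```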
......@@ -98,7 +98,7 @@ PD_REGISTER_STRUCT_KERNEL(alltoall,
ops::AllToAllOpCUDAKernel,
float,
double,
-#if NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16,
#endif
int,
......
......@@ -95,7 +95,7 @@ PD_REGISTER_STRUCT_KERNEL(c_allgather,
ops::CAllGatherOpCUDAKernel,
float,
double,
-#if NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16,
#endif
int,
......
......@@ -28,7 +28,7 @@ PD_REGISTER_STRUCT_KERNEL(c_allreduce_max,
ALL_LAYOUT,
ops::CAllReduceMaxCUDAKernel,
float,
-#if NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16,
#endif
double,
......
......@@ -28,7 +28,7 @@ PD_REGISTER_STRUCT_KERNEL(c_allreduce_sum,
ALL_LAYOUT,
ops::CAllReduceSumCUDAKernel,
float,
-#if NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16,
#endif
double,
......
......@@ -100,7 +100,7 @@ PD_REGISTER_STRUCT_KERNEL(c_broadcast,
int64_t,
float,
double,
-#if NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16,
#endif
plat::float16) {
......
......@@ -137,7 +137,7 @@ PD_REGISTER_STRUCT_KERNEL(c_concat,
double,
int,
int64_t,
-#if NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16,
#endif
plat::float16) {
......
......@@ -239,7 +239,7 @@ PD_REGISTER_STRUCT_KERNEL(c_embedding,
ops::CEmbeddingCUDAKernel,
float,
double,
-#if NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16,
#endif
plat::float16) {
......@@ -251,7 +251,7 @@ PD_REGISTER_STRUCT_KERNEL(c_embedding_grad,
ops::CEmbeddingGradCUDAKernel,
float,
double,
-#if NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16,
#endif
plat::float16) {
......
......@@ -25,7 +25,7 @@ PD_REGISTER_STRUCT_KERNEL(c_identity,
double,
int,
int64_t,
-#if NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16,
#endif
plat::float16) {
......
......@@ -87,7 +87,7 @@ PD_REGISTER_STRUCT_KERNEL(c_reducescatter,
ops::CReduceScatterOpCUDAKernel,
float,
double,
-#if NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16,
#endif
int,
......
......@@ -123,7 +123,7 @@ PD_REGISTER_STRUCT_KERNEL(c_split,
double,
int,
int64_t,
-#if NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16,
#endif
plat::float16) {
......
......@@ -31,7 +31,7 @@ PD_REGISTER_STRUCT_KERNEL(mp_allreduce_sum,
double,
int,
int64_t,
-#if NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16,
#endif
plat::float16) {
......
......@@ -108,7 +108,7 @@ PD_REGISTER_STRUCT_KERNEL(partial_allgather,
ops::PartialAllGatherOpCUDAKernel,
float,
double,
-#if NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16,
#endif
int,
......
......@@ -124,7 +124,7 @@ PD_REGISTER_STRUCT_KERNEL(partial_recv,
ops::PartialRecvOpCUDAKernel,
float,
double,
-#if NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16,
#endif
int,
......
......@@ -123,7 +123,7 @@ PD_REGISTER_STRUCT_KERNEL(partial_send,
ops::PartialSendCUDAKernel,
float,
double,
-#if NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16,
#endif
int,
......
......@@ -238,7 +238,7 @@ PD_REGISTER_STRUCT_KERNEL(recv_v2,
ops::RecvOpV2CUDAKernel,
float,
double,
-#if NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16,
#endif
int,
......
......@@ -223,7 +223,7 @@ PD_REGISTER_STRUCT_KERNEL(send_v2,
ops::SendOpV2CUDAKernel,
float,
double,
-#if NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16,
#endif
int,
......
......@@ -59,7 +59,7 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
return ncclUint8;
} else if (type == framework::proto::VarType::BOOL) {
return ncclUint8;
-#if NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
} else if (type == framework::proto::VarType::BF16) {
return ncclBfloat16;
#endif
......@@ -86,7 +86,7 @@ inline ncclDataType_t ToNCCLDataType(phi::DataType type) {
return ncclInt8;
} else if (type == phi::DataType::BOOL) {
return ncclUint8;
-#if NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
} else if (type == phi::DataType::BFLOAT16) {
return ncclBfloat16;
#endif
......
......@@ -229,7 +229,7 @@ inline ncclDataType_t ToNCCLDataType(DataType type) {
return ncclInt8;
} else if (type == DataType::BOOL) {
return ncclUint8;
-#if NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
} else if (type == DataType::BFLOAT16) {
return ncclBfloat16;
#endif
......
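The guards above are purely compile-time; they say nothing about the NCCL library actually loaded at run time. As a complementary sketch (not part of this commit; `RuntimeSupportsBf16Collectives` is an illustrative helper), the linked library's version can be queried with `ncclGetVersion`, which reports the same encoding as `NCCL_VERSION_CODE`:

```cpp
// Runtime counterpart of the compile-time guard (a sketch, not from the diff):
// ncclGetVersion reports the linked library's version in the same encoding as
// NCCL_VERSION_CODE, so 21000 again corresponds to NCCL 2.10.0.
#include <nccl.h>

#include <cstdio>

bool RuntimeSupportsBf16Collectives() {
  int version = 0;
  if (ncclGetVersion(&version) != ncclSuccess) {
    return false;  // could not query the library; assume no bf16 support
  }
  std::printf("linked NCCL version code: %d\n", version);
  return version >= 21000;
}
```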