未验证 提交 89da2f19 编写于 作者: LiYuRio 提交者: GitHub

Fix NCCL bfloat16 version guard: require CUDA_VERSION >= 11000 in addition to NCCL_VERSION_CODE >= 21000 (#53942)

上级 73d706ce
...@@ -98,7 +98,7 @@ PD_REGISTER_STRUCT_KERNEL(alltoall, ...@@ -98,7 +98,7 @@ PD_REGISTER_STRUCT_KERNEL(alltoall,
ops::AllToAllOpCUDAKernel, ops::AllToAllOpCUDAKernel,
float, float,
double, double,
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16, plat::bfloat16,
#endif #endif
int, int,
......
...@@ -95,7 +95,7 @@ PD_REGISTER_STRUCT_KERNEL(c_allgather, ...@@ -95,7 +95,7 @@ PD_REGISTER_STRUCT_KERNEL(c_allgather,
ops::CAllGatherOpCUDAKernel, ops::CAllGatherOpCUDAKernel,
float, float,
double, double,
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16, plat::bfloat16,
#endif #endif
int, int,
......
...@@ -28,7 +28,7 @@ PD_REGISTER_STRUCT_KERNEL(c_allreduce_max, ...@@ -28,7 +28,7 @@ PD_REGISTER_STRUCT_KERNEL(c_allreduce_max,
ALL_LAYOUT, ALL_LAYOUT,
ops::CAllReduceMaxCUDAKernel, ops::CAllReduceMaxCUDAKernel,
float, float,
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16, plat::bfloat16,
#endif #endif
double, double,
......
...@@ -28,7 +28,7 @@ PD_REGISTER_STRUCT_KERNEL(c_allreduce_sum, ...@@ -28,7 +28,7 @@ PD_REGISTER_STRUCT_KERNEL(c_allreduce_sum,
ALL_LAYOUT, ALL_LAYOUT,
ops::CAllReduceSumCUDAKernel, ops::CAllReduceSumCUDAKernel,
float, float,
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16, plat::bfloat16,
#endif #endif
double, double,
......
...@@ -100,7 +100,7 @@ PD_REGISTER_STRUCT_KERNEL(c_broadcast, ...@@ -100,7 +100,7 @@ PD_REGISTER_STRUCT_KERNEL(c_broadcast,
int64_t, int64_t,
float, float,
double, double,
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16, plat::bfloat16,
#endif #endif
plat::float16) { plat::float16) {
......
...@@ -137,7 +137,7 @@ PD_REGISTER_STRUCT_KERNEL(c_concat, ...@@ -137,7 +137,7 @@ PD_REGISTER_STRUCT_KERNEL(c_concat,
double, double,
int, int,
int64_t, int64_t,
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16, plat::bfloat16,
#endif #endif
plat::float16) { plat::float16) {
......
...@@ -239,7 +239,7 @@ PD_REGISTER_STRUCT_KERNEL(c_embedding, ...@@ -239,7 +239,7 @@ PD_REGISTER_STRUCT_KERNEL(c_embedding,
ops::CEmbeddingCUDAKernel, ops::CEmbeddingCUDAKernel,
float, float,
double, double,
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16, plat::bfloat16,
#endif #endif
plat::float16) { plat::float16) {
...@@ -251,7 +251,7 @@ PD_REGISTER_STRUCT_KERNEL(c_embedding_grad, ...@@ -251,7 +251,7 @@ PD_REGISTER_STRUCT_KERNEL(c_embedding_grad,
ops::CEmbeddingGradCUDAKernel, ops::CEmbeddingGradCUDAKernel,
float, float,
double, double,
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16, plat::bfloat16,
#endif #endif
plat::float16) { plat::float16) {
......
...@@ -25,7 +25,7 @@ PD_REGISTER_STRUCT_KERNEL(c_identity, ...@@ -25,7 +25,7 @@ PD_REGISTER_STRUCT_KERNEL(c_identity,
double, double,
int, int,
int64_t, int64_t,
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16, plat::bfloat16,
#endif #endif
plat::float16) { plat::float16) {
......
...@@ -87,7 +87,7 @@ PD_REGISTER_STRUCT_KERNEL(c_reducescatter, ...@@ -87,7 +87,7 @@ PD_REGISTER_STRUCT_KERNEL(c_reducescatter,
ops::CReduceScatterOpCUDAKernel, ops::CReduceScatterOpCUDAKernel,
float, float,
double, double,
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16, plat::bfloat16,
#endif #endif
int, int,
......
...@@ -123,7 +123,7 @@ PD_REGISTER_STRUCT_KERNEL(c_split, ...@@ -123,7 +123,7 @@ PD_REGISTER_STRUCT_KERNEL(c_split,
double, double,
int, int,
int64_t, int64_t,
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16, plat::bfloat16,
#endif #endif
plat::float16) { plat::float16) {
......
...@@ -31,7 +31,7 @@ PD_REGISTER_STRUCT_KERNEL(mp_allreduce_sum, ...@@ -31,7 +31,7 @@ PD_REGISTER_STRUCT_KERNEL(mp_allreduce_sum,
double, double,
int, int,
int64_t, int64_t,
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16, plat::bfloat16,
#endif #endif
plat::float16) { plat::float16) {
......
...@@ -108,7 +108,7 @@ PD_REGISTER_STRUCT_KERNEL(partial_allgather, ...@@ -108,7 +108,7 @@ PD_REGISTER_STRUCT_KERNEL(partial_allgather,
ops::PartialAllGatherOpCUDAKernel, ops::PartialAllGatherOpCUDAKernel,
float, float,
double, double,
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16, plat::bfloat16,
#endif #endif
int, int,
......
...@@ -124,7 +124,7 @@ PD_REGISTER_STRUCT_KERNEL(partial_recv, ...@@ -124,7 +124,7 @@ PD_REGISTER_STRUCT_KERNEL(partial_recv,
ops::PartialRecvOpCUDAKernel, ops::PartialRecvOpCUDAKernel,
float, float,
double, double,
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16, plat::bfloat16,
#endif #endif
int, int,
......
...@@ -123,7 +123,7 @@ PD_REGISTER_STRUCT_KERNEL(partial_send, ...@@ -123,7 +123,7 @@ PD_REGISTER_STRUCT_KERNEL(partial_send,
ops::PartialSendCUDAKernel, ops::PartialSendCUDAKernel,
float, float,
double, double,
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16, plat::bfloat16,
#endif #endif
int, int,
......
...@@ -238,7 +238,7 @@ PD_REGISTER_STRUCT_KERNEL(recv_v2, ...@@ -238,7 +238,7 @@ PD_REGISTER_STRUCT_KERNEL(recv_v2,
ops::RecvOpV2CUDAKernel, ops::RecvOpV2CUDAKernel,
float, float,
double, double,
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16, plat::bfloat16,
#endif #endif
int, int,
......
...@@ -223,7 +223,7 @@ PD_REGISTER_STRUCT_KERNEL(send_v2, ...@@ -223,7 +223,7 @@ PD_REGISTER_STRUCT_KERNEL(send_v2,
ops::SendOpV2CUDAKernel, ops::SendOpV2CUDAKernel,
float, float,
double, double,
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16, plat::bfloat16,
#endif #endif
int, int,
......
...@@ -59,7 +59,7 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) { ...@@ -59,7 +59,7 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
return ncclUint8; return ncclUint8;
} else if (type == framework::proto::VarType::BOOL) { } else if (type == framework::proto::VarType::BOOL) {
return ncclUint8; return ncclUint8;
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
} else if (type == framework::proto::VarType::BF16) { } else if (type == framework::proto::VarType::BF16) {
return ncclBfloat16; return ncclBfloat16;
#endif #endif
...@@ -86,7 +86,7 @@ inline ncclDataType_t ToNCCLDataType(phi::DataType type) { ...@@ -86,7 +86,7 @@ inline ncclDataType_t ToNCCLDataType(phi::DataType type) {
return ncclInt8; return ncclInt8;
} else if (type == phi::DataType::BOOL) { } else if (type == phi::DataType::BOOL) {
return ncclUint8; return ncclUint8;
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
} else if (type == phi::DataType::BFLOAT16) { } else if (type == phi::DataType::BFLOAT16) {
return ncclBfloat16; return ncclBfloat16;
#endif #endif
......
...@@ -229,7 +229,7 @@ inline ncclDataType_t ToNCCLDataType(DataType type) { ...@@ -229,7 +229,7 @@ inline ncclDataType_t ToNCCLDataType(DataType type) {
return ncclInt8; return ncclInt8;
} else if (type == DataType::BOOL) { } else if (type == DataType::BOOL) {
return ncclUint8; return ncclUint8;
#if NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
} else if (type == DataType::BFLOAT16) { } else if (type == DataType::BFLOAT16) {
return ncclBfloat16; return ncclBfloat16;
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册