diff --git a/paddle/fluid/operators/collective/alltoall_op.cu.cc b/paddle/fluid/operators/collective/alltoall_op.cu.cc
index 0e0ea72208488c520b24680611a272bd378c9338..bb498047a50b04fc24e71eddf6535dadc6798ae0 100644
--- a/paddle/fluid/operators/collective/alltoall_op.cu.cc
+++ b/paddle/fluid/operators/collective/alltoall_op.cu.cc
@@ -91,6 +91,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(alltoall,
                         ops::AllToAllOpCUDAKernel<float>,
                         ops::AllToAllOpCUDAKernel<double>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+                        ops::AllToAllOpCUDAKernel<plat::bfloat16>,
+#endif
                         ops::AllToAllOpCUDAKernel<int>,
                         ops::AllToAllOpCUDAKernel<int64_t>,
                         ops::AllToAllOpCUDAKernel<plat::float16>);
diff --git a/paddle/fluid/operators/collective/c_allgather_op.cu.cc b/paddle/fluid/operators/collective/c_allgather_op.cu.cc
index 0d97ffa96dc5c079642937a8fc06109f7c031c77..62ed916d6e08c8efe522faed1e014dd832eba979 100644
--- a/paddle/fluid/operators/collective/c_allgather_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_allgather_op.cu.cc
@@ -90,6 +90,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(c_allgather,
                         ops::CAllGatherOpCUDAKernel<float>,
                         ops::CAllGatherOpCUDAKernel<double>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+                        ops::CAllGatherOpCUDAKernel<plat::bfloat16>,
+#endif
                         ops::CAllGatherOpCUDAKernel<int>,
                         ops::CAllGatherOpCUDAKernel<int64_t>,
                         ops::CAllGatherOpCUDAKernel<plat::float16>);
diff --git a/paddle/fluid/operators/collective/c_allreduce_sum_op.cu.cc b/paddle/fluid/operators/collective/c_allreduce_sum_op.cu.cc
index 8fe7fce21e465af8af4d045c29dbc12ab9bc3c84..565633c2e7b2d104cb36acce99e640855eba6137 100644
--- a/paddle/fluid/operators/collective/c_allreduce_sum_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_allreduce_sum_op.cu.cc
@@ -19,6 +19,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
     c_allreduce_sum,
     ops::CAllReduceOpCUDAKernel<ops::kRedSum, float>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+    ops::CAllReduceOpCUDAKernel<ops::kRedSum, plat::bfloat16>,
+#endif
     ops::CAllReduceOpCUDAKernel<ops::kRedSum, double>,
     ops::CAllReduceOpCUDAKernel<ops::kRedSum, int>,
     ops::CAllReduceOpCUDAKernel<ops::kRedSum, int64_t>,
diff --git a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
index eeae16a0d71f3881cf9e8447bf96c1508c219d7f..478dc85914964c205124d5b1f97894af20c87728 100644
--- a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
@@ -98,6 +98,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(c_broadcast,
                         ops::CBroadcastOpCUDAKernel<float>,
                         ops::CBroadcastOpCUDAKernel<double>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+                        ops::CBroadcastOpCUDAKernel<plat::bfloat16>,
+#endif
                         ops::CBroadcastOpCUDAKernel<int>,
                         ops::CBroadcastOpCUDAKernel<int64_t>,
                         ops::CBroadcastOpCUDAKernel<plat::float16>);
diff --git a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
index 9b05e940d4f6024a720d432f8754f641aedc81ef..fda192c45e77924687fa6b2f7608b38588a9d62d 100644
--- a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
@@ -76,6 +76,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(c_reducescatter,
                         ops::CReduceScatterOpCUDAKernel<float>,
                         ops::CReduceScatterOpCUDAKernel<double>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+                        ops::CReduceScatterOpCUDAKernel<plat::bfloat16>,
+#endif
                         ops::CReduceScatterOpCUDAKernel<int>,
                         ops::CReduceScatterOpCUDAKernel<int64_t>,
                         ops::CReduceScatterOpCUDAKernel<plat::float16>);
diff --git a/paddle/fluid/operators/collective/recv_v2_op.cu.cc b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
index f7a2e198db938496365923dd99035b26f8e472d4..67c30438869b1836abb635ce7990c89795e0ce97 100644
--- a/paddle/fluid/operators/collective/recv_v2_op.cu.cc
+++ b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
@@ -224,6 +224,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(recv_v2,
                         ops::RecvOpV2CUDAKernel<float>,
                         ops::RecvOpV2CUDAKernel<double>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+                        ops::RecvOpV2CUDAKernel<plat::bfloat16>,
+#endif
                         ops::RecvOpV2CUDAKernel<int>,
                         ops::RecvOpV2CUDAKernel<int64_t>,
                         ops::RecvOpV2CUDAKernel<int8_t>,
diff --git a/paddle/fluid/operators/collective/send_v2_op.cu.cc b/paddle/fluid/operators/collective/send_v2_op.cu.cc
index 8878b7c3449b9fdebdbc232eef4edddf29738fea..cfb3a11513a219bfce1719cd0518f46159e5a364 100644
--- a/paddle/fluid/operators/collective/send_v2_op.cu.cc
+++ b/paddle/fluid/operators/collective/send_v2_op.cu.cc
@@ -197,6 +197,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(send_v2,
                         ops::SendOpV2CUDAKernel<float>,
                         ops::SendOpV2CUDAKernel<double>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+                        ops::SendOpV2CUDAKernel<plat::bfloat16>,
+#endif
                         ops::SendOpV2CUDAKernel<int>,
                         ops::SendOpV2CUDAKernel<int64_t>,
                         ops::SendOpV2CUDAKernel<int8_t>,
diff --git a/paddle/fluid/platform/device/gpu/nccl_helper.h b/paddle/fluid/platform/device/gpu/nccl_helper.h
index 61ea0fd3cd2939c5256a59b738964152483b6499..d0cb9c953a5bfa31129be8746f24de37a7b3bc3c 100644
--- a/paddle/fluid/platform/device/gpu/nccl_helper.h
+++ b/paddle/fluid/platform/device/gpu/nccl_helper.h
@@ -31,6 +31,8 @@
 #ifdef PADDLE_WITH_RCCL
 #include "paddle/fluid/platform/dynload/rccl.h"
 #endif
+#include "paddle/fluid/platform/bfloat16.h"
+#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/float16.h"
 
@@ -52,6 +54,10 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
     return ncclFloat16;
   } else if (type == framework::proto::VarType::INT8) {
     return ncclInt8;
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+  } else if (type == framework::proto::VarType::BF16) {
+    return ncclBfloat16;
+#endif
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "This datatype in nccl is not supported."));
@@ -69,6 +75,10 @@ inline ncclDataType_t ToNCCLDataType(experimental::DataType type) {
     return ncclInt64;
   } else if (type == experimental::DataType::FLOAT16) {
     return ncclFloat16;
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+  } else if (type == experimental::DataType::BFLOAT16) {
+    return ncclBfloat16;
+#endif
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "This datatype in nccl is not supported."));