From 67b9b51b0eb3f8e66624b11ee3600563031c0297 Mon Sep 17 00:00:00 2001
From: Guoxia Wang
Date: Wed, 1 Jun 2022 17:08:14 +0800
Subject: [PATCH] support nccl api for bfloat16, requires cudnn >= 8.1, nccl
 >= 2.10.3 (#43147)

---
 paddle/fluid/operators/collective/alltoall_op.cu.cc    |  3 +++
 paddle/fluid/operators/collective/c_allgather_op.cu.cc |  3 +++
 .../operators/collective/c_allreduce_sum_op.cu.cc      |  3 +++
 paddle/fluid/operators/collective/c_broadcast_op.cu.cc |  3 +++
 .../operators/collective/c_reducescatter_op.cu.cc      |  3 +++
 paddle/fluid/operators/collective/recv_v2_op.cu.cc     |  3 +++
 paddle/fluid/operators/collective/send_v2_op.cu.cc     |  3 +++
 paddle/fluid/platform/device/gpu/nccl_helper.h         | 10 ++++++++++
 8 files changed, 31 insertions(+)

diff --git a/paddle/fluid/operators/collective/alltoall_op.cu.cc b/paddle/fluid/operators/collective/alltoall_op.cu.cc
index 0e0ea722084..bb498047a50 100644
--- a/paddle/fluid/operators/collective/alltoall_op.cu.cc
+++ b/paddle/fluid/operators/collective/alltoall_op.cu.cc
@@ -91,6 +91,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(alltoall, ops::AllToAllOpCUDAKernel<float>,
                         ops::AllToAllOpCUDAKernel<double>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+                        ops::AllToAllOpCUDAKernel<plat::bfloat16>,
+#endif
                         ops::AllToAllOpCUDAKernel<int>,
                         ops::AllToAllOpCUDAKernel<int64_t>,
                         ops::AllToAllOpCUDAKernel<plat::float16>);
diff --git a/paddle/fluid/operators/collective/c_allgather_op.cu.cc b/paddle/fluid/operators/collective/c_allgather_op.cu.cc
index 0d97ffa96dc..62ed916d6e0 100644
--- a/paddle/fluid/operators/collective/c_allgather_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_allgather_op.cu.cc
@@ -90,6 +90,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(c_allgather, ops::CAllGatherOpCUDAKernel<float>,
                         ops::CAllGatherOpCUDAKernel<double>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+                        ops::CAllGatherOpCUDAKernel<plat::bfloat16>,
+#endif
                         ops::CAllGatherOpCUDAKernel<int>,
                         ops::CAllGatherOpCUDAKernel<int64_t>,
                         ops::CAllGatherOpCUDAKernel<plat::float16>);
diff --git a/paddle/fluid/operators/collective/c_allreduce_sum_op.cu.cc b/paddle/fluid/operators/collective/c_allreduce_sum_op.cu.cc
index 8fe7fce21e4..565633c2e7b 100644
--- a/paddle/fluid/operators/collective/c_allreduce_sum_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_allreduce_sum_op.cu.cc
@@ -19,6 +19,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
     c_allreduce_sum, ops::CAllReduceOpCUDAKernel<ops::kRedSum, float>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+    ops::CAllReduceOpCUDAKernel<ops::kRedSum, plat::bfloat16>,
+#endif
     ops::CAllReduceOpCUDAKernel<ops::kRedSum, double>,
     ops::CAllReduceOpCUDAKernel<ops::kRedSum, int>,
     ops::CAllReduceOpCUDAKernel<ops::kRedSum, int64_t>,
diff --git a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
index eeae16a0d71..478dc859149 100644
--- a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
@@ -98,6 +98,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(c_broadcast, ops::CBroadcastOpCUDAKernel<float>,
                         ops::CBroadcastOpCUDAKernel<double>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+                        ops::CBroadcastOpCUDAKernel<plat::bfloat16>,
+#endif
                         ops::CBroadcastOpCUDAKernel<int>,
                         ops::CBroadcastOpCUDAKernel<int64_t>,
                         ops::CBroadcastOpCUDAKernel<plat::float16>);
diff --git a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
index 9b05e940d4f..fda192c45e7 100644
--- a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
@@ -76,6 +76,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(c_reducescatter, ops::CReduceScatterOpCUDAKernel<float>,
                         ops::CReduceScatterOpCUDAKernel<double>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+                        ops::CReduceScatterOpCUDAKernel<plat::bfloat16>,
+#endif
                         ops::CReduceScatterOpCUDAKernel<int>,
                         ops::CReduceScatterOpCUDAKernel<int64_t>,
                         ops::CReduceScatterOpCUDAKernel<plat::float16>);
diff --git a/paddle/fluid/operators/collective/recv_v2_op.cu.cc b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
index f7a2e198db9..67c30438869 100644
--- a/paddle/fluid/operators/collective/recv_v2_op.cu.cc
+++ b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
@@ -224,6 +224,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(recv_v2, ops::RecvOpV2CUDAKernel<float>,
                         ops::RecvOpV2CUDAKernel<double>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+                        ops::RecvOpV2CUDAKernel<plat::bfloat16>,
+#endif
                         ops::RecvOpV2CUDAKernel<int>,
                         ops::RecvOpV2CUDAKernel<int64_t>,
                         ops::RecvOpV2CUDAKernel<int8_t>,
diff --git a/paddle/fluid/operators/collective/send_v2_op.cu.cc b/paddle/fluid/operators/collective/send_v2_op.cu.cc
index 8878b7c3449..cfb3a11513a 100644
--- a/paddle/fluid/operators/collective/send_v2_op.cu.cc
+++ b/paddle/fluid/operators/collective/send_v2_op.cu.cc
@@ -197,6 +197,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(send_v2, ops::SendOpV2CUDAKernel<float>,
                         ops::SendOpV2CUDAKernel<double>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+                        ops::SendOpV2CUDAKernel<plat::bfloat16>,
+#endif
                         ops::SendOpV2CUDAKernel<int>,
                         ops::SendOpV2CUDAKernel<int64_t>,
                         ops::SendOpV2CUDAKernel<int8_t>,
diff --git a/paddle/fluid/platform/device/gpu/nccl_helper.h b/paddle/fluid/platform/device/gpu/nccl_helper.h
index 61ea0fd3cd2..d0cb9c953a5 100644
--- a/paddle/fluid/platform/device/gpu/nccl_helper.h
+++ b/paddle/fluid/platform/device/gpu/nccl_helper.h
@@ -31,6 +31,8 @@
 #ifdef PADDLE_WITH_RCCL
 #include "paddle/fluid/platform/dynload/rccl.h"
 #endif
+#include "paddle/fluid/platform/bfloat16.h"
+#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/float16.h"
 
@@ -52,6 +54,10 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
     return ncclFloat16;
   } else if (type == framework::proto::VarType::INT8) {
     return ncclInt8;
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+  } else if (type == framework::proto::VarType::BF16) {
+    return ncclBfloat16;
+#endif
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "This datatype in nccl is not supported."));
@@ -69,6 +75,10 @@ inline ncclDataType_t ToNCCLDataType(experimental::DataType type) {
     return ncclInt64;
   } else if (type == experimental::DataType::FLOAT16) {
     return ncclFloat16;
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+  } else if (type == experimental::DataType::BFLOAT16) {
+    return ncclBfloat16;
+#endif
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "This datatype in nccl is not supported."));
--
GitLab
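
Reviewer note, not part of the patch: below is a minimal, self-contained sketch
of the NCCL half of the version gate this change relies on. `DType` and
`ToNcclDType` are hypothetical stand-ins for illustration; the real mapping is
`ToNCCLDataType` in `nccl_helper.h` above. `ncclBfloat16` is only declared by
NCCL >= 2.10 (`NCCL_VERSION_CODE >= 21000`), and the patch additionally gates
on cuDNN >= 8.1 via `CUDNN_VERSION_MIN(8, 1, 0)` because Paddle's GPU bfloat16
type depends on it; the sketch keeps only the NCCL check.

// Standalone sketch (illustrative names, not Paddle APIs) of the compile-time
// and runtime behavior of the bfloat16 gate. Requires the NCCL headers.
#include <nccl.h>

#include <stdexcept>

enum class DType { kFloat32, kFloat16, kBFloat16 };

inline ncclDataType_t ToNcclDType(DType t) {
  switch (t) {
    case DType::kFloat32:
      return ncclFloat;
    case DType::kFloat16:
      return ncclFloat16;
#if NCCL_VERSION_CODE >= 21000  // ncclBfloat16 first appears in NCCL 2.10
    case DType::kBFloat16:
      return ncclBfloat16;
#endif
    default:
      // Mirrors the patch's PADDLE_THROW branch: an NCCL build older than
      // 2.10 has no bfloat16 wire type, so the request must fail loudly.
      throw std::invalid_argument("dtype not supported by this NCCL build");
  }
}

Guarding the kernel registrations with the same #if means an unsupported build
still compiles; a bfloat16 collective on such a build finds no registered
kernel (or hits the Unimplemented branch in ToNCCLDataType) and fails at
dispatch with a clear error instead of breaking the build.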