未验证 提交 67b9b51b 编写于 作者: G Guoxia Wang 提交者: GitHub

support nccl api for bfloat16, requires cudnn >= 8.1, nccl >= 2.10.3 (#43147)

上级 048b0013
......@@ -91,6 +91,9 @@ namespace plat = paddle::platform;
// Register CUDA kernels for the alltoall collective op.
// The bfloat16 instantiation is compiled in only when cuDNN >= 8.1 and
// NCCL >= 2.10 (NCCL_VERSION_CODE 21000), the versions that provide
// bf16 support (ncclBfloat16) — see the ToNCCLDataType mapping below.
REGISTER_OP_CUDA_KERNEL(alltoall, ops::AllToAllOpCUDAKernel<float>,
ops::AllToAllOpCUDAKernel<double>,
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
ops::AllToAllOpCUDAKernel<plat::bfloat16>,
#endif
ops::AllToAllOpCUDAKernel<int>,
ops::AllToAllOpCUDAKernel<int64_t>,
ops::AllToAllOpCUDAKernel<plat::float16>);
......@@ -90,6 +90,9 @@ namespace plat = paddle::platform;
// Register CUDA kernels for the c_allgather collective op.
// bfloat16 is guarded: it requires cuDNN >= 8.1 and NCCL >= 2.10
// (NCCL_VERSION_CODE 21000), where ncclBfloat16 becomes available.
REGISTER_OP_CUDA_KERNEL(c_allgather, ops::CAllGatherOpCUDAKernel<float>,
ops::CAllGatherOpCUDAKernel<double>,
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
ops::CAllGatherOpCUDAKernel<plat::bfloat16>,
#endif
ops::CAllGatherOpCUDAKernel<int>,
ops::CAllGatherOpCUDAKernel<int64_t>,
ops::CAllGatherOpCUDAKernel<plat::float16>);
......@@ -19,6 +19,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
c_allreduce_sum, ops::CAllReduceOpCUDAKernel<ops::kRedSum, float>,
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
ops::CAllReduceOpCUDAKernel<ops::kRedSum, plat::bfloat16>,
#endif
ops::CAllReduceOpCUDAKernel<ops::kRedSum, double>,
ops::CAllReduceOpCUDAKernel<ops::kRedSum, int>,
ops::CAllReduceOpCUDAKernel<ops::kRedSum, int64_t>,
......
......@@ -98,6 +98,9 @@ namespace plat = paddle::platform;
// Register CUDA kernels for the c_broadcast collective op.
// The plat::bfloat16 kernel is registered only when the toolchain
// supports it: cuDNN >= 8.1 and NCCL >= 2.10 (NCCL_VERSION_CODE 21000).
REGISTER_OP_CUDA_KERNEL(c_broadcast, ops::CBroadcastOpCUDAKernel<float>,
ops::CBroadcastOpCUDAKernel<double>,
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
ops::CBroadcastOpCUDAKernel<plat::bfloat16>,
#endif
ops::CBroadcastOpCUDAKernel<int>,
ops::CBroadcastOpCUDAKernel<int64_t>,
ops::CBroadcastOpCUDAKernel<plat::float16>);
......@@ -76,6 +76,9 @@ namespace plat = paddle::platform;
// Register CUDA kernels for the c_reducescatter collective op.
// bfloat16 is conditionally compiled: it needs cuDNN >= 8.1 and
// NCCL >= 2.10 (NCCL_VERSION_CODE 21000) for ncclBfloat16 support.
REGISTER_OP_CUDA_KERNEL(c_reducescatter, ops::CReduceScatterOpCUDAKernel<float>,
ops::CReduceScatterOpCUDAKernel<double>,
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
ops::CReduceScatterOpCUDAKernel<plat::bfloat16>,
#endif
ops::CReduceScatterOpCUDAKernel<int>,
ops::CReduceScatterOpCUDAKernel<int64_t>,
ops::CReduceScatterOpCUDAKernel<plat::float16>);
......@@ -224,6 +224,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(recv_v2, ops::RecvOpV2CUDAKernel<float>,
ops::RecvOpV2CUDAKernel<double>,
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
ops::RecvOpV2CUDAKernel<plat::bfloat16>,
#endif
ops::RecvOpV2CUDAKernel<int>,
ops::RecvOpV2CUDAKernel<int64_t>,
ops::RecvOpV2CUDAKernel<int8_t>,
......
......@@ -197,6 +197,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(send_v2, ops::SendOpV2CUDAKernel<float>,
ops::SendOpV2CUDAKernel<double>,
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
ops::SendOpV2CUDAKernel<plat::bfloat16>,
#endif
ops::SendOpV2CUDAKernel<int>,
ops::SendOpV2CUDAKernel<int64_t>,
ops::SendOpV2CUDAKernel<int8_t>,
......
......@@ -31,6 +31,8 @@
#ifdef PADDLE_WITH_RCCL
#include "paddle/fluid/platform/dynload/rccl.h"
#endif
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
......@@ -52,6 +54,10 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
return ncclFloat16;
} else if (type == framework::proto::VarType::INT8) {
return ncclInt8;
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
} else if (type == framework::proto::VarType::BF16) {
return ncclBfloat16;
#endif
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"This datatype in nccl is not supported."));
......@@ -69,6 +75,10 @@ inline ncclDataType_t ToNCCLDataType(experimental::DataType type) {
return ncclInt64;
} else if (type == experimental::DataType::FLOAT16) {
return ncclFloat16;
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
} else if (type == experimental::DataType::BFLOAT16) {
return ncclBfloat16;
#endif
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"This datatype in nccl is not supported."));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册