Unverified commit 67b9b51b, authored by Guoxia Wang and committed by GitHub

Support the NCCL API for bfloat16; requires cuDNN >= 8.1 and NCCL >= 2.10.3 (#43147)

Parent: 048b0013
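At the NCCL level, the guarded registrations below reduce to passing ncclBfloat16 (added in NCCL 2.10) to the usual collective calls. A minimal illustrative sketch, not part of this commit — the helper name AllReduceBf16 is hypothetical, and the #if mirrors the guard used throughout the diff:

#include <cuda_runtime.h>
#include <nccl.h>

#if defined(NCCL_VERSION_CODE) && NCCL_VERSION_CODE >= 21000
// Sum-reduce a device-resident bfloat16 buffer across all ranks in `comm`.
// ncclBfloat16 only exists from NCCL 2.10 onward, hence the guard.
ncclResult_t AllReduceBf16(const void* sendbuff, void* recvbuff, size_t count,
                           ncclComm_t comm, cudaStream_t stream) {
  return ncclAllReduce(sendbuff, recvbuff, count, ncclBfloat16, ncclSum, comm,
                       stream);
}
#endif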
@@ -91,6 +91,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(alltoall, ops::AllToAllOpCUDAKernel<float>,
                         ops::AllToAllOpCUDAKernel<double>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+                        ops::AllToAllOpCUDAKernel<plat::bfloat16>,
+#endif
                         ops::AllToAllOpCUDAKernel<int>,
                         ops::AllToAllOpCUDAKernel<int64_t>,
                         ops::AllToAllOpCUDAKernel<plat::float16>);
@@ -90,6 +90,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(c_allgather, ops::CAllGatherOpCUDAKernel<float>,
                         ops::CAllGatherOpCUDAKernel<double>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+                        ops::CAllGatherOpCUDAKernel<plat::bfloat16>,
+#endif
                         ops::CAllGatherOpCUDAKernel<int>,
                         ops::CAllGatherOpCUDAKernel<int64_t>,
                         ops::CAllGatherOpCUDAKernel<plat::float16>);
@@ -19,6 +19,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
     c_allreduce_sum, ops::CAllReduceOpCUDAKernel<ops::kRedSum, float>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+    ops::CAllReduceOpCUDAKernel<ops::kRedSum, plat::bfloat16>,
+#endif
     ops::CAllReduceOpCUDAKernel<ops::kRedSum, double>,
     ops::CAllReduceOpCUDAKernel<ops::kRedSum, int>,
     ops::CAllReduceOpCUDAKernel<ops::kRedSum, int64_t>,
...
@@ -98,6 +98,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(c_broadcast, ops::CBroadcastOpCUDAKernel<float>,
                         ops::CBroadcastOpCUDAKernel<double>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+                        ops::CBroadcastOpCUDAKernel<plat::bfloat16>,
+#endif
                         ops::CBroadcastOpCUDAKernel<int>,
                         ops::CBroadcastOpCUDAKernel<int64_t>,
                         ops::CBroadcastOpCUDAKernel<plat::float16>);
@@ -76,6 +76,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(c_reducescatter, ops::CReduceScatterOpCUDAKernel<float>,
                         ops::CReduceScatterOpCUDAKernel<double>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+                        ops::CReduceScatterOpCUDAKernel<plat::bfloat16>,
+#endif
                         ops::CReduceScatterOpCUDAKernel<int>,
                         ops::CReduceScatterOpCUDAKernel<int64_t>,
                         ops::CReduceScatterOpCUDAKernel<plat::float16>);
@@ -224,6 +224,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(recv_v2, ops::RecvOpV2CUDAKernel<float>,
                         ops::RecvOpV2CUDAKernel<double>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+                        ops::RecvOpV2CUDAKernel<plat::bfloat16>,
+#endif
                         ops::RecvOpV2CUDAKernel<int>,
                         ops::RecvOpV2CUDAKernel<int64_t>,
                         ops::RecvOpV2CUDAKernel<int8_t>,
...
@@ -197,6 +197,9 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(send_v2, ops::SendOpV2CUDAKernel<float>,
                         ops::SendOpV2CUDAKernel<double>,
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+                        ops::SendOpV2CUDAKernel<plat::bfloat16>,
+#endif
                         ops::SendOpV2CUDAKernel<int>,
                         ops::SendOpV2CUDAKernel<int64_t>,
                         ops::SendOpV2CUDAKernel<int8_t>,
...
@@ -31,6 +31,8 @@
 #ifdef PADDLE_WITH_RCCL
 #include "paddle/fluid/platform/dynload/rccl.h"
 #endif
+#include "paddle/fluid/platform/bfloat16.h"
+#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/float16.h"
@@ -52,6 +54,10 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
     return ncclFloat16;
   } else if (type == framework::proto::VarType::INT8) {
     return ncclInt8;
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+  } else if (type == framework::proto::VarType::BF16) {
+    return ncclBfloat16;
+#endif
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "This datatype in nccl is not supported."));
@@ -69,6 +75,10 @@ inline ncclDataType_t ToNCCLDataType(experimental::DataType type) {
     return ncclInt64;
   } else if (type == experimental::DataType::FLOAT16) {
     return ncclFloat16;
+#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+  } else if (type == experimental::DataType::BFLOAT16) {
+    return ncclBfloat16;
+#endif
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "This datatype in nccl is not supported."));
...
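For context on the threshold used throughout this diff: recent NCCL releases encode their version as major * 10000 + minor * 100 + patch, so NCCL_VERSION_CODE >= 21000 means 2.10.0 or newer, and the 2.10.3 named in the commit message maps to 21003. A compile-time sanity check, illustrative only and assuming the NCCL_VERSION helper macro that nccl.h ships in recent releases:

#include <nccl.h>

#ifdef NCCL_VERSION
// 2 * 10000 + 10 * 100 + 3 == 21003, which clears the >= 21000 guard above.
static_assert(NCCL_VERSION(2, 10, 3) == 21003,
              "NCCL 2.10.3 corresponds to version code 21003");
#endif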