未验证 提交 73473ac2 编写于 作者: shaojie_wang 提交者: GitHub

register bf16 for c ops (#52641)

上级 9431bae1
...@@ -20,6 +20,9 @@ namespace plat = paddle::platform; ...@@ -20,6 +20,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
c_allreduce_max, c_allreduce_max,
ops::CAllReduceOpCUDAKernel<ops::kRedMax, float>, ops::CAllReduceOpCUDAKernel<ops::kRedMax, float>,
#if NCCL_VERSION_CODE >= 21000
ops::CAllReduceOpCUDAKernel<ops::kRedMax, plat::bfloat16>,
#endif
ops::CAllReduceOpCUDAKernel<ops::kRedMax, double>, ops::CAllReduceOpCUDAKernel<ops::kRedMax, double>,
ops::CAllReduceOpCUDAKernel<ops::kRedMax, int>, ops::CAllReduceOpCUDAKernel<ops::kRedMax, int>,
ops::CAllReduceOpCUDAKernel<ops::kRedMax, int64_t>, ops::CAllReduceOpCUDAKernel<ops::kRedMax, int64_t>,
......
...@@ -20,7 +20,7 @@ namespace plat = paddle::platform; ...@@ -20,7 +20,7 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
c_allreduce_sum, c_allreduce_sum,
ops::CAllReduceOpCUDAKernel<ops::kRedSum, float>, ops::CAllReduceOpCUDAKernel<ops::kRedSum, float>,
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000
ops::CAllReduceOpCUDAKernel<ops::kRedSum, plat::bfloat16>, ops::CAllReduceOpCUDAKernel<ops::kRedSum, plat::bfloat16>,
#endif #endif
ops::CAllReduceOpCUDAKernel<ops::kRedSum, double>, ops::CAllReduceOpCUDAKernel<ops::kRedSum, double>,
......
...@@ -132,6 +132,9 @@ namespace plat = paddle::platform; ...@@ -132,6 +132,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_concat, REGISTER_OP_CUDA_KERNEL(c_concat,
ops::CConcatOpCUDAKernel<float>, ops::CConcatOpCUDAKernel<float>,
ops::CConcatOpCUDAKernel<double>, ops::CConcatOpCUDAKernel<double>,
#if NCCL_VERSION_CODE >= 21000
ops::CConcatOpCUDAKernel<plat::bfloat16>,
#endif
ops::CConcatOpCUDAKernel<int>, ops::CConcatOpCUDAKernel<int>,
ops::CConcatOpCUDAKernel<int64_t>, ops::CConcatOpCUDAKernel<int64_t>,
ops::CConcatOpCUDAKernel<plat::float16>); ops::CConcatOpCUDAKernel<plat::float16>);
...@@ -253,8 +253,10 @@ namespace plat = paddle::platform; ...@@ -253,8 +253,10 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_embedding, REGISTER_OP_CUDA_KERNEL(c_embedding,
ops::CEmbeddingCUDAKernel<float>, ops::CEmbeddingCUDAKernel<float>,
ops::CEmbeddingCUDAKernel<double>, ops::CEmbeddingCUDAKernel<double>,
ops::CEmbeddingCUDAKernel<plat::bfloat16>,
ops::CEmbeddingCUDAKernel<plat::float16>); ops::CEmbeddingCUDAKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(c_embedding_grad, REGISTER_OP_CUDA_KERNEL(c_embedding_grad,
ops::CEmbeddingGradCUDAKernel<float>, ops::CEmbeddingGradCUDAKernel<float>,
ops::CEmbeddingGradCUDAKernel<double>, ops::CEmbeddingGradCUDAKernel<double>,
ops::CEmbeddingGradCUDAKernel<plat::bfloat16>,
ops::CEmbeddingGradCUDAKernel<plat::float16>); ops::CEmbeddingGradCUDAKernel<plat::float16>);
...@@ -19,6 +19,9 @@ namespace plat = paddle::platform; ...@@ -19,6 +19,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_identity, REGISTER_OP_CUDA_KERNEL(c_identity,
ops::CIdentityOpKernel<float>, ops::CIdentityOpKernel<float>,
#if NCCL_VERSION_CODE >= 21000
ops::CIdentityOpKernel<plat::bfloat16>,
#endif
ops::CIdentityOpKernel<double>, ops::CIdentityOpKernel<double>,
ops::CIdentityOpKernel<int>, ops::CIdentityOpKernel<int>,
ops::CIdentityOpKernel<int64_t>, ops::CIdentityOpKernel<int64_t>,
......
...@@ -117,6 +117,9 @@ namespace plat = paddle::platform; ...@@ -117,6 +117,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_split, REGISTER_OP_CUDA_KERNEL(c_split,
ops::CSplitOpCUDAKernel<float>, ops::CSplitOpCUDAKernel<float>,
#if NCCL_VERSION_CODE >= 21000
ops::CSplitOpCUDAKernel<plat::bfloat16>,
#endif
ops::CSplitOpCUDAKernel<double>, ops::CSplitOpCUDAKernel<double>,
ops::CSplitOpCUDAKernel<int>, ops::CSplitOpCUDAKernel<int>,
ops::CSplitOpCUDAKernel<int64_t>, ops::CSplitOpCUDAKernel<int64_t>,
......
...@@ -110,6 +110,9 @@ namespace plat = paddle::platform; ...@@ -110,6 +110,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(partial_allgather, REGISTER_OP_CUDA_KERNEL(partial_allgather,
ops::PartialAllGatherOpCUDAKernel<float>, ops::PartialAllGatherOpCUDAKernel<float>,
ops::PartialAllGatherOpCUDAKernel<double>, ops::PartialAllGatherOpCUDAKernel<double>,
#if NCCL_VERSION_CODE >= 21000
ops::PartialAllGatherOpCUDAKernel<plat::bfloat16>,
#endif
ops::PartialAllGatherOpCUDAKernel<int>, ops::PartialAllGatherOpCUDAKernel<int>,
ops::PartialAllGatherOpCUDAKernel<int64_t>, ops::PartialAllGatherOpCUDAKernel<int64_t>,
ops::PartialAllGatherOpCUDAKernel<plat::float16>); ops::PartialAllGatherOpCUDAKernel<plat::float16>);
...@@ -121,6 +121,9 @@ namespace plat = paddle::platform; ...@@ -121,6 +121,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(partial_recv, REGISTER_OP_CUDA_KERNEL(partial_recv,
ops::PartialRecvOpCUDAKernel<float>, ops::PartialRecvOpCUDAKernel<float>,
ops::PartialRecvOpCUDAKernel<double>, ops::PartialRecvOpCUDAKernel<double>,
#if NCCL_VERSION_CODE >= 21000
ops::PartialRecvOpCUDAKernel<plat::bfloat16>,
#endif
ops::PartialRecvOpCUDAKernel<int>, ops::PartialRecvOpCUDAKernel<int>,
ops::PartialRecvOpCUDAKernel<int64_t>, ops::PartialRecvOpCUDAKernel<int64_t>,
ops::PartialRecvOpCUDAKernel<plat::float16>); ops::PartialRecvOpCUDAKernel<plat::float16>);
...@@ -120,6 +120,9 @@ namespace plat = paddle::platform; ...@@ -120,6 +120,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(partial_send, REGISTER_OP_CUDA_KERNEL(partial_send,
ops::PartialSendCUDAKernel<float>, ops::PartialSendCUDAKernel<float>,
ops::PartialSendCUDAKernel<double>, ops::PartialSendCUDAKernel<double>,
#if NCCL_VERSION_CODE >= 21000
ops::PartialSendCUDAKernel<plat::bfloat16>,
#endif
ops::PartialSendCUDAKernel<int>, ops::PartialSendCUDAKernel<int>,
ops::PartialSendCUDAKernel<int64_t>, ops::PartialSendCUDAKernel<int64_t>,
ops::PartialSendCUDAKernel<plat::float16>); ops::PartialSendCUDAKernel<plat::float16>);
...@@ -236,7 +236,7 @@ namespace plat = paddle::platform; ...@@ -236,7 +236,7 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(recv_v2, REGISTER_OP_CUDA_KERNEL(recv_v2,
ops::RecvOpV2CUDAKernel<float>, ops::RecvOpV2CUDAKernel<float>,
ops::RecvOpV2CUDAKernel<double>, ops::RecvOpV2CUDAKernel<double>,
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000
ops::RecvOpV2CUDAKernel<plat::bfloat16>, ops::RecvOpV2CUDAKernel<plat::bfloat16>,
#endif #endif
ops::RecvOpV2CUDAKernel<int>, ops::RecvOpV2CUDAKernel<int>,
......
...@@ -222,7 +222,7 @@ namespace plat = paddle::platform; ...@@ -222,7 +222,7 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(send_v2, REGISTER_OP_CUDA_KERNEL(send_v2,
ops::SendOpV2CUDAKernel<float>, ops::SendOpV2CUDAKernel<float>,
ops::SendOpV2CUDAKernel<double>, ops::SendOpV2CUDAKernel<double>,
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000 #if NCCL_VERSION_CODE >= 21000
ops::SendOpV2CUDAKernel<plat::bfloat16>, ops::SendOpV2CUDAKernel<plat::bfloat16>,
#endif #endif
ops::SendOpV2CUDAKernel<int>, ops::SendOpV2CUDAKernel<int>,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册