Unverified commit 73473ac2, authored by shaojie_wang, committed by GitHub

register bf16 for c ops (#52641)

Parent 9431bae1
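Context for the guards below: most of the new bfloat16 registrations are wrapped in #if NCCL_VERSION_CODE >= 21000 because the ncclBfloat16 datatype first ships with NCCL 2.10.0, whose version code is 21000; the c_embedding kernels are added unconditionally since they do not appear to go through NCCL. The combined CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000 guard is likewise reduced to the NCCL check alone, presumably because these collective kernels do not depend on cuDNN. A minimal illustrative sketch of the same guard pattern follows; the DType enum and ToNcclDataType helper are hypothetical and not part of this patch, while the NCCL symbols and the version macro are real:

#include <nccl.h>

#include <stdexcept>

// Hypothetical helper (not Paddle code): map an element type to the
// corresponding ncclDataType_t. The bf16 branch is compiled only when the
// NCCL headers are new enough to define ncclBfloat16 (NCCL >= 2.10.0,
// i.e. NCCL_VERSION_CODE >= 21000).
enum class DType { kFloat32, kFloat16, kBFloat16 };

inline ncclDataType_t ToNcclDataType(DType t) {
  switch (t) {
    case DType::kFloat32:
      return ncclFloat32;
    case DType::kFloat16:
      return ncclFloat16;
#if NCCL_VERSION_CODE >= 21000
    case DType::kBFloat16:
      return ncclBfloat16;  // introduced in NCCL 2.10.0
#endif
    default:
      throw std::runtime_error("data type not supported by this NCCL build");
  }
}

With the guard in place, building against an older NCCL simply compiles out the bf16 kernels instead of failing on an undefined ncclBfloat16 symbol.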
@@ -20,6 +20,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
c_allreduce_max,
ops::CAllReduceOpCUDAKernel<ops::kRedMax, float>,
+#if NCCL_VERSION_CODE >= 21000
+ops::CAllReduceOpCUDAKernel<ops::kRedMax, plat::bfloat16>,
+#endif
ops::CAllReduceOpCUDAKernel<ops::kRedMax, double>,
ops::CAllReduceOpCUDAKernel<ops::kRedMax, int>,
ops::CAllReduceOpCUDAKernel<ops::kRedMax, int64_t>,
......
@@ -20,7 +20,7 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
c_allreduce_sum,
ops::CAllReduceOpCUDAKernel<ops::kRedSum, float>,
-#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000
ops::CAllReduceOpCUDAKernel<ops::kRedSum, plat::bfloat16>,
#endif
ops::CAllReduceOpCUDAKernel<ops::kRedSum, double>,
......
@@ -132,6 +132,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_concat,
ops::CConcatOpCUDAKernel<float>,
ops::CConcatOpCUDAKernel<double>,
+#if NCCL_VERSION_CODE >= 21000
+ops::CConcatOpCUDAKernel<plat::bfloat16>,
+#endif
ops::CConcatOpCUDAKernel<int>,
ops::CConcatOpCUDAKernel<int64_t>,
ops::CConcatOpCUDAKernel<plat::float16>);
@@ -253,8 +253,10 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_embedding,
ops::CEmbeddingCUDAKernel<float>,
ops::CEmbeddingCUDAKernel<double>,
+ops::CEmbeddingCUDAKernel<plat::bfloat16>,
ops::CEmbeddingCUDAKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(c_embedding_grad,
ops::CEmbeddingGradCUDAKernel<float>,
ops::CEmbeddingGradCUDAKernel<double>,
+ops::CEmbeddingGradCUDAKernel<plat::bfloat16>,
ops::CEmbeddingGradCUDAKernel<plat::float16>);
@@ -19,6 +19,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_identity,
ops::CIdentityOpKernel<float>,
+#if NCCL_VERSION_CODE >= 21000
+ops::CIdentityOpKernel<plat::bfloat16>,
+#endif
ops::CIdentityOpKernel<double>,
ops::CIdentityOpKernel<int>,
ops::CIdentityOpKernel<int64_t>,
......
@@ -117,6 +117,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_split,
ops::CSplitOpCUDAKernel<float>,
+#if NCCL_VERSION_CODE >= 21000
+ops::CSplitOpCUDAKernel<plat::bfloat16>,
+#endif
ops::CSplitOpCUDAKernel<double>,
ops::CSplitOpCUDAKernel<int>,
ops::CSplitOpCUDAKernel<int64_t>,
......
@@ -110,6 +110,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(partial_allgather,
ops::PartialAllGatherOpCUDAKernel<float>,
ops::PartialAllGatherOpCUDAKernel<double>,
+#if NCCL_VERSION_CODE >= 21000
+ops::PartialAllGatherOpCUDAKernel<plat::bfloat16>,
+#endif
ops::PartialAllGatherOpCUDAKernel<int>,
ops::PartialAllGatherOpCUDAKernel<int64_t>,
ops::PartialAllGatherOpCUDAKernel<plat::float16>);
@@ -121,6 +121,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(partial_recv,
ops::PartialRecvOpCUDAKernel<float>,
ops::PartialRecvOpCUDAKernel<double>,
+#if NCCL_VERSION_CODE >= 21000
+ops::PartialRecvOpCUDAKernel<plat::bfloat16>,
+#endif
ops::PartialRecvOpCUDAKernel<int>,
ops::PartialRecvOpCUDAKernel<int64_t>,
ops::PartialRecvOpCUDAKernel<plat::float16>);
@@ -120,6 +120,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(partial_send,
ops::PartialSendCUDAKernel<float>,
ops::PartialSendCUDAKernel<double>,
+#if NCCL_VERSION_CODE >= 21000
+ops::PartialSendCUDAKernel<plat::bfloat16>,
+#endif
ops::PartialSendCUDAKernel<int>,
ops::PartialSendCUDAKernel<int64_t>,
ops::PartialSendCUDAKernel<plat::float16>);
@@ -236,7 +236,7 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(recv_v2,
ops::RecvOpV2CUDAKernel<float>,
ops::RecvOpV2CUDAKernel<double>,
-#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000
ops::RecvOpV2CUDAKernel<plat::bfloat16>,
#endif
ops::RecvOpV2CUDAKernel<int>,
......
@@ -222,7 +222,7 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(send_v2,
ops::SendOpV2CUDAKernel<float>,
ops::SendOpV2CUDAKernel<double>,
-#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000
ops::SendOpV2CUDAKernel<plat::bfloat16>,
#endif
ops::SendOpV2CUDAKernel<int>,
......