未验证 提交 9a0de116 编写于 作者: shaojie_wang's avatar shaojie_wang 提交者: GitHub

[AMP]register bf16 for communication ops (#52555)

* register bf16 for communication ops

* fix bfloat16 type finding compile error in c_allreduce_max_op
上级 8da89b81
......@@ -28,7 +28,11 @@ PD_REGISTER_STRUCT_KERNEL(c_allreduce_max,
ALL_LAYOUT,
ops::CAllReduceMaxCUDAKernel,
float,
#if NCCL_VERSION_CODE >= 21000
plat::bfloat16,
#endif
double,
int,
int64_t,
plat::float16) {}
plat::float16) {
}
......@@ -123,4 +123,8 @@ PD_REGISTER_STRUCT_KERNEL(c_split,
double,
int,
int64_t,
plat::float16) {}
#if NCCL_VERSION_CODE >= 21000
plat::bfloat16,
#endif
plat::float16) {
}
......@@ -104,6 +104,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(partial_allgather,
ops::PartialAllGatherOpCUDAKernel<float>,
#if NCCL_VERSION_CODE >= 21000
ops::PartialAllGatherOpCUDAKernel<plat::bfloat16>,
#endif
ops::PartialAllGatherOpCUDAKernel<double>,
ops::PartialAllGatherOpCUDAKernel<int>,
ops::PartialAllGatherOpCUDAKernel<int64_t>,
......
......@@ -120,6 +120,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(partial_recv,
ops::PartialRecvOpCUDAKernel<float>,
#if NCCL_VERSION_CODE >= 21000
ops::PartialRecvOpCUDAKernel<plat::bfloat16>,
#endif
ops::PartialRecvOpCUDAKernel<double>,
ops::PartialRecvOpCUDAKernel<int>,
ops::PartialRecvOpCUDAKernel<int64_t>,
......
......@@ -120,6 +120,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(partial_send,
ops::PartialSendCUDAKernel<float>,
ops::PartialSendCUDAKernel<double>,
#if NCCL_VERSION_CODE >= 21000
ops::PartialSendCUDAKernel<plat::bfloat16>,
#endif
ops::PartialSendCUDAKernel<int>,
ops::PartialSendCUDAKernel<int64_t>,
ops::PartialSendCUDAKernel<plat::float16>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册