未验证 提交 9a0de116 编写于 作者: shaojie_wang's avatar shaojie_wang 提交者: GitHub

[AMP]register bf16 for communication ops (#52555)

* register bf16 for communication ops

* fix bfloat16 type finding compile error in c_allreduce_max_op
上级 8da89b81
...@@ -28,7 +28,11 @@ PD_REGISTER_STRUCT_KERNEL(c_allreduce_max, ...@@ -28,7 +28,11 @@ PD_REGISTER_STRUCT_KERNEL(c_allreduce_max,
ALL_LAYOUT, ALL_LAYOUT,
ops::CAllReduceMaxCUDAKernel, ops::CAllReduceMaxCUDAKernel,
float, float,
#if NCCL_VERSION_CODE >= 21000
plat::bfloat16,
#endif
double, double,
int, int,
int64_t, int64_t,
plat::float16) {} plat::float16) {
}
...@@ -123,4 +123,8 @@ PD_REGISTER_STRUCT_KERNEL(c_split, ...@@ -123,4 +123,8 @@ PD_REGISTER_STRUCT_KERNEL(c_split,
double, double,
int, int,
int64_t, int64_t,
plat::float16) {} #if NCCL_VERSION_CODE >= 21000
plat::bfloat16,
#endif
plat::float16) {
}
...@@ -104,6 +104,9 @@ namespace plat = paddle::platform; ...@@ -104,6 +104,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(partial_allgather, REGISTER_OP_CUDA_KERNEL(partial_allgather,
ops::PartialAllGatherOpCUDAKernel<float>, ops::PartialAllGatherOpCUDAKernel<float>,
#if NCCL_VERSION_CODE >= 21000
ops::PartialAllGatherOpCUDAKernel<plat::bfloat16>,
#endif
ops::PartialAllGatherOpCUDAKernel<double>, ops::PartialAllGatherOpCUDAKernel<double>,
ops::PartialAllGatherOpCUDAKernel<int>, ops::PartialAllGatherOpCUDAKernel<int>,
ops::PartialAllGatherOpCUDAKernel<int64_t>, ops::PartialAllGatherOpCUDAKernel<int64_t>,
......
...@@ -120,6 +120,9 @@ namespace plat = paddle::platform; ...@@ -120,6 +120,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(partial_recv, REGISTER_OP_CUDA_KERNEL(partial_recv,
ops::PartialRecvOpCUDAKernel<float>, ops::PartialRecvOpCUDAKernel<float>,
#if NCCL_VERSION_CODE >= 21000
ops::PartialRecvOpCUDAKernel<plat::bfloat16>,
#endif
ops::PartialRecvOpCUDAKernel<double>, ops::PartialRecvOpCUDAKernel<double>,
ops::PartialRecvOpCUDAKernel<int>, ops::PartialRecvOpCUDAKernel<int>,
ops::PartialRecvOpCUDAKernel<int64_t>, ops::PartialRecvOpCUDAKernel<int64_t>,
......
...@@ -120,6 +120,9 @@ namespace plat = paddle::platform; ...@@ -120,6 +120,9 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(partial_send, REGISTER_OP_CUDA_KERNEL(partial_send,
ops::PartialSendCUDAKernel<float>, ops::PartialSendCUDAKernel<float>,
ops::PartialSendCUDAKernel<double>, ops::PartialSendCUDAKernel<double>,
#if NCCL_VERSION_CODE >= 21000
ops::PartialSendCUDAKernel<plat::bfloat16>,
#endif
ops::PartialSendCUDAKernel<int>, ops::PartialSendCUDAKernel<int>,
ops::PartialSendCUDAKernel<int64_t>, ops::PartialSendCUDAKernel<int64_t>,
ops::PartialSendCUDAKernel<plat::float16>); ops::PartialSendCUDAKernel<plat::float16>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册