diff --git a/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu b/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu index 0369c9c410ab96a90c8a2c4d6df4e7aec20ea9c2..106b3d66427a8b8b0655d9fe6306a5ed1b7c0f8b 100644 --- a/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu @@ -18,6 +18,26 @@ #include "paddle/phi/kernels/gpu/sync_batch_norm_utils.h" namespace phi { +namespace detail { + +ccl::CCLComm GetCCLComm(const Place &place, int global_gid) { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + ncclComm_t comm = nullptr; + + if (paddle::distributed::ProcessGroupMapFromGid::getInstance()->has( + global_gid)) { + auto *nccl_pg = static_cast( + paddle::distributed::ProcessGroupMapFromGid::getInstance()->get( + global_gid)); + comm = nccl_pg->NCCLComm(place); + } + return comm; +#else + return nullptr; +#endif +} + +} // namespace detail template void SyncBatchNormKernel(const Context &ctx, @@ -102,16 +122,8 @@ void SyncBatchNormKernel(const Context &ctx, } #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - int global_gid = 0; - ncclComm_t comm = nullptr; - - if (paddle::distributed::ProcessGroupMapFromGid::getInstance()->has( - global_gid)) { - auto *nccl_pg = static_cast( - paddle::distributed::ProcessGroupMapFromGid::getInstance()->get( - global_gid)); - comm = nccl_pg->NCCLComm(x.place()); - } else { + ncclComm_t comm = static_cast(detail::GetCCLComm(x.place(), 0)); + if (comm == nullptr) { comm = ctx.nccl_comm(); } diff --git a/paddle/phi/kernels/sync_batch_norm_kernel.h b/paddle/phi/kernels/sync_batch_norm_kernel.h index 5071eaabf8653404951566c44b0294ef3b4441c7..a4909deb648cfe20cb9656dd41901a6d2ce0b2c6 100644 --- a/paddle/phi/kernels/sync_batch_norm_kernel.h +++ b/paddle/phi/kernels/sync_batch_norm_kernel.h @@ -16,9 +16,23 @@ #include +#include "paddle/phi/backends/c_comm_lib.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { +namespace detail { + +// FIXME(paddle-dev): Since the singleton of ProcessGroup in fluid is used in +// SyncBN, the fluid symbol will be dependent on external hardware access. +// Here, the part that depends on the fluid symbol is individually encapsulated +// as a temporary function to isolate external symbol dependencies. +// In the future, the dependence on the singleton in fluid in SyncBN needs +// to be removed. +// In principle, the PHI Kernel cannot use the global singleton internally, +// and the required members need to be passed in from the eucalyptus tree. +ccl::CCLComm GetCCLComm(const Place& place, int global_gid = 0); + +} // namespace detail template void SyncBatchNormKernel(const Context& dev_ctx,