未验证 提交 ab369976 编写于 作者: C Chen Weihang 提交者: GitHub

remove fluid symbol depend in sync bn (#47122)

上级 d00b7d83
......@@ -18,6 +18,26 @@
#include "paddle/phi/kernels/gpu/sync_batch_norm_utils.h"
namespace phi {
namespace detail {
ccl::CCLComm GetCCLComm(const Place &place, int global_gid) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
ncclComm_t comm = nullptr;
if (paddle::distributed::ProcessGroupMapFromGid::getInstance()->has(
global_gid)) {
auto *nccl_pg = static_cast<paddle::distributed::ProcessGroupNCCL *>(
paddle::distributed::ProcessGroupMapFromGid::getInstance()->get(
global_gid));
comm = nccl_pg->NCCLComm(place);
}
return comm;
#else
return nullptr;
#endif
}
} // namespace detail
template <typename T, typename Context>
void SyncBatchNormKernel(const Context &ctx,
......@@ -102,16 +122,8 @@ void SyncBatchNormKernel(const Context &ctx,
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
int global_gid = 0;
ncclComm_t comm = nullptr;
if (paddle::distributed::ProcessGroupMapFromGid::getInstance()->has(
global_gid)) {
auto *nccl_pg = static_cast<paddle::distributed::ProcessGroupNCCL *>(
paddle::distributed::ProcessGroupMapFromGid::getInstance()->get(
global_gid));
comm = nccl_pg->NCCLComm(x.place());
} else {
ncclComm_t comm = static_cast<ncclComm_t>(detail::GetCCLComm(x.place(), 0));
if (comm == nullptr) {
comm = ctx.nccl_comm();
}
......
......@@ -16,9 +16,23 @@
#include <string>
#include "paddle/phi/backends/c_comm_lib.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
namespace detail {
// FIXME(paddle-dev): Since the singleton of ProcessGroup in fluid is used in
// SyncBN, the fluid symbol will be dependent on external hardware access.
// Here, the part that depends on the fluid symbol is individually encapsulated
// as a temporary function to isolate external symbol dependencies.
// In the future, the dependence on the singleton in fluid in SyncBN needs
// to be removed.
// In principle, the PHI Kernel cannot use the global singleton internally,
// and the required members need to be passed in from the eucalyptus tree.
ccl::CCLComm GetCCLComm(const Place& place, int global_gid = 0);
} // namespace detail
template <typename T, typename Context>
void SyncBatchNormKernel(const Context& dev_ctx,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册