Unverified · Commit 1e965756 authored by LiYuRio, committed by GitHub

fix nccl comm in sync_bn (#45100)

Parent acb78ea2
cc_library(
  processgroup
  SRCS ProcessGroup.cc
-  DEPS phi_api eager_api)
+  DEPS dense_tensor)
cc_library(
  eager_reducer
  SRCS reducer.cc
......
@@ -18,7 +18,8 @@ if(WITH_NCCL OR WITH_RCCL)
  cc_library(
    processgroup_nccl
    SRCS ProcessGroupNCCL.cc NCCLTools.cc Common.cc
-    DEPS place enforce collective_helper device_context phi_api eager_api)
+    DEPS processgroup place enforce collective_helper device_context
+         dense_tensor)
  if(WITH_DISTRIBUTE AND WITH_PSCORE)
    cc_library(
      processgroup_heter
......
......
@@ -15,7 +15,6 @@
#pragma once
#include "paddle/fluid/platform/place.h"
-#include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
namespace paddle {
......
......
@@ -18,7 +18,7 @@
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/place.h"
-#include "paddle/phi/api/include/api.h"
+#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/common/place.h"
DECLARE_bool(nccl_blocking_wait);
......
@@ -427,9 +427,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
  platform::CUDADeviceGuard gpuGuard;
  for (auto& place : places) {
    gpuGuard.SetDeviceIndex(place.GetDeviceId());
-    auto dt = full({1}, 0, phi::DataType::FLOAT32, place);
-    barrierTensors.push_back(
-        *std::dynamic_pointer_cast<phi::DenseTensor>(dt.impl()));
+    phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim({1}));
+    auto allocator = std::unique_ptr<phi::Allocator>(
+        new paddle::experimental::DefaultAllocator(place));
+    barrierTensors.emplace_back(allocator.get(), meta);
  }
  auto task = ProcessGroupNCCL::AllReduce(barrierTensors, barrierTensors);
  auto nccl_task = dynamic_cast<ProcessGroupNCCL::NCCLTask*>(task.get());
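The Barrier hunk above builds the dummy tensor for the barrier all-reduce directly from a DenseTensorMeta plus a place-specific DefaultAllocator instead of going through the high-level phi::full API, which is presumably what lets the processgroup targets drop phi_api and eager_api from their DEPS in the CMake hunks above. A minimal sketch of that construction pattern follows; MakeBarrierTensor is a hypothetical helper, and, as the loop above implies, the allocator is only needed while the tensor is being constructed.

// Hypothetical helper sketching the construction pattern used in Barrier above.
#include "paddle/phi/api/lib/utils/allocator.h"  // paddle::experimental::DefaultAllocator
#include "paddle/phi/core/dense_tensor.h"

phi::DenseTensor MakeBarrierTensor(const phi::Place& place) {
  // One FLOAT32 element is enough; its value is irrelevant, since the
  // all-reduce over it is what acts as the barrier.
  phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim({1}));
  // DefaultAllocator allocates on the requested place; the tensor keeps the
  // allocation, not the allocator, so a stack-local allocator suffices.
  paddle::experimental::DefaultAllocator allocator(place);
  return phi::DenseTensor(&allocator, meta);
}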
......
@@ -894,5 +895,15 @@ void ProcessGroupNCCL::GroupEnd() {
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
}

+ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const {
+  std::vector<Place> places = {place};
+  const auto& iter = places_to_ncclcomm_.find(GetKeyFromPlaces(places));
+  PADDLE_ENFORCE_NE(iter,
+                    places_to_ncclcomm_.end(),
+                    platform::errors::InvalidArgument(
+                        "Cannot find nccl comm in process group."));
+  return iter->second[0]->GetNcclComm();
+}
+
}  // namespace distributed
}  // namespace paddle
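The new NCCLComm accessor maps a single place to the communicator the process group has already cached for it, and raises InvalidArgument if no communicator exists for that place yet (the group only builds them once a collective has run there). A hedged usage sketch; CommForTensor is a hypothetical wrapper, and the process-group reference is assumed to come from elsewhere (the sync_bn kernels below obtain it via ProcessGroupMapFromGid).

#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#include "paddle/phi/core/dense_tensor.h"

// Hypothetical wrapper: fetch the raw ncclComm_t for the device a tensor
// lives on, so the caller can issue NCCL calls on its own stream.
ncclComm_t CommForTensor(const paddle::distributed::ProcessGroupNCCL& pg,
                         const phi::DenseTensor& tensor) {
  return pg.NCCLComm(tensor.place());
}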
......
@@ -157,6 +157,8 @@ class ProcessGroupNCCL : public ProcessGroup {
  static void GroupEnd();

+  ncclComm_t NCCLComm(const Place& place) const;
+
 protected:
  virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
      std::vector<Place> places,
......
......
@@ -84,6 +84,11 @@ set(COMMON_KERNEL_DEPS
    gpc
    utf8proc)

+set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} processgroup)
+if(WITH_NCCL OR WITH_RCCL)
+  set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} processgroup_nccl)
+endif()
+
copy_if_different(${kernel_declare_file} ${kernel_declare_file_final})

file(GLOB kernel_h "*.h" "selected_rows/*.h" "sparse/*.h" "strings/*.h")
......
......
@@ -100,7 +100,19 @@ void SyncBatchNormKernel(const Context &ctx,
  }

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-  auto *comm = ctx.nccl_comm();
+  int global_gid = 0;
+  ncclComm_t comm = nullptr;
+
+  if (paddle::distributed::ProcessGroupMapFromGid::getInstance()->has(
+          global_gid)) {
+    auto *nccl_pg = static_cast<paddle::distributed::ProcessGroupNCCL *>(
+        paddle::distributed::ProcessGroupMapFromGid::getInstance()->get(
+            global_gid));
+    comm = nccl_pg->NCCLComm(x.place());
+  } else {
+    comm = ctx.nccl_comm();
+  }
+
  if (comm) {
    int dtype = paddle::platform::ToNCCLDataType(
        paddle::framework::TransToProtoVarType(mean_out->dtype()));
......
@@ -113,6 +125,7 @@ void SyncBatchNormKernel(const Context &ctx,
        ncclSum,
        comm,
        stream));
+    VLOG(3) << "Sync result using all reduce";
  }
#endif
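Taken together, the forward-kernel hunks reduce to a small selection rule: if a process group (as used by the dynamic-graph collectives) is registered under the hard-coded gid 0, take its NCCL communicator for the tensor's place; otherwise fall back to the communicator carried by the device context, and skip the all-reduce entirely when that is also null. A sketch of that rule as a standalone helper; GetSyncBnComm is a hypothetical name, while the gid, the cast, and the fallback follow the hunk above, and the same logic is repeated in SyncBatchNormGradFunctor below.

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"

// Hypothetical helper mirroring the comm selection in SyncBatchNormKernel:
// prefer the communicator owned by the registered process group (gid 0),
// otherwise fall back to the one attached to the device context. The result
// may still be nullptr, in which case the caller skips the all-reduce.
template <typename Context>
ncclComm_t GetSyncBnComm(const Context &ctx, const phi::Place &place) {
  const int global_gid = 0;  // sync_bn currently only consults gid 0
  auto *pg_map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
  if (pg_map->has(global_gid)) {
    auto *nccl_pg = static_cast<paddle::distributed::ProcessGroupNCCL *>(
        pg_map->get(global_gid));
    return nccl_pg->NCCLComm(place);
  }
  return ctx.nccl_comm();
}
#endif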
......
......
@@ -26,6 +26,10 @@ limitations under the License. */
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
+#include "paddle/fluid/distributed/collective/ProcessGroup.h"
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
+#endif
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
......
@@ -411,7 +415,19 @@ void SyncBatchNormGradFunctor(
  }

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-  auto *comm = ctx.nccl_comm();
+  int global_gid = 0;
+  ncclComm_t comm = nullptr;
+
+  if (paddle::distributed::ProcessGroupMapFromGid::getInstance()->has(
+          global_gid)) {
+    auto *nccl_pg = static_cast<paddle::distributed::ProcessGroupNCCL *>(
+        paddle::distributed::ProcessGroupMapFromGid::getInstance()->get(
+            global_gid));
+    comm = nccl_pg->NCCLComm(x->place());
+  } else {
+    comm = ctx.nccl_comm();
+  }
+
  if (comm) {
    int dtype = paddle::platform::ToNCCLDataType(
        paddle::framework::TransToProtoVarType(scale.dtype()));
......
@@ -424,6 +440,7 @@ void SyncBatchNormGradFunctor(
        ncclSum,
        comm,
        stream));
+    VLOG(3) << "Sync result using all reduce";
  }
#endif
......