Unverified commit 1e965756, authored by LiYuRio, committed by GitHub

fix nccl comm in sync_bn (#45100)

Parent: acb78ea2
 cc_library(
   processgroup
   SRCS ProcessGroup.cc
-  DEPS phi_api eager_api)
+  DEPS dense_tensor)
 cc_library(
   eager_reducer
   SRCS reducer.cc
@@ -18,7 +18,8 @@ if(WITH_NCCL OR WITH_RCCL)
   cc_library(
     processgroup_nccl
     SRCS ProcessGroupNCCL.cc NCCLTools.cc Common.cc
-    DEPS place enforce collective_helper device_context phi_api eager_api)
+    DEPS processgroup place enforce collective_helper device_context
+         dense_tensor)
   if(WITH_DISTRIBUTE AND WITH_PSCORE)
     cc_library(
       processgroup_heter
......
@@ -15,7 +15,6 @@
 #pragma once
 #include "paddle/fluid/platform/place.h"
-#include "paddle/phi/api/include/api.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/dense_tensor.h"
 namespace paddle {
......
@@ -18,7 +18,7 @@
 #include "paddle/fluid/platform/device/gpu/gpu_info.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #include "paddle/fluid/platform/place.h"
-#include "paddle/phi/api/include/api.h"
+#include "paddle/phi/api/lib/utils/allocator.h"
 #include "paddle/phi/common/place.h"
 DECLARE_bool(nccl_blocking_wait);
@@ -427,9 +427,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
   platform::CUDADeviceGuard gpuGuard;
   for (auto& place : places) {
     gpuGuard.SetDeviceIndex(place.GetDeviceId());
-    auto dt = full({1}, 0, phi::DataType::FLOAT32, place);
-    barrierTensors.push_back(
-        *std::dynamic_pointer_cast<phi::DenseTensor>(dt.impl()));
+    phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim({1}));
+    auto allocator = std::unique_ptr<phi::Allocator>(
+        new paddle::experimental::DefaultAllocator(place));
+    barrierTensors.emplace_back(allocator.get(), meta);
   }
   auto task = ProcessGroupNCCL::AllReduce(barrierTensors, barrierTensors);
   auto nccl_task = dynamic_cast<ProcessGroupNCCL::NCCLTask*>(task.get());
@@ -894,5 +895,15 @@ void ProcessGroupNCCL::GroupEnd() {
   PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
 }
 
+ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const {
+  std::vector<Place> places = {place};
+  const auto& iter = places_to_ncclcomm_.find(GetKeyFromPlaces(places));
+  PADDLE_ENFORCE_NE(iter,
+                    places_to_ncclcomm_.end(),
+                    platform::errors::InvalidArgument(
+                        "Cannot find nccl comm in process group."));
+  return iter->second[0]->GetNcclComm();
+}
+
 }  // namespace distributed
 }  // namespace paddle
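The Barrier() hunk above replaces the high-level paddle::experimental::full() call with a direct phi::DenseTensor construction, which is what lets the phi_api and eager_api dependencies be dropped from the collective CMake targets. Below is a minimal annotated sketch of that construction pattern; the standalone helper and its name are illustrative only and not part of the commit:

#include <memory>
#include <vector>

#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/dense_tensor.h"

// Illustrative helper (not in the commit): build the one-element float32
// tensors that Barrier() all-reduces, using only a DenseTensorMeta plus an
// allocator bound to the target place instead of the full() API.
std::vector<phi::DenseTensor> MakeBarrierTensors(
    const std::vector<phi::Place>& places) {
  std::vector<phi::DenseTensor> tensors;
  tensors.reserve(places.size());
  for (const auto& place : places) {
    // Shape {1}, dtype float32; the payload is a dummy value, only the
    // completion of the subsequent all-reduce matters for the barrier.
    phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim({1}));
    // DefaultAllocator hands out memory on the given place.
    auto allocator =
        std::make_unique<paddle::experimental::DefaultAllocator>(place);
    // Same constructor the commit uses: allocator pointer plus tensor meta.
    tensors.emplace_back(allocator.get(), meta);
  }
  return tensors;
}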
@@ -157,6 +157,8 @@ class ProcessGroupNCCL : public ProcessGroup {
   static void GroupEnd();
 
+  ncclComm_t NCCLComm(const Place& place) const;
+
  protected:
   virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
       std::vector<Place> places,
......
@@ -84,6 +84,11 @@ set(COMMON_KERNEL_DEPS
     gpc
     utf8proc)
 
+set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} processgroup)
+if(WITH_NCCL OR WITH_RCCL)
+  set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} processgroup_nccl)
+endif()
+
 copy_if_different(${kernel_declare_file} ${kernel_declare_file_final})
 file(GLOB kernel_h "*.h" "selected_rows/*.h" "sparse/*.h" "strings/*.h")
......
@@ -100,7 +100,19 @@ void SyncBatchNormKernel(const Context &ctx,
   }
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-  auto *comm = ctx.nccl_comm();
+  int global_gid = 0;
+  ncclComm_t comm = nullptr;
+  if (paddle::distributed::ProcessGroupMapFromGid::getInstance()->has(
+          global_gid)) {
+    auto *nccl_pg = static_cast<paddle::distributed::ProcessGroupNCCL *>(
+        paddle::distributed::ProcessGroupMapFromGid::getInstance()->get(
+            global_gid));
+    comm = nccl_pg->NCCLComm(x.place());
+  } else {
+    comm = ctx.nccl_comm();
+  }
+
   if (comm) {
     int dtype = paddle::platform::ToNCCLDataType(
         paddle::framework::TransToProtoVarType(mean_out->dtype()));
@@ -113,6 +125,7 @@ void SyncBatchNormKernel(const Context &ctx,
           ncclSum,
           comm,
           stream));
+    VLOG(3) << "Sync result using all reduce";
   }
 #endif
......
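Both SyncBatchNormKernel (above) and SyncBatchNormGradFunctor (below) now pick the communicator the same way: if a process group is registered under group id 0, its NCCL communicator for the tensor's place is fetched through the new NCCLComm() accessor; otherwise the kernel falls back to the communicator owned by the GPU device context. The following hedged sketch factors that selection into a single helper, assuming an NCCL/RCCL build; the helper itself is illustrative and not part of the commit, while the calls it makes are exactly the ones shown in the hunks:

#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"

// Illustrative helper (not in the commit): choose the NCCL communicator for
// sync_bn, preferring the ProcessGroup-managed comm over the legacy one
// stored in the device context.
template <typename Context>
ncclComm_t ChooseSyncBatchNormComm(const Context& ctx, const phi::Place& place) {
  const int global_gid = 0;  // sync_bn always queries the default group id.
  auto* pg_map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
  if (pg_map->has(global_gid)) {
    // A process group was registered for this gid; reuse its communicator
    // for the tensor's place so sync_bn shares it with other collectives.
    auto* nccl_pg = static_cast<paddle::distributed::ProcessGroupNCCL*>(
        pg_map->get(global_gid));
    return nccl_pg->NCCLComm(place);
  }
  // No process group: fall back to the device context's communicator.
  return ctx.nccl_comm();
}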
@@ -26,6 +26,10 @@ limitations under the License. */
 #include <hipcub/hipcub.hpp>
 namespace cub = hipcub;
 #endif
+#include "paddle/fluid/distributed/collective/ProcessGroup.h"
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
+#endif
 #include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
@@ -411,7 +415,19 @@ void SyncBatchNormGradFunctor(
   }
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-  auto *comm = ctx.nccl_comm();
+  int global_gid = 0;
+  ncclComm_t comm = nullptr;
+  if (paddle::distributed::ProcessGroupMapFromGid::getInstance()->has(
+          global_gid)) {
+    auto *nccl_pg = static_cast<paddle::distributed::ProcessGroupNCCL *>(
+        paddle::distributed::ProcessGroupMapFromGid::getInstance()->get(
+            global_gid));
+    comm = nccl_pg->NCCLComm(x->place());
+  } else {
+    comm = ctx.nccl_comm();
+  }
+
   if (comm) {
     int dtype = paddle::platform::ToNCCLDataType(
         paddle::framework::TransToProtoVarType(scale.dtype()));
@@ -424,6 +440,7 @@ void SyncBatchNormGradFunctor(
           ncclSum,
           comm,
           stream));
+    VLOG(3) << "Sync result using all reduce";
   }
 #endif
......