Unverified commit ae14bad1, authored by Wen Sun, committed by GitHub

refactor: ProcessGroupNCCL (#47740)

Parent 87d97246
@@ -350,14 +350,6 @@ class ProcessGroup {
         GetBackendName()));
   }
 
-  virtual std::shared_ptr<ProcessGroup::Task> _ReduceScatterBase(
-      phi::DenseTensor&,  // NOLINT
-      phi::DenseTensor&,  // NOLINT
-      const ReduceScatterOptions&) {
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "ProcessGroup%s does not support ReduceScatter", GetBackendName()));
-  }
-
  protected:
   const int rank_;
   const int size_;
...
@@ -33,7 +33,7 @@ ProcessGroupNCCL::NCCLTask::NCCLTask(const Place& place,
                                      bool use_calc_stream)
     : TaskStream(rank, comm_type, sync_op, use_calc_stream),
       comm_event_(place),
-      place_(place) {}
+      task_place_(place) {}
 
 ProcessGroupNCCL::NCCLTask::~NCCLTask() {}
@@ -53,8 +53,9 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
     return true;
   }
 
-  const auto* calc_ctx = platform::DeviceContextPool::Instance().Get(place_);
-  comm_event_.Wait(platform::Place2DeviceType(place_), calc_ctx);
+  const auto* calc_ctx =
+      platform::DeviceContextPool::Instance().Get(task_place_);
+  comm_event_.Wait(platform::Place2DeviceType(task_place_), calc_ctx);
 
   if (FLAGS_nccl_blocking_wait) {
     // NOTE(shenliang03): It will block host for sync
@@ -63,7 +64,7 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
     }
   }
 
-  if (barrier_) {
+  if (IsBlockCPUInWait()) {
     // If we use the work to do barrier, we should block cpu
 #ifdef PADDLE_WITH_CUDA
     PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
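
For context on what the task's comm_event_ does inside Wait(): it is the standard CUDA record-then-wait handoff, and the barrier case additionally blocks the host. Below is a minimal raw-CUDA sketch of that pattern (illustration only; Paddle's platform::DeviceEvent wraps these calls, and the stream and flag names here are placeholders, not Paddle identifiers):

    #include <cuda_runtime.h>

    // Sketch: an event recorded on the comm stream lets the calc stream
    // (and, for barriers, the CPU) wait until the collective has finished.
    void WaitForCollective(cudaStream_t comm_stream,
                           cudaStream_t calc_stream,
                           bool block_cpu_in_wait) {
      cudaEvent_t done;
      cudaEventCreateWithFlags(&done, cudaEventDisableTiming);
      cudaEventRecord(done, comm_stream);         // marks the end of the NCCL work
      cudaStreamWaitEvent(calc_stream, done, 0);  // GPU-side wait; host not blocked
      if (block_cpu_in_wait) {
        cudaDeviceSynchronize();                  // barrier case: block the host too
      }
      cudaEventDestroy(done);
    }

(Error checking is omitted for brevity; production code would check every return value.)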
@@ -192,7 +193,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
                         /*sync_op*/ true,
                         /*use_calc_stream*/ false);
   auto nccl_task = dynamic_cast<NCCLTask*>(task.get());
-  nccl_task->barrier_ = true;
+  nccl_task->SetBlockCPUInWait();
   return task;
 }
@@ -250,6 +251,10 @@ void ProcessGroupNCCL::BroadcastUniqueNCCLID(ncclUniqueId* nccl_id) {
 
 void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
                                           const std::string& place_key) {
+  if (place_to_comm_ctx_.size() > 0) {
+    VLOG(3) << "Warning: Tensors from multiple devices are not supported yet.";
+  }
+
   ncclUniqueId nccl_id;
   if (rank_ == 0) {
     PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id));
@@ -260,7 +265,6 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
           << ", place: " << place_key
           << ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);
 
-  calc_event_ = std::make_shared<platform::DeviceEvent>(place);
   auto* calc_ctx = static_cast<phi::GPUContext*>(
       platform::DeviceContextPool::Instance().Get(place));
   auto comm_ctx = std::make_unique<phi::GPUContext>(place);
@@ -269,20 +273,23 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
       &nccl_comm, GetSize(), nccl_id, GetRank()));
   comm_ctx->set_nccl_comm(nccl_comm);
 
-  place_to_calc_ctx_[place_key] = calc_ctx;
-  place_to_comm_ctx_[place_key] = std::move(comm_ctx);
+  place_to_calc_event_.emplace(place_key, place);
+  place_to_calc_ctx_.emplace(place_key, calc_ctx);
+  place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx));
 
   // TODO(sunyilun): for compatibility, will be removed later
-  places_to_ctx_[place_key] = {place_to_comm_ctx_[place_key].get()};
+  std::vector<phi::GPUContext*> comm_ctx_wrapper{
+      place_to_comm_ctx_[place_key].get()};
+  places_to_ctx_.emplace(place_key, comm_ctx_wrapper);
 }
 
-void ProcessGroupNCCL::SyncCalcStream(
-    const Place& place, const std::shared_ptr<platform::DeviceEvent>& event) {
+void ProcessGroupNCCL::SyncCalcStream(const Place& place) {
   const std::string& key = GetKeyFromPlace(place);
-  const auto* calc_ctx = place_to_calc_ctx_[key];
-  const auto* comm_ctx = place_to_comm_ctx_[key].get();
-  event->Record(calc_ctx);
-  event->Wait(platform::Place2DeviceType(place), comm_ctx);
+  auto& calc_event = place_to_calc_event_.at(key);
+  const auto* calc_ctx = place_to_calc_ctx_.at(key);
+  const auto* comm_ctx = place_to_comm_ctx_.at(key).get();
+  calc_event.Record(calc_ctx);
+  calc_event.Wait(platform::Place2DeviceType(place), comm_ctx);
 }
 
 template <typename Fn>
@@ -296,26 +303,29 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
   const auto& place = in_tensor.place();
   const auto& key = GetKeyFromPlace(place);
 
-  if (!calc_event_) {
+  platform::CUDADeviceGuard cuda_guard(place);
+
+  if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
     CreateNCCLEnvCache(place, key);
   }
 
   if (!use_calc_stream) {
-    SyncCalcStream(place, calc_event_);
+    SyncCalcStream(place);
   }
 
   auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream);
 
-  const auto* calc_ctx = place_to_calc_ctx_[key];
-  const auto& comm_ctx = place_to_comm_ctx_[key];
+  const auto* calc_ctx = place_to_calc_ctx_.at(key);
+  const auto& comm_ctx = place_to_comm_ctx_.at(key);
+  auto nccl_comm = comm_ctx->nccl_comm();
   auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();
-  fn(out_tensor, in_tensor, comm_ctx->nccl_comm(), nccl_stream);
+  fn(out_tensor, in_tensor, nccl_comm, nccl_stream);
 
   if (!use_calc_stream) {
     if (FLAGS_use_stream_safe_cuda_allocator) {
       memory::RecordStream(in_tensor.Holder(), nccl_stream);
     }
-    task->comm_event_.Record(comm_ctx.get());
+    task->UpdateWaitChain(*comm_ctx);
   }
 
   return task;
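
A pattern worth noting in the hunks above and below: the refactor generally replaces operator[] access on the place-keyed caches with find()-then-emplace() for creation and .at() for lookups, so a missing key throws instead of silently inserting a default-constructed entry. A self-contained sketch of that idiom, using generic names rather than Paddle's:

    #include <iostream>
    #include <memory>
    #include <string>
    #include <unordered_map>

    // Lazily create an entry exactly once, then read it with .at().
    std::unordered_map<std::string, std::unique_ptr<int>> cache;

    int& GetOrCreate(const std::string& key) {
      if (cache.find(key) == cache.end()) {
        cache.emplace(key, std::make_unique<int>(42));  // first use: create
      }
      return *cache.at(key);  // read-only lookup; no accidental insertion
    }

    int main() {
      std::cout << GetOrCreate("gpu:0") << "\n";  // creates and prints 42
      std::cout << GetOrCreate("gpu:0") << "\n";  // reuses the cached entry
    }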
@@ -352,13 +362,13 @@ void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>* split_sizes,
 
 // TODO(sunyilun): methods below will be removed later
 void SyncDefaultStream(const std::vector<Place>& places,
-                       const std::shared_ptr<platform::DeviceEvent>& nccl_event,
+                       platform::DeviceEvent& nccl_event,         // NOLINT
                        std::vector<phi::GPUContext*>& dev_ctx) {  // NOLINT
   for (size_t i = 0; i < places.size(); ++i) {
     auto* default_ctx = static_cast<phi::GPUContext*>(
         platform::DeviceContextPool::Instance().Get(places[i]));
-    nccl_event->Record(default_ctx);
-    nccl_event->Wait(platform::Place2DeviceType(places[i]), dev_ctx[i]);
+    nccl_event.Record(default_ctx);
+    nccl_event.Wait(platform::Place2DeviceType(places[i]), dev_ctx[i]);
   }
 }
@@ -389,7 +399,7 @@ ProcessGroupNCCL::NCCLTask::NCCLTask(
     const std::vector<phi::DenseTensor>& inputs)
     : TaskStream(rank, inputs, CommType),
       comm_event_(places[0]),
-      place_(places[0]) {}
+      task_place_(places[0]) {}
 
 ProcessGroupNCCL::NCCLTask::NCCLTask(
     const std::vector<Place>& places,
@@ -400,7 +410,7 @@ ProcessGroupNCCL::NCCLTask::NCCLTask(
     bool use_calc_stream)
     : TaskStream(rank, inputs, comm_type, sync_op, use_calc_stream),
       comm_event_(places[0]),
-      place_(places[0]) {}
+      task_place_(places[0]) {}
 
 // create NCCLManager cache for places_key
 void ProcessGroupNCCL::CreateNCCLManagerCache(
@@ -437,17 +447,18 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
     NCCLCHECK(platform::dynload::ncclCommInitRank(
         &nccl_comm, GetSize(), nccl_id, GetRank()));
     dev_ctx[i]->set_nccl_comm(nccl_comm);
     dev_ctx_raw[i] = dev_ctx[i].get();
   }
   PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
 
-  calc_event_ = std::make_shared<platform::DeviceEvent>(places[0]);
-
   // TODO(sunyilun): for compatibility, will be removed later
-  place_to_calc_ctx_[places_key] = static_cast<phi::GPUContext*>(
-      platform::DeviceContextPool::Instance().Get(places[0]));
-  place_to_comm_ctx_[places_key] = std::move(dev_ctx[0]);
+  place_to_calc_event_.emplace(places_key, places[0]);
+  place_to_calc_ctx_.emplace(
+      places_key,
+      static_cast<phi::GPUContext*>(
+          platform::DeviceContextPool::Instance().Get(places[0])));
+  place_to_comm_ctx_.emplace(places_key, std::move(dev_ctx[0]));
 
   // These caches will be useful to process sync/wait/communicate
   places_to_ctx_.emplace(places_key, std::move(dev_ctx_raw));
@@ -466,13 +477,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
   {
     std::lock_guard<std::mutex> lock(mutex_);
-    if (!calc_event_) {
+    if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
       CreateNCCLManagerCache(key, places);
     }
   }
 
   if (!use_calc_stream) {
-    SyncDefaultStream(places, calc_event_, places_to_ctx_[key]);
+    SyncDefaultStream(
+        places, place_to_calc_event_.at(key), places_to_ctx_.at(key));
   }
 
   auto task =
@@ -492,12 +504,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
                 platform::DeviceContextPool::Instance().Get(places[i]))
                 ->stream();
       } else {
-        nccl_stream = places_to_ctx_[key][i]->stream();
+        nccl_stream = places_to_ctx_.at(key)[i]->stream();
       }
       fn(inputs[i],
          outputs[i],
-         places_to_ctx_[key][i]->nccl_comm(),
+         places_to_ctx_.at(key)[i]->nccl_comm(),
          nccl_stream);
     }
   }
@@ -513,7 +525,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
                 platform::DeviceContextPool::Instance().Get(places[i]))
                 ->stream();
       } else {
-        nccl_stream = places_to_ctx_[key][i]->stream();
+        nccl_stream = places_to_ctx_.at(key)[i]->stream();
       }
       memory::RecordStream(inputs[i].Holder(), nccl_stream);
@@ -524,7 +536,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
   if (!use_calc_stream) {
     for (size_t i = 0; i < inputs.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
-      task->comm_event_.Record(places_to_ctx_[key][i]);
+      task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
     }
   }
@@ -542,12 +554,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
   {
     std::lock_guard<std::mutex> lock(mutex_);
-    if (!calc_event_) {
+    if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
      CreateNCCLManagerCache(key, places);
     }
   }
 
-  SyncDefaultStream(places, calc_event_, places_to_ctx_[key]);
+  SyncDefaultStream(
+      places, place_to_calc_event_.at(key), places_to_ctx_.at(key));
 
   auto task = CreateTask(places, rank_, op_type, inputs);
@@ -558,10 +571,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
     platform::NCCLGroupGuard nccl_guard;
     for (size_t i = 0; i < inputs.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
-      const auto& nccl_stream = places_to_ctx_[key][i]->stream();
+      const auto& nccl_stream = places_to_ctx_.at(key)[i]->stream();
       fn(inputs[i],
          outputs[i],
-         places_to_ctx_[key][i]->nccl_comm(),
+         places_to_ctx_.at(key)[i]->nccl_comm(),
          nccl_stream);
     }
   }
@@ -570,13 +583,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
     for (size_t i = 0; i < inputs.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
       memory::RecordStream(inputs[i].Holder(),
-                           places_to_ctx_[key][i]->stream());
+                           places_to_ctx_.at(key)[i]->stream());
     }
   }
 
   for (size_t i = 0; i < inputs.size(); ++i) {
     cuda_guard.SetDevice(places[i]);
-    task->comm_event_.Record(places_to_ctx_[key][i]);
+    task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
   }
   return task;
 }
@@ -592,26 +605,27 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
   {
     std::lock_guard<std::mutex> lock(mutex_);
-    if (!calc_event_) {
+    if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
       CreateNCCLManagerCache(key, places);
     }
   }
 
-  SyncDefaultStream(places, calc_event_, places_to_ctx_[key]);
+  SyncDefaultStream(
+      places, place_to_calc_event_.at(key), places_to_ctx_.at(key));
 
   // construct uninitialize guard for device
   platform::CUDADeviceGuard cuda_guard;
 
   if (FLAGS_use_stream_safe_cuda_allocator) {
     cuda_guard.SetDevice(places[0]);
-    memory::RecordStream(in->Holder(), places_to_ctx_[key][0]->stream());
+    memory::RecordStream(in->Holder(), places_to_ctx_.at(key)[0]->stream());
   }
 
   {
     platform::NCCLGroupGuard nccl_guard;
     cuda_guard.SetDevice(places[0]);
-    const auto& nccl_stream = places_to_ctx_[key][0]->stream();
-    fn(in, out, places_to_ctx_[key][0]->nccl_comm(), nccl_stream);
+    const auto& nccl_stream = places_to_ctx_.at(key)[0]->stream();
+    fn(in, out, places_to_ctx_.at(key)[0]->nccl_comm(), nccl_stream);
   }
 
   cuda_guard.SetDevice(places[0]);
@@ -630,13 +644,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
   {
     std::lock_guard<std::mutex> lock(mutex_);
-    if (!calc_event_) {
+    if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
       CreateNCCLManagerCache(key, places);
     }
   }
 
   if (!use_calc_stream) {
-    SyncDefaultStream(places, calc_event_, places_to_ctx_[key]);
+    SyncDefaultStream(
+        places, place_to_calc_event_.at(key), places_to_ctx_.at(key));
   }
 
   auto task =
@@ -655,10 +670,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
                 platform::DeviceContextPool::Instance().Get(places[i]))
                 ->stream();
       } else {
-        nccl_stream = places_to_ctx_[key][i]->stream();
+        nccl_stream = places_to_ctx_.at(key)[i]->stream();
       }
       fn(tensors[i],
-         places_to_ctx_[key][i]->nccl_comm(),
+         places_to_ctx_.at(key)[i]->nccl_comm(),
          nccl_stream,
          dst_rank);
     }
@@ -674,7 +689,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
                 platform::DeviceContextPool::Instance().Get(places[i]))
                 ->stream();
       } else {
-        nccl_stream = places_to_ctx_[key][i]->stream();
+        nccl_stream = places_to_ctx_.at(key)[i]->stream();
      }
       memory::RecordStream(tensors[i].Holder(), nccl_stream);
     }
@@ -683,7 +698,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
   if (!use_calc_stream) {
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
-      task->comm_event_.Record(places_to_ctx_[key][i]);
+      task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
     }
   }
@@ -701,12 +716,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
   {
     std::lock_guard<std::mutex> lock(mutex_);
-    if (!calc_event_) {
+    if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
       CreateNCCLManagerCache(key, places);
     }
   }
 
-  SyncDefaultStream(places, calc_event_, places_to_ctx_[key]);
+  SyncDefaultStream(
+      places, place_to_calc_event_.at(key), places_to_ctx_.at(key));
 
   auto task = CreateTask(places, rank_, op_type, tensors);
@@ -717,9 +733,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
     platform::NCCLGroupGuard nccl_guard;
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
-      const auto& nccl_stream = places_to_ctx_[key][i]->stream();
+      const auto& nccl_stream = places_to_ctx_.at(key)[i]->stream();
       fn(tensors[i],
-         places_to_ctx_[key][i]->nccl_comm(),
+         places_to_ctx_.at(key)[i]->nccl_comm(),
          nccl_stream,
          dst_rank);
     }
@@ -729,13 +745,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
       memory::RecordStream(tensors[i].Holder(),
-                           places_to_ctx_[key][i]->stream());
+                           places_to_ctx_.at(key)[i]->stream());
     }
   }
 
   for (size_t i = 0; i < tensors.size(); ++i) {
     cuda_guard.SetDevice(places[i]);
-    task->comm_event_.Record(places_to_ctx_[key][i]);
+    task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
   }
   return task;
 }
@@ -1608,49 +1624,5 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
       use_calc_stream);
 }
 
-std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::_ReduceScatterBase(
-    phi::DenseTensor& out_tensor,
-    phi::DenseTensor& in_tensor,
-    const ReduceScatterOptions& opts) {
-  // auto tensor = out_tensors.back();
-  PADDLE_ENFORCE_EQ(
-      out_tensor.dtype(),
-      in_tensor.dtype(),
-      platform::errors::InvalidArgument(
-          "Input tensor and output tensor should be same dtype."));
-
-  PADDLE_ENFORCE_EQ(
-      out_tensor.numel() * size_,
-      in_tensor.numel(),
-      platform::errors::InvalidArgument("input tensor must be the same size as "
-                                        "output tensor size times world_size"));
-
-  auto inputs = std::vector<phi::DenseTensor>{in_tensor};
-  auto outputs = std::vector<phi::DenseTensor>{out_tensor};
-
-  return Collective(
-      inputs,
-      outputs,
-      [&](phi::DenseTensor& input,
-          phi::DenseTensor& output,
-          ncclComm_t comm,
-          const gpuStream_t& stream) {
-        if (FLAGS_use_stream_safe_cuda_allocator) {
-          platform::CUDADeviceGuard cuda_guard;
-          cuda_guard.SetDevice(output.place());
-          memory::RecordStream(output.Holder(), stream);
-        }
-        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter(
-            input.data(),
-            output.data(),
-            output.numel(),
-            platform::ToNCCLDataType(input.dtype()),
-            ToNCCLRedType(opts.reduce_op),
-            comm,
-            stream));
-      },
-      CommType::REDUCE_SCATTER);
-}
-
 }  // namespace distributed
 }  // namespace paddle
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#pragma once #pragma once
#include <chrono> #include <chrono>
#include <map>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -61,6 +60,9 @@ class ProcessGroupNCCL final : public ProcessGroupStream { ...@@ -61,6 +60,9 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
void Synchronize() override; void Synchronize() override;
void UpdateWaitChain(const phi::DeviceContext& ctx) override; void UpdateWaitChain(const phi::DeviceContext& ctx) override;
bool IsBlockCPUInWait() const { return block_cpu_in_wait_; }
void SetBlockCPUInWait() { block_cpu_in_wait_ = true; }
// TODO(sunyilun): methods below will be removed later // TODO(sunyilun): methods below will be removed later
NCCLTask(const std::vector<Place>& places, NCCLTask(const std::vector<Place>& places,
int rank, int rank,
...@@ -73,12 +75,10 @@ class ProcessGroupNCCL final : public ProcessGroupStream { ...@@ -73,12 +75,10 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
bool sync_op, bool sync_op,
bool use_calc_stream); bool use_calc_stream);
public:
bool barrier_{false};
platform::DeviceEvent comm_event_; // event on comm stream
private: private:
Place place_; bool block_cpu_in_wait_{false};
platform::DeviceEvent comm_event_; // event on comm stream
Place task_place_;
}; };
public: public:
...@@ -253,11 +253,6 @@ class ProcessGroupNCCL final : public ProcessGroupStream { ...@@ -253,11 +253,6 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
bool sync_op, bool sync_op,
bool use_calc_stream) override; bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> _ReduceScatterBase(
phi::DenseTensor&, // NOLINT
phi::DenseTensor&, // NOLINT
const ReduceScatterOptions&) override;
private: private:
std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(const Place& place, std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(const Place& place,
int rank, int rank,
...@@ -278,8 +273,7 @@ class ProcessGroupNCCL final : public ProcessGroupStream { ...@@ -278,8 +273,7 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
bool sync_op, bool sync_op,
bool use_calc_stream); bool use_calc_stream);
void SyncCalcStream(const Place& place, void SyncCalcStream(const Place& place);
const std::shared_ptr<platform::DeviceEvent>& event);
// TODO(sunyilun): methods below will be removed later // TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask( std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
...@@ -342,7 +336,8 @@ class ProcessGroupNCCL final : public ProcessGroupStream { ...@@ -342,7 +336,8 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
private: private:
std::shared_ptr<Store> store_; std::shared_ptr<Store> store_;
std::shared_ptr<platform::DeviceEvent> calc_event_; // event on calc stream std::unordered_map<std::string, platform::DeviceEvent>
place_to_calc_event_; // event on calc stream
std::unordered_map<std::string, phi::GPUContext*> place_to_calc_ctx_; std::unordered_map<std::string, phi::GPUContext*> place_to_calc_ctx_;
std::unordered_map<std::string, std::unique_ptr<phi::GPUContext>> std::unordered_map<std::string, std::unique_ptr<phi::GPUContext>>
place_to_comm_ctx_; place_to_comm_ctx_;
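
Because the calc-stream events are now stored by value in place_to_calc_event_, the .cc side constructs them in place with emplace(place_key, place) instead of assigning through operator[]. A small generic sketch of why emplace is the right tool when the mapped type is not copyable or default-constructible (hypothetical Event type, not Paddle's DeviceEvent):

    #include <string>
    #include <unordered_map>

    // Hypothetical stand-in for an event type constructed from a device id.
    struct Event {
      explicit Event(int device_id) : device(device_id) {}
      Event(const Event&) = delete;             // not copyable
      Event& operator=(const Event&) = delete;
      int device;
    };

    int main() {
      std::unordered_map<std::string, Event> place_to_event;
      place_to_event.emplace("gpu:0", 0);  // constructed in place inside the map node
      // place_to_event["gpu:0"] = Event(0);  // would not compile: operator[]
      //                                      // needs default-construction + assignment
      return place_to_event.at("gpu:0").device;
    }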
...
@@ -761,27 +761,6 @@ void BindDistributed(py::module *m) {
             py::arg("in"),
             py::arg("out"),
             py::arg("src"),
-            py::call_guard<py::gil_scoped_release>())
-        .def(
-            "_reduce_scatter_base",
-            [](distributed::ProcessGroup &self,
-               py::handle py_out_tensor,
-               py::handle py_in_tensor,
-               distributed::ReduceOp op) {
-              auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
-              auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
-              distributed::ReduceScatterOptions opts;
-              opts.reduce_op = op;
-              auto dense_out = std::dynamic_pointer_cast<phi::DenseTensor>(
-                  out_tensor.impl());
-              auto dense_in = std::dynamic_pointer_cast<phi::DenseTensor>(
-                  in_tensor.impl());
-              return self._ReduceScatterBase(*dense_out, *dense_in, opts);
-            },
-            py::arg("out_tensor"),
-            py::arg("in_tensor"),
-            py::arg("op") = distributed::ReduceOp::SUM,
             py::call_guard<py::gil_scoped_release>());
 
   auto ProcessGroupStream =
...