Unverified commit ae14bad1, authored by Wen Sun and committed by GitHub

refactor: ProcessGroupNCCL (#47740)

Parent 87d97246
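In outline, the hunks below replace the single calc-stream event `calc_event_` with a per-place cache `place_to_calc_event_`, simplify `SyncCalcStream` to take only a `Place`, rename `NCCLTask::place_` to `task_place_`, hide the formerly public `barrier_` flag behind `IsBlockCPUInWait()`/`SetBlockCPUInWait()`, record comm-stream events through `UpdateWaitChain()`, and drop the experimental `_ReduceScatterBase` API together with its `_reduce_scatter_base` Python binding. Below is a minimal, self-contained sketch of the per-place synchronization pattern; only the member and method names `place_to_calc_event_`, `place_to_calc_ctx_`, `place_to_comm_ctx_`, and `SyncCalcStream` are taken from the diff, while `Event`, `Context`, `PerPlaceCacheSketch`, `EnsureCache`, and the string key are invented stand-ins for `platform::DeviceEvent`, `phi::GPUContext`, and the place key.

// Illustrative sketch only (not Paddle code).
#include <memory>
#include <string>
#include <unordered_map>

struct Context { /* owns a CUDA stream in the real code */ };

struct Event {
  void Record(const Context* calc_ctx) { /* capture work queued on the calc stream */ }
  void Wait(const Context* comm_ctx) { /* make the comm stream wait for that work */ }
};

class PerPlaceCacheSketch {
 public:
  // Lazily build the caches for a place key, as CreateNCCLEnvCache does.
  void EnsureCache(const std::string& key, const Context* pool_calc_ctx) {
    if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
      place_to_calc_event_.emplace(key, Event{});
      place_to_calc_ctx_.emplace(key, pool_calc_ctx);  // not owned; from a context pool
      place_to_comm_ctx_.emplace(key, std::make_unique<Context>());
    }
  }

  // Counterpart of the new SyncCalcStream(const Place&): the comm stream of a
  // place waits for everything already queued on its calc stream.
  void SyncCalcStream(const std::string& key) {
    auto& calc_event = place_to_calc_event_.at(key);
    const auto* calc_ctx = place_to_calc_ctx_.at(key);
    const auto* comm_ctx = place_to_comm_ctx_.at(key).get();
    calc_event.Record(calc_ctx);
    calc_event.Wait(comm_ctx);
  }

 private:
  std::unordered_map<std::string, Event> place_to_calc_event_;
  std::unordered_map<std::string, const Context*> place_to_calc_ctx_;
  std::unordered_map<std::string, std::unique_ptr<Context>> place_to_comm_ctx_;
};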
@@ -350,14 +350,6 @@ class ProcessGroup {
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> _ReduceScatterBase(
phi::DenseTensor&, // NOLINT
phi::DenseTensor&, // NOLINT
const ReduceScatterOptions&) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support ReduceScatter", GetBackendName()));
}
protected:
const int rank_;
const int size_;
......
@@ -33,7 +33,7 @@ ProcessGroupNCCL::NCCLTask::NCCLTask(const Place& place,
bool use_calc_stream)
: TaskStream(rank, comm_type, sync_op, use_calc_stream),
comm_event_(place),
place_(place) {}
task_place_(place) {}
ProcessGroupNCCL::NCCLTask::~NCCLTask() {}
@@ -53,8 +53,9 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
return true;
}
const auto* calc_ctx = platform::DeviceContextPool::Instance().Get(place_);
comm_event_.Wait(platform::Place2DeviceType(place_), calc_ctx);
const auto* calc_ctx =
platform::DeviceContextPool::Instance().Get(task_place_);
comm_event_.Wait(platform::Place2DeviceType(task_place_), calc_ctx);
if (FLAGS_nccl_blocking_wait) {
// NOTE(shenliang03): It will block host for sync
@@ -63,7 +64,7 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
}
}
if (barrier_) {
if (IsBlockCPUInWait()) {
// If we use the work to do barrier, we should block cpu
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
@@ -192,7 +193,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
/*sync_op*/ true,
/*use_calc_stream*/ false);
auto nccl_task = dynamic_cast<NCCLTask*>(task.get());
nccl_task->barrier_ = true;
nccl_task->SetBlockCPUInWait();
return task;
}
@@ -250,6 +251,10 @@ void ProcessGroupNCCL::BroadcastUniqueNCCLID(ncclUniqueId* nccl_id) {
void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
const std::string& place_key) {
if (place_to_comm_ctx_.size() > 0) {
VLOG(3) << "Warning: Tensors from multiple devices are not supported yet.";
}
ncclUniqueId nccl_id;
if (rank_ == 0) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id));
@@ -260,7 +265,6 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
<< ", place: " << place_key
<< ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);
calc_event_ = std::make_shared<platform::DeviceEvent>(place);
auto* calc_ctx = static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(place));
auto comm_ctx = std::make_unique<phi::GPUContext>(place);
@@ -269,20 +273,23 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
&nccl_comm, GetSize(), nccl_id, GetRank()));
comm_ctx->set_nccl_comm(nccl_comm);
place_to_calc_ctx_[place_key] = calc_ctx;
place_to_comm_ctx_[place_key] = std::move(comm_ctx);
place_to_calc_event_.emplace(place_key, place);
place_to_calc_ctx_.emplace(place_key, calc_ctx);
place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx));
// TODO(sunyilun): for compatibility, will be removed later
places_to_ctx_[place_key] = {place_to_comm_ctx_[place_key].get()};
std::vector<phi::GPUContext*> comm_ctx_wrapper{
place_to_comm_ctx_[place_key].get()};
places_to_ctx_.emplace(place_key, comm_ctx_wrapper);
}
void ProcessGroupNCCL::SyncCalcStream(
const Place& place, const std::shared_ptr<platform::DeviceEvent>& event) {
void ProcessGroupNCCL::SyncCalcStream(const Place& place) {
const std::string& key = GetKeyFromPlace(place);
const auto* calc_ctx = place_to_calc_ctx_[key];
const auto* comm_ctx = place_to_comm_ctx_[key].get();
event->Record(calc_ctx);
event->Wait(platform::Place2DeviceType(place), comm_ctx);
auto& calc_event = place_to_calc_event_.at(key);
const auto* calc_ctx = place_to_calc_ctx_.at(key);
const auto* comm_ctx = place_to_comm_ctx_.at(key).get();
calc_event.Record(calc_ctx);
calc_event.Wait(platform::Place2DeviceType(place), comm_ctx);
}
template <typename Fn>
@@ -296,26 +303,29 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
const auto& place = in_tensor.place();
const auto& key = GetKeyFromPlace(place);
if (!calc_event_) {
platform::CUDADeviceGuard cuda_guard(place);
if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
CreateNCCLEnvCache(place, key);
}
if (!use_calc_stream) {
SyncCalcStream(place, calc_event_);
SyncCalcStream(place);
}
auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream);
const auto* calc_ctx = place_to_calc_ctx_[key];
const auto& comm_ctx = place_to_comm_ctx_[key];
const auto* calc_ctx = place_to_calc_ctx_.at(key);
const auto& comm_ctx = place_to_comm_ctx_.at(key);
auto nccl_comm = comm_ctx->nccl_comm();
auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();
fn(out_tensor, in_tensor, comm_ctx->nccl_comm(), nccl_stream);
fn(out_tensor, in_tensor, nccl_comm, nccl_stream);
if (!use_calc_stream) {
if (FLAGS_use_stream_safe_cuda_allocator) {
memory::RecordStream(in_tensor.Holder(), nccl_stream);
}
task->comm_event_.Record(comm_ctx.get());
task->UpdateWaitChain(*comm_ctx);
}
return task;
@@ -352,13 +362,13 @@ void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>* split_sizes,
// TODO(sunyilun): methods below will be removed later
void SyncDefaultStream(const std::vector<Place>& places,
const std::shared_ptr<platform::DeviceEvent>& nccl_event,
platform::DeviceEvent& nccl_event, // NOLINT
std::vector<phi::GPUContext*>& dev_ctx) { // NOLINT
for (size_t i = 0; i < places.size(); ++i) {
auto* default_ctx = static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(places[i]));
nccl_event->Record(default_ctx);
nccl_event->Wait(platform::Place2DeviceType(places[i]), dev_ctx[i]);
nccl_event.Record(default_ctx);
nccl_event.Wait(platform::Place2DeviceType(places[i]), dev_ctx[i]);
}
}
@@ -389,7 +399,7 @@ ProcessGroupNCCL::NCCLTask::NCCLTask(
const std::vector<phi::DenseTensor>& inputs)
: TaskStream(rank, inputs, CommType),
comm_event_(places[0]),
place_(places[0]) {}
task_place_(places[0]) {}
ProcessGroupNCCL::NCCLTask::NCCLTask(
const std::vector<Place>& places,
@@ -400,7 +410,7 @@ ProcessGroupNCCL::NCCLTask::NCCLTask(
bool use_calc_stream)
: TaskStream(rank, inputs, comm_type, sync_op, use_calc_stream),
comm_event_(places[0]),
place_(places[0]) {}
task_place_(places[0]) {}
// create NCCLManager cache for places_key
void ProcessGroupNCCL::CreateNCCLManagerCache(
@@ -437,17 +447,18 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
NCCLCHECK(platform::dynload::ncclCommInitRank(
&nccl_comm, GetSize(), nccl_id, GetRank()));
dev_ctx[i]->set_nccl_comm(nccl_comm);
dev_ctx_raw[i] = dev_ctx[i].get();
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
calc_event_ = std::make_shared<platform::DeviceEvent>(places[0]);
// TODO(sunyilun): for compatibility, will be removed later
place_to_calc_ctx_[places_key] = static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(places[0]));
place_to_comm_ctx_[places_key] = std::move(dev_ctx[0]);
place_to_calc_event_.emplace(places_key, places[0]);
place_to_calc_ctx_.emplace(
places_key,
static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(places[0])));
place_to_comm_ctx_.emplace(places_key, std::move(dev_ctx[0]));
// These caches will be useful to process sync/wait/communicate
places_to_ctx_.emplace(places_key, std::move(dev_ctx_raw));
@@ -466,13 +477,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
{
std::lock_guard<std::mutex> lock(mutex_);
if (!calc_event_) {
if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
CreateNCCLManagerCache(key, places);
}
}
if (!use_calc_stream) {
SyncDefaultStream(places, calc_event_, places_to_ctx_[key]);
SyncDefaultStream(
places, place_to_calc_event_.at(key), places_to_ctx_.at(key));
}
auto task =
@@ -492,12 +504,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
platform::DeviceContextPool::Instance().Get(places[i]))
->stream();
} else {
nccl_stream = places_to_ctx_[key][i]->stream();
nccl_stream = places_to_ctx_.at(key)[i]->stream();
}
fn(inputs[i],
outputs[i],
places_to_ctx_[key][i]->nccl_comm(),
places_to_ctx_.at(key)[i]->nccl_comm(),
nccl_stream);
}
}
@@ -513,7 +525,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
platform::DeviceContextPool::Instance().Get(places[i]))
->stream();
} else {
nccl_stream = places_to_ctx_[key][i]->stream();
nccl_stream = places_to_ctx_.at(key)[i]->stream();
}
memory::RecordStream(inputs[i].Holder(), nccl_stream);
@@ -524,7 +536,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
if (!use_calc_stream) {
for (size_t i = 0; i < inputs.size(); ++i) {
cuda_guard.SetDevice(places[i]);
task->comm_event_.Record(places_to_ctx_[key][i]);
task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
}
}
@@ -542,12 +554,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
{
std::lock_guard<std::mutex> lock(mutex_);
if (!calc_event_) {
if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
CreateNCCLManagerCache(key, places);
}
}
SyncDefaultStream(places, calc_event_, places_to_ctx_[key]);
SyncDefaultStream(
places, place_to_calc_event_.at(key), places_to_ctx_.at(key));
auto task = CreateTask(places, rank_, op_type, inputs);
@@ -558,10 +571,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
platform::NCCLGroupGuard nccl_guard;
for (size_t i = 0; i < inputs.size(); ++i) {
cuda_guard.SetDevice(places[i]);
const auto& nccl_stream = places_to_ctx_[key][i]->stream();
const auto& nccl_stream = places_to_ctx_.at(key)[i]->stream();
fn(inputs[i],
outputs[i],
places_to_ctx_[key][i]->nccl_comm(),
places_to_ctx_.at(key)[i]->nccl_comm(),
nccl_stream);
}
}
@@ -570,13 +583,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
for (size_t i = 0; i < inputs.size(); ++i) {
cuda_guard.SetDevice(places[i]);
memory::RecordStream(inputs[i].Holder(),
places_to_ctx_[key][i]->stream());
places_to_ctx_.at(key)[i]->stream());
}
}
for (size_t i = 0; i < inputs.size(); ++i) {
cuda_guard.SetDevice(places[i]);
task->comm_event_.Record(places_to_ctx_[key][i]);
task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
}
return task;
}
@@ -592,26 +605,27 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
{
std::lock_guard<std::mutex> lock(mutex_);
if (!calc_event_) {
if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
CreateNCCLManagerCache(key, places);
}
}
SyncDefaultStream(places, calc_event_, places_to_ctx_[key]);
SyncDefaultStream(
places, place_to_calc_event_.at(key), places_to_ctx_.at(key));
// construct uninitialize guard for device
platform::CUDADeviceGuard cuda_guard;
if (FLAGS_use_stream_safe_cuda_allocator) {
cuda_guard.SetDevice(places[0]);
memory::RecordStream(in->Holder(), places_to_ctx_[key][0]->stream());
memory::RecordStream(in->Holder(), places_to_ctx_.at(key)[0]->stream());
}
{
platform::NCCLGroupGuard nccl_guard;
cuda_guard.SetDevice(places[0]);
const auto& nccl_stream = places_to_ctx_[key][0]->stream();
fn(in, out, places_to_ctx_[key][0]->nccl_comm(), nccl_stream);
const auto& nccl_stream = places_to_ctx_.at(key)[0]->stream();
fn(in, out, places_to_ctx_.at(key)[0]->nccl_comm(), nccl_stream);
}
cuda_guard.SetDevice(places[0]);
@@ -630,13 +644,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
{
std::lock_guard<std::mutex> lock(mutex_);
if (!calc_event_) {
if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
CreateNCCLManagerCache(key, places);
}
}
if (!use_calc_stream) {
SyncDefaultStream(places, calc_event_, places_to_ctx_[key]);
SyncDefaultStream(
places, place_to_calc_event_.at(key), places_to_ctx_.at(key));
}
auto task =
@@ -655,10 +670,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
platform::DeviceContextPool::Instance().Get(places[i]))
->stream();
} else {
nccl_stream = places_to_ctx_[key][i]->stream();
nccl_stream = places_to_ctx_.at(key)[i]->stream();
}
fn(tensors[i],
places_to_ctx_[key][i]->nccl_comm(),
places_to_ctx_.at(key)[i]->nccl_comm(),
nccl_stream,
dst_rank);
}
@@ -674,7 +689,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
platform::DeviceContextPool::Instance().Get(places[i]))
->stream();
} else {
nccl_stream = places_to_ctx_[key][i]->stream();
nccl_stream = places_to_ctx_.at(key)[i]->stream();
}
memory::RecordStream(tensors[i].Holder(), nccl_stream);
}
@@ -683,7 +698,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
if (!use_calc_stream) {
for (size_t i = 0; i < tensors.size(); ++i) {
cuda_guard.SetDevice(places[i]);
task->comm_event_.Record(places_to_ctx_[key][i]);
task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
}
}
@@ -701,12 +716,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
{
std::lock_guard<std::mutex> lock(mutex_);
if (!calc_event_) {
if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
CreateNCCLManagerCache(key, places);
}
}
SyncDefaultStream(places, calc_event_, places_to_ctx_[key]);
SyncDefaultStream(
places, place_to_calc_event_.at(key), places_to_ctx_.at(key));
auto task = CreateTask(places, rank_, op_type, tensors);
@@ -717,9 +733,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
platform::NCCLGroupGuard nccl_guard;
for (size_t i = 0; i < tensors.size(); ++i) {
cuda_guard.SetDevice(places[i]);
const auto& nccl_stream = places_to_ctx_[key][i]->stream();
const auto& nccl_stream = places_to_ctx_.at(key)[i]->stream();
fn(tensors[i],
places_to_ctx_[key][i]->nccl_comm(),
places_to_ctx_.at(key)[i]->nccl_comm(),
nccl_stream,
dst_rank);
}
@@ -729,13 +745,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
for (size_t i = 0; i < tensors.size(); ++i) {
cuda_guard.SetDevice(places[i]);
memory::RecordStream(tensors[i].Holder(),
places_to_ctx_[key][i]->stream());
places_to_ctx_.at(key)[i]->stream());
}
}
for (size_t i = 0; i < tensors.size(); ++i) {
cuda_guard.SetDevice(places[i]);
task->comm_event_.Record(places_to_ctx_[key][i]);
task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
}
return task;
}
@@ -1608,49 +1624,5 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::_ReduceScatterBase(
phi::DenseTensor& out_tensor,
phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts) {
// auto tensor = out_tensors.back();
PADDLE_ENFORCE_EQ(
out_tensor.dtype(),
in_tensor.dtype(),
platform::errors::InvalidArgument(
"Input tensor and output tensor should be same dtype."));
PADDLE_ENFORCE_EQ(
out_tensor.numel() * size_,
in_tensor.numel(),
platform::errors::InvalidArgument("input tensor must be the same size as "
"output tensor size times world_size"));
auto inputs = std::vector<phi::DenseTensor>{in_tensor};
auto outputs = std::vector<phi::DenseTensor>{out_tensor};
return Collective(
inputs,
outputs,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
if (FLAGS_use_stream_safe_cuda_allocator) {
platform::CUDADeviceGuard cuda_guard;
cuda_guard.SetDevice(output.place());
memory::RecordStream(output.Holder(), stream);
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter(
input.data(),
output.data(),
output.numel(),
platform::ToNCCLDataType(input.dtype()),
ToNCCLRedType(opts.reduce_op),
comm,
stream));
},
CommType::REDUCE_SCATTER);
}
} // namespace distributed
} // namespace paddle
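A small side note on a pattern repeated throughout the .cc hunks above: cache insertions move from operator[] assignment to emplace(), and lookups move from operator[] to at(). emplace() constructs the mapped value in place (the calc event, for instance, is built directly from a Place), and at() throws on a missing key instead of silently default-inserting one, which is the behavior you want for a read-only lookup of a cache that should already be populated. A tiny self-contained illustration of the lookup semantics; the map, key, and value type here are placeholders, not Paddle types:

#include <cassert>
#include <stdexcept>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, int> cache;  // stand-in for the place_to_* caches
  cache.emplace("gpu:0", 42);                  // construct the value in place

  assert(cache.at("gpu:0") == 42);             // read-only lookup, no insertion
  bool threw = false;
  try {
    cache.at("gpu:1");                         // unknown key: throws std::out_of_range...
  } catch (const std::out_of_range&) {
    threw = true;                              // ...instead of inserting a default value
  }                                            // the way operator[] would
  assert(threw);
  return 0;
}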
@@ -15,7 +15,6 @@
#pragma once
#include <chrono>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
@@ -61,6 +60,9 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
void Synchronize() override;
void UpdateWaitChain(const phi::DeviceContext& ctx) override;
bool IsBlockCPUInWait() const { return block_cpu_in_wait_; }
void SetBlockCPUInWait() { block_cpu_in_wait_ = true; }
// TODO(sunyilun): methods below will be removed later
NCCLTask(const std::vector<Place>& places,
int rank,
@@ -73,12 +75,10 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
bool sync_op,
bool use_calc_stream);
public:
bool barrier_{false};
platform::DeviceEvent comm_event_; // event on comm stream
private:
Place place_;
bool block_cpu_in_wait_{false};
platform::DeviceEvent comm_event_; // event on comm stream
Place task_place_;
};
public:
@@ -253,11 +253,6 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> _ReduceScatterBase(
phi::DenseTensor&, // NOLINT
phi::DenseTensor&, // NOLINT
const ReduceScatterOptions&) override;
private:
std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(const Place& place,
int rank,
@@ -278,8 +273,7 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
bool sync_op,
bool use_calc_stream);
void SyncCalcStream(const Place& place,
const std::shared_ptr<platform::DeviceEvent>& event);
void SyncCalcStream(const Place& place);
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
@@ -342,7 +336,8 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
private:
std::shared_ptr<Store> store_;
std::shared_ptr<platform::DeviceEvent> calc_event_; // event on calc stream
std::unordered_map<std::string, platform::DeviceEvent>
place_to_calc_event_; // event on calc stream
std::unordered_map<std::string, phi::GPUContext*> place_to_calc_ctx_;
std::unordered_map<std::string, std::unique_ptr<phi::GPUContext>>
place_to_comm_ctx_;
......
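For the NCCLTask changes shown just above, here is a minimal sketch of how the renamed flag is meant to flow between Barrier() and Wait(); only IsBlockCPUInWait(), SetBlockCPUInWait(), and block_cpu_in_wait_ come from the diff, while TaskSketch and BarrierSketch are simplified stand-ins for the real task and barrier path:

#include <memory>

class TaskSketch {
 public:
  bool IsBlockCPUInWait() const { return block_cpu_in_wait_; }
  void SetBlockCPUInWait() { block_cpu_in_wait_ = true; }

  bool Wait() {
    // ... first wait for the comm-stream event, as the real NCCLTask does ...
    if (IsBlockCPUInWait()) {
      // Barrier semantics: the real implementation also blocks the host here
      // (cudaDeviceSynchronize under PADDLE_WITH_CUDA).
    }
    return true;
  }

 private:
  bool block_cpu_in_wait_{false};  // formerly the public `barrier_` member
};

// Counterpart of ProcessGroupNCCL::Barrier(): flag the task so that a later
// Wait() blocks the CPU as well.
std::shared_ptr<TaskSketch> BarrierSketch() {
  auto task = std::make_shared<TaskSketch>();
  task->SetBlockCPUInWait();
  return task;
}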
@@ -761,27 +761,6 @@ void BindDistributed(py::module *m) {
py::arg("in"),
py::arg("out"),
py::arg("src"),
py::call_guard<py::gil_scoped_release>())
.def(
"_reduce_scatter_base",
[](distributed::ProcessGroup &self,
py::handle py_out_tensor,
py::handle py_in_tensor,
distributed::ReduceOp op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
distributed::ReduceScatterOptions opts;
opts.reduce_op = op;
auto dense_out = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
auto dense_in = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
return self._ReduceScatterBase(*dense_out, *dense_in, opts);
},
py::arg("out_tensor"),
py::arg("in_tensor"),
py::arg("op") = distributed::ReduceOp::SUM,
py::call_guard<py::gil_scoped_release>());
auto ProcessGroupStream =
......