未验证 提交 e1a1c354 编写于 作者: W Wen Sun 提交者: GitHub

Refactor collective communication all_gather, all_reduce, broadcast & barrier C++ API (#47481)

上级 2337e609
......@@ -28,6 +28,7 @@ if(WITH_NCCL OR WITH_RCCL)
enforce
collective_helper
device_context
${DEVICE_EVENT_LIBS}
dense_tensor)
if(WITH_DISTRIBUTE AND WITH_PSCORE)
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
......
......@@ -40,6 +40,8 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places) {
return placeList;
}
// Builds the lookup key for a single device: the key is exactly the string
// returned by place.DebugString().
std::string GetKeyFromPlace(const Place& place) {
  std::string place_key = place.DebugString();
  return place_key;
}
bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors) {
return std::all_of(
tensors.cbegin(), tensors.cend(), [&](const phi::DenseTensor& t) {
......
......@@ -25,6 +25,8 @@ using Place = paddle::platform::Place;
std::vector<Place> GetPlaceList(const std::vector<phi::DenseTensor>& tensors);
// Get the deviceList String from the list of devices
std::string GetKeyFromPlaces(const std::vector<Place>& places);
// Get the device string from one device
std::string GetKeyFromPlace(const Place& place);
bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors);
......
......@@ -59,204 +59,6 @@ namespace distributed {
} \
} while (0)
// NOTE(shenliang03): EventManager is a movable, non-copyable CudaEvent wrapper.
// EventManager is different from paddle::platform::CudaEvent.
// It uses lazy initialization: the event is only created when the
// Record() method is called for the first time. It also tracks
// device information to ensure that the recorded stream and the event
// are on the same device.
// Movable, non-copyable wrapper around a raw GPU event (gpuEvent_t).
//
// The underlying event is created lazily by the first Record() call and is
// destroyed in the destructor. The wrapper remembers the device index the
// event was created on and enforces that later Record()/Block() calls use a
// context on that same device.
class EventManager {
 public:
  EventManager() {}
  explicit EventManager(unsigned int flags) : flags_{flags} {}

  // Destroys the event, if one was created, with its device made current.
  // The destroy call's status is deliberately ignored: destructors must not
  // throw.
  ~EventManager() {
    if (is_created_) {
      platform::CUDADeviceGuard guard(device_index_);
#ifdef PADDLE_WITH_HIP
      hipEventDestroy(event_);
#else
      cudaEventDestroy(event_);
#endif
    }
  }

  EventManager(const EventManager&) = delete;
  EventManager& operator=(const EventManager&) = delete;

  // Move construction: this object starts from the default (not-created)
  // state and exchanges it with `other`, leaving `other` empty.
  EventManager(EventManager&& other) noexcept {
    std::swap(flags_, other.flags_);
    std::swap(is_created_, other.is_created_);
    std::swap(device_index_, other.device_index_);
    std::swap(event_, other.event_);
  }

  // Move assignment swaps state; whatever this object held before is
  // released when `other` is destroyed.
  EventManager& operator=(EventManager&& other) noexcept {
    std::swap(flags_, other.flags_);
    std::swap(is_created_, other.is_created_);
    std::swap(device_index_, other.device_index_);
    std::swap(event_, other.event_);
    return *this;
  }

  // True once Record() has created the underlying event.
  bool IsCreated() const { return is_created_; }

  // Index of the device the event was created on (0 until created).
  // Returned as int: the previous `bool` return type collapsed every
  // non-zero device index to 1.
  int DeviceId() const { return device_index_; }

  gpuEvent_t GetRawCudaEvent() const { return event_; }

  // Records the event on ctx's stream, creating the event on ctx's device
  // first if needed. Raises PreconditionNotMet when ctx is on a different
  // device than the one the event was created on.
  void Record(const phi::GPUContext& ctx) {
    auto device_index = ctx.GetPlace().device;
    if (!is_created_) {
      CreateEvent(device_index);
    }
    PADDLE_ENFORCE_EQ(device_index,
                      device_index_,
                      platform::errors::PreconditionNotMet(
                          "phi::GPUContext's device %d does not match "
                          "Event's device %d",
                          device_index,
                          device_index_));

    platform::CUDADeviceGuard guard(device_index_);
#ifdef PADDLE_WITH_CUDA
    PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event_, ctx.stream()));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(hipEventRecord(event_, ctx.stream()));
#endif
  }

  // Returns true when the work captured by the event has finished and false
  // when it is still pending; any other runtime status is raised through
  // PADDLE_ENFORCE_GPU_SUCCESS.
  bool Query() const {
    // An event that was never created has nothing pending. Treat it as
    // complete instead of querying a null handle, mirroring the
    // is_created_ guards in Synchronize() and Block().
    if (!is_created_) {
      return true;
    }
#ifdef PADDLE_WITH_HIP
    gpuError_t err = hipEventQuery(event_);
    if (err == hipSuccess) {
      return true;
    }
    if (err == hipErrorNotReady) {
      return false;
    }
#else
    gpuError_t err = cudaEventQuery(event_);
    if (err == cudaSuccess) {
      return true;
    }
    if (err == cudaErrorNotReady) {
      return false;
    }
#endif
    PADDLE_ENFORCE_GPU_SUCCESS(err);
    return false;
  }

  // Blocks the calling host thread until the event completes. No-op when the
  // event was never created.
  void Synchronize() const {
    if (is_created_) {
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(hipEventSynchronize(event_));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(event_));
#endif
    }
  }

  // Makes ctx's stream wait for the event. No-op when the event was never
  // created. Raises PreconditionNotMet on a device mismatch.
  void Block(const phi::GPUContext& ctx) const {
    if (is_created_) {
      auto device_index = ctx.GetPlace().device;
      PADDLE_ENFORCE_EQ(device_index,
                        device_index_,
                        platform::errors::PreconditionNotMet(
                            "phi::GPUContext's device %d does not match "
                            "Event's device %d",
                            device_index,
                            device_index_));
      platform::CUDADeviceGuard guard(device_index_);

#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(hipStreamWaitEvent(ctx.stream(), event_, 0));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamWaitEvent(ctx.stream(), event_, 0));
#endif
    }
  }

 private:
  // Creates the event on `device_index` with the configured flags and marks
  // the wrapper as created only after the runtime call succeeded.
  void CreateEvent(int device_index) {
    device_index_ = static_cast<int8_t>(device_index);
    platform::CUDADeviceGuard guard(device_index);

#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_GPU_SUCCESS(hipEventCreateWithFlags(&event_, flags_));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreateWithFlags(&event_, flags_));
#endif

    is_created_ = true;
  }

#ifdef PADDLE_WITH_HIP
  unsigned int flags_ = hipEventDefault;
#else
  unsigned int flags_ = cudaEventDefault;
#endif

  bool is_created_{false};
  gpuEvent_t event_{};
  int8_t device_index_{0};
};
// NOTE(shenliang03): NCCLCommManager is more lightweight than
// platform::NCCLComm
class NCCLCommManager {
public:
explicit NCCLCommManager(ncclComm_t ncclComm) : nccl_comm_(ncclComm) {}
NCCLCommManager() : NCCLCommManager(nullptr) {}
~NCCLCommManager() noexcept {
std::unique_lock<std::mutex> lock(mutex_);
if (nccl_comm_) {
platform::dynload::ncclCommDestroy(nccl_comm_);
}
}
static std::shared_ptr<NCCLCommManager> Create(int num_ranks,
int rank,
ncclUniqueId comm_id) {
auto nccl_manager = std::make_shared<NCCLCommManager>();
NCCLCHECK(platform::dynload::ncclCommInitRank(
&(nccl_manager->nccl_comm_), num_ranks, comm_id, rank));
nccl_manager->nccl_id_ = comm_id;
nccl_manager->rank_ = rank;
return nccl_manager;
}
ncclUniqueId GetNcclId() const {
std::unique_lock<std::mutex> lock(mutex_);
return nccl_id_;
}
ncclComm_t GetNcclComm() const {
std::unique_lock<std::mutex> lock(mutex_);
return nccl_comm_;
}
NCCLCommManager(const NCCLCommManager&) = delete;
NCCLCommManager& operator=(const NCCLCommManager&) = delete;
NCCLCommManager& operator=(NCCLCommManager&& other) = delete;
NCCLCommManager(NCCLCommManager&& other) {
std::unique_lock<std::mutex> lock(other.mutex_);
std::swap(nccl_comm_, other.nccl_comm_);
}
protected:
ncclComm_t nccl_comm_;
ncclUniqueId nccl_id_;
int rank_;
mutable std::mutex mutex_;
};
ncclRedOp_t ToNCCLRedType(ReduceOp reduction);
std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID);
......
......@@ -17,15 +17,7 @@
namespace paddle {
namespace distributed {
ProcessGroup::Task::Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type)
: rank_(rank), comm_type_(comm_type) {}
ProcessGroup::Task::Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type,
bool sync_op)
ProcessGroup::Task::Task(int rank, CommType comm_type, bool sync_op)
: rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {}
ProcessGroup::Task::~Task() = default;
......@@ -62,5 +54,17 @@ ProcessGroup::ProcessGroup(int rank, int size, int gid)
}
}
// TODO(sunyilun): methods below will be removed later
ProcessGroup::Task::Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type)
: rank_(rank), comm_type_(comm_type) {}
ProcessGroup::Task::Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type,
bool sync_op)
: rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {}
} // namespace distributed
} // namespace paddle
......@@ -54,13 +54,7 @@ class ProcessGroup {
public:
class Task {
public:
Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type);
Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type,
bool sync_op);
Task(int rank, CommType comm_type, bool sync_op);
virtual ~Task();
virtual bool IsCompleted();
......@@ -69,6 +63,15 @@ class ProcessGroup {
virtual void UpdateWaitChain(const phi::DeviceContext& ctx);
bool IsSync() const { return sync_op_; }
// TODO(sunyilun): methods below will be removed later
Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type);
Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type,
bool sync_op);
protected:
const int rank_;
CommType comm_type_{CommType::UNKNOWN};
......@@ -79,6 +82,7 @@ class ProcessGroup {
bool sync_op_{true};
};
public:
explicit ProcessGroup(int rank,
int size,
const platform::Place& place,
......@@ -93,12 +97,48 @@ class ProcessGroup {
int GetSize() const { return size_; }
virtual std::string GetBackendName() const = 0;
// Returns the device context associated with `place`.
// This default implementation has no context to return: its whole body is a
// PADDLE_THROW of InvalidArgument naming the backend via GetBackendName().
virtual const phi::DeviceContext& GetDeviceContext(const Place& place) const {
PADDLE_THROW(platform::errors::InvalidArgument(
"Does not support to get device_context from ProcessGroup%s.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support all_gather with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support all_reduce with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support barrier", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support broadcast with sync_op flag",
GetBackendName()));
}
// TODO(liyurui): This API will be moved later
virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
......@@ -118,6 +158,7 @@ class ProcessGroup {
GetBackendName()));
}
// TODO(sunyilun): methods below will be removed later
virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
......@@ -136,12 +177,6 @@ class ProcessGroup {
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support barrier", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>&, int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
......
......@@ -229,6 +229,17 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
}
};
// TODO(sunyilun): for compatibility, will be updated later
// Single-tensor broadcast adapter: wraps the input and output tensors in
// one-element vectors and forwards to the vector-based Broadcast overload.
// NOTE(review): `sync_op` is not forwarded; the vector overload is always
// called with `true`. Confirm this is intended for the Gloo backend.
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const BroadcastOptions& opts,
    bool sync_op) {
  std::vector<phi::DenseTensor> wrapped_inputs{in_tensor};
  std::vector<phi::DenseTensor> wrapped_outputs{*out_tensor};
  auto task = Broadcast(wrapped_inputs, wrapped_outputs, opts, true);
  return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast(
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs,
......
......@@ -108,6 +108,13 @@ class ProcessGroupGloo : public ProcessGroup {
~ProcessGroupGloo() = default;
std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) override;
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs,
......
......@@ -15,12 +15,8 @@
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#include "paddle/fluid/distributed/collective/Common.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"
DECLARE_bool(nccl_blocking_wait);
DECLARE_bool(use_stream_safe_cuda_allocator);
......@@ -30,89 +26,299 @@ constexpr int64_t kWaitBlockTImeout = 10;
namespace paddle {
namespace distributed {
void SyncDefaultStream(
const std::vector<Place>& places,
std::vector<EventManager>& ncclEvents, // NOLINT
std::vector<std::unique_ptr<phi::GPUContext>>& dev_ctx) { // NOLINT
for (size_t i = 0; i < places.size(); ++i) {
auto* default_ctx = static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(places[i]));
ncclEvents[i].Record(*default_ctx);
ncclEvents[i].Block(*dev_ctx[i]);
// Builds a task bound to a single place.
// rank, comm_type, sync_op and use_calc_stream are forwarded to TaskStream;
// `place` is used both to construct comm_event_ and to initialize place_.
ProcessGroupNCCL::NCCLTask::NCCLTask(const Place& place,
int rank,
CommType comm_type,
bool sync_op,
bool use_calc_stream)
: TaskStream(rank, comm_type, sync_op, use_calc_stream),
comm_event_(place),
place_(place) {}
// No custom teardown: members are released by their own destructors.
ProcessGroupNCCL::NCCLTask::~NCCLTask() = default;
// Reports completion by querying the task's communication event.
bool ProcessGroupNCCL::NCCLTask::IsCompleted() {
  const bool finished = comm_event_.Query();
  return finished;
}
// Records the task's communication event on the given device context.
void ProcessGroupNCCL::NCCLTask::UpdateWaitChain(
    const phi::DeviceContext& ctx) {
  const phi::DeviceContext* ctx_ptr = &ctx;
  comm_event_.Record(ctx_ptr);
}
// TODO(sheniang03): Add timeout for wait, now timeout unused
// Waits for the communication tracked by this task. Always returns true;
// `timeout` is currently unused.
//
// Steps, in order:
//   1. If the task ran on the calc stream, log at VLOG(3) and return.
//   2. Call comm_event_.Wait() with the pooled device context for place_.
//   3. If FLAGS_nccl_blocking_wait is set, poll IsCompleted() on the host,
//      sleeping kWaitBlockTImeout milliseconds between polls.
//   4. If this task is a barrier, synchronize the whole device.
bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
  // Warning here when use calc stream but also invoke waiting explicitly.
  if (UseCalcStream()) {
    VLOG(3) << "Warning: The communication is on calc stream, wait here is "
               "useless.";
    return true;
  }

  const auto* pooled_calc_ctx =
      platform::DeviceContextPool::Instance().Get(place_);
  comm_event_.Wait(platform::Place2DeviceType(place_), pooled_calc_ctx);

  if (FLAGS_nccl_blocking_wait) {
    // NOTE(shenliang03): It will block host for sync
    const std::chrono::milliseconds poll_interval(kWaitBlockTImeout);
    while (!IsCompleted()) {
      std::this_thread::sleep_for(poll_interval);
    }
  }

  if (barrier_) {
    // If we use the work to do barrier, we should block cpu
#ifdef PADDLE_WITH_CUDA
    PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
    PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif
  }

  return true;
}
std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
std::vector<Place> places,
int rank,
CommType comm_type,
const std::vector<phi::DenseTensor>& inputs) {
return std::make_shared<ProcessGroupNCCL::NCCLTask>(
places, rank, comm_type, inputs);
// Same as Wait
// Delegates to Wait(kWaitTimeout); the boolean result is discarded.
void ProcessGroupNCCL::NCCLTask::Synchronize() {
  Wait(kWaitTimeout);
}
// Constructs the NCCL process group.
// rank, size, place and gid are forwarded to ProcessGroupStream and the
// store handle is kept in store_. The body then calls
// platform::SetDeviceId with place_.device.
ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
int rank,
int size,
const platform::Place& place,
int gid)
: ProcessGroupStream(rank, size, place, gid), store_(store) {
platform::SetDeviceId(place_.device);
}
// Opens an NCCL group by calling ncclGroupStart; the returned status is
// checked with PADDLE_ENFORCE_GPU_SUCCESS.
void ProcessGroupNCCL::GroupStart() {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
}
// Closes an NCCL group by calling ncclGroupEnd; the returned status is
// checked with PADDLE_ENFORCE_GPU_SUCCESS.
void ProcessGroupNCCL::GroupEnd() {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
}
// Convenience overload: returns the context for `place` by calling the
// two-argument overload with use_calc_stream = false.
const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
const Place& place) const {
return GetDeviceContext(place, /*use_calc_stream*/ false);
}
// Returns the cached device context for `place`.
//
// use_calc_stream selects which cache is consulted: place_to_calc_ctx_ when
// true, place_to_comm_ctx_ when false. In both cases a missing entry raises
// NotFound. Previously the calc-stream branch dereferenced the result of
// find() without comparing it to end(), which is undefined behavior when no
// context has been cached for `place` yet.
const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
    const Place& place, bool use_calc_stream) const {
  const std::string& key = GetKeyFromPlace(place);
  if (use_calc_stream) {
    const auto& iter = place_to_calc_ctx_.find(key);
    PADDLE_ENFORCE_NE(
        iter,
        place_to_calc_ctx_.end(),
        platform::errors::NotFound(
            "Cannot find the device context in this process group."));
    return *iter->second;
  } else {
    const auto& iter = place_to_comm_ctx_.find(key);
    PADDLE_ENFORCE_NE(
        iter,
        place_to_comm_ctx_.end(),
        platform::errors::NotFound(
            "Cannot find the device context in this process group."));
    return *iter->second;
  }
}
// Returns the NCCL communicator attached to the cached communication
// context for `place`. Raises NotFound when no context has been cached for
// that place. The error text previously misspelled "communicator".
ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const {
  const std::string& key = GetKeyFromPlace(place);
  const auto& iter = place_to_comm_ctx_.find(key);
  PADDLE_ENFORCE_NE(
      iter,
      place_to_comm_ctx_.end(),
      platform::errors::NotFound(
          "Cannot find the NCCL communicator in this process group."));
  return iter->second->nccl_comm();
}
// Single-tensor all_gather. Builds the NCCL call as a callable and hands it
// to Collective() together with the tensors, the ALLGATHER tag and the
// sync/stream flags. The element count and data type passed to
// ncclAllGather both come from the input tensor.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    bool sync_op,
    bool use_calc_stream) {
  auto all_gather_fn = [&](phi::DenseTensor* output,
                           const phi::DenseTensor& input,
                           ncclComm_t comm,
                           const gpuStream_t& stream) {
    return platform::dynload::ncclAllGather(
        input.data(),
        output->data(),
        input.numel(),
        platform::ToNCCLDataType(input.dtype()),
        comm,
        stream);
  };
  return Collective(out_tensor,
                    in_tensor,
                    all_gather_fn,
                    CommType::ALLGATHER,
                    sync_op,
                    use_calc_stream);
}
// Single-tensor all_reduce. Builds the NCCL call as a callable and hands it
// to Collective() with the ALLREDUCE tag. The reduction operator is taken
// from opts.reduce_op via ToNCCLRedType; count and data type come from the
// input tensor.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const AllreduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  auto all_reduce_fn = [&](phi::DenseTensor* output,
                           const phi::DenseTensor& input,
                           ncclComm_t comm,
                           const gpuStream_t& stream) {
    return platform::dynload::ncclAllReduce(
        input.data(),
        output->data(),
        input.numel(),
        platform::ToNCCLDataType(input.type()),
        ToNCCLRedType(opts.reduce_op),
        comm,
        stream);
  };
  return Collective(out_tensor,
                    in_tensor,
                    all_reduce_fn,
                    CommType::ALLREDUCE,
                    sync_op,
                    use_calc_stream);
}
// Barrier implemented as an all_reduce over a one-element FLOAT32 tensor
// allocated on place_. The all_reduce is issued with sync_op = true and
// use_calc_stream = false, and the resulting task is flagged as a barrier
// (barrier_ = true) before being returned. `opts` is not read here.
// NOTE(review): the dynamic_cast result is dereferenced without a null
// check; this relies on AllReduce returning an NCCLTask.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
    const BarrierOptions& opts) {
  std::unique_ptr<phi::Allocator> allocator(
      new paddle::experimental::DefaultAllocator(place_));
  phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1});
  phi::DenseTensor barrier_tensor{allocator.get(), meta};

  auto task = AllReduce(&barrier_tensor,
                        barrier_tensor,
                        {},
                        /*sync_op*/ true,
                        /*use_calc_stream*/ false);
  dynamic_cast<NCCLTask*>(task.get())->barrier_ = true;
  return task;
}
// Single-tensor broadcast. Builds the NCCL call as a callable and hands it
// to Collective() with the BROADCAST tag. The root passed to ncclBroadcast
// is opts.source_rank + opts.source_root; count and data type come from the
// input tensor.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const BroadcastOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  auto broadcast_fn = [&](phi::DenseTensor* output,
                          const phi::DenseTensor& input,
                          ncclComm_t comm,
                          const gpuStream_t& stream) {
    int root = opts.source_rank + opts.source_root;
    return platform::dynload::ncclBroadcast(
        input.data(),
        output->data(),
        input.numel(),
        platform::ToNCCLDataType(input.type()),
        root,
        comm,
        stream);
  };
  return Collective(out_tensor,
                    in_tensor,
                    broadcast_fn,
                    CommType::BROADCAST,
                    sync_op,
                    use_calc_stream);
}
std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
const std::vector<Place>& places,
const Place& place,
int rank,
CommType comm_type,
const std::vector<phi::DenseTensor>& inputs,
bool is_sync,
bool use_calc_stream) {
return std::make_shared<ProcessGroupNCCL::NCCLTask>(
places, rank, comm_type, inputs, is_sync, use_calc_stream);
place, rank, comm_type, is_sync, use_calc_stream);
}
ProcessGroupNCCL::NCCLTask::NCCLTask(
const std::vector<Place>& places,
int rank,
CommType CommType,
const std::vector<phi::DenseTensor>& inputs)
: TaskStream(rank, inputs, CommType), places_(places) {
control_events_.resize(places.size());
ncclComms_.resize(places.size());
// Shares rank 0's NCCL unique id with every other rank through the store.
//
// Rank 0 publishes the NCCL_UNIQUE_ID_BYTES raw bytes of *nccl_id under a
// key derived from gid_. Every other rank fetches that key and copies the
// bytes into *nccl_id.
//
// The fetched payload must be exactly NCCL_UNIQUE_ID_BYTES long. The copy
// length used to come straight from the payload, so an oversized value
// would have written past the end of *nccl_id; a size mismatch now raises
// InvalidArgument instead.
void ProcessGroupNCCL::BroadcastUniqueNCCLID(ncclUniqueId* nccl_id) {
  const std::string key =
      "ProcessGroupNCCL/nccl_ids/" + std::to_string(gid_) + "/0";
  if (rank_ == 0) {
    std::vector<uint8_t> nccl_id_wrapper(
        reinterpret_cast<uint8_t*>(nccl_id),
        reinterpret_cast<uint8_t*>(nccl_id) + NCCL_UNIQUE_ID_BYTES);
    store_->set(key, nccl_id_wrapper);
  } else {
    const auto& nccl_id_wrapper = store_->get(key);
    PADDLE_ENFORCE_EQ(
        nccl_id_wrapper.size(),
        static_cast<size_t>(NCCL_UNIQUE_ID_BYTES),
        platform::errors::InvalidArgument(
            "The NCCL unique id fetched from the store has %d bytes, "
            "but %d bytes are expected.",
            nccl_id_wrapper.size(),
            static_cast<size_t>(NCCL_UNIQUE_ID_BYTES)));
    std::memcpy(nccl_id, nccl_id_wrapper.data(), nccl_id_wrapper.size());
  }
}
// Legacy multi-place constructor.
// rank, inputs, comm_type, sync_op and use_calc_stream are forwarded to
// TaskStream and `places` is stored in places_. The body sizes
// control_events_ and ncclComms_ to one slot per place.
ProcessGroupNCCL::NCCLTask::NCCLTask(
const std::vector<Place>& places,
int rank,
CommType comm_type,
const std::vector<phi::DenseTensor>& inputs,
bool sync_op,
bool use_calc_stream)
: TaskStream(rank, inputs, comm_type, sync_op, use_calc_stream),
places_(places) {
control_events_.resize(places.size());
ncclComms_.resize(places.size());
}
void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
const std::string& place_key) {
ncclUniqueId nccl_id;
if (rank_ == 0) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id));
}
BroadcastUniqueNCCLID(&nccl_id);
ProcessGroupNCCL::NCCLTask::~NCCLTask() {}
VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
<< ", place: " << place_key
<< ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);
void ProcessGroupNCCL::NCCLTask::SetOutputs(
std::vector<phi::DenseTensor>& outputs) { // NOLINT
outputs_ = std::make_shared<std::vector<phi::DenseTensor>>(outputs);
calc_event_ = std::make_shared<platform::DeviceEvent>(place);
auto* calc_ctx = static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(place));
auto comm_ctx = std::make_unique<phi::GPUContext>(place);
ncclComm_t nccl_comm;
NCCLCHECK(platform::dynload::ncclCommInitRank(
&nccl_comm, GetSize(), nccl_id, GetRank()));
comm_ctx->set_nccl_comm(nccl_comm);
place_to_calc_ctx_[place_key] = calc_ctx;
place_to_comm_ctx_[place_key] = std::move(comm_ctx);
// TODO(sunyilun): for compatibility, will be removed later
places_to_ctx_[place_key] = {place_to_comm_ctx_[place_key].get()};
}
void ProcessGroupNCCL::NCCLTask::SynchronizeStreams() {
for (size_t i = 0; i < places_.size(); ++i) {
auto* default_ctx = static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(places_[i]));
default_ctx->WaitEvent(control_events_[i].GetRawCudaEvent());
}
// Orders the communication context after the calc context for `place`:
// the event is recorded on the cached calc context and then waited on with
// the cached communication context.
// Both caches are indexed with operator[] using the key for `place`.
void ProcessGroupNCCL::SyncCalcStream(
    const Place& place, const std::shared_ptr<platform::DeviceEvent>& event) {
  const std::string& place_key = GetKeyFromPlace(place);
  const auto* calc_side = place_to_calc_ctx_[place_key];
  const auto* comm_side = place_to_comm_ctx_[place_key].get();

  event->Record(calc_side);
  event->Wait(platform::Place2DeviceType(place), comm_side);
}
bool ProcessGroupNCCL::NCCLTask::IsCompleted() {
for (size_t i = 0; i < places_.size(); ++i) {
if (!control_events_[i].Query()) {
return false;
}
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
Fn fn,
CommType comm_type,
bool sync_op,
bool use_calc_stream) {
const auto& place = in_tensor.place();
const auto& key = GetKeyFromPlace(place);
if (!calc_event_) {
CreateNCCLEnvCache(place, key);
}
return true;
}
if (!use_calc_stream) {
SyncCalcStream(place, calc_event_);
}
void ProcessGroupNCCL::NCCLTask::UpdateWaitChain(
const phi::DeviceContext& ctx) {
control_events_[0].Record(*static_cast<const phi::GPUContext*>(&ctx));
auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream);
const auto* calc_ctx = place_to_calc_ctx_[key];
const auto& comm_ctx = place_to_comm_ctx_[key];
auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();
fn(out_tensor, in_tensor, comm_ctx->nccl_comm(), nccl_stream);
if (!use_calc_stream) {
if (FLAGS_use_stream_safe_cuda_allocator) {
memory::RecordStream(in_tensor.Holder(), nccl_stream);
}
task->comm_event_.Record(comm_ctx.get());
}
return task;
}
void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>* split_sizes,
......@@ -144,70 +350,58 @@ void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>* split_sizes,
}
}
// TODO(sheniang03): Add timeout for wait, now timeout unused
bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
// Warning here when use calc stream but also invoke waiting explicitly.
if (UseCalcStream()) {
VLOG(3) << "Warning: The communication is on calc stream, wait here is "
"useless.";
return true;
}
SynchronizeStreams();
if (FLAGS_nccl_blocking_wait) {
// NOTE(shenliang03): It will block host for sync
while (!IsCompleted()) {
std::this_thread::sleep_for(std::chrono::milliseconds(kWaitBlockTImeout));
}
}
if (!barrierTensors_.empty()) {
// If we use the work to do barrier, we should block cpu
for (auto& place : places_) {
platform::CUDADeviceGuard gpuGuard(place);
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif
}
// TODO(sunyilun): methods below will be removed later
void SyncDefaultStream(const std::vector<Place>& places,
const std::shared_ptr<platform::DeviceEvent>& nccl_event,
std::vector<phi::GPUContext*>& dev_ctx) { // NOLINT
for (size_t i = 0; i < places.size(); ++i) {
auto* default_ctx = static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(places[i]));
nccl_event->Record(default_ctx);
nccl_event->Wait(platform::Place2DeviceType(places[i]), dev_ctx[i]);
}
return true;
}
// Same as Wait
// Delegates to Wait(kWaitTimeout); the boolean result is discarded.
void ProcessGroupNCCL::NCCLTask::Synchronize() {
  Wait(kWaitTimeout);
}
ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
int rank,
int size,
const platform::Place& place,
int gid)
: ProcessGroupStream(rank, size, place, gid), store_(store) {
platform::SetDeviceId(place_.device);
// Legacy factory: creates an NCCLTask from a place list, rank,
// communication type and input tensors, forwarding all four arguments to
// the NCCLTask constructor.
std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
    std::vector<Place> places,
    int rank,
    CommType comm_type,
    const std::vector<phi::DenseTensor>& inputs) {
  auto task = std::make_shared<ProcessGroupNCCL::NCCLTask>(
      places, rank, comm_type, inputs);
  return task;
}
void ProcessGroupNCCL::BroadcastUniqueNCCLID(
std::vector<ncclUniqueId>& nccl_ids) { // NOLINT
if (rank_ == 0) {
for (size_t i = 0; i < nccl_ids.size(); i++) {
auto key = "ProcessGroupNCCL/nccl_ids/" + std::to_string(gid_) + "/" +
std::to_string(i);
auto nccl_id = std::vector<uint8_t>(
reinterpret_cast<uint8_t*>(&nccl_ids[i]),
reinterpret_cast<uint8_t*>(&nccl_ids[i]) + NCCL_UNIQUE_ID_BYTES);
store_->set(key, nccl_id);
}
} else {
for (size_t i = 0; i < nccl_ids.size(); i++) {
auto key = "ProcessGroupNCCL/nccl_ids/" + std::to_string(gid_) + "/" +
std::to_string(i);
auto ret = store_->get(key);
std::memcpy(&nccl_ids[i], ret.data(), ret.size());
}
}
// Legacy factory: creates an NCCLTask from a place list, rank,
// communication type, input tensors and the sync/stream flags, forwarding
// every argument to the NCCLTask constructor.
std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
    const std::vector<Place>& places,
    int rank,
    CommType comm_type,
    const std::vector<phi::DenseTensor>& inputs,
    bool is_sync,
    bool use_calc_stream) {
  auto task = std::make_shared<ProcessGroupNCCL::NCCLTask>(
      places, rank, comm_type, inputs, is_sync, use_calc_stream);
  return task;
}
// Legacy constructor taking a place list.
// rank, inputs and the communication type are forwarded to TaskStream. Only
// the first place is used: it constructs comm_event_ and initializes
// place_.
// NOTE(review): places[0] is read unconditionally, so `places` is assumed
// to be non-empty; confirm against callers.
ProcessGroupNCCL::NCCLTask::NCCLTask(
const std::vector<Place>& places,
int rank,
CommType CommType,
const std::vector<phi::DenseTensor>& inputs)
: TaskStream(rank, inputs, CommType),
comm_event_(places[0]),
place_(places[0]) {}
// Legacy constructor taking a place list plus sync/stream flags.
// rank, inputs, comm_type, sync_op and use_calc_stream are forwarded to
// TaskStream. Only the first place is used: it constructs comm_event_ and
// initializes place_.
// NOTE(review): places[0] is read unconditionally, so `places` is assumed
// to be non-empty; confirm against callers.
ProcessGroupNCCL::NCCLTask::NCCLTask(
const std::vector<Place>& places,
int rank,
CommType comm_type,
const std::vector<phi::DenseTensor>& inputs,
bool sync_op,
bool use_calc_stream)
: TaskStream(rank, inputs, comm_type, sync_op, use_calc_stream),
comm_event_(places[0]),
place_(places[0]) {}
// create NCCLManager cache for places_key
void ProcessGroupNCCL::CreateNCCLManagerCache(
const std::string& places_key, const std::vector<Place>& places) {
......@@ -217,22 +411,11 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
"Not able to create/get the NCCL Communicator since "
"the GPU place are not known"));
std::vector<std::shared_ptr<NCCLCommManager>> nccl_comms;
nccl_comms.resize(places.size());
// using vector just for broadcast
std::vector<ncclUniqueId> nccl_ids;
nccl_ids.resize(1);
auto& nccl_id = nccl_ids.front();
for (auto& place : places) {
used_place_ids_.insert(place.GetDeviceId());
}
ncclUniqueId nccl_id;
if (rank_ == 0) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id));
}
BroadcastUniqueNCCLID(nccl_ids);
BroadcastUniqueNCCLID(&nccl_id);
VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
<< ", place: " << places_key
......@@ -241,23 +424,33 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
std::vector<std::unique_ptr<phi::GPUContext>> dev_ctx;
dev_ctx.resize(places.size());
std::vector<phi::GPUContext*> dev_ctx_raw;
dev_ctx_raw.resize(places.size());
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (size_t i = 0; i < places.size(); ++i) {
platform::CUDADeviceGuard guard(places[i]);
nccl_comms[i] = NCCLCommManager::Create(GetSize(), GetRank(), nccl_id);
dev_ctx[i].reset(new phi::GPUContext(places[i]));
ncclComm_t nccl_comm;
NCCLCHECK(platform::dynload::ncclCommInitRank(
&nccl_comm, GetSize(), nccl_id, GetRank()));
dev_ctx[i]->set_nccl_comm(nccl_comm);
dev_ctx_raw[i] = dev_ctx[i].get();
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
calc_event_ = std::make_shared<platform::DeviceEvent>(places[0]);
std::vector<EventManager> events;
events.resize(places.size());
// TODO(sunyilun): for compatibility, will be removed later
place_to_calc_ctx_[places_key] = static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(places[0]));
place_to_comm_ctx_[places_key] = std::move(dev_ctx[0]);
// These caches will be useful to process sync/wait/communicate
places_to_events_.emplace(places_key, std::move(events));
places_to_ncclcomm_.emplace(places_key, std::move(nccl_comms));
places_to_ctx_.emplace(places_key, std::move(dev_ctx));
places_to_ctx_.emplace(places_key, std::move(dev_ctx_raw));
}
template <typename Fn>
......@@ -273,15 +466,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
{
std::lock_guard<std::mutex> lock(mutex_);
if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) {
if (!calc_event_) {
CreateNCCLManagerCache(key, places);
}
}
auto& nccl_comms = places_to_ncclcomm_[key];
if (!use_calc_stream) {
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
SyncDefaultStream(places, calc_event_, places_to_ctx_[key]);
}
auto task =
......@@ -304,7 +495,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
nccl_stream = places_to_ctx_[key][i]->stream();
}
fn(inputs[i], outputs[i], nccl_comms[i]->GetNcclComm(), nccl_stream);
fn(inputs[i],
outputs[i],
places_to_ctx_[key][i]->nccl_comm(),
nccl_stream);
}
}
......@@ -330,7 +524,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
if (!use_calc_stream) {
for (size_t i = 0; i < inputs.size(); ++i) {
cuda_guard.SetDevice(places[i]);
task->control_events_[i].Record(*places_to_ctx_[key][i]);
task->comm_event_.Record(places_to_ctx_[key][i]);
}
}
......@@ -348,14 +542,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
{
std::lock_guard<std::mutex> lock(mutex_);
if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) {
if (!calc_event_) {
CreateNCCLManagerCache(key, places);
}
}
auto& nccl_comms = places_to_ncclcomm_[key];
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
SyncDefaultStream(places, calc_event_, places_to_ctx_[key]);
auto task = CreateTask(places, rank_, op_type, inputs);
......@@ -367,7 +559,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
for (size_t i = 0; i < inputs.size(); ++i) {
cuda_guard.SetDevice(places[i]);
const auto& nccl_stream = places_to_ctx_[key][i]->stream();
fn(inputs[i], outputs[i], nccl_comms[i]->GetNcclComm(), nccl_stream);
fn(inputs[i],
outputs[i],
places_to_ctx_[key][i]->nccl_comm(),
nccl_stream);
}
}
......@@ -381,7 +576,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
for (size_t i = 0; i < inputs.size(); ++i) {
cuda_guard.SetDevice(places[i]);
task->control_events_[i].Record(*places_to_ctx_[key][i]);
task->comm_event_.Record(places_to_ctx_[key][i]);
}
return task;
}
......@@ -393,18 +588,16 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
CommType op_type) {
std::vector<Place> places;
places.push_back(in->place());
const auto key = GetKeyFromPlaces(places);
const std::string& key = GetKeyFromPlaces(places);
{
std::lock_guard<std::mutex> lock(mutex_);
if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) {
if (!calc_event_) {
CreateNCCLManagerCache(key, places);
}
}
auto& nccl_comms = places_to_ncclcomm_[key];
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
SyncDefaultStream(places, calc_event_, places_to_ctx_[key]);
// construct uninitialize guard for device
platform::CUDADeviceGuard cuda_guard;
......@@ -418,7 +611,7 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
platform::NCCLGroupGuard nccl_guard;
cuda_guard.SetDevice(places[0]);
const auto& nccl_stream = places_to_ctx_[key][0]->stream();
fn(in, out, nccl_comms[0]->GetNcclComm(), nccl_stream);
fn(in, out, places_to_ctx_[key][0]->nccl_comm(), nccl_stream);
}
cuda_guard.SetDevice(places[0]);
......@@ -437,15 +630,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
{
std::lock_guard<std::mutex> lock(mutex_);
if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) {
if (!calc_event_) {
CreateNCCLManagerCache(key, places);
}
}
auto& nccl_comms = places_to_ncclcomm_[key];
if (!use_calc_stream) {
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
SyncDefaultStream(places, calc_event_, places_to_ctx_[key]);
}
auto task =
......@@ -466,7 +657,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
} else {
nccl_stream = places_to_ctx_[key][i]->stream();
}
fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
fn(tensors[i],
places_to_ctx_[key][i]->nccl_comm(),
nccl_stream,
dst_rank);
}
}
......@@ -489,7 +683,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
if (!use_calc_stream) {
for (size_t i = 0; i < tensors.size(); ++i) {
cuda_guard.SetDevice(places[i]);
task->control_events_[i].Record(*places_to_ctx_[key][i]);
task->comm_event_.Record(places_to_ctx_[key][i]);
}
}
......@@ -507,14 +701,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
{
std::lock_guard<std::mutex> lock(mutex_);
if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) {
if (!calc_event_) {
CreateNCCLManagerCache(key, places);
}
}
auto& nccl_comms = places_to_ncclcomm_[key];
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
SyncDefaultStream(places, calc_event_, places_to_ctx_[key]);
auto task = CreateTask(places, rank_, op_type, tensors);
......@@ -526,7 +718,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
for (size_t i = 0; i < tensors.size(); ++i) {
cuda_guard.SetDevice(places[i]);
const auto& nccl_stream = places_to_ctx_[key][i]->stream();
fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
fn(tensors[i],
places_to_ctx_[key][i]->nccl_comm(),
nccl_stream,
dst_rank);
}
}
......@@ -540,7 +735,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
for (size_t i = 0; i < tensors.size(); ++i) {
cuda_guard.SetDevice(places[i]);
task->control_events_[i].Record(*places_to_ctx_[key][i]);
task->comm_event_.Record(places_to_ctx_[key][i]);
}
return task;
}
......@@ -572,37 +767,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
CommType::ALLREDUCE);
}
// All-reduces each tensor of `in_tensors` into the matching entry of
// `out_tensors`, using the reduction selected by `opts.reduce_op`.
//
// `sync_op` and `use_calc_stream` are passed through to Collective()
// unchanged. An InvalidArgument error is raised when
// CheckTensorsInCudaPlace(in_tensors) does not return true.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const AllreduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));

  // Per-tensor NCCL launch; the communicator and the stream are the ones
  // handed to the callback by Collective().
  auto all_reduce_fn = [&opts](const phi::DenseTensor& src,
                               phi::DenseTensor& dst,
                               ncclComm_t nccl_comm,
                               const gpuStream_t& nccl_stream) {
    const auto nccl_dtype = platform::ToNCCLDataType(src.type());
    const auto nccl_red_op = ToNCCLRedType(opts.reduce_op);
    return platform::dynload::ncclAllReduce(src.data(),
                                            dst.data(),
                                            src.numel(),
                                            nccl_dtype,
                                            nccl_red_op,
                                            nccl_comm,
                                            nccl_stream);
  };

  return Collective(in_tensors,
                    out_tensors,
                    all_reduce_fn,
                    CommType::ALLREDUCE,
                    sync_op,
                    use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
......@@ -633,63 +797,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
CommType::BROADCAST);
}
// Broadcasts each tensor of `in_tensors` into the matching entry of
// `out_tensors`.
//
// The NCCL root passed to ncclBroadcast is computed as
// `opts.source_rank * in_tensors.size() + opts.source_root`.
// `sync_op` and `use_calc_stream` are passed through to Collective()
// unchanged. An InvalidArgument error is raised when
// CheckTensorsInCudaPlace(in_tensors) does not return true.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const BroadcastOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));

  // Per-tensor NCCL launch; the communicator and the stream are the ones
  // handed to the callback by Collective().
  auto broadcast_fn = [&opts, &in_tensors](phi::DenseTensor& src,
                                           phi::DenseTensor& dst,
                                           ncclComm_t nccl_comm,
                                           const gpuStream_t& nccl_stream) {
    const auto root =
        opts.source_rank * in_tensors.size() + opts.source_root;
    const auto nccl_dtype = platform::ToNCCLDataType(src.type());
    return platform::dynload::ncclBroadcast(src.data(),
                                            dst.data(),
                                            src.numel(),
                                            nccl_dtype,
                                            root,
                                            nccl_comm,
                                            nccl_stream);
  };

  return Collective(in_tensors,
                    out_tensors,
                    broadcast_fn,
                    CommType::BROADCAST,
                    sync_op,
                    use_calc_stream);
}
// Implements a barrier as an all-reduce over small dummy tensors.
//
// One FLOAT32 tensor of shape {1} is created per entry of `places` (which
// holds only `place_`), the vector is all-reduced with itself as both input
// and output using default AllreduceOptions, and the tensors are then moved
// onto the returned task's `barrierTensors_` member.
// `opts` is not read anywhere in this body.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
const BarrierOptions& opts) {
// Only the single-card, single-process case is supported.
std::vector<phi::GPUPlace> places = {place_};
std::vector<phi::DenseTensor> barrierTensors;
barrierTensors.reserve(places.size());
platform::CUDADeviceGuard gpuGuard;
for (auto& place : places) {
// Switch the guard to this place's device before building its tensor.
gpuGuard.SetDeviceIndex(place.GetDeviceId());
phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim({1}));
// NOTE(review): `allocator` is destroyed at the end of this iteration while
// the tensor built from it lives on; presumably DenseTensor only needs the
// allocator during construction -- confirm.
auto allocator = std::unique_ptr<phi::Allocator>(
new paddle::experimental::DefaultAllocator(place));
barrierTensors.emplace_back(allocator.get(), meta);
}
auto task = ProcessGroupNCCL::AllReduce(
barrierTensors, barrierTensors, AllreduceOptions());
// NOTE(review): the dynamic_cast result is dereferenced without a null
// check; this assumes AllReduce always returns an NCCLTask -- confirm.
auto nccl_task = dynamic_cast<ProcessGroupNCCL::NCCLTask*>(task.get());
// Hand the dummy tensors to the task (presumably so they outlive the
// asynchronous all-reduce -- confirm).
nccl_task->barrierTensors_ = std::move(barrierTensors);
return task;
}
void CheckTensorsInDifferentDevices(
const std::vector<phi::DenseTensor>& tensors, const size_t num_devices) {
PADDLE_ENFORCE_EQ(
......@@ -975,39 +1082,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
CommType::ALLGATHER);
}
// All-gathers each tensor of `in_tensors` into the matching entry of
// `out_tensors`; `input.numel()` elements are sent per tensor.
//
// `sync_op` and `use_calc_stream` are passed through to Collective()
// unchanged. An InvalidArgument error is raised when
// CheckTensorsInCudaPlace() does not return true for the inputs, and then
// likewise for the outputs.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    bool sync_op,
    bool use_calc_stream) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(out_tensors),
      true,
      platform::errors::InvalidArgument("All outputs should be in CudaPlace."));

  // Per-tensor NCCL launch; the communicator and the stream are the ones
  // handed to the callback by Collective().
  auto all_gather_fn = [](const phi::DenseTensor& src,
                          phi::DenseTensor& dst,
                          ncclComm_t nccl_comm,
                          const gpuStream_t& nccl_stream) {
    const auto nccl_dtype = platform::ToNCCLDataType(src.dtype());
    return platform::dynload::ncclAllGather(src.data(),
                                            dst.data(),
                                            src.numel(),
                                            nccl_dtype,
                                            nccl_comm,
                                            nccl_stream);
  };

  return Collective(in_tensors,
                    out_tensors,
                    all_gather_fn,
                    CommType::ALLGATHER,
                    sync_op,
                    use_calc_stream);
}
void* GetPointerByOffset(void* raw_pointer,
size_t offset,
experimental::DataType type) {
......@@ -1578,43 +1652,5 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::_ReduceScatterBase(
CommType::REDUCE_SCATTER);
}
// Opens an NCCL group by calling ncclGroupStart(); the return status is
// checked with PADDLE_ENFORCE_GPU_SUCCESS.
void ProcessGroupNCCL::GroupStart() {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
}
// Closes an NCCL group by calling ncclGroupEnd(); the return status is
// checked with PADDLE_ENFORCE_GPU_SUCCESS.
void ProcessGroupNCCL::GroupEnd() {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
}
// Returns the NCCL communicator cached for `place`.
//
// The cache key is GetKeyFromPlaces() applied to a one-element place list.
// An InvalidArgument error is raised when `places_to_ncclcomm_` has no entry
// for that key; otherwise the first communicator stored under it is used.
ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const {
  const std::string place_key = GetKeyFromPlaces(std::vector<Place>{place});
  const auto iter = places_to_ncclcomm_.find(place_key);
  PADDLE_ENFORCE_NE(iter,
                    places_to_ncclcomm_.end(),
                    platform::errors::InvalidArgument(
                        "Cannot find nccl comm in process group."));
  return iter->second.front()->GetNcclComm();
}
// Convenience overload: forwards to GetDeviceContext(place, use_calc_stream)
// with use_calc_stream set to false.
const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
const Place& place) const {
return GetDeviceContext(place, /*use_calc_stream*/ false);
}
// Returns the device context to use for `place`.
//
// When `use_calc_stream` is true, the context comes from
// platform::DeviceContextPool. Otherwise it is the first context stored in
// `places_to_ctx_` under the key GetKeyFromPlaces() yields for a one-element
// place list; an InvalidArgument error is raised when no such entry exists.
const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
    const Place& place, bool use_calc_stream) const {
  if (use_calc_stream) {
    return *platform::DeviceContextPool::Instance().Get(place);
  }

  const std::string place_key = GetKeyFromPlaces(std::vector<Place>{place});
  const auto iter = places_to_ctx_.find(place_key);
  PADDLE_ENFORCE_NE(iter,
                    places_to_ctx_.end(),
                    platform::errors::InvalidArgument(
                        "Cannot find device context in process group."));
  return *iter->second.front();
}
} // namespace distributed
} // namespace paddle
......@@ -24,10 +24,10 @@
#include "paddle/fluid/distributed/collective/ProcessGroupStream.h"
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_event.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/NCCLTools.h"
......@@ -44,16 +44,28 @@ namespace distributed {
using Place = paddle::platform::Place;
class ProcessGroupNCCL : public ProcessGroupStream {
class ProcessGroupNCCL final : public ProcessGroupStream {
public:
class NCCLTask : public ProcessGroupStream::TaskStream,
public std::enable_shared_from_this<NCCLTask> {
class NCCLTask final : public ProcessGroupStream::TaskStream,
public std::enable_shared_from_this<NCCLTask> {
public:
NCCLTask(const Place& place,
int rank,
CommType comm_type,
bool sync_op,
bool use_calc_stream);
virtual ~NCCLTask();
bool IsCompleted() override;
bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) override;
void Synchronize() override;
void UpdateWaitChain(const phi::DeviceContext& ctx) override;
// TODO(sunyilun): methods below will be removed later
NCCLTask(const std::vector<Place>& places,
int rank,
CommType CommType,
const std::vector<phi::DenseTensor>& inputs);
NCCLTask(const std::vector<Place>& places,
int rank,
CommType comm_type,
......@@ -61,31 +73,15 @@ class ProcessGroupNCCL : public ProcessGroupStream {
bool sync_op,
bool use_calc_stream);
bool IsCompleted();
void SynchronizeStreams();
bool Wait(std::chrono::milliseconds timeout = kWaitTimeout);
void Synchronize();
void SetOutputs(std::vector<phi::DenseTensor>& outputs); // NOLINT
virtual ~NCCLTask();
void UpdateWaitChain(const phi::DeviceContext& ctx) override;
std::vector<EventManager> control_events_;
std::vector<phi::DenseTensor> barrierTensors_;
protected:
std::vector<Place> places_;
std::vector<std::shared_ptr<NCCLCommManager>> ncclComms_;
std::shared_ptr<std::vector<phi::DenseTensor>> outputs_;
public:
bool barrier_{false};
platform::DeviceEvent comm_event_; // event on comm stream
private:
Place place_;
};
public:
ProcessGroupNCCL(const std::shared_ptr<Store>& store,
int rank,
int size,
......@@ -99,34 +95,47 @@ class ProcessGroupNCCL : public ProcessGroupStream {
const phi::DeviceContext& GetDeviceContext(
const Place& place, bool use_calc_stream) const override;
std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const AllreduceOptions& options,
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) override;
static void GroupStart();
static void GroupEnd();
ncclComm_t NCCLComm(const Place& place) const;
// TODO(liyurui): This API will be moved later
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const AllreduceOptions& = AllreduceOptions()) override;
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& = BroadcastOptions()) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;
std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors, int dst_rank) override;
......@@ -175,12 +184,6 @@ class ProcessGroupNCCL : public ProcessGroupStream {
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
......@@ -255,20 +258,37 @@ class ProcessGroupNCCL : public ProcessGroupStream {
phi::DenseTensor&, // NOLINT
const ReduceScatterOptions&) override;
static void GroupStart();
private:
std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(const Place& place,
int rank,
CommType op_type,
bool sync_op,
bool use_calc_stream);
static void GroupEnd();
void BroadcastUniqueNCCLID(ncclUniqueId* nccl_id);
ncclComm_t NCCLComm(const Place& place) const;
void CreateNCCLEnvCache(const Place& place, const std::string& place_key);
template <typename Fn>
std::shared_ptr<ProcessGroupStream::Task> Collective(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
Fn fn,
CommType comm_type,
bool sync_op,
bool use_calc_stream);
protected:
virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
void SyncCalcStream(const Place& place,
const std::shared_ptr<platform::DeviceEvent>& event);
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
std::vector<Place> places,
int rank,
CommType op_type,
const std::vector<phi::DenseTensor>& inputs);
virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
const std::vector<Place>& places,
int rank,
CommType op_type,
......@@ -276,27 +296,6 @@ class ProcessGroupNCCL : public ProcessGroupStream {
bool sync_op,
bool use_calc_stream);
protected:
std::shared_ptr<Store> store_;
std::shared_ptr<NCCLCommManager> nccl_comm_;
std::mutex mutex_;
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLCommManager>>>
places_to_ncclcomm_;
std::unordered_map<std::string, std::vector<EventManager>> places_to_events_;
std::unordered_map<std::string, std::vector<std::unique_ptr<phi::GPUContext>>>
places_to_ctx_;
std::set<int> used_place_ids_;
private:
void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, // NOLINT
int root, // NOLINT
int server_fd);
void BroadcastUniqueNCCLID(std::vector<ncclUniqueId>& nccl_ids); // NOLINT
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> Collective(
std::vector<phi::DenseTensor>& inputs, // NOLINT
......@@ -340,6 +339,17 @@ class ProcessGroupNCCL : public ProcessGroupStream {
void CheckSplitSizes(std::vector<int64_t>* split_sizes,
std::vector<int64_t> tensor_shape);
private:
std::shared_ptr<Store> store_;
std::shared_ptr<platform::DeviceEvent> calc_event_; // event on calc stream
std::unordered_map<std::string, phi::GPUContext*> place_to_calc_ctx_;
std::unordered_map<std::string, std::unique_ptr<phi::GPUContext>>
place_to_comm_ctx_;
// TODO(sunyilun): attrs below will be removed later
std::mutex mutex_;
std::unordered_map<std::string, std::vector<phi::GPUContext*>> places_to_ctx_;
};
} // namespace distributed
......
......@@ -30,18 +30,18 @@ const phi::DeviceContext& ProcessGroupStream::GetDeviceContext(
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT
std::vector<phi::DenseTensor>& output_tensors, // NOLINT
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op) {
return AllGather(input_tensors,
output_tensors,
return AllGather(out_tensor,
in_tensor,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT
std::vector<phi::DenseTensor>& output_tensors, // NOLINT
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
......@@ -49,27 +49,50 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT
std::vector<phi::DenseTensor>& output_tensors, // NOLINT
const AllreduceOptions& options,
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op) {
return AllReduce(input_tensors,
output_tensors,
options,
return AllReduce(out_tensor,
in_tensor,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT
std::vector<phi::DenseTensor>& output_tensors, // NOLINT
const AllreduceOptions& options,
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do all_reduce", GetBackendName()));
}
// Single-tensor broadcast without an explicit stream choice: forwards to the
// five-argument Broadcast overload with use_calc_stream set to false.
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) {
return Broadcast(out_tensor,
in_tensor,
opts,
sync_op,
/*use_calc_stream*/ false);
}
// Base-class single-tensor broadcast with stream selection. This body does
// no communication: it unconditionally raises InvalidArgument through
// PADDLE_THROW, with a message that includes GetBackendName(). None of the
// parameters are read here.
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do broadcast", GetBackendName()));
}
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
......@@ -114,28 +137,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAllSingle(
"ProcessGroup%s does not support do alltoall_single", GetBackendName()));
}
// Tensor-vector broadcast without an explicit stream choice: forwards to the
// five-argument Broadcast overload with use_calc_stream set to false.
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& opts,
bool sync_op) {
return Broadcast(in_tensors,
out_tensors,
opts,
sync_op,
/*use_calc_stream*/ false);
}
// Base-class tensor-vector broadcast with stream selection. This body does
// no communication: it unconditionally raises InvalidArgument through
// PADDLE_THROW, with a message that includes GetBackendName(). None of the
// parameters are read here.
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do broadcast", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
......
......@@ -27,6 +27,11 @@ class ProcessGroupStream : public ProcessGroup {
public:
class TaskStream : public ProcessGroup::Task {
public:
TaskStream(int rank, CommType comm_type, bool sync_op, bool use_calc_stream)
: Task(rank, comm_type, sync_op), use_calc_stream_(use_calc_stream) {}
virtual ~TaskStream() = default;
// TODO(liyurui): This constructor is temporary here for compatible reason,
// will be deleted soon.
TaskStream(int rank,
......@@ -42,8 +47,6 @@ class ProcessGroupStream : public ProcessGroup {
: Task(rank, inputs, comm_type, sync_op),
use_calc_stream_(use_calc_stream) {}
virtual ~TaskStream() = default;
protected:
bool UseCalcStream() const { return use_calc_stream_; }
......@@ -51,6 +54,7 @@ class ProcessGroupStream : public ProcessGroup {
bool use_calc_stream_{false};
};
public:
ProcessGroupStream(int rank, int size, const platform::Place& place, int gid);
virtual ~ProcessGroupStream() = default;
......@@ -58,29 +62,43 @@ class ProcessGroupStream : public ProcessGroup {
const Place& place, bool use_calc_stream) const;
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT
std::vector<phi::DenseTensor>& output_tensors, // NOLINT
const AllreduceOptions& options,
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT
std::vector<phi::DenseTensor>& output_tensors, // NOLINT
const AllreduceOptions& options,
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream);
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
......@@ -107,19 +125,6 @@ class ProcessGroupStream : public ProcessGroup {
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const BroadcastOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
......
......@@ -51,14 +51,9 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT
if (map->has(ring_id)) {
paddle::distributed::ProcessGroup *pg = map->get(ring_id);
auto pg_nccl = static_cast<distributed::ProcessGroupNCCL *>(pg);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(tensor);
out_tensor.push_back(tensor);
paddle::distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM;
auto task = pg_nccl->AllReduce(in_tensor, out_tensor, opts, true, true);
auto task = pg_nccl->AllReduce(&tensor, tensor, opts, true, true);
task->Wait();
} else {
auto dtype = platform::ToNCCLDataType(
......
......@@ -44,14 +44,9 @@ static void AllReduce(phi::DenseTensor& tensor, // NOLINT
if (map->has(ring_id)) {
paddle::distributed::ProcessGroup* pg = map->get(ring_id);
auto pg_nccl = static_cast<distributed::ProcessGroupNCCL*>(pg);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(tensor);
out_tensor.push_back(tensor);
paddle::distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM;
auto task = pg_nccl->AllReduce(in_tensor, out_tensor, opts, true, true);
auto task = pg_nccl->AllReduce(&tensor, tensor, opts, true, true);
task->Wait();
} else {
auto dtype = platform::ToNCCLDataType(
......
......@@ -147,12 +147,12 @@ void BindDistributed(py::module *m) {
distributed::ReduceOp op,
bool sync_op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::AllreduceOptions opts;
opts.reduce_op = op;
auto dense =
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.AllReduce(tensors, tensors, opts, sync_op);
auto *out_dense = p_dense.get();
auto in_dense = *p_dense;
distributed::AllreduceOptions opts{op};
return self.AllReduce(out_dense, in_dense, opts, sync_op);
},
py::arg("tensor"),
py::arg("op"),
......@@ -183,11 +183,12 @@ void BindDistributed(py::module *m) {
int src,
bool sync_op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::BroadcastOptions opts{src};
auto dense =
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Broadcast(tensors, tensors, opts, sync_op);
auto *out_dense = p_dense.get();
auto in_dense = *p_dense;
distributed::BroadcastOptions opts{src};
return self.Broadcast(out_dense, in_dense, opts, sync_op);
},
py::arg("tensor"),
py::arg("src"),
......@@ -380,52 +381,52 @@ void BindDistributed(py::module *m) {
.def(
"allgather",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor_list,
py::handle py_in_tensor,
bool sync_op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor_list =
CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
const auto &dev_ctx = self.GetDeviceContext(in_tensor.place());
auto task = self.AllGather(in_wrapper, out_wrapper, sync_op);
auto task = self.AllGather(out_dense, in_dense, sync_op);
distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
task->UpdateWaitChain(dev_ctx);
return task;
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"allgather_into_tensor",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
py::handle py_in_tensor,
bool sync_op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
return self.AllGather(in_wrapper, out_wrapper, sync_op);
return self.AllGather(out_dense, in_dense, sync_op);
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
......@@ -784,55 +785,55 @@ void BindDistributed(py::module *m) {
.def(
"allgather_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor,
py::handle py_out_tensor_list) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
py::handle py_out_tensor_list,
py::handle py_in_tensor) {
auto out_tensor_list =
CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
const auto &dev_ctx =
self.GetDeviceContext(in_tensor.place(), true);
auto task = self.AllGather(in_wrapper,
out_wrapper,
auto task = self.AllGather(out_dense,
in_dense,
/*sync_op*/ true,
/*use_calc_stream*/ true);
distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
return task;
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::call_guard<py::gil_scoped_release>())
.def(
"allgather_into_tensor_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor,
py::handle py_out_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
py::handle py_out_tensor,
py::handle py_in_tensor) {
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
return self.AllGather(in_wrapper,
out_wrapper,
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
return self.AllGather(out_dense,
in_dense,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::call_guard<py::gil_scoped_release>())
.def(
......@@ -872,13 +873,13 @@ void BindDistributed(py::module *m) {
py::handle py_tensor,
distributed::ReduceOp op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::AllreduceOptions opts;
opts.reduce_op = op;
auto dense =
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.AllReduce(tensors,
tensors,
auto in_dense = *p_dense;
auto *out_dense = p_dense.get();
distributed::AllreduceOptions opts{op};
return self.AllReduce(out_dense,
in_dense,
opts,
/*sync_op*/ true,
/*use_calc_stream*/ true);
......@@ -980,12 +981,13 @@ void BindDistributed(py::module *m) {
py::handle py_tensor,
int src) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::BroadcastOptions opts{src};
auto dense =
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Broadcast(tensors,
tensors,
auto *out_dense = p_dense.get();
auto in_dense = *p_dense;
distributed::BroadcastOptions opts{src};
return self.Broadcast(out_dense,
in_dense,
opts,
/*sync_op*/ true,
/*use_calc_stream*/ true);
......
......@@ -21,18 +21,18 @@ def _check_tensor_shape(tensor, shape, nranks=1):
expect_shape = list(shape)
expect_shape[0] *= nranks
if list(tensor.shape) != expect_shape:
raise RuntimeError('The tensor for all_gather is not correctly-sized.')
raise RuntimeError("The tensor for all_gather is not correctly-sized.")
def _check_tensor_list_shape(tensor_list, shape, nranks=1):
if len(tensor_list) != nranks:
raise RuntimeError(
'The tensor_list for all_gather is not correctly-sized.'
"The tensor_list for all_gather is not correctly-sized."
)
for tensor in tensor_list:
if tensor.shape != shape:
raise RuntimeError(
'The tensor_list for all_gather is not correctly-sized.'
"The tensor_list for all_gather is not correctly-sized."
)
......@@ -45,11 +45,12 @@ def _all_gather_into_tensor_in_dygraph(
if use_calc_stream:
return group.process_group.allgather_into_tensor_on_calc_stream(
in_tensor, out_tensor
out_tensor,
in_tensor,
)
task = group.process_group.allgather_into_tensor(
in_tensor, out_tensor, sync_op
out_tensor, in_tensor, sync_op
)
if sync_op:
task.wait()
......@@ -68,9 +69,9 @@ def _all_gather_in_dygraph(
_check_tensor_list_shape(tensor_list, tensor.shape, group.nranks)
if use_calc_stream:
return group.process_group.allgather_on_calc_stream(tensor, tensor_list)
return group.process_group.allgather_on_calc_stream(tensor_list, tensor)
task = group.process_group.allgather(tensor, tensor_list, sync_op)
task = group.process_group.allgather(tensor_list, tensor, sync_op)
if sync_op:
task.wait()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册