Unverified commit e1a1c354, authored by Wen Sun, committed by GitHub

Refactor collective communication all_gather, all_reduce, broadcast & barrier C++ API (#47481)

Parent 2337e609
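For orientation, a minimal sketch of the refactored C++ call pattern this commit introduces. Signatures and option structs are taken from the hunks below; `pg`, `tensor` and `gathered` are assumed here to be an initialized ProcessGroupNCCL and GPU-placed phi::DenseTensor objects, so this is illustrative rather than the commit's own code:

// all_reduce: single output pointer + single const input reference, plus an explicit sync_op flag
paddle::distributed::AllreduceOptions opts;
opts.reduce_op = paddle::distributed::ReduceOp::SUM;
auto task = pg->AllReduce(&tensor, tensor, opts, /*sync_op*/ true);
task->Wait();

// broadcast (here from rank 0) and barrier follow the same style
paddle::distributed::BroadcastOptions b_opts{/*source rank*/ 0};
pg->Broadcast(&tensor, tensor, b_opts, /*sync_op*/ true)->Wait();
pg->Barrier()->Wait();

// all_gather: `gathered` must be pre-sized to nranks * tensor along dim 0
pg->AllGather(&gathered, tensor, /*sync_op*/ true)->Wait();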
......@@ -28,6 +28,7 @@ if(WITH_NCCL OR WITH_RCCL)
enforce
collective_helper
device_context
${DEVICE_EVENT_LIBS}
dense_tensor)
if(WITH_DISTRIBUTE AND WITH_PSCORE)
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
......
......@@ -40,6 +40,8 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places) {
return placeList;
}
std::string GetKeyFromPlace(const Place& place) { return place.DebugString(); }
bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors) {
return std::all_of(
tensors.cbegin(), tensors.cend(), [&](const phi::DenseTensor& t) {
......
......@@ -25,6 +25,8 @@ using Place = paddle::platform::Place;
std::vector<Place> GetPlaceList(const std::vector<phi::DenseTensor>& tensors);
// Get the deviceList String from the list of devices
std::string GetKeyFromPlaces(const std::vector<Place>& places);
// Get the device string from one device
std::string GetKeyFromPlace(const Place& place);
bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors);
......
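For context, a small sketch of how the new single-place key above is presumably consumed. The cache member and the CreateNCCLEnvCache helper appear later in this diff inside ProcessGroupNCCL, but the lookup code itself is an assumption, not part of the commit; `in_tensor` is likewise assumed:

// Hypothetical lookup: one communication context is cached per device,
// keyed by GetKeyFromPlace (i.e. the place's DebugString()).
const std::string key = GetKeyFromPlace(in_tensor.place());
if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
  CreateNCCLEnvCache(in_tensor.place(), key);
}
const auto& comm_ctx = *place_to_comm_ctx_.at(key);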
......@@ -59,204 +59,6 @@ namespace distributed {
} \
} while (0)
// NOTE(shenliang03): EventManager is a movable but not copyable CudaEvent wrapper.
// EventManager is different from paddle::platform::CudaEvent.
// It uses lazy initialization and is only created when the
// Record() method is called for the first time; it also monitors
// device information to ensure that recorded stream and event
// are on the same device.
class EventManager {
public:
EventManager() {}
explicit EventManager(unsigned int flags) : flags_{flags} {}
~EventManager() {
if (is_created_) {
platform::CUDADeviceGuard guard(device_index_);
#ifdef PADDLE_WITH_HIP
hipEventDestroy(event_);
#else
cudaEventDestroy(event_);
#endif
}
}
EventManager(const EventManager&) = delete;
EventManager& operator=(const EventManager&) = delete;
EventManager(EventManager&& other) {
std::swap(flags_, other.flags_);
std::swap(is_created_, other.is_created_);
std::swap(device_index_, other.device_index_);
std::swap(event_, other.event_);
}
EventManager& operator=(EventManager&& other) {
std::swap(flags_, other.flags_);
std::swap(is_created_, other.is_created_);
std::swap(device_index_, other.device_index_);
std::swap(event_, other.event_);
return *this;
}
bool IsCreated() const { return is_created_; }
  int8_t DeviceId() const { return device_index_; }
gpuEvent_t GetRawCudaEvent() const { return event_; }
void Record(const phi::GPUContext& ctx) {
auto device_index = ctx.GetPlace().device;
if (!is_created_) {
CreateEvent(device_index);
}
PADDLE_ENFORCE_EQ(device_index,
device_index_,
platform::errors::PreconditionNotMet(
"phi::GPUContext's device %d does not match"
"Event's device %d",
device_index,
device_index_));
platform::CUDADeviceGuard guard(device_index_);
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event_, ctx.stream()));
#else
PADDLE_ENFORCE_GPU_SUCCESS(hipEventRecord(event_, ctx.stream()));
#endif
}
bool Query() const {
#ifdef PADDLE_WITH_HIP
gpuError_t err = hipEventQuery(event_);
if (err == hipSuccess) {
return true;
}
if (err == hipErrorNotReady) {
return false;
}
#else
gpuError_t err = cudaEventQuery(event_);
if (err == cudaSuccess) {
return true;
}
if (err == cudaErrorNotReady) {
return false;
}
#endif
PADDLE_ENFORCE_GPU_SUCCESS(err);
return false;
}
void Synchronize() const {
if (is_created_) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipEventSynchronize(event_));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(event_));
#endif
}
}
void Block(const phi::GPUContext& ctx) const {
if (is_created_) {
auto device_index = ctx.GetPlace().device;
PADDLE_ENFORCE_EQ(device_index,
device_index_,
platform::errors::PreconditionNotMet(
"phi::GPUContext's device %d does not match"
"Event's device %d",
device_index,
device_index_));
platform::CUDADeviceGuard guard(device_index_);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipStreamWaitEvent(ctx.stream(), event_, 0));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamWaitEvent(ctx.stream(), event_, 0));
#endif
}
}
private:
#ifdef PADDLE_WITH_HIP
unsigned int flags_ = hipEventDefault;
#else
unsigned int flags_ = cudaEventDefault;
#endif
bool is_created_{false};
gpuEvent_t event_{};
int8_t device_index_{0};
private:
void CreateEvent(int device_index) {
device_index_ = device_index;
platform::CUDADeviceGuard guard(device_index);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipEventCreateWithFlags(&event_, flags_));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreateWithFlags(&event_, flags_));
#endif
is_created_ = true;
}
};
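A small usage sketch of the EventManager wrapper shown above, removed by this commit. It is illustrative only; `comm_ctx` and `calc_ctx` are assumed phi::GPUContext instances on the same device:

// Record an event on the communication stream, then make the calculation
// stream wait on it without blocking the host.
EventManager event;          // CUDA/HIP event is created lazily on the first Record()
event.Record(comm_ctx);      // event marks the current tail of comm_ctx's stream
event.Block(calc_ctx);       // calc_ctx's stream waits until the event completes
if (event.Query()) { /* work recorded before the event has already finished */ }
event.Synchronize();         // host-side wait; no-op if the event was never created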
// NOTE(shenliang03): NCCLCommManager is more lightweight than
// platform::NCCLComm
class NCCLCommManager {
public:
explicit NCCLCommManager(ncclComm_t ncclComm) : nccl_comm_(ncclComm) {}
NCCLCommManager() : NCCLCommManager(nullptr) {}
~NCCLCommManager() noexcept {
std::unique_lock<std::mutex> lock(mutex_);
if (nccl_comm_) {
platform::dynload::ncclCommDestroy(nccl_comm_);
}
}
static std::shared_ptr<NCCLCommManager> Create(int num_ranks,
int rank,
ncclUniqueId comm_id) {
auto nccl_manager = std::make_shared<NCCLCommManager>();
NCCLCHECK(platform::dynload::ncclCommInitRank(
&(nccl_manager->nccl_comm_), num_ranks, comm_id, rank));
nccl_manager->nccl_id_ = comm_id;
nccl_manager->rank_ = rank;
return nccl_manager;
}
ncclUniqueId GetNcclId() const {
std::unique_lock<std::mutex> lock(mutex_);
return nccl_id_;
}
ncclComm_t GetNcclComm() const {
std::unique_lock<std::mutex> lock(mutex_);
return nccl_comm_;
}
NCCLCommManager(const NCCLCommManager&) = delete;
NCCLCommManager& operator=(const NCCLCommManager&) = delete;
NCCLCommManager& operator=(NCCLCommManager&& other) = delete;
NCCLCommManager(NCCLCommManager&& other) {
std::unique_lock<std::mutex> lock(other.mutex_);
std::swap(nccl_comm_, other.nccl_comm_);
}
protected:
ncclComm_t nccl_comm_;
ncclUniqueId nccl_id_;
int rank_;
mutable std::mutex mutex_;
};
ncclRedOp_t ToNCCLRedType(ReduceOp reduction);
std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID);
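Likewise, a hedged sketch of how the NCCLCommManager above is typically driven. The unique-id exchange is assumed to happen out of band (e.g. rank 0 generates it with platform::dynload::ncclGetUniqueId and shares it through the store); `num_ranks` and `rank` describe this process group:

// Sketch only: `unique_id` is assumed to have been created on rank 0 and
// distributed to every rank before this point.
auto comm_manager = NCCLCommManager::Create(num_ranks, rank, unique_id);
ncclComm_t raw_comm = comm_manager->GetNcclComm();  // handle passed to nccl* collective calls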
......
......@@ -17,15 +17,7 @@
namespace paddle {
namespace distributed {
ProcessGroup::Task::Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type)
: rank_(rank), comm_type_(comm_type) {}
ProcessGroup::Task::Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type,
bool sync_op)
ProcessGroup::Task::Task(int rank, CommType comm_type, bool sync_op)
: rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {}
ProcessGroup::Task::~Task() = default;
......@@ -62,5 +54,17 @@ ProcessGroup::ProcessGroup(int rank, int size, int gid)
}
}
// TODO(sunyilun): methods below will be removed later
ProcessGroup::Task::Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type)
: rank_(rank), comm_type_(comm_type) {}
ProcessGroup::Task::Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type,
bool sync_op)
: rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {}
} // namespace distributed
} // namespace paddle
......@@ -54,13 +54,7 @@ class ProcessGroup {
public:
class Task {
public:
Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type);
Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type,
bool sync_op);
Task(int rank, CommType comm_type, bool sync_op);
virtual ~Task();
virtual bool IsCompleted();
......@@ -69,6 +63,15 @@ class ProcessGroup {
virtual void UpdateWaitChain(const phi::DeviceContext& ctx);
bool IsSync() const { return sync_op_; }
// TODO(sunyilun): methods below will be removed later
Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type);
Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type,
bool sync_op);
protected:
const int rank_;
CommType comm_type_{CommType::UNKNOWN};
......@@ -79,6 +82,7 @@ class ProcessGroup {
bool sync_op_{true};
};
public:
explicit ProcessGroup(int rank,
int size,
const platform::Place& place,
......@@ -93,12 +97,48 @@ class ProcessGroup {
int GetSize() const { return size_; }
virtual std::string GetBackendName() const = 0;
virtual const phi::DeviceContext& GetDeviceContext(const Place& place) const {
PADDLE_THROW(platform::errors::InvalidArgument(
"Does not support to get device_context from ProcessGroup%s.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support all_gather with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support all_reduce with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support barrier", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support broadcast with sync_op flag",
GetBackendName()));
}
// TODO(liyurui): This API will be moved later
virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
......@@ -118,6 +158,7 @@ class ProcessGroup {
GetBackendName()));
}
// TODO(sunyilun): methods below will be removed later
virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
......@@ -136,12 +177,6 @@ class ProcessGroup {
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support barrier", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>&, int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
......
......@@ -229,6 +229,17 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
}
};
// TODO(sunyilun): for compatibility, will be updated later
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) {
std::vector<phi::DenseTensor> in_wrapper = {in_tensor};
std::vector<phi::DenseTensor> out_wrapper = {*out_tensor};
return Broadcast(in_wrapper, out_wrapper, opts, true);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast(
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs,
......
......@@ -108,6 +108,13 @@ class ProcessGroupGloo : public ProcessGroup {
~ProcessGroupGloo() = default;
std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) override;
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs,
......
......@@ -24,10 +24,10 @@
#include "paddle/fluid/distributed/collective/ProcessGroupStream.h"
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_event.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/NCCLTools.h"
......@@ -44,16 +44,28 @@ namespace distributed {
using Place = paddle::platform::Place;
class ProcessGroupNCCL : public ProcessGroupStream {
class ProcessGroupNCCL final : public ProcessGroupStream {
public:
class NCCLTask : public ProcessGroupStream::TaskStream,
public std::enable_shared_from_this<NCCLTask> {
class NCCLTask final : public ProcessGroupStream::TaskStream,
public std::enable_shared_from_this<NCCLTask> {
public:
NCCLTask(const Place& place,
int rank,
CommType comm_type,
bool sync_op,
bool use_calc_stream);
virtual ~NCCLTask();
bool IsCompleted() override;
bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) override;
void Synchronize() override;
void UpdateWaitChain(const phi::DeviceContext& ctx) override;
// TODO(sunyilun): methods below will be removed later
NCCLTask(const std::vector<Place>& places,
int rank,
CommType CommType,
const std::vector<phi::DenseTensor>& inputs);
NCCLTask(const std::vector<Place>& places,
int rank,
CommType comm_type,
......@@ -61,31 +73,15 @@ class ProcessGroupNCCL : public ProcessGroupStream {
bool sync_op,
bool use_calc_stream);
bool IsCompleted();
void SynchronizeStreams();
bool Wait(std::chrono::milliseconds timeout = kWaitTimeout);
void Synchronize();
void SetOutputs(std::vector<phi::DenseTensor>& outputs); // NOLINT
virtual ~NCCLTask();
void UpdateWaitChain(const phi::DeviceContext& ctx) override;
std::vector<EventManager> control_events_;
std::vector<phi::DenseTensor> barrierTensors_;
protected:
std::vector<Place> places_;
std::vector<std::shared_ptr<NCCLCommManager>> ncclComms_;
std::shared_ptr<std::vector<phi::DenseTensor>> outputs_;
public:
bool barrier_{false};
platform::DeviceEvent comm_event_; // event on comm stream
private:
Place place_;
};
public:
ProcessGroupNCCL(const std::shared_ptr<Store>& store,
int rank,
int size,
......@@ -99,34 +95,47 @@ class ProcessGroupNCCL : public ProcessGroupStream {
const phi::DeviceContext& GetDeviceContext(
const Place& place, bool use_calc_stream) const override;
std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const AllreduceOptions& options,
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) override;
static void GroupStart();
static void GroupEnd();
ncclComm_t NCCLComm(const Place& place) const;
// TODO(liyurui): This API will be moved later
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const AllreduceOptions& = AllreduceOptions()) override;
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& = BroadcastOptions()) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;
std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors, int dst_rank) override;
......@@ -175,12 +184,6 @@ class ProcessGroupNCCL : public ProcessGroupStream {
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
......@@ -255,20 +258,37 @@ class ProcessGroupNCCL : public ProcessGroupStream {
phi::DenseTensor&, // NOLINT
const ReduceScatterOptions&) override;
static void GroupStart();
private:
std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(const Place& place,
int rank,
CommType op_type,
bool sync_op,
bool use_calc_stream);
static void GroupEnd();
void BroadcastUniqueNCCLID(ncclUniqueId* nccl_id);
ncclComm_t NCCLComm(const Place& place) const;
void CreateNCCLEnvCache(const Place& place, const std::string& place_key);
template <typename Fn>
std::shared_ptr<ProcessGroupStream::Task> Collective(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
Fn fn,
CommType comm_type,
bool sync_op,
bool use_calc_stream);
protected:
virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
void SyncCalcStream(const Place& place,
const std::shared_ptr<platform::DeviceEvent>& event);
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
std::vector<Place> places,
int rank,
CommType op_type,
const std::vector<phi::DenseTensor>& inputs);
virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
const std::vector<Place>& places,
int rank,
CommType op_type,
......@@ -276,27 +296,6 @@ class ProcessGroupNCCL : public ProcessGroupStream {
bool sync_op,
bool use_calc_stream);
protected:
std::shared_ptr<Store> store_;
std::shared_ptr<NCCLCommManager> nccl_comm_;
std::mutex mutex_;
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLCommManager>>>
places_to_ncclcomm_;
std::unordered_map<std::string, std::vector<EventManager>> places_to_events_;
std::unordered_map<std::string, std::vector<std::unique_ptr<phi::GPUContext>>>
places_to_ctx_;
std::set<int> used_place_ids_;
private:
void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, // NOLINT
int root, // NOLINT
int server_fd);
void BroadcastUniqueNCCLID(std::vector<ncclUniqueId>& nccl_ids); // NOLINT
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> Collective(
std::vector<phi::DenseTensor>& inputs, // NOLINT
......@@ -340,6 +339,17 @@ class ProcessGroupNCCL : public ProcessGroupStream {
void CheckSplitSizes(std::vector<int64_t>* split_sizes,
std::vector<int64_t> tensor_shape);
private:
std::shared_ptr<Store> store_;
std::shared_ptr<platform::DeviceEvent> calc_event_; // event on calc stream
std::unordered_map<std::string, phi::GPUContext*> place_to_calc_ctx_;
std::unordered_map<std::string, std::unique_ptr<phi::GPUContext>>
place_to_comm_ctx_;
// TODO(sunyilun): attrs below will be removed later
std::mutex mutex_;
std::unordered_map<std::string, std::vector<phi::GPUContext*>> places_to_ctx_;
};
} // namespace distributed
......
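The implementation file for this header is not part of the excerpt, but the templated Collective helper declared above suggests the shape of the refactored collectives. The following is a plausible sketch, not the commit's actual code; in particular the lambda parameter list (communicator handle plus stream) is inferred from the declaration rather than quoted:

// Hypothetical sketch: AllReduce expressed through the single-tensor Collective
// helper, which is assumed to supply the cached per-place NCCL communicator and
// either the comm stream or the calc stream.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const AllreduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  return Collective(
      out_tensor,
      in_tensor,
      [&](phi::DenseTensor* output,
          const phi::DenseTensor& input,
          ncclComm_t comm,
          gpuStream_t stream) {
        return platform::dynload::ncclAllReduce(
            input.data(),
            output->data(),
            input.numel(),
            platform::ToNCCLDataType(input.dtype()),
            ToNCCLRedType(opts.reduce_op),
            comm,
            stream);
      },
      CommType::ALLREDUCE,
      sync_op,
      use_calc_stream);
}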
......@@ -30,18 +30,18 @@ const phi::DeviceContext& ProcessGroupStream::GetDeviceContext(
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT
std::vector<phi::DenseTensor>& output_tensors, // NOLINT
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op) {
return AllGather(input_tensors,
output_tensors,
return AllGather(out_tensor,
in_tensor,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT
std::vector<phi::DenseTensor>& output_tensors, // NOLINT
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
......@@ -49,27 +49,50 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT
std::vector<phi::DenseTensor>& output_tensors, // NOLINT
const AllreduceOptions& options,
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op) {
return AllReduce(input_tensors,
output_tensors,
options,
return AllReduce(out_tensor,
in_tensor,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT
std::vector<phi::DenseTensor>& output_tensors, // NOLINT
const AllreduceOptions& options,
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do all_reduce", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) {
return Broadcast(out_tensor,
in_tensor,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do broadcast", GetBackendName()));
}
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
......@@ -114,28 +137,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAllSingle(
"ProcessGroup%s does not support do alltoall_single", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& opts,
bool sync_op) {
return Broadcast(in_tensors,
out_tensors,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do broadcast", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
......
......@@ -27,6 +27,11 @@ class ProcessGroupStream : public ProcessGroup {
public:
class TaskStream : public ProcessGroup::Task {
public:
TaskStream(int rank, CommType comm_type, bool sync_op, bool use_calc_stream)
: Task(rank, comm_type, sync_op), use_calc_stream_(use_calc_stream) {}
virtual ~TaskStream() = default;
// TODO(liyurui): This constructor is temporary here for compatible reason,
// will be deleted soon.
TaskStream(int rank,
......@@ -42,8 +47,6 @@ class ProcessGroupStream : public ProcessGroup {
: Task(rank, inputs, comm_type, sync_op),
use_calc_stream_(use_calc_stream) {}
virtual ~TaskStream() = default;
protected:
bool UseCalcStream() const { return use_calc_stream_; }
......@@ -51,6 +54,7 @@ class ProcessGroupStream : public ProcessGroup {
bool use_calc_stream_{false};
};
public:
ProcessGroupStream(int rank, int size, const platform::Place& place, int gid);
virtual ~ProcessGroupStream() = default;
......@@ -58,29 +62,43 @@ class ProcessGroupStream : public ProcessGroup {
const Place& place, bool use_calc_stream) const;
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT
std::vector<phi::DenseTensor>& output_tensors, // NOLINT
const AllreduceOptions& options,
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT
std::vector<phi::DenseTensor>& output_tensors, // NOLINT
const AllreduceOptions& options,
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream);
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
......@@ -107,19 +125,6 @@ class ProcessGroupStream : public ProcessGroup {
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const BroadcastOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
......
......@@ -51,14 +51,9 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT
if (map->has(ring_id)) {
paddle::distributed::ProcessGroup *pg = map->get(ring_id);
auto pg_nccl = static_cast<distributed::ProcessGroupNCCL *>(pg);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(tensor);
out_tensor.push_back(tensor);
paddle::distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM;
auto task = pg_nccl->AllReduce(in_tensor, out_tensor, opts, true, true);
auto task = pg_nccl->AllReduce(&tensor, tensor, opts, true, true);
task->Wait();
} else {
auto dtype = platform::ToNCCLDataType(
......
......@@ -44,14 +44,9 @@ static void AllReduce(phi::DenseTensor& tensor, // NOLINT
if (map->has(ring_id)) {
paddle::distributed::ProcessGroup* pg = map->get(ring_id);
auto pg_nccl = static_cast<distributed::ProcessGroupNCCL*>(pg);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(tensor);
out_tensor.push_back(tensor);
paddle::distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM;
auto task = pg_nccl->AllReduce(in_tensor, out_tensor, opts, true, true);
auto task = pg_nccl->AllReduce(&tensor, tensor, opts, true, true);
task->Wait();
} else {
auto dtype = platform::ToNCCLDataType(
......
......@@ -147,12 +147,12 @@ void BindDistributed(py::module *m) {
distributed::ReduceOp op,
bool sync_op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::AllreduceOptions opts;
opts.reduce_op = op;
auto dense =
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.AllReduce(tensors, tensors, opts, sync_op);
auto *out_dense = p_dense.get();
auto in_dense = *p_dense;
distributed::AllreduceOptions opts{op};
return self.AllReduce(out_dense, in_dense, opts, sync_op);
},
py::arg("tensor"),
py::arg("op"),
......@@ -183,11 +183,12 @@ void BindDistributed(py::module *m) {
int src,
bool sync_op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::BroadcastOptions opts{src};
auto dense =
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Broadcast(tensors, tensors, opts, sync_op);
auto *out_dense = p_dense.get();
auto in_dense = *p_dense;
distributed::BroadcastOptions opts{src};
return self.Broadcast(out_dense, in_dense, opts, sync_op);
},
py::arg("tensor"),
py::arg("src"),
......@@ -380,52 +381,52 @@ void BindDistributed(py::module *m) {
.def(
"allgather",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor_list,
py::handle py_in_tensor,
bool sync_op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor_list =
CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
const auto &dev_ctx = self.GetDeviceContext(in_tensor.place());
auto task = self.AllGather(in_wrapper, out_wrapper, sync_op);
auto task = self.AllGather(out_dense, in_dense, sync_op);
distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
task->UpdateWaitChain(dev_ctx);
return task;
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"allgather_into_tensor",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
py::handle py_in_tensor,
bool sync_op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
return self.AllGather(in_wrapper, out_wrapper, sync_op);
return self.AllGather(out_dense, in_dense, sync_op);
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
......@@ -784,55 +785,55 @@ void BindDistributed(py::module *m) {
.def(
"allgather_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor,
py::handle py_out_tensor_list) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
py::handle py_out_tensor_list,
py::handle py_in_tensor) {
auto out_tensor_list =
CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
const auto &dev_ctx =
self.GetDeviceContext(in_tensor.place(), true);
auto task = self.AllGather(in_wrapper,
out_wrapper,
auto task = self.AllGather(out_dense,
in_dense,
/*sync_op*/ true,
/*use_calc_stream*/ true);
distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
return task;
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::call_guard<py::gil_scoped_release>())
.def(
"allgather_into_tensor_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor,
py::handle py_out_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
py::handle py_out_tensor,
py::handle py_in_tensor) {
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
return self.AllGather(in_wrapper,
out_wrapper,
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
return self.AllGather(out_dense,
in_dense,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::call_guard<py::gil_scoped_release>())
.def(
......@@ -872,13 +873,13 @@ void BindDistributed(py::module *m) {
py::handle py_tensor,
distributed::ReduceOp op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::AllreduceOptions opts;
opts.reduce_op = op;
auto dense =
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.AllReduce(tensors,
tensors,
auto in_dense = *p_dense;
auto *out_dense = p_dense.get();
distributed::AllreduceOptions opts{op};
return self.AllReduce(out_dense,
in_dense,
opts,
/*sync_op*/ true,
/*use_calc_stream*/ true);
......@@ -980,12 +981,13 @@ void BindDistributed(py::module *m) {
py::handle py_tensor,
int src) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::BroadcastOptions opts{src};
auto dense =
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Broadcast(tensors,
tensors,
auto *out_dense = p_dense.get();
auto in_dense = *p_dense;
distributed::BroadcastOptions opts{src};
return self.Broadcast(out_dense,
in_dense,
opts,
/*sync_op*/ true,
/*use_calc_stream*/ true);
......
......@@ -21,18 +21,18 @@ def _check_tensor_shape(tensor, shape, nranks=1):
expect_shape = list(shape)
expect_shape[0] *= nranks
if list(tensor.shape) != expect_shape:
raise RuntimeError('The tensor for all_gather is not correctly-sized.')
raise RuntimeError("The tensor for all_gather is not correctly-sized.")
def _check_tensor_list_shape(tensor_list, shape, nranks=1):
if len(tensor_list) != nranks:
raise RuntimeError(
'The tensor_list for all_gather is not correctly-sized.'
"The tensor_list for all_gather is not correctly-sized."
)
for tensor in tensor_list:
if tensor.shape != shape:
raise RuntimeError(
'The tensor_list for all_gather is not correctly-sized.'
"The tensor_list for all_gather is not correctly-sized."
)
......@@ -45,11 +45,12 @@ def _all_gather_into_tensor_in_dygraph(
if use_calc_stream:
return group.process_group.allgather_into_tensor_on_calc_stream(
in_tensor, out_tensor
out_tensor,
in_tensor,
)
task = group.process_group.allgather_into_tensor(
in_tensor, out_tensor, sync_op
out_tensor, in_tensor, sync_op
)
if sync_op:
task.wait()
......@@ -68,9 +69,9 @@ def _all_gather_in_dygraph(
_check_tensor_list_shape(tensor_list, tensor.shape, group.nranks)
if use_calc_stream:
return group.process_group.allgather_on_calc_stream(tensor, tensor_list)
return group.process_group.allgather_on_calc_stream(tensor_list, tensor)
task = group.process_group.allgather(tensor, tensor_list, sync_op)
task = group.process_group.allgather(tensor_list, tensor, sync_op)
if sync_op:
task.wait()
......