未验证 提交 e1a1c354 编写于 作者: W Wen Sun 提交者: GitHub

Refactor collective communication all_gather, all_reduce, broadcast & barrier C++ API (#47481)

上级 2337e609
...@@ -28,6 +28,7 @@ if(WITH_NCCL OR WITH_RCCL) ...@@ -28,6 +28,7 @@ if(WITH_NCCL OR WITH_RCCL)
enforce enforce
collective_helper collective_helper
device_context device_context
${DEVICE_EVENT_LIBS}
dense_tensor) dense_tensor)
if(WITH_DISTRIBUTE AND WITH_PSCORE) if(WITH_DISTRIBUTE AND WITH_PSCORE)
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
......
...@@ -40,6 +40,8 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places) { ...@@ -40,6 +40,8 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places) {
return placeList; return placeList;
} }
std::string GetKeyFromPlace(const Place& place) { return place.DebugString(); }
bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors) { bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors) {
return std::all_of( return std::all_of(
tensors.cbegin(), tensors.cend(), [&](const phi::DenseTensor& t) { tensors.cbegin(), tensors.cend(), [&](const phi::DenseTensor& t) {
......
...@@ -25,6 +25,8 @@ using Place = paddle::platform::Place; ...@@ -25,6 +25,8 @@ using Place = paddle::platform::Place;
std::vector<Place> GetPlaceList(const std::vector<phi::DenseTensor>& tensors); std::vector<Place> GetPlaceList(const std::vector<phi::DenseTensor>& tensors);
// Get the deviceList String from the list of devices // Get the deviceList String from the list of devices
std::string GetKeyFromPlaces(const std::vector<Place>& places); std::string GetKeyFromPlaces(const std::vector<Place>& places);
// Get the device string from one device
std::string GetKeyFromPlace(const Place& place);
bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors); bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors);
......
...@@ -59,204 +59,6 @@ namespace distributed { ...@@ -59,204 +59,6 @@ namespace distributed {
} \ } \
} while (0) } while (0)
// NOTE(shenliang03): EventManager are movable not copyable CudaEvent wrapper.
// EventManage is different from paddle::platform::CudaEvent.
// It uses lazy initialization and is only created when the
// Record() method is called for the first time; it also monitors
// device information to ensure that recorded stream and event
// are on the same device.
class EventManager {
public:
EventManager() {}
explicit EventManager(unsigned int flags) : flags_{flags} {}
~EventManager() {
if (is_created_) {
platform::CUDADeviceGuard guard(device_index_);
#ifdef PADDLE_WITH_HIP
hipEventDestroy(event_);
#else
cudaEventDestroy(event_);
#endif
}
}
EventManager(const EventManager&) = delete;
EventManager& operator=(const EventManager&) = delete;
EventManager(EventManager&& other) {
std::swap(flags_, other.flags_);
std::swap(is_created_, other.is_created_);
std::swap(device_index_, other.device_index_);
std::swap(event_, other.event_);
}
EventManager& operator=(EventManager&& other) {
std::swap(flags_, other.flags_);
std::swap(is_created_, other.is_created_);
std::swap(device_index_, other.device_index_);
std::swap(event_, other.event_);
return *this;
}
bool IsCreated() const { return is_created_; }
bool DeviceId() const { return device_index_; }
gpuEvent_t GetRawCudaEvent() const { return event_; }
void Record(const phi::GPUContext& ctx) {
auto device_index = ctx.GetPlace().device;
if (!is_created_) {
CreateEvent(device_index);
}
PADDLE_ENFORCE_EQ(device_index,
device_index_,
platform::errors::PreconditionNotMet(
"phi::GPUContext's device %d does not match"
"Event's device %d",
device_index,
device_index_));
platform::CUDADeviceGuard guard(device_index_);
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event_, ctx.stream()));
#else
PADDLE_ENFORCE_GPU_SUCCESS(hipEventRecord(event_, ctx.stream()));
#endif
}
bool Query() const {
#ifdef PADDLE_WITH_HIP
gpuError_t err = hipEventQuery(event_);
if (err == hipSuccess) {
return true;
}
if (err == hipErrorNotReady) {
return false;
}
#else
gpuError_t err = cudaEventQuery(event_);
if (err == cudaSuccess) {
return true;
}
if (err == cudaErrorNotReady) {
return false;
}
#endif
PADDLE_ENFORCE_GPU_SUCCESS(err);
return false;
}
void Synchronize() const {
if (is_created_) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipEventSynchronize(event_));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(event_));
#endif
}
}
void Block(const phi::GPUContext& ctx) const {
if (is_created_) {
auto device_index = ctx.GetPlace().device;
PADDLE_ENFORCE_EQ(device_index,
device_index_,
platform::errors::PreconditionNotMet(
"phi::GPUContext's device %d does not match"
"Event's device %d",
device_index,
device_index_));
platform::CUDADeviceGuard guard(device_index_);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipStreamWaitEvent(ctx.stream(), event_, 0));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamWaitEvent(ctx.stream(), event_, 0));
#endif
}
}
private:
#ifdef PADDLE_WITH_HIP
unsigned int flags_ = hipEventDefault;
#else
unsigned int flags_ = cudaEventDefault;
#endif
bool is_created_{false};
gpuEvent_t event_{};
int8_t device_index_{0};
private:
void CreateEvent(int device_index) {
device_index_ = device_index;
platform::CUDADeviceGuard guard(device_index);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipEventCreateWithFlags(&event_, flags_));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreateWithFlags(&event_, flags_));
#endif
is_created_ = true;
}
};
// NOTE(shenliang03): NCCLCommManager is more lightweight than
// platform::NCCLComm
class NCCLCommManager {
public:
explicit NCCLCommManager(ncclComm_t ncclComm) : nccl_comm_(ncclComm) {}
NCCLCommManager() : NCCLCommManager(nullptr) {}
~NCCLCommManager() noexcept {
std::unique_lock<std::mutex> lock(mutex_);
if (nccl_comm_) {
platform::dynload::ncclCommDestroy(nccl_comm_);
}
}
static std::shared_ptr<NCCLCommManager> Create(int num_ranks,
int rank,
ncclUniqueId comm_id) {
auto nccl_manager = std::make_shared<NCCLCommManager>();
NCCLCHECK(platform::dynload::ncclCommInitRank(
&(nccl_manager->nccl_comm_), num_ranks, comm_id, rank));
nccl_manager->nccl_id_ = comm_id;
nccl_manager->rank_ = rank;
return nccl_manager;
}
ncclUniqueId GetNcclId() const {
std::unique_lock<std::mutex> lock(mutex_);
return nccl_id_;
}
ncclComm_t GetNcclComm() const {
std::unique_lock<std::mutex> lock(mutex_);
return nccl_comm_;
}
NCCLCommManager(const NCCLCommManager&) = delete;
NCCLCommManager& operator=(const NCCLCommManager&) = delete;
NCCLCommManager& operator=(NCCLCommManager&& other) = delete;
NCCLCommManager(NCCLCommManager&& other) {
std::unique_lock<std::mutex> lock(other.mutex_);
std::swap(nccl_comm_, other.nccl_comm_);
}
protected:
ncclComm_t nccl_comm_;
ncclUniqueId nccl_id_;
int rank_;
mutable std::mutex mutex_;
};
ncclRedOp_t ToNCCLRedType(ReduceOp reduction); ncclRedOp_t ToNCCLRedType(ReduceOp reduction);
std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID); std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID);
......
...@@ -17,15 +17,7 @@ ...@@ -17,15 +17,7 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
ProcessGroup::Task::Task(int rank, ProcessGroup::Task::Task(int rank, CommType comm_type, bool sync_op)
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type)
: rank_(rank), comm_type_(comm_type) {}
ProcessGroup::Task::Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type,
bool sync_op)
: rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {} : rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {}
ProcessGroup::Task::~Task() = default; ProcessGroup::Task::~Task() = default;
...@@ -62,5 +54,17 @@ ProcessGroup::ProcessGroup(int rank, int size, int gid) ...@@ -62,5 +54,17 @@ ProcessGroup::ProcessGroup(int rank, int size, int gid)
} }
} }
// TODO(sunyilun): methods below will be removed later
ProcessGroup::Task::Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type)
: rank_(rank), comm_type_(comm_type) {}
ProcessGroup::Task::Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type,
bool sync_op)
: rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {}
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -54,13 +54,7 @@ class ProcessGroup { ...@@ -54,13 +54,7 @@ class ProcessGroup {
public: public:
class Task { class Task {
public: public:
Task(int rank, Task(int rank, CommType comm_type, bool sync_op);
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type);
Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type,
bool sync_op);
virtual ~Task(); virtual ~Task();
virtual bool IsCompleted(); virtual bool IsCompleted();
...@@ -69,6 +63,15 @@ class ProcessGroup { ...@@ -69,6 +63,15 @@ class ProcessGroup {
virtual void UpdateWaitChain(const phi::DeviceContext& ctx); virtual void UpdateWaitChain(const phi::DeviceContext& ctx);
bool IsSync() const { return sync_op_; } bool IsSync() const { return sync_op_; }
// TODO(sunyilun): methods below will be removed later
Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type);
Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type,
bool sync_op);
protected: protected:
const int rank_; const int rank_;
CommType comm_type_{CommType::UNKNOWN}; CommType comm_type_{CommType::UNKNOWN};
...@@ -79,6 +82,7 @@ class ProcessGroup { ...@@ -79,6 +82,7 @@ class ProcessGroup {
bool sync_op_{true}; bool sync_op_{true};
}; };
public:
explicit ProcessGroup(int rank, explicit ProcessGroup(int rank,
int size, int size,
const platform::Place& place, const platform::Place& place,
...@@ -93,12 +97,48 @@ class ProcessGroup { ...@@ -93,12 +97,48 @@ class ProcessGroup {
int GetSize() const { return size_; } int GetSize() const { return size_; }
virtual std::string GetBackendName() const = 0; virtual std::string GetBackendName() const = 0;
virtual const phi::DeviceContext& GetDeviceContext(const Place& place) const { virtual const phi::DeviceContext& GetDeviceContext(const Place& place) const {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"Does not support to get device_context from ProcessGroup%s.", "Does not support to get device_context from ProcessGroup%s.",
GetBackendName())); GetBackendName()));
} }
virtual std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support all_gather with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support all_reduce with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support barrier", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support broadcast with sync_op flag",
GetBackendName()));
}
// TODO(liyurui): This API will be moved later // TODO(liyurui): This API will be moved later
virtual std::shared_ptr<ProcessGroup::Task> AllReduce( virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
...@@ -118,6 +158,7 @@ class ProcessGroup { ...@@ -118,6 +158,7 @@ class ProcessGroup {
GetBackendName())); GetBackendName()));
} }
// TODO(sunyilun): methods below will be removed later
virtual std::shared_ptr<ProcessGroup::Task> Broadcast( virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
...@@ -136,12 +177,6 @@ class ProcessGroup { ...@@ -136,12 +177,6 @@ class ProcessGroup {
GetBackendName())); GetBackendName()));
} }
virtual std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support barrier", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Send( virtual std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>&, int) { // NOLINT std::vector<phi::DenseTensor>&, int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
......
...@@ -229,6 +229,17 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask { ...@@ -229,6 +229,17 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
} }
}; };
// TODO(sunyilun): for compatibility, will be updated later
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) {
std::vector<phi::DenseTensor> in_wrapper = {in_tensor};
std::vector<phi::DenseTensor> out_wrapper = {*out_tensor};
return Broadcast(in_wrapper, out_wrapper, opts, true);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast( std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast(
std::vector<phi::DenseTensor>& inputs, std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs, std::vector<phi::DenseTensor>& outputs,
......
...@@ -108,6 +108,13 @@ class ProcessGroupGloo : public ProcessGroup { ...@@ -108,6 +108,13 @@ class ProcessGroupGloo : public ProcessGroup {
~ProcessGroupGloo() = default; ~ProcessGroupGloo() = default;
std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) override;
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> Broadcast( std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& inputs, std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs, std::vector<phi::DenseTensor>& outputs,
......
...@@ -24,10 +24,10 @@ ...@@ -24,10 +24,10 @@
#include "paddle/fluid/distributed/collective/ProcessGroupStream.h" #include "paddle/fluid/distributed/collective/ProcessGroupStream.h"
#include "paddle/fluid/distributed/store/store.h" #include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_event.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h" #include "paddle/phi/common/place.h"
#include "paddle/fluid/platform/place.h" #include "paddle/phi/core/device_context.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/NCCLTools.h" #include "paddle/fluid/distributed/collective/NCCLTools.h"
...@@ -44,16 +44,28 @@ namespace distributed { ...@@ -44,16 +44,28 @@ namespace distributed {
using Place = paddle::platform::Place; using Place = paddle::platform::Place;
class ProcessGroupNCCL : public ProcessGroupStream { class ProcessGroupNCCL final : public ProcessGroupStream {
public: public:
class NCCLTask : public ProcessGroupStream::TaskStream, class NCCLTask final : public ProcessGroupStream::TaskStream,
public std::enable_shared_from_this<NCCLTask> { public std::enable_shared_from_this<NCCLTask> {
public: public:
NCCLTask(const Place& place,
int rank,
CommType comm_type,
bool sync_op,
bool use_calc_stream);
virtual ~NCCLTask();
bool IsCompleted() override;
bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) override;
void Synchronize() override;
void UpdateWaitChain(const phi::DeviceContext& ctx) override;
// TODO(sunyilun): methods below will be removed later
NCCLTask(const std::vector<Place>& places, NCCLTask(const std::vector<Place>& places,
int rank, int rank,
CommType CommType, CommType CommType,
const std::vector<phi::DenseTensor>& inputs); const std::vector<phi::DenseTensor>& inputs);
NCCLTask(const std::vector<Place>& places, NCCLTask(const std::vector<Place>& places,
int rank, int rank,
CommType comm_type, CommType comm_type,
...@@ -61,31 +73,15 @@ class ProcessGroupNCCL : public ProcessGroupStream { ...@@ -61,31 +73,15 @@ class ProcessGroupNCCL : public ProcessGroupStream {
bool sync_op, bool sync_op,
bool use_calc_stream); bool use_calc_stream);
bool IsCompleted(); public:
bool barrier_{false};
void SynchronizeStreams(); platform::DeviceEvent comm_event_; // event on comm stream
bool Wait(std::chrono::milliseconds timeout = kWaitTimeout);
void Synchronize();
void SetOutputs(std::vector<phi::DenseTensor>& outputs); // NOLINT
virtual ~NCCLTask();
void UpdateWaitChain(const phi::DeviceContext& ctx) override;
std::vector<EventManager> control_events_;
std::vector<phi::DenseTensor> barrierTensors_;
protected:
std::vector<Place> places_;
std::vector<std::shared_ptr<NCCLCommManager>> ncclComms_;
std::shared_ptr<std::vector<phi::DenseTensor>> outputs_;
private: private:
Place place_;
}; };
public:
ProcessGroupNCCL(const std::shared_ptr<Store>& store, ProcessGroupNCCL(const std::shared_ptr<Store>& store,
int rank, int rank,
int size, int size,
...@@ -99,34 +95,47 @@ class ProcessGroupNCCL : public ProcessGroupStream { ...@@ -99,34 +95,47 @@ class ProcessGroupNCCL : public ProcessGroupStream {
const phi::DeviceContext& GetDeviceContext( const phi::DeviceContext& GetDeviceContext(
const Place& place, bool use_calc_stream) const override; const Place& place, bool use_calc_stream) const override;
std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllReduce( std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT phi::DenseTensor* out_tensor,
std::vector<phi::DenseTensor>& out_tensors, // NOLINT const phi::DenseTensor& in_tensor,
const AllreduceOptions& options, const AllreduceOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op, bool sync_op,
bool use_calc_stream) override; bool use_calc_stream) override;
static void GroupStart();
static void GroupEnd();
ncclComm_t NCCLComm(const Place& place) const;
// TODO(liyurui): This API will be moved later // TODO(liyurui): This API will be moved later
std::shared_ptr<ProcessGroup::Task> AllReduce( std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, std::vector<phi::DenseTensor>& out_tensors,
const AllreduceOptions& = AllreduceOptions()) override; const AllreduceOptions& = AllreduceOptions()) override;
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> Broadcast( std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& = BroadcastOptions()) override; const BroadcastOptions& = BroadcastOptions()) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;
std::shared_ptr<ProcessGroup::Task> Send( std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors, int dst_rank) override; std::vector<phi::DenseTensor>& tensors, int dst_rank) override;
...@@ -175,12 +184,6 @@ class ProcessGroupNCCL : public ProcessGroupStream { ...@@ -175,12 +184,6 @@ class ProcessGroupNCCL : public ProcessGroupStream {
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override; std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllGather_Partial( std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, std::vector<phi::DenseTensor>& out_tensors,
...@@ -255,20 +258,37 @@ class ProcessGroupNCCL : public ProcessGroupStream { ...@@ -255,20 +258,37 @@ class ProcessGroupNCCL : public ProcessGroupStream {
phi::DenseTensor&, // NOLINT phi::DenseTensor&, // NOLINT
const ReduceScatterOptions&) override; const ReduceScatterOptions&) override;
static void GroupStart(); private:
std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(const Place& place,
int rank,
CommType op_type,
bool sync_op,
bool use_calc_stream);
static void GroupEnd(); void BroadcastUniqueNCCLID(ncclUniqueId* nccl_id);
ncclComm_t NCCLComm(const Place& place) const; void CreateNCCLEnvCache(const Place& place, const std::string& place_key);
template <typename Fn>
std::shared_ptr<ProcessGroupStream::Task> Collective(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
Fn fn,
CommType comm_type,
bool sync_op,
bool use_calc_stream);
protected: void SyncCalcStream(const Place& place,
virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask( const std::shared_ptr<platform::DeviceEvent>& event);
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
std::vector<Place> places, std::vector<Place> places,
int rank, int rank,
CommType op_type, CommType op_type,
const std::vector<phi::DenseTensor>& inputs); const std::vector<phi::DenseTensor>& inputs);
virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask( std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
const std::vector<Place>& places, const std::vector<Place>& places,
int rank, int rank,
CommType op_type, CommType op_type,
...@@ -276,27 +296,6 @@ class ProcessGroupNCCL : public ProcessGroupStream { ...@@ -276,27 +296,6 @@ class ProcessGroupNCCL : public ProcessGroupStream {
bool sync_op, bool sync_op,
bool use_calc_stream); bool use_calc_stream);
protected:
std::shared_ptr<Store> store_;
std::shared_ptr<NCCLCommManager> nccl_comm_;
std::mutex mutex_;
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLCommManager>>>
places_to_ncclcomm_;
std::unordered_map<std::string, std::vector<EventManager>> places_to_events_;
std::unordered_map<std::string, std::vector<std::unique_ptr<phi::GPUContext>>>
places_to_ctx_;
std::set<int> used_place_ids_;
private:
void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, // NOLINT
int root, // NOLINT
int server_fd);
void BroadcastUniqueNCCLID(std::vector<ncclUniqueId>& nccl_ids); // NOLINT
template <typename Fn> template <typename Fn>
std::shared_ptr<ProcessGroup::Task> Collective( std::shared_ptr<ProcessGroup::Task> Collective(
std::vector<phi::DenseTensor>& inputs, // NOLINT std::vector<phi::DenseTensor>& inputs, // NOLINT
...@@ -340,6 +339,17 @@ class ProcessGroupNCCL : public ProcessGroupStream { ...@@ -340,6 +339,17 @@ class ProcessGroupNCCL : public ProcessGroupStream {
void CheckSplitSizes(std::vector<int64_t>* split_sizes, void CheckSplitSizes(std::vector<int64_t>* split_sizes,
std::vector<int64_t> tensor_shape); std::vector<int64_t> tensor_shape);
private:
std::shared_ptr<Store> store_;
std::shared_ptr<platform::DeviceEvent> calc_event_; // event on calc stream
std::unordered_map<std::string, phi::GPUContext*> place_to_calc_ctx_;
std::unordered_map<std::string, std::unique_ptr<phi::GPUContext>>
place_to_comm_ctx_;
// TODO(sunyilun): attrs below will be removed later
std::mutex mutex_;
std::unordered_map<std::string, std::vector<phi::GPUContext*>> places_to_ctx_;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -30,18 +30,18 @@ const phi::DeviceContext& ProcessGroupStream::GetDeviceContext( ...@@ -30,18 +30,18 @@ const phi::DeviceContext& ProcessGroupStream::GetDeviceContext(
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather( std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT phi::DenseTensor* out_tensor,
std::vector<phi::DenseTensor>& output_tensors, // NOLINT const phi::DenseTensor& in_tensor,
bool sync_op) { bool sync_op) {
return AllGather(input_tensors, return AllGather(out_tensor,
output_tensors, in_tensor,
sync_op, sync_op,
/*use_calc_stream*/ false); /*use_calc_stream*/ false);
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather( std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT phi::DenseTensor* out_tensor,
std::vector<phi::DenseTensor>& output_tensors, // NOLINT const phi::DenseTensor& in_tensor,
bool sync_op, bool sync_op,
bool use_calc_stream) { bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
...@@ -49,27 +49,50 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather( ...@@ -49,27 +49,50 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce( std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT phi::DenseTensor* out_tensor,
std::vector<phi::DenseTensor>& output_tensors, // NOLINT const phi::DenseTensor& in_tensor,
const AllreduceOptions& options, const AllreduceOptions& opts,
bool sync_op) { bool sync_op) {
return AllReduce(input_tensors, return AllReduce(out_tensor,
output_tensors, in_tensor,
options, opts,
sync_op, sync_op,
/*use_calc_stream*/ false); /*use_calc_stream*/ false);
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce( std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT phi::DenseTensor* out_tensor,
std::vector<phi::DenseTensor>& output_tensors, // NOLINT const phi::DenseTensor& in_tensor,
const AllreduceOptions& options, const AllreduceOptions& opts,
bool sync_op, bool sync_op,
bool use_calc_stream) { bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do all_reduce", GetBackendName())); "ProcessGroup%s does not support do all_reduce", GetBackendName()));
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) {
return Broadcast(out_tensor,
in_tensor,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do broadcast", GetBackendName()));
}
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll( std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, std::vector<phi::DenseTensor>& out_tensors,
...@@ -114,28 +137,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAllSingle( ...@@ -114,28 +137,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAllSingle(
"ProcessGroup%s does not support do alltoall_single", GetBackendName())); "ProcessGroup%s does not support do alltoall_single", GetBackendName()));
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& opts,
bool sync_op) {
return Broadcast(in_tensors,
out_tensors,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do broadcast", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce( std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, std::vector<phi::DenseTensor>& out_tensors,
......
...@@ -27,6 +27,11 @@ class ProcessGroupStream : public ProcessGroup { ...@@ -27,6 +27,11 @@ class ProcessGroupStream : public ProcessGroup {
public: public:
class TaskStream : public ProcessGroup::Task { class TaskStream : public ProcessGroup::Task {
public: public:
TaskStream(int rank, CommType comm_type, bool sync_op, bool use_calc_stream)
: Task(rank, comm_type, sync_op), use_calc_stream_(use_calc_stream) {}
virtual ~TaskStream() = default;
// TODO(liyurui): This constructor is temporary here for compatible reason, // TODO(liyurui): This constructor is temporary here for compatible reason,
// will be deleted soon. // will be deleted soon.
TaskStream(int rank, TaskStream(int rank,
...@@ -42,8 +47,6 @@ class ProcessGroupStream : public ProcessGroup { ...@@ -42,8 +47,6 @@ class ProcessGroupStream : public ProcessGroup {
: Task(rank, inputs, comm_type, sync_op), : Task(rank, inputs, comm_type, sync_op),
use_calc_stream_(use_calc_stream) {} use_calc_stream_(use_calc_stream) {}
virtual ~TaskStream() = default;
protected: protected:
bool UseCalcStream() const { return use_calc_stream_; } bool UseCalcStream() const { return use_calc_stream_; }
...@@ -51,6 +54,7 @@ class ProcessGroupStream : public ProcessGroup { ...@@ -51,6 +54,7 @@ class ProcessGroupStream : public ProcessGroup {
bool use_calc_stream_{false}; bool use_calc_stream_{false};
}; };
public:
ProcessGroupStream(int rank, int size, const platform::Place& place, int gid); ProcessGroupStream(int rank, int size, const platform::Place& place, int gid);
virtual ~ProcessGroupStream() = default; virtual ~ProcessGroupStream() = default;
...@@ -58,29 +62,43 @@ class ProcessGroupStream : public ProcessGroup { ...@@ -58,29 +62,43 @@ class ProcessGroupStream : public ProcessGroup {
const Place& place, bool use_calc_stream) const; const Place& place, bool use_calc_stream) const;
std::shared_ptr<ProcessGroup::Task> AllGather( std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT phi::DenseTensor* out_tensor,
std::vector<phi::DenseTensor>& out_tensors, // NOLINT const phi::DenseTensor& in_tensor,
bool sync_op) override; bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> AllGather( virtual std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT phi::DenseTensor* out_tensor,
std::vector<phi::DenseTensor>& out_tensors, // NOLINT const phi::DenseTensor& in_tensor,
bool sync_op, bool sync_op,
bool use_calc_stream); bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> AllReduce( std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT phi::DenseTensor* out_tensor,
std::vector<phi::DenseTensor>& output_tensors, // NOLINT const phi::DenseTensor& in_tensor,
const AllreduceOptions& options, const AllreduceOptions& opts,
bool sync_op) override; bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> AllReduce( virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT phi::DenseTensor* out_tensor,
std::vector<phi::DenseTensor>& output_tensors, // NOLINT const phi::DenseTensor& in_tensor,
const AllreduceOptions& options, const AllreduceOptions& opts,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op, bool sync_op,
bool use_calc_stream); bool use_calc_stream);
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> AllToAll( std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT std::vector<phi::DenseTensor>& out_tensors, // NOLINT
...@@ -107,19 +125,6 @@ class ProcessGroupStream : public ProcessGroup { ...@@ -107,19 +125,6 @@ class ProcessGroupStream : public ProcessGroup {
bool sync_op, bool sync_op,
bool use_calc_stream); bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const BroadcastOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Reduce( std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT std::vector<phi::DenseTensor>& out_tensors, // NOLINT
......
...@@ -51,14 +51,9 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT ...@@ -51,14 +51,9 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT
if (map->has(ring_id)) { if (map->has(ring_id)) {
paddle::distributed::ProcessGroup *pg = map->get(ring_id); paddle::distributed::ProcessGroup *pg = map->get(ring_id);
auto pg_nccl = static_cast<distributed::ProcessGroupNCCL *>(pg); auto pg_nccl = static_cast<distributed::ProcessGroupNCCL *>(pg);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(tensor);
out_tensor.push_back(tensor);
paddle::distributed::AllreduceOptions opts; paddle::distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM; opts.reduce_op = distributed::ReduceOp::SUM;
auto task = pg_nccl->AllReduce(in_tensor, out_tensor, opts, true, true); auto task = pg_nccl->AllReduce(&tensor, tensor, opts, true, true);
task->Wait(); task->Wait();
} else { } else {
auto dtype = platform::ToNCCLDataType( auto dtype = platform::ToNCCLDataType(
......
...@@ -44,14 +44,9 @@ static void AllReduce(phi::DenseTensor& tensor, // NOLINT ...@@ -44,14 +44,9 @@ static void AllReduce(phi::DenseTensor& tensor, // NOLINT
if (map->has(ring_id)) { if (map->has(ring_id)) {
paddle::distributed::ProcessGroup* pg = map->get(ring_id); paddle::distributed::ProcessGroup* pg = map->get(ring_id);
auto pg_nccl = static_cast<distributed::ProcessGroupNCCL*>(pg); auto pg_nccl = static_cast<distributed::ProcessGroupNCCL*>(pg);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(tensor);
out_tensor.push_back(tensor);
paddle::distributed::AllreduceOptions opts; paddle::distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM; opts.reduce_op = distributed::ReduceOp::SUM;
auto task = pg_nccl->AllReduce(in_tensor, out_tensor, opts, true, true); auto task = pg_nccl->AllReduce(&tensor, tensor, opts, true, true);
task->Wait(); task->Wait();
} else { } else {
auto dtype = platform::ToNCCLDataType( auto dtype = platform::ToNCCLDataType(
......
...@@ -147,12 +147,12 @@ void BindDistributed(py::module *m) { ...@@ -147,12 +147,12 @@ void BindDistributed(py::module *m) {
distributed::ReduceOp op, distributed::ReduceOp op,
bool sync_op) { bool sync_op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::AllreduceOptions opts; auto p_dense =
opts.reduce_op = op;
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()); std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense}; auto *out_dense = p_dense.get();
return self.AllReduce(tensors, tensors, opts, sync_op); auto in_dense = *p_dense;
distributed::AllreduceOptions opts{op};
return self.AllReduce(out_dense, in_dense, opts, sync_op);
}, },
py::arg("tensor"), py::arg("tensor"),
py::arg("op"), py::arg("op"),
...@@ -183,11 +183,12 @@ void BindDistributed(py::module *m) { ...@@ -183,11 +183,12 @@ void BindDistributed(py::module *m) {
int src, int src,
bool sync_op) { bool sync_op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::BroadcastOptions opts{src}; auto p_dense =
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()); std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense}; auto *out_dense = p_dense.get();
return self.Broadcast(tensors, tensors, opts, sync_op); auto in_dense = *p_dense;
distributed::BroadcastOptions opts{src};
return self.Broadcast(out_dense, in_dense, opts, sync_op);
}, },
py::arg("tensor"), py::arg("tensor"),
py::arg("src"), py::arg("src"),
...@@ -380,52 +381,52 @@ void BindDistributed(py::module *m) { ...@@ -380,52 +381,52 @@ void BindDistributed(py::module *m) {
.def( .def(
"allgather", "allgather",
[](distributed::ProcessGroup &self, [](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor_list, py::handle py_out_tensor_list,
py::handle py_in_tensor,
bool sync_op) { bool sync_op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor_list = auto out_tensor_list =
CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0); CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0); Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>( auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_out_tensor.impl()); concat_out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense}; auto *out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
const auto &dev_ctx = self.GetDeviceContext(in_tensor.place()); const auto &dev_ctx = self.GetDeviceContext(in_tensor.place());
auto task = self.AllGather(in_wrapper, out_wrapper, sync_op); auto task = self.AllGather(out_dense, in_dense, sync_op);
distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list); distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
task->UpdateWaitChain(dev_ctx); task->UpdateWaitChain(dev_ctx);
return task; return task;
}, },
py::arg("in"),
py::arg("out"), py::arg("out"),
py::arg("in"),
py::arg("sync_op"), py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def( .def(
"allgather_into_tensor", "allgather_into_tensor",
[](distributed::ProcessGroup &self, [](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor, py::handle py_out_tensor,
py::handle py_in_tensor,
bool sync_op) { bool sync_op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>( auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl()); out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense}; auto *out_dense = p_out_tensor.get();
return self.AllGather(in_wrapper, out_wrapper, sync_op); auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
return self.AllGather(out_dense, in_dense, sync_op);
}, },
py::arg("in"),
py::arg("out"), py::arg("out"),
py::arg("in"),
py::arg("sync_op"), py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
...@@ -784,55 +785,55 @@ void BindDistributed(py::module *m) { ...@@ -784,55 +785,55 @@ void BindDistributed(py::module *m) {
.def( .def(
"allgather_on_calc_stream", "allgather_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroupStream &self,
py::handle py_in_tensor, py::handle py_out_tensor_list,
py::handle py_out_tensor_list) { py::handle py_in_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor_list = auto out_tensor_list =
CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0); CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0); Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>( auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_out_tensor.impl()); concat_out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense}; auto *out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
const auto &dev_ctx = const auto &dev_ctx =
self.GetDeviceContext(in_tensor.place(), true); self.GetDeviceContext(in_tensor.place(), true);
auto task = self.AllGather(in_wrapper, auto task = self.AllGather(out_dense,
out_wrapper, in_dense,
/*sync_op*/ true, /*sync_op*/ true,
/*use_calc_stream*/ true); /*use_calc_stream*/ true);
distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list); distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
return task; return task;
}, },
py::arg("in"),
py::arg("out"), py::arg("out"),
py::arg("in"),
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def( .def(
"allgather_into_tensor_on_calc_stream", "allgather_into_tensor_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroupStream &self,
py::handle py_in_tensor, py::handle py_out_tensor,
py::handle py_out_tensor) { py::handle py_in_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>( auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl()); out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense}; auto *out_dense = p_out_tensor.get();
return self.AllGather(in_wrapper, auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
out_wrapper, auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
return self.AllGather(out_dense,
in_dense,
/*sync_op*/ true, /*sync_op*/ true,
/*use_calc_stream*/ true); /*use_calc_stream*/ true);
}, },
py::arg("in"),
py::arg("out"), py::arg("out"),
py::arg("in"),
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def( .def(
...@@ -872,13 +873,13 @@ void BindDistributed(py::module *m) { ...@@ -872,13 +873,13 @@ void BindDistributed(py::module *m) {
py::handle py_tensor, py::handle py_tensor,
distributed::ReduceOp op) { distributed::ReduceOp op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::AllreduceOptions opts; auto p_dense =
opts.reduce_op = op;
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()); std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense}; auto in_dense = *p_dense;
return self.AllReduce(tensors, auto *out_dense = p_dense.get();
tensors, distributed::AllreduceOptions opts{op};
return self.AllReduce(out_dense,
in_dense,
opts, opts,
/*sync_op*/ true, /*sync_op*/ true,
/*use_calc_stream*/ true); /*use_calc_stream*/ true);
...@@ -980,12 +981,13 @@ void BindDistributed(py::module *m) { ...@@ -980,12 +981,13 @@ void BindDistributed(py::module *m) {
py::handle py_tensor, py::handle py_tensor,
int src) { int src) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::BroadcastOptions opts{src}; auto p_dense =
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()); std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense}; auto *out_dense = p_dense.get();
return self.Broadcast(tensors, auto in_dense = *p_dense;
tensors, distributed::BroadcastOptions opts{src};
return self.Broadcast(out_dense,
in_dense,
opts, opts,
/*sync_op*/ true, /*sync_op*/ true,
/*use_calc_stream*/ true); /*use_calc_stream*/ true);
......
...@@ -21,18 +21,18 @@ def _check_tensor_shape(tensor, shape, nranks=1): ...@@ -21,18 +21,18 @@ def _check_tensor_shape(tensor, shape, nranks=1):
expect_shape = list(shape) expect_shape = list(shape)
expect_shape[0] *= nranks expect_shape[0] *= nranks
if list(tensor.shape) != expect_shape: if list(tensor.shape) != expect_shape:
raise RuntimeError('The tensor for all_gather is not correctly-sized.') raise RuntimeError("The tensor for all_gather is not correctly-sized.")
def _check_tensor_list_shape(tensor_list, shape, nranks=1): def _check_tensor_list_shape(tensor_list, shape, nranks=1):
if len(tensor_list) != nranks: if len(tensor_list) != nranks:
raise RuntimeError( raise RuntimeError(
'The tensor_list for all_gather is not correctly-sized.' "The tensor_list for all_gather is not correctly-sized."
) )
for tensor in tensor_list: for tensor in tensor_list:
if tensor.shape != shape: if tensor.shape != shape:
raise RuntimeError( raise RuntimeError(
'The tensor_list for all_gather is not correctly-sized.' "The tensor_list for all_gather is not correctly-sized."
) )
...@@ -45,11 +45,12 @@ def _all_gather_into_tensor_in_dygraph( ...@@ -45,11 +45,12 @@ def _all_gather_into_tensor_in_dygraph(
if use_calc_stream: if use_calc_stream:
return group.process_group.allgather_into_tensor_on_calc_stream( return group.process_group.allgather_into_tensor_on_calc_stream(
in_tensor, out_tensor out_tensor,
in_tensor,
) )
task = group.process_group.allgather_into_tensor( task = group.process_group.allgather_into_tensor(
in_tensor, out_tensor, sync_op out_tensor, in_tensor, sync_op
) )
if sync_op: if sync_op:
task.wait() task.wait()
...@@ -68,9 +69,9 @@ def _all_gather_in_dygraph( ...@@ -68,9 +69,9 @@ def _all_gather_in_dygraph(
_check_tensor_list_shape(tensor_list, tensor.shape, group.nranks) _check_tensor_list_shape(tensor_list, tensor.shape, group.nranks)
if use_calc_stream: if use_calc_stream:
return group.process_group.allgather_on_calc_stream(tensor, tensor_list) return group.process_group.allgather_on_calc_stream(tensor_list, tensor)
task = group.process_group.allgather(tensor, tensor_list, sync_op) task = group.process_group.allgather(tensor_list, tensor, sync_op)
if sync_op: if sync_op:
task.wait() task.wait()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册