未验证 提交 e1a1c354 编写于 作者: W Wen Sun 提交者: GitHub

Refactor collective communication all_gather, all_reduce, broadcast & barrier C++ API (#47481)

上级 2337e609
...@@ -28,6 +28,7 @@ if(WITH_NCCL OR WITH_RCCL) ...@@ -28,6 +28,7 @@ if(WITH_NCCL OR WITH_RCCL)
enforce enforce
collective_helper collective_helper
device_context device_context
${DEVICE_EVENT_LIBS}
dense_tensor) dense_tensor)
if(WITH_DISTRIBUTE AND WITH_PSCORE) if(WITH_DISTRIBUTE AND WITH_PSCORE)
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
......
...@@ -40,6 +40,8 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places) { ...@@ -40,6 +40,8 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places) {
return placeList; return placeList;
} }
// Returns the cache key for a single device: the place's debug string.
std::string GetKeyFromPlace(const Place& place) {
  std::string place_key = place.DebugString();
  return place_key;
}
bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors) { bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors) {
return std::all_of( return std::all_of(
tensors.cbegin(), tensors.cend(), [&](const phi::DenseTensor& t) { tensors.cbegin(), tensors.cend(), [&](const phi::DenseTensor& t) {
......
...@@ -25,6 +25,8 @@ using Place = paddle::platform::Place; ...@@ -25,6 +25,8 @@ using Place = paddle::platform::Place;
std::vector<Place> GetPlaceList(const std::vector<phi::DenseTensor>& tensors); std::vector<Place> GetPlaceList(const std::vector<phi::DenseTensor>& tensors);
// Get the deviceList String from the list of devices // Get the deviceList String from the list of devices
std::string GetKeyFromPlaces(const std::vector<Place>& places); std::string GetKeyFromPlaces(const std::vector<Place>& places);
// Get the device string from one device
std::string GetKeyFromPlace(const Place& place);
bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors); bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors);
......
...@@ -59,204 +59,6 @@ namespace distributed { ...@@ -59,204 +59,6 @@ namespace distributed {
} \ } \
} while (0) } while (0)
// NOTE(shenliang03): EventManager is a movable, non-copyable CudaEvent wrapper.
// EventManager is different from paddle::platform::CudaEvent.
// It uses lazy initialization and is only created when the
// Record() method is called for the first time; it also monitors
// device information to ensure that the recorded stream and the event
// are on the same device.
class EventManager {
 public:
  EventManager() = default;
  explicit EventManager(unsigned int flags) : flags_{flags} {}

  // Destroys the underlying event (if one was ever created) on the device it
  // was created on.
  ~EventManager() {
    if (is_created_) {
      platform::CUDADeviceGuard guard(device_index_);
#ifdef PADDLE_WITH_HIP
      hipEventDestroy(event_);
#else
      cudaEventDestroy(event_);
#endif
    }
  }

  EventManager(const EventManager&) = delete;
  EventManager& operator=(const EventManager&) = delete;

  // Moves are implemented as member-wise swaps. They only touch scalars and a
  // raw handle, so they cannot throw; noexcept lets containers move instead
  // of failing to relocate.
  EventManager(EventManager&& other) noexcept {
    std::swap(flags_, other.flags_);
    std::swap(is_created_, other.is_created_);
    std::swap(device_index_, other.device_index_);
    std::swap(event_, other.event_);
  }

  EventManager& operator=(EventManager&& other) noexcept {
    std::swap(flags_, other.flags_);
    std::swap(is_created_, other.is_created_);
    std::swap(device_index_, other.device_index_);
    std::swap(event_, other.event_);
    return *this;
  }

  // True once Record() has created the underlying event.
  bool IsCreated() const { return is_created_; }

  // Index of the device the event was created on (0 until created).
  // Returned as an integer: a bool return would collapse every non-zero
  // device index to `true`.
  int DeviceId() const { return device_index_; }

  gpuEvent_t GetRawCudaEvent() const { return event_; }

  // Records the event on ctx's stream, creating the event on first use.
  // Raises PreconditionNotMet when ctx lives on a different device than the
  // one the event was created on.
  void Record(const phi::GPUContext& ctx) {
    auto device_index = ctx.GetPlace().device;
    if (!is_created_) {
      CreateEvent(device_index);
    }
    PADDLE_ENFORCE_EQ(device_index,
                      device_index_,
                      platform::errors::PreconditionNotMet(
                          "phi::GPUContext's device %d does not match "
                          "Event's device %d",
                          device_index,
                          device_index_));

    platform::CUDADeviceGuard guard(device_index_);
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_GPU_SUCCESS(hipEventRecord(event_, ctx.stream()));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event_, ctx.stream()));
#endif
  }

  // Returns true when the recorded work has finished, false while it is still
  // pending. Any other runtime status is raised as an error.
  bool Query() const {
    // Nothing was ever recorded, so there is no pending work. This mirrors
    // Synchronize() and Block(), which are no-ops before the first Record().
    if (!is_created_) {
      return true;
    }
#ifdef PADDLE_WITH_HIP
    gpuError_t err = hipEventQuery(event_);
    if (err == hipSuccess) {
      return true;
    }
    if (err == hipErrorNotReady) {
      return false;
    }
#else
    gpuError_t err = cudaEventQuery(event_);
    if (err == cudaSuccess) {
      return true;
    }
    if (err == cudaErrorNotReady) {
      return false;
    }
#endif
    PADDLE_ENFORCE_GPU_SUCCESS(err);
    return false;
  }

  // Blocks the calling host thread until the event completes; no-op if the
  // event has not been created.
  void Synchronize() const {
    if (is_created_) {
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(hipEventSynchronize(event_));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(event_));
#endif
    }
  }

  // Makes ctx's stream wait for the event; no-op if the event has not been
  // created. Raises PreconditionNotMet on a device mismatch.
  void Block(const phi::GPUContext& ctx) const {
    if (is_created_) {
      auto device_index = ctx.GetPlace().device;
      PADDLE_ENFORCE_EQ(device_index,
                        device_index_,
                        platform::errors::PreconditionNotMet(
                            "phi::GPUContext's device %d does not match "
                            "Event's device %d",
                            device_index,
                            device_index_));
      platform::CUDADeviceGuard guard(device_index_);

#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(hipStreamWaitEvent(ctx.stream(), event_, 0));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamWaitEvent(ctx.stream(), event_, 0));
#endif
    }
  }

 private:
  // Creates the event on `device_index` with the configured flags and marks
  // this wrapper as created.
  void CreateEvent(int device_index) {
    device_index_ = device_index;
    platform::CUDADeviceGuard guard(device_index);

#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_GPU_SUCCESS(hipEventCreateWithFlags(&event_, flags_));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreateWithFlags(&event_, flags_));
#endif

    is_created_ = true;
  }

#ifdef PADDLE_WITH_HIP
  unsigned int flags_ = hipEventDefault;
#else
  unsigned int flags_ = cudaEventDefault;
#endif

  bool is_created_{false};
  gpuEvent_t event_{};
  int8_t device_index_{0};
};
// NOTE(shenliang03): NCCLCommManager is more lightweight than
// platform::NCCLComm
class NCCLCommManager {
public:
explicit NCCLCommManager(ncclComm_t ncclComm) : nccl_comm_(ncclComm) {}
NCCLCommManager() : NCCLCommManager(nullptr) {}
~NCCLCommManager() noexcept {
std::unique_lock<std::mutex> lock(mutex_);
if (nccl_comm_) {
platform::dynload::ncclCommDestroy(nccl_comm_);
}
}
static std::shared_ptr<NCCLCommManager> Create(int num_ranks,
int rank,
ncclUniqueId comm_id) {
auto nccl_manager = std::make_shared<NCCLCommManager>();
NCCLCHECK(platform::dynload::ncclCommInitRank(
&(nccl_manager->nccl_comm_), num_ranks, comm_id, rank));
nccl_manager->nccl_id_ = comm_id;
nccl_manager->rank_ = rank;
return nccl_manager;
}
ncclUniqueId GetNcclId() const {
std::unique_lock<std::mutex> lock(mutex_);
return nccl_id_;
}
ncclComm_t GetNcclComm() const {
std::unique_lock<std::mutex> lock(mutex_);
return nccl_comm_;
}
NCCLCommManager(const NCCLCommManager&) = delete;
NCCLCommManager& operator=(const NCCLCommManager&) = delete;
NCCLCommManager& operator=(NCCLCommManager&& other) = delete;
NCCLCommManager(NCCLCommManager&& other) {
std::unique_lock<std::mutex> lock(other.mutex_);
std::swap(nccl_comm_, other.nccl_comm_);
}
protected:
ncclComm_t nccl_comm_;
ncclUniqueId nccl_id_;
int rank_;
mutable std::mutex mutex_;
};
ncclRedOp_t ToNCCLRedType(ReduceOp reduction); ncclRedOp_t ToNCCLRedType(ReduceOp reduction);
std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID); std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID);
......
...@@ -17,15 +17,7 @@ ...@@ -17,15 +17,7 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
ProcessGroup::Task::Task(int rank, ProcessGroup::Task::Task(int rank, CommType comm_type, bool sync_op)
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type)
: rank_(rank), comm_type_(comm_type) {}
ProcessGroup::Task::Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type,
bool sync_op)
: rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {} : rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {}
ProcessGroup::Task::~Task() = default; ProcessGroup::Task::~Task() = default;
...@@ -62,5 +54,17 @@ ProcessGroup::ProcessGroup(int rank, int size, int gid) ...@@ -62,5 +54,17 @@ ProcessGroup::ProcessGroup(int rank, int size, int gid)
} }
} }
// TODO(sunyilun): methods below will be removed later
// Legacy constructor that still accepts the input tensor list. The list is
// not stored; only rank and communication type are recorded.
ProcessGroup::Task::Task(int rank,
                         const std::vector<phi::DenseTensor>& /*inputs*/,
                         CommType comm_type)
    : rank_(rank),
      comm_type_(comm_type) {}
// Legacy constructor that still accepts the input tensor list. The list is
// not stored; rank, communication type and the sync flag are recorded.
ProcessGroup::Task::Task(int rank,
                         const std::vector<phi::DenseTensor>& /*inputs*/,
                         CommType comm_type,
                         bool sync_op)
    : rank_(rank),
      comm_type_(comm_type),
      sync_op_(sync_op) {}
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -54,13 +54,7 @@ class ProcessGroup { ...@@ -54,13 +54,7 @@ class ProcessGroup {
public: public:
class Task { class Task {
public: public:
Task(int rank, Task(int rank, CommType comm_type, bool sync_op);
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type);
Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type,
bool sync_op);
virtual ~Task(); virtual ~Task();
virtual bool IsCompleted(); virtual bool IsCompleted();
...@@ -69,6 +63,15 @@ class ProcessGroup { ...@@ -69,6 +63,15 @@ class ProcessGroup {
virtual void UpdateWaitChain(const phi::DeviceContext& ctx); virtual void UpdateWaitChain(const phi::DeviceContext& ctx);
bool IsSync() const { return sync_op_; } bool IsSync() const { return sync_op_; }
// TODO(sunyilun): methods below will be removed later
Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type);
Task(int rank,
const std::vector<phi::DenseTensor>& inputs,
CommType comm_type,
bool sync_op);
protected: protected:
const int rank_; const int rank_;
CommType comm_type_{CommType::UNKNOWN}; CommType comm_type_{CommType::UNKNOWN};
...@@ -79,6 +82,7 @@ class ProcessGroup { ...@@ -79,6 +82,7 @@ class ProcessGroup {
bool sync_op_{true}; bool sync_op_{true};
}; };
public:
explicit ProcessGroup(int rank, explicit ProcessGroup(int rank,
int size, int size,
const platform::Place& place, const platform::Place& place,
...@@ -93,12 +97,48 @@ class ProcessGroup { ...@@ -93,12 +97,48 @@ class ProcessGroup {
int GetSize() const { return size_; } int GetSize() const { return size_; }
virtual std::string GetBackendName() const = 0; virtual std::string GetBackendName() const = 0;
virtual const phi::DeviceContext& GetDeviceContext(const Place& place) const { virtual const phi::DeviceContext& GetDeviceContext(const Place& place) const {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"Does not support to get device_context from ProcessGroup%s.", "Does not support to get device_context from ProcessGroup%s.",
GetBackendName())); GetBackendName()));
} }
virtual std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support all_gather with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support all_reduce with sync_op flag",
GetBackendName()));
}
// Base-class barrier. There is no default implementation: the call always
// fails via PADDLE_THROW with an InvalidArgument error naming the backend.
virtual std::shared_ptr<ProcessGroup::Task> Barrier(
    const BarrierOptions& /* opts */ = BarrierOptions()) {
  PADDLE_THROW(platform::errors::InvalidArgument(
      "ProcessGroup%s does not support barrier", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support broadcast with sync_op flag",
GetBackendName()));
}
// TODO(liyurui): This API will be moved later // TODO(liyurui): This API will be moved later
virtual std::shared_ptr<ProcessGroup::Task> AllReduce( virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
...@@ -118,6 +158,7 @@ class ProcessGroup { ...@@ -118,6 +158,7 @@ class ProcessGroup {
GetBackendName())); GetBackendName()));
} }
// TODO(sunyilun): methods below will be removed later
virtual std::shared_ptr<ProcessGroup::Task> Broadcast( virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
...@@ -136,12 +177,6 @@ class ProcessGroup { ...@@ -136,12 +177,6 @@ class ProcessGroup {
GetBackendName())); GetBackendName()));
} }
// Base-class barrier. There is no default implementation: the call always
// fails via PADDLE_THROW with an InvalidArgument error naming the backend.
virtual std::shared_ptr<ProcessGroup::Task> Barrier(
    const BarrierOptions& /* opts */ = BarrierOptions()) {
  PADDLE_THROW(platform::errors::InvalidArgument(
      "ProcessGroup%s does not support barrier", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Send( virtual std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>&, int) { // NOLINT std::vector<phi::DenseTensor>&, int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
......
...@@ -229,6 +229,17 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask { ...@@ -229,6 +229,17 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
} }
}; };
// TODO(sunyilun): for compatibility, will be updated later
// Single-tensor broadcast: wraps the input and output tensors in one-element
// vectors and forwards to the vector-based Broadcast overload.
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const BroadcastOptions& opts,
    bool sync_op) {
  std::vector<phi::DenseTensor> inputs{in_tensor};
  std::vector<phi::DenseTensor> outputs{*out_tensor};
  // NOTE(review): sync_op is not forwarded; the overload is always invoked
  // with `true`. Confirm this is intended for the Gloo backend.
  return Broadcast(inputs, outputs, opts, true);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast( std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast(
std::vector<phi::DenseTensor>& inputs, std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs, std::vector<phi::DenseTensor>& outputs,
......
...@@ -108,6 +108,13 @@ class ProcessGroupGloo : public ProcessGroup { ...@@ -108,6 +108,13 @@ class ProcessGroupGloo : public ProcessGroup {
~ProcessGroupGloo() = default; ~ProcessGroupGloo() = default;
std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) override;
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> Broadcast( std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& inputs, std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs, std::vector<phi::DenseTensor>& outputs,
......
...@@ -15,12 +15,8 @@ ...@@ -15,12 +15,8 @@
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h" #include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#include "paddle/fluid/distributed/collective/Common.h" #include "paddle/fluid/distributed/collective/Common.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"
DECLARE_bool(nccl_blocking_wait); DECLARE_bool(nccl_blocking_wait);
DECLARE_bool(use_stream_safe_cuda_allocator); DECLARE_bool(use_stream_safe_cuda_allocator);
...@@ -30,89 +26,299 @@ constexpr int64_t kWaitBlockTImeout = 10; ...@@ -30,89 +26,299 @@ constexpr int64_t kWaitBlockTImeout = 10;
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
void SyncDefaultStream( ProcessGroupNCCL::NCCLTask::NCCLTask(const Place& place,
const std::vector<Place>& places, int rank,
std::vector<EventManager>& ncclEvents, // NOLINT CommType comm_type,
std::vector<std::unique_ptr<phi::GPUContext>>& dev_ctx) { // NOLINT bool sync_op,
for (size_t i = 0; i < places.size(); ++i) { bool use_calc_stream)
auto* default_ctx = static_cast<phi::GPUContext*>( : TaskStream(rank, comm_type, sync_op, use_calc_stream),
platform::DeviceContextPool::Instance().Get(places[i])); comm_event_(place),
ncclEvents[i].Record(*default_ctx); place_(place) {}
ncclEvents[i].Block(*dev_ctx[i]);
ProcessGroupNCCL::NCCLTask::~NCCLTask() {}
bool ProcessGroupNCCL::NCCLTask::IsCompleted() { return comm_event_.Query(); }
void ProcessGroupNCCL::NCCLTask::UpdateWaitChain(
const phi::DeviceContext& ctx) {
comm_event_.Record(&ctx);
}
// TODO(sheniang03): Add timeout for wait, now timeout unused
bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
// Warning here when use calc stream but also invoke waiting explicitly.
if (UseCalcStream()) {
VLOG(3) << "Warning: The communication is on calc stream, wait here is "
"useless.";
return true;
} }
const auto* calc_ctx = platform::DeviceContextPool::Instance().Get(place_);
comm_event_.Wait(platform::Place2DeviceType(place_), calc_ctx);
if (FLAGS_nccl_blocking_wait) {
// NOTE(shenliang03): It will block host for sync
while (!IsCompleted()) {
std::this_thread::sleep_for(std::chrono::milliseconds(kWaitBlockTImeout));
}
}
if (barrier_) {
// If we use the work to do barrier, we should block cpu
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif
}
return true;
} }
std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask( // Same as Wait
std::vector<Place> places, void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); }
ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
int rank, int rank,
CommType comm_type, int size,
const std::vector<phi::DenseTensor>& inputs) { const platform::Place& place,
return std::make_shared<ProcessGroupNCCL::NCCLTask>( int gid)
places, rank, comm_type, inputs); : ProcessGroupStream(rank, size, place, gid), store_(store) {
platform::SetDeviceId(place_.device);
}
// Calls ncclGroupStart through the dynamic loader; the returned status is
// checked by PADDLE_ENFORCE_GPU_SUCCESS.
void ProcessGroupNCCL::GroupStart() {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
}
// Calls ncclGroupEnd through the dynamic loader; the returned status is
// checked by PADDLE_ENFORCE_GPU_SUCCESS.
void ProcessGroupNCCL::GroupEnd() {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
}
// Convenience overload: returns the communication-stream context for `place`
// by delegating with use_calc_stream set to false.
const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
    const Place& place) const {
  constexpr bool kUseCalcStream = false;
  return GetDeviceContext(place, kUseCalcStream);
}
// Returns the device context cached for `place`: the calculation-stream
// context when use_calc_stream is true, otherwise the communication-stream
// context. Raises NotFound when no context has been cached for the place.
// Both branches validate the lookup, so a missing entry never dereferences
// end().
const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
    const Place& place, bool use_calc_stream) const {
  const std::string key = GetKeyFromPlace(place);
  if (use_calc_stream) {
    const auto& iter = place_to_calc_ctx_.find(key);
    PADDLE_ENFORCE_NE(
        iter,
        place_to_calc_ctx_.end(),
        platform::errors::NotFound(
            "Cannot find the device context in this process group."));
    return *iter->second;
  } else {
    const auto& iter = place_to_comm_ctx_.find(key);
    PADDLE_ENFORCE_NE(
        iter,
        place_to_comm_ctx_.end(),
        platform::errors::NotFound(
            "Cannot find the device context in this process group."));
    return *iter->second;
  }
}
// Returns the NCCL communicator held by the communication context cached for
// `place`. Raises NotFound when no context has been cached for the place.
ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const {
  // `iter` keeps its name because PADDLE_ENFORCE_NE operands appear in the
  // generated error text.
  const auto iter = place_to_comm_ctx_.find(GetKeyFromPlace(place));
  PADDLE_ENFORCE_NE(
      iter,
      place_to_comm_ctx_.end(),
      platform::errors::NotFound(
          "Cannot find the NCCL commmunicator in this process group."));
  const auto& comm_ctx = iter->second;
  return comm_ctx->nccl_comm();
}
// Runs ncclAllGather from in_tensor into out_tensor through Collective(),
// tagging the task as CommType::ALLGATHER.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    bool sync_op,
    bool use_calc_stream) {
  // The NCCL call is wrapped in a callable so Collective() can supply the
  // communicator and stream.
  auto all_gather_fn = [&](phi::DenseTensor* output,
                           const phi::DenseTensor& input,
                           ncclComm_t comm,
                           const gpuStream_t& stream) {
    const auto nccl_dtype = platform::ToNCCLDataType(input.dtype());
    return platform::dynload::ncclAllGather(
        input.data(), output->data(), input.numel(), nccl_dtype, comm, stream);
  };

  return Collective(out_tensor,
                    in_tensor,
                    all_gather_fn,
                    CommType::ALLGATHER,
                    sync_op,
                    use_calc_stream);
}
// Runs ncclAllReduce from in_tensor into out_tensor through Collective(),
// using the reduction selected by opts.reduce_op and tagging the task as
// CommType::ALLREDUCE.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const AllreduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  // The NCCL call is wrapped in a callable so Collective() can supply the
  // communicator and stream; `opts` is captured by reference.
  auto all_reduce_fn = [&](phi::DenseTensor* output,
                           const phi::DenseTensor& input,
                           ncclComm_t comm,
                           const gpuStream_t& stream) {
    const auto nccl_dtype = platform::ToNCCLDataType(input.type());
    const auto nccl_red_op = ToNCCLRedType(opts.reduce_op);
    return platform::dynload::ncclAllReduce(input.data(),
                                            output->data(),
                                            input.numel(),
                                            nccl_dtype,
                                            nccl_red_op,
                                            comm,
                                            stream);
  };

  return Collective(out_tensor,
                    in_tensor,
                    all_reduce_fn,
                    CommType::ALLREDUCE,
                    sync_op,
                    use_calc_stream);
}
// Implements barrier as an all_reduce over a one-element FLOAT32 tensor
// allocated on place_, run on the communication stream, and marks the
// resulting task with barrier_ = true. `opts` is not consulted.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
    const BarrierOptions& opts) {
  std::unique_ptr<phi::Allocator> allocator(
      new paddle::experimental::DefaultAllocator(place_));
  phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1});
  phi::DenseTensor barrier_tensor{allocator.get(), meta};

  // The same tensor serves as both input and output of the reduction.
  auto task = AllReduce(&barrier_tensor,
                        barrier_tensor,
                        {},
                        /*sync_op*/ true,
                        /*use_calc_stream*/ false);

  // NOTE(review): the cast result is used unchecked; this presumes AllReduce
  // always hands back an NCCLTask — confirm.
  auto* barrier_task = dynamic_cast<NCCLTask*>(task.get());
  barrier_task->barrier_ = true;
  return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) {
return Collective(
out_tensor,
in_tensor,
[&](phi::DenseTensor* output,
const phi::DenseTensor& input,
ncclComm_t comm,
const gpuStream_t& stream) {
int root = opts.source_rank + opts.source_root;
return platform::dynload::ncclBroadcast(
input.data(),
output->data(),
input.numel(),
platform::ToNCCLDataType(input.type()),
root,
comm,
stream);
},
CommType::BROADCAST,
sync_op,
use_calc_stream);
} }
std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask( std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
const std::vector<Place>& places, const Place& place,
int rank, int rank,
CommType comm_type, CommType comm_type,
const std::vector<phi::DenseTensor>& inputs,
bool is_sync, bool is_sync,
bool use_calc_stream) { bool use_calc_stream) {
return std::make_shared<ProcessGroupNCCL::NCCLTask>( return std::make_shared<ProcessGroupNCCL::NCCLTask>(
places, rank, comm_type, inputs, is_sync, use_calc_stream); place, rank, comm_type, is_sync, use_calc_stream);
} }
ProcessGroupNCCL::NCCLTask::NCCLTask( void ProcessGroupNCCL::BroadcastUniqueNCCLID(ncclUniqueId* nccl_id) {
const std::vector<Place>& places, const std::string key =
int rank, "ProcessGroupNCCL/nccl_ids/" + std::to_string(gid_) + "/0";
CommType CommType, if (rank_ == 0) {
const std::vector<phi::DenseTensor>& inputs) std::vector<uint8_t> nccl_id_wrapper(
: TaskStream(rank, inputs, CommType), places_(places) { reinterpret_cast<uint8_t*>(nccl_id),
control_events_.resize(places.size()); reinterpret_cast<uint8_t*>(nccl_id) + NCCL_UNIQUE_ID_BYTES);
ncclComms_.resize(places.size()); store_->set(key, nccl_id_wrapper);
} else {
const auto& nccl_id_wrapper = store_->get(key);
std::memcpy(nccl_id, nccl_id_wrapper.data(), nccl_id_wrapper.size());
}
} }
ProcessGroupNCCL::NCCLTask::NCCLTask( void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
const std::vector<Place>& places, const std::string& place_key) {
int rank, ncclUniqueId nccl_id;
CommType comm_type, if (rank_ == 0) {
const std::vector<phi::DenseTensor>& inputs, PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id));
bool sync_op, }
bool use_calc_stream) BroadcastUniqueNCCLID(&nccl_id);
: TaskStream(rank, inputs, comm_type, sync_op, use_calc_stream),
places_(places) {
control_events_.resize(places.size());
ncclComms_.resize(places.size());
}
ProcessGroupNCCL::NCCLTask::~NCCLTask() {} VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
<< ", place: " << place_key
<< ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);
calc_event_ = std::make_shared<platform::DeviceEvent>(place);
auto* calc_ctx = static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(place));
auto comm_ctx = std::make_unique<phi::GPUContext>(place);
ncclComm_t nccl_comm;
NCCLCHECK(platform::dynload::ncclCommInitRank(
&nccl_comm, GetSize(), nccl_id, GetRank()));
comm_ctx->set_nccl_comm(nccl_comm);
place_to_calc_ctx_[place_key] = calc_ctx;
place_to_comm_ctx_[place_key] = std::move(comm_ctx);
void ProcessGroupNCCL::NCCLTask::SetOutputs( // TODO(sunyilun): for compatibility, will be removed later
std::vector<phi::DenseTensor>& outputs) { // NOLINT places_to_ctx_[place_key] = {place_to_comm_ctx_[place_key].get()};
outputs_ = std::make_shared<std::vector<phi::DenseTensor>>(outputs);
} }
void ProcessGroupNCCL::NCCLTask::SynchronizeStreams() { void ProcessGroupNCCL::SyncCalcStream(
for (size_t i = 0; i < places_.size(); ++i) { const Place& place, const std::shared_ptr<platform::DeviceEvent>& event) {
auto* default_ctx = static_cast<phi::GPUContext*>( const std::string& key = GetKeyFromPlace(place);
platform::DeviceContextPool::Instance().Get(places_[i])); const auto* calc_ctx = place_to_calc_ctx_[key];
default_ctx->WaitEvent(control_events_[i].GetRawCudaEvent()); const auto* comm_ctx = place_to_comm_ctx_[key].get();
} event->Record(calc_ctx);
event->Wait(platform::Place2DeviceType(place), comm_ctx);
} }
bool ProcessGroupNCCL::NCCLTask::IsCompleted() { template <typename Fn>
for (size_t i = 0; i < places_.size(); ++i) { std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
if (!control_events_[i].Query()) { phi::DenseTensor* out_tensor,
return false; const phi::DenseTensor& in_tensor,
Fn fn,
CommType comm_type,
bool sync_op,
bool use_calc_stream) {
const auto& place = in_tensor.place();
const auto& key = GetKeyFromPlace(place);
if (!calc_event_) {
CreateNCCLEnvCache(place, key);
} }
if (!use_calc_stream) {
SyncCalcStream(place, calc_event_);
} }
return true; auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream);
}
void ProcessGroupNCCL::NCCLTask::UpdateWaitChain( const auto* calc_ctx = place_to_calc_ctx_[key];
const phi::DeviceContext& ctx) { const auto& comm_ctx = place_to_comm_ctx_[key];
control_events_[0].Record(*static_cast<const phi::GPUContext*>(&ctx)); auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();
fn(out_tensor, in_tensor, comm_ctx->nccl_comm(), nccl_stream);
if (!use_calc_stream) {
if (FLAGS_use_stream_safe_cuda_allocator) {
memory::RecordStream(in_tensor.Holder(), nccl_stream);
}
task->comm_event_.Record(comm_ctx.get());
}
return task;
} }
void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>* split_sizes, void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>* split_sizes,
...@@ -144,70 +350,58 @@ void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>* split_sizes, ...@@ -144,70 +350,58 @@ void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>* split_sizes,
} }
} }
// TODO(sheniang03): Add timeout for wait, now timeout unused // TODO(sunyilun): methods below will be removed later
bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) { void SyncDefaultStream(const std::vector<Place>& places,
// Warning here when use calc stream but also invoke waiting explicitly. const std::shared_ptr<platform::DeviceEvent>& nccl_event,
if (UseCalcStream()) { std::vector<phi::GPUContext*>& dev_ctx) { // NOLINT
VLOG(3) << "Warning: The communication is on calc stream, wait here is " for (size_t i = 0; i < places.size(); ++i) {
"useless."; auto* default_ctx = static_cast<phi::GPUContext*>(
return true; platform::DeviceContextPool::Instance().Get(places[i]));
} nccl_event->Record(default_ctx);
nccl_event->Wait(platform::Place2DeviceType(places[i]), dev_ctx[i]);
SynchronizeStreams();
if (FLAGS_nccl_blocking_wait) {
// NOTE(shenliang03): It will block host for sync
while (!IsCompleted()) {
std::this_thread::sleep_for(std::chrono::milliseconds(kWaitBlockTImeout));
}
}
if (!barrierTensors_.empty()) {
// If we use the work to do barrier, we should block cpu
for (auto& place : places_) {
platform::CUDADeviceGuard gpuGuard(place);
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif
}
} }
return true;
} }
// Same as Wait std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); } std::vector<Place> places,
ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
int rank, int rank,
int size, CommType comm_type,
const platform::Place& place, const std::vector<phi::DenseTensor>& inputs) {
int gid) return std::make_shared<ProcessGroupNCCL::NCCLTask>(
: ProcessGroupStream(rank, size, place, gid), store_(store) { places, rank, comm_type, inputs);
platform::SetDeviceId(place_.device);
} }
void ProcessGroupNCCL::BroadcastUniqueNCCLID( std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
std::vector<ncclUniqueId>& nccl_ids) { // NOLINT const std::vector<Place>& places,
if (rank_ == 0) { int rank,
for (size_t i = 0; i < nccl_ids.size(); i++) { CommType comm_type,
auto key = "ProcessGroupNCCL/nccl_ids/" + std::to_string(gid_) + "/" + const std::vector<phi::DenseTensor>& inputs,
std::to_string(i); bool is_sync,
auto nccl_id = std::vector<uint8_t>( bool use_calc_stream) {
reinterpret_cast<uint8_t*>(&nccl_ids[i]), return std::make_shared<ProcessGroupNCCL::NCCLTask>(
reinterpret_cast<uint8_t*>(&nccl_ids[i]) + NCCL_UNIQUE_ID_BYTES); places, rank, comm_type, inputs, is_sync, use_calc_stream);
store_->set(key, nccl_id);
}
} else {
for (size_t i = 0; i < nccl_ids.size(); i++) {
auto key = "ProcessGroupNCCL/nccl_ids/" + std::to_string(gid_) + "/" +
std::to_string(i);
auto ret = store_->get(key);
std::memcpy(&nccl_ids[i], ret.data(), ret.size());
}
}
} }
ProcessGroupNCCL::NCCLTask::NCCLTask(
const std::vector<Place>& places,
int rank,
CommType CommType,
const std::vector<phi::DenseTensor>& inputs)
: TaskStream(rank, inputs, CommType),
comm_event_(places[0]),
place_(places[0]) {}
ProcessGroupNCCL::NCCLTask::NCCLTask(
const std::vector<Place>& places,
int rank,
CommType comm_type,
const std::vector<phi::DenseTensor>& inputs,
bool sync_op,
bool use_calc_stream)
: TaskStream(rank, inputs, comm_type, sync_op, use_calc_stream),
comm_event_(places[0]),
place_(places[0]) {}
// create NCCLManager cache for places_key // create NCCLManager cache for places_key
void ProcessGroupNCCL::CreateNCCLManagerCache( void ProcessGroupNCCL::CreateNCCLManagerCache(
const std::string& places_key, const std::vector<Place>& places) { const std::string& places_key, const std::vector<Place>& places) {
...@@ -217,22 +411,11 @@ void ProcessGroupNCCL::CreateNCCLManagerCache( ...@@ -217,22 +411,11 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
"Not able to create/get the NCCL Communicator since " "Not able to create/get the NCCL Communicator since "
"the GPU place are not known")); "the GPU place are not known"));
std::vector<std::shared_ptr<NCCLCommManager>> nccl_comms; ncclUniqueId nccl_id;
nccl_comms.resize(places.size());
// using vector just for broadcast
std::vector<ncclUniqueId> nccl_ids;
nccl_ids.resize(1);
auto& nccl_id = nccl_ids.front();
for (auto& place : places) {
used_place_ids_.insert(place.GetDeviceId());
}
if (rank_ == 0) { if (rank_ == 0) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id)); PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id));
} }
BroadcastUniqueNCCLID(nccl_ids); BroadcastUniqueNCCLID(&nccl_id);
VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_ VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
<< ", place: " << places_key << ", place: " << places_key
...@@ -241,23 +424,33 @@ void ProcessGroupNCCL::CreateNCCLManagerCache( ...@@ -241,23 +424,33 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
std::vector<std::unique_ptr<phi::GPUContext>> dev_ctx; std::vector<std::unique_ptr<phi::GPUContext>> dev_ctx;
dev_ctx.resize(places.size()); dev_ctx.resize(places.size());
std::vector<phi::GPUContext*> dev_ctx_raw;
dev_ctx_raw.resize(places.size());
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (size_t i = 0; i < places.size(); ++i) { for (size_t i = 0; i < places.size(); ++i) {
platform::CUDADeviceGuard guard(places[i]); platform::CUDADeviceGuard guard(places[i]);
nccl_comms[i] = NCCLCommManager::Create(GetSize(), GetRank(), nccl_id);
dev_ctx[i].reset(new phi::GPUContext(places[i])); dev_ctx[i].reset(new phi::GPUContext(places[i]));
ncclComm_t nccl_comm;
NCCLCHECK(platform::dynload::ncclCommInitRank(
&nccl_comm, GetSize(), nccl_id, GetRank()));
dev_ctx[i]->set_nccl_comm(nccl_comm);
dev_ctx_raw[i] = dev_ctx[i].get();
} }
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
calc_event_ = std::make_shared<platform::DeviceEvent>(places[0]);
std::vector<EventManager> events; // TODO(sunyilun): for compatibility, will be removed later
events.resize(places.size()); place_to_calc_ctx_[places_key] = static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(places[0]));
place_to_comm_ctx_[places_key] = std::move(dev_ctx[0]);
// These caches will be useful to process sync/wait/communicate // These caches will be useful to process sync/wait/communicate
places_to_events_.emplace(places_key, std::move(events)); places_to_ctx_.emplace(places_key, std::move(dev_ctx_raw));
places_to_ncclcomm_.emplace(places_key, std::move(nccl_comms));
places_to_ctx_.emplace(places_key, std::move(dev_ctx));
} }
template <typename Fn> template <typename Fn>
...@@ -273,15 +466,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective( ...@@ -273,15 +466,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) { if (!calc_event_) {
CreateNCCLManagerCache(key, places); CreateNCCLManagerCache(key, places);
} }
} }
auto& nccl_comms = places_to_ncclcomm_[key];
if (!use_calc_stream) { if (!use_calc_stream) {
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]); SyncDefaultStream(places, calc_event_, places_to_ctx_[key]);
} }
auto task = auto task =
...@@ -304,7 +495,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective( ...@@ -304,7 +495,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
nccl_stream = places_to_ctx_[key][i]->stream(); nccl_stream = places_to_ctx_[key][i]->stream();
} }
fn(inputs[i], outputs[i], nccl_comms[i]->GetNcclComm(), nccl_stream); fn(inputs[i],
outputs[i],
places_to_ctx_[key][i]->nccl_comm(),
nccl_stream);
} }
} }
...@@ -330,7 +524,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective( ...@@ -330,7 +524,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
if (!use_calc_stream) { if (!use_calc_stream) {
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
cuda_guard.SetDevice(places[i]); cuda_guard.SetDevice(places[i]);
task->control_events_[i].Record(*places_to_ctx_[key][i]); task->comm_event_.Record(places_to_ctx_[key][i]);
} }
} }
...@@ -348,14 +542,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective( ...@@ -348,14 +542,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) { if (!calc_event_) {
CreateNCCLManagerCache(key, places); CreateNCCLManagerCache(key, places);
} }
} }
auto& nccl_comms = places_to_ncclcomm_[key]; SyncDefaultStream(places, calc_event_, places_to_ctx_[key]);
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
auto task = CreateTask(places, rank_, op_type, inputs); auto task = CreateTask(places, rank_, op_type, inputs);
...@@ -367,7 +559,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective( ...@@ -367,7 +559,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
cuda_guard.SetDevice(places[i]); cuda_guard.SetDevice(places[i]);
const auto& nccl_stream = places_to_ctx_[key][i]->stream(); const auto& nccl_stream = places_to_ctx_[key][i]->stream();
fn(inputs[i], outputs[i], nccl_comms[i]->GetNcclComm(), nccl_stream); fn(inputs[i],
outputs[i],
places_to_ctx_[key][i]->nccl_comm(),
nccl_stream);
} }
} }
...@@ -381,7 +576,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective( ...@@ -381,7 +576,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
cuda_guard.SetDevice(places[i]); cuda_guard.SetDevice(places[i]);
task->control_events_[i].Record(*places_to_ctx_[key][i]); task->comm_event_.Record(places_to_ctx_[key][i]);
} }
return task; return task;
} }
...@@ -393,18 +588,16 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in, ...@@ -393,18 +588,16 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
CommType op_type) { CommType op_type) {
std::vector<Place> places; std::vector<Place> places;
places.push_back(in->place()); places.push_back(in->place());
const auto key = GetKeyFromPlaces(places); const std::string& key = GetKeyFromPlaces(places);
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) { if (!calc_event_) {
CreateNCCLManagerCache(key, places); CreateNCCLManagerCache(key, places);
} }
} }
auto& nccl_comms = places_to_ncclcomm_[key]; SyncDefaultStream(places, calc_event_, places_to_ctx_[key]);
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
// construct uninitialize guard for device // construct uninitialize guard for device
platform::CUDADeviceGuard cuda_guard; platform::CUDADeviceGuard cuda_guard;
...@@ -418,7 +611,7 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in, ...@@ -418,7 +611,7 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
platform::NCCLGroupGuard nccl_guard; platform::NCCLGroupGuard nccl_guard;
cuda_guard.SetDevice(places[0]); cuda_guard.SetDevice(places[0]);
const auto& nccl_stream = places_to_ctx_[key][0]->stream(); const auto& nccl_stream = places_to_ctx_[key][0]->stream();
fn(in, out, nccl_comms[0]->GetNcclComm(), nccl_stream); fn(in, out, places_to_ctx_[key][0]->nccl_comm(), nccl_stream);
} }
cuda_guard.SetDevice(places[0]); cuda_guard.SetDevice(places[0]);
...@@ -437,15 +630,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint( ...@@ -437,15 +630,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) { if (!calc_event_) {
CreateNCCLManagerCache(key, places); CreateNCCLManagerCache(key, places);
} }
} }
auto& nccl_comms = places_to_ncclcomm_[key];
if (!use_calc_stream) { if (!use_calc_stream) {
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]); SyncDefaultStream(places, calc_event_, places_to_ctx_[key]);
} }
auto task = auto task =
...@@ -466,7 +657,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint( ...@@ -466,7 +657,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
} else { } else {
nccl_stream = places_to_ctx_[key][i]->stream(); nccl_stream = places_to_ctx_[key][i]->stream();
} }
fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank); fn(tensors[i],
places_to_ctx_[key][i]->nccl_comm(),
nccl_stream,
dst_rank);
} }
} }
...@@ -489,7 +683,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint( ...@@ -489,7 +683,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
if (!use_calc_stream) { if (!use_calc_stream) {
for (size_t i = 0; i < tensors.size(); ++i) { for (size_t i = 0; i < tensors.size(); ++i) {
cuda_guard.SetDevice(places[i]); cuda_guard.SetDevice(places[i]);
task->control_events_[i].Record(*places_to_ctx_[key][i]); task->comm_event_.Record(places_to_ctx_[key][i]);
} }
} }
...@@ -507,14 +701,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint( ...@@ -507,14 +701,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) { if (!calc_event_) {
CreateNCCLManagerCache(key, places); CreateNCCLManagerCache(key, places);
} }
} }
auto& nccl_comms = places_to_ncclcomm_[key]; SyncDefaultStream(places, calc_event_, places_to_ctx_[key]);
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
auto task = CreateTask(places, rank_, op_type, tensors); auto task = CreateTask(places, rank_, op_type, tensors);
...@@ -526,7 +718,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint( ...@@ -526,7 +718,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
for (size_t i = 0; i < tensors.size(); ++i) { for (size_t i = 0; i < tensors.size(); ++i) {
cuda_guard.SetDevice(places[i]); cuda_guard.SetDevice(places[i]);
const auto& nccl_stream = places_to_ctx_[key][i]->stream(); const auto& nccl_stream = places_to_ctx_[key][i]->stream();
fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank); fn(tensors[i],
places_to_ctx_[key][i]->nccl_comm(),
nccl_stream,
dst_rank);
} }
} }
...@@ -540,7 +735,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint( ...@@ -540,7 +735,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
for (size_t i = 0; i < tensors.size(); ++i) { for (size_t i = 0; i < tensors.size(); ++i) {
cuda_guard.SetDevice(places[i]); cuda_guard.SetDevice(places[i]);
task->control_events_[i].Record(*places_to_ctx_[key][i]); task->comm_event_.Record(places_to_ctx_[key][i]);
} }
return task; return task;
} }
...@@ -572,37 +767,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce( ...@@ -572,37 +767,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
CommType::ALLREDUCE); CommType::ALLREDUCE);
} }
// Vector-based all-reduce with explicit sync/stream control.
//
// Fails with InvalidArgument unless CheckTensorsInCudaPlace accepts the
// inputs, then forwards to Collective() with a callback that launches
// ncclAllReduce for one (input, output) pair using the reduction named by
// opts.reduce_op.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const AllreduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));

  // Per-pair NCCL launch; the communicator and stream are supplied by
  // Collective().
  auto launch_all_reduce = [&](const phi::DenseTensor& input,
                               phi::DenseTensor& output,
                               ncclComm_t comm,
                               const gpuStream_t& stream) {
    return platform::dynload::ncclAllReduce(
        input.data(),
        output.data(),
        input.numel(),
        platform::ToNCCLDataType(input.type()),
        ToNCCLRedType(opts.reduce_op),
        comm,
        stream);
  };

  return Collective(in_tensors,
                    out_tensors,
                    launch_all_reduce,
                    CommType::ALLREDUCE,
                    sync_op,
                    use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast( std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, std::vector<phi::DenseTensor>& out_tensors,
...@@ -633,63 +797,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast( ...@@ -633,63 +797,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
CommType::BROADCAST); CommType::BROADCAST);
} }
// Vector-based broadcast with explicit sync/stream control.
//
// Fails with InvalidArgument unless CheckTensorsInCudaPlace accepts the
// inputs, then forwards to Collective() with a callback that launches
// ncclBroadcast for one (input, output) pair.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const BroadcastOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));

  // Per-pair NCCL launch; the communicator and stream are supplied by
  // Collective().
  auto launch_broadcast = [&](phi::DenseTensor& input,
                              phi::DenseTensor& output,
                              ncclComm_t comm,
                              const gpuStream_t& stream) {
    // Root index derived from the source rank, the number of input tensors
    // and the source root offset carried in opts.
    const auto root =
        opts.source_rank * in_tensors.size() + opts.source_root;
    return platform::dynload::ncclBroadcast(
        input.data(),
        output.data(),
        input.numel(),
        platform::ToNCCLDataType(input.type()),
        root,
        comm,
        stream);
  };

  return Collective(in_tensors,
                    out_tensors,
                    launch_broadcast,
                    CommType::BROADCAST,
                    sync_op,
                    use_calc_stream);
}
// Barrier implemented as an all-reduce over a freshly created tensor.
//
// For place_ it builds one tensor whose meta is (FLOAT32, DDim({1})), runs
// AllReduce with that vector as both input and output, and moves the vector
// into the returned task's barrierTensors_ member.
// `opts` is not read in this body.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
const BarrierOptions& opts) {
// Only support single card single process
std::vector<phi::GPUPlace> places = {place_};
std::vector<phi::DenseTensor> barrierTensors;
barrierTensors.reserve(places.size());
// Constructed without a device; the device is chosen per place in the loop.
platform::CUDADeviceGuard gpuGuard;
for (auto& place : places) {
gpuGuard.SetDeviceIndex(place.GetDeviceId());
phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim({1}));
// NOTE(review): `allocator` is destroyed at the end of each iteration while
// the tensor built from allocator.get() lives on; presumably the tensor owns
// its allocation independently of the allocator object -- confirm.
auto allocator = std::unique_ptr<phi::Allocator>(
new paddle::experimental::DefaultAllocator(place));
barrierTensors.emplace_back(allocator.get(), meta);
}
auto task = ProcessGroupNCCL::AllReduce(
barrierTensors, barrierTensors, AllreduceOptions());
// NOTE(review): the dynamic_cast result is dereferenced without a null check;
// this relies on AllReduce always returning an NCCLTask -- confirm.
auto nccl_task = dynamic_cast<ProcessGroupNCCL::NCCLTask*>(task.get());
// Hand the tensors to the task so they are stored alongside it.
nccl_task->barrierTensors_ = std::move(barrierTensors);
return task;
}
void CheckTensorsInDifferentDevices( void CheckTensorsInDifferentDevices(
const std::vector<phi::DenseTensor>& tensors, const size_t num_devices) { const std::vector<phi::DenseTensor>& tensors, const size_t num_devices) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -975,39 +1082,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather( ...@@ -975,39 +1082,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
CommType::ALLGATHER); CommType::ALLGATHER);
} }
// Vector-based all-gather with explicit sync/stream control.
//
// Fails with InvalidArgument unless CheckTensorsInCudaPlace accepts both the
// inputs and the outputs, then forwards to Collective() with a callback that
// launches ncclAllGather for one (input, output) pair.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    bool sync_op,
    bool use_calc_stream) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(out_tensors),
      true,
      platform::errors::InvalidArgument("All outputs should be in CudaPlace."));

  // Per-pair NCCL launch; the communicator and stream are supplied by
  // Collective(). The element count passed to NCCL is the input's numel().
  auto launch_all_gather = [&](const phi::DenseTensor& input,
                               phi::DenseTensor& output,
                               ncclComm_t comm,
                               const gpuStream_t& stream) {
    return platform::dynload::ncclAllGather(
        input.data(),
        output.data(),
        input.numel(),
        platform::ToNCCLDataType(input.dtype()),
        comm,
        stream);
  };

  return Collective(in_tensors,
                    out_tensors,
                    launch_all_gather,
                    CommType::ALLGATHER,
                    sync_op,
                    use_calc_stream);
}
void* GetPointerByOffset(void* raw_pointer, void* GetPointerByOffset(void* raw_pointer,
size_t offset, size_t offset,
experimental::DataType type) { experimental::DataType type) {
...@@ -1578,43 +1652,5 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::_ReduceScatterBase( ...@@ -1578,43 +1652,5 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::_ReduceScatterBase(
CommType::REDUCE_SCATTER); CommType::REDUCE_SCATTER);
} }
// Calls ncclGroupStart through the dynload wrapper and checks the returned
// status with PADDLE_ENFORCE_GPU_SUCCESS.
void ProcessGroupNCCL::GroupStart() {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
}
// Calls ncclGroupEnd through the dynload wrapper and checks the returned
// status with PADDLE_ENFORCE_GPU_SUCCESS.
void ProcessGroupNCCL::GroupEnd() {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
}
// Returns the NCCL communicator cached for `place`.
//
// The lookup key is GetKeyFromPlaces applied to the one-element list
// {place}. Raises InvalidArgument when places_to_ncclcomm_ has no entry for
// that key; otherwise returns the communicator of the first cached manager.
ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const {
  const std::string key = GetKeyFromPlaces(std::vector<Place>{place});
  const auto iter = places_to_ncclcomm_.find(key);
  PADDLE_ENFORCE_NE(iter,
                    places_to_ncclcomm_.end(),
                    platform::errors::InvalidArgument(
                        "Cannot find nccl comm in process group."));
  const auto& comm_managers = iter->second;
  return comm_managers[0]->GetNcclComm();
}
// Convenience overload: forwards to GetDeviceContext(place, use_calc_stream)
// with use_calc_stream set to false.
const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
const Place& place) const {
return GetDeviceContext(place, /*use_calc_stream*/ false);
}
// Returns the device context to use for `place`.
//
// With use_calc_stream == true the context comes from the global
// DeviceContextPool. Otherwise the first context cached in places_to_ctx_
// under the key for {place} is returned; InvalidArgument is raised when no
// such entry exists.
const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
    const Place& place, bool use_calc_stream) const {
  if (!use_calc_stream) {
    const std::string key = GetKeyFromPlaces(std::vector<Place>{place});
    const auto iter = places_to_ctx_.find(key);
    PADDLE_ENFORCE_NE(iter,
                      places_to_ctx_.end(),
                      platform::errors::InvalidArgument(
                          "Cannot find device context in process group."));
    return *iter->second[0];
  }
  return *platform::DeviceContextPool::Instance().Get(place);
}
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -24,10 +24,10 @@ ...@@ -24,10 +24,10 @@
#include "paddle/fluid/distributed/collective/ProcessGroupStream.h" #include "paddle/fluid/distributed/collective/ProcessGroupStream.h"
#include "paddle/fluid/distributed/store/store.h" #include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_event.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h" #include "paddle/phi/common/place.h"
#include "paddle/fluid/platform/place.h" #include "paddle/phi/core/device_context.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/NCCLTools.h" #include "paddle/fluid/distributed/collective/NCCLTools.h"
...@@ -44,16 +44,28 @@ namespace distributed { ...@@ -44,16 +44,28 @@ namespace distributed {
using Place = paddle::platform::Place; using Place = paddle::platform::Place;
class ProcessGroupNCCL : public ProcessGroupStream { class ProcessGroupNCCL final : public ProcessGroupStream {
public: public:
class NCCLTask : public ProcessGroupStream::TaskStream, class NCCLTask final : public ProcessGroupStream::TaskStream,
public std::enable_shared_from_this<NCCLTask> { public std::enable_shared_from_this<NCCLTask> {
public: public:
NCCLTask(const Place& place,
int rank,
CommType comm_type,
bool sync_op,
bool use_calc_stream);
virtual ~NCCLTask();
bool IsCompleted() override;
bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) override;
void Synchronize() override;
void UpdateWaitChain(const phi::DeviceContext& ctx) override;
// TODO(sunyilun): methods below will be removed later
NCCLTask(const std::vector<Place>& places, NCCLTask(const std::vector<Place>& places,
int rank, int rank,
CommType CommType, CommType CommType,
const std::vector<phi::DenseTensor>& inputs); const std::vector<phi::DenseTensor>& inputs);
NCCLTask(const std::vector<Place>& places, NCCLTask(const std::vector<Place>& places,
int rank, int rank,
CommType comm_type, CommType comm_type,
...@@ -61,31 +73,15 @@ class ProcessGroupNCCL : public ProcessGroupStream { ...@@ -61,31 +73,15 @@ class ProcessGroupNCCL : public ProcessGroupStream {
bool sync_op, bool sync_op,
bool use_calc_stream); bool use_calc_stream);
bool IsCompleted(); public:
bool barrier_{false};
void SynchronizeStreams(); platform::DeviceEvent comm_event_; // event on comm stream
bool Wait(std::chrono::milliseconds timeout = kWaitTimeout);
void Synchronize();
void SetOutputs(std::vector<phi::DenseTensor>& outputs); // NOLINT
virtual ~NCCLTask();
void UpdateWaitChain(const phi::DeviceContext& ctx) override;
std::vector<EventManager> control_events_;
std::vector<phi::DenseTensor> barrierTensors_;
protected:
std::vector<Place> places_;
std::vector<std::shared_ptr<NCCLCommManager>> ncclComms_;
std::shared_ptr<std::vector<phi::DenseTensor>> outputs_;
private: private:
Place place_;
}; };
public:
ProcessGroupNCCL(const std::shared_ptr<Store>& store, ProcessGroupNCCL(const std::shared_ptr<Store>& store,
int rank, int rank,
int size, int size,
...@@ -99,34 +95,47 @@ class ProcessGroupNCCL : public ProcessGroupStream { ...@@ -99,34 +95,47 @@ class ProcessGroupNCCL : public ProcessGroupStream {
const phi::DeviceContext& GetDeviceContext( const phi::DeviceContext& GetDeviceContext(
const Place& place, bool use_calc_stream) const override; const Place& place, bool use_calc_stream) const override;
std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllReduce( std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT phi::DenseTensor* out_tensor,
std::vector<phi::DenseTensor>& out_tensors, // NOLINT const phi::DenseTensor& in_tensor,
const AllreduceOptions& options, const AllreduceOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op, bool sync_op,
bool use_calc_stream) override; bool use_calc_stream) override;
static void GroupStart();
static void GroupEnd();
ncclComm_t NCCLComm(const Place& place) const;
// TODO(liyurui): This API will be moved later // TODO(liyurui): This API will be moved later
std::shared_ptr<ProcessGroup::Task> AllReduce( std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, std::vector<phi::DenseTensor>& out_tensors,
const AllreduceOptions& = AllreduceOptions()) override; const AllreduceOptions& = AllreduceOptions()) override;
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> Broadcast( std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& = BroadcastOptions()) override; const BroadcastOptions& = BroadcastOptions()) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;
std::shared_ptr<ProcessGroup::Task> Send( std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors, int dst_rank) override; std::vector<phi::DenseTensor>& tensors, int dst_rank) override;
...@@ -175,12 +184,6 @@ class ProcessGroupNCCL : public ProcessGroupStream { ...@@ -175,12 +184,6 @@ class ProcessGroupNCCL : public ProcessGroupStream {
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override; std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllGather_Partial( std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, std::vector<phi::DenseTensor>& out_tensors,
...@@ -255,20 +258,37 @@ class ProcessGroupNCCL : public ProcessGroupStream { ...@@ -255,20 +258,37 @@ class ProcessGroupNCCL : public ProcessGroupStream {
phi::DenseTensor&, // NOLINT phi::DenseTensor&, // NOLINT
const ReduceScatterOptions&) override; const ReduceScatterOptions&) override;
static void GroupStart(); private:
std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(const Place& place,
int rank,
CommType op_type,
bool sync_op,
bool use_calc_stream);
static void GroupEnd(); void BroadcastUniqueNCCLID(ncclUniqueId* nccl_id);
ncclComm_t NCCLComm(const Place& place) const; void CreateNCCLEnvCache(const Place& place, const std::string& place_key);
template <typename Fn>
std::shared_ptr<ProcessGroupStream::Task> Collective(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
Fn fn,
CommType comm_type,
bool sync_op,
bool use_calc_stream);
protected: void SyncCalcStream(const Place& place,
virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask( const std::shared_ptr<platform::DeviceEvent>& event);
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
std::vector<Place> places, std::vector<Place> places,
int rank, int rank,
CommType op_type, CommType op_type,
const std::vector<phi::DenseTensor>& inputs); const std::vector<phi::DenseTensor>& inputs);
virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask( std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
const std::vector<Place>& places, const std::vector<Place>& places,
int rank, int rank,
CommType op_type, CommType op_type,
...@@ -276,27 +296,6 @@ class ProcessGroupNCCL : public ProcessGroupStream { ...@@ -276,27 +296,6 @@ class ProcessGroupNCCL : public ProcessGroupStream {
bool sync_op, bool sync_op,
bool use_calc_stream); bool use_calc_stream);
protected:
std::shared_ptr<Store> store_;
std::shared_ptr<NCCLCommManager> nccl_comm_;
std::mutex mutex_;
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLCommManager>>>
places_to_ncclcomm_;
std::unordered_map<std::string, std::vector<EventManager>> places_to_events_;
std::unordered_map<std::string, std::vector<std::unique_ptr<phi::GPUContext>>>
places_to_ctx_;
std::set<int> used_place_ids_;
private:
void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, // NOLINT
int root, // NOLINT
int server_fd);
void BroadcastUniqueNCCLID(std::vector<ncclUniqueId>& nccl_ids); // NOLINT
template <typename Fn> template <typename Fn>
std::shared_ptr<ProcessGroup::Task> Collective( std::shared_ptr<ProcessGroup::Task> Collective(
std::vector<phi::DenseTensor>& inputs, // NOLINT std::vector<phi::DenseTensor>& inputs, // NOLINT
...@@ -340,6 +339,17 @@ class ProcessGroupNCCL : public ProcessGroupStream { ...@@ -340,6 +339,17 @@ class ProcessGroupNCCL : public ProcessGroupStream {
void CheckSplitSizes(std::vector<int64_t>* split_sizes, void CheckSplitSizes(std::vector<int64_t>* split_sizes,
std::vector<int64_t> tensor_shape); std::vector<int64_t> tensor_shape);
private:
std::shared_ptr<Store> store_;
std::shared_ptr<platform::DeviceEvent> calc_event_; // event on calc stream
std::unordered_map<std::string, phi::GPUContext*> place_to_calc_ctx_;
std::unordered_map<std::string, std::unique_ptr<phi::GPUContext>>
place_to_comm_ctx_;
// TODO(sunyilun): attrs below will be removed later
std::mutex mutex_;
std::unordered_map<std::string, std::vector<phi::GPUContext*>> places_to_ctx_;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -30,18 +30,18 @@ const phi::DeviceContext& ProcessGroupStream::GetDeviceContext( ...@@ -30,18 +30,18 @@ const phi::DeviceContext& ProcessGroupStream::GetDeviceContext(
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather( std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT phi::DenseTensor* out_tensor,
std::vector<phi::DenseTensor>& output_tensors, // NOLINT const phi::DenseTensor& in_tensor,
bool sync_op) { bool sync_op) {
return AllGather(input_tensors, return AllGather(out_tensor,
output_tensors, in_tensor,
sync_op, sync_op,
/*use_calc_stream*/ false); /*use_calc_stream*/ false);
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather( std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT phi::DenseTensor* out_tensor,
std::vector<phi::DenseTensor>& output_tensors, // NOLINT const phi::DenseTensor& in_tensor,
bool sync_op, bool sync_op,
bool use_calc_stream) { bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
...@@ -49,27 +49,50 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather( ...@@ -49,27 +49,50 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce( std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT phi::DenseTensor* out_tensor,
std::vector<phi::DenseTensor>& output_tensors, // NOLINT const phi::DenseTensor& in_tensor,
const AllreduceOptions& options, const AllreduceOptions& opts,
bool sync_op) { bool sync_op) {
return AllReduce(input_tensors, return AllReduce(out_tensor,
output_tensors, in_tensor,
options, opts,
sync_op, sync_op,
/*use_calc_stream*/ false); /*use_calc_stream*/ false);
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce( std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT phi::DenseTensor* out_tensor,
std::vector<phi::DenseTensor>& output_tensors, // NOLINT const phi::DenseTensor& in_tensor,
const AllreduceOptions& options, const AllreduceOptions& opts,
bool sync_op, bool sync_op,
bool use_calc_stream) { bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do all_reduce", GetBackendName())); "ProcessGroup%s does not support do all_reduce", GetBackendName()));
} }
// Single-tensor broadcast without stream selection: forwards to the
// five-argument Broadcast overload with use_calc_stream set to false.
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) {
return Broadcast(out_tensor,
in_tensor,
opts,
sync_op,
/*use_calc_stream*/ false);
}
// Base-class implementation of the single-tensor, stream-aware broadcast.
// It ignores its arguments and always throws InvalidArgument naming the
// backend returned by GetBackendName().
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do broadcast", GetBackendName()));
}
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll( std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, std::vector<phi::DenseTensor>& out_tensors,
...@@ -114,28 +137,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAllSingle( ...@@ -114,28 +137,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAllSingle(
"ProcessGroup%s does not support do alltoall_single", GetBackendName())); "ProcessGroup%s does not support do alltoall_single", GetBackendName()));
} }
// Vector-based broadcast without stream selection: forwards to the
// five-argument Broadcast overload with use_calc_stream set to false.
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& opts,
bool sync_op) {
return Broadcast(in_tensors,
out_tensors,
opts,
sync_op,
/*use_calc_stream*/ false);
}
// Base-class implementation of the vector-based, stream-aware broadcast.
// It ignores its arguments and always throws InvalidArgument naming the
// backend returned by GetBackendName().
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do broadcast", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce( std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, std::vector<phi::DenseTensor>& out_tensors,
......
...@@ -27,6 +27,11 @@ class ProcessGroupStream : public ProcessGroup { ...@@ -27,6 +27,11 @@ class ProcessGroupStream : public ProcessGroup {
public: public:
class TaskStream : public ProcessGroup::Task { class TaskStream : public ProcessGroup::Task {
public: public:
TaskStream(int rank, CommType comm_type, bool sync_op, bool use_calc_stream)
: Task(rank, comm_type, sync_op), use_calc_stream_(use_calc_stream) {}
virtual ~TaskStream() = default;
// TODO(liyurui): This constructor is temporary here for compatible reason, // TODO(liyurui): This constructor is temporary here for compatible reason,
// will be deleted soon. // will be deleted soon.
TaskStream(int rank, TaskStream(int rank,
...@@ -42,8 +47,6 @@ class ProcessGroupStream : public ProcessGroup { ...@@ -42,8 +47,6 @@ class ProcessGroupStream : public ProcessGroup {
: Task(rank, inputs, comm_type, sync_op), : Task(rank, inputs, comm_type, sync_op),
use_calc_stream_(use_calc_stream) {} use_calc_stream_(use_calc_stream) {}
virtual ~TaskStream() = default;
protected: protected:
bool UseCalcStream() const { return use_calc_stream_; } bool UseCalcStream() const { return use_calc_stream_; }
...@@ -51,6 +54,7 @@ class ProcessGroupStream : public ProcessGroup { ...@@ -51,6 +54,7 @@ class ProcessGroupStream : public ProcessGroup {
bool use_calc_stream_{false}; bool use_calc_stream_{false};
}; };
public:
ProcessGroupStream(int rank, int size, const platform::Place& place, int gid); ProcessGroupStream(int rank, int size, const platform::Place& place, int gid);
virtual ~ProcessGroupStream() = default; virtual ~ProcessGroupStream() = default;
...@@ -58,29 +62,43 @@ class ProcessGroupStream : public ProcessGroup { ...@@ -58,29 +62,43 @@ class ProcessGroupStream : public ProcessGroup {
const Place& place, bool use_calc_stream) const; const Place& place, bool use_calc_stream) const;
std::shared_ptr<ProcessGroup::Task> AllGather( std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT phi::DenseTensor* out_tensor,
std::vector<phi::DenseTensor>& out_tensors, // NOLINT const phi::DenseTensor& in_tensor,
bool sync_op) override; bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> AllGather( virtual std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT phi::DenseTensor* out_tensor,
std::vector<phi::DenseTensor>& out_tensors, // NOLINT const phi::DenseTensor& in_tensor,
bool sync_op, bool sync_op,
bool use_calc_stream); bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> AllReduce( std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT phi::DenseTensor* out_tensor,
std::vector<phi::DenseTensor>& output_tensors, // NOLINT const phi::DenseTensor& in_tensor,
const AllreduceOptions& options, const AllreduceOptions& opts,
bool sync_op) override; bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> AllReduce( virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT phi::DenseTensor* out_tensor,
std::vector<phi::DenseTensor>& output_tensors, // NOLINT const phi::DenseTensor& in_tensor,
const AllreduceOptions& options, const AllreduceOptions& opts,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op, bool sync_op,
bool use_calc_stream); bool use_calc_stream);
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> AllToAll( std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT std::vector<phi::DenseTensor>& out_tensors, // NOLINT
...@@ -107,19 +125,6 @@ class ProcessGroupStream : public ProcessGroup { ...@@ -107,19 +125,6 @@ class ProcessGroupStream : public ProcessGroup {
bool sync_op, bool sync_op,
bool use_calc_stream); bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const BroadcastOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Reduce( std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT std::vector<phi::DenseTensor>& out_tensors, // NOLINT
......
...@@ -51,14 +51,9 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT ...@@ -51,14 +51,9 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT
if (map->has(ring_id)) { if (map->has(ring_id)) {
paddle::distributed::ProcessGroup *pg = map->get(ring_id); paddle::distributed::ProcessGroup *pg = map->get(ring_id);
auto pg_nccl = static_cast<distributed::ProcessGroupNCCL *>(pg); auto pg_nccl = static_cast<distributed::ProcessGroupNCCL *>(pg);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(tensor);
out_tensor.push_back(tensor);
paddle::distributed::AllreduceOptions opts; paddle::distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM; opts.reduce_op = distributed::ReduceOp::SUM;
auto task = pg_nccl->AllReduce(in_tensor, out_tensor, opts, true, true); auto task = pg_nccl->AllReduce(&tensor, tensor, opts, true, true);
task->Wait(); task->Wait();
} else { } else {
auto dtype = platform::ToNCCLDataType( auto dtype = platform::ToNCCLDataType(
......
...@@ -44,14 +44,9 @@ static void AllReduce(phi::DenseTensor& tensor, // NOLINT ...@@ -44,14 +44,9 @@ static void AllReduce(phi::DenseTensor& tensor, // NOLINT
if (map->has(ring_id)) { if (map->has(ring_id)) {
paddle::distributed::ProcessGroup* pg = map->get(ring_id); paddle::distributed::ProcessGroup* pg = map->get(ring_id);
auto pg_nccl = static_cast<distributed::ProcessGroupNCCL*>(pg); auto pg_nccl = static_cast<distributed::ProcessGroupNCCL*>(pg);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(tensor);
out_tensor.push_back(tensor);
paddle::distributed::AllreduceOptions opts; paddle::distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM; opts.reduce_op = distributed::ReduceOp::SUM;
auto task = pg_nccl->AllReduce(in_tensor, out_tensor, opts, true, true); auto task = pg_nccl->AllReduce(&tensor, tensor, opts, true, true);
task->Wait(); task->Wait();
} else { } else {
auto dtype = platform::ToNCCLDataType( auto dtype = platform::ToNCCLDataType(
......
...@@ -147,12 +147,12 @@ void BindDistributed(py::module *m) { ...@@ -147,12 +147,12 @@ void BindDistributed(py::module *m) {
distributed::ReduceOp op, distributed::ReduceOp op,
bool sync_op) { bool sync_op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::AllreduceOptions opts; auto p_dense =
opts.reduce_op = op;
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()); std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense}; auto *out_dense = p_dense.get();
return self.AllReduce(tensors, tensors, opts, sync_op); auto in_dense = *p_dense;
distributed::AllreduceOptions opts{op};
return self.AllReduce(out_dense, in_dense, opts, sync_op);
}, },
py::arg("tensor"), py::arg("tensor"),
py::arg("op"), py::arg("op"),
...@@ -183,11 +183,12 @@ void BindDistributed(py::module *m) { ...@@ -183,11 +183,12 @@ void BindDistributed(py::module *m) {
int src, int src,
bool sync_op) { bool sync_op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::BroadcastOptions opts{src}; auto p_dense =
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()); std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense}; auto *out_dense = p_dense.get();
return self.Broadcast(tensors, tensors, opts, sync_op); auto in_dense = *p_dense;
distributed::BroadcastOptions opts{src};
return self.Broadcast(out_dense, in_dense, opts, sync_op);
}, },
py::arg("tensor"), py::arg("tensor"),
py::arg("src"), py::arg("src"),
...@@ -380,52 +381,52 @@ void BindDistributed(py::module *m) { ...@@ -380,52 +381,52 @@ void BindDistributed(py::module *m) {
.def( .def(
"allgather", "allgather",
[](distributed::ProcessGroup &self, [](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor_list, py::handle py_out_tensor_list,
py::handle py_in_tensor,
bool sync_op) { bool sync_op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor_list = auto out_tensor_list =
CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0); CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0); Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>( auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_out_tensor.impl()); concat_out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense}; auto *out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
const auto &dev_ctx = self.GetDeviceContext(in_tensor.place()); const auto &dev_ctx = self.GetDeviceContext(in_tensor.place());
auto task = self.AllGather(in_wrapper, out_wrapper, sync_op); auto task = self.AllGather(out_dense, in_dense, sync_op);
distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list); distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
task->UpdateWaitChain(dev_ctx); task->UpdateWaitChain(dev_ctx);
return task; return task;
}, },
py::arg("in"),
py::arg("out"), py::arg("out"),
py::arg("in"),
py::arg("sync_op"), py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def( .def(
"allgather_into_tensor", "allgather_into_tensor",
[](distributed::ProcessGroup &self, [](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor, py::handle py_out_tensor,
py::handle py_in_tensor,
bool sync_op) { bool sync_op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>( auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl()); out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense}; auto *out_dense = p_out_tensor.get();
return self.AllGather(in_wrapper, out_wrapper, sync_op); auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
return self.AllGather(out_dense, in_dense, sync_op);
}, },
py::arg("in"),
py::arg("out"), py::arg("out"),
py::arg("in"),
py::arg("sync_op"), py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
...@@ -784,55 +785,55 @@ void BindDistributed(py::module *m) { ...@@ -784,55 +785,55 @@ void BindDistributed(py::module *m) {
.def( .def(
"allgather_on_calc_stream", "allgather_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroupStream &self,
py::handle py_in_tensor, py::handle py_out_tensor_list,
py::handle py_out_tensor_list) { py::handle py_in_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor_list = auto out_tensor_list =
CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0); CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0); Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>( auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_out_tensor.impl()); concat_out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense}; auto *out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
const auto &dev_ctx = const auto &dev_ctx =
self.GetDeviceContext(in_tensor.place(), true); self.GetDeviceContext(in_tensor.place(), true);
auto task = self.AllGather(in_wrapper, auto task = self.AllGather(out_dense,
out_wrapper, in_dense,
/*sync_op*/ true, /*sync_op*/ true,
/*use_calc_stream*/ true); /*use_calc_stream*/ true);
distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list); distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
return task; return task;
}, },
py::arg("in"),
py::arg("out"), py::arg("out"),
py::arg("in"),
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def( .def(
"allgather_into_tensor_on_calc_stream", "allgather_into_tensor_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroupStream &self,
py::handle py_in_tensor, py::handle py_out_tensor,
py::handle py_out_tensor) { py::handle py_in_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>( auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl()); out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense}; auto *out_dense = p_out_tensor.get();
return self.AllGather(in_wrapper, auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
out_wrapper, auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
return self.AllGather(out_dense,
in_dense,
/*sync_op*/ true, /*sync_op*/ true,
/*use_calc_stream*/ true); /*use_calc_stream*/ true);
}, },
py::arg("in"),
py::arg("out"), py::arg("out"),
py::arg("in"),
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def( .def(
...@@ -872,13 +873,13 @@ void BindDistributed(py::module *m) { ...@@ -872,13 +873,13 @@ void BindDistributed(py::module *m) {
py::handle py_tensor, py::handle py_tensor,
distributed::ReduceOp op) { distributed::ReduceOp op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::AllreduceOptions opts; auto p_dense =
opts.reduce_op = op;
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()); std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense}; auto in_dense = *p_dense;
return self.AllReduce(tensors, auto *out_dense = p_dense.get();
tensors, distributed::AllreduceOptions opts{op};
return self.AllReduce(out_dense,
in_dense,
opts, opts,
/*sync_op*/ true, /*sync_op*/ true,
/*use_calc_stream*/ true); /*use_calc_stream*/ true);
...@@ -980,12 +981,13 @@ void BindDistributed(py::module *m) { ...@@ -980,12 +981,13 @@ void BindDistributed(py::module *m) {
py::handle py_tensor, py::handle py_tensor,
int src) { int src) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::BroadcastOptions opts{src}; auto p_dense =
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()); std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense}; auto *out_dense = p_dense.get();
return self.Broadcast(tensors, auto in_dense = *p_dense;
tensors, distributed::BroadcastOptions opts{src};
return self.Broadcast(out_dense,
in_dense,
opts, opts,
/*sync_op*/ true, /*sync_op*/ true,
/*use_calc_stream*/ true); /*use_calc_stream*/ true);
......
...@@ -21,18 +21,18 @@ def _check_tensor_shape(tensor, shape, nranks=1): ...@@ -21,18 +21,18 @@ def _check_tensor_shape(tensor, shape, nranks=1):
expect_shape = list(shape) expect_shape = list(shape)
expect_shape[0] *= nranks expect_shape[0] *= nranks
if list(tensor.shape) != expect_shape: if list(tensor.shape) != expect_shape:
raise RuntimeError('The tensor for all_gather is not correctly-sized.') raise RuntimeError("The tensor for all_gather is not correctly-sized.")
def _check_tensor_list_shape(tensor_list, shape, nranks=1): def _check_tensor_list_shape(tensor_list, shape, nranks=1):
if len(tensor_list) != nranks: if len(tensor_list) != nranks:
raise RuntimeError( raise RuntimeError(
'The tensor_list for all_gather is not correctly-sized.' "The tensor_list for all_gather is not correctly-sized."
) )
for tensor in tensor_list: for tensor in tensor_list:
if tensor.shape != shape: if tensor.shape != shape:
raise RuntimeError( raise RuntimeError(
'The tensor_list for all_gather is not correctly-sized.' "The tensor_list for all_gather is not correctly-sized."
) )
...@@ -45,11 +45,12 @@ def _all_gather_into_tensor_in_dygraph( ...@@ -45,11 +45,12 @@ def _all_gather_into_tensor_in_dygraph(
if use_calc_stream: if use_calc_stream:
return group.process_group.allgather_into_tensor_on_calc_stream( return group.process_group.allgather_into_tensor_on_calc_stream(
in_tensor, out_tensor out_tensor,
in_tensor,
) )
task = group.process_group.allgather_into_tensor( task = group.process_group.allgather_into_tensor(
in_tensor, out_tensor, sync_op out_tensor, in_tensor, sync_op
) )
if sync_op: if sync_op:
task.wait() task.wait()
...@@ -68,9 +69,9 @@ def _all_gather_in_dygraph( ...@@ -68,9 +69,9 @@ def _all_gather_in_dygraph(
_check_tensor_list_shape(tensor_list, tensor.shape, group.nranks) _check_tensor_list_shape(tensor_list, tensor.shape, group.nranks)
if use_calc_stream: if use_calc_stream:
return group.process_group.allgather_on_calc_stream(tensor, tensor_list) return group.process_group.allgather_on_calc_stream(tensor_list, tensor)
task = group.process_group.allgather(tensor, tensor_list, sync_op) task = group.process_group.allgather(tensor_list, tensor, sync_op)
if sync_op: if sync_op:
task.wait() task.wait()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册