Unverified commit 62397cd2, authored by R ronnywang, committed by GitHub

[CustomDevice] add comm context support (#56301)

Parent ede8fd55
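For orientation, the sketch below is not part of the commit; it only mirrors the call pattern this change introduces. Instead of managing raw phi::ccl::CCLComm handles per place, ProcessGroupCustom now looks up a phi::distributed::XCCLCommContext registered under the process-group id and issues collectives through it. The helper names LookupXCCLCommContext and AllReduceSumViaCommContext are illustrative only; the member calls mirror GetCommContext() and the AllReduce lambda in the diff below.

// Hedged sketch, not part of the commit. It assumes the headers already
// included by this diff (custom_ccl_tools.h, comm_context_manager.h,
// xccl_comm_context.h); DenseTensor and Stream come in transitively.
#include <string>

#include "paddle/fluid/distributed/collective/custom_ccl_tools.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/xccl_comm_context.h"

namespace example {

// Look up the communication context registered for a process-group id,
// as ProcessGroupCustom::GetCommContext() does in this diff.
inline phi::distributed::XCCLCommContext* LookupXCCLCommContext(int gid) {
  const auto& manager = phi::distributed::CommContextManager::GetInstance();
  return static_cast<phi::distributed::XCCLCommContext*>(
      manager.Get(std::to_string(gid)));
}

// Issue a SUM all-reduce through the comm context on a given device stream,
// mirroring the lambda passed to RunFnInXCCLEnv in AllReduce below.
inline void AllReduceSumViaCommContext(phi::DenseTensor* out,
                                       const phi::DenseTensor& in,
                                       const phi::stream::Stream& stream,
                                       int gid) {
  auto* comm_context = LookupXCCLCommContext(gid);
  comm_context->AllReduce(
      out,
      in,
      paddle::distributed::ToXCCLRedType(paddle::distributed::ReduceOp::SUM),
      stream);
}

}  // namespace example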
@@ -18,7 +18,7 @@
namespace paddle {
namespace distributed {
phi::ccl::CCLReduceOp ToCustomCCLRedType(ReduceOp reduction) {
phi::ccl::CCLReduceOp ToXCCLRedType(ReduceOp reduction) {
static const std::map<ReduceOp, phi::ccl::CCLReduceOp> red_type = {
{ReduceOp::MIN, phi::ccl::CCLReduceOp::MIN},
{ReduceOp::MAX, phi::ccl::CCLReduceOp::MAX},
@@ -34,14 +34,5 @@ phi::ccl::CCLReduceOp ToCustomCCLRedType(ReduceOp reduction) {
return it->second;
}
std::string SerializeCustomCCLUniqueId(const phi::ccl::CCLRootId& ccl_id) {
const uint8_t* bytes = ccl_id.data();
std::ostringstream oss;
for (size_t i = 0; i < ccl_id.size(); ++i) {
oss << std::hex << static_cast<int>(bytes[i]);
}
return oss.str();
}
} // namespace distributed
} // namespace paddle
@@ -34,170 +34,7 @@
namespace paddle {
namespace distributed {
class CustomEventManager {
public:
CustomEventManager() = default;
~CustomEventManager() {
if (is_created_) {
event_->Destroy();
}
}
CustomEventManager(const CustomEventManager&) = delete;
CustomEventManager& operator=(const CustomEventManager&) = delete;
CustomEventManager(CustomEventManager&& other) {
std::swap(is_created_, other.is_created_);
std::swap(device_index_, other.device_index_);
std::swap(device_type_, other.device_type_);
std::swap(event_, other.event_);
}
CustomEventManager& operator=(CustomEventManager&& other) {
std::swap(is_created_, other.is_created_);
std::swap(device_index_, other.device_index_);
std::swap(device_type_, other.device_type_);
std::swap(event_, other.event_);
return *this;
}
bool IsCreated() const { return is_created_; }
int8_t DeviceId() const { return device_index_; }
std::string DeviceType() const { return device_type_; }
phi::event::event_t GetRawCustomEvent() const { return event_->raw_event(); }
phi::event::Event* GetCustomEvent() const { return event_.get(); }
void Record(const paddle::platform::CustomDeviceContext& ctx) {
auto place = ctx.GetPlace();
auto device_type = place.GetDeviceType();
auto device_index = place.GetDeviceId();
if (!is_created_) {
CreateEvent(place);
}
PADDLE_ENFORCE_EQ(device_index,
device_index_,
platform::errors::PreconditionNotMet(
"CustomDeviceContext's device %d does not match"
"Event's device %d",
device_index,
device_index_));
PADDLE_ENFORCE_EQ(device_type,
device_type_,
platform::errors::PreconditionNotMet(
"CustomDeviceContext's device %d does not match"
"Event's device type %d",
device_type,
device_type_));
phi::DeviceGuard guard(place);
phi::stream::Stream stream(place, ctx.stream());
event_->Record(&stream);
}
bool Query() const { return event_->Query(); }
void Block(const paddle::platform::CustomDeviceContext& ctx) const {
if (is_created_) {
auto place = ctx.GetPlace();
auto device_type = place.GetDeviceType();
auto device_index = place.GetDeviceId();
PADDLE_ENFORCE_EQ(device_index,
device_index_,
platform::errors::PreconditionNotMet(
"CustomDeviceContext's device %d does not match"
"Event's device %d",
device_index,
device_index_));
PADDLE_ENFORCE_EQ(device_type,
device_type_,
platform::errors::PreconditionNotMet(
"CustomDeviceContext's device %d does not match"
"Event's device type %d",
device_type,
device_type_));
phi::DeviceGuard guard(place);
phi::stream::Stream stream(place, ctx.stream());
stream.WaitEvent(event_.get());
}
}
private:
bool is_created_{false};
std::shared_ptr<phi::event::Event> event_{nullptr};
int8_t device_index_{0};
std::string device_type_;
private:
void CreateEvent(const platform::Place& place) {
device_index_ = place.GetDeviceId();
device_type_ = place.GetDeviceType();
event_.reset(new phi::event::Event);
event_->Init(place);
is_created_ = true;
}
};
class CustomCCLCommManager {
public:
CustomCCLCommManager(const std::string& device_type,
phi::ccl::CCLComm ccl_comm)
: device_type_(device_type), ccl_comm_(ccl_comm) {}
CustomCCLCommManager() : CustomCCLCommManager("", nullptr) {}
~CustomCCLCommManager() noexcept {
std::unique_lock<std::mutex> lock(mutex_);
if (phi::DeviceManager::HasDeviceType(device_type_) && ccl_comm_) {
phi::DeviceManager::CCLDestroyComm(device_type_, ccl_comm_);
}
}
static std::shared_ptr<CustomCCLCommManager> Create(
const std::string& device_type,
int num_ranks,
int rank,
phi::ccl::CCLRootId* comm_id,
phi::ccl::CCLComm* ccl_comm) {
auto custom_ccl_manager = std::make_shared<CustomCCLCommManager>();
phi::DeviceManager::CCLCommInitRank(
device_type, num_ranks, comm_id, rank, ccl_comm);
custom_ccl_manager->device_type_ = device_type;
custom_ccl_manager->ccl_id_ = comm_id;
custom_ccl_manager->rank_ = rank;
custom_ccl_manager->ccl_comm_ = *ccl_comm;
return custom_ccl_manager;
}
phi::ccl::CCLRootId* GetCustomCCLId() const {
std::unique_lock<std::mutex> lock(mutex_);
return ccl_id_;
}
phi::ccl::CCLComm GetCustomCCLComm() const {
std::unique_lock<std::mutex> lock(mutex_);
return ccl_comm_;
}
CustomCCLCommManager(const CustomCCLCommManager&) = delete;
CustomCCLCommManager& operator=(const CustomCCLCommManager&) = delete;
CustomCCLCommManager& operator=(CustomCCLCommManager&& other) = delete;
CustomCCLCommManager(CustomCCLCommManager&& other) {
std::unique_lock<std::mutex> lock(other.mutex_);
std::swap(ccl_comm_, other.ccl_comm_);
}
protected:
std::string device_type_;
phi::ccl::CCLComm ccl_comm_;
phi::ccl::CCLRootId* ccl_id_;
int rank_;
mutable std::mutex mutex_;
};
phi::ccl::CCLReduceOp ToCustomCCLRedType(ReduceOp reduction);
std::string SerializeCustomCCLUniqueId(const phi::ccl::CCLRootId& ccl_id);
phi::ccl::CCLReduceOp ToXCCLRedType(ReduceOp reduction);
} // namespace distributed
} // namespace paddle
@@ -17,99 +17,62 @@
#include "paddle/fluid/distributed/collective/common.h"
#include "paddle/fluid/distributed/collective/custom_ccl_tools.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/distributed/check/static_check.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/utils/data_type.h"
DECLARE_bool(xccl_blocking_wait);
#include "paddle/phi/core/distributed/comm_context_manager.h"
constexpr int64_t kWaitBlockTImeout = 10;
DECLARE_bool(use_stream_safe_cuda_allocator);
namespace paddle {
namespace distributed {
void SyncDefaultStream(
const std::vector<Place>& places,
std::vector<CustomEventManager>& cclEvents, // NOLINT
std::vector<std::unique_ptr<CustomDeviceContext>>& dev_ctx) { // NOLINT
for (size_t i = 0; i < places.size(); ++i) {
auto* default_ctx = static_cast<platform::CustomDeviceContext*>(
platform::DeviceContextPool::Instance().Get(places[i]));
cclEvents[i].Record(*default_ctx);
cclEvents[i].Block(*dev_ctx[i]);
}
}
std::shared_ptr<ProcessGroupCustom::CustomTask> ProcessGroupCustom::CreateTask(
std::vector<Place> places,
ProcessGroupCustom::XCCLTask::XCCLTask(const Place& place,
int rank,
CommType comm_type,
const std::vector<phi::DenseTensor>& inputs) {
return std::make_shared<ProcessGroupCustom::CustomTask>(
places, rank, comm_type, inputs);
}
ProcessGroupCustom::CustomTask::CustomTask(
const std::vector<Place>& places,
int rank,
CommType CommType,
const std::vector<phi::DenseTensor>& inputs)
: Task(rank, inputs, CommType), places_(places) {
control_events_.resize(places.size());
cclComms_.resize(places.size());
bool sync_op,
bool use_calc_stream)
: TaskStream(rank, comm_type, sync_op, use_calc_stream),
task_place_(place) {
comm_event_.Init(place);
}
ProcessGroupCustom::CustomTask::~CustomTask() {}
ProcessGroupCustom::XCCLTask::~XCCLTask() = default;
void ProcessGroupCustom::CustomTask::SetOutputs(
std::vector<phi::DenseTensor>& outputs) { // NOLINT
outputs_ = std::make_shared<std::vector<phi::DenseTensor>>(outputs);
}
bool ProcessGroupCustom::XCCLTask::IsCompleted() { return comm_event_.Query(); }
void ProcessGroupCustom::CustomTask::SynchronizeStreams() {
for (size_t i = 0; i < places_.size(); ++i) {
auto* default_ctx = static_cast<platform::CustomDeviceContext*>(
platform::DeviceContextPool::Instance().Get(places_[i]));
phi::DeviceGuard guard(default_ctx->GetPlace());
control_events_[i].Block(*default_ctx);
}
void ProcessGroupCustom::XCCLTask::UpdateWaitChain(
const phi::DeviceContext& ctx) {
comm_event_.Record(
reinterpret_cast<const phi::CustomContext&>(ctx).GetStream().get());
}
bool ProcessGroupCustom::CustomTask::IsCompleted() {
for (size_t i = 0; i < places_.size(); ++i) {
if (!control_events_[i].Query()) {
return false;
}
bool ProcessGroupCustom::XCCLTask::Wait(std::chrono::milliseconds timeout) {
// Warn here when the calc stream is used but Wait is still invoked explicitly.
if (UseCalcStream()) {
VLOG(3) << "Warning: The communication is on calc stream, wait here is "
"useless.";
return true;
}
return true;
}
const auto* calc_ctx = reinterpret_cast<phi::CustomContext*>(
platform::DeviceContextPool::Instance().Get(task_place_));
calc_ctx->GetStream()->WaitEvent(&comm_event_);
bool ProcessGroupCustom::CustomTask::Wait(std::chrono::milliseconds timeout) {
SynchronizeStreams();
while (!IsCompleted()) {
std::this_thread::sleep_for(std::chrono::milliseconds(kWaitBlockTImeout));
if (IsBlockCPUInWait()) {
// If this work is used as a barrier, we should block the CPU
phi::DeviceManager::SynchronizeDevice(task_place_);
}
return true;
}
// Same as Wait
void ProcessGroupCustom::CustomTask::Synchronize() { Wait(kWaitTimeout); }
void ProcessGroupCustom::CustomTask::UpdateWaitChain(
const phi::DeviceContext& ctx) {
PADDLE_ENFORCE_NE(
std::find(places_.cbegin(), places_.cend(), ctx.GetPlace()),
places_.cend(),
phi::errors::NotFound("Cannot find the device context in this task."));
auto index = std::find(places_.cbegin(), places_.cend(), ctx.GetPlace()) -
places_.cbegin();
control_events_[index].Record(
reinterpret_cast<const phi::CustomContext&>(ctx));
}
void ProcessGroupCustom::XCCLTask::Synchronize() { Wait(kWaitTimeout); }
ProcessGroupCustom::ProcessGroupCustom(
const std::shared_ptr<phi::distributed::Store>& store,
@@ -121,147 +84,45 @@ ProcessGroupCustom::ProcessGroupCustom(
store_(store),
device_type_(device_type) {}
void ProcessGroupCustom::BroadcastUniqueCustomID(
std::vector<phi::ccl::CCLRootId>& ccl_ids) { // NOLINT
if (rank_ == 0) {
for (size_t i = 0; i < ccl_ids.size(); i++) {
auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(gid_) + "/" +
std::to_string(i);
store_->set(key, ccl_ids[i]);
}
} else {
for (size_t i = 0; i < ccl_ids.size(); i++) {
auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(gid_) + "/" +
std::to_string(i);
ccl_ids[i] = store_->get(key);
}
}
void ProcessGroupCustom::GroupStart(const std::string& dev_type) {
phi::DeviceManager::CCLGroupStart(dev_type);
}
// create CustomCCLManager cache for places_key
void ProcessGroupCustom::CreateCustomManagerCache(
const std::string& places_key, const std::vector<Place>& places) {
PADDLE_ENFORCE_EQ(
places_key.empty(),
false,
platform::errors::PreconditionNotMet(
"Not able to create/get the CustomCCL Communicator since "
"the NPU place are not known"));
const std::string device_type = places.back().GetDeviceType();
std::vector<std::shared_ptr<CustomCCLCommManager>> ccl_comms;
ccl_comms.resize(places.size());
// using vector just for broadcast
std::vector<phi::ccl::CCLRootId> ccl_ids;
ccl_ids.resize(1);
auto& ccl_id = ccl_ids.front();
if (rank_ == 0) {
phi::DeviceManager::CCLGetUniqueId(device_type, &ccl_id);
}
BroadcastUniqueCustomID(ccl_ids);
VLOG(3) << "init custom ccl rank: " << rank_ << ", nranks: " << size_
<< ", place: " << places_key
<< ", custom ccl uniqueid: " << SerializeCustomCCLUniqueId(ccl_id);
std::vector<std::unique_ptr<CustomDeviceContext>> dev_ctx;
dev_ctx.resize(places.size());
for (size_t i = 0; i < places.size(); ++i) {
phi::DeviceGuard guard(places[i]);
ccl_comms[i] = CustomCCLCommManager::Create(
device_type, GetSize(), GetRank(), &ccl_id, new phi::ccl::CCLComm);
dev_ctx[i] = std::make_unique<CustomDeviceContext>(places[i]);
dev_ctx[i]->SetAllocator(
&(phi::DeviceContextPool::Instance().Get(places[i])->GetAllocator()));
dev_ctx[i]->SetHostAllocator(&(
phi::DeviceContextPool::Instance().Get(places[i])->GetHostAllocator()));
dev_ctx[i]->SetZeroAllocator(&(
phi::DeviceContextPool::Instance().Get(places[i])->GetZeroAllocator()));
dev_ctx[i]->SetHostZeroAllocator(&(phi::DeviceContextPool::Instance()
.Get(places[i])
->GetHostZeroAllocator()));
}
std::vector<CustomEventManager> events;
events.resize(places.size());
// These caches will be useful to process sync/wait/communicate
places_to_events_.emplace(places_key, std::move(events));
places_to_customcomm_.emplace(places_key, std::move(ccl_comms));
places_to_ctx_.emplace(places_key, std::move(dev_ctx));
void ProcessGroupCustom::GroupEnd(const std::string& dev_type) {
phi::DeviceManager::CCLGroupEnd(dev_type);
}
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs,
Fn fn,
CommType op_type,
bool sync_op UNUSED,
bool use_calc_stream) {
const auto places = GetPlaceList(inputs);
const auto key = GetKeyFromPlaces(places);
{
std::lock_guard<std::mutex> lock(mutex_);
if (places_to_customcomm_.find(key) == places_to_customcomm_.end()) {
CreateCustomManagerCache(key, places);
}
}
auto& ccl_comms = places_to_customcomm_[key];
if (!use_calc_stream) {
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
}
auto task = CreateTask(places, rank_, op_type, inputs);
task->SetOutputs(outputs);
for (size_t i = 0; i < inputs.size(); ++i) {
phi::DeviceGuard guard(places[i]);
const auto& ccl_stream =
use_calc_stream ? reinterpret_cast<phi::CustomContext*>(
phi::DeviceContextPool::Instance().Get(places[i]))
->stream()
: places_to_ctx_[key][i]->stream();
phi::stream::Stream stream(places[i], ccl_stream);
fn(inputs[i], outputs[i], ccl_comms[i]->GetCustomCCLComm(), stream);
}
if (!use_calc_stream) {
for (size_t i = 0; i < inputs.size(); ++i) {
phi::DeviceGuard guard(places[i]);
task->control_events_[i].Record(*places_to_ctx_[key][i]);
}
}
return task;
phi::DeviceContext* ProcessGroupCustom::GetDeviceContext(
const Place& place) const {
return GetDeviceContext(place, /*use_calc_stream*/ false);
}
void* XcclGetPointerByOffset(void* raw_pointer,
size_t offset,
phi::DataType type) {
if (type == phi::DataType::FLOAT32) {
return reinterpret_cast<void*>(reinterpret_cast<float*>(raw_pointer) +
offset);
} else if (type == phi::DataType::FLOAT64) {
return reinterpret_cast<void*>(reinterpret_cast<double*>(raw_pointer) +
offset);
} else if (type == phi::DataType::INT32) {
return reinterpret_cast<void*>(reinterpret_cast<int32_t*>(raw_pointer) +
offset);
} else if (type == phi::DataType::INT64) {
return reinterpret_cast<void*>(reinterpret_cast<int64_t*>(raw_pointer) +
offset);
} else if (type == phi::DataType::FLOAT16) {
return reinterpret_cast<void*>(reinterpret_cast<int16_t*>(raw_pointer) +
offset);
phi::DeviceContext* ProcessGroupCustom::GetDeviceContext(
const Place& place, bool use_calc_stream) const {
const std::string& key = GetKeyFromPlace(place);
if (use_calc_stream) {
const auto& iter = place_to_calc_ctx_.find(key);
return iter->second;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"This datatype in xccl is not supported."));
const auto& iter = place_to_comm_ctx_.find(key);
PADDLE_ENFORCE_NE(
iter,
place_to_comm_ctx_.end(),
phi::errors::NotFound(
"Cannot find the device context in this process group."));
return iter->second.get();
}
return nullptr;
}
phi::ccl::CCLComm ProcessGroupCustom::XCCLComm(const Place& place) const {
const std::string& key = GetKeyFromPlace(place);
const auto& iter = place_to_comm_ctx_.find(key);
PADDLE_ENFORCE_NE(
iter,
place_to_comm_ctx_.end(),
phi::errors::NotFound(
"Cannot find the XCCL communicator in this process group."));
return iter->second->xccl_comm();
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
@@ -269,212 +130,110 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
const phi::DenseTensor& in_tensor,
int64_t offset,
int64_t numel,
bool sync_op, // for compatibility, no use now
bool sync_op,
bool use_calc_stream) {
// numel > 0 indicates that the tensor needs to be sliced
const phi::DenseTensor& in_tensor_maybe_partial =
numel > 0
? paddle::distributed::GetPartialTensor(in_tensor, offset, numel)
: in_tensor;
phi::distributed::CommStaticCheck::GatherLikeShape(
*out_tensor,
in_tensor_maybe_partial,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
std::vector<phi::DenseTensor> in_wrapper{in_tensor_maybe_partial};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return Collective(
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
return phi::DeviceManager::CCLAllGather(
device_type_,
input.data(),
output.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
comm,
stream);
numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
return RunFnInXCCLEnv(
[&](const phi::stream::Stream& stream) {
auto comm_context = this->GetCommContext();
comm_context->AllGather(out_tensor, in_tensor_maybe_partial, stream);
},
in_tensor_maybe_partial,
CommType::ALLGATHER,
sync_op,
use_calc_stream);
}
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All inputs should be in CustomPlace(%s).", device_type_));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All outputs should be in CustomPlace(%s).", device_type_));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
return phi::DeviceManager::CCLAllGather(
device_type_,
input.data(),
output.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
comm,
stream);
},
CommType::ALLGATHER,
false,
false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op, // for compatibility, no use now
bool sync_op,
bool use_calc_stream) {
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_wrapper, device_type_),
true,
platform::errors::InvalidArgument(
"All inputs should be in CustomPlace(%s).", device_type_));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_wrapper, device_type_),
true,
platform::errors::InvalidArgument(
"All outputs should be in CustomPlace(%s).", device_type_));
return Collective(
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
return phi::DeviceManager::CCLAllReduce(
device_type_,
input.data(),
output.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
ToCustomCCLRedType(opts.reduce_op),
comm,
return RunFnInXCCLEnv(
[&](const phi::stream::Stream& stream) {
auto comm_context = this->GetCommContext();
comm_context->AllReduce(
out_tensor,
in_tensor,
paddle::distributed::ToXCCLRedType(opts.reduce_op),
stream);
},
in_tensor,
CommType::ALLREDUCE,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const AllreduceOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All inputs should be in CustomPlace(%s).", device_type_));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All outputs should be in CustomPlace(%s).", device_type_));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
return phi::DeviceManager::CCLAllReduce(
device_type_,
input.data(),
output.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
ToCustomCCLRedType(opts.reduce_op),
comm,
stream);
},
CommType::ALLREDUCE,
false,
false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op, // for compatibility, no use now
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op,
bool use_calc_stream) {
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_wrapper, device_type_),
true,
platform::errors::InvalidArgument(
"All inputs should be in CustomPlace(%s).", device_type_));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_wrapper, device_type_),
true,
platform::errors::InvalidArgument(
"All outputs should be in CustomPlace(%s).", device_type_));
return Collective(
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
int root = opts.source_rank * in_wrapper.size() + opts.source_root;
if (rank_ == root) {
return phi::DeviceManager::CCLBroadcast(
device_type_,
input.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
root,
comm,
stream);
} else {
return phi::DeviceManager::CCLBroadcast(
const phi::DDim& out_dim = out_tensor->dims();
const phi::DDim& in_dim = in_tensor.dims();
CheckSizeOnEachRank(out_dim, out_size_each_rank, size_);
CheckSizeOnEachRank(in_dim, in_size_each_rank, size_);
// NOTE: Since `all_to_all` needs other processes' participation, it cannot
// simply be covered by static checks. Factors are set to 0 here to skip the
// shape check. Its shape check will be done by dynamic checks with
// FLAGS_enable_xccl_dynamic_check.
return RunFnInXCCLEnv(
[&](const phi::stream::Stream& stream) {
auto comm_context = this->GetCommContext();
int64_t in_row_size = in_tensor.numel() / in_dim[0],
out_row_size = out_tensor->numel() / out_dim[0];
int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0;
phi::DenseTensor input_partial, output_partial;
std::vector<void*> send_buf, recv_buf;
std::vector<size_t> send_count, recv_count;
std::vector<phi::ccl::CCLDataType> send_dtype, recv_dtype;
for (auto i = 0; i < size_; i++) {
in_numel = in_size_each_rank[i] * in_row_size;
input_partial = GetPartialTensor(in_tensor, in_offset, in_numel);
out_numel = out_size_each_rank[i] * out_row_size;
output_partial = GetPartialTensor(*out_tensor, out_offset, out_numel);
in_offset += in_numel;
out_offset += out_numel;
send_buf.push_back(input_partial.data());
recv_buf.push_back(output_partial.data());
send_count.push_back(in_numel);
recv_count.push_back(out_numel);
send_dtype.push_back(phi::ccl::ToCCLDataType(input_partial.dtype()));
recv_dtype.push_back(phi::ccl::ToCCLDataType(output_partial.dtype()));
}
phi::DeviceManager::CCLAllToAll(
device_type_,
output.data(),
output.numel(),
phi::ccl::ToCCLDataType(output.dtype()),
root,
comm,
const_cast<const void**>(send_buf.data()),
send_count.data(),
send_dtype.data(),
recv_buf.data(),
recv_count.data(),
recv_dtype.data(),
rank_,
size_,
comm_context->GetXcclComm(),
stream);
}
},
CommType::BROADCAST,
in_tensor,
CommType::ALLTOALL,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
const BarrierOptions& opts) {
// Only support single card single process
PADDLE_ENFORCE_GE(opts.device_id,
0,
platform::errors::PreconditionNotMet(
phi::errors::PreconditionNotMet(
"The barrier device id must greater or equal than 0."));
platform::CustomPlace place(device_type_, opts.device_id);
auto allocator = std::unique_ptr<phi::Allocator>(
@@ -482,111 +241,176 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1});
phi::DenseTensor barrier_tensor{allocator.get(), meta};
auto task = ProcessGroupCustom::AllReduce(&barrier_tensor,
auto task = AllReduce(&barrier_tensor,
barrier_tensor,
{},
/*sync_op*/ true,
false);
auto xccl_task = dynamic_cast<ProcessGroupCustom::CustomTask*>(task.get());
xccl_task->barrierTensors_ = {barrier_tensor};
/*use_calc_stream*/ false);
auto xccl_task = dynamic_cast<XCCLTask*>(task.get());
xccl_task->SetBlockCPUInWait();
return task;
}
phi::DeviceContext* ProcessGroupCustom::GetDeviceContext(
const Place& place) const {
const std::string key = GetKeyFromPlace(place);
const auto& iter = places_to_ctx_.find(key);
PADDLE_ENFORCE_NE(
iter,
places_to_ctx_.end(),
platform::errors::NotFound(
"Cannot find the device context in this process group."));
return iter->second[0].get();
}
phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm(const Place& place) const {
std::vector<Place> places = {place};
const auto& iter = places_to_customcomm_.find(GetKeyFromPlaces(places));
PADDLE_ENFORCE_NE(iter,
places_to_customcomm_.end(),
platform::errors::InvalidArgument(
"Cannot find nccl comm in process group."));
return iter->second[0]->GetCustomCCLComm();
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) {
return RunFnInXCCLEnv(
[&](const phi::stream::Stream& stream) {
int root = opts.source_rank + opts.source_root;
auto comm_context = this->GetCommContext();
comm_context->Broadcast(out_tensor, in_tensor, root, stream);
},
in_tensor,
CommType::BROADCAST,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const BroadcastOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All inputs should be in CustomPlace(%s).", device_type_));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All outputs should be in CustomPlace(%s).", device_type_));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
int root = opts.source_rank * in_tensors.size() + opts.source_root;
if (rank_ == root) {
return phi::DeviceManager::CCLBroadcast(
device_type_,
input.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
root,
comm,
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Reduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
return RunFnInXCCLEnv(
[&](const phi::stream::Stream& stream) {
auto comm_context = this->GetCommContext();
comm_context->Reduce(out_tensor,
in_tensor,
paddle::distributed::ToXCCLRedType(opts.reduce_op),
opts.root_rank,
stream);
} else {
return phi::DeviceManager::CCLBroadcast(
device_type_,
output.data(),
output.numel(),
phi::ccl::ToCCLDataType(output.dtype()),
root,
comm,
},
in_tensor,
CommType::REDUCE,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
return RunFnInXCCLEnv(
[&](const phi::stream::Stream& stream) {
auto comm_context = this->GetCommContext();
comm_context->ReduceScatter(
out_tensor,
in_tensor,
paddle::distributed::ToXCCLRedType(opts.reduce_op),
stream);
}
},
CommType::BROADCAST,
false,
false);
in_tensor,
CommType::REDUCE_SCATTER,
sync_op,
use_calc_stream);
}
void CheckTensorsInDifferentCustomDevices(
const std::vector<phi::DenseTensor>& tensors, const size_t num_devices) {
PADDLE_ENFORCE_EQ(
tensors.size() == 0,
false,
phi::errors::InvalidArgument("Tensor list must be nonempty."));
PADDLE_ENFORCE_LE(
tensors.size(),
num_devices,
phi::errors::InvalidArgument("Tensor list mustn't be larger than the "
"number of available CustomDevice."));
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
phi::distributed::CommStaticCheck::ScatterLikeShape(
*out_tensor,
in_tensor,
/*dst_rank*/ opts.root_rank,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
return RunFnInXCCLEnv(
[&](const phi::stream::Stream& stream) {
auto comm_context = this->GetCommContext();
std::set<Place> used_devices;
int64_t numel = in_tensor.numel() / size_;
if (rank_ == opts.root_rank) {
int64_t offset = 0;
phi::DenseTensor partial_tensor;
for (auto i = 0; i < size_; i++) {
partial_tensor = GetPartialTensor(in_tensor, offset, numel);
if (i != rank_) {
comm_context->Send(partial_tensor, numel, i, stream);
} else {
phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace())
->MemoryCopyD2D(out_tensor->data(),
partial_tensor.data(),
numel * phi::SizeOf(partial_tensor.dtype()),
&stream);
}
offset += numel;
}
} else {
comm_context->Recv(out_tensor, numel, opts.root_rank, stream);
}
},
in_tensor,
CommType::SCATTER,
sync_op,
use_calc_stream);
}
for (const auto& t : tensors) {
PADDLE_ENFORCE_EQ(platform::is_custom_place(t.place()),
true,
phi::errors::InvalidArgument(
"Tensors must be CustomDevice and dense tensor."));
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Gather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const GatherOptions& opts,
bool sync_op,
bool use_calc_stream) {
std::vector<phi::DenseTensor> partial_tensors;
if (rank_ == opts.root_rank) {
partial_tensors.reserve(size_);
size_t offset = 0;
size_t numel = out_tensor->numel() / size_;
for (auto i = 0; i < size_; i++) {
partial_tensors.push_back(GetPartialTensor(*out_tensor, offset, numel));
offset += numel;
}
}
return Gather(&partial_tensors, in_tensor, opts, sync_op, use_calc_stream);
}
const auto inserted = used_devices.insert(t.place()).second;
PADDLE_ENFORCE_EQ(inserted,
true,
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Gather(
std::vector<phi::DenseTensor>* gather_tensors_ptr,
const phi::DenseTensor& in_tensor,
const GatherOptions& opts,
bool sync_op,
bool use_calc_stream) {
auto& gather_tensors = *gather_tensors_ptr;
PADDLE_ENFORCE_GT(size_,
opts.root_rank,
phi::errors::InvalidArgument(
"Tensors must be on distinct custom devices."));
"root world size [%d] is less than root rank [%d]",
size_,
opts.root_rank));
auto gather_func = [&](const phi::stream::Stream& stream) {
auto comm_context = this->GetCommContext();
// root receive from all devices
if (rank_ == opts.root_rank) {
for (auto i = 0; i < size_; i++) {
auto& gather_tensor = gather_tensors[i];
if (i != rank_) {
comm_context->Recv(&gather_tensor, gather_tensor.numel(), i, stream);
} else {
phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace())
->MemoryCopyD2D(
gather_tensor.data(),
in_tensor.data(),
in_tensor.numel() * phi::SizeOf(in_tensor.dtype()),
&stream);
}
}
} else {
// send to root
comm_context->Send(in_tensor, in_tensor.numel(), opts.root_rank, stream);
}
};
return RunFnInXCCLEnv(
gather_func, in_tensor, CommType::GATHER, sync_op, use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Recv(
@@ -602,53 +426,18 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Recv(
partial_tensor = GetPartialTensor(*tensor, offset, numel);
tensor = &partial_tensor;
}
phi::distributed::CommStaticCheck::CheckShape(
*tensor, rank_, size_, phi::AllocationType::CUSTOM);
std::vector<phi::DenseTensor> in_wrapper{*tensor};
std::vector<phi::DenseTensor> out_wrapper{*tensor};
return Collective(
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
phi::DeviceManager::CCLRecv(device_type_,
output.data(),
output.numel(),
phi::ccl::ToCCLDataType(output.dtype()),
src_rank,
comm,
stream);
return RunFnInXCCLEnv(
[&](const phi::stream::Stream& stream) {
auto comm_context = this->GetCommContext();
comm_context->Recv(tensor, tensor->numel(), src_rank, stream);
},
*tensor,
CommType::RECV,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank) {
CheckTensorsInDifferentCustomDevices(tensors, static_cast<size_t>(GetSize()));
return Collective(
tensors,
tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
phi::DeviceManager::CCLRecv(device_type_,
output.data(),
output.numel(),
phi::ccl::ToCCLDataType(output.dtype()),
src_rank,
comm,
stream);
},
CommType::RECV,
false,
false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
const phi::DenseTensor& tensor,
int dst_rank,
@@ -659,192 +448,459 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
// numel > 0 indicates that the tensor needs to be sliced
const phi::DenseTensor& tensor_maybe_partial =
numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor;
phi::distributed::CommStaticCheck::CheckShape(
tensor_maybe_partial, rank_, size_, phi::AllocationType::CUSTOM);
std::vector<phi::DenseTensor> in_wrapper{tensor_maybe_partial};
std::vector<phi::DenseTensor> out_wrapper{tensor_maybe_partial};
return Collective(
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
phi::DeviceManager::CCLSend(device_type_,
input.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
return RunFnInXCCLEnv(
[&](const phi::stream::Stream& stream) {
auto comm_context = this->GetCommContext();
comm_context->Send(tensor_maybe_partial,
tensor_maybe_partial.numel(),
dst_rank,
comm,
stream);
},
tensor_maybe_partial,
CommType::SEND,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
std::vector<phi::DenseTensor>& tensors, int dst_rank) {
CheckTensorsInDifferentCustomDevices(tensors, static_cast<size_t>(GetSize()));
std::shared_ptr<ProcessGroupCustom::XCCLTask> ProcessGroupCustom::CreateTask(
const Place& place,
int rank,
CommType comm_type,
bool is_sync,
bool use_calc_stream) {
return std::make_shared<ProcessGroupCustom::XCCLTask>(
place, rank, comm_type, is_sync, use_calc_stream);
}
void ProcessGroupCustom::BroadcastUniqueXCCLID(
phi::ccl::CCLRootId* xccl_root_id) {
const std::string key =
"ProcessGroupCustom/xccl_ids/" + std::to_string(gid_) + "/0";
if (rank_ == 0) {
store_->set(key, *xccl_root_id);
} else {
*xccl_root_id = store_->get(key);
}
}
void ProcessGroupCustom::CreateXCCLEnvCache(const Place& place,
const std::string& place_key) {
if (!place_to_comm_ctx_.empty()) {
VLOG(3) << "Warning: Tensors from multiple devices are not supported yet.";
}
VLOG(3) << "init xccl rank: " << rank_ << ", nranks: " << size_
<< ", place: " << place_key;
phi::distributed::CommContextManager::CreateXCCLCommContext(
store_, std::to_string(gid_), place.GetDeviceType(), rank_, size_);
auto* calc_ctx = static_cast<phi::CustomContext*>(
platform::DeviceContextPool::Instance().Get(place));
auto comm_ctx = std::make_unique<phi::CustomContext>(place);
comm_ctx->SetAllocator(
&(phi::DeviceContextPool::Instance().Get(place)->GetAllocator()));
comm_ctx->SetHostAllocator(
&(phi::DeviceContextPool::Instance().Get(place)->GetHostAllocator()));
comm_ctx->SetZeroAllocator(
&(phi::DeviceContextPool::Instance().Get(place)->GetZeroAllocator()));
comm_ctx->SetHostZeroAllocator(
&(phi::DeviceContextPool::Instance().Get(place)->GetHostZeroAllocator()));
auto xccl_comm_ctx = this->GetCommContext();
comm_ctx->set_xccl_comm(xccl_comm_ctx->GetXcclComm());
auto xccl_event = std::make_unique<phi::event::Event>();
xccl_event->Init(place);
place_to_calc_event_.emplace(place_key, std::move(xccl_event));
place_to_calc_ctx_.emplace(place_key, calc_ctx);
place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx));
// TODO(sunyilun): for compatibility, will be removed later
std::vector<phi::CustomContext*> comm_ctx_wrapper{
place_to_comm_ctx_[place_key].get()};
places_to_ctx_.emplace(place_key, comm_ctx_wrapper);
}
void ProcessGroupCustom::SyncCalcStream(const Place& place) {
const std::string& key = GetKeyFromPlace(place);
auto& calc_event = place_to_calc_event_.at(key);
const auto* calc_ctx = place_to_calc_ctx_.at(key);
const auto* comm_ctx = place_to_comm_ctx_.at(key).get();
calc_event->Record(calc_ctx->GetStream().get());
comm_ctx->GetStream()->WaitEvent(calc_event.get());
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::RunFnInXCCLEnv(
std::function<void(const phi::stream::Stream&)> fn,
const phi::DenseTensor& tensor,
CommType comm_type,
bool sync_op,
bool use_calc_stream) {
const auto& place = tensor.place();
const auto& key = GetKeyFromPlace(place);
phi::DeviceGuard guard(place);
if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
CreateXCCLEnvCache(place, key);
}
if (!use_calc_stream) {
SyncCalcStream(place);
}
auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream);
const auto* calc_ctx = place_to_calc_ctx_.at(key);
const auto& comm_ctx = place_to_comm_ctx_.at(key);
auto& xccl_stream =
use_calc_stream ? *calc_ctx->GetStream() : *comm_ctx->GetStream();
fn(xccl_stream);
if (!use_calc_stream) {
if (FLAGS_use_stream_safe_cuda_allocator) {
memory::RecordStream(tensor.Holder(), xccl_stream.raw_stream());
}
task->UpdateWaitChain(*comm_ctx);
}
return task;
}
// TODO(sunyilun): methods below will be removed later
void SyncDefaultStream(const std::vector<Place>& places,
phi::event::Event& xccl_event, // NOLINT
std::vector<phi::CustomContext*>& dev_ctx) { // NOLINT
for (size_t i = 0; i < places.size(); ++i) {
auto* default_ctx = static_cast<phi::CustomContext*>(
platform::DeviceContextPool::Instance().Get(places[i]));
xccl_event.Record(default_ctx->GetStream().get());
dev_ctx[i]->GetStream()->WaitEvent(&xccl_event);
}
}
std::shared_ptr<ProcessGroupCustom::XCCLTask> ProcessGroupCustom::CreateTask(
std::vector<Place> places,
int rank,
CommType comm_type,
const std::vector<phi::DenseTensor>& inputs) {
return std::make_shared<ProcessGroupCustom::XCCLTask>(
places, rank, comm_type, inputs);
}
ProcessGroupCustom::XCCLTask::XCCLTask(
const std::vector<Place>& places,
int rank,
CommType CommType,
const std::vector<phi::DenseTensor>& inputs)
: TaskStream(rank, inputs, CommType), task_place_(places[0]) {
comm_event_.Init(places[0]);
}
// create XCCLManager cache for places_key
void ProcessGroupCustom::CreateXCCLManagerCache(
const std::string& places_key, const std::vector<Place>& places) {
PADDLE_ENFORCE_EQ(places_key.empty(),
false,
phi::errors::PreconditionNotMet(
"Not able to create/get the XCCL Communicator since "
"the CustomPlace are not known"));
phi::ccl::CCLRootId xccl_root_id;
if (rank_ == 0) {
phi::DeviceManager::CCLGetUniqueId(device_type_, &xccl_root_id);
}
BroadcastUniqueXCCLID(&xccl_root_id);
VLOG(3) << "init xccl rank: " << rank_ << ", nranks: " << size_
<< ", place: " << places_key << ", xccl uniqueid: "
<< phi::ccl::SerializeXCCLUniqueId(xccl_root_id);
std::vector<std::unique_ptr<phi::CustomContext>> dev_ctx;
dev_ctx.resize(places.size());
std::vector<phi::CustomContext*> dev_ctx_raw;
dev_ctx_raw.resize(places.size());
GroupStart(device_type_);
for (size_t i = 0; i < places.size(); ++i) {
phi::DeviceGuard guard(places[i]);
dev_ctx[i] = std::make_unique<phi::CustomContext>(places[i]);
dev_ctx[i]->SetAllocator(
&(phi::DeviceContextPool::Instance().Get(places[i])->GetAllocator()));
dev_ctx[i]->SetHostAllocator(&(
phi::DeviceContextPool::Instance().Get(places[i])->GetHostAllocator()));
dev_ctx[i]->SetZeroAllocator(&(
phi::DeviceContextPool::Instance().Get(places[i])->GetZeroAllocator()));
dev_ctx[i]->SetHostZeroAllocator(&(phi::DeviceContextPool::Instance()
.Get(places[i])
->GetHostZeroAllocator()));
phi::ccl::CCLComm xccl_comm;
phi::DeviceManager::CCLCommInitRank(
device_type_, GetSize(), &xccl_root_id, GetRank(), &xccl_comm);
dev_ctx[i]->set_xccl_comm(xccl_comm);
dev_ctx_raw[i] = dev_ctx[i].get();
}
GroupEnd(device_type_);
// TODO(sunyilun): for compatibility, will be removed later
auto xccl_event = std::make_unique<phi::event::Event>();
xccl_event->Init(places[0]);
place_to_calc_event_.emplace(places_key, std::move(xccl_event));
place_to_calc_ctx_.emplace(
places_key,
static_cast<phi::CustomContext*>(
platform::DeviceContextPool::Instance().Get(places[0])));
place_to_comm_ctx_.emplace(places_key, std::move(dev_ctx[0]));
// These caches will be useful to process sync/wait/communicate
places_to_ctx_.emplace(places_key, std::move(dev_ctx_raw));
}
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs,
Fn fn,
CommType op_type) {
const auto places = GetPlaceList(inputs);
const auto key = GetKeyFromPlaces(places);
{
std::lock_guard<std::mutex> lock(mutex_);
if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
CreateXCCLManagerCache(key, places);
}
}
SyncDefaultStream(
places, *place_to_calc_event_.at(key), places_to_ctx_.at(key));
auto task = CreateTask(places, rank_, op_type, inputs);
// construct uninitialize guard for device
{
GroupStart(device_type_);
for (size_t i = 0; i < inputs.size(); ++i) {
phi::DeviceGuard guard(places[i]);
const auto& xccl_stream = *places_to_ctx_.at(key)[i]->GetStream();
fn(inputs[i],
outputs[i],
places_to_ctx_.at(key)[i]->xccl_comm(),
xccl_stream);
}
GroupEnd(device_type_);
}
if (FLAGS_use_stream_safe_cuda_allocator) {
for (size_t i = 0; i < inputs.size(); ++i) {
phi::DeviceGuard guard(places[i]);
memory::RecordStream(inputs[i].Holder(),
places_to_ctx_.at(key)[i]->stream());
}
}
for (size_t i = 0; i < inputs.size(); ++i) {
phi::DeviceGuard guard(places[i]);
task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
}
return task;
}
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::PointToPoint(
std::vector<phi::DenseTensor>& tensors,
Fn fn,
int dst_rank,
CommType op_type) {
const auto places = GetPlaceList(tensors);
const auto key = GetKeyFromPlaces(places);
{
std::lock_guard<std::mutex> lock(mutex_);
if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
CreateXCCLManagerCache(key, places);
}
}
SyncDefaultStream(
places, *place_to_calc_event_.at(key), places_to_ctx_.at(key));
auto task = CreateTask(places, rank_, op_type, tensors);
// construct uninitialize guard for device
{
GroupStart(device_type_);
for (size_t i = 0; i < tensors.size(); ++i) {
phi::DeviceGuard guard(places[i]);
const auto& xccl_stream = *places_to_ctx_.at(key)[i]->GetStream();
fn(tensors[i],
places_to_ctx_.at(key)[i]->xccl_comm(),
xccl_stream,
dst_rank);
}
GroupEnd(device_type_);
}
if (FLAGS_use_stream_safe_cuda_allocator) {
for (size_t i = 0; i < tensors.size(); ++i) {
phi::DeviceGuard guard(places[i]);
memory::RecordStream(tensors[i].Holder(),
places_to_ctx_.at(key)[i]->stream());
}
}
for (size_t i = 0; i < tensors.size(); ++i) {
phi::DeviceGuard guard(places[i]);
task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
}
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const AllreduceOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
phi::errors::InvalidArgument("All inputs should be in CustomPlace."));
return Collective(
tensors,
tensors,
[&](phi::DenseTensor& input,
in_tensors,
out_tensors,
[&](const phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::ccl::CCLComm& comm,
const phi::stream::Stream& stream) {
phi::DeviceManager::CCLSend(device_type_,
input.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
dst_rank,
comm,
auto comm_context = this->GetCommContext();
comm_context->AllReduce(
&output,
input,
paddle::distributed::ToXCCLRedType(opts.reduce_op),
stream);
},
CommType::SEND,
false,
false);
CommType::ALLREDUCE);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Reduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
phi::distributed::CommStaticCheck::SameShape(*out_tensor,
in_tensor,
/*dst_rank*/ opts.root_rank,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
phi::errors::InvalidArgument("All inputs should be in CustomPlace."));
return Collective(
in_wrapper,
out_wrapper,
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
const phi::ccl::CCLComm& comm,
const phi::stream::Stream& stream) {
const auto root =
opts.source_rank * in_tensors.size() + opts.source_root;
auto comm_context = this->GetCommContext();
comm_context->Broadcast(&output, input, root, stream);
},
CommType::BROADCAST);
}
void CheckTensorsInDifferentDevices(
const std::vector<phi::DenseTensor>& tensors, const size_t num_devices) {
PADDLE_ENFORCE_EQ(
tensors.empty(),
false,
phi::errors::InvalidArgument("Tensor list must be nonempty."));
PADDLE_ENFORCE_LE(
tensors.size(),
num_devices,
phi::errors::InvalidArgument("Tensor list mustn't be larger than the "
"number of available CustomDevices."));
std::set<Place> used_devices;
for (const auto& t : tensors) {
PADDLE_ENFORCE_EQ(platform::is_custom_place(t.place()),
true,
phi::errors::InvalidArgument(
"Tensors must be CustomDevice and dense tensor."));
const auto inserted = used_devices.insert(t.place()).second;
PADDLE_ENFORCE_EQ(inserted,
true,
phi::errors::InvalidArgument(
"Tensors must be on distinct CustomDevice devices."));
}
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
std::vector<phi::DenseTensor>& tensors, int dst_rank) {
CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
auto task = PointToPoint(
tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
phi::DeviceManager::CCLReduce(device_type_,
input.data(),
output.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
ToCustomCCLRedType(opts.reduce_op),
opts.root_rank,
comm,
stream);
const phi::ccl::CCLComm& comm,
const phi::stream::Stream& stream,
int dst_rank) {
auto comm_context = this->GetCommContext();
comm_context->Send(input, input.numel(), dst_rank, stream);
},
CommType::REDUCE,
sync_op,
use_calc_stream);
dst_rank,
CommType::SEND);
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Reduce(
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank) {
CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
auto task = PointToPoint(
tensors,
[&](phi::DenseTensor& output,
const phi::ccl::CCLComm& comm,
const phi::stream::Stream& stream,
int src_rank) {
auto comm_context = this->GetCommContext();
comm_context->Recv(&output, output.numel(), src_rank, stream);
},
src_rank,
CommType::RECV);
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts) {
std::vector<phi::DenseTensor>& out_tensors) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
phi::errors::InvalidArgument("All inputs should be in CustomPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_tensors, device_type_),
true,
phi::errors::InvalidArgument("All outputs should be in CustomPlace."));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
phi::DeviceManager::CCLReduce(device_type_,
input.data(),
output.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
ToCustomCCLRedType(opts.reduce_op),
opts.root_rank,
comm,
stream);
},
CommType::REDUCE,
false,
false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op,
bool use_calc_stream) {
const phi::DDim& out_dim = out_tensor->dims();
const phi::DDim& in_dim = in_tensor.dims();
CheckSizeOnEachRank(out_dim, out_size_each_rank, size_);
CheckSizeOnEachRank(in_dim, in_size_each_rank, size_);
// NOTE: Since `all_to_all` needs other processes' participation, it cannot
// simply be covered by static checks. Factors are set to 0 here to skip the
// shape check. Its shape check will be done by dynamic checks with
// FLAGS_enable_nccl_dynamic_check.
phi::distributed::CommStaticCheck::CheckShape(*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
/*out_size_factor*/ 0,
/*in_size_factor*/ 0,
phi::AllocationType::CUSTOM);
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return Collective(
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
[&](const phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::ccl::CCLComm& comm,
const phi::stream::Stream& stream) {
int64_t in_row_size = in_tensor.numel() / in_dim[0],
out_row_size = out_tensor->numel() / out_dim[0];
int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0;
phi::DenseTensor input_partial, output_partial;
std::vector<void*> send_buf, recv_buf;
std::vector<size_t> send_count, recv_count;
std::vector<phi::ccl::CCLDataType> send_dtype, recv_dtype;
for (auto i = 0; i < size_; i++) {
in_numel = in_size_each_rank[i] * in_row_size;
input_partial = GetPartialTensor(in_tensor, in_offset, in_numel);
out_numel = out_size_each_rank[i] * out_row_size;
output_partial = GetPartialTensor(*out_tensor, out_offset, out_numel);
in_offset += in_numel;
out_offset += out_numel;
send_buf.push_back(input_partial.data());
recv_buf.push_back(output_partial.data());
send_count.push_back(in_numel);
recv_count.push_back(out_numel);
send_dtype.push_back(phi::ccl::ToCCLDataType(input_partial.dtype()));
recv_dtype.push_back(phi::ccl::ToCCLDataType(output_partial.dtype()));
}
phi::DeviceManager::CCLAllToAll(
device_type_,
const_cast<const void**>(send_buf.data()),
send_count.data(),
send_dtype.data(),
recv_buf.data(),
recv_count.data(),
recv_dtype.data(),
rank_,
size_,
comm,
stream);
auto comm_context = this->GetCommContext();
comm_context->AllGather(&output, input, stream);
},
CommType::ALLTOALL,
sync_op,
use_calc_stream);
CommType::ALLGATHER);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
@@ -863,8 +919,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::ccl::CCLComm& comm,
const phi::stream::Stream& stream) {
auto comm_context = this->GetCommContext();
size_t offset = 0;
std::vector<void*> send_buf, recv_buf;
std::vector<size_t> send_count(size_, input.numel() / size_),
@@ -889,111 +947,35 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
recv_dtype.data(),
rank_,
size_,
comm,
stream);
},
CommType::ALLTOALL,
false,
false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
phi::distributed::CommStaticCheck::ScatterLikeShape(
*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return Collective(
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
phi::DeviceManager::CCLReduceScatter(
device_type_,
const_cast<void*>(in_tensor.data()),
out_tensor->data(),
out_tensor->numel(),
phi::ccl::ToCCLDataType(in_tensor.dtype()),
paddle::distributed::ToCustomCCLRedType(opts.reduce_op),
comm,
comm_context->GetXcclComm(),
stream);
},
CommType::REDUCE_SCATTER,
false,
false);
CommType::ALLTOALL);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
phi::distributed::CommStaticCheck::ScatterLikeShape(
*out_tensor,
in_tensor,
/*dst_rank*/ opts.root_rank,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
phi::errors::InvalidArgument("All inputs should be in CustomPlace."));
return Collective(
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
in_tensors,
out_tensors,
[&](const phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::ccl::CCLComm& comm,
const phi::stream::Stream& stream) {
int64_t numel = in_tensor.numel() / size_;
if (rank_ == opts.root_rank) {
int64_t offset = 0;
phi::DenseTensor partial_tensor;
for (auto i = 0; i < size_; i++) {
partial_tensor = GetPartialTensor(in_tensor, offset, numel);
if (i != rank_) {
phi::DeviceManager::CCLSend(
device_type_,
partial_tensor.data(),
numel,
phi::ccl::ToCCLDataType(partial_tensor.dtype()),
i,
comm,
stream);
} else {
phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace())
->MemoryCopyD2D(out_tensor->data(),
partial_tensor.data(),
numel * phi::SizeOf(partial_tensor.dtype()),
&stream);
}
offset += numel;
}
} else {
phi::DeviceManager::CCLRecv(
device_type_,
out_tensor->data(),
numel,
phi::ccl::ToCCLDataType(out_tensor->dtype()),
auto comm_context = this->GetCommContext();
comm_context->Reduce(&output,
input,
paddle::distributed::ToXCCLRedType(opts.reduce_op),
opts.root_rank,
comm,
stream);
}
},
CommType::SCATTER,
sync_op,
use_calc_stream);
CommType::REDUCE);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
@@ -1003,134 +985,36 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
phi::errors::InvalidArgument("All inputs should be in CudaPlace."));
phi::errors::InvalidArgument("All inputs should be in CustomPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_tensors, device_type_),
true,
phi::errors::InvalidArgument("All inputs should be in CudaPlace."));
phi::errors::InvalidArgument("All inputs should be in CustomPlace."));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::ccl::CCLComm& comm,
const phi::stream::Stream& stream) {
int64_t numel = input.numel() / size_;
if (rank_ == opts.root_rank) {
int64_t offset = 0;
phi::DenseTensor partial_tensor;
for (auto i = 0; i < size_; i++) {
partial_tensor = GetPartialTensor(input, offset, numel);
if (i != rank_) {
phi::DeviceManager::CCLSend(
device_type_,
partial_tensor.data(),
numel,
phi::ccl::ToCCLDataType(partial_tensor.dtype()),
i,
comm,
stream);
} else {
phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace())
->MemoryCopyD2D(output.data(),
partial_tensor.data(),
numel * phi::SizeOf(partial_tensor.dtype()),
&stream);
}
offset += numel;
}
} else {
phi::DeviceManager::CCLRecv(device_type_,
output.data(),
numel,
phi::ccl::ToCCLDataType(output.dtype()),
opts.root_rank,
comm,
stream);
}
},
CommType::SCATTER,
false,
false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Gather(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const GatherOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  std::vector<phi::DenseTensor> partial_tensors;
  if (rank_ == opts.root_rank) {
    partial_tensors.reserve(size_);
    size_t offset = 0;
    size_t numel = out_tensor->numel() / size_;
    for (auto i = 0; i < size_; i++) {
      partial_tensors.push_back(GetPartialTensor(*out_tensor, offset, numel));
      offset += numel;
    }
  }
  return Gather(&partial_tensors, in_tensor, opts, sync_op, use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Gather(
    std::vector<phi::DenseTensor>* gather_tensors_ptr,
    const phi::DenseTensor& in_tensor,
    const GatherOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  auto& gather_tensors = *gather_tensors_ptr;
  PADDLE_ENFORCE_GT(size_,
                    opts.root_rank,
                    phi::errors::InvalidArgument(
                        "root world size [%d] is less than root rank [%d]",
                        size_,
                        opts.root_rank));
  std::vector<phi::DenseTensor> in_wrapper{in_tensor};
  return Collective(
      in_wrapper,
      in_wrapper,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          phi::ccl::CCLComm comm,
          const phi::stream::Stream& stream) {
        // root receives from all devices
        if (rank_ == opts.root_rank) {
          for (auto i = 0; i < size_; i++) {
            auto& gather_tensor = gather_tensors[i];
            if (i != rank_) {
              phi::DeviceManager::CCLRecv(
                  device_type_,
                  gather_tensor.data(),
                  gather_tensor.numel(),
                  phi::ccl::ToCCLDataType(gather_tensor.dtype()),
                  i,
                  comm,
                  stream);
            } else {
              phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace())
                  ->MemoryCopyD2D(
                      gather_tensor.data(),
                      in_tensor.data(),
                      in_tensor.numel() * phi::SizeOf(in_tensor.dtype()),
                      &stream);
            }
          }
        } else {
          // send to root
          phi::DeviceManager::CCLSend(
              device_type_,
              const_cast<void*>(in_tensor.data()),
              in_tensor.numel(),
              phi::ccl::ToCCLDataType(in_tensor.dtype()),
              opts.root_rank,
              comm,
              stream);
        }
      },
      CommType::GATHER,
      sync_op,
      use_calc_stream);
}
std::shared_ptr<ProcessGroupCustom>
......@@ -1146,5 +1030,16 @@ ProcessGroupCustom::CreateProcessGroupCustom(
return process_group;
}
phi::distributed::XCCLCommContext* ProcessGroupCustom::GetCommContext() {
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
auto comm_context = static_cast<phi::distributed::XCCLCommContext*>(
comm_context_manager.Get(std::to_string(this->gid_)));
PADDLE_ENFORCE_NE(comm_context,
nullptr,
phi::errors::Unavailable("XCCLCommContext is nullptr"));
return comm_context;
}
} // namespace distributed
} // namespace paddle
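The rewritten collectives above all follow one pattern: look up the XCCLCommContext registered under this group's id via GetCommContext() and issue the primitive on the stream for the current place. A minimal sketch of that pattern, relying only on the XCCLCommContext API introduced later in this patch; the free function, its name and its fixed SUM reduce op are illustrative, not part of the patch:

#include "paddle/phi/core/distributed/xccl_comm_context.h"

namespace paddle {
namespace distributed {

// Sketch: reduce `input` into `output` on `root` through a comm context,
// the same call shape ProcessGroupCustom::Reduce uses after GetCommContext().
void ReduceViaCommContext(phi::distributed::XCCLCommContext* comm_context,
                          const phi::DenseTensor& input,
                          phi::DenseTensor* output,
                          int root,
                          const phi::stream::Stream& stream) {
  // The reduce op is fixed to SUM for illustration; the process group maps
  // ReduceOp to phi::ccl::CCLReduceOp with ToXCCLRedType().
  comm_context->Reduce(output, input, phi::ccl::CCLReduceOp::SUM, root, stream);
}

}  // namespace distributed
}  // namespace paddle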
......@@ -15,82 +15,76 @@
#pragma once
#include <chrono>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/collective/custom_ccl_tools.h"
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_with_stream.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/store/store.h"
#include "paddle/phi/core/distributed/xccl_comm_context.h"
namespace paddle {
namespace distributed {
using Place = phi::Place;

class ProcessGroupCustom final : public ProcessGroupWithStream {
 public:
  class XCCLTask final : public ProcessGroupWithStream::TaskStream,
                         public std::enable_shared_from_this<XCCLTask> {
   public:
    XCCLTask(const Place& place,
             int rank,
             CommType comm_type,
             bool sync_op,
             bool use_calc_stream);
    virtual ~XCCLTask();

    bool IsCompleted() override;
    bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) override;
    void Synchronize() override;
    void UpdateWaitChain(const phi::DeviceContext& ctx) override;

    bool IsBlockCPUInWait() const { return block_cpu_in_wait_; }
    void SetBlockCPUInWait() { block_cpu_in_wait_ = true; }

    // TODO(sunyilun): methods below will be removed later
    XCCLTask(const std::vector<Place>& places,
             int rank,
             CommType CommType,
             const std::vector<phi::DenseTensor>& inputs);

   private:
    const std::string device_type_;
    bool block_cpu_in_wait_{false};
    phi::event::Event comm_event_;  // event on comm stream
    Place task_place_;
  };

 public:
  static std::shared_ptr<ProcessGroupCustom> CreateProcessGroupCustom(
      const std::shared_ptr<phi::distributed::Store>& store,
      const std::string& device_type,
      int rank,
      int size,
      int gid);

  ProcessGroupCustom(const std::shared_ptr<phi::distributed::Store>& store,
                     const std::string& device_type,
                     int rank,
                     int size,
                     int gid);

  std::string GetBackendName() const override { return "XCCL"; }

  phi::DeviceContext* GetDeviceContext(const Place& place) const override;

  phi::DeviceContext* GetDeviceContext(const Place& place,
                                       bool use_calc_stream) const override;

  std::shared_ptr<ProcessGroup::Task> AllGather(
      phi::DenseTensor* out_tensor,
......@@ -100,11 +94,6 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
bool sync_op,
bool use_calc_stream) override;
  std::shared_ptr<ProcessGroup::Task> AllReduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
......@@ -112,61 +101,30 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
bool sync_op,
bool use_calc_stream) override;
  std::shared_ptr<ProcessGroup::Task> AllToAll(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const std::vector<int64_t>& out_size_each_rank,
      const std::vector<int64_t>& in_size_each_rank,
      bool sync_op,
      bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> Barrier(
      const BarrierOptions& = BarrierOptions()) override;

  std::shared_ptr<ProcessGroup::Task> Broadcast(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const BroadcastOptions& opts,
      bool sync_op,
      bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> Reduce(phi::DenseTensor* out_tensor,
                                             const phi::DenseTensor& in_tensor,
                                             const ReduceOptions& opts,
                                             bool sync_op,
                                             bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> ReduceScatter(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
......@@ -180,11 +138,6 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Gather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const GatherOptions& opts,
......@@ -198,52 +151,124 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
bool sync_op,
bool use_calc_stream) override;
  std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
                                           int src_rank,
                                           int64_t offset,
                                           int64_t numel,
                                           bool sync_op,
                                           bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
                                           int dst_rank,
                                           int64_t offset,
                                           int64_t numel,
                                           bool sync_op,
                                           bool use_calc_stream) override;

  static void GroupStart(const std::string& dev_type);

  static void GroupEnd(const std::string& dev_type);

  phi::ccl::CCLComm XCCLComm(const Place& place) const;

  // TODO(liyurui): This API will be moved later
  std::shared_ptr<ProcessGroup::Task> AllReduce(
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors,
      const AllreduceOptions& = AllreduceOptions()) override;

  // TODO(sunyilun): methods below will be removed later
  std::shared_ptr<ProcessGroup::Task> Broadcast(
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors,
      const BroadcastOptions& = BroadcastOptions()) override;

  std::shared_ptr<ProcessGroup::Task> Send(
      std::vector<phi::DenseTensor>& tensors, int dst_rank) override;

  std::shared_ptr<ProcessGroup::Task> Recv(
      std::vector<phi::DenseTensor>& tensors, int src_rank) override;

  std::shared_ptr<ProcessGroup::Task> AllGather(
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors) override;

  std::shared_ptr<ProcessGroup::Task> AllToAll(
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors) override;

  std::shared_ptr<ProcessGroup::Task> Reduce(
      std::vector<phi::DenseTensor>& tensors,
      std::vector<phi::DenseTensor>& out_tensors,
      const ReduceOptions& opts) override;

  std::shared_ptr<ProcessGroup::Task> Scatter(
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors,
      const ScatterOptions& opts) override;

 private:
  std::shared_ptr<ProcessGroupCustom::XCCLTask> CreateTask(
      const Place& place,
      int rank,
      CommType op_type,
      bool sync_op,
      bool use_calc_stream);

  void BroadcastUniqueXCCLID(phi::ccl::CCLRootId* nccl_id);

  void CreateXCCLEnvCache(const Place& place, const std::string& place_key);

  void SyncCalcStream(const Place& place);

  std::shared_ptr<ProcessGroup::Task> RunFnInXCCLEnv(
      std::function<void(const phi::stream::Stream&)> fn,
      const phi::DenseTensor& tensor,
      CommType comm_type,
      bool sync_op,
      bool use_calc_stream);

  // TODO(sunyilun): methods below will be removed later
  std::shared_ptr<ProcessGroupCustom::XCCLTask> CreateTask(
      std::vector<Place> places,
      int rank,
      CommType op_type,
      const std::vector<phi::DenseTensor>& inputs);

  template <typename Fn>
  std::shared_ptr<ProcessGroup::Task> Collective(
      std::vector<phi::DenseTensor>& inputs,   // NOLINT
      std::vector<phi::DenseTensor>& outputs,  // NOLINT
      Fn fn,
      CommType op_type);

  template <typename Fn>
  std::shared_ptr<ProcessGroup::Task> PointToPoint(
      std::vector<phi::DenseTensor>& tensors,  // NOLINT
      Fn fn,
      int dst_rank,
      CommType op_type);

  void CreateXCCLManagerCache(const std::string& places_key,
                              const std::vector<Place>& places);

  phi::distributed::XCCLCommContext* GetCommContext();

 private:
  std::shared_ptr<phi::distributed::Store> store_;
  std::string device_type_;
  std::unordered_map<std::string, std::unique_ptr<phi::event::Event>>
      place_to_calc_event_;  // event on calc stream
  std::unordered_map<std::string, phi::CustomContext*> place_to_calc_ctx_;
  std::unordered_map<std::string, std::unique_ptr<phi::CustomContext>>
      place_to_comm_ctx_;

  // TODO(sunyilun): attrs below will be removed later
  std::mutex mutex_;
  std::unordered_map<std::string, std::vector<phi::CustomContext*>>
      places_to_ctx_;
};

}  // namespace distributed
}  // namespace paddle
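For context, the factory and the tensor-based overloads declared above are what callers use: a group is built from a store, a device type string, the rank/size pair and a group id, and each collective returns a task that can be waited on. A hedged sketch of that call path; the device type "npu", the group id 0 and the default-constructed AllreduceOptions are placeholder choices, not taken from this patch:

#include "paddle/fluid/distributed/collective/process_group_custom.h"

// Sketch: create a custom-device process group and run one allreduce.
std::shared_ptr<paddle::distributed::ProcessGroup::Task> RunAllReduce(
    const std::shared_ptr<phi::distributed::Store>& store,
    int rank,
    int size,
    phi::DenseTensor* out,
    const phi::DenseTensor& in) {
  auto pg = paddle::distributed::ProcessGroupCustom::CreateProcessGroupCustom(
      store, /*device_type=*/"npu", rank, size, /*gid=*/0);
  paddle::distributed::AllreduceOptions opts;  // default reduce op (assumed SUM)
  auto task = pg->AllReduce(out, in, opts, /*sync_op=*/true,
                            /*use_calc_stream=*/false);
  task->Wait();
  return task;
}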
......@@ -53,8 +53,8 @@ ccl::CCLComm GetCCLComm(const Place& place, int global_gid) {
#endif
} else if (place.GetType() == phi::AllocationType::CUSTOM) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
    return static_cast<paddle::distributed::ProcessGroupCustom*>(pg)->XCCLComm(
        place);
#else
return nullptr;
#endif
......
......@@ -1574,6 +1574,7 @@ void AllocatorFacade::SetDefaultStream(const platform::CustomPlace& place,
m_->SetDefaultStream(place, stream);
}
}
#endif
UNUSED static std::shared_ptr<NaiveBestFitAllocator> unused_obj =
......
......@@ -101,6 +101,7 @@ class AllocatorFacade {
phi::stream::stream_t stream);
void RecordStream(std::shared_ptr<phi::Allocation> allocation,
phi::stream::stream_t stream);
void SetDefaultStream(const platform::CustomPlace& place,
phi::stream::stream_t stream);
#endif
......
......@@ -71,6 +71,14 @@ gpuStream_t GetStream(const std::shared_ptr<Allocation>& allocation) {
return allocation::AllocatorFacade::Instance().GetStream(allocation);
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
void RecordStream(std::shared_ptr<Allocation> allocation,
phi::stream::stream_t stream) {
return allocation::AllocatorFacade::Instance().RecordStream(allocation,
stream);
}
#endif
} // namespace memory
} // namespace paddle
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/stream.h"
......@@ -55,5 +56,9 @@ void RecordStream(std::shared_ptr<Allocation> allocation, gpuStream_t stream);
gpuStream_t GetStream(const std::shared_ptr<Allocation>& allocation);
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
void RecordStream(std::shared_ptr<Allocation> allocation,
phi::stream::stream_t stream);
#endif
} // namespace memory
} // namespace paddle
......@@ -56,6 +56,12 @@ void BindCommContextManager(py::module *m) {
"create_gloo_comm_context",
&phi::distributed::CommContextManager::CreateGlooCommContext,
py::call_guard<py::gil_scoped_release>())
#endif
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
.def_static(
"create_xccl_comm_context",
&phi::distributed::CommContextManager::CreateXCCLCommContext,
py::call_guard<py::gil_scoped_release>())
#endif
.def("set_store", &phi::distributed::CommContextManager::SetStore);
}
......
......@@ -21,6 +21,8 @@
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/common/reduce_type.h"
namespace phi {
namespace ccl {
typedef void* CCLComm;
......@@ -38,6 +40,32 @@ enum CCLDataType {
CCL_DATA_TYPE_UINT8
};
inline CCLReduceOp ToXCCLReduceOp(int reduce_type) {
phi::ccl::CCLReduceOp red_type = phi::ccl::CCLReduceOp::SUM;
switch (static_cast<phi::ReduceType>(reduce_type)) {
case phi::ReduceType::kRedSum:
red_type = phi::ccl::CCLReduceOp::SUM;
break;
case phi::ReduceType::kRedMax:
red_type = phi::ccl::CCLReduceOp::MAX;
break;
case phi::ReduceType::kRedMin:
red_type = phi::ccl::CCLReduceOp::MIN;
break;
case phi::ReduceType::kRedProd:
red_type = phi::ccl::CCLReduceOp::PRODUCT;
break;
case phi::ReduceType::kRedAvg:
red_type = phi::ccl::CCLReduceOp::AVG;
break;
default:
PADDLE_THROW(
errors::Unavailable("Unsuppored reduce type. Reduce type must be one "
"of SUM, MAX, MIN, PRODUCT and AVG."));
}
return red_type;
}
inline CCLDataType ToCCLDataType(phi::DataType type) {
if (type == phi::DataType::FLOAT64) {
return CCL_DATA_TYPE_FP64;
......@@ -79,5 +107,14 @@ inline phi::DataType ToPhiDataType(CCLDataType type) {
}
}
inline std::string SerializeXCCLUniqueId(const phi::ccl::CCLRootId& ccl_id) {
const uint8_t* bytes = ccl_id.data();
std::ostringstream oss;
for (size_t i = 0; i < ccl_id.size(); ++i) {
oss << std::hex << static_cast<int>(bytes[i]);
}
return oss.str();
}
} // namespace ccl
} // namespace phi
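The two helpers above are small enough to exercise standalone: ToXCCLReduceOp translates the framework-level reduce_type integer into the device CCL enum, and SerializeXCCLUniqueId hex-prints a root id for logging. A sketch under stated assumptions: the include path for the phi::ccl helpers and the fake root id bytes are assumptions, not taken from this patch.

#include <iostream>

#include "paddle/phi/backends/c_comm_lib.h"  // assumed header for the phi::ccl helpers above
#include "paddle/phi/common/reduce_type.h"

int main() {
  // Map the framework reduce type to the CCL reduce op used by XCCL.
  phi::ccl::CCLReduceOp op =
      phi::ccl::ToXCCLReduceOp(static_cast<int>(phi::ReduceType::kRedMax));
  std::cout << "ccl reduce op: " << static_cast<int>(op) << std::endl;

  // Hex-serialize a (fake) root id the same way CreateXCCLCommContext logs it.
  phi::ccl::CCLRootId root_id = {0x12, 0xab, 0x00, 0xff};
  std::cout << "root id: " << phi::ccl::SerializeXCCLUniqueId(root_id)
            << std::endl;
  return 0;
}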
......@@ -754,14 +754,16 @@ class CustomDevice : public DeviceInterface {
}
  void CCLGroupStart() override {
    if (pimpl_->xccl_group_start) {
      PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_group_start());
    }
  }

  void CCLGroupEnd() override {
    if (pimpl_->xccl_group_end) {
      PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_group_end());
    }
  }
void CCLSend(void* send_buf,
size_t count,
......
......@@ -12,4 +12,8 @@ if(WITH_GLOO)
list(APPEND DISTRIBUTED_COMMON_SRCS gloo_utils.cc gloo_comm_context.cc)
endif()
if(WITH_CUSTOM_DEVICE)
list(APPEND DISTRIBUTED_COMMON_SRCS xccl_comm_context.cc)
endif()
collect_srcs(core_srcs SRCS ${DISTRIBUTED_COMMON_SRCS})
......@@ -53,6 +53,18 @@ DenseTensor ReshardAllGatherFunctor(DeviceContext* dev_ctx,
}));
return out;
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (phi::CustomContext::classof(dev_ctx)) {
PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(
input.dtype(), "AllGather", ([&] {
AllGather<data_t>(static_cast<const CustomContext&>(*dev_ctx),
input,
world_size,
&out);
}));
return out;
}
#endif
PADDLE_THROW(phi::errors::Unimplemented(
"The all_gather in reshard only supported on CPU and GPU for now."));
......
......@@ -170,6 +170,15 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
#else
PADDLE_THROW(phi::errors::Unimplemented(
"Cannot use gloo on CPU, please turn PADDLE_WITH_GLOO flag on."));
#endif
} else if (phi::CustomContext::classof(&dev_ctx)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
CommContextManager::CreateXCCLCommContext(
store,
unique_comm_key,
dev_ctx.GetPlace().GetDeviceType(),
rank,
world_size);
#endif
} else {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
......
......@@ -24,8 +24,8 @@ class CommContext {
CommContext(int rank, int size) : rank_(rank), size_(size) {}
virtual ~CommContext() = default;
  int GetRank() const { return rank_; }
  int GetSize() const { return size_; }
protected:
int rank_;
......
......@@ -32,6 +32,9 @@
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#endif
namespace phi {
namespace distributed {
......@@ -91,6 +94,35 @@ void CommContextManager::CreateGlooCommContext(
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
void CommContextManager::CreateXCCLCommContext(
const std::shared_ptr<Store>& store,
const std::string& unique_comm_key,
const std::string& device_type,
int rank,
int size) {
phi::ccl::CCLRootId xccl_root_id;
if (rank == 0) {
phi::DeviceManager::CCLGetUniqueId(device_type, &xccl_root_id);
}
std::string unique_key = "XCCLCommContext/" + unique_comm_key;
if (rank == 0) {
store->set(unique_key, xccl_root_id);
} else {
xccl_root_id = store->get(unique_key);
}
VLOG(3) << "init xccl rank: " << rank << ", nranks: " << size
<< ", unique_comm_key: " << unique_comm_key << ", xccl uniqueid: "
<< phi::ccl::SerializeXCCLUniqueId(xccl_root_id);
auto xccl_comm_context =
std::make_unique<XCCLCommContext>(device_type, rank, size, xccl_root_id);
auto& comm_context_manager = CommContextManager::GetInstance();
comm_context_manager.SetStore(store);
comm_context_manager.Emplace(unique_comm_key, std::move(xccl_comm_context));
}
#endif
CommContext* CommContextManager::Emplace(
const std::string& unique_comm_key,
std::unique_ptr<CommContext> comm_context) {
......
......@@ -62,6 +62,14 @@ class CommContextManager {
int size);
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
static void CreateXCCLCommContext(const std::shared_ptr<Store>& store,
const std::string& unique_comm_key,
const std::string& device_type,
int rank,
int size);
#endif
private:
DISABLE_COPY_AND_ASSIGN(CommContextManager);
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#include "glog/logging.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/check/static_check.h"
#include "paddle/phi/core/distributed/utils.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
namespace distributed {
XCCLCommContext::XCCLCommContext(const std::string& device_type,
int rank,
int size,
const ccl::CCLRootId& xccl_id)
: CommContext(rank, size) {
device_type_ = device_type;
phi::DeviceManager::CCLCommInitRank(device_type,
size_,
const_cast<ccl::CCLRootId*>(&xccl_id),
rank,
&xccl_comm_);
}
void XCCLCommContext::Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root,
const phi::stream::Stream& stream) const {
CommStaticCheck::SameShape(*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
if (rank_ == root) {
phi::DeviceManager::CCLBroadcast(device_type_,
const_cast<void*>(in_tensor.data()),
in_tensor.numel(),
phi::ccl::ToCCLDataType(in_tensor.dtype()),
root,
xccl_comm_,
stream);
} else {
phi::DeviceManager::CCLBroadcast(device_type_,
out_tensor->data(),
out_tensor->numel(),
phi::ccl::ToCCLDataType(in_tensor.dtype()),
root,
xccl_comm_,
stream);
}
}
void XCCLCommContext::AllGather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const phi::stream::Stream& stream) const {
phi::distributed::CommStaticCheck::GatherLikeShape(
*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
phi::DeviceManager::CCLAllGather(device_type_,
const_cast<void*>(in_tensor.data()),
out_tensor->data(),
in_tensor.numel(),
phi::ccl::ToCCLDataType(in_tensor.dtype()),
xccl_comm_,
stream);
}
void XCCLCommContext::ReduceScatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
phi::ccl::CCLReduceOp reduce_type,
const phi::stream::Stream& stream) const {
phi::distributed::CommStaticCheck::ScatterLikeShape(
*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
phi::DeviceManager::CCLReduceScatter(
device_type_,
const_cast<void*>(in_tensor.data()),
out_tensor->data(),
out_tensor->numel(),
phi::ccl::ToCCLDataType(in_tensor.type()),
reduce_type,
xccl_comm_,
stream);
}
void XCCLCommContext::Send(const phi::DenseTensor& in_tensor,
const int64_t& count,
const int& peer,
const phi::stream::Stream& stream) const {
phi::distributed::CommStaticCheck::CheckShape(
in_tensor, rank_, size_, phi::AllocationType::CUSTOM);
phi::DeviceManager::CCLSend(device_type_,
const_cast<void*>(in_tensor.data()),
count,
phi::ccl::ToCCLDataType(in_tensor.type()),
peer,
xccl_comm_,
stream);
VLOG(3) << "rank " << GetRank() << " send " << phi::product(in_tensor.dims())
<< " to " << peer;
}
void XCCLCommContext::Recv(phi::DenseTensor* out_tensor,
const int64_t& count,
const int& peer,
const phi::stream::Stream& stream) const {
phi::distributed::CommStaticCheck::CheckShape(
*out_tensor, rank_, size_, phi::AllocationType::CUSTOM);
phi::DeviceManager::CCLRecv(device_type_,
out_tensor->data(),
count,
phi::ccl::ToCCLDataType(out_tensor->type()),
peer,
xccl_comm_,
stream);
VLOG(3) << "rank " << GetRank() << " recv "
<< phi::product(out_tensor->dims()) << " from " << peer;
}
void XCCLCommContext::AllReduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
phi::ccl::CCLReduceOp reduce_type,
const phi::stream::Stream& stream) const {
phi::distributed::CommStaticCheck::SameShape(*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
phi::DeviceManager::CCLAllReduce(device_type_,
const_cast<void*>(in_tensor.data()),
out_tensor->data(),
in_tensor.numel(),
phi::ccl::ToCCLDataType(in_tensor.type()),
reduce_type,
xccl_comm_,
stream);
}
void XCCLCommContext::Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
phi::ccl::CCLReduceOp reduce_type,
int root,
const phi::stream::Stream& stream) const {
phi::distributed::CommStaticCheck::SameShape(*out_tensor,
in_tensor,
/*dst_rank*/ root,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
phi::DeviceManager::CCLReduce(device_type_,
const_cast<void*>(in_tensor.data()),
out_tensor->data(),
in_tensor.numel(),
phi::ccl::ToCCLDataType(in_tensor.type()),
reduce_type,
root,
xccl_comm_,
stream);
}
void XCCLCommContext::GroupStart() const {
phi::DeviceManager::CCLGroupStart(device_type_);
}
void XCCLCommContext::GroupEnd() const {
phi::DeviceManager::CCLGroupEnd(device_type_);
}
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/distributed/comm_context.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/backends/device_manager.h"
namespace phi {
class DenseTensor;
namespace distributed {
class XCCLCommContext final : public CommContext {
public:
XCCLCommContext(const std::string& device_type,
int rank,
int size,
const ccl::CCLRootId& xccl_id);
ccl::CCLComm GetXcclComm() const { return xccl_comm_; }
const std::string& GetDeviceType() const { return device_type_; }
void Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root,
const phi::stream::Stream& stream) const;
void Send(const phi::DenseTensor& in_tensor,
const int64_t& count,
const int& peer,
const phi::stream::Stream& stream) const;
void Recv(phi::DenseTensor* out_tensor,
const int64_t& count,
const int& peer,
const phi::stream::Stream& stream) const;
void ReduceScatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
phi::ccl::CCLReduceOp reduce_type,
const phi::stream::Stream& stream) const;
void AllGather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const phi::stream::Stream& stream) const;
void AllReduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
phi::ccl::CCLReduceOp reduce_type,
const phi::stream::Stream& stream) const;
void Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
phi::ccl::CCLReduceOp reduce_type,
int root,
const phi::stream::Stream& stream) const;
void GroupStart() const;
void GroupEnd() const;
private:
DISABLE_COPY_AND_ASSIGN(XCCLCommContext);
std::string device_type_;
ccl::CCLComm xccl_comm_;
};
} // namespace distributed
} // namespace phi
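Outside the process-group path, an XCCLCommContext is normally obtained through the CommContextManager shown earlier: rank 0 generates the root id, every rank registers the context under a ring/group key, and later lookups use that same key. A hedged sketch of that flow; the key "0", the device type value and the manager header path are assumptions:

#include <memory>
#include <string>

#include "paddle/phi/core/distributed/comm_context_manager.h"  // assumed path
#include "paddle/phi/core/distributed/xccl_comm_context.h"

// Sketch: register an XCCL comm context for a custom device and fetch it back.
phi::distributed::XCCLCommContext* GetOrCreateXcclContext(
    const std::shared_ptr<phi::distributed::Store>& store,
    const std::string& device_type,  // e.g. "npu" (placeholder)
    int rank,
    int world_size) {
  const std::string key = "0";  // ring/group id used as the lookup key
  phi::distributed::CommContextManager::CreateXCCLCommContext(
      store, key, device_type, rank, world_size);
  auto& manager = phi::distributed::CommContextManager::GetInstance();
  return static_cast<phi::distributed::XCCLCommContext*>(manager.Get(key));
}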
......@@ -20,6 +20,9 @@
#if defined(PADDLE_WITH_GLOO)
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#endif
namespace phi {
......@@ -49,6 +52,28 @@ void AllGatherKernel(const Context& dev_ctx,
#endif
}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
template <typename T>
void AllGatherKernel(const phi::CustomContext& dev_ctx,
const DenseTensor& x,
int nranks,
DenseTensor* out) {
  auto out_dims = x.dims();
  out_dims[0] *= nranks;
  out->Resize(out_dims);
  dev_ctx.template Alloc<T>(out);
auto comm_ctx =
static_cast<distributed::XCCLCommContext*>(dev_ctx.GetCommContext());
PADDLE_ENFORCE_EQ(
nranks,
comm_ctx->GetSize(),
errors::InvalidArgument(
"nranks: %s should equal to %s", nranks, comm_ctx->GetSize()));
comm_ctx->AllGather(out, x, *dev_ctx.GetStream());
}
#endif
} // namespace phi
PD_REGISTER_KERNEL(all_gather,
......@@ -64,3 +89,19 @@ PD_REGISTER_KERNEL(all_gather,
int16_t,
int64_t,
phi::dtype::float16) {}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(all_gather,
Custom,
ALL_LAYOUT,
phi::AllGatherKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int16_t,
int64_t,
phi::dtype::float16) {}
#endif
......@@ -20,6 +20,9 @@
#if defined(PADDLE_WITH_GLOO)
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#endif
namespace phi {
......@@ -47,6 +50,27 @@ void AllReduceKernel(const Context& dev_ctx,
#endif
}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
template <typename T>
void AllReduceKernel(const phi::CustomContext& dev_ctx,
const DenseTensor& x,
int reduce_type,
DenseTensor* out) {
out->Resize(x.dims());
dev_ctx.template Alloc<T>(out);
auto comm_ctx =
static_cast<distributed::XCCLCommContext*>(dev_ctx.GetCommContext());
PADDLE_ENFORCE_NE(
comm_ctx,
nullptr,
errors::Unavailable("XCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
comm_ctx->AllReduce(
out, x, phi::ccl::ToXCCLReduceOp(reduce_type), *dev_ctx.GetStream());
}
#endif
} // namespace phi
PD_REGISTER_KERNEL(all_reduce,
......@@ -61,3 +85,18 @@ PD_REGISTER_KERNEL(all_reduce,
uint8_t,
int64_t,
phi::dtype::float16) {}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(all_reduce,
Custom,
ALL_LAYOUT,
phi::AllReduceKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
#endif
......@@ -16,6 +16,9 @@
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#endif
namespace phi {
......@@ -26,6 +29,42 @@ void AllToAllKernel(const Context& dev_ctx UNUSED,
PADDLE_THROW(
errors::Unimplemented("Unimplemented cpu kernel for all_to_all."));
}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
template <typename T>
void AllToAllKernel(const phi::CustomContext& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
out->Resize(x.dims());
dev_ctx.template Alloc<T>(out);
auto comm_ctx =
static_cast<distributed::XCCLCommContext*>(dev_ctx.GetCommContext());
int nranks = comm_ctx->GetSize();
int rank = comm_ctx->GetRank();
int send_numel = x.numel() / nranks;
  std::vector<void*> sendbuf, recvbuf;
  // One count and one dtype entry per rank, each rank getting send_numel elements.
  std::vector<size_t> sendsize(nranks, send_numel);
  std::vector<phi::ccl::CCLDataType> sendtype(
      nranks, phi::ccl::ToCCLDataType(x.dtype()));
  for (auto i = 0; i < nranks; ++i) {
    sendbuf.push_back(const_cast<T*>(x.data<T>()) + i * send_numel);
    recvbuf.push_back(out->data<T>() + i * send_numel);
  }
phi::DeviceManager::CCLAllToAll(dev_ctx.GetPlace().GetDeviceType(),
const_cast<const void**>(sendbuf.data()),
sendsize.data(),
sendtype.data(),
recvbuf.data(),
sendsize.data(),
sendtype.data(),
rank,
nranks,
comm_ctx->GetXcclComm(),
*dev_ctx.GetStream());
}
#endif
} // namespace phi
......@@ -41,3 +80,17 @@ PD_REGISTER_KERNEL(all_to_all,
uint8_t,
int64_t,
phi::dtype::float16) {}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(all_to_all,
Custom,
ALL_LAYOUT,
phi::AllToAllKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
#endif
......@@ -20,6 +20,9 @@
#if defined(PADDLE_WITH_GLOO)
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#endif
namespace phi {
......@@ -51,6 +54,35 @@ void ReduceKernel(const Context& dev_ctx,
#endif
}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
template <typename T>
void ReduceKernel(const phi::CustomContext& dev_ctx,
const DenseTensor& x,
int root,
int reduce_type,
DenseTensor* out) {
PADDLE_ENFORCE_GT(
x.numel(),
0,
phi::errors::InvalidArgument("Tensor need be reduced must not empyt."));
out->Resize(x.dims());
dev_ctx.template Alloc<T>(out);
auto comm_ctx =
static_cast<distributed::XCCLCommContext*>(dev_ctx.GetCommContext());
PADDLE_ENFORCE_NE(
comm_ctx,
nullptr,
errors::Unavailable("XCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
comm_ctx->Reduce(out,
x,
phi::ccl::ToXCCLReduceOp(reduce_type),
root,
*dev_ctx.GetStream());
}
#endif
} // namespace phi
PD_REGISTER_KERNEL(reduce,
......@@ -65,3 +97,18 @@ PD_REGISTER_KERNEL(reduce,
uint8_t,
int64_t,
phi::dtype::float16) {}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(reduce,
Custom,
ALL_LAYOUT,
phi::ReduceKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
#endif
......@@ -334,3 +334,9 @@ def _init_parallel_env(backend):
core.CommContextManager.create_nccl_comm_context(
store, "0", rank, world_size
)
elif backend == "xccl":
dev_type = global_env.device_type
paddle.device.set_device(f"{dev_type}:{dev_id}")
core.CommContextManager.create_xccl_comm_context(
store, "0", rank, world_size, dev_type
)