Unverified  Commit 62397cd2  Authored by: R ronnywang  Committed by: GitHub

[CustomDevice] add comm context support (#56301)

Parent commit: ede8fd55
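The core change: collectives in ProcessGroupCustom now obtain an XCCL comm context (registered through phi::distributed::CommContextManager::CreateXCCLCommContext) and run their bodies as lambdas via RunFnInXCCLEnv, instead of driving phi::DeviceManager::CCL* calls through per-place CustomCCLCommManager caches. Below is a rough, self-contained sketch of this look-up-then-dispatch pattern in plain C++; FakeCommContext and FakeProcessGroup are illustrative stand-ins, not Paddle APIs.

#include <functional>
#include <iostream>
#include <map>
#include <string>

// Stand-in for the per-place communication context the real code fetches via
// GetCommContext(); only the shape of the interaction is modeled here.
struct FakeCommContext {
  std::string place;
  void AllReduce(const std::string& tensor) {
    std::cout << "all_reduce(" << tensor << ") on " << place << "\n";
  }
};

class FakeProcessGroup {
 public:
  // Mirrors the flow of RunFnInXCCLEnv: lazily create the per-place env cache
  // (CreateXCCLEnvCache in the real code), then run the op body against it.
  void RunFnInEnv(const std::string& place,
                  const std::function<void(FakeCommContext&)>& fn) {
    auto it = place_to_comm_ctx_.find(place);
    if (it == place_to_comm_ctx_.end()) {
      it = place_to_comm_ctx_.emplace(place, FakeCommContext{place}).first;
    }
    fn(it->second);  // the collective body (e.g. an all-reduce) runs here
  }

 private:
  std::map<std::string, FakeCommContext> place_to_comm_ctx_;
};

int main() {
  FakeProcessGroup pg;
  pg.RunFnInEnv("custom_device:0",
                [](FakeCommContext& ctx) { ctx.AllReduce("grad@rank0"); });
  return 0;
}

In the diff below, AllGather, AllReduce, Broadcast, Reduce, ReduceScatter, Scatter, Send and Recv all follow this shape, delegating the actual XCCL call to the comm context.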
@@ -18,7 +18,7 @@
namespace paddle {
namespace distributed {
phi::ccl::CCLReduceOp ToCustomCCLRedType(ReduceOp reduction) {
phi::ccl::CCLReduceOp ToXCCLRedType(ReduceOp reduction) {
  static const std::map<ReduceOp, phi::ccl::CCLReduceOp> red_type = {
      {ReduceOp::MIN, phi::ccl::CCLReduceOp::MIN},
      {ReduceOp::MAX, phi::ccl::CCLReduceOp::MAX},
@@ -34,14 +34,5 @@ phi::ccl::CCLReduceOp ToXCCLRedType(ReduceOp reduction) {
  return it->second;
}
std::string SerializeCustomCCLUniqueId(const phi::ccl::CCLRootId& ccl_id) {
const uint8_t* bytes = ccl_id.data();
std::ostringstream oss;
for (size_t i = 0; i < ccl_id.size(); ++i) {
oss << std::hex << static_cast<int>(bytes[i]);
}
return oss.str();
}
}  // namespace distributed
}  // namespace paddle
@@ -34,170 +34,7 @@
namespace paddle {
namespace distributed {
phi::ccl::CCLReduceOp ToXCCLRedType(ReduceOp reduction);
class CustomEventManager {
public:
CustomEventManager() = default;
~CustomEventManager() {
if (is_created_) {
event_->Destroy();
}
}
CustomEventManager(const CustomEventManager&) = delete;
CustomEventManager& operator=(const CustomEventManager&) = delete;
CustomEventManager(CustomEventManager&& other) {
std::swap(is_created_, other.is_created_);
std::swap(device_index_, other.device_index_);
std::swap(device_type_, other.device_type_);
std::swap(event_, other.event_);
}
CustomEventManager& operator=(CustomEventManager&& other) {
std::swap(is_created_, other.is_created_);
std::swap(device_index_, other.device_index_);
std::swap(device_type_, other.device_type_);
std::swap(event_, other.event_);
return *this;
}
bool IsCreated() const { return is_created_; }
int8_t DeviceId() const { return device_index_; }
std::string DeviceType() const { return device_type_; }
phi::event::event_t GetRawCustomEvent() const { return event_->raw_event(); }
phi::event::Event* GetCustomEvent() const { return event_.get(); }
void Record(const paddle::platform::CustomDeviceContext& ctx) {
auto place = ctx.GetPlace();
auto device_type = place.GetDeviceType();
auto device_index = place.GetDeviceId();
if (!is_created_) {
CreateEvent(place);
}
PADDLE_ENFORCE_EQ(device_index,
device_index_,
platform::errors::PreconditionNotMet(
"CustomDeviceContext's device %d does not match"
"Event's device %d",
device_index,
device_index_));
PADDLE_ENFORCE_EQ(device_type,
device_type_,
platform::errors::PreconditionNotMet(
"CustomDeviceContext's device %d does not match"
"Event's device type %d",
device_type,
device_type_));
phi::DeviceGuard guard(place);
phi::stream::Stream stream(place, ctx.stream());
event_->Record(&stream);
}
bool Query() const { return event_->Query(); }
void Block(const paddle::platform::CustomDeviceContext& ctx) const {
if (is_created_) {
auto place = ctx.GetPlace();
auto device_type = place.GetDeviceType();
auto device_index = place.GetDeviceId();
PADDLE_ENFORCE_EQ(device_index,
device_index_,
platform::errors::PreconditionNotMet(
"CustomDeviceContext's device %d does not match"
"Event's device %d",
device_index,
device_index_));
PADDLE_ENFORCE_EQ(device_type,
device_type_,
platform::errors::PreconditionNotMet(
"CustomDeviceContext's device %d does not match"
"Event's device type %d",
device_type,
device_type_));
phi::DeviceGuard guard(place);
phi::stream::Stream stream(place, ctx.stream());
stream.WaitEvent(event_.get());
}
}
private:
bool is_created_{false};
std::shared_ptr<phi::event::Event> event_{nullptr};
int8_t device_index_{0};
std::string device_type_;
private:
void CreateEvent(const platform::Place& place) {
device_index_ = place.GetDeviceId();
device_type_ = place.GetDeviceType();
event_.reset(new phi::event::Event);
event_->Init(place);
is_created_ = true;
}
};
class CustomCCLCommManager {
public:
CustomCCLCommManager(const std::string& device_type,
phi::ccl::CCLComm ccl_comm)
: device_type_(device_type), ccl_comm_(ccl_comm) {}
CustomCCLCommManager() : CustomCCLCommManager("", nullptr) {}
~CustomCCLCommManager() noexcept {
std::unique_lock<std::mutex> lock(mutex_);
if (phi::DeviceManager::HasDeviceType(device_type_) && ccl_comm_) {
phi::DeviceManager::CCLDestroyComm(device_type_, ccl_comm_);
}
}
static std::shared_ptr<CustomCCLCommManager> Create(
const std::string& device_type,
int num_ranks,
int rank,
phi::ccl::CCLRootId* comm_id,
phi::ccl::CCLComm* ccl_comm) {
auto custom_ccl_manager = std::make_shared<CustomCCLCommManager>();
phi::DeviceManager::CCLCommInitRank(
device_type, num_ranks, comm_id, rank, ccl_comm);
custom_ccl_manager->device_type_ = device_type;
custom_ccl_manager->ccl_id_ = comm_id;
custom_ccl_manager->rank_ = rank;
custom_ccl_manager->ccl_comm_ = *ccl_comm;
return custom_ccl_manager;
}
phi::ccl::CCLRootId* GetCustomCCLId() const {
std::unique_lock<std::mutex> lock(mutex_);
return ccl_id_;
}
phi::ccl::CCLComm GetCustomCCLComm() const {
std::unique_lock<std::mutex> lock(mutex_);
return ccl_comm_;
}
CustomCCLCommManager(const CustomCCLCommManager&) = delete;
CustomCCLCommManager& operator=(const CustomCCLCommManager&) = delete;
CustomCCLCommManager& operator=(CustomCCLCommManager&& other) = delete;
CustomCCLCommManager(CustomCCLCommManager&& other) {
std::unique_lock<std::mutex> lock(other.mutex_);
std::swap(ccl_comm_, other.ccl_comm_);
}
protected:
std::string device_type_;
phi::ccl::CCLComm ccl_comm_;
phi::ccl::CCLRootId* ccl_id_;
int rank_;
mutable std::mutex mutex_;
};
phi::ccl::CCLReduceOp ToCustomCCLRedType(ReduceOp reduction);
std::string SerializeCustomCCLUniqueId(const phi::ccl::CCLRootId& ccl_id);
}  // namespace distributed
}  // namespace paddle
@@ -17,99 +17,62 @@
#include "paddle/fluid/distributed/collective/common.h"
#include "paddle/fluid/distributed/collective/custom_ccl_tools.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/distributed/check/static_check.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"

DECLARE_bool(xccl_blocking_wait);
constexpr int64_t kWaitBlockTImeout = 10;
DECLARE_bool(use_stream_safe_cuda_allocator);

namespace paddle {
namespace distributed {
void SyncDefaultStream( ProcessGroupCustom::XCCLTask::XCCLTask(const Place& place,
const std::vector<Place>& places, int rank,
std::vector<CustomEventManager>& cclEvents, // NOLINT CommType comm_type,
std::vector<std::unique_ptr<CustomDeviceContext>>& dev_ctx) { // NOLINT bool sync_op,
for (size_t i = 0; i < places.size(); ++i) { bool use_calc_stream)
auto* default_ctx = static_cast<platform::CustomDeviceContext*>( : TaskStream(rank, comm_type, sync_op, use_calc_stream),
platform::DeviceContextPool::Instance().Get(places[i])); task_place_(place) {
cclEvents[i].Record(*default_ctx); comm_event_.Init(place);
cclEvents[i].Block(*dev_ctx[i]);
}
}
std::shared_ptr<ProcessGroupCustom::CustomTask> ProcessGroupCustom::CreateTask(
std::vector<Place> places,
int rank,
CommType comm_type,
const std::vector<phi::DenseTensor>& inputs) {
return std::make_shared<ProcessGroupCustom::CustomTask>(
places, rank, comm_type, inputs);
}
ProcessGroupCustom::CustomTask::CustomTask(
const std::vector<Place>& places,
int rank,
CommType CommType,
const std::vector<phi::DenseTensor>& inputs)
: Task(rank, inputs, CommType), places_(places) {
control_events_.resize(places.size());
cclComms_.resize(places.size());
} }
ProcessGroupCustom::CustomTask::~CustomTask() {} ProcessGroupCustom::XCCLTask::~XCCLTask() = default;
void ProcessGroupCustom::CustomTask::SetOutputs( bool ProcessGroupCustom::XCCLTask::IsCompleted() { return comm_event_.Query(); }
std::vector<phi::DenseTensor>& outputs) { // NOLINT
outputs_ = std::make_shared<std::vector<phi::DenseTensor>>(outputs);
}
void ProcessGroupCustom::CustomTask::SynchronizeStreams() { void ProcessGroupCustom::XCCLTask::UpdateWaitChain(
for (size_t i = 0; i < places_.size(); ++i) { const phi::DeviceContext& ctx) {
auto* default_ctx = static_cast<platform::CustomDeviceContext*>( comm_event_.Record(
platform::DeviceContextPool::Instance().Get(places_[i])); reinterpret_cast<const phi::CustomContext&>(ctx).GetStream().get());
phi::DeviceGuard guard(default_ctx->GetPlace());
control_events_[i].Block(*default_ctx);
}
} }
bool ProcessGroupCustom::CustomTask::IsCompleted() { bool ProcessGroupCustom::XCCLTask::Wait(std::chrono::milliseconds timeout) {
for (size_t i = 0; i < places_.size(); ++i) { // Warning here when use calc stream but also invoke waiting explicitly.
if (!control_events_[i].Query()) { if (UseCalcStream()) {
return false; VLOG(3) << "Warning: The communication is on calc stream, wait here is "
} "useless.";
return true;
} }
return true; const auto* calc_ctx = reinterpret_cast<phi::CustomContext*>(
} platform::DeviceContextPool::Instance().Get(task_place_));
calc_ctx->GetStream()->WaitEvent(&comm_event_);
bool ProcessGroupCustom::CustomTask::Wait(std::chrono::milliseconds timeout) { if (IsBlockCPUInWait()) {
SynchronizeStreams(); // If we use the work to do barrier, we should block cpu
while (!IsCompleted()) { phi::DeviceManager::SynchronizeDevice(task_place_);
std::this_thread::sleep_for(std::chrono::milliseconds(kWaitBlockTImeout));
} }
return true; return true;
} }
// Same as Wait // Same as Wait
void ProcessGroupCustom::CustomTask::Synchronize() { Wait(kWaitTimeout); } void ProcessGroupCustom::XCCLTask::Synchronize() { Wait(kWaitTimeout); }
void ProcessGroupCustom::CustomTask::UpdateWaitChain(
const phi::DeviceContext& ctx) {
PADDLE_ENFORCE_NE(
std::find(places_.cbegin(), places_.cend(), ctx.GetPlace()),
places_.cend(),
phi::errors::NotFound("Cannot find the device context in this task."));
auto index = std::find(places_.cbegin(), places_.cend(), ctx.GetPlace()) -
places_.cbegin();
control_events_[index].Record(
reinterpret_cast<const phi::CustomContext&>(ctx));
}
ProcessGroupCustom::ProcessGroupCustom(
    const std::shared_ptr<phi::distributed::Store>& store,
@@ -121,147 +84,45 @@ ProcessGroupCustom::ProcessGroupCustom(
      store_(store),
      device_type_(device_type) {}
void ProcessGroupCustom::BroadcastUniqueCustomID( void ProcessGroupCustom::GroupStart(const std::string& dev_type) {
std::vector<phi::ccl::CCLRootId>& ccl_ids) { // NOLINT phi::DeviceManager::CCLGroupStart(dev_type);
if (rank_ == 0) {
for (size_t i = 0; i < ccl_ids.size(); i++) {
auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(gid_) + "/" +
std::to_string(i);
store_->set(key, ccl_ids[i]);
}
} else {
for (size_t i = 0; i < ccl_ids.size(); i++) {
auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(gid_) + "/" +
std::to_string(i);
ccl_ids[i] = store_->get(key);
}
}
} }
// create CustomCCLManager cache for places_key void ProcessGroupCustom::GroupEnd(const std::string& dev_type) {
void ProcessGroupCustom::CreateCustomManagerCache( phi::DeviceManager::CCLGroupEnd(dev_type);
const std::string& places_key, const std::vector<Place>& places) {
PADDLE_ENFORCE_EQ(
places_key.empty(),
false,
platform::errors::PreconditionNotMet(
"Not able to create/get the CustomCCL Communicator since "
"the NPU place are not known"));
const std::string device_type = places.back().GetDeviceType();
std::vector<std::shared_ptr<CustomCCLCommManager>> ccl_comms;
ccl_comms.resize(places.size());
// using vector just for broadcast
std::vector<phi::ccl::CCLRootId> ccl_ids;
ccl_ids.resize(1);
auto& ccl_id = ccl_ids.front();
if (rank_ == 0) {
phi::DeviceManager::CCLGetUniqueId(device_type, &ccl_id);
}
BroadcastUniqueCustomID(ccl_ids);
VLOG(3) << "init custom ccl rank: " << rank_ << ", nranks: " << size_
<< ", place: " << places_key
<< ", custom ccl uniqueid: " << SerializeCustomCCLUniqueId(ccl_id);
std::vector<std::unique_ptr<CustomDeviceContext>> dev_ctx;
dev_ctx.resize(places.size());
for (size_t i = 0; i < places.size(); ++i) {
phi::DeviceGuard guard(places[i]);
ccl_comms[i] = CustomCCLCommManager::Create(
device_type, GetSize(), GetRank(), &ccl_id, new phi::ccl::CCLComm);
dev_ctx[i] = std::make_unique<CustomDeviceContext>(places[i]);
dev_ctx[i]->SetAllocator(
&(phi::DeviceContextPool::Instance().Get(places[i])->GetAllocator()));
dev_ctx[i]->SetHostAllocator(&(
phi::DeviceContextPool::Instance().Get(places[i])->GetHostAllocator()));
dev_ctx[i]->SetZeroAllocator(&(
phi::DeviceContextPool::Instance().Get(places[i])->GetZeroAllocator()));
dev_ctx[i]->SetHostZeroAllocator(&(phi::DeviceContextPool::Instance()
.Get(places[i])
->GetHostZeroAllocator()));
}
std::vector<CustomEventManager> events;
events.resize(places.size());
// These caches will be useful to process sync/wait/communicate
places_to_events_.emplace(places_key, std::move(events));
places_to_customcomm_.emplace(places_key, std::move(ccl_comms));
places_to_ctx_.emplace(places_key, std::move(dev_ctx));
} }
template <typename Fn> phi::DeviceContext* ProcessGroupCustom::GetDeviceContext(
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective( const Place& place) const {
std::vector<phi::DenseTensor>& inputs, return GetDeviceContext(place, /*use_calc_stream*/ false);
std::vector<phi::DenseTensor>& outputs,
Fn fn,
CommType op_type,
bool sync_op UNUSED,
bool use_calc_stream) {
const auto places = GetPlaceList(inputs);
const auto key = GetKeyFromPlaces(places);
{
std::lock_guard<std::mutex> lock(mutex_);
if (places_to_customcomm_.find(key) == places_to_customcomm_.end()) {
CreateCustomManagerCache(key, places);
}
}
auto& ccl_comms = places_to_customcomm_[key];
if (!use_calc_stream) {
SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
}
auto task = CreateTask(places, rank_, op_type, inputs);
task->SetOutputs(outputs);
for (size_t i = 0; i < inputs.size(); ++i) {
phi::DeviceGuard guard(places[i]);
const auto& ccl_stream =
use_calc_stream ? reinterpret_cast<phi::CustomContext*>(
phi::DeviceContextPool::Instance().Get(places[i]))
->stream()
: places_to_ctx_[key][i]->stream();
phi::stream::Stream stream(places[i], ccl_stream);
fn(inputs[i], outputs[i], ccl_comms[i]->GetCustomCCLComm(), stream);
}
if (!use_calc_stream) {
for (size_t i = 0; i < inputs.size(); ++i) {
phi::DeviceGuard guard(places[i]);
task->control_events_[i].Record(*places_to_ctx_[key][i]);
}
}
return task;
} }
void* XcclGetPointerByOffset(void* raw_pointer, phi::DeviceContext* ProcessGroupCustom::GetDeviceContext(
size_t offset, const Place& place, bool use_calc_stream) const {
phi::DataType type) { const std::string& key = GetKeyFromPlace(place);
if (type == phi::DataType::FLOAT32) { if (use_calc_stream) {
return reinterpret_cast<void*>(reinterpret_cast<float*>(raw_pointer) + const auto& iter = place_to_calc_ctx_.find(key);
offset); return iter->second;
} else if (type == phi::DataType::FLOAT64) {
return reinterpret_cast<void*>(reinterpret_cast<double*>(raw_pointer) +
offset);
} else if (type == phi::DataType::INT32) {
return reinterpret_cast<void*>(reinterpret_cast<int32_t*>(raw_pointer) +
offset);
} else if (type == phi::DataType::INT64) {
return reinterpret_cast<void*>(reinterpret_cast<int64_t*>(raw_pointer) +
offset);
} else if (type == phi::DataType::FLOAT16) {
return reinterpret_cast<void*>(reinterpret_cast<int16_t*>(raw_pointer) +
offset);
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( const auto& iter = place_to_comm_ctx_.find(key);
"This datatype in xccl is not supported.")); PADDLE_ENFORCE_NE(
iter,
place_to_comm_ctx_.end(),
phi::errors::NotFound(
"Cannot find the device context in this process group."));
return iter->second.get();
} }
return nullptr; }
phi::ccl::CCLComm ProcessGroupCustom::XCCLComm(const Place& place) const {
const std::string& key = GetKeyFromPlace(place);
const auto& iter = place_to_comm_ctx_.find(key);
PADDLE_ENFORCE_NE(
iter,
place_to_comm_ctx_.end(),
phi::errors::NotFound(
"Cannot find the XCCL communicator in this process group."));
return iter->second->xccl_comm();
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
@@ -269,212 +130,110 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
    const phi::DenseTensor& in_tensor,
    int64_t offset,
    int64_t numel,
    bool sync_op,  // for compatibility, no use now
    bool sync_op,
    bool use_calc_stream) {
  // numel > 0 indicates the tensor need to be sliced
  const phi::DenseTensor& in_tensor_maybe_partial =
numel > 0 numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
? paddle::distributed::GetPartialTensor(in_tensor, offset, numel) return RunFnInXCCLEnv(
: in_tensor; [&](const phi::stream::Stream& stream) {
phi::distributed::CommStaticCheck::GatherLikeShape( auto comm_context = this->GetCommContext();
*out_tensor, comm_context->AllGather(out_tensor, in_tensor_maybe_partial, stream);
in_tensor_maybe_partial,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
std::vector<phi::DenseTensor> in_wrapper{in_tensor_maybe_partial};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return Collective(
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
return phi::DeviceManager::CCLAllGather(
device_type_,
input.data(),
output.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
comm,
stream);
}, },
in_tensor_maybe_partial,
CommType::ALLGATHER, CommType::ALLGATHER,
sync_op, sync_op,
use_calc_stream); use_calc_stream);
} }
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All inputs should be in CustomPlace(%s).", device_type_));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All outputs should be in CustomPlace(%s).", device_type_));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
return phi::DeviceManager::CCLAllGather(
device_type_,
input.data(),
output.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
comm,
stream);
},
CommType::ALLGATHER,
false,
false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const AllreduceOptions& opts,
    bool sync_op,  // for compatibility, no use now
    bool sync_op,
    bool use_calc_stream) {
std::vector<phi::DenseTensor> in_wrapper{in_tensor}; return RunFnInXCCLEnv(
std::vector<phi::DenseTensor> out_wrapper{*out_tensor}; [&](const phi::stream::Stream& stream) {
PADDLE_ENFORCE_EQ( auto comm_context = this->GetCommContext();
CheckTensorsInCustomPlace(in_wrapper, device_type_), comm_context->AllReduce(
true, out_tensor,
platform::errors::InvalidArgument( in_tensor,
"All inputs should be in CustomPlace(%s).", device_type_)); paddle::distributed::ToXCCLRedType(opts.reduce_op),
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_wrapper, device_type_),
true,
platform::errors::InvalidArgument(
"All outputs should be in CustomPlace(%s).", device_type_));
return Collective(
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
return phi::DeviceManager::CCLAllReduce(
device_type_,
input.data(),
output.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
ToCustomCCLRedType(opts.reduce_op),
comm,
stream); stream);
}, },
in_tensor,
CommType::ALLREDUCE, CommType::ALLREDUCE,
sync_op, sync_op,
use_calc_stream); use_calc_stream);
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce( std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const AllreduceOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All inputs should be in CustomPlace(%s).", device_type_));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All outputs should be in CustomPlace(%s).", device_type_));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
return phi::DeviceManager::CCLAllReduce(
device_type_,
input.data(),
output.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
ToCustomCCLRedType(opts.reduce_op),
comm,
stream);
},
CommType::ALLREDUCE,
false,
false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
phi::DenseTensor* out_tensor, phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts, const std::vector<int64_t>& out_size_each_rank,
bool sync_op, // for compatibility, no use now const std::vector<int64_t>& in_size_each_rank,
bool sync_op,
bool use_calc_stream) { bool use_calc_stream) {
std::vector<phi::DenseTensor> in_wrapper{in_tensor}; const phi::DDim& out_dim = out_tensor->dims();
std::vector<phi::DenseTensor> out_wrapper{*out_tensor}; const phi::DDim& in_dim = in_tensor.dims();
PADDLE_ENFORCE_EQ( CheckSizeOnEachRank(out_dim, out_size_each_rank, size_);
CheckTensorsInCustomPlace(in_wrapper, device_type_), CheckSizeOnEachRank(in_dim, in_size_each_rank, size_);
true,
platform::errors::InvalidArgument( // NOTE: Since `all_to_all` needs other processes' participation, it cannot
"All inputs should be in CustomPlace(%s).", device_type_)); // simply be covered by static checks. Factors are set to 0 here to skip the
PADDLE_ENFORCE_EQ( // shape check. Its shape check will be done by dynamic checks with
CheckTensorsInCustomPlace(out_wrapper, device_type_), // FLAGS_enable_xccl_dynamic_check.
true, return RunFnInXCCLEnv(
platform::errors::InvalidArgument( [&](const phi::stream::Stream& stream) {
"All outputs should be in CustomPlace(%s).", device_type_)); auto comm_context = this->GetCommContext();
return Collective(
in_wrapper, int64_t in_row_size = in_tensor.numel() / in_dim[0],
out_wrapper, out_row_size = out_tensor->numel() / out_dim[0];
[&](phi::DenseTensor& input, int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0;
phi::DenseTensor& output, phi::DenseTensor input_partial, output_partial;
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) { std::vector<void*> send_buf, recv_buf;
int root = opts.source_rank * in_wrapper.size() + opts.source_root; std::vector<size_t> send_count, recv_count;
if (rank_ == root) { std::vector<phi::ccl::CCLDataType> send_dtype, recv_dtype;
return phi::DeviceManager::CCLBroadcast( for (auto i = 0; i < size_; i++) {
device_type_, in_numel = in_size_each_rank[i] * in_row_size;
input.data(), input_partial = GetPartialTensor(in_tensor, in_offset, in_numel);
input.numel(), out_numel = out_size_each_rank[i] * out_row_size;
phi::ccl::ToCCLDataType(input.dtype()), output_partial = GetPartialTensor(*out_tensor, out_offset, out_numel);
root, in_offset += in_numel;
comm, out_offset += out_numel;
stream); send_buf.push_back(input_partial.data());
} else { recv_buf.push_back(output_partial.data());
return phi::DeviceManager::CCLBroadcast( send_count.push_back(in_numel);
device_type_, recv_count.push_back(out_numel);
output.data(), send_dtype.push_back(phi::ccl::ToCCLDataType(input_partial.dtype()));
output.numel(), recv_dtype.push_back(phi::ccl::ToCCLDataType(output_partial.dtype()));
phi::ccl::ToCCLDataType(output.dtype()),
root,
comm,
stream);
} }
phi::DeviceManager::CCLAllToAll(
device_type_,
const_cast<const void**>(send_buf.data()),
send_count.data(),
send_dtype.data(),
recv_buf.data(),
recv_count.data(),
recv_dtype.data(),
rank_,
size_,
comm_context->GetXcclComm(),
stream);
}, },
CommType::BROADCAST, in_tensor,
CommType::ALLTOALL,
sync_op, sync_op,
use_calc_stream); use_calc_stream);
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
    const BarrierOptions& opts) {
  // Only support single card single process
  PADDLE_ENFORCE_GE(opts.device_id,
                    0,
                    platform::errors::PreconditionNotMet(
                    phi::errors::PreconditionNotMet(
                        "The barrier device id must greater or equal than 0."));
  platform::CustomPlace place(device_type_, opts.device_id);
  auto allocator = std::unique_ptr<phi::Allocator>(
@@ -482,111 +241,176 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
  phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1});
  phi::DenseTensor barrier_tensor{allocator.get(), meta};
auto task = ProcessGroupCustom::AllReduce(&barrier_tensor, auto task = AllReduce(&barrier_tensor,
barrier_tensor, barrier_tensor,
{}, {},
/*sync_op*/ true, /*sync_op*/ true,
false); /*use_calc_stream*/ false);
auto xccl_task = dynamic_cast<ProcessGroupCustom::CustomTask*>(task.get()); auto xccl_task = dynamic_cast<XCCLTask*>(task.get());
xccl_task->barrierTensors_ = {barrier_tensor}; xccl_task->SetBlockCPUInWait();
return task; return task;
} }
phi::DeviceContext* ProcessGroupCustom::GetDeviceContext(
const Place& place) const {
const std::string key = GetKeyFromPlace(place);
const auto& iter = places_to_ctx_.find(key);
PADDLE_ENFORCE_NE(
iter,
places_to_ctx_.end(),
platform::errors::NotFound(
"Cannot find the device context in this process group."));
return iter->second[0].get();
}
phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm(const Place& place) const {
std::vector<Place> places = {place};
const auto& iter = places_to_customcomm_.find(GetKeyFromPlaces(places));
PADDLE_ENFORCE_NE(iter,
places_to_customcomm_.end(),
platform::errors::InvalidArgument(
"Cannot find nccl comm in process group."));
return iter->second[0]->GetCustomCCLComm();
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast( std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT phi::DenseTensor* out_tensor,
std::vector<phi::DenseTensor>& out_tensors, // NOLINT const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts) { const BroadcastOptions& opts,
PADDLE_ENFORCE_EQ( bool sync_op,
CheckTensorsInCustomPlace(in_tensors, device_type_), bool use_calc_stream) {
true, return RunFnInXCCLEnv(
platform::errors::InvalidArgument( [&](const phi::stream::Stream& stream) {
"All inputs should be in CustomPlace(%s).", device_type_)); int root = opts.source_rank + opts.source_root;
PADDLE_ENFORCE_EQ( auto comm_context = this->GetCommContext();
CheckTensorsInCustomPlace(out_tensors, device_type_), comm_context->Broadcast(out_tensor, in_tensor, root, stream);
true,
platform::errors::InvalidArgument(
"All outputs should be in CustomPlace(%s).", device_type_));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
int root = opts.source_rank * in_tensors.size() + opts.source_root;
if (rank_ == root) {
return phi::DeviceManager::CCLBroadcast(
device_type_,
input.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
root,
comm,
stream);
} else {
return phi::DeviceManager::CCLBroadcast(
device_type_,
output.data(),
output.numel(),
phi::ccl::ToCCLDataType(output.dtype()),
root,
comm,
stream);
}
}, },
in_tensor,
CommType::BROADCAST, CommType::BROADCAST,
false, sync_op,
false); use_calc_stream);
} }
void CheckTensorsInDifferentCustomDevices( std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Reduce(
const std::vector<phi::DenseTensor>& tensors, const size_t num_devices) { phi::DenseTensor* out_tensor,
PADDLE_ENFORCE_EQ( const phi::DenseTensor& in_tensor,
tensors.size() == 0, const ReduceOptions& opts,
false, bool sync_op,
phi::errors::InvalidArgument("Tensor list must be nonempty.")); bool use_calc_stream) {
PADDLE_ENFORCE_LE( return RunFnInXCCLEnv(
tensors.size(), [&](const phi::stream::Stream& stream) {
num_devices, auto comm_context = this->GetCommContext();
phi::errors::InvalidArgument("Tensor list mustn't be larger than the " comm_context->Reduce(out_tensor,
"number of available CustomDevice.")); in_tensor,
paddle::distributed::ToXCCLRedType(opts.reduce_op),
opts.root_rank,
stream);
},
in_tensor,
CommType::REDUCE,
sync_op,
use_calc_stream);
}
std::set<Place> used_devices; std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
return RunFnInXCCLEnv(
[&](const phi::stream::Stream& stream) {
auto comm_context = this->GetCommContext();
comm_context->ReduceScatter(
out_tensor,
in_tensor,
paddle::distributed::ToXCCLRedType(opts.reduce_op),
stream);
},
in_tensor,
CommType::REDUCE_SCATTER,
sync_op,
use_calc_stream);
}
for (const auto& t : tensors) { std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
PADDLE_ENFORCE_EQ(platform::is_custom_place(t.place()), phi::DenseTensor* out_tensor,
true, const phi::DenseTensor& in_tensor,
phi::errors::InvalidArgument( const ScatterOptions& opts,
"Tensors must be CustomDevice and dense tensor.")); bool sync_op,
bool use_calc_stream) {
phi::distributed::CommStaticCheck::ScatterLikeShape(
*out_tensor,
in_tensor,
/*dst_rank*/ opts.root_rank,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
return RunFnInXCCLEnv(
[&](const phi::stream::Stream& stream) {
auto comm_context = this->GetCommContext();
const auto inserted = used_devices.insert(t.place()).second; int64_t numel = in_tensor.numel() / size_;
PADDLE_ENFORCE_EQ(inserted, if (rank_ == opts.root_rank) {
true, int64_t offset = 0;
phi::errors::InvalidArgument( phi::DenseTensor partial_tensor;
"Tensors must be on distinct custom devices.")); for (auto i = 0; i < size_; i++) {
partial_tensor = GetPartialTensor(in_tensor, offset, numel);
if (i != rank_) {
comm_context->Send(partial_tensor, numel, i, stream);
} else {
phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace())
->MemoryCopyD2D(out_tensor->data(),
partial_tensor.data(),
numel * phi::SizeOf(partial_tensor.dtype()),
&stream);
}
offset += numel;
}
} else {
comm_context->Recv(out_tensor, numel, opts.root_rank, stream);
}
},
in_tensor,
CommType::SCATTER,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Gather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const GatherOptions& opts,
bool sync_op,
bool use_calc_stream) {
std::vector<phi::DenseTensor> partial_tensors;
if (rank_ == opts.root_rank) {
partial_tensors.reserve(size_);
size_t offset = 0;
size_t numel = out_tensor->numel() / size_;
for (auto i = 0; i < size_; i++) {
partial_tensors.push_back(GetPartialTensor(*out_tensor, offset, numel));
offset += numel;
}
} }
return Gather(&partial_tensors, in_tensor, opts, sync_op, use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Gather(
std::vector<phi::DenseTensor>* gather_tensors_ptr,
const phi::DenseTensor& in_tensor,
const GatherOptions& opts,
bool sync_op,
bool use_calc_stream) {
auto& gather_tensors = *gather_tensors_ptr;
PADDLE_ENFORCE_GT(size_,
opts.root_rank,
phi::errors::InvalidArgument(
"root world size [%d] is less than root rank [%d]",
size_,
opts.root_rank));
auto gather_func = [&](const phi::stream::Stream& stream) {
auto comm_context = this->GetCommContext();
// root receive from all devices
if (rank_ == opts.root_rank) {
for (auto i = 0; i < size_; i++) {
auto& gather_tensor = gather_tensors[i];
if (i != rank_) {
comm_context->Recv(&gather_tensor, gather_tensor.numel(), i, stream);
} else {
phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace())
->MemoryCopyD2D(
gather_tensor.data(),
in_tensor.data(),
in_tensor.numel() * phi::SizeOf(in_tensor.dtype()),
&stream);
}
}
} else {
// send to root
comm_context->Send(in_tensor, in_tensor.numel(), opts.root_rank, stream);
}
};
return RunFnInXCCLEnv(
gather_func, in_tensor, CommType::GATHER, sync_op, use_calc_stream);
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Recv(
@@ -602,53 +426,18 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Recv(
    partial_tensor = GetPartialTensor(*tensor, offset, numel);
    tensor = &partial_tensor;
  }
phi::distributed::CommStaticCheck::CheckShape(
*tensor, rank_, size_, phi::AllocationType::CUSTOM); return RunFnInXCCLEnv(
std::vector<phi::DenseTensor> in_wrapper{*tensor}; [&](const phi::stream::Stream& stream) {
std::vector<phi::DenseTensor> out_wrapper{*tensor}; auto comm_context = this->GetCommContext();
return Collective( comm_context->Recv(tensor, tensor->numel(), src_rank, stream);
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
phi::DeviceManager::CCLRecv(device_type_,
output.data(),
output.numel(),
phi::ccl::ToCCLDataType(output.dtype()),
src_rank,
comm,
stream);
}, },
*tensor,
CommType::RECV, CommType::RECV,
sync_op, sync_op,
use_calc_stream); use_calc_stream);
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank) {
CheckTensorsInDifferentCustomDevices(tensors, static_cast<size_t>(GetSize()));
return Collective(
tensors,
tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
phi::DeviceManager::CCLRecv(device_type_,
output.data(),
output.numel(),
phi::ccl::ToCCLDataType(output.dtype()),
src_rank,
comm,
stream);
},
CommType::RECV,
false,
false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
    const phi::DenseTensor& tensor,
    int dst_rank,
@@ -659,192 +448,459 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
  // numel > 0 indicates the tensor need to be sliced
  const phi::DenseTensor& tensor_maybe_partial =
      numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor;
phi::distributed::CommStaticCheck::CheckShape(
tensor_maybe_partial, rank_, size_, phi::AllocationType::CUSTOM); return RunFnInXCCLEnv(
std::vector<phi::DenseTensor> in_wrapper{tensor_maybe_partial}; [&](const phi::stream::Stream& stream) {
std::vector<phi::DenseTensor> out_wrapper{tensor_maybe_partial}; auto comm_context = this->GetCommContext();
return Collective( comm_context->Send(tensor_maybe_partial,
in_wrapper, tensor_maybe_partial.numel(),
out_wrapper, dst_rank,
[&](phi::DenseTensor& input, stream);
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
phi::DeviceManager::CCLSend(device_type_,
input.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
dst_rank,
comm,
stream);
}, },
tensor_maybe_partial,
CommType::SEND, CommType::SEND,
sync_op, sync_op,
use_calc_stream); use_calc_stream);
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send( std::shared_ptr<ProcessGroupCustom::XCCLTask> ProcessGroupCustom::CreateTask(
std::vector<phi::DenseTensor>& tensors, int dst_rank) { const Place& place,
CheckTensorsInDifferentCustomDevices(tensors, static_cast<size_t>(GetSize())); int rank,
return Collective( CommType comm_type,
tensors, bool is_sync,
tensors, bool use_calc_stream) {
[&](phi::DenseTensor& input, return std::make_shared<ProcessGroupCustom::XCCLTask>(
phi::DenseTensor& output, place, rank, comm_type, is_sync, use_calc_stream);
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
phi::DeviceManager::CCLSend(device_type_,
input.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
dst_rank,
comm,
stream);
},
CommType::SEND,
false,
false);
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Reduce( void ProcessGroupCustom::BroadcastUniqueXCCLID(
phi::DenseTensor* out_tensor, phi::ccl::CCLRootId* xccl_root_id) {
const phi::DenseTensor& in_tensor, const std::string key =
const ReduceOptions& opts, "ProcessGroupCustom/xccl_ids/" + std::to_string(gid_) + "/0";
if (rank_ == 0) {
store_->set(key, *xccl_root_id);
} else {
*xccl_root_id = store_->get(key);
}
}
void ProcessGroupCustom::CreateXCCLEnvCache(const Place& place,
const std::string& place_key) {
if (!place_to_comm_ctx_.empty()) {
VLOG(3) << "Warning: Tensors from multiple devices are not supported yet.";
}
VLOG(3) << "init xccl rank: " << rank_ << ", nranks: " << size_
<< ", place: " << place_key;
phi::distributed::CommContextManager::CreateXCCLCommContext(
store_, std::to_string(gid_), place.GetDeviceType(), rank_, size_);
auto* calc_ctx = static_cast<phi::CustomContext*>(
platform::DeviceContextPool::Instance().Get(place));
auto comm_ctx = std::make_unique<phi::CustomContext>(place);
comm_ctx->SetAllocator(
&(phi::DeviceContextPool::Instance().Get(place)->GetAllocator()));
comm_ctx->SetHostAllocator(
&(phi::DeviceContextPool::Instance().Get(place)->GetHostAllocator()));
comm_ctx->SetZeroAllocator(
&(phi::DeviceContextPool::Instance().Get(place)->GetZeroAllocator()));
comm_ctx->SetHostZeroAllocator(
&(phi::DeviceContextPool::Instance().Get(place)->GetHostZeroAllocator()));
auto xccl_comm_ctx = this->GetCommContext();
comm_ctx->set_xccl_comm(xccl_comm_ctx->GetXcclComm());
auto xccl_event = std::make_unique<phi::event::Event>();
xccl_event->Init(place);
place_to_calc_event_.emplace(place_key, std::move(xccl_event));
place_to_calc_ctx_.emplace(place_key, calc_ctx);
place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx));
// TODO(sunyilun): for compatibility, will be removed later
std::vector<phi::CustomContext*> comm_ctx_wrapper{
place_to_comm_ctx_[place_key].get()};
places_to_ctx_.emplace(place_key, comm_ctx_wrapper);
}
void ProcessGroupCustom::SyncCalcStream(const Place& place) {
const std::string& key = GetKeyFromPlace(place);
auto& calc_event = place_to_calc_event_.at(key);
const auto* calc_ctx = place_to_calc_ctx_.at(key);
const auto* comm_ctx = place_to_comm_ctx_.at(key).get();
calc_event->Record(calc_ctx->GetStream().get());
comm_ctx->GetStream()->WaitEvent(calc_event.get());
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::RunFnInXCCLEnv(
std::function<void(const phi::stream::Stream&)> fn,
const phi::DenseTensor& tensor,
CommType comm_type,
bool sync_op, bool sync_op,
bool use_calc_stream) { bool use_calc_stream) {
phi::distributed::CommStaticCheck::SameShape(*out_tensor, const auto& place = tensor.place();
in_tensor, const auto& key = GetKeyFromPlace(place);
/*dst_rank*/ opts.root_rank,
/*cur_rank*/ rank_, phi::DeviceGuard guard(place);
size_,
phi::AllocationType::CUSTOM); if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
std::vector<phi::DenseTensor> in_wrapper{in_tensor}; CreateXCCLEnvCache(place, key);
std::vector<phi::DenseTensor> out_wrapper{*out_tensor}; }
if (!use_calc_stream) {
SyncCalcStream(place);
}
auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream);
const auto* calc_ctx = place_to_calc_ctx_.at(key);
const auto& comm_ctx = place_to_comm_ctx_.at(key);
auto& xccl_stream =
use_calc_stream ? *calc_ctx->GetStream() : *comm_ctx->GetStream();
fn(xccl_stream);
if (!use_calc_stream) {
if (FLAGS_use_stream_safe_cuda_allocator) {
memory::RecordStream(tensor.Holder(), xccl_stream.raw_stream());
}
task->UpdateWaitChain(*comm_ctx);
}
return task;
}
// TODO(sunyilun): methods below will be removed later
void SyncDefaultStream(const std::vector<Place>& places,
phi::event::Event& xccl_event, // NOLINT
std::vector<phi::CustomContext*>& dev_ctx) { // NOLINT
for (size_t i = 0; i < places.size(); ++i) {
auto* default_ctx = static_cast<phi::CustomContext*>(
platform::DeviceContextPool::Instance().Get(places[i]));
xccl_event.Record(default_ctx->GetStream().get());
dev_ctx[i]->GetStream()->WaitEvent(&xccl_event);
}
}
std::shared_ptr<ProcessGroupCustom::XCCLTask> ProcessGroupCustom::CreateTask(
std::vector<Place> places,
int rank,
CommType comm_type,
const std::vector<phi::DenseTensor>& inputs) {
return std::make_shared<ProcessGroupCustom::XCCLTask>(
places, rank, comm_type, inputs);
}
ProcessGroupCustom::XCCLTask::XCCLTask(
const std::vector<Place>& places,
int rank,
CommType CommType,
const std::vector<phi::DenseTensor>& inputs)
: TaskStream(rank, inputs, CommType), task_place_(places[0]) {
comm_event_.Init(places[0]);
}
// create XCCLManager cache for places_key
void ProcessGroupCustom::CreateXCCLManagerCache(
const std::string& places_key, const std::vector<Place>& places) {
PADDLE_ENFORCE_EQ(places_key.empty(),
false,
phi::errors::PreconditionNotMet(
"Not able to create/get the XCCL Communicator since "
"the CustomPlace are not known"));
phi::ccl::CCLRootId xccl_root_id;
if (rank_ == 0) {
phi::DeviceManager::CCLGetUniqueId(device_type_, &xccl_root_id);
}
BroadcastUniqueXCCLID(&xccl_root_id);
VLOG(3) << "init xccl rank: " << rank_ << ", nranks: " << size_
<< ", place: " << places_key << ", xccl uniqueid: "
<< phi::ccl::SerializeXCCLUniqueId(xccl_root_id);
std::vector<std::unique_ptr<phi::CustomContext>> dev_ctx;
dev_ctx.resize(places.size());
std::vector<phi::CustomContext*> dev_ctx_raw;
dev_ctx_raw.resize(places.size());
GroupStart(device_type_);
for (size_t i = 0; i < places.size(); ++i) {
phi::DeviceGuard guard(places[i]);
dev_ctx[i] = std::make_unique<phi::CustomContext>(places[i]);
dev_ctx[i]->SetAllocator(
&(phi::DeviceContextPool::Instance().Get(places[i])->GetAllocator()));
dev_ctx[i]->SetHostAllocator(&(
phi::DeviceContextPool::Instance().Get(places[i])->GetHostAllocator()));
dev_ctx[i]->SetZeroAllocator(&(
phi::DeviceContextPool::Instance().Get(places[i])->GetZeroAllocator()));
dev_ctx[i]->SetHostZeroAllocator(&(phi::DeviceContextPool::Instance()
.Get(places[i])
->GetHostZeroAllocator()));
phi::ccl::CCLComm xccl_comm;
phi::DeviceManager::CCLCommInitRank(
device_type_, GetSize(), &xccl_root_id, GetRank(), &xccl_comm);
dev_ctx[i]->set_xccl_comm(xccl_comm);
dev_ctx_raw[i] = dev_ctx[i].get();
}
GroupEnd(device_type_);
// TODO(sunyilun): for compatibility, will be removed later
auto xccl_event = std::make_unique<phi::event::Event>();
xccl_event->Init(places[0]);
place_to_calc_event_.emplace(places_key, std::move(xccl_event));
place_to_calc_ctx_.emplace(
places_key,
static_cast<phi::CustomContext*>(
platform::DeviceContextPool::Instance().Get(places[0])));
place_to_comm_ctx_.emplace(places_key, std::move(dev_ctx[0]));
// These caches will be useful to process sync/wait/communicate
places_to_ctx_.emplace(places_key, std::move(dev_ctx_raw));
}
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Collective(
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs,
Fn fn,
CommType op_type) {
const auto places = GetPlaceList(inputs);
const auto key = GetKeyFromPlaces(places);
{
std::lock_guard<std::mutex> lock(mutex_);
if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
CreateXCCLManagerCache(key, places);
}
}
SyncDefaultStream(
places, *place_to_calc_event_.at(key), places_to_ctx_.at(key));
auto task = CreateTask(places, rank_, op_type, inputs);
// construct uninitialize guard for device
{
GroupStart(device_type_);
for (size_t i = 0; i < inputs.size(); ++i) {
phi::DeviceGuard guard(places[i]);
const auto& xccl_stream = *places_to_ctx_.at(key)[i]->GetStream();
fn(inputs[i],
outputs[i],
places_to_ctx_.at(key)[i]->xccl_comm(),
xccl_stream);
}
GroupEnd(device_type_);
}
if (FLAGS_use_stream_safe_cuda_allocator) {
for (size_t i = 0; i < inputs.size(); ++i) {
phi::DeviceGuard guard(places[i]);
memory::RecordStream(inputs[i].Holder(),
places_to_ctx_.at(key)[i]->stream());
}
}
for (size_t i = 0; i < inputs.size(); ++i) {
phi::DeviceGuard guard(places[i]);
task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
}
return task;
}
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::PointToPoint(
std::vector<phi::DenseTensor>& tensors,
Fn fn,
int dst_rank,
CommType op_type) {
const auto places = GetPlaceList(tensors);
const auto key = GetKeyFromPlaces(places);
{
std::lock_guard<std::mutex> lock(mutex_);
if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
CreateXCCLManagerCache(key, places);
}
}
SyncDefaultStream(
places, *place_to_calc_event_.at(key), places_to_ctx_.at(key));
auto task = CreateTask(places, rank_, op_type, tensors);
// construct uninitialize guard for device
{
GroupStart(device_type_);
for (size_t i = 0; i < tensors.size(); ++i) {
phi::DeviceGuard guard(places[i]);
const auto& xccl_stream = *places_to_ctx_.at(key)[i]->GetStream();
fn(tensors[i],
places_to_ctx_.at(key)[i]->xccl_comm(),
xccl_stream,
dst_rank);
}
GroupEnd(device_type_);
}
if (FLAGS_use_stream_safe_cuda_allocator) {
for (size_t i = 0; i < tensors.size(); ++i) {
phi::DeviceGuard guard(places[i]);
memory::RecordStream(tensors[i].Holder(),
places_to_ctx_.at(key)[i]->stream());
}
}
for (size_t i = 0; i < tensors.size(); ++i) {
phi::DeviceGuard guard(places[i]);
task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
}
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const AllreduceOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
phi::errors::InvalidArgument("All inputs should be in CustomPlace."));
return Collective( return Collective(
in_wrapper, in_tensors,
out_wrapper, out_tensors,
[&](phi::DenseTensor& input, [&](const phi::DenseTensor& input,
phi::DenseTensor& output, phi::DenseTensor& output,
phi::ccl::CCLComm comm, const phi::ccl::CCLComm& comm,
const phi::stream::Stream& stream) { const phi::stream::Stream& stream) {
phi::DeviceManager::CCLReduce(device_type_, auto comm_context = this->GetCommContext();
input.data(), comm_context->AllReduce(
output.data(), &output,
input.numel(), input,
phi::ccl::ToCCLDataType(input.dtype()), paddle::distributed::ToXCCLRedType(opts.reduce_op),
ToCustomCCLRedType(opts.reduce_op), stream);
opts.root_rank,
comm,
stream);
}, },
CommType::REDUCE, CommType::ALLREDUCE);
sync_op,
use_calc_stream);
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Reduce( std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors, std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts) { const BroadcastOptions& opts) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_), CheckTensorsInCustomPlace(in_tensors, device_type_),
true, true,
phi::errors::InvalidArgument("All inputs should be in CustomPlace.")); phi::errors::InvalidArgument("All inputs should be in CustomPlace."));
return Collective( return Collective(
in_tensors, in_tensors,
out_tensors, out_tensors,
[&](phi::DenseTensor& input, [&](phi::DenseTensor& input,
phi::DenseTensor& output, phi::DenseTensor& output,
phi::ccl::CCLComm comm, const phi::ccl::CCLComm& comm,
const phi::stream::Stream& stream) { const phi::stream::Stream& stream) {
phi::DeviceManager::CCLReduce(device_type_, const auto root =
input.data(), opts.source_rank * in_tensors.size() + opts.source_root;
output.data(), auto comm_context = this->GetCommContext();
input.numel(), comm_context->Broadcast(&output, input, root, stream);
phi::ccl::ToCCLDataType(input.dtype()),
ToCustomCCLRedType(opts.reduce_op),
opts.root_rank,
comm,
stream);
}, },
CommType::REDUCE, CommType::BROADCAST);
}
void CheckTensorsInDifferentDevices(
const std::vector<phi::DenseTensor>& tensors, const size_t num_devices) {
PADDLE_ENFORCE_EQ(
tensors.empty(),
false, false,
false); phi::errors::InvalidArgument("Tensor list must be nonempty."));
PADDLE_ENFORCE_LE(
tensors.size(),
num_devices,
phi::errors::InvalidArgument("Tensor list mustn't be larger than the "
"number of available CustomDevices."));
std::set<Place> used_devices;
for (const auto& t : tensors) {
PADDLE_ENFORCE_EQ(platform::is_custom_place(t.place()),
true,
phi::errors::InvalidArgument(
"Tensors must be CustomDevice and dense tensor."));
const auto inserted = used_devices.insert(t.place()).second;
PADDLE_ENFORCE_EQ(inserted,
true,
phi::errors::InvalidArgument(
"Tensors must be on distinct CustomDevice devices."));
}
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll( std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
phi::DenseTensor* out_tensor, std::vector<phi::DenseTensor>& tensors, int dst_rank) {
const phi::DenseTensor& in_tensor, CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op,
bool use_calc_stream) {
const phi::DDim& out_dim = out_tensor->dims();
const phi::DDim& in_dim = in_tensor.dims();
CheckSizeOnEachRank(out_dim, out_size_each_rank, size_);
CheckSizeOnEachRank(in_dim, in_size_each_rank, size_);
// NOTE: Since `all_to_all` needs other processes' participation, it cannot auto task = PointToPoint(
// simply be covered by static checks. Factors are set to 0 here to skip the tensors,
// shape check. Its shape check will be done by dynamic checks with
// FLAGS_enable_nccl_dynamic_check.
phi::distributed::CommStaticCheck::CheckShape(*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
/*out_size_factor*/ 0,
/*in_size_factor*/ 0,
phi::AllocationType::CUSTOM);
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return Collective(
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input, [&](phi::DenseTensor& input,
phi::DenseTensor& output, const phi::ccl::CCLComm& comm,
phi::ccl::CCLComm comm, const phi::stream::Stream& stream,
const phi::stream::Stream& stream) { int dst_rank) {
int64_t in_row_size = in_tensor.numel() / in_dim[0], auto comm_context = this->GetCommContext();
out_row_size = out_tensor->numel() / out_dim[0]; comm_context->Send(input, input.numel(), dst_rank, stream);
int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0; },
phi::DenseTensor input_partial, output_partial; dst_rank,
std::vector<void*> send_buf, recv_buf; CommType::SEND);
std::vector<size_t> send_count, recv_count; return task;
std::vector<phi::ccl::CCLDataType> send_dtype, recv_dtype; }
for (auto i = 0; i < size_; i++) { std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Recv(
in_numel = in_size_each_rank[i] * in_row_size; std::vector<phi::DenseTensor>& tensors, int src_rank) {
input_partial = GetPartialTensor(in_tensor, in_offset, in_numel); CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
out_numel = out_size_each_rank[i] * out_row_size;
output_partial = GetPartialTensor(*out_tensor, out_offset, out_numel);
in_offset += in_numel;
out_offset += out_numel;
send_buf.push_back(input_partial.data());
recv_buf.push_back(output_partial.data());
send_count.push_back(in_numel);
recv_count.push_back(out_numel);
send_dtype.push_back(phi::ccl::ToCCLDataType(input_partial.dtype()));
recv_dtype.push_back(phi::ccl::ToCCLDataType(output_partial.dtype()));
}
phi::DeviceManager::CCLAllToAll( auto task = PointToPoint(
device_type_, tensors,
const_cast<const void**>(send_buf.data()), [&](phi::DenseTensor& output,
send_count.data(), const phi::ccl::CCLComm& comm,
send_dtype.data(), const phi::stream::Stream& stream,
recv_buf.data(), int src_rank) {
recv_count.data(), auto comm_context = this->GetCommContext();
recv_dtype.data(), comm_context->Recv(&output, output.numel(), src_rank, stream);
rank_,
size_,
comm,
stream);
}, },
CommType::ALLTOALL, src_rank,
sync_op, CommType::RECV);
use_calc_stream); return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
phi::errors::InvalidArgument("All inputs should be in CustomPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_tensors, device_type_),
true,
phi::errors::InvalidArgument("All outputs should be in CustomPlace."));
return Collective(
in_tensors,
out_tensors,
[&](const phi::DenseTensor& input,
phi::DenseTensor& output,
const phi::ccl::CCLComm& comm,
const phi::stream::Stream& stream) {
auto comm_context = this->GetCommContext();
comm_context->AllGather(&output, input, stream);
},
CommType::ALLGATHER);
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
@@ -863,8 +919,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          phi::ccl::CCLComm comm,
          const phi::ccl::CCLComm& comm,
          const phi::stream::Stream& stream) {
        auto comm_context = this->GetCommContext();
        size_t offset = 0;
        std::vector<void*> send_buf, recv_buf;
        std::vector<size_t> send_count(size_, input.numel() / size_),
@@ -889,111 +947,35 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
            recv_dtype.data(),
            rank_,
            size_,
comm, comm_context->GetXcclComm(),
stream);
},
CommType::ALLTOALL,
false,
false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
phi::distributed::CommStaticCheck::ScatterLikeShape(
*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return Collective(
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
phi::DeviceManager::CCLReduceScatter(
device_type_,
const_cast<void*>(in_tensor.data()),
out_tensor->data(),
out_tensor->numel(),
phi::ccl::ToCCLDataType(in_tensor.dtype()),
paddle::distributed::ToCustomCCLRedType(opts.reduce_op),
comm,
stream); stream);
}, },
CommType::REDUCE_SCATTER, CommType::ALLTOALL);
false,
false);
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter( std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Reduce(
phi::DenseTensor* out_tensor, std::vector<phi::DenseTensor>& in_tensors,
const phi::DenseTensor& in_tensor, std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts, const ReduceOptions& opts) {
bool sync_op, PADDLE_ENFORCE_EQ(
bool use_calc_stream) { CheckTensorsInCustomPlace(in_tensors, device_type_),
phi::distributed::CommStaticCheck::ScatterLikeShape( true,
*out_tensor, phi::errors::InvalidArgument("All inputs should be in CustomPlace."));
in_tensor,
/*dst_rank*/ opts.root_rank,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return Collective( return Collective(
in_wrapper, in_tensors,
out_wrapper, out_tensors,
[&](phi::DenseTensor& input, [&](const phi::DenseTensor& input,
phi::DenseTensor& output, phi::DenseTensor& output,
phi::ccl::CCLComm comm, const phi::ccl::CCLComm& comm,
const phi::stream::Stream& stream) { const phi::stream::Stream& stream) {
int64_t numel = in_tensor.numel() / size_; auto comm_context = this->GetCommContext();
if (rank_ == opts.root_rank) { comm_context->Reduce(&output,
int64_t offset = 0; input,
phi::DenseTensor partial_tensor; paddle::distributed::ToXCCLRedType(opts.reduce_op),
for (auto i = 0; i < size_; i++) { opts.root_rank,
partial_tensor = GetPartialTensor(in_tensor, offset, numel); stream);
if (i != rank_) {
phi::DeviceManager::CCLSend(
device_type_,
partial_tensor.data(),
numel,
phi::ccl::ToCCLDataType(partial_tensor.dtype()),
i,
comm,
stream);
} else {
phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace())
->MemoryCopyD2D(out_tensor->data(),
partial_tensor.data(),
numel * phi::SizeOf(partial_tensor.dtype()),
&stream);
}
offset += numel;
}
} else {
phi::DeviceManager::CCLRecv(
device_type_,
out_tensor->data(),
numel,
phi::ccl::ToCCLDataType(out_tensor->dtype()),
opts.root_rank,
comm,
stream);
}
}, },
CommType::SCATTER, CommType::REDUCE);
sync_op,
use_calc_stream);
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter( std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
...@@ -1003,134 +985,36 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter( ...@@ -1003,134 +985,36 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_), CheckTensorsInCustomPlace(in_tensors, device_type_),
true, true,
phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); phi::errors::InvalidArgument("All inputs should be in CustomPlace."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_tensors, device_type_), CheckTensorsInCustomPlace(out_tensors, device_type_),
true, true,
phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); phi::errors::InvalidArgument("All inputs should be in CustomPlace."));
return Collective( return Collective(
in_tensors, in_tensors,
out_tensors, out_tensors,
[&](phi::DenseTensor& input, [&](phi::DenseTensor& input,
phi::DenseTensor& output, phi::DenseTensor& output,
phi::ccl::CCLComm comm, const phi::ccl::CCLComm& comm,
const phi::stream::Stream& stream) {
int64_t numel = input.numel() / size_;
if (rank_ == opts.root_rank) {
int64_t offset = 0;
phi::DenseTensor partial_tensor;
for (auto i = 0; i < size_; i++) {
partial_tensor = GetPartialTensor(input, offset, numel);
if (i != rank_) {
phi::DeviceManager::CCLSend(
device_type_,
partial_tensor.data(),
numel,
phi::ccl::ToCCLDataType(partial_tensor.dtype()),
i,
comm,
stream);
} else {
phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace())
->MemoryCopyD2D(output.data(),
partial_tensor.data(),
numel * phi::SizeOf(partial_tensor.dtype()),
&stream);
}
offset += numel;
}
} else {
phi::DeviceManager::CCLRecv(device_type_,
output.data(),
numel,
phi::ccl::ToCCLDataType(output.dtype()),
opts.root_rank,
comm,
stream);
}
},
CommType::SCATTER,
false,
false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Gather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const GatherOptions& opts,
bool sync_op,
bool use_calc_stream) {
std::vector<phi::DenseTensor> partial_tensors;
if (rank_ == opts.root_rank) {
partial_tensors.reserve(size_);
size_t offset = 0;
size_t numel = out_tensor->numel() / size_;
for (auto i = 0; i < size_; i++) {
partial_tensors.push_back(GetPartialTensor(*out_tensor, offset, numel));
offset += numel;
}
}
return Gather(&partial_tensors, in_tensor, opts, sync_op, use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Gather(
std::vector<phi::DenseTensor>* gather_tensors_ptr,
const phi::DenseTensor& in_tensor,
const GatherOptions& opts,
bool sync_op,
bool use_calc_stream) {
auto& gather_tensors = *gather_tensors_ptr;
PADDLE_ENFORCE_GT(size_,
opts.root_rank,
phi::errors::InvalidArgument(
"root world size [%d] is less than root rank [%d]",
size_,
opts.root_rank));
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
return Collective(
in_wrapper,
in_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) { const phi::stream::Stream& stream) {
// root receive from all devices auto comm_context = this->GetCommContext();
size_t offset = 0;
size_t count = input.numel() / size_;
if (rank_ == opts.root_rank) { if (rank_ == opts.root_rank) {
comm_context->GroupStart();
for (auto i = 0; i < size_; i++) { for (auto i = 0; i < size_; i++) {
auto& gather_tensor = gather_tensors[i]; auto input_data = reinterpret_cast<phi::DenseTensor*>(
if (i != rank_) { GetPointerByOffset(input.data(), offset, input.dtype()));
phi::DeviceManager::CCLRecv( comm_context->Send(*input_data, count, i, stream);
device_type_, offset += count;
gather_tensor.data(),
gather_tensor.numel(),
phi::ccl::ToCCLDataType(gather_tensor.dtype()),
i,
comm,
stream);
} else {
phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace())
->MemoryCopyD2D(
gather_tensor.data(),
in_tensor.data(),
in_tensor.numel() * phi::SizeOf(in_tensor.dtype()),
&stream);
}
} }
comm_context->Recv(&output, count, opts.root_rank, stream);
comm_context->GroupEnd();
} else { } else {
// send to root comm_context->Recv(&output, count, opts.root_rank, stream);
phi::DeviceManager::CCLSend(
device_type_,
const_cast<void*>(in_tensor.data()),
in_tensor.numel(),
phi::ccl::ToCCLDataType(in_tensor.dtype()),
opts.root_rank,
comm,
stream);
} }
}, },
CommType::GATHER, CommType::SCATTER);
sync_op,
use_calc_stream);
} }
std::shared_ptr<ProcessGroupCustom> std::shared_ptr<ProcessGroupCustom>
...@@ -1146,5 +1030,16 @@ ProcessGroupCustom::CreateProcessGroupCustom( ...@@ -1146,5 +1030,16 @@ ProcessGroupCustom::CreateProcessGroupCustom(
return process_group; return process_group;
} }
phi::distributed::XCCLCommContext* ProcessGroupCustom::GetCommContext() {
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
auto comm_context = static_cast<phi::distributed::XCCLCommContext*>(
comm_context_manager.Get(std::to_string(this->gid_)));
PADDLE_ENFORCE_NE(comm_context,
nullptr,
phi::errors::Unavailable("XCCLCommContext is nullptr"));
return comm_context;
}
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -15,62 +15,57 @@ ...@@ -15,62 +15,57 @@
#pragma once #pragma once
#include <chrono> #include <chrono>
#include <map>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/distributed/collective/custom_ccl_tools.h"
#include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_with_stream.h" #include "paddle/fluid/distributed/collective/process_group_with_stream.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/phi/backends/device_manager.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/phi/common/place.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h" #include "paddle/phi/core/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/distributed/store/store.h" #include "paddle/phi/core/distributed/store/store.h"
#include "paddle/phi/core/distributed/xccl_comm_context.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
using Place = paddle::platform::Place;
using CustomDeviceContext = paddle::platform::CustomDeviceContext;
class ProcessGroupCustom : public ProcessGroupWithStream { using Place = phi::Place;
class ProcessGroupCustom final : public ProcessGroupWithStream {
public: public:
class CustomTask : public ProcessGroup::Task, class XCCLTask final : public ProcessGroupWithStream::TaskStream,
public std::enable_shared_from_this<CustomTask> { public std::enable_shared_from_this<XCCLTask> {
public: public:
CustomTask(const std::vector<Place>& places, XCCLTask(const Place& place,
int rank, int rank,
CommType CommType, CommType comm_type,
const std::vector<phi::DenseTensor>& inputs); bool sync_op,
bool use_calc_stream);
virtual ~XCCLTask();
bool IsCompleted() override; bool IsCompleted() override;
void SynchronizeStreams();
bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) override; bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) override;
void Synchronize() override; void Synchronize() override;
void UpdateWaitChain(const phi::DeviceContext& ctx) override; void UpdateWaitChain(const phi::DeviceContext& ctx) override;
void SetOutputs(std::vector<phi::DenseTensor>& outputs); // NOLINT
virtual ~CustomTask();
std::vector<CustomEventManager> control_events_; bool IsBlockCPUInWait() const { return block_cpu_in_wait_; }
std::vector<phi::DenseTensor> barrierTensors_; void SetBlockCPUInWait() { block_cpu_in_wait_ = true; }
protected: // TODO(sunyilun): methods below will be removed later
std::vector<Place> places_; XCCLTask(const std::vector<Place>& places,
std::vector<std::shared_ptr<CustomCCLCommManager>> cclComms_; int rank,
std::shared_ptr<std::vector<phi::DenseTensor>> outputs_; CommType CommType,
const std::vector<phi::DenseTensor>& inputs);
private: private:
const std::string device_type_; bool block_cpu_in_wait_{false};
phi::event::Event comm_event_; // event on comm stream
Place task_place_;
}; };
ProcessGroupCustom(const std::shared_ptr<phi::distributed::Store>& store, public:
const std::string& device_type,
int rank,
int size,
int gid);
static std::shared_ptr<ProcessGroupCustom> CreateProcessGroupCustom( static std::shared_ptr<ProcessGroupCustom> CreateProcessGroupCustom(
const std::shared_ptr<phi::distributed::Store>& store, const std::shared_ptr<phi::distributed::Store>& store,
const std::string& device_type, const std::string& device_type,
...@@ -78,19 +73,18 @@ class ProcessGroupCustom : public ProcessGroupWithStream { ...@@ -78,19 +73,18 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
int size, int size,
int gid); int gid);
std::string GetBackendName() const override { return "XCCL_" + device_type_; } ProcessGroupCustom(const std::shared_ptr<phi::distributed::Store>& store,
const std::string& device_type,
int rank,
int size,
int gid);
std::shared_ptr<ProcessGroup::Task> Barrier( std::string GetBackendName() const override { return "XCCL"; }
const BarrierOptions& = BarrierOptions()) override;
phi::DeviceContext* GetDeviceContext(const Place& place) const override; phi::DeviceContext* GetDeviceContext(const Place& place) const override;
phi::ccl::CCLComm CustomCCLComm(const Place& place) const; phi::DeviceContext* GetDeviceContext(const Place& place,
bool use_calc_stream) const override;
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> AllGather( std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor, phi::DenseTensor* out_tensor,
...@@ -100,11 +94,6 @@ class ProcessGroupCustom : public ProcessGroupWithStream { ...@@ -100,11 +94,6 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
bool sync_op, bool sync_op,
bool use_calc_stream) override; bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const AllreduceOptions& = AllreduceOptions()) override;
std::shared_ptr<ProcessGroup::Task> AllReduce( std::shared_ptr<ProcessGroup::Task> AllReduce(
phi::DenseTensor* out_tensor, phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
...@@ -112,10 +101,16 @@ class ProcessGroupCustom : public ProcessGroupWithStream { ...@@ -112,10 +101,16 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
bool sync_op, bool sync_op,
bool use_calc_stream) override; bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Broadcast( std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>& in_tensors, phi::DenseTensor* out_tensor,
std::vector<phi::DenseTensor>& out_tensors, const phi::DenseTensor& in_tensor,
const BroadcastOptions& = BroadcastOptions()) override; const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;
std::shared_ptr<ProcessGroup::Task> Broadcast( std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor, phi::DenseTensor* out_tensor,
...@@ -124,49 +119,12 @@ class ProcessGroupCustom : public ProcessGroupWithStream { ...@@ -124,49 +119,12 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
bool sync_op, bool sync_op,
bool use_calc_stream) override; bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
int dst_rank,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors, int dst_rank) override;
std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank) override;
std::shared_ptr<ProcessGroup::Task> Reduce(phi::DenseTensor* out_tensor, std::shared_ptr<ProcessGroup::Task> Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
const ReduceOptions& opts, const ReduceOptions& opts,
bool sync_op, bool sync_op,
bool use_calc_stream) override; bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts) override;
std::shared_ptr<ProcessGroup::Task> AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> ReduceScatter( std::shared_ptr<ProcessGroup::Task> ReduceScatter(
phi::DenseTensor* out_tensor, phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
...@@ -180,11 +138,6 @@ class ProcessGroupCustom : public ProcessGroupWithStream { ...@@ -180,11 +138,6 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
bool sync_op, bool sync_op,
bool use_calc_stream) override; bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts) override;
std::shared_ptr<ProcessGroup::Task> Gather(phi::DenseTensor* out_tensor, std::shared_ptr<ProcessGroup::Task> Gather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
const GatherOptions& opts, const GatherOptions& opts,
...@@ -198,52 +151,124 @@ class ProcessGroupCustom : public ProcessGroupWithStream { ...@@ -198,52 +151,124 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
bool sync_op, bool sync_op,
bool use_calc_stream) override; bool use_calc_stream) override;
protected: std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
virtual std::shared_ptr<ProcessGroupCustom::CustomTask> CreateTask( int src_rank,
std::vector<Place> places, int64_t offset,
int rank, int64_t numel,
CommType opType, bool sync_op,
const std::vector<phi::DenseTensor>& inputs); bool use_calc_stream) override;
std::shared_ptr<phi::distributed::Store> store_; std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
std::shared_ptr<CustomCCLCommManager> custom_comm_; int dst_rank,
std::mutex mutex_; int64_t offset,
std::unordered_map<std::string, int64_t numel,
std::vector<std::shared_ptr<CustomCCLCommManager>>> bool sync_op,
places_to_customcomm_; bool use_calc_stream) override;
std::unordered_map<std::string, std::vector<CustomEventManager>>
places_to_events_; static void GroupStart(const std::string& dev_type);
std::unordered_map<std::string,
std::vector<std::unique_ptr<CustomDeviceContext>>> static void GroupEnd(const std::string& dev_type);
places_to_ctx_;
std::set<int> used_place_ids_; phi::ccl::CCLComm XCCLComm(const Place& place) const;
// TODO(liyurui): This API will be moved later
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const AllreduceOptions& = AllreduceOptions()) override;
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& = BroadcastOptions()) override;
std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors, int dst_rank) override;
std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank) override;
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts) override;
std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts) override;
private: private:
void BcastCustomId(std::vector<phi::ccl::CCLRootId>& ccl_ids, // NOLINT std::shared_ptr<ProcessGroupCustom::XCCLTask> CreateTask(
int root, const Place& place,
int server_fd); int rank,
CommType op_type,
bool sync_op,
bool use_calc_stream);
void BroadcastUniqueCustomID( void BroadcastUniqueXCCLID(phi::ccl::CCLRootId* nccl_id);
std::vector<phi::ccl::CCLRootId>& custom_ccl_ids); // NOLINT
void CreateXCCLEnvCache(const Place& place, const std::string& place_key);
void SyncCalcStream(const Place& place);
std::shared_ptr<ProcessGroup::Task> RunFnInXCCLEnv(
std::function<void(const phi::stream::Stream&)> fn,
const phi::DenseTensor& tensor,
CommType comm_type,
bool sync_op,
bool use_calc_stream);
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroupCustom::XCCLTask> CreateTask(
std::vector<Place> places,
int rank,
CommType op_type,
const std::vector<phi::DenseTensor>& inputs);
template <typename Fn> template <typename Fn>
std::shared_ptr<ProcessGroup::Task> Collective( std::shared_ptr<ProcessGroup::Task> Collective(
std::vector<phi::DenseTensor>& inputs, // NOLINT std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT std::vector<phi::DenseTensor>& outputs, // NOLINT
Fn fn, Fn fn,
CommType op_type, CommType op_type);
bool sync_op,
bool use_calc_stream);
template <typename Fn> template <typename Fn>
std::shared_ptr<ProcessGroup::Task> Collective(Fn fn, std::shared_ptr<ProcessGroup::Task> PointToPoint(
CommType op_type, std::vector<phi::DenseTensor>& tensors, // NOLINT
bool sync_op, Fn fn,
bool use_calc_stream); int dst_rank,
CommType op_type);
void CreateCustomManagerCache(const std::string& places_key,
const std::vector<Place>& places); void CreateXCCLManagerCache(const std::string& places_key,
const std::string device_type_; const std::vector<Place>& places);
phi::distributed::XCCLCommContext* GetCommContext();
private:
std::shared_ptr<phi::distributed::Store> store_;
std::string device_type_;
std::unordered_map<std::string, std::unique_ptr<phi::event::Event>>
place_to_calc_event_; // event on calc stream
std::unordered_map<std::string, phi::CustomContext*> place_to_calc_ctx_;
std::unordered_map<std::string, std::unique_ptr<phi::CustomContext>>
place_to_comm_ctx_;
// TODO(sunyilun): attrs below will be removed later
std::mutex mutex_;
std::unordered_map<std::string, std::vector<phi::CustomContext*>>
places_to_ctx_;
}; };
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -53,8 +53,8 @@ ccl::CCLComm GetCCLComm(const Place& place, int global_gid) { ...@@ -53,8 +53,8 @@ ccl::CCLComm GetCCLComm(const Place& place, int global_gid) {
#endif #endif
} else if (place.GetType() == phi::AllocationType::CUSTOM) { } else if (place.GetType() == phi::AllocationType::CUSTOM) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE) #if defined(PADDLE_WITH_CUSTOM_DEVICE)
return static_cast<paddle::distributed::ProcessGroupCustom*>(pg) return static_cast<paddle::distributed::ProcessGroupCustom*>(pg)->XCCLComm(
->CustomCCLComm(place); place);
#else #else
return nullptr; return nullptr;
#endif #endif
......
...@@ -1574,6 +1574,7 @@ void AllocatorFacade::SetDefaultStream(const platform::CustomPlace& place, ...@@ -1574,6 +1574,7 @@ void AllocatorFacade::SetDefaultStream(const platform::CustomPlace& place,
m_->SetDefaultStream(place, stream); m_->SetDefaultStream(place, stream);
} }
} }
#endif #endif
UNUSED static std::shared_ptr<NaiveBestFitAllocator> unused_obj = UNUSED static std::shared_ptr<NaiveBestFitAllocator> unused_obj =
......
...@@ -101,6 +101,7 @@ class AllocatorFacade { ...@@ -101,6 +101,7 @@ class AllocatorFacade {
phi::stream::stream_t stream); phi::stream::stream_t stream);
void RecordStream(std::shared_ptr<phi::Allocation> allocation, void RecordStream(std::shared_ptr<phi::Allocation> allocation,
phi::stream::stream_t stream); phi::stream::stream_t stream);
void SetDefaultStream(const platform::CustomPlace& place, void SetDefaultStream(const platform::CustomPlace& place,
phi::stream::stream_t stream); phi::stream::stream_t stream);
#endif #endif
......
...@@ -71,6 +71,14 @@ gpuStream_t GetStream(const std::shared_ptr<Allocation>& allocation) { ...@@ -71,6 +71,14 @@ gpuStream_t GetStream(const std::shared_ptr<Allocation>& allocation) {
return allocation::AllocatorFacade::Instance().GetStream(allocation); return allocation::AllocatorFacade::Instance().GetStream(allocation);
} }
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
void RecordStream(std::shared_ptr<Allocation> allocation,
phi::stream::stream_t stream) {
return allocation::AllocatorFacade::Instance().RecordStream(allocation,
stream);
}
#endif #endif
} // namespace memory } // namespace memory
} // namespace paddle } // namespace paddle
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/core/device_context.h" #include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/stream.h" #include "paddle/phi/core/stream.h"
...@@ -55,5 +56,9 @@ void RecordStream(std::shared_ptr<Allocation> allocation, gpuStream_t stream); ...@@ -55,5 +56,9 @@ void RecordStream(std::shared_ptr<Allocation> allocation, gpuStream_t stream);
gpuStream_t GetStream(const std::shared_ptr<Allocation>& allocation); gpuStream_t GetStream(const std::shared_ptr<Allocation>& allocation);
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
void RecordStream(std::shared_ptr<Allocation> allocation,
phi::stream::stream_t stream);
#endif
} // namespace memory } // namespace memory
} // namespace paddle } // namespace paddle
...@@ -56,6 +56,12 @@ void BindCommContextManager(py::module *m) { ...@@ -56,6 +56,12 @@ void BindCommContextManager(py::module *m) {
"create_gloo_comm_context", "create_gloo_comm_context",
&phi::distributed::CommContextManager::CreateGlooCommContext, &phi::distributed::CommContextManager::CreateGlooCommContext,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
#endif
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
.def_static(
"create_xccl_comm_context",
&phi::distributed::CommContextManager::CreateXCCLCommContext,
py::call_guard<py::gil_scoped_release>())
#endif #endif
.def("set_store", &phi::distributed::CommContextManager::SetStore); .def("set_store", &phi::distributed::CommContextManager::SetStore);
} }
......
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
#include "paddle/phi/core/errors.h" #include "paddle/phi/core/errors.h"
#include "paddle/phi/core/macros.h" #include "paddle/phi/core/macros.h"
#include "paddle/phi/common/reduce_type.h"
namespace phi { namespace phi {
namespace ccl { namespace ccl {
typedef void* CCLComm; typedef void* CCLComm;
...@@ -38,6 +40,32 @@ enum CCLDataType { ...@@ -38,6 +40,32 @@ enum CCLDataType {
CCL_DATA_TYPE_UINT8 CCL_DATA_TYPE_UINT8
}; };
inline CCLReduceOp ToXCCLReduceOp(int reduce_type) {
phi::ccl::CCLReduceOp red_type = phi::ccl::CCLReduceOp::SUM;
switch (static_cast<phi::ReduceType>(reduce_type)) {
case phi::ReduceType::kRedSum:
red_type = phi::ccl::CCLReduceOp::SUM;
break;
case phi::ReduceType::kRedMax:
red_type = phi::ccl::CCLReduceOp::MAX;
break;
case phi::ReduceType::kRedMin:
red_type = phi::ccl::CCLReduceOp::MIN;
break;
case phi::ReduceType::kRedProd:
red_type = phi::ccl::CCLReduceOp::PRODUCT;
break;
case phi::ReduceType::kRedAvg:
red_type = phi::ccl::CCLReduceOp::AVG;
break;
default:
PADDLE_THROW(
errors::Unavailable("Unsuppored reduce type. Reduce type must be one "
"of SUM, MAX, MIN, PRODUCT and AVG."));
}
return red_type;
}
inline CCLDataType ToCCLDataType(phi::DataType type) { inline CCLDataType ToCCLDataType(phi::DataType type) {
if (type == phi::DataType::FLOAT64) { if (type == phi::DataType::FLOAT64) {
return CCL_DATA_TYPE_FP64; return CCL_DATA_TYPE_FP64;
...@@ -79,5 +107,14 @@ inline phi::DataType ToPhiDataType(CCLDataType type) { ...@@ -79,5 +107,14 @@ inline phi::DataType ToPhiDataType(CCLDataType type) {
} }
} }
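// Note: this helper only hex-encodes the raw root-id bytes so the id can be
// logged (it is used in the VLOG output of CreateXCCLCommContext below); it is
// not intended as a reversible wire format.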
inline std::string SerializeXCCLUniqueId(const phi::ccl::CCLRootId& ccl_id) {
const uint8_t* bytes = ccl_id.data();
std::ostringstream oss;
for (size_t i = 0; i < ccl_id.size(); ++i) {
oss << std::hex << static_cast<int>(bytes[i]);
}
return oss.str();
}
} // namespace ccl } // namespace ccl
} // namespace phi } // namespace phi
...@@ -754,13 +754,15 @@ class CustomDevice : public DeviceInterface { ...@@ -754,13 +754,15 @@ class CustomDevice : public DeviceInterface {
} }
void CCLGroupStart() override { void CCLGroupStart() override {
CHECK_PTR(pimpl_->xccl_group_start); if (pimpl_->xccl_group_start) {
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_group_start()); PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_group_start());
}
} }
void CCLGroupEnd() override { void CCLGroupEnd() override {
CHECK_PTR(pimpl_->xccl_group_end); if (pimpl_->xccl_group_end) {
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_group_end()); PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_group_end());
}
} }
void CCLSend(void* send_buf, void CCLSend(void* send_buf,
......
...@@ -12,4 +12,8 @@ if(WITH_GLOO) ...@@ -12,4 +12,8 @@ if(WITH_GLOO)
list(APPEND DISTRIBUTED_COMMON_SRCS gloo_utils.cc gloo_comm_context.cc) list(APPEND DISTRIBUTED_COMMON_SRCS gloo_utils.cc gloo_comm_context.cc)
endif() endif()
if(WITH_CUSTOM_DEVICE)
list(APPEND DISTRIBUTED_COMMON_SRCS xccl_comm_context.cc)
endif()
collect_srcs(core_srcs SRCS ${DISTRIBUTED_COMMON_SRCS}) collect_srcs(core_srcs SRCS ${DISTRIBUTED_COMMON_SRCS})
...@@ -53,6 +53,18 @@ DenseTensor ReshardAllGatherFunctor(DeviceContext* dev_ctx, ...@@ -53,6 +53,18 @@ DenseTensor ReshardAllGatherFunctor(DeviceContext* dev_ctx,
})); }));
return out; return out;
} }
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (phi::CustomContext::classof(dev_ctx)) {
PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(
input.dtype(), "AllGather", ([&] {
AllGather<data_t>(static_cast<const CustomContext&>(*dev_ctx),
input,
world_size,
&out);
}));
return out;
}
#endif #endif
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"The all_gather in reshard only supported on CPU and GPU for now.")); "The all_gather in reshard only supported on CPU and GPU for now."));
......
...@@ -170,6 +170,15 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx, ...@@ -170,6 +170,15 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
#else #else
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"Cannot use gloo on CPU, please turn PADDLE_WITH_GLOO flag on.")); "Cannot use gloo on CPU, please turn PADDLE_WITH_GLOO flag on."));
#endif
} else if (phi::CustomContext::classof(&dev_ctx)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
CommContextManager::CreateXCCLCommContext(
store,
unique_comm_key,
dev_ctx.GetPlace().GetDeviceType(),
rank,
world_size);
#endif #endif
} else { } else {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
......
...@@ -24,8 +24,8 @@ class CommContext { ...@@ -24,8 +24,8 @@ class CommContext {
CommContext(int rank, int size) : rank_(rank), size_(size) {} CommContext(int rank, int size) : rank_(rank), size_(size) {}
virtual ~CommContext() = default; virtual ~CommContext() = default;
int GetRank() { return rank_; } int GetRank() const { return rank_; }
int GetSize() { return size_; } int GetSize() const { return size_; }
protected: protected:
int rank_; int rank_;
......
...@@ -32,6 +32,9 @@ ...@@ -32,6 +32,9 @@
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/core/distributed/nccl_comm_context.h" #include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#endif
namespace phi { namespace phi {
namespace distributed { namespace distributed {
...@@ -91,6 +94,35 @@ void CommContextManager::CreateGlooCommContext( ...@@ -91,6 +94,35 @@ void CommContextManager::CreateGlooCommContext(
} }
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
void CommContextManager::CreateXCCLCommContext(
const std::shared_ptr<Store>& store,
const std::string& unique_comm_key,
const std::string& device_type,
int rank,
int size) {
phi::ccl::CCLRootId xccl_root_id;
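  // Rank 0 asks the device plugin for the unique root id and publishes it under
  // the "XCCLCommContext/<unique_comm_key>" key; the remaining ranks read the
  // same id back from the store before constructing their communicator.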
if (rank == 0) {
phi::DeviceManager::CCLGetUniqueId(device_type, &xccl_root_id);
}
std::string unique_key = "XCCLCommContext/" + unique_comm_key;
if (rank == 0) {
store->set(unique_key, xccl_root_id);
} else {
xccl_root_id = store->get(unique_key);
}
VLOG(3) << "init xccl rank: " << rank << ", nranks: " << size
<< ", unique_comm_key: " << unique_comm_key << ", xccl uniqueid: "
<< phi::ccl::SerializeXCCLUniqueId(xccl_root_id);
auto xccl_comm_context =
std::make_unique<XCCLCommContext>(device_type, rank, size, xccl_root_id);
auto& comm_context_manager = CommContextManager::GetInstance();
comm_context_manager.SetStore(store);
comm_context_manager.Emplace(unique_comm_key, std::move(xccl_comm_context));
}
#endif
CommContext* CommContextManager::Emplace( CommContext* CommContextManager::Emplace(
const std::string& unique_comm_key, const std::string& unique_comm_key,
std::unique_ptr<CommContext> comm_context) { std::unique_ptr<CommContext> comm_context) {
......
...@@ -62,6 +62,14 @@ class CommContextManager { ...@@ -62,6 +62,14 @@ class CommContextManager {
int size); int size);
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
static void CreateXCCLCommContext(const std::shared_ptr<Store>& store,
const std::string& unique_comm_key,
const std::string& device_type,
int rank,
int size);
#endif
private: private:
DISABLE_COPY_AND_ASSIGN(CommContextManager); DISABLE_COPY_AND_ASSIGN(CommContextManager);
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#include "glog/logging.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/check/static_check.h"
#include "paddle/phi/core/distributed/utils.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
namespace distributed {
XCCLCommContext::XCCLCommContext(const std::string& device_type,
int rank,
int size,
const ccl::CCLRootId& xccl_id)
: CommContext(rank, size) {
device_type_ = device_type;
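  // Every rank must supply the same xccl_id (generated on rank 0 and exchanged
  // through the store) so CCLCommInitRank joins all ranks into one communicator.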
phi::DeviceManager::CCLCommInitRank(device_type,
size_,
const_cast<ccl::CCLRootId*>(&xccl_id),
rank,
&xccl_comm_);
}
void XCCLCommContext::Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root,
const phi::stream::Stream& stream) const {
CommStaticCheck::SameShape(*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
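  // Only the root reads from in_tensor; the other ranks hand their out_tensor to
  // the plugin so the broadcast result is written directly into the output.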
if (rank_ == root) {
phi::DeviceManager::CCLBroadcast(device_type_,
const_cast<void*>(in_tensor.data()),
in_tensor.numel(),
phi::ccl::ToCCLDataType(in_tensor.dtype()),
root,
xccl_comm_,
stream);
} else {
phi::DeviceManager::CCLBroadcast(device_type_,
out_tensor->data(),
out_tensor->numel(),
phi::ccl::ToCCLDataType(in_tensor.dtype()),
root,
xccl_comm_,
stream);
}
}
void XCCLCommContext::AllGather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const phi::stream::Stream& stream) const {
phi::distributed::CommStaticCheck::GatherLikeShape(
*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
phi::DeviceManager::CCLAllGather(device_type_,
const_cast<void*>(in_tensor.data()),
out_tensor->data(),
in_tensor.numel(),
phi::ccl::ToCCLDataType(in_tensor.dtype()),
xccl_comm_,
stream);
}
void XCCLCommContext::ReduceScatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
phi::ccl::CCLReduceOp reduce_type,
const phi::stream::Stream& stream) const {
phi::distributed::CommStaticCheck::ScatterLikeShape(
*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
phi::DeviceManager::CCLReduceScatter(
device_type_,
const_cast<void*>(in_tensor.data()),
out_tensor->data(),
out_tensor->numel(),
phi::ccl::ToCCLDataType(in_tensor.type()),
reduce_type,
xccl_comm_,
stream);
}
void XCCLCommContext::Send(const phi::DenseTensor& in_tensor,
const int64_t& count,
const int& peer,
const phi::stream::Stream& stream) const {
phi::distributed::CommStaticCheck::CheckShape(
in_tensor, rank_, size_, phi::AllocationType::CUSTOM);
phi::DeviceManager::CCLSend(device_type_,
const_cast<void*>(in_tensor.data()),
count,
phi::ccl::ToCCLDataType(in_tensor.type()),
peer,
xccl_comm_,
stream);
VLOG(3) << "rank " << GetRank() << " send " << phi::product(in_tensor.dims())
<< " to " << peer;
}
void XCCLCommContext::Recv(phi::DenseTensor* out_tensor,
const int64_t& count,
const int& peer,
const phi::stream::Stream& stream) const {
phi::distributed::CommStaticCheck::CheckShape(
*out_tensor, rank_, size_, phi::AllocationType::CUSTOM);
phi::DeviceManager::CCLRecv(device_type_,
out_tensor->data(),
count,
phi::ccl::ToCCLDataType(out_tensor->type()),
peer,
xccl_comm_,
stream);
VLOG(3) << "rank " << GetRank() << " recv "
<< phi::product(out_tensor->dims()) << " from " << peer;
}
void XCCLCommContext::AllReduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
phi::ccl::CCLReduceOp reduce_type,
const phi::stream::Stream& stream) const {
phi::distributed::CommStaticCheck::SameShape(*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
phi::DeviceManager::CCLAllReduce(device_type_,
const_cast<void*>(in_tensor.data()),
out_tensor->data(),
in_tensor.numel(),
phi::ccl::ToCCLDataType(in_tensor.type()),
reduce_type,
xccl_comm_,
stream);
}
void XCCLCommContext::Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
phi::ccl::CCLReduceOp reduce_type,
int root,
const phi::stream::Stream& stream) const {
phi::distributed::CommStaticCheck::SameShape(*out_tensor,
in_tensor,
/*dst_rank*/ root,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
phi::DeviceManager::CCLReduce(device_type_,
const_cast<void*>(in_tensor.data()),
out_tensor->data(),
in_tensor.numel(),
phi::ccl::ToCCLDataType(in_tensor.type()),
reduce_type,
root,
xccl_comm_,
stream);
}
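// GroupStart/GroupEnd simply forward to the device plugin via DeviceManager;
// ProcessGroupCustom uses them to batch the per-rank Send calls issued by the
// root in Scatter.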
void XCCLCommContext::GroupStart() const {
phi::DeviceManager::CCLGroupStart(device_type_);
}
void XCCLCommContext::GroupEnd() const {
phi::DeviceManager::CCLGroupEnd(device_type_);
}
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/distributed/comm_context.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/backends/device_manager.h"
namespace phi {
class DenseTensor;
namespace distributed {
class XCCLCommContext final : public CommContext {
public:
XCCLCommContext(const std::string& device_type,
int rank,
int size,
const ccl::CCLRootId& xccl_id);
ccl::CCLComm GetXcclComm() const { return xccl_comm_; }
const std::string& GetDeviceType() const { return device_type_; }
void Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root,
const phi::stream::Stream& stream) const;
void Send(const phi::DenseTensor& in_tensor,
const int64_t& count,
const int& peer,
const phi::stream::Stream& stream) const;
void Recv(phi::DenseTensor* out_tensor,
const int64_t& count,
const int& peer,
const phi::stream::Stream& stream) const;
void ReduceScatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
phi::ccl::CCLReduceOp reduce_type,
const phi::stream::Stream& stream) const;
void AllGather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const phi::stream::Stream& stream) const;
void AllReduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
phi::ccl::CCLReduceOp reduce_type,
const phi::stream::Stream& stream) const;
void Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
phi::ccl::CCLReduceOp reduce_type,
int root,
const phi::stream::Stream& stream) const;
void GroupStart() const;
void GroupEnd() const;
private:
DISABLE_COPY_AND_ASSIGN(XCCLCommContext);
std::string device_type_;
ccl::CCLComm xccl_comm_;
};
} // namespace distributed
} // namespace phi
...@@ -20,6 +20,9 @@ ...@@ -20,6 +20,9 @@
#if defined(PADDLE_WITH_GLOO) #if defined(PADDLE_WITH_GLOO)
#include "paddle/phi/core/distributed/gloo_comm_context.h" #include "paddle/phi/core/distributed/gloo_comm_context.h"
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#endif
namespace phi { namespace phi {
...@@ -49,6 +52,28 @@ void AllGatherKernel(const Context& dev_ctx, ...@@ -49,6 +52,28 @@ void AllGatherKernel(const Context& dev_ctx,
#endif #endif
} }
#ifdef PADDLE_WITH_CUSTOM_DEVICE
template <typename T>
void AllGatherKernel(const phi::CustomContext& dev_ctx,
const DenseTensor& x,
int nranks,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
auto out_dims = x.dims();
out_dims[0] *= nranks;
out->Resize(out_dims);
auto comm_ctx =
static_cast<distributed::XCCLCommContext*>(dev_ctx.GetCommContext());
PADDLE_ENFORCE_EQ(
nranks,
comm_ctx->GetSize(),
errors::InvalidArgument(
"nranks: %s should equal to %s", nranks, comm_ctx->GetSize()));
comm_ctx->AllGather(out, x, *dev_ctx.GetStream());
}
#endif
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(all_gather, PD_REGISTER_KERNEL(all_gather,
...@@ -64,3 +89,19 @@ PD_REGISTER_KERNEL(all_gather, ...@@ -64,3 +89,19 @@ PD_REGISTER_KERNEL(all_gather,
int16_t, int16_t,
int64_t, int64_t,
phi::dtype::float16) {} phi::dtype::float16) {}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(all_gather,
Custom,
ALL_LAYOUT,
phi::AllGatherKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int16_t,
int64_t,
phi::dtype::float16) {}
#endif
...@@ -20,6 +20,9 @@ ...@@ -20,6 +20,9 @@
#if defined(PADDLE_WITH_GLOO) #if defined(PADDLE_WITH_GLOO)
#include "paddle/phi/core/distributed/gloo_comm_context.h" #include "paddle/phi/core/distributed/gloo_comm_context.h"
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#endif
namespace phi { namespace phi {
...@@ -47,6 +50,27 @@ void AllReduceKernel(const Context& dev_ctx, ...@@ -47,6 +50,27 @@ void AllReduceKernel(const Context& dev_ctx,
#endif #endif
} }
#ifdef PADDLE_WITH_CUSTOM_DEVICE
template <typename T>
void AllReduceKernel(const phi::CustomContext& dev_ctx,
const DenseTensor& x,
int reduce_type,
DenseTensor* out) {
out->Resize(x.dims());
dev_ctx.template Alloc<T>(out);
auto comm_ctx =
static_cast<distributed::XCCLCommContext*>(dev_ctx.GetCommContext());
PADDLE_ENFORCE_NE(
comm_ctx,
nullptr,
errors::Unavailable("XCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
comm_ctx->AllReduce(
out, x, phi::ccl::ToXCCLReduceOp(reduce_type), *dev_ctx.GetStream());
}
#endif
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(all_reduce, PD_REGISTER_KERNEL(all_reduce,
...@@ -61,3 +85,18 @@ PD_REGISTER_KERNEL(all_reduce, ...@@ -61,3 +85,18 @@ PD_REGISTER_KERNEL(all_reduce,
uint8_t, uint8_t,
int64_t, int64_t,
phi::dtype::float16) {} phi::dtype::float16) {}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(all_reduce,
Custom,
ALL_LAYOUT,
phi::AllReduceKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
#endif
...@@ -16,6 +16,9 @@ ...@@ -16,6 +16,9 @@
#include "paddle/phi/backends/all_context.h" #include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#endif
namespace phi { namespace phi {
...@@ -26,6 +29,42 @@ void AllToAllKernel(const Context& dev_ctx UNUSED, ...@@ -26,6 +29,42 @@ void AllToAllKernel(const Context& dev_ctx UNUSED,
PADDLE_THROW( PADDLE_THROW(
errors::Unimplemented("Unimplemented cpu kernel for all_to_all.")); errors::Unimplemented("Unimplemented cpu kernel for all_to_all."));
} }
#ifdef PADDLE_WITH_CUSTOM_DEVICE
template <typename T>
void AllToAllKernel(const phi::CustomContext& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
out->Resize(x.dims());
dev_ctx.template Alloc<T>(out);
auto comm_ctx =
static_cast<distributed::XCCLCommContext*>(dev_ctx.GetCommContext());
int nranks = comm_ctx->GetSize();
int rank = comm_ctx->GetRank();
int send_numel = x.numel() / nranks;
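  // Each rank exchanges send_numel contiguous elements with every peer, so the
  // send/recv buffers below are just per-rank slices of x and out.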
std::vector<void*> sendbuf, recvbuf;
std::vector<size_t> sendsize(send_numel, nranks);
std::vector<phi::ccl::CCLDataType> sendtype(
phi::ccl::ToCCLDataType(x.dtype()), nranks);
for (auto i = 0; i < nranks; ++i) {
sendbuf.push_back(x.data<T>() + i * send_numel);
recvbuf.push_back(out->data<T>() + i * send_numel);
}
phi::DeviceManager::CCLAllToAll(dev_ctx.GetPlace().GetDeviceType(),
const_cast<const void**>(sendbuf.data()),
sendsize.data(),
sendtype.data(),
recvbuf.data(),
sendsize.data(),
sendtype.data(),
rank,
nranks,
comm_ctx->GetXcclComm(),
*dev_ctx.GetStream());
}
#endif
} // namespace phi } // namespace phi
...@@ -41,3 +80,17 @@ PD_REGISTER_KERNEL(all_to_all, ...@@ -41,3 +80,17 @@ PD_REGISTER_KERNEL(all_to_all,
uint8_t, uint8_t,
int64_t, int64_t,
phi::dtype::float16) {} phi::dtype::float16) {}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(all_to_all,
Custom,
ALL_LAYOUT,
phi::AllToAllKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
#endif
...@@ -20,6 +20,9 @@ ...@@ -20,6 +20,9 @@
#if defined(PADDLE_WITH_GLOO) #if defined(PADDLE_WITH_GLOO)
#include "paddle/phi/core/distributed/gloo_comm_context.h" #include "paddle/phi/core/distributed/gloo_comm_context.h"
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#endif
namespace phi { namespace phi {
...@@ -51,6 +54,35 @@ void ReduceKernel(const Context& dev_ctx, ...@@ -51,6 +54,35 @@ void ReduceKernel(const Context& dev_ctx,
#endif #endif
} }
#ifdef PADDLE_WITH_CUSTOM_DEVICE
template <typename T>
void ReduceKernel(const phi::CustomContext& dev_ctx,
const DenseTensor& x,
int root,
int reduce_type,
DenseTensor* out) {
PADDLE_ENFORCE_GT(
x.numel(),
0,
phi::errors::InvalidArgument("Tensor need be reduced must not empyt."));
out->Resize(x.dims());
dev_ctx.template Alloc<T>(out);
auto comm_ctx =
static_cast<distributed::XCCLCommContext*>(dev_ctx.GetCommContext());
PADDLE_ENFORCE_NE(
comm_ctx,
nullptr,
errors::Unavailable("XCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
comm_ctx->Reduce(out,
x,
phi::ccl::ToXCCLReduceOp(reduce_type),
root,
*dev_ctx.GetStream());
}
#endif
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(reduce, PD_REGISTER_KERNEL(reduce,
...@@ -65,3 +97,18 @@ PD_REGISTER_KERNEL(reduce, ...@@ -65,3 +97,18 @@ PD_REGISTER_KERNEL(reduce,
uint8_t, uint8_t,
int64_t, int64_t,
phi::dtype::float16) {} phi::dtype::float16) {}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(reduce,
Custom,
ALL_LAYOUT,
phi::ReduceKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
#endif
...@@ -334,3 +334,9 @@ def _init_parallel_env(backend): ...@@ -334,3 +334,9 @@ def _init_parallel_env(backend):
        core.CommContextManager.create_nccl_comm_context(
            store, "0", rank, world_size
        )
elif backend == "xccl":
dev_type = global_env.device_type
paddle.device.set_device(f"{dev_type}:{dev_id}")
core.CommContextManager.create_xccl_comm_context(
store, "0", rank, world_size, dev_type
)
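A minimal usage sketch (not part of this diff; it assumes a custom device plugin is installed so that _init_parallel_env above is eventually reached with backend == "xccl"):

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    x = paddle.to_tensor([1.0, 2.0, 3.0])
    dist.all_reduce(x)  # served by the XCCLCommContext created above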