Unverified commit 62397cd2 authored by R ronnywang, committed by GitHub

[CustomDevice] add comm context support (#56301)

Parent ede8fd55
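In short, this commit moves PaddlePaddle's CustomDevice (XCCL) collectives onto the phi comm-context framework, mirroring the existing NCCL/Gloo paths: a new phi::distributed::XCCLCommContext wraps the plugin's CCL primitives, CommContextManager learns to create it, and the custom-device kernels and ProcessGroupCustom consume it. A minimal sketch of the resulting flow follows; the store setup and the "my_device" plugin type are assumptions, all other names appear in this diff:
// Hedged sketch (not part of the commit):
// std::shared_ptr<phi::distributed::Store> store = /* rendezvous store */;
// phi::distributed::CommContextManager::CreateXCCLCommContext(
//     store, /*unique_comm_key=*/"0", /*device_type=*/"my_device",
//     /*rank=*/rank, /*size=*/world_size);
// // Kernels then fetch the context from the device context:
// auto* comm_ctx = static_cast<phi::distributed::XCCLCommContext*>(
//     dev_ctx.GetCommContext());
// comm_ctx->AllReduce(&out, in, phi::ccl::CCLReduceOp::SUM,
//                     *dev_ctx.GetStream());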
@@ -18,7 +18,7 @@
namespace paddle {
namespace distributed {
phi::ccl::CCLReduceOp ToCustomCCLRedType(ReduceOp reduction) {
phi::ccl::CCLReduceOp ToXCCLRedType(ReduceOp reduction) {
static const std::map<ReduceOp, phi::ccl::CCLReduceOp> red_type = {
{ReduceOp::MIN, phi::ccl::CCLReduceOp::MIN},
{ReduceOp::MAX, phi::ccl::CCLReduceOp::MAX},
@@ -34,14 +34,5 @@ phi::ccl::CCLReduceOp ToCustomCCLRedType(ReduceOp reduction) {
return it->second;
}
std::string SerializeCustomCCLUniqueId(const phi::ccl::CCLRootId& ccl_id) {
const uint8_t* bytes = ccl_id.data();
std::ostringstream oss;
for (size_t i = 0; i < ccl_id.size(); ++i) {
oss << std::hex << static_cast<int>(bytes[i]);
}
return oss.str();
}
} // namespace distributed
} // namespace paddle
@@ -34,170 +34,7 @@
namespace paddle {
namespace distributed {
class CustomEventManager {
public:
CustomEventManager() = default;
~CustomEventManager() {
if (is_created_) {
event_->Destroy();
}
}
CustomEventManager(const CustomEventManager&) = delete;
CustomEventManager& operator=(const CustomEventManager&) = delete;
CustomEventManager(CustomEventManager&& other) {
std::swap(is_created_, other.is_created_);
std::swap(device_index_, other.device_index_);
std::swap(device_type_, other.device_type_);
std::swap(event_, other.event_);
}
CustomEventManager& operator=(CustomEventManager&& other) {
std::swap(is_created_, other.is_created_);
std::swap(device_index_, other.device_index_);
std::swap(device_type_, other.device_type_);
std::swap(event_, other.event_);
return *this;
}
bool IsCreated() const { return is_created_; }
int8_t DeviceId() const { return device_index_; }
std::string DeviceType() const { return device_type_; }
phi::event::event_t GetRawCustomEvent() const { return event_->raw_event(); }
phi::event::Event* GetCustomEvent() const { return event_.get(); }
void Record(const paddle::platform::CustomDeviceContext& ctx) {
auto place = ctx.GetPlace();
auto device_type = place.GetDeviceType();
auto device_index = place.GetDeviceId();
if (!is_created_) {
CreateEvent(place);
}
PADDLE_ENFORCE_EQ(device_index,
device_index_,
platform::errors::PreconditionNotMet(
"CustomDeviceContext's device %d does not match"
"Event's device %d",
device_index,
device_index_));
PADDLE_ENFORCE_EQ(device_type,
device_type_,
platform::errors::PreconditionNotMet(
"CustomDeviceContext's device %d does not match"
"Event's device type %d",
device_type,
device_type_));
phi::DeviceGuard guard(place);
phi::stream::Stream stream(place, ctx.stream());
event_->Record(&stream);
}
bool Query() const { return event_->Query(); }
void Block(const paddle::platform::CustomDeviceContext& ctx) const {
if (is_created_) {
auto place = ctx.GetPlace();
auto device_type = place.GetDeviceType();
auto device_index = place.GetDeviceId();
PADDLE_ENFORCE_EQ(device_index,
device_index_,
platform::errors::PreconditionNotMet(
"CustomDeviceContext's device %d does not match"
"Event's device %d",
device_index,
device_index_));
PADDLE_ENFORCE_EQ(device_type,
device_type_,
platform::errors::PreconditionNotMet(
"CustomDeviceContext's device %d does not match"
"Event's device type %d",
device_type,
device_type_));
phi::DeviceGuard guard(place);
phi::stream::Stream stream(place, ctx.stream());
stream.WaitEvent(event_.get());
}
}
private:
bool is_created_{false};
std::shared_ptr<phi::event::Event> event_{nullptr};
int8_t device_index_{0};
std::string device_type_;
private:
void CreateEvent(const platform::Place& place) {
device_index_ = place.GetDeviceId();
device_type_ = place.GetDeviceType();
event_.reset(new phi::event::Event);
event_->Init(place);
is_created_ = true;
}
};
class CustomCCLCommManager {
public:
CustomCCLCommManager(const std::string& device_type,
phi::ccl::CCLComm ccl_comm)
: device_type_(device_type), ccl_comm_(ccl_comm) {}
CustomCCLCommManager() : CustomCCLCommManager("", nullptr) {}
~CustomCCLCommManager() noexcept {
std::unique_lock<std::mutex> lock(mutex_);
if (phi::DeviceManager::HasDeviceType(device_type_) && ccl_comm_) {
phi::DeviceManager::CCLDestroyComm(device_type_, ccl_comm_);
}
}
static std::shared_ptr<CustomCCLCommManager> Create(
const std::string& device_type,
int num_ranks,
int rank,
phi::ccl::CCLRootId* comm_id,
phi::ccl::CCLComm* ccl_comm) {
auto custom_ccl_manager = std::make_shared<CustomCCLCommManager>();
phi::DeviceManager::CCLCommInitRank(
device_type, num_ranks, comm_id, rank, ccl_comm);
custom_ccl_manager->device_type_ = device_type;
custom_ccl_manager->ccl_id_ = comm_id;
custom_ccl_manager->rank_ = rank;
custom_ccl_manager->ccl_comm_ = *ccl_comm;
return custom_ccl_manager;
}
phi::ccl::CCLRootId* GetCustomCCLId() const {
std::unique_lock<std::mutex> lock(mutex_);
return ccl_id_;
}
phi::ccl::CCLComm GetCustomCCLComm() const {
std::unique_lock<std::mutex> lock(mutex_);
return ccl_comm_;
}
CustomCCLCommManager(const CustomCCLCommManager&) = delete;
CustomCCLCommManager& operator=(const CustomCCLCommManager&) = delete;
CustomCCLCommManager& operator=(CustomCCLCommManager&& other) = delete;
CustomCCLCommManager(CustomCCLCommManager&& other) {
std::unique_lock<std::mutex> lock(other.mutex_);
std::swap(ccl_comm_, other.ccl_comm_);
}
protected:
std::string device_type_;
phi::ccl::CCLComm ccl_comm_;
phi::ccl::CCLRootId* ccl_id_;
int rank_;
mutable std::mutex mutex_;
};
phi::ccl::CCLReduceOp ToCustomCCLRedType(ReduceOp reduction);
std::string SerializeCustomCCLUniqueId(const phi::ccl::CCLRootId& ccl_id);
phi::ccl::CCLReduceOp ToXCCLRedType(ReduceOp reduction);
} // namespace distributed
} // namespace paddle
@@ -15,62 +15,57 @@
#pragma once
#include <chrono>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/collective/custom_ccl_tools.h"
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_with_stream.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/store/store.h"
#include "paddle/phi/core/distributed/xccl_comm_context.h"
namespace paddle {
namespace distributed {
using Place = paddle::platform::Place;
using CustomDeviceContext = paddle::platform::CustomDeviceContext;
class ProcessGroupCustom : public ProcessGroupWithStream {
using Place = phi::Place;
class ProcessGroupCustom final : public ProcessGroupWithStream {
public:
class CustomTask : public ProcessGroup::Task,
public std::enable_shared_from_this<CustomTask> {
class XCCLTask final : public ProcessGroupWithStream::TaskStream,
public std::enable_shared_from_this<XCCLTask> {
public:
CustomTask(const std::vector<Place>& places,
int rank,
CommType CommType,
const std::vector<phi::DenseTensor>& inputs);
XCCLTask(const Place& place,
int rank,
CommType comm_type,
bool sync_op,
bool use_calc_stream);
virtual ~XCCLTask();
bool IsCompleted() override;
void SynchronizeStreams();
bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) override;
void Synchronize() override;
void UpdateWaitChain(const phi::DeviceContext& ctx) override;
void SetOutputs(std::vector<phi::DenseTensor>& outputs); // NOLINT
virtual ~CustomTask();
std::vector<CustomEventManager> control_events_;
std::vector<phi::DenseTensor> barrierTensors_;
bool IsBlockCPUInWait() const { return block_cpu_in_wait_; }
void SetBlockCPUInWait() { block_cpu_in_wait_ = true; }
protected:
std::vector<Place> places_;
std::vector<std::shared_ptr<CustomCCLCommManager>> cclComms_;
std::shared_ptr<std::vector<phi::DenseTensor>> outputs_;
// TODO(sunyilun): methods below will be removed later
XCCLTask(const std::vector<Place>& places,
int rank,
CommType CommType,
const std::vector<phi::DenseTensor>& inputs);
private:
const std::string device_type_;
bool block_cpu_in_wait_{false};
phi::event::Event comm_event_; // event on comm stream
Place task_place_;
};
ProcessGroupCustom(const std::shared_ptr<phi::distributed::Store>& store,
const std::string& device_type,
int rank,
int size,
int gid);
public:
static std::shared_ptr<ProcessGroupCustom> CreateProcessGroupCustom(
const std::shared_ptr<phi::distributed::Store>& store,
const std::string& device_type,
@@ -78,19 +73,18 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
int size,
int gid);
std::string GetBackendName() const override { return "XCCL_" + device_type_; }
ProcessGroupCustom(const std::shared_ptr<phi::distributed::Store>& store,
const std::string& device_type,
int rank,
int size,
int gid);
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;
std::string GetBackendName() const override { return "XCCL"; }
phi::DeviceContext* GetDeviceContext(const Place& place) const override;
phi::ccl::CCLComm CustomCCLComm(const Place& place) const;
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
phi::DeviceContext* GetDeviceContext(const Place& place,
bool use_calc_stream) const override;
std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
@@ -100,11 +94,6 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const AllreduceOptions& = AllreduceOptions()) override;
std::shared_ptr<ProcessGroup::Task> AllReduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
@@ -112,10 +101,16 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& = BroadcastOptions()) override;
std::shared_ptr<ProcessGroup::Task> AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
@@ -124,49 +119,12 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
int dst_rank,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors, int dst_rank) override;
std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank) override;
std::shared_ptr<ProcessGroup::Task> Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts) override;
std::shared_ptr<ProcessGroup::Task> AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
@@ -180,11 +138,6 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts) override;
std::shared_ptr<ProcessGroup::Task> Gather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const GatherOptions& opts,
@@ -198,52 +151,124 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
bool sync_op,
bool use_calc_stream) override;
protected:
virtual std::shared_ptr<ProcessGroupCustom::CustomTask> CreateTask(
std::vector<Place> places,
int rank,
CommType opType,
const std::vector<phi::DenseTensor>& inputs);
std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<phi::distributed::Store> store_;
std::shared_ptr<CustomCCLCommManager> custom_comm_;
std::mutex mutex_;
std::unordered_map<std::string,
std::vector<std::shared_ptr<CustomCCLCommManager>>>
places_to_customcomm_;
std::unordered_map<std::string, std::vector<CustomEventManager>>
places_to_events_;
std::unordered_map<std::string,
std::vector<std::unique_ptr<CustomDeviceContext>>>
places_to_ctx_;
std::set<int> used_place_ids_;
std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
int dst_rank,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream) override;
static void GroupStart(const std::string& dev_type);
static void GroupEnd(const std::string& dev_type);
phi::ccl::CCLComm XCCLComm(const Place& place) const;
// TODO(liyurui): This API will be moved later
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const AllreduceOptions& = AllreduceOptions()) override;
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const BroadcastOptions& = BroadcastOptions()) override;
std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors, int dst_rank) override;
std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank) override;
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts) override;
std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts) override;
private:
void BcastCustomId(std::vector<phi::ccl::CCLRootId>& ccl_ids, // NOLINT
int root,
int server_fd);
std::shared_ptr<ProcessGroupCustom::XCCLTask> CreateTask(
const Place& place,
int rank,
CommType op_type,
bool sync_op,
bool use_calc_stream);
void BroadcastUniqueCustomID(
std::vector<phi::ccl::CCLRootId>& custom_ccl_ids); // NOLINT
void BroadcastUniqueXCCLID(phi::ccl::CCLRootId* nccl_id);
void CreateXCCLEnvCache(const Place& place, const std::string& place_key);
void SyncCalcStream(const Place& place);
std::shared_ptr<ProcessGroup::Task> RunFnInXCCLEnv(
std::function<void(const phi::stream::Stream&)> fn,
const phi::DenseTensor& tensor,
CommType comm_type,
bool sync_op,
bool use_calc_stream);
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroupCustom::XCCLTask> CreateTask(
std::vector<Place> places,
int rank,
CommType op_type,
const std::vector<phi::DenseTensor>& inputs);
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> Collective(
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
Fn fn,
CommType op_type,
bool sync_op,
bool use_calc_stream);
CommType op_type);
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> Collective(Fn fn,
CommType op_type,
bool sync_op,
bool use_calc_stream);
void CreateCustomManagerCache(const std::string& places_key,
const std::vector<Place>& places);
const std::string device_type_;
std::shared_ptr<ProcessGroup::Task> PointToPoint(
std::vector<phi::DenseTensor>& tensors, // NOLINT
Fn fn,
int dst_rank,
CommType op_type);
void CreateXCCLManagerCache(const std::string& places_key,
const std::vector<Place>& places);
phi::distributed::XCCLCommContext* GetCommContext();
private:
std::shared_ptr<phi::distributed::Store> store_;
std::string device_type_;
std::unordered_map<std::string, std::unique_ptr<phi::event::Event>>
place_to_calc_event_; // event on calc stream
std::unordered_map<std::string, phi::CustomContext*> place_to_calc_ctx_;
std::unordered_map<std::string, std::unique_ptr<phi::CustomContext>>
place_to_comm_ctx_;
// TODO(sunyilun): attrs below will be removed later
std::mutex mutex_;
std::unordered_map<std::string, std::vector<phi::CustomContext*>>
places_to_ctx_;
};
} // namespace distributed
} // namespace paddle
@@ -53,8 +53,8 @@ ccl::CCLComm GetCCLComm(const Place& place, int global_gid) {
#endif
} else if (place.GetType() == phi::AllocationType::CUSTOM) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
return static_cast<paddle::distributed::ProcessGroupCustom*>(pg)
->CustomCCLComm(place);
return static_cast<paddle::distributed::ProcessGroupCustom*>(pg)->XCCLComm(
place);
#else
return nullptr;
#endif
@@ -1574,6 +1574,7 @@ void AllocatorFacade::SetDefaultStream(const platform::CustomPlace& place,
m_->SetDefaultStream(place, stream);
}
}
#endif
UNUSED static std::shared_ptr<NaiveBestFitAllocator> unused_obj =
@@ -101,6 +101,7 @@ class AllocatorFacade {
phi::stream::stream_t stream);
void RecordStream(std::shared_ptr<phi::Allocation> allocation,
phi::stream::stream_t stream);
void SetDefaultStream(const platform::CustomPlace& place,
phi::stream::stream_t stream);
#endif
@@ -71,6 +71,14 @@ gpuStream_t GetStream(const std::shared_ptr<Allocation>& allocation) {
return allocation::AllocatorFacade::Instance().GetStream(allocation);
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
void RecordStream(std::shared_ptr<Allocation> allocation,
phi::stream::stream_t stream) {
return allocation::AllocatorFacade::Instance().RecordStream(allocation,
stream);
}
#endif
} // namespace memory
} // namespace paddle
@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/stream.h"
@@ -55,5 +56,9 @@ void RecordStream(std::shared_ptr<Allocation> allocation, gpuStream_t stream);
gpuStream_t GetStream(const std::shared_ptr<Allocation>& allocation);
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
void RecordStream(std::shared_ptr<Allocation> allocation,
phi::stream::stream_t stream);
#endif
} // namespace memory
} // namespace paddle
@@ -56,6 +56,12 @@ void BindCommContextManager(py::module *m) {
"create_gloo_comm_context",
&phi::distributed::CommContextManager::CreateGlooCommContext,
py::call_guard<py::gil_scoped_release>())
#endif
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
.def_static(
"create_xccl_comm_context",
&phi::distributed::CommContextManager::CreateXCCLCommContext,
py::call_guard<py::gil_scoped_release>())
#endif
.def("set_store", &phi::distributed::CommContextManager::SetStore);
}
@@ -21,6 +21,8 @@
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/common/reduce_type.h"
namespace phi {
namespace ccl {
typedef void* CCLComm;
@@ -38,6 +40,32 @@ enum CCLDataType {
CCL_DATA_TYPE_UINT8
};
inline CCLReduceOp ToXCCLReduceOp(int reduce_type) {
phi::ccl::CCLReduceOp red_type = phi::ccl::CCLReduceOp::SUM;
switch (static_cast<phi::ReduceType>(reduce_type)) {
case phi::ReduceType::kRedSum:
red_type = phi::ccl::CCLReduceOp::SUM;
break;
case phi::ReduceType::kRedMax:
red_type = phi::ccl::CCLReduceOp::MAX;
break;
case phi::ReduceType::kRedMin:
red_type = phi::ccl::CCLReduceOp::MIN;
break;
case phi::ReduceType::kRedProd:
red_type = phi::ccl::CCLReduceOp::PRODUCT;
break;
case phi::ReduceType::kRedAvg:
red_type = phi::ccl::CCLReduceOp::AVG;
break;
default:
PADDLE_THROW(
errors::Unavailable("Unsuppored reduce type. Reduce type must be one "
"of SUM, MAX, MIN, PRODUCT and AVG."));
}
return red_type;
}
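A one-line usage illustration, assuming the phi::ReduceType values referenced in the switch above: a kernel that carries the reduce type as a plain int attribute recovers the CCL op like so.
// auto op = phi::ccl::ToXCCLReduceOp(static_cast<int>(phi::ReduceType::kRedMax));
// op == phi::ccl::CCLReduceOp::MAX; any unmapped value throws Unavailable.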
inline CCLDataType ToCCLDataType(phi::DataType type) {
if (type == phi::DataType::FLOAT64) {
return CCL_DATA_TYPE_FP64;
@@ -79,5 +107,14 @@ inline phi::DataType ToPhiDataType(CCLDataType type) {
}
}
inline std::string SerializeXCCLUniqueId(const phi::ccl::CCLRootId& ccl_id) {
const uint8_t* bytes = ccl_id.data();
std::ostringstream oss;
for (size_t i = 0; i < ccl_id.size(); ++i) {
oss << std::hex << static_cast<int>(bytes[i]);
}
return oss.str();
}
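A small illustration of the serializer's output; this is an observation about the loop above, not a documented contract:
// SerializeXCCLUniqueId for bytes {0x0A, 0x01} yields "a1" -- the hex
// digits are not zero-padded, so the string suits the VLOG in
// CreateXCCLCommContext below but cannot be parsed back into bytes.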
} // namespace ccl
} // namespace phi
@@ -754,13 +754,15 @@ class CustomDevice : public DeviceInterface {
}
void CCLGroupStart() override {
CHECK_PTR(pimpl_->xccl_group_start);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_group_start());
if (pimpl_->xccl_group_start) {
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_group_start());
}
}
void CCLGroupEnd() override {
CHECK_PTR(pimpl_->xccl_group_end);
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_group_end());
if (pimpl_->xccl_group_end) {
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_group_end());
}
}
void CCLSend(void* send_buf,
@@ -12,4 +12,8 @@ if(WITH_GLOO)
list(APPEND DISTRIBUTED_COMMON_SRCS gloo_utils.cc gloo_comm_context.cc)
endif()
if(WITH_CUSTOM_DEVICE)
list(APPEND DISTRIBUTED_COMMON_SRCS xccl_comm_context.cc)
endif()
collect_srcs(core_srcs SRCS ${DISTRIBUTED_COMMON_SRCS})
@@ -53,6 +53,18 @@ DenseTensor ReshardAllGatherFunctor(DeviceContext* dev_ctx,
}));
return out;
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (phi::CustomContext::classof(dev_ctx)) {
PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(
input.dtype(), "AllGather", ([&] {
AllGather<data_t>(static_cast<const CustomContext&>(*dev_ctx),
input,
world_size,
&out);
}));
return out;
}
#endif
PADDLE_THROW(phi::errors::Unimplemented(
"The all_gather in reshard only supported on CPU and GPU for now."));
@@ -170,6 +170,15 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
#else
PADDLE_THROW(phi::errors::Unimplemented(
"Cannot use gloo on CPU, please turn PADDLE_WITH_GLOO flag on."));
#endif
} else if (phi::CustomContext::classof(&dev_ctx)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
CommContextManager::CreateXCCLCommContext(
store,
unique_comm_key,
dev_ctx.GetPlace().GetDeviceType(),
rank,
world_size);
#endif
} else {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
@@ -24,8 +24,8 @@ class CommContext {
CommContext(int rank, int size) : rank_(rank), size_(size) {}
virtual ~CommContext() = default;
int GetRank() { return rank_; }
int GetSize() { return size_; }
int GetRank() const { return rank_; }
int GetSize() const { return size_; }
protected:
int rank_;
@@ -32,6 +32,9 @@
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#endif
namespace phi {
namespace distributed {
@@ -91,6 +94,35 @@ void CommContextManager::CreateGlooCommContext(
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
void CommContextManager::CreateXCCLCommContext(
const std::shared_ptr<Store>& store,
const std::string& unique_comm_key,
const std::string& device_type,
int rank,
int size) {
phi::ccl::CCLRootId xccl_root_id;
if (rank == 0) {
phi::DeviceManager::CCLGetUniqueId(device_type, &xccl_root_id);
}
std::string unique_key = "XCCLCommContext/" + unique_comm_key;
if (rank == 0) {
store->set(unique_key, xccl_root_id);
} else {
xccl_root_id = store->get(unique_key);
}
VLOG(3) << "init xccl rank: " << rank << ", nranks: " << size
<< ", unique_comm_key: " << unique_comm_key << ", xccl uniqueid: "
<< phi::ccl::SerializeXCCLUniqueId(xccl_root_id);
auto xccl_comm_context =
std::make_unique<XCCLCommContext>(device_type, rank, size, xccl_root_id);
auto& comm_context_manager = CommContextManager::GetInstance();
comm_context_manager.SetStore(store);
comm_context_manager.Emplace(unique_comm_key, std::move(xccl_comm_context));
}
#endif
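The store is the rendezvous point here: rank 0 asks the device plugin for a root id and publishes it under the prefixed key, while every other rank blocks in store->get() until the key appears. A hedged two-rank timeline, store construction elided:
// rank 0: CCLGetUniqueId(dev_type, &id); store->set("XCCLCommContext/0", id);
// rank 1: id = store->get("XCCLCommContext/0");  // blocks until rank 0's set
// both:   XCCLCommContext ctor -> DeviceManager::CCLCommInitRank(...)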
CommContext* CommContextManager::Emplace(
const std::string& unique_comm_key,
std::unique_ptr<CommContext> comm_context) {
@@ -62,6 +62,14 @@ class CommContextManager {
int size);
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
static void CreateXCCLCommContext(const std::shared_ptr<Store>& store,
const std::string& unique_comm_key,
const std::string& device_type,
int rank,
int size);
#endif
private:
DISABLE_COPY_AND_ASSIGN(CommContextManager);
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#include "glog/logging.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/check/static_check.h"
#include "paddle/phi/core/distributed/utils.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
namespace distributed {
XCCLCommContext::XCCLCommContext(const std::string& device_type,
int rank,
int size,
const ccl::CCLRootId& xccl_id)
: CommContext(rank, size) {
device_type_ = device_type;
phi::DeviceManager::CCLCommInitRank(device_type,
size_,
const_cast<ccl::CCLRootId*>(&xccl_id),
rank,
&xccl_comm_);
}
void XCCLCommContext::Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root,
const phi::stream::Stream& stream) const {
CommStaticCheck::SameShape(*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
if (rank_ == root) {
phi::DeviceManager::CCLBroadcast(device_type_,
const_cast<void*>(in_tensor.data()),
in_tensor.numel(),
phi::ccl::ToCCLDataType(in_tensor.dtype()),
root,
xccl_comm_,
stream);
} else {
phi::DeviceManager::CCLBroadcast(device_type_,
out_tensor->data(),
out_tensor->numel(),
phi::ccl::ToCCLDataType(in_tensor.dtype()),
root,
xccl_comm_,
stream);
}
}
void XCCLCommContext::AllGather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const phi::stream::Stream& stream) const {
phi::distributed::CommStaticCheck::GatherLikeShape(
*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
phi::DeviceManager::CCLAllGather(device_type_,
const_cast<void*>(in_tensor.data()),
out_tensor->data(),
in_tensor.numel(),
phi::ccl::ToCCLDataType(in_tensor.dtype()),
xccl_comm_,
stream);
}
void XCCLCommContext::ReduceScatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
phi::ccl::CCLReduceOp reduce_type,
const phi::stream::Stream& stream) const {
phi::distributed::CommStaticCheck::ScatterLikeShape(
*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
phi::DeviceManager::CCLReduceScatter(
device_type_,
const_cast<void*>(in_tensor.data()),
out_tensor->data(),
out_tensor->numel(),
phi::ccl::ToCCLDataType(in_tensor.type()),
reduce_type,
xccl_comm_,
stream);
}
void XCCLCommContext::Send(const phi::DenseTensor& in_tensor,
const int64_t& count,
const int& peer,
const phi::stream::Stream& stream) const {
phi::distributed::CommStaticCheck::CheckShape(
in_tensor, rank_, size_, phi::AllocationType::CUSTOM);
phi::DeviceManager::CCLSend(device_type_,
const_cast<void*>(in_tensor.data()),
count,
phi::ccl::ToCCLDataType(in_tensor.type()),
peer,
xccl_comm_,
stream);
VLOG(3) << "rank " << GetRank() << " send " << phi::product(in_tensor.dims())
<< " to " << peer;
}
void XCCLCommContext::Recv(phi::DenseTensor* out_tensor,
const int64_t& count,
const int& peer,
const phi::stream::Stream& stream) const {
phi::distributed::CommStaticCheck::CheckShape(
*out_tensor, rank_, size_, phi::AllocationType::CUSTOM);
phi::DeviceManager::CCLRecv(device_type_,
out_tensor->data(),
count,
phi::ccl::ToCCLDataType(out_tensor->type()),
peer,
xccl_comm_,
stream);
VLOG(3) << "rank " << GetRank() << " recv "
<< phi::product(out_tensor->dims()) << " from " << peer;
}
void XCCLCommContext::AllReduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
phi::ccl::CCLReduceOp reduce_type,
const phi::stream::Stream& stream) const {
phi::distributed::CommStaticCheck::SameShape(*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
phi::DeviceManager::CCLAllReduce(device_type_,
const_cast<void*>(in_tensor.data()),
out_tensor->data(),
in_tensor.numel(),
phi::ccl::ToCCLDataType(in_tensor.type()),
reduce_type,
xccl_comm_,
stream);
}
void XCCLCommContext::Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
phi::ccl::CCLReduceOp reduce_type,
int root,
const phi::stream::Stream& stream) const {
phi::distributed::CommStaticCheck::SameShape(*out_tensor,
in_tensor,
/*dst_rank*/ root,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
phi::DeviceManager::CCLReduce(device_type_,
const_cast<void*>(in_tensor.data()),
out_tensor->data(),
in_tensor.numel(),
phi::ccl::ToCCLDataType(in_tensor.type()),
reduce_type,
root,
xccl_comm_,
stream);
}
void XCCLCommContext::GroupStart() const {
phi::DeviceManager::CCLGroupStart(device_type_);
}
void XCCLCommContext::GroupEnd() const {
phi::DeviceManager::CCLGroupEnd(device_type_);
}
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/distributed/comm_context.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/backends/device_manager.h"
namespace phi {
class DenseTensor;
namespace distributed {
class XCCLCommContext final : public CommContext {
public:
XCCLCommContext(const std::string& device_type,
int rank,
int size,
const ccl::CCLRootId& xccl_id);
ccl::CCLComm GetXcclComm() const { return xccl_comm_; }
const std::string& GetDeviceType() const { return device_type_; }
void Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root,
const phi::stream::Stream& stream) const;
void Send(const phi::DenseTensor& in_tensor,
const int64_t& count,
const int& peer,
const phi::stream::Stream& stream) const;
void Recv(phi::DenseTensor* out_tensor,
const int64_t& count,
const int& peer,
const phi::stream::Stream& stream) const;
void ReduceScatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
phi::ccl::CCLReduceOp reduce_type,
const phi::stream::Stream& stream) const;
void AllGather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const phi::stream::Stream& stream) const;
void AllReduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
phi::ccl::CCLReduceOp reduce_type,
const phi::stream::Stream& stream) const;
void Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
phi::ccl::CCLReduceOp reduce_type,
int root,
const phi::stream::Stream& stream) const;
void GroupStart() const;
void GroupEnd() const;
private:
DISABLE_COPY_AND_ASSIGN(XCCLCommContext);
std::string device_type_;
ccl::CCLComm xccl_comm_;
};
} // namespace distributed
} // namespace phi
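A hedged point-to-point sketch against this interface; comm_ctx, the tensor t, and stream are assumed to be already initialized on the same custom device:
// if (comm_ctx->GetRank() == 0) {
//   comm_ctx->Send(t, t.numel(), /*peer=*/1, stream);
// } else if (comm_ctx->GetRank() == 1) {
//   comm_ctx->Recv(&t, t.numel(), /*peer=*/0, stream);
// }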
@@ -20,6 +20,9 @@
#if defined(PADDLE_WITH_GLOO)
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#endif
namespace phi {
@@ -49,6 +52,28 @@ void AllGatherKernel(const Context& dev_ctx,
#endif
}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
template <typename T>
void AllGatherKernel(const phi::CustomContext& dev_ctx,
const DenseTensor& x,
int nranks,
DenseTensor* out) {
auto out_dims = x.dims();
out_dims[0] *= nranks;
out->Resize(out_dims);
// Allocate after the resize so the buffer holds all nranks shards.
dev_ctx.template Alloc<T>(out);
auto comm_ctx =
static_cast<distributed::XCCLCommContext*>(dev_ctx.GetCommContext());
PADDLE_ENFORCE_EQ(
nranks,
comm_ctx->GetSize(),
errors::InvalidArgument(
"nranks: %s should equal to %s", nranks, comm_ctx->GetSize()));
comm_ctx->AllGather(out, x, *dev_ctx.GetStream());
}
#endif
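For concreteness, the resize arithmetic in the kernel above with hypothetical numbers: nranks = 2 and an input of shape [4, 8] produce an output of shape [8, 8], i.e. the gathered shards are stacked along dim 0.
// x.dims() = [4, 8], nranks = 2
// out_dims[0] *= nranks  ->  out.dims() = [8, 8], out.numel() = 2 * x.numel()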
} // namespace phi
PD_REGISTER_KERNEL(all_gather,
@@ -64,3 +89,19 @@ PD_REGISTER_KERNEL(all_gather,
int16_t,
int64_t,
phi::dtype::float16) {}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(all_gather,
Custom,
ALL_LAYOUT,
phi::AllGatherKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int16_t,
int64_t,
phi::dtype::float16) {}
#endif
@@ -20,6 +20,9 @@
#if defined(PADDLE_WITH_GLOO)
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#endif
namespace phi {
@@ -47,6 +50,27 @@ void AllReduceKernel(const Context& dev_ctx,
#endif
}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
template <typename T>
void AllReduceKernel(const phi::CustomContext& dev_ctx,
const DenseTensor& x,
int reduce_type,
DenseTensor* out) {
out->Resize(x.dims());
dev_ctx.template Alloc<T>(out);
auto comm_ctx =
static_cast<distributed::XCCLCommContext*>(dev_ctx.GetCommContext());
PADDLE_ENFORCE_NE(
comm_ctx,
nullptr,
errors::Unavailable("XCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
comm_ctx->AllReduce(
out, x, phi::ccl::ToXCCLReduceOp(reduce_type), *dev_ctx.GetStream());
}
#endif
} // namespace phi
PD_REGISTER_KERNEL(all_reduce,
@@ -61,3 +85,18 @@ PD_REGISTER_KERNEL(all_reduce,
uint8_t,
int64_t,
phi::dtype::float16) {}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(all_reduce,
Custom,
ALL_LAYOUT,
phi::AllReduceKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
#endif
@@ -16,6 +16,9 @@
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#endif
namespace phi {
@@ -26,6 +29,42 @@ void AllToAllKernel(const Context& dev_ctx UNUSED,
PADDLE_THROW(
errors::Unimplemented("Unimplemented cpu kernel for all_to_all."));
}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
template <typename T>
void AllToAllKernel(const phi::CustomContext& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
out->Resize(x.dims());
dev_ctx.template Alloc<T>(out);
auto comm_ctx =
static_cast<distributed::XCCLCommContext*>(dev_ctx.GetCommContext());
int nranks = comm_ctx->GetSize();
int rank = comm_ctx->GetRank();
int send_numel = x.numel() / nranks;
std::vector<void*> sendbuf, recvbuf;
// One chunk per peer: nranks sizes and nranks dtypes (note the
// (count, value) argument order of the std::vector constructor).
std::vector<size_t> sendsize(nranks, send_numel);
std::vector<phi::ccl::CCLDataType> sendtype(
nranks, phi::ccl::ToCCLDataType(x.dtype()));
for (auto i = 0; i < nranks; ++i) {
sendbuf.push_back(const_cast<T*>(x.data<T>()) + i * send_numel);
recvbuf.push_back(out->data<T>() + i * send_numel);
}
phi::DeviceManager::CCLAllToAll(dev_ctx.GetPlace().GetDeviceType(),
const_cast<const void**>(sendbuf.data()),
sendsize.data(),
sendtype.data(),
recvbuf.data(),
sendsize.data(),
sendtype.data(),
rank,
nranks,
comm_ctx->GetXcclComm(),
*dev_ctx.GetStream());
}
#endif
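The chunking above, spelled out for a hypothetical case: with nranks = 2 and x.numel() = 8, send_numel = 4, so each rank sends elements [0, 4) to peer 0 and [4, 8) to peer 1, receiving one 4-element chunk from each peer at the matching offsets of out.
// sendbuf = { x+0, x+4 }   recvbuf = { out+0, out+4 }
// sendsize = { 4, 4 }      sendtype = { dtype(x), dtype(x) }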
} // namespace phi
@@ -41,3 +80,17 @@ PD_REGISTER_KERNEL(all_to_all,
uint8_t,
int64_t,
phi::dtype::float16) {}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(all_to_all,
Custom,
ALL_LAYOUT,
phi::AllToAllKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
#endif
@@ -20,6 +20,9 @@
#if defined(PADDLE_WITH_GLOO)
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#endif
namespace phi {
@@ -51,6 +54,35 @@ void ReduceKernel(const Context& dev_ctx,
#endif
}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
template <typename T>
void ReduceKernel(const phi::CustomContext& dev_ctx,
const DenseTensor& x,
int root,
int reduce_type,
DenseTensor* out) {
PADDLE_ENFORCE_GT(
x.numel(),
0,
phi::errors::InvalidArgument("Tensor need be reduced must not empyt."));
out->Resize(x.dims());
dev_ctx.template Alloc<T>(out);
auto comm_ctx =
static_cast<distributed::XCCLCommContext*>(dev_ctx.GetCommContext());
PADDLE_ENFORCE_NE(
comm_ctx,
nullptr,
errors::Unavailable("XCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
comm_ctx->Reduce(out,
x,
phi::ccl::ToXCCLReduceOp(reduce_type),
root,
*dev_ctx.GetStream());
}
#endif
} // namespace phi
PD_REGISTER_KERNEL(reduce,
@@ -65,3 +97,18 @@ PD_REGISTER_KERNEL(reduce,
uint8_t,
int64_t,
phi::dtype::float16) {}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_KERNEL(reduce,
Custom,
ALL_LAYOUT,
phi::ReduceKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
#endif
@@ -334,3 +334,9 @@ def _init_parallel_env(backend):
core.CommContextManager.create_nccl_comm_context(
store, "0", rank, world_size
)
elif backend == "xccl":
dev_type = global_env.device_type
paddle.device.set_device(f"{dev_type}:{dev_id}")
core.CommContextManager.create_xccl_comm_context(
store, "0", rank, world_size, dev_type
)