Unverified commit 12e9aaa5, authored by ronnywang, committed by GitHub

[CustomDevice] fix process_group_custom api (#50718)

* [CustomDevice] fix process_group_custom api

* update

* update

* update

* update
Parent bf50784c
@@ -61,5 +61,7 @@ if(WITH_CUSTOM_DEVICE)
       place
       enforce
       collective_helper
-      device_context)
+      device_context
+      comm_static_check
+      dense_tensor)
 endif()
@@ -16,11 +16,13 @@
 #include "paddle/fluid/distributed/collective/common.h"
 #include "paddle/fluid/distributed/collective/custom_ccl_tools.h"
+#include "paddle/fluid/distributed/collective/utils.h"
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/phi/api/lib/utils/allocator.h"
 #include "paddle/phi/common/place.h"
+#include "paddle/phi/core/distributed/check/static_check.h"
 
 DECLARE_bool(xccl_blocking_wait);
@@ -234,10 +236,21 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
     const phi::DenseTensor& in_tensor,
     int64_t offset,
     int64_t numel,
-    bool sync_op  // for compatibility, no use now
-) {
-  std::vector<phi::DenseTensor> in_wrapper{in_tensor};
+    bool sync_op,  // for compatibility, no use now
+    bool use_calc_stream) {
+  // numel > 0 indicates the tensor need to be sliced
+  const phi::DenseTensor& in_tensor_maybe_partial =
+      numel > 0
+          ? paddle::distributed::GetPartialTensor(in_tensor, offset, numel)
+          : in_tensor;
+  phi::distributed::CommStaticCheck::GatherLikeShape(*out_tensor,
+                                                     in_tensor_maybe_partial,
+                                                     /*dst_rank*/ rank_,
+                                                     /*cur_rank*/ rank_,
+                                                     size_);
+  std::vector<phi::DenseTensor> in_wrapper{in_tensor_maybe_partial};
   std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
   return Collective(
       in_wrapper,
       out_wrapper,
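A note on the slicing logic introduced above: when `numel > 0` the input is reduced to the contiguous range `[offset, offset + numel)` before the all-gather runs; otherwise the whole tensor is used. Below is a minimal, self-contained sketch of that rule over a flat buffer; `FlatTensor` and `MaybePartial` are hypothetical stand-ins, not Paddle's `DenseTensor` or `GetPartialTensor`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a contiguous tensor buffer.
struct FlatTensor {
  const float* data;
  int64_t numel;
};

// Mirrors the rule in the diff: numel > 0 means "gather only
// [offset, offset + numel)", otherwise gather the whole tensor.
FlatTensor MaybePartial(const FlatTensor& in, int64_t offset, int64_t numel) {
  if (numel > 0) {
    assert(offset >= 0 && offset + numel <= in.numel);
    return FlatTensor{in.data + offset, numel};
  }
  return in;  // numel <= 0: use the full tensor, offset is ignored
}

int main() {
  std::vector<float> buf(8, 1.0f);
  FlatTensor whole{buf.data(), static_cast<int64_t>(buf.size())};

  FlatTensor part = MaybePartial(whole, /*offset=*/2, /*numel=*/4);
  assert(part.numel == 4 && part.data == buf.data() + 2);

  FlatTensor full = MaybePartial(whole, /*offset=*/0, /*numel=*/-1);
  assert(full.numel == 8);
  return 0;
}
```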
@@ -247,9 +260,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
           const phi::stream::Stream& stream) {
         return phi::DeviceManager::CCLAllGather(
             device_type_,
-            XcclGetPointerByOffset(input.data(), offset, input.dtype()),
+            input.data(),
             output.data(),
-            numel,
+            input.numel(),
             phi::ccl::ToCCLDataType(input.dtype()),
             comm,
             stream);
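Before this hunk, the lambda offset the raw device pointer itself via `XcclGetPointerByOffset(input.data(), offset, input.dtype())` and passed the caller-supplied `numel`; now the tensor handed to the lambda is already the sliced `in_tensor_maybe_partial`, so `input.data()` and `input.numel()` can be used directly. For readers unfamiliar with such helpers, here is a hedged sketch of the element-to-byte offset arithmetic a `*GetPointerByOffset` routine typically performs; the `DType`/`SizeOf` names are assumptions for illustration, not Paddle's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical dtype tag standing in for a framework data-type enum.
enum class DType { kFloat32, kFloat64, kInt32 };

inline std::size_t SizeOf(DType t) {
  switch (t) {
    case DType::kFloat32: return 4;
    case DType::kFloat64: return 8;
    case DType::kInt32:   return 4;
  }
  return 0;
}

// Advance a raw pointer by an element offset: the offset is counted in
// elements, so it must be scaled by the element size before adding.
inline void* PointerByOffset(void* base, int64_t offset, DType dtype) {
  return static_cast<char*>(base) +
         static_cast<std::size_t>(offset) * SizeOf(dtype);
}

int main() {
  float buf[8] = {};
  void* p = PointerByOffset(buf, 3, DType::kFloat32);
  assert(p == static_cast<void*>(buf + 3));
  return 0;
}
```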
@@ -257,70 +270,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
       CommType::ALLGATHER);
 }
 
-std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
-    phi::DenseTensor* out_tensor,
-    const phi::DenseTensor& in_tensor,
-    const AllreduceOptions& opts,
-    bool sync_op  // for compatibility, no use now
-) {
-  std::vector<phi::DenseTensor> in_wrapper{in_tensor};
-  std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
-  return AllReduce(in_wrapper, out_wrapper, opts);
-}
-
-std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
-    phi::DenseTensor* out_tensor,
-    const phi::DenseTensor& in_tensor,
-    const BroadcastOptions& opts,
-    bool sync_op  // for compatibility, no use now
-) {
-  std::vector<phi::DenseTensor> in_wrapper{in_tensor};
-  std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
-  return Broadcast(in_wrapper, out_wrapper, opts);
-}
-
-std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
-    const BarrierOptions& opts) {
-  // Only support single card single process
-  PADDLE_ENFORCE_GE(opts.device_id,
-                    0,
-                    platform::errors::PreconditionNotMet(
-                        "The barrier device id must greater or equal than 0."));
-  platform::CustomPlace place(device_type_, opts.device_id);
-  auto allocator = std::unique_ptr<phi::Allocator>(
-      new paddle::experimental::DefaultAllocator(place));
-  phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1});
-  phi::DenseTensor barrier_tensor{allocator.get(), meta};
-
-  auto task = ProcessGroupCustom::AllReduce(&barrier_tensor,
-                                            barrier_tensor,
-                                            {},
-                                            /*sync_op*/ true);
-  auto xccl_task = dynamic_cast<ProcessGroupCustom::CustomTask*>(task.get());
-  xccl_task->barrierTensors_ = {barrier_tensor};
-  return task;
-}
-
-phi::DeviceContext* ProcessGroupCustom::GetDeviceContext(
-    const Place& place) const {
-  const std::string key = GetKeyFromPlace(place);
-  const auto& iter = places_to_ctx_.find(key);
-  PADDLE_ENFORCE_NE(
-      iter,
-      places_to_ctx_.end(),
-      platform::errors::NotFound(
-          "Cannot find the device context in this process group."));
-  return iter->second[0].get();
-}
-
-phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm(const Place& place) const {
-  std::vector<Place> places = {place};
-  const auto& iter = places_to_customcomm_.find(GetKeyFromPlaces(places));
-  PADDLE_ENFORCE_NE(iter,
-                    places_to_customcomm_.end(),
-                    platform::errors::InvalidArgument(
-                        "Cannot find nccl comm in process group."));
-  return iter->second[0]->GetCustomCCLComm();
-}
+std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
+    phi::DenseTensor* out_tensor,
+    const phi::DenseTensor& in_tensor,
+    int64_t offset,
+    int64_t numel,
+    bool sync_op) {
+  return AllGather(out_tensor, in_tensor, offset, numel, sync_op);
+}
 
 // TODO(sunyilun): methods below will be removed later
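The functions deleted in the hunk above (`AllReduce`, `Broadcast`, `Barrier`, `GetDeviceContext`, `CustomCCLComm`) are not removed for good; they reappear in the later hunks with `use_calc_stream` variants. What remains here is a five-argument `AllGather` convenience overload. As a general pattern, such thin overloads usually delegate to the widest signature with an explicit default; the sketch below shows that pattern with hypothetical names, and the delegation target plus the `use_calc_stream = false` default are assumptions for illustration, not taken from this diff.

```cpp
#include <cstdint>
#include <memory>

// Hypothetical task handle standing in for ProcessGroup::Task.
struct Task {};

class DemoProcessGroup {
 public:
  // Full-parameter entry point: every other overload funnels into this one.
  std::shared_ptr<Task> AllGather(int64_t offset, int64_t numel, bool sync_op,
                                  bool use_calc_stream) {
    (void)offset; (void)numel; (void)sync_op; (void)use_calc_stream;
    return std::make_shared<Task>();  // real code would launch the collective
  }

  // Convenience overload: same arguments minus use_calc_stream, delegating
  // with an explicit default so there is exactly one implementation to fix.
  std::shared_ptr<Task> AllGather(int64_t offset, int64_t numel, bool sync_op) {
    return AllGather(offset, numel, sync_op, /*use_calc_stream=*/false);
  }
};

int main() {
  DemoProcessGroup pg;
  auto task = pg.AllGather(/*offset=*/0, /*numel=*/-1, /*sync_op=*/true);
  return task ? 0 : 1;
}
```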
@@ -356,6 +312,28 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
       CommType::ALLGATHER);
 }
 
+std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
+    phi::DenseTensor* out_tensor,
+    const phi::DenseTensor& in_tensor,
+    const AllreduceOptions& opts,
+    bool sync_op,  // for compatibility, no use now
+    bool use_calc_stream) {
+  std::vector<phi::DenseTensor> in_wrapper{in_tensor};
+  std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
+  return AllReduce(in_wrapper, out_wrapper, opts);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
+    phi::DenseTensor* out_tensor,
+    const phi::DenseTensor& in_tensor,
+    const AllreduceOptions& opts,
+    bool sync_op  // for compatibility, no use now
+) {
+  std::vector<phi::DenseTensor> in_wrapper{in_tensor};
+  std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
+  return AllReduce(in_wrapper, out_wrapper, opts);
+}
+
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
     std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
     std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
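The two `AllReduce` overloads added above simply wrap the single in/out tensors into `std::vector`s and reuse the legacy vector-based `AllReduce` that follows. A compact sketch of that adapter pattern, with hypothetical `Tensor`/`Task` placeholders instead of Paddle's types:

```cpp
#include <memory>
#include <vector>

// Hypothetical stand-ins for phi::DenseTensor and ProcessGroup::Task.
struct Tensor {};
struct Task {};

class DemoGroup {
 public:
  // Legacy vector-based entry point (kept for compatibility).
  std::shared_ptr<Task> AllReduce(std::vector<Tensor>& ins,
                                  std::vector<Tensor>& outs) {
    (void)ins; (void)outs;
    return std::make_shared<Task>();  // real code would run the collective
  }

  // Single-tensor API: wrap the tensors and reuse the vector path.
  std::shared_ptr<Task> AllReduce(Tensor* out, const Tensor& in) {
    std::vector<Tensor> in_wrapper{in};
    std::vector<Tensor> out_wrapper{*out};
    return AllReduce(in_wrapper, out_wrapper);
  }
};

int main() {
  DemoGroup pg;
  Tensor in, out;
  auto task = pg.AllReduce(&out, in);
  return task ? 0 : 1;
}
```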
@@ -390,6 +368,72 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
       CommType::ALLREDUCE);
 }
 
+std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
+    phi::DenseTensor* out_tensor,
+    const phi::DenseTensor& in_tensor,
+    const BroadcastOptions& opts,
+    bool sync_op,  // for compatibility, no use now
+    bool use_calc_stream) {
+  std::vector<phi::DenseTensor> in_wrapper{in_tensor};
+  std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
+  return Broadcast(in_wrapper, out_wrapper, opts);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
+    phi::DenseTensor* out_tensor,
+    const phi::DenseTensor& in_tensor,
+    const BroadcastOptions& opts,
+    bool sync_op) {
+  std::vector<phi::DenseTensor> in_wrapper{in_tensor};
+  std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
+  return Broadcast(in_wrapper, out_wrapper, opts);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
+    const BarrierOptions& opts) {
+  // Only support single card single process
+  PADDLE_ENFORCE_GE(opts.device_id,
+                    0,
+                    platform::errors::PreconditionNotMet(
+                        "The barrier device id must greater or equal than 0."));
+  platform::CustomPlace place(device_type_, opts.device_id);
+  auto allocator = std::unique_ptr<phi::Allocator>(
+      new paddle::experimental::DefaultAllocator(place));
+  phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1});
+  phi::DenseTensor barrier_tensor{allocator.get(), meta};
+
+  auto task = ProcessGroupCustom::AllReduce(&barrier_tensor,
+                                            barrier_tensor,
+                                            {},
+                                            /*sync_op*/ true,
+                                            false);
+  auto xccl_task = dynamic_cast<ProcessGroupCustom::CustomTask*>(task.get());
+  xccl_task->barrierTensors_ = {barrier_tensor};
+  return task;
+}
+
+phi::DeviceContext* ProcessGroupCustom::GetDeviceContext(
+    const Place& place) const {
+  const std::string key = GetKeyFromPlace(place);
+  const auto& iter = places_to_ctx_.find(key);
+  PADDLE_ENFORCE_NE(
+      iter,
+      places_to_ctx_.end(),
+      platform::errors::NotFound(
+          "Cannot find the device context in this process group."));
+  return iter->second[0].get();
+}
+
+phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm(const Place& place) const {
+  std::vector<Place> places = {place};
+  const auto& iter = places_to_customcomm_.find(GetKeyFromPlaces(places));
+  PADDLE_ENFORCE_NE(iter,
+                    places_to_customcomm_.end(),
+                    platform::errors::InvalidArgument(
+                        "Cannot find nccl comm in process group."));
+  return iter->second[0]->GetCustomCCLComm();
+}
+
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
     std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
     std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
...
@@ -80,25 +80,6 @@ class ProcessGroupCustom : public ProcessGroupWithoutStream {
   std::string GetBackendName() const override { return "XCCL_" + device_type_; }
 
-  std::shared_ptr<ProcessGroup::Task> AllGather(
-      phi::DenseTensor* out_tensor,
-      const phi::DenseTensor& in_tensor,
-      int64_t offset,
-      int64_t numel,
-      bool sync_op) override;
-
-  std::shared_ptr<ProcessGroup::Task> AllReduce(
-      phi::DenseTensor* out_tensor,
-      const phi::DenseTensor& in_tensor,
-      const AllreduceOptions& opts,
-      bool sync_op) override;
-
-  std::shared_ptr<ProcessGroup::Task> Broadcast(
-      phi::DenseTensor* out_tensor,
-      const phi::DenseTensor& in_tensor,
-      const BroadcastOptions& opts,
-      bool sync_op) override;
-
   std::shared_ptr<ProcessGroup::Task> Barrier(
       const BarrierOptions& = BarrierOptions()) override;
@@ -111,16 +92,57 @@ class ProcessGroupCustom : public ProcessGroupWithoutStream {
       std::vector<phi::DenseTensor>& in_tensors,
       std::vector<phi::DenseTensor>& out_tensors) override;
 
+  std::shared_ptr<ProcessGroup::Task> AllGather(
+      phi::DenseTensor* out_tensor,
+      const phi::DenseTensor& in_tensor,
+      int64_t offset,
+      int64_t numel,
+      bool sync_op,
+      bool use_calc_stream) override;
+
+  std::shared_ptr<ProcessGroup::Task> AllGather(
+      phi::DenseTensor* out_tensor,
+      const phi::DenseTensor& in_tensor,
+      int64_t offset,
+      int64_t numel,
+      bool sync_op) override;
+
   std::shared_ptr<ProcessGroup::Task> AllReduce(
       std::vector<phi::DenseTensor>& in_tensors,
       std::vector<phi::DenseTensor>& out_tensors,
       const AllreduceOptions& = AllreduceOptions()) override;
 
+  std::shared_ptr<ProcessGroup::Task> AllReduce(
+      phi::DenseTensor* out_tensor,
+      const phi::DenseTensor& in_tensor,
+      const AllreduceOptions& opts,
+      bool sync_op,
+      bool use_calc_stream) override;
+
+  std::shared_ptr<ProcessGroup::Task> AllReduce(
+      phi::DenseTensor* out_tensor,
+      const phi::DenseTensor& in_tensor,
+      const AllreduceOptions& opts,
+      bool sync_op) override;
+
   std::shared_ptr<ProcessGroup::Task> Broadcast(
       std::vector<phi::DenseTensor>& in_tensors,
       std::vector<phi::DenseTensor>& out_tensors,
       const BroadcastOptions& = BroadcastOptions()) override;
 
+  std::shared_ptr<ProcessGroup::Task> Broadcast(
+      phi::DenseTensor* out_tensor,
+      const phi::DenseTensor& in_tensor,
+      const BroadcastOptions& opts,
+      bool sync_op,
+      bool use_calc_stream) override;
+
+  std::shared_ptr<ProcessGroup::Task> Broadcast(
+      phi::DenseTensor* out_tensor,
+      const phi::DenseTensor& in_tensor,
+      const BroadcastOptions& opts,
+      bool sync_op) override;
+
  protected:
   virtual std::shared_ptr<ProcessGroupCustom::CustomTask> CreateTask(
       std::vector<Place> places,
...
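The header change declares, for each of `AllGather`, `AllReduce`, and `Broadcast`, both the `sync_op`-only overload and the `sync_op` + `use_calc_stream` overload as `override`. One reason to keep the full overload set in the derived class is C++ name hiding: declaring only one overload of an inherited virtual hides the others for calls made through the derived type. A small, self-contained illustration with hypothetical classes (not Paddle's actual hierarchy):

```cpp
#include <iostream>

struct Base {
  virtual ~Base() = default;
  virtual void AllGather(bool sync_op) {
    AllGather(sync_op, /*use_calc_stream=*/false);
  }
  virtual void AllGather(bool sync_op, bool use_calc_stream) {
    std::cout << "Base: " << sync_op << ' ' << use_calc_stream << '\n';
  }
};

struct Derived : Base {
  // Overriding only one overload would hide the other when called through a
  // Derived object; overriding (or re-declaring) both keeps the whole
  // overload set usable, which mirrors what the header change does.
  void AllGather(bool sync_op) override { Base::AllGather(sync_op); }
  void AllGather(bool sync_op, bool use_calc_stream) override {
    std::cout << "Derived: " << sync_op << ' ' << use_calc_stream << '\n';
  }
};

int main() {
  Derived d;
  d.AllGather(true);        // one-argument overload is still visible
  d.AllGather(true, true);  // two-argument overload as well
  return 0;
}
```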