Unverified commit 25e63dca, authored by Wen Sun, committed by GitHub

Refactor collective communication send_partial, recv_partial, all_gather_partial C++ API (#47863)

* refactor: simplify send, recv interfaces

* refactor: rm send_partial, recv_partial, all_gather_partial
Parent dac0f7dd
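In short, callers migrate from the dedicated Send_Partial/Recv_Partial/AllGather_Partial entry points to the unified Send/Recv/AllGather, which now take an offset/numel pair (numel == -1 selects the whole tensor, numel > 0 a slice starting at offset). Below is a minimal caller-side sketch based on the signatures introduced in this diff; DemoPartialSendRecv and its parameters are illustrative only and not part of the commit.

#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/phi/core/dense_tensor.h"

namespace paddle {
namespace distributed {

// Illustrative only: exchange one rank's shard using the unified interfaces.
void DemoPartialSendRecv(ProcessGroup* pg, phi::DenseTensor* tensor,
                         int peer, int nranks, int rank_id) {
  int64_t shard_numel = tensor->numel() / nranks;  // assumes an even split
  int64_t offset = shard_numel * rank_id;
  // Previously: pg->Send_Partial(*tensor, peer, offset, shard_numel);
  pg->Send(tensor, peer, offset, shard_numel, /*sync_op*/ true)->Wait();
  // Previously: pg->Recv_Partial(*tensor, peer, offset, shard_numel);
  pg->Recv(tensor, peer, offset, shard_numel, /*sync_op*/ true)->Wait();
}

}  // namespace distributed
}  // namespace paddle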
......@@ -98,17 +98,19 @@ class ProcessGroup {
virtual std::string GetBackendName() const = 0;
virtual const phi::DeviceContext& GetDeviceContext(const Place& place) const {
PADDLE_THROW(platform::errors::InvalidArgument(
"Does not support to get device_context from ProcessGroup%s.",
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support get device_context.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int64_t offset,
int64_t numel,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support all_gather with sync_op flag",
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support all_gather with sync_op flag.",
GetBackendName()));
}
......@@ -117,15 +119,15 @@ class ProcessGroup {
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support all_reduce with sync_op flag",
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support all_reduce with sync_op flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support barrier", GetBackendName()));
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support barrier.", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
......@@ -133,46 +135,28 @@ class ProcessGroup {
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support broadcast with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t numel,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support recv with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> RecvPartial(
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support recv_partial with sync_op flag",
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support recv with sync_op flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor*,
int dst_rank,
int64_t offset,
int64_t numel,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> SendPartial(
phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send_partial with sync_op flag",
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support send with sync_op flag.",
GetBackendName()));
}
......@@ -240,38 +224,6 @@ class ProcessGroup {
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(
phi::DenseTensor&, // NOLINT
int,
int64_t,
int64_t) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send_partial", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(
phi::DenseTensor&, int, int64_t, int64_t, bool) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send_partial with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial(
phi::DenseTensor&, // NOLINT
int,
int64_t,
int64_t) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support recv_partial", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial(
phi::DenseTensor&, int, int64_t, int64_t, bool) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support recv_partial with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&) { // NOLINT
......@@ -288,25 +240,6 @@ class ProcessGroup {
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
int64_t offset,
int64_t length) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support AllGather_Partial", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
int64_t offset,
int64_t length,
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support AllGather_Partial", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&) { // NOLINT
......
......@@ -228,6 +228,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int64_t offset, // for compatibility, no use now
int64_t numel, // for compatibility, no use now
bool sync_op,
bool use_calc_stream) {
return Collective(
......
......@@ -99,6 +99,8 @@ class ProcessGroupBKCL : public ProcessGroupStream {
std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int64_t offset, // for compatibility, no use now
int64_t numel, // for compatibility, no use now
bool sync_op,
bool use_calc_stream) override;
......
......@@ -259,24 +259,18 @@ void* XcclGetPointerByOffset(void* raw_pointer,
return nullptr;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
// NOTE: this is ONLY for compatibility
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int64_t offset,
int64_t length) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All inputs should be in CustomPlace(%s).", device_type_));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_tensors, device_type_),
true,
platform::errors::InvalidArgument(
"All outputs should be in CustomPlace(%s).", device_type_));
int64_t numel,
bool sync_op) {
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return Collective(
in_tensors,
out_tensors,
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
......@@ -285,7 +279,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather_Partial(
device_type_,
XcclGetPointerByOffset(input.data(), offset, input.dtype()),
output.data(),
length,
numel,
phi::ccl::ToCCLDataType(input.dtype()),
comm,
stream);
......
......@@ -72,14 +72,15 @@ class ProcessGroupCustom : public ProcessGroup {
std::string GetBackendName() const override { return "XCCL_" + device_type_; }
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int64_t offset,
int64_t numel,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
int64_t offset,
int64_t length) override;
std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors,
......
......@@ -393,6 +393,8 @@ class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int64_t offset, // for compatibility, no use now
int64_t numel, // for compatibility, no use now
bool sync_op) {
std::vector<phi::DenseTensor> in_wrapper = {in_tensor};
std::vector<phi::DenseTensor> out_wrapper = {*out_tensor};
......
......@@ -110,6 +110,8 @@ class ProcessGroupGloo : public ProcessGroup {
std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int64_t offset, // for compatibility, no use now
int64_t numel, // for compatibility, no use now
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#include "paddle/fluid/distributed/collective/Common.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/lib/utils/allocator.h"
......@@ -129,15 +130,20 @@ ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const {
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream) {
// numel > 0 indicates the tensor needs to be sliced
const phi::DenseTensor& in_tensor_maybe_partial =
numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
return Collective(
out_tensor,
in_tensor,
[&](phi::DenseTensor* output,
const phi::DenseTensor& input,
ncclComm_t comm,
gpuStream_t stream) {
in_tensor_maybe_partial,
[](phi::DenseTensor* output,
const phi::DenseTensor& input,
ncclComm_t comm,
gpuStream_t stream) {
return platform::dynload::ncclAllGather(
input.data(),
output->data(),
......@@ -229,48 +235,25 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
phi::DenseTensor* tensor,
int src_rank,
bool sync_op,
bool use_calc_stream) {
return PointToPoint(
tensor,
src_rank,
[&](phi::DenseTensor* output,
int src,
ncclComm_t comm,
gpuStream_t stream) {
return platform::dynload::ncclRecv(
output->data(),
output->numel(),
platform::ToNCCLDataType(output->dtype()),
src,
comm,
stream);
},
CommType::RECV,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::RecvPartial(
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t length,
int64_t numel,
bool sync_op,
bool use_calc_stream) {
phi::DenseTensor tensor_flattened;
tensor_flattened.ShareDataWith(*tensor).Resize({tensor->numel()});
phi::DenseTensor tensor_recv =
tensor_flattened.Slice(offset, offset + length);
// numel > 0 indicates the tensor needs to be sliced
phi::DenseTensor partial_tensor;
if (numel > 0) {
partial_tensor = GetPartialTensor(*tensor, offset, numel);
tensor = &partial_tensor;
}
return PointToPoint(
&tensor_recv,
tensor,
src_rank,
[&](phi::DenseTensor* output,
int src,
ncclComm_t comm,
gpuStream_t stream) {
[](phi::DenseTensor* output,
int src,
ncclComm_t comm,
gpuStream_t stream) {
return platform::dynload::ncclRecv(
output->data(),
output->numel(),
......@@ -285,48 +268,25 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::RecvPartial(
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
phi::DenseTensor* tensor,
int dst_rank,
bool sync_op,
bool use_calc_stream) {
return PointToPoint(
tensor,
dst_rank,
[&](phi::DenseTensor* input,
int dst,
ncclComm_t comm,
gpuStream_t stream) {
return platform::dynload::ncclSend(
input->data(),
input->numel(),
platform::ToNCCLDataType(input->dtype()),
dst,
comm,
stream);
},
CommType::SEND,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::SendPartial(
phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t length,
int64_t numel,
bool sync_op,
bool use_calc_stream) {
phi::DenseTensor tensor_flattened;
tensor_flattened.ShareDataWith(*tensor).Resize({tensor->numel()});
phi::DenseTensor tensor_send =
tensor_flattened.Slice(offset, offset + length);
// numel > 0 indicates the tensor needs to be sliced
phi::DenseTensor partial_tensor;
if (numel > 0) {
partial_tensor = GetPartialTensor(*tensor, offset, numel);
tensor = &partial_tensor;
}
return PointToPoint(
&tensor_send,
tensor,
dst_rank,
[&](phi::DenseTensor* input,
int dst,
ncclComm_t comm,
gpuStream_t stream) {
[](phi::DenseTensor* input,
int dst,
ncclComm_t comm,
gpuStream_t stream) {
return platform::dynload::ncclSend(
input->data(),
input->numel(),
......@@ -1041,132 +1001,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial(
phi::DenseTensor& tensors, int dst_rank, int64_t offset, int64_t length) {
// CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
phi::DenseTensor flatten_tensor;
flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()});
std::vector<phi::DenseTensor> shared_tensors{
flatten_tensor.Slice(offset, offset + length)};
auto task = PointToPoint(
shared_tensors,
[&](phi::DenseTensor& input,
ncclComm_t comm,
const gpuStream_t& stream,
int dst_rank) {
return platform::dynload::ncclSend(
input.data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
dst_rank,
comm,
stream);
},
dst_rank,
CommType::SEND);
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial(
phi::DenseTensor& tensors,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) {
phi::DenseTensor flatten_tensor;
flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()});
std::vector<phi::DenseTensor> shared_tensors{
flatten_tensor.Slice(offset, offset + length)};
auto task = PointToPoint(
shared_tensors,
[&](phi::DenseTensor& input,
ncclComm_t comm,
const gpuStream_t& stream,
int dst_rank) {
return platform::dynload::ncclSend(
input.data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
dst_rank,
comm,
stream);
},
dst_rank,
CommType::SEND,
sync_op,
use_calc_stream);
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv_Partial(
phi::DenseTensor& tensors, int src_rank, int64_t offset, int64_t length) {
// phi::DenseTensor shared_input = tensors.Slice(offset, offset+length);
phi::DenseTensor flatten_tensor;
flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()});
std::vector<phi::DenseTensor> shared_tensors{
flatten_tensor.Slice(offset, offset + length)};
auto task = PointToPoint(
shared_tensors,
[&](phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream,
int src_rank) {
return platform::dynload::ncclRecv(
output.data(),
output.numel(),
platform::ToNCCLDataType(output.dtype()),
src_rank,
comm,
stream);
},
src_rank,
CommType::RECV);
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv_Partial(
phi::DenseTensor& tensors,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) {
phi::DenseTensor flatten_tensor;
flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()});
std::vector<phi::DenseTensor> shared_tensors{
flatten_tensor.Slice(offset, offset + length)};
auto task = PointToPoint(
shared_tensors,
[&](phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream,
int src_rank) {
return platform::dynload::ncclRecv(
output.data(),
output.numel(),
platform::ToNCCLDataType(output.dtype()),
src_rank,
comm,
stream);
},
src_rank,
CommType::RECV,
sync_op,
use_calc_stream);
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
......@@ -1228,77 +1062,11 @@ void* GetPointerByOffset(void* raw_pointer,
offset);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"This datatype in nccl is not supported."));
"Datatype %s in NCCL is not supported.", type));
}
return nullptr;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
int64_t offset,
int64_t length) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
return platform::dynload::ncclAllGather(
GetPointerByOffset(input.data(), offset, input.dtype()),
output.data(),
length,
platform::ToNCCLDataType(input.dtype()),
comm,
stream);
},
CommType::ALLGATHER);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
return platform::dynload::ncclAllGather(
GetPointerByOffset(input.data(), offset, input.dtype()),
output.data(),
length,
platform::ToNCCLDataType(input.dtype()),
comm,
stream);
},
CommType::ALLGATHER,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
......
......@@ -97,6 +97,8 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream) override;
......@@ -119,30 +121,18 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> RecvPartial(
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> SendPartial(
phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) override;
static void GroupStart();
static void GroupEnd();
......@@ -167,50 +157,10 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank) override;
std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor& tensors,
int dst_rank,
int64_t offset,
int64_t length) override;
std::shared_ptr<ProcessGroup::Task> Send_Partial(
phi::DenseTensor& tensors,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Recv_Partial(phi::DenseTensor& tensors,
int src_rank,
int64_t offset,
int64_t length) override;
std::shared_ptr<ProcessGroup::Task> Recv_Partial(
phi::DenseTensor& tensors,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
int64_t offset,
int64_t length) override;
std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
......
......@@ -22,16 +22,20 @@ ProcessGroupStream::ProcessGroupStream(int rank, int size, int gid)
const phi::DeviceContext& ProcessGroupStream::GetDeviceContext(
const Place& place, bool use_calc_stream) const {
PADDLE_THROW(platform::errors::InvalidArgument(
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support get device_context.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int64_t offset,
int64_t numel,
bool sync_op) {
return AllGather(out_tensor,
in_tensor,
offset,
numel,
sync_op,
/*use_calc_stream*/ false);
}
......@@ -39,10 +43,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do all_gather", GetBackendName()));
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support all_gather.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
......@@ -63,8 +69,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
const AllreduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do all_reduce", GetBackendName()));
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support all_reduce.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
......@@ -85,14 +91,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do broadcast", GetBackendName()));
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support broadcast.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
phi::DenseTensor* tensor, int src_rank, bool sync_op) {
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t numel,
bool sync_op) {
return Recv(tensor,
src_rank,
offset,
numel,
sync_op,
/*use_calc_stream*/ false);
}
......@@ -100,74 +112,37 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do recv", GetBackendName()));
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support recv.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::RecvPartial(
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
phi::DenseTensor* tensor,
int src_rank,
int dst_rank,
int64_t offset,
int64_t length,
int64_t numel,
bool sync_op) {
return RecvPartial(tensor,
src_rank,
offset,
length,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::RecvPartial(
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do recv_partial", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
phi::DenseTensor* tensor, int dst_rank, bool sync_op) {
return Send(tensor,
dst_rank,
offset,
numel,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
phi::DenseTensor*, int dst_rank, bool sync_op, bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do send", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::SendPartial(
phi::DenseTensor* tensor,
phi::DenseTensor*,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op) {
return SendPartial(tensor,
dst_rank,
offset,
length,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::SendPartial(
phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t length,
int64_t numel,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do send partial", GetBackendName()));
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support send.", GetBackendName()));
}
// TODO(sunyilun): methods below will be removed later
......@@ -281,31 +256,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Scatter(
"ProcessGroup%s does not support do scatter", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send_Partial(
phi::DenseTensor& tensors,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op) {
return Send_Partial(tensors,
dst_rank,
offset,
length,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send_Partial(
phi::DenseTensor& tensors,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do send_partial", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank, bool sync_op) {
return Recv(tensors,
......@@ -323,55 +273,5 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
"ProcessGroup%s does not support do recv", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv_Partial(
phi::DenseTensor& tensors,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op) {
return Recv_Partial(tensors,
src_rank,
offset,
length,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv_Partial(
phi::DenseTensor& tensors,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do recv_partial", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
int64_t offset,
int64_t length,
bool sync_op) {
return AllGather_Partial(in_tensors,
out_tensors,
offset,
length,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do recv_partial", GetBackendName()));
}
} // namespace distributed
} // namespace paddle
......@@ -64,11 +64,15 @@ class ProcessGroupStream : public ProcessGroup {
std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int64_t offset,
int64_t numel,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream);
......@@ -100,50 +104,30 @@ class ProcessGroupStream : public ProcessGroup {
std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t numel,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> RecvPartial(phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> RecvPartial(
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t numel,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> SendPartial(phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> SendPartial(
phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream);
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
......@@ -210,21 +194,6 @@ class ProcessGroupStream : public ProcessGroup {
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Send_Partial(
phi::DenseTensor& tensors, // NOLINT
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(
phi::DenseTensor& tensors, // NOLINT
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors, // NOLINT
int src_rank,
......@@ -235,36 +204,6 @@ class ProcessGroupStream : public ProcessGroup {
int src_rank,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Recv_Partial(
phi::DenseTensor& tensors, // NOLINT
int src_rank,
int64_t offset,
int64_t length,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial(
phi::DenseTensor& tensors, // NOLINT
int src_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
int64_t offset,
int64_t length,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream);
};
} // namespace distributed
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace paddle {
namespace distributed {
inline phi::DenseTensor GetPartialTensor(const phi::DenseTensor &tensor,
int64_t offset,
int64_t numel) {
phi::DenseTensor tensor_flattened;
tensor_flattened.ShareDataWith(tensor);
tensor_flattened.Resize({tensor.numel()});
return tensor_flattened.Slice(offset, offset + numel);
}
} // namespace distributed
} // namespace paddle
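A hedged usage sketch of the helper above (ShardView is hypothetical and not part of this commit): GetPartialTensor returns a view that shares storage with the input via ShareDataWith, which is how the NCCL send/recv paths slice out a shard when numel > 0.

// Illustrative only: slice out rank `rank_id`'s shard of a flattened tensor.
inline phi::DenseTensor ShardView(const phi::DenseTensor &tensor,
                                  int rank_id, int nranks) {
  int64_t shard_numel = tensor.numel() / nranks;  // assumes numel % nranks == 0
  // The view aliases the original allocation; no data is copied, so
  // communicating on it reads/writes the corresponding slice in place.
  return paddle::distributed::GetPartialTensor(tensor, shard_numel * rank_id,
                                               shard_numel);
}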
......@@ -226,15 +226,19 @@ struct GlobalGatherProcessGroupFunctor<phi::GPUContext, T> {
int idx = i + j * n_expert;
if (cpu_global_count_data[idx]) {
phi::DenseTensor tmp = *x;
pg->Send_Partial(
tmp, j, send_ptr * in_feat, cpu_global_count_data[idx] * in_feat);
pg->Send(&tmp,
j,
send_ptr * in_feat,
cpu_global_count_data[idx] * in_feat,
/*sync_op*/ true);
send_ptr += cpu_global_count_data[idx];
}
if (cpu_local_count_data[idx]) {
pg->Recv_Partial(*out,
j,
expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat);
pg->Recv(out,
j,
expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat,
/*sync_op*/ true);
}
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
......
......@@ -224,16 +224,18 @@ struct GlobalScatterProcessGroupFunctor<phi::GPUContext, T> {
int idx = i + j * n_expert;
if (cpu_local_count_data[idx]) {
phi::DenseTensor tmp = *x;
pg->Send_Partial(tmp,
j,
expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat);
pg->Send(&tmp,
j,
expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat,
/*sync_op*/ true);
}
if (cpu_global_count_data[idx]) {
pg->Recv_Partial(*out,
j,
recv_ptr * in_feat,
cpu_global_count_data[idx] * in_feat);
pg->Recv(out,
j,
recv_ptr * in_feat,
cpu_global_count_data[idx] * in_feat,
/*sync_op*/ true);
recv_ptr += cpu_global_count_data[idx];
}
}
......
......@@ -67,12 +67,7 @@ class PartialAllGatherOpCUDAKernel : public framework::OpKernel<T> {
if (map->has(rid)) {
// Use ProcessGroup
distributed::ProcessGroup* pg = map->get(rid);
std::vector<phi::DenseTensor> in_tensors;
std::vector<phi::DenseTensor> out_tensors;
in_tensors.push_back(*in);
out_tensors.push_back(*out);
auto task =
pg->AllGather_Partial(in_tensors, out_tensors, offset, send_numel);
auto task = pg->AllGather(out, *in, offset, send_numel, /*sync_op*/ true);
task->Wait();
} else {
const T* send_buff = in->data<T>() + offset;
......
......@@ -75,7 +75,7 @@ class PartialRecvOpCUDAKernel : public framework::OpKernel<T> {
if (map->has(rid)) {
// Use ProcessGroup
distributed::ProcessGroup *pg = map->get(rid);
auto task = pg->Recv_Partial(*out, peer, offset, recv_numel);
auto task = pg->Recv(out, peer, offset, recv_numel, /*sync_op*/ true);
task->Wait();
} else {
gpuStream_t stream = nullptr;
......
......@@ -70,7 +70,7 @@ class PartialSendCUDAKernel : public framework::OpKernel<T> {
// Use ProcessGroup
distributed::ProcessGroup* pg = map->get(rid);
phi::DenseTensor tmp = *x;
auto task = pg->Send_Partial(tmp, peer, offset, send_numel);
auto task = pg->Send(&tmp, peer, offset, send_numel, /*sync_op*/ true);
task->Wait();
} else {
gpuStream_t stream = nullptr;
......
......@@ -24,13 +24,13 @@ limitations under the License. */
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/distributed/collective/ProcessGroupStream.h"
#include "paddle/fluid/distributed/collective/Types.h"
#include "paddle/fluid/distributed/collective/Utils.h"
#include "paddle/fluid/distributed/collective/reducer.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/pybind/distributed_py.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/process_group_utils.h"
#include "paddle/phi/api/all.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
......@@ -171,7 +171,9 @@ void BindDistributed(py::module *m) {
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
auto *out_dense = p_dense.get();
return self.Send(out_dense, dst, sync_op);
// numel == -1 indicates sending the whole tensor
return self.Send(
out_dense, dst, /*offset*/ 0, /*numel*/ -1, sync_op);
},
py::arg("tensor"),
py::arg("dst"),
......@@ -189,18 +191,20 @@ void BindDistributed(py::module *m) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
auto *out_dense = p_dense.get();
int64_t numel = p_dense->numel();
int64_t send_numel = numel / nranks;
int64_t offset = send_numel * rank_id;
auto *out_dense = p_dense.get();
return self.SendPartial(
return self.Send(
out_dense, dst_rank, offset, send_numel, sync_op);
},
py::arg("tensor"),
py::arg("dst"),
py::arg("num"),
py::arg("id"),
py::arg("sync_op"),
py::arg("sync_op") = true,
py::call_guard<py::gil_scoped_release>())
.def(
......@@ -213,7 +217,9 @@ void BindDistributed(py::module *m) {
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
auto *in_dense = p_dense.get();
return self.Recv(in_dense, src, sync_op);
// numel == -1 indicates receiving the whole tensor
return self.Recv(
in_dense, src, /*offset*/ 0, /*numel*/ -1, sync_op);
},
py::arg("tensor"),
py::arg("src"),
......@@ -231,18 +237,20 @@ void BindDistributed(py::module *m) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
auto *out_dense = p_dense.get();
int64_t numel = p_dense->numel();
int64_t recv_numel = numel / nranks;
int64_t offset = recv_numel * rank_id;
auto *out_dense = p_dense.get();
return self.RecvPartial(
return self.Recv(
out_dense, src_rank, offset, recv_numel, sync_op);
},
py::arg("tensor"),
py::arg("src"),
py::arg("num"),
py::arg("id"),
py::arg("sync_op"),
py::arg("sync_op") = true,
py::call_guard<py::gil_scoped_release>())
.def(
......@@ -264,7 +272,11 @@ void BindDistributed(py::module *m) {
auto in_dense = *p_in_tensor;
const auto &dev_ctx = self.GetDeviceContext(in_tensor.place());
auto task = self.AllGather(out_dense, in_dense, sync_op);
auto task = self.AllGather(out_dense,
in_dense,
/*offset*/ 0,
/*numel*/ -1,
sync_op);
distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
task->UpdateWaitChain(dev_ctx);
return task;
......@@ -290,7 +302,11 @@ void BindDistributed(py::module *m) {
in_tensor.impl());
auto in_dense = *p_in_tensor;
return self.AllGather(out_dense, in_dense, sync_op);
return self.AllGather(out_dense,
in_dense,
/*offset*/ 0,
/*numel*/ -1,
sync_op);
},
py::arg("out"),
py::arg("in"),
......@@ -571,27 +587,6 @@ void BindDistributed(py::module *m) {
py::arg("dst"),
py::call_guard<py::gil_scoped_release>())
.def(
"send_partial",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int dst_rank,
int nranks,
int rank_id) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int64_t numel = (*dense).numel();
int64_t send_numel = numel / nranks;
int64_t offset = send_numel * rank_id;
return self.Send_Partial(*dense, dst_rank, offset, send_numel);
},
py::arg("tensor"),
py::arg("dst"),
py::arg("num"),
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv",
[](distributed::ProcessGroup &self,
......@@ -607,27 +602,6 @@ void BindDistributed(py::module *m) {
py::arg("src"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv_partial",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int src_rank,
int nranks,
int rank_id) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int64_t numel = (*dense).numel();
int64_t recv_numel = numel / nranks;
int64_t offset = recv_numel * rank_id;
return self.Recv_Partial(*dense, src_rank, offset, recv_numel);
},
py::arg("tensor"),
py::arg("src"),
py::arg("num"),
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
.def(
"all_gather",
[](distributed::ProcessGroup &self,
......@@ -650,26 +624,28 @@ void BindDistributed(py::module *m) {
.def(
"all_gather_partial",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
py::handle py_in_tensor,
int nranks,
int rank_id) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
int64_t numel = (*in_dense).numel();
auto *out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
int64_t numel = in_dense.numel();
int64_t send_numel = numel / nranks;
int64_t offset = send_numel * rank_id;
return self.AllGather_Partial(
in_tensors, out_tensors, offset, send_numel);
return self.AllGather(
out_dense, in_dense, offset, send_numel, /*sync_op*/ true);
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::arg("num"),
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
......@@ -785,6 +761,8 @@ void BindDistributed(py::module *m) {
self.GetDeviceContext(in_tensor.place(), true);
auto task = self.AllGather(out_dense,
in_dense,
/*offset*/ 0,
/*numel*/ -1,
/*sync_op*/ true,
/*use_calc_stream*/ true);
distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
......@@ -811,6 +789,8 @@ void BindDistributed(py::module *m) {
return self.AllGather(out_dense,
in_dense,
/*offset*/ 0,
/*numel*/ -1,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
......@@ -821,30 +801,33 @@ void BindDistributed(py::module *m) {
.def(
"all_gather_partial_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
py::handle py_in_tensor,
int nranks,
int rank_id) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
int64_t numel = (*in_dense).numel();
auto *out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
int64_t numel = in_dense.numel();
int64_t send_numel = numel / nranks;
int64_t offset = send_numel * rank_id;
return self.AllGather_Partial(in_tensors,
out_tensors,
offset,
send_numel,
/*sync_op*/ true,
/*use_calc_stream*/ true);
return self.AllGather(out_dense,
in_dense,
offset,
send_numel,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::arg("num"),
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
......@@ -1125,8 +1108,11 @@ void BindDistributed(py::module *m) {
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
auto *out_dense = p_dense.get();
// numel == -1 indicates sending the whole tensor
return self.Send(out_dense,
dst,
/*offset*/ 0,
/*numel*/ -1,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
......@@ -1144,16 +1130,18 @@ void BindDistributed(py::module *m) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
auto *out_dense = p_dense.get();
int64_t numel = p_dense->numel();
int64_t send_numel = numel / nranks;
int64_t offset = send_numel * rank_id;
auto *out_dense = p_dense.get();
return self.SendPartial(out_dense,
dst_rank,
offset,
send_numel,
/*sync_op*/ true,
/*use_calc_stream*/ true);
return self.Send(out_dense,
dst_rank,
offset,
send_numel,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
py::arg("tensor"),
py::arg("dst"),
......@@ -1170,8 +1158,11 @@ void BindDistributed(py::module *m) {
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
auto *in_dense = p_dense.get();
// numel == -1 indicates receiving the whole tensor
return self.Recv(in_dense,
src,
/*offset*/ 0,
/*numel*/ -1,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
......@@ -1189,16 +1180,18 @@ void BindDistributed(py::module *m) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
auto *out_dense = p_dense.get();
int64_t numel = p_dense->numel();
int64_t recv_numel = numel / nranks;
int64_t offset = recv_numel * rank_id;
auto *out_dense = p_dense.get();
return self.RecvPartial(out_dense,
src_rank,
offset,
recv_numel,
/*sync_op*/ true,
/*use_calc_stream*/ true);
return self.Recv(out_dense,
src_rank,
offset,
recv_numel,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
py::arg("tensor"),
py::arg("src"),
......
......@@ -14,10 +14,10 @@
#pragma once
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/device_guard.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
namespace paddle {
......@@ -110,7 +110,7 @@ void ConcatDenseTensorWithType(const DeviceContext &dev_ctx,
ConcatDenseTensor<DeviceContext, int64_t>()(dev_ctx, t_list, p_out);
break;
case phi::DataType::FLOAT16:
ConcatDenseTensor<DeviceContext, platform::float16>()(
ConcatDenseTensor<DeviceContext, phi::dtype::float16>()(
dev_ctx, t_list, p_out);
break;
case phi::DataType::FLOAT32:
......@@ -147,7 +147,7 @@ void SplitDenseTensorWithType(const DeviceContext &dev_ctx,
SplitDenseTensor<DeviceContext, int64_t>()(dev_ctx, t_in, p_list);
break;
case phi::DataType::FLOAT16:
SplitDenseTensor<DeviceContext, platform::float16>()(
SplitDenseTensor<DeviceContext, phi::dtype::float16>()(
dev_ctx, t_in, p_list);
break;
case phi::DataType::FLOAT32:
......