Unverified commit 18c77325, authored by Haohongxiang, committed by GitHub

support send_partial, recv_partial and allgather_partial in ProcessGroupNCCL (#44444)

Parent ea4b2c5e
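For orientation, below is a minimal dygraph-mode sketch of the new partial all-gather path this commit enables. The launch, group setup, and tensor shape are assumptions for illustration and require a multi-GPU NCCL run; only the `all_gather_partial` binding name and its `(in, out, num, id)` argument order come from the diff that follows.

```python
# Illustrative sketch only: exercises the new binding under an assumed setup.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()                                    # assumed multi-GPU launch
group = paddle.distributed.collective._get_default_group()  # default ProcessGroupNCCL
rank, nranks = dist.get_rank(), dist.get_world_size()

x = paddle.arange(8, dtype='float32') * (rank + 1)          # numel must divide evenly by nranks
out = paddle.zeros_like(x)

# Each rank contributes only its own slice [rank * numel/nranks, (rank+1) * numel/nranks).
task = group.process_group.all_gather_partial(x, out, nranks, rank)
task.wait()
```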
...@@ -137,6 +137,15 @@ class ProcessGroup {
         "ProcessGroup%s does not support AllGather", GetBackendName()));
   }
 
+  virtual std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
+      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
+      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
+      int offset,
+      int length) {  // NOLINT
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support AllGather_Partial", GetBackendName()));
+  }
+
   virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
       std::vector<phi::DenseTensor>&,    // NOLINT
       std::vector<phi::DenseTensor>&) {  // NOLINT
...
...@@ -85,18 +85,19 @@ bool ProcessGroupNCCL::NCCLTask::IsCompleted() {
   return true;
 }
 
-void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>& split_sizes,
+void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>* split_sizes,
                                        std::vector<int64_t> tensor_shape) {
-  int64_t len_size = split_sizes.size();
+  int64_t len_size = (*split_sizes).size();
   if (len_size == 0) {
     PADDLE_ENFORCE_EQ(tensor_shape[0] % size_ == 0,
                       true,
                       platform::errors::InvalidArgument(
                           "Tensor's dim[0] must be divisible by group size "
                           "when split_sizes not given."));
-    split_sizes.insert(split_sizes.end(),
-                       size_,
-                       static_cast<int64_t>(tensor_shape[0] / size_));
+    (*split_sizes)
+        .insert((*split_sizes).end(),
+                size_,
+                static_cast<int64_t>(tensor_shape[0] / size_));
   } else {
     PADDLE_ENFORCE_EQ(
         len_size == size_,
...@@ -104,7 +105,7 @@ void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>& split_sizes,
         platform::errors::InvalidArgument(
             "The length of split_sizes must be equal to group size."));
     auto sum_size = std::accumulate(
-        split_sizes.begin(), split_sizes.end(), static_cast<int64_t>(0));
+        (*split_sizes).begin(), (*split_sizes).end(), static_cast<int64_t>(0));
     PADDLE_ENFORCE_EQ(
         sum_size == tensor_shape[0],
         true,
...@@ -626,6 +627,37 @@ void* GetPointerByOffset(void* raw_pointer,
   return nullptr;
 }
 
+std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather_Partial(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors,
+    int offset,
+    int length) {
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCudaPlace(in_tensors),
+      true,
+      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCudaPlace(out_tensors),
+      true,
+      platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
+  return Collective(
+      in_tensors,
+      out_tensors,
+      [&](phi::DenseTensor& input,
+          phi::DenseTensor& output,
+          ncclComm_t comm,
+          const gpuStream_t& stream) {
+        return platform::dynload::ncclAllGather(
+            GetPointerByOffset(input.data(), offset, input.dtype()),
+            output.data(),
+            length,
+            platform::ToNCCLDataType(input.dtype()),
+            comm,
+            stream);
+      },
+      CommType::ALLGATHER);
+}
+
 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
     std::vector<phi::DenseTensor>& in_tensors,
     std::vector<phi::DenseTensor>& out_tensors) {
...@@ -695,8 +727,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll_Single(
   std::vector<int64_t> in_dims = phi::vectorize(input.dims());
   std::vector<int64_t> out_dims = phi::vectorize(output.dims());
-  CheckSplitSizes(in_sizes, in_dims);
-  CheckSplitSizes(out_sizes, out_dims);
+  CheckSplitSizes(&in_sizes, in_dims);
+  CheckSplitSizes(&out_sizes, out_dims);
   size_t in_offset = 0, out_offset = 0;
   size_t in_length = 0, out_length = 0;
...
...@@ -125,6 +125,12 @@ class ProcessGroupNCCL : public ProcessGroup {
       std::vector<phi::DenseTensor>& in_tensors,
       std::vector<phi::DenseTensor>& out_tensors) override;
 
+  std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
+      std::vector<phi::DenseTensor>& in_tensors,
+      std::vector<phi::DenseTensor>& out_tensors,
+      int offset,
+      int length) override;
+
   std::shared_ptr<ProcessGroup::Task> AllToAll(
       std::vector<phi::DenseTensor>& in,
       std::vector<phi::DenseTensor>& out) override;
...@@ -206,7 +212,7 @@ class ProcessGroupNCCL : public ProcessGroup {
   void CreateNCCLManagerCache(const std::string& places_key,
                               const std::vector<Place>& places);
 
-  void CheckSplitSizes(std::vector<int64_t>& split_sizes,
+  void CheckSplitSizes(std::vector<int64_t>* split_sizes,
                        std::vector<int64_t> tensor_shape);
 };
...
...@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/partial_allgather_op.h"
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/fluid/distributed/collective/ProcessGroup.h"
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
...@@ -61,24 +62,38 @@ class PartialAllGatherOpCUDAKernel : public framework::OpKernel<T> {
     int64_t send_numel = numel / nranks;
     int offset = send_numel * rank;
 
-    const T* send_buff = in->data<T>() + offset;
-    T* recv_buff = out->data<T>();
-
-    gpuStream_t stream = nullptr;
-    if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+    auto map = distributed::ProcessGroupMapFromGid::getInstance();
+    if (map->has(rid)) {
+      // Use ProcessGroup
+      distributed::ProcessGroup* pg = map->get(rid);
+      std::vector<phi::DenseTensor> in_tensors;
+      std::vector<phi::DenseTensor> out_tensors;
+      in_tensors.push_back(*in);
+      out_tensors.push_back(*out);
+      auto task =
+          pg->AllGather_Partial(in_tensors, out_tensors, offset, send_numel);
+      task->Wait();
     } else {
-      stream = comm->stream();
+      const T* send_buff = in->data<T>() + offset;
+      T* recv_buff = out->data<T>();
+
+      gpuStream_t stream = nullptr;
+      if (ctx.Attr<bool>("use_calc_stream")) {
+        auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+        stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+      } else {
+        stream = comm->stream();
+      }
+
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          platform::dynload::ncclAllGather(send_buff,
+                                           recv_buff,
+                                           send_numel,
+                                           static_cast<ncclDataType_t>(dtype),
+                                           comm->comm(),
+                                           stream));
     }
-
-    PADDLE_ENFORCE_GPU_SUCCESS(
-        platform::dynload::ncclAllGather(send_buff,
-                                         recv_buff,
-                                         send_numel,
-                                         static_cast<ncclDataType_t>(dtype),
-                                         comm->comm(),
-                                         stream));
 #else
     PADDLE_THROW(platform::errors::PreconditionNotMet(
         "PaddlePaddle should compile with GPU."));
...
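The kernel above reduces to simple slice arithmetic: every rank pushes only the `send_numel = numel / nranks` elements starting at `offset = send_numel * rank`, and the gathered slices tile the full output buffer. A host-side NumPy restatement of that arithmetic, purely illustrative and not Paddle code:

```python
import numpy as np

def partial_allgather(inputs):
    """Mimic what all-gathering each rank's own slice produces."""
    nranks = len(inputs)                 # inputs[r] is rank r's full-size buffer
    numel = inputs[0].size
    assert numel % nranks == 0
    send_numel = numel // nranks         # int64_t send_numel = numel / nranks;
    out = np.empty(numel, dtype=inputs[0].dtype)
    for rank, x in enumerate(inputs):
        offset = send_numel * rank       # int offset = send_numel * rank;
        out[offset:offset + send_numel] = x[offset:offset + send_numel]
    return out

# Every rank holds the full-size buffer, but only its own slice is meaningful.
ranks = [np.array([1., 1., 0., 0.]), np.array([0., 0., 2., 2.])]
print(partial_allgather(ranks))          # -> [1. 1. 2. 2.]
```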
...@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/partial_recv_op.h"
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/fluid/distributed/collective/ProcessGroup.h"
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
...@@ -65,37 +66,44 @@ class PartialRecvOpCUDAKernel : public framework::OpKernel<T> {
         platform::errors::InvalidArgument(
             "The input numel (%d) must be divisible by num(%d)", numel, num));
 
-    gpuStream_t stream = nullptr;
     auto place = ctx.GetPlace();
-    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
-    if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<platform::CUDADeviceContext *>(dev_ctx)->stream();
-    } else {
-      stream = comm->stream();
-    }
-    PADDLE_ENFORCE_LT(
-        peer,
-        comm->nranks(),
-        platform::errors::InvalidArgument("The value of peer (%d) you set must "
-                                          "be less than comm->nranks (%d).",
-                                          peer,
-                                          comm->nranks()));
     out->mutable_data<T>(out_dims, place);
-    ncclDataType_t dtype = platform::ToNCCLDataType(type);
     int recv_numel = numel / num;
     int offset = recv_numel * id;
 
-    PADDLE_ENFORCE_GPU_SUCCESS(
-        platform::dynload::ncclRecv(out->data<T>() + offset,
-                                    recv_numel,
-                                    dtype,
-                                    peer,
-                                    comm->comm(),
-                                    stream));
-    VLOG(3) << "rank " << comm->rank() << " recv " << recv_numel
-            << " from offset[" << offset << "] from " << peer;
+    auto map = distributed::ProcessGroupMapFromGid::getInstance();
+    if (map->has(rid)) {
+      // Use ProcessGroup
+      distributed::ProcessGroup *pg = map->get(rid);
+      auto task = pg->Recv_Partial(*out, peer, offset, recv_numel);
+      task->Wait();
+    } else {
+      gpuStream_t stream = nullptr;
+      auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
+      if (ctx.Attr<bool>("use_calc_stream")) {
+        auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+        stream = static_cast<platform::CUDADeviceContext *>(dev_ctx)->stream();
+      } else {
+        stream = comm->stream();
+      }
+      PADDLE_ENFORCE_LT(peer,
+                        comm->nranks(),
+                        platform::errors::InvalidArgument(
+                            "The value of peer (%d) you set must "
+                            "be less than comm->nranks (%d).",
+                            peer,
+                            comm->nranks()));
+      ncclDataType_t dtype = platform::ToNCCLDataType(type);
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          platform::dynload::ncclRecv(out->data<T>() + offset,
+                                      recv_numel,
+                                      dtype,
+                                      peer,
+                                      comm->comm(),
+                                      stream));
+      VLOG(3) << "rank " << comm->rank() << " recv " << recv_numel
+              << " from offset[" << offset << "] from " << peer;
    }
 #else
     PADDLE_THROW(platform::errors::Unavailable(
         "PaddlePaddle should be compiled with NCCL and "
...
...@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/collective/partial_send_op.h"
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/fluid/distributed/collective/ProcessGroup.h"
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
...@@ -61,32 +62,47 @@ class PartialSendCUDAKernel : public framework::OpKernel<T> {
         platform::errors::InvalidArgument(
             "The input numel (%d) must be divisible by num(%d)", numel, num));
 
-    gpuStream_t stream = nullptr;
-    auto place = ctx.GetPlace();
-    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
-    if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
-    } else {
-      stream = comm->stream();
-    }
-    PADDLE_ENFORCE_LT(
-        peer,
-        comm->nranks(),
-        platform::errors::InvalidArgument("The value of peer (%d) you set must "
-                                          "be less than comm->nranks (%d).",
-                                          peer,
-                                          comm->nranks()));
-    ncclDataType_t dtype =
-        platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype()));
     int send_numel = numel / num;
     int offset = send_numel * id;
 
-    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
-        x->data<T>() + offset, send_numel, dtype, peer, comm->comm(), stream));
-    VLOG(3) << "rank " << comm->rank() << " send " << send_numel
-            << " from offset[" << offset << "] to " << peer;
+    auto map = distributed::ProcessGroupMapFromGid::getInstance();
+    if (map->has(rid)) {
+      // Use ProcessGroup
+      distributed::ProcessGroup* pg = map->get(rid);
+      phi::DenseTensor tmp = *x;
+      auto task = pg->Send_Partial(tmp, peer, offset, send_numel);
+      task->Wait();
+    } else {
+      gpuStream_t stream = nullptr;
+      auto place = ctx.GetPlace();
+      auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
+      if (ctx.Attr<bool>("use_calc_stream")) {
+        auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+        stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+      } else {
+        stream = comm->stream();
+      }
+      PADDLE_ENFORCE_LT(peer,
+                        comm->nranks(),
+                        platform::errors::InvalidArgument(
+                            "The value of peer (%d) you set must "
+                            "be less than comm->nranks (%d).",
+                            peer,
+                            comm->nranks()));
+      ncclDataType_t dtype =
+          platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype()));
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          platform::dynload::ncclSend(x->data<T>() + offset,
+                                      send_numel,
+                                      dtype,
+                                      peer,
+                                      comm->comm(),
+                                      stream));
+      VLOG(3) << "rank " << comm->rank() << " send " << send_numel
+              << " from offset[" << offset << "] to " << peer;
    }
 #else
     PADDLE_THROW(platform::errors::Unavailable(
         "PaddlePaddle should be compiled with NCCL "
...
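partial_send and partial_recv mirror each other: both compute the same slice length `numel / num` and offset `slice * id`, so the receiver writes the payload back at exactly the offset the sender read it from. A NumPy sketch of that pairing, with assumed shapes, purely for illustration:

```python
import numpy as np

def partial_send(x, num, idx):
    send_numel = x.size // num            # int send_numel = numel / num;
    offset = send_numel * idx             # int offset = send_numel * id;
    return x[offset:offset + send_numel]  # what ncclSend would transmit

def partial_recv(out, payload, num, idx):
    recv_numel = out.size // num
    offset = recv_numel * idx
    out[offset:offset + recv_numel] = payload  # where ncclRecv lands
    return out

x = np.arange(8.0)
buf = np.zeros(8)
print(partial_recv(buf, partial_send(x, num=2, idx=1), num=2, idx=1))
# -> [0. 0. 0. 0. 4. 5. 6. 7.]
```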
...@@ -172,6 +172,27 @@ void BindDistributed(py::module *m) {
             py::arg("dst"),
             py::call_guard<py::gil_scoped_release>())
 
+          .def(
+              "send_partial",
+              [](distributed::ProcessGroup &self,
+                 py::handle py_tensor,
+                 int dst_rank,
+                 int nranks,
+                 int rank_id) {
+                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
+                auto dense =
+                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
+                int numel = (*dense).numel();
+                int send_numel = numel / nranks;
+                int offset = send_numel * rank_id;
+                return self.Send_Partial(*dense, dst_rank, offset, send_numel);
+              },
+              py::arg("tensor"),
+              py::arg("dst"),
+              py::arg("num"),
+              py::arg("id"),
+              py::call_guard<py::gil_scoped_release>())
+
           .def(
               "recv",
               [](distributed::ProcessGroup &self,
...@@ -187,6 +208,27 @@ void BindDistributed(py::module *m) {
             py::arg("src"),
             py::call_guard<py::gil_scoped_release>())
 
+          .def(
+              "recv_partial",
+              [](distributed::ProcessGroup &self,
+                 py::handle py_tensor,
+                 int src_rank,
+                 int nranks,
+                 int rank_id) {
+                auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
+                auto dense =
+                    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
+                int numel = (*dense).numel();
+                int recv_numel = numel / nranks;
+                int offset = recv_numel * rank_id;
+                return self.Recv_Partial(*dense, src_rank, offset, recv_numel);
+              },
+              py::arg("tensor"),
+              py::arg("src"),
+              py::arg("num"),
+              py::arg("id"),
+              py::call_guard<py::gil_scoped_release>())
+
           .def(
               "all_gather",
               [](distributed::ProcessGroup &self,
...@@ -206,6 +248,33 @@ void BindDistributed(py::module *m) {
             py::arg("out"),
             py::call_guard<py::gil_scoped_release>())
 
+          .def(
+              "all_gather_partial",
+              [](distributed::ProcessGroup &self,
+                 py::handle py_in_tensor,
+                 py::handle py_out_tensor,
+                 int nranks,
+                 int rank_id) {
+                auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
+                auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
+                auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
+                    in_tensor.impl());
+                auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
+                    out_tensor.impl());
+                std::vector<phi::DenseTensor> in_tensors = {*in_dense};
+                std::vector<phi::DenseTensor> out_tensors = {*out_dense};
+                int numel = (*in_dense).numel();
+                int send_numel = numel / nranks;
+                int offset = send_numel * rank_id;
+                return self.AllGather_Partial(
+                    in_tensors, out_tensors, offset, send_numel);
+              },
+              py::arg("in"),
+              py::arg("out"),
+              py::arg("num"),
+              py::arg("id"),
+              py::call_guard<py::gil_scoped_release>())
+
           .def(
               "alltoall",
               [](distributed::ProcessGroup &self,
...
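A sketch of driving the bindings registered above from two dygraph ranks. The group and tensor setup is assumed for illustration; the `(tensor, dst/src, num, id)` argument order matches the `py::arg` lists in this diff, and it requires a two-GPU NCCL launch.

```python
# Illustrative only: point-to-point partial transfer between rank 0 and rank 1.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
pg = paddle.distributed.collective._get_default_group().process_group
rank, nranks = dist.get_rank(), dist.get_world_size()

buf = paddle.arange(8, dtype='float32') * (rank + 1)
sid = 0  # both peers must agree on which slice travels
if rank == 0:
    pg.send_partial(buf, 1, nranks, sid).wait()   # ship slice 0 of buf to rank 1
elif rank == 1:
    pg.recv_partial(buf, 0, nranks, sid).wait()   # overwrite slice 0 of buf from rank 0
```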
...@@ -158,14 +158,26 @@ _send_recv_meta = SendRecvMeta()
 
 def _is_valid_send_recv_partial(tensor, mp_degree):
+    tensor_numel = np.prod(tensor.shape)
+    assert tensor_numel != 0, "can't send/recv zero element"
+    return mp_degree > 1 and tensor_numel % mp_degree == 0
+
+
+def _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, nranks,
+                     rank_id):
     if _in_legacy_dygraph():
-        tensor_numel = np.prod(tensor.shape)
-        assert tensor_numel != 0, "can't send/recv zero element"
-        return mp_degree > 1 and tensor_numel % mp_degree == 0
+        return _C_ops.partial_send(tensor.detach(), 'use_calc_stream',
+                                   use_calc_stream, 'ring_id', ring_id, 'peer',
+                                   dst, 'num', nranks, 'id', rank_id)
     elif in_dygraph_mode():
-        # TODO(shenliang03) support mp+pp optimizer in future.
-        # (partial_send/partial_recv/partial_allgather_)
-        return False
+        group = paddle.distributed.collective._get_default_group(
+        ) if group is None else group
+        task = group.process_group.send_partial(tensor, dst, nranks, rank_id)
+        if use_calc_stream:
+            task.wait()
+            return None
+        else:
+            return task
 
 
 def send_partial(tensor,
...@@ -180,9 +192,8 @@ def send_partial(tensor,
     ring_id = 0 if group is None else group.id
 
     if _is_valid_send_recv_partial(tensor, nranks):
-        return _C_ops.partial_send(tensor.detach(), 'use_calc_stream',
-                                   use_calc_stream, 'ring_id', ring_id, 'peer',
-                                   dst, 'num', nranks, 'id', rank_id)
+        return _partial_send_op(tensor, group, use_calc_stream, ring_id, dst,
+                                nranks, rank_id)
     else:
         return paddle.distributed.send(tensor.detach(),
                                        dst=group.ranks[dst],
...@@ -190,6 +201,24 @@ def send_partial(tensor,
                                        use_calc_stream=use_calc_stream)
 
 
+def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks,
+                     rank_id):
+    if _in_legacy_dygraph():
+        return _C_ops.partial_recv(tensor.detach(), 'use_calc_stream',
+                                   use_calc_stream, 'ring_id', ring_id, 'peer',
+                                   src, 'num', nranks, 'id', rank_id, 'dtype',
+                                   tensor.dtype, 'out_shape', tensor.shape)
+    elif in_dygraph_mode():
+        group = paddle.distributed.collective._get_default_group(
+        ) if group is None else group
+        task = group.process_group.recv_partial(tensor, src, nranks, rank_id)
+        if use_calc_stream:
+            task.wait()
+            return None
+        else:
+            return task
+
+
 def recv_partial(tensor,
                  src=0,
                  nranks=1,
...@@ -202,15 +231,31 @@ def recv_partial(tensor,
     ring_id = 0 if group is None else group.id
 
     if _is_valid_send_recv_partial(tensor, nranks):
-        _C_ops.partial_recv(tensor.detach(), 'use_calc_stream', use_calc_stream,
-                            'ring_id', ring_id, 'peer', src, 'num', nranks,
-                            'id', rank_id, 'dtype', tensor.dtype, 'out_shape',
-                            tensor.shape)
+        return _partial_recv_op(tensor, group, use_calc_stream, ring_id, src,
+                                nranks, rank_id)
     else:
-        paddle.distributed.recv(tensor.detach(),
+        return paddle.distributed.recv(tensor.detach(),
                                 src=group.ranks[src],
                                 group=group,
                                 use_calc_stream=use_calc_stream)
+
+
+def _partial_allgather_op(tensor, group, use_calc_stream, ring_id, nranks,
+                          rank_id):
+    if _in_legacy_dygraph():
+        return _C_ops.partial_allgather_(tensor.detach(), 'use_calc_stream',
+                                         use_calc_stream, 'ring_id', ring_id,
+                                         'nranks', nranks, 'rank', rank_id)
+    elif in_dygraph_mode():
+        group = paddle.distributed.collective._get_default_group(
+        ) if group is None else group
+        task = group.process_group.all_gather_partial(tensor, tensor, nranks,
+                                                       rank_id)
+        if use_calc_stream:
+            task.wait()
+            return None
+        else:
+            return task
 
 
 def allgather_partial(tensor,
...@@ -224,9 +269,8 @@ def allgather_partial(tensor,
         return
     ring_id = 0 if group is None else group.id
 
-    return _C_ops.partial_allgather_(tensor.detach(), 'use_calc_stream',
-                                     use_calc_stream, 'ring_id', ring_id,
-                                     'nranks', nranks, 'rank', rank_id)
+    return _partial_allgather_op(tensor, group, use_calc_stream, ring_id,
+                                 nranks, rank_id)
 
 
 def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
...
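The dispatch rule the wrappers above rely on is unchanged by this commit: the partial path is taken only when model parallelism is enabled and the tensor splits evenly across the mp degree; otherwise they fall back to plain send/recv. A standalone restatement of that check, for illustration only:

```python
import numpy as np

def is_valid_send_recv_partial(shape, mp_degree):
    tensor_numel = int(np.prod(shape))
    assert tensor_numel != 0, "can't send/recv zero element"
    return mp_degree > 1 and tensor_numel % mp_degree == 0

print(is_valid_send_recv_partial((4, 6), mp_degree=3))  # True  -> partial_send / partial_recv
print(is_valid_send_recv_partial((4, 5), mp_degree=3))  # False -> paddle.distributed.send/recv
print(is_valid_send_recv_partial((4, 6), mp_degree=1))  # False -> paddle.distributed.send/recv
```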