未验证 提交 d926c270 编写于 作者: W Wen Sun 提交者: GitHub

Refactor collective communication P2P C++ API (#47801)

* refactor: send, recv, send_partial, recv_partial

* refactor: rm useless const ref
上级 1b250710
...@@ -139,6 +139,44 @@ class ProcessGroup { ...@@ -139,6 +139,44 @@ class ProcessGroup {
GetBackendName())); GetBackendName()));
} }
virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support recv with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> RecvPartial(
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support recv_partial with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor*,
int dst_rank,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> SendPartial(
phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send_partial with sync_op flag",
GetBackendName()));
}
// TODO(liyurui): This API will be moved later // TODO(liyurui): This API will be moved later
virtual std::shared_ptr<ProcessGroup::Task> AllReduce( virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
......
...@@ -139,7 +139,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather( ...@@ -139,7 +139,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
[&](phi::DenseTensor* output, [&](phi::DenseTensor* output,
const phi::DenseTensor& input, const phi::DenseTensor& input,
ncclComm_t comm, ncclComm_t comm,
const gpuStream_t& stream) { gpuStream_t stream) {
return platform::dynload::ncclAllGather( return platform::dynload::ncclAllGather(
input.data(), input.data(),
output->data(), output->data(),
...@@ -165,7 +165,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce( ...@@ -165,7 +165,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
[&](phi::DenseTensor* output, [&](phi::DenseTensor* output,
const phi::DenseTensor& input, const phi::DenseTensor& input,
ncclComm_t comm, ncclComm_t comm,
const gpuStream_t& stream) { gpuStream_t stream) {
return platform::dynload::ncclAllReduce( return platform::dynload::ncclAllReduce(
input.data(), input.data(),
output->data(), output->data(),
...@@ -209,7 +209,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast( ...@@ -209,7 +209,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
[&](phi::DenseTensor* output, [&](phi::DenseTensor* output,
const phi::DenseTensor& input, const phi::DenseTensor& input,
ncclComm_t comm, ncclComm_t comm,
const gpuStream_t& stream) { gpuStream_t stream) {
int root = opts.source_rank + opts.source_root; int root = opts.source_rank + opts.source_root;
return platform::dynload::ncclBroadcast( return platform::dynload::ncclBroadcast(
input.data(), input.data(),
...@@ -225,6 +225,118 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast( ...@@ -225,6 +225,118 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
use_calc_stream); use_calc_stream);
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
phi::DenseTensor* tensor,
int src_rank,
bool sync_op,
bool use_calc_stream) {
return PointToPoint(
tensor,
src_rank,
[&](phi::DenseTensor* output,
int src,
ncclComm_t comm,
gpuStream_t stream) {
return platform::dynload::ncclRecv(
output->data(),
output->numel(),
platform::ToNCCLDataType(output->dtype()),
src,
comm,
stream);
},
CommType::RECV,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::RecvPartial(
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) {
phi::DenseTensor tensor_flattened;
tensor_flattened.ShareDataWith(*tensor).Resize({tensor->numel()});
phi::DenseTensor tensor_recv =
tensor_flattened.Slice(offset, offset + length);
return PointToPoint(
&tensor_recv,
src_rank,
[&](phi::DenseTensor* output,
int src,
ncclComm_t comm,
gpuStream_t stream) {
return platform::dynload::ncclRecv(
output->data(),
output->numel(),
platform::ToNCCLDataType(output->dtype()),
src,
comm,
stream);
},
CommType::RECV,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
phi::DenseTensor* tensor,
int dst_rank,
bool sync_op,
bool use_calc_stream) {
return PointToPoint(
tensor,
dst_rank,
[&](phi::DenseTensor* input,
int dst,
ncclComm_t comm,
gpuStream_t stream) {
return platform::dynload::ncclSend(
input->data(),
input->numel(),
platform::ToNCCLDataType(input->dtype()),
dst,
comm,
stream);
},
CommType::SEND,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::SendPartial(
phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) {
phi::DenseTensor tensor_flattened;
tensor_flattened.ShareDataWith(*tensor).Resize({tensor->numel()});
phi::DenseTensor tensor_send =
tensor_flattened.Slice(offset, offset + length);
return PointToPoint(
&tensor_send,
dst_rank,
[&](phi::DenseTensor* input,
int dst,
ncclComm_t comm,
gpuStream_t stream) {
return platform::dynload::ncclSend(
input->data(),
input->numel(),
platform::ToNCCLDataType(input->dtype()),
dst,
comm,
stream);
},
CommType::SEND,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask( std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
const Place& place, const Place& place,
int rank, int rank,
...@@ -331,6 +443,45 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective( ...@@ -331,6 +443,45 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
return task; return task;
} }
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
phi::DenseTensor* tensor,
int rank,
Fn fn,
CommType comm_type,
bool sync_op,
bool use_calc_stream) {
const auto& place = tensor->place();
const auto& key = GetKeyFromPlace(place);
platform::CUDADeviceGuard cuda_guard(place);
if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
CreateNCCLEnvCache(place, key);
}
if (!use_calc_stream) {
SyncCalcStream(place);
}
auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream);
const auto* calc_ctx = place_to_calc_ctx_.at(key);
const auto& comm_ctx = place_to_comm_ctx_.at(key);
auto nccl_comm = comm_ctx->nccl_comm();
auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();
fn(tensor, rank, nccl_comm, nccl_stream);
if (!use_calc_stream) {
if (FLAGS_use_stream_safe_cuda_allocator) {
memory::RecordStream(tensor->Holder(), nccl_stream);
}
task->UpdateWaitChain(*comm_ctx);
}
return task;
}
void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>* split_sizes, void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>* split_sizes,
std::vector<int64_t> tensor_shape) { std::vector<int64_t> tensor_shape) {
int64_t len_size = (*split_sizes).size(); int64_t len_size = (*split_sizes).size();
...@@ -864,34 +1015,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send( ...@@ -864,34 +1015,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
return task; return task;
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
std::vector<phi::DenseTensor>& tensors,
int dst_rank,
bool sync_op,
bool use_calc_stream) {
CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
auto task = PointToPoint(
tensors,
[&](phi::DenseTensor& input,
ncclComm_t comm,
const gpuStream_t& stream,
int dst_rank) {
return platform::dynload::ncclSend(
input.data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
dst_rank,
comm,
stream);
},
dst_rank,
CommType::SEND,
sync_op,
use_calc_stream);
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv( std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank) { std::vector<phi::DenseTensor>& tensors, int src_rank) {
CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize())); CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
...@@ -915,34 +1038,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv( ...@@ -915,34 +1038,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
return task; return task;
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
std::vector<phi::DenseTensor>& tensors,
int src_rank,
bool sync_op,
bool use_calc_stream) {
CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
auto task = PointToPoint(
tensors,
[&](phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream,
int src_rank) {
return platform::dynload::ncclRecv(
output.data(),
output.numel(),
platform::ToNCCLDataType(output.dtype()),
src_rank,
comm,
stream);
},
src_rank,
CommType::RECV,
sync_op,
use_calc_stream);
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial( std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial(
phi::DenseTensor& tensors, int dst_rank, int64_t offset, int64_t length) { phi::DenseTensor& tensors, int dst_rank, int64_t offset, int64_t length) {
// CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize())); // CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
......
...@@ -118,6 +118,32 @@ class ProcessGroupNCCL final : public ProcessGroupStream { ...@@ -118,6 +118,32 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
bool sync_op, bool sync_op,
bool use_calc_stream) override; bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> RecvPartial(
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor* tensor,
int dst_rank,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> SendPartial(
phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) override;
static void GroupStart(); static void GroupStart();
static void GroupEnd(); static void GroupEnd();
...@@ -139,21 +165,9 @@ class ProcessGroupNCCL final : public ProcessGroupStream { ...@@ -139,21 +165,9 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
std::shared_ptr<ProcessGroup::Task> Send( std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors, int dst_rank) override; std::vector<phi::DenseTensor>& tensors, int dst_rank) override;
std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors,
int dst_rank,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Recv( std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank) override; std::vector<phi::DenseTensor>& tensors, int src_rank) override;
std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors,
int src_rank,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor& tensors, std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor& tensors,
int dst_rank, int dst_rank,
int64_t offset, int64_t offset,
...@@ -273,6 +287,14 @@ class ProcessGroupNCCL final : public ProcessGroupStream { ...@@ -273,6 +287,14 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
bool sync_op, bool sync_op,
bool use_calc_stream); bool use_calc_stream);
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> PointToPoint(phi::DenseTensor* tensor,
int rank,
Fn fn,
CommType op_type,
bool sync_op,
bool use_calc_stream);
void SyncCalcStream(const Place& place); void SyncCalcStream(const Place& place);
// TODO(sunyilun): methods below will be removed later // TODO(sunyilun): methods below will be removed later
......
...@@ -92,6 +92,87 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast( ...@@ -92,6 +92,87 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
"ProcessGroup%s does not support do broadcast", GetBackendName())); "ProcessGroup%s does not support do broadcast", GetBackendName()));
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
phi::DenseTensor* tensor, int src_rank, bool sync_op) {
return Recv(tensor,
src_rank,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
phi::DenseTensor* tensor,
int src_rank,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do recv", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::RecvPartial(
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op) {
return RecvPartial(tensor,
src_rank,
offset,
length,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::RecvPartial(
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do recv_partial", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
phi::DenseTensor* tensor, int dst_rank, bool sync_op) {
return Send(tensor,
dst_rank,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
phi::DenseTensor*, int dst_rank, bool sync_op, bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do send", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::SendPartial(
phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op) {
return SendPartial(tensor,
dst_rank,
offset,
length,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::SendPartial(
phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do send partial", GetBackendName()));
}
// TODO(sunyilun): methods below will be removed later // TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll( std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
std::vector<phi::DenseTensor>& in_tensors, std::vector<phi::DenseTensor>& in_tensors,
...@@ -203,23 +284,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Scatter( ...@@ -203,23 +284,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Scatter(
"ProcessGroup%s does not support do scatter", GetBackendName())); "ProcessGroup%s does not support do scatter", GetBackendName()));
} }
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
std::vector<phi::DenseTensor>& tensors, int dst_rank, bool sync_op) {
return Send(tensors,
dst_rank,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
std::vector<phi::DenseTensor>& tensors,
int dst_rank,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do send", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send_Partial( std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send_Partial(
phi::DenseTensor& tensors, phi::DenseTensor& tensors,
int dst_rank, int dst_rank,
......
...@@ -98,6 +98,52 @@ class ProcessGroupStream : public ProcessGroup { ...@@ -98,6 +98,52 @@ class ProcessGroupStream : public ProcessGroup {
bool sync_op, bool sync_op,
bool use_calc_stream); bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> RecvPartial(phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> RecvPartial(
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor* tensor,
int dst_rank,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor* tensor,
int dst_rank,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> SendPartial(phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> SendPartial(
phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream);
// TODO(sunyilun): methods below will be removed later // TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> AllToAll( std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT std::vector<phi::DenseTensor>& in_tensors, // NOLINT
...@@ -164,17 +210,6 @@ class ProcessGroupStream : public ProcessGroup { ...@@ -164,17 +210,6 @@ class ProcessGroupStream : public ProcessGroup {
bool sync_op, bool sync_op,
bool use_calc_stream); bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors, // NOLINT
int dst_rank,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors, // NOLINT
int dst_rank,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Send_Partial( std::shared_ptr<ProcessGroup::Task> Send_Partial(
phi::DenseTensor& tensors, // NOLINT phi::DenseTensor& tensors, // NOLINT
int dst_rank, int dst_rank,
......
...@@ -173,10 +173,10 @@ void BindDistributed(py::module *m) { ...@@ -173,10 +173,10 @@ void BindDistributed(py::module *m) {
int dst, int dst,
bool sync_op) { bool sync_op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense = auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()); std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense}; auto *out_dense = p_dense.get();
return self.Send(tensors, dst, sync_op); return self.Send(out_dense, dst, sync_op);
}, },
py::arg("tensor"), py::arg("tensor"),
py::arg("dst"), py::arg("dst"),
...@@ -192,13 +192,14 @@ void BindDistributed(py::module *m) { ...@@ -192,13 +192,14 @@ void BindDistributed(py::module *m) {
int rank_id, int rank_id,
bool sync_op) { bool sync_op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense = auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()); std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int64_t numel = (*dense).numel(); int64_t numel = p_dense->numel();
int64_t send_numel = numel / nranks; int64_t send_numel = numel / nranks;
int64_t offset = send_numel * rank_id; int64_t offset = send_numel * rank_id;
return self.Send_Partial( auto *out_dense = p_dense.get();
*dense, dst_rank, offset, send_numel, sync_op); return self.SendPartial(
out_dense, dst_rank, offset, send_numel, sync_op);
}, },
py::arg("tensor"), py::arg("tensor"),
py::arg("dst"), py::arg("dst"),
...@@ -214,10 +215,10 @@ void BindDistributed(py::module *m) { ...@@ -214,10 +215,10 @@ void BindDistributed(py::module *m) {
int src, int src,
bool sync_op) { bool sync_op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense = auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()); std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense}; auto *in_dense = p_dense.get();
return self.Recv(tensors, src, sync_op); return self.Recv(in_dense, src, sync_op);
}, },
py::arg("tensor"), py::arg("tensor"),
py::arg("src"), py::arg("src"),
...@@ -233,13 +234,14 @@ void BindDistributed(py::module *m) { ...@@ -233,13 +234,14 @@ void BindDistributed(py::module *m) {
int rank_id, int rank_id,
bool sync_op) { bool sync_op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense = auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()); std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int64_t numel = (*dense).numel(); int64_t numel = p_dense->numel();
int64_t recv_numel = numel / nranks; int64_t recv_numel = numel / nranks;
int64_t offset = recv_numel * rank_id; int64_t offset = recv_numel * rank_id;
return self.Recv_Partial( auto *out_dense = p_dense.get();
*dense, src_rank, offset, recv_numel, sync_op); return self.RecvPartial(
out_dense, src_rank, offset, recv_numel, sync_op);
}, },
py::arg("tensor"), py::arg("tensor"),
py::arg("src"), py::arg("src"),
...@@ -1125,10 +1127,10 @@ void BindDistributed(py::module *m) { ...@@ -1125,10 +1127,10 @@ void BindDistributed(py::module *m) {
py::handle py_tensor, py::handle py_tensor,
int dst) { int dst) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense = auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()); std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense}; auto *out_dense = p_dense.get();
return self.Send(tensors, return self.Send(out_dense,
dst, dst,
/*sync_op*/ true, /*sync_op*/ true,
/*use_calc_stream*/ true); /*use_calc_stream*/ true);
...@@ -1145,12 +1147,13 @@ void BindDistributed(py::module *m) { ...@@ -1145,12 +1147,13 @@ void BindDistributed(py::module *m) {
int nranks, int nranks,
int rank_id) { int rank_id) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense = auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()); std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int64_t numel = (*dense).numel(); int64_t numel = p_dense->numel();
int64_t send_numel = numel / nranks; int64_t send_numel = numel / nranks;
int64_t offset = send_numel * rank_id; int64_t offset = send_numel * rank_id;
return self.Send_Partial(*dense, auto *out_dense = p_dense.get();
return self.SendPartial(out_dense,
dst_rank, dst_rank,
offset, offset,
send_numel, send_numel,
...@@ -1169,10 +1172,10 @@ void BindDistributed(py::module *m) { ...@@ -1169,10 +1172,10 @@ void BindDistributed(py::module *m) {
py::handle py_tensor, py::handle py_tensor,
int src) { int src) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense = auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()); std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense}; auto *in_dense = p_dense.get();
return self.Recv(tensors, return self.Recv(in_dense,
src, src,
/*sync_op*/ true, /*sync_op*/ true,
/*use_calc_stream*/ true); /*use_calc_stream*/ true);
...@@ -1189,12 +1192,13 @@ void BindDistributed(py::module *m) { ...@@ -1189,12 +1192,13 @@ void BindDistributed(py::module *m) {
int nranks, int nranks,
int rank_id) { int rank_id) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense = auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()); std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int64_t numel = (*dense).numel(); int64_t numel = p_dense->numel();
int64_t recv_numel = numel / nranks; int64_t recv_numel = numel / nranks;
int64_t offset = recv_numel * rank_id; int64_t offset = recv_numel * rank_id;
return self.Recv_Partial(*dense, auto *out_dense = p_dense.get();
return self.RecvPartial(out_dense,
src_rank, src_rank,
offset, offset,
recv_numel, recv_numel,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册