Unverified · Commit 88410225 · Authored by: Wen Sun · Committed by: GitHub

Unify `ProcessGroupNCCL` APIs underlying implementation (#48163)

* refactor: replace Collective & PointToPoint with NCCLEnv

* refactor: rename to RunFnInNCCLEnv

* refactor: pass std::function by value
Parent 2a47416c
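
After this change every collective and point-to-point op follows one calling convention: it wraps its NCCL call in a lambda that takes only (ncclComm_t, gpuStream_t) and hands that lambda, together with the tensor whose allocation must be recorded on the communication stream, to RunFnInNCCLEnv. A minimal sketch of the pattern, lifted from the AllReduce hunk below (in_tensor, out_tensor, and opts are the enclosing method's arguments):

  // Inside ProcessGroupNCCL::AllReduce(...) after this refactor:
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        // The lambda captures the tensors by reference; RunFnInNCCLEnv picks
        // the communicator and stream and invokes fn(comm, stream).
        NCCL_CHECK(platform::dynload::ncclAllReduce(
            in_tensor.data(),
            out_tensor->data(),
            in_tensor.numel(),
            platform::ToNCCLDataType(in_tensor.dtype()),
            ToNCCLRedType(opts.reduce_op),
            comm,
            stream));
      },
      in_tensor,            // recorded via memory::RecordStream when needed
      CommType::ALLREDUCE,
      sync_op,
      use_calc_stream);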
@@ -185,11 +185,12 @@ class ProcessGroup {
         GetBackendName()));
   }
 
-  virtual std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor*,
-                                                   int dst_rank,
-                                                   int64_t offset,
-                                                   int64_t numel,
-                                                   bool sync_op) {
+  virtual std::shared_ptr<ProcessGroup::Task> Send(
+      const phi::DenseTensor& tensor,
+      int dst_rank,
+      int64_t offset,
+      int64_t numel,
+      bool sync_op) {
     PADDLE_THROW(platform::errors::Unimplemented(
         "ProcessGroup%s does not support send with sync_op flag.",
         GetBackendName()));
......
@@ -137,21 +137,17 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
   // numel > 0 indicates the tensor need to be sliced
   const phi::DenseTensor& in_tensor_maybe_partial =
       numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
-  return Collective(
-      out_tensor,
-      in_tensor_maybe_partial,
-      [](phi::DenseTensor* output,
-         const phi::DenseTensor& input,
-         ncclComm_t comm,
-         gpuStream_t stream) {
+  return RunFnInNCCLEnv(
+      [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclAllGather(
-            input.data(),
-            output->data(),
-            input.numel(),
-            platform::ToNCCLDataType(input.dtype()),
+            in_tensor_maybe_partial.data(),
+            out_tensor->data(),
+            in_tensor_maybe_partial.numel(),
+            platform::ToNCCLDataType(in_tensor_maybe_partial.dtype()),
             comm,
             stream));
       },
+      in_tensor_maybe_partial,
       CommType::ALLGATHER,
       sync_op,
       use_calc_stream);
@@ -163,22 +159,18 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
     const AllreduceOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  return Collective(
-      out_tensor,
-      in_tensor,
-      [&](phi::DenseTensor* output,
-          const phi::DenseTensor& input,
-          ncclComm_t comm,
-          gpuStream_t stream) {
+  return RunFnInNCCLEnv(
+      [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclAllReduce(
-            input.data(),
-            output->data(),
-            input.numel(),
-            platform::ToNCCLDataType(input.type()),
+            in_tensor.data(),
+            out_tensor->data(),
+            in_tensor.numel(),
+            platform::ToNCCLDataType(in_tensor.dtype()),
             ToNCCLRedType(opts.reduce_op),
             comm,
             stream));
       },
+      in_tensor,
       CommType::ALLREDUCE,
       sync_op,
       use_calc_stream);
@@ -215,37 +207,32 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
   CheckSizeOnEachRank(out_dim, out_size_each_rank, size_);
   CheckSizeOnEachRank(in_dim, in_size_each_rank, size_);
 
-  return Collective(
-      out_tensor,
-      in_tensor,
-      [&](phi::DenseTensor* output,
-          const phi::DenseTensor& input,
-          ncclComm_t comm,
-          gpuStream_t stream) {
-        int64_t in_row_size = input.numel() / in_dim[0],
-                out_row_size = output->numel() / out_dim[0];
+  return RunFnInNCCLEnv(
+      [&](ncclComm_t comm, gpuStream_t stream) {
+        int64_t in_row_size = in_tensor.numel() / in_dim[0],
+                out_row_size = out_tensor->numel() / out_dim[0];
         int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0;
         phi::DenseTensor input_partial, output_partial;
         GroupStart();
         for (auto i = 0; i < size_; i++) {
           in_numel = in_size_each_rank[i] * in_row_size;
-          input_partial = GetPartialTensor(input, in_offset, in_numel);
+          input_partial = GetPartialTensor(in_tensor, in_offset, in_numel);
           NCCL_CHECK(platform::dynload::ncclSend(
               input_partial.data(),
               in_numel,
-              platform::ToNCCLDataType(input.dtype()),
+              platform::ToNCCLDataType(input_partial.dtype()),
               i,
               comm,
               stream));
           in_offset += in_numel;
           out_numel = out_size_each_rank[i] * out_row_size;
-          output_partial = GetPartialTensor(*output, out_offset, out_numel);
+          output_partial = GetPartialTensor(*out_tensor, out_offset, out_numel);
           NCCL_CHECK(platform::dynload::ncclRecv(
               output_partial.data(),
               out_numel,
-              platform::ToNCCLDataType(output->dtype()),
+              platform::ToNCCLDataType(output_partial.dtype()),
               i,
               comm,
               stream));
@@ -253,6 +240,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
         }
         GroupEnd();
       },
+      in_tensor,
       CommType::ALLTOALL,
       sync_op,
       use_calc_stream);
@@ -286,23 +274,19 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
     const BroadcastOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  return Collective(
-      out_tensor,
-      in_tensor,
-      [&](phi::DenseTensor* output,
-          const phi::DenseTensor& input,
-          ncclComm_t comm,
-          gpuStream_t stream) {
+  return RunFnInNCCLEnv(
+      [&](ncclComm_t comm, gpuStream_t stream) {
         int root = opts.source_rank + opts.source_root;
         NCCL_CHECK(platform::dynload::ncclBroadcast(
-            input.data(),
-            output->data(),
-            input.numel(),
-            platform::ToNCCLDataType(input.type()),
+            in_tensor.data(),
+            out_tensor->data(),
+            in_tensor.numel(),
+            platform::ToNCCLDataType(in_tensor.dtype()),
             root,
             comm,
             stream));
       },
+      in_tensor,
       CommType::BROADCAST,
       sync_op,
       use_calc_stream);
@@ -314,23 +298,19 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
     const ReduceOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  return Collective(
-      out_tensor,
-      in_tensor,
-      [&](phi::DenseTensor* output,
-          const phi::DenseTensor& input,
-          ncclComm_t comm,
-          gpuStream_t stream) {
+  return RunFnInNCCLEnv(
+      [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclReduce(
-            input.data(),
-            output->data(),
-            input.numel(),
-            platform::ToNCCLDataType(input.dtype()),
+            in_tensor.data(),
+            out_tensor->data(),
+            in_tensor.numel(),
+            platform::ToNCCLDataType(in_tensor.dtype()),
             ToNCCLRedType(opts.reduce_op),
             opts.root_rank,
             comm,
             stream));
       },
+      in_tensor,
       CommType::REDUCE,
       sync_op,
       use_calc_stream);
@@ -342,22 +322,18 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
     const ReduceScatterOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  return Collective(
-      out_tensor,
-      in_tensor,
-      [&](phi::DenseTensor* output,
-          const phi::DenseTensor& input,
-          ncclComm_t comm,
-          gpuStream_t stream) {
+  return RunFnInNCCLEnv(
+      [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclReduceScatter(
-            input.data(),
-            output->data(),
-            output->numel(),
-            platform::ToNCCLDataType(input.dtype()),
+            in_tensor.data(),
+            out_tensor->data(),
+            out_tensor->numel(),
+            platform::ToNCCLDataType(in_tensor.dtype()),
             ToNCCLRedType(opts.reduce_op),
             comm,
             stream));
       },
+      in_tensor,
       CommType::REDUCE_SCATTER,
       sync_op,
       use_calc_stream);
@@ -369,47 +345,43 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
     const ScatterOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  return Collective(
-      out_tensor,
-      in_tensor,
-      [&](phi::DenseTensor* output,
-          const phi::DenseTensor& input,
-          ncclComm_t comm,
-          gpuStream_t stream) {
-        int64_t numel = input.numel() / size_;
+  return RunFnInNCCLEnv(
+      [&](ncclComm_t comm, gpuStream_t stream) {
+        int64_t numel = in_tensor.numel() / size_;
         if (rank_ == opts.root_rank) {
           int64_t offset = 0;
           phi::DenseTensor partial_tensor;
           GroupStart();
           for (auto i = 0; i < size_; i++) {
-            partial_tensor = GetPartialTensor(input, offset, numel);
+            partial_tensor = GetPartialTensor(in_tensor, offset, numel);
             NCCL_CHECK(platform::dynload::ncclSend(
                 partial_tensor.data(),
                 numel,
-                platform::ToNCCLDataType(input.dtype()),
+                platform::ToNCCLDataType(partial_tensor.dtype()),
                 i,
                 comm,
                 stream));
             offset += numel;
           }
           NCCL_CHECK(platform::dynload::ncclRecv(
-              output->data(),
+              out_tensor->data(),
              numel,
-              platform::ToNCCLDataType(output->dtype()),
+              platform::ToNCCLDataType(out_tensor->dtype()),
              opts.root_rank,
              comm,
              stream));
           GroupEnd();
         } else {
           NCCL_CHECK(platform::dynload::ncclRecv(
-              output->data(),
+              out_tensor->data(),
              numel,
-              platform::ToNCCLDataType(output->dtype()),
+              platform::ToNCCLDataType(out_tensor->dtype()),
              opts.root_rank,
              comm,
              stream));
         }
       },
+      in_tensor,
       CommType::SCATTER,
       sync_op,
       use_calc_stream);
@@ -428,54 +400,43 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
     partial_tensor = GetPartialTensor(*tensor, offset, numel);
     tensor = &partial_tensor;
   }
-  return PointToPoint(
-      tensor,
-      src_rank,
-      [](phi::DenseTensor* output,
-         int src,
-         ncclComm_t comm,
-         gpuStream_t stream) {
+  return RunFnInNCCLEnv(
+      [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclRecv(
-            output->data(),
-            output->numel(),
-            platform::ToNCCLDataType(output->dtype()),
-            src,
+            tensor->data(),
+            tensor->numel(),
+            platform::ToNCCLDataType(tensor->dtype()),
+            src_rank,
             comm,
             stream));
       },
+      *tensor,
       CommType::RECV,
       sync_op,
       use_calc_stream);
 }
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
-    phi::DenseTensor* tensor,
+    const phi::DenseTensor& tensor,
     int dst_rank,
     int64_t offset,
     int64_t numel,
     bool sync_op,
     bool use_calc_stream) {
   // numel > 0 indicates the tensor need to be sliced
-  phi::DenseTensor partial_tensor;
-  if (numel > 0) {
-    partial_tensor = GetPartialTensor(*tensor, offset, numel);
-    tensor = &partial_tensor;
-  }
-  return PointToPoint(
-      tensor,
-      dst_rank,
-      [](phi::DenseTensor* input,
-         int dst,
-         ncclComm_t comm,
-         gpuStream_t stream) {
+  const phi::DenseTensor& tensor_maybe_partial =
+      numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor;
+  return RunFnInNCCLEnv(
+      [&](ncclComm_t comm, gpuStream_t stream) {
        NCCL_CHECK(platform::dynload::ncclSend(
-            input->data(),
-            input->numel(),
-            platform::ToNCCLDataType(input->dtype()),
-            dst,
+            tensor_maybe_partial.data(),
+            tensor_maybe_partial.numel(),
+            platform::ToNCCLDataType(tensor_maybe_partial.dtype()),
+            dst_rank,
             comm,
             stream));
       },
+      tensor_maybe_partial,
       CommType::SEND,
       sync_op,
       use_calc_stream);
@@ -548,54 +509,13 @@ void ProcessGroupNCCL::SyncCalcStream(const Place& place) {
   calc_event.Wait(platform::Place2DeviceType(place), comm_ctx);
 }
 
-template <typename Fn>
-std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
-    phi::DenseTensor* out_tensor,
-    const phi::DenseTensor& in_tensor,
-    Fn fn,
-    CommType comm_type,
-    bool sync_op,
-    bool use_calc_stream) {
-  const auto& place = in_tensor.place();
-  const auto& key = GetKeyFromPlace(place);
-  platform::CUDADeviceGuard cuda_guard(place);
-  if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
-    CreateNCCLEnvCache(place, key);
-  }
-  if (!use_calc_stream) {
-    SyncCalcStream(place);
-  }
-  auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream);
-  const auto* calc_ctx = place_to_calc_ctx_.at(key);
-  const auto& comm_ctx = place_to_comm_ctx_.at(key);
-  auto nccl_comm = comm_ctx->nccl_comm();
-  auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();
-  fn(out_tensor, in_tensor, nccl_comm, nccl_stream);
-  if (!use_calc_stream) {
-    if (FLAGS_use_stream_safe_cuda_allocator) {
-      memory::RecordStream(in_tensor.Holder(), nccl_stream);
-    }
-    task->UpdateWaitChain(*comm_ctx);
-  }
-  return task;
-}
-
-template <typename Fn>
-std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
-    phi::DenseTensor* tensor,
-    int rank,
-    Fn fn,
+std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::RunFnInNCCLEnv(
+    std::function<void(ncclComm_t, gpuStream_t)> fn,
+    const phi::DenseTensor& tensor,
     CommType comm_type,
     bool sync_op,
     bool use_calc_stream) {
-  const auto& place = tensor->place();
+  const auto& place = tensor.place();
   const auto& key = GetKeyFromPlace(place);
   platform::CUDADeviceGuard cuda_guard(place);
 
@@ -614,11 +534,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
   const auto& comm_ctx = place_to_comm_ctx_.at(key);
   auto nccl_comm = comm_ctx->nccl_comm();
   auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();
-  fn(tensor, rank, nccl_comm, nccl_stream);
+  fn(nccl_comm, nccl_stream);
 
   if (!use_calc_stream) {
     if (FLAGS_use_stream_safe_cuda_allocator) {
-      memory::RecordStream(tensor->Holder(), nccl_stream);
+      memory::RecordStream(tensor.Holder(), nccl_stream);
     }
     task->UpdateWaitChain(*comm_ctx);
   }
......
@@ -150,7 +150,7 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
       bool sync_op,
       bool use_calc_stream) override;
 
-  std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor* tensor,
+  std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
                                            int dst_rank,
                                            int64_t offset,
                                            int64_t numel,
@@ -210,23 +210,13 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
   void CreateNCCLEnvCache(const Place& place, const std::string& place_key);
 
-  template <typename Fn>
-  std::shared_ptr<ProcessGroupStream::Task> Collective(
-      phi::DenseTensor* out_tensor,
-      const phi::DenseTensor& in_tensor,
-      Fn fn,
+  std::shared_ptr<ProcessGroup::Task> RunFnInNCCLEnv(
+      std::function<void(ncclComm_t, gpuStream_t)> fn,
+      const phi::DenseTensor& tensor,
       CommType comm_type,
       bool sync_op,
       bool use_calc_stream);
 
-  template <typename Fn>
-  std::shared_ptr<ProcessGroup::Task> PointToPoint(phi::DenseTensor* tensor,
-                                                   int rank,
-                                                   Fn fn,
-                                                   CommType op_type,
-                                                   bool sync_op,
-                                                   bool use_calc_stream);
-
   void SyncCalcStream(const Place& place);
 
   // TODO(sunyilun): methods below will be removed later
......
@@ -212,7 +212,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
 }
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
-    phi::DenseTensor* tensor,
+    const phi::DenseTensor& tensor,
     int dst_rank,
     int64_t offset,
     int64_t numel,
@@ -226,7 +226,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
 }
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
-    phi::DenseTensor*,
+    const phi::DenseTensor& tensor,
     int dst_rank,
     int64_t offset,
     int64_t numel,
......
@@ -168,18 +168,19 @@ class ProcessGroupStream : public ProcessGroup {
       bool sync_op,
       bool use_calc_stream);
 
-  std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor* tensor,
+  std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
                                            int dst_rank,
                                            int64_t offset,
                                            int64_t numel,
                                            bool sync_op) override;
 
-  virtual std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor* tensor,
-                                                   int dst_rank,
-                                                   int64_t offset,
-                                                   int64_t numel,
-                                                   bool sync_op,
-                                                   bool use_calc_stream);
+  virtual std::shared_ptr<ProcessGroup::Task> Send(
+      const phi::DenseTensor& tensor,
+      int dst_rank,
+      int64_t offset,
+      int64_t numel,
+      bool sync_op,
+      bool use_calc_stream);
 };
 
 }  // namespace distributed
......
@@ -226,7 +226,7 @@ struct GlobalGatherProcessGroupFunctor<phi::GPUContext, T> {
         int idx = i + j * n_expert;
         if (cpu_global_count_data[idx]) {
           phi::DenseTensor tmp = *x;
-          pg->Send(&tmp,
+          pg->Send(tmp,
                    j,
                    send_ptr * in_feat,
                    cpu_global_count_data[idx] * in_feat,
......
@@ -224,7 +224,7 @@ struct GlobalScatterProcessGroupFunctor<phi::GPUContext, T> {
         int idx = i + j * n_expert;
         if (cpu_local_count_data[idx]) {
           phi::DenseTensor tmp = *x;
-          pg->Send(&tmp,
+          pg->Send(tmp,
                    j,
                    expert_ptr[idx] * in_feat,
                    cpu_local_count_data[idx] * in_feat,
......
@@ -70,7 +70,7 @@ class PartialSendCUDAKernel : public framework::OpKernel<T> {
       // Use ProcessGroup
       distributed::ProcessGroup* pg = map->get(rid);
       phi::DenseTensor tmp = *x;
-      auto task = pg->Send(&tmp, peer, offset, send_numel, /*sync_op*/ true);
+      auto task = pg->Send(tmp, peer, offset, send_numel, /*sync_op*/ true);
       task->Wait();
     } else {
       gpuStream_t stream = nullptr;
......
@@ -168,7 +168,7 @@ void BindDistributed(py::module *m) {
             auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
             auto p_dense =
                 std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
-            auto *out_dense = p_dense.get();
+            auto out_dense = *p_dense;
             // numel == -1 indicates sending the whole tensor
             return self.Send(
                 out_dense, dst, /*offset*/ 0, /*numel*/ -1, sync_op);
@@ -189,7 +189,7 @@ void BindDistributed(py::module *m) {
             auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
             auto p_dense =
                 std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
-            auto *out_dense = p_dense.get();
+            auto out_dense = *p_dense;
             int64_t numel = p_dense->numel();
             int64_t send_numel = numel / nranks;
@@ -1126,7 +1126,7 @@ void BindDistributed(py::module *m) {
             auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
             auto p_dense =
                 std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
-            auto *out_dense = p_dense.get();
+            auto out_dense = *p_dense;
             // numel == -1 indicates sending the whole tensor
             return self.Send(out_dense,
                              dst,
@@ -1149,7 +1149,7 @@ void BindDistributed(py::module *m) {
             auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
             auto p_dense =
                 std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
-            auto *out_dense = p_dense.get();
+            auto out_dense = *p_dense;
             int64_t numel = p_dense->numel();
             int64_t send_numel = numel / nranks;
......