未验证 提交 edda13cd 编写于 作者: W Wen Sun 提交者: GitHub

Refactor collective communication reduce, scatter, reduce_scatter C++ API (#48115)

上级 d7f7963f
......@@ -47,7 +47,7 @@
namespace paddle {
namespace distributed {
#define NCCLCHECK(cmd) \
#define NCCL_CHECK(cmd) \
do { \
ncclResult_t r = cmd; \
if (r != ncclSuccess) { \
......@@ -60,6 +60,7 @@ namespace distributed {
} while (0)
ncclRedOp_t ToNCCLRedType(ReduceOp reduction);
std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID);
} // namespace distributed
......
......@@ -150,6 +150,36 @@ class ProcessGroup {
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Reduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support reduce with sync_op flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support reduce_scatter with sync_op flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Scatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support scatter with sync_op flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
......@@ -273,16 +303,6 @@ class ProcessGroup {
"ProcessGroup%s does not support reduce", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
const ReduceOptions&,
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support reduce with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
......@@ -291,26 +311,6 @@ class ProcessGroup {
"ProcessGroup%s does not support scatter", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
const ScatterOptions&,
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support scatter with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> ReduceScatter(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
const ReduceScatterOptions&,
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support reduce_scatter with sync_op flag",
GetBackendName()));
}
protected:
const int rank_;
const int size_;
......
......@@ -234,8 +234,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast(
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) {
std::vector<phi::DenseTensor> in_wrapper = {in_tensor};
std::vector<phi::DenseTensor> out_wrapper = {*out_tensor};
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return Broadcast(in_wrapper, out_wrapper, opts, true);
}
......@@ -396,8 +396,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
int64_t offset, // for compatibility, no use now
int64_t numel, // for compatibility, no use now
bool sync_op) {
std::vector<phi::DenseTensor> in_wrapper = {in_tensor};
std::vector<phi::DenseTensor> out_wrapper = {*out_tensor};
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return AllGather(in_wrapper, out_wrapper, true);
}
......@@ -475,26 +475,34 @@ class ReduceGlooTask : public ProcessGroupGloo::GlooTask {
};
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce(
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs,
const ReduceOptions& opts) {
return Reduce(inputs, outputs, opts, true);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce(
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs,
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op) {
bool sync_op // for compatibility, no use now
) {
std::shared_ptr<ReduceGlooTask> task;
auto tag = next_tag();
auto context = get_context();
task = std::make_shared<ReduceGlooTask>(
rank_, context, inputs, outputs, opts.reduce_op, opts.root_rank, tag);
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
task = std::make_shared<ReduceGlooTask>(rank_,
context,
in_wrapper,
out_wrapper,
opts.reduce_op,
opts.root_rank,
tag);
task->Run();
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce(
std::vector<phi::DenseTensor>& inputs,
std::vector<phi::DenseTensor>& outputs,
const ReduceOptions& opts) {
return Reduce(&outputs[0], inputs[0], opts, true);
}
class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
public:
ScatterGlooTask(int rank,
......@@ -538,26 +546,28 @@ class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
};
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts) {
return Scatter(in_tensors, out_tensors, opts, true);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op) {
std::shared_ptr<ScatterGlooTask> task;
auto tag = next_tag();
auto context = get_context();
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
task = std::make_shared<ScatterGlooTask>(
rank_, context, in_tensors, out_tensors, opts.root_rank, size_, tag);
rank_, context, in_wrapper, out_wrapper, opts.root_rank, size_, tag);
task->Run();
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts) {
return Scatter(&out_tensors[0], in_tensors[0], opts, true);
}
std::shared_ptr<::gloo::transport::Device>
ProcessGroupGloo::createDeviceForInterface(const std::string& ifname) {
::gloo::transport::tcp::attr attr;
......
......@@ -120,6 +120,16 @@ class ProcessGroupGloo : public ProcessGroup {
const BroadcastOptions& opts,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Scatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op) override;
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& inputs,
......@@ -155,23 +165,11 @@ class ProcessGroupGloo : public ProcessGroup {
std::vector<phi::DenseTensor>& out_tensors,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts) override;
std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions&,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
......
......@@ -87,11 +87,11 @@ ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
: ProcessGroupStream(rank, size, gid), store_(store) {}
void ProcessGroupNCCL::GroupStart() {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
NCCL_CHECK(platform::dynload::ncclGroupStart());
}
void ProcessGroupNCCL::GroupEnd() {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
NCCL_CHECK(platform::dynload::ncclGroupEnd());
}
const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
......@@ -144,13 +144,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
const phi::DenseTensor& input,
ncclComm_t comm,
gpuStream_t stream) {
return platform::dynload::ncclAllGather(
NCCL_CHECK(platform::dynload::ncclAllGather(
input.data(),
output->data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
comm,
stream);
stream));
},
CommType::ALLGATHER,
sync_op,
......@@ -170,14 +170,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
const phi::DenseTensor& input,
ncclComm_t comm,
gpuStream_t stream) {
return platform::dynload::ncclAllReduce(
NCCL_CHECK(platform::dynload::ncclAllReduce(
input.data(),
output->data(),
input.numel(),
platform::ToNCCLDataType(input.type()),
ToNCCLRedType(opts.reduce_op),
comm,
stream);
stream));
},
CommType::ALLREDUCE,
sync_op,
......@@ -231,7 +231,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
for (auto i = 0; i < size_; i++) {
in_numel = in_size_each_rank[i] * in_row_size;
input_partial = GetPartialTensor(input, in_offset, in_numel);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
NCCL_CHECK(platform::dynload::ncclSend(
input_partial.data(),
in_numel,
platform::ToNCCLDataType(input.dtype()),
......@@ -242,7 +242,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
out_numel = out_size_each_rank[i] * out_row_size;
output_partial = GetPartialTensor(*output, out_offset, out_numel);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
NCCL_CHECK(platform::dynload::ncclRecv(
output_partial.data(),
out_numel,
platform::ToNCCLDataType(output->dtype()),
......@@ -294,20 +294,127 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
ncclComm_t comm,
gpuStream_t stream) {
int root = opts.source_rank + opts.source_root;
return platform::dynload::ncclBroadcast(
NCCL_CHECK(platform::dynload::ncclBroadcast(
input.data(),
output->data(),
input.numel(),
platform::ToNCCLDataType(input.type()),
root,
comm,
stream);
stream));
},
CommType::BROADCAST,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
return Collective(
out_tensor,
in_tensor,
[&](phi::DenseTensor* output,
const phi::DenseTensor& input,
ncclComm_t comm,
gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclReduce(
input.data(),
output->data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
ToNCCLRedType(opts.reduce_op),
opts.root_rank,
comm,
stream));
},
CommType::REDUCE,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
return Collective(
out_tensor,
in_tensor,
[&](phi::DenseTensor* output,
const phi::DenseTensor& input,
ncclComm_t comm,
gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclReduceScatter(
input.data(),
output->data(),
output->numel(),
platform::ToNCCLDataType(input.dtype()),
ToNCCLRedType(opts.reduce_op),
comm,
stream));
},
CommType::REDUCE_SCATTER,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
return Collective(
out_tensor,
in_tensor,
[&](phi::DenseTensor* output,
const phi::DenseTensor& input,
ncclComm_t comm,
gpuStream_t stream) {
int64_t numel = input.numel() / size_;
if (rank_ == opts.root_rank) {
int64_t offset = 0;
phi::DenseTensor partial_tensor;
GroupStart();
for (auto i = 0; i < size_; i++) {
partial_tensor = GetPartialTensor(input, offset, numel);
NCCL_CHECK(platform::dynload::ncclSend(
partial_tensor.data(),
numel,
platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
offset += numel;
}
NCCL_CHECK(platform::dynload::ncclRecv(
output->data(),
numel,
platform::ToNCCLDataType(output->dtype()),
opts.root_rank,
comm,
stream));
GroupEnd();
} else {
NCCL_CHECK(platform::dynload::ncclRecv(
output->data(),
numel,
platform::ToNCCLDataType(output->dtype()),
opts.root_rank,
comm,
stream));
}
},
CommType::SCATTER,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
phi::DenseTensor* tensor,
int src_rank,
......@@ -328,13 +435,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
int src,
ncclComm_t comm,
gpuStream_t stream) {
return platform::dynload::ncclRecv(
NCCL_CHECK(platform::dynload::ncclRecv(
output->data(),
output->numel(),
platform::ToNCCLDataType(output->dtype()),
src,
comm,
stream);
stream));
},
CommType::RECV,
sync_op,
......@@ -361,13 +468,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
int dst,
ncclComm_t comm,
gpuStream_t stream) {
return platform::dynload::ncclSend(
NCCL_CHECK(platform::dynload::ncclSend(
input->data(),
input->numel(),
platform::ToNCCLDataType(input->dtype()),
dst,
comm,
stream);
stream));
},
CommType::SEND,
sync_op,
......@@ -406,7 +513,7 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
ncclUniqueId nccl_id;
if (rank_ == 0) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id));
NCCL_CHECK(platform::dynload::ncclGetUniqueId(&nccl_id));
}
BroadcastUniqueNCCLID(&nccl_id);
......@@ -418,7 +525,7 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
platform::DeviceContextPool::Instance().Get(place));
auto comm_ctx = std::make_unique<phi::GPUContext>(place);
ncclComm_t nccl_comm;
NCCLCHECK(platform::dynload::ncclCommInitRank(
NCCL_CHECK(platform::dynload::ncclCommInitRank(
&nccl_comm, GetSize(), nccl_id, GetRank()));
comm_ctx->set_nccl_comm(nccl_comm);
......@@ -611,7 +718,7 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
ncclUniqueId nccl_id;
if (rank_ == 0) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id));
NCCL_CHECK(platform::dynload::ncclGetUniqueId(&nccl_id));
}
BroadcastUniqueNCCLID(&nccl_id);
......@@ -632,7 +739,7 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
dev_ctx[i].reset(new phi::GPUContext(places[i]));
ncclComm_t nccl_comm;
NCCLCHECK(platform::dynload::ncclCommInitRank(
NCCL_CHECK(platform::dynload::ncclCommInitRank(
&nccl_comm, GetSize(), nccl_id, GetRank()));
dev_ctx[i]->set_nccl_comm(nccl_comm);
dev_ctx_raw[i] = dev_ctx[i].get();
......@@ -1257,70 +1364,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
CommType::REDUCE);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors,
out_tensors,
[&](const phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(
input.data(),
output.data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
ToNCCLRedType(opts.reduce_op),
opts.root_rank,
comm,
stream));
},
CommType::REDUCE,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
if (FLAGS_use_stream_safe_cuda_allocator) {
platform::CUDADeviceGuard cuda_guard;
cuda_guard.SetDevice(output.place());
memory::RecordStream(output.Holder(), stream);
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter(
input.data(),
output.data(),
output.numel(),
platform::ToNCCLDataType(input.dtype()),
ToNCCLRedType(opts.reduce_op),
comm,
stream));
},
CommType::REDUCE_SCATTER,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
......@@ -1374,67 +1417,5 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
CommType::SCATTER);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
PADDLE_ENFORCE_EQ(
output.numel(),
input.numel() / size_,
platform::errors::InvalidArgument(
"Input and output tensors should have the same shape."));
size_t offset = 0;
if (rank_ == opts.root_rank) {
GroupStart();
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input.data(), offset, input.dtype()),
input.numel() / size_,
platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
offset += input.numel() / size_;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output.data(),
input.numel() / size_,
platform::ToNCCLDataType(input.dtype()),
opts.root_rank,
comm,
stream));
GroupEnd();
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output.data(),
input.numel() / size_,
platform::ToNCCLDataType(input.dtype()),
opts.root_rank,
comm,
stream));
}
},
CommType::SCATTER,
sync_op,
use_calc_stream);
}
} // namespace distributed
} // namespace paddle
......@@ -127,6 +127,25 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Scatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
......@@ -184,32 +203,11 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> ReduceScatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts) override;
std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream) override;
private:
std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(const Place& place,
int rank,
......
......@@ -120,6 +120,72 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
"ProcessGroup%s does not support broadcast.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op) {
return Reduce(out_tensor,
in_tensor,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support reduce.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op) {
return ReduceScatter(out_tensor,
in_tensor,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support reduce_scatter.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Scatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op) {
return Scatter(out_tensor,
in_tensor,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Scatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support scatter.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
phi::DenseTensor* tensor,
int src_rank,
......@@ -190,72 +256,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
"ProcessGroup%s does not support do alltoall", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts,
bool sync_op) {
return Reduce(in_tensors,
out_tensors,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do reduce", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::ReduceScatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceScatterOptions& opts,
bool sync_op) {
return ReduceScatter(in_tensors,
out_tensors,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::ReduceScatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do reduce_scatter", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts,
bool sync_op) {
return Scatter(in_tensors,
out_tensors,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do scatter", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank, bool sync_op) {
return Recv(tensors,
......
......@@ -117,6 +117,43 @@ class ProcessGroupStream : public ProcessGroup {
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Reduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Scatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Scatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
......@@ -155,45 +192,6 @@ class ProcessGroupStream : public ProcessGroup {
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const ReduceOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> ReduceScatter(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const ReduceScatterOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> ReduceScatter(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const ScatterOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors, // NOLINT
int src_rank,
......
......@@ -412,16 +412,17 @@ void BindDistributed(py::module *m) {
.def(
"reduce",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_tensor,
int dst,
distributed::ReduceOp op,
bool sync_op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
auto *out_dense = p_dense.get();
auto in_dense = *p_dense;
distributed::ReduceOptions opts{op, dst};
auto dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Reduce(tensors, tensors, opts, sync_op);
return self.Reduce(out_dense, in_dense, opts, sync_op);
},
py::arg("tensor"),
py::arg("dst"),
......@@ -432,28 +433,27 @@ void BindDistributed(py::module *m) {
.def(
"reduce_scatter",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor_list,
py::handle py_out_tensor,
py::handle py_in_tensor_list,
distributed::ReduceOp op,
bool sync_op) {
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
auto out_dense = p_out_tensor.get();
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto in_dense = *p_in_tensor;
distributed::ReduceScatterOptions opts{op};
return self.ReduceScatter(
in_wrapper, out_wrapper, opts, sync_op);
return self.ReduceScatter(out_dense, in_dense, opts, sync_op);
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::arg("op"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
......@@ -461,26 +461,25 @@ void BindDistributed(py::module *m) {
.def(
"reduce_scatter_tensor",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
py::handle py_in_tensor,
distributed::ReduceOp op,
bool sync_op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
distributed::ReduceScatterOptions opts{op};
return self.ReduceScatter(
in_wrapper, out_wrapper, opts, sync_op);
return self.ReduceScatter(out_dense, in_dense, opts, sync_op);
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::arg("op"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
......@@ -488,27 +487,27 @@ void BindDistributed(py::module *m) {
.def(
"scatter",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor_list,
py::handle py_out_tensor,
py::handle py_in_tensor_list,
int src,
bool sync_op) {
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
auto *out_dense = p_out_tensor.get();
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto in_dense = *p_in_tensor;
distributed::ScatterOptions opts{src};
return self.Scatter(in_wrapper, out_wrapper, opts, sync_op);
return self.Scatter(out_dense, in_dense, opts, sync_op);
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::arg("src"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
......@@ -516,25 +515,25 @@ void BindDistributed(py::module *m) {
.def(
"scatter_tensor",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
py::handle py_in_tensor,
int src,
bool sync_op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
distributed::ScatterOptions opts{src};
return self.Scatter(in_wrapper, out_wrapper, opts, sync_op);
return self.Scatter(out_dense, in_dense, opts, sync_op);
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::arg("src"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
......@@ -986,16 +985,17 @@ void BindDistributed(py::module *m) {
.def(
"reduce_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor,
py::handle py_tensor,
int dst,
distributed::ReduceOp op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
auto *out_dense = p_dense.get();
auto in_dense = *p_dense;
distributed::ReduceOptions opts{op, dst};
auto dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Reduce(tensors,
tensors,
return self.Reduce(out_dense,
in_dense,
opts,
/*sync_op*/ true,
/*use_calc_stream*/ true);
......@@ -1008,116 +1008,116 @@ void BindDistributed(py::module *m) {
.def(
"reduce_scatter_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor_list,
py::handle py_out_tensor,
py::handle py_in_tensor_list,
distributed::ReduceOp op) {
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
auto out_dense = p_out_tensor.get();
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto in_dense = *p_in_tensor;
distributed::ReduceScatterOptions opts{op};
return self.ReduceScatter(in_wrapper,
out_wrapper,
return self.ReduceScatter(out_dense,
in_dense,
opts,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::arg("op"),
py::call_guard<py::gil_scoped_release>())
.def(
"reduce_scatter_tensor_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
py::handle py_in_tensor,
distributed::ReduceOp op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
distributed::ReduceScatterOptions opts{op};
return self.ReduceScatter(in_wrapper,
out_wrapper,
return self.ReduceScatter(out_dense,
in_dense,
opts,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::arg("op"),
py::call_guard<py::gil_scoped_release>())
.def(
"scatter_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor_list,
py::handle py_out_tensor,
py::handle py_in_tensor_list,
int src) {
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
auto *out_dense = p_out_tensor.get();
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto in_dense = *p_in_tensor;
distributed::ScatterOptions opts{src};
return self.Scatter(in_wrapper,
out_wrapper,
return self.Scatter(out_dense,
in_dense,
opts,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::arg("src"),
py::call_guard<py::gil_scoped_release>())
.def(
"scatter_tensor_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
py::handle py_in_tensor,
int src) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
distributed::ScatterOptions opts{src};
return self.Scatter(in_wrapper,
out_wrapper,
return self.Scatter(out_dense,
in_dense,
opts,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::arg("src"),
py::call_guard<py::gil_scoped_release>())
......
......@@ -57,11 +57,11 @@ def _reduce_scatter_tensor_in_dygraph(
if use_calc_stream:
return group.process_group.reduce_scatter_tensor_on_calc_stream(
in_tensor, out_tensor, op_type
out_tensor, in_tensor, op_type
)
task = group.process_group.reduce_scatter_tensor(
in_tensor, out_tensor, op_type, sync_op
out_tensor, in_tensor, op_type, sync_op
)
if sync_op:
task.wait()
......@@ -78,11 +78,11 @@ def _reduce_scatter_in_dygraph(
if use_calc_stream:
return group.process_group.reduce_scatter_on_calc_stream(
tensor_list, tensor, op_type
tensor, tensor_list, op_type
)
task = group.process_group.reduce_scatter(
tensor_list, tensor, op_type, sync_op
tensor, tensor_list, op_type, sync_op
)
if sync_op:
task.wait()
......
......@@ -53,11 +53,11 @@ def _scatter_tensor_in_dygraph(
if use_calc_stream:
return group.process_group.scatter_tensor_on_calc_stream(
in_tensor, out_tensor, src_rank_in_group
out_tensor, in_tensor, src_rank_in_group
)
task = group.process_group.scatter_tensor(
in_tensor, out_tensor, src_rank_in_group, sync_op
out_tensor, in_tensor, src_rank_in_group, sync_op
)
if sync_op:
task.wait()
......@@ -80,11 +80,11 @@ def _scatter_in_dygraph(
if use_calc_stream:
return group.process_group.scatter_on_calc_stream(
tensor_list, tensor, src_rank_in_group
tensor, tensor_list, src_rank_in_group
)
task = group.process_group.scatter(
tensor_list, tensor, src_rank_in_group, sync_op
tensor, tensor_list, src_rank_in_group, sync_op
)
if sync_op:
task.wait()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册