Unverified commit edda13cd, authored by Wen Sun, committed via GitHub

Refactor collective communication reduce, scatter, reduce_scatter C++ API (#48115)

Parent d7f7963f
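This commit replaces the vector-of-tensor signatures of the reduce, scatter and reduce_scatter collectives with single dense-tensor arguments (output pointer first, then const input reference) and renames the NCCLCHECK macro to NCCL_CHECK. A minimal usage sketch of the new C++ signatures follows; the process-group object `pg`, the tensors `out` and `in`, and their setup are illustrative assumptions, not part of this patch.

// Sketch only: `pg` is assumed to be an already-initialized
// paddle::distributed::ProcessGroupNCCL, and `out` / `in` are assumed to be
// caller-allocated phi::DenseTensor objects with matching dtype and layout.
paddle::distributed::ReduceOptions reduce_opts;
reduce_opts.reduce_op = paddle::distributed::ReduceOp::SUM;
reduce_opts.root_rank = 0;
// New API shape: single tensors instead of std::vector wrappers.
auto reduce_task = pg->Reduce(&out, in, reduce_opts, /*sync_op=*/true);

paddle::distributed::ScatterOptions scatter_opts;
scatter_opts.root_rank = 0;
// `out` receives this rank's chunk of the root's input.
auto scatter_task = pg->Scatter(&out, in, scatter_opts, /*sync_op=*/true);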
@@ -47,7 +47,7 @@
namespace paddle {
namespace distributed {
-#define NCCLCHECK(cmd) \
+#define NCCL_CHECK(cmd) \
do { \
ncclResult_t r = cmd; \
if (r != ncclSuccess) { \
@@ -60,6 +60,7 @@ namespace distributed {
} while (0)
ncclRedOp_t ToNCCLRedType(ReduceOp reduction);
std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID);
} // namespace distributed
......
@@ -150,6 +150,36 @@ class ProcessGroup {
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Reduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support reduce with sync_op flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support reduce_scatter with sync_op flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Scatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support scatter with sync_op flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
@@ -273,16 +303,6 @@ class ProcessGroup {
"ProcessGroup%s does not support reduce", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
const ReduceOptions&,
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support reduce with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
@@ -291,26 +311,6 @@ class ProcessGroup {
"ProcessGroup%s does not support scatter", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
const ScatterOptions&,
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support scatter with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> ReduceScatter(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
const ReduceScatterOptions&,
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support reduce_scatter with sync_op flag",
GetBackendName()));
}
protected:
const int rank_;
const int size_;
......
@@ -234,8 +234,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast(
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) {
-std::vector<phi::DenseTensor> in_wrapper = {in_tensor};
-std::vector<phi::DenseTensor> out_wrapper = {*out_tensor};
+std::vector<phi::DenseTensor> in_wrapper{in_tensor};
+std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return Broadcast(in_wrapper, out_wrapper, opts, true);
}
@@ -396,8 +396,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
int64_t offset, // for compatibility, no use now
int64_t numel, // for compatibility, no use now
bool sync_op) {
-std::vector<phi::DenseTensor> in_wrapper = {in_tensor};
-std::vector<phi::DenseTensor> out_wrapper = {*out_tensor};
+std::vector<phi::DenseTensor> in_wrapper{in_tensor};
+std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return AllGather(in_wrapper, out_wrapper, true);
}
@@ -475,26 +475,34 @@ class ReduceGlooTask : public ProcessGroupGloo::GlooTask {
};
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce(
-std::vector<phi::DenseTensor>& inputs,
-std::vector<phi::DenseTensor>& outputs,
-const ReduceOptions& opts) {
-return Reduce(inputs, outputs, opts, true);
-}
-std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce(
-std::vector<phi::DenseTensor>& inputs,
-std::vector<phi::DenseTensor>& outputs,
+phi::DenseTensor* out_tensor,
+const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
-bool sync_op) {
+bool sync_op  // for compatibility, no use now
+) {
std::shared_ptr<ReduceGlooTask> task;
auto tag = next_tag();
auto context = get_context();
-task = std::make_shared<ReduceGlooTask>(
-rank_, context, inputs, outputs, opts.reduce_op, opts.root_rank, tag);
+std::vector<phi::DenseTensor> in_wrapper{in_tensor};
+std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
+task = std::make_shared<ReduceGlooTask>(rank_,
+context,
+in_wrapper,
+out_wrapper,
+opts.reduce_op,
+opts.root_rank,
+tag);
task->Run();
return task;
}
+std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce(
+std::vector<phi::DenseTensor>& inputs,
+std::vector<phi::DenseTensor>& outputs,
+const ReduceOptions& opts) {
+return Reduce(&outputs[0], inputs[0], opts, true);
+}
class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
public:
ScatterGlooTask(int rank,
@@ -538,26 +546,28 @@ class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
};
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Scatter(
-std::vector<phi::DenseTensor>& in_tensors,
-std::vector<phi::DenseTensor>& out_tensors,
-const ScatterOptions& opts) {
-return Scatter(in_tensors, out_tensors, opts, true);
-}
-std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Scatter(
-std::vector<phi::DenseTensor>& in_tensors,
-std::vector<phi::DenseTensor>& out_tensors,
+phi::DenseTensor* out_tensor,
+const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op) {
std::shared_ptr<ScatterGlooTask> task;
auto tag = next_tag();
auto context = get_context();
+std::vector<phi::DenseTensor> in_wrapper{in_tensor};
+std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
task = std::make_shared<ScatterGlooTask>(
-rank_, context, in_tensors, out_tensors, opts.root_rank, size_, tag);
+rank_, context, in_wrapper, out_wrapper, opts.root_rank, size_, tag);
task->Run();
return task;
}
+std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Scatter(
+std::vector<phi::DenseTensor>& in_tensors,
+std::vector<phi::DenseTensor>& out_tensors,
+const ScatterOptions& opts) {
+return Scatter(&out_tensors[0], in_tensors[0], opts, true);
+}
std::shared_ptr<::gloo::transport::Device>
ProcessGroupGloo::createDeviceForInterface(const std::string& ifname) {
::gloo::transport::tcp::attr attr;
......
@@ -120,6 +120,16 @@ class ProcessGroupGloo : public ProcessGroup {
const BroadcastOptions& opts,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Scatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op) override;
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& inputs,
@@ -155,23 +165,11 @@ class ProcessGroupGloo : public ProcessGroup {
std::vector<phi::DenseTensor>& out_tensors,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts) override;
std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions&,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
......
@@ -87,11 +87,11 @@ ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
: ProcessGroupStream(rank, size, gid), store_(store) {}
void ProcessGroupNCCL::GroupStart() {
-PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
+NCCL_CHECK(platform::dynload::ncclGroupStart());
}
void ProcessGroupNCCL::GroupEnd() {
-PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
+NCCL_CHECK(platform::dynload::ncclGroupEnd());
}
const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
@@ -144,13 +144,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
const phi::DenseTensor& input,
ncclComm_t comm,
gpuStream_t stream) {
-return platform::dynload::ncclAllGather(
+NCCL_CHECK(platform::dynload::ncclAllGather(
input.data(),
output->data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
comm,
-stream);
+stream));
},
CommType::ALLGATHER,
sync_op,
@@ -170,14 +170,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
const phi::DenseTensor& input,
ncclComm_t comm,
gpuStream_t stream) {
-return platform::dynload::ncclAllReduce(
+NCCL_CHECK(platform::dynload::ncclAllReduce(
input.data(),
output->data(),
input.numel(),
platform::ToNCCLDataType(input.type()),
ToNCCLRedType(opts.reduce_op),
comm,
-stream);
+stream));
},
CommType::ALLREDUCE,
sync_op,
@@ -231,7 +231,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
for (auto i = 0; i < size_; i++) {
in_numel = in_size_each_rank[i] * in_row_size;
input_partial = GetPartialTensor(input, in_offset, in_numel);
-PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
+NCCL_CHECK(platform::dynload::ncclSend(
input_partial.data(),
in_numel,
platform::ToNCCLDataType(input.dtype()),
@@ -242,7 +242,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
out_numel = out_size_each_rank[i] * out_row_size;
output_partial = GetPartialTensor(*output, out_offset, out_numel);
-PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
+NCCL_CHECK(platform::dynload::ncclRecv(
output_partial.data(),
out_numel,
platform::ToNCCLDataType(output->dtype()),
@@ -294,20 +294,127 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
ncclComm_t comm,
gpuStream_t stream) {
int root = opts.source_rank + opts.source_root;
-return platform::dynload::ncclBroadcast(
+NCCL_CHECK(platform::dynload::ncclBroadcast(
input.data(),
output->data(),
input.numel(),
platform::ToNCCLDataType(input.type()),
root,
comm,
-stream);
+stream));
},
CommType::BROADCAST,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
return Collective(
out_tensor,
in_tensor,
[&](phi::DenseTensor* output,
const phi::DenseTensor& input,
ncclComm_t comm,
gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclReduce(
input.data(),
output->data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
ToNCCLRedType(opts.reduce_op),
opts.root_rank,
comm,
stream));
},
CommType::REDUCE,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
return Collective(
out_tensor,
in_tensor,
[&](phi::DenseTensor* output,
const phi::DenseTensor& input,
ncclComm_t comm,
gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclReduceScatter(
input.data(),
output->data(),
output->numel(),
platform::ToNCCLDataType(input.dtype()),
ToNCCLRedType(opts.reduce_op),
comm,
stream));
},
CommType::REDUCE_SCATTER,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
return Collective(
out_tensor,
in_tensor,
[&](phi::DenseTensor* output,
const phi::DenseTensor& input,
ncclComm_t comm,
gpuStream_t stream) {
int64_t numel = input.numel() / size_;
if (rank_ == opts.root_rank) {
int64_t offset = 0;
phi::DenseTensor partial_tensor;
GroupStart();
for (auto i = 0; i < size_; i++) {
partial_tensor = GetPartialTensor(input, offset, numel);
NCCL_CHECK(platform::dynload::ncclSend(
partial_tensor.data(),
numel,
platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
offset += numel;
}
NCCL_CHECK(platform::dynload::ncclRecv(
output->data(),
numel,
platform::ToNCCLDataType(output->dtype()),
opts.root_rank,
comm,
stream));
GroupEnd();
} else {
NCCL_CHECK(platform::dynload::ncclRecv(
output->data(),
numel,
platform::ToNCCLDataType(output->dtype()),
opts.root_rank,
comm,
stream));
}
},
CommType::SCATTER,
sync_op,
use_calc_stream);
}
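As a worked illustration of the chunk arithmetic in the new Scatter path above (values assumed for the example, not taken from the patch): with input.numel() == 12 and size_ == 4, each rank gets numel = 12 / 4 = 3 elements; the root posts ncclSend for the slices at offsets 0, 3, 6 and 9 to ranks 0 through 3 inside the GroupStart()/GroupEnd() pair and also posts one ncclRecv for its own slice, while every non-root rank only posts the single ncclRecv that fills its out_tensor.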
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
phi::DenseTensor* tensor,
int src_rank,
@@ -328,13 +435,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
int src,
ncclComm_t comm,
gpuStream_t stream) {
-return platform::dynload::ncclRecv(
+NCCL_CHECK(platform::dynload::ncclRecv(
output->data(),
output->numel(),
platform::ToNCCLDataType(output->dtype()),
src,
comm,
-stream);
+stream));
},
CommType::RECV,
sync_op,
@@ -361,13 +468,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
int dst,
ncclComm_t comm,
gpuStream_t stream) {
-return platform::dynload::ncclSend(
+NCCL_CHECK(platform::dynload::ncclSend(
input->data(),
input->numel(),
platform::ToNCCLDataType(input->dtype()),
dst,
comm,
-stream);
+stream));
},
CommType::SEND,
sync_op,
@@ -406,7 +513,7 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
ncclUniqueId nccl_id;
if (rank_ == 0) {
-PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id));
+NCCL_CHECK(platform::dynload::ncclGetUniqueId(&nccl_id));
}
BroadcastUniqueNCCLID(&nccl_id);
@@ -418,7 +525,7 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
platform::DeviceContextPool::Instance().Get(place));
auto comm_ctx = std::make_unique<phi::GPUContext>(place);
ncclComm_t nccl_comm;
-NCCLCHECK(platform::dynload::ncclCommInitRank(
+NCCL_CHECK(platform::dynload::ncclCommInitRank(
&nccl_comm, GetSize(), nccl_id, GetRank()));
comm_ctx->set_nccl_comm(nccl_comm);
@@ -611,7 +718,7 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
ncclUniqueId nccl_id;
if (rank_ == 0) {
-PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id));
+NCCL_CHECK(platform::dynload::ncclGetUniqueId(&nccl_id));
}
BroadcastUniqueNCCLID(&nccl_id);
@@ -632,7 +739,7 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
dev_ctx[i].reset(new phi::GPUContext(places[i]));
ncclComm_t nccl_comm;
-NCCLCHECK(platform::dynload::ncclCommInitRank(
+NCCL_CHECK(platform::dynload::ncclCommInitRank(
&nccl_comm, GetSize(), nccl_id, GetRank()));
dev_ctx[i]->set_nccl_comm(nccl_comm);
dev_ctx_raw[i] = dev_ctx[i].get();
@@ -1257,70 +1364,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
CommType::REDUCE);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors,
out_tensors,
[&](const phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(
input.data(),
output.data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
ToNCCLRedType(opts.reduce_op),
opts.root_rank,
comm,
stream));
},
CommType::REDUCE,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
if (FLAGS_use_stream_safe_cuda_allocator) {
platform::CUDADeviceGuard cuda_guard;
cuda_guard.SetDevice(output.place());
memory::RecordStream(output.Holder(), stream);
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter(
input.data(),
output.data(),
output.numel(),
platform::ToNCCLDataType(input.dtype()),
ToNCCLRedType(opts.reduce_op),
comm,
stream));
},
CommType::REDUCE_SCATTER,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
@@ -1374,67 +1417,5 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
CommType::SCATTER);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
PADDLE_ENFORCE_EQ(
output.numel(),
input.numel() / size_,
platform::errors::InvalidArgument(
"Input and output tensors should have the same shape."));
size_t offset = 0;
if (rank_ == opts.root_rank) {
GroupStart();
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input.data(), offset, input.dtype()),
input.numel() / size_,
platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
offset += input.numel() / size_;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output.data(),
input.numel() / size_,
platform::ToNCCLDataType(input.dtype()),
opts.root_rank,
comm,
stream));
GroupEnd();
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output.data(),
input.numel() / size_,
platform::ToNCCLDataType(input.dtype()),
opts.root_rank,
comm,
stream));
}
},
CommType::SCATTER,
sync_op,
use_calc_stream);
}
} // namespace distributed
} // namespace paddle
@@ -127,6 +127,25 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Scatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
@@ -184,32 +203,11 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> ReduceScatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts) override;
std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream) override;
private:
std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(const Place& place,
int rank,
......
@@ -120,6 +120,72 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
"ProcessGroup%s does not support broadcast.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op) {
return Reduce(out_tensor,
in_tensor,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support reduce.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op) {
return ReduceScatter(out_tensor,
in_tensor,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support reduce_scatter.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Scatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op) {
return Scatter(out_tensor,
in_tensor,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Scatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support scatter.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
phi::DenseTensor* tensor,
int src_rank,
@@ -190,72 +256,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
"ProcessGroup%s does not support do alltoall", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts,
bool sync_op) {
return Reduce(in_tensors,
out_tensors,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do reduce", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::ReduceScatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceScatterOptions& opts,
bool sync_op) {
return ReduceScatter(in_tensors,
out_tensors,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::ReduceScatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do reduce_scatter", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts,
bool sync_op) {
return Scatter(in_tensors,
out_tensors,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do scatter", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank, bool sync_op) {
return Recv(tensors,
......
@@ -117,6 +117,43 @@ class ProcessGroupStream : public ProcessGroup {
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Reduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Scatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Scatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
@@ -155,45 +192,6 @@ class ProcessGroupStream : public ProcessGroup {
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const ReduceOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> ReduceScatter(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const ReduceScatterOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> ReduceScatter(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const ScatterOptions& opts,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors, // NOLINT
int src_rank,
......
@@ -412,16 +412,17 @@ void BindDistributed(py::module *m) {
.def(
"reduce",
[](distributed::ProcessGroup &self,
-py::handle py_in_tensor,
+py::handle py_tensor,
int dst,
distributed::ReduceOp op,
bool sync_op) {
-auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
+auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
+auto p_dense =
+std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
+auto *out_dense = p_dense.get();
+auto in_dense = *p_dense;
distributed::ReduceOptions opts{op, dst};
-auto dense = std::dynamic_pointer_cast<phi::DenseTensor>(
-in_tensor.impl());
-std::vector<phi::DenseTensor> tensors = {*dense};
-return self.Reduce(tensors, tensors, opts, sync_op);
+return self.Reduce(out_dense, in_dense, opts, sync_op);
},
py::arg("tensor"),
py::arg("dst"),
@@ -432,28 +433,27 @@ void BindDistributed(py::module *m) {
.def(
"reduce_scatter",
[](distributed::ProcessGroup &self,
-py::handle py_in_tensor_list,
py::handle py_out_tensor,
+py::handle py_in_tensor_list,
distributed::ReduceOp op,
bool sync_op) {
+auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
+auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
+out_tensor.impl());
+auto out_dense = p_out_tensor.get();
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
-auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
+auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
-std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
-auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
-auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
-out_tensor.impl());
-std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
+auto in_dense = *p_in_tensor;
distributed::ReduceScatterOptions opts{op};
-return self.ReduceScatter(
-in_wrapper, out_wrapper, opts, sync_op);
+return self.ReduceScatter(out_dense, in_dense, opts, sync_op);
},
-py::arg("in"),
py::arg("out"),
+py::arg("in"),
py::arg("op"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
@@ -461,26 +461,25 @@ void BindDistributed(py::module *m) {
.def(
"reduce_scatter_tensor",
[](distributed::ProcessGroup &self,
-py::handle py_in_tensor,
py::handle py_out_tensor,
+py::handle py_in_tensor,
distributed::ReduceOp op,
bool sync_op) {
-auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
-auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
-in_tensor.impl());
-std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
-auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
+auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
-std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
+auto out_dense = p_out_tensor.get();
+auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
+auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
+in_tensor.impl());
+auto in_dense = *p_in_tensor;
distributed::ReduceScatterOptions opts{op};
-return self.ReduceScatter(
-in_wrapper, out_wrapper, opts, sync_op);
+return self.ReduceScatter(out_dense, in_dense, opts, sync_op);
},
-py::arg("in"),
py::arg("out"),
+py::arg("in"),
py::arg("op"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
@@ -488,27 +487,27 @@ void BindDistributed(py::module *m) {
.def(
"scatter",
[](distributed::ProcessGroup &self,
-py::handle py_in_tensor_list,
py::handle py_out_tensor,
+py::handle py_in_tensor_list,
int src,
bool sync_op) {
+auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
+auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
+out_tensor.impl());
+auto *out_dense = p_out_tensor.get();
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
-auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
+auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
-std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
-auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
-auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
-out_tensor.impl());
-std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
+auto in_dense = *p_in_tensor;
distributed::ScatterOptions opts{src};
-return self.Scatter(in_wrapper, out_wrapper, opts, sync_op);
+return self.Scatter(out_dense, in_dense, opts, sync_op);
},
-py::arg("in"),
py::arg("out"),
+py::arg("in"),
py::arg("src"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
@@ -516,25 +515,25 @@ void BindDistributed(py::module *m) {
.def(
"scatter_tensor",
[](distributed::ProcessGroup &self,
-py::handle py_in_tensor,
py::handle py_out_tensor,
+py::handle py_in_tensor,
int src,
bool sync_op) {
-auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
-auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
-in_tensor.impl());
-std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
-auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
+auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
-std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
+auto *out_dense = p_out_tensor.get();
+auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
+auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
+in_tensor.impl());
+auto in_dense = *p_in_tensor;
distributed::ScatterOptions opts{src};
-return self.Scatter(in_wrapper, out_wrapper, opts, sync_op);
+return self.Scatter(out_dense, in_dense, opts, sync_op);
},
-py::arg("in"),
py::arg("out"),
+py::arg("in"),
py::arg("src"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
@@ -986,16 +985,17 @@ void BindDistributed(py::module *m) {
.def(
"reduce_on_calc_stream",
[](distributed::ProcessGroupStream &self,
-py::handle py_in_tensor,
+py::handle py_tensor,
int dst,
distributed::ReduceOp op) {
-auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
+auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
+auto p_dense =
+std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
+auto *out_dense = p_dense.get();
+auto in_dense = *p_dense;
distributed::ReduceOptions opts{op, dst};
-auto dense = std::dynamic_pointer_cast<phi::DenseTensor>(
-in_tensor.impl());
-std::vector<phi::DenseTensor> tensors = {*dense};
-return self.Reduce(tensors,
-tensors,
+return self.Reduce(out_dense,
+in_dense,
opts,
/*sync_op*/ true,
/*use_calc_stream*/ true);
@@ -1008,116 +1008,116 @@ void BindDistributed(py::module *m) {
.def(
"reduce_scatter_on_calc_stream",
[](distributed::ProcessGroupStream &self,
-py::handle py_in_tensor_list,
py::handle py_out_tensor,
+py::handle py_in_tensor_list,
distributed::ReduceOp op) {
+auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
+auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
+out_tensor.impl());
+auto out_dense = p_out_tensor.get();
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
-auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
+auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
-std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
-auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
-auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
-out_tensor.impl());
-std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
+auto in_dense = *p_in_tensor;
distributed::ReduceScatterOptions opts{op};
-return self.ReduceScatter(in_wrapper,
-out_wrapper,
+return self.ReduceScatter(out_dense,
+in_dense,
opts,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
-py::arg("in"),
py::arg("out"),
+py::arg("in"),
py::arg("op"),
py::call_guard<py::gil_scoped_release>())
.def(
"reduce_scatter_tensor_on_calc_stream",
[](distributed::ProcessGroupStream &self,
-py::handle py_in_tensor,
py::handle py_out_tensor,
+py::handle py_in_tensor,
distributed::ReduceOp op) {
-auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
-auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
-in_tensor.impl());
-std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
-auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
+auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
-std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
+auto out_dense = p_out_tensor.get();
+auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
+auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
+in_tensor.impl());
+auto in_dense = *p_in_tensor;
distributed::ReduceScatterOptions opts{op};
-return self.ReduceScatter(in_wrapper,
-out_wrapper,
+return self.ReduceScatter(out_dense,
+in_dense,
opts,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
-py::arg("in"),
py::arg("out"),
+py::arg("in"),
py::arg("op"),
py::call_guard<py::gil_scoped_release>())
.def(
"scatter_on_calc_stream",
[](distributed::ProcessGroupStream &self,
-py::handle py_in_tensor_list,
py::handle py_out_tensor,
+py::handle py_in_tensor_list,
int src) {
+auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
+auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
+out_tensor.impl());
+auto *out_dense = p_out_tensor.get();
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
-auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
+auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
-std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
-auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
-auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
-out_tensor.impl());
-std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
+auto in_dense = *p_in_tensor;
distributed::ScatterOptions opts{src};
-return self.Scatter(in_wrapper,
-out_wrapper,
+return self.Scatter(out_dense,
+in_dense,
opts,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
-py::arg("in"),
py::arg("out"),
+py::arg("in"),
py::arg("src"),
py::call_guard<py::gil_scoped_release>())
.def(
"scatter_tensor_on_calc_stream",
[](distributed::ProcessGroupStream &self,
-py::handle py_in_tensor,
py::handle py_out_tensor,
+py::handle py_in_tensor,
int src) {
-auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
-auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
-in_tensor.impl());
-std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
-auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
+auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
-std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
+auto *out_dense = p_out_tensor.get();
+auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
+auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
+in_tensor.impl());
+auto in_dense = *p_in_tensor;
distributed::ScatterOptions opts{src};
-return self.Scatter(in_wrapper,
-out_wrapper,
+return self.Scatter(out_dense,
+in_dense,
opts,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
-py::arg("in"),
py::arg("out"),
+py::arg("in"),
py::arg("src"),
py::call_guard<py::gil_scoped_release>())
......
@@ -57,11 +57,11 @@ def _reduce_scatter_tensor_in_dygraph(
if use_calc_stream:
return group.process_group.reduce_scatter_tensor_on_calc_stream(
-in_tensor, out_tensor, op_type
+out_tensor, in_tensor, op_type
)
task = group.process_group.reduce_scatter_tensor(
-in_tensor, out_tensor, op_type, sync_op
+out_tensor, in_tensor, op_type, sync_op
)
if sync_op:
task.wait()
@@ -78,11 +78,11 @@ def _reduce_scatter_in_dygraph(
if use_calc_stream:
return group.process_group.reduce_scatter_on_calc_stream(
-tensor_list, tensor, op_type
+tensor, tensor_list, op_type
)
task = group.process_group.reduce_scatter(
-tensor_list, tensor, op_type, sync_op
+tensor, tensor_list, op_type, sync_op
)
if sync_op:
task.wait()
......
@@ -53,11 +53,11 @@ def _scatter_tensor_in_dygraph(
if use_calc_stream:
return group.process_group.scatter_tensor_on_calc_stream(
-in_tensor, out_tensor, src_rank_in_group
+out_tensor, in_tensor, src_rank_in_group
)
task = group.process_group.scatter_tensor(
-in_tensor, out_tensor, src_rank_in_group, sync_op
+out_tensor, in_tensor, src_rank_in_group, sync_op
)
if sync_op:
task.wait()
@@ -80,11 +80,11 @@ def _scatter_in_dygraph(
if use_calc_stream:
return group.process_group.scatter_on_calc_stream(
-tensor_list, tensor, src_rank_in_group
+tensor, tensor_list, src_rank_in_group
)
task = group.process_group.scatter(
-tensor_list, tensor, src_rank_in_group, sync_op
+tensor, tensor_list, src_rank_in_group, sync_op
)
if sync_op:
task.wait()
......