Unverified commit 3f480af2, authored by Wen Sun, committed by GitHub

Refactor collective communication all_to_all, all_to_all_single C++ API (#48059)

Parent dbc63555
@@ -46,7 +46,6 @@ enum class CommType : std::uint8_t {
SEND = 9,
RECV = 10,
BARRIER = 11,
ALLTOALL_SINGLE = 12,
UNKNOWN = 100,
};
@@ -124,6 +123,17 @@ class ProcessGroup {
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support all_to_all with sync_op flag.",
GetBackendName()));
}
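This new base-class entry point replaces the vector-of-tensor AllToAll_Single / AllToAllSingle overloads removed further down: both all_to_all and all_to_all_single now funnel into one signature that takes a concatenated input tensor, a concatenated output tensor, and per-rank row counts. A minimal caller sketch follows; the process group, tensors, and size values are hypothetical, not code from this PR. Rank r sends in_size_each_rank[i] rows of in_tensor to rank i and receives out_size_each_rank[i] rows from rank i into out_tensor.

```cpp
// Hypothetical usage sketch of the new signature (assume world_size == 4 and
// that process_group, in_tensor, and out_tensor were set up elsewhere).
std::vector<int64_t> in_size_each_rank = {1, 2, 3, 4};   // rows sent to ranks 0..3
std::vector<int64_t> out_size_each_rank = {4, 3, 2, 1};  // rows received from ranks 0..3
auto task = process_group->AllToAll(&out_tensor,         // concatenated receive buffer
                                    in_tensor,           // concatenated send buffer
                                    out_size_each_rank,
                                    in_size_each_rank,
                                    /*sync_op=*/true);
task->Wait();  // assuming the usual ProcessGroup::Task::Wait()
```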
virtual std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) {
PADDLE_THROW(platform::errors::Unimplemented(
@@ -255,25 +265,6 @@ class ProcessGroup {
"ProcessGroup%s does not support alltoall", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllToAll_Single(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<int64_t>&,
std::vector<int64_t>&) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support AllToAll_Single", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllToAllSingle(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<int64_t>&,
std::vector<int64_t>&,
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support alltoall_single", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
......
@@ -184,6 +184,80 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
use_calc_stream);
}
void CheckSizeOnEachRank(const phi::DDim& tensor_dim,
const std::vector<int64_t>& size_on_each_rank,
int world_size) {
int length_size_on_each_rank = size_on_each_rank.size();
PADDLE_ENFORCE_EQ(
length_size_on_each_rank,
world_size,
platform::errors::InvalidArgument(
"The length of size_on_each_rank must be equal to world_size."));
int64_t sum_size_on_each_rank =
std::accumulate(size_on_each_rank.begin(), size_on_each_rank.end(), 0);
PADDLE_ENFORCE_EQ(
sum_size_on_each_rank,
tensor_dim[0],
platform::errors::InvalidArgument(
"The sum of size_on_each_rank must be equal to tensor's dim[0]."));
}
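CheckSizeOnEachRank validates the size vectors passed to the new AllToAll: one entry per rank, and the entries must sum to dim[0] of the concatenated tensor. A standalone illustration of the same invariant in plain C++ (independent of Paddle, with made-up values):

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  const int world_size = 4;
  const int64_t dim0 = 10;  // first dimension of the concatenated tensor
  const std::vector<int64_t> size_each_rank = {1, 2, 3, 4};

  // The two checks mirrored from CheckSizeOnEachRank:
  assert(static_cast<int>(size_each_rank.size()) == world_size);
  assert(std::accumulate(size_each_rank.begin(), size_each_rank.end(),
                         int64_t{0}) == dim0);
  return 0;
}
```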
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op,
bool use_calc_stream) {
const phi::DDim& out_dim = out_tensor->dims();
const phi::DDim& in_dim = in_tensor.dims();
CheckSizeOnEachRank(out_dim, out_size_each_rank, size_);
CheckSizeOnEachRank(in_dim, in_size_each_rank, size_);
return Collective(
out_tensor,
in_tensor,
[&](phi::DenseTensor* output,
const phi::DenseTensor& input,
ncclComm_t comm,
gpuStream_t stream) {
int64_t in_row_size = input.numel() / in_dim[0],
out_row_size = output->numel() / out_dim[0];
int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0;
phi::DenseTensor input_partial, output_partial;
GroupStart();
for (auto i = 0; i < size_; i++) {
in_numel = in_size_each_rank[i] * in_row_size;
input_partial = GetPartialTensor(input, in_offset, in_numel);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
input_partial.data(),
in_numel,
platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
in_offset += in_numel;
out_numel = out_size_each_rank[i] * out_row_size;
output_partial = GetPartialTensor(*output, out_offset, out_numel);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output_partial.data(),
out_numel,
platform::ToNCCLDataType(output->dtype()),
i,
comm,
stream));
out_offset += out_numel;
}
GroupEnd();
},
CommType::ALLTOALL,
sync_op,
use_calc_stream);
}
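GetPartialTensor is not shown in this diff; it is a small helper defined elsewhere in the collective sources. The sketch below shows the behavior it is assumed to have, so that each ncclSend/ncclRecv in the loop above operates on a contiguous per-rank chunk of the flattened tensor; the name GetPartialTensorSketch and the exact implementation are assumptions, not the helper from this PR.

```cpp
// Assumed behavior of the GetPartialTensor helper used above: view the tensor
// as 1-D (sharing storage, no copy) and return the [offset, offset + numel)
// sub-range. The real helper lives outside this diff.
phi::DenseTensor GetPartialTensorSketch(const phi::DenseTensor& tensor,
                                        int64_t offset,
                                        int64_t numel) {
  phi::DenseTensor flattened;
  flattened.ShareDataWith(tensor);                 // reuse the same buffer
  flattened.Resize({tensor.numel()});              // treat it as 1-D
  return flattened.Slice(offset, offset + numel);  // contiguous chunk
}
```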
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
const BarrierOptions& opts) {
PADDLE_ENFORCE_GE(opts.device_id,
@@ -551,7 +625,7 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
std::vector<phi::GPUContext*> dev_ctx_raw;
dev_ctx_raw.resize(places.size());
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
GroupStart();
for (size_t i = 0; i < places.size(); ++i) {
platform::CUDADeviceGuard guard(places[i]);
@@ -564,7 +638,7 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
dev_ctx_raw[i] = dev_ctx[i].get();
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
GroupEnd();
// TODO(sunyilun): for compatibility, will be removed later
place_to_calc_event_.emplace(places_key, places[0]);
@@ -1086,7 +1160,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
ncclComm_t comm,
const gpuStream_t& stream) {
size_t offset = 0;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
GroupStart();
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input.data(), offset, input.dtype()),
@@ -1104,7 +1178,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
stream));
offset += input.numel() / size_;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
GroupEnd();
},
CommType::ALLTOALL);
}
@@ -1130,7 +1204,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
ncclComm_t comm,
const gpuStream_t& stream) {
size_t offset = 0;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
GroupStart();
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input.data(), offset, input.dtype()),
@@ -1148,141 +1222,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
stream));
offset += input.numel() / size_;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
GroupEnd();
},
CommType::ALLTOALL,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll_Single(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
std::vector<int64_t>& in_sizes,
std::vector<int64_t>& out_sizes) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
PADDLE_ENFORCE_EQ(input.dtype() == output.dtype(),
true,
platform::errors::InvalidArgument(
"The dtypes of input and output must be equal."));
std::vector<int64_t> in_dims = phi::vectorize(input.dims());
std::vector<int64_t> out_dims = phi::vectorize(output.dims());
CheckSplitSizes(&in_sizes, in_dims);
CheckSplitSizes(&out_sizes, out_dims);
size_t in_offset = 0, out_offset = 0;
size_t in_length = 0, out_length = 0;
size_t in_row_size = input.numel() / in_dims[0];
size_t out_row_size = output.numel() / out_dims[0];
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto i = 0; i < size_; i++) {
in_length = in_sizes[i] * in_row_size;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input.data(), in_offset, input.dtype()),
in_length,
platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
in_offset += in_length;
out_length = out_sizes[i] * out_row_size;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
GetPointerByOffset(output.data(), out_offset, input.dtype()),
out_length,
platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
out_offset += out_length;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
},
CommType::ALLTOALL_SINGLE);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAllSingle(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
std::vector<int64_t>& in_sizes,
std::vector<int64_t>& out_sizes,
bool sync_op,
bool use_calc_stream) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
PADDLE_ENFORCE_EQ(input.dtype() == output.dtype(),
true,
platform::errors::InvalidArgument(
"The dtypes of input and output must be equal."));
std::vector<int64_t> in_dims = phi::vectorize(input.dims());
std::vector<int64_t> out_dims = phi::vectorize(output.dims());
CheckSplitSizes(&in_sizes, in_dims);
CheckSplitSizes(&out_sizes, out_dims);
size_t in_offset = 0, out_offset = 0;
size_t in_length = 0, out_length = 0;
size_t in_row_size = input.numel() / in_dims[0];
size_t out_row_size = output.numel() / out_dims[0];
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto i = 0; i < size_; i++) {
in_length = in_sizes[i] * in_row_size;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input.data(), in_offset, input.dtype()),
in_length,
platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
in_offset += in_length;
out_length = out_sizes[i] * out_row_size;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
GetPointerByOffset(output.data(), out_offset, input.dtype()),
out_length,
platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
out_offset += out_length;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
},
CommType::ALLTOALL_SINGLE,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
@@ -1396,7 +1342,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
const gpuStream_t& stream) {
size_t offset = 0;
if (rank_ == opts.root_rank) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
GroupStart();
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input.data(), offset, input.dtype()),
@@ -1414,7 +1360,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
opts.root_rank,
comm,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
GroupEnd();
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output.data(),
@@ -1456,7 +1402,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
"Input and output tensors should have the same shape."));
size_t offset = 0;
if (rank_ == opts.root_rank) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
GroupStart();
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input.data(), offset, input.dtype()),
@@ -1474,7 +1420,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
opts.root_rank,
comm,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
GroupEnd();
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output.data(),
......
@@ -109,6 +109,14 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;
@@ -171,20 +179,6 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllToAll_Single(
std::vector<phi::DenseTensor>& in,
std::vector<phi::DenseTensor>& out,
std::vector<int64_t>& in_sizes,
std::vector<int64_t>& out_sizes) override;
std::shared_ptr<ProcessGroup::Task> AllToAllSingle(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
std::vector<int64_t>& in_sizes,
std::vector<int64_t>& out_sizes,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& tensors,
std::vector<phi::DenseTensor>& out_tensors,
......
@@ -73,6 +73,31 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
"ProcessGroup%s does not support all_reduce.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op) {
return AllToAll(out_tensor,
in_tensor,
out_size_each_rank,
in_size_each_rank,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support all_to_all.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
@@ -165,31 +190,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
"ProcessGroup%s does not support do alltoall", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAllSingle(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
std::vector<int64_t>& in_sizes,
std::vector<int64_t>& out_sizes,
bool sync_op) {
return AllToAllSingle(in_tensors,
out_tensors,
in_sizes,
out_sizes,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAllSingle(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
std::vector<int64_t>& in_sizes,
std::vector<int64_t>& out_sizes,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do alltoall_single", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
......
@@ -89,6 +89,21 @@ class ProcessGroupStream : public ProcessGroup {
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
@@ -140,21 +155,6 @@ class ProcessGroupStream : public ProcessGroup {
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> AllToAllSingle(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
std::vector<int64_t>& in_sizes, // NOLINT
std::vector<int64_t>& out_sizes, // NOLINT
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> AllToAllSingle(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
std::vector<int64_t>& in_sizes, // NOLINT
std::vector<int64_t>& out_sizes, // NOLINT
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
......
@@ -277,7 +277,7 @@ void BindDistributed(py::module *m) {
/*offset*/ 0,
/*numel*/ -1,
sync_op);
distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
task->UpdateWaitChain(dev_ctx);
return task;
},
@@ -316,84 +316,96 @@ void BindDistributed(py::module *m) {
.def(
"all_to_all",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor_list,
py::handle py_out_tensor_list,
py::handle py_in_tensor_list,
bool sync_op) {
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor_list =
CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
auto in_dense = *p_in_tensor;
// in_tensor_list should not be empty
const auto &dev_ctx =
self.GetDeviceContext(in_tensor_list.back().place());
auto task = self.AllToAll(in_wrapper, out_wrapper, sync_op);
distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
int world_size = self.GetSize();
auto task =
self.AllToAll(out_dense,
in_dense,
GetDefaultSplitSizes(*out_dense, world_size),
GetDefaultSplitSizes(in_dense, world_size),
sync_op);
SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
task->UpdateWaitChain(dev_ctx);
return task;
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"all_to_all_tensor",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
py::handle py_in_tensor,
bool sync_op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
return self.AllToAll(in_wrapper, out_wrapper, sync_op);
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
int world_size = self.GetSize();
return self.AllToAll(
out_dense,
in_dense,
GetDefaultSplitSizes(*out_dense, world_size),
GetDefaultSplitSizes(in_dense, world_size),
sync_op);
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"all_to_all_single",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
std::vector<int64_t> &in_sizes,
std::vector<int64_t> &out_sizes,
py::handle py_in_tensor,
const std::vector<int64_t> &out_sizes,
const std::vector<int64_t> &in_sizes,
bool sync_op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
return self.AllToAllSingle(
in_wrapper, out_wrapper, in_sizes, out_sizes, sync_op);
return self.AllToAll(
out_dense, in_dense, out_sizes, in_sizes, sync_op);
},
py::arg("in"),
py::arg("out"),
py::arg("in_sizes"),
py::arg("in"),
py::arg("out_sizes"),
py::arg("in_sizes"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
@@ -674,18 +686,20 @@ void BindDistributed(py::module *m) {
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
std::vector<int64_t> in_sizes,
std::vector<int64_t> out_sizes) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
const std::vector<int64_t> in_sizes,
const std::vector<int64_t> out_sizes) {
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllToAll_Single(
in_tensors, out_tensors, in_sizes, out_sizes);
auto *out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
return self.AllToAll(
out_dense, in_dense, out_sizes, in_sizes, /*sync_op*/ true);
},
py::arg("in"),
py::arg("out"),
@@ -765,7 +779,7 @@ void BindDistributed(py::module *m) {
/*numel*/ -1,
/*sync_op*/ true,
/*use_calc_stream*/ true);
distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
return task;
},
py::arg("out"),
@@ -856,88 +870,96 @@ void BindDistributed(py::module *m) {
.def(
"all_to_all_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor_list,
py::handle py_out_tensor_list) {
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
py::handle py_out_tensor_list,
py::handle py_in_tensor_list) {
auto out_tensor_list =
CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
// in_tensor_list must not be empty
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
auto in_dense = *p_in_tensor;
// in_tensor_list should not be empty
const auto &dev_ctx = self.GetDeviceContext(
in_tensor_list.back().place(), /*use_calc_stream*/ true);
auto task = self.AllToAll(in_wrapper,
out_wrapper,
/*sync_op*/ true,
/*use_calc_stream*/ true);
distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
int world_size = self.GetSize();
auto task =
self.AllToAll(out_dense,
in_dense,
GetDefaultSplitSizes(*out_dense, world_size),
GetDefaultSplitSizes(in_dense, world_size),
/*sync_op*/ true,
/*use_calc_stream*/ true);
SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
return task;
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::call_guard<py::gil_scoped_release>())
.def(
"all_to_all_tensor_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor,
py::handle py_out_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
py::handle py_out_tensor,
py::handle py_in_tensor) {
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
return self.AllToAll(in_wrapper,
out_wrapper,
/*sync_op*/ true,
/*use_calc_stream*/ true);
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
int world_size = self.GetSize();
return self.AllToAll(
out_dense,
in_dense,
GetDefaultSplitSizes(*out_dense, world_size),
GetDefaultSplitSizes(in_dense, world_size),
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::call_guard<py::gil_scoped_release>())
.def(
"all_to_all_single_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
std::vector<int64_t> &in_sizes,
std::vector<int64_t> &out_sizes) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
py::handle py_in_tensor,
const std::vector<int64_t> &out_sizes,
const std::vector<int64_t> &in_sizes) {
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
return self.AllToAllSingle(in_wrapper,
out_wrapper,
in_sizes,
out_sizes,
/*sync_op*/ true,
/*use_calc_stream*/ true);
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
return self.AllToAll(out_dense,
in_dense,
out_sizes,
in_sizes,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
py::arg("in"),
py::arg("out"),
py::arg("in_sizes"),
py::arg("in"),
py::arg("out_sizes"),
py::arg("in_sizes"),
py::call_guard<py::gil_scoped_release>())
.def(
......
@@ -21,7 +21,7 @@
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
namespace paddle {
namespace distributed {
namespace pybind {
template <typename DeviceContext, typename T>
struct ConcatDenseTensor {
@@ -113,6 +113,10 @@ void ConcatDenseTensorWithType(const DeviceContext &dev_ctx,
ConcatDenseTensor<DeviceContext, phi::dtype::float16>()(
dev_ctx, t_list, p_out);
break;
case phi::DataType::BFLOAT16:
ConcatDenseTensor<DeviceContext, phi::dtype::bfloat16>()(
dev_ctx, t_list, p_out);
break;
case phi::DataType::FLOAT32:
ConcatDenseTensor<DeviceContext, float>()(dev_ctx, t_list, p_out);
break;
@@ -150,6 +154,10 @@ void SplitDenseTensorWithType(const DeviceContext &dev_ctx,
SplitDenseTensor<DeviceContext, phi::dtype::float16>()(
dev_ctx, t_in, p_list);
break;
case phi::DataType::BFLOAT16:
SplitDenseTensor<DeviceContext, phi::dtype::bfloat16>()(
dev_ctx, t_in, p_list);
break;
case phi::DataType::FLOAT32:
SplitDenseTensor<DeviceContext, float>()(dev_ctx, t_in, p_list);
break;
@@ -249,5 +257,10 @@ void SplitTensor(const phi::DeviceContext &dev_ctx,
}
}
} // namespace distributed
inline std::vector<int64_t> GetDefaultSplitSizes(const phi::DenseTensor &tensor,
int world_size) {
return std::vector<int64_t>(world_size, tensor.dims()[0] / world_size);
}
} // namespace pybind
} // namespace paddle
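GetDefaultSplitSizes is what the bindings above pass to AllToAll when the Python caller supplies no explicit split sizes: an even split of dim[0] across ranks, which implicitly assumes dim[0] is divisible by world_size. A small standalone analogue with made-up values:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Plain-C++ analogue of GetDefaultSplitSizes, for illustration only.
std::vector<int64_t> DefaultSplitSizes(int64_t dim0, int world_size) {
  return std::vector<int64_t>(world_size, dim0 / world_size);
}

int main() {
  // e.g. a concatenated tensor with dims()[0] == 8 on 4 ranks -> 2 2 2 2
  for (int64_t s : DefaultSplitSizes(/*dim0=*/8, /*world_size=*/4)) {
    std::cout << s << ' ';
  }
  std::cout << '\n';
  return 0;
}
```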
@@ -75,11 +75,11 @@ def _all_to_all_in_dygraph(
if use_calc_stream:
return group.process_group.all_to_all_on_calc_stream(
in_tensor_list, out_tensor_list
out_tensor_list, in_tensor_list
)
task = group.process_group.all_to_all(
in_tensor_list, out_tensor_list, sync_op
out_tensor_list, in_tensor_list, sync_op
)
if sync_op:
task.wait()
@@ -243,18 +243,23 @@ def _alltoall_single_in_dygraph(
sync_op,
use_calc_stream,
):
world_size = dist.get_world_size()
if out_split_sizes is None:
out_split_sizes = []
out_split_sizes = [
out_tensor.shape[0] // world_size for _ in range(world_size)
]
if in_split_sizes is None:
in_split_sizes = []
in_split_sizes = [
in_tensor.shape[0] // world_size for _ in range(world_size)
]
if use_calc_stream:
return group.process_group.all_to_all_single_on_calc_stream(
in_tensor, out_tensor, in_split_sizes, out_split_sizes
out_tensor, in_tensor, out_split_sizes, in_split_sizes
)
task = group.process_group.all_to_all_single(
in_tensor, out_tensor, in_split_sizes, out_split_sizes, sync_op
out_tensor, in_tensor, out_split_sizes, in_split_sizes, sync_op
)
if sync_op:
task.wait()
......