Unverified commit 8c0529fd, authored by Haohongxiang and committed by GitHub

[Dygraph] Fix performance of pp+mp by using send/recv_calc_stream instead of send/recv (#46116)

Parent ff37e48e
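At the Python API level, the change replaces the comm-stream partial ops (which return a task that must be waited on) with new `*_on_calc_stream` bindings that enqueue the collective directly on the calculation stream. Below is a minimal, hedged sketch of the before/after call pattern, assuming a dygraph program launched on at least two GPUs via `paddle.distributed.launch`; only the `all_gather_partial*` bindings are taken from this diff, the surrounding setup is illustrative.

```python
# Sketch only: old vs. new call pattern for the partial all-gather.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
group = paddle.distributed.collective._get_default_group()
nranks, rank_id = dist.get_world_size(), dist.get_rank()

tensor = paddle.ones([4 * nranks], dtype='float32')

# Old path: the collective runs on the comm stream and returns a task
# that has to be waited on before the result can be used.
task = group.process_group.all_gather_partial(tensor, tensor, nranks, rank_id)
task.wait()

# New path (this PR): the collective runs on the calculation stream, so no
# extra task/wait or stream synchronization is needed on the pp+mp hot path.
group.process_group.all_gather_partial_on_calc_stream(tensor, tensor, nranks, rank_id)
```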
......
@@ -214,6 +214,16 @@ class ProcessGroup {
"ProcessGroup%s does not support AllGather_Partial", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
int offset,
int length,
bool) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support AllGather_Partial", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&) { // NOLINT
......
......
@@ -1034,6 +1034,41 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather_Partial(
CommType::ALLGATHER);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
int offset,
int length,
bool sync_op,
bool use_calc_stream) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
return platform::dynload::ncclAllGather(
GetPointerByOffset(input.data(), offset, input.dtype()),
output.data(),
length,
platform::ToNCCLDataType(input.dtype()),
comm,
stream);
},
CommType::ALLGATHER,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
......
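For reference, the `ncclAllGather` call above sends only the `[offset, offset + length)` slice of each rank's input and writes the per-rank slices into the output in rank order. A single-process NumPy sketch of that semantics, using the `offset = send_numel * rank_id` convention from the Python binding later in this diff (illustrative only, no NCCL involved):

```python
# Reference semantics of the partial all-gather: rank r contributes
# input[length * r : length * (r + 1)], and the slices are concatenated
# in rank order into the output buffer.
import numpy as np

def all_gather_partial_reference(local_buffers):
    nranks = len(local_buffers)
    length = local_buffers[0].size // nranks
    out = np.empty(local_buffers[0].size, dtype=local_buffers[0].dtype)
    for rank, buf in enumerate(local_buffers):
        offset = length * rank
        out[offset:offset + length] = buf[offset:offset + length]
    return out

# Two ranks, each holding a length-4 buffer in which only its own half is valid.
bufs = [np.array([1, 2, 0, 0]), np.array([0, 0, 3, 4])]
print(all_gather_partial_reference(bufs))  # -> [1 2 3 4]
```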
......
@@ -182,6 +182,14 @@ class ProcessGroupNCCL : public ProcessGroupStream {
int offset,
int length) override;
std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
int offset,
int length,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>& in,
std::vector<phi::DenseTensor>& out) override;
......
......
@@ -154,5 +154,30 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv_Partial(
"ProcessGroup%s does not support do recv_partial", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
int offset,
int length,
bool sync_op) {
return AllGather_Partial(in_tensors,
out_tensors,
offset,
length,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
int offset,
int length,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do recv_partial", GetBackendName()));
}
} // namespace distributed
} // namespace paddle
......
@@ -132,6 +132,21 @@ class ProcessGroupStream : public ProcessGroup {
int length,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
int offset,
int length,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
int offset,
int length,
bool sync_op,
bool use_calc_stream);
};
} // namespace distributed
......
......
@@ -621,6 +621,37 @@ void BindDistributed(py::module *m) {
py::arg("op"),
py::call_guard<py::gil_scoped_release>())
.def(
"all_gather_partial_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
int nranks,
int rank_id) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
int numel = (*in_dense).numel();
int send_numel = numel / nranks;
int offset = send_numel * rank_id;
return self.AllGather_Partial(in_tensors,
out_tensors,
offset,
send_numel,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
py::arg("in"),
py::arg("out"),
py::arg("num"),
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
.def(
"send_on_calc_stream",
[](distributed::ProcessGroupStream &self,
......
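A small worked example of the slicing arithmetic in the binding above (numbers are illustrative): a 1-D tensor with 8 elements gathered across 4 ranks gives each rank a 2-element shard.

```python
# Worked example of the offset computation used by the binding above.
numel, nranks = 8, 4
send_numel = numel // nranks                  # 2 elements contributed per rank
offsets = [send_numel * rank_id for rank_id in range(nranks)]
assert offsets == [0, 2, 4, 6]                # rank r sends input[offsets[r]:offsets[r] + send_numel]
```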
......
@@ -43,7 +43,26 @@ def _c_identity(tensor, group=None):
return
ring_id = 0 if group is None else group.id
if _non_static_mode():
if in_dygraph_mode():
from paddle.autograd import PyLayer
class c_identity_eager(PyLayer):
@staticmethod
def forward(ctx, tensor):
return _legacy_C_ops.c_identity(tensor, 'use_calc_stream', True,
'ring_id', group.id,
'use_model_parallel', True)
@staticmethod
def backward(ctx, dy):
op_type = collective._get_reduce_op(ReduceOp.SUM, "_c_identity")
group.process_group.allreduce_on_calc_stream(dy, op_type)
return dy
return c_identity_eager.apply(tensor)
elif _in_legacy_dygraph():
return _legacy_C_ops.c_identity(tensor, 'use_calc_stream', True,
'ring_id', ring_id,
'use_model_parallel', True)
......
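The eager branch above wraps the op in a `PyLayer`: forward is an identity, while backward all-reduces the incoming gradient on the calculation stream. A minimal stand-alone sketch of that autograd pattern, using the public `paddle.distributed.all_reduce` instead of the `allreduce_on_calc_stream` binding (so it mirrors the structure, not the exact performance fix; requires an initialized parallel environment):

```python
# Sketch only: identity forward / all-reduce backward, as in c_identity_eager.
import paddle
import paddle.distributed as dist
from paddle.autograd import PyLayer

class IdentityWithGradAllReduce(PyLayer):

    @staticmethod
    def forward(ctx, x):
        # Forward is a plain copy of the input.
        return paddle.assign(x)

    @staticmethod
    def backward(ctx, dy):
        # Sum the gradient over all ranks before it flows further backward.
        dist.all_reduce(dy)
        return dy

# Usage inside a dygraph program (after dist.init_parallel_env()):
# y = IdentityWithGradAllReduce.apply(x)
```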
......
@@ -173,7 +173,9 @@ def _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, nranks,
elif in_dygraph_mode():
group = paddle.distributed.collective._get_default_group(
) if group is None else group
return group.process_group.send_partial(tensor, dst, nranks, rank_id)
comm_op = group.process_group.send_partial_on_calc_stream \
if use_calc_stream else group.process_group.send_partial
return comm_op(tensor, dst, nranks, rank_id)
def send_partial(tensor,
......
@@ -212,12 +214,9 @@ def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks,
elif in_dygraph_mode():
group = paddle.distributed.collective._get_default_group(
) if group is None else group
task = group.process_group.recv_partial(tensor, src, nranks, rank_id)
if use_calc_stream:
task.wait()
return None
else:
return task
comm_op = group.process_group.recv_partial_on_calc_stream \
if use_calc_stream else group.process_group.recv_partial
return comm_op(tensor, src, nranks, rank_id)
def recv_partial(tensor,
......
@@ -255,13 +254,9 @@ def _partial_allgather_op(tensor, group, use_calc_stream, ring_id, nranks,
elif in_dygraph_mode():
group = paddle.distributed.collective._get_default_group(
) if group is None else group
task = group.process_group.all_gather_partial(tensor, tensor, nranks,
rank_id)
if use_calc_stream:
task.wait()
return None
else:
return task
comm_op = group.process_group.all_gather_partial_on_calc_stream \
if use_calc_stream else group.process_group.all_gather_partial
return comm_op(tensor, tensor, nranks, rank_id)
def allgather_partial(tensor,
......