Unverified commit 84c9a0d6, authored by L LiYuRio, committed by GitHub

refine comm api implementation (#47713)

Parent 3198af20
@@ -391,9 +391,25 @@ class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
}
};
+ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
+ phi::DenseTensor* out_tensor,
+ const phi::DenseTensor& in_tensor,
+ bool sync_op) {
+ std::vector<phi::DenseTensor> in_wrapper = {in_tensor};
+ std::vector<phi::DenseTensor> out_wrapper = {*out_tensor};
+ return AllGather(in_wrapper, out_wrapper, true);
+ }
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
+ return AllGather(in_tensors, out_tensors, true);
+ }
+ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
+ std::vector<phi::DenseTensor>& in_tensors,
+ std::vector<phi::DenseTensor>& out_tensors,
+ bool sync_op) {
std::shared_ptr<AllgatherGlooTask> task;
auto tag = next_tag();
auto context = get_context();
......
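Editor's note: the hunk above adds a single-tensor `AllGather(out_tensor, in_tensor, sync_op)` overload to `ProcessGroupGloo` that wraps both tensors into vectors and delegates to the existing vector-based path. This appears to be the backend reached by the Python-side `all_gather_into_tensor` call seen later in this diff. Below is a hedged usage sketch from the Python layer only, not code from this commit; it assumes a multi-process job (for example one started with `paddle.distributed.launch`) and that `paddle.distributed.stream.all_gather` is available in this Paddle version.

```python
# Hedged sketch: gather every rank's tensor into one pre-allocated output tensor,
# the path that the new single-tensor AllGather overload serves.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
nranks = dist.get_world_size()
local = paddle.full([2, 3], float(dist.get_rank()))
# Output holds the rank tensors concatenated along axis 0.
out = paddle.empty([2 * nranks, 3], local.dtype)
task = dist.stream.all_gather(out, local, sync_op=False)
task.wait()
```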
@@ -108,6 +108,11 @@ class ProcessGroupGloo : public ProcessGroup {
~ProcessGroupGloo() = default;
+ std::shared_ptr<ProcessGroup::Task> AllGather(
+ phi::DenseTensor* out_tensor,
+ const phi::DenseTensor& in_tensor,
+ bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
@@ -144,6 +149,11 @@ class ProcessGroupGloo : public ProcessGroup {
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
+ std::shared_ptr<ProcessGroup::Task> AllGather(
+ std::vector<phi::DenseTensor>& in_tensors,
+ std::vector<phi::DenseTensor>& out_tensors,
+ bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
......
@@ -546,7 +546,7 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True):
out = paddle.empty(tensor_shape, tensor.dtype)
else:
out = paddle.concat(tensor_list, axis=0)
- task = group.process_group.all_gather(tensor, out)
+ task = group.process_group.all_gather_into_tensor(out, tensor, sync_op)
task.wait()
tensor_list.clear()
list_of_tensor = paddle.split(out, group.nranks, 0)
......
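Editor's note: the legacy `all_gather` entry point above now routes through the renamed `all_gather_into_tensor` process-group method, gathering into one concatenated buffer and splitting it back into `tensor_list`. A hedged usage sketch of the public API whose internals changed here; it assumes a multi-process launch so that each rank contributes one tensor.

```python
# Hedged sketch of the user-facing call; tensor_list is filled with one entry per rank.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
tensor_list = []
data = paddle.full([2, 3], float(dist.get_rank()))
dist.all_gather(tensor_list, data)  # tensor_list now has world_size entries
```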
@@ -44,12 +44,12 @@ def _all_gather_into_tensor_in_dygraph(
_check_tensor_shape(out_tensor, in_tensor.shape, group.nranks)
if use_calc_stream:
- return group.process_group.allgather_into_tensor_on_calc_stream(
+ return group.process_group.all_gather_into_tensor_on_calc_stream(
out_tensor,
in_tensor,
)
- task = group.process_group.allgather_into_tensor(
+ task = group.process_group.all_gather_into_tensor(
out_tensor, in_tensor, sync_op
)
if sync_op:
@@ -69,9 +69,11 @@ def _all_gather_in_dygraph(
_check_tensor_list_shape(tensor_list, tensor.shape, group.nranks)
if use_calc_stream:
- return group.process_group.allgather_on_calc_stream(tensor_list, tensor)
+ return group.process_group.all_gather_on_calc_stream(
+ tensor_list, tensor
+ )
- task = group.process_group.allgather(tensor_list, tensor, sync_op)
+ task = group.process_group.all_gather(tensor_list, tensor, sync_op)
if sync_op:
task.wait()
......
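Editor's note: both dygraph helpers above now call the snake_case process-group bindings (`all_gather_into_tensor[_on_calc_stream]`, `all_gather[_on_calc_stream]`). With `use_calc_stream=True` the collective runs on the computation stream and no task handle is returned. A hedged sketch of that variant, assuming `paddle.distributed.stream` is available and that the calc-stream path requires synchronous behavior in this version.

```python
# Hedged sketch: calc-stream all_gather into a list; no task to wait on.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
gathered = []
local = paddle.full([4], float(dist.get_rank()))
dist.stream.all_gather(gathered, local, sync_op=True, use_calc_stream=True)
```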
@@ -25,11 +25,10 @@ from paddle.distributed.communication.group import (
def _all_reduce_in_dygraph(tensor, op, group, sync_op, use_calc_stream):
op_type = _get_reduce_op(op, "allreduce")
- group = _get_global_group() if group is None else group
if use_calc_stream:
- return group.process_group.allreduce_on_calc_stream(tensor, op_type)
+ return group.process_group.all_reduce_on_calc_stream(tensor, op_type)
- task = group.process_group.allreduce(tensor, op_type, sync_op)
+ task = group.process_group.all_reduce(tensor, op_type, sync_op)
if sync_op:
task.wait()
@@ -119,6 +118,7 @@ def all_reduce(
)
if framework.in_dygraph_mode():
+ group = _get_global_group() if group is None else group
return _all_reduce_in_dygraph(
tensor, op, group, sync_op, use_calc_stream
)
......
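Editor's note: the group-resolution line moved from the dygraph helper into the public `all_reduce` entry point, and the process-group methods are now `all_reduce` / `all_reduce_on_calc_stream`. A hedged sketch of the corresponding public call, assuming a multi-rank launch; `op` and `group` keep their defaults (SUM, global group).

```python
# Hedged sketch: asynchronous all_reduce via the stream API; wait on the task.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
data = paddle.full([2, 2], float(dist.get_rank() + 1))
task = dist.stream.all_reduce(data, sync_op=False)
task.wait()  # data now holds the elementwise sum across ranks
```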
@@ -47,11 +47,11 @@ def _all_to_all_tensor_in_dygraph(
_check_tensor_shape(out_tensor, in_tensor.shape, group.nranks)
if use_calc_stream:
- return group.process_group.alltoall_tensor_on_calc_stream(
+ return group.process_group.all_to_all_tensor_on_calc_stream(
in_tensor, out_tensor
)
- task = group.process_group.alltoall_tensor(in_tensor, out_tensor, sync_op)
+ task = group.process_group.all_to_all_tensor(in_tensor, out_tensor, sync_op)
if sync_op:
task.wait()
@@ -74,11 +74,11 @@ def _all_to_all_in_dygraph(
)
if use_calc_stream:
- return group.process_group.alltoall_on_calc_stream(
+ return group.process_group.all_to_all_on_calc_stream(
in_tensor_list, out_tensor_list
)
- task = group.process_group.alltoall(
+ task = group.process_group.all_to_all(
in_tensor_list, out_tensor_list, sync_op
)
if sync_op:
@@ -249,11 +249,11 @@ def _alltoall_single_in_dygraph(
in_split_sizes = []
if use_calc_stream:
- return group.process_group.alltoall_single_on_calc_stream(
+ return group.process_group.all_to_all_single_on_calc_stream(
in_tensor, out_tensor, in_split_sizes, out_split_sizes
)
- task = group.process_group.alltoall_single(
+ task = group.process_group.all_to_all_single(
in_tensor, out_tensor, in_split_sizes, out_split_sizes, sync_op
)
if sync_op:
......
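Editor's note: the `alltoall*` process-group bindings are renamed to `all_to_all*` above; the public wrappers keep their old names. A hedged sketch of the list form, assuming two ranks and the legacy `paddle.distributed.alltoall(in_tensor_list, out_tensor_list)` signature: rank i sends its j-th tensor to rank j and receives the i-th tensor from every rank.

```python
# Hedged sketch: list-based all-to-all exchange; out_tensor_list is populated in place.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
rank = dist.get_rank()
in_tensor_list = [
    paddle.full([2], float(rank * 10 + peer))
    for peer in range(dist.get_world_size())
]
out_tensor_list = []
dist.alltoall(in_tensor_list, out_tensor_list)
```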
@@ -52,7 +52,6 @@ def _reduce_scatter_tensor_in_dygraph(
caller="reduce_scatter",
):
op_type = _get_reduce_op(op, caller)
- group = _get_global_group() if group is None else group
_check_tensor_shape(out_tensor, in_tensor.shape, group.nranks)
@@ -74,7 +73,6 @@ def _reduce_scatter_in_dygraph(
tensor, tensor_list, op, group, sync_op, use_calc_stream
):
op_type = _get_reduce_op(op, "reduce_scatter")
- group = _get_global_group() if group is None else group
_check_tensor_list_shape(tensor_list, tensor.shape, group.nranks)
@@ -149,6 +147,7 @@ def reduce_scatter(
)
if framework.in_dygraph_mode():
+ group = _get_global_group() if group is None else group
if paddle.is_tensor(tensor_or_tensor_list):
return _reduce_scatter_tensor_in_dygraph(
tensor,
@@ -230,6 +229,7 @@ def _reduce_scatter_base(
)
if framework.in_dygraph_mode():
+ group = _get_global_group() if group is None else group
return _reduce_scatter_tensor_in_dygraph(
out_tensor,
in_tensor,
......
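Editor's note: as with all_reduce, the `group = _get_global_group() ...` resolution is hoisted out of the two reduce-scatter dygraph helpers and into the public `reduce_scatter` / `_reduce_scatter_base` entry points, so the helpers now expect an already-resolved group. A hedged sketch of the tensor-list form, assuming two ranks: inputs are summed elementwise across ranks, and rank i keeps the i-th chunk.

```python
# Hedged sketch: reduce_scatter with default op=SUM, global group, sync_op=True.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
rank = dist.get_rank()
chunks = [paddle.full([2], float(rank)), paddle.full([2], float(rank + 1))]
output = paddle.empty([2], chunks[0].dtype)
dist.reduce_scatter(output, chunks)  # output receives this rank's reduced chunk
```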
@@ -62,7 +62,7 @@ def _c_identity(tensor, group=None):
@staticmethod
def backward(ctx, dy):
op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
- group.process_group.allreduce_on_calc_stream(dy, op_type)
+ group.process_group.all_reduce_on_calc_stream(dy, op_type)
return dy
return c_identity_eager.apply(tensor)
@@ -255,7 +255,7 @@ def _mp_allreduce(
if use_calc_stream:
op_type = _get_reduce_op(op, "_mp_allreduce")
- group.process_group.all_reduce_on_calc_stream(
+ group.process_group.all_reduce_on_calc_stream(
tensor, op_type
)
return tensor
......
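Editor's note: `_c_identity` and `_mp_allreduce` are internal tensor-parallel ops (identity in one pass, all-reduce of the tensor or its gradient in the other); they now call the renamed `all_reduce_on_calc_stream` binding. The sketch below re-expresses the `_c_identity` pattern with public APIs only; it is illustrative, the class name is invented, and the real op talks to the process group directly on the calc stream.

```python
# Hedged sketch of identity-forward / all-reduce-backward, the pattern _c_identity uses.
import paddle
import paddle.distributed as dist
from paddle.autograd import PyLayer


class IdentityForwardAllReduceBackward(PyLayer):
    @staticmethod
    def forward(ctx, x):
        return x  # identity in the forward pass

    @staticmethod
    def backward(ctx, dy):
        dist.all_reduce(dy)  # sum the gradient across the model-parallel ranks
        return dy

# Usage (after dist.init_parallel_env()): y = IdentityForwardAllReduceBackward.apply(x)
```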