Unverified commit 84c9a0d6, authored by LiYuRio and committed by GitHub

refine comm api implementation (#47713)

Parent 3198af20
@@ -391,9 +391,25 @@ class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
   }
 };
 
+std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
+    phi::DenseTensor* out_tensor,
+    const phi::DenseTensor& in_tensor,
+    bool sync_op) {
+  std::vector<phi::DenseTensor> in_wrapper = {in_tensor};
+  std::vector<phi::DenseTensor> out_wrapper = {*out_tensor};
+  return AllGather(in_wrapper, out_wrapper, true);
+}
+
 std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
     std::vector<phi::DenseTensor>& in_tensors,
     std::vector<phi::DenseTensor>& out_tensors) {
+  return AllGather(in_tensors, out_tensors, true);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors,
+    bool sync_op) {
   std::shared_ptr<AllgatherGlooTask> task;
   auto tag = next_tag();
   auto context = get_context();
......
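Note (not part of the commit): the new single-tensor overload above wraps its arguments into vectors and forwards to the existing vector-based AllGather; the Python changes later in this diff expose it to dygraph as process_group.all_gather_into_tensor. A minimal sketch of how that path is reached from user code, assuming two processes started with paddle.distributed.launch:

    # Illustrative sketch only; launch with:
    #   python -m paddle.distributed.launch --nproc_per_node=2 demo_all_gather.py
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()                 # builds the default process group
    rank = dist.get_rank()

    data = paddle.to_tensor([rank + 1, rank + 2], dtype="int64")
    gathered = []                            # receives one tensor per rank
    dist.all_gather(gathered, data)          # dygraph path ends in all_gather_into_tensor
    print(gathered)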
@@ -108,6 +108,11 @@ class ProcessGroupGloo : public ProcessGroup {
   ~ProcessGroupGloo() = default;
 
+  std::shared_ptr<ProcessGroup::Task> AllGather(
+      phi::DenseTensor* out_tensor,
+      const phi::DenseTensor& in_tensor,
+      bool sync_op) override;
+
   std::shared_ptr<ProcessGroup::Task> Broadcast(
       phi::DenseTensor* out_tensor,
       const phi::DenseTensor& in_tensor,
@@ -144,6 +149,11 @@ class ProcessGroupGloo : public ProcessGroup {
       std::vector<phi::DenseTensor>& in_tensors,
       std::vector<phi::DenseTensor>& out_tensors) override;
 
+  std::shared_ptr<ProcessGroup::Task> AllGather(
+      std::vector<phi::DenseTensor>& in_tensors,
+      std::vector<phi::DenseTensor>& out_tensors,
+      bool sync_op) override;
+
   std::shared_ptr<ProcessGroup::Task> Reduce(
       std::vector<phi::DenseTensor>& in_tensors,
       std::vector<phi::DenseTensor>& out_tensors,
......
@@ -546,7 +546,7 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True):
             out = paddle.empty(tensor_shape, tensor.dtype)
         else:
             out = paddle.concat(tensor_list, axis=0)
-        task = group.process_group.all_gather(tensor, out)
+        task = group.process_group.all_gather_into_tensor(out, tensor, sync_op)
         task.wait()
         tensor_list.clear()
         list_of_tensor = paddle.split(out, group.nranks, 0)
......
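The dygraph branch of all_gather now gathers into a single flat buffer via the renamed all_gather_into_tensor binding and then splits it back into tensor_list. A shape-only sketch of that round trip (single process, no communication; illustrative only):

    import paddle

    nranks = 4
    shard = paddle.ones([2, 3])                        # per-rank contribution
    out = paddle.empty([2 * nranks, 3], shard.dtype)   # flat gather buffer, shards stacked on axis 0
    list_of_tensor = paddle.split(out, nranks, 0)      # how the wrapper rebuilds tensor_list
    assert len(list_of_tensor) == nranks
    assert list_of_tensor[0].shape == [2, 3]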
@@ -44,12 +44,12 @@ def _all_gather_into_tensor_in_dygraph(
     _check_tensor_shape(out_tensor, in_tensor.shape, group.nranks)
 
     if use_calc_stream:
-        return group.process_group.allgather_into_tensor_on_calc_stream(
+        return group.process_group.all_gather_into_tensor_on_calc_stream(
             out_tensor,
             in_tensor,
         )
 
-    task = group.process_group.allgather_into_tensor(
+    task = group.process_group.all_gather_into_tensor(
         out_tensor, in_tensor, sync_op
     )
     if sync_op:
@@ -69,9 +69,11 @@ def _all_gather_in_dygraph(
     _check_tensor_list_shape(tensor_list, tensor.shape, group.nranks)
 
     if use_calc_stream:
-        return group.process_group.allgather_on_calc_stream(tensor_list, tensor)
+        return group.process_group.all_gather_on_calc_stream(
+            tensor_list, tensor
+        )
 
-    task = group.process_group.allgather(tensor_list, tensor, sync_op)
+    task = group.process_group.all_gather(tensor_list, tensor, sync_op)
     if sync_op:
         task.wait()
......
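These helpers choose between the renamed *_on_calc_stream bindings (no task object, runs on the calculation stream) and the task-returning bindings gated by sync_op. A usage sketch, assuming the stream-level API is exposed as paddle.distributed.stream.all_gather with sync_op/use_calc_stream flags (illustrative; two launched processes):

    # Assumes paddle.distributed.stream.all_gather(tensor_or_tensor_list, tensor,
    # group=None, sync_op=True, use_calc_stream=False); illustrative only.
    # Launch with: python -m paddle.distributed.launch --nproc_per_node=2 demo.py
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    rank = dist.get_rank()
    nranks = dist.get_world_size()

    data = paddle.full([2, 3], float(rank))
    out = paddle.empty([2 * nranks, 3], data.dtype)            # pre-allocated, gathered along axis 0

    task = dist.stream.all_gather(out, data, sync_op=False)    # async: returns a task
    task.wait()                                                # wait before reading `out`
    print(out)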
@@ -25,11 +25,10 @@ from paddle.distributed.communication.group import (
 
 def _all_reduce_in_dygraph(tensor, op, group, sync_op, use_calc_stream):
     op_type = _get_reduce_op(op, "allreduce")
-    group = _get_global_group() if group is None else group
     if use_calc_stream:
-        return group.process_group.allreduce_on_calc_stream(tensor, op_type)
+        return group.process_group.all_reduce_on_calc_stream(tensor, op_type)
 
-    task = group.process_group.allreduce(tensor, op_type, sync_op)
+    task = group.process_group.all_reduce(tensor, op_type, sync_op)
     if sync_op:
         task.wait()
@@ -119,6 +118,7 @@ def all_reduce(
         )
 
     if framework.in_dygraph_mode():
+        group = _get_global_group() if group is None else group
        return _all_reduce_in_dygraph(
            tensor, op, group, sync_op, use_calc_stream
        )
......
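The group=None fallback now happens once in the public all_reduce before dispatching to the dygraph helper, so the helper can assume a concrete group. Minimal usage sketch (illustrative; two launched processes):

    # Illustrative only; group=None resolves to the global group inside all_reduce.
    # Launch with: python -m paddle.distributed.launch --nproc_per_node=2 demo.py
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    x = paddle.to_tensor([1.0, 2.0, 3.0]) * (dist.get_rank() + 1)

    dist.all_reduce(x)          # in place; defaults to SUM over the global group
    print(x)                    # every rank now holds the elementwise sum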
@@ -47,11 +47,11 @@ def _all_to_all_tensor_in_dygraph(
     _check_tensor_shape(out_tensor, in_tensor.shape, group.nranks)
 
     if use_calc_stream:
-        return group.process_group.alltoall_tensor_on_calc_stream(
+        return group.process_group.all_to_all_tensor_on_calc_stream(
             in_tensor, out_tensor
         )
 
-    task = group.process_group.alltoall_tensor(in_tensor, out_tensor, sync_op)
+    task = group.process_group.all_to_all_tensor(in_tensor, out_tensor, sync_op)
     if sync_op:
         task.wait()
@@ -74,11 +74,11 @@ def _all_to_all_in_dygraph(
     )
 
     if use_calc_stream:
-        return group.process_group.alltoall_on_calc_stream(
+        return group.process_group.all_to_all_on_calc_stream(
             in_tensor_list, out_tensor_list
         )
 
-    task = group.process_group.alltoall(
+    task = group.process_group.all_to_all(
         in_tensor_list, out_tensor_list, sync_op
     )
     if sync_op:
@@ -249,11 +249,11 @@ def _alltoall_single_in_dygraph(
         in_split_sizes = []
 
     if use_calc_stream:
-        return group.process_group.alltoall_single_on_calc_stream(
+        return group.process_group.all_to_all_single_on_calc_stream(
             in_tensor, out_tensor, in_split_sizes, out_split_sizes
         )
 
-    task = group.process_group.alltoall_single(
+    task = group.process_group.all_to_all_single(
         in_tensor, out_tensor, in_split_sizes, out_split_sizes, sync_op
     )
     if sync_op:
......
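The three all-to-all bindings are renamed on the same pattern (alltoall* becomes all_to_all*) for the tensor-list, single-tensor, and split-sizes variants. A sketch of the list form through the public API, keeping the (input list, output list) order used by these helpers (illustrative; two launched processes):

    # Illustrative only; rank r sends tensor j of its input list to rank j.
    # Launch with: python -m paddle.distributed.launch --nproc_per_node=2 demo.py
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    rank = dist.get_rank()
    nranks = dist.get_world_size()

    in_tensor_list = [paddle.full([2], float(rank * 10 + j)) for j in range(nranks)]
    out_tensor_list = []                     # filled with one tensor received from each peer
    dist.alltoall(in_tensor_list, out_tensor_list)
    print(out_tensor_list)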
@@ -52,7 +52,6 @@ def _reduce_scatter_tensor_in_dygraph(
     caller="reduce_scatter",
 ):
     op_type = _get_reduce_op(op, caller)
-    group = _get_global_group() if group is None else group
 
     _check_tensor_shape(out_tensor, in_tensor.shape, group.nranks)
@@ -74,7 +73,6 @@ def _reduce_scatter_in_dygraph(
     tensor, tensor_list, op, group, sync_op, use_calc_stream
 ):
     op_type = _get_reduce_op(op, "reduce_scatter")
-    group = _get_global_group() if group is None else group
 
     _check_tensor_list_shape(tensor_list, tensor.shape, group.nranks)
@@ -149,6 +147,7 @@ def reduce_scatter(
         )
 
     if framework.in_dygraph_mode():
+        group = _get_global_group() if group is None else group
         if paddle.is_tensor(tensor_or_tensor_list):
             return _reduce_scatter_tensor_in_dygraph(
                 tensor,
@@ -230,6 +229,7 @@ def _reduce_scatter_base(
         )
 
     if framework.in_dygraph_mode():
+        group = _get_global_group() if group is None else group
         return _reduce_scatter_tensor_in_dygraph(
             out_tensor,
             in_tensor,
......
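As with all_reduce, resolving group=None moves out of the dygraph helpers and into the public reduce_scatter / _reduce_scatter_base, which dispatch on whether a tensor or a tensor list is passed in. A usage sketch of the list form, assuming the (output tensor, input list) argument order of paddle.distributed.reduce_scatter (illustrative; two launched processes):

    # Illustrative only; assumes dist.reduce_scatter(output_tensor, input_list, ...).
    # Each rank ends up with the elementwise SUM of slice `rank` from all ranks.
    # Launch with: python -m paddle.distributed.launch --nproc_per_node=2 demo.py
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    rank = dist.get_rank()
    nranks = dist.get_world_size()

    tensor_list = [paddle.full([2], float(rank + j)) for j in range(nranks)]
    result = paddle.empty([2], tensor_list[0].dtype)
    dist.reduce_scatter(result, tensor_list)   # default op=SUM, sync_op=True
    print(result)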
@@ -62,7 +62,7 @@ def _c_identity(tensor, group=None):
         @staticmethod
         def backward(ctx, dy):
             op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
-            group.process_group.allreduce_on_calc_stream(dy, op_type)
+            group.process_group.all_reduce_on_calc_stream(dy, op_type)
             return dy
 
     return c_identity_eager.apply(tensor)
@@ -255,7 +255,7 @@ def _mp_allreduce(
             if use_calc_stream:
                 op_type = _get_reduce_op(op, "_mp_allreduce")
-                group.process_group.allreduce_on_calc_stream(
+                group.process_group.all_reduce_on_calc_stream(
                     tensor, op_type
                 )
                 return tensor
......