Unverified commit fd689106, authored by Wen Sun and committed by GitHub

Remove unnecessary exports in `distributed.communication` and move `wait` & `barrier` (#48396)

* refactor: move wait

* refactor: move barrier

* fix: fix incorrect import
Parent commit: 174726fc
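
For context, the two relocated helpers remain reachable from the top-level `paddle.distributed` namespace; only their home module moves from `paddle.distributed.collective` to `paddle.distributed.communication.group`. A minimal usage sketch, adapted from the docstring examples in this diff (it assumes a multi-process GPU launch, e.g. via `paddle.distributed.launch`):

```python
import paddle
import paddle.distributed as dist

# Bind each process to its device and set up the parallel environment
# before any collective call, as the barrier() docstring in this diff does.
paddle.set_device('gpu:%d' % dist.ParallelEnv().dev_id)
dist.init_parallel_env()

dist.barrier()  # block until every rank in the global group reaches this point
```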
@@ -27,10 +27,8 @@ from paddle.distributed.fleet.dataset import InMemoryDataset # noqa: F401
 from paddle.distributed.fleet.dataset import QueueDataset # noqa: F401
 from paddle.distributed.fleet.base.topology import ParallelMode # noqa: F401
-from .collective import barrier # noqa: F401
 from .collective import split # noqa: F401
 from .collective import new_group # noqa: F401
-from .collective import wait # noqa: F401
 from .communication import (
     stream,
@@ -53,6 +51,8 @@ from .communication import (
     is_initialized,
     destroy_process_group,
     get_group,
+    wait,
+    barrier,
 ) # noqa: F401
 from .auto_parallel import shard_op # noqa: F401
...
@@ -13,13 +13,10 @@
 # limitations under the License.
 import datetime
-from ..fluid.layer_helper import LayerHelper
 from ..fluid.framework import in_dygraph_mode
 from ..fluid.framework import _non_static_mode
-from ..fluid.layers.tensor import fill_constant
 import paddle
 import paddle.fluid.core as core
-from paddle import _legacy_C_ops
 from .fleet.layers.mpu.mp_ops import split # noqa: F401
 from .fleet.layers.mpu.mp_ops import _c_identity # noqa: F401
 from .fleet.layers.mpu.mp_ops import _c_concat # noqa: F401
@@ -160,60 +157,6 @@ def _new_process_group_impl(
     return pg
-
-def barrier(group=None):
-    """
-    Barrier among all participators in the group.
-
-    Args:
-        group (Group): The group instance return by new_group or None for global default group.
-
-    Returns:
-        None.
-
-    Examples:
-        .. code-block:: python
-
-            import paddle
-            from paddle.distributed import init_parallel_env
-
-            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
-            init_parallel_env()
-            paddle.distributed.barrier()
-    """
-    if group is not None and not group.is_member():
-        return
-
-    if in_dygraph_mode():
-        group = _get_default_group() if group is None else group
-        place = paddle.fluid.framework._current_expected_place()
-        if isinstance(place, paddle.fluid.core.CPUPlace):
-            task = group.process_group.barrier()
-        else:
-            device_id = place.get_device_id()
-            task = group.process_group.barrier(device_id)
-        task.wait()
-        return
-
-    ring_id = 0 if group is None else group.id
-
-    temp = fill_constant([1], dtype="int32", value="1")
-    if _non_static_mode():
-        return _legacy_C_ops.barrier(temp, temp, 'ring_id', ring_id)
-
-    op_type = 'barrier'
-    if not isinstance(ring_id, int):
-        raise ValueError("The type of 'group' for barrier must be int.")
-    helper = LayerHelper(op_type, **locals())
-    helper.append_op(
-        type=op_type,
-        inputs={'X': [temp]},
-        outputs={'Out': [temp]},
-        attrs={'ring_id': ring_id},
-    )
 # _custom_gid provides a way for users to
 # set the group id, which is usually useful
 # to be compatible with the static mode.
@@ -356,78 +299,8 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
         tmp = (
             paddle.to_tensor([1], dtype="int32")
             if _non_static_mode()
-            else fill_constant([0], dtype="int32", value="1")
+            else paddle.full([0], 1, dtype="int32")
         )
         paddle.distributed.all_reduce(tmp, sync_op=True)
         paddle.distributed.wait(tmp)
     return gp
-
-def wait(tensor, group=None, use_calc_stream=True):
-    """
-    wait to sync stream for group.
-
-    Args:
-        tensor (Tensor): The Tensor used before sync.
-        group (Group): The Group instance to perform sync.
-        use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
-            Default to True.
-
-    Returns:
-        None.
-
-    Examples:
-        .. code-block:: python
-
-            import paddle
-
-            paddle.distributed.init_parallel_env()
-            tindata = paddle.randn(shape=[2, 3])
-            paddle.distributed.all_reduce(tindata, sync_op=True)
-            paddle.distributed.wait(tindata)
-    """
-    if group is not None and not group.is_member():
-        return
-
-    ring_id = 0 if group is None else group.id
-
-    if use_calc_stream:
-        _sync_calc_stream(tensor)
-    else:
-        _sync_comm_stream(tensor, ring_id)
-
-
-def _sync_calc_stream(tensor):
-    if _non_static_mode():
-        return _legacy_C_ops.c_sync_calc_stream(tensor, tensor)
-
-    op_type = 'c_sync_calc_stream'
-    helper = LayerHelper(op_type, **locals())
-    helper.append_op(
-        type=op_type,
-        inputs={'X': [tensor]},
-        outputs={'Out': [tensor]},
-    )
-
-
-def _sync_comm_stream(tensor, ring_id=0):
-    if _non_static_mode():
-        return _legacy_C_ops.c_sync_comm_stream(
-            [tensor], [tensor], 'ring_id', ring_id
-        )
-
-    op_type = 'c_sync_comm_stream'
-    helper = LayerHelper(op_type, **locals())
-    helper.append_op(
-        type=op_type,
-        inputs={'X': [tensor]},
-        outputs={'Out': [tensor]},
-        attrs={'ring_id': ring_id},
-    )
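
Aside from the removals, the one substantive change retained in this file is inside `new_group`: the warm-up tensor for the static-graph branch is now built with `paddle.full` instead of the deleted `fill_constant` import (the surrounding `all_reduce`/`wait` calls presumably exist just to exercise the new communicator). A standalone sketch of that replacement expression:

```python
import paddle

# Equivalent of the replaced expression: a zero-size int32 tensor filled with 1.
tmp = paddle.full([0], 1, dtype="int32")
print(tmp.dtype, tmp.shape)  # paddle.int32 [0]
```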
@@ -21,26 +21,10 @@ from .scatter import scatter
 from .batch_isend_irecv import batch_isend_irecv, P2POp
 from .reduce_scatter import reduce_scatter
 from .all_to_all import alltoall, alltoall_single
-from .group import is_initialized, destroy_process_group, get_group
-
-__all__ = [
-    "ReduceOp",
-    "all_gather",
-    "all_gather_object",
-    "all_reduce",
-    "alltoall",
-    "alltoall_single",
-    "broadcast",
-    "reduce",
-    "send",
-    "scatter",
-    "isend",
-    "recv",
-    "irecv",
-    "batch_isend_irecv",
-    "P2POp",
-    "reduce_scatter",
-    "is_initialized",
-    "destroy_process_group",
-    "get_group",
-]
+from .group import (
+    is_initialized,
+    destroy_process_group,
+    get_group,
+    wait,
+    barrier,
+)
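
A small hedged sketch of what this tidied re-export means for downstream imports: the long `__all__` list is gone, but the names are still bound on the subpackage at import time, with `wait` and `barrier` now resolving to the new `group` module:

```python
from paddle.distributed.communication import barrier, get_group, wait

print(wait.__module__)  # expected: paddle.distributed.communication.group
```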
@@ -13,7 +13,11 @@
 # limitations under the License.
 import warnings
+import paddle
 import paddle.distributed as dist
+import paddle.fluid.core as core
+import paddle.fluid.framework as framework
+import paddle.fluid.layer_helper as layer_helper
 
 class Group:
@@ -227,3 +231,122 @@ def get_group(id=0):
         return _GroupManager.group_map_by_id[id]
     warnings.warn("Group {} is not initialized.".format(id))
     return None
+
+
+def _sync_calc_stream(tensor):
+    if framework._non_static_mode():
+        return paddle._legacy_C_ops.c_sync_calc_stream(tensor, tensor)
+
+    op_type = 'c_sync_calc_stream'
+    helper = layer_helper.LayerHelper(op_type, **locals())
+    helper.append_op(
+        type=op_type,
+        inputs={'X': [tensor]},
+        outputs={'Out': [tensor]},
+    )
+
+
+def _sync_comm_stream(tensor, ring_id=0):
+    if framework._non_static_mode():
+        return paddle._legacy_C_ops.c_sync_comm_stream(
+            [tensor], [tensor], 'ring_id', ring_id
+        )
+
+    op_type = 'c_sync_comm_stream'
+    helper = layer_helper.LayerHelper(op_type, **locals())
+    helper.append_op(
+        type=op_type,
+        inputs={'X': [tensor]},
+        outputs={'Out': [tensor]},
+        attrs={'ring_id': ring_id},
+    )
+
+
+def wait(tensor, group=None, use_calc_stream=True):
+    """
+    wait to sync stream for group.
+
+    Args:
+        tensor (Tensor): The Tensor used before sync.
+        group (Group): The Group instance to perform sync.
+        use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
+            Default to True.
+
+    Returns:
+        None.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            paddle.distributed.init_parallel_env()
+            tindata = paddle.randn(shape=[2, 3])
+            paddle.distributed.all_reduce(tindata, sync_op=True)
+            paddle.distributed.wait(tindata)
+    """
+    if group is not None and not group.is_member():
+        return
+
+    if use_calc_stream:
+        _sync_calc_stream(tensor)
+    else:
+        ring_id = 0 if group is None else group.id
+        _sync_comm_stream(tensor, ring_id)
+
+
+def barrier(group=None):
+    """
+    Barrier among all participators in the group.
+
+    Args:
+        group (Group): The group instance return by new_group or None for global default group.
+
+    Returns:
+        None.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            from paddle.distributed import init_parallel_env
+
+            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
+            init_parallel_env()
+            paddle.distributed.barrier()
+    """
+    if group is not None and not group.is_member():
+        return
+
+    if framework.in_dygraph_mode():
+        group = _get_global_group() if group is None else group
+        place = framework._current_expected_place()
+        if isinstance(place, core.CPUPlace):
+            task = group.process_group.barrier()
+        else:
+            device_id = place.get_device_id()
+            task = group.process_group.barrier(device_id)
+        task.wait()
+        return
+
+    ring_id = 0 if group is None else group.id
+
+    barrier_tensor = paddle.full([1], 1, dtype="int32")
+    if framework._non_static_mode():
+        return paddle._legacy_C_ops.barrier(
+            barrier_tensor, barrier_tensor, 'ring_id', ring_id
+        )
+
+    op_type = 'barrier'
+    if not isinstance(ring_id, int):
+        raise ValueError("The type of 'group' for barrier must be int.")
+    helper = layer_helper.LayerHelper(op_type, **locals())
+    helper.append_op(
+        type=op_type,
+        inputs={'X': [barrier_tensor]},
+        outputs={'Out': [barrier_tensor]},
+        attrs={'ring_id': ring_id},
+    )
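
A hedged usage sketch of the relocated `wait`, following the two branches above: the default synchronizes the calculation stream, while `use_calc_stream=False` synchronizes the communication stream of the (default) ring:

```python
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
tindata = paddle.randn(shape=[2, 3])
dist.all_reduce(tindata, sync_op=True)

dist.wait(tindata)                          # -> _sync_calc_stream(tindata)
dist.wait(tindata, use_calc_stream=False)   # -> _sync_comm_stream(tindata, ring_id=0)
```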
@@ -34,7 +34,6 @@ from paddle.fluid.clip import ClipGradByGlobalNorm
 from paddle.distributed.collective import (
     _get_global_group,
     new_group,
-    wait,
 )
 from ...utils.internal_storage import ParamStorage, GradStorage
@@ -174,7 +173,7 @@ class ShardingOptimizerStage2(Optimizer):
             )
             # Multi stream operation will be supported later
-            wait(tensor=p, group=self.group, use_calc_stream=True)
+            dist.wait(tensor=p, group=self.group, use_calc_stream=True)
     def _generate_master_params(self, trainable_params):
         if self.offload:
@@ -464,7 +463,7 @@ class ShardingOptimizerStage2(Optimizer):
             )
             # Multi stream operation will be supported later
-            wait(
+            dist.wait(
                 tensor=internal_storage.buffer,
                 group=self.group,
                 use_calc_stream=True,
...
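
The sharding call sites in this and the following files all follow the same mechanical migration; a hedged before/after sketch (the tensor and group names are placeholders standing in for `self.group` / `self._group` in the sharding code):

```python
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
p = paddle.zeros([8], dtype="float32")
group = None  # None means the global default group

# Before this commit:
#   from paddle.distributed.collective import wait
#   wait(tensor=p, group=group, use_calc_stream=True)

# After this commit, the call goes through the already-imported dist alias:
dist.wait(tensor=p, group=group, use_calc_stream=True)
```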
@@ -318,7 +318,7 @@ class ShardingStage2(nn.Layer):
             buffer, self._global_root_rank, self._group, sync_op=True
         )
         # Multi stream operation will be supported later
-        collective.wait(tensor=buffer, group=self._group, use_calc_stream=True)
+        dist.wait(tensor=buffer, group=self._group, use_calc_stream=True)
     def __getattr__(self, name):
         """Forward missing attributes to wrapped layer."""
@@ -382,7 +382,7 @@ class ShardingStage2(nn.Layer):
             )
             # Multi stream operation will be supported later
-            collective.wait(
+            dist.wait(
                 tensor=param.grad,
                 group=self._group,
                 use_calc_stream=True,
@@ -448,7 +448,7 @@ class ShardingStage2(nn.Layer):
             )
             # Multi stream operation will be supported later
-            collective.wait(
+            dist.wait(
                 tensor=grad_storage.buffer,
                 group=self._group,
                 use_calc_stream=True,
...
@@ -184,7 +184,7 @@ class ShardingStage3(nn.Layer):
             )
             # Multi stream operation will be supported later
-            collective.wait(tensor=p, group=self._group, use_calc_stream=True)
+            dist.wait(tensor=p, group=self._group, use_calc_stream=True)
     def _clear_gradients(self):
         assert len(self._trainable_params.keys()) > 0
@@ -485,7 +485,7 @@ class ShardingStage3(nn.Layer):
             buffer, self._global_root_rank, self._group, sync_op=True
         )
         # Multi stream operation will be supported later
-        collective.wait(tensor=buffer, group=self._group, use_calc_stream=True)
+        dist.wait(tensor=buffer, group=self._group, use_calc_stream=True)
     def __getattr__(self, name):
         """Forward missing attributes to wrapped layer."""
@@ -529,7 +529,7 @@ class ShardingStage3(nn.Layer):
         dist.all_reduce(
             tensor=grad_storage.buffer, group=self._group, sync_op=True
         )
-        collective.wait(
+        dist.wait(
            tensor=grad_storage.buffer,
            group=self._group,
            use_calc_stream=True,
@@ -601,7 +601,7 @@ class ShardingStage3(nn.Layer):
            dist.all_reduce(
                tensor=full_grad, group=self._group, sync_op=True
            )
-            collective.wait(
+            dist.wait(
                tensor=full_grad, group=self._group, use_calc_stream=True
            )
@@ -946,7 +946,7 @@ def _allgather_buffer(
     # Allgather current layer in the 1st step synchronously
     if sync_wait:
         with paddle.amp.auto_cast(enable=False):
-            collective.wait(
+            dist.wait(
                 tensor=full_param,
                 group=group,
                 use_calc_stream=use_calc_stream,
...
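
Finally, a hedged sketch of the group-scoped path that the relocated `barrier` keeps supporting: a sub-group created with `new_group` can be passed explicitly, and non-member ranks return immediately through the `is_member()` guard shown above (assumes a launch with at least two ranks):

```python
import paddle.distributed as dist

dist.init_parallel_env()
sub = dist.new_group(ranks=[0, 1])  # every rank must call new_group

dist.barrier(group=sub)  # ranks outside [0, 1] fall through the is_member() guard
```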