Unverified  Commit e0b50269  Authored by: W Wen Sun  Committed by: GitHub

refactor: rm fluid deps in distributed communication (#49722)

Parent: b1faa562
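
This commit is a pure namespace migration: every paddle.fluid import used by the distributed communication modules is replaced with the corresponding paddle.framework entry point, with no intended behavior change. A minimal sketch of the mapping applied throughout the hunks below, using only names that actually appear in this diff:

    # Before: framework internals reached through the legacy fluid package
    import paddle.fluid.core as core
    import paddle.fluid.framework as framework
    import paddle.fluid.layer_helper as layer_helper

    # After: the same symbols reached through the public paddle.framework module
    import paddle.framework as framework

    framework.core.ProcessGroupNCCL   # was core.ProcessGroupNCCL
    framework.core.ReduceOp           # was core.ReduceOp
    framework.CPUPlace                # was core.CPUPlace
    framework.LayerHelper             # was layer_helper.LayerHelper
    framework.in_dygraph_mode()       # same attribute name, new import path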
@@ -16,7 +16,7 @@ import numpy as np
 import paddle
 import paddle.distributed.communication.stream as stream
-import paddle.fluid.framework as framework
+import paddle.framework as framework
 from .serialization_utils import (
     convert_object_to_tensor,
......
@@ -15,8 +15,7 @@
 import contextlib

 import paddle.distributed as dist
-import paddle.fluid.core as core
-import paddle.fluid.framework as framework
+import paddle.framework as framework
 from paddle.distributed.communication.group import (
     _get_global_group,
     _warn_cur_rank_not_in_group,
@@ -79,12 +78,12 @@ class P2POp:
 @contextlib.contextmanager
 def _with_batch_p2p_guard(backend):
     if backend == "NCCL":
-        core.ProcessGroupNCCL.group_start()
+        framework.core.ProcessGroupNCCL.group_start()
     try:
         yield
     finally:
         if backend == "NCCL":
-            core.ProcessGroupNCCL.group_end()
+            framework.core.ProcessGroupNCCL.group_end()


 def _check_p2p_op_list(p2p_op_list):
......
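
The _with_batch_p2p_guard context manager above brackets a batch of point-to-point operations between ProcessGroupNCCL.group_start() and group_end(), so NCCL issues them as a single fused group. A hedged sketch of how such a guard is typically consumed; the caller, op list, and attribute names are illustrative and not taken from this diff:

    # Hypothetical caller: only _with_batch_p2p_guard and the NCCL group calls come from the code above.
    def _issue_batched_p2p(p2p_op_list, backend="NCCL"):
        with _with_batch_p2p_guard(backend):   # group_start() on NCCL backends
            for p2p_op in p2p_op_list:
                # each entry bundles an isend/irecv-style callable with its tensor and peer rank
                p2p_op.op(p2p_op.tensor, p2p_op.peer)
        # leaving the block runs group_end(), flushing the batched operations together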
@@ -15,7 +15,7 @@
 import paddle
 import paddle.distributed as dist
 import paddle.distributed.communication.stream as stream
-import paddle.fluid.framework as framework
+import paddle.framework as framework
 from .serialization_utils import (
     convert_object_to_tensor,
......
@@ -16,9 +16,7 @@ import warnings
 import paddle
 import paddle.distributed as dist
-import paddle.fluid.core as core
-import paddle.fluid.framework as framework
-import paddle.fluid.layer_helper as layer_helper
+import paddle.framework as framework


 class Group:
@@ -239,7 +237,7 @@ def _sync_calc_stream(tensor):
         return paddle._legacy_C_ops.c_sync_calc_stream(tensor, tensor)
     else:
         op_type = 'c_sync_calc_stream'
-        helper = layer_helper.LayerHelper(op_type, **locals())
+        helper = framework.LayerHelper(op_type, **locals())
         helper.append_op(
             type=op_type,
             inputs={'X': [tensor]},
@@ -254,7 +252,7 @@ def _sync_comm_stream(tensor, ring_id=0):
         )
     else:
         op_type = 'c_sync_comm_stream'
-        helper = layer_helper.LayerHelper(op_type, **locals())
+        helper = framework.LayerHelper(op_type, **locals())
         helper.append_op(
             type=op_type,
             inputs={'X': [tensor]},
@@ -325,7 +323,7 @@ def barrier(group=None):
     if framework.in_dygraph_mode():
         group = _get_global_group() if group is None else group
         place = framework._current_expected_place()
-        if isinstance(place, core.CPUPlace):
+        if isinstance(place, framework.CPUPlace):
             task = group.process_group.barrier()
         else:
             device_id = place.get_device_id()
@@ -344,7 +342,7 @@ def barrier(group=None):
     op_type = 'barrier'
     if not isinstance(ring_id, int):
         raise ValueError("The type of 'group' for barrier must be int.")
-    helper = layer_helper.LayerHelper(op_type, **locals())
+    helper = framework.LayerHelper(op_type, **locals())
     helper.append_op(
         type=op_type,
         inputs={'X': [barrier_tensor]},
......
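
barrier() keeps two code paths: in dygraph mode it dispatches to the process group directly (using a device barrier on non-CPU places), while in static graph mode it appends a 'barrier' op through framework.LayerHelper. A minimal usage sketch, assuming the script is launched with paddle.distributed.launch; the per-rank work is illustrative:

    import paddle.distributed as dist

    dist.init_parallel_env()   # set up the default process group
    # ... each rank does some independent work ...
    dist.barrier()             # every rank blocks here until all ranks arrive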
@@ -14,8 +14,7 @@
 import paddle
 import paddle.distributed.communication.stream as stream
-import paddle.fluid.core as core
-import paddle.fluid.framework as framework
+import paddle.framework as framework


 class ReduceOp:
@@ -59,13 +58,13 @@ class ReduceOp:
 def _get_reduce_op(reduce_op, func_name):
     if framework.in_dygraph_mode():
         if reduce_op == ReduceOp.SUM:
-            return core.ReduceOp.SUM
+            return framework.core.ReduceOp.SUM
         elif reduce_op == ReduceOp.MAX:
-            return core.ReduceOp.MAX
+            return framework.core.ReduceOp.MAX
         elif reduce_op == ReduceOp.MIN:
-            return core.ReduceOp.MIN
+            return framework.core.ReduceOp.MIN
         elif reduce_op == ReduceOp.PROD:
-            return core.ReduceOp.PRODUCT
+            return framework.core.ReduceOp.PRODUCT
     else:
         if reduce_op == ReduceOp.SUM:
             return 'c_{}_sum'.format(func_name)
......
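
_get_reduce_op translates the Python-level ReduceOp enum into framework.core.ReduceOp values in dygraph mode, and into concrete operator names in static graph mode by formatting func_name. Worked examples from the branches shown above ("reduce" is the func_name used by the reduce hunk later in this commit):

    # dygraph mode:       _get_reduce_op(ReduceOp.SUM, "reduce")  ->  framework.core.ReduceOp.SUM
    # static graph mode:  _get_reduce_op(ReduceOp.SUM, "reduce")  ->  'c_reduce_sum'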
@@ -17,7 +17,7 @@ import numpy as np
 import paddle
 import paddle.distributed as dist
 import paddle.distributed.communication.stream as stream
-import paddle.fluid.framework as framework
+import paddle.framework as framework
 from .serialization_utils import (
     convert_object_to_tensor,
......
@@ -15,8 +15,7 @@
 import paddle
 import paddle.distributed as dist
 import paddle.fluid.data_feeder as data_feeder
-import paddle.fluid.framework as framework
-import paddle.fluid.layer_helper as layer_helper
+import paddle.framework as framework
 from paddle.distributed.communication.group import _get_global_group
@@ -62,7 +61,7 @@ def _all_gather_in_dygraph(
 def _all_gather_in_static_mode(tensor_list, tensor, group, sync_op):
     op_type = 'c_allgather'
-    helper = layer_helper.LayerHelper(op_type, **locals())
+    helper = framework.LayerHelper(op_type, **locals())
     out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
     for elem in tensor_list:
         data_feeder.check_variable_and_dtype(
......
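
_all_gather_in_static_mode builds the c_allgather op and lets LayerHelper infer the output variable; in dygraph mode the public API fills a Python list with one tensor per rank. A hedged usage sketch of that dygraph entry point (shapes and values illustrative):

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    data = paddle.to_tensor([dist.get_rank()])   # each rank contributes its own value
    tensor_list = []
    dist.all_gather(tensor_list, data)           # tensor_list now holds one tensor per rank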
@@ -13,8 +13,7 @@
 # limitations under the License.

 import paddle.fluid.data_feeder as data_feeder
-import paddle.fluid.framework as framework
-import paddle.fluid.layer_helper as layer_helper
+import paddle.framework as framework
 from paddle.distributed.communication.group import (
     _get_global_group,
     _warn_cur_rank_not_in_group,
@@ -60,7 +59,7 @@ def _all_reduce_in_static_mode(tensor, op, group, sync_op, use_calc_stream):
     # TODO: Support task and use task.wait in static graph mode
     #       Use use_calc_stream rather than sync_op
-    helper = layer_helper.LayerHelper(op_type, **locals())
+    helper = framework.LayerHelper(op_type, **locals())
     helper.append_op(
         type=op_type,
         inputs={'X': [tensor]},
......
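
The static-graph path here follows the same pattern as the other stream collectives in this commit: validate the tensor dtype via data_feeder, create a framework.LayerHelper for the communication op, then append it to the current program. A condensed sketch of that recurring pattern, assembled from the hunks in this diff; the dtype list and attrs are illustrative rather than the exact set used by every op:

    import paddle.fluid.data_feeder as data_feeder
    import paddle.framework as framework

    def _collective_in_static_mode(tensor, op_type, ring_id, sync_op):
        # 1. validate the input variable and dtype for the chosen op
        data_feeder.check_variable_and_dtype(
            tensor, 'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64'],
            op_type,
        )
        # 2. build the op through LayerHelper and append it to the current program
        helper = framework.LayerHelper(op_type, **locals())
        helper.append_op(
            type=op_type,
            inputs={'X': [tensor]},
            outputs={'Out': [tensor]},
            attrs={'ring_id': ring_id, 'use_calc_stream': sync_op},
        )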
@@ -15,8 +15,7 @@
 import paddle
 import paddle.distributed as dist
 import paddle.fluid.data_feeder as data_feeder
-import paddle.fluid.framework as framework
-import paddle.fluid.layer_helper as layer_helper
+import paddle.framework as framework
 from paddle.distributed.communication.group import (
     _get_global_group,
     _warn_cur_rank_not_in_group,
@@ -73,7 +72,7 @@ def _all_to_all_in_static_mode(
     op_type = 'alltoall'
     ring_id = 0 if group is None else group.id
     nranks = dist.get_world_size()
-    helper = layer_helper.LayerHelper(op_type, **locals())
+    helper = framework.LayerHelper(op_type, **locals())
     in_tensor = in_tensor_or_tensor_list
     if isinstance(in_tensor_or_tensor_list, list):
......
@@ -13,8 +13,7 @@
 # limitations under the License.

 import paddle.fluid.data_feeder as data_feeder
-import paddle.fluid.framework as framework
-import paddle.fluid.layer_helper as layer_helper
+import paddle.framework as framework
 from paddle.distributed.communication.group import (
     _get_global_group,
     _get_or_throw_group_rank,
@@ -57,7 +56,7 @@ def _broadcast_in_static_mode(
         )

     op_type = 'c_broadcast'
-    helper = layer_helper.LayerHelper(op_type, **locals())
+    helper = framework.LayerHelper(op_type, **locals())
     ring_id = 0 if group is None else group.id
     helper.append_op(
......
@@ -13,8 +13,7 @@
 # limitations under the License.

 import paddle.fluid.data_feeder as data_feeder
-import paddle.fluid.framework as framework
-import paddle.fluid.layer_helper as layer_helper
+import paddle.framework as framework
 from paddle.distributed.communication.group import (
     _get_global_group,
     _get_or_throw_group_rank,
@@ -48,7 +47,7 @@ def _recv_in_static_mode(
             'recv',
         )
     ring_id = 0 if group is None else group.id
-    helper = layer_helper.LayerHelper(op_type, **locals())
+    helper = framework.LayerHelper(op_type, **locals())
     helper.append_op(
         type=op_type,
         outputs={'Out': [tensor]},
......
@@ -13,8 +13,7 @@
 # limitations under the License.

 import paddle.fluid.data_feeder as data_feeder
-import paddle.fluid.framework as framework
-import paddle.fluid.layer_helper as layer_helper
+import paddle.framework as framework
 from paddle.distributed.communication.group import (
     _get_global_group,
     _get_or_throw_group_rank,
@@ -63,7 +62,7 @@ def _reduce_in_static_mode(
     op_type = _get_reduce_op(op, "reduce")
     ring_id = 0 if group is None else group.id
-    helper = layer_helper.LayerHelper(op_type, **locals())
+    helper = framework.LayerHelper(op_type, **locals())
     helper.append_op(
         type=op_type,
         inputs={'X': [tensor]},
......
@@ -13,7 +13,7 @@
 # limitations under the License.

 import paddle
-import paddle.fluid.framework as framework
+import paddle.framework as framework
 from paddle.distributed.communication.group import (
     _get_global_group,
     _warn_cur_rank_not_in_group,
......
@@ -17,8 +17,7 @@ import warnings
 import paddle
 import paddle.distributed as dist
 import paddle.fluid.data_feeder as data_feeder
-import paddle.fluid.framework as framework
-import paddle.fluid.layer_helper as layer_helper
+import paddle.framework as framework
 from paddle.distributed.communication.group import (
     _get_global_group,
     _get_or_throw_group_rank,
@@ -113,7 +112,7 @@ def _scatter_in_static_mode(
         )

     op_type = 'c_scatter'
-    helper = layer_helper.LayerHelper(op_type, **locals())
+    helper = framework.LayerHelper(op_type, **locals())
     helper.append_op(
         type=op_type,
         inputs={'X': [input_tensor]},
......
@@ -13,8 +13,7 @@
 # limitations under the License.

 import paddle.fluid.data_feeder as data_feeder
-import paddle.fluid.framework as framework
-import paddle.fluid.layer_helper as layer_helper
+import paddle.framework as framework
 from paddle.distributed.communication.group import (
     _get_global_group,
     _get_or_throw_group_rank,
@@ -49,7 +48,7 @@ def _send_in_static_mode(
         )
     ring_id = 0 if group is None else group.id
-    helper = layer_helper.LayerHelper(op_type, **locals())
+    helper = framework.LayerHelper(op_type, **locals())
     helper.append_op(
         type=op_type,
         inputs={'X': [tensor]},
......
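
send and recv are the point-to-point counterparts of the collectives above and keep the same dygraph/static split. A hedged usage sketch pairing the two across ranks (rank ids and tensor contents illustrative):

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    if dist.get_rank() == 0:
        data = paddle.to_tensor([1, 2, 3])
        dist.send(data, dst=1)                    # rank 0 sends to rank 1
    else:
        data = paddle.zeros([3], dtype='int64')
        dist.recv(data, src=0)                    # receives in place from rank 0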