Unverified · Commit e9e68c36 authored by lilong12, committed by GitHub

support group with one rank (#41398)

Parent 814315b4
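The key behavioral change is in `new_group`: the `assert size > 1` check is removed, and the branch that builds a backing process group now runs only when the group has more than one member, so a ranks list with a single entry is accepted. A minimal usage sketch of what this enables (hypothetical two-trainer launch, not part of the patch):

# Hedged sketch: build a one-member group after this change.
# Assumes the script is started with paddle.distributed.launch across >= 2 trainers.
import paddle.distributed as dist

dist.init_parallel_env()
rank = dist.get_rank()

# Previously new_group() asserted that a group must have at least two members;
# with this commit a ranks list of length one is accepted (the size > 1 branch
# that calls _new_process_group_impl is simply skipped).
solo = dist.new_group(ranks=[rank])
print(solo.nranks)  # expected: 1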
@@ -16,7 +16,6 @@ import numpy as np
 import os
 from datetime import timedelta
 from ..fluid.layer_helper import LayerHelper
-import paddle.fluid.framework as framework
 from ..fluid.framework import Variable
 from ..fluid.framework import in_dygraph_mode
 from ..fluid.framework import OpProtoHolder
@@ -144,6 +143,16 @@ _default_store = None  # the default tcp store
 _default_backend = None


+def _set_default_backend(backend):
+    global _default_backend
+    _default_backend = backend
+
+
+def _set_default_store(store):
+    global _default_store
+    _default_store = store
+
+
 def _get_group_map():
     global _group_map
     if not _group_map:
@@ -159,19 +168,29 @@ def _get_global_group():
 def _get_group_map_by_name():
     global _group_map_by_name
-    assert _default_group_name in _group_map_by_name, (
-        "Call paddle.distributed.init_parallel_env first "
-        "to initialize the distributed environment.")
     return _group_map_by_name


 def _get_default_group():
+    global _group_map_by_name
     assert _default_group_name in _group_map_by_name, (
         "Call paddle.distributed.init_parallel_env first "
         "to initialize the distributed environment.")
     return _get_group_map_by_name()[_default_group_name]


+def _set_group_map(gid, group):
+    global _group_map
+    assert gid not in _group_map
+    _group_map[gid] = group
+
+
+def _set_group_map_by_name(name, group):
+    global _group_map_by_name
+    assert name not in _group_map_by_name
+    _group_map_by_name[name] = group
+
+
 def _new_ring_id():
     return len(_get_group_map()) + max(_get_global_env().nrings, 9)
@@ -208,6 +227,7 @@ def _new_process_group_impl(backend,
                             pg_options,
                             group_id=0):
     pg = None
+    assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
     if backend == "gloo":
         pg = core.ProcessGroupGloo(store, rank, world_size, group_id)
     elif backend == "nccl":
@@ -242,7 +262,7 @@ def barrier(group=None):
     if group is not None and not group.is_member():
         return

-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         group = _get_default_group() if group is None else group
         task = group.process_group.barrier()
         task.wait()
@@ -290,22 +310,22 @@ def new_group(ranks=None, backend=None):
     """
     global _group_map
-    if framework._in_eager_mode_:
+    if in_dygraph_mode():
         global _default_group_name
         gid = _new_ring_id()
         group_name = _default_group_name + str(gid)
         global_group = _get_default_group()
         global_rank = global_group.rank
         global_ranks = global_group.ranks
+        backend = _default_backend if backend is None else backend
         if ranks is None:
             ranks = global_ranks
         assert len(ranks) <= len(global_ranks), (
             "Size of new group must be less than or "
             "equal to that of the default global group.")
         size = len(ranks)
-        assert size > 1, "A group must have at least two memebers."
         ranks = sorted(ranks)
-        if global_rank in ranks:
+        if global_rank in ranks and size > 1:
             rank = ranks.index(global_rank)
             pg = _new_process_group_impl(
                 backend,
@@ -495,7 +515,7 @@ def broadcast(tensor, src, group=None, use_calc_stream=True):
     if not isinstance(src, int):
         raise ValueError("src should be int.")

-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         group = _get_default_group() if group is None else group
         gsrc = group.get_group_rank(src)
         assert gsrc >= 0, ("src rank out of group, need global rank")
@@ -579,7 +599,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
     if group is not None and not group.is_member():
         return

-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         if op == ReduceOp.SUM:
             op_type = core.ReduceOp.SUM
         elif op == ReduceOp.MAX:
@@ -681,7 +701,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
     if group is not None and not group.is_member():
         return

-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         if op == ReduceOp.SUM:
             op_type = core.ReduceOp.SUM
         elif op == ReduceOp.MAX:
@@ -802,7 +822,7 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
     if group is not None and not group.is_member():
         return

-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         group = _get_default_group() if group is None else group
         out = paddle.concat(tensor_list)
         task = group.process_group.all_gather(tensor, out)
@@ -899,7 +919,7 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
     if not isinstance(src, int):
         raise ValueError("src should be int.")

-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         group = _get_default_group() if group is None else group
         gsrc = group.get_group_rank(src)
         rank = group.rank
@@ -916,7 +936,7 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
         for _ in range(nranks):
             tensor_list.append(tensor)
     temp = paddle.concat(tensor_list, axis=0)
-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         task = group.process_group.scatter(temp, tensor, gsrc)
         if use_calc_stream:
             task.wait()
@@ -924,7 +944,7 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
         else:
             return task

-    if in_dygraph_mode():
+    if _non_static_mode():
         return _C_ops.c_scatter(temp, tensor, 'use_calc_stream',
                                 use_calc_stream, 'ring_id', ring_id, 'nranks',
                                 nranks, 'root', gsrc)
@@ -1694,14 +1714,14 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
     if group is not None and not group.is_member():
         return

-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         group = _get_default_group() if group is None else group
     else:
         ring_id = 0 if group is None else group.id

     temp = paddle.concat(in_tensor_list, axis=0)
     nranks = len(in_tensor_list)
-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         out = paddle.concat(out_tensor_list, axis=0)
         task = group.process_group.alltoall(temp, out)
         task.wait()
@@ -1776,7 +1796,7 @@ def send(tensor, dst=0, group=None, use_calc_stream=True):
     if group is not None and not group.is_member():
         return

-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         group = _get_default_group() if group is None else group
         task = group.process_group.send(tensor, dst)
         if use_calc_stream:
@@ -1839,7 +1859,7 @@ def recv(tensor, src=0, group=None, use_calc_stream=True):
     if group is not None and not group.is_member():
         return

-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         group = _get_default_group() if group is None else group
         task = group.process_group.recv(tensor, src)
         if use_calc_stream:
......
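The new module-level setters above (`_set_default_backend`, `_set_default_store`, `_set_group_map`, `_set_group_map_by_name`) let other modules update this shared state through function calls instead of assigning to names they imported, which in Python only rebinds the importer's local copy. A small sketch of that distinction (assumes `init_parallel_env` has not been called yet, so the backend is still unset):

# Rebinding an imported name does not touch the module's global; the setter does.
from paddle.distributed.collective import _default_backend
import paddle.distributed.collective as collective

_default_backend = "gloo"                    # rebinds only this file's name
assert collective._default_backend is None   # shared module state unchanged

collective._set_default_backend("gloo")      # the setter mutates the shared global
assert collective._default_backend == "gloo"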
@@ -24,19 +24,20 @@ from paddle import compat as cpt
 # deprecated module import
 from paddle.fluid import core
-import paddle.fluid.framework as framework
+from paddle.fluid.framework import in_dygraph_mode
 from paddle.fluid.framework import _set_expected_place
 from paddle.fluid.dygraph import parallel_helper
 from paddle.distributed.fleet.launch_utils import check_backend
 from paddle.fluid.dygraph.parallel import ParallelEnv
 from paddle.distributed.fleet.base.private_helper_function import wait_server_ready  # noqa: F401
-import paddle.distributed.collective as collective
+from paddle.distributed.collective import _set_group_map
+from paddle.distributed.collective import _set_group_map_by_name
+from paddle.distributed.collective import _get_group_map_by_name
 from paddle.distributed.collective import _group_map_by_name
-from paddle.distributed.collective import _group_map
 from paddle.distributed.collective import _default_group_name
 from paddle.distributed.collective import _valid_backend_list
-from paddle.distributed.collective import _default_backend
-from paddle.distributed.collective import _default_store
+from paddle.distributed.collective import _set_default_backend
+from paddle.distributed.collective import _set_default_store
 from paddle.distributed.collective import _new_process_group_impl
 from paddle.distributed.collective import Group
@@ -205,10 +206,10 @@ def init_parallel_env():
     _set_expected_place(place)

     group = None
-    if backend in _valid_backend_list and framework._in_eager_mode_:
-        if _default_group_name in collective._group_map_by_name:
-            return collective._group_map_by_name[_default_group_name]
-        _default_backend = backend
+    if backend in _valid_backend_list and in_dygraph_mode():
+        if _default_group_name in _get_group_map_by_name():
+            return _get_group_map_by_name()[_default_group_name]
+        _set_default_backend(backend)
         rank = int(os.getenv("PADDLE_TRAINER_ID"))
         world_size = int(os.getenv("PADDLE_TRAINERS_NUM"))
         assert rank >= 0 and world_size > rank and world_size > 1, (
@@ -230,11 +231,12 @@ def init_parallel_env():
         master_addr, master_port = endpoints.split(":")
         master_port = int(master_port)
         is_master = rank == 0
-        _default_store = core.TCPStore(master_addr, master_port, is_master,
-                                       world_size)
+        default_store = core.TCPStore(master_addr, master_port, is_master,
+                                      world_size)
+        _set_default_store(default_store)
         pg = _new_process_group_impl(
             backend,
-            _default_store,
+            default_store,
             rank,
             world_size,
             _default_group_name,
@@ -247,8 +249,8 @@ def init_parallel_env():
                      ranks=ranks,
                      pg=pg,
                      name=_default_group_name)
-        collective._group_map_by_name[_default_group_name] = group
-        _group_map[0] = group
+        _set_group_map_by_name(_default_group_name, group)
+        _set_group_map(0, group)
         parallel_helper._set_parallel_ctx(True)
         return group
......
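With the registration above, `init_parallel_env` in eager/dygraph mode stores the default group under `_default_group_name`, so a repeated call short-circuits and returns the cached group instead of rebuilding it. A hedged sketch of that behavior (assumes a supported backend and a launch via `paddle.distributed.launch`):

import paddle.distributed as dist

group = dist.init_parallel_env()    # builds the TCPStore, process group, and default Group
again = dist.init_parallel_env()    # found in the name map, returned as-is
assert again is group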
@@ -45,7 +45,7 @@ class TestProcessGroupFp32(unittest.TestCase):
         nranks = ParallelEnv().nranks
         rank = ParallelEnv().local_rank
         is_master = True if rank == 0 else False
-        store = paddle.fluid.core.TCPStore("127.0.0.1", 6172, is_master,
-                                           nranks, datetime.timedelta(0))
+        store = paddle.fluid.core.TCPStore("127.0.0.1", 6272, is_master,
+                                           nranks, datetime.timedelta(0))
         pg = paddle.fluid.core.ProcessGroupGloo(store, rank, nranks)
......