Unverified commit e9e68c36, authored by lilong12, committed by GitHub

support group with one rank (#41398)

Parent 814315b4
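The commit teaches `paddle.distributed.new_group` to accept a ranks list with a single member: the old "at least two members" assertion is dropped and no backend process group is built for a one-rank group. Below is a minimal usage sketch of what the patch enables; it assumes a two-process run started with `paddle.distributed.launch`, and the script name is illustrative only.

```python
# one_rank_group.py -- illustrative sketch, assumes a 2-process launch:
#   python -m paddle.distributed.launch one_rank_group.py
import paddle.distributed as dist

dist.init_parallel_env()

# Before this patch, new_group asserted that a group has at least two
# members; now a single-member ranks list is accepted and no backend
# process group is spun up for it.
solo = dist.new_group(ranks=[0])
print("rank", dist.get_rank(), "member of solo group:", solo.is_member())
```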
@@ -16,7 +16,6 @@ import numpy as np
import os
from datetime import timedelta
from ..fluid.layer_helper import LayerHelper
import paddle.fluid.framework as framework
from ..fluid.framework import Variable
from ..fluid.framework import in_dygraph_mode
from ..fluid.framework import OpProtoHolder
@@ -144,6 +143,16 @@ _default_store = None # the default tcp store
_default_backend = None
def _set_default_backend(backend):
global _default_backend
_default_backend = backend
def _set_default_store(store):
global _default_store
_default_store = store
def _get_group_map():
global _group_map
if not _group_map:
@@ -159,19 +168,29 @@ def _get_global_group():
def _get_group_map_by_name():
global _group_map_by_name
assert _default_group_name in _group_map_by_name, (
"Call paddle.distributed.init_parallel_env first "
"to initialize the distributed environment.")
return _group_map_by_name
def _get_default_group():
global _group_map_by_name
assert _default_group_name in _group_map_by_name, (
"Call paddle.distributed.init_parallel_env first "
"to initialize the distributed environment.")
return _get_group_map_by_name()[_default_group_name]
def _set_group_map(gid, group):
global _group_map
assert gid not in _group_map
_group_map[gid] = group
def _set_group_map_by_name(name, group):
global _group_map_by_name
assert name not in _group_map_by_name
_group_map_by_name[name] = group
def _new_ring_id():
return len(_get_group_map()) + max(_get_global_env().nrings, 9)
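The new module-level helpers above keep the two registries (`_group_map`, keyed by ring/group id, and `_group_map_by_name`, keyed by group name) in sync and assert against double registration, while `_get_default_group` now fails loudly if `init_parallel_env` was never called. A stripped-down sketch of the same registry pattern, with a simplified default name standing in for the real `_default_group_name` constant:

```python
# Simplified stand-ins for the registries kept in collective.py.
_group_map = {}          # ring/group id -> Group
_group_map_by_name = {}  # group name    -> Group
DEFAULT_NAME = "_default_pg"  # illustrative; the real constant is _default_group_name

def set_group_map(gid, group):
    assert gid not in _group_map, "group id registered twice"
    _group_map[gid] = group

def set_group_map_by_name(name, group):
    assert name not in _group_map_by_name, "group name registered twice"
    _group_map_by_name[name] = group

def get_default_group():
    # Mirrors the diff: the default group only exists after init_parallel_env.
    assert DEFAULT_NAME in _group_map_by_name, (
        "Call paddle.distributed.init_parallel_env first.")
    return _group_map_by_name[DEFAULT_NAME]
```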
@@ -208,6 +227,7 @@ def _new_process_group_impl(backend,
pg_options,
group_id=0):
pg = None
assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
if backend == "gloo":
pg = core.ProcessGroupGloo(store, rank, world_size, group_id)
elif backend == "nccl":
@@ -242,7 +262,7 @@ def barrier(group=None):
if group is not None and not group.is_member():
return
if framework._in_eager_mode_ and in_dygraph_mode():
if in_dygraph_mode():
group = _get_default_group() if group is None else group
task = group.process_group.barrier()
task.wait()
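With the eager-mode check reduced to `in_dygraph_mode()`, `barrier` resolves the default group when none is given and waits on the process-group task. Caller-side usage is unchanged; a minimal example (multi-process launch assumed):

```python
import paddle.distributed as dist

dist.init_parallel_env()
# Each rank blocks here until every rank in the default group arrives.
dist.barrier()
```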
@@ -290,22 +310,22 @@ def new_group(ranks=None, backend=None):
"""
global _group_map
if framework._in_eager_mode_:
if in_dygraph_mode():
global _default_group_name
gid = _new_ring_id()
group_name = _default_group_name + str(gid)
global_group = _get_default_group()
global_rank = global_group.rank
global_ranks = global_group.ranks
backend = _default_backend if backend is None else backend
if ranks is None:
ranks = global_ranks
assert len(ranks) <= len(global_ranks), (
"Size of new group must be less than or "
"equal to that of the default global group.")
size = len(ranks)
assert size > 1, "A group must have at least two memebers."
ranks = sorted(ranks)
if global_rank in ranks:
if global_rank in ranks and size > 1:
rank = ranks.index(global_rank)
pg = _new_process_group_impl(
backend,
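The `and size > 1` condition is the heart of the patch: a single-member group keeps its ring id, `Group` object and registration, but `_new_process_group_impl` is skipped for it. A condensed, simplified view of that branch (argument list and error handling omitted; `build_pg` is a hypothetical stand-in for `_new_process_group_impl`):

```python
def new_group_branch(ranks, global_rank, backend, build_pg):
    """Condensed sketch of the eager-mode branch of new_group after the patch."""
    ranks = sorted(ranks)
    size = len(ranks)
    pg = None
    if global_rank in ranks and size > 1:
        # Only multi-member groups containing this rank build a backend
        # process group (gloo / nccl / ...).
        pg = build_pg(backend, rank=ranks.index(global_rank), world_size=size)
    # One-rank groups (and non-members) still get a Group record, just no pg.
    return pg
```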
@@ -495,7 +515,7 @@ def broadcast(tensor, src, group=None, use_calc_stream=True):
if not isinstance(src, int):
raise ValueError("src should be int.")
if framework._in_eager_mode_ and in_dygraph_mode():
if in_dygraph_mode():
group = _get_default_group() if group is None else group
gsrc = group.get_group_rank(src)
assert gsrc >= 0, ("src rank out of group, need global rank")
@@ -579,7 +599,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
if group is not None and not group.is_member():
return
if framework._in_eager_mode_ and in_dygraph_mode():
if in_dygraph_mode():
if op == ReduceOp.SUM:
op_type = core.ReduceOp.SUM
elif op == ReduceOp.MAX:
@@ -681,7 +701,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
if group is not None and not group.is_member():
return
if framework._in_eager_mode_ and in_dygraph_mode():
if in_dygraph_mode():
if op == ReduceOp.SUM:
op_type = core.ReduceOp.SUM
elif op == ReduceOp.MAX:
@@ -802,7 +822,7 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
if group is not None and not group.is_member():
return
if framework._in_eager_mode_ and in_dygraph_mode():
if in_dygraph_mode():
group = _get_default_group() if group is None else group
out = paddle.concat(tensor_list)
task = group.process_group.all_gather(tensor, out)
@@ -899,7 +919,7 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
if not isinstance(src, int):
raise ValueError("src should be int.")
if framework._in_eager_mode_ and in_dygraph_mode():
if in_dygraph_mode():
group = _get_default_group() if group is None else group
gsrc = group.get_group_rank(src)
rank = group.rank
@@ -916,7 +936,7 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
for _ in range(nranks):
tensor_list.append(tensor)
temp = paddle.concat(tensor_list, axis=0)
if framework._in_eager_mode_ and in_dygraph_mode():
if in_dygraph_mode():
task = group.process_group.scatter(temp, tensor, gsrc)
if use_calc_stream:
task.wait()
@@ -924,7 +944,7 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
else:
return task
if in_dygraph_mode():
if _non_static_mode():
return _C_ops.c_scatter(temp, tensor, 'use_calc_stream',
use_calc_stream, 'ring_id', ring_id, 'nranks',
nranks, 'root', gsrc)
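After the change, `scatter` dispatches three ways: the eager path through `group.process_group.scatter`, the legacy imperative path through `_C_ops.c_scatter` when `_non_static_mode()` is true, and the static-graph op otherwise. From the caller's side the API is unchanged; a two-process sketch with rank 0 as the source:

```python
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
out = paddle.zeros([2], dtype="float32")
if dist.get_rank() == 0:
    chunks = [paddle.to_tensor([1.0, 2.0]), paddle.to_tensor([3.0, 4.0])]
    dist.scatter(out, chunks, src=0)   # rank 0 keeps chunk 0
else:
    dist.scatter(out, src=0)           # rank 1 receives chunk 1
print(out)  # rank 0: [1, 2]; rank 1: [3, 4]
```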
@@ -1694,14 +1714,14 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
if group is not None and not group.is_member():
return
if framework._in_eager_mode_ and in_dygraph_mode():
if in_dygraph_mode():
group = _get_default_group() if group is None else group
else:
ring_id = 0 if group is None else group.id
temp = paddle.concat(in_tensor_list, axis=0)
nranks = len(in_tensor_list)
if framework._in_eager_mode_ and in_dygraph_mode():
if in_dygraph_mode():
out = paddle.concat(out_tensor_list, axis=0)
task = group.process_group.alltoall(temp, out)
task.wait()
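`alltoall` concatenates the per-rank inputs, runs the exchange (via the process group in eager mode, or the legacy op otherwise), and splits the result back into `out_tensor_list`. A two-process sketch; the concrete values are illustrative:

```python
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
rank = dist.get_rank()
# rank 0 contributes [0, 1] and [2, 3]; rank 1 contributes [4, 5] and [6, 7].
ins = [paddle.to_tensor([4 * rank, 4 * rank + 1]),
       paddle.to_tensor([4 * rank + 2, 4 * rank + 3])]
outs = []
dist.alltoall(ins, outs)
# rank 0 ends up with [0, 1] and [4, 5]; rank 1 with [2, 3] and [6, 7].
```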
@@ -1776,7 +1796,7 @@ def send(tensor, dst=0, group=None, use_calc_stream=True):
if group is not None and not group.is_member():
return
if framework._in_eager_mode_ and in_dygraph_mode():
if in_dygraph_mode():
group = _get_default_group() if group is None else group
task = group.process_group.send(tensor, dst)
if use_calc_stream:
@@ -1839,7 +1859,7 @@ def recv(tensor, src=0, group=None, use_calc_stream=True):
if group is not None and not group.is_member():
return
if framework._in_eager_mode_ and in_dygraph_mode():
if in_dygraph_mode():
group = _get_default_group() if group is None else group
task = group.process_group.recv(tensor, src)
if use_calc_stream:
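`send` and `recv` follow the same shape: in eager mode they resolve the default group and call into `group.process_group`, waiting on the task when `use_calc_stream` is true. A minimal point-to-point sketch with two processes:

```python
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
if dist.get_rank() == 0:
    data = paddle.to_tensor([7, 8, 9])
    dist.send(data, dst=1)
else:
    data = paddle.to_tensor([0, 0, 0])
    dist.recv(data, src=0)
print(data)  # both ranks end up holding [7, 8, 9]
```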
@@ -24,19 +24,20 @@ from paddle import compat as cpt
# deprecated module import
from paddle.fluid import core
import paddle.fluid.framework as framework
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import _set_expected_place
from paddle.fluid.dygraph import parallel_helper
from paddle.distributed.fleet.launch_utils import check_backend
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed.fleet.base.private_helper_function import wait_server_ready # noqa: F401
import paddle.distributed.collective as collective
from paddle.distributed.collective import _set_group_map
from paddle.distributed.collective import _set_group_map_by_name
from paddle.distributed.collective import _get_group_map_by_name
from paddle.distributed.collective import _group_map_by_name
from paddle.distributed.collective import _group_map
from paddle.distributed.collective import _default_group_name
from paddle.distributed.collective import _valid_backend_list
from paddle.distributed.collective import _default_backend
from paddle.distributed.collective import _default_store
from paddle.distributed.collective import _set_default_backend
from paddle.distributed.collective import _set_default_store
from paddle.distributed.collective import _new_process_group_impl
from paddle.distributed.collective import Group
@@ -205,10 +206,10 @@ def init_parallel_env():
_set_expected_place(place)
group = None
if backend in _valid_backend_list and framework._in_eager_mode_:
if _default_group_name in collective._group_map_by_name:
return collective._group_map_by_name[_default_group_name]
_default_backend = backend
if backend in _valid_backend_list and in_dygraph_mode():
if _default_group_name in _get_group_map_by_name():
return _get_group_map_by_name()[_default_group_name]
_set_default_backend(backend)
rank = int(os.getenv("PADDLE_TRAINER_ID"))
world_size = int(os.getenv("PADDLE_TRAINERS_NUM"))
assert rank >= 0 and world_size > rank and world_size > 1, (
@@ -230,11 +231,12 @@ def init_parallel_env():
master_addr, master_port = endpoints.split(":")
master_port = int(master_port)
is_master = rank == 0
_default_store = core.TCPStore(master_addr, master_port, is_master,
default_store = core.TCPStore(master_addr, master_port, is_master,
world_size)
_set_default_store(default_store)
pg = _new_process_group_impl(
backend,
_default_store,
default_store,
rank,
world_size,
_default_group_name,
@@ -247,8 +249,8 @@ def init_parallel_env():
ranks=ranks,
pg=pg,
name=_default_group_name)
collective._group_map_by_name[_default_group_name] = group
_group_map[0] = group
_set_group_map_by_name(_default_group_name, group)
_set_group_map(0, group)
parallel_helper._set_parallel_ctx(True)
return group
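On the `parallel.py` side the module globals are no longer rebound locally; they go through the new setters (`_set_default_backend`, `_set_default_store`, `_set_group_map_by_name`, `_set_group_map`), so `collective.py` and `parallel.py` observe the same defaults. The bootstrap order the hunks imply is: build the `TCPStore` (rank 0 is the master), create the backend process group via `_new_process_group_impl`, wrap it in a `Group`, then register it under the default name and ring id 0. A usage-level sketch of that flow from a training script, using `paddle.distributed.spawn` to launch two processes:

```python
import paddle
import paddle.distributed as dist

def worker():
    # init_parallel_env performs the TCPStore rendezvous, builds the default
    # process group and registers it, as shown in the hunks above.
    dist.init_parallel_env()
    x = paddle.to_tensor([float(dist.get_rank())])
    dist.all_reduce(x)
    print(x)  # every rank prints the sum of all ranks (here: 1.0)

if __name__ == "__main__":
    dist.spawn(worker, nprocs=2)
```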
@@ -45,7 +45,7 @@ class TestProcessGroupFp32(unittest.TestCase):
nranks = ParallelEnv().nranks
rank = ParallelEnv().local_rank
is_master = True if rank == 0 else False
store = paddle.fluid.core.TCPStore("127.0.0.1", 6172, is_master,
store = paddle.fluid.core.TCPStore("127.0.0.1", 6272, is_master,
nranks, datetime.timedelta(0))
pg = paddle.fluid.core.ProcessGroupGloo(store, rank, nranks)