未验证 提交 1734bc6f 编写于 作者: L LiYuRio 提交者: GitHub

Add store_barrier to prevent master exit (#44964)

上级 eb4531f1
...@@ -194,8 +194,8 @@ void MasterDaemon::ProcessCommands(std::vector<struct pollfd>* p_fds) { ...@@ -194,8 +194,8 @@ void MasterDaemon::ProcessCommands(std::vector<struct pollfd>* p_fds) {
<< " from addr info:" << GetSockName(fds[i].fd); << " from addr info:" << GetSockName(fds[i].fd);
} }
} catch (const std::exception& ex) { } catch (const std::exception& ex) {
fds.erase(fds.begin() + i);
tcputils::close_socket(fds[i].fd); tcputils::close_socket(fds[i].fd);
fds.erase(fds.begin() + i);
#ifdef _WIN32 #ifdef _WIN32
_sockets.erase(_sockets.begin() + i - 1); _sockets.erase(_sockets.begin() + i - 1);
#else #else
...@@ -405,12 +405,14 @@ std::vector<uint8_t> TCPStore::get(const std::string& key) { ...@@ -405,12 +405,14 @@ std::vector<uint8_t> TCPStore::get(const std::string& key) {
void TCPStore::wait(const std::string& key) { void TCPStore::wait(const std::string& key) {
ReplyType reply; ReplyType reply;
VLOG(3) << "TCPStore wait."; VLOG(3) << "TCPStore wait.";
do { _client->send_command_for_key(Command::WAIT, _key_prefix + key);
_client->send_command_for_key(Command::WAIT, _key_prefix + key); reply = _client->receive_value<ReplyType>();
while (reply != ReplyType::STOP_WAIT) {
std::this_thread::sleep_for(std::chrono::milliseconds(500));
_client->send_command_for_key(Command::WAIT, _key_prefix + key);
reply = _client->receive_value<ReplyType>(); reply = _client->receive_value<ReplyType>();
std::this_thread::sleep_for(std::chrono::milliseconds(500)); }
} while (reply != ReplyType::STOP_WAIT);
} }
TCPStore::~TCPStore() { VLOG(3) << "TCPStore destructure"; } TCPStore::~TCPStore() { VLOG(3) << "TCPStore destructure"; }
......
...@@ -16,7 +16,8 @@ import numpy as np ...@@ -16,7 +16,8 @@ import numpy as np
import os import os
import pickle import pickle
import io import io
from datetime import timedelta import datetime
import time
from ..fluid.layer_helper import LayerHelper from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import Variable from ..fluid.framework import Variable
from ..fluid.framework import in_dygraph_mode from ..fluid.framework import in_dygraph_mode
...@@ -134,6 +135,7 @@ def _get_global_env(): ...@@ -134,6 +135,7 @@ def _get_global_env():
# group map : the map of all group, 0 for GlobalGroup # group map : the map of all group, 0 for GlobalGroup
# Dict[int, Group] # Dict[int, Group]
_group_map = {} _group_map = {}
_global_env_gid = 0
# group map by name : the map of all groups from their names # group map by name : the map of all groups from their names
# Dict[name, Group] # Dict[name, Group]
...@@ -149,6 +151,8 @@ _default_group_name = "_default_pg" ...@@ -149,6 +151,8 @@ _default_group_name = "_default_pg"
_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter', 'xccl'] _valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter', 'xccl']
_default_store = None # the default tcp store _default_store = None # the default tcp store
_default_backend = None _default_backend = None
_default_timeout = datetime.timedelta(seconds=1800)
_start_ring_id = 0
def _set_default_backend(backend): def _set_default_backend(backend):
...@@ -163,16 +167,16 @@ def _set_default_store(store): ...@@ -163,16 +167,16 @@ def _set_default_store(store):
def _get_group_map(): def _get_group_map():
global _group_map global _group_map
if not _group_map: if _global_env_gid not in _group_map:
genv = _get_global_env() genv = _get_global_env()
_group_map[0] = Group(genv.rank, _group_map[_global_env_gid] = Group(genv.rank,
genv.world_size, genv.world_size,
ranks=list(range(genv.world_size))) ranks=list(range(genv.world_size)))
return _group_map return _group_map
def _get_global_group(): def _get_global_group():
return _get_group_map()[0] return _get_group_map()[_global_env_gid]
def _get_group_map_by_name(): def _get_group_map_by_name():
...@@ -206,7 +210,13 @@ def _set_group_map_backend(group, backend): ...@@ -206,7 +210,13 @@ def _set_group_map_backend(group, backend):
def _new_ring_id(): def _new_ring_id():
return len(_get_group_map()) + max(_get_global_env().nrings, 9) # NOTE(liyurui): For compatible reason, auto parallel and eager mode relay on previous syntax.
if in_dygraph_mode():
global _start_ring_id
_start_ring_id += 1
return _start_ring_id + max(_get_global_env().nrings, 9)
else:
return len(_get_group_map()) + max(_get_global_env().nrings, 9)
def _get_reduce_op(reduce_op, func_name): def _get_reduce_op(reduce_op, func_name):
...@@ -293,6 +303,7 @@ def _new_process_group_impl(backend, ...@@ -293,6 +303,7 @@ def _new_process_group_impl(backend,
cluster_id - 1] cluster_id - 1]
global_rank = cluster_offset + rank global_rank = cluster_offset + rank
global_world_size = cluster_size_cumsum[-1] global_world_size = cluster_size_cumsum[-1]
global_rank, global_world_size = _get_global_config(backend, rank)
pg = core.ProcessGroupHeter(store, pg = core.ProcessGroupHeter(store,
rank=global_rank, rank=global_rank,
world_size=global_world_size, world_size=global_world_size,
...@@ -368,7 +379,43 @@ def _set_custom_gid(gid): ...@@ -368,7 +379,43 @@ def _set_custom_gid(gid):
_custom_gid = gid _custom_gid = gid
def new_group(ranks=None, backend=None): def _barrier_by_tcp_store(group_name, store, timeout):
global_rank = paddle.distributed.get_rank()
global_world_size = paddle.distributed.get_world_size()
if global_world_size < 2:
return
barrier_prefix = "Barrier/" + group_name + "/"
is_master = (global_rank == 0)
def _check_keys_ready(wait_keys):
start_time = time.time()
while len(wait_keys) > 0:
time.sleep(0.1)
elapse_time = time.time() - start_time
if datetime.timedelta(seconds=elapse_time) > timeout:
raise RuntimeError(
"Timeout while initializing process group {}."
"Keys {} are not ready sinck rank {} is waiting them."
"Two reason may cause this error:\n 1. The create process group api should be called by all ranks.\n"
" 2. Try to increase the waiting time.\n".format(
group_name, wait_keys, global_rank))
wait_keys = list(
filter(lambda key: int(store.get(key)) != 1, wait_keys))
# all the workers set their exiting key and exit
# the master will wait for all workers' exiting key, ensure to exit in the end
if is_master:
wait_keys = [
barrier_prefix + str(rank) for rank in range(1, global_world_size)
]
_check_keys_ready(wait_keys)
else:
store.add(barrier_prefix + str(global_rank), 1)
def new_group(ranks=None, backend=None, timeout=_default_timeout):
""" """
Creates a new distributed communication group. Creates a new distributed communication group.
...@@ -376,6 +423,7 @@ def new_group(ranks=None, backend=None): ...@@ -376,6 +423,7 @@ def new_group(ranks=None, backend=None):
Args: Args:
ranks (list): The global ranks of group members. ranks (list): The global ranks of group members.
backend (str): The backend used to create group, only nccl is supported now. backend (str): The backend used to create group, only nccl is supported now.
timeout (datetime.timedelta, optional): The waiting timeout for store relevant options, default is 30 minutes.
Returns: Returns:
Group: The group instance. Group: The group instance.
...@@ -433,6 +481,11 @@ def new_group(ranks=None, backend=None): ...@@ -433,6 +481,11 @@ def new_group(ranks=None, backend=None):
# TODO(shenliang03): This is a temporary solution to solve the problem of # TODO(shenliang03): This is a temporary solution to solve the problem of
# hang caused by tcp # hang caused by tcp
paddle.distributed.barrier(group=group) paddle.distributed.barrier(group=group)
# NOTE(liyurui): All processors should hang and wait using tcp store, in case master exit before sub-group is created.
if backend != 'heter':
_barrier_by_tcp_store(group_name, _default_store, timeout)
else:
print("Warning: store barrier is not supported for heter backend.")
return group return group
if not backend: if not backend:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册