From 1734bc6fff4490d231cdedd0f2112419ea981dd2 Mon Sep 17 00:00:00 2001 From: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Date: Tue, 23 Aug 2022 11:23:49 +0800 Subject: [PATCH] Add store_barrier to prevent master exit (#44964) --- paddle/fluid/distributed/store/tcp_store.cc | 12 ++-- python/paddle/distributed/collective.py | 69 ++++++++++++++++++--- 2 files changed, 68 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/distributed/store/tcp_store.cc b/paddle/fluid/distributed/store/tcp_store.cc index e4228e4428d..28387af44df 100644 --- a/paddle/fluid/distributed/store/tcp_store.cc +++ b/paddle/fluid/distributed/store/tcp_store.cc @@ -194,8 +194,8 @@ void MasterDaemon::ProcessCommands(std::vector* p_fds) { << " from addr info:" << GetSockName(fds[i].fd); } } catch (const std::exception& ex) { - fds.erase(fds.begin() + i); tcputils::close_socket(fds[i].fd); + fds.erase(fds.begin() + i); #ifdef _WIN32 _sockets.erase(_sockets.begin() + i - 1); #else @@ -405,12 +405,14 @@ std::vector TCPStore::get(const std::string& key) { void TCPStore::wait(const std::string& key) { ReplyType reply; VLOG(3) << "TCPStore wait."; - do { - _client->send_command_for_key(Command::WAIT, _key_prefix + key); + _client->send_command_for_key(Command::WAIT, _key_prefix + key); + reply = _client->receive_value(); + while (reply != ReplyType::STOP_WAIT) { + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + _client->send_command_for_key(Command::WAIT, _key_prefix + key); reply = _client->receive_value(); - std::this_thread::sleep_for(std::chrono::milliseconds(500)); - } while (reply != ReplyType::STOP_WAIT); + } } TCPStore::~TCPStore() { VLOG(3) << "TCPStore destructure"; } diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index bf36015f894..a4e918fd01f 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -16,7 +16,8 @@ import numpy as np import os import pickle import io -from datetime import timedelta +import datetime +import time from ..fluid.layer_helper import LayerHelper from ..fluid.framework import Variable from ..fluid.framework import in_dygraph_mode @@ -134,6 +135,7 @@ def _get_global_env(): # group map : the map of all group, 0 for GlobalGroup # Dict[int, Group] _group_map = {} +_global_env_gid = 0 # group map by name : the map of all groups from their names # Dict[name, Group] @@ -149,6 +151,8 @@ _default_group_name = "_default_pg" _valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter', 'xccl'] _default_store = None # the default tcp store _default_backend = None +_default_timeout = datetime.timedelta(seconds=1800) +_start_ring_id = 0 def _set_default_backend(backend): @@ -163,16 +167,16 @@ def _set_default_store(store): def _get_group_map(): global _group_map - if not _group_map: + if _global_env_gid not in _group_map: genv = _get_global_env() - _group_map[0] = Group(genv.rank, - genv.world_size, - ranks=list(range(genv.world_size))) + _group_map[_global_env_gid] = Group(genv.rank, + genv.world_size, + ranks=list(range(genv.world_size))) return _group_map def _get_global_group(): - return _get_group_map()[0] + return _get_group_map()[_global_env_gid] def _get_group_map_by_name(): @@ -206,7 +210,13 @@ def _set_group_map_backend(group, backend): def _new_ring_id(): - return len(_get_group_map()) + max(_get_global_env().nrings, 9) + # NOTE(liyurui): For compatible reason, auto parallel and eager mode relay on previous syntax. + if in_dygraph_mode(): + global _start_ring_id + _start_ring_id += 1 + return _start_ring_id + max(_get_global_env().nrings, 9) + else: + return len(_get_group_map()) + max(_get_global_env().nrings, 9) def _get_reduce_op(reduce_op, func_name): @@ -293,6 +303,7 @@ def _new_process_group_impl(backend, cluster_id - 1] global_rank = cluster_offset + rank global_world_size = cluster_size_cumsum[-1] + global_rank, global_world_size = _get_global_config(backend, rank) pg = core.ProcessGroupHeter(store, rank=global_rank, world_size=global_world_size, @@ -368,7 +379,43 @@ def _set_custom_gid(gid): _custom_gid = gid -def new_group(ranks=None, backend=None): +def _barrier_by_tcp_store(group_name, store, timeout): + global_rank = paddle.distributed.get_rank() + global_world_size = paddle.distributed.get_world_size() + + if global_world_size < 2: + return + + barrier_prefix = "Barrier/" + group_name + "/" + is_master = (global_rank == 0) + + def _check_keys_ready(wait_keys): + start_time = time.time() + while len(wait_keys) > 0: + time.sleep(0.1) + elapse_time = time.time() - start_time + if datetime.timedelta(seconds=elapse_time) > timeout: + raise RuntimeError( + "Timeout while initializing process group {}." + "Keys {} are not ready sinck rank {} is waiting them." + "Two reason may cause this error:\n 1. The create process group api should be called by all ranks.\n" + " 2. Try to increase the waiting time.\n".format( + group_name, wait_keys, global_rank)) + wait_keys = list( + filter(lambda key: int(store.get(key)) != 1, wait_keys)) + + # all the workers set their exiting key and exit + # the master will wait for all workers' exiting key, ensure to exit in the end + if is_master: + wait_keys = [ + barrier_prefix + str(rank) for rank in range(1, global_world_size) + ] + _check_keys_ready(wait_keys) + else: + store.add(barrier_prefix + str(global_rank), 1) + + +def new_group(ranks=None, backend=None, timeout=_default_timeout): """ Creates a new distributed communication group. @@ -376,6 +423,7 @@ def new_group(ranks=None, backend=None): Args: ranks (list): The global ranks of group members. backend (str): The backend used to create group, only nccl is supported now. + timeout (datetime.timedelta, optional): The waiting timeout for store relevant options, default is 30 minutes. Returns: Group: The group instance. @@ -433,6 +481,11 @@ def new_group(ranks=None, backend=None): # TODO(shenliang03): This is a temporary solution to solve the problem of # hang caused by tcp paddle.distributed.barrier(group=group) + # NOTE(liyurui): All processors should hang and wait using tcp store, in case master exit before sub-group is created. + if backend != 'heter': + _barrier_by_tcp_store(group_name, _default_store, timeout) + else: + print("Warning: store barrier is not supported for heter backend.") return group if not backend: -- GitLab