From ebf486acb8accd341cf19dc9667f365de0bdd57d Mon Sep 17 00:00:00 2001 From: kuizhiqing Date: Tue, 24 May 2022 11:16:57 +0800 Subject: [PATCH] [launch] fix timeout reset (#42941) --- python/paddle/distributed/launch/context/__init__.py | 7 +++++++ .../paddle/distributed/launch/context/args_envs.py | 4 ++-- .../distributed/launch/controllers/__init__.py | 1 + .../distributed/launch/controllers/collective.py | 6 +++++- .../paddle/distributed/launch/controllers/master.py | 12 +++++++++++- python/paddle/distributed/launch/controllers/ps.py | 2 ++ python/paddle/distributed/launch/plugins/__init__.py | 3 ++- python/paddle/fluid/tests/unittests/test_run.py | 4 ++-- 8 files changed, 32 insertions(+), 7 deletions(-) diff --git a/python/paddle/distributed/launch/context/__init__.py b/python/paddle/distributed/launch/context/__init__.py index 08c8f0835c5..fbea5d0db86 100644 --- a/python/paddle/distributed/launch/context/__init__.py +++ b/python/paddle/distributed/launch/context/__init__.py @@ -17,6 +17,7 @@ from paddle.distributed.launch import plugins from .node import Node from .status import Status from .args_envs import parse_args, fetch_envs, env_args_mapping +import six import logging @@ -39,6 +40,12 @@ class Context(object): if enable_plugin: self._enable_plugin() + def print(self): + self.logger.info("----------- Configuration ----------------------") + for arg, value in sorted(six.iteritems(vars(self.args))): + self.logger.info("%s: %s" % (arg, value)) + self.logger.info("--------------------------------------------------") + def is_legacy_mode(self): if self.args.legacy: return True diff --git a/python/paddle/distributed/launch/context/args_envs.py b/python/paddle/distributed/launch/context/args_envs.py index b624281e44d..ea8bf3d597a 100644 --- a/python/paddle/distributed/launch/context/args_envs.py +++ b/python/paddle/distributed/launch/context/args_envs.py @@ -85,7 +85,7 @@ def parse_args(): base_group.add_argument( "--run_mode", type=str, - default="collective", + default=None, help="run mode of the job, collective/ps/ps-heter") base_group.add_argument( @@ -125,7 +125,7 @@ def parse_args(): ps_group.add_argument( "--gloo_port", type=int, default=6767, help="gloo http port") ps_group.add_argument( - "--with_gloo", type=str, default="0", help="use gloo or not") + "--with_gloo", type=str, default="1", help="use gloo or not") # parameter elastic mode elastic_group = parser.add_argument_group("Elastic Parameters") diff --git a/python/paddle/distributed/launch/controllers/__init__.py b/python/paddle/distributed/launch/controllers/__init__.py index 706131300f0..f1c6ea5399a 100644 --- a/python/paddle/distributed/launch/controllers/__init__.py +++ b/python/paddle/distributed/launch/controllers/__init__.py @@ -29,4 +29,5 @@ _controllers = [ def init(ctx): for c in _controllers: if c.enable(ctx): + ctx.print() return c(ctx) diff --git a/python/paddle/distributed/launch/controllers/collective.py b/python/paddle/distributed/launch/controllers/collective.py index 3763bac0414..5225fd6e81f 100644 --- a/python/paddle/distributed/launch/controllers/collective.py +++ b/python/paddle/distributed/launch/controllers/collective.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .controller import Controller +from .controller import Controller, ControleMode import json import os @@ -23,8 +23,10 @@ import time class CollectiveController(Controller): @classmethod def enable(cls, ctx): + # collective is the default mode if ctx: ctx.logger.debug("{} enabled".format(cls.__name__)) + ctx.args.run_mode = ControleMode.COLLECTIVE return True else: return False @@ -85,6 +87,7 @@ class CollectiveController(Controller): "PADDLE_LOCAL_SIZE": "{}".format(self.pod.replicas), "PADDLE_GLOBAL_RANK": "{}".format(i + rank_offset), "PADDLE_LOCAL_RANK": "{}".format(i), + "PADDLE_NNODES": "{}".format(self.job.replicas), ## compatible env "PADDLE_TRAINER_ENDPOINTS": ",".join(job_endpoints), "PADDLE_CURRENT_ENDPOINT": endpoints[i], @@ -106,6 +109,7 @@ class CollectiveElasticController(CollectiveController): def enable(cls, ctx): if ctx.args.master and ctx.args.master.startswith("etcd://"): ctx.logger.debug("{} enabled".format(cls.__name__)) + ctx.args.run_mode = ControleMode.COLLECTIVE return True else: return False diff --git a/python/paddle/distributed/launch/controllers/master.py b/python/paddle/distributed/launch/controllers/master.py index 43eda4cdffa..742fea9e16d 100644 --- a/python/paddle/distributed/launch/controllers/master.py +++ b/python/paddle/distributed/launch/controllers/master.py @@ -276,10 +276,20 @@ class ETCDMaster(Master): return peer_alive def wait_peer_ready(self, replicas_min, replicas_max, timeout): + timeout = timeout if timeout > 1 else 3 + end = time.time() + timeout + np_pre = len(self.fetch_peer_alive()) while not self.ctx.status.is_done() and time.time() < end: - if len(self.fetch_peer_alive()) == replicas_max: + np = len(self.fetch_peer_alive()) + if np == replicas_max: + # maximum replicas reached, return immediately return (True, replicas_max) + elif np != np_pre: + # replicas are changing, reset timeout + end = time.time() + timeout + np_pre = np + time.sleep(0.2) else: time.sleep(0.5) diff --git a/python/paddle/distributed/launch/controllers/ps.py b/python/paddle/distributed/launch/controllers/ps.py index 6504f1240ee..037bd313bbc 100644 --- a/python/paddle/distributed/launch/controllers/ps.py +++ b/python/paddle/distributed/launch/controllers/ps.py @@ -171,6 +171,7 @@ class PSController(Controller): for i in range(server_num): e = { + "PADDLE_NNODES": "{}".format(self.job.replicas), "PADDLE_PSERVERS_IP_PORT_LIST": ",".join(server_endpoints), "PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints), "PADDLE_PORT": @@ -186,6 +187,7 @@ class PSController(Controller): for i in range(trainer_num): e = { + "PADDLE_NNODES": "{}".format(self.job.replicas), "PADDLE_PSERVERS_IP_PORT_LIST": ",".join(server_endpoints), "PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints), "PADDLE_PORT": diff --git a/python/paddle/distributed/launch/plugins/__init__.py b/python/paddle/distributed/launch/plugins/__init__.py index 35a44ed942c..13c09b4c27c 100644 --- a/python/paddle/distributed/launch/plugins/__init__.py +++ b/python/paddle/distributed/launch/plugins/__init__.py @@ -17,6 +17,7 @@ import six __all__ = [] +# print configuration after args are well filled in controller init def log(ctx): ctx.logger.info("----------- Configuration ----------------------") for arg, value in sorted(six.iteritems(vars(ctx.args))): @@ -59,4 +60,4 @@ def rewrite_host_ip(ctx): ctx.node.ip = ctx.args.host -enabled_plugins = [collective_compatible, rewrite_host_ip, process_args, log] +enabled_plugins = [collective_compatible, rewrite_host_ip, process_args] diff --git a/python/paddle/fluid/tests/unittests/test_run.py b/python/paddle/fluid/tests/unittests/test_run.py index 28bcc379fb9..c0157c5b906 100644 --- a/python/paddle/fluid/tests/unittests/test_run.py +++ b/python/paddle/fluid/tests/unittests/test_run.py @@ -95,7 +95,7 @@ class Collective_Test(unittest.TestCase): shutil.rmtree('./log') port = random.randrange(6000, 8000) - args = "--job_id test3 --devices 0,1 --master 127.0.0.1:{} --np 2".format( + args = "--job_id test3 --devices 0,1 --master 127.0.0.1:{} --nnodes 2".format( port) p1 = self.pdrun(args) p2 = self.pdrun(args) @@ -143,7 +143,7 @@ class PS_Test(unittest.TestCase): shutil.rmtree('./log') port = random.randrange(6000, 8000) - args = "--job_id ps3 --master 127.0.0.1:{} --np 2 --server_num=1 --trainer_num=1".format( + args = "--job_id ps3 --master 127.0.0.1:{} --nnodes 2 --server_num=1 --trainer_num=1".format( port) p1 = self.pdrun(args) p2 = self.pdrun(args) -- GitLab