未验证 提交 ebf486ac 编写于 作者: K kuizhiqing 提交者: GitHub

[launch] fix timeout reset (#42941)

上级 a5ad2659
...@@ -17,6 +17,7 @@ from paddle.distributed.launch import plugins ...@@ -17,6 +17,7 @@ from paddle.distributed.launch import plugins
from .node import Node from .node import Node
from .status import Status from .status import Status
from .args_envs import parse_args, fetch_envs, env_args_mapping from .args_envs import parse_args, fetch_envs, env_args_mapping
import six
import logging import logging
...@@ -39,6 +40,12 @@ class Context(object): ...@@ -39,6 +40,12 @@ class Context(object):
if enable_plugin: if enable_plugin:
self._enable_plugin() self._enable_plugin()
def print(self):
self.logger.info("----------- Configuration ----------------------")
for arg, value in sorted(six.iteritems(vars(self.args))):
self.logger.info("%s: %s" % (arg, value))
self.logger.info("--------------------------------------------------")
def is_legacy_mode(self): def is_legacy_mode(self):
if self.args.legacy: if self.args.legacy:
return True return True
......
...@@ -85,7 +85,7 @@ def parse_args(): ...@@ -85,7 +85,7 @@ def parse_args():
base_group.add_argument( base_group.add_argument(
"--run_mode", "--run_mode",
type=str, type=str,
default="collective", default=None,
help="run mode of the job, collective/ps/ps-heter") help="run mode of the job, collective/ps/ps-heter")
base_group.add_argument( base_group.add_argument(
...@@ -125,7 +125,7 @@ def parse_args(): ...@@ -125,7 +125,7 @@ def parse_args():
ps_group.add_argument( ps_group.add_argument(
"--gloo_port", type=int, default=6767, help="gloo http port") "--gloo_port", type=int, default=6767, help="gloo http port")
ps_group.add_argument( ps_group.add_argument(
"--with_gloo", type=str, default="0", help="use gloo or not") "--with_gloo", type=str, default="1", help="use gloo or not")
# parameter elastic mode # parameter elastic mode
elastic_group = parser.add_argument_group("Elastic Parameters") elastic_group = parser.add_argument_group("Elastic Parameters")
......
...@@ -29,4 +29,5 @@ _controllers = [ ...@@ -29,4 +29,5 @@ _controllers = [
def init(ctx): def init(ctx):
for c in _controllers: for c in _controllers:
if c.enable(ctx): if c.enable(ctx):
ctx.print()
return c(ctx) return c(ctx)
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .controller import Controller from .controller import Controller, ControleMode
import json import json
import os import os
...@@ -23,8 +23,10 @@ import time ...@@ -23,8 +23,10 @@ import time
class CollectiveController(Controller): class CollectiveController(Controller):
@classmethod @classmethod
def enable(cls, ctx): def enable(cls, ctx):
# collective is the default mode
if ctx: if ctx:
ctx.logger.debug("{} enabled".format(cls.__name__)) ctx.logger.debug("{} enabled".format(cls.__name__))
ctx.args.run_mode = ControleMode.COLLECTIVE
return True return True
else: else:
return False return False
...@@ -85,6 +87,7 @@ class CollectiveController(Controller): ...@@ -85,6 +87,7 @@ class CollectiveController(Controller):
"PADDLE_LOCAL_SIZE": "{}".format(self.pod.replicas), "PADDLE_LOCAL_SIZE": "{}".format(self.pod.replicas),
"PADDLE_GLOBAL_RANK": "{}".format(i + rank_offset), "PADDLE_GLOBAL_RANK": "{}".format(i + rank_offset),
"PADDLE_LOCAL_RANK": "{}".format(i), "PADDLE_LOCAL_RANK": "{}".format(i),
"PADDLE_NNODES": "{}".format(self.job.replicas),
## compatible env ## compatible env
"PADDLE_TRAINER_ENDPOINTS": ",".join(job_endpoints), "PADDLE_TRAINER_ENDPOINTS": ",".join(job_endpoints),
"PADDLE_CURRENT_ENDPOINT": endpoints[i], "PADDLE_CURRENT_ENDPOINT": endpoints[i],
...@@ -106,6 +109,7 @@ class CollectiveElasticController(CollectiveController): ...@@ -106,6 +109,7 @@ class CollectiveElasticController(CollectiveController):
def enable(cls, ctx): def enable(cls, ctx):
if ctx.args.master and ctx.args.master.startswith("etcd://"): if ctx.args.master and ctx.args.master.startswith("etcd://"):
ctx.logger.debug("{} enabled".format(cls.__name__)) ctx.logger.debug("{} enabled".format(cls.__name__))
ctx.args.run_mode = ControleMode.COLLECTIVE
return True return True
else: else:
return False return False
......
...@@ -276,10 +276,20 @@ class ETCDMaster(Master): ...@@ -276,10 +276,20 @@ class ETCDMaster(Master):
return peer_alive return peer_alive
def wait_peer_ready(self, replicas_min, replicas_max, timeout): def wait_peer_ready(self, replicas_min, replicas_max, timeout):
timeout = timeout if timeout > 1 else 3
end = time.time() + timeout end = time.time() + timeout
np_pre = len(self.fetch_peer_alive())
while not self.ctx.status.is_done() and time.time() < end: while not self.ctx.status.is_done() and time.time() < end:
if len(self.fetch_peer_alive()) == replicas_max: np = len(self.fetch_peer_alive())
if np == replicas_max:
# maximum replicas reached, return immediately
return (True, replicas_max) return (True, replicas_max)
elif np != np_pre:
# replicas are changing, reset timeout
end = time.time() + timeout
np_pre = np
time.sleep(0.2)
else: else:
time.sleep(0.5) time.sleep(0.5)
......
...@@ -171,6 +171,7 @@ class PSController(Controller): ...@@ -171,6 +171,7 @@ class PSController(Controller):
for i in range(server_num): for i in range(server_num):
e = { e = {
"PADDLE_NNODES": "{}".format(self.job.replicas),
"PADDLE_PSERVERS_IP_PORT_LIST": ",".join(server_endpoints), "PADDLE_PSERVERS_IP_PORT_LIST": ",".join(server_endpoints),
"PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints), "PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints),
"PADDLE_PORT": "PADDLE_PORT":
...@@ -186,6 +187,7 @@ class PSController(Controller): ...@@ -186,6 +187,7 @@ class PSController(Controller):
for i in range(trainer_num): for i in range(trainer_num):
e = { e = {
"PADDLE_NNODES": "{}".format(self.job.replicas),
"PADDLE_PSERVERS_IP_PORT_LIST": ",".join(server_endpoints), "PADDLE_PSERVERS_IP_PORT_LIST": ",".join(server_endpoints),
"PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints), "PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints),
"PADDLE_PORT": "PADDLE_PORT":
......
...@@ -17,6 +17,7 @@ import six ...@@ -17,6 +17,7 @@ import six
__all__ = [] __all__ = []
# print configuration after args are well filled in controller init
def log(ctx): def log(ctx):
ctx.logger.info("----------- Configuration ----------------------") ctx.logger.info("----------- Configuration ----------------------")
for arg, value in sorted(six.iteritems(vars(ctx.args))): for arg, value in sorted(six.iteritems(vars(ctx.args))):
...@@ -59,4 +60,4 @@ def rewrite_host_ip(ctx): ...@@ -59,4 +60,4 @@ def rewrite_host_ip(ctx):
ctx.node.ip = ctx.args.host ctx.node.ip = ctx.args.host
enabled_plugins = [collective_compatible, rewrite_host_ip, process_args, log] enabled_plugins = [collective_compatible, rewrite_host_ip, process_args]
...@@ -95,7 +95,7 @@ class Collective_Test(unittest.TestCase): ...@@ -95,7 +95,7 @@ class Collective_Test(unittest.TestCase):
shutil.rmtree('./log') shutil.rmtree('./log')
port = random.randrange(6000, 8000) port = random.randrange(6000, 8000)
args = "--job_id test3 --devices 0,1 --master 127.0.0.1:{} --np 2".format( args = "--job_id test3 --devices 0,1 --master 127.0.0.1:{} --nnodes 2".format(
port) port)
p1 = self.pdrun(args) p1 = self.pdrun(args)
p2 = self.pdrun(args) p2 = self.pdrun(args)
...@@ -143,7 +143,7 @@ class PS_Test(unittest.TestCase): ...@@ -143,7 +143,7 @@ class PS_Test(unittest.TestCase):
shutil.rmtree('./log') shutil.rmtree('./log')
port = random.randrange(6000, 8000) port = random.randrange(6000, 8000)
args = "--job_id ps3 --master 127.0.0.1:{} --np 2 --server_num=1 --trainer_num=1".format( args = "--job_id ps3 --master 127.0.0.1:{} --nnodes 2 --server_num=1 --trainer_num=1".format(
port) port)
p1 = self.pdrun(args) p1 = self.pdrun(args)
p2 = self.pdrun(args) p2 = self.pdrun(args)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册