未验证 提交 ebf486ac 编写于 作者: K kuizhiqing 提交者: GitHub

[launch] fix timeout reset (#42941)

上级 a5ad2659
......@@ -17,6 +17,7 @@ from paddle.distributed.launch import plugins
from .node import Node
from .status import Status
from .args_envs import parse_args, fetch_envs, env_args_mapping
import six
import logging
......@@ -39,6 +40,12 @@ class Context(object):
if enable_plugin:
self._enable_plugin()
def print(self):
    """Log the launch configuration at INFO level.

    Emits a header banner, then one ``name: value`` line for every
    attribute of ``self.args`` in ascending order of attribute name,
    then a footer banner. All output goes through ``self.logger``.
    """
    self.logger.info("----------- Configuration ----------------------")
    # dict.items() behaves the same on Python 2 and 3 once passed to
    # sorted(), so no compatibility shim is needed for the iteration.
    for arg, value in sorted(vars(self.args).items()):
        # Pass the values as logger arguments so the message is only
        # formatted when the INFO level is actually enabled.
        self.logger.info("%s: %s", arg, value)
    self.logger.info("--------------------------------------------------")
def is_legacy_mode(self):
if self.args.legacy:
return True
......
......@@ -85,7 +85,7 @@ def parse_args():
base_group.add_argument(
"--run_mode",
type=str,
default="collective",
default=None,
help="run mode of the job, collective/ps/ps-heter")
base_group.add_argument(
......@@ -125,7 +125,7 @@ def parse_args():
ps_group.add_argument(
"--gloo_port", type=int, default=6767, help="gloo http port")
ps_group.add_argument(
"--with_gloo", type=str, default="0", help="use gloo or not")
"--with_gloo", type=str, default="1", help="use gloo or not")
# parameter elastic mode
elastic_group = parser.add_argument_group("Elastic Parameters")
......
......@@ -29,4 +29,5 @@ _controllers = [
def init(ctx):
    """Select and build the controller for this launch context.

    Walks ``_controllers`` in order and stops at the first entry whose
    ``enable(ctx)`` is truthy. The configuration is then printed through
    ``ctx.print()`` and the result of calling that controller with
    ``ctx`` is returned. Returns ``None`` when no controller is enabled.
    """
    # The generator is lazy, so enable() is only invoked up to and
    # including the first controller that accepts the context.
    selected = next((c for c in _controllers if c.enable(ctx)), None)
    if selected is None:
        return None
    ctx.print()
    return selected(ctx)
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .controller import Controller
from .controller import Controller, ControleMode
import json
import os
......@@ -23,8 +23,10 @@ import time
class CollectiveController(Controller):
@classmethod
def enable(cls, ctx):
    """Report whether this controller handles the given context.

    Any truthy ``ctx`` is accepted: the match is logged at DEBUG level
    and ``ctx.args.run_mode`` is set to ``ControleMode.COLLECTIVE``
    before ``True`` is returned. A falsy ``ctx`` yields ``False``.
    """
    # collective is the default mode
    if not ctx:
        return False
    ctx.logger.debug("{} enabled".format(cls.__name__))
    ctx.args.run_mode = ControleMode.COLLECTIVE
    return True
......@@ -85,6 +87,7 @@ class CollectiveController(Controller):
"PADDLE_LOCAL_SIZE": "{}".format(self.pod.replicas),
"PADDLE_GLOBAL_RANK": "{}".format(i + rank_offset),
"PADDLE_LOCAL_RANK": "{}".format(i),
"PADDLE_NNODES": "{}".format(self.job.replicas),
## compatible env
"PADDLE_TRAINER_ENDPOINTS": ",".join(job_endpoints),
"PADDLE_CURRENT_ENDPOINT": endpoints[i],
......@@ -106,6 +109,7 @@ class CollectiveElasticController(CollectiveController):
def enable(cls, ctx):
    """Report whether this controller handles the given context.

    The controller is enabled only when ``ctx.args.master`` is set and
    starts with ``"etcd://"``. In that case the match is logged at DEBUG
    level and ``ctx.args.run_mode`` is set to ``ControleMode.COLLECTIVE``
    before ``True`` is returned; otherwise ``False`` is returned.
    """
    master = ctx.args.master
    uses_etcd = bool(master) and master.startswith("etcd://")
    if not uses_etcd:
        return False
    ctx.logger.debug("{} enabled".format(cls.__name__))
    ctx.args.run_mode = ControleMode.COLLECTIVE
    return True
......
......@@ -276,10 +276,20 @@ class ETCDMaster(Master):
return peer_alive
def wait_peer_ready(self, replicas_min, replicas_max, timeout):
timeout = timeout if timeout > 1 else 3
end = time.time() + timeout
np_pre = len(self.fetch_peer_alive())
while not self.ctx.status.is_done() and time.time() < end:
if len(self.fetch_peer_alive()) == replicas_max:
np = len(self.fetch_peer_alive())
if np == replicas_max:
# maximum replicas reached, return immediately
return (True, replicas_max)
elif np != np_pre:
# replicas are changing, reset timeout
end = time.time() + timeout
np_pre = np
time.sleep(0.2)
else:
time.sleep(0.5)
......
......@@ -171,6 +171,7 @@ class PSController(Controller):
for i in range(server_num):
e = {
"PADDLE_NNODES": "{}".format(self.job.replicas),
"PADDLE_PSERVERS_IP_PORT_LIST": ",".join(server_endpoints),
"PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints),
"PADDLE_PORT":
......@@ -186,6 +187,7 @@ class PSController(Controller):
for i in range(trainer_num):
e = {
"PADDLE_NNODES": "{}".format(self.job.replicas),
"PADDLE_PSERVERS_IP_PORT_LIST": ",".join(server_endpoints),
"PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints),
"PADDLE_PORT":
......
......@@ -17,6 +17,7 @@ import six
__all__ = []
# The configuration is printed only after controller init has fully populated the args.
def log(ctx):
ctx.logger.info("----------- Configuration ----------------------")
for arg, value in sorted(six.iteritems(vars(ctx.args))):
......@@ -59,4 +60,4 @@ def rewrite_host_ip(ctx):
ctx.node.ip = ctx.args.host
enabled_plugins = [collective_compatible, rewrite_host_ip, process_args, log]
enabled_plugins = [collective_compatible, rewrite_host_ip, process_args]
......@@ -95,7 +95,7 @@ class Collective_Test(unittest.TestCase):
shutil.rmtree('./log')
port = random.randrange(6000, 8000)
args = "--job_id test3 --devices 0,1 --master 127.0.0.1:{} --np 2".format(
args = "--job_id test3 --devices 0,1 --master 127.0.0.1:{} --nnodes 2".format(
port)
p1 = self.pdrun(args)
p2 = self.pdrun(args)
......@@ -143,7 +143,7 @@ class PS_Test(unittest.TestCase):
shutil.rmtree('./log')
port = random.randrange(6000, 8000)
args = "--job_id ps3 --master 127.0.0.1:{} --np 2 --server_num=1 --trainer_num=1".format(
args = "--job_id ps3 --master 127.0.0.1:{} --nnodes 2 --server_num=1 --trainer_num=1".format(
port)
p1 = self.pdrun(args)
p2 = self.pdrun(args)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册