未验证 提交 6d34d266 编写于 作者: X xiayanming 提交者: GitHub

Optimize fleet elastic scale in/out (#37177)

* fleet support elastic train

* fleet support elastic train

* support elastic

* add unittest

* fix unitest bug

* fix unittest bug

* fix unittest bug

* fix unittest coverage

* fix unittest coverage

* fix unittest coverage

* fix unittest coverage

* fix unittest coverage

* fix elastic bug

* fix ci fail

* fix ci fail

* fix elastic bug

* fix elastic bug

* fix joint debugging bug

* fix joint debugging bug

* fix windows ci failed

* fix windows ci failed

* Optimize fleet elastic scale in/out

* elastic support pre hook

* add prehook unittest
上级 36a95654
......@@ -56,6 +56,9 @@ def launch_elastic(args, distribute_mode):
# wait for all nodes ready to run
elastic.wait()
# execute pre hook action, eg: run shell
elastic.pre_hook()
# run self with specified launcher
elastic.run(CollectiveLauncher)
......
......@@ -22,6 +22,7 @@ import signal
import random
import threading
import traceback
import subprocess
from paddle.distributed.fleet import cloud_utils
from paddle.distributed.fleet import launch_utils
......@@ -133,11 +134,7 @@ class ElasticManager(object):
scale = args.scale or int(os.getenv('PADDLE_ELASTIC_SCALE', 0))
force = args.force or os.getenv('PADDLE_ELASTIC_FORCE')
start_port = 6170
if os.environ.get('FLAGS_START_PORT') is not None:
start_port = int(os.environ.get('FLAGS_START_PORT'))
if cloud_utils.use_paddlecloud():
start_port = int(os.getenv("PADDLE_PORT", ""))
self.host = host if host else self._get_host()
(self.device_mode,
self.devices_per_proc) = launch_utils.get_device_proc_info(args)
......@@ -145,16 +142,30 @@ class ElasticManager(object):
self.elastic_timeout = int(
os.getenv('PADDLE_ELASTIC_TIMEOUT', ELASTIC_TIMEOUT))
elastic_ttl = int(os.getenv('PADDLE_ELASTIC_TTL', ELASTIC_TTL))
self.dist_endpoints = os.getenv('DISTRIBUTED_TRAINER_ENDPOINTS', '')
self.trainers = os.getenv('PADDLE_TRAINERS', '')
self.all_host_endpoints = os.getenv('PADDLE_TRAINER_ENDPOINTS',
'').split(",")
self.np = len(self.all_host_endpoints)
logger.info(f'start job with np={self.np}')
#[ "%s:%d" % (ip, start_port) for ip in self.trainers.split(",")]
self.start_port = None
if cloud_utils.use_paddlecloud():
self.trainers = os.getenv('PADDLE_TRAINERS', '')
self.np = len(self.trainers.split(","))
self.start_port = int(os.getenv("PADDLE_PORT", "6170"))
self.dist_endpoints = os.getenv('DISTRIBUTED_TRAINER_ENDPOINTS', '')
trainer_endpoints = os.getenv('PADDLE_TRAINER_ENDPOINTS', '')
self.trainer_endpoints_list = trainer_endpoints.split(",")
else:
self.trainers = args.ips or os.getenv('PADDLE_TRAINERS', '')
node_ips = self.trainers.split(",")
self.np = len(node_ips)
self.start_port = int(os.getenv("FLAGS_START_PORT", "6170"))
self.dist_endpoints = self._host_to_endpoints(
node_ips, self.devices_per_proc, self.start_port)
self.trainer_endpoints_list = [
"%s:%d" % (ip, self.start_port) for ip in node_ips
]
self.curr_host = "%s:%d" % (self.host, self.start_port)
logger.info(f'start job with np={self.np}')
logger.info(
f"trainers={self.trainers}, all_host_endpoints={self.all_host_endpoints}"
f"trainers={self.trainers}, trainer_endpoints_list={self.trainer_endpoints_list}"
)
# auto correct the value of elastic_level
......@@ -198,8 +209,6 @@ class ElasticManager(object):
self.enable = True
self.etcd = etcd_client
self.host = host if host else self._get_host()
self.host_port = "%s:%d" % (self.host, start_port)
# etcd data
self.prefix = "/paddle/" + name
......@@ -224,7 +233,7 @@ class ElasticManager(object):
for i in self.etcd.get_prefix(self.node_prefix)
]
logger.info(
f"host_call_back curr_host={self.host_port}, hosts:{self.hosts}")
f"host_call_back curr_host={self.curr_host}, hosts:{self.hosts}")
self.need_sync = True
self.elastic_startup_time = None
......@@ -243,13 +252,13 @@ class ElasticManager(object):
for i in self.etcd.get_prefix(self.node_prefix)
]
logger.info(
f"[lease_heartbeat] curr_host={self.host_port}, hosts={hosts}"
f"[lease_heartbeat] curr_host={self.curr_host}, hosts={hosts}"
)
if self.host_port not in hosts:
if self.curr_host not in hosts:
logger.info(
f"[lease_heartbeat] register host={self.host_port}")
f"[lease_heartbeat] register host={self.curr_host}")
self.etcd.put(self.host_path,
six.b(self.host_port),
six.b(self.curr_host),
lease=host_lease)
except Exception as e:
logger.error("[lease_heartbeat] internal error:{} {}".
......@@ -261,7 +270,7 @@ class ElasticManager(object):
name='lease_heartbeat', target=lease_heartbeat, daemon=True)
keepalived_thread.start()
self.etcd.put(self.host_path, six.b(self.host_port), lease=host_lease)
self.etcd.put(self.host_path, six.b(self.curr_host), lease=host_lease)
# endpoints handle DISTRIBUTED_TRAINER_ENDPOINTS and PADDLE_TRAINERS
self.etcd.put(self.endpoints_path,
......@@ -282,6 +291,26 @@ class ElasticManager(object):
self.watches = [host_watch, endpoints_watch]
self.launcher = None
def _host_to_endpoints(self,
ip_port_list: list,
devices_per_proc: list,
start_port: int=6170) -> str:
endpoint_list = []
for ip_port in ip_port_list:
endpoints = ip_port.split(":")
if len(endpoints) == 2:
ip = endpoints[0]
port = int(endpoints[1])
else:
ip = endpoints
port = start_port
ports = [x for x in range(port, port + len(devices_per_proc))]
endpoint_list.extend(["%s:%d" % (ip, port) for port in ports])
dist_endpoints = ','.join(endpoint_list)
return dist_endpoints
def exit(self, completed=False):
logger.info('manager exist completed {}'.format(completed))
......@@ -302,6 +331,22 @@ class ElasticManager(object):
if len(hosts) == 0:
self.etcd.delete_prefix(self.prefix)
def pre_hook(self):
if not self.args.elastic_pre_hook:
logger.info("skip pre_hook")
return
current_env = copy.copy(os.environ.copy())
out, err = subprocess.Popen(
self.args.elastic_pre_hook,
env=current_env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True).communicate()
if err:
logger.warn("pre_hook exec failed")
else:
logger.info(f"pre_hook exec result: {out.decode('utf-8').strip()}")
def _parse_np(self, np: str):
"""
np format is "MIN" or "MIN:MAX"
......@@ -354,7 +399,6 @@ class ElasticManager(object):
return False
if self.elastic_level == ElasticLevel.ELASTIC:
# FIXME(xym) add freeze status
hosts_num = len(self.hosts)
if hosts_num == self.np:
return True
......@@ -384,120 +428,113 @@ class ElasticManager(object):
self.etcd.put(self.endpoints_path,
six.b('{}|{}'.format(endpoints, hosts)))
def _update_hosts(self):
assert len(self.hosts) != 0, 'hosts empty'
def _update_fault_tolrance(self):
rank = int(os.getenv('PADDLE_TRAINER_ID', -1))
if self.elastic_level == ElasticLevel.FAULT_TOLERANCE:
if self.host_port in self.dist_endpoints:
os.environ[
'DISTRIBUTED_TRAINER_ENDPOINTS'] = self.dist_endpoints
os.environ['PADDLE_TRAINERS'] = self.trainers
logger.info("update env DISTRIBUTED_TRAINER_ENDPOINTS {} ".
format(self.dist_endpoints))
logger.info("update env PADDLE_TRAINERS {} ".format(
self.trainers))
return
if self.curr_host in self.dist_endpoints:
os.environ['DISTRIBUTED_TRAINER_ENDPOINTS'] = self.dist_endpoints
os.environ['PADDLE_TRAINERS'] = self.trainers
logger.info("update env DISTRIBUTED_TRAINER_ENDPOINTS {} ".format(
self.dist_endpoints))
logger.info("update env PADDLE_TRAINERS {} ".format(self.trainers))
return
# fault tolerance
idx = self.hosts.index(self.curr_host)
# swap if self.host not in the right position
if rank >= 0:
self.hosts[idx] = self.hosts[rank]
self.hosts[rank] = self.curr_host
else:
os.environ['PADDLE_TRAINER_ID'] = '{}'.format(idx)
hosts = ','.join([host_port.split(":")[0] for host_port in self.hosts])
self.args.ips = hosts
os.environ['PADDLE_TRAINERS'] = hosts
def _update_elastic_scale_out(self):
host_endpoints = copy.deepcopy(self.trainer_endpoints_list)
logger.info(
f"elastic scale out, from {len(self.hosts)} to {self.np}, hosts={self.hosts}, host_endpoints={host_endpoints}"
)
# fault tolerance
idx = self.hosts.index(self.host_port)
for curr_host_port in self.hosts:
if curr_host_port not in host_endpoints:
host_endpoints.append(curr_host_port)
os.environ['PADDLE_TRAINER_ID'] = '{}'.format(
host_endpoints.index(self.curr_host))
hosts = ','.join(
[host_port.split(":")[0] for host_port in host_endpoints])
self.args.ips = hosts
os.environ['PADDLE_TRAINERS'] = hosts
self.np = len(host_endpoints)
os.environ['PADDLE_TRAINER_ENDPOINTS'] = ','.join(host_endpoints)
os.environ['DISTRIBUTED_TRAINER_ENDPOINTS'] = self.dist_endpoints
self.trainer_endpoints_list = host_endpoints
def _update_elastic_scale_in(self):
host_endpoints = copy.deepcopy(self.trainer_endpoints_list)
logger.info(
f"elastic scale in, from {self.np} to {len(self.hosts)}, hosts={self.hosts}, host_endpoints={host_endpoints}"
)
# swap if self.host not in the right position
if rank >= 0:
self.hosts[idx] = self.hosts[rank]
self.hosts[rank] = self.host_port
# If scale in node from the first of the rank list, you need to minimize the movement of the rank
# eg:
# the source trainers is:10.10.10.0,10.10.10.1,10.10.10.2,10.10.10.3
# 10.10.10.0 is removed
# the new trainers is:10.10.10.3,10.10.10.1,10.10.10.2
# In this case, the rank of 10.10.10.1 and 10.10.10.2 remains unchanged, while the rank of 10.10.10.3 is set to rank0
endpoints_dict = dict()
unsorted_endpoints = []
for id, host_port in enumerate(self.hosts):
idx = host_endpoints.index(host_port)
if idx <= len(self.hosts) - 1 and not endpoints_dict.get(idx):
endpoints_dict[idx] = host_port
else:
os.environ['PADDLE_TRAINER_ID'] = '{}'.format(idx)
hosts = ','.join(
[host_port.split(":")[0] for host_port in self.hosts])
self.args.ips = hosts
os.environ['PADDLE_TRAINERS'] = hosts
unsorted_endpoints.append(host_port)
idle_index = 0
sorted_endpoints = []
for idx in range(len(self.hosts)):
if not endpoints_dict.get(idx) and len(unsorted_endpoints) > 0:
endpoints_dict[idx] = unsorted_endpoints[idle_index]
idle_index += 1
sorted_endpoints.append(endpoints_dict.get(idx))
logger.info(f"elastic scale in, sorted_endpoints={sorted_endpoints}")
self.trainer_endpoints_list = sorted_endpoints
ip_list = [ip_port.split(":")[0] for ip_port in sorted_endpoints]
hosts = ','.join(ip_list)
new_endpoints = self._host_to_endpoints(sorted_endpoints,
self.devices_per_proc)
self.args.ips = hosts
os.environ['PADDLE_TRAINER_ID'] = '{}'.format(
sorted_endpoints.index(self.curr_host))
os.environ['PADDLE_TRAINERS'] = hosts
self.np = len(sorted_endpoints)
os.environ['PADDLE_TRAINER_ENDPOINTS'] = ','.join(sorted_endpoints)
os.environ['DISTRIBUTED_TRAINER_ENDPOINTS'] = new_endpoints
self._update_endpoint(new_endpoints, hosts)
def _update_hosts(self):
assert len(self.hosts) != 0, 'hosts empty'
if self.elastic_level == ElasticLevel.FAULT_TOLERANCE:
self._update_fault_tolrance()
else:
# elastic, scale up/down
endpoints = copy.deepcopy(self.all_host_endpoints)
if len(self.hosts) > self.np:
# scale up
logger.info(
f"elastic scale up, from {self.np} to {len(self.hosts)}, hosts={self.hosts}, endpoints={endpoints}"
)
for curr_host_port in self.hosts:
if curr_host_port not in endpoints:
endpoints.append(curr_host_port)
os.environ['PADDLE_TRAINER_ID'] = '{}'.format(
endpoints.index(self.host_port))
hosts = ','.join(
[host_port.split(":")[0] for host_port in endpoints])
self.args.ips = hosts
os.environ['PADDLE_TRAINERS'] = hosts
self.np = len(endpoints)
os.environ['PADDLE_TRAINER_ENDPOINTS'] = ','.join(endpoints)
os.environ[
'DISTRIBUTED_TRAINER_ENDPOINTS'] = self.dist_endpoints
self.all_host_endpoints = endpoints
# elastic
if len(self.hosts) == self.np:
logger.info(f"elastic startup, hosts={self.hosts}")
self._update_fault_tolrance()
elif len(self.hosts) > self.np:
# scale out
self._update_elastic_scale_out()
else:
# scale down
logger.info(
f"elastic scale down, from {len(self.hosts)} to {self.np}, hosts={self.hosts}, endpoints={endpoints}"
)
# If the shrink node is from the first of the rank list, you need to minimize the movement of the rank
# eg:
# the source trainers is:10.10.10.0,10.10.10.1,10.10.10.2,10.10.10.3
# 10.10.10.0 is removed
# the new trainers is:10.10.10.3,10.10.10.1,10.10.10.2
# In this case, the rank of 10.10.10.1 and 10.10.10.2 remains unchanged, while the rank of 10.10.10.3 is set to rank0
endpoints_dict = dict()
unsorted_endpoints = []
for id, host_port in enumerate(self.hosts):
idx = endpoints.index(host_port)
if idx <= len(self.hosts) - 1 and not endpoints_dict.get(
idx):
endpoints_dict[idx] = host_port
else:
unsorted_endpoints.append(host_port)
idle_index = 0
sorted_endpoints = []
for idx in range(len(self.hosts)):
if not endpoints_dict.get(idx) and len(
unsorted_endpoints) > 0:
endpoints_dict[idx] = unsorted_endpoints[idle_index]
idle_index += 1
sorted_endpoints.append(endpoints_dict.get(idx))
logger.info(
f"elastic scale down, sorted_endpoints={sorted_endpoints}")
self.all_host_endpoints = sorted_endpoints
endpoint_list = []
ip_list = []
for host_port in sorted_endpoints:
host_port_list = host_port.split(":")
ip = host_port_list[0]
port = int(host_port_list[1])
ip_list.append(ip)
ports = [
x
for x in range(port, port + len(self.devices_per_proc))
]
endpoint_list.extend(
["%s:%d" % (ip, port) for port in ports])
hosts = ','.join(ip_list)
new_endpoints = ','.join(endpoint_list)
self.args.ips = hosts
os.environ['PADDLE_TRAINER_ID'] = '{}'.format(
sorted_endpoints.index(self.host_port))
os.environ['PADDLE_TRAINERS'] = hosts
self.np = len(sorted_endpoints)
os.environ['PADDLE_TRAINER_ENDPOINTS'] = ','.join(
sorted_endpoints)
os.environ['DISTRIBUTED_TRAINER_ENDPOINTS'] = new_endpoints
self._update_endpoint(new_endpoints, hosts)
# scale in
self._update_elastic_scale_in()
def wait(self):
if not self.enable:
......
......@@ -218,6 +218,9 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
elastic_group = parser.add_argument_group("Elastic Parameters")
elastic_group.add_argument(
"--elastic_server", type=str, help="etcd server host:port")
elastic_group.add_argument(
"--elastic_pre_hook", type=str, help="elastic pre_hook shell cmd")
elastic_group.add_argument("--job_id", type=str, help="job unique id")
elastic_group.add_argument("--np", type=int, help="job pod/node number")
elastic_group.add_argument("--scale", type=int, default=0, help="scale np")
......
......@@ -78,7 +78,8 @@ class TestElasticManager(unittest.TestCase):
gpus = "0"
nproc_per_node = 1
host = None
host_port = None
curr_host = None
ips = None
scale = None
force = None
backend = 'gloo'
......@@ -100,18 +101,25 @@ class TestElasticManager(unittest.TestCase):
gpus = "0"
nproc_per_node = 1
host = None
host_port = None
curr_host = None
ips = None
scale = None
force = None
backend = 'gloo'
args = Argument()
args.ips = "10.10.10.1,10.10.10.2"
elastic = ElasticManager(args, self.etcd_client)
os.environ['FLAGS_START_PORT'] = "6001"
hosts = ["10.10.10.1:6001", "10.10.10.2:6001"]
os.environ[
'PADDLE_TRAINER_ENDPOINTS'] = "10.10.10.1:6001,10.10.10.2:6001"
self.assertEqual(elastic._match(hosts), True)
hosts = ["10.10.10.1:6001"]
args.ips = "10.10.10.1"
os.environ['PADDLE_TRAINER_ENDPOINTS'] = "10.10.10.1:6001"
self.assertEqual(elastic._match(hosts), False)
......@@ -123,13 +131,15 @@ class TestElasticManager(unittest.TestCase):
gpus = "0"
nproc_per_node = 1
host = None
host_port = None
curr_host = None
ips = None
scale = None
force = None
backend = 'gloo'
os.environ['PADDLE_ELASTIC_TIMEOUT'] = "60"
args = Argument()
args.ips = "10.10.10.1,10.10.10.2,10.10.10.3,10.10.10.4"
os.environ['FLAGS_START_PORT'] = "6001"
os.environ[
'DISTRIBUTED_TRAINER_ENDPOINTS'] = "10.10.10.1:6001,10.10.10.2:6001,10.10.10.3:6001,10.10.10.4:6001"
......@@ -151,6 +161,7 @@ class TestElasticManager(unittest.TestCase):
hosts = ["10.10.10.1:6001"]
self.assertEqual(elastic._match(hosts), False)
args.ips = "10.10.10.1,10.10.10.2"
os.environ[
'DISTRIBUTED_TRAINER_ENDPOINTS'] = "10.10.10.1:6001,10.10.10.2:6001"
os.environ[
......@@ -171,7 +182,8 @@ class TestElasticManager(unittest.TestCase):
gpus = "0"
nproc_per_node = 1
host = None
host_port = None
curr_host = None
ips = None
scale = None
force = None
backend = 'gloo'
......@@ -187,19 +199,19 @@ class TestElasticManager(unittest.TestCase):
elastic = ElasticManager(args, self.etcd_client)
# add 10.10.10.3:6001
os.environ['PADDLE_TRAINER_ID'] = "0"
elastic.host_port = "10.10.10.1:6001"
elastic.curr_host = "10.10.10.1:6001"
elastic.hosts = ["10.10.10.1:6001", "10.10.10.2:6001"]
elastic._update_hosts()
self.assertEqual(os.getenv('PADDLE_TRAINERS'), "10.10.10.1,10.10.10.2")
# add 10.10.10.3:6001
elastic.host_port = "10.10.10.3:6001"
elastic.curr_host = "10.10.10.3:6001"
elastic.hosts = ["10.10.10.1:6001", "10.10.10.3:6001"]
os.environ['PADDLE_TRAINER_ID'] = "1"
elastic._update_hosts()
self.assertEqual(os.getenv('PADDLE_TRAINERS'), "10.10.10.1,10.10.10.3")
elastic.host_port = "10.10.10.3:6001"
elastic.curr_host = "10.10.10.3:6001"
elastic.hosts = ["10.10.10.1:6001", "10.10.10.3:6001"]
os.environ['PADDLE_TRAINER_ID'] = "-1"
elastic._update_hosts()
......@@ -216,7 +228,8 @@ class TestElasticManager(unittest.TestCase):
gpus = "0"
nproc_per_node = 1
host = None
host_port = None
curr_host = None
ips = None
scale = None
force = None
backend = 'gloo'
......@@ -231,7 +244,7 @@ class TestElasticManager(unittest.TestCase):
'PADDLE_TRAINER_ENDPOINTS'] = "10.10.10.1:6001,10.10.10.2:6001"
elastic = ElasticManager(args, self.etcd_client)
# add 10.10.10.3:6001
elastic.host_port = "10.10.10.1:6001"
elastic.curr_host = "10.10.10.1:6001"
elastic.hosts = [
"10.10.10.1:6001", "10.10.10.2:6001", "10.10.10.3:6001"
]
......@@ -242,7 +255,7 @@ class TestElasticManager(unittest.TestCase):
os.getenv('PADDLE_TRAINERS'), "10.10.10.1,10.10.10.2,10.10.10.3")
#######################
# elastic, scale down #
# elastic, scale in #
#######################
os.environ[
'PADDLE_TRAINERS'] = "10.10.10.0,10.10.10.1,10.10.10.2,10.10.10.3"
......@@ -250,9 +263,14 @@ class TestElasticManager(unittest.TestCase):
'DISTRIBUTED_TRAINER_ENDPOINTS'] = "10.10.10.0:6000,10.10.10.1:6001,10.10.10.2:6001,10.10.10.3:6001"
os.environ[
'PADDLE_TRAINER_ENDPOINTS'] = "10.10.10.0:6000,10.10.10.1:6001,10.10.10.2:6001,10.10.10.3:6001"
os.environ['POD_IP'] = "10.10.10.1"
os.environ['TRAINER_PORTS_NUM'] = "4"
os.environ['PADDLE_TRAINER_ID'] = "1"
os.environ['PADDLE_PORT'] = "6001"
args = Argument()
elastic = ElasticManager(args, self.etcd_client)
# remove 10.10.10.1:6001
elastic.host_port = "10.10.10.1:6001"
elastic.curr_host = "10.10.10.1:6001"
elastic.hosts = [
"10.10.10.1:6001", "10.10.10.2:6001", "10.10.10.3:6001"
]
......@@ -266,23 +284,28 @@ class TestElasticManager(unittest.TestCase):
"10.10.10.3:6001,10.10.10.1:6001,10.10.10.2:6001")
############
os.environ['PADDLE_TRAINERS'] = "10.10.10.1,10.10.10.1"
os.environ[
'PADDLE_TRAINERS'] = "10.10.10.1,10.10.10.1,10.10.10.1,10.10.10.1"
os.environ[
'DISTRIBUTED_TRAINER_ENDPOINTS'] = "10.10.10.1:6001,10.10.10.1:6002,10.10.10.1:6003,10.10.10.1:6004"
os.environ[
'PADDLE_TRAINER_ENDPOINTS'] = "10.10.10.1:6001,10.10.10.1:6002,10.10.10.1:6003,10.10.10.1:6004"
os.environ['POD_IP'] = "10.10.10.1"
os.environ['TRAINER_PORTS_NUM'] = "4"
os.environ['PADDLE_PORT'] = "6001"
args = Argument()
elastic = ElasticManager(args, self.etcd_client)
# remove 10.10.10.1:6001
elastic.host_port = "10.10.10.1:6001"
elastic.curr_host = "10.10.10.1:6001"
os.environ['PADDLE_TRAINER_ID'] = "-1"
elastic.hosts = ["10.10.10.1:6001", "10.10.10.1:6001"]
elastic.hosts = ["10.10.10.1:6001", "10.10.10.1:6003"]
elastic._update_hosts()
#self.assertEqual(elastic.all_host_endpoints,
# ["10.10.10.1:6001", "10.10.10.1:6001"])
self.assertEqual(os.getenv('PADDLE_TRAINERS'), "10.10.10.1,10.10.10.1")
self.assertEqual(
os.getenv('DISTRIBUTED_TRAINER_ENDPOINTS'),
"10.10.10.1:6001,10.10.10.1:6001")
"10.10.10.1:6001,10.10.10.1:6003")
def test_exit(self):
class Argument:
......@@ -292,7 +315,8 @@ class TestElasticManager(unittest.TestCase):
gpus = "0"
nproc_per_node = 1
host = None
host_port = None
curr_host = None
ips = None
scale = None
force = None
backend = 'gloo'
......@@ -301,6 +325,28 @@ class TestElasticManager(unittest.TestCase):
elastic = ElasticManager(args, self.etcd_client)
elastic.exit()
def test_pre_hook(self):
class Argument:
elastic_server = "127.0.0.1:2379"
job_id = "test_job_id_123"
np = "2"
gpus = "0"
nproc_per_node = 1
host = None
curr_host = None
ips = None
scale = None
force = None
backend = 'gloo'
elastic_pre_hook = None
args = Argument()
elastic = ElasticManager(args, self.etcd_client)
elastic.pre_hook()
args.elastic_pre_hook = "hostname"
elastic.pre_hook()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册