未验证 提交 6d34d266 编写于 作者: X xiayanming 提交者: GitHub

Optimize fleet elastic scale in/out (#37177)

* fleet support elastic train

* fleet support elastic train

* support elastic

* add unittest

* fix unittest bug

* fix unittest bug

* fix unittest bug

* fix unittest coverage

* fix unittest coverage

* fix unittest coverage

* fix unittest coverage

* fix unittest coverage

* fix elastic bug

* fix ci fail

* fix ci fail

* fix elastic bug

* fix elastic bug

* fix joint debugging bug

* fix joint debugging bug

* fix windows ci failed

* fix windows ci failed

* Optimize fleet elastic scale in/out

* elastic support pre hook

* add prehook unittest
上级 36a95654
...@@ -56,6 +56,9 @@ def launch_elastic(args, distribute_mode): ...@@ -56,6 +56,9 @@ def launch_elastic(args, distribute_mode):
# wait for all nodes ready to run # wait for all nodes ready to run
elastic.wait() elastic.wait()
# execute pre hook action, eg: run shell
elastic.pre_hook()
# run self with specified launcher # run self with specified launcher
elastic.run(CollectiveLauncher) elastic.run(CollectiveLauncher)
......
...@@ -22,6 +22,7 @@ import signal ...@@ -22,6 +22,7 @@ import signal
import random import random
import threading import threading
import traceback import traceback
import subprocess
from paddle.distributed.fleet import cloud_utils from paddle.distributed.fleet import cloud_utils
from paddle.distributed.fleet import launch_utils from paddle.distributed.fleet import launch_utils
...@@ -133,11 +134,7 @@ class ElasticManager(object): ...@@ -133,11 +134,7 @@ class ElasticManager(object):
scale = args.scale or int(os.getenv('PADDLE_ELASTIC_SCALE', 0)) scale = args.scale or int(os.getenv('PADDLE_ELASTIC_SCALE', 0))
force = args.force or os.getenv('PADDLE_ELASTIC_FORCE') force = args.force or os.getenv('PADDLE_ELASTIC_FORCE')
start_port = 6170 self.host = host if host else self._get_host()
if os.environ.get('FLAGS_START_PORT') is not None:
start_port = int(os.environ.get('FLAGS_START_PORT'))
if cloud_utils.use_paddlecloud():
start_port = int(os.getenv("PADDLE_PORT", ""))
(self.device_mode, (self.device_mode,
self.devices_per_proc) = launch_utils.get_device_proc_info(args) self.devices_per_proc) = launch_utils.get_device_proc_info(args)
...@@ -145,16 +142,30 @@ class ElasticManager(object): ...@@ -145,16 +142,30 @@ class ElasticManager(object):
self.elastic_timeout = int( self.elastic_timeout = int(
os.getenv('PADDLE_ELASTIC_TIMEOUT', ELASTIC_TIMEOUT)) os.getenv('PADDLE_ELASTIC_TIMEOUT', ELASTIC_TIMEOUT))
elastic_ttl = int(os.getenv('PADDLE_ELASTIC_TTL', ELASTIC_TTL)) elastic_ttl = int(os.getenv('PADDLE_ELASTIC_TTL', ELASTIC_TTL))
self.dist_endpoints = os.getenv('DISTRIBUTED_TRAINER_ENDPOINTS', '')
self.start_port = None
if cloud_utils.use_paddlecloud():
self.trainers = os.getenv('PADDLE_TRAINERS', '') self.trainers = os.getenv('PADDLE_TRAINERS', '')
self.all_host_endpoints = os.getenv('PADDLE_TRAINER_ENDPOINTS', self.np = len(self.trainers.split(","))
'').split(",") self.start_port = int(os.getenv("PADDLE_PORT", "6170"))
self.np = len(self.all_host_endpoints) self.dist_endpoints = os.getenv('DISTRIBUTED_TRAINER_ENDPOINTS', '')
logger.info(f'start job with np={self.np}') trainer_endpoints = os.getenv('PADDLE_TRAINER_ENDPOINTS', '')
self.trainer_endpoints_list = trainer_endpoints.split(",")
else:
self.trainers = args.ips or os.getenv('PADDLE_TRAINERS', '')
node_ips = self.trainers.split(",")
self.np = len(node_ips)
self.start_port = int(os.getenv("FLAGS_START_PORT", "6170"))
self.dist_endpoints = self._host_to_endpoints(
node_ips, self.devices_per_proc, self.start_port)
self.trainer_endpoints_list = [
"%s:%d" % (ip, self.start_port) for ip in node_ips
]
#[ "%s:%d" % (ip, start_port) for ip in self.trainers.split(",")] self.curr_host = "%s:%d" % (self.host, self.start_port)
logger.info(f'start job with np={self.np}')
logger.info( logger.info(
f"trainers={self.trainers}, all_host_endpoints={self.all_host_endpoints}" f"trainers={self.trainers}, trainer_endpoints_list={self.trainer_endpoints_list}"
) )
# auto correct the value of elastic_level # auto correct the value of elastic_level
...@@ -198,8 +209,6 @@ class ElasticManager(object): ...@@ -198,8 +209,6 @@ class ElasticManager(object):
self.enable = True self.enable = True
self.etcd = etcd_client self.etcd = etcd_client
self.host = host if host else self._get_host()
self.host_port = "%s:%d" % (self.host, start_port)
# etcd data # etcd data
self.prefix = "/paddle/" + name self.prefix = "/paddle/" + name
...@@ -224,7 +233,7 @@ class ElasticManager(object): ...@@ -224,7 +233,7 @@ class ElasticManager(object):
for i in self.etcd.get_prefix(self.node_prefix) for i in self.etcd.get_prefix(self.node_prefix)
] ]
logger.info( logger.info(
f"host_call_back curr_host={self.host_port}, hosts:{self.hosts}") f"host_call_back curr_host={self.curr_host}, hosts:{self.hosts}")
self.need_sync = True self.need_sync = True
self.elastic_startup_time = None self.elastic_startup_time = None
...@@ -243,13 +252,13 @@ class ElasticManager(object): ...@@ -243,13 +252,13 @@ class ElasticManager(object):
for i in self.etcd.get_prefix(self.node_prefix) for i in self.etcd.get_prefix(self.node_prefix)
] ]
logger.info( logger.info(
f"[lease_heartbeat] curr_host={self.host_port}, hosts={hosts}" f"[lease_heartbeat] curr_host={self.curr_host}, hosts={hosts}"
) )
if self.host_port not in hosts: if self.curr_host not in hosts:
logger.info( logger.info(
f"[lease_heartbeat] register host={self.host_port}") f"[lease_heartbeat] register host={self.curr_host}")
self.etcd.put(self.host_path, self.etcd.put(self.host_path,
six.b(self.host_port), six.b(self.curr_host),
lease=host_lease) lease=host_lease)
except Exception as e: except Exception as e:
logger.error("[lease_heartbeat] internal error:{} {}". logger.error("[lease_heartbeat] internal error:{} {}".
...@@ -261,7 +270,7 @@ class ElasticManager(object): ...@@ -261,7 +270,7 @@ class ElasticManager(object):
name='lease_heartbeat', target=lease_heartbeat, daemon=True) name='lease_heartbeat', target=lease_heartbeat, daemon=True)
keepalived_thread.start() keepalived_thread.start()
self.etcd.put(self.host_path, six.b(self.host_port), lease=host_lease) self.etcd.put(self.host_path, six.b(self.curr_host), lease=host_lease)
# endpoints handle DISTRIBUTED_TRAINER_ENDPOINTS and PADDLE_TRAINERS # endpoints handle DISTRIBUTED_TRAINER_ENDPOINTS and PADDLE_TRAINERS
self.etcd.put(self.endpoints_path, self.etcd.put(self.endpoints_path,
...@@ -282,6 +291,26 @@ class ElasticManager(object): ...@@ -282,6 +291,26 @@ class ElasticManager(object):
self.watches = [host_watch, endpoints_watch] self.watches = [host_watch, endpoints_watch]
self.launcher = None self.launcher = None
def _host_to_endpoints(self,
ip_port_list: list,
devices_per_proc: list,
start_port: int=6170) -> str:
endpoint_list = []
for ip_port in ip_port_list:
endpoints = ip_port.split(":")
if len(endpoints) == 2:
ip = endpoints[0]
port = int(endpoints[1])
else:
ip = endpoints
port = start_port
ports = [x for x in range(port, port + len(devices_per_proc))]
endpoint_list.extend(["%s:%d" % (ip, port) for port in ports])
dist_endpoints = ','.join(endpoint_list)
return dist_endpoints
def exit(self, completed=False): def exit(self, completed=False):
logger.info('manager exist completed {}'.format(completed)) logger.info('manager exist completed {}'.format(completed))
...@@ -302,6 +331,22 @@ class ElasticManager(object): ...@@ -302,6 +331,22 @@ class ElasticManager(object):
if len(hosts) == 0: if len(hosts) == 0:
self.etcd.delete_prefix(self.prefix) self.etcd.delete_prefix(self.prefix)
def pre_hook(self):
    """Execute the user-supplied shell command (``--elastic_pre_hook``)
    once before launching training.

    The hook runs through the shell with the current environment; it is
    skipped when no hook command is configured. Failures are logged but
    do not abort the launch.
    """
    if not self.args.elastic_pre_hook:
        logger.info("skip pre_hook")
        return
    current_env = copy.copy(os.environ.copy())
    out, err = subprocess.Popen(
        self.args.elastic_pre_hook,
        env=current_env,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        shell=True).communicate()
    if err:
        # logger.warn is a deprecated alias of logger.warning; also
        # surface the captured stderr so the failure cause is visible.
        logger.warning("pre_hook exec failed: %s",
                       err.decode('utf-8', errors='replace').strip())
    else:
        logger.info(f"pre_hook exec result: {out.decode('utf-8').strip()}")
def _parse_np(self, np: str): def _parse_np(self, np: str):
""" """
np format is "MIN" or "MIN:MAX" np format is "MIN" or "MIN:MAX"
...@@ -354,7 +399,6 @@ class ElasticManager(object): ...@@ -354,7 +399,6 @@ class ElasticManager(object):
return False return False
if self.elastic_level == ElasticLevel.ELASTIC: if self.elastic_level == ElasticLevel.ELASTIC:
# FIXME(xym) add freeze status
hosts_num = len(self.hosts) hosts_num = len(self.hosts)
if hosts_num == self.np: if hosts_num == self.np:
return True return True
...@@ -384,64 +428,57 @@ class ElasticManager(object): ...@@ -384,64 +428,57 @@ class ElasticManager(object):
self.etcd.put(self.endpoints_path, self.etcd.put(self.endpoints_path,
six.b('{}|{}'.format(endpoints, hosts))) six.b('{}|{}'.format(endpoints, hosts)))
def _update_hosts(self): def _update_fault_tolrance(self):
assert len(self.hosts) != 0, 'hosts empty'
rank = int(os.getenv('PADDLE_TRAINER_ID', -1)) rank = int(os.getenv('PADDLE_TRAINER_ID', -1))
if self.elastic_level == ElasticLevel.FAULT_TOLERANCE: if self.curr_host in self.dist_endpoints:
if self.host_port in self.dist_endpoints: os.environ['DISTRIBUTED_TRAINER_ENDPOINTS'] = self.dist_endpoints
os.environ[
'DISTRIBUTED_TRAINER_ENDPOINTS'] = self.dist_endpoints
os.environ['PADDLE_TRAINERS'] = self.trainers os.environ['PADDLE_TRAINERS'] = self.trainers
logger.info("update env DISTRIBUTED_TRAINER_ENDPOINTS {} ". logger.info("update env DISTRIBUTED_TRAINER_ENDPOINTS {} ".format(
format(self.dist_endpoints)) self.dist_endpoints))
logger.info("update env PADDLE_TRAINERS {} ".format( logger.info("update env PADDLE_TRAINERS {} ".format(self.trainers))
self.trainers))
return return
# fault tolerance # fault tolerance
idx = self.hosts.index(self.host_port) idx = self.hosts.index(self.curr_host)
# swap if self.host not in the right position # swap if self.host not in the right position
if rank >= 0: if rank >= 0:
self.hosts[idx] = self.hosts[rank] self.hosts[idx] = self.hosts[rank]
self.hosts[rank] = self.host_port self.hosts[rank] = self.curr_host
else: else:
os.environ['PADDLE_TRAINER_ID'] = '{}'.format(idx) os.environ['PADDLE_TRAINER_ID'] = '{}'.format(idx)
hosts = ','.join( hosts = ','.join([host_port.split(":")[0] for host_port in self.hosts])
[host_port.split(":")[0] for host_port in self.hosts])
self.args.ips = hosts self.args.ips = hosts
os.environ['PADDLE_TRAINERS'] = hosts os.environ['PADDLE_TRAINERS'] = hosts
else:
# elastic, scale up/down def _update_elastic_scale_out(self):
endpoints = copy.deepcopy(self.all_host_endpoints) host_endpoints = copy.deepcopy(self.trainer_endpoints_list)
if len(self.hosts) > self.np:
# scale up
logger.info( logger.info(
f"elastic scale up, from {self.np} to {len(self.hosts)}, hosts={self.hosts}, endpoints={endpoints}" f"elastic scale out, from {len(self.hosts)} to {self.np}, hosts={self.hosts}, host_endpoints={host_endpoints}"
) )
for curr_host_port in self.hosts: for curr_host_port in self.hosts:
if curr_host_port not in endpoints: if curr_host_port not in host_endpoints:
endpoints.append(curr_host_port) host_endpoints.append(curr_host_port)
os.environ['PADDLE_TRAINER_ID'] = '{}'.format( os.environ['PADDLE_TRAINER_ID'] = '{}'.format(
endpoints.index(self.host_port)) host_endpoints.index(self.curr_host))
hosts = ','.join( hosts = ','.join(
[host_port.split(":")[0] for host_port in endpoints]) [host_port.split(":")[0] for host_port in host_endpoints])
self.args.ips = hosts self.args.ips = hosts
os.environ['PADDLE_TRAINERS'] = hosts os.environ['PADDLE_TRAINERS'] = hosts
self.np = len(endpoints) self.np = len(host_endpoints)
os.environ['PADDLE_TRAINER_ENDPOINTS'] = ','.join(endpoints) os.environ['PADDLE_TRAINER_ENDPOINTS'] = ','.join(host_endpoints)
os.environ[ os.environ['DISTRIBUTED_TRAINER_ENDPOINTS'] = self.dist_endpoints
'DISTRIBUTED_TRAINER_ENDPOINTS'] = self.dist_endpoints self.trainer_endpoints_list = host_endpoints
self.all_host_endpoints = endpoints
else: def _update_elastic_scale_in(self):
# scale down host_endpoints = copy.deepcopy(self.trainer_endpoints_list)
logger.info( logger.info(
f"elastic scale down, from {len(self.hosts)} to {self.np}, hosts={self.hosts}, endpoints={endpoints}" f"elastic scale in, from {self.np} to {len(self.hosts)}, hosts={self.hosts}, host_endpoints={host_endpoints}"
) )
# If the shrink node is from the first of the rank list, you need to minimize the movement of the rank # If scale in node from the first of the rank list, you need to minimize the movement of the rank
# eg: # eg:
# the source trainers is:10.10.10.0,10.10.10.1,10.10.10.2,10.10.10.3 # the source trainers is:10.10.10.0,10.10.10.1,10.10.10.2,10.10.10.3
# 10.10.10.0 is removed # 10.10.10.0 is removed
...@@ -450,9 +487,8 @@ class ElasticManager(object): ...@@ -450,9 +487,8 @@ class ElasticManager(object):
endpoints_dict = dict() endpoints_dict = dict()
unsorted_endpoints = [] unsorted_endpoints = []
for id, host_port in enumerate(self.hosts): for id, host_port in enumerate(self.hosts):
idx = endpoints.index(host_port) idx = host_endpoints.index(host_port)
if idx <= len(self.hosts) - 1 and not endpoints_dict.get( if idx <= len(self.hosts) - 1 and not endpoints_dict.get(idx):
idx):
endpoints_dict[idx] = host_port endpoints_dict[idx] = host_port
else: else:
unsorted_endpoints.append(host_port) unsorted_endpoints.append(host_port)
...@@ -460,45 +496,46 @@ class ElasticManager(object): ...@@ -460,45 +496,46 @@ class ElasticManager(object):
idle_index = 0 idle_index = 0
sorted_endpoints = [] sorted_endpoints = []
for idx in range(len(self.hosts)): for idx in range(len(self.hosts)):
if not endpoints_dict.get(idx) and len( if not endpoints_dict.get(idx) and len(unsorted_endpoints) > 0:
unsorted_endpoints) > 0:
endpoints_dict[idx] = unsorted_endpoints[idle_index] endpoints_dict[idx] = unsorted_endpoints[idle_index]
idle_index += 1 idle_index += 1
sorted_endpoints.append(endpoints_dict.get(idx)) sorted_endpoints.append(endpoints_dict.get(idx))
logger.info( logger.info(f"elastic scale in, sorted_endpoints={sorted_endpoints}")
f"elastic scale down, sorted_endpoints={sorted_endpoints}") self.trainer_endpoints_list = sorted_endpoints
self.all_host_endpoints = sorted_endpoints
endpoint_list = []
ip_list = []
for host_port in sorted_endpoints:
host_port_list = host_port.split(":")
ip = host_port_list[0]
port = int(host_port_list[1])
ip_list.append(ip)
ports = [
x
for x in range(port, port + len(self.devices_per_proc))
]
endpoint_list.extend(
["%s:%d" % (ip, port) for port in ports])
ip_list = [ip_port.split(":")[0] for ip_port in sorted_endpoints]
hosts = ','.join(ip_list) hosts = ','.join(ip_list)
new_endpoints = ','.join(endpoint_list) new_endpoints = self._host_to_endpoints(sorted_endpoints,
self.devices_per_proc)
self.args.ips = hosts self.args.ips = hosts
os.environ['PADDLE_TRAINER_ID'] = '{}'.format( os.environ['PADDLE_TRAINER_ID'] = '{}'.format(
sorted_endpoints.index(self.host_port)) sorted_endpoints.index(self.curr_host))
os.environ['PADDLE_TRAINERS'] = hosts os.environ['PADDLE_TRAINERS'] = hosts
self.np = len(sorted_endpoints) self.np = len(sorted_endpoints)
os.environ['PADDLE_TRAINER_ENDPOINTS'] = ','.join( os.environ['PADDLE_TRAINER_ENDPOINTS'] = ','.join(sorted_endpoints)
sorted_endpoints)
os.environ['DISTRIBUTED_TRAINER_ENDPOINTS'] = new_endpoints os.environ['DISTRIBUTED_TRAINER_ENDPOINTS'] = new_endpoints
self._update_endpoint(new_endpoints, hosts) self._update_endpoint(new_endpoints, hosts)
def _update_hosts(self):
    """Refresh trainer/rank bookkeeping after the host list changed,
    dispatching on the configured elastic level."""
    assert len(self.hosts) != 0, 'hosts empty'
    if self.elastic_level == ElasticLevel.FAULT_TOLERANCE:
        self._update_fault_tolrance()
        return
    # Elastic mode: compare the current world size against the target np.
    current_world = len(self.hosts)
    if current_world == self.np:
        logger.info(f"elastic startup, hosts={self.hosts}")
        self._update_fault_tolrance()
    elif current_world > self.np:
        # more hosts than expected -> scale out
        self._update_elastic_scale_out()
    else:
        # fewer hosts than expected -> scale in
        self._update_elastic_scale_in()
def wait(self): def wait(self):
if not self.enable: if not self.enable:
return return
......
...@@ -218,6 +218,9 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra ...@@ -218,6 +218,9 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
elastic_group = parser.add_argument_group("Elastic Parameters") elastic_group = parser.add_argument_group("Elastic Parameters")
elastic_group.add_argument( elastic_group.add_argument(
"--elastic_server", type=str, help="etcd server host:port") "--elastic_server", type=str, help="etcd server host:port")
elastic_group.add_argument(
"--elastic_pre_hook", type=str, help="elastic pre_hook shell cmd")
elastic_group.add_argument("--job_id", type=str, help="job unique id") elastic_group.add_argument("--job_id", type=str, help="job unique id")
elastic_group.add_argument("--np", type=int, help="job pod/node number") elastic_group.add_argument("--np", type=int, help="job pod/node number")
elastic_group.add_argument("--scale", type=int, default=0, help="scale np") elastic_group.add_argument("--scale", type=int, default=0, help="scale np")
......
...@@ -78,7 +78,8 @@ class TestElasticManager(unittest.TestCase): ...@@ -78,7 +78,8 @@ class TestElasticManager(unittest.TestCase):
gpus = "0" gpus = "0"
nproc_per_node = 1 nproc_per_node = 1
host = None host = None
host_port = None curr_host = None
ips = None
scale = None scale = None
force = None force = None
backend = 'gloo' backend = 'gloo'
...@@ -100,18 +101,25 @@ class TestElasticManager(unittest.TestCase): ...@@ -100,18 +101,25 @@ class TestElasticManager(unittest.TestCase):
gpus = "0" gpus = "0"
nproc_per_node = 1 nproc_per_node = 1
host = None host = None
host_port = None curr_host = None
ips = None
scale = None scale = None
force = None force = None
backend = 'gloo' backend = 'gloo'
args = Argument() args = Argument()
args.ips = "10.10.10.1,10.10.10.2"
elastic = ElasticManager(args, self.etcd_client) elastic = ElasticManager(args, self.etcd_client)
os.environ['FLAGS_START_PORT'] = "6001"
hosts = ["10.10.10.1:6001", "10.10.10.2:6001"] hosts = ["10.10.10.1:6001", "10.10.10.2:6001"]
os.environ[ os.environ[
'PADDLE_TRAINER_ENDPOINTS'] = "10.10.10.1:6001,10.10.10.2:6001" 'PADDLE_TRAINER_ENDPOINTS'] = "10.10.10.1:6001,10.10.10.2:6001"
self.assertEqual(elastic._match(hosts), True) self.assertEqual(elastic._match(hosts), True)
hosts = ["10.10.10.1:6001"] hosts = ["10.10.10.1:6001"]
args.ips = "10.10.10.1"
os.environ['PADDLE_TRAINER_ENDPOINTS'] = "10.10.10.1:6001" os.environ['PADDLE_TRAINER_ENDPOINTS'] = "10.10.10.1:6001"
self.assertEqual(elastic._match(hosts), False) self.assertEqual(elastic._match(hosts), False)
...@@ -123,13 +131,15 @@ class TestElasticManager(unittest.TestCase): ...@@ -123,13 +131,15 @@ class TestElasticManager(unittest.TestCase):
gpus = "0" gpus = "0"
nproc_per_node = 1 nproc_per_node = 1
host = None host = None
host_port = None curr_host = None
ips = None
scale = None scale = None
force = None force = None
backend = 'gloo' backend = 'gloo'
os.environ['PADDLE_ELASTIC_TIMEOUT'] = "60" os.environ['PADDLE_ELASTIC_TIMEOUT'] = "60"
args = Argument() args = Argument()
args.ips = "10.10.10.1,10.10.10.2,10.10.10.3,10.10.10.4"
os.environ['FLAGS_START_PORT'] = "6001" os.environ['FLAGS_START_PORT'] = "6001"
os.environ[ os.environ[
'DISTRIBUTED_TRAINER_ENDPOINTS'] = "10.10.10.1:6001,10.10.10.2:6001,10.10.10.3:6001,10.10.10.4:6001" 'DISTRIBUTED_TRAINER_ENDPOINTS'] = "10.10.10.1:6001,10.10.10.2:6001,10.10.10.3:6001,10.10.10.4:6001"
...@@ -151,6 +161,7 @@ class TestElasticManager(unittest.TestCase): ...@@ -151,6 +161,7 @@ class TestElasticManager(unittest.TestCase):
hosts = ["10.10.10.1:6001"] hosts = ["10.10.10.1:6001"]
self.assertEqual(elastic._match(hosts), False) self.assertEqual(elastic._match(hosts), False)
args.ips = "10.10.10.1,10.10.10.2"
os.environ[ os.environ[
'DISTRIBUTED_TRAINER_ENDPOINTS'] = "10.10.10.1:6001,10.10.10.2:6001" 'DISTRIBUTED_TRAINER_ENDPOINTS'] = "10.10.10.1:6001,10.10.10.2:6001"
os.environ[ os.environ[
...@@ -171,7 +182,8 @@ class TestElasticManager(unittest.TestCase): ...@@ -171,7 +182,8 @@ class TestElasticManager(unittest.TestCase):
gpus = "0" gpus = "0"
nproc_per_node = 1 nproc_per_node = 1
host = None host = None
host_port = None curr_host = None
ips = None
scale = None scale = None
force = None force = None
backend = 'gloo' backend = 'gloo'
...@@ -187,19 +199,19 @@ class TestElasticManager(unittest.TestCase): ...@@ -187,19 +199,19 @@ class TestElasticManager(unittest.TestCase):
elastic = ElasticManager(args, self.etcd_client) elastic = ElasticManager(args, self.etcd_client)
# add 10.10.10.3:6001 # add 10.10.10.3:6001
os.environ['PADDLE_TRAINER_ID'] = "0" os.environ['PADDLE_TRAINER_ID'] = "0"
elastic.host_port = "10.10.10.1:6001" elastic.curr_host = "10.10.10.1:6001"
elastic.hosts = ["10.10.10.1:6001", "10.10.10.2:6001"] elastic.hosts = ["10.10.10.1:6001", "10.10.10.2:6001"]
elastic._update_hosts() elastic._update_hosts()
self.assertEqual(os.getenv('PADDLE_TRAINERS'), "10.10.10.1,10.10.10.2") self.assertEqual(os.getenv('PADDLE_TRAINERS'), "10.10.10.1,10.10.10.2")
# add 10.10.10.3:6001 # add 10.10.10.3:6001
elastic.host_port = "10.10.10.3:6001" elastic.curr_host = "10.10.10.3:6001"
elastic.hosts = ["10.10.10.1:6001", "10.10.10.3:6001"] elastic.hosts = ["10.10.10.1:6001", "10.10.10.3:6001"]
os.environ['PADDLE_TRAINER_ID'] = "1" os.environ['PADDLE_TRAINER_ID'] = "1"
elastic._update_hosts() elastic._update_hosts()
self.assertEqual(os.getenv('PADDLE_TRAINERS'), "10.10.10.1,10.10.10.3") self.assertEqual(os.getenv('PADDLE_TRAINERS'), "10.10.10.1,10.10.10.3")
elastic.host_port = "10.10.10.3:6001" elastic.curr_host = "10.10.10.3:6001"
elastic.hosts = ["10.10.10.1:6001", "10.10.10.3:6001"] elastic.hosts = ["10.10.10.1:6001", "10.10.10.3:6001"]
os.environ['PADDLE_TRAINER_ID'] = "-1" os.environ['PADDLE_TRAINER_ID'] = "-1"
elastic._update_hosts() elastic._update_hosts()
...@@ -216,7 +228,8 @@ class TestElasticManager(unittest.TestCase): ...@@ -216,7 +228,8 @@ class TestElasticManager(unittest.TestCase):
gpus = "0" gpus = "0"
nproc_per_node = 1 nproc_per_node = 1
host = None host = None
host_port = None curr_host = None
ips = None
scale = None scale = None
force = None force = None
backend = 'gloo' backend = 'gloo'
...@@ -231,7 +244,7 @@ class TestElasticManager(unittest.TestCase): ...@@ -231,7 +244,7 @@ class TestElasticManager(unittest.TestCase):
'PADDLE_TRAINER_ENDPOINTS'] = "10.10.10.1:6001,10.10.10.2:6001" 'PADDLE_TRAINER_ENDPOINTS'] = "10.10.10.1:6001,10.10.10.2:6001"
elastic = ElasticManager(args, self.etcd_client) elastic = ElasticManager(args, self.etcd_client)
# add 10.10.10.3:6001 # add 10.10.10.3:6001
elastic.host_port = "10.10.10.1:6001" elastic.curr_host = "10.10.10.1:6001"
elastic.hosts = [ elastic.hosts = [
"10.10.10.1:6001", "10.10.10.2:6001", "10.10.10.3:6001" "10.10.10.1:6001", "10.10.10.2:6001", "10.10.10.3:6001"
] ]
...@@ -242,7 +255,7 @@ class TestElasticManager(unittest.TestCase): ...@@ -242,7 +255,7 @@ class TestElasticManager(unittest.TestCase):
os.getenv('PADDLE_TRAINERS'), "10.10.10.1,10.10.10.2,10.10.10.3") os.getenv('PADDLE_TRAINERS'), "10.10.10.1,10.10.10.2,10.10.10.3")
####################### #######################
# elastic, scale down # # elastic, scale in #
####################### #######################
os.environ[ os.environ[
'PADDLE_TRAINERS'] = "10.10.10.0,10.10.10.1,10.10.10.2,10.10.10.3" 'PADDLE_TRAINERS'] = "10.10.10.0,10.10.10.1,10.10.10.2,10.10.10.3"
...@@ -250,9 +263,14 @@ class TestElasticManager(unittest.TestCase): ...@@ -250,9 +263,14 @@ class TestElasticManager(unittest.TestCase):
'DISTRIBUTED_TRAINER_ENDPOINTS'] = "10.10.10.0:6000,10.10.10.1:6001,10.10.10.2:6001,10.10.10.3:6001" 'DISTRIBUTED_TRAINER_ENDPOINTS'] = "10.10.10.0:6000,10.10.10.1:6001,10.10.10.2:6001,10.10.10.3:6001"
os.environ[ os.environ[
'PADDLE_TRAINER_ENDPOINTS'] = "10.10.10.0:6000,10.10.10.1:6001,10.10.10.2:6001,10.10.10.3:6001" 'PADDLE_TRAINER_ENDPOINTS'] = "10.10.10.0:6000,10.10.10.1:6001,10.10.10.2:6001,10.10.10.3:6001"
os.environ['POD_IP'] = "10.10.10.1"
os.environ['TRAINER_PORTS_NUM'] = "4"
os.environ['PADDLE_TRAINER_ID'] = "1"
os.environ['PADDLE_PORT'] = "6001"
args = Argument()
elastic = ElasticManager(args, self.etcd_client) elastic = ElasticManager(args, self.etcd_client)
# remove 10.10.10.1:6001 # remove 10.10.10.1:6001
elastic.host_port = "10.10.10.1:6001" elastic.curr_host = "10.10.10.1:6001"
elastic.hosts = [ elastic.hosts = [
"10.10.10.1:6001", "10.10.10.2:6001", "10.10.10.3:6001" "10.10.10.1:6001", "10.10.10.2:6001", "10.10.10.3:6001"
] ]
...@@ -266,23 +284,28 @@ class TestElasticManager(unittest.TestCase): ...@@ -266,23 +284,28 @@ class TestElasticManager(unittest.TestCase):
"10.10.10.3:6001,10.10.10.1:6001,10.10.10.2:6001") "10.10.10.3:6001,10.10.10.1:6001,10.10.10.2:6001")
############ ############
os.environ['PADDLE_TRAINERS'] = "10.10.10.1,10.10.10.1" os.environ[
'PADDLE_TRAINERS'] = "10.10.10.1,10.10.10.1,10.10.10.1,10.10.10.1"
os.environ[ os.environ[
'DISTRIBUTED_TRAINER_ENDPOINTS'] = "10.10.10.1:6001,10.10.10.1:6002,10.10.10.1:6003,10.10.10.1:6004" 'DISTRIBUTED_TRAINER_ENDPOINTS'] = "10.10.10.1:6001,10.10.10.1:6002,10.10.10.1:6003,10.10.10.1:6004"
os.environ[ os.environ[
'PADDLE_TRAINER_ENDPOINTS'] = "10.10.10.1:6001,10.10.10.1:6002,10.10.10.1:6003,10.10.10.1:6004" 'PADDLE_TRAINER_ENDPOINTS'] = "10.10.10.1:6001,10.10.10.1:6002,10.10.10.1:6003,10.10.10.1:6004"
os.environ['POD_IP'] = "10.10.10.1"
os.environ['TRAINER_PORTS_NUM'] = "4"
os.environ['PADDLE_PORT'] = "6001"
args = Argument()
elastic = ElasticManager(args, self.etcd_client) elastic = ElasticManager(args, self.etcd_client)
# remove 10.10.10.1:6001 # remove 10.10.10.1:6001
elastic.host_port = "10.10.10.1:6001" elastic.curr_host = "10.10.10.1:6001"
os.environ['PADDLE_TRAINER_ID'] = "-1" os.environ['PADDLE_TRAINER_ID'] = "-1"
elastic.hosts = ["10.10.10.1:6001", "10.10.10.1:6001"] elastic.hosts = ["10.10.10.1:6001", "10.10.10.1:6003"]
elastic._update_hosts() elastic._update_hosts()
#self.assertEqual(elastic.all_host_endpoints, #self.assertEqual(elastic.all_host_endpoints,
# ["10.10.10.1:6001", "10.10.10.1:6001"]) # ["10.10.10.1:6001", "10.10.10.1:6001"])
self.assertEqual(os.getenv('PADDLE_TRAINERS'), "10.10.10.1,10.10.10.1") self.assertEqual(os.getenv('PADDLE_TRAINERS'), "10.10.10.1,10.10.10.1")
self.assertEqual( self.assertEqual(
os.getenv('DISTRIBUTED_TRAINER_ENDPOINTS'), os.getenv('DISTRIBUTED_TRAINER_ENDPOINTS'),
"10.10.10.1:6001,10.10.10.1:6001") "10.10.10.1:6001,10.10.10.1:6003")
def test_exit(self): def test_exit(self):
class Argument: class Argument:
...@@ -292,7 +315,8 @@ class TestElasticManager(unittest.TestCase): ...@@ -292,7 +315,8 @@ class TestElasticManager(unittest.TestCase):
gpus = "0" gpus = "0"
nproc_per_node = 1 nproc_per_node = 1
host = None host = None
host_port = None curr_host = None
ips = None
scale = None scale = None
force = None force = None
backend = 'gloo' backend = 'gloo'
...@@ -301,6 +325,28 @@ class TestElasticManager(unittest.TestCase): ...@@ -301,6 +325,28 @@ class TestElasticManager(unittest.TestCase):
elastic = ElasticManager(args, self.etcd_client) elastic = ElasticManager(args, self.etcd_client)
elastic.exit() elastic.exit()
def test_pre_hook(self):
    """Exercise ElasticManager.pre_hook both without and with a hook cmd."""

    class Argument:
        # Attribute names mirror the launcher CLI and are read by
        # ElasticManager — they must not be renamed.
        elastic_server = "127.0.0.1:2379"
        job_id = "test_job_id_123"
        np = "2"
        gpus = "0"
        nproc_per_node = 1
        host = None
        curr_host = None
        ips = None
        scale = None
        force = None
        backend = 'gloo'
        elastic_pre_hook = None

    arguments = Argument()
    manager = ElasticManager(arguments, self.etcd_client)
    # No hook configured: pre_hook must be a silent no-op.
    manager.pre_hook()
    # With a shell command configured, the hook is executed.
    arguments.elastic_pre_hook = "hostname"
    manager.pre_hook()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册