diff --git a/python/paddle/distributed/fleet/elastic.py b/python/paddle/distributed/fleet/elastic.py
new file mode 100644
index 0000000000000000000000000000000000000000..b919c4737576d5947265004854a8d1f66450a8fc
--- /dev/null
+++ b/python/paddle/distributed/fleet/elastic.py
@@ -0,0 +1,312 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+import socket
+import os
+import six
+import logging
+import signal
+
+logging.basicConfig(level=os.environ.get('LOGLEVEL', 'INFO').upper())
+logger = logging.getLogger("ELASTIC")
+
+ELASTIC_EXIT_CODE = 101
+
+
+class ElasticStatus:
+    COMPLETED = "completed"
+    ERROR = "error"
+    HOLD = "hold"
+    RESTART = "restart"
+    EXIT = "exit"
+
+
+class LauncherInterface(object):
+    def __init__(self, args):
+        self.args = args
+        self.procs = []
+
+    def _terminate_procs(self):
+        for p in self.procs:
+            if p.proc.poll() is None:
+                p.proc.terminate()
+                if p.log_fn:
+                    p.log_fn.close()
+                logger.info("terminate process id:{}".format(p.proc.pid))
+
+        for step in range(0, 50):
+            alive = False
+            for p in self.procs:
+                if p.proc.poll() is None:  # not terminated yet
+                    os.kill(p.proc.pid, signal.SIGKILL)
+                    alive = True
+
+            if not alive:
+                logger.info("terminate all the procs")
+                return True
+
+            time.sleep(1)
+        return False
+
+    def _check_procs(self):
+        alive = False
+        result = None
+        for p in self.procs:
+            ret = p.proc.poll()
+            if ret is None:
+                alive = True
+            elif ret != 0:
+                logger.error("ERROR rank {} error with code {}".format(p.rank,
+                                                                       ret))
+                result = ret
+        if not alive and result is None:
+            return 0
+        else:
+            return result
+
+    def launch(self):
+        raise NotImplementedError
+
+    def stop(self):
+        raise NotImplementedError
+
+    def watch(self):
+        raise NotImplementedError
+
+
+class ElasticManager(object):
+    def __init__(self, args):
+
+        self.args = args
+        server = args.elastic_server or os.getenv('PADDLE_ELASTIC_SERVER')
+        name = args.job_id or os.getenv('PADDLE_ELASTIC_JOB_ID')
+        np = args.np or int(os.getenv('PADDLE_ELASTIC_NP', 0))
+        host = args.host or os.getenv('POD_IP')
+        scale = args.scale or int(os.getenv('PADDLE_ELASTIC_SCALE', 0))
+        force = args.force or os.getenv('PADDLE_ELASTIC_FORCE')
+
+        self.endpoints = os.getenv('DISTRIBUTED_TRAINER_ENDPOINTS', '')
+        self.trainers = os.getenv('PADDLE_TRAINERS', '')
+
+        self.elastic_level = int(
+            os.getenv('PADDLE_ELASTIC_FAULT_TOLERANC_LEVEL', 1))
+
+        #elastic_timeout = os.getenv('PADDLE_ELASTIC_TIMEOUT',1)
+
+        logger.debug('init with server {} host {}'.format(server, host))
+
+        self.hosts = []
+        self.stopped = False
+
+        self.sigint = 0
+
+        if not server or ':' not in server or not name or not np:
+            logger.info(
+                'Elastic is not enabled with server {} name {} and np {}'.
+                format(server, name, np))
+            self.enable = False
+            return
+        else:
+            self.enable = True
+
+        import etcd3
+
+        srv, port = server.split(':')
+        self.etcd = etcd3.client(host=srv, port=port)
+        self.host = host if host else self._get_host()
+
+        # etcd data
+        self.prefix = "/paddle/" + name
+        self.node_prefix = self.prefix + '/nodes/'
+        self.np_path = self.prefix + '/np'
+        self.endpoints_path = self.prefix + '/endpoints'
+        self.host_path = '{}{}'.format(self.node_prefix, time.time())
+
+        self.np = np + scale
+        '''
+        0 group mode, be aware of healthy status of other workers
+        1 decouple mode, check own status only
+        '''
+        self.etcd.put(self.prefix, b'0')
+
+        # host
+        # register self host to etcd
+        # register watch to re-register host after it has been deleted
+        self.etcd.delete_prefix(self.node_prefix)
+
+        def host_call_back(event):
+            if self.etcd.get(self.host_path)[0] == None:
+                # ensure unmatch trigger
+                logger.info('register host again {}'.format(self.host))
+                time.sleep(5)
+
+                self.etcd.put(self.host_path, six.b(self.host))
+
+        host_watch = self.etcd.add_watch_callback(self.host_path,
+                                                  host_call_back)
+        self.etcd.put(self.host_path, six.b(self.host))
+
+        # np describes the exact number of nodes to run the job
+        inp = int(self.etcd.get(self.np_path)[0] or 0)
+        if scale == 0 and not force:
+            assert inp == np or inp == 0, "np {} is not consistent with np in etcd {}".format(
+                np, inp)
+        else:
+            assert inp == np or inp == self.np, "np {} scale to {} by {} is not allowed".format(
+                inp, self.np, scale)
+
+        self.etcd.put(self.np_path, six.b("%d" % (self.np)))
+
+        def np_call_back(event):
+            gnp = int(self.etcd.get(self.np_path)[0])
+            if gnp != self.np:
+                logger.info("scale np {} to {} ".format(self.np, gnp))
+                self.np = gnp
+
+        np_watch = self.etcd.add_watch_callback(self.np_path, np_call_back)
+
+        # endpoints handle DISTRIBUTED_TRAINER_ENDPOINTS and PADDLE_TRAINERS
+        self.etcd.put(self.endpoints_path,
+                      six.b('{}|{}'.format(self.endpoints, self.trainers)))
+
+        def endpoints_call_back(event):
+            if not self.endpoints:
+                return
+            edps = six.ensure_str(self.etcd.get(self.endpoints_path)[0] or '')
+            self.endpoints, self.trainers = edps.split('|')
+            logger.info("set DISTRIBUTED_TRAINER_ENDPOINTS {} ".format(
+                self.endpoints))
+            logger.info("set PADDLE_TRAINERS {} ".format(self.trainers))
+
+        endpoints_watch = self.etcd.add_watch_callback(self.endpoints_path,
+                                                       endpoints_call_back)
+
+        self.watches = [host_watch, np_watch, endpoints_watch]
+
+    def exit(self, completed=False):
+        logger.info('manager exit completed {}'.format(completed))
+
+        if not self.enable:
+            return
+
+        if completed:
+            self.etcd.put(self.prefix, b'1')
+
+        for watch in self.watches:
+            self.etcd.cancel_watch(watch)
+        self.etcd.delete(self.host_path)
+
+        hosts = [i for i in self.etcd.get_prefix(self.node_prefix)]
+        if len(hosts) == 0:
+            self.etcd.delete_prefix(self.prefix)
+
+    def _get_host(self):
+        try:
+            return socket.gethostbyname(socket.getfqdn(socket.gethostname()))
+        except:
+            return '127.0.0.1'
+
+    def _completed(self):
+        if not self.enable:
+            return True
+
+        return int(self.etcd.get(self.prefix)[0]) == 1
+
+    def _match(self):
+        self.hosts = [
+            six.ensure_str(i[0]) for i in self.etcd.get_prefix(self.node_prefix)
+        ]
+        if len(self.hosts) == self.np:
+            return True
+        else:
+            return False
+
+    def _update_hosts(self):
+        assert len(self.hosts) != 0, 'hosts empty'
+
+        if self.host in self.endpoints:
+            os.environ['DISTRIBUTED_TRAINER_ENDPOINTS'] = self.endpoints
+            os.environ['PADDLE_TRAINERS'] = self.trainers
+            logger.info("update env DISTRIBUTED_TRAINER_ENDPOINTS {} ".format(
+                self.endpoints))
+            logger.info("update env PADDLE_TRAINERS {} ".format(self.trainers))
+            return
+
+        rank = int(os.getenv('PADDLE_TRAINER_ID', -1))
+        idx = self.hosts.index(self.host)
+
+        # swap if self.host is not in the right position
+        if rank >= 0:
+            self.hosts[idx] = self.hosts[rank]
+            self.hosts[rank] = self.host
+        else:
+            os.environ['PADDLE_TRAINER_ID'] = '{}'.format(idx)
+
+        hosts = ','.join(self.hosts)
+        self.args.ips = hosts
+        os.environ['PADDLE_TRAINERS'] = hosts
+
+    def wait(self):
+        if not self.enable:
+            return
+
+        while not self.stopped:
+            if self._match():
+                logger.info('ready with hosts {}'.format(self.hosts))
+                self._update_hosts()
+                return
+            logger.info('not ready for np {} with hosts {}'.format(self.np,
+                                                                   self.hosts))
+            time.sleep(3)
+        return
+
+    def run(self, launcher):
+        if self.stopped:
+            return
+
+        self.launcher = launcher(self.args)
+        self.launcher.launch()
+
+    def watch(self):
+
+        while not self.stopped:
+            ret = self.launcher.watch()
+
+            if ret is not None:  # self terminated
+                logger.info('job exit with code {}'.format(ret))
+                # the job is completed when ret == 0, otherwise it is an error
+                completed = True if ret == 0 else False
+                self.launcher.stop()
+                self.exit(completed=completed)
+                if completed:
+                    return ElasticStatus.COMPLETED
+                if self.elastic_level == 1:
+                    return ElasticStatus.RESTART
+                else:
+                    return ElasticStatus.ERROR
+
+            if not self._completed() and not self._match():
+                self.launcher.stop()
+                return ElasticStatus.HOLD
+
+            time.sleep(3)
+
+        return ElasticStatus.EXIT
+
+    def signal_handler(self, sigint, frame):
+        if self.enable:
+            self.exit()
+        self.sigint = sigint
+        self.stopped = True
diff --git a/python/paddle/distributed/fleet/launch.py b/python/paddle/distributed/fleet/launch.py
index 25b10133191788cd85085fd4612ae7cea0f122f3..07862a07c92c419c88869e2544414ba4e63141e0 100644
--- a/python/paddle/distributed/fleet/launch.py
+++ b/python/paddle/distributed/fleet/launch.py
@@ -69,12 +69,18 @@ from argparse import ArgumentParser, REMAINDER
 import paddle
 import paddle.fluid as fluid
 from paddle.distributed.fleet import launch_utils
+import signal
 
 # TODO(danleifeng): Don't import * from a module
 from paddle.distributed.fleet.launch_utils import *
 import paddle.distributed.fleet.cloud_utils as cloud_utils
 import paddle.distributed.fleet.ascend_utils as ascend_utils
 
+from paddle.distributed.fleet.elastic import ElasticManager
+from paddle.distributed.fleet.elastic import LauncherInterface
+from paddle.distributed.fleet.elastic import ElasticStatus
+from paddle.distributed.fleet.elastic import ELASTIC_EXIT_CODE
+
 __all__ = []
 
@@ -175,6 +181,18 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
         "--heter_worker_num", type=int, help="number of heter_workers")
     ps_group.add_argument("--http_port", type=int, help="Gloo http Port")
 
+    # parameters for elastic mode
+    elastic_group = parser.add_argument_group("Elastic Parameters")
+    elastic_group.add_argument(
+        "--elastic_server", type=str, help="etcd server host:port")
+    elastic_group.add_argument("--job_id", type=str, help="job unique id")
+    elastic_group.add_argument("--np", type=int, help="job pod/node number")
+    elastic_group.add_argument("--scale", type=int, default=0, help="scale np")
+    elastic_group.add_argument(
+        "--host", type=str, help="bind host, default to POD_IP env")
+    elastic_group.add_argument(
+        "--force", type=bool, default=False, help="update np force")
+
     return parser.parse_args()
 
 
@@ -183,7 +201,10 @@ def get_cluster_from_args(args, device_mode, devices_per_proc):
     if len(node_ips) == 1:
         node_ip = node_ips[0]
     else:
-        _, node_ip = get_host_name_ip()
+        if args.host:
+            node_ip = args.host
+        else:
+            _, node_ip = get_host_name_ip()
 
     assert node_ip in node_ips, "Can't find your local ip {%s} in node_ips: {%s}" \
         % (node_ip, node_ips)
@@ -214,65 +235,75 @@ def get_cluster_from_args(args, device_mode, devices_per_proc):
                               devices_per_proc)
 
 
-def launch_collective(args):
-    # parse arguments, used for cloud-single-machine and local
-    (device_mode, devices_per_proc) = launch_utils.get_device_proc_info(args)
-    trainers_num = cloud_utils.get_trainers_num()
-    logger.debug("parsed from args trainerss_num:{} mode:{} devices:{}".format(
-        trainers_num, device_mode, devices_per_proc))
-
-    cluster = None
-    pod = None
-
-    start_port = 6170
-    if os.environ.get('FLAGS_START_PORT') is not None:
-        start_port = os.environ.get('FLAGS_START_PORT')
-    if cloud_utils.use_paddlecloud() and trainers_num != 1:
-        cluster, pod = cloud_utils.get_cloud_cluster(
-            args.ips, device_mode, devices_per_proc, start_port)
-        logger.debug("get cluster from cloud:{}".format(cluster))
-    elif device_mode == DeviceMode.ASCEND_NPU:
-        # for ascend
-        cluster, pod = ascend_utils.get_cloud_cluster(
-            rank_table_file=os.getenv("RANK_TABLE_FILE", None),
-            device_mode=device_mode,
-            start_port=start_port)
-    else:
-        # trainers_num = 1 or not use paddlecloud ips="a,b"
-        cluster, pod = get_cluster_from_args(args, device_mode,
-                                             devices_per_proc)
-        logger.debug("get cluster from args:{}".format(cluster))
-
-    global_envs = copy.copy(os.environ.copy())
-    gloo_rendezvous_dir = tempfile.mkdtemp()
-    # add gloo env
-    global_envs["PADDLE_WITH_GLOO"] = str(os.getenv("PADDLE_WITH_GLOO", "0"))
-    global_envs["PADDLE_GLOO_RENDEZVOUS"] = "3"
-    global_envs["PADDLE_GLOO_FS_PATH"] = gloo_rendezvous_dir
-
-    procs = start_local_trainers(
-        cluster,
-        pod,
-        training_script=args.training_script,
-        training_script_args=args.training_script_args,
-        log_dir=args.log_dir,
-        envs=global_envs)
-
-    for idx, proc in enumerate(procs):
-        print("launch proc_id:{} idx:{}".format(proc.proc.pid, idx))
+class CollectiveLauncher(LauncherInterface):
+    def __init__(self, args):
+        self.args = args
+        self.procs = []
 
-    while True:
-        alive = watch_local_trainers(procs, cluster.trainers_nranks())
+    def launch(self):
+        logger.info("collective launcher launch ...")
+        args = self.args
+        # parse arguments, used for cloud-single-machine and local
+        (device_mode,
+         devices_per_proc) = launch_utils.get_device_proc_info(args)
+        trainers_num = cloud_utils.get_trainers_num()
+        logger.debug("parsed from args trainers_num:{} mode:{} devices:{}".
+                     format(trainers_num, device_mode, devices_per_proc))
 
-        if not alive:
-            logger.info("Local processes completed.")
-            logger.debug("POD info:{}".format(pod))
-            break
+        cluster = None
+        pod = None
 
-        time.sleep(3)
-
-    if os.path.exists(gloo_rendezvous_dir):
-        shutil.rmtree(gloo_rendezvous_dir)
+        start_port = 6170
+        if os.environ.get('FLAGS_START_PORT') is not None:
+            start_port = os.environ.get('FLAGS_START_PORT')
+        if cloud_utils.use_paddlecloud() and trainers_num != 1:
+            cluster, pod = cloud_utils.get_cloud_cluster(
+                args.ips, device_mode, devices_per_proc, start_port)
+            logger.debug("get cluster from cloud:{}".format(cluster))
+        elif device_mode == DeviceMode.ASCEND_NPU:
+            # for ascend
+            cluster, pod = ascend_utils.get_cloud_cluster(
+                rank_table_file=os.getenv("RANK_TABLE_FILE", None),
+                device_mode=device_mode,
+                start_port=start_port)
+        else:
+            # trainers_num = 1 or not use paddlecloud ips="a,b"
+            cluster, pod = get_cluster_from_args(args, device_mode,
+                                                 devices_per_proc)
+            logger.debug("get cluster from args:{}".format(cluster))
+
+        global_envs = copy.copy(os.environ.copy())
+        self.gloo_rendezvous_dir = tempfile.mkdtemp()
+        # add gloo env
+        global_envs["PADDLE_WITH_GLOO"] = str(
+            os.getenv("PADDLE_WITH_GLOO", "0"))
+        global_envs["PADDLE_GLOO_RENDEZVOUS"] = "3"
+        global_envs["PADDLE_GLOO_FS_PATH"] = self.gloo_rendezvous_dir
+
+        self.procs = start_local_trainers(
+            cluster,
+            pod,
+            training_script=args.training_script,
+            training_script_args=args.training_script_args,
+            log_dir=args.log_dir,
+            envs=global_envs)
+
+        for idx, proc in enumerate(self.procs):
+            logger.info("launch proc_id:{} idx:{}".format(proc.proc.pid, idx))
+
+    def stop(self):
+        logger.info("collective launcher stop ...")
+        self._terminate_procs()
+        if os.path.exists(self.gloo_rendezvous_dir):
+            shutil.rmtree(self.gloo_rendezvous_dir)
+
+    def watch(self):
+        logger.debug("collective launcher watch ...")
+        for p in self.procs:
+            if p.log_fn and p.local_rank == 0:
+                pull_worker_log(p)
+        ret = self._check_procs()
+        return ret
 
 
 def launch_ps(args, distribute_mode):
@@ -367,10 +398,42 @@ def launch():
     _print_arguments(args)
 
     distribute_mode = which_distributed_mode(args)
-    if distribute_mode == DistributeMode.COLLECTIVE:
-        launch_collective(args)
-    else:
+    # TODO(kuizhiqing) support ps later
+    if not distribute_mode == DistributeMode.COLLECTIVE:
         launch_ps(args, distribute_mode)
+        return
+
+    elastic = ElasticManager(args)
+
+    signal.signal(signal.SIGTERM, elastic.signal_handler)
+    signal.signal(signal.SIGABRT, elastic.signal_handler)
+    signal.signal(signal.SIGINT, elastic.signal_handler)
+
+    while True:
+
+        # wait for all nodes to be ready to run
+        elastic.wait()
+
+        # run self with the specified launcher
+        elastic.run(CollectiveLauncher)
+
+        # keep watching the health status of self and be notified of other nodes' failures
+        ret = elastic.watch()
+        if ret == ElasticStatus.COMPLETED:
+            break
+        if ret == ElasticStatus.HOLD:
+            continue
+        if ret == ElasticStatus.EXIT:
+            break
+        if ret == ElasticStatus.ERROR:
+            sys.exit(3)
+        if ret == ElasticStatus.RESTART:
+            sys.exit(ELASTIC_EXIT_CODE)
+
+    if int(elastic.sigint) > 0:
+        sys.exit(128 + int(elastic.sigint))
+    else:
+        sys.exit(0)
 
 
 if __name__ == "__main__":
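
Usage note (not part of the patch): with this change the launcher exits with ELASTIC_EXIT_CODE (101) when a node should be relaunched, so elastic recovery relies on something outside the launcher restarting it. Below is a minimal supervisor sketch of that contract; the etcd address, job id, np and train.py are placeholder values, and a re-exec loop like this is only one possible deployment (a Kubernetes restart policy or a shell loop would serve the same purpose).

# supervisor_sketch.py -- assumption: an external wrapper relaunches the
# elastic launcher whenever it exits with ELASTIC_EXIT_CODE.
import subprocess
import sys

ELASTIC_EXIT_CODE = 101  # matches the constant defined in elastic.py

cmd = [
    sys.executable, "-m", "paddle.distributed.fleet.launch",
    "--elastic_server", "127.0.0.1:2379",  # etcd server host:port (placeholder)
    "--job_id", "job-demo",                # unique job id shared by all nodes (placeholder)
    "--np", "2",                           # number of pods/nodes in the job (placeholder)
    "train.py",                            # user training script (placeholder)
]

while True:
    ret = subprocess.call(cmd)
    if ret == ELASTIC_EXIT_CODE:
        # Relaunch: the node re-registers its host in etcd and blocks in
        # ElasticManager.wait() until np nodes are present again.
        continue
    sys.exit(ret)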