diff --git a/python/paddle/distributed/auto_tuner/prune.py b/python/paddle/distributed/auto_tuner/prune.py index 6f6d549e504927ea17281c0936d565112da999e4..66f16ff67fb9dab9c3d2eff7637fd57e778f8d11 100644 --- a/python/paddle/distributed/auto_tuner/prune.py +++ b/python/paddle/distributed/auto_tuner/prune.py @@ -91,7 +91,7 @@ def prune_by_pp(tuner_cfg, cur_cfg, history_cfgs=None): """ pp_degree = cur_cfg.get("pp_degree", None) num_layers = tuner_cfg["model_cfg"].get("num_layers", None) - num_nodes = tuner_cfg.get("num_nodes", 1) + num_nodes = tuner_cfg.get("nodes", 1) if pp_degree is None: return False diff --git a/python/paddle/distributed/launch/controllers/collective.py b/python/paddle/distributed/launch/controllers/collective.py index ad3a811ec8f458e00619fa3885a89de57ce06119..35bd244bb2f89732c3a7c5b8414475b4378e9ab5 100644 --- a/python/paddle/distributed/launch/controllers/collective.py +++ b/python/paddle/distributed/launch/controllers/collective.py @@ -199,6 +199,10 @@ class CollectiveController(Controller): ''' collective_master = peer_list[0]['candidate'] + # get collective master ip + collective_master_ip = collective_master.split(':')[0].strip() + os.environ["COLLECTIVE_MASTER_IP"] = collective_master_ip + job_endpoints = [i['endpoints'] for i in peer_list] # self.pod.reset() diff --git a/python/paddle/distributed/launch/controllers/master.py b/python/paddle/distributed/launch/controllers/master.py index e04ee59b2428587b6f2a6bb994af408c003f8b7a..d625887b8167f0f71776c26a2f6451aea53e523e 100644 --- a/python/paddle/distributed/launch/controllers/master.py +++ b/python/paddle/distributed/launch/controllers/master.py @@ -193,7 +193,11 @@ class ETCDMaster(Master): import etcd3 + from ..utils.etcd_client import ETCDClient + host, port = self.endpoint.split(':') + if ctx.is_auto_tuner_mode(): + self.etcd_client = ETCDClient(host=host, port=port) self.client = etcd3.client(host=host, port=port) def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int): @@ -253,29 +257,21 @@ class ETCDMaster(Master): self.job_prefix = f'/paddle/{job_id}' self.heartbeat_prefix = f'{self.job_prefix}/heartbeat' if self.ctx.is_auto_tuner_mode(): - delete_success = False - while not delete_success: - try: - self.client.delete_prefix(self.job_prefix) - delete_success = True - except: - time.sleep(1) - - if self.ctx.is_auto_tuner_mode(): - lease_success = False - while not lease_success: - try: - lease = self.client.lease(ttl) - lease_success = True - except: - time.sleep(1) + self.etcd_client.delete_prefix(self.job_prefix) + lease = self.etcd_client.lease(ttl) else: + self.client.delete_prefix(self.job_prefix) lease = self.client.lease(ttl) # self.client.delete_prefix(self.job_prefix) beat_path = f"{self.heartbeat_prefix}/{pod_id}" - self.client.put(beat_path, pod_id.encode('latin-1'), lease=lease) + if self.ctx.is_auto_tuner_mode(): + self.etcd_client.put( + beat_path, pod_id.encode('latin-1'), lease=lease + ) + else: + self.client.put(beat_path, pod_id.encode('latin-1'), lease=lease) def _beat_watch(event): self.ctx.status.restart() diff --git a/python/paddle/distributed/launch/main.py b/python/paddle/distributed/launch/main.py index bd77c84bbcc87fd6ef3172b6b6a618225cd39bcf..3cee9b9b4e14359c862260d240c77bceb2a2c408 100644 --- a/python/paddle/distributed/launch/main.py +++ b/python/paddle/distributed/launch/main.py @@ -295,6 +295,7 @@ def launch(): elif ctx.is_auto_tuner_mode(): import copy import json + import os import sys import time @@ -333,12 +334,13 @@ def launch(): tuner_cfg["num_gpus"] = 
gpus_per_node * tuner_cfg["nodes"] if nnodes > 1: - import etcd3 + from .utils.etcd_client import ETCDClient assert "etcd://" in ctx.args.master master_ip, port = ctx.args.master.strip("etcd://").split(':') - client = etcd3.client(host=master_ip, port=port) + client = ETCDClient(host=master_ip, port=port) client.delete("best_cfg") + client.delete_prefix("auto_tuner") # get max time per task run max_time_per_task = tuner_cfg.get("max_time_per_task", 1800) @@ -532,6 +534,45 @@ def launch(): target_metric=tuner_cfg["metric_cfg"]["name"], memory_file=f"{ctx.args.job_id}.gpu.log", ) + # sync sigint + timeout_flag = True + + if nnodes > 1: + import socket + + ip = None + try: + hostname = socket.gethostname() + ip = socket.gethostbyname(socket.getfqdn(hostname)) + except: + ip = '127.0.0.1' + assert ip != '127.0.0.1' + path = f"auto_tuner/{job_id}/{ip}" + OOM_flag = err & (1 << 1) + if OOM_flag: + client.put(path, "OOM".encode('latin-1')) + ctx.logger.info(f"Put OOM to {path}") + elif hasattr(c, 'sigint') and c.sigint == 14: + client.put(path, "OK".encode('latin-1')) + ctx.logger.info(f"Put OK to {path}") + else: + client.put(path, "Error".encode('latin-1')) + ctx.logger.info(f"Put Error to {path}") + + result = list(client.get_prefix(f"auto_tuner/{job_id}/")) + size = len(result) + while size != nnodes: + time.sleep(1) + result = list(client.get_prefix(f"auto_tuner/{job_id}/")) + size = len(result) + + status = [i[0].decode() for i in result] + ctx.logger.info(f"Status of auto_tuner/{job_id}/: {status}") + + if "OOM" in status: + timeout_flag = False + elif "OK" not in status: + timeout_flag = False if err & (1 << 0): ctx.logger.warning( @@ -556,12 +597,17 @@ def launch(): ) cur_cfg["max_mem_usage"] = None - if not err: + if not err and timeout_flag: # for pruner use cur_cfg['time'] = metric cur_cfg[tuner_cfg['metric_cfg']['name']] = metric cur_cfg["max_mem_usage"] = mem + if not err and not timeout_flag: + cur_cfg['time'] = -1 + cur_cfg[tuner_cfg['metric_cfg']['name']] = None + cur_cfg["max_mem_usage"] = None + # record history cur_cfg['job_id'] = job_id recorder.add_cfg(**cur_cfg) @@ -586,6 +632,14 @@ def launch(): auto_tuner.add_cfg(cur_cfg) # per task launch interval + self_pid = str(os.getpid()) + processes = os.popen( + "fuser -v /dev/nvidia* |awk '{for(i=1;i<=NF;i++) print $i;}'" + ).readlines() + for process in processes: + pid = str(process.strip()) + if pid != self_pid: + os.system("kill -9 " + pid) time.sleep(3) recorder.store_history() @@ -601,7 +655,10 @@ def launch(): ip = socket.gethostbyname(socket.getfqdn(hostname)) except: ip = '127.0.0.1' - if ip == master_ip: + + collective_master_ip = os.environ.get("COLLECTIVE_MASTER_IP", None) + assert collective_master_ip is not None + if ip == collective_master_ip: best_cfg, err = recorder.get_best( metric=tuner_cfg['metric_cfg']['name'], direction=tuner_cfg['metric_cfg']['OptimizationDirection'], diff --git a/python/paddle/distributed/launch/utils/etcd_client.py b/python/paddle/distributed/launch/utils/etcd_client.py new file mode 100644 index 0000000000000000000000000000000000000000..e4bbf8e1409a4defd86dc8c1374fa68bf8322193 --- /dev/null +++ b/python/paddle/distributed/launch/utils/etcd_client.py @@ -0,0 +1,142 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import time
+
+import etcd3
+
+
+class ETCDClient:
+    """Thin wrapper around etcd3.client that retries failed requests.
+
+    Every call is retried up to retry_times times with a one second
+    pause between attempts before a ValueError is raised.
+    """
+    def __init__(self, host, port, retry_times=20):
+        self.retry_times = retry_times
+        times = 0
+        while times < self.retry_times:
+            try:
+                self.client = etcd3.client(host=host, port=port)
+                break
+            except Exception as e:
+                times += 1
+                logging.info(
+                    f"Initialize etcd client failed with exception {e}, retry after 1 second."
+                )
+                time.sleep(1)
+
+        if times >= self.retry_times:
+            raise ValueError(
+                f"Initialize etcd client failed after {self.retry_times} times."
+            )
+
+    def put(self, key, value, lease=None, prev_kv=False):
+        times = 0
+        while times < self.retry_times:
+            try:
+                return self.client.put(key, value, lease, prev_kv)
+            except Exception as e:
+                times += 1
+                logging.info(
+                    f"Put failed with exception {e}, retry after 1 second."
+                )
+                time.sleep(1)
+
+        if times >= self.retry_times:
+            raise ValueError(f"Put failed after {self.retry_times} times.")
+
+    def get(self, key):
+        times = 0
+        while times < self.retry_times:
+            try:
+                return self.client.get(key)
+            except Exception as e:
+                times += 1
+                logging.info(
+                    f"Get {key} failed with exception {e}, retry after 1 second."
+                )
+                time.sleep(1)
+
+        if times >= self.retry_times:
+            raise ValueError(
+                f"Get {key} failed after {self.retry_times} times."
+            )
+
+    def delete(self, key, prev_kv=False, return_response=False):
+        times = 0
+        while times < self.retry_times:
+            try:
+                return self.client.delete(key, prev_kv, return_response)
+            except Exception as e:
+                times += 1
+                logging.info(
+                    f"Delete {key} failed with exception {e}, retry after 1 second."
+                )
+                time.sleep(1)
+
+        if times >= self.retry_times:
+            raise ValueError(
+                f"Delete {key} failed after {self.retry_times} times."
+            )
+
+    def get_prefix(self, key_prefix, sort_order=None, sort_target='key'):
+        times = 0
+        while times < self.retry_times:
+            try:
+                return self.client.get_prefix(key_prefix)
+            except Exception as e:
+                times += 1
+                logging.info(
+                    f"Get prefix {key_prefix} failed with exception {e}, retry after 1 second."
+                )
+                time.sleep(1)
+
+        if times >= self.retry_times:
+            raise ValueError(
+                f"Get prefix {key_prefix} failed after {self.retry_times} times."
+            )
+
+    def delete_prefix(self, prefix):
+        times = 0
+        while times < self.retry_times:
+            try:
+                return self.client.delete_prefix(prefix)
+            except Exception as e:
+                times += 1
+                logging.info(
+                    f"Delete prefix {prefix} failed with exception {e}, retry after 1 second."
+                )
+                time.sleep(1)
+
+        if times >= self.retry_times:
+            raise ValueError(
+                f"Delete prefix {prefix} failed after {self.retry_times} times."
+            )
+
+    def lease(self, ttl, lease_id=None):
+        times = 0
+        while times < self.retry_times:
+            try:
+                return self.client.lease(ttl, lease_id)
+            except Exception as e:
+                times += 1
+                logging.info(
+                    f"Lease failed with exception {e}, retry after 1 second."
+                )
+                time.sleep(1)
+
+        if times >= self.retry_times:
+            raise ValueError(f"Lease failed after {self.retry_times} times.")
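
A minimal usage sketch of the ETCDClient retry wrapper introduced by this patch, assuming a reachable etcd server at 127.0.0.1:2379; the key names and TTL below are illustrative and not taken from the change:

# Illustrative only: exercises the retry wrapper against a local etcd server.
from paddle.distributed.launch.utils.etcd_client import ETCDClient

client = ETCDClient(host="127.0.0.1", port=2379, retry_times=5)

# Each call below goes through the same retry loop (up to retry_times
# attempts, one second apart) before delegating to the etcd3 client.
lease = client.lease(ttl=10)
client.put("auto_tuner/demo/worker0", b"OK", lease=lease)

value, _meta = client.get("auto_tuner/demo/worker0")  # (value, metadata) pair
print(value)  # b'OK' while the lease is alive

# get_prefix() yields (value, metadata) pairs for every key under the prefix.
for val, meta in client.get_prefix("auto_tuner/demo/"):
    print(meta.key, val)

client.delete_prefix("auto_tuner/demo")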