未验证 提交 2f69edc5 编写于 作者: C caozhou 提交者: GitHub

update auto tuner in the multi nodes scene (#56374)

上级 ca8f9552
...@@ -91,7 +91,7 @@ def prune_by_pp(tuner_cfg, cur_cfg, history_cfgs=None): ...@@ -91,7 +91,7 @@ def prune_by_pp(tuner_cfg, cur_cfg, history_cfgs=None):
""" """
pp_degree = cur_cfg.get("pp_degree", None) pp_degree = cur_cfg.get("pp_degree", None)
num_layers = tuner_cfg["model_cfg"].get("num_layers", None) num_layers = tuner_cfg["model_cfg"].get("num_layers", None)
num_nodes = tuner_cfg.get("num_nodes", 1) num_nodes = tuner_cfg.get("nodes", 1)
if pp_degree is None: if pp_degree is None:
return False return False
......
...@@ -199,6 +199,10 @@ class CollectiveController(Controller): ...@@ -199,6 +199,10 @@ class CollectiveController(Controller):
''' '''
collective_master = peer_list[0]['candidate'] collective_master = peer_list[0]['candidate']
# get collective master ip
collective_master_ip = collective_master.split(':')[0].strip()
os.environ["COLLECTIVE_MASTER_IP"] = collective_master_ip
job_endpoints = [i['endpoints'] for i in peer_list] job_endpoints = [i['endpoints'] for i in peer_list]
# self.pod.reset() # self.pod.reset()
......
...@@ -193,7 +193,11 @@ class ETCDMaster(Master): ...@@ -193,7 +193,11 @@ class ETCDMaster(Master):
import etcd3 import etcd3
from ..utils.etcd_client import ETCDClient
host, port = self.endpoint.split(':') host, port = self.endpoint.split(':')
if ctx.is_auto_tuner_mode():
self.etcd_client = ETCDClient(host=host, port=port)
self.client = etcd3.client(host=host, port=port) self.client = etcd3.client(host=host, port=port)
def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int): def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int):
...@@ -253,28 +257,20 @@ class ETCDMaster(Master): ...@@ -253,28 +257,20 @@ class ETCDMaster(Master):
self.job_prefix = f'/paddle/{job_id}' self.job_prefix = f'/paddle/{job_id}'
self.heartbeat_prefix = f'{self.job_prefix}/heartbeat' self.heartbeat_prefix = f'{self.job_prefix}/heartbeat'
if self.ctx.is_auto_tuner_mode(): if self.ctx.is_auto_tuner_mode():
delete_success = False self.etcd_client.delete_prefix(self.job_prefix)
while not delete_success: lease = self.etcd_client.lease(ttl)
try:
self.client.delete_prefix(self.job_prefix)
delete_success = True
except:
time.sleep(1)
if self.ctx.is_auto_tuner_mode():
lease_success = False
while not lease_success:
try:
lease = self.client.lease(ttl)
lease_success = True
except:
time.sleep(1)
else: else:
self.client.delete_prefix(self.job_prefix)
lease = self.client.lease(ttl) lease = self.client.lease(ttl)
# self.client.delete_prefix(self.job_prefix) # self.client.delete_prefix(self.job_prefix)
beat_path = f"{self.heartbeat_prefix}/{pod_id}" beat_path = f"{self.heartbeat_prefix}/{pod_id}"
if self.ctx.is_auto_tuner_mode():
self.etcd_client.put(
beat_path, pod_id.encode('latin-1'), lease=lease
)
else:
self.client.put(beat_path, pod_id.encode('latin-1'), lease=lease) self.client.put(beat_path, pod_id.encode('latin-1'), lease=lease)
def _beat_watch(event): def _beat_watch(event):
......
...@@ -295,6 +295,7 @@ def launch(): ...@@ -295,6 +295,7 @@ def launch():
elif ctx.is_auto_tuner_mode(): elif ctx.is_auto_tuner_mode():
import copy import copy
import json import json
import os
import sys import sys
import time import time
...@@ -333,12 +334,13 @@ def launch(): ...@@ -333,12 +334,13 @@ def launch():
tuner_cfg["num_gpus"] = gpus_per_node * tuner_cfg["nodes"] tuner_cfg["num_gpus"] = gpus_per_node * tuner_cfg["nodes"]
if nnodes > 1: if nnodes > 1:
import etcd3 from .utils.etcd_client import ETCDClient
assert "etcd://" in ctx.args.master assert "etcd://" in ctx.args.master
master_ip, port = ctx.args.master.strip("etcd://").split(':') master_ip, port = ctx.args.master.strip("etcd://").split(':')
client = etcd3.client(host=master_ip, port=port) client = ETCDClient(host=master_ip, port=port)
client.delete("best_cfg") client.delete("best_cfg")
client.delete_prefix("auto_tuner")
# get max time per task run # get max time per task run
max_time_per_task = tuner_cfg.get("max_time_per_task", 1800) max_time_per_task = tuner_cfg.get("max_time_per_task", 1800)
...@@ -532,6 +534,45 @@ def launch(): ...@@ -532,6 +534,45 @@ def launch():
target_metric=tuner_cfg["metric_cfg"]["name"], target_metric=tuner_cfg["metric_cfg"]["name"],
memory_file=f"{ctx.args.job_id}.gpu.log", memory_file=f"{ctx.args.job_id}.gpu.log",
) )
# sync sigint
timeout_flag = True
if nnodes > 1:
import socket
ip = None
try:
hostname = socket.gethostname()
ip = socket.gethostbyname(socket.getfqdn(hostname))
except:
ip = '127.0.0.1'
assert ip != '127.0.0.1'
path = f"auto_tuner/{job_id}/{ip}"
OOM_flag = err & (1 << 1)
if OOM_flag:
client.put(path, "OOM".encode('latin-1'))
ctx.logger.info(f"Put OOM to {path}")
elif hasattr(c, 'sigint') and c.sigint == 14:
client.put(path, "OK".encode('latin-1'))
ctx.logger.info(f"Put OK to {path}")
else:
client.put(path, "Error".encode('latin-1'))
ctx.logger.info(f"Put Error to {path}")
result = list(client.get_prefix(f"auto_tuner/{job_id}/"))
size = len(result)
while size != nnodes:
time.sleep(1)
result = list(client.get_prefix(f"auto_tuner/{job_id}/"))
size = len(result)
status = [i[0].decode() for i in result]
ctx.logger.info(f"Status of auto_tuner/{job_id}/: {status}")
if "OOM" in status:
timeout_flag = False
elif "OK" not in status:
timeout_flag = False
if err & (1 << 0): if err & (1 << 0):
ctx.logger.warning( ctx.logger.warning(
...@@ -556,12 +597,17 @@ def launch(): ...@@ -556,12 +597,17 @@ def launch():
) )
cur_cfg["max_mem_usage"] = None cur_cfg["max_mem_usage"] = None
if not err: if not err and timeout_flag:
# for pruner use # for pruner use
cur_cfg['time'] = metric cur_cfg['time'] = metric
cur_cfg[tuner_cfg['metric_cfg']['name']] = metric cur_cfg[tuner_cfg['metric_cfg']['name']] = metric
cur_cfg["max_mem_usage"] = mem cur_cfg["max_mem_usage"] = mem
if not err and not timeout_flag:
cur_cfg['time'] = -1
cur_cfg[tuner_cfg['metric_cfg']['name']] = None
cur_cfg["max_mem_usage"] = None
# record history # record history
cur_cfg['job_id'] = job_id cur_cfg['job_id'] = job_id
recorder.add_cfg(**cur_cfg) recorder.add_cfg(**cur_cfg)
...@@ -586,6 +632,14 @@ def launch(): ...@@ -586,6 +632,14 @@ def launch():
auto_tuner.add_cfg(cur_cfg) auto_tuner.add_cfg(cur_cfg)
# per task launch interval # per task launch interval
self_pid = str(os.getpid())
processes = os.popen(
"fuser -v /dev/nvidia* |awk '{for(i=1;i<=NF;i++) print $i;}'"
).readlines()
for process in processes:
pid = str(process.strip())
if pid != self_pid:
os.system("kill -9 " + pid)
time.sleep(3) time.sleep(3)
recorder.store_history() recorder.store_history()
...@@ -601,7 +655,10 @@ def launch(): ...@@ -601,7 +655,10 @@ def launch():
ip = socket.gethostbyname(socket.getfqdn(hostname)) ip = socket.gethostbyname(socket.getfqdn(hostname))
except: except:
ip = '127.0.0.1' ip = '127.0.0.1'
if ip == master_ip:
collective_master_ip = os.environ.get("COLLECTIVE_MASTER_IP", None)
assert collective_master_ip is not None
if ip == collective_master_ip:
best_cfg, err = recorder.get_best( best_cfg, err = recorder.get_best(
metric=tuner_cfg['metric_cfg']['name'], metric=tuner_cfg['metric_cfg']['name'],
direction=tuner_cfg['metric_cfg']['OptimizationDirection'], direction=tuner_cfg['metric_cfg']['OptimizationDirection'],
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
import etcd3
class ETCDClient:
def __init__(self, host, port, retry_times=20):
self.retry_times = retry_times
times = 0
while times < self.retry_times:
try:
self.client = etcd3.client(host=host, port=port)
break
except Exception as e:
times += 1
logging.info(
f"Initialize etcd client failed with exception {e}, retry after 1 second."
)
time.sleep(1)
if times >= self.retry_times:
raise ValueError(
f"Initialize etcd client failed failed after {self.retry_times} times."
)
def put(self, key, value, lease=None, prev_kv=False):
times = 0
while times < self.retry_times:
try:
return self.client.put(key, value, lease, prev_kv)
except Exception as e:
times += 1
logging.info(
f"Put failed with exception {e}, retry after 1 second."
)
time.sleep(1)
if times >= self.retry_times:
raise ValueError(f"Put failed after {self.retry_times} times.")
def get(self, key):
times = 0
while times < self.retry_times:
try:
return self.client.get(key)
break
except Exception as e:
times += 1
logging.info(
f"Get {key} failed with exception {e}, retry after 1 second."
)
time.sleep(1)
if times >= self.retry_times:
raise ValueError(
f"Get {key} failed after {self.retry_times} times."
)
def delete(self, key, prev_kv=False, return_response=False):
times = 0
while times < self.retry_times:
try:
return self.client.delete(key, prev_kv, return_response)
break
except Exception as e:
times += 1
logging.info(
f"Delete {key} failed with exception {e}, retry after 1 second."
)
time.sleep(1)
if times >= self.retry_times:
raise ValueError(
f"Delete {key} failed after {self.retry_times} times."
)
def get_prefix(self, key_prefix, sort_order=None, sort_target='key'):
times = 0
while times < self.retry_times:
try:
return self.client.get_prefix(key_prefix)
break
except Exception as e:
times += 1
logging.info(
f"Get prefix {key_prefix} failed with exception {e}, retry after 1 second."
)
time.sleep(1)
if times >= self.retry_times:
raise ValueError(
f"Get prefix {key_prefix} failed after {self.retry_times} times."
)
def delete_prefix(self, prefix):
times = 0
while times < self.retry_times:
try:
return self.client.delete_prefix(prefix)
break
except Exception as e:
times += 1
logging.info(
f"Delete prefix {prefix} failed with exception {e}, retry after 1 second."
)
time.sleep(1)
if times >= self.retry_times:
raise ValueError(
f"Delete prefix {prefix} failed after {self.retry_times} times."
)
def lease(self, ttl, lease_id=None):
times = 0
while times < self.retry_times:
try:
return self.client.lease(ttl, lease_id)
break
except Exception as e:
times += 1
logging.info(
f"Lease failed with exception {e}, retry after 1 second."
)
time.sleep(1)
if times >= self.retry_times:
raise ValueError(f"Lease failed after {self.retry_times} times.")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册