Unverified commit 4c0c458a, authored by Azure, committed by GitHub

[AutoTuner] Add GBS search, gpu memory usage (#55466)

* temp commit

* distribute best cfg

* update metric extracting

* fix bugs of prune and reading log

* fix adding cfg bug

* reset status

* remove alarm and set logdir

* deepcopy ctx

* change alarm

* fix restart bug

* best no need alarm

* add gbs search, add gpu memory to history csv, add memory detect

* fix bug

* fix memory read bug; fix etcd connection bug

* fix memory read bug, add oom detection for all ranks

* fix read log and oom detection, add error code for read log

* add unit test

* Update master.py

---------
Co-authored-by: caozhou <caozhou@radi.ac.cn>
Parent: bd930f83
@@ -122,6 +122,8 @@ def prune_by_mbs(tuner_cfg, cur_cfg, history_cfgs=None):
     """
     micro_batch_size = cur_cfg.get("micro_batch_size", None)
     global_batch_size = tuner_cfg["model_cfg"].get("global_batch_size", None)
+    if global_batch_size == "auto":
+        global_batch_size = cur_cfg["global_batch_size"]
     if global_batch_size:
         local_batch_size = (
             global_batch_size
...
@@ -19,7 +19,7 @@ from typing import Tuple
 import pandas as pd


-class History_recorder:
+class HistoryRecorder:
     # NOTE increase extenable ablitity
     def __init__(self) -> None:
         self.history = []
@@ -63,7 +63,9 @@ class History_recorder:
         cols = df.columns.tolist()
         cols.insert(0, cols.pop(cols.index('job_id')))
         df = df.reindex(columns=cols)
-        df = df.drop(columns=['time'])
+        # check if 'time' exists
+        if 'time' in df.columns:
+            df = df.drop(columns=['time'])
         # write to csv
         df.to_csv(self.store_path, index=False)
@@ -79,3 +81,7 @@ class History_recorder:
             reader = csv.reader(f)
             self.history = list(reader)
         return (self.history, err)
+
+    def clean_history(self) -> None:
+        """Clean history."""
+        self.history = []
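
For reference, a minimal sketch of how the recorder is exercised by the launcher changes later in this commit; the absolute module path is inferred from the relative import in the launch diff, and the config keys are illustrative only.

```python
from paddle.distributed.auto_tuner.recorder import HistoryRecorder

recorder = HistoryRecorder()
# each finished trial is appended as one row; the kwargs become CSV columns
recorder.add_cfg(job_id=1, dp_degree=2, mp_degree=2, pp_degree=2,
                 micro_batch_size=4, max_mem_usage=31250)
# 'job_id' is moved to the first column and any 'time' column is dropped
recorder.store_history("./tuner_gbs_history.csv")
# reset the rows so a follow-up search starts with an empty history
recorder.clean_history()
```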
@@ -16,7 +16,7 @@
 from abc import ABC, abstractmethod

 from .prune import _PRUNE_FUNC
-from .utils import search_all
+from .utils import gbs_search_all, search_all


 class SearchAlgo(ABC):
@@ -52,3 +52,24 @@ class GridSearch(SearchAlgo):
             else:
                 return None
         return new_cfg
+
+
+class GBSSearch(SearchAlgo):
+    def __init__(self, tuner_cfg):
+        super().__init__(tuner_cfg)
+        self.idx = 0
+        self.all_tasks = gbs_search_all(tuner_cfg)
+
+    def search_once(self, history_cfgs):
+        new_cfg = None
+        stop = False
+        while not stop:
+            if self.idx < len(self.all_tasks):
+                new_cfg = self.all_tasks[self.idx]
+                self.idx += 1
+                glb = new_cfg.get("global_batch_size", None)
+                self.tuner_cfg["model_cfg"]["global_batch_size"] = glb
+                stop = not self.prune(self.tuner_cfg, new_cfg, history_cfgs)
+            else:
+                return None
+        return new_cfg
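
A rough sketch of the search contract shared by GridSearch and the new GBSSearch, as used by AutoTuner below: `search_once` hands back the next configuration that survives pruning, or `None` once all candidate tasks are exhausted. `run_trial` is a placeholder, and `tuner_cfg` is assumed to already carry the candidates built by `gbs_default_candidates`.

```python
algo = GBSSearch(tuner_cfg)   # tuner_cfg["candidates"] filled in beforehand
history = []
cfg = algo.search_once(history)
while cfg is not None:
    run_trial(cfg)            # placeholder for launching one training job
    history.append(cfg)
    cfg = algo.search_once(history)
```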
@@ -13,7 +13,7 @@
 # limitations under the License.

-from .utils import default_candidates
+from .utils import default_candidates, gbs_default_candidates


 class AutoTuner:
@@ -29,13 +29,18 @@ class AutoTuner:
         self.cur_task_id = 1
         self.task_limit = tuner_cfg.get("task_limit", 100)

-        tuner_cfg["candidates"] = default_candidates(tuner_cfg)
         search_algo = tuner_cfg.get("search_algo", "grid")

         if search_algo == "grid":
             from .search import GridSearch

+            tuner_cfg["candidates"] = default_candidates(tuner_cfg)
             self.algo = GridSearch(tuner_cfg)
+        elif search_algo == "gbs":
+            from .search import GBSSearch
+
+            tuner_cfg["candidates"] = gbs_default_candidates(tuner_cfg)
+            self.algo = GBSSearch(tuner_cfg)
         else:
             raise NotImplementedError()
...
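
For orientation, a minimal, hypothetical `tuner_cfg` that takes the new `gbs` branch; only keys that actually appear in this diff are shown, and in practice the launcher sets `search_algo` to `"gbs"` itself when `global_batch_size` is `"auto"` (see the launch changes below).

```python
tuner_cfg = {
    "num_gpus": 8,
    "nodes": 1,
    "search_algo": "gbs",
    "task_limit": 100,
    "model_cfg": {"global_batch_size": "auto"},
    "metric_cfg": {"name": "step/s"},
}
tuner = AutoTuner(tuner_cfg)
cur_cfg = tuner.search_once()   # first candidate produced by GBSSearch
```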
@@ -13,6 +13,7 @@
 # limitations under the License.

 import copy
+import csv
 import itertools
 import os
 import re
@@ -320,38 +321,227 @@ def gen_new_args(raw_args, cfg, tuner_cfg):
     return res_args


-def read_log(
+def read_metric_log(
     path, file="workerlog.0", target_metric='step/s'
-) -> Tuple[float, bool]:
+) -> Tuple[float, int]:
     """For extracting metric from log file."""
+    """
+    return:
+        metric: average metric of last 10 steps
+        err_code:
+            00: no error
+            01: no metric
+            10: out of memory
+    """
+    err_code = 0
     target_file = path + "/" + file
     if not os.path.exists(target_file):
-        return (0.0, True)
+        return (0.0, 1)
     with open(target_file, "r") as f:
         # read file
         re_metric_pattern = (
             target_metric + r":* *(\d+(\.\d*)?)|(\d+(\.\d*)?) *" + target_metric
         )
+        re_out_of_memory_pattern = r"Out of memory"
+        out_of_memory_flag = 0
         metric_list = []
         lines = f.readlines()
         for line in lines:
             metric = re.findall(re_metric_pattern, line)
+            out_of_memory = re.findall(
+                re_out_of_memory_pattern, line, re.IGNORECASE
+            )
             if metric:
                 metric_list.append(float(metric[0][0]))
+            if out_of_memory:
+                out_of_memory_flag = 1
+        if out_of_memory_flag:
+            metric_ave = 0.0
+            err_code = err_code | (out_of_memory_flag << 1)
         if not metric_list:
             metric_ave = 0.0
-            flag = True
+            err_code = err_code | 1
         elif len(metric_list) < 10:
             metric_ave = metric_list[-1]
-            flag = False
         elif len(metric_list) < 20:
             metric_ave = sum(metric_list[9:]) / (len(metric_list[9:]))
-            flag = False
         else:
             metric_ave = sum(metric_list[-10:]) / 10
-            flag = False
         # round to 5 decimal places
         metric_ave = round(metric_ave, 5)
-        res = metric_ave, flag
+        res = metric_ave, err_code
     return res
+
+
+def read_memory_log(path, file) -> Tuple[float, bool]:
+    log_path = os.path.join(path, file)
+    if not os.path.exists(log_path):
+        return (0.0, True)
+    memory_used = []
+    utilization_gpu = []
+    indexs = []
+    with open(log_path, 'r') as f:
+        reader = csv.reader(f)
+        flag = False
+        # skip headers
+        while not flag:
+            # show the first line of reader
+            row = next(reader)
+            if len(row) == 6 and 'memory_used' in row:
+                flag = True
+        for row in reader:
+            # If row length is 6 then it's a utilization data row
+            # skip header
+            if len(row) == 6:
+                index, util_gpu, _, mem_used, _, _ = row
+                indexs.append(int(index))
+                memory_used.append(int(mem_used))
+                utilization_gpu.append(int(util_gpu))
+    return max(memory_used), False
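
`read_memory_log` expects a CSV whose header is a 6-field row containing the literal string `memory_used`; judging from the tuple unpacking, the data columns are roughly index, GPU utilization, an unused field, memory used, and two more unused fields. A made-up log that the parser would accept:

```python
# hypothetical "<job_id>.gpu.log"; all column names except 'memory_used'
# are guesses based on the unpacking above
sample = (
    "index,utilization_gpu,utilization_memory,memory_used,memory_total,timestamp\n"
    "0,98,40,30500,40960,t0\n"
    "0,97,41,31250,40960,t1\n"
)
with open("0.gpu.log", "w") as f:
    f.write(sample)

print(read_memory_log(".", "0.gpu.log"))  # (31250, False): peak memory, no error
```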
+
+
+def read_log(
+    path,
+    metric_file="workerlog.0",
+    target_metric='step/s',
+    memory_file="0.gpu.log",
+) -> Tuple[float, float, int]:
+    """
+    extract metric and max memory usage from log file
+    return:
+        metric: average metric of last 10 steps
+        memory: max memory used
+        err_code: 00: no error, 01: no metric, 10: out of memory, 100: no memory log
+    """
+    err_code = 0
+    # check out of memory
+    for root, dirs, files in os.walk(path):
+        for file in files:
+            if not file.startswith("workerlog"):
+                continue
+            metric, metric_flag = read_metric_log(path, file, target_metric)
+            if metric_flag:
+                err_code = (metric_flag & 2) | err_code
+
+    # read metric
+    res_metric, metric_flag = read_metric_log(path, metric_file, target_metric)
+    err_code = metric_flag | err_code
+
+    # check max memory usage
+    try:
+        res_memory, memory_flag = read_memory_log(path, memory_file)
+        err_code = (memory_flag << 2) | err_code
+    except:
+        res_memory = 0.0
+        err_code = (1 << 2) | err_code
+
+    return res_metric, res_memory, err_code
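
The returned error code is a small bitmask, which is why the launcher below tests individual bits instead of comparing whole values; a minimal decoding sketch (paths are placeholders):

```python
metric, max_mem, err = read_log(
    path="./log_dir", target_metric="step/s", memory_file="0.gpu.log"
)
if err & (1 << 0):
    print("no metric found in the metric log")
if err & (1 << 1):
    print("an 'Out of memory' line appeared in some workerlog")
if err & (1 << 2):
    print("the GPU memory log was missing or unreadable")
if not err:
    print(f"throughput={metric}, max GPU memory={max_mem}")
```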
+
+
+def three_mul_combinations(target):
+    """Return the combinations of three numbers which product is target."""
+    results = []
+    for i in range(1, target // 3 + 1):
+        if target % i == 0:
+            for j in range(i, target // 2 + 1):
+                if (target // i) % j == 0:
+                    results.append((i, j, target // i // j))
+    return results
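
A quick sanity check of this helper (it is not called elsewhere in this hunk); for a product of 8 it enumerates the ordered factor triples:

```python
print(three_mul_combinations(8))
# [(1, 1, 8), (1, 2, 4), (1, 4, 2), (2, 2, 2), (2, 4, 1)]
```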
+
+
+def gbs_dp_mp_pp_candidates(tuner_cfg, num_gpus, num_nodes):
+    """Return middle candidates of dp, mp, pp"""
+
+    start = round(num_gpus ** (1 / 3))
+
+    # find factors that can be evenly distributed
+    for i in range(start, 0, -1):
+        if num_gpus % i == 0:
+            remaining = num_gpus // i
+            # find the square root as a factor for the remaining part
+            j = round(remaining**0.5)
+            while remaining % j != 0:
+                j -= 1
+            return i, j, remaining // j
+
+    raise ValueError("Cannot distribute GPUs equally")
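
Worked through by hand for two common GPU counts; `tuner_cfg` and `num_nodes` are not used by this version of the function, so an empty dict is enough for the sketch.

```python
print(gbs_dp_mp_pp_candidates({}, num_gpus=8, num_nodes=1))   # (2, 2, 2)
print(gbs_dp_mp_pp_candidates({}, num_gpus=16, num_nodes=1))  # (2, 2, 4)
```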
+
+
+def gbs_default_candidates(tuner_cfg):
+    """Return the default candidates of every hyper param which user defined auto"""
+    candidates = {}
+    num_gpus = tuner_cfg["num_gpus"]
+    num_nodes = tuner_cfg["nodes"]
+    assert num_gpus > 0
+    global_batch_size = tuner_cfg.get("model_cfg", {}).get(
+        "global_batch_size", "auto"
+    )
+    if global_batch_size == "auto":
+        dp_candidate, mp_candidate, pp_candidate = gbs_dp_mp_pp_candidates(
+            tuner_cfg, num_gpus, num_nodes
+        )
+        sharding_dgree_candidate = dp_candidate
+        candidates["dp_degree"] = [1]
+        candidates["mp_degree"] = [mp_candidate]
+        candidates["pp_degree"] = [pp_candidate]
+        candidates["sharding_degree"] = [sharding_dgree_candidate]
+        candidates["sharding_stage"] = [1]
+        candidates["use_recompute"] = [False]
+        candidates["recompute_granularity"] = [None]
+        candidates["micro_batch_size"] = [2**i for i in range(0, 10)]
+        candidates["global_batch_size"] = [
+            pp_candidate * dp_candidate * e
+            for e in candidates["micro_batch_size"]
+        ]
+    return candidates
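
Combining the two helpers, an 8-GPU single-node job with `global_batch_size: "auto"` pins the parallel degrees and sweeps only the micro batch size; a sketch of the resulting candidates:

```python
cands = gbs_default_candidates(
    {"num_gpus": 8, "nodes": 1, "model_cfg": {"global_batch_size": "auto"}}
)
# dp_degree         [1]          mp_degree        [2]
# pp_degree         [2]          sharding_degree  [2]
# micro_batch_size  [1, 2, 4, ..., 512]
# global_batch_size [4, 8, 16, ..., 2048]   (pp_candidate * dp_candidate * mbs)
```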
+
+
+def gbs_search_all(tuner_cfg):
+    """Permutate the candidates of all hyper params."""
+    candidates = tuner_cfg["candidates"]
+    # Order: dp -> mp -> pp -> mbs -> sharding-> recompute
+    dp_degree_candidates = candidates["dp_degree"]
+    mp_degree_candidates = candidates["mp_degree"]
+    pp_degree_candidates = candidates["pp_degree"]
+    mbs_candidates = candidates["micro_batch_size"]
+    sharding_stage_candidates = candidates["sharding_stage"]
+    sharding_degree_candidates = candidates["sharding_degree"]
+    use_recompute_candidates = candidates["use_recompute"]
+    recompute_granularity_candidates = candidates["recompute_granularity"]
+    # gbs_candidates = candidates["global_batch_size"]
+    all_cfgs = list(
+        itertools.product(
+            dp_degree_candidates,
+            mp_degree_candidates,
+            pp_degree_candidates,
+            mbs_candidates,
+            sharding_degree_candidates,
+            sharding_stage_candidates,
+            use_recompute_candidates,
+            recompute_granularity_candidates,
+            # gbs_candidates,
+        )
+    )
+    mapping = {
+        0: "dp_degree",
+        1: "mp_degree",
+        2: "pp_degree",
+        3: "micro_batch_size",
+        5: "sharding_stage",
+        4: "sharding_degree",
+        6: "use_recompute",
+        7: "recompute_granularity",
+        # 8: "global_batch_size",
+    }
+    new_all_cfgs = []
+    for cfg in all_cfgs:
+        new_cfg = {}
+        for idx, val in enumerate(cfg):
+            new_cfg[mapping[idx]] = val
+        new_cfg["global_batch_size"] = (
+            new_cfg["pp_degree"]
+            * new_cfg["dp_degree"]
+            * new_cfg["micro_batch_size"]
+        )
+        new_all_cfgs.append(new_cfg)
+    return new_all_cfgs
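
Continuing the 8-GPU sketch, every non-batch-size candidate list is a singleton, so `gbs_search_all` yields one task per micro batch size, and each task's `global_batch_size` is recomputed as `pp_degree * dp_degree * micro_batch_size` (note it uses `dp_degree`, which is 1 here, not the sharding degree):

```python
tuner_cfg["candidates"] = gbs_default_candidates(tuner_cfg)
tasks = gbs_search_all(tuner_cfg)
print(len(tasks))   # 10, one per micro_batch_size candidate
print(tasks[0])
# {'dp_degree': 1, 'mp_degree': 2, 'pp_degree': 2, 'micro_batch_size': 1,
#  'sharding_degree': 2, 'sharding_stage': 1, 'use_recompute': False,
#  'recompute_granularity': None, 'global_batch_size': 2}
```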
@@ -260,7 +260,17 @@ class ETCDMaster(Master):
                     delete_success = True
                 except:
                     time.sleep(1)
-            lease = self.client.lease(ttl)
+
+        if self.ctx.is_auto_tuner_mode():
+            lease_success = False
+            while not lease_success:
+                try:
+                    lease = self.client.lease(ttl)
+                    lease_success = True
+                except:
+                    time.sleep(1)
+        else:
+            lease = self.client.lease(ttl)

         # self.client.delete_prefix(self.job_prefix)
...
@@ -298,7 +298,7 @@ def launch():
         import sys
         import time

-        from ..auto_tuner.recorder import History_recorder
+        from ..auto_tuner.recorder import HistoryRecorder
         from ..auto_tuner.tuner import AutoTuner
         from ..auto_tuner.utils import gen_new_args, read_log
         from . import controllers
@@ -340,11 +340,6 @@ def launch():
             client = etcd3.client(host=master_ip, port=port)
             client.delete("best_cfg")

-        # build AutoTuner to get new config
-        auto_tuner = AutoTuner(tuner_cfg)
-        cur_cfg = auto_tuner.search_once()
-        auto_tuner.add_cfg(cur_cfg)
-
         # get max time per task run
         max_time_per_task = tuner_cfg.get("max_time_per_task", 1800)
         ctx.max_time_per_task = max_time_per_task
@@ -358,11 +353,140 @@ def launch():
         is_first_task = True
         # build history recorder
-        recorder = History_recorder()
+        recorder = HistoryRecorder()

         job_id = 0
         ctx.args.max_restart = -1
         raw_ctx = copy.deepcopy(ctx)
+
+        # gbs search
+        if (
+            tuner_cfg.get('model_cfg', {}).get('global_batch_size', 'auto')
+            == "auto"
+        ):
+            # adjust micron batch size until out of memory to get best global batch size
+            gbs_tuner_cfg = copy.deepcopy(tuner_cfg)
+            gbs_tuner_cfg["search_algo"] = "gbs"
+            gbs_tuner = AutoTuner(gbs_tuner_cfg)
+            gbs_cur_cfg = gbs_tuner.search_once()
+            best_gbs = None
+            while gbs_cur_cfg:
+                ctx = copy.deepcopy(raw_ctx)
+                log_dir = "GBSSearch/GBS{}_DP{}_MP{}_PP{}_Sharding_degree_{}_stage_{}_MBS_{}_Recompute_{}_granularity_{}".format(
+                    gbs_cur_cfg["global_batch_size"],
+                    gbs_cur_cfg["dp_degree"],
+                    gbs_cur_cfg["mp_degree"],
+                    gbs_cur_cfg["pp_degree"],
+                    gbs_cur_cfg["sharding_degree"],
+                    gbs_cur_cfg["sharding_stage"],
+                    gbs_cur_cfg["micro_batch_size"],
+                    gbs_cur_cfg["use_recompute"],
+                    gbs_cur_cfg["recompute_granularity"],
+                )
+                ctx.args.log_dir = log_dir
+
+                # every task has own job id
+                job_id += 1
+                task_job_id = "gbs_tuner_" + str(job_id)
+                ctx.args.job_id = task_job_id
+
+                # generate script args of task
+                gbs_new_args = gen_new_args(
+                    raw_args, gbs_cur_cfg, gbs_tuner_cfg
+                )
+                ctx.args.training_script_args = gbs_new_args
+
+                # launch task
+                ctx.logger.info(
+                    "Launch task from auto tuner: job_id {}, log_dir {}, config {}".format(
+                        task_job_id, log_dir, gbs_cur_cfg
+                    )
+                )
+                c = controllers.init(ctx)
+                c.run()
+
+                # process generated result
+                # TODO diffentiate out of memory and no loss(maybe over time)
+                # TODO integragte memory and metric read
+                metric, mem, err = read_log(
+                    path=ctx.args.log_dir,
+                    metric_file="workerlog.0",
+                    target_metric=tuner_cfg["metric_cfg"]["name"],
+                    memory_file=f"{ctx.args.job_id}.gpu.log",
+                )
+
+                if err & (1 << 0):
+                    ctx.logger.warning(
+                        f"Read metric failed for parameters: {log_dir}"
+                    )
+                    # for pruner use
+                    gbs_cur_cfg['time'] = -1
+                    gbs_cur_cfg[tuner_cfg['metric_cfg']['name']] = None
+                    gbs_cur_cfg["max_mem_usage"] = mem
+
+                if err & (1 << 1):
+                    ctx.logger.warning(
+                        f"Out of memory for parameters: {log_dir}"
+                    )
+                    # for pruner use
+                    gbs_cur_cfg['time'] = -1
+                    gbs_cur_cfg[tuner_cfg['metric_cfg']['name']] = None
+                    gbs_cur_cfg["max_mem_usage"] = "OOM"
+
+                # not err & (1 << 1): do not record memory usage when out of memory
+                if err & (1 << 2) and not err & (1 << 1):
+                    ctx.logger.warning(
+                        f"Read memory usage failed for parameters: {log_dir}"
+                    )
+                    gbs_cur_cfg["max_mem_usage"] = None
+
+                if not err:
+                    # for pruner use
+                    gbs_cur_cfg['time'] = metric
+                    gbs_cur_cfg[tuner_cfg['metric_cfg']['name']] = metric
+                    gbs_cur_cfg["max_mem_usage"] = mem
+
+                if err & (1 << 0) or err & (1 << 1):
+                    # no metric or out of memory, end gbs search
+                    break
+
+                # store and update args for next round
+                gbs_cur_cfg["job_id"] = job_id
+                best_gbs = gbs_cur_cfg["global_batch_size"]
+                recorder.add_cfg(**gbs_cur_cfg)
+                c.finalize(exit=False)
+                recorder.store_history("./tuner_gbs_history.csv")
+
+                # new cfgs for next round
+                gbs_new_cfg = gbs_tuner.search_once()
+                gbs_cur_cfg = copy.deepcopy(gbs_new_cfg)
+                gbs_tuner.add_cfg(gbs_cur_cfg)
+
+                # per task launch interval
+                time.sleep(3)
+
+            # prevent no valid global batch size found
+            if best_gbs is None:
+                raise ValueError(
+                    "No valid global batch size found, check memory or valid search time. cur_tuner_cfg{}".format(
+                        gbs_tuner_cfg
+                    )
+                )
+
+            # set best global batch size to tuner cfg
+            tuner_cfg["model_cfg"]["global_batch_size"] = best_gbs
+
+            recorder.store_history("./tuner_gbs_history.csv")
+            recorder.clean_history()
+
+            end_time = time.time()
+            ctx.logger.info(
+                f"AtuoTuner for GBS search ends in {end_time-start_time}s."
+            )
+
+        # build AutoTuner to get new config
+        auto_tuner = AutoTuner(tuner_cfg)
+        cur_cfg = auto_tuner.search_once()
+        auto_tuner.add_cfg(cur_cfg)
+
         while cur_cfg:
             ctx = copy.deepcopy(raw_ctx)
             if is_first_task:
@@ -401,20 +525,42 @@ def launch():
             c.run()

             # process generated result
-            metric, err = read_log(
+            metric, mem, err = read_log(
                 path=ctx.args.log_dir,
-                file="workerlog.0",
+                metric_file="workerlog.0",
                 target_metric=tuner_cfg["metric_cfg"]["name"],
+                memory_file=f"{ctx.args.job_id}.gpu.log",
             )
-            if err:
-                ctx.logger.warning(f"Read log failed for parameters: {log_dir}")
+
+            if err & (1 << 0):
+                ctx.logger.warning(
+                    f"Read metric failed for parameters: {log_dir}"
+                )
                 # for pruner use
                 cur_cfg['time'] = -1
                 cur_cfg[tuner_cfg['metric_cfg']['name']] = None
-            else:
+                cur_cfg["max_mem_usage"] = mem
+
+            if err & (1 << 1):
+                ctx.logger.warning(f"Out of memory for parameters: {log_dir}")
+                # for pruner use
+                cur_cfg['time'] = -1
+                cur_cfg[tuner_cfg['metric_cfg']['name']] = None
+                cur_cfg["max_mem_usage"] = "OOM"
+
+            # not err & (1 << 1): do not record memory usage when out of memory
+            if err & (1 << 2) and not err & (1 << 1):
+                ctx.logger.warning(
+                    f"Read memory usage failed for parameters: {log_dir}"
+                )
+                cur_cfg["max_mem_usage"] = None
+
+            if not err:
                 # for pruner use
                 cur_cfg['time'] = metric
                 cur_cfg[tuner_cfg['metric_cfg']['name']] = metric
+                cur_cfg["max_mem_usage"] = mem
+
             # record history
             cur_cfg['job_id'] = job_id
...
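
Putting the pieces together, each trial recorded by HistoryRecorder now carries the metric column plus the new `max_mem_usage` column (the numeric peak when the memory log was read, the string "OOM" on out-of-memory, or None when the memory log was missing); a hypothetical row of the stored history:

```python
# illustrative only; the exact column set depends on the searched hyperparameters,
# and the 'time' column is dropped by store_history()
row = {
    "job_id": 3,
    "dp_degree": 1, "mp_degree": 2, "pp_degree": 2,
    "micro_batch_size": 4, "sharding_degree": 2, "sharding_stage": 1,
    "use_recompute": False, "recompute_granularity": None,
    "global_batch_size": 8,
    "step/s": 1.73205,        # column named after tuner_cfg["metric_cfg"]["name"]
    "max_mem_usage": 31250,   # or "OOM" / None on the error paths above
}
```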