Unverified · Commit 7f6d222f · authored by caozhou, committed by GitHub

[AutoTuner] Distribute best cfg (#54834)

* distribute best cfg

* adapt to multi-arg transmission

* update metric extracting

* fix bugs in pruning and log reading

* fix time default value

* remove time record

* adjust the order of search dimensions

* fix prune bugs

* fix adding cfg bug

* fix multi nodes bug

* reset status

* remove alarm and set logdir

* deepcopy ctx

* change alarm

* fix restart bug

* add exit

* no alarm needed for the best-config run

* add warmup time
Parent 5de773d1
@@ -26,7 +26,7 @@ def same_cfgs_beside(attr, cur_cfg, history_cfgs):
         for key in cur_cfg:
             if key == attr:
                 continue
-            if key not in history_cfgs or history_cfgs[key] != cur_cfg[key]:
+            if key not in cfg or cfg[key] != cur_cfg[key]:
                 same = False
                 break
         if same:
@@ -56,7 +56,7 @@ def prune_by_mp(tuner_cfg, cur_cfg, history_cfgs=None):
     hidden_size = tuner_cfg["model_cfg"].get("hidden_size", None)
     vocab_size = tuner_cfg["model_cfg"].get("vocab_size", None)
-    if not mp_degree:
+    if mp_degree is None:
         return False
     if hidden_size and hidden_size % mp_degree != 0:
@@ -93,7 +93,7 @@ def prune_by_pp(tuner_cfg, cur_cfg, history_cfgs=None):
     num_layers = tuner_cfg["model_cfg"].get("num_layers", None)
     num_nodes = tuner_cfg.get("num_nodes", 1)
-    if not pp_degree:
+    if pp_degree is None:
         return False
     if num_layers:
@@ -128,12 +128,15 @@ def prune_by_mbs(tuner_cfg, cur_cfg, history_cfgs=None):
         // cur_cfg["dp_degree"]
         // cur_cfg["sharding_degree"]
     )
+    if local_batch_size == 0:
+        return True
     mbs_candidates = tuner_cfg.get("micro_batch_size", None)
     if mbs_candidates == "auto":
         mbs_candidates = tuner_cfg["candidates"]["micro_batch_size"]
-    if not micro_batch_size:
+    if micro_batch_size is None:
        return False
     if local_batch_size:
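The new `local_batch_size == 0` prune catches candidates whose `dp_degree * sharding_degree` exceeds the global batch size, since the integer division above then collapses to zero. A quick illustration with made-up numbers:

```python
global_batch_size = 32              # tuner_cfg["model_cfg"]["global_batch_size"]
dp_degree, sharding_degree = 8, 8   # candidate parallel degrees
local_batch_size = global_batch_size // dp_degree // sharding_degree
print(local_batch_size)             # 0 -> the candidate is pruned (prune returns True)
```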
@@ -222,7 +225,7 @@ def prune_by_recompute(tuner_cfg, cur_cfg, history_cfgs):
     """
     recompute_granularity = cur_cfg.get("recompute_granularity", None)
     use_recompute = cur_cfg.get("use_recompute", None)
-    if not use_recompute:
+    if use_recompute is None:
         return False
     recompute_granularity_candidates = tuner_cfg["candidates"].get(
@@ -253,10 +256,11 @@ def prune_by_recompute(tuner_cfg, cur_cfg, history_cfgs):
     ):
         return True
-    if use_recompute is False:
+    if not use_recompute:
         cfgs = same_cfgs_beside("recompute_granularity", cur_cfg, history_cfgs)
         if cfgs:
             return True
     return False
...
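The `same_cfgs_beside` fix above matters because every prune rule deduplicates against history by holding exactly one attribute free; the old code compared keys against the whole history list instead of each entry. A minimal, self-contained sketch of the corrected helper, reconstructed from the hunk (the toy configs are illustrative, not from a real tuning run):

```python
def same_cfgs_beside(attr, cur_cfg, history_cfgs):
    """Return history configs that differ from cur_cfg only in `attr`."""
    results = []
    for cfg in history_cfgs:
        same = True
        for key in cur_cfg:
            if key == attr:
                continue
            # compare against each history entry, not the whole list
            if key not in cfg or cfg[key] != cur_cfg[key]:
                same = False
                break
        if same:
            results.append(cfg)
    return results

history = [{"mp_degree": 2, "pp_degree": 1}, {"mp_degree": 2, "pp_degree": 2}]
cur = {"mp_degree": 4, "pp_degree": 1}
# only the first history entry matches once mp_degree is ignored
print(same_cfgs_beside("mp_degree", cur, history))
```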
@@ -63,6 +63,7 @@ class History_recorder:
         cols = df.columns.tolist()
         cols.insert(0, cols.pop(cols.index('job_id')))
         df = df.reindex(columns=cols)
+        df = df.drop(columns=['time'])
         # write to csv
         df.to_csv(self.store_path, index=False)
...
@@ -47,6 +47,9 @@ class AutoTuner:
             return None
         new_cfg = self.algo.search_once(self.history_cfgs)
         self.cur_task_id += 1
-        self.history_cfgs.append(new_cfg)
         return new_cfg

+    def add_cfg(self, cfg):
+        """Add cfg into history cfgs"""
+        self.history_cfgs.append(cfg)
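With this change `search_once` only proposes a candidate; recording it in the history becomes the caller's job via `add_cfg`, which is what the launcher further down does. A sketch of the intended calling pattern (`run_trial` is a hypothetical helper, not part of this diff):

```python
# sketch of the split generate/record API, assuming tuner_cfg is already loaded
auto_tuner = AutoTuner(tuner_cfg)
cur_cfg = auto_tuner.search_once()   # propose a candidate; history is untouched
auto_tuner.add_cfg(cur_cfg)          # record it once the trial is actually scheduled

while cur_cfg:
    run_trial(cur_cfg)               # hypothetical: launch and measure one task
    cur_cfg = auto_tuner.search_once()
    if cur_cfg:
        auto_tuner.add_cfg(cur_cfg)
```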
@@ -38,11 +38,11 @@ def dist_degree(mode, num_gpus, num_nodes):
     assert mode in ["dp", "mp", "pp", "sharding"]
     results = []
     if mode == "dp":
-        results = divisor(num_gpus, reverse=True)
+        results = divisor(num_gpus, reverse=False)
     elif mode == "pp":
         if num_nodes > 1:
-            results = list(range(num_nodes))
+            results = list(range(1, num_nodes + 1))
         else:
             results = divisor(num_gpus, reverse=True)
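Two search-order changes land here: dp candidates are now enumerated in ascending order (smallest degree first), and multi-node pp candidates include `num_nodes` itself instead of starting at 0. A small sketch, assuming `divisor(n, reverse)` simply returns the sorted divisors of `n` (its actual implementation is not shown in this diff):

```python
def divisor(num, reverse=False):
    """Assumed behaviour: all divisors of num, ascending unless reverse=True."""
    return sorted((i for i in range(1, num + 1) if num % i == 0), reverse=reverse)

print(divisor(8, reverse=False))   # [1, 2, 4, 8]  -> new dp search order
print(divisor(8, reverse=True))    # [8, 4, 2, 1]  -> still used for single-node pp
print(list(range(1, 2 + 1)))       # [1, 2]        -> pp degrees for 2 nodes (old code gave [0, 1])
```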
@@ -123,9 +123,7 @@ def default_candidates(tuner_cfg):
     elif tuner_cfg.get("micro_batch_size", None):
         candidates["micro_batch_size"] = tuner_cfg.get("micro_batch_size")
     else:
-        candidates["micro_batch_size"] = [
-            tuner_cfg["model_cfg"]["global_batch_size"]
-        ]
+        candidates["micro_batch_size"] = [None]
     return candidates
@@ -133,7 +131,7 @@ def default_candidates(tuner_cfg):
 def search_all(tuner_cfg):
     """Permutate the candidates of all hyper params."""
     candidates = tuner_cfg["candidates"]
-    # Order: dp -> mp -> pp -> mbs -> sharding -> recompute
+    # Order: dp -> sharding -> mbs -> pp -> mp -> recompute
     dp_degree_candidates = candidates["dp_degree"]
     mp_degree_candidates = candidates["mp_degree"]
     pp_degree_candidates = candidates["pp_degree"]
@@ -145,22 +143,22 @@ def search_all(tuner_cfg):
     all_cfgs = list(
         itertools.product(
             dp_degree_candidates,
-            mp_degree_candidates,
-            pp_degree_candidates,
-            mbs_candidates,
             sharding_degree_candidates,
             sharding_stage_candidates,
+            mbs_candidates,
+            pp_degree_candidates,
+            mp_degree_candidates,
             use_recompute_candidates,
             recompute_granularity_candidates,
         )
     )
     mapping = {
         0: "dp_degree",
-        1: "mp_degree",
-        2: "pp_degree",
+        1: "sharding_degree",
+        2: "sharding_stage",
         3: "micro_batch_size",
-        5: "sharding_stage",
-        4: "sharding_degree",
+        4: "pp_degree",
+        5: "mp_degree",
         6: "use_recompute",
         7: "recompute_granularity",
     }
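The candidate grid is still a full Cartesian product; only the dimension order and the index-to-name mapping change, so dp now varies slowest and the recompute settings fastest. A toy illustration of how the product plus mapping yields config dicts (the candidate lists below are made up):

```python
import itertools

order = [
    "dp_degree", "sharding_degree", "sharding_stage", "micro_batch_size",
    "pp_degree", "mp_degree", "use_recompute", "recompute_granularity",
]
candidates = {
    "dp_degree": [1, 2], "sharding_degree": [1], "sharding_stage": [1],
    "micro_batch_size": [1, 2], "pp_degree": [1], "mp_degree": [1, 2],
    "use_recompute": [True], "recompute_granularity": ["full"],
}
all_cfgs = [
    dict(zip(order, combo))
    for combo in itertools.product(*(candidates[k] for k in order))
]
print(len(all_cfgs))   # 8 combinations for this toy grid
print(all_cfgs[0])     # {'dp_degree': 1, 'sharding_degree': 1, ...}
```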
@@ -179,42 +177,90 @@ def gen_new_args(raw_args, cfg, tuner_cfg):
     cmd = copy.deepcopy(tuner_cfg["run_cmd"])
     res_args = copy.deepcopy(raw_args)
     if "dp_degree" in cmd and "dp_degree" in cfg:
-        cmd["dp_degree"][1] = cmd["dp_degree"][1] + "=" + str(cfg["dp_degree"])
-        res_args.extend(cmd["dp_degree"])
+        if "--" in cmd["dp_degree"][0]:
+            cmd["dp_degree"][1] = cmd["dp_degree"][1] + str(cfg["dp_degree"])
+            res_args.extend(cmd["dp_degree"])
+        else:
+            cmd["dp_degree"][1] = (
+                cmd["dp_degree"][1] + "=" + str(cfg["dp_degree"])
+            )
+            res_args.extend(cmd["dp_degree"])
     if "mp_degree" in cmd and "mp_degree" in cfg:
-        cmd["mp_degree"][1] = cmd["mp_degree"][1] + "=" + str(cfg["mp_degree"])
-        res_args.extend(cmd["mp_degree"])
+        if "--" in cmd["mp_degree"][0]:
+            cmd["mp_degree"][1] = cmd["mp_degree"][1] + str(cfg["mp_degree"])
+            res_args.extend(cmd["mp_degree"])
+        else:
+            cmd["mp_degree"][1] = (
+                cmd["mp_degree"][1] + "=" + str(cfg["mp_degree"])
+            )
+            res_args.extend(cmd["mp_degree"])
     if "pp_degree" in cmd and "pp_degree" in cfg:
-        cmd["pp_degree"][1] = cmd["pp_degree"][1] + "=" + str(cfg["pp_degree"])
-        res_args.extend(cmd["pp_degree"])
+        if "--" in cmd["pp_degree"][0]:
+            cmd["pp_degree"][1] = cmd["pp_degree"][1] + str(cfg["pp_degree"])
+            res_args.extend(cmd["pp_degree"])
+        else:
+            cmd["pp_degree"][1] = (
+                cmd["pp_degree"][1] + "=" + str(cfg["pp_degree"])
+            )
+            res_args.extend(cmd["pp_degree"])
     if "micro_batch_size" in cmd and "micro_batch_size" in cfg:
-        cmd["micro_batch_size"][1] = (
-            cmd["micro_batch_size"][1] + "=" + str(cfg["micro_batch_size"])
-        )
-        res_args.extend(cmd["micro_batch_size"])
+        if "--" in cmd["micro_batch_size"][0]:
+            cmd["micro_batch_size"][1] = cmd["micro_batch_size"][1] + str(
+                cfg["micro_batch_size"]
+            )
+            res_args.extend(cmd["micro_batch_size"])
+        else:
+            cmd["micro_batch_size"][1] = (
+                cmd["micro_batch_size"][1] + "=" + str(cfg["micro_batch_size"])
+            )
+            res_args.extend(cmd["micro_batch_size"])
     if "sharding_degree" in cmd and "sharding_degree" in cfg:
-        cmd["sharding_degree"][1] = (
-            cmd["sharding_degree"][1] + "=" + str(cfg["sharding_degree"])
-        )
-        res_args.extend(cmd["sharding_degree"])
+        if "--" in cmd["sharding_degree"][0]:
+            cmd["sharding_degree"][1] = cmd["sharding_degree"][1] + str(
+                cfg["sharding_degree"]
+            )
+            res_args.extend(cmd["sharding_degree"])
+        else:
+            cmd["sharding_degree"][1] = (
+                cmd["sharding_degree"][1] + "=" + str(cfg["sharding_degree"])
+            )
+            res_args.extend(cmd["sharding_degree"])
     if "sharding_stage" in cmd and "sharding_stage" in cfg:
-        cmd["sharding_stage"][1] = (
-            cmd["sharding_stage"][1] + "=" + str(cfg["sharding_stage"])
-        )
-        res_args.extend(cmd["sharding_stage"])
+        if "--" in cmd["sharding_stage"][0]:
+            cmd["sharding_stage"][1] = cmd["sharding_stage"][1] + str(
+                cfg["sharding_stage"]
+            )
+            res_args.extend(cmd["sharding_stage"])
+        else:
+            cmd["sharding_stage"][1] = (
+                cmd["sharding_stage"][1] + "=" + str(cfg["sharding_stage"])
+            )
+            res_args.extend(cmd["sharding_stage"])
     if "use_recompute" in cmd and "use_recompute" in cfg:
-        cmd["use_recompute"][1] = (
-            cmd["use_recompute"][1] + "=" + str(cfg["use_recompute"])
-        )
-        res_args.extend(cmd["use_recompute"])
+        if "--" in cmd["use_recompute"][0]:
+            cmd["use_recompute"][1] = cmd["use_recompute"][1] + str(
+                cfg["use_recompute"]
+            )
+            res_args.extend(cmd["use_recompute"])
+        else:
+            cmd["use_recompute"][1] = (
+                cmd["use_recompute"][1] + "=" + str(cfg["use_recompute"])
+            )
+            res_args.extend(cmd["use_recompute"])
     if "recompute_granularity" in cmd and "recompute_granularity" in cfg:
-        cmd["recompute_granularity"][1] = (
-            cmd["recompute_granularity"][1]
-            + "="
+        if "--" in cmd["recompute_granularity"][0]:
+            cmd["recompute_granularity"][1] = cmd["recompute_granularity"][
+                1
+            ] + str(cfg["recompute_granularity"])
+            res_args.extend(cmd["recompute_granularity"])
+        else:
+            cmd["recompute_granularity"][1] = (
+                cmd["recompute_granularity"][1]
+                + "="
@@ -228,11 +274,49 @@ def gen_new_args(raw_args, cfg, tuner_cfg):
             // cfg["sharding_degree"]
             // cfg["dp_degree"]
         )
-        cmd["local_batch_size"][1] = (
-            cmd["local_batch_size"][1] + "=" + str(local_batch_size)
-        )
-        res_args.extend(cmd["local_batch_size"])
+        if "--" in cmd["local_batch_size"][0]:
+            cmd["local_batch_size"][1] = cmd["local_batch_size"][1] + str(
+                local_batch_size
+            )
+            res_args.extend(cmd["local_batch_size"])
+        else:
+            cmd["local_batch_size"][1] = (
+                cmd["local_batch_size"][1] + "=" + str(local_batch_size)
+            )
+            res_args.extend(cmd["local_batch_size"])
+    if "gradient_accumulation_steps" in cmd:
+        if "--" in cmd["gradient_accumulation_steps"][0]:
+            try:
+                gradient_accumulation_steps = (
+                    tuner_cfg["model_cfg"]["global_batch_size"]
+                    // cfg["sharding_degree"]
+                    // cfg["dp_degree"]
+                    // cfg["micro_batch_size"]
+                )
+                cmd["gradient_accumulation_steps"][1] = cmd[
+                    "gradient_accumulation_steps"
+                ][1] + str(gradient_accumulation_steps)
+                res_args.extend(cmd["gradient_accumulation_steps"])
+            except:
+                pass
+        else:
+            try:
+                gradient_accumulation_steps = (
+                    tuner_cfg["model_cfg"]["global_batch_size"]
+                    // cfg["sharding_degree"]
+                    // cfg["dp_degree"]
+                    // cfg["micro_batch_size"]
+                )
+                cmd["gradient_accumulation_steps"][1] = (
+                    cmd["gradient_accumulation_steps"][1]
+                    + "="
+                    + str(gradient_accumulation_steps)
+                )
+                res_args.extend(cmd["gradient_accumulation_steps"])
+            except:
+                pass
     return res_args
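`gen_new_args` now supports two ways of passing a hyperparameter to the training script, selected by whether the first element of the `run_cmd` entry contains `--`: a plain flag whose value is appended directly, or an override-style `key=value` argument. The new `gradient_accumulation_steps` entry is derived as `global_batch_size // sharding_degree // dp_degree // micro_batch_size`. A sketch of the two styles (the `run_cmd` entries below are illustrative placeholders, not taken from a real config):

```python
# hypothetical run_cmd entries: [flag, value-prefix]
cmd = {
    "dp_degree": ["--dp_degree", ""],              # "--" style: value appended as-is
    "mp_degree": ["-o", "Distributed.mp_degree"],  # override style: "=" inserted
}
cfg = {"dp_degree": 2, "mp_degree": 4}

res_args = []
for key, (flag, arg) in cmd.items():
    if "--" in flag:
        res_args.extend([flag, arg + str(cfg[key])])
    else:
        res_args.extend([flag, arg + "=" + str(cfg[key])])

print(res_args)   # ['--dp_degree', '2', '-o', 'Distributed.mp_degree=4']
```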
@@ -245,7 +329,9 @@ def read_log(
         return (0.0, True)
     with open(target_file, "r") as f:
         # read file
-        re_metric_pattern = r'speed: (\d+(\.\d*)?) *' + target_metric
+        re_metric_pattern = (
+            target_metric + r":* *(\d+(\.\d*)?)|(\d+(\.\d*)?) *" + target_metric
+        )
         metric_list = []
         lines = f.readlines()
...
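The hard-coded `speed: <number>` pattern is replaced by one that accepts either `<metric>: <number>` or `<number> <metric>`, so logs with different formatting can be parsed. A small check of the new pattern (the metric name and log lines below are invented):

```python
import re

target_metric = "ips"   # hypothetical value of tuner_cfg['metric_cfg']['name']
re_metric_pattern = (
    target_metric + r":* *(\d+(\.\d*)?)|(\d+(\.\d*)?) *" + target_metric
)

for line in ["step 10, ips: 4321.5 tokens/s", "speed 4321.5 ips/card"]:
    match = re.search(re_metric_pattern, line)
    if match:
        # group 1 is filled by the "metric: value" form, group 3 by "value metric"
        print(float(match.group(1) or match.group(3)))
```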
@@ -38,6 +38,8 @@ class Context:
         if enable_plugin:
             self._enable_plugin()
+        self.max_time_per_task = -1
+        self.run_best = False

     def print(self):
         self.logger.info("----------- Configuration ----------------------")
...
@@ -282,6 +282,13 @@ class CollectiveElasticController(CollectiveController):
                 self.job.replicas = replicas
             else:
                 self.ctx.logger.warning(f"peer not ready {self.job}")
+                if self.ctx.is_auto_tuner_mode():
+                    self.ctx.logger.info(
+                        "Failed to start peer, auto tuner exit."
+                    )
+                    import sys
+                    sys.exit(-1)
                 break

         self.ctx.logger.debug(f"Run {self.job}")
...
@@ -36,6 +36,13 @@ class ControllerBase:
         signal.signal(signal.SIGTERM, self.signal_handler)
         signal.signal(signal.SIGABRT, self.signal_handler)
         signal.signal(signal.SIGINT, self.signal_handler)
+        if ctx.is_auto_tuner_mode():
+            if not ctx.run_best:
+                # set per task timeout
+                signal.signal(signal.SIGALRM, self.not_exit_signal_handler)
+                signal.alarm(ctx.max_time_per_task)
+            else:
+                signal.alarm(0)

         self.ctx = ctx
         self.master = Master.factory(self.ctx)
...
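The per-task timeout moves from the launcher into `ControllerBase`: during the search phase a `SIGALRM` is armed with `ctx.max_time_per_task` (or the warmup time for the first task) and handled by a non-exiting handler, while the final best-config run disarms it with `signal.alarm(0)`. A standalone sketch of that mechanism on a Unix system (the handler below is illustrative, not the controller's actual method):

```python
import signal
import time

def not_exit_handler(signum, frame):
    # illustrative: stop the current trial but keep the tuner process alive
    print("per-task time budget reached, cancelling the running trial")

signal.signal(signal.SIGALRM, not_exit_handler)
signal.alarm(3)        # e.g. max_time_per_task seconds for one search trial
time.sleep(5)          # stand-in for a training trial that overruns its budget
signal.alarm(0)        # the best-cfg run disables the per-task alarm
```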
@@ -295,7 +295,6 @@ def launch():
     elif ctx.is_auto_tuner_mode():
         import copy
         import json
-        import signal
         import sys
         import time
@@ -304,6 +303,7 @@ def launch():
         from ..auto_tuner.utils import gen_new_args, read_log
         from . import controllers

+        start_time = time.time()
         # read user defined tuner config json
         try:
             with open(ctx.args.auto_tuner_json, "r") as f:
@@ -326,24 +326,48 @@ def launch():
         gpus_per_node = len(ctx.args.devices.split(","))
         nnodes = ctx.args.nnodes
         if isinstance(nnodes, str):
-            tuner_cfg["nodes"] = int(nnodes.split(":")[0])
+            nnodes = int(nnodes.split(":")[0])
         else:
-            tuner_cfg["nodes"] = int(nnodes)
+            nnodes = int(nnodes)
+        tuner_cfg["nodes"] = nnodes
         tuner_cfg["num_gpus"] = gpus_per_node * tuner_cfg["nodes"]
+        if nnodes > 1:
+            import etcd3
+            assert "etcd://" in ctx.args.master
+            master_ip, port = ctx.args.master.strip("etcd://").split(':')
+            client = etcd3.client(host=master_ip, port=port)
+            client.delete("best_cfg")
         # build AutoTuner to get new config
         auto_tuner = AutoTuner(tuner_cfg)
         cur_cfg = auto_tuner.search_once()
+        auto_tuner.add_cfg(cur_cfg)
         # get max time per task run
         max_time_per_task = tuner_cfg.get("max_time_per_task", 1800)
+        ctx.max_time_per_task = max_time_per_task
+        # warmup
+        warmup_time = (
+            max_time_per_task
+            if "warmup_time" not in tuner_cfg
+            else tuner_cfg.get("warmup_time")
+        )
+        is_first_task = True
         # build history recorder
         recorder = History_recorder()
         job_id = 0
+        ctx.args.max_restart = -1
+        raw_ctx = copy.deepcopy(ctx)
         while cur_cfg:
-            ctx.status._current_status = None
+            ctx = copy.deepcopy(raw_ctx)
+            if is_first_task:
+                ctx.max_time_per_task = warmup_time
+            is_first_task = False
             # auto tuner supports dp, mp, pp, micro batch size, sharding, recompute by default and every task has own log dir
             log_dir = "DP{}_MP{}_PP{}_Sharding_degree_{}_stage_{}_MBS_{}_Recompute_{}_granularity_{}".format(
                 cur_cfg["dp_degree"],
@@ -373,14 +397,10 @@ def launch():
                     task_job_id, log_dir, cur_cfg
                 )
             )
             c = controllers.init(ctx)
-            # set per task timeout
-            signal.signal(signal.SIGALRM, c.not_exit_signal_handler)
-            signal.alarm(max_time_per_task)
             c.run()

-            # Process generated result
+            # process generated result
             metric, err = read_log(
                 path=ctx.args.log_dir,
                 file="workerlog.0",
@@ -388,11 +408,14 @@ def launch():
             )
             if err:
                 ctx.logger.warning(f"Read log failed for parameters: {log_dir}")
-                cur_cfg['time'] = None  # for pruner use.
+                # for pruner use
+                cur_cfg['time'] = -1
                 cur_cfg[tuner_cfg['metric_cfg']['name']] = None
             else:
-                cur_cfg['time'] = metric  # for pruner use.
+                # for pruner use
+                cur_cfg['time'] = metric
                 cur_cfg[tuner_cfg['metric_cfg']['name']] = metric

             # record history
             cur_cfg['job_id'] = job_id
             recorder.add_cfg(**cur_cfg)
@@ -409,18 +432,78 @@ def launch():
                 ctx.logger.info(
                     "Get best config failed. Currently there are no appropriate configs."
                 )
+            c.finalize(exit=False)
+            # generate a new config
             new_cfg = auto_tuner.search_once()
-            if new_cfg:
-                c.finalize(exit=False)
-            else:
-                c.finalize(exit=True)
+            cur_cfg = copy.deepcopy(new_cfg)
+            auto_tuner.add_cfg(cur_cfg)
             # per task launch interval
-            time.sleep(5)
-            cur_cfg = copy.deepcopy(new_cfg)
+            time.sleep(3)
         recorder.store_history()
+        # get best config to run
+        best_cfg = None
+        ctx = copy.deepcopy(raw_ctx)
+        if nnodes > 1:
+            import socket
+            ip = None
+            try:
+                hostname = socket.gethostname()
+                ip = socket.gethostbyname(socket.getfqdn(hostname))
+            except:
+                ip = '127.0.0.1'
+            if ip == master_ip:
+                best_cfg, err = recorder.get_best(
+                    metric=tuner_cfg['metric_cfg']['name'],
+                    direction=tuner_cfg['metric_cfg']['OptimizationDirection'],
+                )
+                if err:
+                    raise ValueError(
+                        "Get best config failed. Currently there are no appropriate configs."
+                    )
+                data = json.dumps(best_cfg)
+                while not client.put("best_cfg", data):
+                    time.sleep(1)
+                    continue
+            else:
+                for i in range(10):
+                    try:
+                        data = client.get("best_cfg")[0].decode()
+                        best_cfg = json.loads(data)
+                    except Exception as e:
+                        ctx.logger.warning(e)
+                        time.sleep(2)
+                    if best_cfg:
+                        break
+                assert best_cfg
+        else:
+            best_cfg, err = recorder.get_best(
+                metric=tuner_cfg['metric_cfg']['name'],
+                direction=tuner_cfg['metric_cfg']['OptimizationDirection'],
+            )
+            if err:
+                raise ValueError(
+                    "Get best config failed. Currently there are no appropriate configs."
+                )
+            assert best_cfg
+        end_time = time.time()
+        ctx.logger.info(f"AutoTuner ends in {end_time-start_time}s.")
+        # launch best cfg
+        new_args = gen_new_args(raw_args, best_cfg, tuner_cfg)
+        ctx.run_best = True
+        ctx.args.training_script_args = new_args
+        ctx.args.job_id = "best_cfg"
+        ctx.logger.info(f"Launch best cfg from auto tuner: {best_cfg}")
+        ctx.args.log_dir = "best_cfg"
+        # run best cfg
+        c = controllers.init(ctx)
+        c.run()
+        c.finalize(exit=True)
     else:
         from . import controllers
...
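The core of this change: on multi-node jobs only the node whose IP matches the etcd master writes the winning config, every other node polls the same key, and then all nodes relaunch with the best arguments. A minimal sketch of that handshake using the `etcd3` client (host, port, and config values are placeholders):

```python
import json
import time
import etcd3

# placeholders; in the launcher these come from ctx.args.master
client = etcd3.client(host="10.0.0.1", port=2379)

def publish_best(best_cfg):
    # master node: store the winning config under a well-known key
    client.put("best_cfg", json.dumps(best_cfg))

def fetch_best(retries=10):
    # worker nodes: poll until the master has published the config
    for _ in range(retries):
        value, _meta = client.get("best_cfg")
        if value:
            return json.loads(value.decode())
        time.sleep(2)
    raise RuntimeError("best_cfg was never published")
```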