Unverified commit e12d2867 authored by caozhou, committed by GitHub

[AutoTuner] Add auto tuner to obtain optimal configuration (#54460)

* add auto tuner

* fix prune

* fix sharding prune and mbs candidates

* fix cfg

* fix launch

* fix launch

* add unittest

* fix code style
Parent a90d9088
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = []
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
_PRUNE_FUNC = []
def same_cfgs_beside(attr, cur_cfg, history_cfgs):
"""
    Compare the current configuration with the history configurations, and
    return the history configurations that match the current one on every key
    except the given attr.
"""
results = []
same = True
for cfg in history_cfgs:
for key in cur_cfg:
if key == attr:
continue
            if key not in cfg or cfg[key] != cur_cfg[key]:
same = False
break
if same:
results.append(cfg)
else:
same = True
return results
def register_prune(func):
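    """Register a prune function into the global registry _PRUNE_FUNC."""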
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
_PRUNE_FUNC.append(wrapper)
return wrapper
@register_prune
def prune_by_mp(tuner_cfg, cur_cfg, history_cfgs=None):
"""
    Prune by mp (tensor model parallelism), the rules are:
    1. The hidden size and vocab size should be evenly divisible by the MP degree.
    2. MP degree should be in the user-defined candidates if given.
    3. MP degree should not exceed 8.
"""
mp_degree = cur_cfg.get("mp_degree", None)
hidden_size = tuner_cfg["model_cfg"].get("hidden_size", None)
vocab_size = tuner_cfg["model_cfg"].get("vocab_size", None)
if not mp_degree:
return False
if hidden_size and hidden_size % mp_degree != 0:
return True
if vocab_size and vocab_size % mp_degree != 0:
return True
mp_degree_candidates = tuner_cfg.get("mp_degree", None)
if mp_degree_candidates == "auto":
mp_degree_candidates = tuner_cfg["candidates"]["mp_degree"]
if mp_degree_candidates:
if mp_degree not in mp_degree_candidates:
return True
# prune default candidates
if mp_degree > 8:
return True
return False
@register_prune
def prune_by_pp(tuner_cfg, cur_cfg, history_cfgs=None):
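    """
    Prune by pp (pipeline parallelism), the rules are:
    1. The number of layers should be evenly divisible by the PP degree.
    2. PP degree should be in the user-defined candidates if given.
    3. If no candidates are given and multiple nodes are used, PP degree should not exceed the number of nodes.
    """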
pp_degree = cur_cfg.get("pp_degree", None)
num_layers = tuner_cfg["model_cfg"].get("num_layers", None)
    num_nodes = tuner_cfg.get("nodes", 1)
if not pp_degree:
return False
if num_layers:
if num_layers % pp_degree != 0:
return True
pp_degree_candidates = tuner_cfg.get("pp_degree", None)
if pp_degree_candidates == "auto":
pp_degree_candidates = tuner_cfg["candidates"]["pp_degree"]
if pp_degree_candidates:
if pp_degree not in pp_degree_candidates:
return True
else:
if num_nodes != 1 and pp_degree > num_nodes:
return True
return False
@register_prune
def prune_by_mbs(tuner_cfg, cur_cfg, history_cfgs=None):
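    """
    Prune by micro batch size, the rules are:
    1. The local batch size (global batch size / dp degree / sharding degree) should be evenly divisible by the micro batch size.
    2. Micro batch size should be in the user-defined candidates if given.
    3. Prune if a history config that differs only in micro batch size used a larger micro batch size and ran successfully.
    """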
micro_batch_size = cur_cfg.get("micro_batch_size", None)
global_batch_size = tuner_cfg["model_cfg"].get("global_batch_size", None)
if global_batch_size:
local_batch_size = (
global_batch_size
// cur_cfg["dp_degree"]
// cur_cfg["sharding_degree"]
)
mbs_candidates = tuner_cfg.get("micro_batch_size", None)
if mbs_candidates == "auto":
mbs_candidates = tuner_cfg["candidates"]["micro_batch_size"]
if not micro_batch_size:
return False
if local_batch_size:
if local_batch_size % micro_batch_size != 0:
return True
if mbs_candidates:
if micro_batch_size not in mbs_candidates:
return True
cfgs = same_cfgs_beside("micro_batch_size", cur_cfg, history_cfgs)
if cfgs:
for cfg in cfgs:
if (
cfg["micro_batch_size"] > micro_batch_size
and cfg.get("time", -1) > 0
):
return True
return False
@register_prune
def prune_by_sharding(tuner_cfg, cur_cfg, history_cfgs):
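    """
    Prune by sharding, the rules are:
    1. Sharding stage and sharding degree should be in the user-defined candidates if given.
    2. If PP degree is larger than 1, only sharding stage 1 is kept.
    3. Prune if a history config that differs only in sharding stage used a lower stage and ran successfully.
    """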
sharding_stage = cur_cfg.get("sharding_stage", None)
sharding_degree = cur_cfg.get("sharding_degree", None)
pp_degree = cur_cfg.get("pp_degree", None)
if not sharding_stage:
return False
if not sharding_degree:
return False
sharding_stage_candidates = tuner_cfg.get("sharding_stage", None)
if sharding_stage_candidates == "auto":
sharding_stage_candidates = tuner_cfg["candidates"]["sharding_stage"]
sharding_degree_candidates = tuner_cfg.get("sharding_degree", None)
if sharding_degree_candidates == "auto":
sharding_degree_candidates = tuner_cfg["candidates"]["sharding_degree"]
if sharding_stage_candidates:
if sharding_stage not in sharding_stage_candidates:
return True
if sharding_degree_candidates:
if sharding_degree not in sharding_degree_candidates:
return True
if pp_degree and pp_degree != 1 and sharding_stage != 1:
return True
cfgs = same_cfgs_beside("sharding_stage", cur_cfg, history_cfgs)
if cfgs:
for cfg in cfgs:
if (
cfg["sharding_stage"] < sharding_stage
and cfg.get("time", -1) > 0
):
return True
return False
@register_prune
def prune_by_recompute(tuner_cfg, cur_cfg, history_cfgs):
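    """
    Prune by recompute, the rules are:
    1. use_recompute and recompute_granularity should be in the user-defined candidates if given.
    2. Prune if a history config that differs only in use_recompute ran successfully without recompute while the current config enables it.
    """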
recompute_granularity = cur_cfg.get("recompute_granularity", None)
use_recompute = cur_cfg.get("use_recompute", None)
if not use_recompute:
return False
recompute_granularity_candidates = tuner_cfg["candidates"].get(
"recompute_granularity", None
)
use_recompute_candidates = tuner_cfg["candidates"].get(
"use_recompute", None
)
if use_recompute_candidates:
if use_recompute not in use_recompute_candidates:
return True
if recompute_granularity_candidates and recompute_granularity:
if recompute_granularity not in recompute_granularity_candidates:
return True
if not use_recompute and recompute_granularity:
return True
cfgs = same_cfgs_beside("use_recompute", cur_cfg, history_cfgs)
if cfgs:
for cfg in cfgs:
if (
not cfg["use_recompute"]
and use_recompute
and cfg.get("time", -1) > 0
):
return True
return False
@register_prune
def prune_by_num_gpus(tuner_cfg, cur_cfg, history_cfgs):
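    """
    Prune configurations whose dp_degree * mp_degree * pp_degree * sharding_degree
    does not equal the total number of GPUs.
    """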
num_gpus = tuner_cfg.get("num_gpus")
dp_degree = cur_cfg.get("dp_degree", 1)
mp_degree = cur_cfg.get("mp_degree", 1)
pp_degree = cur_cfg.get("pp_degree", 1)
sharding_degree = cur_cfg.get("sharding_degree", 1)
if dp_degree * mp_degree * pp_degree * sharding_degree != num_gpus:
return True
return False
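As a usage sketch (not part of this change), a user-defined rule can be attached through the same register_prune decorator; SearchAlgo.prune then evaluates it together with the built-in rules above. The rule name prune_by_vpp and the vpp_degree key below are hypothetical and only for illustration.

@register_prune
def prune_by_vpp(tuner_cfg, cur_cfg, history_cfgs=None):
    """Hypothetical rule: prune configs whose vpp_degree is outside the user candidates."""
    vpp_degree = cur_cfg.get("vpp_degree", None)  # hypothetical key, for illustration only
    if not vpp_degree:
        return False
    vpp_candidates = tuner_cfg.get("candidates", {}).get("vpp_degree", None)
    if vpp_candidates and vpp_degree not in vpp_candidates:
        return True
    return False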
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from .prune import _PRUNE_FUNC
from .utils import search_all
class SearchAlgo(ABC):
def __init__(self, tuner_cfg):
self.tuner_cfg = tuner_cfg
@abstractmethod
def search_once(self, history_cfgs):
pass
def prune(self, tuner_cfg, cur_cfg, history_cfgs):
for func in _PRUNE_FUNC:
result = func(tuner_cfg, cur_cfg, history_cfgs)
if result:
return True
return False
class GridSearch(SearchAlgo):
def __init__(self, tuner_cfg):
super().__init__(tuner_cfg)
self.idx = 0
self.all_tasks = search_all(tuner_cfg)
def search_once(self, history_cfgs):
new_cfg = None
stop = False
while not stop:
if self.idx < len(self.all_tasks):
new_cfg = self.all_tasks[self.idx]
self.idx += 1
stop = not self.prune(self.tuner_cfg, new_cfg, history_cfgs)
else:
return None
return new_cfg
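A minimal sketch of driving GridSearch directly; the candidate lists and model sizes below are made up for illustration, and in normal use AutoTuner fills in "candidates" via default_candidates instead.

tuner_cfg = {
    "num_gpus": 8,
    "nodes": 1,
    "model_cfg": {
        "hidden_size": 2048,
        "vocab_size": 50304,
        "num_layers": 24,
        "global_batch_size": 64,
    },
    "candidates": {
        "dp_degree": [1, 2],
        "mp_degree": [2, 4],
        "pp_degree": [1, 2],
        "micro_batch_size": [1, 2],
        "sharding_degree": [1, 2],
        "sharding_stage": [1],
        "use_recompute": [False],
        "recompute_granularity": [None],
    },
}
algo = GridSearch(tuner_cfg)
history = []
cfg = algo.search_once(history)  # first combination that survives all prune rules
while cfg is not None:
    history.append(cfg)
    cfg = algo.search_once(history)
print(len(history))  # number of valid configurations in the grid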
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .utils import default_candidates
class AutoTuner:
"""
    The AutoTuner automatically generates task configurations based on
    user-defined settings, and each generated task is launched for execution.

    Args:
        tuner_cfg (dict): The user-defined configuration of the auto tuner.
"""
def __init__(self, tuner_cfg):
self.cur_task_id = 1
self.task_limit = tuner_cfg.get("task_limit", 100)
tuner_cfg["candidates"] = default_candidates(tuner_cfg)
search_algo = tuner_cfg.get("search_algo", "grid")
if search_algo == "grid":
from .search import GridSearch
self.algo = GridSearch(tuner_cfg)
else:
raise NotImplementedError()
self.history_cfgs = []
def search_once(self):
"""Return a new task config."""
if self.cur_task_id > self.task_limit:
return None
new_cfg = self.algo.search_once(self.history_cfgs)
self.cur_task_id += 1
self.history_cfgs.append(new_cfg)
return new_cfg
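A minimal sketch of the tuner in isolation; the values below are illustrative, and in the launcher this dict is loaded from the --auto_tuner_json file instead.

tuner_cfg = {
    "num_gpus": 8,
    "nodes": 1,
    "dp_degree": "auto",
    "mp_degree": "auto",
    "pp_degree": "auto",
    "sharding_degree": "auto",
    "sharding_stage": "auto",
    "micro_batch_size": [1, 2],
    "use_recompute": "auto",
    "recompute_granularity": "auto",
    "task_limit": 10,
    "model_cfg": {
        "hidden_size": 2048,
        "vocab_size": 50304,
        "num_layers": 24,
        "global_batch_size": 64,
    },
}
tuner = AutoTuner(tuner_cfg)
cur_cfg = tuner.search_once()
while cur_cfg is not None:
    # launch a training task with cur_cfg here; optionally set
    # cur_cfg["time"] to the measured runtime before asking again
    cur_cfg = tuner.search_once()

Recording cur_cfg["time"] matters because the history-based prune rules (prune_by_mbs, prune_by_sharding, prune_by_recompute) only act on history configurations whose recorded time is positive.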
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import itertools
def divisor(num, reverse=False):
"""Return the divisor of the given number."""
results = set()
i = 1
mid = num // 2 + 1
    while i <= mid:
if num % i == 0:
results.add(i)
results.add(num // i)
i += 1
results = list(results)
return sorted(results, reverse=reverse)
def dist_degree(mode, num_gpus, num_nodes):
"""Return the degree of different parallel modes by gpus and nodes num."""
assert mode in ["dp", "mp", "pp", "sharding"]
results = []
if mode == "dp":
results = divisor(num_gpus, reverse=True)
elif mode == "pp":
if num_nodes > 1:
            results = list(range(1, num_nodes + 1))
else:
results = divisor(num_gpus, reverse=True)
elif mode == "mp":
gpus_per_node = num_gpus // num_nodes
results = divisor(gpus_per_node, reverse=True)
elif mode == "sharding":
results = divisor(num_gpus, reverse=True)
return results
def default_candidates(tuner_cfg):
"""Return the default candidates of every hyper param which user defined auto"""
candidates = {}
num_gpus = tuner_cfg["num_gpus"]
num_nodes = tuner_cfg["nodes"]
assert num_gpus > 0
if tuner_cfg.get("dp_degree", None) == "auto":
candidates["dp_degree"] = dist_degree("dp", num_gpus, num_nodes)
elif tuner_cfg.get("dp_degree", None):
candidates["dp_degree"] = tuner_cfg.get("dp_degree")
else:
candidates["dp_degree"] = [1]
if tuner_cfg.get("mp_degree", None) == "auto":
candidates["mp_degree"] = dist_degree("mp", num_gpus, num_nodes)
elif tuner_cfg.get("mp_degree", None):
candidates["mp_degree"] = tuner_cfg.get("mp_degree")
else:
candidates["mp_degree"] = [1]
if tuner_cfg.get("pp_degree", None) == "auto":
candidates["pp_degree"] = dist_degree("pp", num_gpus, num_nodes)
elif tuner_cfg.get("pp_degree", None):
candidates["pp_degree"] = tuner_cfg.get("pp_degree")
else:
candidates["pp_degree"] = [1]
if tuner_cfg.get("sharding_degree", None) == "auto":
candidates["sharding_degree"] = dist_degree(
"sharding", num_gpus, num_nodes
)
elif tuner_cfg.get("sharding_degree", None):
candidates["sharding_degree"] = tuner_cfg.get("sharding_degree")
else:
candidates["sharding_degree"] = [1]
if tuner_cfg.get("sharding_stage", None) == "auto":
candidates["sharding_stage"] = [1, 2, 3]
elif tuner_cfg.get("sharding_stage", None):
candidates["sharding_stage"] = tuner_cfg.get("sharding_stage")
else:
candidates["sharding_stage"] = [None]
if tuner_cfg.get("use_recompute", None) == "auto":
candidates["use_recompute"] = [False, True]
elif tuner_cfg.get("use_recompute", None):
candidates["use_recompute"] = tuner_cfg.get("use_recompute")
else:
candidates["use_recompute"] = [None]
if tuner_cfg.get("recompute_granularity", None) == "auto":
candidates["recompute_granularity"] = ["full_attn", "full"]
elif tuner_cfg.get("recompute_granularity", None):
candidates["recompute_granularity"] = tuner_cfg.get(
"recompute_granularity"
)
else:
candidates["recompute_granularity"] = [None]
if tuner_cfg.get("micro_batch_size", None) == "auto":
candidates["micro_batch_size"] = list(
range(tuner_cfg["model_cfg"]["global_batch_size"], 0, -1)
)
elif tuner_cfg.get("micro_batch_size", None):
candidates["micro_batch_size"] = tuner_cfg.get("micro_batch_size")
else:
candidates["micro_batch_size"] = [
tuner_cfg["model_cfg"]["global_batch_size"]
]
return candidates
def search_all(tuner_cfg):
"""Permutate the candidates of all hyper params."""
candidates = tuner_cfg["candidates"]
    # Order: dp -> mp -> pp -> mbs -> sharding -> recompute
dp_degree_candidates = candidates["dp_degree"]
mp_degree_candidates = candidates["mp_degree"]
pp_degree_candidates = candidates["pp_degree"]
mbs_candidates = candidates["micro_batch_size"]
sharding_stage_candidates = candidates["sharding_stage"]
sharding_degree_candidates = candidates["sharding_degree"]
use_recompute_candidates = candidates["use_recompute"]
recompute_granularity_candidates = candidates["recompute_granularity"]
all_cfgs = list(
itertools.product(
dp_degree_candidates,
mp_degree_candidates,
pp_degree_candidates,
mbs_candidates,
sharding_degree_candidates,
sharding_stage_candidates,
use_recompute_candidates,
recompute_granularity_candidates,
)
)
mapping = {
0: "dp_degree",
1: "mp_degree",
2: "pp_degree",
3: "micro_batch_size",
5: "sharding_stage",
4: "sharding_degree",
6: "use_recompute",
7: "recompute_granularity",
}
new_all_cfgs = []
for cfg in all_cfgs:
new_cfg = {}
for idx, val in enumerate(cfg):
new_cfg[mapping[idx]] = val
new_all_cfgs.append(new_cfg)
return new_all_cfgs
def gen_new_args(raw_args, cfg, tuner_cfg):
"""Generate new script args."""
assert "run_cmd" in tuner_cfg
cmd = copy.deepcopy(tuner_cfg["run_cmd"])
res_args = copy.deepcopy(raw_args)
if "dp_degree" in cmd and "dp_degree" in cfg:
cmd["dp_degree"][1] = cmd["dp_degree"][1] + "=" + str(cfg["dp_degree"])
res_args.extend(cmd["dp_degree"])
if "mp_degree" in cmd and "mp_degree" in cfg:
cmd["mp_degree"][1] = cmd["mp_degree"][1] + "=" + str(cfg["mp_degree"])
res_args.extend(cmd["mp_degree"])
if "pp_degree" in cmd and "pp_degree" in cfg:
cmd["pp_degree"][1] = cmd["pp_degree"][1] + "=" + str(cfg["pp_degree"])
res_args.extend(cmd["pp_degree"])
if "micro_batch_size" in cmd and "micro_batch_size" in cfg:
cmd["micro_batch_size"][1] = (
cmd["micro_batch_size"][1] + "=" + str(cfg["micro_batch_size"])
)
res_args.extend(cmd["micro_batch_size"])
if "sharding_degree" in cmd and "sharding_degree" in cfg:
cmd["sharding_degree"][1] = (
cmd["sharding_degree"][1] + "=" + str(cfg["sharding_degree"])
)
res_args.extend(cmd["sharding_degree"])
if "sharding_stage" in cmd and "sharding_stage" in cfg:
cmd["sharding_stage"][1] = (
cmd["sharding_stage"][1] + "=" + str(cfg["sharding_stage"])
)
res_args.extend(cmd["sharding_stage"])
if "use_recompute" in cmd and "use_recompute" in cfg:
cmd["use_recompute"][1] = (
cmd["use_recompute"][1] + "=" + str(cfg["use_recompute"])
)
res_args.extend(cmd["use_recompute"])
if "recompute_granularity" in cmd and "recompute_granularity" in cfg:
cmd["recompute_granularity"][1] = (
cmd["recompute_granularity"][1]
+ "="
+ str(cfg["recompute_granularity"])
)
res_args.extend(cmd["recompute_granularity"])
if "local_batch_size" in cmd:
local_batch_size = (
tuner_cfg["model_cfg"]["global_batch_size"]
// cfg["sharding_degree"]
// cfg["dp_degree"]
)
cmd["local_batch_size"][1] = (
cmd["local_batch_size"][1] + "=" + str(local_batch_size)
)
res_args.extend(cmd["local_batch_size"])
return res_args
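A worked sketch of gen_new_args; the run_cmd mapping mirrors the unit test at the end of this change, while the "-c config.yaml" script args are made up for illustration.

tuner_cfg = {
    "model_cfg": {"global_batch_size": 64},
    "run_cmd": {
        "dp_degree": ["-o", "Distributed.dp_degree"],
        "mp_degree": ["-o", "Distributed.mp_degree"],
        "micro_batch_size": ["-o", "Global.micro_batch_size"],
        "local_batch_size": ["-o", "Global.local_batch_size"],
    },
}
cfg = {"dp_degree": 2, "mp_degree": 4, "micro_batch_size": 2, "sharding_degree": 1}
print(gen_new_args(["-c", "config.yaml"], cfg, tuner_cfg))
# ['-c', 'config.yaml',
#  '-o', 'Distributed.dp_degree=2',
#  '-o', 'Distributed.mp_degree=4',
#  '-o', 'Global.micro_batch_size=2',
#  '-o', 'Global.local_batch_size=32']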
@@ -60,6 +60,11 @@ class Context:
return False
def is_auto_tuner_mode(self):
if self.args.auto_tuner_json:
return True
return False
def get_envs(self):
return self.envs.copy()
......
@@ -161,6 +161,13 @@ def parse_args():
"training script",
)
base_group.add_argument(
"--auto_tuner_json",
type=str,
default=None,
help="auto tuner json file path",
)
base_group.add_argument('training_script_args', nargs=REMAINDER)
ps_group = parser.add_argument_group("Parameter-Server Parameters")
......
@@ -130,6 +130,11 @@ class ControllerBase:
self.ctx.status.is_restarting()
and self.master.get_status() != self.ctx.status.COMPLETED
):
            # when a peer fails, stop this pod
if self.ctx.args.elastic_level == -1:
self.pod.stop(timeout=3)
return True
self.pod.stop(timeout=30)
return False
@@ -141,12 +146,13 @@ class ControllerBase:
self.master.stop()
self.pod.stop(timeout=30)
def finalize(self):
def finalize(self, exit=True):
self.pod.join()
self.master.stop()
self.ctx.logger.info(f"Exit code {self.pod.exit_code}")
sys.exit(self.pod.exit_code)
if exit:
sys.exit(self.pod.exit_code)
def signal_handler(self, sigint, frame):
if hasattr(self, 'sigint'):
@@ -162,6 +168,18 @@ class ControllerBase:
self.ctx.logger.info(f"Exit with signal {sigint}")
sys.exit(sigint)
def not_exit_signal_handler(self, sigint, frame):
if hasattr(self, 'sigint'):
self.ctx.logger.info("Force quit in 10 seconds...")
self.pod.stop(timeout=10)
self.ctx.logger.info(f"Terminating with signal {sigint}")
self.sigint = sigint
self.ctx.status.done()
self.stop(sigint=sigint)
self.ctx.logger.info(f"Exit with signal {sigint}")
class Controller(ControllerBase):
'''
......
@@ -292,6 +292,98 @@ def launch():
launch.launch()
elif ctx.is_auto_tuner_mode():
import copy
import json
import signal
import sys
import time
from ..auto_tuner.tuner import AutoTuner
from ..auto_tuner.utils import gen_new_args
from . import controllers
# read user defined tuner config json
try:
with open(ctx.args.auto_tuner_json, "r") as f:
tuner_cfg = json.load(f)
        except Exception:
            raise ValueError(
                "Please check whether the auto tuner json file is valid."
            )
# copy training script args
if ctx.args.training_script.endswith('.py'):
entrypoint = [sys.executable, "-u", ctx.args.training_script]
else:
entrypoint = [ctx.args.training_script]
entrypoint.extend(ctx.args.training_script_args)
raw_args = copy.deepcopy(ctx.args.training_script_args)
# get nodes and gpus from args
if not ctx.args.devices:
gpus_per_node = 8
else:
gpus_per_node = len(ctx.args.devices.split(","))
tuner_cfg["nodes"] = int(ctx.args.nnodes)
tuner_cfg["num_gpus"] = gpus_per_node * tuner_cfg["nodes"]
# build AutoTuner to get new config
auto_tuner = AutoTuner(tuner_cfg)
cur_cfg = auto_tuner.search_once()
# get max time per task run
max_time_per_task = tuner_cfg.get("max_time_per_task", 1800)
job_id = 0
while cur_cfg:
            # The auto tuner supports dp, mp, pp, micro batch size, sharding and recompute by default, and each task has its own log dir
log_dir = "DP{}_MP{}_PP{}_Sharding_degree_{}_stage_{}_MBS_{}_Recompute_{}_granularity_{}".format(
cur_cfg["dp_degree"],
cur_cfg["pp_degree"],
cur_cfg["pp_degree"],
cur_cfg["sharding_degree"],
cur_cfg["sharding_stage"],
cur_cfg["micro_batch_size"],
cur_cfg["use_recompute"],
cur_cfg["recompute_granularity"],
)
ctx.args.log_dir = log_dir
            # each task has its own job id
job_id += 1
task_job_id = "auto_tuner_" + str(job_id)
ctx.args.job_id = task_job_id
# generate script args of task
new_args = gen_new_args(raw_args, cur_cfg, tuner_cfg)
ctx.args.training_script_args = new_args
# launch task
ctx.logger.info(
"Launch task from auto tuner: job_id {}, log_dir {}, config {}".format(
task_job_id, log_dir, cur_cfg
)
)
c = controllers.init(ctx)
# set per task timeout
signal.signal(signal.SIGALRM, c.not_exit_signal_handler)
signal.alarm(max_time_per_task)
c.run()
new_cfg = auto_tuner.search_once()
if new_cfg:
c.finalize(exit=False)
else:
c.finalize(exit=True)
# NOTE: The statistics and comparison function of task results will be implemented in the future.
# per task launch interval
time.sleep(5)
cur_cfg = copy.deepcopy(new_cfg)
else:
from . import controllers
......
@@ -65,6 +65,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_pass_1F1B MODULES test_pass_1F1B)
set_tests_properties(test_pass_1F1B PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
py_test_modules(test_auto_tuner MODULES test_auto_tuner)
set_tests_properties(test_auto_tuner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 100)
# End of unittests WITH multi cards and timeout
# NOTE(zyl): unittests WITH multi cards and WITHOUT timeout
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import subprocess
import sys
import tempfile
import unittest
class TestEngineAPI(unittest.TestCase):
def test_auto_tuner(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "engine_api_dp.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
test_info = {
"dp_degree": "auto",
"mp_degree": "auto",
"pp_degree": "auto",
"micro_batch_size": "auto",
"sharding_degree": "auto",
"sharding_stage": "auto",
"use_recompute": "auto",
"recompute_granularity": "auto",
"task_limit": 1,
"max_time_per_task": 90,
"model_cfg": {
"hidden_size": 2048,
"global_batch_size": 64,
"num_layers": 24,
"num_attention_heads": 16,
"vocab_size": 50304,
},
"run_cmd": {
"dp_degree": ["-o", "Distributed.dp_degree"],
"mp_degree": ["-o", "Distributed.mp_degree"],
"pp_degree": ["-o", "Distributed.pp_degree"],
"micro_batch_size": ["-o", "Global.micro_batch_size"],
"local_batch_size": ["-o", "Global.local_batch_size"],
"sharding_degree": [
"-o",
"Distributed.sharding.sharding_degree",
],
"sharding_stage": ["-o", "Distributed.sharding.sharding_stage"],
"use_recompute": ["-o", "Model.use_recompute"],
"recompute_granularity": ["-o", "Model.recompute_granularity"],
},
}
tmp_dir = tempfile.TemporaryDirectory()
json_object = json.dumps(test_info)
test_json_path = os.path.join(tmp_dir.name, "test.json")
with open(test_json_path, "w") as f:
f.write(json_object)
cmd = (
[sys.executable, "-u"]
+ coverage_args
+ [
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"--log_dir",
tmp_dir.name,
"--auto_tuner_json",
test_json_path,
launch_model_path,
]
)
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()