diff --git a/python/paddle/distributed/auto_tuner/prune.py b/python/paddle/distributed/auto_tuner/prune.py
index 9016f222e53af04b416e88999cc883c535c82909..71e7d95e2ed64077bbf8bf08a3e63d1a5c1c97d5 100644
--- a/python/paddle/distributed/auto_tuner/prune.py
+++ b/python/paddle/distributed/auto_tuner/prune.py
@@ -83,6 +83,12 @@ def prune_by_mp(tuner_cfg, cur_cfg, history_cfgs=None):
 
 @register_prune
 def prune_by_pp(tuner_cfg, cur_cfg, history_cfgs=None):
+    """
+    Prune by pp (pipeline parallelism). The rules are:
+    1. The PP degree must evenly divide the number of layers.
+    2. The PP degree must be among the user-defined candidates.
+    3. If no candidates are given, the PP degree must not exceed the number of nodes.
+    """
     pp_degree = cur_cfg.get("pp_degree", None)
     num_layers = tuner_cfg["model_cfg"].get("num_layers", None)
     num_nodes = tuner_cfg.get("num_nodes", 1)
@@ -108,6 +114,12 @@ def prune_by_pp(tuner_cfg, cur_cfg, history_cfgs=None):
 
 @register_prune
 def prune_by_mbs(tuner_cfg, cur_cfg, history_cfgs=None):
+    """
+    Prune by mbs (micro batch size). The rules are:
+    1. The micro batch size must evenly divide the local batch size.
+    2. The micro batch size must be among the user-defined candidates.
+    3. Prune if a similar configuration with a larger micro batch size resulted in a valid run.
+    """
     micro_batch_size = cur_cfg.get("micro_batch_size", None)
     global_batch_size = tuner_cfg["model_cfg"].get("global_batch_size", None)
     if global_batch_size:
@@ -146,6 +158,13 @@ def prune_by_mbs(tuner_cfg, cur_cfg, history_cfgs=None):
 
 @register_prune
 def prune_by_sharding(tuner_cfg, cur_cfg, history_cfgs):
+    """
+    Prune by sharding parameters. The rules are:
+    1. The sharding stage and sharding degree must be specified.
+    2. The sharding stage and degree must be among the user-defined candidates.
+    3. If the PP (pipeline parallelism) degree is not 1, the sharding stage must be 1.
+    4. Prune if a similar configuration with a lower sharding stage resulted in a valid run.
+    """
     sharding_stage = cur_cfg.get("sharding_stage", None)
     sharding_degree = cur_cfg.get("sharding_degree", None)
     pp_degree = cur_cfg.get("pp_degree", None)
@@ -183,11 +202,24 @@ def prune_by_sharding(tuner_cfg, cur_cfg, history_cfgs):
             and cfg.get("time", -1) > 0
         ):
             return True
+
+    if sharding_degree == 1:
+        cfgs = same_cfgs_beside("sharding_stage", cur_cfg, history_cfgs)
+        if cfgs:
+            return True
+
     return False
 
 
 @register_prune
 def prune_by_recompute(tuner_cfg, cur_cfg, history_cfgs):
+    """
+    Prune by recompute parameters. The rules are:
+    1. If recompute is not used, return False directly.
+    2. Both the use of recompute and the recompute granularity must be among the user-defined candidates.
+    3. If recompute is not used but a recompute granularity is set, return True to prune.
+    4. Prune if a similar configuration without recompute resulted in a valid run.
+    """
    recompute_granularity = cur_cfg.get("recompute_granularity", None)
     use_recompute = cur_cfg.get("use_recompute", None)
     if not use_recompute:
@@ -221,6 +253,10 @@ def prune_by_recompute(tuner_cfg, cur_cfg, history_cfgs):
         ):
             return True
 
+    if use_recompute is False:
+        cfgs = same_cfgs_beside("recompute_granularity", cur_cfg, history_cfgs)
+        if cfgs:
+            return True
     return False
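Read together, these docstrings spell out the pruning contract: every function behind `@register_prune` receives the tuner config, the candidate config, and the run history, and returns `True` when the candidate should be skipped. As a minimal standalone sketch (not the code under review), the three rules of `prune_by_pp` could be exercised like this, assuming only the dict shapes visible in the diff and that candidate lists, when given, live under the same key in `tuner_cfg`:

```python
def prune_by_pp_sketch(tuner_cfg, cur_cfg, history_cfgs=None):
    """Illustrative only; mirrors the three rules in prune_by_pp's docstring."""
    pp_degree = cur_cfg.get("pp_degree", None)
    num_layers = tuner_cfg["model_cfg"].get("num_layers", None)
    if pp_degree is None:
        return False  # nothing to decide on
    # Rule 1: the layers must split evenly across pipeline stages.
    if num_layers is not None and num_layers % pp_degree != 0:
        return True
    # Rule 2: respect explicit user candidates when provided as a list
    # (in the tests this key holds "auto" instead, meaning no candidates).
    candidates = tuner_cfg.get("pp_degree", None)
    if isinstance(candidates, list) and pp_degree not in candidates:
        return True
    # Rule 3: without candidates, cap the PP degree at the node count.
    if not isinstance(candidates, list) and pp_degree > tuner_cfg.get("num_nodes", 1):
        return True
    return False
```

The history-based rules additionally rely on `same_cfgs_beside`, which is assumed here to return earlier configs that differ from the current one only in the named key.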
+ """ recompute_granularity = cur_cfg.get("recompute_granularity", None) use_recompute = cur_cfg.get("use_recompute", None) if not use_recompute: @@ -221,6 +253,10 @@ def prune_by_recompute(tuner_cfg, cur_cfg, history_cfgs): ): return True + if use_recompute is False: + cfgs = same_cfgs_beside("recompute_granularity", cur_cfg, history_cfgs) + if cfgs: + return True return False diff --git a/python/paddle/distributed/auto_tuner/recorder.py b/python/paddle/distributed/auto_tuner/recorder.py new file mode 100644 index 0000000000000000000000000000000000000000..d742d751a7a2ce1768955bbd17d3a616a730a3ec --- /dev/null +++ b/python/paddle/distributed/auto_tuner/recorder.py @@ -0,0 +1,80 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import os +from typing import Tuple + +import pandas as pd + + +class History_recorder: + # NOTE increase extenable ablitity + def __init__(self) -> None: + self.history = [] + self.store_path = None + + def add_cfg(self, **kwargs): + cur_configs = {} + for key, val in kwargs.items(): + cur_configs[key] = val + self.history.append(cur_configs) + + def sort_metric(self, direction, metric_name) -> None: + if direction == 'Maximize': + self.history.sort( + key=lambda x: x[metric_name] + if x[metric_name] is not None + else float('-inf'), + reverse=True, + ) + else: + self.history.sort( + key=lambda x: x[metric_name] + if x[metric_name] is not None + else float('inf'), + reverse=False, + ) + return + + def get_best(self, metric, direction) -> Tuple[dict, bool]: + self.sort_metric(direction=direction, metric_name=metric) + if len(self.history) == 0: + return (self.history[0], True) + return (self.history[0], False) + + def store_history(self, path="./history.csv"): + """Store history to csv file.""" + self.store_path = path + # convert to pd dataframe + df = pd.DataFrame(self.history) + # move 'job_id' to the first column + cols = df.columns.tolist() + cols.insert(0, cols.pop(cols.index('job_id'))) + df = df.reindex(columns=cols) + # write to csv + df.to_csv(self.store_path, index=False) + + def load_history(self, path="./history.csv") -> Tuple[list, bool]: + """Load history from csv file.""" + err = False + if self.store_path is None: + self.store_path = path + if not os.path.exists(self.store_path): + err = True + else: + with open(self.store_path, "r") as f: + reader = csv.reader(f) + self.history = list(reader) + return (self.history, err) diff --git a/python/paddle/distributed/auto_tuner/utils.py b/python/paddle/distributed/auto_tuner/utils.py index a0a4d669f23505127bd4e9d8859fdc1d01dfcfbe..9c322b1d7a535186975f17ad3b7222eb5eec5a0c 100644 --- a/python/paddle/distributed/auto_tuner/utils.py +++ b/python/paddle/distributed/auto_tuner/utils.py @@ -14,6 +14,9 @@ import copy import itertools +import os +import re +from typing import Tuple def divisor(num, reverse=False): @@ -231,3 +234,38 @@ def gen_new_args(raw_args, cfg, tuner_cfg): res_args.extend(cmd["local_batch_size"]) return res_args + + 
diff --git a/python/paddle/distributed/auto_tuner/utils.py b/python/paddle/distributed/auto_tuner/utils.py
index a0a4d669f23505127bd4e9d8859fdc1d01dfcfbe..9c322b1d7a535186975f17ad3b7222eb5eec5a0c 100644
--- a/python/paddle/distributed/auto_tuner/utils.py
+++ b/python/paddle/distributed/auto_tuner/utils.py
@@ -14,6 +14,9 @@
 
 import copy
 import itertools
+import os
+import re
+from typing import Tuple
 
 
 def divisor(num, reverse=False):
@@ -231,3 +234,38 @@ def gen_new_args(raw_args, cfg, tuner_cfg):
             res_args.extend(cmd["local_batch_size"])
 
     return res_args
+
+
+def read_log(
+    path, file="workerlog.0", target_metric='step/s'
+) -> Tuple[float, bool]:
+    """Extract the target metric from a log file; returns (value, error_flag)."""
+    target_file = path + "/" + file
+    if not os.path.exists(target_file):
+        return (0.0, True)
+    with open(target_file, "r") as f:
+        # match lines of the form "speed: <number> step/s"
+        re_metric_pattern = r'speed: (\d+(\.\d*)?) *' + target_metric
+
+        metric_list = []
+        lines = f.readlines()
+        for line in lines:
+            metric = re.findall(re_metric_pattern, line)
+            if metric:
+                metric_list.append(float(metric[0][0]))
+        if not metric_list:
+            metric_ave = 0.0
+            flag = True
+        elif len(metric_list) < 10:
+            metric_ave = metric_list[-1]
+            flag = False
+        elif len(metric_list) < 20:
+            metric_ave = sum(metric_list[9:]) / len(metric_list[9:])
+            flag = False
+        else:
+            metric_ave = sum(metric_list[-10:]) / 10
+            flag = False
+        # round to 5 decimal places
+        metric_ave = round(metric_ave, 5)
+    res = metric_ave, flag
+    return res
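To make the extraction rule concrete, here is what the pattern matches on an invented log line (only the `speed: <number> step/s` shape matters; the surrounding text is hypothetical). Note that `target_metric` is spliced into the pattern verbatim, so a metric name containing regex metacharacters would need `re.escape`:

```python
import re

target_metric = 'step/s'
re_metric_pattern = r'speed: (\d+(\.\d*)?) *' + target_metric

line = "step 12/1000, loss: 7.31, speed: 1.2345 step/s"  # invented log line
match = re.findall(re_metric_pattern, line)
print(match)               # [('1.2345', '.2345')] -- two groups, outer one first
print(float(match[0][0]))  # 1.2345
```

The averaging that follows treats the first nine steps as warm-up: with fewer than 10 samples the last one is returned, with 10 to 19 samples everything after step 9 is averaged, and with 20 or more only the last 10 are averaged.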
diff --git a/python/paddle/distributed/launch/main.py b/python/paddle/distributed/launch/main.py
index 53a1be6c915a01ea25003765dba29db574aac894..7823ddad27ca364d7b649e9fe748903e6376baab 100644
--- a/python/paddle/distributed/launch/main.py
+++ b/python/paddle/distributed/launch/main.py
@@ -299,8 +299,9 @@ def launch():
         import sys
         import time
 
+        from ..auto_tuner.recorder import History_recorder
         from ..auto_tuner.tuner import AutoTuner
-        from ..auto_tuner.utils import gen_new_args
+        from ..auto_tuner.utils import gen_new_args, read_log
         from . import controllers
 
         # read user defined tuner config json
@@ -323,7 +324,11 @@ def launch():
             gpus_per_node = 8
         else:
             gpus_per_node = len(ctx.args.devices.split(","))
-        tuner_cfg["nodes"] = int(ctx.args.nnodes)
+        nnodes = ctx.args.nnodes
+        if isinstance(nnodes, str):
+            tuner_cfg["nodes"] = int(nnodes.split(":")[0])
+        else:
+            tuner_cfg["nodes"] = int(nnodes)
         tuner_cfg["num_gpus"] = gpus_per_node * tuner_cfg["nodes"]
 
         # build AutoTuner to get new config
@@ -333,12 +338,16 @@ def launch():
 
         # get max time per task run
         max_time_per_task = tuner_cfg.get("max_time_per_task", 1800)
 
+        # build history recorder
+        recorder = History_recorder()
+
         job_id = 0
         while cur_cfg:
+            ctx.status._current_status = None
            # auto tuner supports dp, mp, pp, micro batch size, sharding, recompute by default and every task has own log dir
            log_dir = "DP{}_MP{}_PP{}_Sharding_degree_{}_stage_{}_MBS_{}_Recompute_{}_granularity_{}".format(
                 cur_cfg["dp_degree"],
-                cur_cfg["pp_degree"],
+                cur_cfg["mp_degree"],
                 cur_cfg["pp_degree"],
                 cur_cfg["sharding_degree"],
                 cur_cfg["sharding_stage"],
@@ -371,19 +380,47 @@ def launch():
             signal.alarm(max_time_per_task)
             c.run()
 
+            # process the generated result
+            metric, err = read_log(
+                path=ctx.args.log_dir,
+                file="workerlog.0",
+                target_metric=tuner_cfg["metric_cfg"]["name"],
+            )
+            if err:
+                ctx.logger.warning(f"Read log failed for parameters: {log_dir}")
+                cur_cfg['time'] = None  # for pruner use
+                cur_cfg[tuner_cfg['metric_cfg']['name']] = None
+            else:
+                cur_cfg['time'] = metric  # for pruner use
+                cur_cfg[tuner_cfg['metric_cfg']['name']] = metric
+            # record history
+            cur_cfg['job_id'] = job_id
+            recorder.add_cfg(**cur_cfg)
+            cur_best_cfgs, err = recorder.get_best(
+                metric=tuner_cfg['metric_cfg']['name'],
+                direction=tuner_cfg['metric_cfg']['OptimizationDirection'],
+            )
+            if not err:
+                ctx.logger.info(f"Current best config: {cur_best_cfgs}")
+                recorder.store_history(
+                    ctx.args.auto_tuner_json.split(".")[0] + "_history.csv"
+                )
+            else:
+                ctx.logger.info(
+                    "Getting the best config failed; no valid config has run yet."
+                )
+
             new_cfg = auto_tuner.search_once()
             if new_cfg:
                 c.finalize(exit=False)
             else:
                 c.finalize(exit=True)
 
-            # NOTE: The statistics and comparison function of task results will be implemented in the future.
-
             # per task launch interval
             time.sleep(5)
             cur_cfg = copy.deepcopy(new_cfg)
-
+        recorder.store_history(
+            ctx.args.auto_tuner_json.split(".")[0] + "_history.csv"
+        )
     else:
         from . import controllers
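The loop above needs only two user-supplied keys to score jobs. A hedged sketch of the relevant fragment of the tuner JSON, expressed as the Python dict main.py ends up reading (field names as used in this diff; values are examples):

```python
# Result-handling keys read from the user-defined tuner JSON.
# "name" must match the metric tag that read_log greps for in workerlog.0;
# the recorder treats any value other than "Maximize" as minimization.
metric_cfg = {
    "name": "step/s",
    "OptimizationDirection": "Maximize",
}

# After each job, main.py stores the parsed value under both cur_cfg["time"]
# (consumed by the pruners) and cur_cfg["step/s"], tags the record with
# job_id, and flushes the full history to <auto_tuner_json stem>_history.csv.
```

Note also the `nnodes` change above: launch accepts elastic specs like `"2:4"`, so the node count is taken from the part before the colon when `nnodes` is a string.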
diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt
index d7c86c4a01af7ca47e835267dbd5fc8545bd3845..43c62bd9b1e083dee0feab6253b9992c331e26af 100644
--- a/test/auto_parallel/CMakeLists.txt
+++ b/test/auto_parallel/CMakeLists.txt
@@ -68,6 +68,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_auto_tuner MODULES test_auto_tuner)
   set_tests_properties(test_auto_tuner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
                                                   TIMEOUT 100)
+  py_test_modules(test_auto_tuner_compare MODULES test_auto_tuner_compare)
+  set_tests_properties(test_auto_tuner_compare
+                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
   # End of unittests WITH multi cards and timeout
 
   # NOTE(zyl): unittests WITH multi cards and WITHOUT timeout
diff --git a/test/auto_parallel/test_auto_tuner.py b/test/auto_parallel/test_auto_tuner.py
index 31973051b6f24af77eee6a9ba8e956bbc405652d..ad0eb94d48c78b60b0db369d539c27d1ab81aebd 100644
--- a/test/auto_parallel/test_auto_tuner.py
+++ b/test/auto_parallel/test_auto_tuner.py
@@ -61,6 +61,10 @@ class TestEngineAPI(unittest.TestCase):
                 "use_recompute": ["-o", "Model.use_recompute"],
                 "recompute_granularity": ["-o", "Model.recompute_granularity"],
             },
+            "metric_cfg": {
+                "name": "step/s",
+                "OptimizationDirection": "Maximize",
+            },
         }
 
         tmp_dir = tempfile.TemporaryDirectory()
diff --git a/test/auto_parallel/test_auto_tuner_compare.py b/test/auto_parallel/test_auto_tuner_compare.py
new file mode 100644
index 0000000000000000000000000000000000000000..b693d72493fdab2158cded097e096553557df0c4
--- /dev/null
+++ b/test/auto_parallel/test_auto_tuner_compare.py
@@ -0,0 +1,99 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import subprocess
+import sys
+import tempfile
+import unittest
+
+
+class TestEngineAPI(unittest.TestCase):
+    def test_auto_tuner_compare(self):
+        file_dir = os.path.dirname(os.path.abspath(__file__))
+        launch_model_path = os.path.join(file_dir, "engine_api_dp.py")
+
+        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
+            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
+        else:
+            coverage_args = []
+        test_info = {
+            "dp_degree": "auto",
+            "mp_degree": "auto",
+            "pp_degree": "auto",
+            "micro_batch_size": "auto",
+            "sharding_degree": "auto",
+            "sharding_stage": "auto",
+            "use_recompute": "auto",
+            "recompute_granularity": "auto",
+            "task_limit": 1,
+            "max_time_per_task": 90,
+            "model_cfg": {
+                "hidden_size": 2048,
+                "global_batch_size": 64,
+                "num_layers": 24,
+                "num_attention_heads": 16,
+                "vocab_size": 50304,
+            },
+            "run_cmd": {
+                "dp_degree": ["-o", "Distributed.dp_degree"],
+                "mp_degree": ["-o", "Distributed.mp_degree"],
+                "pp_degree": ["-o", "Distributed.pp_degree"],
+                "micro_batch_size": ["-o", "Global.micro_batch_size"],
+                "local_batch_size": ["-o", "Global.local_batch_size"],
+                "sharding_degree": [
+                    "-o",
+                    "Distributed.sharding.sharding_degree",
+                ],
+                "sharding_stage": ["-o", "Distributed.sharding.sharding_stage"],
+                "use_recompute": ["-o", "Model.use_recompute"],
+                "recompute_granularity": ["-o", "Model.recompute_granularity"],
+            },
+            "metric_cfg": {
+                "name": "step/s",
+                "OptimizationDirection": "Maximize",
+            },
+        }
+
+        tmp_dir = tempfile.TemporaryDirectory()
+        json_object = json.dumps(test_info)
+        test_json_path = os.path.join(tmp_dir.name, "test.json")
+        with open(test_json_path, "w") as f:
+            f.write(json_object)
+        cmd = (
+            [sys.executable, "-u"]
+            + coverage_args
+            + [
+                "-m",
+                "paddle.distributed.launch",
+                "--devices",
+                "0,1",
+                "--log_dir",
+                tmp_dir.name,
+                "--auto_tuner_json",
+                test_json_path,
+                launch_model_path,
+            ]
+        )
+
+        process = subprocess.Popen(cmd)
+        process.wait()
+        self.assertEqual(process.returncode, 0)
+
+        tmp_dir.cleanup()
+
+
+if __name__ == "__main__":
+    unittest.main()