diff --git a/python/paddle/distributed/auto_tuner/prune.py b/python/paddle/distributed/auto_tuner/prune.py
index 888b53ee6b2b6ef0dc3c64d6d055d7678307304e..6f6d549e504927ea17281c0936d565112da999e4 100644
--- a/python/paddle/distributed/auto_tuner/prune.py
+++ b/python/paddle/distributed/auto_tuner/prune.py
@@ -122,6 +122,8 @@ def prune_by_mbs(tuner_cfg, cur_cfg, history_cfgs=None):
     """
     micro_batch_size = cur_cfg.get("micro_batch_size", None)
     global_batch_size = tuner_cfg["model_cfg"].get("global_batch_size", None)
+    if global_batch_size == "auto":
+        global_batch_size = cur_cfg["global_batch_size"]
     if global_batch_size:
         local_batch_size = (
             global_batch_size
diff --git a/python/paddle/distributed/auto_tuner/recorder.py b/python/paddle/distributed/auto_tuner/recorder.py
index ad388a9bfe2f792254abf8798e465595dcacbd37..ad8847c8a0dc6d31e7f2624ba0f960e5b5d6456c 100644
--- a/python/paddle/distributed/auto_tuner/recorder.py
+++ b/python/paddle/distributed/auto_tuner/recorder.py
@@ -19,7 +19,7 @@ from typing import Tuple
 import pandas as pd
 
 
-class History_recorder:
+class HistoryRecorder:
     # NOTE increase extenable ablitity
     def __init__(self) -> None:
         self.history = []
@@ -63,7 +63,9 @@ class History_recorder:
         cols = df.columns.tolist()
         cols.insert(0, cols.pop(cols.index('job_id')))
         df = df.reindex(columns=cols)
-        df = df.drop(columns=['time'])
+        # check if 'time' exists
+        if 'time' in df.columns:
+            df = df.drop(columns=['time'])
         # write to csv
         df.to_csv(self.store_path, index=False)
 
@@ -79,3 +81,7 @@ class History_recorder:
             reader = csv.reader(f)
             self.history = list(reader)
         return (self.history, err)
+
+    def clean_history(self) -> None:
+        """Clean history."""
+        self.history = []
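A minimal usage sketch of the renamed HistoryRecorder, mirroring the call sites added to main.py further down. The keyword arguments are illustrative, and add_cfg is assumed to accept the task config as keyword arguments, as in the recorder.add_cfg(**cfg) calls below:

from paddle.distributed.auto_tuner.recorder import HistoryRecorder

recorder = HistoryRecorder()
# each finished task contributes one row; 'job_id' is moved to the first CSV column
recorder.add_cfg(
    job_id=1,
    dp_degree=1,
    mp_degree=2,
    pp_degree=2,
    micro_batch_size=4,
    max_mem_usage=31415,
)
recorder.store_history("./tuner_gbs_history.csv")  # a 'time' column, if present, is dropped
recorder.clean_history()  # reset before the main search reuses the recorder
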
diff --git a/python/paddle/distributed/auto_tuner/search.py b/python/paddle/distributed/auto_tuner/search.py
index 01029e7f3727b9627f1b468ce00f66420412b6eb..0e0114a5249f0874d5fcb39617f25cc5a621bcd7 100644
--- a/python/paddle/distributed/auto_tuner/search.py
+++ b/python/paddle/distributed/auto_tuner/search.py
@@ -16,7 +16,7 @@ from abc import ABC, abstractmethod
 
 from .prune import _PRUNE_FUNC
-from .utils import search_all
+from .utils import gbs_search_all, search_all
 
 
 class SearchAlgo(ABC):
@@ -52,3 +52,24 @@ class GridSearch(SearchAlgo):
             else:
                 return None
         return new_cfg
+
+
+class GBSSearch(SearchAlgo):
+    def __init__(self, tuner_cfg):
+        super().__init__(tuner_cfg)
+        self.idx = 0
+        self.all_tasks = gbs_search_all(tuner_cfg)
+
+    def search_once(self, history_cfgs):
+        new_cfg = None
+        stop = False
+        while not stop:
+            if self.idx < len(self.all_tasks):
+                new_cfg = self.all_tasks[self.idx]
+                self.idx += 1
+                glb = new_cfg.get("global_batch_size", None)
+                self.tuner_cfg["model_cfg"]["global_batch_size"] = glb
+                stop = not self.prune(self.tuner_cfg, new_cfg, history_cfgs)
+            else:
+                return None
+        return new_cfg
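GBSSearch keeps the GridSearch contract: search_once returns the next candidate that survives pruning, or None when the task list is exhausted. A rough sketch of a driving loop (in the real flow this is wrapped by AutoTuner, which also tracks history; tuner_cfg is assumed to already carry the "candidates" lists and a "model_cfg"):

from paddle.distributed.auto_tuner.search import GBSSearch


def drive(tuner_cfg):
    algo = GBSSearch(tuner_cfg)
    history = []
    while True:
        cfg = algo.search_once(history)
        if cfg is None:
            break  # candidate space exhausted
        # ... launch the task described by cfg and measure it ...
        history.append(cfg)
    return history
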
diff --git a/python/paddle/distributed/auto_tuner/tuner.py b/python/paddle/distributed/auto_tuner/tuner.py
index 26831a2e8fcbda89c54821c46467ab2022380aa3..bdc6bed5c6a0854b09cbf9f5730581252a3de9fb 100644
--- a/python/paddle/distributed/auto_tuner/tuner.py
+++ b/python/paddle/distributed/auto_tuner/tuner.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
-from .utils import default_candidates
+from .utils import default_candidates, gbs_default_candidates
 
 
 class AutoTuner:
@@ -29,13 +29,18 @@ class AutoTuner:
         self.cur_task_id = 1
         self.task_limit = tuner_cfg.get("task_limit", 100)
 
-        tuner_cfg["candidates"] = default_candidates(tuner_cfg)
         search_algo = tuner_cfg.get("search_algo", "grid")
 
         if search_algo == "grid":
             from .search import GridSearch
 
+            tuner_cfg["candidates"] = default_candidates(tuner_cfg)
             self.algo = GridSearch(tuner_cfg)
+        elif search_algo == "gbs":
+            from .search import GBSSearch
+
+            tuner_cfg["candidates"] = gbs_default_candidates(tuner_cfg)
+            self.algo = GBSSearch(tuner_cfg)
         else:
             raise NotImplementedError()
 
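With candidate generation moved under each branch, selecting the new strategy is just a matter of the search_algo field. A hypothetical minimal tuner_cfg for the "gbs" path (field names taken from the code above and from the utils.py changes below):

from paddle.distributed.auto_tuner.tuner import AutoTuner

tuner_cfg = {
    "search_algo": "gbs",                        # "grid" is the default
    "num_gpus": 8,
    "nodes": 1,
    "model_cfg": {"global_batch_size": "auto"},  # triggers the GBS search in main.py
    "metric_cfg": {"name": "step/s"},
}

auto_tuner = AutoTuner(tuner_cfg)  # fills tuner_cfg["candidates"] via gbs_default_candidates
cur_cfg = auto_tuner.search_once()
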
diff --git a/python/paddle/distributed/auto_tuner/utils.py b/python/paddle/distributed/auto_tuner/utils.py
index 8db11df08c5a997515612d82d6f2c6eac1030097..43ac4bddf48e69bc82c01be270937f51c4d7e099 100644
--- a/python/paddle/distributed/auto_tuner/utils.py
+++ b/python/paddle/distributed/auto_tuner/utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import copy
+import csv
 import itertools
 import os
 import re
@@ -320,38 +321,227 @@ def gen_new_args(raw_args, cfg, tuner_cfg):
     return res_args
 
 
-def read_log(
+def read_metric_log(
     path, file="workerlog.0", target_metric='step/s'
-) -> Tuple[float, bool]:
+) -> Tuple[float, int]:
     """For extracting metric from log file."""
+    """
+    return:
+        metric: average metric of last 10 steps
+        err_code:
+            0b00: no error
+            0b01: no metric
+            0b10: out of memory
+    """
+    err_code = 0
     target_file = path + "/" + file
     if not os.path.exists(target_file):
-        return (0.0, True)
+        return (0.0, 1)
     with open(target_file, "r") as f:
         # read file
         re_metric_pattern = (
             target_metric + r":* *(\d+(\.\d*)?)|(\d+(\.\d*)?) *" + target_metric
         )
-
+        re_out_of_memory_pattern = r"Out of memory"
+        out_of_memory_flag = 0
         metric_list = []
         lines = f.readlines()
         for line in lines:
            metric = re.findall(re_metric_pattern, line)
+            out_of_memory = re.findall(
+                re_out_of_memory_pattern, line, re.IGNORECASE
+            )
            if metric:
                metric_list.append(float(metric[0][0]))
+            if out_of_memory:
+                out_of_memory_flag = 1
+
+        if out_of_memory_flag:
+            metric_ave = 0.0
+            err_code = err_code | (out_of_memory_flag << 1)
        if not metric_list:
            metric_ave = 0.0
-            flag = True
+            err_code = err_code | 1
        elif len(metric_list) < 10:
            metric_ave = metric_list[-1]
-            flag = False
        elif len(metric_list) < 20:
            metric_ave = sum(metric_list[9:]) / (len(metric_list[9:]))
-            flag = False
        else:
            metric_ave = sum(metric_list[-10:]) / 10
-            flag = False
        # round to 5 decimal places
        metric_ave = round(metric_ave, 5)
-    res = metric_ave, flag
+    res = metric_ave, err_code
     return res
+
+
+def read_memory_log(path, file) -> Tuple[float, bool]:
+    log_path = os.path.join(path, file)
+    if not os.path.exists(log_path):
+        return (0.0, True)
+    memory_used = []
+    utilization_gpu = []
+    indexs = []
+
+    with open(log_path, 'r') as f:
+        reader = csv.reader(f)
+        flag = False
+        # skip headers
+        while not flag:
+            # read the next row
+            row = next(reader)
+            if len(row) == 6 and 'memory_used' in row:
+                flag = True
+        for row in reader:
+            # If row length is 6 then it's a utilization data row
+            # skip header
+            if len(row) == 6:
+                index, util_gpu, _, mem_used, _, _ = row
+                indexs.append(int(index))
+                memory_used.append(int(mem_used))
+                utilization_gpu.append(int(util_gpu))
+    return max(memory_used), False
+
+
+def read_log(
+    path,
+    metric_file="workerlog.0",
+    target_metric='step/s',
+    memory_file="0.gpu.log",
+) -> Tuple[float, float, int]:
+    """
+    Extract metric and max memory usage from log files.
+    return:
+        metric: average metric of last 10 steps
+        memory: max memory used
+        err_code: 0b00: no error, 0b01: no metric, 0b10: out of memory, 0b100: no memory log
+    """
+    err_code = 0
+    # check out of memory
+    for root, dirs, files in os.walk(path):
+        for file in files:
+            if not file.startswith("workerlog"):
+                continue
+            metric, metric_flag = read_metric_log(path, file, target_metric)
+            if metric_flag:
+                err_code = (metric_flag & 2) | err_code
+
+    # read metric
+    res_metric, metric_flag = read_metric_log(path, metric_file, target_metric)
+    err_code = metric_flag | err_code
+    # check max memory usage
+    try:
+        res_memory, memory_flag = read_memory_log(path, memory_file)
+        err_code = (memory_flag << 2) | err_code
+    except:
+        res_memory = 0.0
+        err_code = (1 << 2) | err_code
+    return res_metric, res_memory, err_code
+
+
+def three_mul_combinations(target):
+    """Return the combinations of three numbers whose product is target."""
+    results = []
+    for i in range(1, target // 3 + 1):
+        if target % i == 0:
+            for j in range(i, target // 2 + 1):
+                if (target // i) % j == 0:
+                    results.append((i, j, target // i // j))
+    return results
+
+
+def gbs_dp_mp_pp_candidates(tuner_cfg, num_gpus, num_nodes):
+    """Return middle candidates of dp, mp, pp"""
+
+    start = round(num_gpus ** (1 / 3))
+
+    # find factors that can be evenly distributed
+    for i in range(start, 0, -1):
+        if num_gpus % i == 0:
+            remaining = num_gpus // i
+            # find the square root as a factor for the remaining part
+            j = round(remaining**0.5)
+            while remaining % j != 0:
+                j -= 1
+            return i, j, remaining // j
+
+    raise ValueError("Cannot distribute GPUs equally")
+
+
+def gbs_default_candidates(tuner_cfg):
+    """Return the default candidates of every hyper param the user defined as auto."""
+    candidates = {}
+    num_gpus = tuner_cfg["num_gpus"]
+    num_nodes = tuner_cfg["nodes"]
+    assert num_gpus > 0
+    global_batch_size = tuner_cfg.get("model_cfg", {}).get(
+        "global_batch_size", "auto"
+    )
+    if global_batch_size == "auto":
+        dp_candidate, mp_candidate, pp_candidate = gbs_dp_mp_pp_candidates(
+            tuner_cfg, num_gpus, num_nodes
+        )
+        sharding_degree_candidate = dp_candidate
+        candidates["dp_degree"] = [1]
+        candidates["mp_degree"] = [mp_candidate]
+        candidates["pp_degree"] = [pp_candidate]
+        candidates["sharding_degree"] = [sharding_degree_candidate]
+        candidates["sharding_stage"] = [1]
+        candidates["use_recompute"] = [False]
+        candidates["recompute_granularity"] = [None]
+        candidates["micro_batch_size"] = [2**i for i in range(0, 10)]
+        candidates["global_batch_size"] = [
+            pp_candidate * dp_candidate * e
+            for e in candidates["micro_batch_size"]
+        ]
+    return candidates
+
+
+def gbs_search_all(tuner_cfg):
+    """Permute the candidates of all hyper params."""
+    candidates = tuner_cfg["candidates"]
+    # Order: dp -> mp -> pp -> mbs -> sharding -> recompute
+    dp_degree_candidates = candidates["dp_degree"]
+    mp_degree_candidates = candidates["mp_degree"]
+    pp_degree_candidates = candidates["pp_degree"]
+    mbs_candidates = candidates["micro_batch_size"]
+    sharding_stage_candidates = candidates["sharding_stage"]
+    sharding_degree_candidates = candidates["sharding_degree"]
+    use_recompute_candidates = candidates["use_recompute"]
+    recompute_granularity_candidates = candidates["recompute_granularity"]
+    # gbs_candidates = candidates["global_batch_size"]
+    all_cfgs = list(
+        itertools.product(
+            dp_degree_candidates,
+            mp_degree_candidates,
+            pp_degree_candidates,
+            mbs_candidates,
+            sharding_degree_candidates,
+            sharding_stage_candidates,
+            use_recompute_candidates,
+            recompute_granularity_candidates,
+            # gbs_candidates,
+        )
+    )
+    mapping = {
+        0: "dp_degree",
+        1: "mp_degree",
+        2: "pp_degree",
+        3: "micro_batch_size",
+        5: "sharding_stage",
+        4: "sharding_degree",
+        6: "use_recompute",
+        7: "recompute_granularity",
+        # 8: "global_batch_size",
+    }
+    new_all_cfgs = []
+    for cfg in all_cfgs:
+        new_cfg = {}
+        for idx, val in enumerate(cfg):
+            new_cfg[mapping[idx]] = val
+        new_cfg["global_batch_size"] = (
+            new_cfg["pp_degree"]
+            * new_cfg["dp_degree"]
+            * new_cfg["micro_batch_size"]
+        )
+        new_all_cfgs.append(new_cfg)
+    return new_all_cfgs
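The err_code returned by read_log is a small bit mask (the docstring values above are written in binary): bit 0 means no metric was found, bit 1 means an out-of-memory message was seen, bit 2 means the GPU memory log could not be read. A short sketch of decoding it, matching the checks added to main.py below; the paths and file names are illustrative:

from paddle.distributed.auto_tuner.utils import read_log

metric, mem, err = read_log(
    path="autotuner_logs",          # illustrative log directory
    metric_file="workerlog.0",
    target_metric="step/s",
    memory_file="tuner_1.gpu.log",  # illustrative "<job_id>.gpu.log" file
)
if err & (1 << 0):
    print("no metric found in the worker logs")
if err & (1 << 1):
    print("a worker reported out of memory")
if err & (1 << 2) and not err & (1 << 1):
    print("GPU memory log missing or unreadable")
if not err:
    print(f"average metric {metric}, max memory used {mem}")
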
diff --git a/python/paddle/distributed/launch/controllers/master.py b/python/paddle/distributed/launch/controllers/master.py
index c19a7ff53595fa302143984379222518797f2bf1..e04ee59b2428587b6f2a6bb994af408c003f8b7a 100644
--- a/python/paddle/distributed/launch/controllers/master.py
+++ b/python/paddle/distributed/launch/controllers/master.py
@@ -260,7 +260,17 @@ class ETCDMaster(Master):
                     delete_success = True
                 except:
                     time.sleep(1)
-        lease = self.client.lease(ttl)
+
+        if self.ctx.is_auto_tuner_mode():
+            lease_success = False
+            while not lease_success:
+                try:
+                    lease = self.client.lease(ttl)
+                    lease_success = True
+                except:
+                    time.sleep(1)
+        else:
+            lease = self.client.lease(ttl)
 
         # self.client.delete_prefix(self.job_prefix)
 
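Before the main.py changes, a worked example of what the gbs candidate generation in utils.py above would produce, under the assumption of a single node with 8 GPUs:

from paddle.distributed.auto_tuner.utils import gbs_default_candidates, gbs_search_all

tuner_cfg = {
    "num_gpus": 8,
    "nodes": 1,
    "model_cfg": {"global_batch_size": "auto"},
}
tuner_cfg["candidates"] = gbs_default_candidates(tuner_cfg)
# gbs_dp_mp_pp_candidates(tuner_cfg, 8, 1) returns (2, 2, 2), so the candidates are
# mp_degree=[2], pp_degree=[2], sharding_degree=[2], dp_degree=[1],
# micro_batch_size=[1, 2, 4, ..., 512]
tasks = gbs_search_all(tuner_cfg)
print(len(tasks))                     # 10 tasks, one per micro_batch_size
print(tasks[0]["global_batch_size"])  # pp * dp * mbs = 2 * 1 * 1 = 2
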
diff --git a/python/paddle/distributed/launch/main.py b/python/paddle/distributed/launch/main.py
index 908c0af8cc18f17648f2e2d6e790c75537766d17..bd77c84bbcc87fd6ef3172b6b6a618225cd39bcf 100644
--- a/python/paddle/distributed/launch/main.py
+++ b/python/paddle/distributed/launch/main.py
@@ -298,7 +298,7 @@ def launch():
         import sys
         import time
 
-        from ..auto_tuner.recorder import History_recorder
+        from ..auto_tuner.recorder import HistoryRecorder
         from ..auto_tuner.tuner import AutoTuner
         from ..auto_tuner.utils import gen_new_args, read_log
         from . import controllers
@@ -340,11 +340,6 @@ def launch():
             client = etcd3.client(host=master_ip, port=port)
             client.delete("best_cfg")
 
-        # build AutoTuner to get new config
-        auto_tuner = AutoTuner(tuner_cfg)
-        cur_cfg = auto_tuner.search_once()
-        auto_tuner.add_cfg(cur_cfg)
-
         # get max time per task run
         max_time_per_task = tuner_cfg.get("max_time_per_task", 1800)
         ctx.max_time_per_task = max_time_per_task
@@ -358,11 +353,140 @@ def launch():
         is_first_task = True
         # build history recorder
-        recorder = History_recorder()
+        recorder = HistoryRecorder()
 
         job_id = 0
         ctx.args.max_restart = -1
         raw_ctx = copy.deepcopy(ctx)
+
+        # gbs search
+        if (
+            tuner_cfg.get('model_cfg', {}).get('global_batch_size', 'auto')
+            == "auto"
+        ):
+            # adjust micro batch size until out of memory to get the best global batch size
+            gbs_tuner_cfg = copy.deepcopy(tuner_cfg)
+            gbs_tuner_cfg["search_algo"] = "gbs"
+            gbs_tuner = AutoTuner(gbs_tuner_cfg)
+
+            gbs_cur_cfg = gbs_tuner.search_once()
+            best_gbs = None
+            while gbs_cur_cfg:
+                ctx = copy.deepcopy(raw_ctx)
+                log_dir = "GBSSearch/GBS{}_DP{}_MP{}_PP{}_Sharding_degree_{}_stage_{}_MBS_{}_Recompute_{}_granularity_{}".format(
+                    gbs_cur_cfg["global_batch_size"],
+                    gbs_cur_cfg["dp_degree"],
+                    gbs_cur_cfg["mp_degree"],
+                    gbs_cur_cfg["pp_degree"],
+                    gbs_cur_cfg["sharding_degree"],
+                    gbs_cur_cfg["sharding_stage"],
+                    gbs_cur_cfg["micro_batch_size"],
+                    gbs_cur_cfg["use_recompute"],
+                    gbs_cur_cfg["recompute_granularity"],
+                )
+                ctx.args.log_dir = log_dir
+
+                # every task has its own job id
+                job_id += 1
+                task_job_id = "gbs_tuner_" + str(job_id)
+                ctx.args.job_id = task_job_id
+
+                # generate script args of task
+                gbs_new_args = gen_new_args(
+                    raw_args, gbs_cur_cfg, gbs_tuner_cfg
+                )
+                ctx.args.training_script_args = gbs_new_args
+
+                # launch task
+                ctx.logger.info(
+                    "Launch task from auto tuner: job_id {}, log_dir {}, config {}".format(
+                        task_job_id, log_dir, gbs_cur_cfg
+                    )
+                )
+                c = controllers.init(ctx)
+                c.run()
+
+                # process generated result
+                # TODO: differentiate out of memory and no loss (maybe over time)
+                # TODO: integrate memory and metric reading
+                metric, mem, err = read_log(
+                    path=ctx.args.log_dir,
+                    metric_file="workerlog.0",
+                    target_metric=tuner_cfg["metric_cfg"]["name"],
+                    memory_file=f"{ctx.args.job_id}.gpu.log",
+                )
+
+                if err & (1 << 0):
+                    ctx.logger.warning(
+                        f"Read metric failed for parameters: {log_dir}"
+                    )
+                    # for pruner use
+                    gbs_cur_cfg['time'] = -1
+                    gbs_cur_cfg[tuner_cfg['metric_cfg']['name']] = None
+                    gbs_cur_cfg["max_mem_usage"] = mem
+
+                if err & (1 << 1):
+                    ctx.logger.warning(
+                        f"Out of memory for parameters: {log_dir}"
+                    )
+                    # for pruner use
+                    gbs_cur_cfg['time'] = -1
+                    gbs_cur_cfg[tuner_cfg['metric_cfg']['name']] = None
+                    gbs_cur_cfg["max_mem_usage"] = "OOM"
+
+                # not err & (1 << 1): do not record memory usage when out of memory
+                if err & (1 << 2) and not err & (1 << 1):
+                    ctx.logger.warning(
+                        f"Read memory usage failed for parameters: {log_dir}"
+                    )
+                    gbs_cur_cfg["max_mem_usage"] = None
+
+                if not err:
+                    # for pruner use
+                    gbs_cur_cfg['time'] = metric
+                    gbs_cur_cfg[tuner_cfg['metric_cfg']['name']] = metric
+                    gbs_cur_cfg["max_mem_usage"] = mem
+
+                if err & (1 << 0) or err & (1 << 1):
+                    # no metric or out of memory, end gbs search
+                    break
+
+                # store and update args for next round
+                gbs_cur_cfg["job_id"] = job_id
+                best_gbs = gbs_cur_cfg["global_batch_size"]
+                recorder.add_cfg(**gbs_cur_cfg)
+                c.finalize(exit=False)
+                recorder.store_history("./tuner_gbs_history.csv")
+
+                # new cfg for next round
+                gbs_new_cfg = gbs_tuner.search_once()
+                gbs_cur_cfg = copy.deepcopy(gbs_new_cfg)
+                gbs_tuner.add_cfg(gbs_cur_cfg)
+
+                # per task launch interval
+                time.sleep(3)
+            # guard against no valid global batch size being found
+            if best_gbs is None:
+                raise ValueError(
+                    "No valid global batch size found, check memory or search time. cur_tuner_cfg: {}".format(
+                        gbs_tuner_cfg
+                    )
+                )
+            # set best global batch size to tuner cfg
+            tuner_cfg["model_cfg"]["global_batch_size"] = best_gbs
+
+            recorder.store_history("./tuner_gbs_history.csv")
+            recorder.clean_history()
+
+            end_time = time.time()
+            ctx.logger.info(
+                f"AutoTuner for GBS search ends in {end_time - start_time}s."
+            )
+        # build AutoTuner to get new config
+        auto_tuner = AutoTuner(tuner_cfg)
+        cur_cfg = auto_tuner.search_once()
+        auto_tuner.add_cfg(cur_cfg)
+
         while cur_cfg:
             ctx = copy.deepcopy(raw_ctx)
             if is_first_task:
@@ -401,20 +525,42 @@ def launch():
             c.run()
 
             # process generated result
-            metric, err = read_log(
+
+            metric, mem, err = read_log(
                 path=ctx.args.log_dir,
-                file="workerlog.0",
+                metric_file="workerlog.0",
                 target_metric=tuner_cfg["metric_cfg"]["name"],
+                memory_file=f"{ctx.args.job_id}.gpu.log",
             )
-            if err:
-                ctx.logger.warning(f"Read log failed for parameters: {log_dir}")
+
+            if err & (1 << 0):
+                ctx.logger.warning(
+                    f"Read metric failed for parameters: {log_dir}"
+                )
                 # for pruner use
                 cur_cfg['time'] = -1
                 cur_cfg[tuner_cfg['metric_cfg']['name']] = None
-            else:
+                cur_cfg["max_mem_usage"] = mem
+
+            if err & (1 << 1):
+                ctx.logger.warning(f"Out of memory for parameters: {log_dir}")
+                # for pruner use
+                cur_cfg['time'] = -1
+                cur_cfg[tuner_cfg['metric_cfg']['name']] = None
+                cur_cfg["max_mem_usage"] = "OOM"
+
+            # not err & (1 << 1): do not record memory usage when out of memory
+            if err & (1 << 2) and not err & (1 << 1):
+                ctx.logger.warning(
+                    f"Read memory usage failed for parameters: {log_dir}"
+                )
+                cur_cfg["max_mem_usage"] = None
+
+            if not err:
                 # for pruner use
                 cur_cfg['time'] = metric
                 cur_cfg[tuner_cfg['metric_cfg']['name']] = metric
+                cur_cfg["max_mem_usage"] = mem
 
             # record history
             cur_cfg['job_id'] = job_id
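Taken together, when global_batch_size is "auto" the launcher now runs in two phases. A condensed, non-authoritative outline of the control flow added to launch(); launch_task stands in for building the log dir, generating script args, running the controller, and calling read_log:

import copy

from paddle.distributed.auto_tuner.recorder import HistoryRecorder
from paddle.distributed.auto_tuner.tuner import AutoTuner


def tune(tuner_cfg, launch_task):
    recorder = HistoryRecorder()
    job_id = 0

    # Phase 1: grow micro_batch_size until OOM (or no metric) to fix the global batch size.
    if tuner_cfg.get("model_cfg", {}).get("global_batch_size", "auto") == "auto":
        gbs_cfg = copy.deepcopy(tuner_cfg)
        gbs_cfg["search_algo"] = "gbs"
        gbs_tuner = AutoTuner(gbs_cfg)
        best_gbs = None
        cur = gbs_tuner.search_once()
        while cur:
            job_id += 1
            metric, mem, err = launch_task(cur)
            if err & 0b11:  # bit 0: no metric, bit 1: out of memory -> stop growing
                break
            cur["job_id"] = job_id
            best_gbs = cur["global_batch_size"]
            recorder.add_cfg(**cur)
            cur = gbs_tuner.search_once()
        if best_gbs is None:
            raise ValueError("no valid global batch size found")
        tuner_cfg["model_cfg"]["global_batch_size"] = best_gbs
        recorder.store_history("./tuner_gbs_history.csv")
        recorder.clean_history()

    # Phase 2: the ordinary grid search, now with a fixed global batch size.
    auto_tuner = AutoTuner(tuner_cfg)
    cur = auto_tuner.search_once()
    while cur:
        job_id += 1
        metric, mem, err = launch_task(cur)
        cur["job_id"] = job_id
        recorder.add_cfg(**cur)
        cur = auto_tuner.search_once()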