From 8da0238b7a7bb8cfa658fda5a71603f99ae1fa3c Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Wed, 26 Oct 2022 13:46:09 -0700 Subject: [PATCH] rollback ds config changes (#2395) * rollback ds config changes * fix format * Fix error when output_file is a relative path without a prefix (#2397) Co-authored-by: Benjamin Steenhoek * fix restuls and exprs path to use absolute path * write out optimial config after tuning * fix format * assert tuning result dir creation Co-authored-by: Benjamin Steenhoek Co-authored-by: Michael Wyatt --- deepspeed/autotuning/autotuner.py | 91 +++++++++++++------ deepspeed/autotuning/config.py | 18 ++-- deepspeed/autotuning/constants.py | 7 +- deepspeed/autotuning/scheduler.py | 8 +- deepspeed/launcher/runner.py | 1 + .../profiling/flops_profiler/profiler.py | 2 +- 6 files changed, 81 insertions(+), 46 deletions(-) diff --git a/deepspeed/autotuning/autotuner.py b/deepspeed/autotuning/autotuner.py index b8a67075..a3f72e3a 100755 --- a/deepspeed/autotuning/autotuner.py +++ b/deepspeed/autotuning/autotuner.py @@ -9,7 +9,7 @@ import hjson from ..runtime.config_utils import dict_raise_error_on_duplicate_keys from ..runtime.constants import * -from ..runtime.zero.config import DeepSpeedZeroConfig, ZERO_OPTIMIZATION, ZeroStageEnum +from ..runtime.zero.config import ZERO_OPTIMIZATION, ZeroStageEnum from ..utils import logger from .config import DeepSpeedAutotuningConfig from .constants import * @@ -22,6 +22,11 @@ try: except ImportError: tabulate = None +ZERO_OPTIMIZATION_STAGE = "stage" +OFFLOAD_OPTIMIZER = "offload_optimizer" +OFFLOAD_PARAM = "offload_param" +ZERO_OPTIMIZATION_STAGE_DEFAULT = ZeroStageEnum.disabled + class Autotuner: """The DeepSpeed Autotuner automatically discovers the optimal DeepSpeed configuration that delivers good training speed. The Autotuner uses model information, system information, and heuristics to efficiently tune system knobs that affect compute and memory efficiencies, such as ZeRO optimization stages, micro-batch sizes, and many other ZeRO optimization configurations. It not only reduces the time and resources user spend on tuning, but also can discover configurations better than hand-tuned methods. @@ -39,22 +44,37 @@ class Autotuner: assert self.user_config is not None, "DeepSpeed configuration is not provided" self.autotuning_config = DeepSpeedAutotuningConfig(self.user_config) + if self.user_config[AUTOTUNING]: + if AUTOTUNING_EXPS_DIR in self.user_config[AUTOTUNING].keys(): + del self.user_config[AUTOTUNING][AUTOTUNING_EXPS_DIR] + if AUTOTUNING_RESULTS_DIR in self.user_config[AUTOTUNING].keys(): + del self.user_config[AUTOTUNING][AUTOTUNING_RESULTS_DIR] - self.exps_dir = DEFAULT_EXPRS_DIR - if self.autotuning_config.exps_dir and self.autotuning_config.exps_dir != "": - self.exps_dir = self.autotuning_config.exps_dir + self.exps_dir = self.autotuning_config.exps_dir if self.autotuning_config.overwrite and os.path.exists(self.exps_dir): shutil.rmtree(self.exps_dir, ignore_errors=True) if not os.path.exists(self.exps_dir): - os.makedirs(self.exps_dir, exist_ok=True) + try: + os.makedirs(self.exps_dir, exist_ok=True) + logger.info(f"Created autotuning experiments directory: {self.exps_dir}") + except: + logger.error( + f"Failed to create {self.exps_dir}, please check `exps_dir` in the autotuning config file is accessible by all the nodes in the job." + ) + exit(-1) - self.results_dir = DEFAULT_RESULTS_DIR - if self.autotuning_config.results_dir and self.autotuning_config.results_dir != "": - self.results_dir = self.autotuning_config.results_dir + self.results_dir = self.autotuning_config.results_dir if self.autotuning_config.overwrite and os.path.exists(self.results_dir): shutil.rmtree(self.results_dir, ignore_errors=True) if not os.path.exists(self.results_dir): - os.makedirs(self.results_dir, exist_ok=True) + try: + os.makedirs(self.results_dir, exist_ok=True) + logger.info(f"Created autotuning resutls directory: {self.exps_dir}") + except: + logger.error( + f"Failed to create {self.results_dir}, please check `results_dir` in the autotuning config file is accessible by all the nodes in the job." + ) + exit(-1) # set the active resource for the autotuner resource manager self.rm = self._get_resource_manager(active_resources) @@ -304,8 +324,8 @@ class Autotuner: exps = [] # each zero stage uses a different template configuration file - config_zero = tuning_space.zero_optimization - stage = config_zero.stage + config_zero = tuning_space.get(ZERO_OPTIMIZATION, {}) + stage = config_zero.get(ZERO_OPTIMIZATION_STAGE, ZERO_OPTIMIZATION_STAGE_DEFAULT) template_config = {} if stage == 0: template_path = DEFAULT_TEMPLATE_PATH_ZERO_0 @@ -365,13 +385,12 @@ class Autotuner: # if the config does not use offloading, remove the offloading section config_zero = config.get(ZERO_OPTIMIZATION, None) if config_zero: - if not config_zero.offload_optimizer and 'offload_optimizer' in exp_config[ + if OFFLOAD_OPTIMIZER not in config_zero and OFFLOAD_OPTIMIZER in exp_config[ ZERO_OPTIMIZATION]: - del exp_config[ZERO_OPTIMIZATION]['offload_optimizer'] - if not config_zero.offload_param and 'offload_param' in exp_config[ + del exp_config[ZERO_OPTIMIZATION][OFFLOAD_OPTIMIZER] + if OFFLOAD_PARAM not in config_zero and OFFLOAD_PARAM in exp_config[ ZERO_OPTIMIZATION]: - del exp_config[ZERO_OPTIMIZATION]['offload_param'] - + del exp_config[ZERO_OPTIMIZATION][OFFLOAD_PARAM] # set gradient accumulation steps according to max_train_batch_size_per_gpu mbs = exp_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU] gas = max_train_batch_size_per_gpu // mbs @@ -416,7 +435,11 @@ class Autotuner: f"The model requires at least {memory_to_string(self.activation_mem, postfix='B')} activation memory for micro batch size 1." ) - stage = self.user_config.zero_optimization.stage if 'stage' in self.user_config.zero_optimization.__fields_set__ else "all" + #TODO: FIX THIS + stage = self.user_config.get(ZERO_OPTIMIZATION, + {}).get(ZERO_OPTIMIZATION_STAGE, + "all") + stage = "all" user_zero_stages = [stage] if not isinstance(stage, list) else stage logger.info(f"User-defined zero stages are {stage}.") @@ -499,7 +522,7 @@ class Autotuner: prev_best_mbs=0, prev_best_metric_val=0): config_zero = tuning_space.get(ZERO_OPTIMIZATION, {}) - stage = config_zero.stage + stage = config_zero.get(ZERO_OPTIMIZATION_STAGE, None) tuning_space_name = TUNING_MICRO_BATCH_SIZE_PREFIX + str(stage) tuning_micro_batch_sizes = [] max_train_batch_size_per_gpu = 0 @@ -753,7 +776,7 @@ class Autotuner: max_micro_batch_size_metric_val = 0 ds_config = get_first_config(self.user_config) - ds_config[ZERO_OPTIMIZATION] = DeepSpeedZeroConfig(stage=stage) + ds_config[ZERO_OPTIMIZATION] = {ZERO_OPTIMIZATION_STAGE: stage} tuning_space_name = TUNING_MICRO_BATCH_SIZE_PREFIX + str(stage) exp_paths = [] @@ -852,7 +875,7 @@ class Autotuner: tuning_space_name = TUNING_MICRO_BATCH_SIZE_PREFIX + str(stage) ds_config = get_first_config(self.user_config) - ds_config[ZERO_OPTIMIZATION] = DeepSpeedZeroConfig(stage=stage) + ds_config[ZERO_OPTIMIZATION] = {ZERO_OPTIMIZATION_STAGE: stage} gas = self.get_gas_from_user_config() ds_config[GRADIENT_ACCUMULATION_STEPS] = gas @@ -1085,19 +1108,12 @@ class Autotuner: self.rm.clear() return exp, metric_val - def run_after_tuning(self): - """ Launches the training with the optmimal DeepSpeed configuration found through the autotuning process. - "ds_config_optimal.json" describing the optmimal DeepSpeed configuration as well the command used to launch training "cmd_optimal.txt" are saved to self.results_dir. - """ + def write_optimal_config(self): best_space_records = self.get_best_space_records() if GLOBAL_TUNING_SPACE not in best_space_records: return best_exp, best_metric_val, _ = best_space_records[GLOBAL_TUNING_SPACE] if best_exp: - logger.info( - "Start training with the optmimal DeepSpeed configuration found through the tuning process" - ) - exp_dir = best_exp["result_dir"] cmd = None with open(os.path.join(exp_dir, "cmd.txt"), "r") as f: @@ -1117,10 +1133,25 @@ class Autotuner: fd.write(" ".join(cmd)) fd.write("\n") fd.flush() + self.optimal_cmd = cmd + self.optmal_ds_config = ds_config + logger.info( + f"Wrote the optimal DeepSpeed configuration found by autotuning to {ds_config_path}, and the corresponding DeepSpeed command to {cmd_path}" + ) + else: + self.optimal_cmd = None + self.optmal_ds_config = None - result = subprocess.Popen(cmd) + def run_after_tuning(self): + """ Launches the training with the optimal DeepSpeed configuration found through the autotuning process. + "ds_config_optimal.json" describing the optmimal DeepSpeed configuration as well the command used to launch training "cmd_optimal.txt" are saved to self.results_dir. + """ + if self.optimal_cmd: + result = subprocess.Popen(self.optimal_cmd) result.wait() logger.info( - f"Done running with the optimal DeepSpeed configuration found by autotuning: {ds_config_path}" + f"Done running with the optimal DeepSpeed configuration using {self.optimal_cmd}" ) + else: + logger.info(f"No optimal DeepSpeed configuration found by autotuning.") diff --git a/deepspeed/autotuning/config.py b/deepspeed/autotuning/config.py index dea36f03..f3a658a0 100644 --- a/deepspeed/autotuning/config.py +++ b/deepspeed/autotuning/config.py @@ -38,14 +38,16 @@ class DeepSpeedAutotuningConfig(DeepSpeedConfigObject): AUTOTUNING_FAST, AUTOTUNING_FAST_DEFAULT) - self.results_dir = get_scalar_param(autotuning_dict, - AUTOTUNING_RESULTS_DIR, - AUTOTUNING_RESULTS_DIR_DEFAULT) - - self.exps_dir = get_scalar_param(autotuning_dict, - AUTOTUNING_EXPS_DIR, - AUTOTUNING_EXPS_DIR_DEFAULT) - + self.results_dir = os.path.abspath( + get_scalar_param(autotuning_dict, + AUTOTUNING_RESULTS_DIR, + AUTOTUNING_RESULTS_DIR_DEFAULT)) + assert self.results_dir, "results_dir cannot be empty" + self.exps_dir = os.path.abspath( + get_scalar_param(autotuning_dict, + AUTOTUNING_EXPS_DIR, + AUTOTUNING_EXPS_DIR_DEFAULT)) + assert self.exps_dir, "exps_dir cannot be empty" self.overwrite = get_scalar_param(autotuning_dict, AUTOTUNING_OVERWRITE, AUTOTUNING_OVERWRITE_DEFAULT) diff --git a/deepspeed/autotuning/constants.py b/deepspeed/autotuning/constants.py index 3bfcd272..6d1c530a 100644 --- a/deepspeed/autotuning/constants.py +++ b/deepspeed/autotuning/constants.py @@ -22,9 +22,6 @@ DEFAULT_TEMPLATE_PATH_ZERO_3 = os.path.join(os.path.dirname(os.path.realpath(__f "config_templates", "template_zero3.json") -DEFAULT_EXPRS_DIR = os.path.join(os.getcwd(), "autotuning_exps") -DEFAULT_RESULTS_DIR = os.path.join(os.getcwd(), "autotuning_results") - METRIC_PERCENT_DIFF_CONST = 0.05 DS_CONFIG = "ds_config" BUFSIZE = 1 # line buffer size for writing files @@ -54,10 +51,10 @@ AUTOTUNING_FAST = "fast" AUTOTUNING_FAST_DEFAULT = True AUTOTUNING_RESULTS_DIR = "results_dir" -AUTOTUNING_RESULTS_DIR_DEFAULT = None +AUTOTUNING_RESULTS_DIR_DEFAULT = "autotuning_results" AUTOTUNING_EXPS_DIR = "exps_dir" -AUTOTUNING_EXPS_DIR_DEFAULT = None +AUTOTUNING_EXPS_DIR_DEFAULT = "autotuning_exps" AUTOTUNING_OVERWRITE = "overwrite" AUTOTUNING_OVERWRITE_DEFAULT = True diff --git a/deepspeed/autotuning/scheduler.py b/deepspeed/autotuning/scheduler.py index 4f91f3cc..1c6bef65 100755 --- a/deepspeed/autotuning/scheduler.py +++ b/deepspeed/autotuning/scheduler.py @@ -374,7 +374,9 @@ def run_experiment(exp: dict, reservations, user_script, user_args): fd.flush() os.fsync(fd) - logger.info(f"Launching exp_id = {exp['exp_id']}, exp_name = {exp['name']}") + logger.info( + f"Launching exp_id = {exp['exp_id']}, exp_name = {exp['name']}, with resource = {include_str}" + ) with open(os.path.join(exp_dir, "stdout.log"), "wb") as out, open( os.path.join(exp_dir, "stderr.log"), "wb" @@ -388,7 +390,9 @@ def run_experiment(exp: dict, reservations, user_script, user_args): clean_up(exp, reservations) - logger.info(f"Done running exp_id = {exp['exp_id']}, exp_name = {exp['name']}") + logger.info( + f"Done running exp_id = {exp['exp_id']}, exp_name = {exp['name']}, with resource = {include_str}" + ) PDSH_MAX_FAN_OUT = 1024 diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py index bb78c0d0..cd9f90e6 100755 --- a/deepspeed/launcher/runner.py +++ b/deepspeed/launcher/runner.py @@ -329,6 +329,7 @@ def run_autotuning(args, active_resources): tuner.print_tuning_results() logger.info("[End] Running autotuning") + tuner.write_optimal_config() if args.autotuning == "run": tuner.run_after_tuning() diff --git a/deepspeed/profiling/flops_profiler/profiler.py b/deepspeed/profiling/flops_profiler/profiler.py index c112d7a2..a93c6fc2 100644 --- a/deepspeed/profiling/flops_profiler/profiler.py +++ b/deepspeed/profiling/flops_profiler/profiler.py @@ -254,7 +254,7 @@ class FlopsProfiler(object): original_stdout = None f = None if output_file and output_file != "": - dir_path = os.path.dirname(output_file) + dir_path = os.path.dirname(os.path.abspath(output_file)) if not os.path.exists(dir_path): os.makedirs(dir_path) original_stdout = sys.stdout -- GitLab