# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import yaml import os import sys import copy import shlex import pathlib import time import shutil import pickle import json import logging import subprocess import traceback import paddle from paddle.fluid import program_guard from paddle.fluid.backward import append_backward from paddle.distributed.passes import new_pass, PassContext from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.reshard import Resharder from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.process_group import clear_all_process_groups, get_all_process_groups from paddle.distributed.auto_parallel.utils import debug_program from paddle.distributed.auto_parallel.utils import make_data_unshard, set_grad_var_shape from ..utils import get_logger from .config import TuningConfig from .algorithms import new_algorithm from .trial import TrialStatus def _get_new_params_grads(target_program, ref_program, ref_params_grads): ref_block = ref_program.global_block() target_block = target_program.global_block() target_params_grads = [] for p, g in ref_params_grads: # NOTE grad var might not be generated assert ref_block.has_var(p.name) assert target_block.has_var(p.name) new_p = target_block.var(p.name) if g: new_g = target_block.var(g.name) else: new_g = None target_params_grads.append((new_p, new_g)) return target_params_grads def _get_new_loss(target_program, ref_program, loss): ref_block = ref_program.global_block() target_block = target_program.global_block() assert ref_block.has_var(loss.name) return target_block.var(loss.name) def parse_process_groups(): group_map = {} all_process_groups = get_all_process_groups() for process_group in all_process_groups: group_map[process_group.id] = process_group.ranks return group_map def get_metric(results): assert isinstance( results, dict), "results should be type of dictionary, but got {}.".format( type(results)) if 'Throughtput' in results and isinstance(results['Throughtput'], float): return float(results['Throughtput']) else: return -1.0 def parse_results(results): if results['Throughtput'] > 0: return "Throughtput: {} step / s.".format(results['Throughtput']) et = results.get("ErrorType", None) if et == "ResourceExhaustedError": return "Fail with OOM" else: return "Fail with UNKWON ERROR" # TODO only dependent on dist context # all env need to be start a new pass are member of dist context def _copy_context(ref_dist_context): clear_all_process_groups() new_dist_context = DistributedContext() new_dist_context._serial_main_program = ref_dist_context.serial_main_program.clone( for_test=False) new_dist_context._serial_startup_program = ref_dist_context.serial_startup_program.clone( for_test=False) # mapping variable into new dist context if getattr(ref_dist_context, '_params_grads', None): new_dist_context._params_grads = _get_new_params_grads( new_dist_context.serial_main_program, ref_dist_context.serial_main_program, ref_dist_context._params_grads) new_dist_context._serial_loss = _get_new_loss( new_dist_context.serial_main_program, ref_dist_context.serial_main_program, ref_dist_context.serial_loss) for key, var_list in ref_dist_context._serial_feed_vars.items(): new_var_list = [] for var in var_list: block_idx = var.block.idx var_name = var.name var = new_dist_context._serial_main_program.blocks[ block_idx]._var_recursive(var_name) new_var_list.append(var) new_dist_context._serial_feed_vars[key] = new_var_list for key, var_list in ref_dist_context._serial_fetch_vars.items(): new_var_list = [] # metrics is a list of list if key == "metrics": for inner_var_list in var_list: new_inner_var_list = [] for var in inner_var_list: block_idx = var.block.idx var_name = var.name var = new_dist_context._serial_main_program.blocks[ block_idx]._var_recursive(var_name) new_inner_var_list.append(var) new_var_list.append(new_inner_var_list) else: for var in var_list: block_idx = var.block.idx var_name = var.name var = new_dist_context._serial_main_program.blocks[ block_idx]._var_recursive(var_name) new_var_list.append(var) new_dist_context._serial_fetch_vars[key] = new_var_list # copy information in forward and backward new_dist_context._serial_optimizer = copy.deepcopy( ref_dist_context.serial_optimizer) new_dist_context._dist_tensors_for_program = copy.deepcopy( ref_dist_context._dist_tensors_for_program) new_dist_context._dist_ops_for_program = copy.deepcopy( ref_dist_context._dist_ops_for_program) for pm in ref_dist_context.process_meshes: new_dist_context.add_process_mesh(pm) new_dist_context._dist_op_context = copy.deepcopy( ref_dist_context._dist_op_context) new_dist_context._block_state = copy.deepcopy(ref_dist_context.block_state) return new_dist_context class OptimizationTuner: """ OptimizationTuner is used to manage the tuning procedure of hyper-parameters (configs) of Optimization Pass in AutoParallel. """ def __init__( self, user_configs, dist_context, dataset, inputs_spec, labels_spec, batch_size, rank, ): self._config = TuningConfig(user_configs, dist_context._strategy) # should not modify dist context from calling function self._baseline_dist_context = _copy_context(dist_context) self._baseline_completer = Completer(self._baseline_dist_context) self._rank = rank self._inputs_spec = inputs_spec self._labels_spec = labels_spec self._dataset = dataset self._batch_size = batch_size self._finished_trials = [] self._best_metric = None self._best_iter = float("-inf") self._logger = get_logger(logging.INFO) self._build_programs_without_optimization() self._select_tuning_algorithm() @property def project_dir(self): dirname = self._config.project_dir if not os.path.exists(dirname): if self.rank == 0: pathlib.Path(dirname).mkdir(parents=True, exist_ok=True) return dirname @property def rank(self): return self._rank @property def device_id(self): return paddle.distributed.ParallelEnv().device_id # TODO Generate compelet program with all parts like forward, backward, update # as well as parallelism transformation. def _build_programs_without_optimization(self): serial_main_program = self._baseline_dist_context.serial_main_program serial_startup_program = self._baseline_dist_context.serial_startup_program serial_loss = self._baseline_dist_context.serial_loss with program_guard(serial_main_program, serial_startup_program): params_grads = append_backward( serial_loss, distop_context=self._baseline_dist_context.dist_op_context) self._baseline_completer.complete_backward_annotation( serial_main_program) self._baseline_dist_context.block_state.parse_backward_blocks( serial_main_program) self._baseline_dist_context._params_grads = params_grads if self._config.verbose: baseline_dir = os.path.join(self.project_dir, "baseline") if not os.path.exists(baseline_dir): pathlib.Path(baseline_dir).mkdir(parents=True, exist_ok=True) debug_program(self._baseline_dist_context._serial_main_program, baseline_dir, "main") debug_program(self._baseline_dist_context._serial_startup_program, baseline_dir, "startup") def _select_tuning_algorithm(self): selected_passes_set = self._config.tuning_passes_name algorithm_name = "_".join(sorted(selected_passes_set)) self._algorithm = new_algorithm(algorithm_name, self._config) def _apply_optimization(self, trial): new_strategy = trial.space dist_context = _copy_context(self._baseline_dist_context) pass_context = PassContext() completer = Completer(dist_context) main_program = dist_context.serial_main_program startup_program = dist_context.serial_startup_program # applying optimization pass if new_strategy.amp.enable: config = copy.deepcopy(new_strategy.amp.to_dict()) config["dist_context"] = dist_context config["params_grads"] = dist_context._params_grads # TODO AMP Pass should not use loss var config["loss"] = dist_context.serial_loss config["input_data"] = self._baseline_dist_context.serial_feed_vars["inputs"] \ + self._baseline_dist_context.serial_feed_vars["labels"] if config["use_pure_fp16"]: config["base_opt"] = dist_context.serial_optimizer auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config) auto_parallel_fp16_pass.apply([main_program], [startup_program], pass_context) dist_context.serial_loss = auto_parallel_fp16_pass.get_loss() else: auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) auto_parallel_amp_pass.apply([main_program], [startup_program], pass_context) dist_context.serial_loss = auto_parallel_amp_pass.get_loss() if new_strategy.recompute.enable: config = copy.deepcopy(new_strategy.recompute.to_dict()) config["dist_context"] = dist_context config["no_grad_set"] = None config["loss"] = dist_context.serial_loss auto_parallel_recompute_pass = new_pass("auto_parallel_recompute", config) auto_parallel_recompute_pass.apply([main_program], [startup_program], pass_context) # Do logical partition partitioner = Partitioner(dist_context, self.rank) dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition( main_program, startup_program, dist_context._params_grads) # Generate optimizer # FIXME should be remove from apply pass after pass support optimizers with program_guard(dist_main_prog, dist_startup_prog): optimizer_ops = dist_context.serial_optimizer.apply_gradients( dist_params_grads) completer.complete_update_annotation(dist_main_prog) # Do reshard process set_grad_var_shape(dist_main_prog, dist_context) resharder = Resharder(dist_main_prog, dist_startup_prog, self.rank, dist_context, dist_params_grads) resharder.reshard() if new_strategy.sharding.enable: config = copy.deepcopy(new_strategy.sharding.to_dict()) config["dist_context"] = dist_context config["params_grads"] = dist_params_grads config["global_rank"] = self.rank auto_parallel_sharding_pass = new_pass("auto_parallel_sharding", config) auto_parallel_sharding_pass.apply([dist_main_prog], [dist_startup_prog], pass_context) if new_strategy.gradient_merge.enable: config = copy.deepcopy(new_strategy.gradient_merge.to_dict()) config["dist_context"] = dist_context config["params_grads"] = dist_params_grads auto_parallel_gradient_merge_pass = new_pass( "auto_parallel_gradient_merge_pass", config) auto_parallel_gradient_merge_pass.apply([dist_main_prog], [dist_startup_prog], pass_context) trial.main_program, trial.startup_program = dist_main_prog, dist_startup_prog return trial def _get_profile_context(self, trial, result_path): profile_ctx = {} profile_ctx['distributed_env'] = copy.deepcopy( paddle.distributed.ParallelEnv()) profile_ctx['group_map'] = parse_process_groups() profile_ctx[ "loss_var_name"] = self._baseline_dist_context.serial_loss.name profile_ctx[ "main_program_decs"] = trial.main_program.desc.serialize_to_string( ) profile_ctx[ "startup_program_decs"] = trial.startup_program.desc.serialize_to_string( ) self._dataset.batch_size = self._batch_size self._dataset.input_names = self._get_input_names() profile_ctx["dataset"] = self._dataset profile_ctx["result_filename"] = result_path return profile_ctx def _get_input_names(self): input_names = [] for input_spec in self._inputs_spec[:] + self._labels_spec[:]: input_names.append(input_spec.name) return input_names def _launch_profile(self, ctx_path, trial_dir): if os.environ.get("WITH_COVERAGE", "OFF") == "ON": coverage_args = ["-m", "coverage", "run", "--branch", "-p"] else: coverage_args = [] profile_args = " ".join([ "--rank", str(self.rank), "--device_id", str(self.device_id), "--ctx_filename", ctx_path, "--profile_start_step", str(self._config.profile_start_step), "--profile_end_step", str(self._config.profile_end_step), ]) cmd_args = "-m paddle.distributed.auto_parallel.tuner.profiler" + " " + profile_args cmd = [sys.executable, "-u"] + coverage_args + shlex.split(cmd_args) parent_env = copy.copy(os.environ.copy()) # env flags need for profile new_env = { "FLAGS_USE_STANDALONE_EXECUTOR": "False", } new_env.update(parent_env) # TODO if any rank hang or fail, kill all processes self._logger.debug("Executing cmd:\n{} .".format(" ".join(cmd))) # new_process = subprocess.Popen(cmd, env=new_env) with open(os.path.join(trial_dir, "stdout.log" + str(self.rank)), "wb") as out, open( os.path.join(trial_dir, "stderr.log" + str(self.rank)), "wb") as err: result = subprocess.Popen(cmd, stdout=out, stderr=err, env=new_env) result.wait() out.flush() err.flush() os.fsync(out) os.fsync(err) def _profile_trial(self, trial): # Making working directory trial_dir = self._get_trial_dir(trial) if not os.path.exists(trial_dir): if self.rank == 0: pathlib.Path(trial_dir).mkdir(parents=True, exist_ok=True) else: while not os.path.exists(trial_dir): pass ctx_filename = "profile_ctx." + str(self.rank) ctx_path = os.path.join(trial_dir, ctx_filename) result_path = os.path.join(trial_dir, "result.json") # Prepare Profile Context profile_ctx = self._get_profile_context(trial, result_path) with open(ctx_path, 'wb') as f: pickle.dump(profile_ctx, f, protocol=4) if self._config.verbose: debug_program(trial.main_program, trial_dir, "main_program") debug_program(trial.startup_program, trial_dir, "startup_program") # Run self._launch_profile(ctx_path, trial_dir) # Load results try: with open(result_path, 'r') as fp: results = json.load(fp) return results except FileNotFoundError: Error_results = {"Throughtput": -1, "ErrorType": 'FatalError'} return Error_results def _evaluate_trial(self, trial): self._logger.info("Trial {} evaluation start.".format(trial.name)) self._apply_optimization(trial) if self._config.mode == "PROFILE": results = self._profile_trial(trial) elif self._config.mode == "COSTMODEL": raise NotImplementedError( "COSTMODEL mode for optimization tuning is not supported yet!") else: raise NotImplementedError("invalid evaluation mode: {}".format( self._config.mode)) self._logger.info("Trial {} evaluation finish with {}.".format( trial.name, parse_results(results))) return results def _update(self, i, trial, results): self._finished_trials.append(trial) cur_mertic = get_metric(results) if self._best_metric == None or cur_mertic > self._best_metric: self._best_metric = cur_mertic self._best_iter = i def _get_trial_dir(self, trial): return os.path.join(self.project_dir, trial.name) def get_best_config(self): """ Return the best optimization configuration found in the tuning. Returns: A object of fleet.DistributedStrategy with best configuration. """ assert self._best_iter >= 0, "The best configuration is not found yet !" best_trial = self._finished_trials[self._best_iter] return self._algorithm.get_config_from_trial(best_trial) def summary(self): """ Display tuning result summary. """ # TODO summary with the trial_name with metric_of_trial best_trial = self._finished_trials[self._best_iter] summary_ = """ Tuning Result Summary Run total {} trials with {} min. The best trial is: [{}], whose configuration is following: """.format(len(self._finished_trials), (time.time() - self._tuning_start_time) / 60, best_trial.name) summary_ += "\n" + best_trial.summary() + "\n"\ self._logger.info(summary_) with open(os.path.join(self.project_dir, "summary.txt"), "w+") as fw: for line in summary_.split("\n"): fw.write(line + "\n") # full_strategy = self.get_best_config() # path = os.path.join(self.project_dir, "tuned_dist_strategy.yaml") # with open(path, 'w') as outfile: # yaml.dump(full_strategy, outfile, default_flow_style=False) def clear(self): """ Clear the temporary file generated in tuning procedure. """ # TODO clear up zombie process created by tuning if not self._config.verbose: for trial in self._finished_trials: trial_dir = self._get_trial_dir(trial) shutil.rmtree(trial_dir, ignore_errors=True) def tune(self): """ Performs the search for best hyperparameter configuations for the selected optimization pass(es). """ # step1: collect model info which might be used for # pruning the search space of the algorithm self._tuning_start_time = time.time() self._algorithm.collect_model_info( self._baseline_dist_context.serial_main_program, self._baseline_dist_context.serial_startup_program) # main search loop i = 0 while i < self._config.max_num_trial: # step2: create a new trial trial = self._algorithm.next_trial() if trial.status == TrialStatus.STOPPED: break # step3: evaluate the trial results = self._evaluate_trial(trial) # step4: update the algorithm with last result, # which could be used by algorithm to pruning the # remaining search space. self._algorithm.update(results) self._update(i, trial, results) # early stop i += 1 if self._config.early_stop and self._config.early_stop <= i - self._best_iter: self._logger.info( "Early stop the Tuning since there is no better trial found within [{}] trials" .format(self._config.early_stop)) break # step5: summary the best config and return self.summary() self.clear()