diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index 857245b9be4257ae9536b8d53c614cf8e18d3f96..ce72304dc75cdf732bf0362c6a82d08ec134155b 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -54,7 +54,7 @@ set_field_default_config(BASE, "reinit", False) # Only for debug ######################################### RECOMPUTE = "recompute" set_field_default_config(RECOMPUTE, "enable", False) -set_field_default_config(RECOMPUTE, "checkpoints", None) +set_field_default_config(RECOMPUTE, "checkpoints", []) set_field_default_config(RECOMPUTE, "no_recompute_segments", []) set_field_default_config(RECOMPUTE, "enable_tuning", False) @@ -113,12 +113,10 @@ set_field_default_config(QAT, "algo", None) # ######################################### TUNING = "tuning" set_field_default_config(TUNING, "enable", False) -set_field_default_config(TUNING, "batch_size", 1) -set_field_default_config(TUNING, "dataset", None) set_field_default_config(TUNING, "profile_start_step", 1) set_field_default_config(TUNING, "profile_end_step", 1) set_field_default_config(TUNING, "run_after_tuning", True) -set_field_default_config(TUNING, "verbose", True) +set_field_default_config(TUNING, "debug", False) ######################################### # dataset configuration diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 092212a87168b5481f57c451afa58a739bfdc74c..dc7470283aef8859d3c46462a117f3b6bb427e6d 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -609,7 +609,9 @@ class Engine: if mode != "train": serial_main_prog = serial_main_prog.clone(for_test=True) - auto_utils.set_recompute_ckpts(self._model, self._strategy) + auto_utils.set_recompute_segments( + self._model, self._losses, self._strategy, serial_main_prog + ) self._dist_contexts[mode] = DistributedContext( serial_main_prog, serial_startup_prog, @@ -649,7 +651,6 @@ class Engine: from .tuner.optimization_tuner import OptimizationTuner self._optimization_tuner = OptimizationTuner( - self._tuning.to_dict(), self._dist_contexts[mode], dataset, self._inputs_spec, diff --git a/python/paddle/distributed/auto_parallel/strategy.py b/python/paddle/distributed/auto_parallel/strategy.py index 4d626bb6ae495046bf6a3015d8d2bfaee52e5375..7e6b98665a8d011bed4e051a45bba701eda9ce03 100644 --- a/python/paddle/distributed/auto_parallel/strategy.py +++ b/python/paddle/distributed/auto_parallel/strategy.py @@ -73,6 +73,10 @@ class BaseConfig: setattr(result, k, copy.deepcopy(v, memo)) return result + def get(self, k, d=None): + result_dict = self.to_dict() + return result_dict.get(k, d) + class RecomputeConfig(BaseConfig): def __init__(self, config_dict=None): diff --git a/python/paddle/distributed/auto_parallel/tuner/algorithms.py b/python/paddle/distributed/auto_parallel/tuner/algorithms.py index 8ce570d03c2881c91f9f3dd188e3f9bf4b1de90f..74e8f3e9ee3f1d7e3fd57fb1ddf0d8961fa2ceb9 100644 --- a/python/paddle/distributed/auto_parallel/tuner/algorithms.py +++ b/python/paddle/distributed/auto_parallel/tuner/algorithms.py @@ -16,7 +16,7 @@ import copy import logging from abc import ABC, abstractmethod -from ..utils import get_logger +from ..utils import get_logger, is_recompute_op from .trial import OptimizationTunerTrial as Trial from .trial import TrialStatus @@ -54,7 +54,7 @@ class AlgorithmBase(ABC): def 
collect_model_info(self, main_prog, startup_prog): """ Collect the model static info (from programs) that could be used to - pruning candidate trials and saving tuning time.For instance, + prune candidate trials and save tuning time. For instance, model info like number of model parameters and activation memory could be used to prune candidated trial and decide the next trial. """ @@ -116,7 +116,7 @@ class ShardingStageAlgorithm(AlgorithmBase): self._max_stage = 3 self._trial_idx = 0 - stage_range = self._config.sharding.to_dict().get("tuning_range", None) + stage_range = self._config.sharding.get("tuning_range", None) if stage_range: assert set(stage_range).issubset( set([0, 1, 2, 3]) @@ -157,3 +157,92 @@ class ShardingStageAlgorithm(AlgorithmBase): ) else: self._trial_idx += 1 + + +@register_algor("recompute") +class RecomputeCheckpointAlgorithm(AlgorithmBase): + def __init__(self, config): + super().__init__(config) + self._changed_configs = ["recompute"] + + def collect_model_info(self, main_prog, startup_prog): + segments = [] + for op in main_prog.global_block().ops: + if not is_recompute_op(op): + continue + + seg_name = op.attr('op_namescope') + if seg_name not in segments: + segments.append(seg_name) + + self._total_num_trial = len(segments) + self._tuning_segments = list(range(len(segments))) + self._trial_left = 0 + self._trial_right = len(segments) - 1 + self._trial_idx = int((len(segments) - 1) / 2) + + def _init_spaces(self): + self._recompute_mode = "all" + + def next_trial(self): + if self._trial_idx < self._total_num_trial: + if self._recompute_mode == "all": + self._recompute_flag = False + new_strategy = copy.deepcopy(self._config.dist_strategy) + name = "trial-recompute-all-segments" + return Trial(new_strategy, name, self.changed_configs) + elif self._recompute_mode == "none": + self._recompute_flag = False + new_strategy = copy.deepcopy(self._config.dist_strategy) + recompute = new_strategy.recompute + recompute.enable = False + name = "trial-recompute-none-segments" + return Trial(new_strategy, name, self.changed_configs) + elif self._recompute_mode == "part": + new_no_recompute = self._tuning_segments[: self._trial_idx] + new_strategy = copy.deepcopy(self._config.dist_strategy) + recompute = new_strategy.recompute + recompute.no_recompute_segments.extend(new_no_recompute) + name = "trial-recompute-part-segments-idx{}".format( + self._trial_idx + ) + return Trial(new_strategy, name, self.changed_configs) + else: + return Trial(None, None, None, status=TrialStatus.STOPPED) + + def update(self, results): + + et = results.get("ErrorType", None) + if self._recompute_mode == "all": + if et == "ResourceExhaustedError": + self._trial_idx = self._total_num_trial + self._logger.info( + "Recomputing all candidate segments still results in OOM; please reduce the model size or batch size." + ) + else: + self._recompute_mode = "none" + elif self._recompute_mode == "none": + if et == "ResourceExhaustedError": + self._recompute_mode = "part" + else: + self._trial_idx = self._total_num_trial + self._logger.info( + "Recompute is unnecessary for this model size; enabling it would only reduce throughput."
+ ) + else: + if self._trial_left >= self._trial_right: + self._trial_idx = self._total_num_trial + elif et == "ResourceExhaustedError": + # OOM at the midpoint: keep the left bound, shrink from the right + self._trial_right = self._trial_idx - 1 + self._trial_idx = int( + self._trial_left + + (self._trial_right - self._trial_left) / 2 + ) + else: + self._trial_left = self._trial_idx + 1 + # no OOM: keep the right bound, try keeping more segments resident + self._trial_idx = int( + self._trial_left + + (self._trial_right - self._trial_left) / 2 + ) diff --git a/python/paddle/distributed/auto_parallel/tuner/config.py b/python/paddle/distributed/auto_parallel/tuner/config.py index f47ec1ae2d04160ae36df1110c42dee8acc1bfe3..78f94b87b360b32f25d2a3b3c4e5c677586780b5 100644 --- a/python/paddle/distributed/auto_parallel/tuner/config.py +++ b/python/paddle/distributed/auto_parallel/tuner/config.py @@ -32,14 +32,11 @@ class TuningConfig: tuning config: configuration for the tuning process: mode (profile or cost model), log dir, extra tuning config for optimization like search range for specific """ - def __init__(self, user_config, strategy): + def __init__(self, strategy): if not isinstance(strategy, Strategy): raise TypeError("'strategy' must be object of class `Strategy`.") - if not user_config: - user_config = {} - self._tuning_passes_name = set() self._dist_strategy = copy.deepcopy(strategy) self._mode = None @@ -48,9 +45,9 @@ class TuningConfig: self._project_dir = None self._max_num_trial = None self._early_stop = None - self._verbose = None + self._debug = None - self._initialize(user_config) + self._initialize() @property def mode(self): @@ -81,29 +78,25 @@ class TuningConfig: return self._early_stop @property - def verbose(self): - return self._verbose + def debug(self): + return self._debug @property def dist_strategy(self): return self._dist_strategy # initialize config with user define value or default value - def _initialize(self, user_config): - - self._mode = user_config.get("mode", "PROFILE") - - self._profile_start_step = user_config.get("profile_start_step", 10) - - self._profile_end_step = user_config.get("profile_end_step", 30) - - self._max_num_trial = user_config.get("max_num_trial", 50) - - self._early_stop = user_config.get("early_stop", None) + def _initialize(self): + tuning_strategy = self._dist_strategy.tuning - self._verbose = user_config.get("verbose", False) + self._mode = tuning_strategy.get("mode", "PROFILE") + self._profile_start_step = tuning_strategy.get("profile_start_step", 10) + self._profile_end_step = tuning_strategy.get("profile_end_step", 30) + self._max_num_trial = tuning_strategy.get("max_num_trial", 50) + self._early_stop = tuning_strategy.get("early_stop", None) + self._debug = tuning_strategy.get("debug", False) - project_dir = user_config.get("project_dir", None) + project_dir = tuning_strategy.get("project_dir", None) if not project_dir: project_dir = os.path.join(os.getcwd(), "OptimizationTuning") self._project_dir = project_dir @@ -116,15 +109,14 @@ # TODO distinguish different args of each passes self._tuning_passes_name.add(p) - config_name = p - p_dict = getattr(self._dist_strategy, config_name) - self.__dict__[config_name] = p_dict + p_strategy = getattr(self._dist_strategy, p) + self.__dict__[p] = p_strategy - # TODO verify the user defined configs - user_config_for_pass = user_config.get(p, None) - if user_config_for_pass: - for k, v in user_config_for_pass.items(): - self.__dict__[config_name][k] = v + # # TODO verify the user defined configs + # tuning_config_for_pass = 
tuning_strategy.get(p, None) + # if tuning_config_for_pass: + # for k, v in tuning_config_for_pass.items(): + # self.__dict__[p][k] = v # (NOTE)tuning config ONLY wraps dist strategy for pass config which is to be tuned def __getattr__(self, item): diff --git a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py index 8a2867a315d3e0f795a29ead793fcfd5943fecec..c3de081c752bab614d6cc6946d0d6dd7c99bea68 100644 --- a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py +++ b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py @@ -33,6 +33,7 @@ from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.process_group import ( clear_all_process_groups, get_all_process_groups, + new_process_group, ) from paddle.distributed.auto_parallel.reshard import Resharder from paddle.distributed.auto_parallel.utils import ( @@ -40,7 +41,7 @@ from paddle.distributed.auto_parallel.utils import ( set_grad_var_shape, ) from paddle.distributed.passes import PassContext, new_pass -from paddle.fluid import program_guard +from paddle.fluid import program_guard, unique_name from paddle.fluid.backward import append_backward from ..utils import get_logger @@ -109,7 +110,12 @@ def parse_results(results): # all env need to be start a new pass are member of dist context def _copy_context(ref_dist_context): + # clear all process groups and recover the world process group clear_all_process_groups() + ranks = [] + for process_mesh in ref_dist_context._process_meshes: + ranks.extend(process_mesh.processes) + new_process_group(list(set(ranks))) new_dist_context = DistributedContext() new_dist_context._serial_main_program = ( @@ -195,7 +201,6 @@ class OptimizationTuner: def __init__( self, - user_configs, dist_context, dataset, inputs_spec, @@ -204,7 +209,7 @@ class OptimizationTuner: rank, ): - self._config = TuningConfig(user_configs, dist_context._strategy) + self._config = TuningConfig(dist_context.strategy) # should not modify dist context from calling function self._baseline_dist_context = _copy_context(dist_context) self._baseline_completer = Completer(self._baseline_dist_context) @@ -264,7 +269,7 @@ class OptimizationTuner: ) self._baseline_dist_context._params_grads = params_grads - if self._config.verbose: + if self._config.debug: baseline_dir = os.path.join(self.project_dir, "baseline") if not os.path.exists(baseline_dir): pathlib.Path(baseline_dir).mkdir(parents=True, exist_ok=True) @@ -299,7 +304,6 @@ class OptimizationTuner: config = copy.deepcopy(new_strategy.amp.to_dict()) config["dist_context"] = dist_context config["params_grads"] = dist_context._params_grads - # TODO AMP Pass should not use loss var config["loss"] = dist_context.serial_loss config["input_data"] = ( @@ -312,13 +316,13 @@ class OptimizationTuner: auto_parallel_fp16_pass.apply( [main_program], [startup_program], pass_context ) - dist_context.serial_loss = auto_parallel_fp16_pass.get_loss() + dist_context._serial_loss = auto_parallel_fp16_pass.get_loss() else: auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) auto_parallel_amp_pass.apply( [main_program], [startup_program], pass_context ) - dist_context.serial_loss = auto_parallel_amp_pass.get_loss() + dist_context._serial_loss = auto_parallel_amp_pass.get_loss() if new_strategy.recompute.enable: config = copy.deepcopy(new_strategy.recompute.to_dict()) @@ -345,9 +349,10 @@ class OptimizationTuner: # Generate optimizer # 
FIXME should be remove from apply pass after pass support optimizers with program_guard(dist_main_prog, dist_startup_prog): - optimizer_ops = dist_context.serial_optimizer.apply_gradients( - dist_params_grads - ) + with unique_name.guard("opt_"): + optimizer_ops = dist_context.serial_optimizer.apply_gradients( + dist_params_grads + ) completer.complete_update_annotation(dist_main_prog) # Do reshard process @@ -361,6 +366,13 @@ class OptimizationTuner: ) resharder.reshard() + config = {} + config["dist_context"] = dist_context + config["global_rank"] = self.rank + config["use_sharding"] = new_strategy.sharding.enable + dp_pass = new_pass("auto_parallel_data_parallel_optimization", config) + dp_pass.apply([dist_main_prog], [dist_startup_prog], pass_context) + if new_strategy.sharding.enable: config = copy.deepcopy(new_strategy.sharding.to_dict()) config["dist_context"] = dist_context @@ -372,6 +384,17 @@ class OptimizationTuner: auto_parallel_sharding_pass.apply( [dist_main_prog], [dist_startup_prog], pass_context ) + dist_params_grads = pass_context.get_attr("params_grads") + + # gradient clip + config = copy.deepcopy(new_strategy.sharding.to_dict()) + config["dist_context"] = dist_context + config["params_grads"] = dist_params_grads + config["rank_id"] = self.rank + auto_parallel_clip_pass = new_pass("auto_parallel_grad_clip", config) + auto_parallel_clip_pass.apply( + [dist_main_prog], [dist_startup_prog], pass_context + ) if new_strategy.gradient_merge.enable: config = copy.deepcopy(new_strategy.gradient_merge.to_dict()) @@ -488,7 +511,7 @@ class OptimizationTuner: with open(ctx_path, 'wb') as f: pickle.dump(profile_ctx, f, protocol=4) - if self._config.verbose: + if self._config.debug: debug_program(trial.main_program, trial_dir, "main_program") debug_program(trial.startup_program, trial_dir, "startup_program") @@ -581,7 +604,7 @@ The best trial is: [{}], whose configuration is following: Clear the temporary file generated in tuning procedure. 
""" # TODO clear up zombie process created by tuning - if not self._config.verbose: + if not self._config.debug: for trial in self._finished_trials: trial_dir = self._get_trial_dir(trial) shutil.rmtree(trial_dir, ignore_errors=True) diff --git a/python/paddle/distributed/auto_parallel/tuner/profiler.py b/python/paddle/distributed/auto_parallel/tuner/profiler.py index 4a4dfea7631575a9906cf49f29be4b372900a106..cdd4a0045c8c9c122ee6529fff298529293d6fbd 100644 --- a/python/paddle/distributed/auto_parallel/tuner/profiler.py +++ b/python/paddle/distributed/auto_parallel/tuner/profiler.py @@ -89,7 +89,7 @@ def init_process_groups(group_map, rank): # TODO should instantiate global group first all_process_groups = get_all_process_groups() for process_group in all_process_groups: - if process_group.id == 0 or rank not in process_group.ranks: + if rank not in process_group.ranks: continue print(process_group) process_group.instantiate() @@ -173,10 +173,11 @@ def init_comm(profile_ctx): genv = _get_global_env() genv = dist_env print( - "current process rank: {}, device_id: {}, ip: {}.", - genv.rank, - genv.device_id, - genv.current_endpoint, + "current process rank: {}, device_id: {}, ip: {}.".format( + genv.rank, + genv.device_id, + genv.current_endpoint, + ) ) # init nccl comm @@ -231,13 +232,12 @@ def profiler(args): exe = get_executor() - exe.run(startup_program) - - # profile main - duration = 0 - eval_step = 0 - data_loader._inner_dataloader.start() try: + exe.run(startup_program) + # profile main + duration = 0 + eval_step = 0 + data_loader._inner_dataloader.start() while eval_step < args.profile_end_step: start_time = time.time() diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 0883417fc9e82ca2fce373b831311747773ab811..4d474569fb3ebcc2f09c561136d917bf811515e1 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -22,18 +22,17 @@ from functools import reduce import numpy as np import paddle -import paddle.fluid.core as core -from paddle.distributed.auto_parallel.dist_attribute import ( +from paddle.fluid.framework import Variable +from paddle.fluid.io import is_belong_to_optimizer, is_parameter +from paddle.framework import core + +from .dist_attribute import ( OperatorDistributedAttribute, TensorDistributedAttribute, ) -from paddle.distributed.auto_parallel.process_group import ( - get_all_process_groups, -) -from paddle.distributed.fleet.meta_optimizers.common import OpRole -from paddle.fluid.framework import Variable -from paddle.fluid.io import is_belong_to_optimizer, is_parameter +from .process_group import get_all_process_groups +OpRole = core.op_proto_and_checker_maker.OpRole OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() __no_shape_var_type__ = [ @@ -1921,10 +1920,16 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank): server_socket.close() -def set_recompute_ckpts(model, strategy): - from .interface import _g_recompute_idx +def is_recompute_op(op): + return op.has_attr('op_namescope') and "/auto_parallel/rc" in op.attr( + 'op_namescope' + ) - if _g_recompute_idx > -1: + +def set_recompute_segments(model, losses, strategy, program): + from ..passes.auto_parallel_recompute import RecomputeState + + if not losses: return recompute = strategy.recompute @@ -1934,24 +1939,65 @@ def set_recompute_ckpts(model, strategy): # NOTE: hack to enable recompute in engine api for GPT-3 # TODO support more PaddleNLP/CV models here # extract 
ckpts by specific model + ckpts = [] if isinstance(model, paddle.nn.Layer): - if hasattr(model, "gpt") and model.__class__.__name__ in [ - 'GPTForPretraining', - 'GPTForPretrainingAuto', - ]: - exact_ckpts = model.gpt.checkpoints + if ( + hasattr(model, "gpt") + and model.__class__.__name__ + in [ + 'GPTForPretraining', + 'GPTForPretrainingAuto', + ] + and hasattr(model.gpt, "checkpoints") + ): + ckpts = model.gpt.checkpoints else: - exact_ckpts = recompute.checkpoints + ckpts = recompute.checkpoints else: - exact_ckpts = recompute.checkpoints + ckpts = recompute.checkpoints - # modify strategy - recompute.checkpoints = exact_ckpts[:] - logs = { - 'Model Class': model.__class__.__name__, - 'Applied Recompute ckpts': exact_ckpts, - } - logging.info(logs) + if not ckpts: + return + + block = program.global_block() + rc_state = RecomputeState(block, block.ops) + rc_state.build_states() + checkpoints = rc_state.sort_checkpoints(ckpts) + + segments = [] + start_idx = -1 + pre_segment_end_idx = -1 + while start_idx + 1 < len(checkpoints): + if start_idx == -1: + ckpt_name = checkpoints[start_idx + 1] + if ckpt_name not in rc_state.var_op_deps: + start_idx += 1 + continue + op_idx_list = rc_state.var_op_deps[ckpt_name]["var_as_output_ops"] + if op_idx_list and max(op_idx_list) > 0: + segments.append([0, max(op_idx_list) + 1]) + else: + flag, min_idx, max_idx = rc_state.is_subgraph( + [checkpoints[start_idx]], [checkpoints[start_idx + 1]] + ) + if flag: + min_idx = rc_state._update_segment_start( + min_idx, pre_segment_end_idx + ) + segments.append([min_idx, max_idx + 1]) + else: + logging.debug( + "Could not recompute op range [{}] - [{}] ".format( + min_idx, max_idx + 1 + ) + ) + start_idx += 1 + + for i, segment in enumerate(segments): + for j in range(segment[0], segment[1]): + block.ops[j]._set_attr( + 'op_namescope', "/auto_parallel/rc_" + str(i) + ) def get_input_split_info(cur_rank, var, dist_context): diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py index e96cd4ec77d8f1b74505c74e9526d5aa3792e7d6..cba613676d58de12ac44fa8ffa3e412c4002ef93 100644 --- a/python/paddle/distributed/passes/auto_parallel_amp.py +++ b/python/paddle/distributed/passes/auto_parallel_amp.py @@ -226,6 +226,9 @@ class AMPState: dist_context, out_var, ref_mapping, ref_mesh ) + op_namescope = "/" + if op.has_attr('op_namescope'): + op_namescope = op.attr('op_namescope') cast_op = self._block._insert_op_without_sync( idx, type="cast", @@ -236,6 +239,9 @@ "out_dtype": out_var.dtype, }, ) + cast_op._set_attr( + 'op_namescope', op_namescope + ) # for recompute naive_set_dist_op_attr_for_program_by_mesh_and_mapping( cast_op, ref_mesh, ref_mapping, dist_context ) diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py index 7aed31b01ec2b2561ec793612ac1418801f871a2..0e834343e2800bacd2ba043c1aa4eab3d3b5fb0c 100644 --- a/python/paddle/distributed/passes/auto_parallel_fp16.py +++ b/python/paddle/distributed/passes/auto_parallel_fp16.py @@ -22,13 +22,12 @@ from paddle.distributed.auto_parallel.process_group import ( get_world_process_group, ) from paddle.distributed.auto_parallel.utils import ( - OP_ROLE_KEY, - OpRole, is_backward_op, is_forward_op, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr, ) +from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole from paddle.fluid import unique_name from 
paddle.fluid.contrib.mixed_precision.fp16_utils import ( AutoMixedPrecisionLists, @@ -417,6 +416,9 @@ class FP16State: dist_context, cast_var, ref_mapping, ref_mesh ) + op_namescope = "/" + if op.has_attr('op_namescope'): + op_namescope = op.attr('op_namescope') cast_op = block._insert_op_without_sync( idx, type="cast", @@ -428,6 +430,9 @@ class FP16State: OP_ROLE_KEY: OpRole.Forward, }, ) + cast_op._set_attr( + 'op_namescope', op_namescope + ) # for recompute naive_set_dist_op_attr_for_program_by_mesh_and_mapping( cast_op, ref_mesh, ref_mapping, dist_context ) diff --git a/python/paddle/distributed/passes/auto_parallel_grad_clip.py b/python/paddle/distributed/passes/auto_parallel_grad_clip.py index af5259680e4a596042c5307c416c443a2477a104..7258eca661d63e5d1bc0e852b16d16d5dc661b11 100644 --- a/python/paddle/distributed/passes/auto_parallel_grad_clip.py +++ b/python/paddle/distributed/passes/auto_parallel_grad_clip.py @@ -17,6 +17,7 @@ from functools import reduce import numpy as np import paddle +from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole from ..auto_parallel.dist_attribute import ( OperatorDistributedAttribute, @@ -25,8 +26,6 @@ from ..auto_parallel.dist_attribute import ( from ..auto_parallel.process_group import get_world_process_group from ..auto_parallel.reshard import Resharder from ..auto_parallel.utils import ( - OP_ROLE_KEY, - OpRole, _get_comm_group, insert_dependencies_for_two_vars, is_gradient_clip_op, diff --git a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py index c4ccb89d2f56fc846387fb5f88f4b9b39d7b1b0d..1ec482e5cdfdcbb38f116170d33ce823837a1175 100644 --- a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py +++ b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py @@ -19,12 +19,11 @@ from paddle.distributed.auto_parallel.process_group import ( get_world_process_group, ) from paddle.distributed.auto_parallel.utils import ( - OP_ROLE_KEY, - OpRole, is_optimize_op, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr, ) +from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole from paddle.fluid import layers from paddle.fluid.framework import device_guard from paddle.framework import core diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index aa213e24322323567736d9874a48d9a967e88459..d99f335517a1675450a0943af6aa536747a63809 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -14,19 +14,8 @@ import logging -from paddle.distributed.auto_parallel.dist_attribute import ( - OperatorDistributedAttribute, -) -from paddle.distributed.auto_parallel.utils import ( - get_loss_op, - insert_dependencies_for_two_ops, - naive_set_dist_op_attr_for_program_by_mesh_and_mapping, - set_dist_op_desc_original_id, - set_var_dist_attr, -) -from paddle.fluid import core -from paddle.fluid import framework as framework -from paddle.fluid import unique_name +from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole +from paddle.fluid import core, framework, unique_name from paddle.fluid.backward import ( ProgramStats, _append_grad_suffix_, @@ -35,28 +24,43 @@ from paddle.fluid.backward import ( _rename_arg_, ) +from ..auto_parallel.dist_attribute import OperatorDistributedAttribute +from ..auto_parallel.utils import ( + 
get_loss_op, + insert_dependencies_for_two_ops, + is_backward_op, + is_recompute_op, + naive_set_dist_op_attr_for_program_by_mesh_and_mapping, + set_dist_op_desc_original_id, + set_var_dist_attr, +) from .pass_base import PassBase, register_pass -def _to_be_recomputed(op): - return op.has_attr('op_namescope') and "/auto_parallel/rc_" in op.attr( - 'op_namescope' - ) - - class RecomputeState(ProgramStats): def __init__(self, block, ops): super().__init__(block=block, ops=ops) - self._block = block - self._ops = ops - # {varname: {as_input_ops: op_idx, as_output_ops: op_idx}} - self.var_op_deps = {} - # {segment_name: op_idx} self.seg_op_deps = {} + self._checkpoints = [] + self._reserved_vars = [] + + @property + def checkpoints(self): + return self._checkpoints + + @property + def reserved_vars(self): + return self._reserved_vars - def build_stats(self): - for i, op in enumerate(self._ops): - for name in op.desc.input_arg_names(): + def is_recompute(self): + return any([is_recompute_op(op) for op in self.ops]) + + def build_states(self): + for i, op in enumerate(self.ops): + if is_backward_op(op): + break + + for name in op.input_arg_names: if name in self.var_op_deps: self.var_op_deps[name]["var_as_input_ops"].extend([i]) else: @@ -64,7 +68,7 @@ class RecomputeState(ProgramStats): self.var_op_deps[name]["var_as_input_ops"] = [i] self.var_op_deps[name]["var_as_output_ops"] = [] - for name in op.desc.output_arg_names(): + for name in op.output_arg_names: if name in self.var_op_deps: self.var_op_deps[name]["var_as_output_ops"].extend([i]) else: @@ -72,7 +76,8 @@ class RecomputeState(ProgramStats): self.var_op_deps[name]["var_as_input_ops"] = [] self.var_op_deps[name]["var_as_output_ops"] = [i] - if not _to_be_recomputed(op): + if not is_recompute_op(op): + self._checkpoints.extend(op.output_arg_names) continue seg_name = op.attr('op_namescope') @@ -84,97 +89,42 @@ class RecomputeState(ProgramStats): ), "The recompute segment's ops should be continuous" self.seg_op_deps[seg_name].extend([i]) - def get_recompute_segments( - self, checkpoints_list=None, no_recompute_segments=[] - ): - """get recompute segments and checkpoints""" + def get_recompute_segments(self, no_recompute_segments=[]): segments = [] - checkpoints = checkpoints_list or [] - - if len(checkpoints) == 0: - # the segments is marked by `auto.recompute()` api - for segment_idx in self.seg_op_deps.values(): - if len(segment_idx) == 1: - continue - segments.append([segment_idx[0], segment_idx[-1] + 1]) - checkpoints.extend(self._ops[segment_idx[-1]].output_arg_names) - else: - # the segments is marked by `strategy.checkpoints` api - start_idx = -1 - pre_segment_end_idx = -1 - while start_idx + 1 < len(checkpoints): - if start_idx == -1: - ckpt_name = checkpoints[start_idx + 1] - if ckpt_name not in self.var_op_deps: - start_idx += 1 - continue - op_idx_list = self.var_op_deps[ckpt_name][ - "var_as_output_ops" - ] - if op_idx_list: - segments.append([0, max(op_idx_list) + 1]) - else: - flag, min_idx, max_idx = self.is_subgraph( - [checkpoints[start_idx]], [checkpoints[start_idx + 1]] - ) - if flag: - min_idx = self._update_segment_start( - min_idx, pre_segment_end_idx - ) - segments.append([min_idx, max_idx + 1]) - else: - logging.info( - "Could not recompute op range [{}] - [{}] ".format( - min_idx, max_idx + 1 - ) - ) - start_idx += 1 - - if no_recompute_segments: - for i in reversed(sorted(no_recompute_segments)): - assert i < len( - segments - ), "the no_recompute_segments idx [{}] should be lower the number of segment 
[{}]".format( - i, len(segments) - ) - segments.pop(i) - - for i, (idx1, idx2) in enumerate(segments): - logging.info("recompute segment[{}]".format(i)) - logging.info( - "segment start op: [{}]: [{}] [{}]".format( - self._ops[idx1].desc.type(), - self._ops[idx1].desc.input_arg_names(), - self._ops[idx1].desc.output_arg_names(), - ) - ) - logging.info( - "segment end op: [{}]: [{}] [{}]".format( - self._ops[idx2 - 1].desc.type(), - self._ops[idx2 - 1].desc.input_arg_names(), - self._ops[idx2 - 1].desc.output_arg_names(), - ) + for segment_idx in self.seg_op_deps.values(): + if len(segment_idx) == 1: + continue + segments.append([segment_idx[0], segment_idx[-1] + 1]) + self._checkpoints.extend(self.ops[segment_idx[-1]].output_arg_names) + + for i in reversed(sorted(no_recompute_segments)): + assert i < len( + segments + ), "the no_recompute_segments idx [{}] should be lower the number of segment [{}]".format( + i, len(segments) ) + segments.pop(i) - return segments, checkpoints - - def is_recompute(self): - return any([_to_be_recomputed(op) for op in self._ops]) + return segments def modify_forward_desc_for_recompute(self, dist_context): """ If program's foward part has 'dropout' op, this function will insert a seed op before it to guarantee that two dropout op have the same outputs. """ - op_types = [op.desc.type() for op in self._ops] + op_types = [op.type for op in self.ops] if "dropout" not in op_types: return op_idx = 0 - while op_idx < len(self._ops): - cur_op = self._ops[op_idx] + while op_idx < len(self.ops): + cur_op = self.ops[op_idx] if "grad" in cur_op.type: break + if cur_op.type == "seed": + self._reserved_vars.extend(cur_op.output_arg_names) + op_idx += 1 + continue if cur_op.type != "dropout": op_idx += 1 continue @@ -188,7 +138,8 @@ class RecomputeState(ProgramStats): var_unique_name = unique_name.generate_with_ignorable_key( ".".join([op_unique_name, 'tmp']) ) - seed_var = self._block.create_var( + self._reserved_vars.append(var_unique_name) + seed_var = self.block.create_var( name=var_unique_name, dtype='int32', type=core.VarDesc.VarType.LOD_TENSOR, @@ -209,7 +160,7 @@ class RecomputeState(ProgramStats): else int(cur_op.attr("seed")) ) # TODO add dependency for seed op to ensure it be issued just before recompute. 
- seed_op = self._block._insert_op_without_sync( + seed_op = self.block._insert_op_without_sync( index=cur_op.idx, type="seed", inputs={}, @@ -223,7 +174,7 @@ class RecomputeState(ProgramStats): ) # modify dropout op's desc - self._ops.insert(op_idx, seed_op) + self.ops.insert(op_idx, seed_op) cur_op.desc.set_input("Seed", [var_unique_name]) cur_op._remove_attr("fix_seed") cur_op._remove_attr("seed") @@ -232,7 +183,7 @@ class RecomputeState(ProgramStats): ) op_idx += 2 - self._block._sync_with_cpp() + self.block._sync_with_cpp() def _find_op_index(block, cur_op): @@ -242,7 +193,7 @@ def _find_op_index(block, cur_op): return -1 -def _get_stop_gradients(program, no_grad_set): +def _get_stop_gradients(program, no_grad_set=None): """get no grad var""" if no_grad_set is None: no_grad_set = set() @@ -260,16 +211,15 @@ def _get_stop_gradients(program, no_grad_set): def _add_needed_descs_to_block( - descs, block, main_block, in_memory_vars, dist_context + descs, block, main_block, vars_should_be_hold, dist_context ): """ Get the recomputed ops which will insert the backward part """ if len(descs) == 0: return [] + result_descs = [] - op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() - backward = core.op_proto_and_checker_maker.OpRole.Backward for desc in descs: if isinstance(desc, framework.Operator): desc = desc.desc @@ -279,22 +229,29 @@ def _add_needed_descs_to_block( for name in desc.output_arg_names(): if main_block.has_var(name) and main_block.var(name).persistable: continue - if name not in in_memory_vars: + if name not in vars_should_be_hold: is_needed = True if is_needed: new_op_desc = block.desc.append_op() new_op_desc.copy_from(desc) set_dist_op_desc_original_id(new_op_desc, desc, dist_context) - new_op_desc._set_attr(op_role_attr_name, backward) + new_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward) result_descs.append(new_op_desc) return result_descs +def _find_op_path(main_program, loss, no_grad_set=None): + no_grad_set_name = _get_stop_gradients(main_program, no_grad_set) + op_path = _find_op_path_( + main_program.global_block(), [loss], [], no_grad_set_name + ) + return op_path + + @register_pass("auto_parallel_recompute") class RecomputePass(PassBase): def __init__(self): super().__init__() - self.set_attr("checkpoints", None) self.set_attr("loss", None) self.set_attr("dist_context", None) self.set_attr("no_grad_set", None) @@ -311,49 +268,64 @@ class RecomputePass(PassBase): return True def _apply_single_impl(self, main_program, startup_program, context): - checkpoints = self.get_attr("checkpoints") - no_recompute_segments = self.get_attr("no_recompute_segments") loss = self.get_attr("loss") no_grad_set = self.get_attr("no_grad_set") + no_recompute_segments = self.get_attr("no_recompute_segments") self._dist_context = self.get_attr("dist_context") # 0. get op_path which is related to loss main_block = main_program.global_block() - no_grad_set_name = _get_stop_gradients(main_program, no_grad_set) - op_path = _find_op_path_(main_block, [loss], [], no_grad_set_name) + op_path = _find_op_path(main_program, loss, no_grad_set) # 1. build recompute state rc_state = RecomputeState(main_block, op_path) - if not rc_state.is_recompute() and not checkpoints: + if not rc_state.is_recompute(): return # 2. 
get the segments to be recomputed rc_state.modify_forward_desc_for_recompute(self._dist_context) - rc_state.build_stats() - checkpoints = rc_state.sort_checkpoints(checkpoints or []) - segments, checkpoints = rc_state.get_recompute_segments( - checkpoints, no_recompute_segments - ) - if segments == [] or checkpoints == []: + rc_state.build_states() + segments = rc_state.get_recompute_segments(no_recompute_segments) + if segments == []: return + for i, (idx1, idx2) in enumerate(segments): + logging.info( + "recompute segment[{}/{}]".format(i + 1, len(segments)) + ) + logging.info( + "segment start op: [{}]: [{}] [{}]".format( + rc_state.ops[idx1].type, + rc_state.ops[idx1].input_arg_names, + rc_state.ops[idx1].output_arg_names, + ) + ) + logging.info( + "segment end op: [{}]: [{}] [{}]".format( + rc_state.ops[idx2 - 1].type, + rc_state.ops[idx2 - 1].input_arg_names, + rc_state.ops[idx2 - 1].output_arg_names, + ) + ) + # 3. get vars that should be hold in memory vars_should_be_hold = [] for segment in segments: vars_should_be_hold.extend( rc_state.get_out_of_subgraph_vars(segment[0], segment[1]) ) - cross_vars = set(vars_should_be_hold) - set(checkpoints) + cross_vars = set(vars_should_be_hold) - set(rc_state.checkpoints) logging.info( "found [{}] vars which cross recompute segment: [{}]," "better checkpoints might be set to reduce those vars".format( len(cross_vars), cross_vars ) ) - vars_should_be_hold.extend(rc_state.get_reserved_vars()) + vars_should_be_hold.extend(rc_state.reserved_vars) vars_should_be_hold.extend(rc_state.get_input_nodes()) - vars_should_be_hold = list(set(vars_should_be_hold)) - vars_in_memory = vars_should_be_hold + checkpoints + vars_should_be_hold = list( + set(vars_should_be_hold) | set(rc_state.checkpoints) + ) # 4. get the fwd ops desc to be recomputed. 
var_name_dict = {} # varname --> varname.subprog_XXX @@ -364,20 +336,23 @@ class RecomputePass(PassBase): var_suffix = ".subprog_%d" % i for op in fwd_ops: input_and_output_names = [] - input_and_output_names.extend(op.desc.input_arg_names()) - input_and_output_names.extend(op.desc.output_arg_names()) + input_and_output_names.extend(op.input_arg_names) + input_and_output_names.extend(op.output_arg_names) + cur_op_dist_attr = ( self._dist_context.get_op_dist_attr_for_program(op) ) assert cur_op_dist_attr is not None + for name in input_and_output_names: - if main_block.var(name).persistable or name in checkpoints: - continue - if name in vars_should_be_hold: + if ( + main_block.var(name).persistable + or name in vars_should_be_hold + ): continue if name not in var_name_dict: ref_process_mesh = cur_op_dist_attr.process_mesh - if name in op.desc.input_arg_names(): + if name in op.input_arg_names: ref_dims_mapping = ( cur_op_dist_attr.get_input_dims_mapping(name) ) @@ -385,6 +360,7 @@ class RecomputePass(PassBase): ref_dims_mapping = ( cur_op_dist_attr.get_output_dims_mapping(name) ) + # record recomputed var's old_name and new_name (old_name.subprog_XXX) # create new var with new name var_name_dict[name] = name + var_suffix @@ -409,7 +385,7 @@ class RecomputePass(PassBase): fwd_ops, buffer_block, main_block, - vars_in_memory, + vars_should_be_hold, self._dist_context, ) # rename recomputed ops' input and output var name @@ -437,15 +413,15 @@ class RecomputePass(PassBase): grad_op._remove_attr("fix_seed") grad_op._remove_attr("seed") - # rename grad op's var_name which is not in 'vars_in_memory' - for key in var_name_dict: - if ( - key - not in grad_op.input_arg_names + grad_op.output_arg_names - ): + input_and_output_names = [] + input_and_output_names.extend(grad_op.input_arg_names) + input_and_output_names.extend(grad_op.output_arg_names) + + for varname in var_name_dict: + if varname not in input_and_output_names: continue self.reset_op_dist_attr(grad_op, var_name_dict) - _rename_arg_([grad_op.desc], key, var_name_dict[key]) + _rename_arg_([grad_op.desc], varname, var_name_dict[varname]) # insert recomputed ops original_id = grad_op.desc.original_id() @@ -504,13 +480,13 @@ class RecomputePass(PassBase): def reset_op_dist_attr(self, op, var_name_dict): op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op) assert op_dist_attr is not None - for input in op.desc.input_arg_names(): + for input in op.input_arg_names: if input in var_name_dict.keys(): in_dist_attr = op_dist_attr.get_input_dist_attr(input) op_dist_attr.set_input_dist_attr( var_name_dict[input], in_dist_attr ) - for output in op.desc.output_arg_names(): + for output in op.output_arg_names: if output in var_name_dict.keys(): out_dist_attr = op_dist_attr.get_output_dist_attr(output) op_dist_attr.set_output_dist_attr( diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index d13e9b69b578a1dc29ae3f6a998d803f55301c14..21c0f88438ad818a17e48195438feb30bb3b538c 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -74,6 +74,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU) set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120) py_test_modules(test_selective_recompute MODULES test_selective_recompute) set_tests_properties(test_selective_recompute PROPERTIES TIMEOUT 50) + py_test_modules(test_tuning_recompute MODULES 
test_tuning_recompute) + set_tests_properties(test_tuning_recompute PROPERTIES TIMEOUT 240) py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS}) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py index b77d42653abdba910a9a64b46f9a880fccdfd8dc..35bf1a323d15c4a0f07d51bbb9f3394ab65ee4b9 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py @@ -28,12 +28,9 @@ from auto_parallel_gpt_model import ( GPTPretrainingCriterion, ) -sequence_len = 512 -vocab_size = 1000 - class FakeDataset(paddle.io.Dataset): - def __init__(self, num_samples): + def __init__(self, num_samples, vocab_size=1000, sequence_len=512): self.num_samples = num_samples self.sequence_len = sequence_len self.vocab_size = vocab_size @@ -57,7 +54,7 @@ class FakeDataset(paddle.io.Dataset): return self.num_samples -def create_data_holder(batch_size): +def create_data_holder(batch_size, vocab_size=1000, sequence_len=512): tokens = paddle.static.InputSpec( name="tokens", shape=[batch_size, sequence_len], dtype='int64' ) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py index 10005008cdbe54dd6a3d2809c16d7abff94e7e0f..dfb554ac722d1ffc8b479e9d2cbcecdc23cea4f7 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py @@ -98,7 +98,7 @@ def train(fetch): tuning.profile_start_step = 1 tuning.profile_end_step = 5 tuning.run_after_tuning = True - tuning.verbose = True + tuning.debug = True dataset = MyDataset(batch_num * batch_size) engine = auto.Engine( diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py index 8649c0f8dffcd30b8412cc3a0a9a34f632fe2114..529d1d5f6255d64873d20e836d82df5db13c0bb8 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py @@ -24,7 +24,7 @@ class TestStrategy(unittest.TestCase): recompute = strategy.recompute self.assertEqual(recompute.enable, False) - self.assertIsNone(recompute.checkpoints) + self.assertEqual(recompute.checkpoints, []) amp = strategy.amp self.assertEqual(amp.enable, False) @@ -66,12 +66,10 @@ class TestStrategy(unittest.TestCase): tuning = strategy.tuning self.assertEqual(tuning.enable, False) - self.assertEqual(tuning.batch_size, 1) - self.assertIsNone(tuning.dataset) self.assertEqual(tuning.profile_start_step, 1) self.assertEqual(tuning.profile_end_step, 1) self.assertEqual(tuning.run_after_tuning, True) - self.assertEqual(tuning.verbose, True) + self.assertEqual(tuning.debug, False) def test_modify_config(self): strategy = auto.Strategy() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py new file mode 100644 index 0000000000000000000000000000000000000000..a2a7deee6d216001ef04751e7cc423d81eb1f6cc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_tuning_recompute.py @@ -0,0 +1,111 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +from get_gpt_model import FakeDataset + +import paddle +from paddle.distributed.fleet import auto + +sys.path.append("..") +import auto_parallel_gpt_model as modeling +from auto_parallel_gpt_model import ( + GPTForPretraining, + GPTModel, + GPTPretrainingCriterion, +) + + +def generate_model(): + modeling.init_global() + modeling._global_parallel_strategy = "serial" + + gpt = GPTModel( + vocab_size=50304, + hidden_size=1024, + num_hidden_layers=14, + num_attention_heads=16, + intermediate_size=1024 * 4, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=1024, + type_vocab_size=1, + initializer_range=0.02, + pad_token_id=0, + eos_token_id=7, + bos_token_id=0, + eol_token_id=3, + use_new_recompute=True, + recompute_granularity="full", + ) + model = GPTForPretraining( + gpt, vocab_size=50304, hidden_size=1024, initializer_range=0.02 + ) + criterion = GPTPretrainingCriterion() + return model, criterion + + +def apply_pass(): + strategy = auto.Strategy() + strategy.auto_mode = "semi" + + recompute = strategy.recompute + recompute.enable = True + recompute.enable_tuning = True + + tuning = strategy.tuning + tuning.enable = True + tuning.profile_start_step = 1 + tuning.profile_end_step = 2 + tuning.run_after_tuning = True + tuning.debug = True + return strategy + + +class TestRecomputePassTuning(unittest.TestCase): + def setUp(self): + + self.batch_size = 8 + self.batch_num = 200 + self.dataset = FakeDataset( + self.batch_size * self.batch_num, + vocab_size=50304, + sequence_len=1024, + ) + + def test_recompute_pass(self): + + strategy = apply_pass() + clip = paddle.nn.ClipGradByGlobalNorm(0.2) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model() + + engine = auto.Engine(model, loss, opt, strategy=strategy) + engine._tune(self.dataset, 3, batch_size=self.batch_size) + + assert ( + len( + engine._dist_contexts[ + 'train' + ].strategy.recompute.no_recompute_segments + ) + > 0 + ) + + +if __name__ == "__main__": + unittest.main()
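
Note on the search behavior: the RecomputeCheckpointAlgorithm added above first profiles a trial that recomputes every marked segment, then one that disables recompute entirely, and only then bisects over how many leading segments to exclude from recompute. The snippet below is a self-contained sketch, not part of the diff, of that bisection; `oom_if_kept` is a hypothetical stand-in for the profiler reporting ResourceExhaustedError for a trial.

    # Sketch only: mirrors the "part" phase of RecomputeCheckpointAlgorithm.update().
    # oom_if_kept(k) is a fake predicate: True if keeping the first k segments
    # un-recomputed (resident in memory) would exhaust device memory.
    def simulate_recompute_search(num_segments, oom_if_kept):
        """Yield the no_recompute_segments prefix tried at each step."""
        left, right = 0, num_segments - 1
        idx = (num_segments - 1) // 2
        while left <= right:
            yield list(range(idx))  # segment indices excluded from recompute
            if oom_if_kept(idx):
                # keeping idx segments resident exhausts memory: shrink from the right
                right = idx - 1
            else:
                # fits in memory: try keeping more segments resident for speed
                left = idx + 1
            idx = left + (right - left) // 2

    if __name__ == "__main__":
        # pretend at most 5 un-recomputed segments fit on the device
        for trial in simulate_recompute_search(14, lambda k: k > 5):
            print(trial)

Each OOM moves the right bound down and each successful probe moves the left bound up, so the tuner converges in O(log n) profile runs instead of profiling all n candidate prefixes.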