Unverified commit 170a31f9, authored by zhaoyingli and committed by GitHub

[AutoParallel] recompute tuning (#48608)

* [AutoParallel] recompute tuning

* fix conflict

* update comment

* bug fix

* update rc algo

* tiny fix

* fix clear process_group

* remove comment

* update segment print

* fix import OpRole

* adapt amp pass and grad_clip pass for opt_tuner

* update tuning config

* fix import

* annotate recompute info on ops and upgrade recompute pass

* add op_namescope for seed op

* record reserved vars

* fix recompute var's dist_attr

* fix strategy unittest

* adapt for fp16

* update unittest

* revert copy opt

* update unittest

* rename set_recompute_segments

* fix unittest
Parent b9207054
...@@ -54,7 +54,7 @@ set_field_default_config(BASE, "reinit", False) # Only for debug ...@@ -54,7 +54,7 @@ set_field_default_config(BASE, "reinit", False) # Only for debug
######################################### #########################################
RECOMPUTE = "recompute" RECOMPUTE = "recompute"
set_field_default_config(RECOMPUTE, "enable", False) set_field_default_config(RECOMPUTE, "enable", False)
set_field_default_config(RECOMPUTE, "checkpoints", None) set_field_default_config(RECOMPUTE, "checkpoints", [])
set_field_default_config(RECOMPUTE, "no_recompute_segments", []) set_field_default_config(RECOMPUTE, "no_recompute_segments", [])
set_field_default_config(RECOMPUTE, "enable_tuning", False) set_field_default_config(RECOMPUTE, "enable_tuning", False)
...@@ -113,12 +113,10 @@ set_field_default_config(QAT, "algo", None) ...@@ -113,12 +113,10 @@ set_field_default_config(QAT, "algo", None)
# ######################################### # #########################################
TUNING = "tuning" TUNING = "tuning"
set_field_default_config(TUNING, "enable", False) set_field_default_config(TUNING, "enable", False)
set_field_default_config(TUNING, "batch_size", 1)
set_field_default_config(TUNING, "dataset", None)
set_field_default_config(TUNING, "profile_start_step", 1) set_field_default_config(TUNING, "profile_start_step", 1)
set_field_default_config(TUNING, "profile_end_step", 1) set_field_default_config(TUNING, "profile_end_step", 1)
set_field_default_config(TUNING, "run_after_tuning", True) set_field_default_config(TUNING, "run_after_tuning", True)
set_field_default_config(TUNING, "verbose", True) set_field_default_config(TUNING, "debug", False)
######################################### #########################################
# dataset configuration # dataset configuration
......
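The defaults above drop the old tuning fields ("batch_size", "dataset") and rename "verbose" to "debug". A minimal sketch of how these fields are driven from user code, mirroring the configuration used by the new test_tuning_recompute.py added later in this commit (nothing here beyond what the diff itself shows):

from paddle.distributed.fleet import auto

strategy = auto.Strategy()

recompute = strategy.recompute
recompute.enable = True          # RECOMPUTE "enable"
recompute.enable_tuning = True   # RECOMPUTE "enable_tuning": let the tuner search segments

tuning = strategy.tuning
tuning.enable = True             # TUNING "enable"
tuning.profile_start_step = 1    # profiling window for each trial
tuning.profile_end_step = 2
tuning.run_after_tuning = True   # train with the tuned strategy afterwards
tuning.debug = False             # keep per-trial programs only when debugging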
...@@ -609,7 +609,9 @@ class Engine: ...@@ -609,7 +609,9 @@ class Engine:
if mode != "train": if mode != "train":
serial_main_prog = serial_main_prog.clone(for_test=True) serial_main_prog = serial_main_prog.clone(for_test=True)
auto_utils.set_recompute_ckpts(self._model, self._strategy) auto_utils.set_recompute_segments(
self._model, self._losses, self._strategy, serial_main_prog
)
self._dist_contexts[mode] = DistributedContext( self._dist_contexts[mode] = DistributedContext(
serial_main_prog, serial_main_prog,
serial_startup_prog, serial_startup_prog,
...@@ -649,7 +651,6 @@ class Engine: ...@@ -649,7 +651,6 @@ class Engine:
from .tuner.optimization_tuner import OptimizationTuner from .tuner.optimization_tuner import OptimizationTuner
self._optimization_tuner = OptimizationTuner( self._optimization_tuner = OptimizationTuner(
self._tuning.to_dict(),
self._dist_contexts[mode], self._dist_contexts[mode],
dataset, dataset,
self._inputs_spec, self._inputs_spec,
......
...@@ -73,6 +73,10 @@ class BaseConfig: ...@@ -73,6 +73,10 @@ class BaseConfig:
setattr(result, k, copy.deepcopy(v, memo)) setattr(result, k, copy.deepcopy(v, memo))
return result return result
def get(self, k, d=None):
result_dict = self.to_dict()
return result_dict.get(k, d)
class RecomputeConfig(BaseConfig): class RecomputeConfig(BaseConfig):
def __init__(self, config_dict=None): def __init__(self, config_dict=None):
......
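The new BaseConfig.get() mirrors dict.get() over the serialized config, so callers can read optional fields with a default instead of going through to_dict() themselves. A small usage sketch (the "tuning_range" field comes from the sharding algorithm change shown below):

stage_range = strategy.sharding.get("tuning_range", None)   # replaces .to_dict().get(...)
mode = strategy.tuning.get("mode", "PROFILE")               # missing keys fall back to the default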
...@@ -16,7 +16,7 @@ import copy ...@@ -16,7 +16,7 @@ import copy
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from ..utils import get_logger from ..utils import get_logger, is_recompute_op
from .trial import OptimizationTunerTrial as Trial from .trial import OptimizationTunerTrial as Trial
from .trial import TrialStatus from .trial import TrialStatus
...@@ -54,7 +54,7 @@ class AlgorithmBase(ABC): ...@@ -54,7 +54,7 @@ class AlgorithmBase(ABC):
def collect_model_info(self, main_prog, startup_prog): def collect_model_info(self, main_prog, startup_prog):
""" """
Collect the model static info (from programs) that could be used to Collect the model static info (from programs) that could be used to
pruning candidate trials and saving tuning time.For instance, pruning candidate trials and saving tuning time. For instance,
model info like number of model parameters and activation memory could be model info like number of model parameters and activation memory could be
used to prune candidated trial and decide the next trial. used to prune candidated trial and decide the next trial.
""" """
...@@ -116,7 +116,7 @@ class ShardingStageAlgorithm(AlgorithmBase): ...@@ -116,7 +116,7 @@ class ShardingStageAlgorithm(AlgorithmBase):
self._max_stage = 3 self._max_stage = 3
self._trial_idx = 0 self._trial_idx = 0
stage_range = self._config.sharding.to_dict().get("tuning_range", None) stage_range = self._config.sharding.get("tuning_range", None)
if stage_range: if stage_range:
assert set(stage_range).issubset( assert set(stage_range).issubset(
set([0, 1, 2, 3]) set([0, 1, 2, 3])
...@@ -157,3 +157,92 @@ class ShardingStageAlgorithm(AlgorithmBase): ...@@ -157,3 +157,92 @@ class ShardingStageAlgorithm(AlgorithmBase):
) )
else: else:
self._trial_idx += 1 self._trial_idx += 1
@register_algor("recompute")
class ReccomputeCheckpointAlgorithm(AlgorithmBase):
def __init__(self, config):
super().__init__(config)
self._changed_configs = ["recompute"]
def collect_model_info(self, main_prog, startup_prog):
segments = []
for op in main_prog.global_block().ops:
if not is_recompute_op(op):
continue
seg_name = op.attr('op_namescope')
if seg_name not in segments:
segments.append(seg_name)
self._total_num_trial = len(segments)
self._tuning_segments = list(range(len(segments)))
self._trail_left = 0
self._trail_right = len(segments) - 1
self._trial_idx = int(0 + (len(segments) - 1) / 2)
def _init_spaces(self):
self._recompute_mode = "all"
def next_trial(self):
if self._trial_idx < self._total_num_trial:
if self._recompute_mode == "all":
self._recompute_flag = False
new_strategy = copy.deepcopy(self._config.dist_strategy)
name = "trial-recompute-all-segments"
return Trial(new_strategy, name, self.changed_configs)
elif self._recompute_mode == "none":
self._recompute_flag = False
new_strategy = copy.deepcopy(self._config.dist_strategy)
recompute = new_strategy.recompute
recompute.enable = False
name = "trial-recompute-none-segments"
return Trial(new_strategy, name, self.changed_configs)
elif self._recompute_mode == "part":
new_no_recompute = self._tuning_segments[: self._trial_idx]
new_strategy = copy.deepcopy(self._config.dist_strategy)
recompute = new_strategy.recompute
recompute.no_recompute_segments.extend(new_no_recompute)
name = "trial-recompute-part-segments-idx{}".format(
self._trial_idx
)
return Trial(new_strategy, name, self.changed_configs)
else:
return Trial(None, None, None, status=TrialStatus.STOPPED)
def update(self, results):
et = results.get("ErrorType", None)
if self._recompute_mode == "all":
if et and et == "ResourceExhaustedError":
self._trial_idx = self._total_num_trial
self._logger.info(
"Recompute all candidate segments is failed with OOM, please reduce model size or batch size."
)
else:
self._recompute_mode = "none"
elif self._recompute_mode == "none":
if et and et == "ResourceExhaustedError":
self._recompute_mode = "part"
else:
self._trial_idx = self._total_num_trial
self._logger.info(
"Recompute is unnecessary for this model size, which will reduce the Throughtput."
)
else:
if self._trail_left >= self._trail_right:
self._trial_idx = self._total_num_trial
elif et and et == "ResourceExhaustedError":
self._trail_left = self._trail_left
self._trail_right = self._trial_idx - 1
self._trial_idx = int(
self._trail_left
+ (self._trail_right - self._trail_left) / 2
)
else:
self._trail_left = self._trial_idx + 1
self._trail_right = self._trail_right
self._trial_idx = int(
self._trail_left
+ (self._trail_right - self._trail_left) / 2
)
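The algorithm above tries three regimes in order: recompute every marked segment ("all"), recompute nothing ("none"), and finally a binary search over how many leading segments are excluded from recompute ("part"). A standalone sketch of the bisection step that update() performs; the function name is illustrative and not part of the commit:

def bisect_no_recompute(left, right, trial_idx, oom):
    # Trials that hit ResourceExhaustedError shrink the right bound (exclude fewer
    # segments from recompute); successful trials raise the left bound.
    if oom:
        right = trial_idx - 1
    else:
        left = trial_idx + 1
    return left, right, left + (right - left) // 2   # next trial index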
...@@ -32,14 +32,11 @@ class TuningConfig: ...@@ -32,14 +32,11 @@ class TuningConfig:
tuning config: configuration for the tuning process: mode (profile or cost model), log dir, extra tuning config for optimization like search range for specific tuning config: configuration for the tuning process: mode (profile or cost model), log dir, extra tuning config for optimization like search range for specific
""" """
def __init__(self, user_config, strategy): def __init__(self, strategy):
if not isinstance(strategy, Strategy): if not isinstance(strategy, Strategy):
raise TypeError("'strategy' must be object of class `Strategy`.") raise TypeError("'strategy' must be object of class `Strategy`.")
if not user_config:
user_config = {}
self._tuning_passes_name = set() self._tuning_passes_name = set()
self._dist_strategy = copy.deepcopy(strategy) self._dist_strategy = copy.deepcopy(strategy)
self._mode = None self._mode = None
...@@ -48,9 +45,9 @@ class TuningConfig: ...@@ -48,9 +45,9 @@ class TuningConfig:
self._project_dir = None self._project_dir = None
self._max_num_trial = None self._max_num_trial = None
self._early_stop = None self._early_stop = None
self._verbose = None self._debug = None
self._initialize(user_config) self._initialize()
@property @property
def mode(self): def mode(self):
...@@ -81,29 +78,25 @@ class TuningConfig: ...@@ -81,29 +78,25 @@ class TuningConfig:
return self._early_stop return self._early_stop
@property @property
def verbose(self): def debug(self):
return self._verbose return self._debug
@property @property
def dist_strategy(self): def dist_strategy(self):
return self._dist_strategy return self._dist_strategy
# initialize config with user define value or default value # initialize config with user define value or default value
def _initialize(self, user_config): def _initialize(self):
tuning_strategy = self._dist_strategy.tuning
self._mode = user_config.get("mode", "PROFILE")
self._profile_start_step = user_config.get("profile_start_step", 10)
self._profile_end_step = user_config.get("profile_end_step", 30)
self._max_num_trial = user_config.get("max_num_trial", 50)
self._early_stop = user_config.get("early_stop", None)
self._verbose = user_config.get("verbose", False) self._mode = tuning_strategy.get("mode", "PROFILE")
self._profile_start_step = tuning_strategy.get("profile_start_step", 10)
self._profile_end_step = tuning_strategy.get("profile_end_step", 30)
self._max_num_trial = tuning_strategy.get("max_num_trial", 50)
self._early_stop = tuning_strategy.get("early_stop", None)
self._debug = tuning_strategy.get("debug", False)
project_dir = user_config.get("project_dir", None) project_dir = tuning_strategy.get("project_dir", None)
if not project_dir: if not project_dir:
project_dir = os.path.join(os.getcwd(), "OptimizationTuning") project_dir = os.path.join(os.getcwd(), "OptimizationTuning")
self._project_dir = project_dir self._project_dir = project_dir
...@@ -116,15 +109,14 @@ class TuningConfig: ...@@ -116,15 +109,14 @@ class TuningConfig:
# TODO distinguish different args of each passes # TODO distinguish different args of each passes
self._tuning_passes_name.add(p) self._tuning_passes_name.add(p)
config_name = p p_strategy = getattr(self._dist_strategy, p)
p_dict = getattr(self._dist_strategy, config_name) self.__dict__[p] = p_strategy
self.__dict__[config_name] = p_dict
# TODO verify the user defined configs # # TODO verify the user defined configs
user_config_for_pass = user_config.get(p, None) # tuning_config_for_pass = tuning_strategy.get(p, None)
if user_config_for_pass: # if tuning_config_for_pass:
for k, v in user_config_for_pass.items(): # for k, v in tuning_config_for_pass.items():
self.__dict__[config_name][k] = v # self.__dict__[p][k] = v
# (NOTE)tuning config ONLY wraps dist strategy for pass config which is to be tuned # (NOTE)tuning config ONLY wraps dist strategy for pass config which is to be tuned
def __getattr__(self, item): def __getattr__(self, item):
......
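With this refactor TuningConfig no longer takes a separate user_configs dict; every option is read from the Strategy's tuning section through the new BaseConfig.get(). A minimal sketch of the new construction path, matching the OptimizationTuner change further down:

config = TuningConfig(dist_context.strategy)   # the user_configs dict argument is gone
print(config.mode, config.debug)               # values come from strategy.tuning with defaults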
...@@ -33,6 +33,7 @@ from paddle.distributed.auto_parallel.partitioner import Partitioner ...@@ -33,6 +33,7 @@ from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.process_group import ( from paddle.distributed.auto_parallel.process_group import (
clear_all_process_groups, clear_all_process_groups,
get_all_process_groups, get_all_process_groups,
new_process_group,
) )
from paddle.distributed.auto_parallel.reshard import Resharder from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.utils import ( from paddle.distributed.auto_parallel.utils import (
...@@ -40,7 +41,7 @@ from paddle.distributed.auto_parallel.utils import ( ...@@ -40,7 +41,7 @@ from paddle.distributed.auto_parallel.utils import (
set_grad_var_shape, set_grad_var_shape,
) )
from paddle.distributed.passes import PassContext, new_pass from paddle.distributed.passes import PassContext, new_pass
from paddle.fluid import program_guard from paddle.fluid import program_guard, unique_name
from paddle.fluid.backward import append_backward from paddle.fluid.backward import append_backward
from ..utils import get_logger from ..utils import get_logger
...@@ -109,7 +110,12 @@ def parse_results(results): ...@@ -109,7 +110,12 @@ def parse_results(results):
# all env need to be start a new pass are member of dist context # all env need to be start a new pass are member of dist context
def _copy_context(ref_dist_context): def _copy_context(ref_dist_context):
# clear all process groups and recover the world process group
clear_all_process_groups() clear_all_process_groups()
ranks = []
for process_mesh in ref_dist_context._process_meshes:
ranks.extend(process_mesh.processes)
new_process_group(list(set(ranks)))
new_dist_context = DistributedContext() new_dist_context = DistributedContext()
new_dist_context._serial_main_program = ( new_dist_context._serial_main_program = (
...@@ -195,7 +201,6 @@ class OptimizationTuner: ...@@ -195,7 +201,6 @@ class OptimizationTuner:
def __init__( def __init__(
self, self,
user_configs,
dist_context, dist_context,
dataset, dataset,
inputs_spec, inputs_spec,
...@@ -204,7 +209,7 @@ class OptimizationTuner: ...@@ -204,7 +209,7 @@ class OptimizationTuner:
rank, rank,
): ):
self._config = TuningConfig(user_configs, dist_context._strategy) self._config = TuningConfig(dist_context.strategy)
# should not modify dist context from calling function # should not modify dist context from calling function
self._baseline_dist_context = _copy_context(dist_context) self._baseline_dist_context = _copy_context(dist_context)
self._baseline_completer = Completer(self._baseline_dist_context) self._baseline_completer = Completer(self._baseline_dist_context)
...@@ -264,7 +269,7 @@ class OptimizationTuner: ...@@ -264,7 +269,7 @@ class OptimizationTuner:
) )
self._baseline_dist_context._params_grads = params_grads self._baseline_dist_context._params_grads = params_grads
if self._config.verbose: if self._config.debug:
baseline_dir = os.path.join(self.project_dir, "baseline") baseline_dir = os.path.join(self.project_dir, "baseline")
if not os.path.exists(baseline_dir): if not os.path.exists(baseline_dir):
pathlib.Path(baseline_dir).mkdir(parents=True, exist_ok=True) pathlib.Path(baseline_dir).mkdir(parents=True, exist_ok=True)
...@@ -299,7 +304,6 @@ class OptimizationTuner: ...@@ -299,7 +304,6 @@ class OptimizationTuner:
config = copy.deepcopy(new_strategy.amp.to_dict()) config = copy.deepcopy(new_strategy.amp.to_dict())
config["dist_context"] = dist_context config["dist_context"] = dist_context
config["params_grads"] = dist_context._params_grads config["params_grads"] = dist_context._params_grads
# TODO AMP Pass should not use loss var # TODO AMP Pass should not use loss var
config["loss"] = dist_context.serial_loss config["loss"] = dist_context.serial_loss
config["input_data"] = ( config["input_data"] = (
...@@ -312,13 +316,13 @@ class OptimizationTuner: ...@@ -312,13 +316,13 @@ class OptimizationTuner:
auto_parallel_fp16_pass.apply( auto_parallel_fp16_pass.apply(
[main_program], [startup_program], pass_context [main_program], [startup_program], pass_context
) )
dist_context.serial_loss = auto_parallel_fp16_pass.get_loss() dist_context._serial_loss = auto_parallel_fp16_pass.get_loss()
else: else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply( auto_parallel_amp_pass.apply(
[main_program], [startup_program], pass_context [main_program], [startup_program], pass_context
) )
dist_context.serial_loss = auto_parallel_amp_pass.get_loss() dist_context._serial_loss = auto_parallel_amp_pass.get_loss()
if new_strategy.recompute.enable: if new_strategy.recompute.enable:
config = copy.deepcopy(new_strategy.recompute.to_dict()) config = copy.deepcopy(new_strategy.recompute.to_dict())
...@@ -345,9 +349,10 @@ class OptimizationTuner: ...@@ -345,9 +349,10 @@ class OptimizationTuner:
# Generate optimizer # Generate optimizer
# FIXME should be remove from apply pass after pass support optimizers # FIXME should be remove from apply pass after pass support optimizers
with program_guard(dist_main_prog, dist_startup_prog): with program_guard(dist_main_prog, dist_startup_prog):
optimizer_ops = dist_context.serial_optimizer.apply_gradients( with unique_name.guard("opt_"):
dist_params_grads optimizer_ops = dist_context.serial_optimizer.apply_gradients(
) dist_params_grads
)
completer.complete_update_annotation(dist_main_prog) completer.complete_update_annotation(dist_main_prog)
# Do reshard process # Do reshard process
...@@ -361,6 +366,13 @@ class OptimizationTuner: ...@@ -361,6 +366,13 @@ class OptimizationTuner:
) )
resharder.reshard() resharder.reshard()
config = {}
config["dist_context"] = dist_context
config["global_rank"] = self.rank
config["use_sharding"] = new_strategy.sharding.enable
dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
dp_pass.apply([dist_main_prog], [dist_startup_prog], pass_context)
if new_strategy.sharding.enable: if new_strategy.sharding.enable:
config = copy.deepcopy(new_strategy.sharding.to_dict()) config = copy.deepcopy(new_strategy.sharding.to_dict())
config["dist_context"] = dist_context config["dist_context"] = dist_context
...@@ -372,6 +384,17 @@ class OptimizationTuner: ...@@ -372,6 +384,17 @@ class OptimizationTuner:
auto_parallel_sharding_pass.apply( auto_parallel_sharding_pass.apply(
[dist_main_prog], [dist_startup_prog], pass_context [dist_main_prog], [dist_startup_prog], pass_context
) )
dist_params_grads = pass_context.get_attr("params_grads")
# gradient clip
config = copy.deepcopy(new_strategy.sharding.to_dict())
config["dist_context"] = dist_context
config["params_grads"] = dist_params_grads
config["rank_id"] = self.rank
auto_parallel_clip_pass = new_pass("auto_parallel_grad_clip", config)
auto_parallel_clip_pass.apply(
[dist_main_prog], [dist_startup_prog], pass_context
)
if new_strategy.gradient_merge.enable: if new_strategy.gradient_merge.enable:
config = copy.deepcopy(new_strategy.gradient_merge.to_dict()) config = copy.deepcopy(new_strategy.gradient_merge.to_dict())
...@@ -488,7 +511,7 @@ class OptimizationTuner: ...@@ -488,7 +511,7 @@ class OptimizationTuner:
with open(ctx_path, 'wb') as f: with open(ctx_path, 'wb') as f:
pickle.dump(profile_ctx, f, protocol=4) pickle.dump(profile_ctx, f, protocol=4)
if self._config.verbose: if self._config.debug:
debug_program(trial.main_program, trial_dir, "main_program") debug_program(trial.main_program, trial_dir, "main_program")
debug_program(trial.startup_program, trial_dir, "startup_program") debug_program(trial.startup_program, trial_dir, "startup_program")
...@@ -581,7 +604,7 @@ The best trial is: [{}], whose configuration is following: ...@@ -581,7 +604,7 @@ The best trial is: [{}], whose configuration is following:
Clear the temporary file generated in tuning procedure. Clear the temporary file generated in tuning procedure.
""" """
# TODO clear up zombie process created by tuning # TODO clear up zombie process created by tuning
if not self._config.verbose: if not self._config.debug:
for trial in self._finished_trials: for trial in self._finished_trials:
trial_dir = self._get_trial_dir(trial) trial_dir = self._get_trial_dir(trial)
shutil.rmtree(trial_dir, ignore_errors=True) shutil.rmtree(trial_dir, ignore_errors=True)
......
...@@ -89,7 +89,7 @@ def init_process_groups(group_map, rank): ...@@ -89,7 +89,7 @@ def init_process_groups(group_map, rank):
# TODO should instantiate global group first # TODO should instantiate global group first
all_process_groups = get_all_process_groups() all_process_groups = get_all_process_groups()
for process_group in all_process_groups: for process_group in all_process_groups:
if process_group.id == 0 or rank not in process_group.ranks: if rank not in process_group.ranks:
continue continue
print(process_group) print(process_group)
process_group.instantiate() process_group.instantiate()
...@@ -173,10 +173,11 @@ def init_comm(profile_ctx): ...@@ -173,10 +173,11 @@ def init_comm(profile_ctx):
genv = _get_global_env() genv = _get_global_env()
genv = dist_env genv = dist_env
print( print(
"current process rank: {}, device_id: {}, ip: {}.", "current process rank: {}, device_id: {}, ip: {}.".format(
genv.rank, genv.rank,
genv.device_id, genv.device_id,
genv.current_endpoint, genv.current_endpoint,
)
) )
# init nccl comm # init nccl comm
...@@ -231,13 +232,12 @@ def profiler(args): ...@@ -231,13 +232,12 @@ def profiler(args):
exe = get_executor() exe = get_executor()
exe.run(startup_program)
# profile main
duration = 0
eval_step = 0
data_loader._inner_dataloader.start()
try: try:
exe.run(startup_program)
# profile main
duration = 0
eval_step = 0
data_loader._inner_dataloader.start()
while eval_step < args.profile_end_step: while eval_step < args.profile_end_step:
start_time = time.time() start_time = time.time()
......
...@@ -22,18 +22,17 @@ from functools import reduce ...@@ -22,18 +22,17 @@ from functools import reduce
import numpy as np import numpy as np
import paddle import paddle
import paddle.fluid.core as core from paddle.fluid.framework import Variable
from paddle.distributed.auto_parallel.dist_attribute import ( from paddle.fluid.io import is_belong_to_optimizer, is_parameter
from paddle.framework import core
from .dist_attribute import (
OperatorDistributedAttribute, OperatorDistributedAttribute,
TensorDistributedAttribute, TensorDistributedAttribute,
) )
from paddle.distributed.auto_parallel.process_group import ( from .process_group import get_all_process_groups
get_all_process_groups,
)
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid.framework import Variable
from paddle.fluid.io import is_belong_to_optimizer, is_parameter
OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
__no_shape_var_type__ = [ __no_shape_var_type__ = [
...@@ -1921,10 +1920,16 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank): ...@@ -1921,10 +1920,16 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank):
server_socket.close() server_socket.close()
def set_recompute_ckpts(model, strategy): def is_recompute_op(op):
from .interface import _g_recompute_idx return op.has_attr('op_namescope') and "/auto_parallel/rc" in op.attr(
'op_namescope'
)
if _g_recompute_idx > -1:
def set_recompute_segments(model, losses, strategy, program):
from ..passes.auto_parallel_recompute import RecomputeState
if not losses:
return return
recompute = strategy.recompute recompute = strategy.recompute
...@@ -1934,24 +1939,65 @@ def set_recompute_ckpts(model, strategy): ...@@ -1934,24 +1939,65 @@ def set_recompute_ckpts(model, strategy):
# NOTE: hack to enable recompute in engine api for GPT-3 # NOTE: hack to enable recompute in engine api for GPT-3
# TODO support more PaddleNLP/CV models here # TODO support more PaddleNLP/CV models here
# extract ckpts by specific model # extract ckpts by specific model
ckpts = []
if isinstance(model, paddle.nn.Layer): if isinstance(model, paddle.nn.Layer):
if hasattr(model, "gpt") and model.__class__.__name__ in [ if (
'GPTForPretraining', hasattr(model, "gpt")
'GPTForPretrainingAuto', and model.__class__.__name__
]: in [
exact_ckpts = model.gpt.checkpoints 'GPTForPretraining',
'GPTForPretrainingAuto',
]
and hasattr(model.gpt, "checkpoints")
):
ckpts = model.gpt.checkpoints
else: else:
exact_ckpts = recompute.checkpoints ckpts = recompute.checkpoints
else: else:
exact_ckpts = recompute.checkpoints ckpts = recompute.checkpoints
# modify strategy if not ckpts:
recompute.checkpoints = exact_ckpts[:] return
logs = {
'Model Class': model.__class__.__name__, block = program.global_block()
'Applied Recompute ckpts': exact_ckpts, rc_state = RecomputeState(block, block.ops)
} rc_state.build_stats()
logging.info(logs) checkpoints = rc_state.sort_checkpoints(ckpts)
segments = []
start_idx = -1
pre_segment_end_idx = -1
while start_idx + 1 < len(checkpoints):
if start_idx == -1:
ckpt_name = checkpoints[start_idx + 1]
if ckpt_name not in rc_state.var_op_deps:
start_idx += 1
continue
op_idx_list = rc_state.var_op_deps[ckpt_name]["var_as_output_ops"]
if op_idx_list and max(op_idx_list) > 0:
segments.append([0, max(op_idx_list) + 1])
else:
flag, min_idx, max_idx = rc_state.is_subgraph(
[checkpoints[start_idx]], [checkpoints[start_idx + 1]]
)
if flag:
min_idx = rc_state._update_segment_start(
min_idx, pre_segment_end_idx
)
segments.append([min_idx, max_idx + 1])
else:
logging.debug(
"Could not recompute op range [{}] - [{}] ".format(
min_idx, max_idx + 1
)
)
start_idx += 1
for i, segment in enumerate(segments):
for j in range(segment[0], segment[1]):
block.ops[j]._set_attr(
'op_namescope', "/auto_parallel/rc_" + str(i)
)
def get_input_split_info(cur_rank, var, dist_context): def get_input_split_info(cur_rank, var, dist_context):
......
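set_recompute_segments() turns the user-supplied (or model-extracted) checkpoints into op index ranges and stamps every op in segment i with the namescope "/auto_parallel/rc_" + str(i); is_recompute_op() only tests for that substring. A hedged helper sketch built on that convention (the function itself is illustrative, not part of the commit):

def ops_in_segment(block, segment_id):
    # Collect the ops annotated for a given recompute segment.
    scope = "/auto_parallel/rc_" + str(segment_id)
    return [
        op for op in block.ops
        if op.has_attr('op_namescope') and scope in op.attr('op_namescope')
    ]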
...@@ -226,6 +226,9 @@ class AMPState: ...@@ -226,6 +226,9 @@ class AMPState:
dist_context, out_var, ref_mapping, ref_mesh dist_context, out_var, ref_mapping, ref_mesh
) )
op_namescope = "/"
if op.has_attr('op_namescope'):
op_namescope = op.attr('op_namescope')
cast_op = self._block._insert_op_without_sync( cast_op = self._block._insert_op_without_sync(
idx, idx,
type="cast", type="cast",
...@@ -236,6 +239,9 @@ class AMPState: ...@@ -236,6 +239,9 @@ class AMPState:
"out_dtype": out_var.dtype, "out_dtype": out_var.dtype,
}, },
) )
cast_op._set_attr(
'op_namescope', op_namescope
) # for recompute
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context cast_op, ref_mesh, ref_mapping, dist_context
) )
......
...@@ -22,13 +22,12 @@ from paddle.distributed.auto_parallel.process_group import ( ...@@ -22,13 +22,12 @@ from paddle.distributed.auto_parallel.process_group import (
get_world_process_group, get_world_process_group,
) )
from paddle.distributed.auto_parallel.utils import ( from paddle.distributed.auto_parallel.utils import (
OP_ROLE_KEY,
OpRole,
is_backward_op, is_backward_op,
is_forward_op, is_forward_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping, naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr, set_var_dist_attr,
) )
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.contrib.mixed_precision.fp16_utils import ( from paddle.fluid.contrib.mixed_precision.fp16_utils import (
AutoMixedPrecisionLists, AutoMixedPrecisionLists,
...@@ -417,6 +416,9 @@ class FP16State: ...@@ -417,6 +416,9 @@ class FP16State:
dist_context, cast_var, ref_mapping, ref_mesh dist_context, cast_var, ref_mapping, ref_mesh
) )
op_namescope = "/"
if op.has_attr('op_namescope'):
op_namescope = op.attr('op_namescope')
cast_op = block._insert_op_without_sync( cast_op = block._insert_op_without_sync(
idx, idx,
type="cast", type="cast",
...@@ -428,6 +430,9 @@ class FP16State: ...@@ -428,6 +430,9 @@ class FP16State:
OP_ROLE_KEY: OpRole.Forward, OP_ROLE_KEY: OpRole.Forward,
}, },
) )
cast_op._set_attr(
'op_namescope', op_namescope
) # for recompute
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context cast_op, ref_mesh, ref_mapping, dist_context
) )
......
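Both the AMP pass and the FP16 pass now copy the original op's namescope onto the cast ops they insert, so a cast that lands inside a recompute segment keeps the "/auto_parallel/rc_*" marker. A hedged illustration of that convention (helper name is not from the commit):

def copy_namescope(src_op, cast_op):
    # Propagate recompute segment membership to a newly inserted cast op.
    scope = src_op.attr('op_namescope') if src_op.has_attr('op_namescope') else "/"
    cast_op._set_attr('op_namescope', scope)  # for recompute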
...@@ -17,6 +17,7 @@ from functools import reduce ...@@ -17,6 +17,7 @@ from functools import reduce
import numpy as np import numpy as np
import paddle import paddle
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from ..auto_parallel.dist_attribute import ( from ..auto_parallel.dist_attribute import (
OperatorDistributedAttribute, OperatorDistributedAttribute,
...@@ -25,8 +26,6 @@ from ..auto_parallel.dist_attribute import ( ...@@ -25,8 +26,6 @@ from ..auto_parallel.dist_attribute import (
from ..auto_parallel.process_group import get_world_process_group from ..auto_parallel.process_group import get_world_process_group
from ..auto_parallel.reshard import Resharder from ..auto_parallel.reshard import Resharder
from ..auto_parallel.utils import ( from ..auto_parallel.utils import (
OP_ROLE_KEY,
OpRole,
_get_comm_group, _get_comm_group,
insert_dependencies_for_two_vars, insert_dependencies_for_two_vars,
is_gradient_clip_op, is_gradient_clip_op,
......
...@@ -19,12 +19,11 @@ from paddle.distributed.auto_parallel.process_group import ( ...@@ -19,12 +19,11 @@ from paddle.distributed.auto_parallel.process_group import (
get_world_process_group, get_world_process_group,
) )
from paddle.distributed.auto_parallel.utils import ( from paddle.distributed.auto_parallel.utils import (
OP_ROLE_KEY,
OpRole,
is_optimize_op, is_optimize_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping, naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr, set_var_dist_attr,
) )
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.fluid import layers from paddle.fluid import layers
from paddle.fluid.framework import device_guard from paddle.fluid.framework import device_guard
from paddle.framework import core from paddle.framework import core
......
...@@ -14,19 +14,8 @@ ...@@ -14,19 +14,8 @@
import logging import logging
from paddle.distributed.auto_parallel.dist_attribute import ( from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
OperatorDistributedAttribute, from paddle.fluid import core, framework, unique_name
)
from paddle.distributed.auto_parallel.utils import (
get_loss_op,
insert_dependencies_for_two_ops,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_dist_op_desc_original_id,
set_var_dist_attr,
)
from paddle.fluid import core
from paddle.fluid import framework as framework
from paddle.fluid import unique_name
from paddle.fluid.backward import ( from paddle.fluid.backward import (
ProgramStats, ProgramStats,
_append_grad_suffix_, _append_grad_suffix_,
...@@ -35,28 +24,43 @@ from paddle.fluid.backward import ( ...@@ -35,28 +24,43 @@ from paddle.fluid.backward import (
_rename_arg_, _rename_arg_,
) )
from ..auto_parallel.dist_attribute import OperatorDistributedAttribute
from ..auto_parallel.utils import (
get_loss_op,
insert_dependencies_for_two_ops,
is_backward_op,
is_recompute_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_dist_op_desc_original_id,
set_var_dist_attr,
)
from .pass_base import PassBase, register_pass from .pass_base import PassBase, register_pass
def _to_be_recomputed(op):
return op.has_attr('op_namescope') and "/auto_parallel/rc_" in op.attr(
'op_namescope'
)
class RecomputeState(ProgramStats): class RecomputeState(ProgramStats):
def __init__(self, block, ops): def __init__(self, block, ops):
super().__init__(block=block, ops=ops) super().__init__(block=block, ops=ops)
self._block = block
self._ops = ops
# {varname: {as_input_ops: op_idx, as_output_ops: op_idx}}
self.var_op_deps = {}
# {segment_name: op_idx}
self.seg_op_deps = {} self.seg_op_deps = {}
self._checkpoints = []
self._reserved_vars = []
@property
def checkpoints(self):
return self._checkpoints
@property
def reserved_vars(self):
return self._reserved_vars
def build_stats(self): def is_recompute(self):
for i, op in enumerate(self._ops): return any([is_recompute_op(op) for op in self.ops])
for name in op.desc.input_arg_names():
def build_states(self):
for i, op in enumerate(self.ops):
if is_backward_op(op):
break
for name in op.input_arg_names:
if name in self.var_op_deps: if name in self.var_op_deps:
self.var_op_deps[name]["var_as_input_ops"].extend([i]) self.var_op_deps[name]["var_as_input_ops"].extend([i])
else: else:
...@@ -64,7 +68,7 @@ class RecomputeState(ProgramStats): ...@@ -64,7 +68,7 @@ class RecomputeState(ProgramStats):
self.var_op_deps[name]["var_as_input_ops"] = [i] self.var_op_deps[name]["var_as_input_ops"] = [i]
self.var_op_deps[name]["var_as_output_ops"] = [] self.var_op_deps[name]["var_as_output_ops"] = []
for name in op.desc.output_arg_names(): for name in op.output_arg_names:
if name in self.var_op_deps: if name in self.var_op_deps:
self.var_op_deps[name]["var_as_output_ops"].extend([i]) self.var_op_deps[name]["var_as_output_ops"].extend([i])
else: else:
...@@ -72,7 +76,8 @@ class RecomputeState(ProgramStats): ...@@ -72,7 +76,8 @@ class RecomputeState(ProgramStats):
self.var_op_deps[name]["var_as_input_ops"] = [] self.var_op_deps[name]["var_as_input_ops"] = []
self.var_op_deps[name]["var_as_output_ops"] = [i] self.var_op_deps[name]["var_as_output_ops"] = [i]
if not _to_be_recomputed(op): if not is_recompute_op(op):
self._checkpoints.extend(op.output_arg_names)
continue continue
seg_name = op.attr('op_namescope') seg_name = op.attr('op_namescope')
...@@ -84,97 +89,42 @@ class RecomputeState(ProgramStats): ...@@ -84,97 +89,42 @@ class RecomputeState(ProgramStats):
), "The recompute segment's ops should be continuous" ), "The recompute segment's ops should be continuous"
self.seg_op_deps[seg_name].extend([i]) self.seg_op_deps[seg_name].extend([i])
def get_recompute_segments( def get_recompute_segments(self, no_recompute_segments=[]):
self, checkpoints_list=None, no_recompute_segments=[]
):
"""get recompute segments and checkpoints"""
segments = [] segments = []
checkpoints = checkpoints_list or [] for segment_idx in self.seg_op_deps.values():
if len(segment_idx) == 1:
if len(checkpoints) == 0: continue
# the segments is marked by `auto.recompute()` api segments.append([segment_idx[0], segment_idx[-1] + 1])
for segment_idx in self.seg_op_deps.values(): self._checkpoints.extend(self.ops[segment_idx[-1]].output_arg_names)
if len(segment_idx) == 1:
continue for i in reversed(sorted(no_recompute_segments)):
segments.append([segment_idx[0], segment_idx[-1] + 1]) assert i < len(
checkpoints.extend(self._ops[segment_idx[-1]].output_arg_names) segments
else: ), "the no_recompute_segments idx [{}] should be lower the number of segment [{}]".format(
# the segments is marked by `strategy.checkpoints` api i, len(segments)
start_idx = -1
pre_segment_end_idx = -1
while start_idx + 1 < len(checkpoints):
if start_idx == -1:
ckpt_name = checkpoints[start_idx + 1]
if ckpt_name not in self.var_op_deps:
start_idx += 1
continue
op_idx_list = self.var_op_deps[ckpt_name][
"var_as_output_ops"
]
if op_idx_list:
segments.append([0, max(op_idx_list) + 1])
else:
flag, min_idx, max_idx = self.is_subgraph(
[checkpoints[start_idx]], [checkpoints[start_idx + 1]]
)
if flag:
min_idx = self._update_segment_start(
min_idx, pre_segment_end_idx
)
segments.append([min_idx, max_idx + 1])
else:
logging.info(
"Could not recompute op range [{}] - [{}] ".format(
min_idx, max_idx + 1
)
)
start_idx += 1
if no_recompute_segments:
for i in reversed(sorted(no_recompute_segments)):
assert i < len(
segments
), "the no_recompute_segments idx [{}] should be lower the number of segment [{}]".format(
i, len(segments)
)
segments.pop(i)
for i, (idx1, idx2) in enumerate(segments):
logging.info("recompute segment[{}]".format(i))
logging.info(
"segment start op: [{}]: [{}] [{}]".format(
self._ops[idx1].desc.type(),
self._ops[idx1].desc.input_arg_names(),
self._ops[idx1].desc.output_arg_names(),
)
)
logging.info(
"segment end op: [{}]: [{}] [{}]".format(
self._ops[idx2 - 1].desc.type(),
self._ops[idx2 - 1].desc.input_arg_names(),
self._ops[idx2 - 1].desc.output_arg_names(),
)
) )
segments.pop(i)
return segments, checkpoints return segments
def is_recompute(self):
return any([_to_be_recomputed(op) for op in self._ops])
def modify_forward_desc_for_recompute(self, dist_context): def modify_forward_desc_for_recompute(self, dist_context):
""" """
If program's foward part has 'dropout' op, this function will insert If program's foward part has 'dropout' op, this function will insert
a seed op before it to guarantee that two dropout op have the same outputs. a seed op before it to guarantee that two dropout op have the same outputs.
""" """
op_types = [op.desc.type() for op in self._ops] op_types = [op.type for op in self.ops]
if "dropout" not in op_types: if "dropout" not in op_types:
return return
op_idx = 0 op_idx = 0
while op_idx < len(self._ops): while op_idx < len(self.ops):
cur_op = self._ops[op_idx] cur_op = self.ops[op_idx]
if "grad" in cur_op.type: if "grad" in cur_op.type:
break break
if cur_op.type == "seed":
self._reserved_vars.extend(cur_op.output_arg_names)
op_idx += 1
continue
if cur_op.type != "dropout": if cur_op.type != "dropout":
op_idx += 1 op_idx += 1
continue continue
...@@ -188,7 +138,8 @@ class RecomputeState(ProgramStats): ...@@ -188,7 +138,8 @@ class RecomputeState(ProgramStats):
var_unique_name = unique_name.generate_with_ignorable_key( var_unique_name = unique_name.generate_with_ignorable_key(
".".join([op_unique_name, 'tmp']) ".".join([op_unique_name, 'tmp'])
) )
seed_var = self._block.create_var( self._reserved_vars.append(var_unique_name)
seed_var = self.block.create_var(
name=var_unique_name, name=var_unique_name,
dtype='int32', dtype='int32',
type=core.VarDesc.VarType.LOD_TENSOR, type=core.VarDesc.VarType.LOD_TENSOR,
...@@ -209,7 +160,7 @@ class RecomputeState(ProgramStats): ...@@ -209,7 +160,7 @@ class RecomputeState(ProgramStats):
else int(cur_op.attr("seed")) else int(cur_op.attr("seed"))
) )
# TODO add dependency for seed op to ensure it be issued just before recompute. # TODO add dependency for seed op to ensure it be issued just before recompute.
seed_op = self._block._insert_op_without_sync( seed_op = self.block._insert_op_without_sync(
index=cur_op.idx, index=cur_op.idx,
type="seed", type="seed",
inputs={}, inputs={},
...@@ -223,7 +174,7 @@ class RecomputeState(ProgramStats): ...@@ -223,7 +174,7 @@ class RecomputeState(ProgramStats):
) )
# modify dropout op's desc # modify dropout op's desc
self._ops.insert(op_idx, seed_op) self.ops.insert(op_idx, seed_op)
cur_op.desc.set_input("Seed", [var_unique_name]) cur_op.desc.set_input("Seed", [var_unique_name])
cur_op._remove_attr("fix_seed") cur_op._remove_attr("fix_seed")
cur_op._remove_attr("seed") cur_op._remove_attr("seed")
...@@ -232,7 +183,7 @@ class RecomputeState(ProgramStats): ...@@ -232,7 +183,7 @@ class RecomputeState(ProgramStats):
) )
op_idx += 2 op_idx += 2
self._block._sync_with_cpp() self.block._sync_with_cpp()
def _find_op_index(block, cur_op): def _find_op_index(block, cur_op):
...@@ -242,7 +193,7 @@ def _find_op_index(block, cur_op): ...@@ -242,7 +193,7 @@ def _find_op_index(block, cur_op):
return -1 return -1
def _get_stop_gradients(program, no_grad_set): def _get_stop_gradients(program, no_grad_set=None):
"""get no grad var""" """get no grad var"""
if no_grad_set is None: if no_grad_set is None:
no_grad_set = set() no_grad_set = set()
...@@ -260,16 +211,15 @@ def _get_stop_gradients(program, no_grad_set): ...@@ -260,16 +211,15 @@ def _get_stop_gradients(program, no_grad_set):
def _add_needed_descs_to_block( def _add_needed_descs_to_block(
descs, block, main_block, in_memory_vars, dist_context descs, block, main_block, vars_should_be_hold, dist_context
): ):
""" """
Get the recomputed ops which will insert the backward part Get the recomputed ops which will insert the backward part
""" """
if len(descs) == 0: if len(descs) == 0:
return [] return []
result_descs = [] result_descs = []
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
backward = core.op_proto_and_checker_maker.OpRole.Backward
for desc in descs: for desc in descs:
if isinstance(desc, framework.Operator): if isinstance(desc, framework.Operator):
desc = desc.desc desc = desc.desc
...@@ -279,22 +229,29 @@ def _add_needed_descs_to_block( ...@@ -279,22 +229,29 @@ def _add_needed_descs_to_block(
for name in desc.output_arg_names(): for name in desc.output_arg_names():
if main_block.has_var(name) and main_block.var(name).persistable: if main_block.has_var(name) and main_block.var(name).persistable:
continue continue
if name not in in_memory_vars: if name not in vars_should_be_hold:
is_needed = True is_needed = True
if is_needed: if is_needed:
new_op_desc = block.desc.append_op() new_op_desc = block.desc.append_op()
new_op_desc.copy_from(desc) new_op_desc.copy_from(desc)
set_dist_op_desc_original_id(new_op_desc, desc, dist_context) set_dist_op_desc_original_id(new_op_desc, desc, dist_context)
new_op_desc._set_attr(op_role_attr_name, backward) new_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward)
result_descs.append(new_op_desc) result_descs.append(new_op_desc)
return result_descs return result_descs
def _find_op_path(main_program, loss, no_grad_set=None):
no_grad_set_name = _get_stop_gradients(main_program, no_grad_set)
op_path = _find_op_path_(
main_program.global_block(), [loss], [], no_grad_set_name
)
return op_path
@register_pass("auto_parallel_recompute") @register_pass("auto_parallel_recompute")
class RecomputePass(PassBase): class RecomputePass(PassBase):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.set_attr("checkpoints", None)
self.set_attr("loss", None) self.set_attr("loss", None)
self.set_attr("dist_context", None) self.set_attr("dist_context", None)
self.set_attr("no_grad_set", None) self.set_attr("no_grad_set", None)
...@@ -311,49 +268,64 @@ class RecomputePass(PassBase): ...@@ -311,49 +268,64 @@ class RecomputePass(PassBase):
return True return True
def _apply_single_impl(self, main_program, startup_program, context): def _apply_single_impl(self, main_program, startup_program, context):
checkpoints = self.get_attr("checkpoints")
no_recompute_segments = self.get_attr("no_recompute_segments")
loss = self.get_attr("loss") loss = self.get_attr("loss")
no_grad_set = self.get_attr("no_grad_set") no_grad_set = self.get_attr("no_grad_set")
no_recompute_segments = self.get_attr("no_recompute_segments")
self._dist_context = self.get_attr("dist_context") self._dist_context = self.get_attr("dist_context")
# 0. get op_path which is related to loss # 0. get op_path which is related to loss
main_block = main_program.global_block() main_block = main_program.global_block()
no_grad_set_name = _get_stop_gradients(main_program, no_grad_set) op_path = _find_op_path(main_program, loss, no_grad_set)
op_path = _find_op_path_(main_block, [loss], [], no_grad_set_name)
# 1. build recompute state # 1. build recompute state
rc_state = RecomputeState(main_block, op_path) rc_state = RecomputeState(main_block, op_path)
if not rc_state.is_recompute() and not checkpoints: if not rc_state.is_recompute():
return return
# 2. get the segments to be recomputed # 2. get the segments to be recomputed
rc_state.modify_forward_desc_for_recompute(self._dist_context) rc_state.modify_forward_desc_for_recompute(self._dist_context)
rc_state.build_stats() rc_state.build_states()
checkpoints = rc_state.sort_checkpoints(checkpoints or []) segments = rc_state.get_recompute_segments(no_recompute_segments)
segments, checkpoints = rc_state.get_recompute_segments( if segments == []:
checkpoints, no_recompute_segments
)
if segments == [] or checkpoints == []:
return return
for i, (idx1, idx2) in enumerate(segments):
logging.info(
"recompute segment[{}/{}]".format(i + 1, len(segments))
)
logging.info(
"segment start op: [{}]: [{}] [{}]".format(
rc_state.ops[idx1].type,
rc_state.ops[idx1].input_arg_names,
rc_state.ops[idx1].output_arg_names,
)
)
logging.info(
"segment end op: [{}]: [{}] [{}]".format(
rc_state.ops[idx2 - 1].type,
rc_state.ops[idx2 - 1].input_arg_names,
rc_state.ops[idx2 - 1].output_arg_names,
)
)
# 3. get vars that should be hold in memory # 3. get vars that should be hold in memory
vars_should_be_hold = [] vars_should_be_hold = []
for segment in segments: for segment in segments:
vars_should_be_hold.extend( vars_should_be_hold.extend(
rc_state.get_out_of_subgraph_vars(segment[0], segment[1]) rc_state.get_out_of_subgraph_vars(segment[0], segment[1])
) )
cross_vars = set(vars_should_be_hold) - set(checkpoints) cross_vars = set(vars_should_be_hold) - set(rc_state.checkpoints)
logging.info( logging.info(
"found [{}] vars which cross recompute segment: [{}]," "found [{}] vars which cross recompute segment: [{}],"
"better checkpoints might be set to reduce those vars".format( "better checkpoints might be set to reduce those vars".format(
len(cross_vars), cross_vars len(cross_vars), cross_vars
) )
) )
vars_should_be_hold.extend(rc_state.get_reserved_vars()) vars_should_be_hold.extend(rc_state.reserved_vars)
vars_should_be_hold.extend(rc_state.get_input_nodes()) vars_should_be_hold.extend(rc_state.get_input_nodes())
vars_should_be_hold = list(set(vars_should_be_hold)) vars_should_be_hold = list(
vars_in_memory = vars_should_be_hold + checkpoints set(vars_should_be_hold) | set(rc_state.checkpoints)
)
# 4. get the fwd ops desc to be recomputed. # 4. get the fwd ops desc to be recomputed.
var_name_dict = {} # varname --> varname.subprog_XXX var_name_dict = {} # varname --> varname.subprog_XXX
...@@ -364,20 +336,23 @@ class RecomputePass(PassBase): ...@@ -364,20 +336,23 @@ class RecomputePass(PassBase):
var_suffix = ".subprog_%d" % i var_suffix = ".subprog_%d" % i
for op in fwd_ops: for op in fwd_ops:
input_and_output_names = [] input_and_output_names = []
input_and_output_names.extend(op.desc.input_arg_names()) input_and_output_names.extend(op.input_arg_names)
input_and_output_names.extend(op.desc.output_arg_names()) input_and_output_names.extend(op.output_arg_names)
cur_op_dist_attr = ( cur_op_dist_attr = (
self._dist_context.get_op_dist_attr_for_program(op) self._dist_context.get_op_dist_attr_for_program(op)
) )
assert cur_op_dist_attr is not None assert cur_op_dist_attr is not None
for name in input_and_output_names: for name in input_and_output_names:
if main_block.var(name).persistable or name in checkpoints: if (
continue main_block.var(name).persistable
if name in vars_should_be_hold: or name in vars_should_be_hold
):
continue continue
if name not in var_name_dict: if name not in var_name_dict:
ref_process_mesh = cur_op_dist_attr.process_mesh ref_process_mesh = cur_op_dist_attr.process_mesh
if name in op.desc.input_arg_names(): if name in op.input_arg_names:
ref_dims_mapping = ( ref_dims_mapping = (
cur_op_dist_attr.get_input_dims_mapping(name) cur_op_dist_attr.get_input_dims_mapping(name)
) )
...@@ -385,6 +360,7 @@ class RecomputePass(PassBase): ...@@ -385,6 +360,7 @@ class RecomputePass(PassBase):
ref_dims_mapping = ( ref_dims_mapping = (
cur_op_dist_attr.get_output_dims_mapping(name) cur_op_dist_attr.get_output_dims_mapping(name)
) )
# record recomputed var's old_name and new_name (old_name.subprog_XXX) # record recomputed var's old_name and new_name (old_name.subprog_XXX)
# create new var with new name # create new var with new name
var_name_dict[name] = name + var_suffix var_name_dict[name] = name + var_suffix
...@@ -409,7 +385,7 @@ class RecomputePass(PassBase): ...@@ -409,7 +385,7 @@ class RecomputePass(PassBase):
fwd_ops, fwd_ops,
buffer_block, buffer_block,
main_block, main_block,
vars_in_memory, vars_should_be_hold,
self._dist_context, self._dist_context,
) )
# rename recomputed ops' input and output var name # rename recomputed ops' input and output var name
...@@ -437,15 +413,15 @@ class RecomputePass(PassBase): ...@@ -437,15 +413,15 @@ class RecomputePass(PassBase):
grad_op._remove_attr("fix_seed") grad_op._remove_attr("fix_seed")
grad_op._remove_attr("seed") grad_op._remove_attr("seed")
# rename grad op's var_name which is not in 'vars_in_memory' input_and_output_names = []
for key in var_name_dict: input_and_output_names.extend(grad_op.input_arg_names)
if ( input_and_output_names.extend(grad_op.output_arg_names)
key
not in grad_op.input_arg_names + grad_op.output_arg_names for varname in var_name_dict:
): if varname not in input_and_output_names:
continue continue
self.reset_op_dist_attr(grad_op, var_name_dict) self.reset_op_dist_attr(grad_op, var_name_dict)
_rename_arg_([grad_op.desc], key, var_name_dict[key]) _rename_arg_([grad_op.desc], varname, var_name_dict[varname])
# insert recomputed ops # insert recomputed ops
original_id = grad_op.desc.original_id() original_id = grad_op.desc.original_id()
...@@ -504,13 +480,13 @@ class RecomputePass(PassBase): ...@@ -504,13 +480,13 @@ class RecomputePass(PassBase):
def reset_op_dist_attr(self, op, var_name_dict): def reset_op_dist_attr(self, op, var_name_dict):
op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op) op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr is not None assert op_dist_attr is not None
for input in op.desc.input_arg_names(): for input in op.input_arg_names:
if input in var_name_dict.keys(): if input in var_name_dict.keys():
in_dist_attr = op_dist_attr.get_input_dist_attr(input) in_dist_attr = op_dist_attr.get_input_dist_attr(input)
op_dist_attr.set_input_dist_attr( op_dist_attr.set_input_dist_attr(
var_name_dict[input], in_dist_attr var_name_dict[input], in_dist_attr
) )
for output in op.desc.output_arg_names(): for output in op.output_arg_names:
if output in var_name_dict.keys(): if output in var_name_dict.keys():
out_dist_attr = op_dist_attr.get_output_dist_attr(output) out_dist_attr = op_dist_attr.get_output_dist_attr(output)
op_dist_attr.set_output_dist_attr( op_dist_attr.set_output_dist_attr(
......
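Stripped of the diff noise, the rewritten pass keeps its old pipeline but sources segments from the op annotations instead of a checkpoints argument: find the op path to the loss, build RecomputeState and bail out if nothing is marked, insert seed ops for dropout, pick the annotated segments minus no_recompute_segments, compute the variables that must stay resident, and clone each segment's forward descs into the backward under a ".subprog_i" suffix. A hedged sketch of that renaming scheme from step 4, where segment_id and recomputed_var_names stand in for the pass's loop variables:

var_suffix = ".subprog_%d" % segment_id          # one suffix per recompute segment
var_name_dict = {}
for name in recomputed_var_names:                # non-persistable vars not held in memory
    var_name_dict[name] = name + var_suffix      # backward ops are redirected to the alias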
...@@ -74,6 +74,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ...@@ -74,6 +74,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120)
py_test_modules(test_selective_recompute MODULES test_selective_recompute) py_test_modules(test_selective_recompute MODULES test_selective_recompute)
set_tests_properties(test_selective_recompute PROPERTIES TIMEOUT 50) set_tests_properties(test_selective_recompute PROPERTIES TIMEOUT 50)
py_test_modules(test_tuning_recompute MODULES test_tuning_recompute)
set_tests_properties(test_tuning_recompute PROPERTIES TIMEOUT 240)
py_test_modules(test_while_op_completion MODULES test_while_op_completion py_test_modules(test_while_op_completion MODULES test_while_op_completion
ENVS ${dist_ENVS}) ENVS ${dist_ENVS})
......
...@@ -28,12 +28,9 @@ from auto_parallel_gpt_model import ( ...@@ -28,12 +28,9 @@ from auto_parallel_gpt_model import (
GPTPretrainingCriterion, GPTPretrainingCriterion,
) )
sequence_len = 512
vocab_size = 1000
class FakeDataset(paddle.io.Dataset): class FakeDataset(paddle.io.Dataset):
def __init__(self, num_samples): def __init__(self, num_samples, vocab_size=1000, sequence_len=512):
self.num_samples = num_samples self.num_samples = num_samples
self.sequence_len = sequence_len self.sequence_len = sequence_len
self.vocab_size = vocab_size self.vocab_size = vocab_size
...@@ -57,7 +54,7 @@ class FakeDataset(paddle.io.Dataset): ...@@ -57,7 +54,7 @@ class FakeDataset(paddle.io.Dataset):
return self.num_samples return self.num_samples
def create_data_holder(batch_size): def create_data_holder(batch_size, vocab_size=1000, sequence_len=512):
tokens = paddle.static.InputSpec( tokens = paddle.static.InputSpec(
name="tokens", shape=[batch_size, sequence_len], dtype='int64' name="tokens", shape=[batch_size, sequence_len], dtype='int64'
) )
......
...@@ -98,7 +98,7 @@ def train(fetch): ...@@ -98,7 +98,7 @@ def train(fetch):
tuning.profile_start_step = 1 tuning.profile_start_step = 1
tuning.profile_end_step = 5 tuning.profile_end_step = 5
tuning.run_after_tuning = True tuning.run_after_tuning = True
tuning.verbose = True tuning.debug = True
dataset = MyDataset(batch_num * batch_size) dataset = MyDataset(batch_num * batch_size)
engine = auto.Engine( engine = auto.Engine(
......
...@@ -24,7 +24,7 @@ class TestStrategy(unittest.TestCase): ...@@ -24,7 +24,7 @@ class TestStrategy(unittest.TestCase):
recompute = strategy.recompute recompute = strategy.recompute
self.assertEqual(recompute.enable, False) self.assertEqual(recompute.enable, False)
self.assertIsNone(recompute.checkpoints) self.assertEqual(recompute.checkpoints, [])
amp = strategy.amp amp = strategy.amp
self.assertEqual(amp.enable, False) self.assertEqual(amp.enable, False)
...@@ -66,12 +66,10 @@ class TestStrategy(unittest.TestCase): ...@@ -66,12 +66,10 @@ class TestStrategy(unittest.TestCase):
tuning = strategy.tuning tuning = strategy.tuning
self.assertEqual(tuning.enable, False) self.assertEqual(tuning.enable, False)
self.assertEqual(tuning.batch_size, 1)
self.assertIsNone(tuning.dataset)
self.assertEqual(tuning.profile_start_step, 1) self.assertEqual(tuning.profile_start_step, 1)
self.assertEqual(tuning.profile_end_step, 1) self.assertEqual(tuning.profile_end_step, 1)
self.assertEqual(tuning.run_after_tuning, True) self.assertEqual(tuning.run_after_tuning, True)
self.assertEqual(tuning.verbose, True) self.assertEqual(tuning.debug, False)
def test_modify_config(self): def test_modify_config(self):
strategy = auto.Strategy() strategy = auto.Strategy()
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
from get_gpt_model import FakeDataset
import paddle
from paddle.distributed.fleet import auto
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import (
GPTForPretraining,
GPTModel,
GPTPretrainingCriterion,
)
def generate_model():
modeling.init_global()
modeling._global_parallel_strategy = "serial"
gpt = GPTModel(
vocab_size=50304,
hidden_size=1024,
num_hidden_layers=14,
num_attention_heads=16,
intermediate_size=1024 * 4,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=1024,
type_vocab_size=1,
initializer_range=0.02,
pad_token_id=0,
eos_token_id=7,
bos_token_id=0,
eol_token_id=3,
use_new_recompute=True,
recompute_granularity="full",
)
model = GPTForPretraining(
gpt, vocab_size=50304, hidden_size=1024, initializer_range=0.02
)
criterion = GPTPretrainingCriterion()
return model, criterion
def apply_pass():
strategy = auto.Strategy()
strategy.auto_mode = "semi"
recompute = strategy.recompute
recompute.enable = True
recompute.enable_tuning = True
tuning = strategy.tuning
tuning.enable = True
tuning.profile_start_step = 1
tuning.profile_end_step = 2
tuning.run_after_tuning = True
tuning.verbose = True
return strategy
class TestRecomputePassTuning(unittest.TestCase):
def setUp(self):
self.batch_size = 8
self.batch_num = 200
self.dataset = FakeDataset(
self.batch_size * self.batch_num,
vocab_size=50304,
sequence_len=1024,
)
def test_recompute_pass(self):
strategy = apply_pass()
clip = paddle.nn.ClipGradByGlobalNorm(0.2)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model()
engine = auto.Engine(model, loss, opt, strategy=strategy)
engine._tune(self.dataset, 3, batch_size=self.batch_size)
assert (
len(
engine._dist_contexts[
'train'
].strategy.recompute.no_recompute_segments
)
> 0
)
if __name__ == "__main__":
unittest.main()
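A usage note (not part of the commit): after engine._tune() returns, the tuned result lives on the train dist context's strategy, which is exactly what the assertion above inspects.

tuned = engine._dist_contexts['train'].strategy.recompute.no_recompute_segments
print("segments excluded from recompute:", tuned)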