Unverified commit 170a31f9, authored by zhaoyingli, committed by GitHub

[AutoParallel] recompute tuning (#48608)

* [AutoParallel] recompute tuning

* fix conflict

* update comment

* bug fix

* update rc algo

* tiny fix

* fix clear process_group

* remove comment

* update segment print

* fix import OpRole

* adapt amp pass and grad_clip pass for opt_tuner

* update tuning config

* fix import

* annotate recompute info on ops and upgrade recompute pass

* add op_namescope for seed op

* record reserved vars

* fix recompute var's dist_attr

* fix strategy unittest

* adapt for fp16

* update unittest

* revert copy opt

* update unittest

* rename set_recompute_segments

* fix unittest
Parent b9207054
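
The changes below make the recompute pass tunable end to end. A minimal usage sketch (mirroring the unit test added at the end of this diff; model, loss, optimizer, and dataset are placeholders):

import paddle
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.recompute.enable = True
strategy.recompute.enable_tuning = True  # expose recompute to the tuner
tuning = strategy.tuning
tuning.enable = True
tuning.profile_start_step = 1
tuning.profile_end_step = 2
tuning.run_after_tuning = True  # train with the best trial afterwards

engine = auto.Engine(model, loss, optimizer, strategy=strategy)
engine._tune(dataset, 3, batch_size=8)
# The tuned result is recorded in:
# engine._dist_contexts['train'].strategy.recompute.no_recompute_segments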
......@@ -54,7 +54,7 @@ set_field_default_config(BASE, "reinit", False) # Only for debug
#########################################
RECOMPUTE = "recompute"
set_field_default_config(RECOMPUTE, "enable", False)
set_field_default_config(RECOMPUTE, "checkpoints", None)
set_field_default_config(RECOMPUTE, "checkpoints", [])
set_field_default_config(RECOMPUTE, "no_recompute_segments", [])
set_field_default_config(RECOMPUTE, "enable_tuning", False)
......@@ -113,12 +113,10 @@ set_field_default_config(QAT, "algo", None)
# #########################################
TUNING = "tuning"
set_field_default_config(TUNING, "enable", False)
set_field_default_config(TUNING, "batch_size", 1)
set_field_default_config(TUNING, "dataset", None)
set_field_default_config(TUNING, "profile_start_step", 1)
set_field_default_config(TUNING, "profile_end_step", 1)
set_field_default_config(TUNING, "run_after_tuning", True)
set_field_default_config(TUNING, "verbose", True)
set_field_default_config(TUNING, "debug", False)
#########################################
# dataset configuration
......
......@@ -609,7 +609,9 @@ class Engine:
if mode != "train":
serial_main_prog = serial_main_prog.clone(for_test=True)
auto_utils.set_recompute_ckpts(self._model, self._strategy)
auto_utils.set_recompute_segments(
self._model, self._losses, self._strategy, serial_main_prog
)
self._dist_contexts[mode] = DistributedContext(
serial_main_prog,
serial_startup_prog,
......@@ -649,7 +651,6 @@ class Engine:
from .tuner.optimization_tuner import OptimizationTuner
self._optimization_tuner = OptimizationTuner(
self._tuning.to_dict(),
self._dist_contexts[mode],
dataset,
self._inputs_spec,
......
......@@ -73,6 +73,10 @@ class BaseConfig:
setattr(result, k, copy.deepcopy(v, memo))
return result
def get(self, k, d=None):
result_dict = self.to_dict()
return result_dict.get(k, d)
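
The new BaseConfig.get mirrors dict.get, so callers can read an optional key with a default instead of materializing the whole dict first. A usage sketch (assuming a strategy object):

stage_range = strategy.sharding.get("tuning_range", None)  # None when unset
use_debug = strategy.tuning.get("debug", False)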
class RecomputeConfig(BaseConfig):
def __init__(self, config_dict=None):
......
......@@ -16,7 +16,7 @@ import copy
import logging
from abc import ABC, abstractmethod
from ..utils import get_logger
from ..utils import get_logger, is_recompute_op
from .trial import OptimizationTunerTrial as Trial
from .trial import TrialStatus
......@@ -54,7 +54,7 @@ class AlgorithmBase(ABC):
def collect_model_info(self, main_prog, startup_prog):
"""
Collect the model static info (from programs) that could be used to
pruning candidate trials and saving tuning time.For instance,
pruning candidate trials and saving tuning time. For instance,
model info like number of model parameters and activation memory could be
used to prune candidate trials and decide the next trial.
"""
......@@ -116,7 +116,7 @@ class ShardingStageAlgorithm(AlgorithmBase):
self._max_stage = 3
self._trial_idx = 0
stage_range = self._config.sharding.to_dict().get("tuning_range", None)
stage_range = self._config.sharding.get("tuning_range", None)
if stage_range:
assert set(stage_range).issubset(
set([0, 1, 2, 3])
......@@ -157,3 +157,92 @@ class ShardingStageAlgorithm(AlgorithmBase):
)
else:
self._trial_idx += 1
@register_algor("recompute")
class RecomputeCheckpointAlgorithm(AlgorithmBase):
def __init__(self, config):
super().__init__(config)
self._changed_configs = ["recompute"]
def collect_model_info(self, main_prog, startup_prog):
segments = []
for op in main_prog.global_block().ops:
if not is_recompute_op(op):
continue
seg_name = op.attr('op_namescope')
if seg_name not in segments:
segments.append(seg_name)
self._total_num_trial = len(segments)
self._tuning_segments = list(range(len(segments)))
self._trial_left = 0
self._trial_right = len(segments) - 1
self._trial_idx = int(0 + (len(segments) - 1) / 2)
def _init_spaces(self):
self._recompute_mode = "all"
def next_trial(self):
if self._trial_idx < self._total_num_trial:
if self._recompute_mode == "all":
self._recompute_flag = False
new_strategy = copy.deepcopy(self._config.dist_strategy)
name = "trial-recompute-all-segments"
return Trial(new_strategy, name, self.changed_configs)
elif self._recompute_mode == "none":
self._recompute_flag = False
new_strategy = copy.deepcopy(self._config.dist_strategy)
recompute = new_strategy.recompute
recompute.enable = False
name = "trial-recompute-none-segments"
return Trial(new_strategy, name, self.changed_configs)
elif self._recompute_mode == "part":
new_no_recompute = self._tuning_segments[: self._trial_idx]
new_strategy = copy.deepcopy(self._config.dist_strategy)
recompute = new_strategy.recompute
recompute.no_recompute_segments.extend(new_no_recompute)
name = "trial-recompute-part-segments-idx{}".format(
self._trial_idx
)
return Trial(new_strategy, name, self.changed_configs)
else:
return Trial(None, None, None, status=TrialStatus.STOPPED)
def update(self, results):
et = results.get("ErrorType", None)
if self._recompute_mode == "all":
if et == "ResourceExhaustedError":
self._trial_idx = self._total_num_trial
self._logger.info(
"Recompute all candidate segments is failed with OOM, please reduce model size or batch size."
)
else:
self._recompute_mode = "none"
elif self._recompute_mode == "none":
if et == "ResourceExhaustedError":
self._recompute_mode = "part"
else:
self._trial_idx = self._total_num_trial
self._logger.info(
"Recompute is unnecessary for this model size, which will reduce the Throughtput."
)
else:
if self._trial_left >= self._trial_right:
self._trial_idx = self._total_num_trial
elif et == "ResourceExhaustedError":
self._trial_right = self._trial_idx - 1
self._trial_idx = int(
self._trial_left
+ (self._trial_right - self._trial_left) / 2
)
else:
self._trial_left = self._trial_idx + 1
self._trial_idx = int(
self._trial_left
+ (self._trial_right - self._trial_left) / 2
)
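
A standalone sketch of the search order implemented above (hypothetical helper, not part of this PR): the tuner first tries recomputing all segments, then none; if "none" OOMs, it binary-searches for the largest prefix of segments that can skip recompute without exhausting memory.

def simulate_search(num_segments, min_recompute_needed):
    # A trial OOMs when too many segments skip recompute, i.e. when the
    # number of segments still being recomputed drops below the threshold.
    left, right = 0, num_segments - 1
    idx = (num_segments - 1) // 2
    history = []
    while left <= right:
        skipped = idx  # segments [0, idx) go into no_recompute_segments
        oom = (num_segments - skipped) < min_recompute_needed
        history.append((skipped, "OOM" if oom else "ok"))
        if oom:
            right = idx - 1  # must recompute more: skip fewer segments
        else:
            left = idx + 1   # memory headroom left: try skipping more
        idx = left + (right - left) // 2
    return history

print(simulate_search(8, 3))  # [(3, 'ok'), (5, 'ok'), (6, 'OOM')]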
......@@ -32,14 +32,11 @@ class TuningConfig:
tuning config: configuration for the tuning process, including the mode (profile or cost model), the log directory, and extra per-pass tuning options such as the search range for a specific optimization.
"""
def __init__(self, user_config, strategy):
def __init__(self, strategy):
if not isinstance(strategy, Strategy):
raise TypeError("'strategy' must be object of class `Strategy`.")
if not user_config:
user_config = {}
self._tuning_passes_name = set()
self._dist_strategy = copy.deepcopy(strategy)
self._mode = None
......@@ -48,9 +45,9 @@ class TuningConfig:
self._project_dir = None
self._max_num_trial = None
self._early_stop = None
self._verbose = None
self._debug = None
self._initialize(user_config)
self._initialize()
@property
def mode(self):
......@@ -81,29 +78,25 @@ class TuningConfig:
return self._early_stop
@property
def verbose(self):
return self._verbose
def debug(self):
return self._debug
@property
def dist_strategy(self):
return self._dist_strategy
# initialize config with user-defined values or default values
def _initialize(self, user_config):
self._mode = user_config.get("mode", "PROFILE")
self._profile_start_step = user_config.get("profile_start_step", 10)
self._profile_end_step = user_config.get("profile_end_step", 30)
self._max_num_trial = user_config.get("max_num_trial", 50)
self._early_stop = user_config.get("early_stop", None)
def _initialize(self):
tuning_strategy = self._dist_strategy.tuning
self._verbose = user_config.get("verbose", False)
self._mode = tuning_strategy.get("mode", "PROFILE")
self._profile_start_step = tuning_strategy.get("profile_start_step", 10)
self._profile_end_step = tuning_strategy.get("profile_end_step", 30)
self._max_num_trial = tuning_strategy.get("max_num_trial", 50)
self._early_stop = tuning_strategy.get("early_stop", None)
self._debug = tuning_strategy.get("debug", False)
project_dir = user_config.get("project_dir", None)
project_dir = tuning_strategy.get("project_dir", None)
if not project_dir:
project_dir = os.path.join(os.getcwd(), "OptimizationTuning")
self._project_dir = project_dir
......@@ -116,15 +109,14 @@ class TuningConfig:
# TODO distinguish the different args of each pass
self._tuning_passes_name.add(p)
config_name = p
p_dict = getattr(self._dist_strategy, config_name)
self.__dict__[config_name] = p_dict
p_strategy = getattr(self._dist_strategy, p)
self.__dict__[p] = p_strategy
# TODO verify the user defined configs
user_config_for_pass = user_config.get(p, None)
if user_config_for_pass:
for k, v in user_config_for_pass.items():
self.__dict__[config_name][k] = v
# # TODO verify the user defined configs
# tuning_config_for_pass = tuning_strategy.get(p, None)
# if tuning_config_for_pass:
# for k, v in tuning_config_for_pass.items():
# self.__dict__[p][k] = v
# NOTE: the tuning config ONLY wraps the dist strategy for the pass configs to be tuned
def __getattr__(self, item):
......
......@@ -33,6 +33,7 @@ from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.process_group import (
clear_all_process_groups,
get_all_process_groups,
new_process_group,
)
from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.utils import (
......@@ -40,7 +41,7 @@ from paddle.distributed.auto_parallel.utils import (
set_grad_var_shape,
)
from paddle.distributed.passes import PassContext, new_pass
from paddle.fluid import program_guard
from paddle.fluid import program_guard, unique_name
from paddle.fluid.backward import append_backward
from ..utils import get_logger
......@@ -109,7 +110,12 @@ def parse_results(results):
# everything needed to start a new pass is a member of the dist context
def _copy_context(ref_dist_context):
# clear all process groups and recover the world process group
clear_all_process_groups()
ranks = []
for process_mesh in ref_dist_context._process_meshes:
ranks.extend(process_mesh.processes)
new_process_group(list(set(ranks)))
new_dist_context = DistributedContext()
new_dist_context._serial_main_program = (
......@@ -195,7 +201,6 @@ class OptimizationTuner:
def __init__(
self,
user_configs,
dist_context,
dataset,
inputs_spec,
......@@ -204,7 +209,7 @@ class OptimizationTuner:
rank,
):
self._config = TuningConfig(user_configs, dist_context._strategy)
self._config = TuningConfig(dist_context.strategy)
# should not modify dist context from calling function
self._baseline_dist_context = _copy_context(dist_context)
self._baseline_completer = Completer(self._baseline_dist_context)
......@@ -264,7 +269,7 @@ class OptimizationTuner:
)
self._baseline_dist_context._params_grads = params_grads
if self._config.verbose:
if self._config.debug:
baseline_dir = os.path.join(self.project_dir, "baseline")
if not os.path.exists(baseline_dir):
pathlib.Path(baseline_dir).mkdir(parents=True, exist_ok=True)
......@@ -299,7 +304,6 @@ class OptimizationTuner:
config = copy.deepcopy(new_strategy.amp.to_dict())
config["dist_context"] = dist_context
config["params_grads"] = dist_context._params_grads
# TODO AMP Pass should not use loss var
config["loss"] = dist_context.serial_loss
config["input_data"] = (
......@@ -312,13 +316,13 @@ class OptimizationTuner:
auto_parallel_fp16_pass.apply(
[main_program], [startup_program], pass_context
)
dist_context.serial_loss = auto_parallel_fp16_pass.get_loss()
dist_context._serial_loss = auto_parallel_fp16_pass.get_loss()
else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply(
[main_program], [startup_program], pass_context
)
dist_context.serial_loss = auto_parallel_amp_pass.get_loss()
dist_context._serial_loss = auto_parallel_amp_pass.get_loss()
if new_strategy.recompute.enable:
config = copy.deepcopy(new_strategy.recompute.to_dict())
......@@ -345,6 +349,7 @@ class OptimizationTuner:
# Generate optimizer
# FIXME should be removed from apply pass after passes support optimizers
with program_guard(dist_main_prog, dist_startup_prog):
with unique_name.guard("opt_"):
optimizer_ops = dist_context.serial_optimizer.apply_gradients(
dist_params_grads
)
......@@ -361,6 +366,13 @@ class OptimizationTuner:
)
resharder.reshard()
config = {}
config["dist_context"] = dist_context
config["global_rank"] = self.rank
config["use_sharding"] = new_strategy.sharding.enable
dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
dp_pass.apply([dist_main_prog], [dist_startup_prog], pass_context)
if new_strategy.sharding.enable:
config = copy.deepcopy(new_strategy.sharding.to_dict())
config["dist_context"] = dist_context
......@@ -372,6 +384,17 @@ class OptimizationTuner:
auto_parallel_sharding_pass.apply(
[dist_main_prog], [dist_startup_prog], pass_context
)
dist_params_grads = pass_context.get_attr("params_grads")
# gradient clip
config = copy.deepcopy(new_strategy.sharding.to_dict())
config["dist_context"] = dist_context
config["params_grads"] = dist_params_grads
config["rank_id"] = self.rank
auto_parallel_clip_pass = new_pass("auto_parallel_grad_clip", config)
auto_parallel_clip_pass.apply(
[dist_main_prog], [dist_startup_prog], pass_context
)
if new_strategy.gradient_merge.enable:
config = copy.deepcopy(new_strategy.gradient_merge.to_dict())
......@@ -488,7 +511,7 @@ class OptimizationTuner:
with open(ctx_path, 'wb') as f:
pickle.dump(profile_ctx, f, protocol=4)
if self._config.verbose:
if self._config.debug:
debug_program(trial.main_program, trial_dir, "main_program")
debug_program(trial.startup_program, trial_dir, "startup_program")
......@@ -581,7 +604,7 @@ The best trial is: [{}], whose configuration is following:
Clear the temporary files generated during the tuning procedure.
"""
# TODO clean up zombie processes created by tuning
if not self._config.verbose:
if not self._config.debug:
for trial in self._finished_trials:
trial_dir = self._get_trial_dir(trial)
shutil.rmtree(trial_dir, ignore_errors=True)
......
......@@ -89,7 +89,7 @@ def init_process_groups(group_map, rank):
# TODO should instantiate global group first
all_process_groups = get_all_process_groups()
for process_group in all_process_groups:
if process_group.id == 0 or rank not in process_group.ranks:
if rank not in process_group.ranks:
continue
print(process_group)
process_group.instantiate()
......@@ -173,11 +173,12 @@ def init_comm(profile_ctx):
genv = _get_global_env()
genv = dist_env
print(
"current process rank: {}, device_id: {}, ip: {}.",
"current process rank: {}, device_id: {}, ip: {}.".format(
genv.rank,
genv.device_id,
genv.current_endpoint,
)
)
# init nccl comm
group_map = profile_ctx['group_map']
......@@ -231,13 +232,12 @@ def profiler(args):
exe = get_executor()
try:
exe.run(startup_program)
# profile main
duration = 0
eval_step = 0
data_loader._inner_dataloader.start()
try:
while eval_step < args.profile_end_step:
start_time = time.time()
......
......@@ -22,18 +22,17 @@ from functools import reduce
import numpy as np
import paddle
import paddle.fluid.core as core
from paddle.distributed.auto_parallel.dist_attribute import (
from paddle.fluid.framework import Variable
from paddle.fluid.io import is_belong_to_optimizer, is_parameter
from paddle.framework import core
from .dist_attribute import (
OperatorDistributedAttribute,
TensorDistributedAttribute,
)
from paddle.distributed.auto_parallel.process_group import (
get_all_process_groups,
)
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid.framework import Variable
from paddle.fluid.io import is_belong_to_optimizer, is_parameter
from .process_group import get_all_process_groups
OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
__no_shape_var_type__ = [
......@@ -1921,10 +1920,16 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank):
server_socket.close()
def set_recompute_ckpts(model, strategy):
from .interface import _g_recompute_idx
def is_recompute_op(op):
return op.has_attr('op_namescope') and "/auto_parallel/rc" in op.attr(
'op_namescope'
)
if _g_recompute_idx > -1:
def set_recompute_segments(model, losses, strategy, program):
from ..passes.auto_parallel_recompute import RecomputeState
if not losses:
return
recompute = strategy.recompute
......@@ -1934,24 +1939,65 @@ def set_recompute_ckpts(model, strategy):
# NOTE: hack to enable recompute in engine api for GPT-3
# TODO support more PaddleNLP/CV models here
# extract ckpts by specific model
ckpts = []
if isinstance(model, paddle.nn.Layer):
if hasattr(model, "gpt") and model.__class__.__name__ in [
if (
hasattr(model, "gpt")
and model.__class__.__name__
in [
'GPTForPretraining',
'GPTForPretrainingAuto',
]:
exact_ckpts = model.gpt.checkpoints
]
and hasattr(model.gpt, "checkpoints")
):
ckpts = model.gpt.checkpoints
else:
exact_ckpts = recompute.checkpoints
ckpts = recompute.checkpoints
else:
exact_ckpts = recompute.checkpoints
ckpts = recompute.checkpoints
# modify strategy
recompute.checkpoints = exact_ckpts[:]
logs = {
'Model Class': model.__class__.__name__,
'Applied Recompute ckpts': exact_ckpts,
}
logging.info(logs)
if not ckpts:
return
block = program.global_block()
rc_state = RecomputeState(block, block.ops)
rc_state.build_states()
checkpoints = rc_state.sort_checkpoints(ckpts)
segments = []
start_idx = -1
pre_segment_end_idx = -1
while start_idx + 1 < len(checkpoints):
if start_idx == -1:
ckpt_name = checkpoints[start_idx + 1]
if ckpt_name not in rc_state.var_op_deps:
start_idx += 1
continue
op_idx_list = rc_state.var_op_deps[ckpt_name]["var_as_output_ops"]
if op_idx_list and max(op_idx_list) > 0:
segments.append([0, max(op_idx_list) + 1])
else:
flag, min_idx, max_idx = rc_state.is_subgraph(
[checkpoints[start_idx]], [checkpoints[start_idx + 1]]
)
if flag:
min_idx = rc_state._update_segment_start(
min_idx, pre_segment_end_idx
)
segments.append([min_idx, max_idx + 1])
else:
logging.debug(
"Could not recompute op range [{}] - [{}] ".format(
min_idx, max_idx + 1
)
)
start_idx += 1
for i, segment in enumerate(segments):
for j in range(segment[0], segment[1]):
block.ops[j]._set_attr(
'op_namescope', "/auto_parallel/rc_" + str(i)
)
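
For illustration (hypothetical snippet): after set_recompute_segments runs, every op inside segment i carries the op_namescope "/auto_parallel/rc_" + str(i), which is exactly what is_recompute_op above, and the upgraded recompute pass, key on:

for op in program.global_block().ops:
    if is_recompute_op(op):
        # prints e.g. "matmul_v2 /auto_parallel/rc_0"
        print(op.type, op.attr('op_namescope'))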
def get_input_split_info(cur_rank, var, dist_context):
......
......@@ -226,6 +226,9 @@ class AMPState:
dist_context, out_var, ref_mapping, ref_mesh
)
op_namescope = "/"
if op.has_attr('op_namescope'):
op_namescope = op.attr('op_namescope')
cast_op = self._block._insert_op_without_sync(
idx,
type="cast",
......@@ -236,6 +239,9 @@ class AMPState:
"out_dtype": out_var.dtype,
},
)
cast_op._set_attr(
'op_namescope', op_namescope
) # for recompute
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context
)
......
......@@ -22,13 +22,12 @@ from paddle.distributed.auto_parallel.process_group import (
get_world_process_group,
)
from paddle.distributed.auto_parallel.utils import (
OP_ROLE_KEY,
OpRole,
is_backward_op,
is_forward_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.fluid import unique_name
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
AutoMixedPrecisionLists,
......@@ -417,6 +416,9 @@ class FP16State:
dist_context, cast_var, ref_mapping, ref_mesh
)
op_namescope = "/"
if op.has_attr('op_namescope'):
op_namescope = op.attr('op_namescope')
cast_op = block._insert_op_without_sync(
idx,
type="cast",
......@@ -428,6 +430,9 @@ class FP16State:
OP_ROLE_KEY: OpRole.Forward,
},
)
cast_op._set_attr(
'op_namescope', op_namescope
) # for recompute
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context
)
......
......@@ -17,6 +17,7 @@ from functools import reduce
import numpy as np
import paddle
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from ..auto_parallel.dist_attribute import (
OperatorDistributedAttribute,
......@@ -25,8 +26,6 @@ from ..auto_parallel.dist_attribute import (
from ..auto_parallel.process_group import get_world_process_group
from ..auto_parallel.reshard import Resharder
from ..auto_parallel.utils import (
OP_ROLE_KEY,
OpRole,
_get_comm_group,
insert_dependencies_for_two_vars,
is_gradient_clip_op,
......
......@@ -19,12 +19,11 @@ from paddle.distributed.auto_parallel.process_group import (
get_world_process_group,
)
from paddle.distributed.auto_parallel.utils import (
OP_ROLE_KEY,
OpRole,
is_optimize_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.fluid import layers
from paddle.fluid.framework import device_guard
from paddle.framework import core
......
......@@ -14,19 +14,8 @@
import logging
from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistributedAttribute,
)
from paddle.distributed.auto_parallel.utils import (
get_loss_op,
insert_dependencies_for_two_ops,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_dist_op_desc_original_id,
set_var_dist_attr,
)
from paddle.fluid import core
from paddle.fluid import framework as framework
from paddle.fluid import unique_name
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.fluid import core, framework, unique_name
from paddle.fluid.backward import (
ProgramStats,
_append_grad_suffix_,
......@@ -35,28 +24,43 @@ from paddle.fluid.backward import (
_rename_arg_,
)
from ..auto_parallel.dist_attribute import OperatorDistributedAttribute
from ..auto_parallel.utils import (
get_loss_op,
insert_dependencies_for_two_ops,
is_backward_op,
is_recompute_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_dist_op_desc_original_id,
set_var_dist_attr,
)
from .pass_base import PassBase, register_pass
def _to_be_recomputed(op):
return op.has_attr('op_namescope') and "/auto_parallel/rc_" in op.attr(
'op_namescope'
)
class RecomputeState(ProgramStats):
def __init__(self, block, ops):
super().__init__(block=block, ops=ops)
self._block = block
self._ops = ops
# {varname: {as_input_ops: op_idx, as_output_ops: op_idx}}
self.var_op_deps = {}
# {segment_name: op_idx}
self.seg_op_deps = {}
self._checkpoints = []
self._reserved_vars = []
@property
def checkpoints(self):
return self._checkpoints
@property
def reserved_vars(self):
return self._reserved_vars
def is_recompute(self):
return any([is_recompute_op(op) for op in self.ops])
def build_stats(self):
for i, op in enumerate(self._ops):
for name in op.desc.input_arg_names():
def build_states(self):
for i, op in enumerate(self.ops):
if is_backward_op(op):
break
for name in op.input_arg_names:
if name in self.var_op_deps:
self.var_op_deps[name]["var_as_input_ops"].extend([i])
else:
......@@ -64,7 +68,7 @@ class RecomputeState(ProgramStats):
self.var_op_deps[name]["var_as_input_ops"] = [i]
self.var_op_deps[name]["var_as_output_ops"] = []
for name in op.desc.output_arg_names():
for name in op.output_arg_names:
if name in self.var_op_deps:
self.var_op_deps[name]["var_as_output_ops"].extend([i])
else:
......@@ -72,7 +76,8 @@ class RecomputeState(ProgramStats):
self.var_op_deps[name]["var_as_input_ops"] = []
self.var_op_deps[name]["var_as_output_ops"] = [i]
if not _to_be_recomputed(op):
if not is_recompute_op(op):
self._checkpoints.extend(op.output_arg_names)
continue
seg_name = op.attr('op_namescope')
......@@ -84,53 +89,14 @@ class RecomputeState(ProgramStats):
), "The recompute segment's ops should be continuous"
self.seg_op_deps[seg_name].extend([i])
def get_recompute_segments(
self, checkpoints_list=None, no_recompute_segments=[]
):
"""get recompute segments and checkpoints"""
def get_recompute_segments(self, no_recompute_segments=[]):
segments = []
checkpoints = checkpoints_list or []
if len(checkpoints) == 0:
# the segments is marked by `auto.recompute()` api
for segment_idx in self.seg_op_deps.values():
if len(segment_idx) == 1:
continue
segments.append([segment_idx[0], segment_idx[-1] + 1])
checkpoints.extend(self._ops[segment_idx[-1]].output_arg_names)
else:
# the segments is marked by `strategy.checkpoints` api
start_idx = -1
pre_segment_end_idx = -1
while start_idx + 1 < len(checkpoints):
if start_idx == -1:
ckpt_name = checkpoints[start_idx + 1]
if ckpt_name not in self.var_op_deps:
start_idx += 1
continue
op_idx_list = self.var_op_deps[ckpt_name][
"var_as_output_ops"
]
if op_idx_list:
segments.append([0, max(op_idx_list) + 1])
else:
flag, min_idx, max_idx = self.is_subgraph(
[checkpoints[start_idx]], [checkpoints[start_idx + 1]]
)
if flag:
min_idx = self._update_segment_start(
min_idx, pre_segment_end_idx
)
segments.append([min_idx, max_idx + 1])
else:
logging.info(
"Could not recompute op range [{}] - [{}] ".format(
min_idx, max_idx + 1
)
)
start_idx += 1
self._checkpoints.extend(self.ops[segment_idx[-1]].output_arg_names)
if no_recompute_segments:
for i in reversed(sorted(no_recompute_segments)):
assert i < len(
segments
......@@ -139,42 +105,26 @@ class RecomputeState(ProgramStats):
)
segments.pop(i)
for i, (idx1, idx2) in enumerate(segments):
logging.info("recompute segment[{}]".format(i))
logging.info(
"segment start op: [{}]: [{}] [{}]".format(
self._ops[idx1].desc.type(),
self._ops[idx1].desc.input_arg_names(),
self._ops[idx1].desc.output_arg_names(),
)
)
logging.info(
"segment end op: [{}]: [{}] [{}]".format(
self._ops[idx2 - 1].desc.type(),
self._ops[idx2 - 1].desc.input_arg_names(),
self._ops[idx2 - 1].desc.output_arg_names(),
)
)
return segments, checkpoints
def is_recompute(self):
return any([_to_be_recomputed(op) for op in self._ops])
return segments
def modify_forward_desc_for_recompute(self, dist_context):
"""
If the program's forward part has a 'dropout' op, this function inserts
a seed op before it to guarantee that the dropout op and its recomputed
counterpart produce the same outputs.
"""
op_types = [op.desc.type() for op in self._ops]
op_types = [op.type for op in self.ops]
if "dropout" not in op_types:
return
op_idx = 0
while op_idx < len(self._ops):
cur_op = self._ops[op_idx]
while op_idx < len(self.ops):
cur_op = self.ops[op_idx]
if "grad" in cur_op.type:
break
if cur_op.type == "seed":
self._reserved_vars.extend(cur_op.output_arg_names)
op_idx += 1
continue
if cur_op.type != "dropout":
op_idx += 1
continue
......@@ -188,7 +138,8 @@ class RecomputeState(ProgramStats):
var_unique_name = unique_name.generate_with_ignorable_key(
".".join([op_unique_name, 'tmp'])
)
seed_var = self._block.create_var(
self._reserved_vars.append(var_unique_name)
seed_var = self.block.create_var(
name=var_unique_name,
dtype='int32',
type=core.VarDesc.VarType.LOD_TENSOR,
......@@ -209,7 +160,7 @@ class RecomputeState(ProgramStats):
else int(cur_op.attr("seed"))
)
# TODO add a dependency for the seed op to ensure it is issued just before recompute.
seed_op = self._block._insert_op_without_sync(
seed_op = self.block._insert_op_without_sync(
index=cur_op.idx,
type="seed",
inputs={},
......@@ -223,7 +174,7 @@ class RecomputeState(ProgramStats):
)
# modify dropout op's desc
self._ops.insert(op_idx, seed_op)
self.ops.insert(op_idx, seed_op)
cur_op.desc.set_input("Seed", [var_unique_name])
cur_op._remove_attr("fix_seed")
cur_op._remove_attr("seed")
......@@ -232,7 +183,7 @@ class RecomputeState(ProgramStats):
)
op_idx += 2
self._block._sync_with_cpp()
self.block._sync_with_cpp()
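
The seed-op rewrite above exists because a recomputed dropout must reproduce the exact mask drawn in the original forward pass; feeding both ops from one seed variable guarantees this. A plain-NumPy analogy of the invariant (not Paddle API):

import numpy as np

seed = 1234  # stands in for the inserted seed op's output
keep_prob = 0.9
mask_fwd = np.random.RandomState(seed).rand(4) < keep_prob  # original forward
mask_rc = np.random.RandomState(seed).rand(4) < keep_prob   # recomputed forward
assert (mask_fwd == mask_rc).all()  # identical masks -> identical outputs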
def _find_op_index(block, cur_op):
......@@ -242,7 +193,7 @@ def _find_op_index(block, cur_op):
return -1
def _get_stop_gradients(program, no_grad_set):
def _get_stop_gradients(program, no_grad_set=None):
"""get no grad var"""
if no_grad_set is None:
no_grad_set = set()
......@@ -260,16 +211,15 @@ def _get_stop_gradients(program, no_grad_set):
def _add_needed_descs_to_block(
descs, block, main_block, in_memory_vars, dist_context
descs, block, main_block, vars_should_be_hold, dist_context
):
"""
Get the descs of the recomputed ops that will be inserted into the backward part
"""
if len(descs) == 0:
return []
result_descs = []
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
backward = core.op_proto_and_checker_maker.OpRole.Backward
for desc in descs:
if isinstance(desc, framework.Operator):
desc = desc.desc
......@@ -279,22 +229,29 @@ def _add_needed_descs_to_block(
for name in desc.output_arg_names():
if main_block.has_var(name) and main_block.var(name).persistable:
continue
if name not in in_memory_vars:
if name not in vars_should_be_hold:
is_needed = True
if is_needed:
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(desc)
set_dist_op_desc_original_id(new_op_desc, desc, dist_context)
new_op_desc._set_attr(op_role_attr_name, backward)
new_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward)
result_descs.append(new_op_desc)
return result_descs
def _find_op_path(main_program, loss, no_grad_set=None):
no_grad_set_name = _get_stop_gradients(main_program, no_grad_set)
op_path = _find_op_path_(
main_program.global_block(), [loss], [], no_grad_set_name
)
return op_path
@register_pass("auto_parallel_recompute")
class RecomputePass(PassBase):
def __init__(self):
super().__init__()
self.set_attr("checkpoints", None)
self.set_attr("loss", None)
self.set_attr("dist_context", None)
self.set_attr("no_grad_set", None)
......@@ -311,49 +268,64 @@ class RecomputePass(PassBase):
return True
def _apply_single_impl(self, main_program, startup_program, context):
checkpoints = self.get_attr("checkpoints")
no_recompute_segments = self.get_attr("no_recompute_segments")
loss = self.get_attr("loss")
no_grad_set = self.get_attr("no_grad_set")
no_recompute_segments = self.get_attr("no_recompute_segments")
self._dist_context = self.get_attr("dist_context")
# 0. get op_path which is related to loss
main_block = main_program.global_block()
no_grad_set_name = _get_stop_gradients(main_program, no_grad_set)
op_path = _find_op_path_(main_block, [loss], [], no_grad_set_name)
op_path = _find_op_path(main_program, loss, no_grad_set)
# 1. build recompute state
rc_state = RecomputeState(main_block, op_path)
if not rc_state.is_recompute() and not checkpoints:
if not rc_state.is_recompute():
return
# 2. get the segments to be recomputed
rc_state.modify_forward_desc_for_recompute(self._dist_context)
rc_state.build_stats()
checkpoints = rc_state.sort_checkpoints(checkpoints or [])
segments, checkpoints = rc_state.get_recompute_segments(
checkpoints, no_recompute_segments
)
if segments == [] or checkpoints == []:
rc_state.build_states()
segments = rc_state.get_recompute_segments(no_recompute_segments)
if segments == []:
return
for i, (idx1, idx2) in enumerate(segments):
logging.info(
"recompute segment[{}/{}]".format(i + 1, len(segments))
)
logging.info(
"segment start op: [{}]: [{}] [{}]".format(
rc_state.ops[idx1].type,
rc_state.ops[idx1].input_arg_names,
rc_state.ops[idx1].output_arg_names,
)
)
logging.info(
"segment end op: [{}]: [{}] [{}]".format(
rc_state.ops[idx2 - 1].type,
rc_state.ops[idx2 - 1].input_arg_names,
rc_state.ops[idx2 - 1].output_arg_names,
)
)
# 3. get vars that should be held in memory
vars_should_be_hold = []
for segment in segments:
vars_should_be_hold.extend(
rc_state.get_out_of_subgraph_vars(segment[0], segment[1])
)
cross_vars = set(vars_should_be_hold) - set(checkpoints)
cross_vars = set(vars_should_be_hold) - set(rc_state.checkpoints)
logging.info(
"found [{}] vars which cross recompute segment: [{}],"
"better checkpoints might be set to reduce those vars".format(
len(cross_vars), cross_vars
)
)
vars_should_be_hold.extend(rc_state.get_reserved_vars())
vars_should_be_hold.extend(rc_state.reserved_vars)
vars_should_be_hold.extend(rc_state.get_input_nodes())
vars_should_be_hold = list(set(vars_should_be_hold))
vars_in_memory = vars_should_be_hold + checkpoints
vars_should_be_hold = list(
set(vars_should_be_hold) | set(rc_state.checkpoints)
)
# 4. get the fwd ops desc to be recomputed.
var_name_dict = {} # varname --> varname.subprog_XXX
......@@ -364,20 +336,23 @@ class RecomputePass(PassBase):
var_suffix = ".subprog_%d" % i
for op in fwd_ops:
input_and_output_names = []
input_and_output_names.extend(op.desc.input_arg_names())
input_and_output_names.extend(op.desc.output_arg_names())
input_and_output_names.extend(op.input_arg_names)
input_and_output_names.extend(op.output_arg_names)
cur_op_dist_attr = (
self._dist_context.get_op_dist_attr_for_program(op)
)
assert cur_op_dist_attr is not None
for name in input_and_output_names:
if main_block.var(name).persistable or name in checkpoints:
continue
if name in vars_should_be_hold:
if (
main_block.var(name).persistable
or name in vars_should_be_hold
):
continue
if name not in var_name_dict:
ref_process_mesh = cur_op_dist_attr.process_mesh
if name in op.desc.input_arg_names():
if name in op.input_arg_names:
ref_dims_mapping = (
cur_op_dist_attr.get_input_dims_mapping(name)
)
......@@ -385,6 +360,7 @@ class RecomputePass(PassBase):
ref_dims_mapping = (
cur_op_dist_attr.get_output_dims_mapping(name)
)
# record recomputed var's old_name and new_name (old_name.subprog_XXX)
# create new var with new name
var_name_dict[name] = name + var_suffix
......@@ -409,7 +385,7 @@ class RecomputePass(PassBase):
fwd_ops,
buffer_block,
main_block,
vars_in_memory,
vars_should_be_hold,
self._dist_context,
)
# rename recomputed ops' input and output var name
......@@ -437,15 +413,15 @@ class RecomputePass(PassBase):
grad_op._remove_attr("fix_seed")
grad_op._remove_attr("seed")
# rename the grad op's vars that were recorded in var_name_dict
for key in var_name_dict:
if (
key
not in grad_op.input_arg_names + grad_op.output_arg_names
):
input_and_output_names = []
input_and_output_names.extend(grad_op.input_arg_names)
input_and_output_names.extend(grad_op.output_arg_names)
for varname in var_name_dict:
if varname not in input_and_output_names:
continue
self.reset_op_dist_attr(grad_op, var_name_dict)
_rename_arg_([grad_op.desc], key, var_name_dict[key])
_rename_arg_([grad_op.desc], varname, var_name_dict[varname])
# insert recomputed ops
original_id = grad_op.desc.original_id()
......@@ -504,13 +480,13 @@ class RecomputePass(PassBase):
def reset_op_dist_attr(self, op, var_name_dict):
op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr is not None
for input in op.desc.input_arg_names():
for input in op.input_arg_names:
if input in var_name_dict.keys():
in_dist_attr = op_dist_attr.get_input_dist_attr(input)
op_dist_attr.set_input_dist_attr(
var_name_dict[input], in_dist_attr
)
for output in op.desc.output_arg_names():
for output in op.output_arg_names:
if output in var_name_dict.keys():
out_dist_attr = op_dist_attr.get_output_dist_attr(output)
op_dist_attr.set_output_dist_attr(
......
......@@ -74,6 +74,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120)
py_test_modules(test_selective_recompute MODULES test_selective_recompute)
set_tests_properties(test_selective_recompute PROPERTIES TIMEOUT 50)
py_test_modules(test_tuning_recompute MODULES test_tuning_recompute)
set_tests_properties(test_tuning_recompute PROPERTIES TIMEOUT 240)
py_test_modules(test_while_op_completion MODULES test_while_op_completion
ENVS ${dist_ENVS})
......
......@@ -28,12 +28,9 @@ from auto_parallel_gpt_model import (
GPTPretrainingCriterion,
)
sequence_len = 512
vocab_size = 1000
class FakeDataset(paddle.io.Dataset):
def __init__(self, num_samples):
def __init__(self, num_samples, vocab_size=1000, sequence_len=512):
self.num_samples = num_samples
self.sequence_len = sequence_len
self.vocab_size = vocab_size
......@@ -57,7 +54,7 @@ class FakeDataset(paddle.io.Dataset):
return self.num_samples
def create_data_holder(batch_size):
def create_data_holder(batch_size, vocab_size=1000, sequence_len=512):
tokens = paddle.static.InputSpec(
name="tokens", shape=[batch_size, sequence_len], dtype='int64'
)
......
......@@ -98,7 +98,7 @@ def train(fetch):
tuning.profile_start_step = 1
tuning.profile_end_step = 5
tuning.run_after_tuning = True
tuning.verbose = True
tuning.debug = True
dataset = MyDataset(batch_num * batch_size)
engine = auto.Engine(
......
......@@ -24,7 +24,7 @@ class TestStrategy(unittest.TestCase):
recompute = strategy.recompute
self.assertEqual(recompute.enable, False)
self.assertIsNone(recompute.checkpoints)
self.assertEqual(recompute.checkpoints, [])
amp = strategy.amp
self.assertEqual(amp.enable, False)
......@@ -66,12 +66,10 @@ class TestStrategy(unittest.TestCase):
tuning = strategy.tuning
self.assertEqual(tuning.enable, False)
self.assertEqual(tuning.batch_size, 1)
self.assertIsNone(tuning.dataset)
self.assertEqual(tuning.profile_start_step, 1)
self.assertEqual(tuning.profile_end_step, 1)
self.assertEqual(tuning.run_after_tuning, True)
self.assertEqual(tuning.verbose, True)
self.assertEqual(tuning.debug, False)
def test_modify_config(self):
strategy = auto.Strategy()
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
from get_gpt_model import FakeDataset
import paddle
from paddle.distributed.fleet import auto
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import (
GPTForPretraining,
GPTModel,
GPTPretrainingCriterion,
)
def generate_model():
modeling.init_global()
modeling._global_parallel_strategy = "serial"
gpt = GPTModel(
vocab_size=50304,
hidden_size=1024,
num_hidden_layers=14,
num_attention_heads=16,
intermediate_size=1024 * 4,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=1024,
type_vocab_size=1,
initializer_range=0.02,
pad_token_id=0,
eos_token_id=7,
bos_token_id=0,
eol_token_id=3,
use_new_recompute=True,
recompute_granularity="full",
)
model = GPTForPretraining(
gpt, vocab_size=50304, hidden_size=1024, initializer_range=0.02
)
criterion = GPTPretrainingCriterion()
return model, criterion
def apply_pass():
strategy = auto.Strategy()
strategy.auto_mode = "semi"
recompute = strategy.recompute
recompute.enable = True
recompute.enable_tuning = True
tuning = strategy.tuning
tuning.enable = True
tuning.profile_start_step = 1
tuning.profile_end_step = 2
tuning.run_after_tuning = True
tuning.debug = True
return strategy
class TestRecomputePassTuning(unittest.TestCase):
def setUp(self):
self.batch_size = 8
self.batch_num = 200
self.dataset = FakeDataset(
self.batch_size * self.batch_num,
vocab_size=50304,
sequence_len=1024,
)
def test_recompute_pass(self):
strategy = apply_pass()
clip = paddle.nn.ClipGradByGlobalNorm(0.2)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model()
engine = auto.Engine(model, loss, opt, strategy=strategy)
engine._tune(self.dataset, 3, batch_size=self.batch_size)
assert (
len(
engine._dist_contexts[
'train'
].strategy.recompute.no_recompute_segments
)
> 0
)
if __name__ == "__main__":
unittest.main()