Unverified commit 72f2ed43, authored by JZ-LIANG, committed via GitHub

[Auto parallel] Optimization Tuning (#43782)

* fixed bugs in pass & engine

* fixed bugs in the GPT-3 benchmark

* add tuner & profiler

* add algorithms & config
Parent da3743fd
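For context, the optimization tuning entry point added by this commit is exercised roughly as in the new unit test further below. A minimal sketch, assuming placeholder objects for model, dataset, input/label specs, loss and optimizer:

    import paddle
    from paddle.distributed import fleet
    from paddle.distributed.auto_parallel.engine import Engine

    # mark the sharding pass as tunable
    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.semi_auto = True
    dist_strategy.sharding = True
    dist_strategy.sharding_configs = {
        "sharding_degree": 2,
        "stage": 3,
        "enable_tuning": True,
    }
    fleet.init(is_collective=True, strategy=dist_strategy)

    # tuning options consumed by Engine._optimization_tuning()
    tuning_config = {
        "batch_size": 16,
        "dataset": dataset,                        # placeholder dataset
        "profile_start_step": 1,
        "profile_end_step": 5,
        "run_after_tuning": True,
        "sharding": {"stage_range": [0, 1, 2, 3]},
        "verbose": True,
    }

    engine = Engine(model,                         # placeholder model
                    inputs_spec=inputs_spec,
                    labels_spec=labels_spec,
                    strategy=dist_strategy,
                    user_tuning_config=tuning_config)
    engine.prepare(optimizer, loss)                # tuning runs during prepare()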
......@@ -27,6 +27,7 @@ message RecomputeConfig {
repeated string checkpoints = 1;
optional bool enable_offload = 2 [ default = false ];
repeated int32 checkpoint_shape = 3;
optional bool enable_tuning = 4 [ default = false ]; // incubate for auto parallel
}
message ShardingConfig {
......@@ -46,6 +47,7 @@ message ShardingConfig {
// Optimizer sharding. Temporary plans and may be deprecated
optional bool _dp_as_optimizer_sharding = 13 [ default = false ];
optional int32 stage = 14 [ default = 1 ];
optional bool enable_tuning = 15 [ default = false ]; // incubate for auto parallel
}
message HybridConfig {
......
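The recompute pass gains the same flag. A hedged sketch of marking recompute as tunable through the strategy configs (field names follow the proto above; the empty checkpoint list is only illustrative, since the engine fills checkpoints automatically for known models such as GPT-3):

    from paddle.distributed import fleet

    strategy = fleet.DistributedStrategy()
    strategy.recompute = True
    strategy.recompute_configs = {
        "checkpoints": [],
        "enable_tuning": True,   # new field: let the tuner decide whether recompute pays off
    }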
......@@ -58,7 +58,8 @@ class Engine:
inputs_spec=None,
labels_spec=None,
cluster=None,
strategy=None):
strategy=None,
user_tuning_config=None):
self.model = model
self.inputs_spec = self._validate_spec(inputs_spec)
self.labels_spec = self._validate_spec(labels_spec)
......@@ -68,6 +69,7 @@ class Engine:
self.strategy = strategy
if self.strategy is None:
self.strategy = fleet.DistributedStrategy()
self._user_tuning_config = user_tuning_config
self._executor = None
self._cur_rank = paddle.distributed.get_rank()
......@@ -127,19 +129,21 @@ class Engine:
self._prepare_single_mode("train")
def _prepare_single_mode(self, mode):
self._modes = [mode]
self._build(self._modes[0])
# Do auto parallel process
for mode in self._modes:
# Do the planning process
self._plan(mode)
for mode in self._modes:
# Do the parallel process
self._parallel(mode, self._all_ranks)
# Init comm and startup program
self._initialize(mode)
self._mode_init_states[mode] = True
self._build(mode)
# Do the planning process
self._plan(mode)
# Do the Optimization tuning
if self._user_tuning_config and mode == "train":
self._optimization_tuning(mode)
# Do the parallel process
self._parallel(mode, self._all_ranks)
# Init comm and startup program
self._initialize(mode)
self._mode_init_states[mode] = True
def _build(self, mode):
if _non_static_mode() or self._dygraph_mode:
......@@ -174,6 +178,7 @@ class Engine:
metrics = []
serial_main_prog = self._orig_main_prog.clone()
serial_startup_prog = self._orig_startup_prog.clone()
# FIXME to support grad clip
with static.program_guard(serial_main_prog, serial_startup_prog), \
utils.unique_name.guard():
inputs_spec = self.inputs_spec
......@@ -204,12 +209,41 @@ class Engine:
"metrics": metrics
}
self._set_recompute_ckpts()
self._dist_contexts[mode] = DistributedContext(
serial_main_prog, serial_startup_prog, self._optimizer, losses,
feed_vars, fetch_vars, self.cluster, self.strategy)
self._dist_contexts[mode].gradient_scale = self._gradient_scale
self._dist_contexts[mode]._dygraph_mode = self._dygraph_mode
def _optimization_tuning(self, mode):
self.mode = mode
assert "batch_size" in self._user_tuning_config, "Optimization Tuning should provide with batch size."
assert "dataset" in self._user_tuning_config, "Optimization Tuning should provide with dataset."
batch_size = self._user_tuning_config["batch_size"]
dataset = self._user_tuning_config["dataset"]
dataset.dp_world_size = self._dp_world_size
dataset.dp_rank = self._dp_rank
from .tuner.optimization_tuner import OptimizationTuner
self._optimization_tuner = OptimizationTuner(self._user_tuning_config,
self._dist_contexts[mode],
dataset,
self.inputs_spec,
self.labels_spec,
batch_size=batch_size,
rank=self._cur_rank)
self._optimization_tuner.tune()
if self._user_tuning_config["run_after_tuning"]:
# update the strategy
self._dist_contexts[
mode]._strategy = self._optimization_tuner.get_best_config()
else:
return
def _plan(self, mode):
if self._planned_mode is None:
self._planned_mode = mode
......@@ -219,6 +253,18 @@ class Engine:
self._planners[mode] = Planner(mode, self._dist_contexts[mode])
self._planners[mode].plan()
# infer data parallel info
inputs_var = self._dist_contexts[mode].serial_feed_vars["inputs"]
labels_var = self._dist_contexts[mode].serial_feed_vars["labels"]
block = self._dist_contexts[mode].serial_main_program.global_block()
feed_list = []
for var in inputs_var + labels_var:
if var.name in block.vars:
feed_list.append(block.vars[var.name])
self._dp_world_size, self._dp_rank = self._get_data_parallel_info(
feed_list[0], self._dist_contexts[mode])
def _parallel(self, mode, all_ranks):
# Parallelize program based on the planner's results
# For now, the completer has to be passed to the planner,
......@@ -317,6 +363,40 @@ class Engine:
prune_startup_prog = dist_startup_prog._prune(uninitialized)
self._executor.run(prune_startup_prog)
if self.strategy.amp and self.strategy.amp_configs['use_pure_fp16']:
# from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_parameters_to_fp16
def cast_parameters_to_fp16(place,
program,
scope=None,
to_fp16_var_names=None):
"""
Traverse all parameters in the model and cast them to the FP16 data type,
while keeping the parameters of batch norm layers in FP32.
Args:
place(fluid.CPUPlace|fluid.CUDAPlace): `place` is used to restore the FP16 weight tensors.
program (Program): The used program.
scope(fluid.Scope, optional): `scope` is used to get the FP32 weight tensor values.
Default is None.
to_fp16_var_names(set|list, optional): The data types of vars in `to_fp16_var_names`
will be set to FP16. Usually, it is the returned
value of `cast_model_to_fp16` API.
"""
from paddle.framework import core
import numpy as np
all_parameters = []
for block in program.blocks:
all_parameters.extend(block.all_parameters())
var_scope = scope if scope else paddle.static.global_scope()
for param in all_parameters:
if param.dtype == core.VarDesc.VarType.FP16:
param_t = var_scope.find_var(
param.name).get_tensor()
data = np.array(param_t)
param_t.set(np.float16(data), place)
cast_parameters_to_fp16(self._place, prune_startup_prog)
def fit(self,
train_data,
batch_size=1,
......@@ -342,7 +422,6 @@ class Engine:
usr_fetch = self._validate_fetches(fetches)
fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch)
for epoch in range(epochs):
train_logs = {"epoch": epoch}
for step, _ in enumerate(train_dataloader):
......@@ -457,8 +536,6 @@ class Engine:
for var in inputs_var + labels_var:
if var.name in dist_main_block.vars:
feed_list.append(dist_main_block.vars[var.name])
dp_world_size, dp_rank = self._get_data_parallel_info(
feed_list[0], dist_context)
# remove the first three ops if fit/evaluate/predict is run more than once
op_size = len(dist_main_block.ops)
......@@ -477,8 +554,8 @@ class Engine:
batch_size,
epochs,
steps_per_epoch,
data_parallel_world_size=dp_world_size,
data_parallel_rank=dp_rank)
data_parallel_world_size=self._dp_world_size,
data_parallel_rank=self._dp_rank)
# move read op from the end of program to the start of program
new_op_size = len(dist_main_block.ops)
......@@ -561,6 +638,32 @@ class Engine:
return None, None
def _set_recompute_ckpts(self):
# NOTE hack to enable recompute in engine api for GPT-3
# TODO support more PaddleNLP/CV models here
config = self.strategy.recompute_configs
# extract checkpoints for the specific model
if isinstance(self.model, paddle.nn.Layer):
if hasattr(
self.model, "model"
) and self.model.model.__class__.__name__ == 'GPTForPretraining':
exact_ckpts = self.model.model.gpt.checkpoints
else:
exact_ckpts = config["checkpoints"]
# modify strategy
if self.strategy.recompute:
config["checkpoints"] = exact_ckpts[:]
self.strategy.recompute_configs = config
logs = {
'Model Class': self.model.model.__class__.__name__,
'Applied Recompute ckpts': exact_ckpts
}
self._logger.info(logs)
def save(self, path, training=True, mode=None):
if not mode:
mode = self.mode
......
......@@ -48,7 +48,6 @@ from .mapper import mapping
from .dist_op import DistributedOperator
from .dist_tensor import DistributedTensor
from .planner import Planner
from paddle.distributed.passes import new_pass, PassContext
_logger = get_logger(logging.INFO)
......
......@@ -42,7 +42,13 @@ def get_world_process_group():
return _g_process_group_map[0]
def new_process_group(ranks):
def clear_all_process_groups():
global _g_process_group_map
_g_process_group_map = {}
_g_process_group_map[0] = ProcessGroup(0, [])
def new_process_group(ranks, group_id=None):
global _g_process_group_map
# A key constructed from ranks is used for avoiding duplication
new_key = ''.join(map(str, sorted(ranks)))
......@@ -54,7 +60,9 @@ def new_process_group(ranks):
num_groups = len(_g_process_group_map)
# Note: our process group may interfere with the original implementation
# so the created group id should start from the original _new_ring_id()
group_id = _new_ring_id() + num_groups + 1
if group_id is None:
group_id = _new_ring_id() + num_groups + 1
new_pg = ProcessGroup(group_id, ranks)
_g_process_group_map[group_id] = new_pg
return new_pg
......
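Together, clear_all_process_groups() and the new group_id argument let a profiling subprocess rebuild exactly the communication groups recorded by the parent tuner process, so that ring ids stay consistent. A minimal sketch of that reconstruction, assuming a group_map dict of {group_id: ranks} like the one consumed by init_process_groups in the profiler further below:

    from paddle.distributed.auto_parallel.process_group import (
        clear_all_process_groups, get_all_process_groups, new_process_group)

    def rebuild_process_groups(group_map, rank):
        # drop any previously created groups, then recreate them with the recorded ids
        clear_all_process_groups()
        for group_id, ranks in group_map.items():
            if group_id == 0:
                continue
            new_process_group(ranks=ranks, group_id=group_id)
        # only instantiate the groups this rank actually participates in
        for pg in get_all_process_groups():
            if pg.id != 0 and rank in pg.ranks:
                pg.instantiate()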
......@@ -11,3 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .profiler import profiler
__all__ = []
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from abc import ABC, abstractmethod
import logging
from paddle.distributed.utils import get_logger
from .trial import TrialStatus
from .trial import OptimizationTunerTrial as Trial
class AlgorithmBase(ABC):
"""
A tuning algorithm is a class that finds an optimal configuration
given the selected tuning optimization pass(es) and the arguments to be tuned.
Different optimization pass(es) correspond to different algorithms,
each with its own search space **pruning rules**.
In other words, the essence of the "algorithm" in this class is the set of
search space pruning rules specific to the given optimization scenario.
"""
_REGISTERED_ALGORITHMS = {}
name = None
@staticmethod
def _register(algo_name, algo_class):
assert issubclass(algo_class, AlgorithmBase)
AlgorithmBase._REGISTERED_ALGORITHMS[algo_name] = algo_class
def __init__(self, config):
self._config = config
self._init_spaces()
self._logger = get_logger(logging.INFO)
self._changed_configs = []
@property
def changed_configs(self):
return self._changed_configs[:]
def collect_model_info(self, main_prog, startup_prog):
"""
Collect static model info (from the programs) that can be used to
prune candidate trials and save tuning time. For instance,
model info such as the number of parameters and the activation memory
can be used to prune a candidate trial and decide the next trial.
"""
pass
@abstractmethod
def _init_spaces(self):
pass
@abstractmethod
def next_trial(self):
pass
@abstractmethod
def update(self, results):
"""
Update the algorithm with the results of the last trial. This information
is used to prune the search space for future trials.
"""
pass
def get_config_from_trial(self, trial):
"""
Return a new fleet.DistributedStrategy with the configurations in trial.
"""
assert len(self._changed_configs) > 0
new_strategy = copy.deepcopy(self._config.dist_strategy)
for name in self._changed_configs:
config = getattr(trial.space, name)
setattr(new_strategy, name, config)
return new_strategy
def register_algor(name):
def impl(cls):
AlgorithmBase._register(name, cls)
cls.name = name
return cls
return impl
def new_algorithm(name, config):
algor_class = AlgorithmBase._REGISTERED_ALGORITHMS.get(name)
assert algor_class is not None, "Algorithm {} is not defined.".format(name)
algor_obj = algor_class(config)
return algor_obj
@register_algor("sharding")
class ShardingStageAlgorithm(AlgorithmBase):
# TODO import trial class & copy strategy
def __init__(self, config):
super().__init__(config)
self._changed_configs = ["sharding_configs"]
def _init_spaces(self):
self._max_stage = 3
self._trial_idx = 0
stage_range = self._config.sharding_configs.get("stage_range", None)
if stage_range:
assert set(stage_range).issubset(
set([0, 1, 2, 3])
), "Sharding Stage should belong into range within 0 - 3 but got {}.".format(
stage_range)
stage_range.sort(reverse=True)
else:
stage_range = sorted(range(self._max_stage + 1), reverse=True)
self._stage_range = stage_range[:]
self._total_num_trial = len(self._stage_range)
def next_trial(self):
if self._trial_idx < self._total_num_trial:
stage = self._stage_range[self._trial_idx]
new_strategy = copy.deepcopy(self._config.dist_strategy)
config_dict = new_strategy.sharding_configs
config_dict["stage"] = stage
new_strategy.sharding_configs = config_dict
name = "trial-sharding-stage{}".format(stage)
trial = Trial(new_strategy, name, self.changed_configs)
return trial
else:
return Trial(None, None, None, status=TrialStatus.STOPPED)
def update(self, results):
et = results.get("ErrorType", None)
if et and et == "ResourceExhaustedError":
self._trial_idx = self._total_num_trial
self._logger.info(
"Last trial is failed with OOM, all remaining trials are pruned to save time !"
)
else:
self._trial_idx += 1
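For reference, an algorithm for another pass would be plugged in the same way as the sharding algorithm above: subclass AlgorithmBase, implement the three abstract hooks, and register it with register_algor. The class below is a hypothetical illustration (not part of this commit), reusing the imports already at the top of this module:

    @register_algor("recompute")
    class RecomputeCheckpointsAlgorithm(AlgorithmBase):
        # hypothetical sketch: first try all user checkpoints, then every other one

        def __init__(self, config):
            super().__init__(config)
            self._changed_configs = ["recompute_configs"]

        def _init_spaces(self):
            ckpts = self._config.recompute_configs.get("checkpoints", [])
            self._candidates = [ckpts, ckpts[::2]]
            self._trial_idx = 0

        def next_trial(self):
            if self._trial_idx < len(self._candidates):
                new_strategy = copy.deepcopy(self._config.dist_strategy)
                config_dict = new_strategy.recompute_configs
                config_dict["checkpoints"] = self._candidates[self._trial_idx]
                new_strategy.recompute_configs = config_dict
                name = "trial-recompute-ckpts-{}".format(self._trial_idx)
                return Trial(new_strategy, name, self.changed_configs)
            return Trial(None, None, None, status=TrialStatus.STOPPED)

        def update(self, results):
            # on OOM, prune the remaining candidates; otherwise advance to the next one
            if results.get("ErrorType") == "ResourceExhaustedError":
                self._trial_idx = len(self._candidates)
            else:
                self._trial_idx += 1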
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import pathlib
import paddle
from paddle.distributed import fleet
_tuning_supported_passes = ["sharding", "recompute"]
_strategy_config_suffix = "_configs"
def _get_pass_config(strategy, pass_name):
config_name = pass_name + _strategy_config_suffix
config = getattr(strategy, config_name)
return config
class TuningConfig(object):
"""
A uniform wrapper around two kinds of configuration:
distributed strategy: the user-defined configuration for the optimization passes;
tuning config: configuration of the tuning process itself, such as the mode (profile or cost model), the log directory, and extra per-pass tuning options like the search range.
"""
def __init__(self, user_config, strategy):
if not isinstance(strategy, fleet.DistributedStrategy):
raise TypeError(
"'strategy' must be object of class `fleet.DistributedStrategy`."
)
if not user_config:
user_config = {}
self._tuning_passes_name = set()
self._dist_strategy = copy.deepcopy(strategy)
self._mode = None
self._profile_start_step = None
self._profile_end_step = None
self._project_dir = None
self._max_num_trial = None
self._early_stop = None
self._verbose = None
self._initialize(user_config)
@property
def mode(self):
return self._mode
@property
def profile_start_step(self):
return self._profile_start_step
@property
def profile_end_step(self):
return self._profile_end_step
@property
def project_dir(self):
return self._project_dir
@property
def tuning_passes_name(self):
return self._tuning_passes_name
@property
def max_num_trial(self):
return self._max_num_trial
@property
def early_stop(self):
return self._early_stop
@property
def verbose(self):
return self._verbose
@property
def dist_strategy(self):
return self._dist_strategy
# initialize config with user-defined values or default values
def _initialize(self, user_config):
self._mode = user_config.get("mode", "PROFILE")
self._profile_start_step = user_config.get("profile_start_step", 10)
self._profile_end_step = user_config.get("profile_end_step", 30)
self._max_num_trial = user_config.get("max_num_trial", 50)
self._early_stop = user_config.get("early_stop", None)
self._verbose = user_config.get("verbose", False)
project_dir = user_config.get("project_dir", None)
if not project_dir:
project_dir = os.path.join(os.getcwd(), "OptimizationTuning")
self._project_dir = project_dir
for p in _tuning_supported_passes:
if getattr(self._dist_strategy, p) and _get_pass_config(
self._dist_strategy, p)["enable_tuning"]:
# TODO distinguish the different args of each pass
self._tuning_passes_name.add(p)
config_name = p + _strategy_config_suffix
p_dict = getattr(self._dist_strategy, config_name)
self.__dict__[config_name] = p_dict
# TODO verify the user defined configs
user_config_for_pass = user_config.get(p, None)
if user_config_for_pass:
for k, v in user_config_for_pass.items():
self.__dict__[config_name][k] = v
# NOTE: the tuning config ONLY wraps the dist strategy for the pass configs that are to be tuned
def __getattr__(self, item):
return getattr(self._dist_strategy, item)
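A minimal sketch of how TuningConfig merges user-supplied tuning options into the pass configs; the import path is assumed to be paddle.distributed.auto_parallel.tuner.config, and the values are illustrative:

    from paddle.distributed import fleet
    from paddle.distributed.auto_parallel.tuner.config import TuningConfig  # assumed module path

    strategy = fleet.DistributedStrategy()
    strategy.sharding = True
    strategy.sharding_configs = {"stage": 3, "enable_tuning": True}

    user_config = {
        "mode": "PROFILE",
        "max_num_trial": 8,
        "sharding": {"stage_range": [1, 2, 3]},   # merged into sharding_configs
    }

    cfg = TuningConfig(user_config, strategy)
    assert "sharding" in cfg.tuning_passes_name
    assert cfg.sharding_configs["stage_range"] == [1, 2, 3]
    assert cfg.max_num_trial == 8
    assert cfg.sharding  # unknown attributes fall through to the wrapped DistributedStrategy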
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse
import traceback
import pickle
import json
import time
import numpy as np
from functools import partial
import paddle
from paddle.fluid.framework import Program, _current_expected_place
from paddle.fluid.framework import Operator, Parameter
from paddle.distributed.auto_parallel.process_group import clear_all_process_groups, get_all_process_groups, new_process_group
from paddle.distributed.auto_parallel.dist_loader import NonIterableGeneratorLoader
from paddle.distributed.collective import _get_global_env
paddle.enable_static()
def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Unsupported value encountered.')
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--profile_start_step",
default=10,
type=int,
help="integer indicates the warmup step before starting profile.")
parser.add_argument("--profile_end_step",
default=30,
type=int,
help="integer indicates at the end step of profile.")
parser.add_argument("--rank",
type=int,
required=True,
help="the rank id of the this process.")
parser.add_argument("--device_id",
type=int,
required=True,
help="the device id of the this process.")
parser.add_argument(
"--ctx_filename",
type=str,
required=True,
help=
"the filename to the profile context file saved by optimizaiton tuner")
args = parser.parse_args()
return args
def init_process_groups(group_map, rank):
for group_id, ranks in group_map.items():
if group_id == 0:
continue
new_process_group(ranks=ranks, group_id=group_id)
# TODO should instantiate global group first
all_process_groups = get_all_process_groups()
for process_group in all_process_groups:
if process_group.id == 0 or rank not in process_group.ranks:
continue
print(process_group)
process_group.instantiate()
def get_cpp_error_type(error):
msg = str(error).splitlines()
cpp_error_types = [
'InvalidArgumentError',
'NotFoundError',
'OutOfRangeError',
'AlreadyExistsError',
'ResourceExhaustedError',
'PreconditionNotMetError',
'PermissionDeniedError',
'ExecutionTimeoutError',
'UnimplementedError',
'UnavailableError',
'FatalError',
'ExternalError',
]
error_type = 'FatalError'
for et in cpp_error_types:
for line in msg:
if et in line:
return et
return error_type
def create_dataloader(main_program,
startup_program,
profile_ctx,
epochs=1,
steps_per_epoch=None):
dataset = profile_ctx["dataset"]
main_block = main_program.global_block()
feed_list = []
for name in dataset.input_names:
if name in main_block.vars:
feed_list.append(main_block.vars[name])
# remove the first three ops if fit/evaluate/predict is run more than once
op_size = len(main_block.ops)
if main_block.ops[0].type == 'create_py_reader':
op_size -= 3
for _ in range(3):
main_block._remove_op(0, sync=False)
# insert read op at the end of program
places = paddle.static.cuda_places()
with paddle.static.program_guard(main_program, startup_program):
dataloader = NonIterableGeneratorLoader(
dataset,
feed_list,
places,
dataset.batch_size,
epochs,
steps_per_epoch,
data_parallel_world_size=dataset.dp_world_size,
data_parallel_rank=dataset.dp_rank)
# move read op from the end of program to the start of program
new_op_size = len(main_block.ops)
for _ in range(new_op_size - 1, op_size - 1, -1):
op = main_block.ops[new_op_size - 1]
new_op_desc = main_block.desc._prepend_op()
new_op_desc.copy_from(op.desc)
new_op = Operator(main_block, new_op_desc, type=new_op_desc.type())
main_block.ops.insert(0, new_op)
for _ in range(new_op_size - op_size):
main_block._remove_op(new_op_size, sync=False)
main_block._sync_with_cpp()
return dataloader
def init_comm(profile_ctx):
# override the env for current process
dist_env = profile_ctx['distributed_env']
genv = _get_global_env()
genv = dist_env
print("current process rank: {}, device_id: {}, ip: {}.", genv.rank,
genv.device_id, genv.current_endpoint)
# init nccl comm
group_map = profile_ctx['group_map']
init_process_groups(group_map, args.rank)
def load_programs(profile_ctx):
main_program_desc_str = profile_ctx['main_program_decs']
main_program = Program.parse_from_string(main_program_desc_str)
startup_program_decs_str = profile_ctx['startup_program_decs']
startup_program = Program.parse_from_string(startup_program_decs_str)
loss_var_name = profile_ctx["loss_var_name"]
assert main_program.global_block().has_var(loss_var_name)
loss_var = main_program.global_block().var(loss_var_name)
return main_program, startup_program, loss_var
def get_executor():
place_type = _current_expected_place()
if not isinstance(place_type, paddle.CUDAPlace):
raise RuntimeError("OptimizationTuner only support CUDA GPU right now.")
genv = _get_global_env()
place = paddle.CUDAPlace(genv.device_id)
exe = paddle.static.Executor(place)
return exe
def profiler(args):
"""
Main function that profiles one experiment for a given set of pass hyper-parameters.
"""
# load ctx
if not os.path.isfile(args.ctx_filename):
raise ValueError("There is no profile context named {}.".format(
args.ctx_filename))
with open(args.ctx_filename, 'rb') as f:
profile_ctx = pickle.load(f, encoding='latin1')
init_comm(profile_ctx)
main_program, startup_program, loss_var = load_programs(profile_ctx)
data_loader = create_dataloader(main_program, startup_program, profile_ctx)
result_path = profile_ctx["result_filename"]
exe = get_executor()
exe.run(startup_program)
# profile main
duration = 0
eval_step = 0
data_loader._inner_dataloader.start()
try:
while eval_step < args.profile_end_step:
start_time = time.time()
loss = exe.run(
main_program,
fetch_list=[loss_var],
use_program_cache=True,
)
end_time = time.time()
if eval_step >= args.profile_start_step:
duration += end_time - start_time
print("step: %d, loss_print: %f" % (eval_step, loss[0]))
eval_step += 1
avg_tput = 1.0 * (args.profile_end_step -
args.profile_start_step) / duration
result_dict = {
"Throughtput": avg_tput,
"ErrorType": None,
}
if paddle.distributed.get_rank() == 0:
with open(result_path, 'w') as fp:
json.dump(result_dict, fp)
print("profile done! avg speed : {} step / s.".format((avg_tput)))
except paddle.framework.core.EOFException:
data_loader._inner_dataloader.reset()
except Exception as e:
error_type = get_cpp_error_type(e)
result_dict = {
"Throughtput": -1,
"ErrorType": error_type,
}
if not os.path.isfile(result_path):
with open(result_path, 'w') as fp:
json.dump(result_dict, fp)
print("profile failed with error: [{}]".format(error_type))
print(e)
print(traceback.format_exc())
data_loader._inner_dataloader.reset()
del data_loader._inner_dataloader
exit(1)
data_loader._inner_dataloader.reset()
del data_loader._inner_dataloader
if __name__ == "__main__":
args = parse_args()
profiler(args)
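Each trial is measured in a separate process: the profile context (serialized programs, dataset, distributed env, and process-group map) is written to a file, and this script is launched once per rank to read it back and report throughput or the error type. A hedged sketch of such a launch, using only the arguments defined in parse_args above; the module path and the context file path are assumptions/placeholders:

    import subprocess
    import sys

    cmd = [
        sys.executable, "-u", "-m",
        "paddle.distributed.auto_parallel.tuner.profiler",   # assumed module path
        "--rank", "0",
        "--device_id", "0",
        "--ctx_filename", "/tmp/OptimizationTuning/trial-sharding-stage3/profile_ctx.pkl",  # placeholder
        "--profile_start_step", "10",
        "--profile_end_step", "30",
    ]
    subprocess.check_call(cmd)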
......@@ -115,6 +115,55 @@ class Trial(Storable):
return trial
class OptimizationTunerTrial(Trial):
def __init__(self,
config,
name,
changed_configs,
trial_id=None,
status=TrialStatus.RUNNING):
super(OptimizationTunerTrial, self).__init__(config, trial_id, status)
self._name = name
self._changed_configs = changed_configs
@property
def name(self):
return self._name
def summary(self):
spacing = 2
max_k = 38
max_v = 38
length = max_k + max_v + spacing
h1_format = " " + "|{{:^{}s}}|\n".format(length)
h2_format = " " + "|{{:>{}s}}{}{{:^{}s}}|\n".format(
max_k, " " * spacing, max_v)
border = " +" + "".join(["="] * length) + "+"
line = " +" + "".join(["-"] * length) + "+"
draws = border + "\n"
draws += h1_format.format("")
draws += h1_format.format("Tuned Configuartions Overview")
draws += h1_format.format("")
for name in self._changed_configs:
draws += border + "\n"
draws += h1_format.format("{} auto=True <-> {}".format(name, name))
draws += line + "\n"
my_configs = getattr(self.space, name)
keys = my_configs.keys()
for key in keys:
draws += h2_format.format(key, str(my_configs.get(key, None)))
result_res = draws + border
return result_res
def _generate_trial_id():
s = str(time.time()) + str(random.randint(1, int(1e7)))
return hashlib.sha256(s.encode("utf-8")).hexdigest()[:32]
......@@ -1473,3 +1473,11 @@ def to_list(value):
if isinstance(value, (list, tuple)):
return list(value)
return [value]
def debug_program(program, path, name):
filename = os.path.join(
path, name + '_program' + ".%d" % (paddle.distributed.get_rank()))
with open(filename, 'w') as f:
f.write(str(program))
......@@ -142,7 +142,6 @@ class AMPState(object):
modified from paddle.fluid.contrib.mixed_precision
"""
num_cast_ops = 0
var_name_dict = {}
for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
......
......@@ -245,10 +245,10 @@ class ShardingPass(PassBase):
})
dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
main_block.var(sum_op_output))
assert dist_attr is not None
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_op, dist_attr.process_mesh, dist_attr.dims_mapping,
self._dist_context)
# assert dist_attr is not None
# naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
# new_op, dist_attr.process_mesh, dist_attr.dims_mapping,
# self._dist_context)
break
main_block._sync_with_cpp()
......
......@@ -25,6 +25,11 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
set_tests_properties(test_engine_api_dp
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80)
py_test_modules(test_optimization_tuner_api MODULES
test_optimization_tuner_api ENVS ${dist_ENVS})
set_tests_properties(test_optimization_tuner_api
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80)
py_test_modules(test_converter MODULES test_converter ENVS ${dist_ENVS})
set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import time
import tempfile
import copy
import os
import numpy as np
import subprocess
import paddle
import paddle.nn as nn
import paddle.fluid as fluid
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.static import InputSpec
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
from engine_api_dp import MyDataset
paddle.enable_static()
batch_size = 16
batch_num = 5
hidden_size = 1024
sequence_len = 512
image_size = hidden_size
class_num = 10
paddle.seed(44)
# class MyDataset(Dataset):
# def __init__(self, num_samples):
# super(MyDataset, self).__init__()
# self.num_samples = num_samples
# def __getitem__(self, index):
# input = np.random.uniform(size=image_size).astype("float32")
# label = np.random.randint(0, class_num - 1, dtype="int64")
# return input, label
# def __len__(self):
# return self.num_samples
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
bias_attr = None
self.linear0 = nn.Linear(d_model,
dim_feedforward,
weight_attr,
bias_attr=bias_attr)
self.linear1 = nn.Linear(dim_feedforward,
d_model,
weight_attr,
bias_attr=bias_attr)
self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
out = self.dropout(out)
out = self.linear2(out)
self.out = out
return out
def train(fetch):
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
labels_spec = InputSpec([batch_size], 'int64', 'label')
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = False
dist_strategy.pipeline = False
dist_strategy.recompute = False
# init parallel optimizer
dist_strategy.semi_auto = True
dist_strategy.sharding = True
dist_strategy.sharding_configs = {
"sharding_degree": 2,
"stage": 3,
"enable_tuning": True,
}
fleet.init(is_collective=True, strategy=dist_strategy)
# init engine
import tempfile
tmp_dir = tempfile.TemporaryDirectory()
dataset = MyDataset(batch_num * batch_size)
# Tuning configuration
tuning_config = {
"batch_size": batch_size,
"dataset": dataset,
"profile_start_step": 1,
"profile_end_step": 5,
"run_after_tuning": True,
"sharding": {
"stage_range": [0, 1, 2, 3]
},
"verbose": True,
}
engine = Engine(mlp,
inputs_spec=inputs_spec,
labels_spec=labels_spec,
strategy=dist_strategy,
user_tuning_config=tuning_config)
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
# check tuned
assert (engine._dist_contexts['train'].strategy.sharding_configs['stage'] !=
3)
if __name__ == "__main__":
train(True)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import os
import sys
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class TestOptimizationTunerAPI(unittest.TestCase):
def test_engine_api(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "optimization_tuner_api.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "launch", "--gpus", "0,1", "--log_dir", tmp_dir.name,
launch_model_path
]
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
shutil.rmtree('./OptimizationTuning', ignore_errors=True)
if __name__ == "__main__":
unittest.main()