Unverified commit 50db9490, authored by ceci3, committed by GitHub

optimization train config (#1187)

Parent 8c6e3ab9
@@ -46,12 +46,13 @@ default_hpo_config = {
 # default quant config, can be used by ptq&hpo and qat&distillation
 default_quant_config = {
-    'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul', 'matmul'],
+    'quantize_op_types':
+    ['conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2'],
     'weight_bits': 8,
     'activation_bits': 8,
     "is_full_quantize": False,
-    "activation_quantize_type": 'range_abs_max',
-    "weight_quantize_type": 'abs_max',
+    "activation_quantize_type": 'moving_average_abs_max',
+    "weight_quantize_type": 'channel_wise_abs_max',
     "not_quant_pattern": ["skip_quant"],
 }
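For illustration only (not part of this commit): user-side settings can be layered over these defaults with a plain dict merge; `my_quant_config` is a hypothetical name.

    # hypothetical override built on top of default_quant_config above
    my_quant_config = dict(default_quant_config, weight_bits=4)
    # unchanged keys keep their new defaults, e.g. weights stay 'channel_wise_abs_max'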
@@ -60,10 +61,12 @@ DefaultTrainConfig = {
     "epochs": 1,
     "eval_iter": 500,
     "learning_rate": 0.0001,
-    "optimizer": "Momentum",
-    "optim_args": {
+    "optimizer_builder": {
+        "optimizer": {
+            "type": "Momentum",
+        },
         "weight_decay": 4.0e-05
     },
 }
 }
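For comparison (editor's sketch derived from this hunk): the optimizer settings move from two flat keys into one nested block, so a train config that used to carry

    {"optimizer": "Momentum", "optim_args": {"weight_decay": 4.0e-05}}

now expresses the same information as

    {"optimizer_builder": {"optimizer": {"type": "Momentum"}, "weight_decay": 4.0e-05}}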
 EXPERIENCE_STRATEGY_WITHOUT_LOSS = [
...
@@ -16,6 +16,7 @@ import logging
 import os
 import sys
 import numpy as np
+import copy
 import inspect
 import shutil
 from time import gmtime, strftime
@@ -28,7 +29,7 @@ from ..common import get_logger
 from ..common.patterns import get_patterns
 from ..analysis import TableLatencyPredictor
 from .create_compressed_program import build_distill_program, build_quant_program, build_prune_program, remove_unused_var_nodes
-from .strategy_config import ProgramInfo, merge_config
+from .strategy_config import TrainConfig, ProgramInfo, merge_config
 from .auto_strategy import prepare_strategy, get_final_quant_config, create_strategy_config, create_train_config
 from .utils.predict import with_variable_shape
@@ -127,7 +128,6 @@ class AutoCompression:
         if not os.path.exists(self.final_dir):
             os.makedirs(self.final_dir)
         self.strategy_config = strategy_config
-        self.train_config = train_config
         self.train_dataloader = train_dataloader
         self.target_speedup = target_speedup
         self.eval_function = eval_callback
@@ -142,7 +142,7 @@ class AutoCompression:
         self.model_type = self._get_model_type(self._exe, model_dir,
                                                model_filename, params_filename)
-        if self.train_config is not None and self.train_config.use_fleet:
+        if train_config is not None and train_config.use_fleet:
             fleet.init(is_collective=True)
         if with_variable_shape(
@@ -173,10 +173,48 @@ class AutoCompression:
         self._strategy, self._config = self._prepare_strategy(
             self.strategy_config)
+        self.train_config = self._get_final_train_config(
+            train_config, self._strategy, self.model_type)

+    def _get_final_train_config(self, train_config, strategy_config,
+                                model_type):
         # If train_config is None, set the default train config
-        if self.train_config is None:
-            self.train_config = create_train_config(self.strategy_config,
-                                                    self.model_type)
+        if train_config is None:
+            train_config = create_train_config(strategy_config, model_type)
+
+        train_configs = [train_config]
+        for idx in range(1, len(self._strategy)):
+            if 'qat' in self._strategy[idx]:
+                ### If more than one compression strategy is set, the train
+                ### config in the YAML applies to pruning; the train config
+                ### for quantization is derived from it.
+                tmp_train_config = copy.deepcopy(train_config.__dict__)
+                ### The epochs, train_iter and learning rate of the quant
+                ### stage are 10% of those of the pruning stage.
+                tmp_train_config['epochs'] = max(
+                    int(train_config.epochs * 0.1), 1)
+                if train_config.train_iter is not None:
+                    tmp_train_config['train_iter'] = int(
+                        train_config.train_iter * 0.1)
+                if isinstance(train_config.learning_rate, float):
+                    tmp_train_config[
+                        'learning_rate'] = train_config.learning_rate * 0.1
+                else:
+                    if 'learning_rate' in train_config.learning_rate:
+                        tmp_train_config['learning_rate'][
+                            'learning_rate'] = train_config.learning_rate[
+                                'learning_rate'] * 0.1
+                    else:  ### the learning rate decay is PiecewiseDecay
+                        tmp_train_config['learning_rate']['values'] = list(
+                            map(lambda x: x * 0.1, train_config.learning_rate[
+                                'values']))
+                train_cfg = TrainConfig(**tmp_train_config)
+            elif 'ptq' in self._strategy[idx]:
+                train_cfg = None
+            else:
+                tmp_train_config = copy.deepcopy(train_config.__dict__)
+                train_cfg = TrainConfig(**tmp_train_config)
+            train_configs.append(train_cfg)
+        return train_configs
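A worked example of the derivation above (values assumed): for a plan such as ['channel_prune_dis', 'qat_dis'] with a YAML train config of epochs=10, train_iter=20000 and learning_rate=0.01, the derived quant-stage TrainConfig gets

    # epochs:        max(int(10 * 0.1), 1)  -> 1
    # train_iter:    int(20000 * 0.1)       -> 2000
    # learning_rate: 0.01 * 0.1             -> 0.001
    # a PiecewiseDecay schedule instead has every entry of its 'values'
    # list scaled, e.g. [0.01, 0.001] -> [0.001, 0.0001]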
     def _infer_shape(self, model_dir, model_filename, params_filename,
                      input_shapes, save_path):
@@ -285,42 +323,50 @@ class AutoCompression:
             single_teacher_distill_config is not None else \
             multi_teacher_distill_config

-        ### case1: quant_config & hpo_config ==> PTQ & HPO
-        if quant_config is not None and hpo_config is not None:
-            strategy.append('ptq_hpo')
-            config.append(merge_config(quant_config, hpo_config))
-        ### case2: quant_config & distill config ==> QAT & Distill
-        elif quant_config is not None and self._distill_config is not None:
-            strategy.append('qat_dis')
-            config.append(merge_config(quant_config, self._distill_config))
+        only_distillation = True

-        ### case3: prune_config & distill config
-        elif prune_config is not None and self._distill_config is not None:
+        ### case1: prune_config & distill config
+        if prune_config is not None and self._distill_config is not None:
+            only_distillation = False
             strategy.append('channel_prune_dis')
             config.append(merge_config(prune_config, self._distill_config))

-        ### case4: asp_config & distill config
-        elif asp_config is not None and self._distill_config is not None:
+        ### case2: asp_config & distill config
+        if asp_config is not None and self._distill_config is not None:
+            only_distillation = False
             strategy.append('asp_prune_dis')
             config.append(merge_config(asp_config, self._distill_config))

-        ### case5: transformer_prune_config & distill config
-        elif transformer_prune_config is not None and self._distill_config is not None:
+        ### case3: transformer_prune_config & distill config
+        if transformer_prune_config is not None and self._distill_config is not None:
+            only_distillation = False
             strategy.append('transformer_prune_dis')
             config.append(
                 merge_config(transformer_prune_config,
                              self._distill_config))

-        ### case6: unstructure_config & distill config
-        elif unstructure_prune_config is not None and self._distill_config is not None:
+        ### case4: unstructure_config & distill config
+        if unstructure_prune_config is not None and self._distill_config is not None:
+            only_distillation = False
             strategy.append('unstructure_prune_dis')
             config.append(
                 merge_config(unstructure_prune_config,
                              self._distill_config))

+        ### case5: quant_config & hpo_config ==> PTQ & HPO
+        if quant_config is not None and hpo_config is not None:
+            only_distillation = False
+            strategy.append('ptq_hpo')
+            config.append(merge_config(quant_config, hpo_config))

+        ### case6: quant_config & distill config ==> QAT & Distill
+        if quant_config is not None and self._distill_config is not None:
+            only_distillation = False
+            strategy.append('qat_dis')
+            config.append(merge_config(quant_config, self._distill_config))

         ### case7: distill_config
-        elif self._distill_config is not None:
+        if only_distillation == True and self._distill_config is not None:
             if single_teacher_distill_config is not None:
                 strategy.append('single_teacher_dis')
                 config.append(single_teacher_distill_config)
@@ -328,11 +374,18 @@ class AutoCompression:
             strategy.append('multi_teacher_dis')
             config.append(multi_teacher_distill_config)

         ### case N: todo
         else:
             raise NotImplementedError(
                 "Setting {} at the same time is not implemented yet".format(
                     strategy_c.keys()))

+        ### NOTE: keep quantization as the last step
+        idx = -1
+        if 'qat_dis' in strategy and strategy.index('qat_dis') != (
+                len(strategy) - 1):
+            idx = strategy.index('qat_dis')
+        elif 'ptq_hpo' in strategy and strategy.index('ptq_hpo') != (
+                len(strategy) - 1):
+            idx = strategy.index('ptq_hpo')
+        if idx != -1:
+            strategy = strategy[:idx] + strategy[idx + 1:] + [strategy[idx]]
+            config = config[:idx] + config[idx + 1:] + [config[idx]]

         return strategy, config
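For illustration (hypothetical input): should an earlier step ever emit quantization first, the rotation above moves it to the end and keeps the config list aligned:

    strategy = ['qat_dis', 'channel_prune_dis']   # idx = 0
    strategy[:0] + strategy[1:] + [strategy[0]]
    # -> ['channel_prune_dis', 'qat_dis']; config is rotated the same way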
@@ -356,7 +409,8 @@ class AutoCompression:
         return strategy

     def _prepare_program(self, program, feed_target_names, fetch_targets,
-                         patterns, default_distill_node_pair, strategy, config):
+                         patterns, default_distill_node_pair, strategy, config,
+                         train_config):
         train_program = recover_inference_program(program)
         startup_program = paddle.static.Program()
         train_program_info = ProgramInfo(startup_program, train_program,
@@ -369,11 +423,11 @@ class AutoCompression:
         _logger.info(
             "Calculating the iterations per epoch... (It will take some time)")
         # NOTE:XXX: This way of calculating the iters needs to be improved.
-        if self.train_config.epochs:
+        if train_config.epochs:
             iters_per_epoch = len(list(self.train_dataloader()))
-            total_iters = self.train_config.epochs * iters_per_epoch
-        elif self.train_config.train_iter:
-            total_iters = self.train_config.train_iter
+            total_iters = train_config.epochs * iters_per_epoch
+        elif train_config.train_iter:
+            total_iters = train_config.train_iter
         else:
             raise RuntimeError(
                 'train_config must have an `epochs` or `train_iter` field.')
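A worked example (numbers assumed):

    # epochs=2, dataloader yields 500 batches -> total_iters = 2 * 500 = 1000
    # epochs unset, train_iter=3000           -> total_iters = 3000
    # neither field set                       -> the RuntimeError above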
@@ -392,8 +446,8 @@ class AutoCompression:
             self._exe, self._places, config_dict, train_program_info,
             strategy, patterns, self.eval_dataloader)

-        if self.train_config.use_fleet:
-            dist_strategy = _prepare_fleet_strategy(self.train_config)
+        if train_config.use_fleet:
+            dist_strategy = _prepare_fleet_strategy(train_config)
         else:
             dist_strategy = None
@@ -403,7 +457,7 @@ class AutoCompression:
                 self._exe,
                 self._places,
                 config_dict,
-                self.train_config.__dict__,
+                train_config.__dict__,
                 train_program_info,
                 pruner=self._pruner,
                 dist_strategy=dist_strategy,
@@ -415,7 +469,7 @@ class AutoCompression:
             train_program_info, test_program_info, self._quant_config = build_quant_program(
                 self._exe, self._places, config_dict, train_program_info,
                 test_program_info)
-        if self.train_config.sparse_model:
+        if train_config.sparse_model:
             from ..prune.unstructured_pruner import UnstructuredPruner
             # NOTE: The initialization parameters of this pruner do not take effect; it is only used to call the 'set_static_masks' function.
             self._pruner = UnstructuredPruner(
@@ -428,10 +482,10 @@ class AutoCompression:
         self._exe.run(train_program_info.startup_program)

-        if (not self.train_config.use_fleet
-            ) and self.train_config.amp_config is not None:
-            if hasattr(self.train_config.amp_config, 'use_pure_fp16'
-                       ) and self.train_config.amp_config.use_pure_fp16:
+        if (not train_config.use_fleet) and train_config.amp_config is not None:
+            if hasattr(
+                    train_config.amp_config,
+                    'use_pure_fp16') and train_config.amp_config.use_pure_fp16:
                 train_program_info.optimizer.amp_init(
                     self._places, scope=paddle.static.global_scope())
@@ -439,7 +493,7 @@ class AutoCompression:
             ### prune weight in scope
             self._pruner.prune_model(train_program_info.program)

-        if not self.train_config.use_fleet:
+        if not train_config.use_fleet:
             train_program_info = self._compiled_program(train_program_info,
                                                         strategy)
             test_program_info = self._compiled_program(test_program_info,
@@ -475,9 +529,10 @@ class AutoCompression:
     def compress(self):
         self.tmp_dir = self.create_tmp_dir(self.final_dir)
         for strategy_idx, (
-                strategy,
-                config) in enumerate(zip(self._strategy, self._config)):
-            self.single_strategy_compress(strategy, config, strategy_idx)
+                strategy, config, train_config
+        ) in enumerate(zip(self._strategy, self._config, self.train_config)):
+            self.single_strategy_compress(strategy, config, strategy_idx,
+                                          train_config)
             if strategy == 'ptq_hpo' and config.max_quant_count == 1 and platform.system(
             ).lower() == 'linux':
@@ -488,7 +543,8 @@ class AutoCompression:
                 quant_strategy, quant_config = self._prepare_strategy(
                     final_quant_config)
                 self.single_strategy_compress(quant_strategy[0],
-                                              quant_config[0], strategy_idx)
+                                              quant_config[0], strategy_idx,
+                                              train_config)
         tmp_model_path = os.path.join(
             self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1)))
         final_model_path = os.path.join(self.final_dir)
@@ -507,7 +563,8 @@ class AutoCompression:
                    format(final_model_path))
             os._exit(0)

-    def single_strategy_compress(self, strategy, config, strategy_idx):
+    def single_strategy_compress(self, strategy, config, strategy_idx,
+                                 train_config):
         # start compress, including train/eval model
         # TODO: add the emd loss of evaluation model.
         if strategy == 'quant_post':
@@ -581,19 +638,19 @@ class AutoCompression:
             ### used to check whether the dataloader is right
             self.metric_before_compressed = None
-            if self.eval_function is not None and self.train_config.origin_metric is not None:
+            if self.eval_function is not None and train_config.origin_metric is not None:
                 _logger.info("start to test the metric before compression")
                 metric = self.eval_function(self._exe, inference_program,
                                             feed_target_names, fetch_targets)
                 _logger.info("metric of the pretrained model is: {}".format(
                     metric))
                 buf = 0.05
-                if metric < (float(self.train_config.origin_metric) - buf) or \
-                        metric > (float(self.train_config.origin_metric) + buf):
+                if metric < (float(train_config.origin_metric) - buf) or \
+                        metric > (float(train_config.origin_metric) + buf):
                     raise RuntimeError(
                         "The target metric of the pretrained model is {}, but it "
                         "is now {}. Please check the format of the evaluation "
                         "dataset or the origin_metric in train_config".format(
-                            self.train_config.origin_metric, metric))
+                            train_config.origin_metric, metric))
                 self.metric_before_compressed = metric
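A worked check (metric value assumed): with origin_metric=0.765 and the fixed buf=0.05, the accepted range is

    # 0.765 - 0.05 <= metric <= 0.765 + 0.05, i.e. [0.715, 0.815]

and anything outside it raises the RuntimeError above, which usually points at a wrongly formatted evaluation dataset or a stale origin_metric.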
             patterns, default_distill_node_pair, _ = get_patterns(
@@ -601,15 +658,16 @@ class AutoCompression:
         train_program_info, test_program_info = self._prepare_program(
             inference_program, feed_target_names, fetch_targets, patterns,
-            default_distill_node_pair, strategy, config)
+            default_distill_node_pair, strategy, config, train_config)
         if 'unstructure' in self._strategy:
             test_program_info.program._program = remove_unused_var_nodes(
                 test_program_info.program._program)
-        test_program_info = self._start_train(train_program_info,
-                                              test_program_info, strategy)
+        test_program_info = self._start_train(
+            train_program_info, test_program_info, strategy, train_config)
         self._save_model(test_program_info, strategy, strategy_idx)

-    def _start_train(self, train_program_info, test_program_info, strategy):
+    def _start_train(self, train_program_info, test_program_info, strategy,
+                     train_config):
         best_metric = -1.0
         total_epochs = self.train_config.epochs if self.train_config.epochs else 100
         total_train_iter = 0
@@ -623,10 +681,10 @@ class AutoCompression:
                 if 'unstructure' in strategy:
                     self._pruner.step()

-                if self.train_config.logging_iter is None:
+                if train_config.logging_iter is None:
                     logging_iter = 10
                 else:
-                    logging_iter = self.train_config.logging_iter
+                    logging_iter = train_config.logging_iter
                 if batch_id % int(logging_iter) == 0:
                     _logger.info(
                         "Total iter: {}, epoch: {}, batch: {}, loss: {}".format(
@@ -661,8 +719,8 @@ class AutoCompression:
                                 self.metric_before_compressed)
                         ) / self.metric_before_compressed <= 0.005:
                             break
-                    if self.train_config.target_metric is not None:
-                        if metric > float(self.train_config.target_metric):
+                    if train_config.target_metric is not None:
+                        if metric > float(train_config.target_metric):
                             break

                 else:
@@ -672,7 +730,7 @@ class AutoCompression:
                 if self.train_config.train_iter and total_train_iter >= self.train_config.train_iter:
                     break

-        if 'unstructure' in self._strategy or self.train_config.sparse_model:
+        if 'unstructure' in self._strategy or train_config.sparse_model:
             self._pruner.update_params()
         return test_program_info
...
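A usage sketch (argument names inferred from this diff; the full constructor signature is not shown here): callers still pass a single train config, and after this commit it is expanded internally into one TrainConfig per strategy, with the quantization stage receiving the scaled-down copy.

    ac = AutoCompression(..., strategy_config=strategy_config,
                         train_config=train_config,
                         train_dataloader=train_dataloader)
    ac.compress()  # each strategy runs with its own derived train config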