Unverified commit 50db9490, authored by ceci3, committed by GitHub

optimization train config (#1187)

Parent 8c6e3ab9
@@ -46,12 +46,13 @@ default_hpo_config = {
 # default quant config, can be used by ptq&hpo and qat&distillation
 default_quant_config = {
-    'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul', 'matmul'],
+    'quantize_op_types':
+    ['conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2'],
     'weight_bits': 8,
     'activation_bits': 8,
     "is_full_quantize": False,
-    "activation_quantize_type": 'range_abs_max',
-    "weight_quantize_type": 'abs_max',
+    "activation_quantize_type": 'moving_average_abs_max',
+    "weight_quantize_type": 'channel_wise_abs_max',
     "not_quant_pattern": ["skip_quant"],
 }
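For reference, the new defaults switch activations to 'moving_average_abs_max', switch weights to per-channel 'channel_wise_abs_max', and additionally quantize 'matmul_v2' ops. A minimal sketch of overriding these defaults from user code, assuming plain dict manipulation (real configs normally go through this module's strategy_config helpers):

    import copy

    # Start from the defaults above and relax selected fields; the keys
    # mirror default_quant_config in this diff, nothing else is assumed.
    custom_quant_config = copy.deepcopy(default_quant_config)
    custom_quant_config['weight_quantize_type'] = 'abs_max'  # back to per-tensor
    custom_quant_config['quantize_op_types'].remove('matmul_v2')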
@@ -60,10 +61,12 @@ DefaultTrainConfig = {
     "epochs": 1,
     "eval_iter": 500,
     "learning_rate": 0.0001,
-    "optimizer": "Momentum",
-    "optim_args": {
-        "weight_decay": 4.0e-05
-    },
+    "optimizer_builder": {
+        "optimizer": {
+            "type": "Momentum",
+        },
+        "weight_decay": 4.0e-05
+    }
 }
 EXPERIENCE_STRATEGY_WITHOUT_LOSS = [
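The optimizer settings move from flat keys into a nested optimizer_builder block. A before/after sketch of the user-facing structure as plain Python dicts (assuming the yaml form maps onto these keys one-to-one):

    # Before this commit: flat optimizer keys.
    old_style = {
        "optimizer": "Momentum",
        "optim_args": {"weight_decay": 4.0e-05},
    }

    # After this commit: nested under optimizer_builder, matching DefaultTrainConfig.
    new_style = {
        "optimizer_builder": {
            "optimizer": {"type": "Momentum"},
            "weight_decay": 4.0e-05,
        }
    }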
...
@@ -16,6 +16,7 @@ import logging
 import os
 import sys
 import numpy as np
+import copy
 import inspect
 import shutil
 from time import gmtime, strftime
@@ -28,7 +29,7 @@ from ..common import get_logger
 from ..common.patterns import get_patterns
 from ..analysis import TableLatencyPredictor
 from .create_compressed_program import build_distill_program, build_quant_program, build_prune_program, remove_unused_var_nodes
-from .strategy_config import ProgramInfo, merge_config
+from .strategy_config import TrainConfig, ProgramInfo, merge_config
 from .auto_strategy import prepare_strategy, get_final_quant_config, create_strategy_config, create_train_config
 from .utils.predict import with_variable_shape
@@ -127,7 +128,6 @@ class AutoCompression:
         if not os.path.exists(self.final_dir):
             os.makedirs(self.final_dir)
         self.strategy_config = strategy_config
-        self.train_config = train_config
         self.train_dataloader = train_dataloader
         self.target_speedup = target_speedup
         self.eval_function = eval_callback
@@ -142,7 +142,7 @@ class AutoCompression:
         self.model_type = self._get_model_type(self._exe, model_dir,
                                                model_filename, params_filename)

-        if self.train_config is not None and self.train_config.use_fleet:
+        if train_config is not None and train_config.use_fleet:
             fleet.init(is_collective=True)
         if with_variable_shape(
@@ -173,10 +173,48 @@ class AutoCompression:
         self._strategy, self._config = self._prepare_strategy(
             self.strategy_config)
-
-        # If train_config is None, set default train_config
-        if self.train_config is None:
-            self.train_config = create_train_config(self.strategy_config,
-                                                    self.model_type)
+        self.train_config = self._get_final_train_config(
+            train_config, self._strategy, self.model_type)
+
+    def _get_final_train_config(self, train_config, strategy_config,
+                                model_type):
+        # If train_config is None, set default train_config
+        if train_config is None:
+            train_config = create_train_config(strategy_config, model_type)
+
+        train_configs = [train_config]
+        for idx in range(1, len(self._strategy)):
+            if 'qat' in self._strategy[idx]:
+                ### If more than one compress strategy is set, the train config
+                ### in the yaml applies to the prune step; the train config for
+                ### quantization is extrapolated from it.
+                tmp_train_config = copy.deepcopy(train_config.__dict__)
+                ### The epochs, train_iter and learning rate of the quant step
+                ### are 10% of the prune step's.
+                tmp_train_config['epochs'] = max(
+                    int(train_config.epochs * 0.1), 1)
+                if train_config.train_iter is not None:
+                    tmp_train_config['train_iter'] = int(
+                        train_config.train_iter * 0.1)
+                if isinstance(train_config.learning_rate, float):
+                    tmp_train_config[
+                        'learning_rate'] = train_config.learning_rate * 0.1
+                else:
+                    if 'learning_rate' in train_config.learning_rate:
+                        tmp_train_config['learning_rate'][
+                            'learning_rate'] = train_config.learning_rate[
+                                'learning_rate'] * 0.1
+                    else:  ### learning rate decay is PiecewiseDecay
+                        tmp_train_config['learning_rate']['values'] = list(
+                            map(lambda x: x * 0.1, train_config.learning_rate[
+                                'values']))
+                train_cfg = TrainConfig(**tmp_train_config)
+            elif 'ptq' in self._strategy[idx]:
+                train_cfg = None
+            else:
+                tmp_train_config = copy.deepcopy(train_config.__dict__)
+                train_cfg = TrainConfig(**tmp_train_config)
+            train_configs.append(train_cfg)
+        return train_configs

     def _infer_shape(self, model_dir, model_filename, params_filename,
                      input_shapes, save_path):
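A worked example of the derivation above, assuming a two-step strategy where the yaml train config drives the prune step; the qat step receives 10% of the epochs (at least 1), train_iter and learning rate, and a piecewise schedule is scaled element-wise over its values:

    # Hypothetical inputs: strategies ['channel_prune_dis', 'qat_dis'] and a
    # yaml train config with epochs=20, learning_rate=0.001. The method returns:
    #   train_configs[0]: epochs=20, learning_rate=0.001   (prune step, as given)
    #   train_configs[1]: epochs=2,  learning_rate=0.0001  (qat step, 10% of each)
    # For a piecewise schedule only the values are scaled ('boundaries' is illustrative):
    lr = {"boundaries": [100], "values": [0.01, 0.001]}
    scaled_values = [v * 0.1 for v in lr["values"]]  # -> [0.001, 0.0001]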
@@ -285,42 +323,50 @@ class AutoCompression:
             single_teacher_distill_config is not None else \
             multi_teacher_distill_config

-        ### case1: quant_config & hpo_config ==> PTQ & HPO
-        if quant_config is not None and hpo_config is not None:
-            strategy.append('ptq_hpo')
-            config.append(merge_config(quant_config, hpo_config))
-
-        ### case2: quant_config & distill config ==> QAT & Distill
-        elif quant_config is not None and self._distill_config is not None:
-            strategy.append('qat_dis')
-            config.append(merge_config(quant_config, self._distill_config))
-
-        ### case3: prune_config & distill config
-        elif prune_config is not None and self._distill_config is not None:
+        only_distillation = True
+
+        ### case1: prune_config & distill config
+        if prune_config is not None and self._distill_config is not None:
+            only_distillation = False
             strategy.append('channel_prune_dis')
             config.append(merge_config(prune_config, self._distill_config))

-        ### case4: asp_config & distill config
-        elif asp_config is not None and self._distill_config is not None:
+        ### case2: asp_config & distill config
+        if asp_config is not None and self._distill_config is not None:
+            only_distillation = False
             strategy.append('asp_prune_dis')
             config.append(merge_config(asp_config, self._distill_config))

-        ### case5: transformer_prune_config & distill config
-        elif transformer_prune_config is not None and self._distill_config is not None:
+        ### case3: transformer_prune_config & distill config
+        if transformer_prune_config is not None and self._distill_config is not None:
+            only_distillation = False
             strategy.append('transformer_prune_dis')
             config.append(
                 merge_config(transformer_prune_config,
                              self._distill_config))

-        ### case6: unstructure_config & distill config
-        elif unstructure_prune_config is not None and self._distill_config is not None:
+        ### case4: unstructure_config & distill config
+        if unstructure_prune_config is not None and self._distill_config is not None:
+            only_distillation = False
             strategy.append('unstructure_prune_dis')
             config.append(
                 merge_config(unstructure_prune_config,
                              self._distill_config))

+        ### case5: quant_config & hpo_config ==> PTQ & HPO
+        if quant_config is not None and hpo_config is not None:
+            only_distillation = False
+            strategy.append('ptq_hpo')
+            config.append(merge_config(quant_config, hpo_config))
+
+        ### case6: quant_config & distill config ==> QAT & Distill
+        if quant_config is not None and self._distill_config is not None:
+            only_distillation = False
+            strategy.append('qat_dis')
+            config.append(merge_config(quant_config, self._distill_config))
+
         ### case7: distill_config
-        elif self._distill_config is not None:
+        if only_distillation == True and self._distill_config is not None:
             if single_teacher_distill_config is not None:
                 strategy.append('single_teacher_dis')
                 config.append(single_teacher_distill_config)
@@ -328,11 +374,18 @@ class AutoCompression:
             strategy.append('multi_teacher_dis')
             config.append(multi_teacher_distill_config)

-        ### case N: todo
-        else:
-            raise NotImplementedError(
-                "Not Implemented {} be set at the same time now".format(
-                    strategy_c.keys()))
+        ### NOTE: keep quantization as the last step
+        idx = -1
+        if 'qat_dis' in strategy and strategy.index('qat_dis') != (
+                len(strategy) - 1):
+            idx = strategy.index('qat_dis')
+        elif 'ptq_hpo' in strategy and strategy.index('ptq_hpo') != (
+                len(strategy) - 1):
+            idx = strategy.index('ptq_hpo')
+        if idx != -1:
+            strategy = strategy[:idx] + strategy[idx + 1:] + [strategy[idx]]
+            config = config[:idx] + config[idx + 1:] + [config[idx]]

         return strategy, config
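The rotation above guarantees that a quantization step always runs after pruning and distillation. A pure-Python trace with strategy names taken from this file:

    strategy = ['qat_dis', 'channel_prune_dis', 'unstructure_prune_dis']
    idx = strategy.index('qat_dis')  # 0, i.e. not in the last slot
    strategy = strategy[:idx] + strategy[idx + 1:] + [strategy[idx]]
    # -> ['channel_prune_dis', 'unstructure_prune_dis', 'qat_dis']
    # The parallel config list is rotated with the same index so pairing holds.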
@@ -356,7 +409,8 @@ class AutoCompression:
         return strategy

     def _prepare_program(self, program, feed_target_names, fetch_targets,
-                         patterns, default_distill_node_pair, strategy, config):
+                         patterns, default_distill_node_pair, strategy, config,
+                         train_config):
         train_program = recover_inference_program(program)
         startup_program = paddle.static.Program()
         train_program_info = ProgramInfo(startup_program, train_program,
@@ -369,11 +423,11 @@ class AutoCompression:
         _logger.info(
             "Calculating the iterations per epoch……(It will take some time)")
         # NOTE:XXX: This way of calculating the iters needs to be improved.
-        if self.train_config.epochs:
+        if train_config.epochs:
             iters_per_epoch = len(list(self.train_dataloader()))
-            total_iters = self.train_config.epochs * iters_per_epoch
-        elif self.train_config.train_iter:
-            total_iters = self.train_config.train_iter
+            total_iters = train_config.epochs * iters_per_epoch
+        elif train_config.train_iter:
+            total_iters = train_config.train_iter
         else:
             raise RuntimeError(
                 'train_config must has `epochs` or `train_iter` field.')
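The iteration accounting is unchanged apart from reading the per-strategy config: with epochs set, total iterations are derived from the dataloader length, otherwise train_iter is used directly. A sketch with hypothetical numbers:

    epochs, iters_per_epoch = 2, 1000   # iters_per_epoch = len(list(dataloader()))
    total_iters = epochs * iters_per_epoch  # -> 2000; else total_iters = train_iter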
@@ -392,8 +446,8 @@ class AutoCompression:
             self._exe, self._places, config_dict, train_program_info,
             strategy, patterns, self.eval_dataloader)

-        if self.train_config.use_fleet:
-            dist_strategy = _prepare_fleet_strategy(self.train_config)
+        if train_config.use_fleet:
+            dist_strategy = _prepare_fleet_strategy(train_config)
         else:
             dist_strategy = None
@@ -403,7 +457,7 @@ class AutoCompression:
                 self._exe,
                 self._places,
                 config_dict,
-                self.train_config.__dict__,
+                train_config.__dict__,
                 train_program_info,
                 pruner=self._pruner,
                 dist_strategy=dist_strategy,
@@ -415,7 +469,7 @@ class AutoCompression:
             train_program_info, test_program_info, self._quant_config = build_quant_program(
                 self._exe, self._places, config_dict, train_program_info,
                 test_program_info)
-        if self.train_config.sparse_model:
+        if train_config.sparse_model:
             from ..prune.unstructured_pruner import UnstructuredPruner
             # NOTE: The initialization parameter of this pruner doesn't work, it is only used to call the 'set_static_masks' function
             self._pruner = UnstructuredPruner(
@@ -428,10 +482,10 @@ class AutoCompression:
         self._exe.run(train_program_info.startup_program)

-        if (not self.train_config.use_fleet
-            ) and self.train_config.amp_config is not None:
-            if hasattr(self.train_config.amp_config, 'use_pure_fp16'
-                       ) and self.train_config.amp_config.use_pure_fp16:
+        if (not train_config.use_fleet) and train_config.amp_config is not None:
+            if hasattr(
+                    train_config.amp_config,
+                    'use_pure_fp16') and train_config.amp_config.use_pure_fp16:
                 train_program_info.optimizer.amp_init(
                     self._places, scope=paddle.static.global_scope())
@@ -439,7 +493,7 @@ class AutoCompression:
             ### prune weight in scope
             self._pruner.prune_model(train_program_info.program)

-        if not self.train_config.use_fleet:
+        if not train_config.use_fleet:
             train_program_info = self._compiled_program(train_program_info,
                                                         strategy)
             test_program_info = self._compiled_program(test_program_info,
@@ -475,9 +529,10 @@ class AutoCompression:
     def compress(self):
         self.tmp_dir = self.create_tmp_dir(self.final_dir)
         for strategy_idx, (
-                strategy,
-                config) in enumerate(zip(self._strategy, self._config)):
-            self.single_strategy_compress(strategy, config, strategy_idx)
+                strategy, config, train_config
+        ) in enumerate(zip(self._strategy, self._config, self.train_config)):
+            self.single_strategy_compress(strategy, config, strategy_idx,
+                                          train_config)
             if strategy == 'ptq_hpo' and config.max_quant_count == 1 and platform.system(
             ).lower() == 'linux':
@@ -488,7 +543,8 @@ class AutoCompression:
                 quant_strategy, quant_config = self._prepare_strategy(
                     final_quant_config)
                 self.single_strategy_compress(quant_strategy[0],
-                                              quant_config[0], strategy_idx)
+                                              quant_config[0], strategy_idx,
+                                              train_config)
             tmp_model_path = os.path.join(
                 self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1)))
             final_model_path = os.path.join(self.final_dir)
@@ -507,7 +563,8 @@ class AutoCompression:
                    format(final_model_path))
             os._exit(0)

-    def single_strategy_compress(self, strategy, config, strategy_idx):
+    def single_strategy_compress(self, strategy, config, strategy_idx,
+                                 train_config):
         # start compress, including train/eval model
         # TODO: add the emd loss of evaluation model.
         if strategy == 'quant_post':
@@ -581,19 +638,19 @@ class AutoCompression:
             ### used to check whether the dataloader is right
             self.metric_before_compressed = None
-            if self.eval_function is not None and self.train_config.origin_metric is not None:
+            if self.eval_function is not None and train_config.origin_metric is not None:
                 _logger.info("start to test metric before compress")
                 metric = self.eval_function(self._exe, inference_program,
                                             feed_target_names, fetch_targets)
                 _logger.info("metric of compressed model is: {}".format(metric))
                 buf = 0.05
-                if metric < (float(self.train_config.origin_metric) - buf) or \
-                        metric > (float(self.train_config.origin_metric) + buf):
+                if metric < (float(train_config.origin_metric) - buf) or \
+                        metric > (float(train_config.origin_metric) + buf):
                     raise RuntimeError("target metric of pretrained model is {}, \
                             but now is {}, Please check the format of evaluation dataset \
                             or check the origin_metric in train_config"
                                        .format(\
-                        self.train_config.origin_metric, metric))
+                        train_config.origin_metric, metric))
                 self.metric_before_compressed = metric

         patterns, default_distill_node_pair, _ = get_patterns(
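The sanity check above accepts the measured metric only inside a fixed band of +/- 0.05 around the user-declared origin_metric. A numeric sketch with hypothetical values:

    origin_metric, buf = 0.7600, 0.05  # buf is hard-coded above
    metric = 0.7612                    # hypothetical eval_function result
    ok = (origin_metric - buf) <= metric <= (origin_metric + buf)  # True: proceed
    # A metric outside the band raises RuntimeError pointing at the evaluation
    # dataset or the origin_metric field in train_config.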
@@ -601,15 +658,16 @@ class AutoCompression:
         train_program_info, test_program_info = self._prepare_program(
             inference_program, feed_target_names, fetch_targets, patterns,
-            default_distill_node_pair, strategy, config)
+            default_distill_node_pair, strategy, config, train_config)
         if 'unstructure' in self._strategy:
             test_program_info.program._program = remove_unused_var_nodes(
                 test_program_info.program._program)
-        test_program_info = self._start_train(train_program_info,
-                                              test_program_info, strategy)
+        test_program_info = self._start_train(
+            train_program_info, test_program_info, strategy, train_config)
         self._save_model(test_program_info, strategy, strategy_idx)

-    def _start_train(self, train_program_info, test_program_info, strategy):
+    def _start_train(self, train_program_info, test_program_info, strategy,
+                     train_config):
         best_metric = -1.0
         total_epochs = self.train_config.epochs if self.train_config.epochs else 100
         total_train_iter = 0
@@ -623,10 +681,10 @@ class AutoCompression:
                 if 'unstructure' in strategy:
                     self._pruner.step()

-                if self.train_config.logging_iter is None:
+                if train_config.logging_iter is None:
                     logging_iter = 10
                 else:
-                    logging_iter = self.train_config.logging_iter
+                    logging_iter = train_config.logging_iter
                 if batch_id % int(logging_iter) == 0:
                     _logger.info(
                         "Total iter: {}, epoch: {}, batch: {}, loss: {}".format(
@@ -661,8 +719,8 @@ class AutoCompression:
                                     self.metric_before_compressed)
                             ) / self.metric_before_compressed <= 0.005:
                                 break
-                        if self.train_config.target_metric is not None:
-                            if metric > float(self.train_config.target_metric):
+                        if train_config.target_metric is not None:
+                            if metric > float(train_config.target_metric):
                                 break
                     else:
@@ -672,7 +730,7 @@ class AutoCompression:
                 if self.train_config.train_iter and total_train_iter >= self.train_config.train_iter:
                     break

-        if 'unstructure' in self._strategy or self.train_config.sparse_model:
+        if 'unstructure' in self._strategy or train_config.sparse_model:
             self._pruner.update_params()

         return test_program_info
...
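Taken together, the commit lets one AutoCompression run chain several strategies, each with its own derived train config and quantization last. A hedged usage sketch; the argument names are inferred from attributes assigned in this diff and may not match the released signature exactly:

    ac = AutoCompression(
        model_dir='./inference_model',    # hypothetical paths
        model_filename='model.pdmodel',
        params_filename='model.pdiparams',
        save_dir='./output',              # assumed name for final_dir
        train_dataloader=train_loader,
        strategy_config=strategy_config,  # e.g. a prune config plus a quant config
        train_config=train_config,        # yaml train config; drives the prune step
        eval_callback=eval_function)
    ac.compress()  # runs each strategy with its derived config, quantization last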