From 50db9490baffb385586e38e70fb7c33edd9370f6 Mon Sep 17 00:00:00 2001
From: ceci3
Date: Fri, 24 Jun 2022 11:51:34 +0800
Subject: [PATCH] optimization train config (#1187)

---
 paddleslim/auto_compression/auto_strategy.py |  15 +-
 paddleslim/auto_compression/compressor.py    | 180 ++++++++++++-------
 2 files changed, 128 insertions(+), 67 deletions(-)

diff --git a/paddleslim/auto_compression/auto_strategy.py b/paddleslim/auto_compression/auto_strategy.py
index a51153e1..a1bb5ad7 100644
--- a/paddleslim/auto_compression/auto_strategy.py
+++ b/paddleslim/auto_compression/auto_strategy.py
@@ -46,12 +46,13 @@ default_hpo_config = {
 
 # default quant config, can be used by ptq&hpo and qat&distillation
 default_quant_config = {
-    'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul', 'matmul'],
+    'quantize_op_types':
+    ['conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2'],
     'weight_bits': 8,
     'activation_bits': 8,
     "is_full_quantize": False,
-    "activation_quantize_type": 'range_abs_max',
-    "weight_quantize_type": 'abs_max',
+    "activation_quantize_type": 'moving_average_abs_max',
+    "weight_quantize_type": 'channel_wise_abs_max',
     "not_quant_pattern": ["skip_quant"],
 }
 
@@ -60,10 +61,12 @@ DefaultTrainConfig = {
     "epochs": 1,
     "eval_iter": 500,
     "learning_rate": 0.0001,
-    "optimizer": "Momentum",
-    "optim_args": {
+    "optimizer_builder": {
+        "optimizer": {
+            "type": "Momentum",
+        },
         "weight_decay": 4.0e-05
-    },
+    }
 }
 
 EXPERIENCE_STRATEGY_WITHOUT_LOSS = [
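Note: for anyone updating an existing config, the flat `optimizer`/`optim_args` keys are replaced by a nested `optimizer_builder` section. A minimal sketch of the new layout as a plain Python dict (names mirror DefaultTrainConfig above; the snippet is illustrative only, not part of the patch):

    # The new nested train config; old flat keys ("optimizer", "optim_args")
    # map into the "optimizer_builder" section.
    default_train_config = {
        "epochs": 1,
        "eval_iter": 500,
        "learning_rate": 0.0001,
        "optimizer_builder": {
            "optimizer": {
                "type": "Momentum",
            },
            "weight_decay": 4.0e-05,
        },
    }

    # Reading the optimizer type and weight decay from the nested layout:
    opt_builder = default_train_config["optimizer_builder"]
    print(opt_builder["optimizer"]["type"])  # Momentum
    print(opt_builder["weight_decay"])       # 4e-05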
diff --git a/paddleslim/auto_compression/compressor.py b/paddleslim/auto_compression/compressor.py
index 4ea8cf4a..1909482e 100644
--- a/paddleslim/auto_compression/compressor.py
+++ b/paddleslim/auto_compression/compressor.py
@@ -16,6 +16,7 @@ import logging
 import os
 import sys
 import numpy as np
+import copy
 import inspect
 import shutil
 from time import gmtime, strftime
@@ -28,7 +29,7 @@ from ..common import get_logger
 from ..common.patterns import get_patterns
 from ..analysis import TableLatencyPredictor
 from .create_compressed_program import build_distill_program, build_quant_program, build_prune_program, remove_unused_var_nodes
-from .strategy_config import ProgramInfo, merge_config
+from .strategy_config import TrainConfig, ProgramInfo, merge_config
 from .auto_strategy import prepare_strategy, get_final_quant_config, create_strategy_config, create_train_config
 from .utils.predict import with_variable_shape
 
@@ -127,7 +128,6 @@ class AutoCompression:
         if not os.path.exists(self.final_dir):
             os.makedirs(self.final_dir)
         self.strategy_config = strategy_config
-        self.train_config = train_config
         self.train_dataloader = train_dataloader
         self.target_speedup = target_speedup
         self.eval_function = eval_callback
@@ -142,7 +142,7 @@
         self.model_type = self._get_model_type(self._exe, model_dir,
                                                model_filename, params_filename)
 
-        if self.train_config is not None and self.train_config.use_fleet:
+        if train_config is not None and train_config.use_fleet:
             fleet.init(is_collective=True)
 
         if with_variable_shape(
@@ -173,10 +173,48 @@
         self._strategy, self._config = self._prepare_strategy(
             self.strategy_config)
 
+        self.train_config = self._get_final_train_config(
+            train_config, self._strategy, self.model_type)
+
+    def _get_final_train_config(self, train_config, strategy_config,
+                                model_type):
         # If train_config is None, set default train_config
-        if self.train_config is None:
-            self.train_config = create_train_config(self.strategy_config,
-                                                    self.model_type)
+        if train_config is None:
+            train_config = create_train_config(strategy_config, model_type)
+
+        train_configs = [train_config]
+        for idx in range(1, len(self._strategy)):
+            if 'qat' in self._strategy[idx]:
+                ### If more than one compression strategy is used, the train config
+                ### in the YAML applies to pruning; derive the quantization one from it.
+                tmp_train_config = copy.deepcopy(train_config.__dict__)
+                ### epochs, train_iter and learning rate are scaled down to 10%.
+                tmp_train_config['epochs'] = max(
+                    int(train_config.epochs * 0.1), 1)
+                if train_config.train_iter is not None:
+                    tmp_train_config['train_iter'] = int(
+                        train_config.train_iter * 0.1)
+                if isinstance(train_config.learning_rate, float):
+                    tmp_train_config[
+                        'learning_rate'] = train_config.learning_rate * 0.1
+                else:
+                    if 'learning_rate' in train_config.learning_rate:
+                        tmp_train_config['learning_rate'][
+                            'learning_rate'] = train_config.learning_rate[
+                                'learning_rate'] * 0.1
+                    else:  ### the learning rate decay is PiecewiseDecay
+                        tmp_train_config['learning_rate']['values'] = list(
+                            map(lambda x: x * 0.1, train_config.learning_rate[
+                                'values']))
+                train_cfg = TrainConfig(**tmp_train_config)
+            elif 'ptq' in self._strategy[idx]:
+                train_cfg = None
+            else:
+                tmp_train_config = copy.deepcopy(train_config.__dict__)
+                train_cfg = TrainConfig(**tmp_train_config)
+
+            train_configs.append(train_cfg)
+        return train_configs
 
     def _infer_shape(self, model_dir, model_filename, params_filename,
                      input_shapes, save_path):
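Note: the 10% scaling rule above can be read in isolation. A standalone sketch, assuming a plain dict stands in for TrainConfig and using made-up numbers:

    import copy

    # Hypothetical pruning train config; quant settings are derived from it.
    prune_cfg = {"epochs": 20, "train_iter": None, "learning_rate": 0.01}

    quant_cfg = copy.deepcopy(prune_cfg)
    quant_cfg["epochs"] = max(int(prune_cfg["epochs"] * 0.1), 1)  # -> 2
    if isinstance(prune_cfg["learning_rate"], float):
        quant_cfg["learning_rate"] = prune_cfg["learning_rate"] * 0.1  # -> 0.001

    # A PiecewiseDecay schedule scales every boundary value instead:
    piecewise = {"values": [0.01, 0.001, 0.0001]}
    scaled_values = [v * 0.1 for v in piecewise["values"]]
    print(quant_cfg["epochs"], quant_cfg["learning_rate"], scaled_values)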
@@ -285,42 +323,50 @@
             single_teacher_distill_config is not None else \
             multi_teacher_distill_config
 
-        ### case1: quant_config & hpo_config ==> PTQ & HPO
-        if quant_config is not None and hpo_config is not None:
-            strategy.append('ptq_hpo')
-            config.append(merge_config(quant_config, hpo_config))
-
-        ### case2: quant_config & distill config ==> QAT & Distill
-        elif quant_config is not None and self._distill_config is not None:
-            strategy.append('qat_dis')
-            config.append(merge_config(quant_config, self._distill_config))
+        only_distillation = True
 
-        ### case3: prune_config & distill config
-        elif prune_config is not None and self._distill_config is not None:
+        ### case1: prune_config & distill config
+        if prune_config is not None and self._distill_config is not None:
+            only_distillation = False
             strategy.append('channel_prune_dis')
             config.append(merge_config(prune_config, self._distill_config))
 
-        ### case4: asp_config & distill config
-        elif asp_config is not None and self._distill_config is not None:
+        ### case2: asp_config & distill config
+        if asp_config is not None and self._distill_config is not None:
+            only_distillation = False
             strategy.append('asp_prune_dis')
             config.append(merge_config(asp_config, self._distill_config))
 
-        ### case5: transformer_prune_config & distill config
-        elif transformer_prune_config is not None and self._distill_config is not None:
+        ### case3: transformer_prune_config & distill config
+        if transformer_prune_config is not None and self._distill_config is not None:
+            only_distillation = False
             strategy.append('transformer_prune_dis')
             config.append(
                 merge_config(transformer_prune_config, self._distill_config))
 
-        ### case6: unstructure_config & distill config
-        elif unstructure_prune_config is not None and self._distill_config is not None:
+        ### case4: unstructure_config & distill config
+        if unstructure_prune_config is not None and self._distill_config is not None:
+            only_distillation = False
             strategy.append('unstructure_prune_dis')
             config.append(
                 merge_config(unstructure_prune_config,
                              self._distill_config))
 
+        ### case5: quant_config & hpo_config ==> PTQ & HPO
+        if quant_config is not None and hpo_config is not None:
+            only_distillation = False
+            strategy.append('ptq_hpo')
+            config.append(merge_config(quant_config, hpo_config))
+
+        ### case6: quant_config & distill config ==> QAT & Distill
+        if quant_config is not None and self._distill_config is not None:
+            only_distillation = False
+            strategy.append('qat_dis')
+            config.append(merge_config(quant_config, self._distill_config))
+
         ### case7: distill_config
-        elif self._distill_config is not None:
+        if only_distillation and self._distill_config is not None:
             if single_teacher_distill_config is not None:
                 strategy.append('single_teacher_dis')
                 config.append(single_teacher_distill_config)
@@ -328,11 +374,18 @@
                 strategy.append('multi_teacher_dis')
                 config.append(multi_teacher_distill_config)
 
-        ### case N: todo
-        else:
-            raise NotImplementedError(
-                "Not Implemented {} be set at the same time now".format(
-                    strategy_c.keys()))
+        ### NOTE: keep quantization as the last step
+        idx = -1
+        if 'qat_dis' in strategy and strategy.index('qat_dis') != (
+                len(strategy) - 1):
+            idx = strategy.index('qat_dis')
+        elif 'ptq_hpo' in strategy and strategy.index('ptq_hpo') != (
+                len(strategy) - 1):
+            idx = strategy.index('ptq_hpo')
+
+        if idx != -1:
+            strategy = strategy[:idx] + strategy[idx + 1:] + [strategy[idx]]
+            config = config[:idx] + config[idx + 1:] + [config[idx]]
 
         return strategy, config
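Note: the reordering above rotates whichever quantization entry is not already last ('qat_dis' is checked first, then 'ptq_hpo') to the end of the pipeline, together with its config. A standalone sketch with placeholder config strings:

    strategy = ['qat_dis', 'channel_prune_dis', 'transformer_prune_dis']
    config = ['quant_cfg', 'prune_cfg', 'trans_prune_cfg']  # placeholders

    idx = -1
    for name in ('qat_dis', 'ptq_hpo'):
        if name in strategy and strategy.index(name) != len(strategy) - 1:
            idx = strategy.index(name)
            break

    if idx != -1:
        # Rotate the quantization step (and its config) to the end.
        strategy = strategy[:idx] + strategy[idx + 1:] + [strategy[idx]]
        config = config[:idx] + config[idx + 1:] + [config[idx]]

    print(strategy)  # ['channel_prune_dis', 'transformer_prune_dis', 'qat_dis']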
@@ -356,7 +409,8 @@
         return strategy
 
     def _prepare_program(self, program, feed_target_names, fetch_targets,
-                         patterns, default_distill_node_pair, strategy, config):
+                         patterns, default_distill_node_pair, strategy, config,
+                         train_config):
         train_program = recover_inference_program(program)
         startup_program = paddle.static.Program()
         train_program_info = ProgramInfo(startup_program, train_program,
@@ -369,11 +423,11 @@
             feed_target_names, fetch_targets)
         _logger.info(
-            "Calculating the iterations per epoch……(It will take some time)")
+            "Calculating the iterations per epoch... (it will take some time)")
         # NOTE:XXX: This way of calculating the iters needs to be improved.
-        if self.train_config.epochs:
+        if train_config.epochs:
             iters_per_epoch = len(list(self.train_dataloader()))
-            total_iters = self.train_config.epochs * iters_per_epoch
-        elif self.train_config.train_iter:
-            total_iters = self.train_config.train_iter
+            total_iters = train_config.epochs * iters_per_epoch
+        elif train_config.train_iter:
+            total_iters = train_config.train_iter
         else:
             raise RuntimeError(
-                'train_config must has `epochs` or `train_iter` field.')
+                'train_config must have an `epochs` or `train_iter` field.')
@@ -392,8 +446,8 @@
                 self._exe, self._places, config_dict, train_program_info,
                 strategy, patterns, self.eval_dataloader)
 
-        if self.train_config.use_fleet:
-            dist_strategy = _prepare_fleet_strategy(self.train_config)
+        if train_config.use_fleet:
+            dist_strategy = _prepare_fleet_strategy(train_config)
         else:
             dist_strategy = None
 
@@ -403,7 +457,7 @@
                 self._exe,
                 self._places,
                 config_dict,
-                self.train_config.__dict__,
+                train_config.__dict__,
                 train_program_info,
                 pruner=self._pruner,
                 dist_strategy=dist_strategy,
@@ -415,7 +469,7 @@
             train_program_info, test_program_info, self._quant_config = build_quant_program(
                 self._exe, self._places, config_dict, train_program_info,
                 test_program_info)
-        if self.train_config.sparse_model:
+        if train_config.sparse_model:
             from ..prune.unstructured_pruner import UnstructuredPruner
             # NOTE: The initialization parameters of this pruner don't take effect; it is only used to call 'set_static_masks'.
             self._pruner = UnstructuredPruner(
@@ -428,10 +482,10 @@
 
         self._exe.run(train_program_info.startup_program)
 
-        if (not self.train_config.use_fleet
-            ) and self.train_config.amp_config is not None:
-            if hasattr(self.train_config.amp_config, 'use_pure_fp16'
-                       ) and self.train_config.amp_config.use_pure_fp16:
+        if (not train_config.use_fleet) and train_config.amp_config is not None:
+            if hasattr(
+                    train_config.amp_config,
+                    'use_pure_fp16') and train_config.amp_config.use_pure_fp16:
                 train_program_info.optimizer.amp_init(
                     self._places, scope=paddle.static.global_scope())
 
@@ -439,7 +493,7 @@
             ### prune weight in scope
             self._pruner.prune_model(train_program_info.program)
 
-        if not self.train_config.use_fleet:
+        if not train_config.use_fleet:
             train_program_info = self._compiled_program(train_program_info,
                                                         strategy)
             test_program_info = self._compiled_program(test_program_info,
@@ -475,9 +529,10 @@
     def compress(self):
         self.tmp_dir = self.create_tmp_dir(self.final_dir)
         for strategy_idx, (
-                strategy,
-                config) in enumerate(zip(self._strategy, self._config)):
-            self.single_strategy_compress(strategy, config, strategy_idx)
+                strategy, config, train_config
+        ) in enumerate(zip(self._strategy, self._config, self.train_config)):
+            self.single_strategy_compress(strategy, config, strategy_idx,
+                                          train_config)
 
             if strategy == 'ptq_hpo' and config.max_quant_count == 1 and platform.system(
             ).lower() == 'linux':
@@ -488,7 +543,8 @@
                 quant_strategy, quant_config = self._prepare_strategy(
                     final_quant_config)
                 self.single_strategy_compress(quant_strategy[0],
-                                              quant_config[0], strategy_idx)
+                                              quant_config[0], strategy_idx,
+                                              train_config)
             tmp_model_path = os.path.join(
                 self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1)))
             final_model_path = os.path.join(self.final_dir)
@@ -507,7 +563,8 @@
                 format(final_model_path))
             os._exit(0)
 
-    def single_strategy_compress(self, strategy, config, strategy_idx):
+    def single_strategy_compress(self, strategy, config, strategy_idx,
+                                 train_config):
         # start compress, including train/eval model
         # TODO: add the emd loss of evaluation model.
         if strategy == 'quant_post':
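Note: compress() now walks three aligned lists, pairing each strategy and its config with the per-strategy train config built by _get_final_train_config. A minimal sketch with placeholder values (the real loop calls single_strategy_compress; here it is only referenced in a comment):

    strategies = ['channel_prune_dis', 'qat_dis']
    configs = ['prune_cfg', 'quant_cfg']                   # placeholder configs
    train_configs = ['prune_train_cfg', 'qat_train_cfg']   # one per strategy

    for idx, (s, c, tc) in enumerate(zip(strategies, configs, train_configs)):
        # compress() would call single_strategy_compress(s, c, idx, tc) here
        print(idx, s, c, tc)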
@@ -581,19 +638,19 @@
 
         ### used to check whether the dataloader is right
         self.metric_before_compressed = None
-        if self.eval_function is not None and self.train_config.origin_metric is not None:
+        if self.eval_function is not None and train_config.origin_metric is not None:
             _logger.info("start to test metric before compress")
             metric = self.eval_function(self._exe, inference_program,
                                         feed_target_names, fetch_targets)
-            _logger.info("metric of compressed model is: {}".format(metric))
+            _logger.info("metric of the model before compression is: {}".format(metric))
             buf = 0.05
-            if metric < (float(self.train_config.origin_metric) - buf) or \
-                    metric > (float(self.train_config.origin_metric) + buf):
+            if metric < (float(train_config.origin_metric) - buf) or \
+                    metric > (float(train_config.origin_metric) + buf):
-                raise RuntimeError("target metric of pretrained model is {}, \
-                    but now is {}, Please check the format of evaluation dataset \
+                raise RuntimeError("the target metric of the pretrained model is {}, \
+                    but it is now {}. Please check the format of the evaluation dataset \
                     or check the origin_metric in train_config"
                                    .format(\
-                    self.train_config.origin_metric, metric))
+                    train_config.origin_metric, metric))
             self.metric_before_compressed = metric
 
         patterns, default_distill_node_pair, _ = get_patterns(
@@ -601,15 +658,16 @@
         train_program_info, test_program_info = self._prepare_program(
             inference_program, feed_target_names, fetch_targets, patterns,
-            default_distill_node_pair, strategy, config)
+            default_distill_node_pair, strategy, config, train_config)
         if 'unstructure' in self._strategy:
             test_program_info.program._program = remove_unused_var_nodes(
                 test_program_info.program._program)
-        test_program_info = self._start_train(train_program_info,
-                                              test_program_info, strategy)
+        test_program_info = self._start_train(
+            train_program_info, test_program_info, strategy, train_config)
         self._save_model(test_program_info, strategy, strategy_idx)
 
-    def _start_train(self, train_program_info, test_program_info, strategy):
+    def _start_train(self, train_program_info, test_program_info, strategy,
+                     train_config):
         best_metric = -1.0
-        total_epochs = self.train_config.epochs if self.train_config.epochs else 100
+        total_epochs = train_config.epochs if train_config.epochs else 100
         total_train_iter = 0
@@ -623,10 +681,10 @@
                     if 'unstructure' in strategy:
                         self._pruner.step()
 
-                if self.train_config.logging_iter is None:
+                if train_config.logging_iter is None:
                     logging_iter = 10
                 else:
-                    logging_iter = self.train_config.logging_iter
+                    logging_iter = train_config.logging_iter
                 if batch_id % int(logging_iter) == 0:
                     _logger.info(
                         "Total iter: {}, epoch: {}, batch: {}, loss: {}".format(
@@ -661,8 +719,8 @@
                                 self.metric_before_compressed)
                         ) / self.metric_before_compressed <= 0.005:
                             break
-                        if self.train_config.target_metric is not None:
-                            if metric > float(self.train_config.target_metric):
+                        if train_config.target_metric is not None:
+                            if metric > float(train_config.target_metric):
                                 break
 
                     else:
@@ -672,7 +730,7 @@
-            if self.train_config.train_iter and total_train_iter >= self.train_config.train_iter:
+            if train_config.train_iter and total_train_iter >= train_config.train_iter:
                 break
 
-        if 'unstructure' in self._strategy or self.train_config.sparse_model:
+        if 'unstructure' in self._strategy or train_config.sparse_model:
             self._pruner.update_params()
         return test_program_info
--
GitLab
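Postscript: the early-stop rules in _start_train can be summarized in isolation. A hedged, standalone sketch; should_stop is a hypothetical helper, not part of the patch:

    # Training stops once the metric recovers to within 0.5% of the
    # pre-compression metric, or exceeds an explicit target_metric.
    def should_stop(metric, metric_before_compressed=None, target_metric=None):
        if metric_before_compressed is not None:
            if abs(metric - metric_before_compressed) / metric_before_compressed <= 0.005:
                return True
        if target_metric is not None and metric > float(target_metric):
            return True
        return False

    print(should_stop(0.759, metric_before_compressed=0.76))  # True
    print(should_stop(0.70, target_metric=0.75))              # False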