diff --git a/paddleslim/auto_compression/analysis.py b/paddleslim/auto_compression/analysis.py index db9f601e77edf3d42ca865da19fddc98995e94d2..3423db4a796819a2aae4d832c21847c8f0f60b14 100644 --- a/paddleslim/auto_compression/analysis.py +++ b/paddleslim/auto_compression/analysis.py @@ -28,7 +28,19 @@ def analysis_prune(eval_function, params_filename, analysis_file, pruned_ratios, - target_loss=None): + target_loss=None, + criterion='l1_norm'): + ''' + Args: + eval_function(function): The callback function used to evaluate the model. It should accept an instance of `paddle.static.Program` as argument and return a score on test dataset. + model_dir(str): Directory path to load model. If you want to load onnx model, only set ``model_dir=model.onnx``. + model_filename(str): Specify model_filename. If you want to load onnx model, model filename should be None. + params_filename(str): Specify params_filename. If you want to load onnx model, params filename should be None. + analysis_file(str): The file to save the sensitivities. It will append the latest computed sensitivities into the file. And the sensitivities in the file would not be computed again. This file can be loaded by `pickle` library. + pruned_ratios(list): The ratios to be pruned. + criterion(str|function): The criterion used to sort channels for pruning. Currently supports l1_norm, bn_scale, geometry_median. Default: l1_norm.
+ ''' + devices = paddle.device.get_device().split(':')[0] places = paddle.device._convert_to_place(devices) exe = paddle.static.Executor(places) @@ -47,7 +59,8 @@ def analysis_prune(eval_function, eval_function, sensitivities_file=analysis_file, eval_args=[exe, feed_target_names, fetch_targets], - pruned_ratios=pruned_ratios) + pruned_ratios=pruned_ratios, + criterion=criterion) with open(analysis_file, 'rb') as f: if sys.version_info < (3, 0): diff --git a/paddleslim/auto_compression/compressor.py b/paddleslim/auto_compression/compressor.py index d59d8d9e1af6281a084365c655365436fec75c2c..f86911b823e82bac81d37daf3ab2033c2d53a955 100644 --- a/paddleslim/auto_compression/compressor.py +++ b/paddleslim/auto_compression/compressor.py @@ -783,13 +783,17 @@ class AutoCompression: total_epochs = train_config.epochs if train_config.epochs else 100 total_train_iter = 0 stop_training = False + + loss_vars = [var for var in train_program_info.loss_dict.values()] + loss_names = [name for name in train_program_info.loss_dict.keys()] + for epoch_id in range(total_epochs): if stop_training: break for batch_id, data in enumerate(self.train_dataloader()): - np_probs_float, = self._exe.run(train_program_info.program, \ + loss = self._exe.run(train_program_info.program, \ feed=data, \ - fetch_list=train_program_info.fetch_targets) + fetch_list=train_program_info.fetch_targets+loss_vars) if not isinstance(train_program_info.learning_rate, float): train_program_info.learning_rate.step() if 'unstructure' in strategy: @@ -800,10 +804,12 @@ class AutoCompression: else: logging_iter = train_config.logging_iter if batch_id % int(logging_iter) == 0: - _logger.info( - "Total iter: {}, epoch: {}, batch: {}, loss: {}".format( - total_train_iter, epoch_id, batch_id, - np_probs_float)) + print_info = "Total iter: {}, epoch: {}, batch: {}, loss: {}".format( + total_train_iter, epoch_id, batch_id, loss[0]) + for idx, loss_value in enumerate(loss[1:]): + print_info += '{}: {} 
'.format(loss_names[idx], + loss_value) + _logger.info(print_info) total_train_iter += 1 if total_train_iter % int( train_config.eval_iter) == 0 and total_train_iter != 0: diff --git a/paddleslim/auto_compression/create_compressed_program.py b/paddleslim/auto_compression/create_compressed_program.py index 7332619a30f807cce3e5cc06de6fc5ade4f98d15..677bba963c89255b51d92102a96c7375d215d8c6 100644 --- a/paddleslim/auto_compression/create_compressed_program.py +++ b/paddleslim/auto_compression/create_compressed_program.py @@ -24,6 +24,7 @@ from ..common.recover_program import recover_inference_program, _remove_fetch_no from ..common import get_logger from .strategy_config import ProgramInfo from ..common.load_model import load_inference_model +from ..analysis import flops _logger = get_logger(__name__, level=logging.INFO) __all__ = [ @@ -133,7 +134,7 @@ def _parse_distill_loss(distill_node_pair, distill_lambda=1.0): """parse distill loss config""" loss_dist = 0.0 - losses = [] + losses = {} if isinstance(distill_node_pair[0], str): assert isinstance(distill_loss, str) assert isinstance(distill_lambda, float) @@ -143,16 +144,17 @@ def _parse_distill_loss(distill_node_pair, assert len(distill_node_pair) == len(distill_loss) assert len(distill_node_pair) == len(distill_lambda) - for node, loss, lam in zip(distill_node_pair, distill_loss, distill_lambda): - tmp_loss = 0.0 - _logger.info("train config.distill_node_pair: {}".format(node, loss, - lam)) + for node, loss_clas, lam in zip(distill_node_pair, distill_loss, + distill_lambda): + tmp_loss = losses.get(loss_clas, 0.0) + _logger.info("train config.distill_node_pair: {}".format( + node, loss_clas, lam)) assert len(node) % 2 == 0, \ "distill_node_pair config wrong, the length needs to be an even number" for i in range(len(node) // 2): - tmp_loss += eval(loss)(node[i * 2], node[i * 2 + 1]) - loss_dist += lam * tmp_loss - losses.append(tmp_loss) + tmp_loss += eval(loss_clas)(node[i * 2], node[i * 2 + 1]) * lam + loss_dist 
+= tmp_loss + losses[loss_clas] = tmp_loss return loss_dist, losses @@ -364,7 +366,7 @@ def build_distill_program(executor, use_dynamic_loss_scaling=True, **train_config['amp_config']) - distill_loss, losses = _parse_distill_loss( + distill_loss, loss_dict = _parse_distill_loss( distill_node_pair, config.get('loss') or 'l2', ### default loss is l2 config.get('alpha') or 1.0) ### default alpha is 1.0 @@ -385,7 +387,7 @@ def build_distill_program(executor, train_program_info = ProgramInfo(startup_program, train_program, feed_target_names, train_fetch_list, - optimizer, learning_rate) + optimizer, learning_rate, loss_dict) test_program_info = ProgramInfo(startup_program, test_program, feed_target_names, fetch_targets) return train_program_info, test_program_info @@ -520,6 +522,8 @@ def build_prune_program(executor, params.append(param.name) original_shapes[param.name] = param.shape + origin_flops = flops(train_program_info.program) + pruned_program, _, _ = pruner.prune( train_program_info.program, paddle.static.global_scope(), @@ -530,12 +534,18 @@ def build_prune_program(executor, place=place) _logger.info( "####################channel pruning##########################") - for param in pruned_program.all_parameters(): + for param in pruned_program.global_block().all_parameters(): if param.name in original_shapes: _logger.info("{}, from {} to {}".format( param.name, original_shapes[param.name], param.shape)) _logger.info( "####################channel pruning end##########################") + + final_flops = flops(pruned_program) + pruned_flops = abs(origin_flops - final_flops) / origin_flops + _logger.info("FLOPs before pruning: {}".format(origin_flops)) + _logger.info("FLOPs after pruning: {}. 
Pruned FLOPs: {}%.".format( + final_flops, round(pruned_flops * 100, 2))) train_program_info.program = pruned_program elif strategy.startswith('asp'): diff --git a/paddleslim/auto_compression/strategy_config.py b/paddleslim/auto_compression/strategy_config.py index c2f663d1ce01bfd41b502efc406bd264d90650aa..522b22743bc03a39ed37866200bed0691c073f3f 100644 --- a/paddleslim/auto_compression/strategy_config.py +++ b/paddleslim/auto_compression/strategy_config.py @@ -431,7 +431,8 @@ class ProgramInfo: feed_target_names, fetch_targets, optimizer=None, - learning_rate=None): + learning_rate=None, + loss_dict=None): """ ProgramInfo Config. Args: @@ -441,6 +442,7 @@ class ProgramInfo: fetch_targets(list(Variable)): The fetch variable in the program. optimizer(Optimizer, optional): Optimizer in training. Default: None. learning_rate(float|paddle.optimizer.lr, optional): learning_rate in training. Default: None. + loss_dict(dict): The components of losses. """ self.startup_program = startup_program self.program = program @@ -448,3 +450,4 @@ class ProgramInfo: self.fetch_targets = fetch_targets self.optimizer = optimizer self.learning_rate = learning_rate + self.loss_dict = loss_dict