# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
import logging
import pickle
import numpy as np
import paddle.fluid as fluid
from ..core import GraphWrapper
from ..common import get_logger
from ..analysis import flops
from ..prune import Pruner

_logger = get_logger(__name__, level=logging.INFO)

__all__ = [
    "sensitivity", "flops_sensitivity", "load_sensitivities",
    "merge_sensitive", "get_ratios_by_loss"
]


def sensitivity(program,
                place,
                param_names,
                eval_func,
                sensitivities_file=None,
                pruned_ratios=None):
    """Measure the metric loss caused by pruning each parameter on its own.

    Every parameter in ``param_names`` (plus any parameter already present in
    ``sensitivities_file``) is pruned by every ratio in ``pruned_ratios``, one
    (parameter, ratio) pair at a time. The recorded loss is the relative
    metric drop ``(baseline - pruned_metric) / baseline``. Pairs that already
    exist in the loaded sensitivities are skipped, so an interrupted run can
    be resumed from ``sensitivities_file``.

    Args:
        program: The program to analyse; it is wrapped in ``GraphWrapper``.
        place: Device place used when pruning and when restoring tensors.
        param_names(list<str>): Names of the parameters to analyse.
        eval_func(callable): Called with a program; returns a scalar metric.
        sensitivities_file(str|None): Pickle file used to load previous
            results and to save progress after every evaluation. When None,
            nothing is loaded or saved.
        pruned_ratios(list<float>|None): Ratios to try. Defaults to
            ``np.arange(0.1, 1, step=0.1)``.

    Returns:
        dict: ``{param_name: {pruned_ratio: loss}}``.
    """
    scope = fluid.global_scope()
    graph = GraphWrapper(program)
    sensitivities = load_sensitivities(sensitivities_file)

    if pruned_ratios is None:
        pruned_ratios = np.arange(0.1, 1, step=0.1)

    for name in param_names:
        if name not in sensitivities:
            sensitivities[name] = {}
    baseline = None
    for name in sensitivities:
        for ratio in pruned_ratios:
            if ratio in sensitivities[name]:
                _logger.debug('{}, {} has computed.'.format(name, ratio))
                continue
            # The unpruned metric is evaluated lazily and at most once.
            if baseline is None:
                baseline = eval_func(graph.program)

            pruner = Pruner()
            _logger.info("sensitive - param: {}; ratios: {}".format(name,
                                                                    ratio))
            pruned_program, param_backup, _ = pruner.prune(
                program=graph.program,
                scope=scope,
                params=[name],
                ratios=[ratio],
                place=place,
                lazy=True,
                only_graph=False,
                param_backup=True)
            try:
                pruned_metric = eval_func(pruned_program)
                loss = (baseline - pruned_metric) / baseline
                _logger.info("pruned param: {}; {}; loss={}".format(
                    name, ratio, loss))
                sensitivities[name][ratio] = loss
                _save_sensitivities(sensitivities, sensitivities_file)
            finally:
                # Always put the original weights back so the scope is never
                # left holding pruned parameters, even if eval_func raises.
                _restore_params(scope, param_backup, place)
    return sensitivities


def flops_sensitivity(program,
                      place,
                      param_names,
                      eval_func,
                      sensitivities_file=None,
                      pruned_flops_rate=0.1):
    """Measure the metric loss of pruning each parameter by a FLOPs budget.

    For each parameter, a pruning ratio is chosen such that pruning only that
    parameter removes roughly ``pruned_flops_rate`` of the program's FLOPs.
    The parameter is then pruned by that ratio and the relative metric drop is
    recorded. When the required ratio would remove every channel, the loss is
    recorded as 1 without evaluating. Parameters that already have any entry
    in the loaded sensitivities are skipped.

    Args:
        program: The program to analyse; it is wrapped in ``GraphWrapper``.
        place: Device place used when pruning and when restoring tensors.
        param_names(list<str>): Names of the parameters to analyse. Must be
            non-empty, and ``1 / len(param_names)`` must exceed
            ``pruned_flops_rate``.
        eval_func(callable): Called with a program; returns a scalar metric.
        sensitivities_file(str|None): Pickle file used to load previous
            results and to save progress. When None, nothing is loaded or
            saved.
        pruned_flops_rate(float): Fraction of total FLOPs to remove per
            parameter.

    Returns:
        dict: ``{param_name: {pruned_ratio: loss}}``.
    """
    assert (1.0 / len(param_names) > pruned_flops_rate)

    scope = fluid.global_scope()
    graph = GraphWrapper(program)
    sensitivities = load_sensitivities(sensitivities_file)

    for name in param_names:
        if name not in sensitivities:
            sensitivities[name] = {}
    base_flops = flops(program)
    target_pruned_flops = base_flops * pruned_flops_rate

    pruner = Pruner()
    baseline = None
    for name in sensitivities:
        # Graph-only prune by 50% to estimate the FLOPs attributable to this
        # parameter; doubling the removed FLOPs extrapolates to 100%.
        pruned_program, _, _ = pruner.prune(
            program=graph.program,
            scope=None,
            params=[name],
            ratios=[0.5],
            place=None,
            lazy=False,
            only_graph=True)
        param_flops = (base_flops - flops(pruned_program)) * 2
        channel_size = graph.var(name).shape()[0]
        pruned_ratio = target_pruned_flops / float(param_flops)
        pruned_ratio = round(pruned_ratio, 3)
        pruned_size = round(pruned_ratio * channel_size)
        pruned_ratio = 1 if pruned_size >= channel_size else pruned_ratio

        if sensitivities[name]:
            _logger.debug(
                '{} exist; pruned ratio: {}; excepted ratio: {}'.format(
                    name, sensitivities[name].keys(), pruned_ratio))
            continue
        if baseline is None:
            baseline = eval_func(graph.program)
        pruner = Pruner()
        _logger.info("sensitive - param: {}; ratios: {}".format(name,
                                                                pruned_ratio))
        loss = 1
        if pruned_ratio < 1:
            # prune() returns (program, param_backup, _); unpack it, as
            # sensitivity() does, so eval_func gets a program and the pruned
            # weights can be restored afterwards.
            pruned_program, param_backup, _ = pruner.prune(
                program=graph.program,
                scope=scope,
                params=[name],
                ratios=[pruned_ratio],
                place=place,
                lazy=True,
                only_graph=False,
                param_backup=True)
            try:
                pruned_metric = eval_func(pruned_program)
                loss = (baseline - pruned_metric) / baseline
            finally:
                _restore_params(scope, param_backup, place)
        _logger.info("pruned param: {}; {}; loss={}".format(name, pruned_ratio,
                                                            loss))
        sensitivities[name][pruned_ratio] = loss
        _save_sensitivities(sensitivities, sensitivities_file)
    return sensitivities


def _restore_params(scope, param_backup, place):
    """Write backed-up parameter values back into ``scope``.

    Args:
        scope: Scope holding the parameter variables.
        param_backup(dict): ``{param_name: value}`` as returned by
            ``Pruner.prune(..., param_backup=True)``.
        place: Device place passed to ``tensor.set``.
    """
    for param_name, value in param_backup.items():
        param_t = scope.find_var(param_name).get_tensor()
        param_t.set(value, place)


def merge_sensitive(sensitivities):
    """
    Merge sensitivities.

    Args:
      sensitivities(list<dict> | list<str>): The sensitivities to be merged.
          It can be a list of sensitivities dicts or a list of paths to
          pickled sensitivities files. The type of the first element decides
          which form is assumed.

    Returns:
      sensitivities(dict): A dict ``{param_name: {pruned_ratio: loss}}``.
          When the same (param, ratio) pair appears more than once, the
          later entry wins.
    """
    assert len(sensitivities) > 0
    if not isinstance(sensitivities[0], dict):
        sensitivities = [_load_pickle(sen) for sen in sensitivities]

    new_sensitivities = {}
    for sen in sensitivities:
        for param, losses in sen.items():
            new_sensitivities.setdefault(param, {}).update(losses)
    return new_sensitivities


def _load_pickle(path):
    """Unpickle and return the object stored in the file ``path``.

    The file is opened in binary mode and closed afterwards. On Python 3,
    ``encoding='bytes'`` is used so that files pickled by Python 2 can be
    read. NOTE: unpickling can execute arbitrary code; only load trusted files.
    """
    with open(path, 'rb') as f:
        if sys.version_info < (3, 0):
            return pickle.load(f)
        return pickle.load(f, encoding='bytes')


def load_sensitivities(sensitivities_file):
    """
    Load sensitivities from file.

    Returns an empty dict when ``sensitivities_file`` is empty/None or does
    not exist.
    """
    sensitivities = {}
    if sensitivities_file and os.path.exists(sensitivities_file):
        sensitivities = _load_pickle(sensitivities_file)
    return sensitivities


def _save_sensitivities(sensitivities, sensitivities_file):
    """
    Save sensitivities into file.

    Does nothing when ``sensitivities_file`` is empty/None, which is the
    default for ``sensitivity`` and ``flops_sensitivity``.
    """
    if not sensitivities_file:
        return
    with open(sensitivities_file, 'wb') as f:
        pickle.dump(sensitivities, f)


def get_ratios_by_loss(sensitivities, loss):
    """
    Get the max ratio of each parameter. The loss of accuracy must be less
    than given `loss` when the single parameter was pruned by the max ratio.

    Args:
      sensitivities(dict): The sensitivities used to generate a group of
          pruning ratios. The key of dict is name of parameters to be pruned.
          The value of dict is a dict ``{pruned_ratio: accuracy_loss}``.
      loss(float): The threshold of accuracy loss.

    Returns:
      ratios(dict): A group of ratios. The key of dict is name of parameters
          while the value is the ratio to be pruned. Parameters for which no
          measured ratio has a loss <= `loss` are omitted.
    """
    ratios = {}
    for param, losses in sensitivities.items():
        # sorted() works on both Python 2 lists and Python 3 dict views.
        losses = sorted(losses.items())
        # Scan from the largest ratio down and take the first one whose loss
        # is within the threshold.
        for i in range(len(losses) - 1, -1, -1):
            r0, l0 = losses[i]
            if l0 > loss:
                continue
            if i == len(losses) - 1:
                ratios[param] = r0
            else:
                # losses[i + 1] was rejected, so l1 > loss >= l0 and the
                # denominator is positive; interpolate linearly to the ratio
                # at which the loss reaches the threshold.
                r1, l1 = losses[i + 1]
                ratios[param] = r0 + (loss - l0) * (r1 - r0) / (l1 - l0)
            break
    return ratios