From 0d04d6bd65b54075b5247a77386a637a3b9fcf96 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Wed, 4 Dec 2019 20:12:38 +0800
Subject: [PATCH] Simplify arguments of sensitive API.

---
 paddleslim/prune/sensitive.py | 17 +++++------------
 1 file changed, 5 insertions(+), 12 deletions(-)

diff --git a/paddleslim/prune/sensitive.py b/paddleslim/prune/sensitive.py
index d0edbc2b..a93211cb 100644
--- a/paddleslim/prune/sensitive.py
+++ b/paddleslim/prune/sensitive.py
@@ -33,12 +33,14 @@ def sensitivity(program,
                 param_names,
                 eval_func,
                 sensitivities_file=None,
-                step_size=0.2,
-                max_pruned_times=None):
+                pruned_ratios=None):
     scope = fluid.global_scope()
     graph = GraphWrapper(program)
     sensitivities = _load_sensitivities(sensitivities_file)
 
+    if pruned_ratios is None:
+        pruned_ratios = np.arange(0.1, 1, step=0.1)
+
     for name in param_names:
         if name not in sensitivities:
             size = graph.var(name).shape()[0]
@@ -49,16 +51,9 @@ def sensitivity(program,
             }
     baseline = None
     for name in sensitivities:
-        ratio = step_size
-        pruned_times = 0
-        while ratio < 1:
-            if max_pruned_times is not None and pruned_times >= max_pruned_times:
-                break
-            ratio = round(ratio, 2)
+        for ratio in pruned_ratios:
             if ratio in sensitivities[name]['pruned_percent']:
                 _logger.debug('{}, {} has computed.'.format(name, ratio))
-                ratio += step_size
-                pruned_times += 1
                 continue
             if baseline is None:
                 baseline = eval_func(graph.program)
@@ -92,8 +87,6 @@ def sensitivity(program,
             for param_name in param_backup.keys():
                 param_t = scope.find_var(param_name).get_tensor()
                 param_t.set(param_backup[param_name], place)
-            ratio += step_size
-            pruned_times += 1
     return sensitivities
 
 
--
GitLab
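
A minimal sketch of calling the revised API under the new signature, for illustration only. It assumes sensitivity() is importable from paddleslim.prune, that a place argument (referenced in the function body via param_t.set(..., place)) sits between program and param_names, and that val_program, param_names and eval_func are supplied by the caller; eval_func is expected to take the program and return a scalar metric such as top-1 accuracy.

import numpy as np
import paddle.fluid as fluid
from paddleslim.prune import sensitivity


def run_sensitivity_analysis(val_program, param_names, eval_func):
    """Evaluate pruning sensitivity for the given parameters.

    val_program : fluid.Program holding the model to analyse.
    param_names : names of the conv2d weight parameters to prune.
    eval_func   : callable taking the program and returning a scalar metric.
    """
    place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() \
        else fluid.CPUPlace()

    # Instead of the removed step_size / max_pruned_times pair, pass the
    # exact ratios to evaluate; omitting pruned_ratios falls back to
    # np.arange(0.1, 1, step=0.1) inside sensitivity().
    return sensitivity(
        val_program,
        place,
        param_names,
        eval_func,
        sensitivities_file="sensitivities.data",
        pruned_ratios=np.arange(0.1, 1.0, step=0.2))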