diff --git a/paddleslim/prune/sensitive.py b/paddleslim/prune/sensitive.py
index d0edbc2ba9f1e6709fea22136628836081e676a8..a93211cb51886638af545be5215512b6e23b372a 100644
--- a/paddleslim/prune/sensitive.py
+++ b/paddleslim/prune/sensitive.py
@@ -33,12 +33,14 @@ def sensitivity(program,
                 param_names,
                 eval_func,
                 sensitivities_file=None,
-                step_size=0.2,
-                max_pruned_times=None):
     scope = fluid.global_scope()
     graph = GraphWrapper(program)
     sensitivities = _load_sensitivities(sensitivities_file)
+    if pruned_ratios is None:
+        pruned_ratios = np.arange(0.1, 1, step=0.1)
     for name in param_names:
         if name not in sensitivities:
             size = graph.var(name).shape()[0]
@@ -49,16 +51,9 @@ def sensitivity(program,
             }
     baseline = None
     for name in sensitivities:
-        ratio = step_size
-        pruned_times = 0
-        while ratio < 1:
-            if max_pruned_times is not None and pruned_times >= max_pruned_times:
-                break
-            ratio = round(ratio, 2)
+        for ratio in pruned_ratios:
             if ratio in sensitivities[name]['pruned_percent']:
                 _logger.debug('{}, {} has computed.'.format(name, ratio))
-                ratio += step_size
-                pruned_times += 1
                 continue
             if baseline is None:
                 baseline = eval_func(graph.program)
@@ -92,8 +87,6 @@ def sensitivity(program,
         for param_name in param_backup.keys():
             param_t = scope.find_var(param_name).get_tensor()
             param_t.set(param_backup[param_name], place)
-            ratio += step_size
-            pruned_times += 1
     return sensitivities
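
For context, below is a minimal sketch of how a caller migrates from the removed `step_size`/`max_pruned_times` arguments to the new `pruned_ratios` parameter. This is an illustration under stated assumptions, not code from this patch: the `program`, `place`, `eval_func` values and the parameter names are hypothetical placeholders, the import path is inferred from the file location, and the call assumes the full signature also takes a `place` argument (the restore loop in the last hunk uses `place`, but the complete signature is not visible in this diff).

```python
import paddle.fluid as fluid

from paddleslim.prune import sensitivity  # import path inferred from file location

# --- hypothetical caller setup, not part of this patch ----------------
program = ...  # the caller's fluid Program, built elsewhere
place = fluid.CPUPlace()

def eval_func(prog):
    """Placeholder: a real eval_func runs `prog` on a validation set
    and returns a scalar metric such as top-1 accuracy."""
    return 0.75

# The old API with step_size=0.2 and max_pruned_times=None swept the
# ratios 0.2, 0.4, 0.6, 0.8; the same schedule is now passed explicitly.
ratios = [0.2, 0.4, 0.6, 0.8]

# max_pruned_times=N previously stopped the sweep after N ratios; that
# is now just a slice of the list, e.g. ratios[:2] for N=2.
sensitivities = sensitivity(
    program,
    place,
    param_names=["conv1_weights", "conv2_weights"],  # hypothetical names
    eval_func=eval_func,
    sensitivities_file="sensitivities.data",
    pruned_ratios=ratios)

# Leaving pruned_ratios=None falls back to np.arange(0.1, 1, step=0.1),
# i.e. the nine ratios 0.1 through 0.9.
```

One caveat visible in the new code: the resume check `ratio in sensitivities[name]['pruned_percent']` compares floats for exact equality (the old loop normalized with `round(ratio, 2)` before comparing), so reruns should pass identical ratio values, e.g. the same list literal or the same `np.arange` call as the first run.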