From 5970f5b6ca6a6745e4afa0c7f771ae5a8cb2cda2 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Mon, 25 Nov 2019 17:45:57 +0800 Subject: [PATCH] Fix greedy pruner. --- demo/sensitive_prune/greedy_prune.py | 2 +- paddleslim/prune/sensitive.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/demo/sensitive_prune/greedy_prune.py b/demo/sensitive_prune/greedy_prune.py index b6703a9a..2534045f 100644 --- a/demo/sensitive_prune/greedy_prune.py +++ b/demo/sensitive_prune/greedy_prune.py @@ -217,7 +217,7 @@ def compress(args): train(i, pruned_program) acc = test(i, pruned_val_program) print("iter:{}; pruned FLOPS: {}; acc: {}".format( - iter, float(base_flops - current_flops) / base_flops), acc) + iter, float(base_flops - current_flops) / base_flops, acc)) pruner.save_checkpoint(pruned_program, pruned_val_program) diff --git a/paddleslim/prune/sensitive.py b/paddleslim/prune/sensitive.py index 23338284..ab3a57cb 100644 --- a/paddleslim/prune/sensitive.py +++ b/paddleslim/prune/sensitive.py @@ -57,12 +57,15 @@ def sensitivity(program, if ratio in sensitivities[name]['pruned_percent']: _logger.debug('{}, {} has computed.'.format(name, ratio)) ratio += step_size + pruned_times += 1 continue if baseline is None: baseline = eval_func(graph.program) param_backup = {} pruner = Pruner() + _logger.info("sensitive - param: {}; ratios: {}".format(name, + ratio)) pruned_program = pruner.prune( program=graph.program, scope=scope, -- GitLab