Commit 983002a3 authored by wanghaoshuang

Change the sensitive pruner from being based on parameter size to being based on FLOPS.

Parent 5970f5b6
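
The gist of the change: the sensitivity probe no longer prunes every parameter by the same fixed channel step; instead it derives a per-parameter channel ratio from a shared FLOPS budget. A minimal standalone sketch of that conversion (illustrative names only; the committed logic lives in flops_sensitivity below):

# Sketch: convert a global FLOPS budget into a per-parameter channel ratio.
# target_pruned_flops: FLOPS each trial pruning should remove (the budget).
# param_flops: estimated FLOPS contributed by this parameter's channels.
# channel_size: number of prunable output channels.
def ratio_for_budget(target_pruned_flops, param_flops, channel_size):
    ratio = target_pruned_flops / float(param_flops)
    if round(ratio * channel_size) >= channel_size:
        return 1.0  # budget would remove every channel; caller skips this param
    return ratio

Under a fixed budget, wide layers are pruned by a smaller channel fraction than narrow ones, so every trial removes roughly the same amount of compute.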
@@ -208,7 +208,7 @@ def compress(args):
     end = args.prune_steps
     for iter in range(start, end):
         pruned_program, pruned_val_program = pruner.greedy_prune(
-            pruned_program, pruned_val_program, params, 0.1, topk=1)
+            pruned_program, pruned_val_program, params, 0.03, topk=1)
         current_flops = flops(pruned_val_program)
         print("iter:{}; pruned FLOPS: {}".format(
             iter, float(base_flops - current_flops) / base_flops))
...
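
To make the smaller step concrete, a worked example with assumed numbers (not from the source): at base_flops = 1000 MFLOPs, the 0.03 rate above targets 1000 * 0.03 = 30 MFLOPs per iteration; a parameter whose channels account for an estimated 300 MFLOPs then gets a trial ratio of 30 / 300 = 0.1, i.e. 10% of its channels.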
@@ -20,11 +20,12 @@ import numpy as np
 import paddle.fluid as fluid
 from ..core import GraphWrapper
 from ..common import get_logger
+from ..analysis import flops
 from ..prune import Pruner

 _logger = get_logger(__name__, level=logging.INFO)

-__all__ = ["sensitivity"]
+__all__ = ["sensitivity", "flops_sensitivity"]


 def sensitivity(program,
@@ -92,6 +93,85 @@ def sensitivity(program,
     return sensitivities
+
+
+def flops_sensitivity(program,
+                      place,
+                      param_names,
+                      eval_func,
+                      sensitivities_file=None,
+                      pruned_flops_rate=0.1):
+
+    assert (1.0 / len(param_names) > pruned_flops_rate)
+
+    scope = fluid.global_scope()
+    graph = GraphWrapper(program)
+    sensitivities = _load_sensitivities(sensitivities_file)
+
+    for name in param_names:
+        if name not in sensitivities:
+            size = graph.var(name).shape()[0]
+            sensitivities[name] = {
+                'pruned_percent': [],
+                'loss': [],
+                'size': size
+            }
+
+    base_flops = flops(program)
+    target_pruned_flops = base_flops * pruned_flops_rate
+
+    pruner = Pruner()
+    baseline = None
+    for name in sensitivities:
+        # Prune half the channels on a graph-only copy; doubling the FLOPS
+        # delta estimates this parameter's total FLOPS contribution.
+        pruned_program = pruner.prune(
+            program=graph.program,
+            scope=None,
+            params=[name],
+            ratios=[0.5],
+            place=None,
+            lazy=False,
+            only_graph=True)
+        param_flops = (base_flops - flops(pruned_program)) * 2
+        channel_size = sensitivities[name]["size"]
+        pruned_ratio = target_pruned_flops / float(param_flops)
+        pruned_size = round(pruned_ratio * channel_size)
+        pruned_ratio = 1 if pruned_size >= channel_size else pruned_ratio
+        if len(sensitivities[name]["pruned_percent"]) > 0:
+            _logger.debug('{} exist; pruned ratio: {}; expected ratio: {}'.
+                          format(name, sensitivities[name]["pruned_percent"][
+                              0], pruned_ratio))
+            continue
+        if baseline is None:
+            baseline = eval_func(graph.program)
+        param_backup = {}
+        pruner = Pruner()
+        _logger.info("sensitive - param: {}; ratios: {}".format(name,
+                                                                pruned_ratio))
+        loss = 1
+        if pruned_ratio < 1:
+            pruned_program = pruner.prune(
+                program=graph.program,
+                scope=scope,
+                params=[name],
+                ratios=[pruned_ratio],
+                place=place,
+                lazy=True,
+                only_graph=False,
+                param_backup=param_backup)
+            pruned_metric = eval_func(pruned_program)
+            loss = (baseline - pruned_metric) / baseline
+        _logger.info("pruned param: {}; {}; loss={}".format(name, pruned_ratio,
+                                                            loss))
+        sensitivities[name]['pruned_percent'].append(pruned_ratio)
+        sensitivities[name]['loss'].append(loss)
+        _save_sensitivities(sensitivities, sensitivities_file)
+
+        # restore pruned parameters
+        for param_name in param_backup.keys():
+            param_t = scope.find_var(param_name).get_tensor()
+            param_t.set(param_backup[param_name], place)
+
+    return sensitivities
+

 def _load_sensitivities(sensitivities_file):
     """
     Load sensitivities from file.
...
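
For orientation, a hedged usage sketch of the new entry point. The tiny network, the eval_func, and the import path (inferred from the relative imports above) are illustrative assumptions, not part of this commit:

# Hypothetical driver for flops_sensitivity (fluid 1.x style; names illustrative).
import paddle.fluid as fluid
from paddleslim.prune.sensitive import flops_sensitivity  # path inferred, see above

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    image = fluid.layers.data(name="image", shape=[3, 32, 32], dtype="float32")
    conv1 = fluid.layers.conv2d(image, num_filters=16, filter_size=3)
    conv2 = fluid.layers.conv2d(conv1, num_filters=32, filter_size=3)

place = fluid.CPUPlace()
fluid.Executor(place).run(startup_prog)

def eval_func(program):
    # Placeholder metric; a real eval_func runs `program` on a validation
    # set and returns e.g. top-1 accuracy.
    return 1.0

conv_weights = [p.name for p in main_prog.global_block().all_parameters()
                if p.name.endswith(".w_0")]  # conv filter weights only
sensitivities = flops_sensitivity(
    main_prog,
    place,
    conv_weights,
    eval_func,
    sensitivities_file="sensitivities.data",
    pruned_flops_rate=0.1)  # each trial prunes ~10% of total FLOPS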
@@ -20,6 +20,7 @@ import numpy as np
 import paddle.fluid as fluid
 from ..common import get_logger
 from .sensitive import sensitivity
+from .sensitive import flops_sensitivity
 from ..analysis import flops
 from .pruner import Pruner
@@ -90,20 +91,19 @@ class SensitivePruner(object):
                     train_program,
                     eval_program,
                     params,
-                     pruned_ratio,
+                     pruned_flops_rate,
                     topk=1):
        sensitivities_file = "greedy_sensitivities_iter{}.data".format(
            self._iter)
        with fluid.scope_guard(self._scope):
-            sensitivities = sensitivity(
+            sensitivities = flops_sensitivity(
                eval_program,
                self._place,
                params,
                self._eval_func,
                sensitivities_file=sensitivities_file,
-                step_size=pruned_ratio,
-                max_pruned_times=1)
+                pruned_flops_rate=pruned_flops_rate)
        print(sensitivities)
        params, ratios = self._greedy_ratio_by_sensitive(sensitivities, topk)
...
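
Correspondingly, a sketch of the caller side under the new signature. The SensitivePruner constructor shown here is an assumption based on the attributes the class uses (self._place, self._eval_func, self._scope); this diff does not show it:

# Hypothetical caller of greedy_prune with the new FLOPS-rate argument.
from paddleslim.prune import SensitivePruner  # assumes the usual package re-export

params = ["conv2d_0.w_0", "conv2d_1.w_0"]   # prunable parameter names (illustrative)
pruner = SensitivePruner(place, eval_func)  # constructor args assumed, see above
pruned_program, pruned_val_program = train_program, eval_program
for it in range(5):
    # Each iteration prunes the topk most insensitive parameters by whatever
    # channel ratio removes ~3% of total FLOPS, per flops_sensitivity.
    pruned_program, pruned_val_program = pruner.greedy_prune(
        pruned_program, pruned_val_program, params, 0.03, topk=1)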