How should one understand the difference between greedy_prune and the ordinary prune in sensitivity-based pruning?
Created by: zhizunbao-y
1. As asked in the title.
2. In the code below, what is the purpose of the part marked off with #, and why is the FLOPs difference multiplied by 2? (A numeric sketch of these lines follows the code.)
def flops_sensitivity(program,
                      place,
                      param_names,
                      eval_func,
                      sensitivities_file=None,
                      pruned_flops_rate=0.1):
    assert (1.0 / len(param_names) > pruned_flops_rate)

    scope = fluid.global_scope()
    graph = GraphWrapper(program)
    sensitivities = load_sensitivities(sensitivities_file)

    for name in param_names:
        if name not in sensitivities:
            sensitivities[name] = {}
    base_flops = flops(program)
    target_pruned_flops = base_flops * pruned_flops_rate

    pruner = Pruner()
    baseline = None
    for name in sensitivities:
        # Graph-only dry run: prune this parameter at ratio 0.5 to measure
        # how many FLOPs it accounts for.
        pruned_program, _, _ = pruner.prune(
            program=graph.program,
            scope=None,
            params=[name],
            ratios=[0.5],
            place=None,
            lazy=False,
            only_graph=True)
        ################################################
        param_flops = (base_flops - flops(pruned_program)) * 2
        channel_size = graph.var(name).shape()[0]
        pruned_ratio = target_pruned_flops / float(param_flops)
        ################################################
        pruned_ratio = round(pruned_ratio, 3)
        pruned_size = round(pruned_ratio * channel_size)
        pruned_ratio = 1 if pruned_size >= channel_size else pruned_ratio

        if len(sensitivities[name].keys()) > 0:
            _logger.debug(
                '{} exist; pruned ratio: {}; excepted ratio: {}'.format(
                    name, sensitivities[name].keys(), pruned_ratio))
            continue

        if baseline is None:
            baseline = eval_func(graph.program)

        param_backup = {}
        pruner = Pruner()
        _logger.info("sensitive - param: {}; ratios: {}".format(name,
                                                                pruned_ratio))
        loss = 1
        if pruned_ratio < 1:
            # Real (lazy) prune with parameter backup, then measure the
            # relative accuracy loss. prune() returns a tuple, as in the
            # graph-only call above, so unpack it the same way.
            pruned_program, _, _ = pruner.prune(
                program=graph.program,
                scope=scope,
                params=[name],
                ratios=[pruned_ratio],
                place=place,
                lazy=True,
                only_graph=False,
                param_backup=param_backup)
            pruned_metric = eval_func(pruned_program)
            loss = (baseline - pruned_metric) / baseline
        _logger.info("pruned param: {}; {}; loss={}".format(name, pruned_ratio,
                                                            loss))
        sensitivities[name][pruned_ratio] = loss
        _save_sensitivities(sensitivities, sensitivities_file)

        # restore pruned parameters
        for param_name in param_backup.keys():
            param_t = scope.find_var(param_name).get_tensor()
            param_t.set(param_backup[param_name], place)

    return sensitivities
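For item 2, here is a minimal numeric sketch of what the #-marked lines compute, with made-up numbers. Reading the *2 as compensation for the 0.5 dry-run ratio (pruning half of the parameter's channels removes roughly half of the FLOPs attributable to it) is my assumption; the source does not state it:

# Hypothetical numbers, for illustration only.
base_flops = 1000.0      # FLOPs of the unpruned program
flops_at_half = 900.0    # FLOPs after the graph-only prune at ratio 0.5

# Assumption: removing half the channels removed roughly half of this
# parameter's FLOPs, so doubling the observed drop estimates its full share.
param_flops = (base_flops - flops_at_half) * 2    # 200.0

target_pruned_flops = base_flops * 0.1            # pruned_flops_rate = 0.1
pruned_ratio = target_pruned_flops / param_flops  # 0.5

Under this reading, pruned_ratio is the ratio at which this single parameter would account for the entire FLOPs budget target_pruned_flops; the lines that follow then clip it against the parameter's channel count.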
3. In the code below, min_loss and max_loss are both initialized to 0, so the while condition is false on the first check and the loop body never executes. (A minimal repro follows the code.)
def get_ratios_by_sensitive(self, sensitivities, pruned_flops,
                            eval_program):
    """
    Search a group of ratios for pruning target flops.
    Args:
        sensitivities(dict): The sensitivities used to generate a group of pruning ratios. The key of dict
                             is name of parameters to be pruned. The value of dict is a list of tuple with
                             format `(pruned_ratio, accuracy_loss)`.
        pruned_flops(float): The percent of FLOPS to be pruned.
        eval_program(Program): The program whose FLOPS is considered.
    Returns:
        dict: A group of ratios. The key of dict is name of parameters while the value is the ratio to be pruned.
    """
    min_loss = 0.
    max_loss = 0.
    # step 2: Find a group of ratios by binary searching.
    base_flops = flops(eval_program)
    ratios = None
    max_times = 20
    while min_loss < max_loss and max_times > 0:
        loss = (max_loss + min_loss) / 2
        _logger.info(
            '-----------Try pruned ratios while acc loss={}-----------'.
            format(loss))
        ratios = self.get_ratios_by_loss(sensitivities, loss)
        _logger.info('Pruned ratios={}'.format(
            [round(ratio, 3) for ratio in ratios.values()]))
        pruned_program = self._pruner.prune(
            eval_program,
            None,  # scope
            ratios.keys(),
            ratios.values(),
            None,  # place
            only_graph=True)
        pruned_ratio = 1 - (float(flops(pruned_program)) / base_flops)
        _logger.info('Pruned flops: {:.4f}'.format(pruned_ratio))
        # Check whether current ratios is enough
        if abs(pruned_ratio - pruned_flops) < 0.015:
            break
        if pruned_ratio > pruned_flops:
            max_loss = loss
        else:
            min_loss = loss
        max_times -= 1
    return ratios
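For item 3, the dead loop can be reproduced in isolation. A stripped-down stand-in for the method (the search body replaced by a comment) shows the guard failing on the first check, so the caller gets back None:

min_loss = 0.
max_loss = 0.
ratios = None
max_times = 20
while min_loss < max_loss and max_times > 0:  # 0.0 < 0.0 is False
    # binary-search body would run here; with both bounds at 0 it is never reached
    max_times -= 1
print(ratios)  # prints: None

Unless max_loss is set above min_loss before the loop starts (for example, derived from the largest accuracy loss recorded in sensitivities), the binary search never runs and get_ratios_by_sensitive returns None.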