未验证 提交 6464b878 编写于 作者: Y yukavio 提交者: GitHub

make sensitive analysis more flexible (#449)

上级 7f3d0460
......@@ -26,8 +26,8 @@ from ..prune import Pruner
_logger = get_logger(__name__, level=logging.INFO)
__all__ = [
"sensitivity", "flops_sensitivity", "load_sensitivities",
"merge_sensitive", "get_ratios_by_loss"
"sensitivity", "flops_sensitivity", "load_sensitivities", "merge_sensitive",
"get_ratios_by_loss"
]
......@@ -36,7 +36,9 @@ def sensitivity(program,
param_names,
eval_func,
sensitivities_file=None,
pruned_ratios=None):
pruned_ratios=None,
eval_args=None,
criterion='l1_norm'):
"""Compute the sensitivities of convolutions in a model. The sensitivity of a convolution is the losses of accuracy on test dataset in differenct pruned ratios. The sensitivities can be used to get a group of best ratios with some condition.
This function return a dict storing sensitivities as below:
......@@ -83,9 +85,12 @@ def sensitivity(program,
_logger.debug('{}, {} has computed.'.format(name, ratio))
continue
if baseline is None:
baseline = eval_func(graph.program)
if eval_args is None:
baseline = eval_func(graph.program)
else:
baseline = eval_func(eval_args)
pruner = Pruner()
pruner = Pruner(criterion=criterion)
_logger.info("sensitive - param: {}; ratios: {}".format(name,
ratio))
pruned_program, param_backup, _ = pruner.prune(
......@@ -97,7 +102,10 @@ def sensitivity(program,
lazy=True,
only_graph=False,
param_backup=True)
pruned_metric = eval_func(pruned_program)
if eval_args is None:
pruned_metric = eval_func(pruned_program)
else:
pruned_metric = eval_func(eval_args)
loss = (baseline - pruned_metric) / baseline
_logger.info("pruned param: {}; {}; loss={}".format(name, ratio,
loss))
......
......@@ -60,20 +60,42 @@ class TestSensitivity(unittest.TestCase):
print("acc_val_mean: {}".format(acc_val_mean))
return acc_val_mean
def eval_func_for_args(args):
program = args[0]
feeder = fluid.DataFeeder(
feed_list=['image', 'label'], place=place, program=program)
acc_set = []
for data in val_reader():
acc_np = exe.run(program=program,
feed=feeder.feed(data),
fetch_list=[acc_top1])
acc_set.append(float(acc_np[0]))
acc_val_mean = numpy.array(acc_set).mean()
print("acc_val_mean: {}".format(acc_val_mean))
return acc_val_mean
sensitivity(
eval_program,
place, ["conv4_weights"],
eval_func,
"./sensitivities_file_0",
sensitivities_file="./sensitivities_file_0",
pruned_ratios=[0.1, 0.2])
sensitivity(
eval_program,
place, ["conv4_weights"],
eval_func,
"./sensitivities_file_1",
sensitivities_file="./sensitivities_file_1",
pruned_ratios=[0.3, 0.4])
params_sens = sensitivity(
eval_program,
place, ["conv4_weights"],
eval_func_for_args,
eval_args=[eval_program],
sensitivities_file="./sensitivites_file_params",
pruned_ratios=[0.1, 0.2, 0.3, 0.4])
sens_0 = load_sensitivities('./sensitivities_file_0')
sens_1 = load_sensitivities('./sensitivities_file_1')
sens = merge_sensitive([sens_0, sens_1])
......@@ -81,8 +103,9 @@ class TestSensitivity(unittest.TestCase):
eval_program,
place, ["conv4_weights"],
eval_func,
"./sensitivities_file_1",
sensitivities_file="./sensitivities_file_2",
pruned_ratios=[0.1, 0.2, 0.3, 0.4])
self.assertTrue(params_sens == origin_sens)
self.assertTrue(sens == origin_sens)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册