diff --git a/paddleslim/prune/sensitive.py b/paddleslim/prune/sensitive.py
index 608bc83fa42574d060d07d1d758d535c3027850e..a5a6e3601e4a493db17de83200e48bf04109164a 100644
--- a/paddleslim/prune/sensitive.py
+++ b/paddleslim/prune/sensitive.py
@@ -26,8 +26,8 @@ from ..prune import Pruner
 _logger = get_logger(__name__, level=logging.INFO)
 
 __all__ = [
-    "sensitivity", "flops_sensitivity", "load_sensitivities",
-    "merge_sensitive", "get_ratios_by_loss"
+    "sensitivity", "flops_sensitivity", "load_sensitivities", "merge_sensitive",
+    "get_ratios_by_loss"
 ]
 
 
@@ -36,7 +36,9 @@ def sensitivity(program,
                 param_names,
                 eval_func,
                 sensitivities_file=None,
-                pruned_ratios=None):
+                pruned_ratios=None,
+                eval_args=None,
+                criterion='l1_norm'):
     """Compute the sensitivities of convolutions in a model. The sensitivity of a convolution is the losses of accuracy on test dataset in differenct pruned ratios. The sensitivities can be used to get a group of best ratios with some condition.
 
     This function return a dict storing sensitivities as below:
@@ -83,9 +85,12 @@ def sensitivity(program,
                 _logger.debug('{}, {} has computed.'.format(name, ratio))
                 continue
             if baseline is None:
-                baseline = eval_func(graph.program)
+                if eval_args is None:
+                    baseline = eval_func(graph.program)
+                else:
+                    baseline = eval_func(eval_args)
 
-            pruner = Pruner()
+            pruner = Pruner(criterion=criterion)
             _logger.info("sensitive - param: {}; ratios: {}".format(name,
                                                                     ratio))
             pruned_program, param_backup, _ = pruner.prune(
@@ -97,7 +102,10 @@
                 lazy=True,
                 only_graph=False,
                 param_backup=True)
-            pruned_metric = eval_func(pruned_program)
+            if eval_args is None:
+                pruned_metric = eval_func(pruned_program)
+            else:
+                pruned_metric = eval_func(eval_args)
             loss = (baseline - pruned_metric) / baseline
             _logger.info("pruned param: {}; {}; loss={}".format(name, ratio,
                                                                 loss))
diff --git a/tests/test_sensitivity.py b/tests/test_sensitivity.py
index 2fad0a4273518130f4b45f16c6aa96274d0873de..948a57e1095d0190e1b9f79782c257fbc88fae57 100644
--- a/tests/test_sensitivity.py
+++ b/tests/test_sensitivity.py
@@ -60,20 +60,42 @@ class TestSensitivity(unittest.TestCase):
             print("acc_val_mean: {}".format(acc_val_mean))
             return acc_val_mean
 
+        def eval_func_for_args(args):
+            program = args[0]
+            feeder = fluid.DataFeeder(
+                feed_list=['image', 'label'], place=place, program=program)
+            acc_set = []
+            for data in val_reader():
+                acc_np = exe.run(program=program,
+                                 feed=feeder.feed(data),
+                                 fetch_list=[acc_top1])
+                acc_set.append(float(acc_np[0]))
+            acc_val_mean = numpy.array(acc_set).mean()
+            print("acc_val_mean: {}".format(acc_val_mean))
+            return acc_val_mean
+
         sensitivity(
             eval_program,
             place, ["conv4_weights"],
             eval_func,
-            "./sensitivities_file_0",
+            sensitivities_file="./sensitivities_file_0",
             pruned_ratios=[0.1, 0.2])
 
         sensitivity(
             eval_program,
             place, ["conv4_weights"],
             eval_func,
-            "./sensitivities_file_1",
+            sensitivities_file="./sensitivities_file_1",
             pruned_ratios=[0.3, 0.4])
 
+        params_sens = sensitivity(
+            eval_program,
+            place, ["conv4_weights"],
+            eval_func_for_args,
+            eval_args=[eval_program],
+            sensitivities_file="./sensitivites_file_params",
+            pruned_ratios=[0.1, 0.2, 0.3, 0.4])
+
         sens_0 = load_sensitivities('./sensitivities_file_0')
         sens_1 = load_sensitivities('./sensitivities_file_1')
         sens = merge_sensitive([sens_0, sens_1])
@@ -81,8 +103,9 @@ class TestSensitivity(unittest.TestCase):
             eval_program,
             place, ["conv4_weights"],
             eval_func,
-            "./sensitivities_file_1",
+            sensitivities_file="./sensitivities_file_2",
             pruned_ratios=[0.1, 0.2, 0.3, 0.4])
+        self.assertTrue(params_sens == origin_sens)
         self.assertTrue(sens == origin_sens)
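Usage note: with this change, `sensitivity` can route evaluation through a caller-supplied argument list (`eval_args`) and forward a ranking criterion to `Pruner`. The sketch below is a minimal illustration of the new keywords, not part of the patch; it assumes a typical PaddleSlim 1.x setup in which `eval_program`, `place`, `exe`, `val_reader`, and `acc_top1` already exist (as in `tests/test_sensitivity.py`), and the cache path `./sens.data` is made up.

```python
import numpy
import paddle.fluid as fluid
from paddleslim.prune import sensitivity

# Callback used when `eval_args` is given: sensitivity() then calls
# eval_func(eval_args) for the baseline and after each pruning step,
# instead of eval_func(program).
def eval_func_for_args(args):
    program = args[0]  # unpack whatever the caller put into eval_args
    feeder = fluid.DataFeeder(
        feed_list=['image', 'label'], place=place, program=program)
    acc_set = []
    for data in val_reader():
        acc_np = exe.run(program=program,
                         feed=feeder.feed(data),
                         fetch_list=[acc_top1])
        acc_set.append(float(acc_np[0]))
    return numpy.array(acc_set).mean()

sens = sensitivity(
    eval_program,
    place,
    ["conv4_weights"],
    eval_func_for_args,
    eval_args=[eval_program],          # forwarded verbatim to the callback
    sensitivities_file="./sens.data",  # hypothetical cache path
    pruned_ratios=[0.1, 0.2, 0.3, 0.4],
    criterion='l1_norm')               # default; passed to Pruner(criterion=...)
```

Passing the unpruned `eval_program` through `eval_args` still measures the pruned network because the function prunes with `lazy=True` and `only_graph=False` (see the context lines above): parameter values are zeroed in the executor's scope and later restored from `param_backup`, while the program topology that `eval_program` describes is unchanged. That is why the new test can assert `params_sens == origin_sens`.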