diff --git a/deploy/slim/prune/sensitivity_anal.py b/deploy/slim/prune/sensitivity_anal.py index f80ddd9fb424a4b0c8f30304530d7d20f982598a..0f0492af2f57eea9b9c1d13ec5ee1dad9fc2f1bc 100644 --- a/deploy/slim/prune/sensitivity_anal.py +++ b/deploy/slim/prune/sensitivity_anal.py @@ -110,25 +110,42 @@ def main(config, device, logger, vdl_writer): logger.info("metric['hmean']: {}".format(metric['hmean'])) return metric['hmean'] - params_sensitive = pruner.sensitive( - eval_func=eval_fn, - sen_file="./sen.pickle", - skip_vars=[ - "conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0" - ]) - - logger.info( - "The sensitivity analysis results of model parameters saved in sen.pickle" - ) - # calculate pruned params's ratio - params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02) - for key in params_sensitive.keys(): - logger.info("{}, {}".format(key, params_sensitive[key])) - - #params_sensitive = {} - #for param in model.parameters(): - # if 'transpose' not in param.name and 'linear' not in param.name: - # params_sensitive[param.name] = 0.1 + run_sensitive_analysis = False + """ + run_sensitive_analysis=True: + Automatically compute the sensitivities of convolutions in a model. + The sensitivity of a convolution is the losses of accuracy on test dataset in + differenct pruned ratios. The sensitivities can be used to get a group of best + ratios with some condition. + + run_sensitive_analysis=False: + Set prune trim ratio to a fixed value, such as 10%. The larger the value, + the more convolution weights will be cropped. + + """ + + if run_sensitive_analysis: + params_sensitive = pruner.sensitive( + eval_func=eval_fn, + sen_file="./deploy/slim/prune/sen.pickle", + skip_vars=[ + "conv2d_57.w_0", "conv2d_transpose_2.w_0", + "conv2d_transpose_3.w_0" + ]) + logger.info( + "The sensitivity analysis results of model parameters saved in sen.pickle" + ) + # calculate pruned params's ratio + params_sensitive = pruner._get_ratios_by_loss( + params_sensitive, loss=0.02) + for key in params_sensitive.keys(): + logger.info("{}, {}".format(key, params_sensitive[key])) + else: + params_sensitive = {} + for param in model.parameters(): + if 'transpose' not in param.name and 'linear' not in param.name: + # set prune ratio as 10%. The larger the value, the more convolution weights will be cropped + params_sensitive[param.name] = 0.1 plan = pruner.prune_vars(params_sensitive, [0]) diff --git a/tools/program.py b/tools/program.py index f484cf4a1f512107bd755a6b446935751563bfcb..2015a0fa74cff8f9b18dc59a594f2051c272432a 100755 --- a/tools/program.py +++ b/tools/program.py @@ -351,7 +351,7 @@ def eval(model, valid_dataloader, post_process_class, eval_class, - model_type, + model_type=None, use_srn=False, use_sar=False): model.eval()