提交 9dba4a12 编写于 作者: L LDOUBLEV

optimize the prune

上级 e2ed89fa
...@@ -110,25 +110,42 @@ def main(config, device, logger, vdl_writer): ...@@ -110,25 +110,42 @@ def main(config, device, logger, vdl_writer):
logger.info("metric['hmean']: {}".format(metric['hmean'])) logger.info("metric['hmean']: {}".format(metric['hmean']))
return metric['hmean'] return metric['hmean']
params_sensitive = pruner.sensitive( run_sensitive_analysis = False
eval_func=eval_fn, """
sen_file="./sen.pickle", run_sensitive_analysis=True:
skip_vars=[ Automatically compute the sensitivities of convolutions in a model.
"conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0" The sensitivity of a convolution is the losses of accuracy on test dataset in
]) differenct pruned ratios. The sensitivities can be used to get a group of best
ratios with some condition.
logger.info(
"The sensitivity analysis results of model parameters saved in sen.pickle" run_sensitive_analysis=False:
) Set prune trim ratio to a fixed value, such as 10%. The larger the value,
# calculate pruned params's ratio the more convolution weights will be cropped.
params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
for key in params_sensitive.keys(): """
logger.info("{}, {}".format(key, params_sensitive[key]))
if run_sensitive_analysis:
#params_sensitive = {} params_sensitive = pruner.sensitive(
#for param in model.parameters(): eval_func=eval_fn,
# if 'transpose' not in param.name and 'linear' not in param.name: sen_file="./deploy/slim/prune/sen.pickle",
# params_sensitive[param.name] = 0.1 skip_vars=[
"conv2d_57.w_0", "conv2d_transpose_2.w_0",
"conv2d_transpose_3.w_0"
])
logger.info(
"The sensitivity analysis results of model parameters saved in sen.pickle"
)
# calculate pruned params's ratio
params_sensitive = pruner._get_ratios_by_loss(
params_sensitive, loss=0.02)
for key in params_sensitive.keys():
logger.info("{}, {}".format(key, params_sensitive[key]))
else:
params_sensitive = {}
for param in model.parameters():
if 'transpose' not in param.name and 'linear' not in param.name:
# set prune ratio as 10%. The larger the value, the more convolution weights will be cropped
params_sensitive[param.name] = 0.1
plan = pruner.prune_vars(params_sensitive, [0]) plan = pruner.prune_vars(params_sensitive, [0])
......
...@@ -351,7 +351,7 @@ def eval(model, ...@@ -351,7 +351,7 @@ def eval(model,
valid_dataloader, valid_dataloader,
post_process_class, post_process_class,
eval_class, eval_class,
model_type, model_type=None,
use_srn=False, use_srn=False,
use_sar=False): use_sar=False):
model.eval() model.eval()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册