Commit b66b3a1a authored by jiangjiajun

modify cal sensitivities

Parent f001960b
@@ -66,16 +66,15 @@ def sensitivity(program,
             progress = "%.2f%%" % (progress * 100)
             logging.info(
                 "Total evaluate iters={}, current={}, progress={}, eta={}".
-                format(
-                    total_evaluate_iters, current_iter, progress,
-                    seconds_to_hms(
-                        int(cost * (total_evaluate_iters - current_iter)))),
+                format(total_evaluate_iters, current_iter, progress,
+                       seconds_to_hms(
+                           int(cost * (total_evaluate_iters - current_iter)))),
                 use_color=True)
             current_iter += 1
 
             pruner = Pruner()
-            logging.info("sensitive - param: {}; ratios: {}".format(
-                name, ratio))
+            logging.info("sensitive - param: {}; ratios: {}".format(name,
+                                                                    ratio))
             pruned_program, param_backup, _ = pruner.prune(
                 program=graph.program,
                 scope=scope,
@@ -87,8 +86,8 @@ def sensitivity(program,
                 param_backup=True)
             pruned_metric = eval_func(pruned_program)
             loss = (baseline - pruned_metric) / baseline
-            logging.info("pruned param: {}; {}; loss={}".format(
-                name, ratio, loss))
+            logging.info("pruned param: {}; {}; loss={}".format(name, ratio,
+                                                                loss))
             sensitivities[name][ratio] = loss
@@ -221,6 +220,9 @@ def cal_params_sensitivities(model, save_file, eval_dataset, batch_size=8):
         where ``weight_0`` is the name of a convolution kernel; ``sensitivities['weight_0']`` is a dict whose keys are pruning ratios and whose values are sensitivities.
     """
+    if os.path.exists(save_file):
+        os.remove(save_file)
     prune_names = get_prune_params(model)
 
     def eval_for_prune(program):
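
For context (not part of the commit), a minimal sketch of the structure the docstring above describes, with made-up numbers: each entry maps a prunable parameter to a dict keyed by pruning ratio, whose value is the relative metric loss computed in sensitivity() as (baseline - pruned_metric) / baseline.

# Illustrative values only; "weight_0" follows the docstring's example name.
sensitivities = {
    "weight_0": {    # convolution kernel name
        0.1: 0.002,  # pruning 10% of channels loses ~0.2% of the baseline metric
        0.2: 0.010,
        0.3: 0.045,
        0.4: 0.120,
    },
}
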
@@ -264,8 +266,8 @@ def get_params_ratios(sensitivities_file, eval_metric_loss=0.05):
     if not osp.exists(sensitivities_file):
         raise Exception('The sensitivities file is not exists!')
     sensitivitives = paddleslim.prune.load_sensitivities(sensitivities_file)
-    params_ratios = paddleslim.prune.get_ratios_by_loss(
-        sensitivitives, eval_metric_loss)
+    params_ratios = paddleslim.prune.get_ratios_by_loss(sensitivitives,
                                                        eval_metric_loss)
     return params_ratios
......
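
The get_params_ratios hunk above delegates ratio selection to paddleslim.prune.get_ratios_by_loss. As a rough sketch of that idea (a simplified stand-in, not the PaddleSlim implementation, which may interpolate between measured ratios), one can pick, for each parameter, the largest measured ratio whose loss stays within eval_metric_loss:

def pick_ratios_by_loss(sensitivities, eval_metric_loss=0.05):
    # Simplified stand-in for paddleslim.prune.get_ratios_by_loss (assumption).
    ratios = {}
    for name, ratio_to_loss in sensitivities.items():
        feasible = [ratio for ratio, loss in ratio_to_loss.items()
                    if loss <= eval_metric_loss]
        if feasible:
            # Prune as aggressively as the loss budget allows.
            ratios[name] = max(feasible)
    return ratios

# With the illustrative dict above and the default eval_metric_loss=0.05,
# "weight_0" gets a ratio of 0.3 (loss 0.045 <= 0.05; 0.4 would lose 0.12).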