提交 b66b3a1a 编写于 作者: J jiangjiajun

modify cal sensitivities

上级 f001960b
...@@ -66,16 +66,15 @@ def sensitivity(program, ...@@ -66,16 +66,15 @@ def sensitivity(program,
progress = "%.2f%%" % (progress * 100) progress = "%.2f%%" % (progress * 100)
logging.info( logging.info(
"Total evaluate iters={}, current={}, progress={}, eta={}". "Total evaluate iters={}, current={}, progress={}, eta={}".
format( format(total_evaluate_iters, current_iter, progress,
total_evaluate_iters, current_iter, progress, seconds_to_hms(
seconds_to_hms( int(cost * (total_evaluate_iters - current_iter)))),
int(cost * (total_evaluate_iters - current_iter)))),
use_color=True) use_color=True)
current_iter += 1 current_iter += 1
pruner = Pruner() pruner = Pruner()
logging.info("sensitive - param: {}; ratios: {}".format( logging.info("sensitive - param: {}; ratios: {}".format(name,
name, ratio)) ratio))
pruned_program, param_backup, _ = pruner.prune( pruned_program, param_backup, _ = pruner.prune(
program=graph.program, program=graph.program,
scope=scope, scope=scope,
...@@ -87,8 +86,8 @@ def sensitivity(program, ...@@ -87,8 +86,8 @@ def sensitivity(program,
param_backup=True) param_backup=True)
pruned_metric = eval_func(pruned_program) pruned_metric = eval_func(pruned_program)
loss = (baseline - pruned_metric) / baseline loss = (baseline - pruned_metric) / baseline
logging.info("pruned param: {}; {}; loss={}".format( logging.info("pruned param: {}; {}; loss={}".format(name, ratio,
name, ratio, loss)) loss))
sensitivities[name][ratio] = loss sensitivities[name][ratio] = loss
...@@ -221,6 +220,9 @@ def cal_params_sensitivities(model, save_file, eval_dataset, batch_size=8): ...@@ -221,6 +220,9 @@ def cal_params_sensitivities(model, save_file, eval_dataset, batch_size=8):
其中``weight_0``是卷积Kernel名;``sensitivities['weight_0']``是一个字典,key是裁剪率,value是敏感度。 其中``weight_0``是卷积Kernel名;``sensitivities['weight_0']``是一个字典,key是裁剪率,value是敏感度。
""" """
if os.path.exists(save_file):
os.remove(save_file)
prune_names = get_prune_params(model) prune_names = get_prune_params(model)
def eval_for_prune(program): def eval_for_prune(program):
...@@ -264,8 +266,8 @@ def get_params_ratios(sensitivities_file, eval_metric_loss=0.05): ...@@ -264,8 +266,8 @@ def get_params_ratios(sensitivities_file, eval_metric_loss=0.05):
if not osp.exists(sensitivities_file): if not osp.exists(sensitivities_file):
raise Exception('The sensitivities file is not exists!') raise Exception('The sensitivities file is not exists!')
sensitivitives = paddleslim.prune.load_sensitivities(sensitivities_file) sensitivitives = paddleslim.prune.load_sensitivities(sensitivities_file)
params_ratios = paddleslim.prune.get_ratios_by_loss( params_ratios = paddleslim.prune.get_ratios_by_loss(sensitivitives,
sensitivitives, eval_metric_loss) eval_metric_loss)
return params_ratios return params_ratios
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册