提交 570fdf6c 编写于 作者: S sunyanfang01

fix the prune

上级 640b80f6
......@@ -116,6 +116,21 @@ def channel_prune(program, prune_names, prune_ratios, place, only_graph=False):
Returns:
paddle.fluid.Program: 裁剪后的Program。
"""
prog_var_shape_dict = {}
for var in program.list_vars():
try:
prog_var_shape_dict[var.name] = var.shape
except Exception:
pass
index = 0
for param, ratio in zip(prune_names, prune_ratios):
origin_num = prog_var_shape_dict[param][0]
pruned_num = int(round(origin_num * ratio))
while origin_num == pruned_num:
ratio -= 0.1
pruned_num = int(round(origin_num * (ratio)))
prune_ratios[index] = ratio
index += 1
scope = fluid.global_scope()
pruner = Pruner()
program, _, _ = pruner.prune(
......@@ -266,6 +281,7 @@ def get_params_ratios(sensitivities_file, eval_metric_loss=0.05):
sensitivitives = paddleslim.prune.load_sensitivities(sensitivities_file)
params_ratios = paddleslim.prune.get_ratios_by_loss(
sensitivitives, eval_metric_loss)
return params_ratios
......@@ -284,6 +300,19 @@ def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05):
"""
prune_params_ratios = get_params_ratios(sensitivities_file,
eval_metric_loss)
prog_var_shape_dict = {}
for var in program.list_vars():
try:
prog_var_shape_dict[var.name] = var.shape
except Exception:
pass
for param, ratio in prune_params_ratios.items():
origin_num = prog_var_shape_dict[param][0]
pruned_num = int(round(origin_num * ratio))
while origin_num == pruned_num:
ratio -= 0.1
pruned_num = int(round(origin_num * (ratio)))
prune_params_ratios[param] = ratio
prune_program = channel_prune(
program,
list(prune_params_ratios.keys()),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册