未验证 提交 0623a5df 编写于 作者: L lijianshe02 提交者: GitHub

fix optimal threshold pruning bugs in only_graph mode test=develop (#192)

上级 30636a29
......@@ -8,6 +8,7 @@ import math
import time
import numpy as np
import paddle.fluid as fluid
sys.path.append("../../")
from paddleslim.prune import load_model
from paddleslim.common import get_logger
from paddleslim.analysis import flops
......
......@@ -79,7 +79,11 @@ class Pruner():
if only_graph:
param_v = graph.var(param)
pruned_num = int(round(param_v.shape()[0] * ratio))
pruned_idx = [0] * pruned_num
if self.criterion == "optimal_threshold":
pruned_idx = self._cal_pruned_idx(
graph, scope, param, ratio, axis=0)
else:
pruned_idx = [0] * pruned_num
else:
pruned_idx = self._cal_pruned_idx(
graph, scope, param, ratio, axis=0)
......@@ -202,7 +206,7 @@ class Pruner():
) / 2 if i > 0 else 0
return th
optimal_th = get_optimal_threshold(bn_scale_np, 0.12)
optimal_th = get_optimal_threshold(bn_scale_np, ratio)
pruned_idx = np.squeeze(
np.argwhere(bn_scale_np < optimal_th))
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册