未验证 提交 0623a5df 编写于 作者: L lijianshe02 提交者: GitHub

fix optimal threshold pruning bugs in only_graph mode test=develop (#192)

上级 30636a29
...@@ -8,6 +8,7 @@ import math ...@@ -8,6 +8,7 @@ import math
import time import time
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
sys.path.append("../../")
from paddleslim.prune import load_model from paddleslim.prune import load_model
from paddleslim.common import get_logger from paddleslim.common import get_logger
from paddleslim.analysis import flops from paddleslim.analysis import flops
......
...@@ -79,7 +79,11 @@ class Pruner(): ...@@ -79,7 +79,11 @@ class Pruner():
if only_graph: if only_graph:
param_v = graph.var(param) param_v = graph.var(param)
pruned_num = int(round(param_v.shape()[0] * ratio)) pruned_num = int(round(param_v.shape()[0] * ratio))
pruned_idx = [0] * pruned_num if self.criterion == "optimal_threshold":
pruned_idx = self._cal_pruned_idx(
graph, scope, param, ratio, axis=0)
else:
pruned_idx = [0] * pruned_num
else: else:
pruned_idx = self._cal_pruned_idx( pruned_idx = self._cal_pruned_idx(
graph, scope, param, ratio, axis=0) graph, scope, param, ratio, axis=0)
...@@ -202,7 +206,7 @@ class Pruner(): ...@@ -202,7 +206,7 @@ class Pruner():
) / 2 if i > 0 else 0 ) / 2 if i > 0 else 0
return th return th
optimal_th = get_optimal_threshold(bn_scale_np, 0.12) optimal_th = get_optimal_threshold(bn_scale_np, ratio)
pruned_idx = np.squeeze( pruned_idx = np.squeeze(
np.argwhere(bn_scale_np < optimal_th)) np.argwhere(bn_scale_np < optimal_th))
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册