未验证 提交 b0bb64ab 编写于 作者: Z zhouzj 提交者: GitHub

Fix the bug when pruning conv2d with small filter number. (#1424)

上级 dd544b87
......@@ -42,6 +42,7 @@ class FPGMFilterPruner(FilterPruner):
sorted_idx = scores.argsort()
pruned_num = int(round(len(sorted_idx) * pruned_ratio))
pruned_num = min(len(sorted_idx) - 1, pruned_num)
pruned_idx = sorted_idx[:pruned_num]
mask_shape = [value.shape[pruned_axis]]
mask = np.ones(mask_shape, dtype="int32")
......
......@@ -38,6 +38,7 @@ class L1NormFilterPruner(FilterPruner):
sorted_idx = l1norm.argsort()
pruned_num = int(round(len(sorted_idx) * pruned_ratio))
pruned_num = min(len(sorted_idx) - 1, pruned_num)
pruned_idx = sorted_idx[:pruned_num]
mask_shape = [value.shape[pruned_axis]]
......
......@@ -38,6 +38,7 @@ class L2NormFilterPruner(FilterPruner):
sorted_idx = scores.argsort()
pruned_num = int(round(len(sorted_idx) * pruned_ratio))
pruned_num = min(len(sorted_idx) - 1, pruned_num)
pruned_idx = sorted_idx[:pruned_num]
mask_shape = [value.shape[pruned_axis]]
......
......@@ -63,6 +63,7 @@ def default_idx_selector(group, scores, ratios):
sorted_idx = score.argsort()
ratio = ratios[name]
pruned_num = int(round(len(sorted_idx) * ratio))
pruned_num = min(len(sorted_idx) - 1, pruned_num)
pruned_idx = sorted_idx[:pruned_num]
# convert indices of channel groups to indices of channels.
if max_groups > 1:
......
......@@ -52,7 +52,7 @@ class TestSensitivity(unittest.TestCase):
for _ratio, _loss in _value.items():
if not np.allclose(_losses[_ratio], _loss, atol=1e-2):
print(
f'static loss: {static_sen[_name][_ratio]}; dygraph loss: {_loss}'
f'ratio: {_ratio}; static loss: {_losses[_ratio]}; dygraph loss: {_loss}'
)
all_right = False
self.assertTrue(all_right)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册