From 73fb5929043a1b4abdd632575ddc387a18381259 Mon Sep 17 00:00:00 2001 From: lishuang Date: Mon, 26 Apr 2021 09:02:41 +0800 Subject: [PATCH] fix the get_input_mask function judge the layer name rather than idx --- utils/prune_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/prune_utils.py b/utils/prune_utils.py index 653a85c..07c51b8 100644 --- a/utils/prune_utils.py +++ b/utils/prune_utils.py @@ -182,7 +182,7 @@ def get_input_mask(module_defs, idx, CBLidx2mask): if idx == 0: return np.ones(3) - if idx == 1: + if module_defs[idx - 1]['type'] == 'focus': return np.ones(12) if module_defs[idx - 1]['type'] == 'convolutional': return CBLidx2mask[idx - 1] -- GitLab