diff --git a/utils/prune_utils.py b/utils/prune_utils.py index 653a85c3557706bb69d2ba7698d5e82fb84b8939..07c51b88b4592e1133dac75fb908d74d886bc7e0 100644 --- a/utils/prune_utils.py +++ b/utils/prune_utils.py @@ -182,7 +182,7 @@ def get_input_mask(module_defs, idx, CBLidx2mask): if idx == 0: return np.ones(3) - if idx == 1: + if module_defs[idx - 1]['type'] == 'focus': return np.ones(12) if module_defs[idx - 1]['type'] == 'convolutional': return CBLidx2mask[idx - 1]