diff --git a/paddlex/cv/models/slim/prune_config.py b/paddlex/cv/models/slim/prune_config.py index d5e6325e805f6dda7987c1e0e909950e43aa5218..0ecdb1e8252233ef4b75b845deee27cb5be6c5af 100644 --- a/paddlex/cv/models/slim/prune_config.py +++ b/paddlex/cv/models/slim/prune_config.py @@ -243,6 +243,32 @@ def get_prune_params(model): for i in params_not_prune: if i in prune_names: prune_names.remove(i) + + elif model_type.startswith('HRNet'): + for param in program.global_block().all_parameters(): + if 'weight' not in param.name: + continue + prune_names.append(param.name) + params_not_prune = [ + 'conv-1_weights' + ] + for i in params_not_prune: + if i in prune_names: + prune_names.remove(i) + + elif model_type.startswith('FastSCNN'): + for param in program.global_block().all_parameters(): + if 'weight' not in param.name: + continue + if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name: + continue + prune_names.append(param.name) + params_not_prune = [ + 'classifier/weights' + ] + for i in params_not_prune: + if i in prune_names: + prune_names.remove(i) elif model_type.startswith('DeepLabv3p'): for param in program.global_block().all_parameters():