diff --git a/paddlex/cv/models/slim/prune_config.py b/paddlex/cv/models/slim/prune_config.py index d85867e9cb3b921715c2a22aa7900f9c23a6491a..e9019f21e874e3960fd19b133c89dc607ac1c147 100644 --- a/paddlex/cv/models/slim/prune_config.py +++ b/paddlex/cv/models/slim/prune_config.py @@ -285,11 +285,35 @@ def get_prune_params(model): prune_names.remove(i) elif model_type.startswith('DeepLabv3p'): + if model_type.lower() == "deeplabv3p_mobilenetv3_large_x1_0_ssld": + params_not_prune = [ + 'last_1x1_conv_weights', 'conv14_se_2_weights', + 'conv16_depthwise_weights', 'conv13_depthwise_weights', + 'conv15_se_2_weights', 'conv2_depthwise_weights', + 'conv6_depthwise_weights', 'conv8_depthwise_weights', + 'fc_weights', 'conv3_depthwise_weights', 'conv7_se_2_weights', + 'conv16_expand_weights', 'conv16_se_2_weights', + 'conv10_depthwise_weights', 'conv11_depthwise_weights', + 'conv15_expand_weights', 'conv5_expand_weights', + 'conv15_depthwise_weights', 'conv14_depthwise_weights', + 'conv12_se_2_weights', 'conv1_weights', + 'conv13_expand_weights', 'conv_last_weights', + 'conv12_depthwise_weights', 'conv13_se_2_weights', + 'conv12_expand_weights', 'conv5_depthwise_weights', + 'conv6_se_2_weights', 'conv10_expand_weights', + 'conv9_depthwise_weights', 'conv6_expand_weights', + 'conv5_se_2_weights', 'conv14_expand_weights', + 'conv4_depthwise_weights', 'conv7_expand_weights', + 'conv7_depthwise_weights' + ] for param in program.global_block().all_parameters(): if 'weight' not in param.name: continue if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name: continue + if model_type.lower() == "deeplabv3p_mobilenetv3_large_x1_0_ssld": + if param.name in params_not_prune: + continue prune_names.append(param.name) params_not_prune = [ 'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'.