diff --git a/paddlex/cv/models/slim/prune_config.py b/paddlex/cv/models/slim/prune_config.py index 65baefa96a42c86c4a1b6ad5797e58b2c2ea5420..49430e9bfb1dcc47fb93aa9fc7d05ceb21e2b9e8 100644 --- a/paddlex/cv/models/slim/prune_config.py +++ b/paddlex/cv/models/slim/prune_config.py @@ -142,13 +142,16 @@ def get_prune_params(model): program = model.test_prog if model_type.startswith('ResNet') or \ model_type.startswith('DenseNet') or \ - model_type.startswith('DarkNet'): + model_type.startswith('DarkNet') or \ + model_type.startswith('AlexNet'): for block in program.blocks: for param in block.all_parameters(): pd_var = fluid.global_scope().find_var(param.name) pd_param = pd_var.get_tensor() if len(np.array(pd_param).shape) == 4: prune_names.append(param.name) + if model_type == 'AlexNet': + prune_names.remove('conv5_weights') elif model_type == "MobileNetV1": prune_names.append("conv1_weights") for param in program.global_block().all_parameters():