From f84a56f57918fd6fe1ca03b0fdca8ccee76d21d2 Mon Sep 17 00:00:00 2001 From: FlyingQianMM <245467267@qq.com> Date: Fri, 28 Aug 2020 14:32:58 +0800 Subject: [PATCH] add sensitive files for hrnet and fastscnn --- paddlex/cv/models/slim/prune_config.py | 34 ++++++++++++++++++-------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/paddlex/cv/models/slim/prune_config.py b/paddlex/cv/models/slim/prune_config.py index 0ecdb1e..d85867e 100644 --- a/paddlex/cv/models/slim/prune_config.py +++ b/paddlex/cv/models/slim/prune_config.py @@ -91,7 +91,23 @@ sensitivities_data = { 'DeepLabv3p_Xception65_aspp_decoder': 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception65_with_aspp_decoder.sensitivities', 'DeepLabv3p_Xception41_aspp_decoder': - 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception41_with_aspp_decoder.sensitivities' + 'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception41_with_aspp_decoder.sensitivities', + 'HRNet_W18_Seg': + 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w18.sensitivities', + 'HRNet_W30_Seg': + 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w30.sensitivities', + 'HRNet_W32_Seg': + 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w32.sensitivities', + 'HRNet_W40_Seg': + 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w40.sensitivities', + 'HRNet_W44_Seg': + 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w44.sensitivities', + 'HRNet_W48_Seg': + 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w48.sensitivities', + 'HRNet_W64_Seg': + 'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w64.sensitivities', + 'FastSCNN': + 'https://bj.bcebos.com/paddlex/slim_prune/fast_scnn.sensitivities' } @@ -105,6 +121,8 @@ def get_sensitivities(flag, model, save_dir): elif hasattr(model, 'encoder_with_aspp') or hasattr(model, 'enable_decoder'): model_type = model_type + '_' + 'aspp' + '_' + 'decoder' + if model_type.startswith('HRNet') and model.model_type == 'segmenter': + model_type = '{}_W{}_Seg'.format(model_type, model.width) if osp.isfile(flag): return flag elif flag == 'DEFAULT': @@ -243,19 +261,17 @@ def get_prune_params(model): for i in params_not_prune: if i in prune_names: prune_names.remove(i) - - elif model_type.startswith('HRNet'): + + elif model_type.startswith('HRNet') and model.model_type == 'segmenter': for param in program.global_block().all_parameters(): if 'weight' not in param.name: continue prune_names.append(param.name) - params_not_prune = [ - 'conv-1_weights' - ] + params_not_prune = ['conv-1_weights'] for i in params_not_prune: if i in prune_names: prune_names.remove(i) - + elif model_type.startswith('FastSCNN'): for param in program.global_block().all_parameters(): if 'weight' not in param.name: @@ -263,9 +279,7 @@ def get_prune_params(model): if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name: continue prune_names.append(param.name) - params_not_prune = [ - 'classifier/weights' - ] + params_not_prune = ['classifier/weights'] for i in params_not_prune: if i in prune_names: prune_names.remove(i) -- GitLab