Unverified commit d4b71967, authored by Jason, committed by GitHub

Merge pull request #293 from FlyingQianMM/develop_test

Add sensitivities files for HRNet and FastSCNN
@@ -91,7 +91,23 @@ sensitivities_data = {
     'DeepLabv3p_Xception65_aspp_decoder':
     'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception65_with_aspp_decoder.sensitivities',
     'DeepLabv3p_Xception41_aspp_decoder':
-    'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception41_with_aspp_decoder.sensitivities'
+    'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception41_with_aspp_decoder.sensitivities',
+    'HRNet_W18_Seg':
+    'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w18.sensitivities',
+    'HRNet_W30_Seg':
+    'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w30.sensitivities',
+    'HRNet_W32_Seg':
+    'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w32.sensitivities',
+    'HRNet_W40_Seg':
+    'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w40.sensitivities',
+    'HRNet_W44_Seg':
+    'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w44.sensitivities',
+    'HRNet_W48_Seg':
+    'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w48.sensitivities',
+    'HRNet_W64_Seg':
+    'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w64.sensitivities',
+    'FastSCNN':
+    'https://bj.bcebos.com/paddlex/slim_prune/fast_scnn.sensitivities'
 }
@@ -105,6 +121,8 @@ def get_sensitivities(flag, model, save_dir):
     elif hasattr(model, 'encoder_with_aspp') or hasattr(model,
                                                         'enable_decoder'):
         model_type = model_type + '_' + 'aspp' + '_' + 'decoder'
+    if model_type.startswith('HRNet') and model.model_type == 'segmenter':
+        model_type = '{}_W{}_Seg'.format(model_type, model.width)
     if osp.isfile(flag):
         return flag
     elif flag == 'DEFAULT':
@@ -243,19 +261,17 @@ def get_prune_params(model):
         for i in params_not_prune:
             if i in prune_names:
                 prune_names.remove(i)
-    elif model_type.startswith('HRNet'):
+    elif model_type.startswith('HRNet') and model.model_type == 'segmenter':
         for param in program.global_block().all_parameters():
             if 'weight' not in param.name:
                 continue
             prune_names.append(param.name)
-        params_not_prune = [
-            'conv-1_weights'
-        ]
+        params_not_prune = ['conv-1_weights']
         for i in params_not_prune:
             if i in prune_names:
                 prune_names.remove(i)
     elif model_type.startswith('FastSCNN'):
         for param in program.global_block().all_parameters():
             if 'weight' not in param.name:
@@ -263,9 +279,7 @@ def get_prune_params(model):
             if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
                 continue
             prune_names.append(param.name)
-        params_not_prune = [
-            'classifier/weights'
-        ]
+        params_not_prune = ['classifier/weights']
         for i in params_not_prune:
             if i in prune_names:
                 prune_names.remove(i)
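The first two hunks register download URLs for the new sensitivities files and build the lookup key for HRNet segmentation models from the model's width. Below is a minimal standalone sketch (not part of the PR) of that key construction and lookup; FakeHRNet is a hypothetical stand-in with the model_type and width attributes that get_sensitivities() checks.

    # Sketch only: mirrors the key construction added in get_sensitivities().
    sensitivities_data = {
        'HRNet_W18_Seg':
        'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w18.sensitivities',
        'FastSCNN':
        'https://bj.bcebos.com/paddlex/slim_prune/fast_scnn.sensitivities',
    }

    class FakeHRNet:
        # Hypothetical stand-in for a PaddleX HRNet segmentation model.
        model_type = 'segmenter'
        width = 18

    model = FakeHRNet()
    model_type = 'HRNet'
    if model_type.startswith('HRNet') and model.model_type == 'segmenter':
        model_type = '{}_W{}_Seg'.format(model_type, model.width)

    print(model_type)                      # HRNet_W18_Seg
    print(sensitivities_data[model_type])  # .../hrnet_w18.sensitivities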
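The last two hunks collapse the params_not_prune lists onto one line and restrict the HRNet branch to segmentation models. The filtering pattern itself is unchanged: collect every parameter whose name contains 'weight', skip depthwise and logit convolutions (FastSCNN branch), then remove the names listed in params_not_prune. A minimal sketch of that pattern follows; FakeParam and the parameter names are hypothetical stand-ins for the parameters of the Paddle program.

    # Sketch only: the prune-name filtering pattern from get_prune_params().
    class FakeParam:
        def __init__(self, name):
            self.name = name

    all_parameters = [FakeParam(n) for n in [
        'conv1_weights', 'conv2_weights', 'depthwise_conv_weights',
        'logit_weights', 'classifier/weights', 'fc_bias'
    ]]

    prune_names = []
    for param in all_parameters:
        if 'weight' not in param.name:
            continue
        if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
            continue
        prune_names.append(param.name)

    # Names that must never be pruned are removed afterwards.
    params_not_prune = ['classifier/weights']
    for i in params_not_prune:
        if i in prune_names:
            prune_names.remove(i)

    print(prune_names)  # ['conv1_weights', 'conv2_weights']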