未验证 提交 d4b71967 编写于 作者: J Jason 提交者: GitHub

Merge pull request #293 from FlyingQianMM/develop_test

add sensitive files for hrnet and fastscnn
......@@ -91,7 +91,23 @@ sensitivities_data = {
'DeepLabv3p_Xception65_aspp_decoder':
'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception65_with_aspp_decoder.sensitivities',
'DeepLabv3p_Xception41_aspp_decoder':
'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception41_with_aspp_decoder.sensitivities'
'https://bj.bcebos.com/paddlex/slim_prune/deeplab_xception41_with_aspp_decoder.sensitivities',
'HRNet_W18_Seg':
'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w18.sensitivities',
'HRNet_W30_Seg':
'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w30.sensitivities',
'HRNet_W32_Seg':
'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w32.sensitivities',
'HRNet_W40_Seg':
'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w40.sensitivities',
'HRNet_W44_Seg':
'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w44.sensitivities',
'HRNet_W48_Seg':
'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w48.sensitivities',
'HRNet_W64_Seg':
'https://bj.bcebos.com/paddlex/slim_prune/hrnet_w64.sensitivities',
'FastSCNN':
'https://bj.bcebos.com/paddlex/slim_prune/fast_scnn.sensitivities'
}
......@@ -105,6 +121,8 @@ def get_sensitivities(flag, model, save_dir):
elif hasattr(model, 'encoder_with_aspp') or hasattr(model,
'enable_decoder'):
model_type = model_type + '_' + 'aspp' + '_' + 'decoder'
if model_type.startswith('HRNet') and model.model_type == 'segmenter':
model_type = '{}_W{}_Seg'.format(model_type, model.width)
if osp.isfile(flag):
return flag
elif flag == 'DEFAULT':
......@@ -243,19 +261,17 @@ def get_prune_params(model):
for i in params_not_prune:
if i in prune_names:
prune_names.remove(i)
elif model_type.startswith('HRNet'):
elif model_type.startswith('HRNet') and model.model_type == 'segmenter':
for param in program.global_block().all_parameters():
if 'weight' not in param.name:
continue
prune_names.append(param.name)
params_not_prune = [
'conv-1_weights'
]
params_not_prune = ['conv-1_weights']
for i in params_not_prune:
if i in prune_names:
prune_names.remove(i)
elif model_type.startswith('FastSCNN'):
for param in program.global_block().all_parameters():
if 'weight' not in param.name:
......@@ -263,9 +279,7 @@ def get_prune_params(model):
if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
continue
prune_names.append(param.name)
params_not_prune = [
'classifier/weights'
]
params_not_prune = ['classifier/weights']
for i in params_not_prune:
if i in prune_names:
prune_names.remove(i)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册