未验证 提交 9dcbe669 编写于 作者: J Jason 提交者: GitHub

Merge pull request #154 from SunAhong1993/syf_prune

add prune configs and prompt
......@@ -108,6 +108,7 @@ def load_model(model_dir, fixed_input_shape=None):
logging.info("Model[{}] loaded.".format(info['Model']))
model.trainable = False
model.status = status
return model
......
......@@ -158,6 +158,7 @@ def prune_program(model, prune_params_ratios=None):
prune_params_ratios (dict): 由裁剪参数名和裁剪率组成的字典,当为None时
使用默认裁剪参数名和裁剪率。默认为None。
"""
assert model.status == 'Normal', 'Only the models saved while training are supported!'
place = model.places[0]
train_prog = model.train_prog
eval_prog = model.test_prog
......@@ -235,6 +236,7 @@ def cal_params_sensitivities(model, save_file, eval_dataset, batch_size=8):
其中``weight_0``是卷积Kernel名;``sensitivities['weight_0']``是一个字典,key是裁剪率,value是敏感度。
"""
assert model.status == 'Normal', 'Only the models saved while training are supported!'
if os.path.exists(save_file):
os.remove(save_file)
......
......@@ -19,6 +19,8 @@ import paddle.fluid as fluid
import paddlex
sensitivities_data = {
'AlexNet':
'https://bj.bcebos.com/paddlex/slim_prune/alexnet_sensitivities.data',
'ResNet18':
'https://bj.bcebos.com/paddlex/slim_prune/resnet18.sensitivities',
'ResNet34':
......@@ -41,6 +43,10 @@ sensitivities_data = {
'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_large.sensitivities',
'MobileNetV3_small':
'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_small.sensitivities',
'MobileNetV3_large_ssld':
'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_large_ssld_sensitivities.data',
'MobileNetV3_small_ssld':
'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_small_ssld_sensitivities.data',
'DenseNet121':
'https://bj.bcebos.com/paddlex/slim_prune/densenet121.sensitivities',
'DenseNet161':
......@@ -51,6 +57,8 @@ sensitivities_data = {
'https://bj.bcebos.com/paddlex/slim_prune/xception41.sensitivities',
'Xception65':
'https://bj.bcebos.com/paddlex/slim_prune/xception65.sensitivities',
'ShuffleNetV2':
'https://bj.bcebos.com/paddlex/slim_prune/shufflenetv2_sensitivities.data',
'YOLOv3_MobileNetV1':
'https://bj.bcebos.com/paddlex/slim_prune/yolov3_mobilenetv1.sensitivities',
'YOLOv3_MobileNetV3_large':
......@@ -143,7 +151,8 @@ def get_prune_params(model):
if model_type.startswith('ResNet') or \
model_type.startswith('DenseNet') or \
model_type.startswith('DarkNet') or \
model_type.startswith('AlexNet'):
model_type.startswith('AlexNet') or \
model_type.startswith('ShuffleNetV2'):
for block in program.blocks:
for param in block.all_parameters():
pd_var = fluid.global_scope().find_var(param.name)
......@@ -152,6 +161,28 @@ def get_prune_params(model):
prune_names.append(param.name)
if model_type == 'AlexNet':
prune_names.remove('conv5_weights')
if model_type == 'ShuffleNetV2':
not_prune_names = ['stage_2_1_conv5_weights',
'stage_2_1_conv3_weights',
'stage_2_2_conv3_weights',
'stage_2_3_conv3_weights',
'stage_2_4_conv3_weights',
'stage_3_1_conv5_weights',
'stage_3_1_conv3_weights',
'stage_3_2_conv3_weights',
'stage_3_3_conv3_weights',
'stage_3_4_conv3_weights',
'stage_3_5_conv3_weights',
'stage_3_6_conv3_weights',
'stage_3_7_conv3_weights',
'stage_3_8_conv3_weights',
'stage_4_1_conv5_weights',
'stage_4_1_conv3_weights',
'stage_4_2_conv3_weights',
'stage_4_3_conv3_weights',
'stage_4_4_conv3_weights',]
for name in not_prune_names:
prune_names.remove(name)
elif model_type == "MobileNetV1":
prune_names.append("conv1_weights")
for param in program.global_block().all_parameters():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册