You need to sign in or sign up before continuing.
未验证 提交 9dcbe669 编写于 作者: J Jason 提交者: GitHub

Merge pull request #154 from SunAhong1993/syf_prune

add prune configs and prompt
...@@ -108,6 +108,7 @@ def load_model(model_dir, fixed_input_shape=None): ...@@ -108,6 +108,7 @@ def load_model(model_dir, fixed_input_shape=None):
logging.info("Model[{}] loaded.".format(info['Model'])) logging.info("Model[{}] loaded.".format(info['Model']))
model.trainable = False model.trainable = False
model.status = status
return model return model
......
...@@ -158,6 +158,7 @@ def prune_program(model, prune_params_ratios=None): ...@@ -158,6 +158,7 @@ def prune_program(model, prune_params_ratios=None):
prune_params_ratios (dict): 由裁剪参数名和裁剪率组成的字典,当为None时 prune_params_ratios (dict): 由裁剪参数名和裁剪率组成的字典,当为None时
使用默认裁剪参数名和裁剪率。默认为None。 使用默认裁剪参数名和裁剪率。默认为None。
""" """
assert model.status == 'Normal', 'Only the models saved while training are supported!'
place = model.places[0] place = model.places[0]
train_prog = model.train_prog train_prog = model.train_prog
eval_prog = model.test_prog eval_prog = model.test_prog
...@@ -235,6 +236,7 @@ def cal_params_sensitivities(model, save_file, eval_dataset, batch_size=8): ...@@ -235,6 +236,7 @@ def cal_params_sensitivities(model, save_file, eval_dataset, batch_size=8):
其中``weight_0``是卷积Kernel名;``sensitivities['weight_0']``是一个字典,key是裁剪率,value是敏感度。 其中``weight_0``是卷积Kernel名;``sensitivities['weight_0']``是一个字典,key是裁剪率,value是敏感度。
""" """
assert model.status == 'Normal', 'Only the models saved while training are supported!'
if os.path.exists(save_file): if os.path.exists(save_file):
os.remove(save_file) os.remove(save_file)
......
...@@ -19,6 +19,8 @@ import paddle.fluid as fluid ...@@ -19,6 +19,8 @@ import paddle.fluid as fluid
import paddlex import paddlex
sensitivities_data = { sensitivities_data = {
'AlexNet':
'https://bj.bcebos.com/paddlex/slim_prune/alexnet_sensitivities.data',
'ResNet18': 'ResNet18':
'https://bj.bcebos.com/paddlex/slim_prune/resnet18.sensitivities', 'https://bj.bcebos.com/paddlex/slim_prune/resnet18.sensitivities',
'ResNet34': 'ResNet34':
...@@ -41,6 +43,10 @@ sensitivities_data = { ...@@ -41,6 +43,10 @@ sensitivities_data = {
'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_large.sensitivities', 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_large.sensitivities',
'MobileNetV3_small': 'MobileNetV3_small':
'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_small.sensitivities', 'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_small.sensitivities',
'MobileNetV3_large_ssld':
'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_large_ssld_sensitivities.data',
'MobileNetV3_small_ssld':
'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_small_ssld_sensitivities.data',
'DenseNet121': 'DenseNet121':
'https://bj.bcebos.com/paddlex/slim_prune/densenet121.sensitivities', 'https://bj.bcebos.com/paddlex/slim_prune/densenet121.sensitivities',
'DenseNet161': 'DenseNet161':
...@@ -51,6 +57,8 @@ sensitivities_data = { ...@@ -51,6 +57,8 @@ sensitivities_data = {
'https://bj.bcebos.com/paddlex/slim_prune/xception41.sensitivities', 'https://bj.bcebos.com/paddlex/slim_prune/xception41.sensitivities',
'Xception65': 'Xception65':
'https://bj.bcebos.com/paddlex/slim_prune/xception65.sensitivities', 'https://bj.bcebos.com/paddlex/slim_prune/xception65.sensitivities',
'ShuffleNetV2':
'https://bj.bcebos.com/paddlex/slim_prune/shufflenetv2_sensitivities.data',
'YOLOv3_MobileNetV1': 'YOLOv3_MobileNetV1':
'https://bj.bcebos.com/paddlex/slim_prune/yolov3_mobilenetv1.sensitivities', 'https://bj.bcebos.com/paddlex/slim_prune/yolov3_mobilenetv1.sensitivities',
'YOLOv3_MobileNetV3_large': 'YOLOv3_MobileNetV3_large':
...@@ -143,7 +151,8 @@ def get_prune_params(model): ...@@ -143,7 +151,8 @@ def get_prune_params(model):
if model_type.startswith('ResNet') or \ if model_type.startswith('ResNet') or \
model_type.startswith('DenseNet') or \ model_type.startswith('DenseNet') or \
model_type.startswith('DarkNet') or \ model_type.startswith('DarkNet') or \
model_type.startswith('AlexNet'): model_type.startswith('AlexNet') or \
model_type.startswith('ShuffleNetV2'):
for block in program.blocks: for block in program.blocks:
for param in block.all_parameters(): for param in block.all_parameters():
pd_var = fluid.global_scope().find_var(param.name) pd_var = fluid.global_scope().find_var(param.name)
...@@ -152,6 +161,28 @@ def get_prune_params(model): ...@@ -152,6 +161,28 @@ def get_prune_params(model):
prune_names.append(param.name) prune_names.append(param.name)
if model_type == 'AlexNet': if model_type == 'AlexNet':
prune_names.remove('conv5_weights') prune_names.remove('conv5_weights')
if model_type == 'ShuffleNetV2':
not_prune_names = ['stage_2_1_conv5_weights',
'stage_2_1_conv3_weights',
'stage_2_2_conv3_weights',
'stage_2_3_conv3_weights',
'stage_2_4_conv3_weights',
'stage_3_1_conv5_weights',
'stage_3_1_conv3_weights',
'stage_3_2_conv3_weights',
'stage_3_3_conv3_weights',
'stage_3_4_conv3_weights',
'stage_3_5_conv3_weights',
'stage_3_6_conv3_weights',
'stage_3_7_conv3_weights',
'stage_3_8_conv3_weights',
'stage_4_1_conv5_weights',
'stage_4_1_conv3_weights',
'stage_4_2_conv3_weights',
'stage_4_3_conv3_weights',
'stage_4_4_conv3_weights',]
for name in not_prune_names:
prune_names.remove(name)
elif model_type == "MobileNetV1": elif model_type == "MobileNetV1":
prune_names.append("conv1_weights") prune_names.append("conv1_weights")
for param in program.global_block().all_parameters(): for param in program.global_block().all_parameters():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册