提交 b58cfe8b 编写于 作者: S sunyanfang01

add alexnet prune

上级 f13d9367
...@@ -142,13 +142,16 @@ def get_prune_params(model): ...@@ -142,13 +142,16 @@ def get_prune_params(model):
program = model.test_prog program = model.test_prog
if model_type.startswith('ResNet') or \ if model_type.startswith('ResNet') or \
model_type.startswith('DenseNet') or \ model_type.startswith('DenseNet') or \
model_type.startswith('DarkNet'): model_type.startswith('DarkNet') or \
model_type.startswith('AlexNet'):
for block in program.blocks: for block in program.blocks:
for param in block.all_parameters(): for param in block.all_parameters():
pd_var = fluid.global_scope().find_var(param.name) pd_var = fluid.global_scope().find_var(param.name)
pd_param = pd_var.get_tensor() pd_param = pd_var.get_tensor()
if len(np.array(pd_param).shape) == 4: if len(np.array(pd_param).shape) == 4:
prune_names.append(param.name) prune_names.append(param.name)
if model_type == 'AlexNet':
prune_names.remove('conv5_weights')
elif model_type == "MobileNetV1": elif model_type == "MobileNetV1":
prune_names.append("conv1_weights") prune_names.append("conv1_weights")
for param in program.global_block().all_parameters(): for param in program.global_block().all_parameters():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册