提交 b58cfe8b 编写于 作者: S sunyanfang01

add alexnet prune

上级 f13d9367
......@@ -142,13 +142,16 @@ def get_prune_params(model):
program = model.test_prog
if model_type.startswith('ResNet') or \
model_type.startswith('DenseNet') or \
model_type.startswith('DarkNet'):
model_type.startswith('DarkNet') or \
model_type.startswith('AlexNet'):
for block in program.blocks:
for param in block.all_parameters():
pd_var = fluid.global_scope().find_var(param.name)
pd_param = pd_var.get_tensor()
if len(np.array(pd_param).shape) == 4:
prune_names.append(param.name)
if model_type == 'AlexNet':
prune_names.remove('conv5_weights')
elif model_type == "MobileNetV1":
prune_names.append("conv1_weights")
for param in program.global_block().all_parameters():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册