# slim_fpgm.py (665 bytes) -- committed by LDOUBLEV
import paddleslim
import paddle
import numpy as np

from paddleslim.dygraph import FPGMFilterPruner


def prune_model(model, input_shape, prune_ratio=0.1):
    """Prune ``model`` with PaddleSlim's ``FPGMFilterPruner`` and return it.

    Every parameter whose name contains neither ``'transpose'`` nor
    ``'linear'`` is scheduled for pruning with the same ``prune_ratio``;
    the schedule is then applied along axis 0 through
    ``FPGMFilterPruner.prune_vars``.

    Args:
        model: The network to prune. It must expose ``parameters()``, and
            each parameter must have a ``name`` attribute.
        input_shape: Input shape forwarded unchanged to ``paddle.flops``
            and to the ``FPGMFilterPruner`` constructor.
        prune_ratio: Ratio assigned to every selected parameter. The
            larger the value, the more convolution weights are cropped.
            Defaults to ``0.1``.

    Returns:
        The same ``model`` object that was passed in (presumably modified
        in place by the pruner -- confirm against the PaddleSlim docs).
    """
    # FLOPs before pruning. The return value is not needed here; the call
    # is kept, presumably for the report paddle.flops emits.
    paddle.flops(model, input_shape)
    pruner = FPGMFilterPruner(model, input_shape)

    # Use one ratio for every parameter except those whose name marks them
    # as belonging to a transpose or linear layer.
    params_sensitive = {
        param.name: prune_ratio
        for param in model.parameters()
        if 'transpose' not in param.name and 'linear' not in param.name
    }

    # Prune along axis 0 (presumably the output-filter axis of conv weights).
    pruner.prune_vars(params_sensitive, [0])

    # FLOPs after pruning, for comparison with the figure reported above.
    paddle.flops(model, input_shape)
    return model