import os
import shutil

import paddle
from paddleslim.analysis import TableLatencyPredictor

from .fake_ptq import post_quant_fake
from .prune_model import get_sparse_model, get_prune_model


def predict_compressed_model(model_dir,
                             model_filename,
                             params_filename,
                             hardware='SD710'):
    """
    Evaluate the latency of the model under various compression strategies.
    Args:
        model_dir(str): The path of the inference model that will be
            compressed; the model and params files saved by
            ``paddle.static.io.save_inference_model`` are under this path.
        model_filename(str): The name of the model file.
        params_filename(str): The name of the params file. When all parameters
            are saved in a single file, set it to that filename; if parameters
            are saved in separate files, set it to None.
        hardware(str): Target device. Default: 'SD710'.
    Returns:
        latency_dict(dict): The predicted latency of the model under each
            compression strategy.
    """
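    # Use per-rank temporary directories so that distributed workers do not
    # overwrite each other's exported models.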
    local_rank = paddle.distributed.get_rank()
    quant_model_path = f'quant_model/rank_{local_rank}'
    prune_model_path = f'prune_model/rank_{local_rank}'
    sparse_model_path = f'sparse_model/rank_{local_rank}'

    latency_dict = {}

    model_file = os.path.join(model_dir, model_filename)
    param_file = os.path.join(model_dir, params_filename)

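    # Build the table-based latency predictor for the target hardware and
    # record the baseline FP32 latency of the original model.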
    predictor = TableLatencyPredictor(hardware)
    latency = predictor.predict(
        model_file=model_file, param_file=param_file, data_type='fp32')
    latency_dict.update({'origin_fp32': latency})
    paddle.enable_static()
    place = paddle.CPUPlace()
    exe = paddle.static.Executor(place)
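    # "Fake" post-training quantization: rewrites the program with int8 quant
    # ops only so that its latency can be looked up; the resulting weights are
    # not calibrated for accuracy.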
    post_quant_fake(
        exe,
        model_dir=model_dir,
        model_filename=model_filename,
        params_filename=params_filename,
        save_model_path=quant_model_path,
        quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
        is_full_quantize=False,
        activation_bits=8,
        weight_bits=8)
    quant_model_file = os.path.join(quant_model_path, model_filename)
    quant_param_file = os.path.join(quant_model_path, params_filename)

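    # Predicted int8 latency of the quantized (but otherwise uncompressed) model.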
    latency = predictor.predict(
        model_file=quant_model_file,
        param_file=quant_param_file,
        data_type='int8')
    latency_dict.update({'origin_int8': latency})

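    # Sweep channel-pruning ratios: for each ratio, predict the FP32 latency
    # of the pruned model, then quantize it and predict the int8 latency.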
    for prune_ratio in [0.3, 0.4, 0.5, 0.6]:
        get_prune_model(
            model_file=model_file,
            param_file=param_file,
            ratio=prune_ratio,
            save_path=prune_model_path)
        prune_model_file = os.path.join(prune_model_path, model_filename)
        prune_param_file = os.path.join(prune_model_path, params_filename)

        latency = predictor.predict(
            model_file=prune_model_file,
            param_file=prune_param_file,
            data_type='fp32')
        latency_dict.update({f'prune_{prune_ratio}_fp32': latency})

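        # Quantize the pruned model and predict its int8 latency.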
        post_quant_fake(
            exe,
            model_dir=prune_model_path,
            model_filename=model_filename,
            params_filename=params_filename,
            save_model_path=quant_model_path,
            quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
            is_full_quantize=False,
            activation_bits=8,
            weight_bits=8)
        quant_model_file = os.path.join(quant_model_path, model_filename)
        quant_param_file = os.path.join(quant_model_path, params_filename)

        latency = predictor.predict(
            model_file=quant_model_file,
            param_file=quant_param_file,
            data_type='int8')
        latency_dict.update({f'prune_{prune_ratio}_int8': latency})

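    # Sweep unstructured-sparsity ratios, analogous to the pruning sweep above.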
    for sparse_ratio in [0.70, 0.75, 0.80, 0.85, 0.90, 0.95]:
        get_sparse_model(
            model_file=model_file,
            param_file=param_file,
            ratio=sparse_ratio,
            save_path=sparse_model_path)
        sparse_model_file = os.path.join(sparse_model_path, model_filename)
        sparse_param_file = os.path.join(sparse_model_path, params_filename)

        latency = predictor.predict(
            model_file=sparse_model_file,
            param_file=sparse_param_file,
            data_type='fp32')
        latency_dict.update({f'sparse_{sparse_ratio}_fp32': latency})

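        # Quantize the sparse model and predict its int8 latency.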
        post_quant_fake(
            exe,
            model_dir=sparse_model_path,
            model_filename=model_filename,
            params_filename=params_filename,
            save_model_path=quant_model_path,
            quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
            is_full_quantize=False,
            activation_bits=8,
            weight_bits=8)
        quant_model_file = os.path.join(quant_model_path, model_filename)
        quant_param_file = os.path.join(quant_model_path, params_filename)

        latency = predictor.predict(
            model_file=quant_model_file,
            param_file=quant_param_file,
            data_type='int8')
        latency_dict.update({f'sparse_{sparse_ratio}_int8': latency})

    # NOTE: Delete the temporary model directories created above.
    for temp_dir in ['quant_model', 'prune_model', 'sparse_model']:
        shutil.rmtree(temp_dir, ignore_errors=True)
    return latency_dict
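

# A minimal usage sketch (hedged): the directory and file names below are
# hypothetical; any static-graph inference model exported by
# ``paddle.static.io.save_inference_model`` should work.
#
#     latency_dict = predict_compressed_model(
#         model_dir='./infer_model',
#         model_filename='inference.pdmodel',
#         params_filename='inference.pdiparams',
#         hardware='SD710')
#     print(latency_dict['origin_fp32'], latency_dict['prune_0.3_int8'])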