# predict.py
import os
import shutil

import paddle
from paddleslim.analysis import TableLatencyPredictor

from .fake_ptq import post_quant_fake
from .prune_model import get_sparse_model, get_prune_model


def predict_compressed_model(model_file, param_file, hardware='SD710'):
    """
    Evaluating the latency of the model under various compression strategies.
    Args:
        model_file(str), param_file(str): The inference model to be compressed.
        hardware(str): Target device.
    Returns:
        latency_dict(dict): The latency latency of the model under various compression strategies.
    """
    latency_dict = {}

    model_filename = os.path.basename(model_file)
    param_filename = os.path.basename(param_file)

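    # TableLatencyPredictor estimates latency from an op-level latency table
    # pre-measured on the target hardware.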
    predictor = TableLatencyPredictor(hardware)
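    # Baseline: latency of the original FP32 model.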
    latency = predictor.predict(
        model_file=model_file, param_file=param_file, data_type='fp32')
    latency_dict.update({'origin_fp32': latency})
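    # post_quant_fake operates on the static-graph program, so switch to
    # static mode and create a CPU executor for it.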
    paddle.enable_static()
    place = paddle.CPUPlace()
    exe = paddle.static.Executor(place)
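    # Fake post-training quantization: insert quant/dequant ops without real
    # calibration data; sufficient for latency prediction, not for accuracy.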
    post_quant_fake(
        exe,
        model_dir=os.path.dirname(model_file),
        model_filename=model_filename,
        params_filename=param_filename,
        save_model_path='quant_model',
        quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
        is_full_quantize=False,
        activation_bits=8,
        weight_bits=8)
    quant_model_file = os.path.join('quant_model', model_filename)
    quant_param_file = os.path.join('quant_model', param_filename)

    latency = predictor.predict(
        model_file=quant_model_file,
        param_file=quant_param_file,
        data_type='int8')
    latency_dict.update({'origin_int8': latency})

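    # Structured pruning sweep: prune the original model at each ratio, then
    # measure the pruned model in FP32 and again after fake int8 quantization.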
    for prune_ratio in [0.3, 0.4, 0.5, 0.6]:
        get_prune_model(
            model_file=model_file,
            param_file=param_file,
            ratio=prune_ratio,
            save_path='prune_model')
        prune_model_file = os.path.join('prune_model', model_filename)
        prune_param_file = os.path.join('prune_model', param_filename)

        latency = predictor.predict(
            model_file=prune_model_file,
            param_file=prune_param_file,
            data_type='fp32')
        latency_dict.update({f'prune_{prune_ratio}_fp32': latency})

        post_quant_fake(
            exe,
            model_dir='prune_model',
            model_filename=model_filename,
            params_filename=param_filename,
            save_model_path='quant_model',
            quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
            is_full_quantize=False,
            activation_bits=8,
            weight_bits=8)
        quant_model_file = os.path.join('quant_model', model_filename)
        quant_param_file = os.path.join('quant_model', param_filename)

        latency = predictor.predict(
            model_file=quant_model_file,
            param_file=quant_param_file,
            data_type='int8')
        latency_dict.update({f'prune_{prune_ratio}_int8': latency})

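    # Unstructured sparsity sweep: sparsify the original model at each ratio,
    # then measure FP32 and fake-int8 latency as above.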
    for sparse_ratio in [0.70, 0.75, 0.80, 0.85, 0.90, 0.95]:
        get_sparse_model(
            model_file=model_file,
            param_file=param_file,
            ratio=sparse_ratio,
            save_path='sparse_model')
        sparse_model_file = os.path.join('sparse_model', model_filename)
        sparse_param_file = os.path.join('sparse_model', param_filename)

        latency = predictor.predict(
            model_file=sparse_model_file,
            param_file=sparse_param_file,
            data_type='fp32')
        latency_dict.update({f'sparse_{sparse_ratio}_fp32': latency})

        post_quant_fake(
            exe,
            model_dir='sparse_model',
            model_filename=model_filename,
            params_filename=param_filename,
            save_model_path='quant_model',
            quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
            is_full_quantize=False,
            activation_bits=8,
            weight_bits=8)
        quant_model_file = os.path.join('quant_model', model_filename)
        quant_param_file = os.path.join('quant_model', param_filename)

        latency = predictor.predict(
            model_file=quant_model_file,
            param_file=quant_param_file,
            data_type='int8')
        latency_dict.update({f'sparse_{sparse_ratio}_int8': latency})

    # Delete the temporary model directories.
    shutil.rmtree('quant_model', ignore_errors=True)
    shutil.rmtree('prune_model', ignore_errors=True)
    shutil.rmtree('sparse_model', ignore_errors=True)
    return latency_dict
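

# A minimal usage sketch. The paths below are hypothetical, and this module
# uses relative imports, so import the function through its parent package
# (assuming it is exported there) rather than running this file directly:
#
#   from paddleslim.auto_compression import predict_compressed_model
#
#   latency_dict = predict_compressed_model(
#       model_file='inference_model/model.pdmodel',
#       param_file='inference_model/model.pdiparams',
#       hardware='SD710')
#   # -> {'origin_fp32': ..., 'origin_int8': ..., 'prune_0.3_fp32': ..., ...}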