# quant_post.py -- post-training static quantization demo built on
# paddleslim.quant.quant_post_static.
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import random
import numpy as np
import paddle

# Point the first import-search entry at this script's grandparent directory
# (two levels up from the file itself, not from the current working
# directory), presumably where the shared demo helpers imported below
# (``utility``, ``imagenet_reader``) live -- NOTE(review): confirm layout.
# ``__file__`` must be the module attribute, not the string "__file__":
# os.path.dirname("__file__") is "" and would make the path cwd-relative.
sys.path[0] = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), os.path.pardir, os.path.pardir)
from paddleslim.common import get_logger
from paddleslim.quant import quant_post_static
from utility import add_arguments, print_arguments
import imagenet_reader as reader
# Module-level logger at INFO level.
_logger = get_logger(__name__, level=logging.INFO)

# Command-line interface. ``add_arg`` is ``utility.add_arguments`` with the
# parser pre-bound; each call registers one flag as
# (name, type, default, help text). ``quantize()`` reads every field; the
# algo / round_type / hist_percent / is_full_quantize / bias_correction /
# onnx_format values are forwarded unchanged to ``quant_post_static``.
# NOTE(review): how ``bool`` flags are converted from command-line strings is
# decided inside ``utility.add_arguments`` -- confirm it does not use plain
# ``bool()``, which would turn any non-empty string (even "False") into True.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',       int,   32,                                     "Minibatch size.")
add_arg('batch_num',        int,   10,                                     "Batch number")
add_arg('use_gpu',          bool,  True,                                   "Whether to use GPU or not.")
add_arg('model_path',       str,   "./inference_model/MobileNetV1_infer/", "model dir")
add_arg('save_path',        str,   "./quant_model/MobileNet/",             "model dir to save quanted model")
add_arg('model_filename',   str,   'inference.pdmodel',                    "model file name")
add_arg('params_filename',  str,   'inference.pdiparams',                  "params file name")
add_arg('algo',             str,   'avg',                                  "calibration algorithm")
add_arg('round_type',       str,   'round',                                "The method of converting the quantized weights.")
add_arg('hist_percent',     float, 0.9999,                                 "The percentile of algo:hist")
add_arg('is_full_quantize', bool,  False,                                  "Whether is full quantization or not.")
add_arg('bias_correction',  bool,  False,                                  "Whether to use bias correction")
add_arg('ce_test',          bool,  False,                                  "Whether to CE test.")
add_arg('onnx_format',      bool,  False,                                  "Whether to export the quantized model with format of ONNX.")
add_arg('input_name',       str,   'inputs',                               "The name of model input.")
# yapf: enable


def quantize(args):
    """Run post-training static quantization on a saved inference model.

    Builds a calibration ``DataLoader`` over the ImageNet ``test`` split and
    hands it to ``paddleslim.quant.quant_post_static``, which reads the model
    from ``args.model_path`` (``args.model_filename`` /
    ``args.params_filename``) and writes the quantized model to
    ``args.save_path``.

    Args:
        args: the ``argparse.Namespace`` produced by the module-level
            ``parser``.

    Raises:
        FileNotFoundError: ``args.model_path`` does not exist.
        NotADirectoryError: ``args.model_path`` exists but is not a directory.
    """
    # Validate the model directory before building any graph or data
    # pipeline. Explicit raises instead of ``assert`` so the check is not
    # stripped when Python runs with -O.
    if not os.path.exists(args.model_path):
        raise FileNotFoundError(
            "args.model_path doesn't exist: {}".format(args.model_path))
    if not os.path.isdir(args.model_path):
        raise NotADirectoryError(
            "args.model_path must be a dir: {}".format(args.model_path))

    # Calibration batches are drawn in random order by default. In ce_test
    # mode every RNG is seeded and shuffling is disabled so that repeated
    # runs see the same calibration data.
    shuffle = True
    if args.ce_test:
        seed = 111
        np.random.seed(seed)
        paddle.seed(seed)
        random.seed(seed)
        shuffle = False

    place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()

    val_dataset = reader.ImageNetDataset(mode='test')
    # Per-sample input shape; the leading ``None`` below is the dynamic batch
    # dimension. NOTE(review): assumed channel-first (C, H, W) -- must match
    # what imagenet_reader yields.
    image_shape = [3, 224, 224]
    # The feed variable is named after --input_name, which the help text
    # describes as the name of the model input.
    image = paddle.static.data(
        name=args.input_name, shape=[None] + image_shape, dtype='float32')
    data_loader = paddle.io.DataLoader(
        val_dataset,
        places=place,
        feed_list=[image],
        drop_last=False,
        return_list=False,
        batch_size=args.batch_size,
        shuffle=shuffle)

    exe = paddle.static.Executor(place)
    quant_post_static(
        executor=exe,
        model_dir=args.model_path,
        quantize_model_path=args.save_path,
        data_loader=data_loader,
        model_filename=args.model_filename,
        params_filename=args.params_filename,
        batch_size=args.batch_size,
        batch_nums=args.batch_num,
        algo=args.algo,
        round_type=args.round_type,
        hist_percent=args.hist_percent,
        is_full_quantize=args.is_full_quantize,
        bias_correction=args.bias_correction,
        onnx_format=args.onnx_format)


def main():
    """Command-line entry point: parse the flags, echo them, then quantize."""
    cli_args = parser.parse_args()
    # Show the effective configuration (via utility.print_arguments).
    print_arguments(cli_args)
    quantize(cli_args)


if __name__ == '__main__':
    # Switch Paddle to static-graph mode before main() runs: quantize() uses
    # static-graph APIs (paddle.static.data, paddle.static.Executor).
    paddle.enable_static()
    main()