quant_post.py 2.1 KB
Newer Older
S
slf12 已提交
1 2 3 4 5 6 7 8 9 10
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np

W
whs 已提交
11 12
# Put the directory two levels above this script at the front of the import
# path so the helper modules imported below can be found. Anchor on the
# script's real location (the ``__file__`` variable, made absolute) rather
# than on the current working directory, so the script can be launched from
# any folder.
# NOTE(review): presumably ``utility`` / ``imagenet_reader`` live there — confirm.
_DEMO_ROOT = os.path.normpath(
    os.path.join(
        os.path.dirname(os.path.abspath(__file__)), os.path.pardir,
        os.path.pardir))
sys.path[0] = _DEMO_ROOT
S
slf12 已提交
13 14 15
from paddleslim.common import get_logger
from paddleslim.quant import quant_post
from utility import add_arguments, print_arguments
L
Liufang Sang 已提交
16
import imagenet_reader as reader
S
slf12 已提交
17 18 19 20 21
_logger = get_logger(__name__, level=logging.INFO)

# Command-line interface: each call below is add_arg(name, type, default,
# help), which forwards to the shared ``add_arguments`` helper bound to
# ``parser``. Registration order is the order flags appear in --help.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
# Batching of the samples handed to quant_post.
add_arg('batch_size',      int,   32,                             "Minibatch size.")
add_arg('batch_num',       int,   1,                              "Batch number")
# Execution device.
add_arg('use_gpu',         bool,  True,                           "Whether to use GPU or not.")
# Input model location and output location.
add_arg('model_path',      str,   "./inference_model/MobileNet/", "model dir")
add_arg('save_path',       str,   "./quant_model/MobileNet/",     "model dir to save quanted model")
add_arg('model_filename',  str,   None,                           "model file name")
add_arg('params_filename', str,   None,                           "params file name")
# Calibration algorithm and its histogram percentile.
add_arg('algo',            str,   'hist',                         "calibration algorithm")
add_arg('hist_percent',    float, 0.9999,                         "The percentile of algo:hist")
# yapf: enable


def quantize(args):
    """Run PaddleSlim post-training quantization on an inference model.

    Loads the inference model under ``args.model_path``, feeds it samples from
    ``reader.train()`` and calls ``quant_post``. The output location passed to
    ``quant_post`` is ``args.save_path``.

    Args:
        args: Parsed command-line namespace. Reads ``model_path``,
            ``use_gpu``, ``save_path``, ``model_filename``,
            ``params_filename``, ``batch_size``, ``batch_num``, ``algo``
            and ``hist_percent``.

    Raises:
        FileNotFoundError: ``args.model_path`` does not exist.
        NotADirectoryError: ``args.model_path`` exists but is not a directory.
    """
    # Validate with real exceptions rather than ``assert``: asserts are
    # stripped under ``python -O``, which would let a bad path reach
    # quant_post and fail far less clearly. Checked first so nothing else is
    # set up for a run that cannot succeed.
    if not os.path.exists(args.model_path):
        raise FileNotFoundError(
            f"args.model_path doesn't exist: {args.model_path}")
    if not os.path.isdir(args.model_path):
        raise NotADirectoryError(
            f"args.model_path must be a dir: {args.model_path}")

    # Samples for quant_post are drawn from the *training* reader, so the
    # name reflects its role (calibration input), not a validation split.
    calib_reader = reader.train()

    place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    quant_post(
        executor=exe,
        model_dir=args.model_path,
        quantize_model_path=args.save_path,
        sample_generator=calib_reader,
        model_filename=args.model_filename,
        params_filename=args.params_filename,
        batch_size=args.batch_size,
        batch_nums=args.batch_num,
        algo=args.algo,
        hist_percent=args.hist_percent)
S
slf12 已提交
54 55 56 57 58 59 60 61 62


def main():
    """Entry point: parse the command line, echo the settings, then quantize."""
    parsed = parser.parse_args()
    print_arguments(parsed)
    quantize(parsed)


if __name__ == '__main__':
    # quantize() runs the model through a paddle.static.Executor, so switch
    # Paddle into static-graph mode before any work starts.
    paddle.enable_static()
    main()