quant_post_hpo.py 2.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
import os
import sys
import numpy as np
import paddle
import logging
import argparse
import functools

sys.path[0] = os.path.join(
    os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
sys.path[1] = os.path.join(
    os.path.dirname("__file__"), os.path.pardir, os.path.pardir, os.path.pardir)
from paddleslim.common import get_logger
from paddleslim.quant import quant_post_hpo
from utility import add_arguments, print_arguments
import imagenet_reader as reader
_logger = get_logger(__name__, level=logging.INFO)

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('use_gpu',          bool, True,                "Whether to use GPU or not.")
add_arg('model_path',       str,  "./inference_model/MobileNet/",  "model dir")
add_arg('save_path',       str,  "./quant_model/MobileNet/",  "model dir to save quanted model")
add_arg('model_filename',       str, None,                 "model file name")
add_arg('params_filename',      str, None,                 "params file name")
add_arg('max_model_quant_count',    int, 30,                 "max model quant count")

def quantize(args):
    place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()

    assert os.path.exists(args.model_path), "args.model_path doesn't exist"
    assert os.path.isdir(args.model_path), "args.model_path must be a dir"

    def reader_generator(imagenet_reader):
        def gen():
            for i, data in enumerate(imagenet_reader()):
                image, label = data
                image = np.expand_dims(image, axis=0)
                yield image
        return gen

    exe = paddle.static.Executor(place)
    quant_post_hpo(
        exe,
        place,
        args.model_path,
        args.save_path,
        train_sample_generator=reader_generator(reader.train()),
        eval_sample_generator=reader_generator(reader.val()),
        model_filename=args.model_filename,
        params_filename=args.params_filename,
Z
zhouzj 已提交
53 54
        save_model_filename='model.pdmodel',
        save_params_filename='model.pdiparams',
55
        quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
C
ceci3 已提交
56
        weight_quantize_type=['channel_wise_abs_max'],
57 58 59 60 61 62 63 64 65 66 67
        runcount_limit=args.max_model_quant_count)

def main():
    args = parser.parse_args()
    print_arguments(args)
    quantize(args)


if __name__ == '__main__':
    paddle.enable_static()
    main()