# quant_offline.py: offline quantization of a trained HumanSeg model.
import argparse
from datasets.dataset import Dataset
import transforms
import models


def parse_args(argv=None):
    """Build the offline-quantization command-line parser and parse it.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the original
            behaviour of reading the real command line.

    Returns:
        argparse.Namespace with attributes ``model_dir``, ``batch_size``,
        ``batch_nums``, ``data_dir``, ``quant_list``, ``save_dir`` and
        ``image_shape``.
    """
    parser = argparse.ArgumentParser(
        description='HumanSeg offline quantization')
    parser.add_argument(
        '--model_dir',
        dest='model_dir',
        help='Model path for quant',
        type=str,
        default='output/best_model')
    parser.add_argument(
        '--batch_size',
        dest='batch_size',
        help='Mini batch size',
        type=int,
        default=1)
    parser.add_argument(
        '--batch_nums',
        dest='batch_nums',
        help='Batch number for quant',
        type=int,
        default=10)
    # NOTE(review): no default and not required, so this is None unless
    # passed; the Dataset built in evaluate() presumably needs it — confirm.
    parser.add_argument(
        '--data_dir',
        dest='data_dir',
        help='the root directory of dataset',
        type=str)
    parser.add_argument(
        '--quant_list',
        dest='quant_list',
        help='Image file list for model quantization, '
        'it can be val.txt or train.txt',
        type=str,
        default=None)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the quant model',
        type=str,
        default='./output/quant_offline')
    # Two ints, passed to transforms.Resize in evaluate(); the order
    # (width, height vs. height, width) is defined by Resize — TODO confirm.
    parser.add_argument(
        "--image_shape",
        dest="image_shape",
        help="The image shape for net inputs.",
        nargs=2,
        default=[192, 192],
        type=int)
    return parser.parse_args(argv)


def evaluate(args):
    """Export an offline-quantized copy of a trained model.

    Despite its name, this function computes no evaluation metrics. It does
    the following:

    1. Builds a dataset from ``args.quant_list`` with resize + normalize
       preprocessing.
    2. Loads the model stored in ``args.model_dir``.
    3. Calls the model's ``export_quant_model`` to write the quantized model
       to ``args.save_dir``.

    Args:
        args: Namespace produced by ``parse_args()``. This function reads
            ``data_dir``, ``quant_list``, ``image_shape``, ``model_dir``,
            ``save_dir``, ``batch_size`` and ``batch_nums``.
    """
    # Resize images to the configured network input shape, then apply
    # transforms.Normalize with its default settings.
    quant_transforms = transforms.Compose(
        [transforms.Resize(args.image_shape),
         transforms.Normalize()])

    quant_dataset = Dataset(
        data_dir=args.data_dir,
        file_list=args.quant_list,
        transforms=quant_transforms,
        num_workers='auto',
        buffer_size=100,
        parallel_method='thread',
        # Read the file list in order (no shuffling).
        shuffle=False)

    model = models.load_model(args.model_dir)
    # batch_size / batch_nums control how many samples export_quant_model
    # consumes from the dataset (see the --batch_nums help text).
    model.export_quant_model(
        dataset=quant_dataset,
        save_dir=args.save_dir,
        batch_size=args.batch_size,
        batch_nums=args.batch_nums)


if __name__ == "__main__":
    # Script entry point: read the command-line options, then run the
    # offline quantization export with them.
    args = parse_args()
    evaluate(args)