export_model.py 2.8 KB
Newer Older
S
slf12 已提交
1 2 3 4 5 6 7 8 9
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
W
whs 已提交
10 11
# Replace the script's own entry on the import path with the directory two
# levels above this file (presumably the checkout root that provides
# `paddleslim` -- confirm against the repository layout).
# NOTE: this must use the `__file__` variable, not the string "__file__":
# os.path.dirname("__file__") is always "", which made the path resolve to
# "../.." relative to the current working directory instead of the script.
sys.path[0] = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), os.path.pardir, os.path.pardir)
S
slf12 已提交
12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
from paddleslim.common import get_logger
import models
from utility import add_arguments, print_arguments

_logger = get_logger(__name__, level=logging.INFO)

# Command-line interface; `add_arg` registers one flag on `parser`.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('use_gpu',          bool, True,                "Whether to use GPU or not.")
add_arg('model',            str,  "MobileNet",                "The target model.")
add_arg('pretrained_model', str,  "../pretrained_model/MobileNetV1_pretained",                "Path of the pretrained model to load (required).")
add_arg('data',             str, "mnist",                 "Which data to use. 'mnist' or 'imagenet'")
# NOTE(review): not read by the visible code in this script; kept so existing
# command lines that pass --test_period still parse.
add_arg('test_period',      int, 10,                 "Test period in epochs.")
# yapf: enable

# Names exposed by the `models` package, skipping dunder attributes;
# `--model` is validated against this list.
model_list = [m for m in dir(models) if "__" not in m]


def _dataset_spec(data):
    """Return ``(class_dim, image_shape)`` for a supported dataset name.

    ``image_shape`` is the per-sample input shape passed to the network
    (without the batch dimension).

    Raises:
        ValueError: if ``data`` is neither ``"mnist"`` nor ``"imagenet"``.
    """
    if data == "mnist":
        return 10, [1, 28, 28]
    if data == "imagenet":
        return 1000, [3, 224, 224]
    raise ValueError("{} is not supported.".format(data))


def export_model(args):
    """Build ``args.model``, load pretrained weights and save an inference model.

    Args:
        args: parsed command-line namespace; reads ``data``, ``model``,
            ``pretrained_model`` and ``use_gpu``.

    The model is written via ``paddle.static.save_inference_model`` with the
    path prefix ``./inference_model/<args.model>/model``.

    Raises:
        ValueError: if ``args.data`` is unsupported, ``args.pretrained_model``
            is empty, or ``args.model`` is not in ``model_list``.
    """
    # Exporting only needs the input geometry, so no data reader is built
    # (building the mnist reader would download the dataset).
    class_dim, image_shape = _dataset_spec(args.data)

    # Validate before building anything. Use exceptions rather than `assert`,
    # which is stripped under `python -O` and would let untrained weights
    # be exported silently.
    if not args.pretrained_model:
        raise ValueError("args.pretrained_model must set")
    if args.model not in model_list:
        raise ValueError("{} is not in lists: {}".format(args.model,
                                                         model_list))

    image = paddle.static.data(
        name='image', shape=[None] + image_shape, dtype='float32')

    # model definition
    model = models.__dict__[args.model]()
    out = model.net(input=image, class_dim=class_dim)

    # Test-mode clone of the main program; this clone is what gets loaded
    # and exported below.
    val_program = paddle.static.default_main_program().clone(for_test=True)
    place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    exe.run(paddle.static.default_startup_program())

    paddle.static.load(val_program, args.pretrained_model, exe)

    # Paddle 2.x signature: (path_prefix, feed_vars, fetch_vars, executor).
    # The legacy fluid keywords (feeded_var_names, target_vars, main_program,
    # model_filename, params_filename) are not accepted by this API.
    path_prefix = os.path.join('./inference_model', args.model, 'model')
    paddle.static.save_inference_model(
        path_prefix, [image], [out], exe, program=val_program)


def main():
    """Entry point: parse CLI flags, print them, then export the model."""
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    export_model(cli_args)


if __name__ == '__main__':
    # The export path uses the paddle.static.* API, so static-graph mode is
    # enabled before any program is built.
    paddle.enable_static()
    main()