infer.py 2.7 KB
Newer Older
G
guosheng 已提交
1
import os
W
wwhu 已提交
2
import gzip
G
guosheng 已提交
3 4 5 6
import argparse
import numpy as np
from PIL import Image

W
wwhu 已提交
7 8 9 10 11 12
import paddle.v2 as paddle
import reader
import vgg
import resnet
import alexnet
import googlenet
13
import inception_v4
G
guosheng 已提交
14
import inception_resnet_v2
W
wangmeng28 已提交
15
import xception
W
wwhu 已提交
16

G
guosheng 已提交
17
# Length of the flattened image vector declared for the "image" data layer
# (3 * 224 * 224, presumably channels * height * width).
# Use 3 * 331 * 331 or 3 * 299 * 299 for Inception-ResNet-v2.
DATA_DIM = 3 * 224 * 224
W
wwhu 已提交
18 19 20 21 22 23 24 25 26 27 28 29 30
# Number of output classes; passed as class_dim to every model builder below.
CLASS_DIM = 102


def main():
    """Classify every image named in a list file and print its top-1 label.

    Positional command-line arguments:
        data_list: path of a text file with one image path per line
            (blank lines are ignored).
        model: name of the network topology to build; must be one of the
            values in ``choices`` below.
        params_path: gzip-compressed file which stores the parameters.

    For each image a line "Label of <path> is: <class index>" is printed,
    where the class index is the one with the highest predicted score.
    """
    # parse the argument
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'data_list',
        help='The path of data list file, which consists of one image path per line'
    )
    parser.add_argument(
        'model',
        help='The model for image classification',
        choices=[
            'alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet',
            'inception-resnet-v2', 'inception_v4', 'xception'
        ])
    parser.add_argument(
        'params_path', help='The file which stores the parameters')
    args = parser.parse_args()

    # PaddlePaddle init
    paddle.init(use_gpu=True, trainer_count=1)

    image = paddle.layer.data(
        name="image", type=paddle.data_type.dense_vector(DATA_DIM))

    # Build the requested topology. argparse's ``choices`` guarantees that
    # exactly one of these branches matches, so ``out`` is always bound.
    if args.model == 'alexnet':
        out = alexnet.alexnet(image, class_dim=CLASS_DIM)
    elif args.model == 'vgg13':
        out = vgg.vgg13(image, class_dim=CLASS_DIM)
    elif args.model == 'vgg16':
        out = vgg.vgg16(image, class_dim=CLASS_DIM)
    elif args.model == 'vgg19':
        out = vgg.vgg19(image, class_dim=CLASS_DIM)
    elif args.model == 'resnet':
        out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM)
    elif args.model == 'googlenet':
        # googlenet returns three outputs; only the first is used here.
        out, _, _ = googlenet.googlenet(image, class_dim=CLASS_DIM)
    elif args.model == 'inception-resnet-v2':
        assert DATA_DIM == 3 * 331 * 331 or DATA_DIM == 3 * 299 * 299, (
            'inception-resnet-v2 requires DATA_DIM to be 3 * 331 * 331 '
            'or 3 * 299 * 299')
        out = inception_resnet_v2.inception_resnet_v2(
            image, class_dim=CLASS_DIM, dropout_rate=0.5, data_dim=DATA_DIM)
    elif args.model == 'inception_v4':
        out = inception_v4.inception_v4(image, class_dim=CLASS_DIM)
    elif args.model == 'xception':
        out = xception.xception(image, class_dim=CLASS_DIM)

    # load parameters
    with gzip.open(args.params_path, 'r') as f:
        parameters = paddle.parameters.Parameters.from_tar(f)

    # Read the image paths, closing the list file promptly and skipping
    # blank lines so that an empty path is never handed to the image loader.
    with open(args.data_list) as data_list_file:
        stripped_lines = (line.strip() for line in data_list_file)
        file_list = [path for path in stripped_lines if path]

    # NOTE(review): the sizes 256 and 224 are hard-coded here; if DATA_DIM is
    # changed (as inception-resnet-v2 requires) the flattened image length
    # presumably no longer matches the data layer -- confirm and adjust.
    test_data = [(paddle.image.load_and_transform(image_file, 256, 224, False)
                  .flatten().astype('float32'), ) for image_file in file_list]
    probs = paddle.infer(
        output_layer=out, parameters=parameters, input=test_data)
    # Sort class indices by descending score; column 0 is the top-1 class.
    lab = np.argsort(-probs)
    for file_name, result in zip(file_list, lab):
        # Single-argument parenthesised form prints identically under
        # Python 2 (print statement) and Python 3 (print function).
        print("Label of %s is: %d" % (file_name, result[0]))


# Run inference only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()