# predict.py: classify sample images with a PaddleHub image-classification task.
#coding:utf-8

import argparse
import ast
import os
import sys

import numpy as np
import paddle.fluid as fluid
import paddlehub as hub

# Command-line options for the prediction script. ``description`` is passed by
# keyword: the first positional parameter of ArgumentParser is ``prog``.
# yapf: disable
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--use_gpu",            type=ast.literal_eval,  default=True,                      help="Whether to use GPU for prediction.")
parser.add_argument("--checkpoint_dir",     type=str,               default="paddlehub_finetune_ckpt",  help="Directory holding the fine-tuned checkpoint.")
parser.add_argument("--batch_size",         type=int,               default=16,                         help="Total examples' number in batch for prediction.")
parser.add_argument("--module",             type=str,               default="resnet50",                 help="Module used as a feature extractor.")
parser.add_argument("--dataset",            type=str,               default="flowers",                  help="Dataset the model was fine-tuned on (determines the label set).")
# yapf: enable

# Short ``--module`` aliases mapped to the full PaddleHub module names.
module_map = {
    "resnet50": "resnet_v2_50_imagenet",
    "resnet101": "resnet_v2_101_imagenet",
    "resnet152": "resnet_v2_152_imagenet",
    "mobilenet": "mobilenet_v2_imagenet",
    "nasnet": "nasnet_imagenet",
    "pnasnet": "pnasnet_imagenet"
}


# Lower-cased ``--dataset`` values mapped to the ``hub.dataset`` class names
# that provide them. Names are stored as strings so a class is only looked up
# when its dataset is actually requested.
_DATASET_CLASSES = {
    "flowers": "Flowers",
    "dogcat": "DogCat",
    "indoor67": "Indoor67",
    "food101": "Food101",
    "stanforddogs": "StanfordDogs",
}

# Images classified when ``predict`` is called without explicit ``data``;
# the relative paths resolve against the current working directory.
_DEFAULT_IMAGES = ("./test/test_img_daisy.jpg", "./test/test_img_roses.jpg")


def _load_dataset(name):
    """Instantiate the PaddleHub dataset registered under *name* (case-insensitive).

    Raises:
        ValueError: if *name* is not one of the supported datasets.
    """
    class_name = _DATASET_CLASSES.get(name.lower())
    if class_name is None:
        raise ValueError("%s dataset is not defined" % name)
    return getattr(hub.dataset, class_name)()


def predict(args, data=None):
    """Classify images with a PaddleHub image-classification task and print the result.

    Args:
        args: Parsed command-line namespace. Reads ``dataset``, ``module``
            (a full PaddleHub module name such as ``resnet_v2_50_imagenet``),
            ``use_gpu``, ``batch_size`` and ``checkpoint_dir``.
        data: Optional list of image paths to classify. Defaults to the two
            sample images under ``./test/``.

    Returns:
        The value of ``task.predict(data=..., return_result=True)``, which is
        also printed.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    # Resolve the dataset first so an unknown --dataset fails before the
    # pretrained module is downloaded. Constructing the dataset may itself
    # download it; the reader and the label count below depend on it.
    dataset = _load_dataset(args.dataset)

    # Load the PaddleHub pretrained module and its input/output variables.
    module = hub.Module(name=args.module)
    input_dict, output_dict, _ = module.context(trainable=True)

    # Use ImageClassificationReader, sized and normalized for this module.
    data_reader = hub.reader.ImageClassificationReader(
        image_width=module.get_expected_image_width(),
        image_height=module.get_expected_image_height(),
        images_mean=module.get_pretrained_images_mean(),
        images_std=module.get_pretrained_images_std(),
        dataset=dataset)

    # The module's feature map is the input to the classification head.
    feature_map = output_dict["feature_map"]

    # Setup feed list for data feeder.
    feed_list = [input_dict["image"].name]

    # Setup RunConfig for PaddleHub Fine-tune API. checkpoint_dir should hold
    # the fine-tuning output; presumably the task restores weights from it
    # before predicting -- confirm for the installed PaddleHub version.
    config = hub.RunConfig(
        use_data_parallel=False,
        use_cuda=args.use_gpu,
        batch_size=args.batch_size,
        checkpoint_dir=args.checkpoint_dir,
        strategy=hub.finetune.strategy.DefaultFinetuneStrategy())

    # Define an image classification task by PaddleHub Fine-tune API.
    task = hub.ImageClassifierTask(
        data_reader=data_reader,
        feed_list=feed_list,
        feature=feature_map,
        num_classes=dataset.num_labels,
        config=config)

    if data is None:
        data = list(_DEFAULT_IMAGES)
    result = task.predict(data=data, return_result=True)
    print(result)
    return result


if __name__ == "__main__":
    args = parser.parse_args()
    # Translate the short --module alias into the full PaddleHub module name;
    # reject unknown aliases with a non-zero exit status.
    if args.module not in module_map:
        hub.logger.error("module should be one of %s" % ", ".join(module_map))
        sys.exit(1)
    args.module = module_map[args.module]

    predict(args)