predict_yolo.py 2.2 KB
Newer Older
W
wuzewu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# -*- coding:utf8 -*-
import argparse
import os
import ast

import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from paddlehub.reader.cv_reader import ObjectDetectionReader
from paddlehub.dataset.base_cv_dataset import ObjectDetectionDataset

# yapf: disable
# NOTE: the first positional parameter of ArgumentParser is `prog`, not
# `description`; pass the module docstring by keyword so it becomes the help
# description instead of overriding the program name.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--use_gpu",            type=ast.literal_eval,  default=True,                         help="Whether to use GPU for prediction.")
parser.add_argument("--checkpoint_dir",     type=str,               default="yolo_finetune_ckpt",         help="Directory of the fine-tuned checkpoint used for prediction.")
parser.add_argument("--batch_size",         type=int,               default=8,                            help="Number of examples per batch.")
parser.add_argument("--module",             type=str,               default="yolov3_darknet53_coco2017",  help="Module used as feature extractor.")
parser.add_argument("--dataset",            type=str,               default="coco_10",                    help="Dataset the model was fine-tuned on (only coco_10 is supported).")
# yapf: enable


def predict(args):
    """Run YOLOv3 object detection on a fixed list of sample images.

    Loads the PaddleHub module named by ``args.module``, builds a
    ``hub.YOLOTask`` on top of it with the Coco10 dataset's label count, then
    predicts on the sample images under ``./test`` and prints the raw results.

    Args:
        args: parsed command-line namespace; reads ``dataset``, ``module``,
            ``use_gpu``, ``batch_size`` and ``checkpoint_dir``.

    Returns:
        Whatever ``task.predict(..., return_result=True)`` returns (also
        printed to stdout).

    Raises:
        ValueError: if ``args.dataset`` is not ``"coco_10"``, the only dataset
            this script knows how to load.
    """
    # The dataset below is hard-wired to Coco10; fail loudly rather than
    # silently ignoring a different --dataset value (wrong labels otherwise).
    if args.dataset != "coco_10":
        raise ValueError(
            "Unsupported dataset %r; only 'coco_10' is supported." %
            args.dataset)

    module = hub.Module(name=args.module)
    dataset = hub.dataset.Coco10('yolo')

    print("dataset.num_labels:", dataset.num_labels)

    # define batch reader
    data_reader = ObjectDetectionReader(dataset=dataset, model_type='yolo')

    # NOTE(review): trainable=True mirrors the fine-tuning setup; confirm
    # whether prediction actually needs trainable parameters.
    input_dict, output_dict, _ = module.context(trainable=True)
    feed_list = [input_dict["image"].name, input_dict["im_size"].name]
    feature = output_dict['body_features']

    config = hub.RunConfig(
        use_data_parallel=False,
        use_pyreader=True,
        use_cuda=args.use_gpu,
        batch_size=args.batch_size,
        enable_memory_optim=False,
        checkpoint_dir=args.checkpoint_dir,
        strategy=hub.finetune.strategy.DefaultFinetuneStrategy())

    task = hub.YOLOTask(
        data_reader=data_reader,
        num_classes=dataset.num_labels,
        feed_list=feed_list,
        feature=feature,
        config=config)

    data = [
        "./test/test_img_bird.jpg",
        "./test/test_img_cat.jpg",
    ]
    results = task.predict(data=data, return_result=True, accelerate_mode=False)
    print(results)
    return results


if __name__ == "__main__":
    # Script entry point: parse the command-line flags and run prediction.
    predict(args=parser.parse_args())