# img_classifier.py (committed by wuzewu)
import argparse
import os

import paddle.fluid as fluid
import paddlehub as hub
import numpy as np

# yapf: disable
def _str2bool(value):
    """Convert a command-line string into a bool for argparse.

    ``type=bool`` does not work with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so the option could never be
    switched off from the command line.

    Args:
        value: the raw option value (a string, or an already-parsed bool).

    Returns:
        The boolean the text stands for.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays False, as it was under ``type=bool``.
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got %r" % (value, ))


# The first positional parameter of ArgumentParser is ``prog``; the module
# docstring belongs in ``description``.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--num_epoch",      type=int,       default=1,                          help="Number of epoches for fine-tuning.")
parser.add_argument("--use_gpu",        type=_str2bool, default=False,                      help="Whether use GPU for fine-tuning.")
parser.add_argument("--checkpoint_dir", type=str,       default="paddlehub_finetune_ckpt",  help="Path to save log data.")
parser.add_argument("--batch_size",     type=int,       default=16,                         help="Total examples' number in batch for training.")
parser.add_argument("--module",         type=str,       default="resnet50",                 help="Module used as feature extractor.")
parser.add_argument("--dataset",        type=str,       default="flowers",                  help="Dataset to finetune.")
# yapf: enable.

# Short names accepted by --module, mapped to the module names that are
# handed to hub.Module.
module_map = {
    # ResNet v2 family
    "resnet50": "resnet_v2_50_imagenet",
    "resnet101": "resnet_v2_101_imagenet",
    "resnet152": "resnet_v2_152_imagenet",
    # Other architectures
    "mobilenet": "mobilenet_v2_imagenet",
    "nasnet": "nasnet_imagenet",
    "pnasnet": "pnasnet_imagenet",
}


# Lower-case names accepted by --dataset, mapped to the attribute of
# ``hub.dataset`` that builds the dataset. Only "Flowers" is known to exist
# in every PaddleHub release this script targets; the other entries are
# resolved lazily, so a release that lacks them reports a clear error.
# NOTE(review): confirm the non-Flowers names against the installed PaddleHub.
_DATASET_CLASS_NAMES = {
    "flowers": "Flowers",
    "dogcat": "DogCat",
    "indoor67": "Indoor67",
    "food101": "Food101",
    "stanforddogs": "StanfordDogs",
}


def _create_dataset(name):
    """Instantiate the ``hub.dataset`` class selected by *name* (case-insensitive)."""
    key = str(name).lower()
    if key not in _DATASET_CLASS_NAMES:
        raise ValueError("%s dataset is not defined, choose from %s" %
                         (name, sorted(_DATASET_CLASS_NAMES)))
    dataset_cls = getattr(hub.dataset, _DATASET_CLASS_NAMES[key], None)
    if dataset_cls is None:
        raise ValueError(
            "%s dataset is not available in the installed paddlehub" % name)
    return dataset_cls()


def finetune(args):
    """Fine-tune a PaddleHub image module on the selected dataset.

    Args:
        args: parsed command-line namespace. Reads ``module`` (the PaddleHub
            module name), ``dataset`` (optional, defaults to "flowers"),
            ``use_gpu``, ``num_epoch``, ``batch_size`` and ``checkpoint_dir``.

    Raises:
        ValueError: if ``args.dataset`` does not name a supported dataset.
    """
    # Resolve the dataset first so a mistyped --dataset fails before the
    # module is loaded. Previously the option was ignored and Flowers was
    # always used.
    dataset = _create_dataset(getattr(args, "dataset", "flowers"))

    module = hub.Module(name=args.module)
    input_dict, output_dict, program = module.context(trainable=True)

    # The reader is configured with the input size and normalisation values
    # reported by the module itself.
    data_reader = hub.reader.ImageClassificationReader(
        image_width=module.get_expected_image_width(),
        image_height=module.get_expected_image_height(),
        images_mean=module.get_pretrained_images_mean(),
        images_std=module.get_pretrained_images_std(),
        dataset=dataset)

    # Build the classification task on top of the module's feature map.
    feature_map = output_dict["feature_map"]
    task = hub.create_img_cls_task(
        feature=feature_map, num_classes=dataset.num_labels)

    # Names of the variables fed at each step: the image input and the label.
    img = input_dict["image"]
    feed_list = [img.name, task.variable('label').name]

    config = hub.RunConfig(
        use_cuda=args.use_gpu,
        num_epoch=args.num_epoch,
        batch_size=args.batch_size,
        enable_memory_optim=False,
        checkpoint_dir=args.checkpoint_dir,
        strategy=hub.finetune.strategy.DefaultFinetuneStrategy())

    hub.finetune_and_eval(
        task, feed_list=feed_list, data_reader=data_reader, config=config)


if __name__ == "__main__":
    args = parser.parse_args()
    if args.module not in module_map:
        # list(...) prints a plain list on both Python 2 and 3 instead of
        # "dict_keys([...])".
        hub.logger.error("module should in %s" % list(module_map))
        # ``exit`` is injected by the ``site`` module and may be missing
        # (e.g. ``python -S``); raising SystemExit always works.
        raise SystemExit(1)
    # Translate the short name into the module name expected by hub.Module.
    args.module = module_map[args.module]

    finetune(args)