# train_faster_rcnn.py
# -*- coding:utf8 -*-
import argparse
import os
import ast

import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from paddlehub.reader.cv_reader import ObjectDetectionReader
from paddlehub.dataset.base_cv_dataset import ObjectDetectionDataset

# Command-line flags for the fine-tuning run.
# NOTE: the module docstring is passed as `description`; the first positional
# argument of ArgumentParser is `prog` (the program name shown in usage lines),
# so passing it positionally would misuse the docstring as the program name.
# yapf: disable
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--num_epoch",          type=int,               default=50,                               help="Number of epochs for fine-tuning.")
parser.add_argument("--use_gpu",            type=ast.literal_eval,  default=True,                             help="Whether use GPU for fine-tuning.")
parser.add_argument("--checkpoint_dir",     type=str,               default="faster_rcnn_finetune_ckpt",      help="Path to save log data.")
parser.add_argument("--batch_size",         type=int,               default=1,                                help="Total examples' number in batch for training.")
parser.add_argument("--module",             type=str,               default="faster_rcnn_resnet50_coco2017",  help="Module used as feature extractor.")
parser.add_argument("--dataset",            type=str,               default="coco_10",                        help="Dataset to finetune.")
parser.add_argument("--use_data_parallel",  type=ast.literal_eval,  default=False,                            help="Whether use data parallel.")
# yapf: enable.


def finetune(args):
    """Fine-tune a Faster R-CNN PaddleHub module on COCO-10 and evaluate it.

    Loads the hub module named by ``args.module`` and builds its training and
    prediction contexts. It then wires them into a ``hub.FasterRCNNTask``
    together with an ``ObjectDetectionReader`` over the COCO-10 demo dataset,
    and runs ``task.finetune_and_eval()``.

    Args:
        args: Namespace produced by the module-level ``parser``. The fields
            read here are ``dataset``, ``module``, ``use_data_parallel``,
            ``use_gpu``, ``num_epoch``, ``batch_size`` and ``checkpoint_dir``.

    Raises:
        ValueError: If ``args.dataset`` is not ``"coco_10"``
            (case-insensitive). That is the only dataset this script loads, so
            any other value is rejected rather than silently ignored.
    """
    # Validate before loading the module so that a bad flag fails fast instead
    # of silently training on COCO-10 regardless of what was requested.
    if args.dataset.lower() != "coco_10":
        raise ValueError(
            "Unsupported dataset %r: this script only supports 'coco_10'." %
            args.dataset)

    module = hub.Module(name=args.module)
    # NOTE(review): 'rcnn' presumably selects the R-CNN flavour of the dataset
    # and reader preprocessing (it is also passed as model_type below) --
    # confirm against the PaddleHub Coco10 / ObjectDetectionReader docs.
    dataset = hub.dataset.Coco10('rcnn')

    print("dataset.num_labels:", dataset.num_labels)

    # define batch reader
    data_reader = ObjectDetectionReader(dataset=dataset, model_type='rcnn')

    # Two contexts from the same module: a trainable one for fine-tuning and a
    # non-trainable 'predict' one. The returned programs are not needed here.
    input_dict, output_dict, _ = module.context(trainable=True)
    pred_input_dict, pred_output_dict, _ = module.context(
        trainable=False, phase='predict')

    # Names of the variables fed during training / prediction.
    # NOTE(review): the order presumably has to match the tuples yielded by
    # ObjectDetectionReader for the 'rcnn' model type -- confirm.
    feed_list = [
        input_dict["image"].name, input_dict["im_info"].name,
        input_dict['gt_bbox'].name, input_dict['gt_class'].name,
        input_dict['is_crowd'].name
    ]

    pred_feed_list = [
        pred_input_dict['image'].name, pred_input_dict['im_info'].name,
        pred_input_dict['im_shape'].name
    ]

    # Module outputs consumed by the task: head features plus the RPN losses
    # and proposal labels for training; head features and RoIs for prediction.
    feature = [
        output_dict['head_features'], output_dict['rpn_cls_loss'],
        output_dict['rpn_reg_loss'], output_dict['generate_proposal_labels']
    ]

    pred_feature = [pred_output_dict['head_features'], pred_output_dict['rois']]

    config = hub.RunConfig(
        log_interval=10,
        eval_interval=100,
        use_data_parallel=args.use_data_parallel,
        use_pyreader=True,
        use_cuda=args.use_gpu,
        num_epoch=args.num_epoch,
        batch_size=args.batch_size,
        enable_memory_optim=False,
        checkpoint_dir=args.checkpoint_dir,
        strategy=hub.finetune.strategy.DefaultFinetuneStrategy(
            learning_rate=0.00025, optimizer_name="momentum", momentum=0.9))

    task = hub.FasterRCNNTask(
        data_reader=data_reader,
        num_classes=dataset.num_labels,
        feed_list=feed_list,
        feature=feature,
        predict_feed_list=pred_feed_list,
        predict_feature=pred_feature,
        config=config)
    task.finetune_and_eval()


if __name__ == "__main__":
    # Script entry point: read the command-line flags, then hand them to the
    # fine-tuning routine.
    cli_args = parser.parse_args()
    finetune(cli_args)