import paddle
import paddle.fluid as fluid
import reader
import load_model as load_model
from mobilenet_ssd import mobile_net
from utility import add_arguments, print_arguments
import os
import time
import numpy as np
import argparse
import functools

# Command-line interface. add_arg is shorthand that forwards every call to
# utility.add_arguments with this parser pre-bound.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('learning_rate', float, 0.001, "Learning rate.")
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('num_passes', int, 25, "Epoch number.")
add_arg('parallel', bool, True, "Whether use parallel training.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('data_dir', str, './data/COCO17', "Root path of data")
add_arg('train_file_list', str, 'annotations/instances_train2017.json',
        "train file list")
add_arg('val_file_list', str, 'annotations/instances_val2017.json',
        "valid file list")
add_arg('model_save_dir', str, 'model_COCO17', "where to save model")

add_arg('dataset', str, 'coco', "coco or pascalvoc")
add_arg(
    'is_toy', int, 0,
    "Is Toy for quick debug, 0 means using all data, while n means using only n sample"
)
add_arg('label_file', str, 'label_list',
        "Label file which lists all label name")
add_arg('apply_distort', bool, True, "Whether apply distort")
add_arg('apply_expand', bool, False, "Whether apply expand")
add_arg('resize_h', int, 300, "resize image size")
add_arg('resize_w', int, 300, "resize image size")
# Per-channel mean subtracted in preprocessing; the trailing commented values
# are the usual ImageNet means for reference (B/G/R).
add_arg('mean_value_B', float, 127.5,
        "mean value which will be subtracted")  #123.68
add_arg('mean_value_G', float, 127.5,
        "mean value which will be subtracted")  #116.78
add_arg('mean_value_R', float, 127.5,
        "mean value which will be subtracted")  #103.94


def train(args,
          train_file_list,
          val_file_list,
          data_args,
          learning_rate,
          batch_size,
          num_passes,
          model_save_dir='model',
          init_model_path=None):
    """Build the SSD (MobileNet backbone) train/test programs and run training.

    Args:
        args: parsed argparse namespace; reads args.parallel and args.use_gpu.
        train_file_list (str): annotation list used for training.
        val_file_list (str): annotation list used for mAP evaluation.
        data_args: reader.Settings with dataset name and preprocessing options.
        learning_rate (float): base LR of the piecewise-decay schedule.
        batch_size (int): minibatch size.
        num_passes (int): number of epochs to train.
        model_save_dir (str): root directory for saved inference models.
        init_model_path: unused; kept for backward compatibility of the API.

    Raises:
        ValueError: if data_args.dataset is not 'coco'/'pascalvoc', or a COCO
            file list does not identify the 2014/2017 split.
    """
    image_shape = [3, data_args.resize_h, data_args.resize_w]
    if data_args.dataset == 'coco':
        num_classes = 81
    elif data_args.dataset == 'pascalvoc':
        num_classes = 21
    else:
        # Previously an unknown dataset surfaced later as a NameError on
        # num_classes/boundaries; fail early with a clear message instead.
        raise ValueError("Unknown dataset: %s" % data_args.dataset)

    # Input layers; boxes/labels are variable-length per image (lod_level=1).
    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    gt_box = fluid.layers.data(
        name='gt_box', shape=[4], dtype='float32', lod_level=1)
    gt_label = fluid.layers.data(
        name='gt_label', shape=[1], dtype='int32', lod_level=1)
    difficult = fluid.layers.data(
        name='gt_difficult', shape=[1], dtype='int32', lod_level=1)

    if args.parallel:
        # Replicate the model across all available devices with ParallelDo.
        places = fluid.layers.get_places()
        pd = fluid.layers.ParallelDo(places)
        with pd.do():
            image_ = pd.read_input(image)
            gt_box_ = pd.read_input(gt_box)
            gt_label_ = pd.read_input(gt_label)
            difficult_ = pd.read_input(difficult)
            locs, confs, box, box_var = mobile_net(num_classes, image_,
                                                   image_shape)
            loss = fluid.layers.ssd_loss(locs, confs, gt_box_, gt_label_, box,
                                         box_var)
            nmsed_out = fluid.layers.detection_output(
                locs, confs, box, box_var, nms_threshold=0.45)
            loss = fluid.layers.reduce_sum(loss)
            pd.write_output(loss)
            pd.write_output(nmsed_out)

        loss, nmsed_out = pd()
        loss = fluid.layers.mean(loss)
    else:
        locs, confs, box, box_var = mobile_net(num_classes, image, image_shape)
        nmsed_out = fluid.layers.detection_output(
            locs, confs, box, box_var, nms_threshold=0.45)
        loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label, box,
                                     box_var)
        loss = fluid.layers.reduce_sum(loss)

    # Clone the network for evaluation before the optimizer mutates the
    # default main program.
    test_program = fluid.default_main_program().clone(for_test=True)
    with fluid.program_guard(test_program):
        map_eval = fluid.evaluator.DetectionMAP(
            nmsed_out,
            gt_label,
            gt_box,
            difficult,
            num_classes,
            overlap_threshold=0.5,
            evaluate_difficult=False,
            ap_version='integral')

    if data_args.dataset == 'coco':
        # Decay the learning rate after epochs 12 and 19; 82783/118287 are
        # the COCO2014/COCO2017 training-set sizes. Use // so the iteration
        # boundaries stay integral under Python 3 as well.
        if '2014' in train_file_list:
            boundaries = [82783 // batch_size * 12, 82783 // batch_size * 19]
        elif '2017' in train_file_list:
            boundaries = [118287 // batch_size * 12, 118287 // batch_size * 19]
        else:
            # Previously fell through and crashed later with NameError.
            raise ValueError(
                "Cannot infer COCO split (2014/2017) from train_file_list: %s"
                % train_file_list)
    else:  # pascalvoc; any other value was rejected above
        boundaries = [40000, 60000]
    values = [learning_rate, learning_rate * 0.5, learning_rate * 0.25]
    optimizer = fluid.optimizer.RMSProp(
        learning_rate=fluid.layers.piecewise_decay(boundaries, values),
        regularization=fluid.regularizer.L2Decay(0.00005), )

    optimizer.minimize(loss)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # Initialize weights converted from PaddlePaddle v1.
    #load_model.load_and_set_vars(place)
    load_model.load_paddlev1_vars(place)

    train_reader = paddle.batch(
        reader.train(data_args, train_file_list), batch_size=batch_size)
    test_reader = paddle.batch(
        reader.test(data_args, val_file_list), batch_size=batch_size)
    feeder = fluid.DataFeeder(
        place=place, feed_list=[image, gt_box, gt_label, difficult])

    def test(pass_id):
        # Accumulate mAP over the whole validation set and report it.
        _, accum_map = map_eval.get_map_var()
        map_eval.reset(exe)
        test_map = None
        for data in test_reader():
            test_map = exe.run(test_program,
                               feed=feeder.feed(data),
                               fetch_list=[accum_map])
        if test_map is None:
            # Empty validation set: previously crashed on test_map[0].
            print("Test {0}, no validation data".format(pass_id))
        else:
            print("Test {0}, map {1}".format(pass_id, test_map[0]))

    for pass_id in range(num_passes):
        start_time = time.time()
        prev_start_time = start_time
        for batch_id, data in enumerate(train_reader()):
            prev_start_time = start_time
            start_time = time.time()
            loss_v = exe.run(fluid.default_main_program(),
                             feed=feeder.feed(data),
                             fetch_list=[loss])
            if batch_id % 20 == 0:
                print("Pass {0}, batch {1}, loss {2}, time {3}".format(
                    pass_id, batch_id, loss_v[0], start_time - prev_start_time))
        test(pass_id)

        # Snapshot every 10 passes and always on the final pass.
        if pass_id % 10 == 0 or pass_id == num_passes - 1:
            model_path = os.path.join(model_save_dir, str(pass_id))
            # Was a Python 2 print statement; use the function form for
            # consistency with the rest of the file.
            print("save models to %s" % model_path)
            fluid.io.save_inference_model(model_path, ['image'], [nmsed_out],
                                          exe)


if __name__ == '__main__':
    # Parse and echo the command-line configuration.
    args = parser.parse_args()
    print_arguments(args)

    # Fold the three per-channel flags into the single list the reader wants.
    mean_bgr = [args.mean_value_B, args.mean_value_G, args.mean_value_R]
    data_args = reader.Settings(
        dataset=args.dataset,  # either 'coco' or 'pascalvoc'
        toy=args.is_toy,
        data_dir=args.data_dir,
        label_file=args.label_file,
        apply_distort=args.apply_distort,
        apply_expand=args.apply_expand,
        resize_h=args.resize_h,
        resize_w=args.resize_w,
        mean_value=mean_bgr)

    # Kick off training with everything else taken straight from the CLI.
    train(
        args,
        train_file_list=args.train_file_list,
        val_file_list=args.val_file_list,
        data_args=data_args,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        num_passes=args.num_passes,
        model_save_dir=args.model_save_dir)