train.py 9.9 KB
Newer Older
1
import os
X
Xingyuan Bu 已提交
2
import time
3 4 5
import numpy as np
import argparse
import functools
D
Dang Qingqing 已提交
6
import shutil
7

D
Dang Qingqing 已提交
8 9 10 11 12 13
import paddle
import paddle.fluid as fluid
import reader
from mobilenet_ssd import mobile_net
from utility import add_arguments, print_arguments

14 15
# Command-line interface. Every option is registered through ``add_arg``,
# which is ``utility.add_arguments`` with the parser pre-bound; the positional
# order used below is (name, type, default, help text).
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('learning_rate',    float, 0.001,     "Learning rate.")
add_arg('batch_size',       int,   64,        "Minibatch size.")
add_arg('num_passes',       int,   120,       "Epoch number.")
add_arg('use_gpu',          bool,  True,      "Whether use GPU.")
add_arg('parallel',         bool,  True,      "Parallel.")
add_arg('dataset',          str,   'pascalvoc', "coco2014, coco2017, and pascalvoc.")
add_arg('model_save_dir',   str,   'model',     "The path to save model.")
add_arg('pretrained_model', str,   'pretrained/ssd_mobilenet_v1_coco/', "The init model path.")
add_arg('apply_distort',    bool,  True,   "Whether apply distort.")
add_arg('apply_expand',     bool,  True,   "Whether apply expand.")
add_arg('nms_threshold',    float, 0.45,   "NMS threshold.")
add_arg('ap_version',       str,   '11point',   "integral, 11point.")
add_arg('resize_h',         int,   300,    "The resized image height.")
# Fixed help text: this option is the width, not the height.
add_arg('resize_w',         int,   300,    "The resized image width.")
# NOTE(review): the trailing numbers look like alternative per-channel means;
# confirm which channel each value belongs to before using them.
add_arg('mean_value_B',     float, 127.5,  "Mean value for B channel which will be subtracted.")  #123.68
add_arg('mean_value_G',     float, 127.5,  "Mean value for G channel which will be subtracted.")  #116.78
add_arg('mean_value_R',     float, 127.5,  "Mean value for R channel which will be subtracted.")  #103.94
add_arg('is_toy',           int,   0, "Toy for quick debug, 0 means using all data, while n means using only n sample.")
add_arg('data_dir',         str,   'data/pascalvoc', "data directory")
add_arg('enable_ce',     bool,  False, "Whether use CE to evaluate the model")
#yapf: enable
38

X
Xingyuan Bu 已提交
39 40 41 42 43 44 45 46 47
def _num_classes(dataset):
    """Return the class count handed to the network for ``dataset``.

    Raises:
        ValueError: if ``dataset`` contains neither 'coco' nor 'pascalvoc'.
    """
    if 'coco' in dataset:
        return 91
    if 'pascalvoc' in dataset:
        return 21
    raise ValueError(
        "Unsupported dataset '%s': expected a name containing 'coco' or "
        "'pascalvoc'." % dataset)


def _lr_schedule(dataset, train_file_list, learning_rate, batch_size):
    """Build the piecewise-constant learning-rate schedule for ``dataset``.

    Returns:
        A tuple ``(epocs, boundaries, values)``: the number of iterations in
        one pass, the iteration counts at which the rate changes, and the
        rate used in each resulting interval.

    Raises:
        ValueError: if the dataset is unsupported, or a COCO file list does
            not identify the 2014 or 2017 release.
    """
    if 'coco' in dataset:
        # NOTE(review): 82783 / 118287 are presumably the train-set sizes of
        # COCO 2014 / 2017 -- confirm against the annotation files.
        if '2014' in train_file_list:
            epocs = 82783 // batch_size
        elif '2017' in train_file_list:
            epocs = 118287 // batch_size
        else:
            raise ValueError(
                "Cannot tell the COCO release from train_file_list '%s': "
                "expected it to contain '2014' or '2017'." % train_file_list)
        # learning rate decay in 12, 19 pass, respectively
        boundaries = [epocs * 12, epocs * 19]
        values = [learning_rate, learning_rate * 0.5, learning_rate * 0.25]
    elif 'pascalvoc' in dataset:
        # NOTE(review): 19200 is presumably the number of training samples
        # per pass for Pascal VOC -- confirm.
        epocs = 19200 // batch_size
        boundaries = [epocs * 40, epocs * 60, epocs * 80, epocs * 100]
        values = [
            learning_rate, learning_rate * 0.5, learning_rate * 0.25,
            learning_rate * 0.1, learning_rate * 0.01
        ]
    else:
        raise ValueError(
            "Unsupported dataset '%s': expected a name containing 'coco' or "
            "'pascalvoc'." % dataset)
    return epocs, boundaries, values


def train(args,
          train_file_list,
          val_file_list,
          data_args,
          learning_rate,
          batch_size,
          num_passes,
          model_save_dir,
          pretrained_model=None):
    """Train the MobileNet-SSD detector and evaluate it after every pass.

    Builds the network, the SSD loss and a detection-mAP evaluator, optimizes
    with RMSProp under a piecewise-constant learning rate, and saves
    persistable variables every 10 passes, on the last pass, and whenever the
    test mAP improves ('best_model').

    Args:
        args: parsed command-line namespace; this function reads
            ``enable_ce``, ``nms_threshold``, ``ap_version``, ``use_gpu`` and
            ``parallel`` from it.
        train_file_list: training list file, forwarded to ``reader.train``.
        val_file_list: validation list file, forwarded to ``reader.test``.
        data_args: reader settings; ``dataset``, ``resize_h`` and
            ``resize_w`` are read here, the object is forwarded to the reader.
        learning_rate: base learning rate of the schedule.
        batch_size: minibatch size for both readers.
        num_passes: number of passes over the training data.
        model_save_dir: directory under which checkpoints are written.
        pretrained_model: optional directory of initial weights; only
            variables that have a file of the same name there are loaded.

    Raises:
        ValueError: if the dataset (or the COCO release) is not supported.
        RuntimeError: if the test reader yields no batches.
    """
    if args.enable_ce:
        # Fixed seed so continuous-evaluation runs are comparable.
        fluid.framework.default_startup_program().random_seed = 111

    image_shape = [3, data_args.resize_h, data_args.resize_w]
    # Resolve every dataset-dependent setting before building the graph so an
    # unsupported dataset fails immediately with a clear message rather than
    # with an UnboundLocalError later on.
    num_classes = _num_classes(data_args.dataset)
    epocs, boundaries, values = _lr_schedule(
        data_args.dataset, train_file_list, learning_rate, batch_size)

    # An unset/empty CUDA_VISIBLE_DEVICES still counts as one device, because
    # "".split(",") yields [""].
    devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
    devices_num = len(devices.split(","))

    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    gt_box = fluid.layers.data(
        name='gt_box', shape=[4], dtype='float32', lod_level=1)
    gt_label = fluid.layers.data(
        name='gt_label', shape=[1], dtype='int32', lod_level=1)
    difficult = fluid.layers.data(
        name='gt_difficult', shape=[1], dtype='int32', lod_level=1)
    locs, confs, box, box_var = mobile_net(num_classes, image, image_shape)
    nmsed_out = fluid.layers.detection_output(
        locs, confs, box, box_var, nms_threshold=args.nms_threshold)
    loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label, box,
                                 box_var)
    loss = fluid.layers.reduce_sum(loss)

    # Clone the program for evaluation before the optimizer ops are appended.
    test_program = fluid.default_main_program().clone(for_test=True)
    with fluid.program_guard(test_program):
        map_eval = fluid.evaluator.DetectionMAP(
            nmsed_out,
            gt_label,
            gt_box,
            difficult,
            num_classes,
            overlap_threshold=0.5,
            evaluate_difficult=False,
            ap_version=args.ap_version)

    optimizer = fluid.optimizer.RMSProp(
        learning_rate=fluid.layers.piecewise_decay(boundaries, values),
        regularization=fluid.regularizer.L2Decay(0.00005), )
    optimizer.minimize(loss)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    if pretrained_model:
        def if_exist(var):
            # Load only the variables that have a file in the model directory.
            return os.path.exists(os.path.join(pretrained_model, var.name))
        fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)

    if args.parallel:
        train_exe = fluid.ParallelExecutor(
            use_cuda=args.use_gpu, loss_name=loss.name)

    if args.enable_ce:
        # Seed the Python and NumPy generators and pass False as the third
        # positional argument of reader.train.
        # NOTE(review): that flag presumably disables shuffling -- confirm
        # against reader.train.
        import random
        random.seed(0)
        np.random.seed(0)
        train_reader = paddle.batch(
            reader.train(data_args, train_file_list, False),
            batch_size=batch_size)
    else:
        train_reader = paddle.batch(
            reader.train(data_args, train_file_list), batch_size=batch_size)
    test_reader = paddle.batch(
        reader.test(data_args, val_file_list), batch_size=batch_size)
    feeder = fluid.DataFeeder(
        place=place, feed_list=[image, gt_box, gt_label, difficult])

    def save_model(postfix):
        """Write all persistable variables to ``model_save_dir/postfix``."""
        model_path = os.path.join(model_save_dir, postfix)
        # Replace any previous checkpoint with the same name.
        if os.path.isdir(model_path):
            shutil.rmtree(model_path)
        print('save models to %s' % (model_path))
        fluid.io.save_persistables(exe, model_path)

    best_map = 0.

    def test(pass_id, best_map):
        """Evaluate on the test reader.

        Returns:
            ``(best_map, mean_map)``: the best mAP seen so far (updated, and
            'best_model' saved, when this pass beats it) and the mean of the
            mAP values sampled every 20 batches.
        """
        _, accum_map = map_eval.get_map_var()
        map_eval.reset(exe)
        every_pass_map = []
        test_map = None
        for batch_id, data in enumerate(test_reader()):
            test_map, = exe.run(test_program,
                                feed=feeder.feed(data),
                                fetch_list=[accum_map])
            if batch_id % 20 == 0:
                every_pass_map.append(test_map)
                print("Batch {0}, map {1}".format(batch_id, test_map))
        if test_map is None:
            # Without this guard an empty reader surfaced as an
            # UnboundLocalError on test_map.
            raise RuntimeError(
                "The test reader produced no batches; cannot evaluate mAP.")
        mean_map = np.mean(every_pass_map)
        if test_map[0] > best_map:
            best_map = test_map[0]
            save_model('best_model')
        print("Pass {0}, test map {1}".format(pass_id, test_map))
        return best_map, mean_map

    for pass_id in range(num_passes):
        # pass_begin marks the start of the whole pass; start_time is
        # refreshed on every batch and is only used for per-batch timing.
        pass_begin = time.time()
        start_time = pass_begin
        every_pass_loss = []
        for batch_id, data in enumerate(train_reader()):
            prev_start_time = start_time
            start_time = time.time()
            if len(data) < (devices_num * 2):
                print("There are too few data to train on all devices.")
                continue
            if args.parallel:
                loss_v, = train_exe.run(fetch_list=[loss.name],
                                        feed=feeder.feed(data))
            else:
                loss_v, = exe.run(fluid.default_main_program(),
                                  feed=feeder.feed(data),
                                  fetch_list=[loss])
            loss_v = np.mean(np.array(loss_v))
            every_pass_loss.append(loss_v)
            if batch_id % 20 == 0:
                # The reported time is the interval between the starts of the
                # previous and the current iteration.
                print("Pass {0}, batch {1}, loss {2}, time {3}".format(
                    pass_id, batch_id, loss_v, start_time - prev_start_time))

        end_time = time.time()
        best_map, mean_map = test(pass_id, best_map)
        # Use the num_passes parameter (not args.num_passes) so this check
        # agrees with the loop bound and the checkpoint condition below.
        if args.enable_ce and pass_id == num_passes - 1:
            # Duration of the whole pass. Measuring from start_time would
            # cover only the last batch, since it is reset every iteration.
            total_time = end_time - pass_begin
            train_avg_loss = np.mean(every_pass_loss)
            if devices_num == 1:
                print("kpis\ttrain_cost\t%s" % train_avg_loss)
                print("kpis\ttest_acc\t%s" % mean_map)
                print("kpis\ttrain_speed\t%s" % (epocs / total_time))
            else:
                print("kpis\ttrain_cost_card%s\t%s" %
                      (devices_num, train_avg_loss))
                print("kpis\ttest_acc_card%s\t%s" %
                      (devices_num, mean_map))
                print("kpis\ttrain_speed_card%s\t%f" %
                      (devices_num, epocs / total_time))

        if pass_id % 10 == 0 or pass_id == num_passes - 1:
            save_model(str(pass_id))
    print("Best test map {0}".format(best_map))
D
dangqingqing 已提交
209 210

if __name__ == '__main__':
    # Script entry point: parse the options, pick the dataset-specific file
    # lists, build the reader settings and start training.
    args = parser.parse_args()
    print_arguments(args)

    data_dir = args.data_dir
    label_file = 'label_list'
    model_save_dir = args.model_save_dir
    # File lists used for Pascal VOC; replaced below for COCO.
    train_file_list = 'trainval.txt'
    val_file_list = 'test.txt'
    if 'coco' in args.dataset:
        # Switch to the COCO directory only when --data_dir was left at its
        # default, so an explicitly supplied directory is honored instead of
        # being silently overwritten.
        if data_dir == parser.get_default('data_dir'):
            data_dir = 'data/coco'
        if '2014' in args.dataset:
            train_file_list = 'annotations/instances_train2014.json'
            val_file_list = 'annotations/instances_val2014.json'
        elif '2017' in args.dataset:
            train_file_list = 'annotations/instances_train2017.json'
            val_file_list = 'annotations/instances_val2017.json'
        else:
            # A COCO name without a release year would otherwise fall through
            # with the Pascal VOC file lists and fail inside train().
            parser.error(
                "unsupported dataset %r: expected coco2014, coco2017 or "
                "pascalvoc" % args.dataset)
    elif 'pascalvoc' not in args.dataset:
        parser.error(
            "unsupported dataset %r: expected coco2014, coco2017 or "
            "pascalvoc" % args.dataset)

    data_args = reader.Settings(
        dataset=args.dataset,
        data_dir=data_dir,
        label_file=label_file,
        resize_h=args.resize_h,
        resize_w=args.resize_w,
        mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R],
        apply_distort=args.apply_distort,
        apply_expand=args.apply_expand,
        ap_version=args.ap_version,
        toy=args.is_toy)
    train(
        args,
        train_file_list=train_file_list,
        val_file_list=val_file_list,
        data_args=data_args,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        num_passes=args.num_passes,
        model_save_dir=model_save_dir,
        pretrained_model=args.pretrained_model)