import os
import time
import numpy as np
import argparse
import functools
import shutil

import paddle
import paddle.fluid as fluid
import reader
from mobilenet_ssd import mobile_net
from utility import add_arguments, print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('learning_rate',    float, 0.001,     "Learning rate.")
add_arg('batch_size',       int,   64,        "Minibatch size.")
add_arg('num_passes',       int,   120,       "Epoch number.")
add_arg('use_gpu',          bool,  True,      "Whether to use GPU.")
add_arg('parallel',         bool,  True,      "Whether to train in parallel on multiple devices.")
add_arg('dataset',          str,   'pascalvoc', "Dataset: coco2014, coco2017, or pascalvoc.")
add_arg('model_save_dir',   str,   'model',     "The path to save model.")
add_arg('pretrained_model', str,   'pretrained/ssd_mobilenet_v1_coco/', "The init model path.")
add_arg('apply_distort',    bool,  True,   "Whether to apply image distortion.")
add_arg('apply_expand',     bool,  True,   "Whether to apply image expansion.")
add_arg('nms_threshold',    float, 0.45,   "NMS threshold.")
add_arg('ap_version',       str,   '11point',   "AP version: 'integral' or '11point'.")
add_arg('resize_h',         int,   300,    "The resized image height.")
add_arg('resize_w',         int,   300,    "The resized image width.")
add_arg('mean_value_B',     float, 127.5,  "Mean value for B channel which will be subtracted.")  #123.68
add_arg('mean_value_G',     float, 127.5,  "Mean value for G channel which will be subtracted.")  #116.78
add_arg('mean_value_R',     float, 127.5,  "Mean value for R channel which will be subtracted.")  #103.94
add_arg('is_toy',           int,   0, "Toy mode for quick debugging: 0 uses all data, n uses only n samples.")
add_arg('data_dir',         str,   'data/pascalvoc', "Data directory.")
add_arg('enable_ce',     bool,  False, "Whether to use CE (continuous evaluation) to evaluate the model.")
# yapf: enable


def train(args,
          train_file_list,
          val_file_list,
          data_args,
          learning_rate,
          batch_size,
          num_passes,
          model_save_dir,
          pretrained_model=None):
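    # Fix the startup program's random seed for reproducible continuous-evaluation (CE) runs.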
    if args.enable_ce:
        fluid.framework.default_startup_program().random_seed = 111

    image_shape = [3, data_args.resize_h, data_args.resize_w]
    if 'coco' in data_args.dataset:
        num_classes = 91
    elif 'pascalvoc' in data_args.dataset:
        num_classes = 21

    devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
    devices_num = len(devices.split(","))

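    # Input layers, the SSD-MobileNet predictions, the NMS-ed detection output, and the SSD loss.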
    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    gt_box = fluid.layers.data(
        name='gt_box', shape=[4], dtype='float32', lod_level=1)
    gt_label = fluid.layers.data(
        name='gt_label', shape=[1], dtype='int32', lod_level=1)
    difficult = fluid.layers.data(
        name='gt_difficult', shape=[1], dtype='int32', lod_level=1)
    locs, confs, box, box_var = mobile_net(num_classes, image, image_shape)
    nmsed_out = fluid.layers.detection_output(
        locs, confs, box, box_var, nms_threshold=args.nms_threshold)
    loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label, box,
                                 box_var)
    loss = fluid.layers.reduce_sum(loss)

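    # Clone the main program for evaluation and attach a DetectionMAP evaluator to it.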
    test_program = fluid.default_main_program().clone(for_test=True)
    with fluid.program_guard(test_program):
        map_eval = fluid.evaluator.DetectionMAP(
            nmsed_out,
            gt_label,
            gt_box,
            difficult,
            num_classes,
            overlap_threshold=0.5,
            evaluate_difficult=False,
            ap_version=args.ap_version)

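    # Piecewise learning-rate schedule; step boundaries are derived from the dataset's image count.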
    if 'coco' in data_args.dataset:
        # decay the learning rate at pass 12 and pass 19
        if '2014' in train_file_list:
            epocs = 82783 // batch_size
            boundaries = [epocs * 12, epocs * 19]
        elif '2017' in train_file_list:
            epocs = 118287 // batch_size
            boundaries = [epocs * 12, epocs * 19]
        values = [
            learning_rate, learning_rate * 0.5, learning_rate * 0.25
        ]
    elif 'pascalvoc' in data_args.dataset:
        epocs = 19200 // batch_size
        boundaries = [epocs * 40, epocs * 60, epocs * 80, epocs * 100]
        values = [
            learning_rate, learning_rate * 0.5, learning_rate * 0.25,
            learning_rate * 0.1, learning_rate * 0.01
        ]
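    # RMSProp with piecewise learning-rate decay and a small L2 weight decay.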
    optimizer = fluid.optimizer.RMSProp(
        learning_rate=fluid.layers.piecewise_decay(boundaries, values),
        regularization=fluid.regularizer.L2Decay(0.00005), )

    optimizer.minimize(loss)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

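    # Initialize from a pretrained model, loading only the variables that exist on disk.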
    if pretrained_model:
        def if_exist(var):
            return os.path.exists(os.path.join(pretrained_model, var.name))
        fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)

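    # Use a ParallelExecutor for multi-device training when --parallel is set.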
    if args.parallel:
        train_exe = fluid.ParallelExecutor(
            use_cuda=args.use_gpu, loss_name=loss.name)

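    # Build the batched training reader; in CE mode, seed the Python and NumPy RNGs so runs are reproducible.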
    if not args.enable_ce:
        train_reader = paddle.batch(
            reader.train(data_args, train_file_list), batch_size=batch_size)
    else:
        import random
        random.seed(0)
        np.random.seed(0)
        train_reader = paddle.batch(
            reader.train(data_args, train_file_list, False), batch_size=batch_size)
    test_reader = paddle.batch(
        reader.test(data_args, val_file_list), batch_size=batch_size)
    feeder = fluid.DataFeeder(
        place=place, feed_list=[image, gt_box, gt_label, difficult])

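    # Persist all model variables under model_save_dir/<postfix>, replacing any existing copy.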
    def save_model(postfix):
        model_path = os.path.join(model_save_dir, postfix)
        if os.path.isdir(model_path):
            shutil.rmtree(model_path)
        print('save models to %s' % (model_path))
        fluid.io.save_persistables(exe, model_path)

    best_map = 0.

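    # Run a full evaluation pass, report the accumulated mAP, and keep the best model so far.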
    def test(pass_id, best_map):
        _, accum_map = map_eval.get_map_var()
        map_eval.reset(exe)
        every_pass_map = []
        for batch_id, data in enumerate(test_reader()):
            test_map, = exe.run(test_program,
                               feed=feeder.feed(data),
                               fetch_list=[accum_map])
            if batch_id % 20 == 0:
                every_pass_map.append(test_map)
                print("Batch {0}, map {1}".format(batch_id, test_map))
        mean_map = np.mean(every_pass_map)
        if test_map[0] > best_map:
            best_map = test_map[0]
            save_model('best_model')
        print("Pass {0}, test map {1}".format(pass_id, test_map))
        return best_map, mean_map

    total_time = 0.0
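    # Main training loop over passes.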
    for pass_id in range(num_passes):
        epoch_idx = pass_id + 1
        start_time = time.time()
        prev_start_time = start_time
        every_pass_loss = []
        for batch_id, data in enumerate(train_reader()):
            prev_start_time = start_time
            start_time = time.time()
            if len(data) < (devices_num * 2):
                print("There are too few data to train on all devices.")
                continue
            if args.parallel:
                loss_v, = train_exe.run(fetch_list=[loss.name],
                                        feed=feeder.feed(data))
            else:
                loss_v, = exe.run(fluid.default_main_program(),
                                  feed=feeder.feed(data),
                                  fetch_list=[loss])
            loss_v = np.mean(np.array(loss_v))
            every_pass_loss.append(loss_v)
            if batch_id % 20 == 0:
                print("Pass {0}, batch {1}, loss {2}, time {3}".format(
                    pass_id, batch_id, loss_v, start_time - prev_start_time))

        end_time = time.time()
        best_map, mean_map = test(pass_id, best_map)
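        # Emit continuous-evaluation (CE) KPIs once, after the second pass (pass_id == 1).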
        if args.enable_ce and pass_id == 1:
            total_time += end_time - start_time
            train_avg_loss = np.mean(every_pass_loss)
            if devices_num == 1:
                print("kpis	train_cost	%s" % train_avg_loss)
                print("kpis	test_acc	%s" % mean_map)
                print("kpis	train_speed	%s" % (total_time / epoch_idx))
            else:
                print("kpis	train_cost_card%s	%s" %
                       (devices_num, train_avg_loss))
                print("kpis	test_acc_card%s	%s" %
                       (devices_num, mean_map))
                print("kpis	train_speed_card%s	%f" %
                       (devices_num, total_time / epoch_idx))


        if pass_id % 10 == 0 or pass_id == num_passes - 1:
            save_model(str(pass_id))
    print("Best test map {0}".format(best_map))

if __name__ == '__main__':
    args = parser.parse_args()
    print_arguments(args)

    data_dir = args.data_dir
    label_file = 'label_list'
    model_save_dir = args.model_save_dir
    train_file_list = 'trainval.txt'
    val_file_list = 'test.txt'
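    # Point the train/val file lists at the COCO annotation files when a COCO dataset is selected.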
    if 'coco' in args.dataset:
        data_dir = 'data/coco'
        if '2014' in args.dataset:
            train_file_list = 'annotations/instances_train2014.json'
            val_file_list = 'annotations/instances_val2014.json'
        elif '2017' in args.dataset:
            train_file_list = 'annotations/instances_train2017.json'
            val_file_list = 'annotations/instances_val2017.json'

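    # Bundle dataset and preprocessing options for the reader.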
    data_args = reader.Settings(
        dataset=args.dataset,
        data_dir=data_dir,
        label_file=label_file,
        resize_h=args.resize_h,
        resize_w=args.resize_w,
        mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R],
        apply_distort=args.apply_distort,
        apply_expand=args.apply_expand,
        ap_version=args.ap_version,
        toy=args.is_toy)
    train(
        args,
        train_file_list=train_file_list,
        val_file_list=val_file_list,
        data_args=data_args,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        num_passes=args.num_passes,
        model_save_dir=model_save_dir,
        pretrained_model=args.pretrained_model)
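
# Example invocation (flags and paths are illustrative; defaults are defined by add_arg above):
#   python train.py --dataset='pascalvoc' --data_dir='data/pascalvoc' \
#       --pretrained_model='pretrained/ssd_mobilenet_v1_coco/' --batch_size=64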