train.py 11.4 KB
Newer Older
1
import os
X
Xingyuan Bu 已提交
2
import time
3 4 5
import numpy as np
import argparse
import functools
D
Dang Qingqing 已提交
6
import shutil
7

D
Dang Qingqing 已提交
8 9 10 11 12 13
import paddle
import paddle.fluid as fluid
import reader
from mobilenet_ssd import mobile_net
from utility import add_arguments, print_arguments

14 15
# Command-line options for the training script. Each add_arg(...) call
# forwards (name, type, default, help) to utility.add_arguments, bound to the
# module-level `parser`.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('learning_rate',    float, 0.001,     "Learning rate.")
add_arg('batch_size',       int,   16,        "Minibatch size.")
add_arg('num_passes',       int,   120,       "Epoch number.")
add_arg('use_gpu',          bool,  True,      "Whether use GPU.")
add_arg('parallel',         bool,  True,      "Parallel.")
add_arg('dataset',          str,   'pascalvoc', "coco2014, coco2017, and pascalvoc.")
add_arg('model_save_dir',   str,   'model',     "The path to save model.")
add_arg('pretrained_model', str,   'pretrained/ssd_mobilenet_v1_coco/', "The init model path.")
add_arg('apply_distort',    bool,  True,   "Whether apply distort.")
add_arg('apply_expand',     bool,  True,   "Whether apply expand.")
add_arg('nms_threshold',    float, 0.45,   "NMS threshold.")
add_arg('ap_version',       str,   '11point',   "integral, 11point.")
add_arg('resize_h',         int,   300,    "The resized image height.")
# Help text previously said "height" for the width option (copy-paste slip).
add_arg('resize_w',         int,   300,    "The resized image width.")
add_arg('mean_value_B',     float, 127.5,  "Mean value for B channel which will be subtracted.")  #123.68
add_arg('mean_value_G',     float, 127.5,  "Mean value for G channel which will be subtracted.")  #116.78
add_arg('mean_value_R',     float, 127.5,  "Mean value for R channel which will be subtracted.")  #103.94
add_arg('is_toy',           int,   0, "Toy for quick debug, 0 means using all data, while n means using only n sample.")
add_arg('data_dir',         str,   'data/pascalvoc', "data directory")
add_arg('enable_ce',     bool,  False, "Whether use CE to evaluate the model")
# yapf: enable
38

39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
def build_program(is_train, main_prog, startup_prog, args, data_args,
                  boundaries=None, values=None, train_file_list=None):
    """Populate ``main_prog`` with the SSD-MobileNet training or eval graph.

    Args:
        is_train: When true, append ``ssd_loss`` and an RMSProp optimizer;
            otherwise append ``detection_output`` and a ``DetectionMAP``
            evaluator.
        main_prog: ``fluid.Program`` that receives the network ops.
        startup_prog: ``fluid.Program`` that receives initialisation ops.
        args: Parsed options; ``nms_threshold`` and ``ap_version`` are read
            in eval mode only.
        data_args: Settings object; ``resize_h``, ``resize_w`` and
            ``dataset`` are read.
        boundaries: Step boundaries handed to ``piecewise_decay``
            (training only).
        values: Learning-rate values handed to ``piecewise_decay``
            (training only).
        train_file_list: Not used by this function; kept so existing
            callers that pass it keep working.

    Returns:
        Tuple ``(py_reader, loss)``. In eval mode the second element is the
        ``DetectionMAP`` evaluator rather than a loss variable.

    Raises:
        ValueError: If ``data_args.dataset`` contains neither ``'coco'`` nor
            ``'pascalvoc'``.
    """
    image_shape = [3, data_args.resize_h, data_args.resize_w]
    if 'coco' in data_args.dataset:
        num_classes = 91
    elif 'pascalvoc' in data_args.dataset:
        num_classes = 21
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on num_classes.
        raise ValueError(
            "Unsupported dataset %r: expected a name containing 'coco' or "
            "'pascalvoc'." % (data_args.dataset, ))

    with fluid.program_guard(main_prog, startup_prog):
        py_reader = fluid.layers.py_reader(
            capacity=64,
            shapes=[[-1] + image_shape, [-1, 4], [-1, 1], [-1, 1]],
            lod_levels=[0, 1, 1, 1],
            dtypes=["float32", "float32", "int32", "int32"],
            use_double_buffer=True)
        with fluid.unique_name.guard():
            image, gt_box, gt_label, difficult = fluid.layers.read_file(
                py_reader)
            locs, confs, box, box_var = mobile_net(num_classes, image,
                                                   image_shape)
            if is_train:
                loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label,
                                             box, box_var)
                loss = fluid.layers.reduce_sum(loss)
                optimizer = fluid.optimizer.RMSProp(
                    learning_rate=fluid.layers.piecewise_decay(boundaries,
                                                               values),
                    regularization=fluid.regularizer.L2Decay(0.00005), )
                optimizer.minimize(loss)
            else:
                nmsed_out = fluid.layers.detection_output(
                    locs, confs, box, box_var,
                    nms_threshold=args.nms_threshold)
                with fluid.program_guard(main_prog):
                    loss = fluid.evaluator.DetectionMAP(
                        nmsed_out,
                        gt_label,
                        gt_box,
                        difficult,
                        num_classes,
                        overlap_threshold=0.5,
                        evaluate_difficult=False,
                        ap_version=args.ap_version)
    if not is_train:
        # NOTE(review): the clone is bound to the local name only and is not
        # returned, so the caller keeps using its un-cloned program. Kept
        # as-is because returning it would change the interface -- confirm
        # whether callers should receive the for_test clone.
        main_prog = main_prog.clone(for_test=True)
    return py_reader, loss
85

X
Xingyuan Bu 已提交
86 87 88 89 90 91 92 93 94
def _piecewise_schedule(dataset, train_file_list, batch_size, devices_num,
                        learning_rate):
    """Return ``(epocs, test_epocs, boundaries, values)`` for ``dataset``.

    ``epocs`` / ``test_epocs`` are the per-pass batch caps used by the
    training and evaluation loops; ``boundaries`` / ``values`` feed the
    piecewise learning-rate decay. The sample counts are the constants the
    script has always hard-coded (presumably the dataset sizes -- confirm).

    Raises:
        ValueError: For a dataset (or COCO year) the script has no schedule
            for. This used to surface later as an UnboundLocalError.
    """
    if 'coco' in dataset:
        if '2014' in train_file_list:
            train_samples, test_samples = 82783, 40504
        elif '2017' in train_file_list:
            train_samples, test_samples = 118287, 5000
        else:
            raise ValueError(
                "Cannot infer the COCO year (2014 or 2017) from "
                "train_file_list %r." % (train_file_list, ))
        epocs = train_samples // batch_size // devices_num
        test_epocs = test_samples // batch_size
        # learning rate decay in 12, 19 pass, respectively
        boundaries = [epocs * 12, epocs * 19]
        values = [learning_rate, learning_rate * 0.5, learning_rate * 0.25]
    elif 'pascalvoc' in dataset:
        epocs = 19200 // batch_size // devices_num
        test_epocs = 4952 // batch_size
        boundaries = [epocs * 40, epocs * 60, epocs * 80, epocs * 100]
        values = [
            learning_rate, learning_rate * 0.5, learning_rate * 0.25,
            learning_rate * 0.1, learning_rate * 0.01
        ]
    else:
        raise ValueError(
            "Unsupported dataset %r: expected a name containing 'coco' or "
            "'pascalvoc'." % (dataset, ))
    return epocs, test_epocs, boundaries, values


def train(args,
          train_file_list,
          val_file_list,
          data_args,
          learning_rate,
          batch_size,
          num_passes,
          model_save_dir,
          pretrained_model=None):
    """Train the detector for ``num_passes`` passes, evaluating after each.

    Args:
        args: Parsed options; ``enable_ce``, ``use_gpu`` and ``parallel``
            are read here, and the object is forwarded to ``build_program``.
        train_file_list: Training annotation list passed to the reader; for
            COCO it must contain ``'2014'`` or ``'2017'``.
        val_file_list: Evaluation annotation list passed to the reader.
        data_args: Reader settings; ``dataset`` selects the schedule.
        learning_rate: Base learning rate for the piecewise decay.
        batch_size: Minibatch size given to the batch readers.
        num_passes: Number of training passes.
        model_save_dir: Directory under which checkpoints are written.
        pretrained_model: Optional directory; variables that have a file of
            the same name in it are loaded before training.

    Raises:
        ValueError: If no learning-rate schedule exists for the dataset.
    """
    devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
    devices_num = len(devices.split(","))
    # Resolve the schedule first so an unsupported dataset fails before any
    # program or executor is created.
    epocs, test_epocs, boundaries, values = _piecewise_schedule(
        data_args.dataset, train_file_list, batch_size, devices_num,
        learning_rate)

    startup_prog = fluid.Program()
    train_prog = fluid.Program()
    test_prog = fluid.Program()

    if args.enable_ce:
        # Fixed seeds so CE runs are repeatable.
        startup_prog.random_seed = 111
        train_prog.random_seed = 111
        test_prog.random_seed = 111

    train_py_reader, loss = build_program(
        is_train=True,
        main_prog=train_prog,
        startup_prog=startup_prog,
        args=args,
        data_args=data_args,
        values=values,
        boundaries=boundaries,
        train_file_list=train_file_list)
    test_py_reader, map_eval = build_program(
        is_train=False,
        main_prog=test_prog,
        startup_prog=startup_prog,
        args=args,
        data_args=data_args)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup_prog)

    if pretrained_model:

        def if_exist(var):
            return os.path.exists(os.path.join(pretrained_model, var.name))

        fluid.io.load_vars(
            exe, pretrained_model, main_program=train_prog,
            predicate=if_exist)

    if args.parallel:
        train_exe = fluid.ParallelExecutor(
            main_program=train_prog,
            use_cuda=args.use_gpu,
            loss_name=loss.name)
        # Not referenced again below (evaluation runs through `exe`); the
        # construction is kept to preserve the original behaviour.
        test_exe = fluid.ParallelExecutor(
            main_program=test_prog,
            use_cuda=args.use_gpu,
            share_vars_from=train_exe)

    if not args.enable_ce:
        train_reader = reader.batch_reader(data_args, train_file_list,
                                           batch_size, "train")
    else:
        import random
        random.seed(0)
        np.random.seed(0)
        train_reader = reader.batch_reader(
            data_args, train_file_list, batch_size, "train", shuffle=False)
    test_reader = reader.batch_reader(data_args, val_file_list, batch_size,
                                      "test")
    train_py_reader.decorate_paddle_reader(train_reader)
    test_py_reader.decorate_paddle_reader(test_reader)

    def save_model(postfix, main_prog):
        """Save persistables of ``main_prog`` to model_save_dir/postfix."""
        model_path = os.path.join(model_save_dir, postfix)
        if os.path.isdir(model_path):
            # Replace any previous checkpoint with the same name.
            shutil.rmtree(model_path)
        print('save models to %s' % (model_path))
        fluid.io.save_persistables(exe, model_path, main_program=main_prog)

    best_map = 0.

    def test(pass_id, best_map):
        """Run one evaluation pass; return ``(best_map, mean_map)``."""
        _, accum_map = map_eval.get_map_var()
        map_eval.reset(exe)
        every_pass_map = []
        test_map = None
        test_py_reader.start()
        batch_id = 0
        try:
            while True:
                test_map, = exe.run(test_prog, fetch_list=[accum_map])
                if batch_id % 20 == 0:
                    every_pass_map.append(test_map)
                    print("Batch {0}, map {1}".format(batch_id, test_map))
                batch_id += 1
                if batch_id > test_epocs:
                    break
        except fluid.core.EOFException:
            test_py_reader.reset()
        if test_map is None:
            # The reader was exhausted before a single batch ran; there is
            # no mAP to compare (this used to raise on an unbound name).
            print("Pass {0}, no test batch was evaluated".format(pass_id))
            return best_map, float('nan')
        mean_map = np.mean(every_pass_map)
        if test_map[0] > best_map:
            best_map = test_map[0]
            save_model('best_model', test_prog)
        print("Pass {0}, test map {1}".format(pass_id, test_map))
        return best_map, mean_map

    total_time = 0.0
    for pass_id in range(num_passes):
        epoch_idx = pass_id + 1
        # start_time is overwritten per batch for the log line, so keep a
        # separate timestamp for the whole pass.
        pass_start_time = time.time()
        start_time = pass_start_time
        train_py_reader.start()
        every_pass_loss = []
        batch_id = 0
        try:
            while True:
                prev_start_time = start_time
                start_time = time.time()

                if args.parallel:
                    loss_v, = train_exe.run(fetch_list=[loss.name])
                else:
                    loss_v, = exe.run(train_prog, fetch_list=[loss])
                loss_v = np.mean(np.array(loss_v))
                every_pass_loss.append(loss_v)
                if batch_id % 20 == 0:
                    print("Pass {0}, batch {1}, loss {2}, time {3}".format(
                        pass_id, batch_id, loss_v,
                        start_time - prev_start_time))
                batch_id += 1
                # NOTE(review): leaving via this break does not reset the
                # py_reader before the next pass calls start() -- confirm
                # this is safe.
                if batch_id > epocs:
                    break
        except fluid.core.EOFException:
            train_py_reader.reset()

        end_time = time.time()
        # Accumulate the wall time of every pass so that
        # total_time / epoch_idx below is the mean time per pass. The old
        # code added only the last batch's duration on the final pass.
        total_time += end_time - pass_start_time
        best_map, mean_map = test(pass_id, best_map)
        if args.enable_ce and pass_id == num_passes - 1:
            train_avg_loss = np.mean(every_pass_loss)
            if devices_num == 1:
                print("kpis\ttrain_cost\t%s" % train_avg_loss)
                print("kpis\ttest_acc\t%s" % mean_map)
                print("kpis\ttrain_speed\t%s" % (total_time / epoch_idx))
            else:
                print("kpis\ttrain_cost_card%s\t%s" %
                      (devices_num, train_avg_loss))
                print("kpis\ttest_acc_card%s\t%s" %
                      (devices_num, mean_map))
                print("kpis\ttrain_speed_card%s\t%f" %
                      (devices_num, total_time / epoch_idx))

        if pass_id % 10 == 0 or pass_id == num_passes - 1:
            save_model(str(pass_id), train_prog)
    print("Best test map {0}".format(best_map))
D
dangqingqing 已提交
248 249

# Script entry point: parse options, pick the dataset's annotation lists,
# build the reader settings and start training.
if __name__ == '__main__':
    args = parser.parse_args()
    print_arguments(args)

    data_dir = args.data_dir
    label_file = 'label_list'
    model_save_dir = args.model_save_dir
    # Defaults are the Pascal VOC lists; the COCO branch overrides them.
    train_file_list = 'trainval.txt'
    val_file_list = 'test.txt'
    if 'coco' in args.dataset:
        # Switch to the COCO directory only when --data_dir was left at its
        # default; an explicitly supplied directory used to be overwritten.
        if data_dir == parser.get_default('data_dir'):
            data_dir = 'data/coco'
        if '2014' in args.dataset:
            train_file_list = 'annotations/instances_train2014.json'
            val_file_list = 'annotations/instances_val2014.json'
        elif '2017' in args.dataset:
            train_file_list = 'annotations/instances_train2017.json'
            val_file_list = 'annotations/instances_val2017.json'
        else:
            # Without a year the Pascal VOC lists would be kept and training
            # would crash later on an unbound variable.
            raise ValueError(
                "Unsupported COCO dataset %r: expected 'coco2014' or "
                "'coco2017'." % (args.dataset, ))

    data_args = reader.Settings(
        dataset=args.dataset,
        data_dir=data_dir,
        label_file=label_file,
        resize_h=args.resize_h,
        resize_w=args.resize_w,
        mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R],
        apply_distort=args.apply_distort,
        apply_expand=args.apply_expand,
        ap_version=args.ap_version,
        toy=args.is_toy)
    train(
        args,
        train_file_list=train_file_list,
        val_file_list=val_file_list,
        data_args=data_args,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        num_passes=args.num_passes,
        model_save_dir=model_save_dir,
        pretrained_model=args.pretrained_model)