"""Train the MobileNet-SSD detector with PaddlePaddle Fluid (train.py).

Run this script with ``--help`` to list the available command-line options.
"""
import argparse
import functools
import os
import shutil
import time

import numpy as np
import paddle
import paddle.fluid as fluid

import reader
from mobilenet_ssd import mobile_net
from utility import add_arguments, print_arguments

# Command-line interface. ``add_arg`` is ``utility.add_arguments`` bound to
# the module-level parser; as used below, each call passes the option name,
# its type, its default value and its help text, in that order.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('learning_rate',    float, 0.001,     "Learning rate.")
add_arg('batch_size',       int,   16,        "Minibatch size.")
add_arg('num_passes',       int,   120,       "Epoch number.")
add_arg('use_gpu',          bool,  True,      "Whether use GPU.")
add_arg('parallel',         bool,  True,      "Parallel.")
add_arg('dataset',          str,   'pascalvoc', "coco2014, coco2017, and pascalvoc.")
add_arg('model_save_dir',   str,   'model',     "The path to save model.")
add_arg('pretrained_model', str,   'pretrained/ssd_mobilenet_v1_coco/', "The init model path.")
add_arg('apply_distort',    bool,  True,   "Whether apply distort.")
add_arg('apply_expand',     bool,  True,   "Whether apply expand.")
add_arg('nms_threshold',    float, 0.45,   "NMS threshold.")
add_arg('ap_version',       str,   '11point',   "integral, 11point.")
add_arg('resize_h',         int,   300,    "The resized image height.")
add_arg('resize_w',         int,   300,    "The resized image width.")
add_arg('mean_value_B',     float, 127.5,  "Mean value for B channel which will be subtracted.")  #123.68
add_arg('mean_value_G',     float, 127.5,  "Mean value for G channel which will be subtracted.")  #116.78
add_arg('mean_value_R',     float, 127.5,  "Mean value for R channel which will be subtracted.")  #103.94
add_arg('is_toy',           int,   0, "Toy for quick debug, 0 means using all data, while n means using only n sample.")
add_arg('data_dir',         str,   'data/pascalvoc', "data directory")
add_arg('enable_ce',     bool,  False, "Whether use CE to evaluate the model")
# yapf: enable


def build_program(is_train, main_prog, startup_prog, args, data_args,
                  boundaries=None, values=None, train_file_list=None):
    """Build the training or evaluation network inside ``main_prog``.

    Args:
        is_train: When true, the SSD loss and the optimizer are appended;
            otherwise detection output and a ``DetectionMAP`` evaluator are
            appended.
        main_prog: Program that receives the network operators.
        startup_prog: Program that receives the initialisation operators.
        args: Parsed command-line arguments; ``enable_ce``,
            ``nms_threshold`` and ``ap_version`` are read.
        data_args: Data settings; ``resize_h``, ``resize_w`` and ``dataset``
            are read.
        boundaries: Step boundaries handed to ``piecewise_decay``. Only used
            when training with ``args.enable_ce`` disabled.
        values: Learning-rate values handed to ``piecewise_decay`` together
            with ``boundaries``.
        train_file_list: Not used by this function; kept so existing callers
            keep working.

    Returns:
        A ``(py_reader, loss)`` tuple. When ``is_train`` is false, ``loss``
        is the ``DetectionMAP`` evaluator object rather than a loss variable.

    Raises:
        ValueError: If ``data_args.dataset`` names neither a COCO nor the
            pascalvoc dataset.
    """
    image_shape = [3, data_args.resize_h, data_args.resize_w]
    if 'coco' in data_args.dataset:
        num_classes = 91
    elif 'pascalvoc' in data_args.dataset:
        num_classes = 21
    else:
        # Fail early with a clear message instead of hitting an unbound
        # ``num_classes`` further down.
        raise ValueError(
            "Unsupported dataset: %r (expected a name containing 'coco' "
            "or 'pascalvoc')" % (data_args.dataset, ))

    def get_optimizer():
        """Return the RMSProp optimizer used for training."""
        if not args.enable_ce:
            optimizer = fluid.optimizer.RMSProp(
                learning_rate=fluid.layers.piecewise_decay(boundaries, values),
                regularization=fluid.regularizer.L2Decay(0.00005), )
        else:
            # CE runs use a fixed learning rate and no regularization.
            optimizer = fluid.optimizer.RMSProp(learning_rate=0.001)
        return optimizer

    with fluid.program_guard(main_prog, startup_prog):
        # Four input slots, read below as image, gt_box, gt_label, difficult.
        py_reader = fluid.layers.py_reader(
            capacity=64,
            shapes=[[-1] + image_shape, [-1, 4], [-1, 1], [-1, 1]],
            lod_levels=[0, 1, 1, 1],
            dtypes=["float32", "float32", "int32", "int32"],
            use_double_buffer=True)
        with fluid.unique_name.guard():
            image, gt_box, gt_label, difficult = fluid.layers.read_file(
                py_reader)
            locs, confs, box, box_var = mobile_net(num_classes, image,
                                                   image_shape)
            if is_train:
                loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label,
                                             box, box_var)
                loss = fluid.layers.reduce_sum(loss)
                optimizer = get_optimizer()
                optimizer.minimize(loss)
            else:
                nmsed_out = fluid.layers.detection_output(
                    locs, confs, box, box_var,
                    nms_threshold=args.nms_threshold)
                with fluid.program_guard(main_prog):
                    loss = fluid.evaluator.DetectionMAP(
                        nmsed_out,
                        gt_label,
                        gt_box,
                        difficult,
                        num_classes,
                        overlap_threshold=0.5,
                        evaluate_difficult=False,
                        ap_version=args.ap_version)
    if not is_train:
        # NOTE(review): this only rebinds the local name; the clone is not
        # returned, so the caller keeps using the program it passed in.
        main_prog = main_prog.clone(for_test=True)
    return py_reader, loss


def _lr_schedule(dataset, train_file_list, batch_size, devices_num,
                 learning_rate):
    """Return ``(epocs, test_epocs, boundaries, values)`` for ``dataset``.

    ``epocs`` and ``test_epocs`` are the per-pass batch caps used by the
    training and evaluation loops; ``boundaries`` and ``values`` describe the
    piecewise learning-rate decay.

    Raises:
        ValueError: If the dataset (or, for COCO, the year found in
            ``train_file_list``) is not one of the supported choices.
    """
    if 'coco' in dataset:
        # Hard-coded per-split sample counts (presumably the image counts of
        # the train/val splits -- verify against the dataset).
        if '2014' in train_file_list:
            train_size, test_size = 82783, 40504
        elif '2017' in train_file_list:
            train_size, test_size = 118287, 5000
        else:
            raise ValueError(
                "Unsupported COCO train file list: %r (expected '2014' or "
                "'2017' in the name)" % (train_file_list, ))
        epocs = train_size // batch_size // devices_num
        test_epocs = test_size // batch_size
        # learning rate decay in 12, 19 pass, respectively
        boundaries = [epocs * 12, epocs * 19]
        values = [learning_rate, learning_rate * 0.5, learning_rate * 0.25]
    elif 'pascalvoc' in dataset:
        epocs = 19200 // batch_size // devices_num
        test_epocs = 4952 // batch_size
        boundaries = [epocs * 40, epocs * 60, epocs * 80, epocs * 100]
        values = [
            learning_rate, learning_rate * 0.5, learning_rate * 0.25,
            learning_rate * 0.1, learning_rate * 0.01
        ]
    else:
        raise ValueError(
            "Unsupported dataset: %r (expected a name containing 'coco' "
            "or 'pascalvoc')" % (dataset, ))
    return epocs, test_epocs, boundaries, values


def train(args,
          train_file_list,
          val_file_list,
          data_args,
          learning_rate,
          batch_size,
          num_passes,
          model_save_dir,
          pretrained_model=None):
    """Train the network for ``num_passes`` passes, evaluating after each.

    Args:
        args: Parsed command-line arguments; ``enable_ce``, ``use_gpu`` and
            ``parallel`` are read here, and ``args`` is forwarded to
            ``build_program``.
        train_file_list: Training list/annotation file handed to the reader;
            for COCO its name also selects the 2014 or 2017 schedule.
        val_file_list: Validation list/annotation file handed to the reader.
        data_args: Data settings forwarded to ``build_program`` and the
            readers; ``dataset`` is read here.
        learning_rate: Base learning rate of the piecewise decay.
        batch_size: Minibatch size.
        num_passes: Number of passes over the training reader.
        model_save_dir: Directory under which checkpoints are written.
        pretrained_model: Optional directory with initial variables; only
            variables that have a file of the same name there are loaded.

    Raises:
        ValueError: If the dataset named by ``data_args.dataset`` (or the
            COCO year in ``train_file_list``) is not supported.
    """
    # An unset/empty CUDA_VISIBLE_DEVICES counts as a single device.
    devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
    devices_num = len(devices.split(","))
    # Resolve the schedule first so an unsupported dataset fails before any
    # program or executor is created.
    epocs, test_epocs, boundaries, values = _lr_schedule(
        data_args.dataset, train_file_list, batch_size, devices_num,
        learning_rate)

    startup_prog = fluid.Program()
    train_prog = fluid.Program()
    test_prog = fluid.Program()

    if args.enable_ce:
        # Fixed seeds for CE runs.
        startup_prog.random_seed = 111
        train_prog.random_seed = 111
        test_prog.random_seed = 111

    train_py_reader, loss = build_program(
        is_train=True,
        main_prog=train_prog,
        startup_prog=startup_prog,
        args=args,
        data_args=data_args,
        values=values,
        boundaries=boundaries,
        train_file_list=train_file_list)
    test_py_reader, map_eval = build_program(
        is_train=False,
        main_prog=test_prog,
        startup_prog=startup_prog,
        args=args,
        data_args=data_args)
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup_prog)

    if pretrained_model:

        def if_exist(var):
            return os.path.exists(os.path.join(pretrained_model, var.name))

        fluid.io.load_vars(
            exe, pretrained_model, main_program=train_prog,
            predicate=if_exist)

    if args.parallel:
        train_exe = fluid.ParallelExecutor(
            main_program=train_prog,
            use_cuda=args.use_gpu,
            loss_name=loss.name)
        # Created alongside train_exe; evaluation below still runs through
        # ``exe``.
        test_exe = fluid.ParallelExecutor(
            main_program=test_prog,
            use_cuda=args.use_gpu,
            share_vars_from=train_exe)
    if not args.enable_ce:
        train_reader = reader.batch_reader(data_args, train_file_list,
                                           batch_size, "train")
    else:
        # CE runs: seed the Python/NumPy RNGs and disable shuffling.
        import random
        random.seed(0)
        np.random.seed(0)
        train_reader = reader.batch_reader(
            data_args, train_file_list, batch_size, "train", shuffle=False)
    test_reader = reader.batch_reader(data_args, val_file_list, batch_size,
                                      "test")
    train_py_reader.decorate_paddle_reader(train_reader)
    test_py_reader.decorate_paddle_reader(test_reader)

    def save_model(postfix, main_prog):
        """Save the persistables of ``main_prog`` to model_save_dir/postfix."""
        model_path = os.path.join(model_save_dir, postfix)
        # Replace any checkpoint already stored under the same name.
        if os.path.isdir(model_path):
            shutil.rmtree(model_path)
        print('save models to %s' % (model_path))
        fluid.io.save_persistables(exe, model_path, main_program=main_prog)

    best_map = 0.

    def test(pass_id, best_map):
        """Evaluate once; return ``(best_map, mean_map)``.

        ``mean_map`` averages the mAP values sampled every 20 batches.
        'best_model' is saved when the last fetched mAP beats ``best_map``.
        """
        _, accum_map = map_eval.get_map_var()
        map_eval.reset(exe)
        every_pass_map = []
        test_py_reader.start()
        batch_id = 0
        try:
            while True:
                test_map, = exe.run(test_prog, fetch_list=[accum_map])
                if batch_id % 20 == 0:
                    every_pass_map.append(test_map)
                    print("Batch {0}, map {1}".format(batch_id, test_map))
                batch_id += 1
                if batch_id > test_epocs:
                    break
        except fluid.core.EOFException:
            test_py_reader.reset()
        mean_map = np.mean(every_pass_map)
        if test_map[0] > best_map:
            best_map = test_map[0]
            save_model('best_model', test_prog)
        print("Pass {0}, test map {1}".format(pass_id, test_map))
        return best_map, mean_map

    for pass_id in range(num_passes):
        batch_begin = time.time()
        start_time = time.time()
        train_py_reader.start()
        every_pass_loss = []
        batch_id = 0
        try:
            while True:
                prev_start_time = start_time
                start_time = time.time()

                if args.parallel:
                    loss_v, = train_exe.run(fetch_list=[loss.name])
                else:
                    loss_v, = exe.run(train_prog, fetch_list=[loss])
                loss_v = np.mean(np.array(loss_v))
                every_pass_loss.append(loss_v)
                if batch_id % 20 == 0:
                    print("Pass {0}, batch {1}, loss {2}, time {3}".format(
                        pass_id, batch_id, loss_v,
                        start_time - prev_start_time))
                batch_id += 1
                # NOTE(review): when this cap ends the pass, the reader is
                # not reset before the next start() -- confirm py_reader
                # tolerates that.
                if batch_id > epocs:
                    break
        except fluid.core.EOFException:
            train_py_reader.reset()
        batch_end = time.time()
        best_map, mean_map = test(pass_id, best_map)
        if args.enable_ce and pass_id == num_passes - 1:
            # KPI lines are emitted for the final pass only.
            total_time = batch_end - batch_begin
            train_avg_loss = np.mean(every_pass_loss)
            if devices_num == 1:
                print("kpis\ttrain_cost\t%s" % train_avg_loss)
                print("kpis\ttest_acc\t%s" % mean_map)
                print("kpis\ttrain_speed\t%s" % (total_time / epocs))
            else:
                print("kpis\ttrain_cost_card%s\t%s" %
                      (devices_num, train_avg_loss))
                print("kpis\ttest_acc_card%s\t%s" % (devices_num, mean_map))
                # Training speed is measured per training batch, matching
                # the single-card KPI above (was divided by test_epocs).
                print("kpis\ttrain_speed_card%s\t%f" %
                      (devices_num, total_time / epocs))

        if pass_id % 10 == 0 or pass_id == num_passes - 1:
            save_model(str(pass_id), train_prog)
    print("Best test map {0}".format(best_map))


if __name__ == '__main__':
    # Script entry point: parse the options, pick the list/annotation files
    # for the chosen dataset, build the reader settings and start training.
    args = parser.parse_args()
    print_arguments(args)

    data_dir = args.data_dir
    label_file = 'label_list'
    model_save_dir = args.model_save_dir
    # File lists used when the dataset is not a COCO variant.
    train_file_list = 'trainval.txt'
    val_file_list = 'test.txt'
    if 'coco' in args.dataset:
        # NOTE(review): this replaces whatever was passed via --data_dir.
        data_dir = 'data/coco'
        if '2014' in args.dataset:
            train_file_list = 'annotations/instances_train2014.json'
            val_file_list = 'annotations/instances_val2014.json'
        elif '2017' in args.dataset:
            train_file_list = 'annotations/instances_train2017.json'
            val_file_list = 'annotations/instances_val2017.json'

    data_args = reader.Settings(
        dataset=args.dataset,
        data_dir=data_dir,
        label_file=label_file,
        resize_h=args.resize_h,
        resize_w=args.resize_w,
        mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R],
        apply_distort=args.apply_distort,
        apply_expand=args.apply_expand,
        ap_version=args.ap_version,
        toy=args.is_toy)
    train(
        args,
        train_file_list=train_file_list,
        val_file_list=val_file_list,
        data_args=data_args,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        num_passes=args.num_passes,
        model_save_dir=model_save_dir,
        pretrained_model=args.pretrained_model)