train.py 6.6 KB
Newer Older
1
import os
2
import shutil
3
import numpy as np
4
import time
5 6 7 8 9
import argparse
import functools

import paddle.fluid as fluid
from pyramidbox import PyramidBox
Q
qingqing01 已提交
10
import reader
11 12 13 14
from utility import add_arguments, print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)

# Command-line options as (name, type, default, help) rows. Each row is
# registered on ``parser`` through ``add_arg`` in the order listed here.
# yapf: disable
_ARGUMENT_TABLE = (
    ('parallel',         bool,  True,            "Whether use multi-GPU/threads or not."),
    ('learning_rate',    float, 0.001,           "The start learning rate."),
    ('batch_size',       int,   16,              "Minibatch size."),
    ('num_passes',       int,   160,             "Epoch number."),
    ('use_gpu',          bool,  True,            "Whether use GPU."),
    ('use_pyramidbox',   bool,  True,            "Whether use PyramidBox model."),
    ('model_save_dir',   str,   'output',        "The path to save model."),
    ('resize_h',         int,   640,             "The resized image height."),
    ('resize_w',         int,   640,             "The resized image width."),
    ('with_mem_opt',     bool,  True,            "Whether to use memory optimization or not."),
    ('pretrained_model', str,   './vgg_ilsvrc_16_fc_reduced/', "The init model path."),
)
# yapf: enable
for _row in _ARGUMENT_TABLE:
    add_arg(*_row)
del _row


Q
qingqing01 已提交
31 32 33 34 35 36 37 38 39 40
def train(args, config, train_file_list, optimizer_method):
    """Train a PyramidBox (or plain VGG-SSD) face detector with Fluid.

    Builds the network and optimizer in the default Fluid programs and
    optionally loads pre-trained or checkpointed persistable variables. It
    then runs passes ``start_pass .. num_passes - 1`` over the training data.
    Losses are logged every 10 batches, and all persistables are saved to
    ``<model_save_dir>/<pass_id>`` after every pass.

    Args:
        args: Parsed command-line namespace. Reads learning_rate, batch_size,
            num_passes, resize_h, resize_w, use_gpu, use_pyramidbox,
            model_save_dir, pretrained_model, with_mem_opt and parallel.
            If ``pretrained_model`` consists only of digits, it is treated as
            a pass id to resume from. Variables are then loaded from
            ``<model_save_dir>/<id>`` and training restarts at pass id + 1.
        config: Reader settings object, passed through unchanged to
            ``reader.train_batch_reader``.
        train_file_list: Path of the annotation list, passed to the reader.
        optimizer_method: ``"momentum"`` selects Momentum (0.9). Any other
            value selects RMSProp. Both use L2 weight decay of 0.0005 and the
            same piecewise learning-rate schedule.

    Raises:
        ValueError: If ``pretrained_model`` is set but the resolved path does
            not exist.
    """
    learning_rate = args.learning_rate
    batch_size = args.batch_size
    num_passes = args.num_passes
    height = args.resize_h
    width = args.resize_w
    use_gpu = args.use_gpu
    use_pyramidbox = args.use_pyramidbox
    model_save_dir = args.model_save_dir
    pretrained_model = args.pretrained_model
    with_memory_optimization = args.with_mem_opt

    num_classes = 2
    image_shape = [3, height, width]

    network = PyramidBox(image_shape, num_classes,
                         sub_network=use_pyramidbox)
    if use_pyramidbox:
        face_loss, head_loss, loss = network.train()
        fetches = [face_loss, head_loss]
    else:
        loss = network.vgg_ssd_loss()
        fetches = [loss]

    # Presumably the number of images in the WIDER FACE training split that
    # the __main__ section of this script trains on. TODO: confirm if the
    # dataset changes.
    num_train_images = 12880
    # Use floor division: range() below and the decay boundaries need
    # integers, and "/" yields a float under Python 3.
    steps_per_pass = num_train_images // batch_size
    # Step the learning rate down at the start of passes 50, 80, 120 and 140.
    boundaries = [steps_per_pass * 50, steps_per_pass * 80,
                  steps_per_pass * 120, steps_per_pass * 140]
    values = [
        learning_rate, learning_rate * 0.5, learning_rate * 0.25,
        learning_rate * 0.1, learning_rate * 0.01
    ]

    # Both optimizers share the same schedule and weight decay.
    lr = fluid.layers.piecewise_decay(boundaries=boundaries, values=values)
    regularizer = fluid.regularizer.L2Decay(0.0005)
    if optimizer_method == "momentum":
        optimizer = fluid.optimizer.Momentum(
            learning_rate=lr,
            momentum=0.9,
            regularization=regularizer,
        )
    else:
        optimizer = fluid.optimizer.RMSProp(
            learning_rate=lr,
            regularization=regularizer,
        )

    optimizer.minimize(loss)
    if with_memory_optimization:
        fluid.memory_optimize(fluid.default_main_program())

    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    start_pass = 0
    if pretrained_model:
        if pretrained_model.isdigit():
            # A bare number names a checkpoint written by save_model():
            # resume from it and continue with the following pass.
            start_pass = int(pretrained_model) + 1
            pretrained_model = os.path.join(model_save_dir, pretrained_model)
            print("Resume from %s " % (pretrained_model))

        if not os.path.exists(pretrained_model):
            raise ValueError("The pre-trained model path [%s] does not exist." %
                             (pretrained_model))

        def if_exist(var):
            # Load only variables that have a matching file on disk.
            # Presumably this lets a partial (e.g. backbone-only) model
            # initialise part of the network.
            return os.path.exists(os.path.join(pretrained_model, var.name))
        fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)

    if args.parallel:
        train_exe = fluid.ParallelExecutor(
            use_cuda=use_gpu, loss_name=loss.name)

    train_reader = reader.train_batch_reader(config, train_file_list, batch_size=batch_size)

    def save_model(postfix):
        """Save all persistables to <model_save_dir>/<postfix>, replacing any old copy."""
        model_path = os.path.join(model_save_dir, postfix)
        if os.path.isdir(model_path):
            shutil.rmtree(model_path)
        print('save models to %s' % (model_path))
        fluid.io.save_persistables(exe, model_path)

    def tensor(data, place, lod=None):
        """Wrap ``data`` in a LoDTensor on ``place``, attaching ``lod`` if given."""
        t = fluid.core.LoDTensor()
        t.set(data, place)
        if lod:
            t.set_lod(lod)
        return t

    for pass_id in range(start_pass, num_passes):
        start_time = time.time()
        prev_start_time = start_time
        for batch_id in range(steps_per_pass):
            # Assumes train_reader is an iterator yielding at least
            # steps_per_pass batches per pass. TODO: confirm in reader.py.
            im, face_box, head_box, labels, lod = next(train_reader)
            im_t = tensor(im, place)
            box1 = tensor(face_box, place, [lod])
            box2 = tensor(head_box, place, [lod])
            lbl_t = tensor(labels, place, [lod])
            feeding = {'image': im_t, 'face_box': box1,
                       'head_box': box2, 'gt_label': lbl_t}

            prev_start_time = start_time
            start_time = time.time()
            if args.parallel:
                fetch_vars = train_exe.run(fetch_list=[v.name for v in fetches],
                                           feed=feeding)
            else:
                fetch_vars = exe.run(fluid.default_main_program(),
                                     feed=feeding,
                                     fetch_list=fetches)
            fetch_vars = [np.mean(np.array(v)) for v in fetch_vars]
            if batch_id % 10 == 0:
                # Time between consecutive batch starts. This covers the
                # previous step's run plus loading this batch, not this
                # step's run.
                batch_time = start_time - prev_start_time
                if not use_pyramidbox:
                    print("Pass {0}, batch {1}, loss {2}, time {3}".format(
                        pass_id, batch_id, fetch_vars[0], batch_time))
                else:
                    print("Pass {0}, batch {1}, face loss {2}, head loss {3}, "
                          "time {4}".format(pass_id, batch_id, fetch_vars[0],
                                            fetch_vars[1], batch_time))

        # Checkpoint after every pass. The pass id doubles as the resume key.
        save_model(str(pass_id))


if __name__ == '__main__':
    # Parse the command line and echo it before doing any work.
    cli_args = parser.parse_args()
    print_arguments(cli_args)

    # WIDER FACE training images and their annotation list (bbx_gt file).
    image_root = 'data/WIDER_train/images/'
    annotation_list = 'data/wider_face_split/wider_face_train_bbx_gt.txt'

    # Reader configuration. The resize targets come from the command line;
    # augmentation flags, mean pixel values and AP version are fixed here.
    reader_settings = dict(
        data_dir=image_root,
        resize_h=cli_args.resize_h,
        resize_w=cli_args.resize_w,
        apply_distort=True,
        apply_expand=False,
        mean_value=[104., 117., 123.],
        ap_version='11point',
    )
    train(cli_args, reader.Settings(**reader_settings), annotation_list,
          optimizer_method="momentum")