# train.py -- trains the PyramidBox face detector (or a plain VGG-SSD baseline).
import os
import shutil
import numpy as np
import time
import argparse
import functools

import reader
import paddle
import paddle.fluid as fluid
from pyramidbox import PyramidBox
from utility import add_arguments, print_arguments

parser = argparse.ArgumentParser(description=__doc__)
# Shorthand that registers one option on ``parser`` via utility.add_arguments.
add_arg = functools.partial(add_arguments, argparser=parser)

# Positional fields below are, presumably, (option name, type, default,
# help text) as expected by utility.add_arguments -- confirm there.
# Note: a purely numeric --pretrained_model (e.g. "5") makes train() resume
# from <model_save_dir>/5 and continue at pass 6.
# yapf: disable
add_arg('parallel',         bool,  True,            "parallel")
add_arg('learning_rate',    float, 0.001,           "Learning rate.")
add_arg('batch_size',       int,   12,              "Minibatch size.")
add_arg('num_passes',       int,   120,             "Epoch number.")
add_arg('use_gpu',          bool,  True,            "Whether use GPU.")
add_arg('use_pyramidbox',   bool,  True,            "Whether use PyramidBox model.")
add_arg('model_save_dir',   str,   'output',        "The path to save model.")
add_arg('pretrained_model', str,   './pretrained/', "The init model path.")
add_arg('resize_h',         int,   640,             "The resized image height.")
add_arg('resize_w',         int,   640,             "The resized image width.")
# yapf: enable


def _piecewise_lr_schedule(learning_rate, batch_size, num_train_images=12880):
    """Return ``(boundaries, values)`` for the step-wise learning-rate decay.

    The base rate is scaled by 1, 0.5, 0.25, 0.1 and 0.01, switching after
    40, 60, 80 and 100 epochs respectively.

    Args:
        learning_rate: base learning rate used for the first 40 epochs.
        batch_size: minibatch size; with ``num_train_images`` it fixes how
            many iterations make up one epoch.
        num_train_images: samples per epoch. 12880 is the value this script
            always used (presumably the WIDER FACE training-split size --
            confirm before changing the dataset).

    Returns:
        ``(boundaries, values)``: integer iteration counts and the matching
        learning rates, with ``len(values) == len(boundaries) + 1``.
    """
    # Floor division keeps the boundaries integral on Python 3 too (true
    # division would turn them into floats); on Python 2 the value is unchanged.
    steps_per_epoch = num_train_images // batch_size
    boundaries = [steps_per_epoch * epoch for epoch in (40, 60, 80, 100)]
    values = [
        learning_rate, learning_rate * 0.5, learning_rate * 0.25,
        learning_rate * 0.1, learning_rate * 0.01
    ]
    return boundaries, values


def _build_optimizer(optimizer_method, boundaries, values):
    """Create the optimizer driven by a piecewise-decayed learning rate.

    ``"momentum"`` selects Momentum (momentum 0.9); any other value selects
    RMSProp. Both apply L2 weight decay with coefficient 0.0005.
    """
    learning_rate = fluid.layers.piecewise_decay(
        boundaries=boundaries, values=values)
    regularization = fluid.regularizer.L2Decay(0.0005)
    if optimizer_method == "momentum":
        return fluid.optimizer.Momentum(
            learning_rate=learning_rate,
            momentum=0.9,
            regularization=regularization,
        )
    return fluid.optimizer.RMSProp(
        learning_rate=learning_rate,
        regularization=regularization,
    )


def _load_pretrained(exe, pretrained_model, model_save_dir):
    """Load saved variables from ``pretrained_model`` and return the start pass.

    A purely numeric ``pretrained_model`` (e.g. ``"5"``) means "resume": the
    weights are read from ``<model_save_dir>/5`` and training restarts at
    pass 6. An empty value loads nothing and starts at pass 0.

    Raises:
        ValueError: if the resolved model directory does not exist.
    """
    start_pass = 0
    if not pretrained_model:
        return start_pass

    if pretrained_model.isdigit():
        start_pass = int(pretrained_model) + 1
        pretrained_model = os.path.join(model_save_dir, pretrained_model)
        print("Resume from %s " % (pretrained_model))

    if not os.path.exists(pretrained_model):
        raise ValueError("The pre-trained model path [%s] does not exist." %
                         (pretrained_model))

    def if_exist(var):
        # Only variables that have a file of the same name in the directory
        # are loaded; presumably so a partially matching checkpoint (e.g. a
        # backbone) can initialise the network.
        return os.path.exists(os.path.join(pretrained_model, var.name))

    fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
    return start_pass


def train(args, config, train_file_list, optimizer_method):
    """Build the detection network and run the training loop.

    Args:
        args: parsed command-line options of this script (learning_rate,
            batch_size, num_passes, resize_h, resize_w, use_gpu,
            use_pyramidbox, model_save_dir, pretrained_model, parallel).
        config: reader settings object, passed through to ``reader.train``.
        train_file_list: annotation file path, passed to ``reader.train``.
        optimizer_method: ``"momentum"`` for Momentum; any other value
            selects RMSProp.

    Raises:
        ValueError: if ``args.pretrained_model`` is set but the resolved
            model directory does not exist.

    All persistable variables are saved to ``<model_save_dir>/<pass_id>``
    at the end of every pass, replacing an existing directory of that name.
    """
    learning_rate = args.learning_rate
    batch_size = args.batch_size
    num_passes = args.num_passes
    use_gpu = args.use_gpu
    use_pyramidbox = args.use_pyramidbox
    model_save_dir = args.model_save_dir

    num_classes = 2
    image_shape = [3, args.resize_h, args.resize_w]

    # Number of entries in CUDA_VISIBLE_DEVICES; unset/empty counts as one.
    devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
    devices_num = len(devices.split(","))

    network = PyramidBox(image_shape, num_classes,
                         sub_network=use_pyramidbox)
    if use_pyramidbox:
        face_loss, head_loss, loss = network.train()
        fetches = [face_loss, head_loss]
    else:
        loss = network.vgg_ssd_loss()
        fetches = [loss]

    boundaries, values = _piecewise_lr_schedule(learning_rate, batch_size)
    optimizer = _build_optimizer(optimizer_method, boundaries, values)
    optimizer.minimize(loss)

    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    start_pass = _load_pretrained(exe, args.pretrained_model, model_save_dir)

    if args.parallel:
        train_exe = fluid.ParallelExecutor(
            use_cuda=use_gpu, loss_name=loss.name)

    train_reader = paddle.batch(
        reader.train(config, train_file_list), batch_size=batch_size)
    feeder = fluid.DataFeeder(place=place, feed_list=network.feeds())

    def save_model(postfix):
        """Save persistables to ``<model_save_dir>/<postfix>``, replacing it."""
        model_path = os.path.join(model_save_dir, postfix)
        if os.path.isdir(model_path):
            shutil.rmtree(model_path)
        print('save models to %s' % (model_path))
        fluid.io.save_persistables(exe, model_path)

    for pass_id in range(start_pass, num_passes):
        start_time = time.time()
        for batch_id, data in enumerate(train_reader()):
            prev_start_time = start_time
            start_time = time.time()
            # Skip undersized batches (at most the last one of a pass);
            # presumably so every device still gets samples -- confirm.
            if len(data) < 2 * devices_num:
                continue
            if args.parallel:
                fetch_vars = train_exe.run(fetch_list=[v.name for v in fetches],
                                           feed=feeder.feed(data))
            else:
                fetch_vars = exe.run(fluid.default_main_program(),
                                     feed=feeder.feed(data),
                                     fetch_list=fetches)
            # Reduce each fetched loss to a scalar mean for logging.
            fetch_vars = [np.mean(np.array(v)) for v in fetch_vars]
            # Wall time between the start of the previous batch and this one.
            batch_time = start_time - prev_start_time
            if not use_pyramidbox:
                print("Pass {0}, batch {1}, loss {2}, time {3}".format(
                    pass_id, batch_id, fetch_vars[0], batch_time))
            else:
                print("Pass {0}, batch {1}, face loss {2}, head loss {3}, "
                      "time {4}".format(pass_id, batch_id, fetch_vars[0],
                                        fetch_vars[1], batch_time))

        save_model(str(pass_id))


if __name__ == '__main__':
    args = parser.parse_args()
    print_arguments(args)

    # Reader settings: the resize target comes from the command line, the
    # image directory and preprocessing options are fixed here.
    settings_kwargs = dict(
        data_dir='data/WIDERFACE/WIDER_train/images/',
        resize_h=args.resize_h,
        resize_w=args.resize_w,
        apply_expand=False,
        mean_value=[104., 117., 123],
        ap_version='11point',
    )
    config = reader.Settings(**settings_kwargs)

    # Annotation file handed to reader.train() inside train().
    gt_label_file = 'label/train_gt_widerface.res'
    train(args, config, gt_label_file, optimizer_method="momentum")