train.py 5.6 KB
Newer Older
1
import os
2
import shutil
3
import numpy as np
4
import time
5 6 7
import argparse
import functools

8
import reader
9 10 11 12 13 14 15
import paddle
import paddle.fluid as fluid
from pyramidbox import PyramidBox
from utility import add_arguments, print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)

# Command-line options; each one is exposed on the parsed namespace as
# ``args.<name>``. Registration details (e.g. how ``bool`` values are parsed)
# are handled by ``utility.add_arguments``.
# yapf: disable
add_arg('parallel',         bool,  True,           "Whether to use ParallelExecutor.")
add_arg('learning_rate',    float, 0.0001,         "Learning rate.")
add_arg('batch_size',       int,   16,             "Minibatch size.")
add_arg('num_passes',       int,   120,            "Epoch number.")
add_arg('use_gpu',          bool,  True,           "Whether to use GPU.")
add_arg('use_pyramidbox',   bool,  False,          "Whether to use the full PyramidBox model (otherwise the VGG-SSD sub-network).")
add_arg('dataset',          str,   'WIDERFACE',    "The dataset name.")
add_arg('model_save_dir',   str,   'model',        "The path to save model.")
add_arg('pretrained_model', str,   './pretrained/', "The init model path.")
add_arg('resize_h',         int,   640,            "The resized image height.")
add_arg('resize_w',         int,   640,            "The resized image width.")
#yapf: enable


32
def train(args, data_args, learning_rate, batch_size, pretrained_model,
          num_passes, optimizer_method, train_file_list=None,
          model_save_dir=None):
    """Train the face detector and periodically save checkpoints.

    Args:
        args: Parsed command-line namespace. ``use_pyramidbox``, ``use_gpu``
            and ``parallel`` are read; ``model_save_dir`` is read as a
            fallback (see below).
        data_args: Reader settings; ``resize_h``/``resize_w`` define the
            input image shape and the object is passed to ``reader.train``.
        learning_rate: Base learning rate before piecewise decay.
        batch_size: Number of samples per minibatch.
        pretrained_model: Directory of saved variables to warm-start from;
            falsy to skip loading.
        num_passes: Number of passes over the training data.
        optimizer_method: ``"momentum"`` selects Momentum; any other value
            selects RMSProp.
        train_file_list: Annotation list passed to ``reader.train``. When
            None, falls back to the module-level ``train_file_list`` that the
            ``__main__`` block defines (the previous behavior).
        model_save_dir: Directory that checkpoints are written under. When
            None, falls back to ``args.model_save_dir``.

    Raises:
        ValueError: If ``pretrained_model`` is given but does not exist, or
            if no training file list is available.
    """
    if train_file_list is None:
        # Historically this was read from a global set in ``__main__``;
        # keep that working for existing callers.
        train_file_list = globals().get('train_file_list')
    if train_file_list is None:
        raise ValueError("train_file_list must be provided.")
    if model_save_dir is None:
        model_save_dir = args.model_save_dir

    # Only used by the VGG-SSD branch; presumably face vs. background.
    num_classes = 2

    # An unset/empty CUDA_VISIBLE_DEVICES still counts as one device.
    devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
    devices_num = len(devices.split(","))

    image_shape = [3, data_args.resize_h, data_args.resize_w]

    # Both modes build the same wrapper; only the loss graph differs.
    network = PyramidBox(image_shape, sub_network=args.use_pyramidbox)
    if args.use_pyramidbox:
        face_loss, head_loss, loss = network.train()
        fetches = [face_loss, head_loss]
    else:
        loss = network.vgg_ssd(num_classes, image_shape)
        fetches = [loss]

    # Iterations per pass; 12880 is presumably the WIDER FACE training-set
    # size. Floor division keeps the boundaries integral on Python 3 too
    # (identical to Python 2's ``/`` on ints).
    steps_per_pass = 12880 // batch_size
    boundaries = [steps_per_pass * 100, steps_per_pass * 125,
                  steps_per_pass * 150]
    values = [learning_rate, learning_rate * 0.1, learning_rate * 0.01,
              learning_rate * 0.001]
    lr = fluid.layers.piecewise_decay(boundaries=boundaries, values=values)

    if optimizer_method == "momentum":
        optimizer = fluid.optimizer.Momentum(
            learning_rate=lr,
            momentum=0.9,
            regularization=fluid.regularizer.L2Decay(0.0005))
    else:
        optimizer = fluid.optimizer.RMSProp(
            learning_rate=lr,
            regularization=fluid.regularizer.L2Decay(0.0005))
    optimizer.minimize(loss)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    if pretrained_model:
        if not os.path.exists(pretrained_model):
            raise ValueError("The pre-trained model path [%s] does not exist." %
                             (pretrained_model))

        def if_exist(var):
            # Load only variables that have a saved file in the directory.
            return os.path.exists(os.path.join(pretrained_model, var.name))

        fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)

    if args.parallel:
        train_exe = fluid.ParallelExecutor(
            use_cuda=args.use_gpu, loss_name=loss.name)

    train_reader = paddle.batch(
        reader.train(data_args, train_file_list), batch_size=batch_size)
    feeder = fluid.DataFeeder(place=place, feed_list=network.feeds())

    def save_model(postfix):
        # Write persistables to <model_save_dir>/<postfix>, replacing any
        # existing directory at that path.
        model_path = os.path.join(model_save_dir, postfix)
        if os.path.isdir(model_path):
            shutil.rmtree(model_path)
        print('save models to %s' % (model_path))
        fluid.io.save_persistables(exe, model_path)

    fetch_names = [v.name for v in fetches]
    for pass_id in range(num_passes):
        start_time = time.time()
        for batch_id, data in enumerate(train_reader()):
            prev_start_time = start_time
            start_time = time.time()
            # Skip batches with fewer samples than devices (presumably only
            # the final, partial batch).
            if len(data) < devices_num:
                continue
            if args.parallel:
                fetch_vars = train_exe.run(fetch_list=fetch_names,
                                           feed=feeder.feed(data))
            else:
                fetch_vars = exe.run(fluid.default_main_program(),
                                     feed=feeder.feed(data),
                                     fetch_list=fetches)
            fetch_vars = [np.mean(np.array(v)) for v in fetch_vars]
            # Wall time since the previous batch started (includes reading).
            batch_time = start_time - prev_start_time
            if args.use_pyramidbox:
                print("Pass {0}, batch {1}, face loss {2}, head loss {3}, "
                      "time {4}".format(pass_id, batch_id, fetch_vars[0],
                                        fetch_vars[1], batch_time))
            else:
                print("Pass {0}, batch {1}, loss {2}, time {3}".format(
                    pass_id, batch_id, fetch_vars[0], batch_time))

        if pass_id % 10 == 0 or pass_id == num_passes - 1:
            save_model(str(pass_id))
131 132 133 134 135 136


if __name__ == '__main__':
    args = parser.parse_args()
    print_arguments(args)

    # WIDER FACE locations. These stay module-level names because train()
    # may read ``train_file_list`` / ``model_save_dir`` as globals.
    data_dir = 'data/WIDERFACE/WIDER_train/images/'
    train_file_list = 'label/train_gt_widerface.res'
    val_file_list = 'label/val_gt_widerface.res'
    model_save_dir = args.model_save_dir

    # Reader configuration: image size comes from the CLI, while augmentation
    # and mean values are fixed here.
    reader_options = dict(
        dataset=args.dataset,
        data_dir=data_dir,
        resize_h=args.resize_h,
        resize_w=args.resize_w,
        apply_expand=False,
        mean_value=[104., 117., 123],
        ap_version='11point')
    data_args = reader.Settings(**reader_options)

    train(args,
          data_args=data_args,
          learning_rate=args.learning_rate,
          batch_size=args.batch_size,
          pretrained_model=args.pretrained_model,
          num_passes=args.num_passes,
          optimizer_method="momentum")