train.py 5.2 KB
Newer Older
1
import os
2
import shutil
3
import numpy as np
4
import time
5 6 7
import argparse
import functools

8
import reader
9 10 11 12 13 14 15
import paddle
import paddle.fluid as fluid
from pyramidbox import PyramidBox
from utility import add_arguments, print_arguments

# Command-line interface. `add_arg` is `add_arguments` pre-bound to this parser,
# so each call below registers one option as (name, type, default, help text).
# NOTE(review): assumes add_arguments exposes each name as a `--<name>` flag;
# confirm against utility.add_arguments.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)

# yapf: disable
add_arg('parallel', bool, True, "parallel")
add_arg('learning_rate', float, 0.0001, "Learning rate.")
add_arg('batch_size', int, 16, "Minibatch size.")
add_arg('num_passes', int, 120, "Epoch number.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('use_pyramidbox', bool, False, "Whether use PyramidBox model.")
# NOTE(review): the help text lists coco/pascalvoc while the default is
# 'WIDERFACE'; confirm which dataset names the reader actually accepts.
add_arg('dataset', str, 'WIDERFACE', "coco2014, coco2017, and pascalvoc.")
add_arg('model_save_dir', str, 'model', "The path to save model.")
add_arg('pretrained_model', str, './pretrained/', "The init model path.")
add_arg('resize_h', int, 640, "The resized image height.")
# Help text previously said "height" (copy-paste from resize_h).
add_arg('resize_w', int, 640, "The resized image width.")
# yapf: enable


32
def train(args, data_args, learning_rate, batch_size, pretrained_model,
          num_passes, optimizer_method):
    """Build the detection network, train it, and save checkpoints.

    Args:
        args: Parsed command-line namespace. Reads ``use_pyramidbox``,
            ``use_gpu`` and ``parallel``.
        data_args: Reader settings. ``resize_h`` / ``resize_w`` define the
            network input size; the object is also handed to ``reader.train``.
        learning_rate: Base learning rate; the schedule scales it by 0.1,
            0.01 and 0.001 at the later boundaries.
        batch_size: Number of samples per minibatch.
        pretrained_model: Directory of initial weights. A falsy value skips
            loading; otherwise only variables with a same-named file in the
            directory are loaded.
        num_passes: Number of passes over the training reader.
        optimizer_method: ``"momentum"`` selects Momentum; any other value
            selects RMSProp.

    Note:
        This function reads the module-level names ``train_file_list`` and
        ``model_save_dir``, which are assigned in the ``__main__`` block.
    """
    num_classes = 2

    # An unset CUDA_VISIBLE_DEVICES yields "" and therefore a count of 1.
    devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
    devices_num = len(devices.split(","))

    image_shape = [3, data_args.resize_h, data_args.resize_w]

    # The network is constructed the same way for both variants; only the
    # loss-building call differs.
    network = PyramidBox(image_shape, sub_network=args.use_pyramidbox)
    if args.use_pyramidbox:
        # Only the combined loss is optimized; the two partial losses that
        # network.train() also returns are not used here.
        _, _, loss = network.train()
    else:
        loss = network.vgg_ssd(num_classes, image_shape)

    # Floor division keeps the schedule boundaries whole step counts under
    # both Python 2 and Python 3.
    # NOTE(review): 12880 is presumably the number of training images, making
    # this the iteration count per pass -- confirm.
    iters_per_pass = 12880 // batch_size
    boundaries = [
        iters_per_pass * 100, iters_per_pass * 125, iters_per_pass * 150
    ]
    values = [
        learning_rate, learning_rate * 0.1, learning_rate * 0.01,
        learning_rate * 0.001
    ]

    if optimizer_method == "momentum":
        optimizer = fluid.optimizer.Momentum(
            learning_rate=fluid.layers.piecewise_decay(
                boundaries=boundaries, values=values),
            momentum=0.9,
            regularization=fluid.regularizer.L2Decay(0.0005),
        )
    else:
        optimizer = fluid.optimizer.RMSProp(
            learning_rate=fluid.layers.piecewise_decay(boundaries, values),
            regularization=fluid.regularizer.L2Decay(0.0005),
        )

    optimizer.minimize(loss)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    if pretrained_model:

        def if_exist(var):
            # Load only variables that have a file in the pretrained directory.
            return os.path.exists(os.path.join(pretrained_model, var.name))

        print('Load pre-trained model.')
        fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)

    if args.parallel:
        train_exe = fluid.ParallelExecutor(
            use_cuda=args.use_gpu, loss_name=loss.name)

    train_reader = paddle.batch(
        reader.train(data_args, train_file_list), batch_size=batch_size)
    feeder = fluid.DataFeeder(
        place=place,
        feed_list=[
            network.image, network.gt_box, network.gt_label, network.difficult
        ])

    def save_model(postfix):
        """Save persistable variables to ``model_save_dir/postfix``."""
        model_path = os.path.join(model_save_dir, postfix)
        # Replace any checkpoint already stored under the same name.
        if os.path.isdir(model_path):
            shutil.rmtree(model_path)
        # Function-call form: valid on Python 3 and prints the same text on
        # Python 2, matching the other print calls in this file.
        print('save models to %s' % model_path)
        fluid.io.save_persistables(exe, model_path)

    # No evaluation runs in this function, so this value is reported unchanged.
    best_map = 0.

    for pass_id in range(num_passes):
        start_time = time.time()
        for batch_id, data in enumerate(train_reader()):
            prev_start_time = start_time
            start_time = time.time()
            # Skip batches holding fewer samples than there are devices.
            if len(data) < devices_num:
                continue
            if args.parallel:
                loss_v, = train_exe.run(fetch_list=[loss.name],
                                        feed=feeder.feed(data))
            else:
                loss_v, = exe.run(fluid.default_main_program(),
                                  feed=feeder.feed(data),
                                  fetch_list=[loss])
            loss_v = np.mean(np.array(loss_v))
            # The reported time is the gap between the starts of the previous
            # and the current iteration.
            print("Pass {0}, batch {1}, loss {2}, time {3}".format(
                pass_id, batch_id, loss_v, start_time - prev_start_time))
        # Checkpoint every tenth pass and after the final pass.
        if pass_id % 10 == 0 or pass_id == num_passes - 1:
            save_model(str(pass_id))
    print("Best test map {0}".format(best_map))
125 126 127 128 129 130


if __name__ == '__main__':
    # Parse the command line and echo the effective configuration.
    args = parser.parse_args()
    print_arguments(args)

    # Module-level paths. train() reads train_file_list and model_save_dir
    # from module scope, so these names must stay as they are.
    model_save_dir = args.model_save_dir
    data_dir = 'data/WIDERFACE/WIDER_train/images/'
    train_file_list = 'label/train_gt_widerface.res'
    val_file_list = 'label/val_gt_widerface.res'

    # Reader configuration: dataset location, input size and preprocessing.
    reader_options = dict(
        dataset=args.dataset,
        data_dir=data_dir,
        resize_h=args.resize_h,
        resize_w=args.resize_w,
        apply_expand=False,
        mean_value=[104., 117., 123],
        ap_version='11point')
    data_args = reader.Settings(**reader_options)

    # Hyper-parameters come from the command line; the optimizer is fixed.
    train_options = dict(
        data_args=data_args,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        pretrained_model=args.pretrained_model,
        num_passes=args.num_passes,
        optimizer_method="momentum")
    train(args, **train_options)