# train.py
import os
import sys
import math
import time
import argparse
import functools
import numpy as np
import paddle
import paddle.fluid as fluid
import models
from losses import tripletloss
from losses import quadrupletloss
from losses import emlloss
from losses.metrics import recall_topk
from utility import add_arguments, print_arguments

# Command-line interface: every flag is registered through the project's
# add_arguments helper, bound here to the module-level parser.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('train_batch_size', int, 80, "Minibatch size for training.")
add_arg('test_batch_size', int, 10, "Minibatch size for testing.")
add_arg('num_epochs', int, 120, "number of epochs.")
add_arg('image_shape', str, "3,224,224", "input image size")
add_arg('model_save_dir', str, "output", "model save directory")
add_arg('with_mem_opt', bool, True,
        "Whether to use memory optimization or not.")
add_arg('pretrained_model', str, None,
        "Directory of pretrained model variables to load.")
add_arg('checkpoint', str, None,
        "Directory of a checkpoint to resume from.")
add_arg('lr', float, 0.1, "set learning rate.")
add_arg('lr_strategy', str, "piecewise_decay",
        "Set the learning rate decay strategy.")
add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.")
add_arg('loss_name', str, "tripletloss", "Set the loss type to use.")
add_arg('samples_each_class', int, 2, "Samples each class.")
add_arg('margin', float, 0.1,
        "Margin passed to the triplet / quadruplet loss.")
add_arg('alpha', float, 0.0,
        "Weight of the classification loss added to the metric loss.")
# yapf: enable

# Every public (non-dunder) name exported by the models package is accepted
# as a value for --model.
model_list = [m for m in dir(models) if "__" not in m]

def optimizer_setting(params):
    """Build a Momentum optimizer whose learning rate decays piecewise.

    Args:
        params: dict providing
            - "learning_strategy": dict with "name" (must be
              "piecewise_decay") and "epochs" (values that are scaled by a
              fixed factor of 10000 to obtain the decay boundaries);
            - "lr": the base learning rate.

    Returns:
        A fluid.optimizer.Momentum instance (momentum 0.9, L2 decay 1e-4)
        driven by fluid.layers.piecewise_decay.

    Raises:
        AssertionError: if the strategy name is not "piecewise_decay".
    """
    ls = params["learning_strategy"]
    # Report the offending strategy name taken from `ls`; `lr` is only bound
    # further down, so it must not be referenced in this message.
    assert ls["name"] == "piecewise_decay", \
        "learning rate strategy must be {}, but got {}".format(
            "piecewise_decay", ls["name"])

    step = 10000
    bd = [step * e for e in ls["epochs"]]
    base_lr = params["lr"]
    # One learning-rate value per interval: each boundary divides the rate
    # by ten, so there is one more value than there are boundaries.
    lr = [base_lr * (0.1 ** i) for i in range(len(bd) + 1)]
    optimizer = fluid.optimizer.Momentum(
        learning_rate=fluid.layers.piecewise_decay(
            boundaries=bd, values=lr),
        momentum=0.9,
        regularization=fluid.regularizer.L2Decay(1e-4))

    return optimizer

def train(args):
    """Build the network, train it and evaluate it after every epoch.

    The network selected by ``args.model`` produces two outputs: ``out[0]``
    is fed to the metric loss chosen by ``args.loss_name`` and is also the
    tensor fetched at test time, while ``out[1]`` is fed to the cross-entropy
    loss and the top-1 / top-5 accuracy layers.  The optimized cost is
    ``metric_loss + args.alpha * classification_loss``.

    After each epoch the test reader is run, recall@1 is computed with
    ``recall_topk`` and the persistable variables are saved under
    ``<model_save_dir>/<model>/<pass_id>``.

    Args:
        args: parsed command-line namespace.  Fields read here: model,
            checkpoint, pretrained_model, with_mem_opt, model_save_dir,
            loss_name, image_shape, train_batch_size, test_batch_size,
            samples_each_class, margin, alpha, lr, num_epochs, lr_strategy.

    Raises:
        AssertionError: if ``args.model`` is not in ``model_list`` or if both
            a checkpoint and a pretrained model are given.
        ValueError: if ``args.loss_name`` is not a supported loss.
    """
    # parameters from arguments
    model_name = args.model
    checkpoint = args.checkpoint
    pretrained_model = args.pretrained_model
    with_memory_optimization = args.with_mem_opt
    model_save_dir = args.model_save_dir
    loss_name = args.loss_name

    image_shape = [int(m) for m in args.image_shape.split(",")]

    assert model_name in model_list, \
        "{} is not in lists: {}".format(args.model, model_list)

    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')

    # model definition
    model = models.__dict__[model_name]()
    # NOTE: the number of classes is hard-coded to 200 here.
    out = model.net(input=image, class_dim=200)

    # Select the metric loss.  An unknown name is rejected explicitly instead
    # of failing later with an unbound-variable error.
    if loss_name == "tripletloss":
        metricloss = tripletloss(
            train_batch_size=args.train_batch_size,
            margin=args.margin)
    elif loss_name == "quadrupletloss":
        metricloss = quadrupletloss(
            train_batch_size=args.train_batch_size,
            samples_each_class=args.samples_each_class,
            margin=args.margin)
    elif loss_name == "emlloss":
        metricloss = emlloss(
            train_batch_size=args.train_batch_size,
            samples_each_class=args.samples_each_class)
    else:
        raise ValueError(
            "loss_name must be one of tripletloss, quadrupletloss or "
            "emlloss, but got {}".format(loss_name))
    cost_metric = metricloss.loss(out[0])
    avg_cost_metric = fluid.layers.mean(x=cost_metric)

    cost_cls = fluid.layers.cross_entropy(input=out[1], label=label)
    avg_cost_cls = fluid.layers.mean(x=cost_cls)
    acc_top1 = fluid.layers.accuracy(input=out[1], label=label, k=1)
    acc_top5 = fluid.layers.accuracy(input=out[1], label=label, k=5)
    avg_cost = avg_cost_metric + args.alpha * avg_cost_cls

    # Clone for evaluation before the optimizer ops are appended.
    test_program = fluid.default_main_program().clone(for_test=True)

    # parameters from model and arguments
    params = model.params
    params["lr"] = args.lr
    params["num_epochs"] = args.num_epochs
    params["learning_strategy"]["batch_size"] = args.train_batch_size
    params["learning_strategy"]["name"] = args.lr_strategy

    # initialize optimizer
    optimizer = optimizer_setting(params)
    optimizer.minimize(avg_cost)

    global_lr = optimizer._global_learning_rate()

    # Training runs on CUDA device 0.
    place = fluid.CUDAPlace(0)
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    if checkpoint is not None:
        fluid.io.load_persistables(exe, checkpoint)

    if pretrained_model:
        assert checkpoint is None, \
            "checkpoint and pretrained_model must not be set together"

        def if_exist(var):
            # Only load variables that have a file in the pretrained dir.
            has_var = os.path.exists(os.path.join(pretrained_model, var.name))
            if has_var:
                print('var: %s found' % (var.name))
            return has_var

        fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)

    train_reader = paddle.batch(
        metricloss.train_reader, batch_size=args.train_batch_size)
    test_reader = paddle.batch(
        metricloss.test_reader, batch_size=args.test_batch_size)
    feeder = fluid.DataFeeder(place=place, feed_list=[image, label])

    train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)

    fetch_list_train = [
        avg_cost_metric.name, avg_cost_cls.name, acc_top1.name,
        acc_top5.name, global_lr.name
    ]
    fetch_list_test = [out[0].name]

    if with_memory_optimization:
        # Fetched variables are excluded from the memory optimization pass.
        fluid.memory_optimize(
            fluid.default_main_program(),
            skip_opt_set=set(fetch_list_train))

    for pass_id in range(params["num_epochs"]):
        # Per-batch history: metric loss, classification loss, acc1, acc5.
        train_info = [[], [], [], []]
        for batch_id, data in enumerate(train_reader()):
            t1 = time.time()
            loss_metric, loss_cls, acc1, acc5, lr = train_exe.run(
                fetch_list_train, feed=feeder.feed(data))
            t2 = time.time()
            period = t2 - t1
            loss_metric = np.mean(np.array(loss_metric))
            loss_cls = np.mean(np.array(loss_cls))
            acc1 = np.mean(np.array(acc1))
            acc5 = np.mean(np.array(acc5))
            lr = np.mean(np.array(lr))
            train_info[0].append(loss_metric)
            train_info[1].append(loss_cls)
            train_info[2].append(acc1)
            train_info[3].append(acc5)
            if batch_id % 10 == 0:
                print("Pass {0}, trainbatch {1}, lr {2}, loss_metric {3}, loss_cls {4}, acc1 {5}, acc5 {6}, time {7}".format(
                    pass_id, batch_id, lr, loss_metric, loss_cls, acc1, acc5,
                    "%2.2f sec" % period))

        train_loss_metric = np.array(train_info[0]).mean()
        train_loss_cls = np.array(train_info[1]).mean()
        train_acc1 = np.array(train_info[2]).mean()
        train_acc5 = np.array(train_info[3]).mean()

        test_feas = []
        test_labels = []
        for batch_id, data in enumerate(test_reader()):
            # Incomplete trailing batches are skipped.
            if len(data) < args.test_batch_size:
                continue
            t1 = time.time()
            [feas] = exe.run(test_program,
                             fetch_list=fetch_list_test,
                             feed=feeder.feed(data))
            # Named batch_labels so the graph variable `label` defined above
            # is not rebound inside the loop.
            batch_labels = np.asarray([x[1] for x in data])
            test_feas.append(feas)
            test_labels.append(batch_labels)

            t2 = time.time()
            period = t2 - t1
            if batch_id % 20 == 0:
                print("Pass {0}, testbatch {1}, time {2}".format(
                    pass_id, batch_id, "%2.2f sec" % period))

        test_feas = np.vstack(test_feas)
        test_labels = np.hstack(test_labels)
        recall = recall_topk(test_feas, test_labels, k=1)
        print("End pass {0}, train_loss_metric {1}, train_loss_cls {2}, train_acc1 {3}, train_acc5 {4}, test_recall {5}".format(
            pass_id, train_loss_metric, train_loss_cls, train_acc1,
            train_acc5, recall))
        sys.stdout.flush()

        # Save one snapshot per epoch: <model_save_dir>/<model>/<pass_id>.
        model_path = os.path.join(model_save_dir, model_name, str(pass_id))
        if not os.path.isdir(model_path):
            os.makedirs(model_path)
        fluid.io.save_persistables(exe, model_path)

def main():
    """Parse the command-line flags, print them, and start training."""
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    train(cli_args)


# Only start training when the file is executed as a script.
if __name__ == "__main__":
    main()