train.py 18.7 KB
Newer Older
B
Bubbliiiing 已提交
1
import keras.backend as K
J
JiaQi Xu 已提交
2 3
import numpy as np
import tensorflow as tf
B
Bubbliiiing 已提交
4 5 6
from keras.backend.tensorflow_backend import set_session
from keras.callbacks import (EarlyStopping, ModelCheckpoint, ReduceLROnPlateau,
                             TensorBoard)
J
JiaQi Xu 已提交
7 8 9
from keras.layers import Input, Lambda
from keras.models import Model
from keras.optimizers import Adam
B
Bubbliiiing 已提交
10

J
JiaQi Xu 已提交
11
from nets.loss import yolo_loss
B
Bubbliiiing 已提交
12 13 14
from nets.yolo4 import yolo_body
from utils.utils import (WarmUpCosineDecayScheduler, get_random_data,
                         get_random_data_with_Mosaic, rand)
J
JiaQi Xu 已提交
15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36


#---------------------------------------------------#
#   获得类和先验框
#---------------------------------------------------#
def get_classes(classes_path):
    '''Load class names from a text file, one name per line.

    Surrounding whitespace is stripped from every line and blank lines are
    skipped, so a stray empty line in the file does not become an extra,
    nameless class (which would silently inflate num_classes).

    Args:
        classes_path: path to the classes text file.

    Returns:
        list of str: the non-empty class names, in file order.
    '''
    with open(classes_path) as f:
        class_names = [line.strip() for line in f]
    # Drop lines that were empty or whitespace-only.
    return [name for name in class_names if name]

def get_anchors(anchors_path):
    '''Load anchor sizes from the first line of a comma-separated file.

    Only the first line of the file is read. Its comma-separated numbers
    are parsed as floats and grouped into consecutive pairs
    (presumably width, height -- confirm against the anchors file).

    Args:
        anchors_path: path to the anchors text file.

    Returns:
        np.ndarray of float64 with shape (N, 2).
    '''
    with open(anchors_path) as anchor_file:
        first_line = anchor_file.readline()
    flat = np.array(list(map(float, first_line.split(','))))
    return flat.reshape(-1, 2)

#---------------------------------------------------#
#   训练数据生成器
#---------------------------------------------------#
B
Bubbliiiing 已提交
37
def data_generator(annotation_lines, batch_size, input_shape, anchors, num_classes, mosaic=False, random=True):
    '''Yield training batches forever.

    Args:
        annotation_lines: list of annotation strings, each passed to
            get_random_data / get_random_data_with_Mosaic. The list is
            shuffled IN PLACE every time the read position wraps to 0.
        batch_size: number of samples per batch.
        input_shape: (h, w) network input size, forwarded to the loaders
            and to preprocess_true_boxes.
        anchors: anchor array forwarded to preprocess_true_boxes.
        num_classes: number of classes, forwarded to preprocess_true_boxes.
        mosaic: if True, every other sample is built as a 4-image mosaic
            from annotation_lines[i:i+4] whenever that window fits.
        random: forwarded to get_random_data.

    Yields:
        ([image_data, *y_true], np.zeros(batch_size)). The zeros are
        placeholder targets: the training script compiles the loss as
        `lambda y_true, y_pred: y_pred`, so they are never used.

    Raises:
        ValueError: if annotation_lines is empty (raised on the first
            next(), since this is a generator).
    '''
    n = len(annotation_lines)
    if n == 0:
        # Without this guard the loop fails with a bare IndexError
        # (or ZeroDivisionError from `% n`) that hides the real cause.
        raise ValueError('annotation_lines must not be empty')
    i = 0
    # Alternates mosaic / plain samples when mosaic is enabled.
    flag = True
    while True:
        image_data = []
        box_data = []
        for _ in range(batch_size):
            if i == 0:
                np.random.shuffle(annotation_lines)
            # annotation_lines[i:i+4] holds four lines whenever i + 4 <= n.
            if mosaic and flag and i + 4 <= n:
                image, box = get_random_data_with_Mosaic(annotation_lines[i:i+4], input_shape)
            else:
                image, box = get_random_data(annotation_lines[i], input_shape, random=random)
            # The cursor advances by one even after a mosaic sample, so
            # mosaic windows overlap with the following samples.
            i = (i + 1) % n
            if mosaic:
                flag = not flag
            image_data.append(image)
            box_data.append(box)
        image_data = np.array(image_data)
        box_data = np.array(box_data)
        y_true = preprocess_true_boxes(box_data, input_shape, anchors, num_classes)
        yield [image_data, *y_true], np.zeros(batch_size)

#---------------------------------------------------#
#   Convert ground-truth box arrays into the y_true training targets
#---------------------------------------------------#
def preprocess_true_boxes(true_boxes, input_shape, anchors, num_classes):
    '''Convert ground-truth boxes into per-layer YOLO training targets.

    Args:
        true_boxes: array-like of shape (m, T, 5). Each row is
            (x_min, y_min, x_max, y_max, class_id) in input-image pixels.
            Rows whose width or height is not positive are treated as
            padding and ignored. The caller's array is not modified.
        input_shape: (h, w) of the network input.
        anchors: (9, 2) array of anchor sizes in input pixels. Nine are
            required because anchor_mask below is hard-coded for 3 layers.
        num_classes: number of classes; every class_id must be below it.

    Returns:
        list of num_layers float32 arrays. y_true[l] has shape
        (m, h // stride, w // stride, 3, 5 + num_classes), with stride
        32, 16, 8 for l = 0, 1, 2. Each box is written to exactly one
        (layer, cell, anchor) slot, chosen by its best-IoU anchor:
        [0:4] = (x, y, w, h) normalised by the input size,
        [4] = 1 (objectness), [5 + class_id] = 1 (one-hot class).
        If two boxes map to the same slot, the later one overwrites it.
    '''
    true_boxes = np.array(true_boxes, dtype='float32')
    assert (true_boxes[..., 4]<num_classes).all(), 'class id must be less than num_classes'
    # Number of output feature layers.
    num_layers = len(anchors)//3
    #-----------------------------------------------------------#
    #   With the default anchors and a 416x416 input:
    #   13x13 layer uses anchors [142, 110], [192, 243], [459, 401]
    #   26x26 layer uses anchors [36, 75], [76, 55], [72, 146]
    #   52x52 layer uses anchors [12, 16], [19, 36], [40, 28]
    #-----------------------------------------------------------#
    anchor_mask = [[6,7,8], [3,4,5], [0,1,2]]

    input_shape = np.array(input_shape, dtype='int32')
    #-----------------------------------------------------------#
    #   Box centres and sizes, shapes (m, T, 2).
    #   True division: floor division would bias every centre
    #   towards the top-left by up to half a pixel.
    #-----------------------------------------------------------#
    boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) / 2
    boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2]
    # Normalise to [0, 1]; input_shape is (h, w), so reverse it to (w, h).
    true_boxes[..., 0:2] = boxes_xy/input_shape[::-1]
    true_boxes[..., 2:4] = boxes_wh/input_shape[::-1]

    # m = number of images; grid_shapes[l] = (grid_h, grid_w) of layer l.
    m = true_boxes.shape[0]
    grid_shapes = [input_shape//{0:32, 1:16, 2:8}[l] for l in range(num_layers)]
    y_true = [np.zeros((m, grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]), 5+num_classes),
        dtype='float32') for l in range(num_layers)]

    # [9,2] -> [1,9,2] so anchors broadcast against [n,1,2] boxes.
    # Boxes and anchors are both centred at the origin: IoU here
    # compares shapes only.
    anchors = np.expand_dims(anchors, 0)
    anchor_maxes = anchors / 2.
    anchor_mins = -anchor_maxes
    anchor_area = anchors[..., 0] * anchors[..., 1]

    # A box is real only if both width and height are positive.
    valid_mask = np.logical_and(boxes_wh[..., 0] > 0, boxes_wh[..., 1] > 0)

    for b in range(m):
        # Indices of the real boxes in image b. Used to map results back to
        # true_boxes rows even when padding rows are interleaved.
        valid_idx = np.flatnonzero(valid_mask[b])
        if len(valid_idx) == 0:
            continue
        # [n,2] -> [n,1,2]
        wh = np.expand_dims(boxes_wh[b, valid_idx], -2)
        box_maxes = wh / 2.
        box_mins = -box_maxes

        #-----------------------------------------------------------#
        #   IoU between every real box and every anchor
        #   intersect_area  [n,9]
        #   box_area        [n,1]
        #   anchor_area     [1,9]
        #   iou             [n,9]
        #-----------------------------------------------------------#
        intersect_mins = np.maximum(box_mins, anchor_mins)
        intersect_maxes = np.minimum(box_maxes, anchor_maxes)
        intersect_wh = np.maximum(intersect_maxes - intersect_mins, 0.)
        intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
        box_area = wh[..., 0] * wh[..., 1]
        iou = intersect_area / (box_area + anchor_area - intersect_area)

        # Best anchor per real box, shape [n,].
        best_anchor = np.argmax(iou, axis=-1)

        for t, n in zip(valid_idx, best_anchor):
            # Find the feature layer that owns anchor n.
            for l in range(num_layers):
                if n not in anchor_mask[l]:
                    continue
                # Grid cell (column i, row j) containing the box centre.
                i = np.floor(true_boxes[b, t, 0] * grid_shapes[l][1]).astype('int32')
                j = np.floor(true_boxes[b, t, 1] * grid_shapes[l][0]).astype('int32')
                # k = position of anchor n among this layer's anchors.
                k = anchor_mask[l].index(n)
                # c = class id of this box.
                c = true_boxes[b, t, 4].astype('int32')
                y_true[l][b, j, i, k, 0:4] = true_boxes[b, t, 0:4]
                y_true[l][b, j, i, k, 4] = 1
                y_true[l][b, j, i, k, 5+c] = 1

    return y_true

B
Bubbliiiing 已提交
179 180 181 182
#----------------------------------------------------#
#   For how detection mAP and the PR curve are
#   computed, see this reference video:
#   https://www.bilibili.com/video/BV1zE411u7Vw
#----------------------------------------------------#
if __name__ == "__main__":
    #----------------------------------------------------#
    #   Annotation file listing image paths and labels
    #----------------------------------------------------#
    annotation_path = '2007_train.txt'
    #------------------------------------------------------#
    #   Where trained weights are saved (the logs folder)
    #------------------------------------------------------#
    log_dir = 'logs/'
    #----------------------------------------------------#
    #   Paths to the class and anchor files -- important!
    #   Before training, change classes_path so it matches
    #   your own dataset.
    #----------------------------------------------------#
    classes_path = 'model_data/voc_classes.txt'
    anchors_path = 'model_data/yolo_anchors.txt'
    #------------------------------------------------------#
    #   Pretrained weights: see the README (Baidu Netdisk).
    #   Shape-mismatch messages when training on your own
    #   dataset are expected: the heads predict different
    #   things, so their shapes differ (skip_mismatch=True
    #   below skips those layers).
    #------------------------------------------------------#
    weights_path = 'model_data/yolo4_weight.h5'
    #------------------------------------------------------#
    #   Training image size (h, w), usually 416x416 or 608x608
    #------------------------------------------------------#
    input_shape = (416,416)
    #------------------------------------------------------#
    #   Whether the loss is normalised
    #------------------------------------------------------#
    normalize = True

    #------------------------------------------------------#
    #   Classes and anchors, and how many of each there are
    #------------------------------------------------------#
    class_names = get_classes(classes_path)
    anchors = get_anchors(anchors_path)
    num_classes = len(class_names)
    num_anchors = len(anchors)

    #------------------------------------------------------#
    #   YOLOv4 training tricks
    #   mosaic            mosaic data augmentation (True/False)
    #   Cosine_scheduler  warm-up + cosine-annealing LR (True/False)
    #   label_smoothing   usually 0.01 or below, e.g. 0.01, 0.005
    #------------------------------------------------------#
    mosaic = True
    Cosine_scheduler = False
    label_smoothing = 0

    K.clear_session()
    #------------------------------------------------------#
    #   Build the YOLOv4 body and load pretrained weights
    #------------------------------------------------------#
    image_input = Input(shape=(None, None, 3))
    h, w = input_shape
    print('Create YOLOv4 model with {} anchors and {} classes.'.format(num_anchors, num_classes))
    model_body = yolo_body(image_input, num_anchors//3, num_classes)

    print('Load weights {}.'.format(weights_path))
    model_body.load_weights(weights_path, by_name=True, skip_mismatch=True)

    #------------------------------------------------------#
    #   Attach the loss: network outputs plus the y_true
    #   inputs feed a Lambda layer whose output is the loss
    #   itself, so the whole model's output is the loss.
    #------------------------------------------------------#
    y_true = [Input(shape=(h//{0:32, 1:16, 2:8}[l], w//{0:32, 1:16, 2:8}[l],
        num_anchors//3, num_classes+5)) for l in range(3)]
    loss_input = [*model_body.output, *y_true]
    model_loss = Lambda(yolo_loss, output_shape=(1,), name='yolo_loss',
        arguments={'anchors': anchors, 'num_classes': num_classes, 'ignore_thresh': 0.5,
            'label_smoothing': label_smoothing, 'normalize': normalize})(loss_input)

    model = Model([model_body.input, *y_true], model_loss)

    #-------------------------------------------------------------------------------#
    #   Training callbacks
    #   logging         TensorBoard log directory
    #   checkpoint      weight-saving details; period = save every N epochs
    #   reduce_lr       (built per stage below) learning-rate schedule
    #   early_stopping  stops training once val_loss has not improved for
    #                   `patience` epochs, i.e. the model has roughly converged
    #-------------------------------------------------------------------------------#
    logging = TensorBoard(log_dir=log_dir)
    checkpoint = ModelCheckpoint(log_dir + 'ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5',
        monitor='val_loss', save_weights_only=True, save_best_only=False, period=1)
    early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1)

    #----------------------------------------------------------------------#
    #   The validation split is done here in train.py. It is normal for
    #   2007_test.txt and 2007_val.txt to be empty; training never uses them.
    #   With val_split = 0.1, validation : training is 1 : 9.
    #----------------------------------------------------------------------#
    val_split = 0.1
    with open(annotation_path) as f:
        lines = f.readlines()
    # Fixed seed makes the train/val split reproducible; reseeding with
    # None afterwards restores fresh (entropy-based) randomness.
    np.random.seed(10101)
    np.random.shuffle(lines)
    np.random.seed(None)
    # The validation generator always runs at least one step
    # (validation_steps=max(1, ...)), so it must never get an empty list.
    # int(len(lines)*val_split) is 0 for fewer than 10 lines; reserve at
    # least one line for validation instead of crashing mid-epoch.
    if len(lines) < 2:
        raise ValueError('{} must contain at least 2 annotation lines '
                         '(one for training, one for validation).'.format(annotation_path))
    num_val = max(1, int(len(lines)*val_split))
    num_train = len(lines) - num_val

    #------------------------------------------------------#
    #   Backbone features are generic: freezing the backbone
    #   speeds up training and also keeps the pretrained
    #   weights from being destroyed early in training.
    #   Init_epoch    starting epoch
    #   Freeze_epoch  last epoch of frozen training
    #   Epoch         total number of training epochs
    #   On OOM / insufficient GPU memory, reduce batch_size.
    #------------------------------------------------------#
    freeze_layers = 249
    for i in range(freeze_layers): model_body.layers[i].trainable = False
    print('Freeze the first {} layers of total {} layers.'.format(freeze_layers, len(model_body.layers)))

    def fit_stage(init_epoch, end_epoch, batch_size, learning_rate_base,
                  warmup_learning_rate, hold_base_rate_steps, weights_name):
        '''(Re)compile `model`, train it from init_epoch to end_epoch, save weights.

        Both training stages share this code and differ only in these
        arguments. Weights are written to log_dir + weights_name.
        '''
        if Cosine_scheduler:
            # Warm-up period: the first 20% of this stage's epochs.
            warmup_epoch = int((end_epoch-init_epoch)*0.2)
            # Total number of batch steps in this stage.
            total_steps = int((end_epoch-init_epoch) * num_train / batch_size)
            # Number of warm-up batch steps.
            warmup_steps = int(warmup_epoch * num_train / batch_size)
            # NOTE(review): callers pass hold_base_rate_steps as a sample
            # count (num_train, num_train//2), while total_steps and
            # warmup_steps are batch-step counts -- confirm which unit
            # WarmUpCosineDecayScheduler expects.
            reduce_lr = WarmUpCosineDecayScheduler(learning_rate_base=learning_rate_base,
                                                   total_steps=total_steps,
                                                   warmup_learning_rate=warmup_learning_rate,
                                                   warmup_steps=warmup_steps,
                                                   hold_base_rate_steps=hold_base_rate_steps,
                                                   min_learn_rate=1e-6
                                                   )
            # Adam gets no explicit learning rate here; presumably the
            # scheduler callback drives it.
            model.compile(optimizer=Adam(), loss={'yolo_loss': lambda y_true, y_pred: y_pred})
        else:
            reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, verbose=1)
            model.compile(optimizer=Adam(learning_rate_base), loss={'yolo_loss': lambda y_true, y_pred: y_pred})

        print('Train on {} samples, val on {} samples, with batch size {}.'.format(num_train, num_val, batch_size))
        model.fit_generator(data_generator(lines[:num_train], batch_size, input_shape, anchors, num_classes, mosaic=mosaic, random=True),
                steps_per_epoch=max(1, num_train//batch_size),
                validation_data=data_generator(lines[num_train:], batch_size, input_shape, anchors, num_classes, mosaic=False, random=False),
                validation_steps=max(1, num_val//batch_size),
                epochs=end_epoch,
                initial_epoch=init_epoch,
                callbacks=[logging, checkpoint, reduce_lr, early_stopping])
        model.save_weights(log_dir + weights_name)

    Init_epoch = 0
    Freeze_epoch = 50
    Epoch = 100

    # Stage 1: first freeze_layers layers frozen.
    fit_stage(Init_epoch, Freeze_epoch, batch_size=8, learning_rate_base=1e-3,
              warmup_learning_rate=1e-4, hold_base_rate_steps=num_train,
              weights_name='trained_weights_stage_1.h5')

    for i in range(freeze_layers): model_body.layers[i].trainable = True

    # Stage 2: all layers trainable (recompiled inside fit_stage so the
    # trainable change takes effect), smaller batch and learning rate.
    fit_stage(Freeze_epoch, Epoch, batch_size=2, learning_rate_base=1e-4,
              warmup_learning_rate=1e-5, hold_base_rate_steps=num_train//2,
              weights_name='last1.h5')