#-------------------------------------#
#       Train on the dataset
#-------------------------------------#
import os
import time
B
Bubbliiiing 已提交
6 7

import numpy as np
B
Bubbliiiing 已提交
8
import torch
B
Bubbliiiing 已提交
9
import torch.backends.cudnn as cudnn
B
Bubbliiiing 已提交
10 11
import torch.nn as nn
import torch.nn.functional as F
B
Bubbliiiing 已提交
12 13
import torch.optim as optim
from torch.autograd import Variable
B
Bubbliiiing 已提交
14
from torch.utils.data import DataLoader
B
Bubbliiiing 已提交
15
from tqdm import tqdm
B
Bubbliiiing 已提交
16

B
Bubbliiiing 已提交
17 18 19 20 21
from nets.yolo4 import YoloBody
from nets.yolo_training import Generator, YOLOLoss
from utils.dataloader import YoloDataset, yolo_dataset_collate


#---------------------------------------------------#
#   Load the class names and the anchor (prior) boxes
#---------------------------------------------------#
def get_classes(classes_path):
    """Load class names from a text file, one name per line.

    The file is decoded as UTF-8 explicitly so that non-ASCII class names
    are read the same way on every platform instead of depending on the
    locale's default encoding.

    Args:
        classes_path: Path to the class-list text file.

    Returns:
        A list with one entry per line of the file, in file order, with
        leading/trailing whitespace (including the newline) stripped.
        Blank lines are kept as empty strings.
    """
    with open(classes_path, encoding='utf-8') as f:
        return [line.strip() for line in f]

def get_anchors(anchors_path):
    """Read the anchor values stored on the first line of a text file.

    The first line is split on commas, every token is parsed as a float,
    and the flat list is regrouped into an array of shape (-1, 3, 2).
    The groups are then returned in reverse order along the first axis.
    Any further lines in the file are ignored.
    """
    with open(anchors_path) as handle:
        first_line = handle.readline()
    flat_values = np.array([float(token) for token in first_line.split(',')])
    grouped = flat_values.reshape([-1, 3, 2])
    # Reverse only the group order; the pairs inside each group keep theirs.
    return grouped[::-1, :, :]

B
Bubbliiiing 已提交
39 40 41 42
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Yields None when the optimizer has no parameter groups at all.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)

B
Bubbliiiing 已提交
43
        
B
Bubbliiiing 已提交
44
def fit_one_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,genval,Epoch,cuda):
    """Train for one epoch, validate, and save a checkpoint under ``logs/``.

    NOTE: this function still reads the module-level ``optimizer`` that the
    training script creates before calling it; it is not passed in as an
    argument.

    Args:
        net: The network to train. It is called on an image batch and must
            return one output per entry of ``yolo_losses``.
        yolo_losses: Sequence of loss callables; each is called as
            ``loss_fn(output, targets)`` and returns ``(loss, num_pos)``.
        epoch: Zero-based index of the current epoch (displayed as
            ``epoch + 1``).
        epoch_size: Maximum number of training batches to consume.
        epoch_size_val: Maximum number of validation batches to consume.
        gen: Iterable yielding training batches ``(images, targets)`` as
            numpy arrays.
        genval: Iterable yielding validation batches in the same format.
        Epoch: Final epoch number, used only for progress display.
        cuda: When true, image batches are moved to the GPU.

    The mean losses that are printed and embedded in the checkpoint file
    name are averaged over the number of batches actually processed.
    """
    total_loss = 0
    val_loss = 0
    train_steps = 0
    val_steps = 0

    net.train()
    with tqdm(total=epoch_size,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar:
        for iteration, batch in enumerate(gen):
            if iteration >= epoch_size:
                break
            images, targets = _batch_to_tensors(batch, cuda)

            # Clear old gradients, run forward + loss, then backpropagate.
            optimizer.zero_grad()
            loss = _combined_yolo_loss(net, yolo_losses, images, targets)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            train_steps += 1

            pbar.set_postfix(**{'total_loss': total_loss / train_steps,
                                'lr'        : get_lr(optimizer)})
            pbar.update(1)

    net.eval()
    print('Start Validation')
    with tqdm(total=epoch_size_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar:
        for iteration, batch in enumerate(genval):
            if iteration >= epoch_size_val:
                break
            images_val, targets_val = _batch_to_tensors(batch, cuda)

            # No gradients are needed for validation, so nothing has to be
            # zeroed on the optimizer here.
            with torch.no_grad():
                loss = _combined_yolo_loss(net, yolo_losses, images_val, targets_val)
            val_loss += loss.item()
            val_steps += 1

            pbar.set_postfix(**{'total_loss': val_loss / val_steps})
            pbar.update(1)
    print('Finish Validation')
    print('Epoch:'+ str(epoch+1) + '/' + str(Epoch))

    # Average over the batches that really ran; max(..., 1) keeps an empty
    # loader (e.g. a validation split smaller than one batch) from dividing
    # by zero.
    mean_train_loss = total_loss / max(train_steps, 1)
    mean_val_loss = val_loss / max(val_steps, 1)
    print('Total Loss: %.4f || Val Loss: %.4f ' % (mean_train_loss, mean_val_loss))

    print('Saving state, iter:', str(epoch+1))
    # Save the underlying model's weights so the keys carry no "module."
    # prefix when the network is wrapped for data-parallel execution.
    parallel_wrappers = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    weights_owner = net.module if isinstance(net, parallel_wrappers) else net
    os.makedirs('logs', exist_ok=True)
    torch.save(weights_owner.state_dict(), 'logs/Epoch%d-Total_Loss%.4f-Val_Loss%.4f.pth'%((epoch+1),mean_train_loss,mean_val_loss))


def _batch_to_tensors(batch, cuda):
    """Convert one ``(images, targets)`` numpy batch into float tensors.

    Only the image tensor is moved to the GPU when ``cuda`` is true; the
    per-image target tensors are left on the CPU.
    """
    images, targets = batch[0], batch[1]
    images = torch.from_numpy(images).type(torch.FloatTensor)
    if cuda:
        images = images.cuda()
    targets = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in targets]
    return images, targets


def _combined_yolo_loss(net, yolo_losses, images, targets):
    """Run ``net`` on ``images`` and merge the per-output losses into one value.

    Each loss callable is paired with the network output at the same
    position; the summed losses are divided by the summed ``num_pos``
    values the callables report.
    """
    outputs = net(images)
    losses = []
    num_pos_all = 0
    for loss_fn, output in zip(yolo_losses, outputs):
        loss_item, num_pos = loss_fn(output, targets)
        losses.append(loss_item)
        num_pos_all += num_pos
    return sum(losses) / num_pos_all

#----------------------------------------------------#
#   Reference video on computing detection mAP and
#   PR curves:
#   https://www.bilibili.com/video/BV1zE411u7Vw
#----------------------------------------------------#
B
Bubbliiiing 已提交
131 132
if __name__ == "__main__":
    # Script entry point: builds the model, loads pretrained weights, then
    # trains in two phases (backbone frozen, then unfrozen).
    # NOTE: `optimizer` and `model` are deliberately module-level names here,
    # because fit_one_epoch reads them as globals.

    #-------------------------------#
    #   Whether to use CUDA.
    #   Set to False if there is no GPU.
    #-------------------------------#
    Cuda = True
    #-------------------------------#
    #   Whether to load data through torch's DataLoader
    #   (otherwise the Generator class is used).
    #-------------------------------#
    Use_Data_Loader = True
    #------------------------------------------------------#
    #   Whether to normalise the loss; changes its magnitude.
    #   Decides whether the final loss is divided by the
    #   batch size or by the number of positive samples.
    #------------------------------------------------------#
    normalize = False
    #-------------------------------#
    #   Input shape.
    #   Use 416x416 when GPU memory is small,
    #   608x608 when it is large.
    #-------------------------------#
    input_shape = (416,416)

    #----------------------------------------------------#
    #   Paths of the classes and anchors files; very important.
    #   Always update classes_path to match your own dataset
    #   before training.
    #----------------------------------------------------#
    anchors_path = 'model_data/yolo_anchors.txt'
    classes_path = 'model_data/voc_classes.txt'
    #----------------------------------------------------#
    #   Load the classes and anchors.
    #----------------------------------------------------#
    class_names = get_classes(classes_path)
    anchors = get_anchors(anchors_path)
    num_classes = len(class_names)

    #------------------------------------------------------#
    #   YOLOv4 training tricks.
    #   mosaic: mosaic data augmentation, True or False.
    #   In practice mosaic augmentation proved unstable, so
    #   it defaults to False.
    #   Cosine_lr: cosine-annealing learning rate, True or False.
    #   smoooth_label: label smoothing, usually 0.01 or less,
    #   e.g. 0.01 or 0.005.
    #------------------------------------------------------#
    mosaic = False
    Cosine_lr = False
    smoooth_label = 0

    #------------------------------------------------------#
    #   Create the YOLO model.
    #   Always update classes_path and the matching txt file
    #   before training.
    #------------------------------------------------------#
    model = YoloBody(len(anchors[0]), num_classes)

    #------------------------------------------------------#
    #   See the README for where to download the weight file
    #   (Baidu Netdisk).
    #------------------------------------------------------#
    model_path = "model_data/yolo4_weights.pth"
    print('Loading weights into state dict...')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_path, map_location=device)
    # Keep only checkpoint entries that exist in this model AND have the same
    # shape; checking membership first avoids a KeyError on checkpoints that
    # carry extra keys.
    pretrained_dict = {k: v for k, v in pretrained_dict.items()
                       if k in model_dict and np.shape(model_dict[k]) == np.shape(v)}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    print('Finished!')

    net = model.train()

    if Cuda:
        net = torch.nn.DataParallel(model)
        cudnn.benchmark = True
        net = net.cuda()

    # Build one loss function per network output.
    yolo_losses = []
    for i in range(3):
        yolo_losses.append(YOLOLoss(np.reshape(anchors,[-1,2]),num_classes,
                                (input_shape[1], input_shape[0]), smoooth_label, Cuda, normalize))

    #----------------------------------------------------#
    #   Image paths and labels.
    #----------------------------------------------------#
    annotation_path = '2007_train.txt'
    #----------------------------------------------------------------------#
    #   The validation split is made here in train.py.
    #   It is normal for 2007_test.txt and 2007_val.txt to be empty;
    #   training does not use them.
    #   With the current split the validation:training ratio is 1:9.
    #----------------------------------------------------------------------#
    val_split = 0.1
    with open(annotation_path) as f:
        lines = f.readlines()
    # Fixed seed so the shuffle is the same on every run; the RNG is
    # re-seeded from entropy afterwards.
    np.random.seed(10101)
    np.random.shuffle(lines)
    np.random.seed(None)
    num_val = int(len(lines)*val_split)
    num_train = len(lines) - num_val

    #------------------------------------------------------#
    #   Backbone features are generic, so training with the
    #   backbone frozen speeds training up and also protects
    #   the weights from being wrecked early on.
    #   Init_Epoch is the starting epoch.
    #   Freeze_Epoch is the last epoch of frozen training.
    #   Epoch is the total number of training epochs.
    #   On OOM / insufficient GPU memory, reduce Batch_size.
    #------------------------------------------------------#
    if True:
        lr = 1e-3
        Batch_size = 4
        Init_Epoch = 0
        Freeze_Epoch = 50

        #----------------------------------------------------------------------------#
        #   In the author's tests the optimizer's weight_decay was counter-productive,
        #   so it was removed. You may enable it to try; weight_decay=5e-4 is typical.
        #----------------------------------------------------------------------------#
        optimizer = optim.Adam(net.parameters(),lr)
        if Cosine_lr:
            lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=1e-5)
        else:
            lr_scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=1,gamma=0.92)

        if Use_Data_Loader:
            train_dataset = YoloDataset(lines[:num_train], (input_shape[0], input_shape[1]), mosaic=mosaic, is_train=True)
            val_dataset = YoloDataset(lines[num_train:], (input_shape[0], input_shape[1]), mosaic=False, is_train=False)
            gen = DataLoader(train_dataset, shuffle=True, batch_size=Batch_size, num_workers=4, pin_memory=True,
                                    drop_last=True, collate_fn=yolo_dataset_collate)
            gen_val = DataLoader(val_dataset, shuffle=True, batch_size=Batch_size, num_workers=4,pin_memory=True,
                                    drop_last=True, collate_fn=yolo_dataset_collate)
        else:
            gen = Generator(Batch_size, lines[:num_train],
                            (input_shape[0], input_shape[1])).generate(train=True, mosaic = mosaic)
            gen_val = Generator(Batch_size, lines[num_train:],
                            (input_shape[0], input_shape[1])).generate(train=False, mosaic = mosaic)

        epoch_size = max(1, num_train//Batch_size)
        epoch_size_val = num_val//Batch_size
        #------------------------------------#
        #   Train with the backbone frozen.
        #------------------------------------#
        for param in model.backbone.parameters():
            param.requires_grad = False

        for epoch in range(Init_Epoch,Freeze_Epoch):
            fit_one_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,gen_val,Freeze_Epoch,Cuda)
            lr_scheduler.step()

    if True:
        lr = 1e-4
        Batch_size = 2
        Freeze_Epoch = 50
        Unfreeze_Epoch = 100

        #----------------------------------------------------------------------------#
        #   In the author's tests the optimizer's weight_decay was counter-productive,
        #   so it was removed. You may enable it to try; weight_decay=5e-4 is typical.
        #----------------------------------------------------------------------------#
        optimizer = optim.Adam(net.parameters(),lr)
        if Cosine_lr:
            lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=1e-5)
        else:
            lr_scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=1,gamma=0.92)

        if Use_Data_Loader:
            train_dataset = YoloDataset(lines[:num_train], (input_shape[0], input_shape[1]), mosaic=mosaic, is_train=True)
            val_dataset = YoloDataset(lines[num_train:], (input_shape[0], input_shape[1]), mosaic=False, is_train=False)
            gen = DataLoader(train_dataset, shuffle=True, batch_size=Batch_size, num_workers=4, pin_memory=True,
                                    drop_last=True, collate_fn=yolo_dataset_collate)
            gen_val = DataLoader(val_dataset, shuffle=True, batch_size=Batch_size, num_workers=4,pin_memory=True,
                                    drop_last=True, collate_fn=yolo_dataset_collate)
        else:
            gen = Generator(Batch_size, lines[:num_train],
                            (input_shape[0], input_shape[1])).generate(train=True, mosaic = mosaic)
            gen_val = Generator(Batch_size, lines[num_train:],
                            (input_shape[0], input_shape[1])).generate(train=False, mosaic = mosaic)

        epoch_size = max(1, num_train//Batch_size)
        epoch_size_val = num_val//Batch_size
        #------------------------------------#
        #   Train after unfreezing the backbone.
        #------------------------------------#
        for param in model.backbone.parameters():
            param.requires_grad = True

        for epoch in range(Freeze_Epoch,Unfreeze_Epoch):
            fit_one_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,gen_val,Unfreeze_Epoch,Cuda)
            lr_scheduler.step()