未验证 提交 cb538495 编写于 作者: B Bubbliiiing 提交者: GitHub

Add files via upload

上级 a78c8232
......@@ -10,6 +10,8 @@ import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from utils.dataloader import yolo_dataset_collate, YoloDataset
from nets.yolo_training import YOLOLoss,Generator
from nets.yolo4 import YoloBody
......@@ -34,10 +36,10 @@ def get_anchors(anchors_path):
def fit_ont_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,genval,Epoch,cuda):
total_loss = 0
val_loss = 0
start_time = time.time()
for iteration, batch in enumerate(gen):
if iteration >= epoch_size:
break
start_time = time.time()
images, targets = batch[0], batch[1]
with torch.no_grad():
if cuda:
......@@ -60,6 +62,7 @@ def fit_ont_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,genval,Epo
waste_time = time.time() - start_time
print('\nEpoch:'+ str(epoch+1) + '/' + str(Epoch))
print('iter:' + str(iteration) + '/' + str(epoch_size) + ' || Total Loss: %.4f || %.4fs/step' % (total_loss/(iteration+1),waste_time))
start_time = time.time()
print('Start Validation')
for iteration, batch in enumerate(genval):
......@@ -106,6 +109,10 @@ if __name__ == "__main__":
# 用于设定是否使用cuda
Cuda = True
smoooth_label = 0
#-------------------------------#
# Dataloder的使用
#-------------------------------#
Use_Data_Loader = True
annotation_path = '2007_train.txt'
#-------------------------------#
......@@ -165,11 +172,19 @@ if __name__ == "__main__":
else:
lr_scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=1,gamma=0.9)
gen = Generator(Batch_size, lines[:num_train],
(input_shape[0], input_shape[1])).generate(mosaic = mosaic)
gen_val = Generator(Batch_size, lines[num_train:],
(input_shape[0], input_shape[1])).generate(mosaic = False)
if Use_Data_Loader:
train_dataset = YoloDataset(lines[:num_train], (input_shape[0], input_shape[1]), mosaic=mosaic)
val_dataset = YoloDataset(lines[num_train:], (input_shape[0], input_shape[1]), mosaic=False)
gen = DataLoader(train_dataset, batch_size=Batch_size, num_workers=8, pin_memory=True,
drop_last=True, collate_fn=yolo_dataset_collate)
gen_val = DataLoader(val_dataset, batch_size=Batch_size, num_workers=8,pin_memory=True,
drop_last=True, collate_fn=yolo_dataset_collate)
else:
gen = Generator(Batch_size, lines[:num_train],
(input_shape[0], input_shape[1])).generate(mosaic = mosaic)
gen_val = Generator(Batch_size, lines[num_train:],
(input_shape[0], input_shape[1])).generate(mosaic = False)
epoch_size = max(1, num_train//Batch_size)
epoch_size_val = num_val//Batch_size
#------------------------------------#
......@@ -194,11 +209,19 @@ if __name__ == "__main__":
else:
lr_scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=1,gamma=0.9)
gen = Generator(Batch_size, lines[:num_train],
(input_shape[0], input_shape[1])).generate(mosaic = mosaic)
gen_val = Generator(Batch_size, lines[num_train:],
(input_shape[0], input_shape[1])).generate(mosaic = False)
if Use_Data_Loader:
train_dataset = YoloDataset(lines[:num_train], (input_shape[0], input_shape[1]), mosaic=mosaic)
val_dataset = YoloDataset(lines[num_train:], (input_shape[0], input_shape[1]), mosaic=False)
gen = DataLoader(train_dataset, batch_size=Batch_size, num_workers=8, pin_memory=True,
drop_last=True, collate_fn=yolo_dataset_collate)
gen_val = DataLoader(val_dataset, batch_size=Batch_size, num_workers=8,pin_memory=True,
drop_last=True, collate_fn=yolo_dataset_collate)
else:
gen = Generator(Batch_size, lines[:num_train],
(input_shape[0], input_shape[1])).generate(mosaic = mosaic)
gen_val = Generator(Batch_size, lines[num_train:],
(input_shape[0], input_shape[1])).generate(mosaic = False)
epoch_size = max(1, num_train//Batch_size)
epoch_size_val = num_val//Batch_size
#------------------------------------#
......@@ -209,4 +232,4 @@ if __name__ == "__main__":
for epoch in range(Freeze_Epoch,Unfreeze_Epoch):
fit_ont_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,gen_val,Unfreeze_Epoch,Cuda)
lr_scheduler.step()
lr_scheduler.step()
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册