diff --git a/train.py b/train.py index 64ed193844d17ad6204c208790c69e54ae3c4373..242be509cf2a85d115b0f679ce53fbd983ec8010 100644 --- a/train.py +++ b/train.py @@ -63,6 +63,7 @@ if __name__ == "__main__": False, Cuda) epoch_size = num_train // Batch_size + if True: # ------------------------------------# # 冻结一定部分训练 @@ -76,8 +77,11 @@ if __name__ == "__main__": adjust_learning_rate(optimizer,lr,0.9,epoch) loc_loss = 0 conf_loss = 0 - for iteration in range(epoch_size): - images, targets = next(gen) + for iteration, batch in enumerate(gen): + if iteration >= epoch_size: + break + start_time = time.time() + images, targets = batch[0], batch[1] with torch.no_grad(): if Cuda: images = Variable(torch.from_numpy(images).type(torch.FloatTensor)).cuda() @@ -118,8 +122,11 @@ if __name__ == "__main__": adjust_learning_rate(optimizer,freeze_lr,0.9,epoch) loc_loss = 0 conf_loss = 0 - for iteration in range(epoch_size): - images, targets = next(gen) + for iteration, batch in enumerate(gen): + if iteration >= epoch_size: + break + start_time = time.time() + images, targets = batch[0], batch[1] with torch.no_grad(): if Cuda: images = Variable(torch.from_numpy(images).type(torch.FloatTensor)).cuda()