未验证 提交 763c56a5 编写于 作者: B Bubbliiiing 提交者: GitHub

Update train.py

上级 24def704
......@@ -63,6 +63,7 @@ if __name__ == "__main__":
False, Cuda)
epoch_size = num_train // Batch_size
if True:
# ------------------------------------#
# 冻结一定部分训练
......@@ -76,8 +77,11 @@ if __name__ == "__main__":
adjust_learning_rate(optimizer,lr,0.9,epoch)
loc_loss = 0
conf_loss = 0
for iteration in range(epoch_size):
images, targets = next(gen)
for iteration, batch in enumerate(gen):
if iteration >= epoch_size:
break
start_time = time.time()
images, targets = batch[0], batch[1]
with torch.no_grad():
if Cuda:
images = Variable(torch.from_numpy(images).type(torch.FloatTensor)).cuda()
......@@ -118,8 +122,11 @@ if __name__ == "__main__":
adjust_learning_rate(optimizer,freeze_lr,0.9,epoch)
loc_loss = 0
conf_loss = 0
for iteration in range(epoch_size):
images, targets = next(gen)
for iteration, batch in enumerate(gen):
if iteration >= epoch_size:
break
start_time = time.time()
images, targets = batch[0], batch[1]
with torch.no_grad():
if Cuda:
images = Variable(torch.from_numpy(images).type(torch.FloatTensor)).cuda()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册