未验证 提交 763c56a5 编写于 作者: B Bubbliiiing 提交者: GitHub

Update train.py

上级 24def704
...@@ -63,6 +63,7 @@ if __name__ == "__main__": ...@@ -63,6 +63,7 @@ if __name__ == "__main__":
False, Cuda) False, Cuda)
epoch_size = num_train // Batch_size epoch_size = num_train // Batch_size
if True: if True:
# ------------------------------------# # ------------------------------------#
# 冻结一定部分训练 # 冻结一定部分训练
...@@ -76,8 +77,11 @@ if __name__ == "__main__": ...@@ -76,8 +77,11 @@ if __name__ == "__main__":
adjust_learning_rate(optimizer,lr,0.9,epoch) adjust_learning_rate(optimizer,lr,0.9,epoch)
loc_loss = 0 loc_loss = 0
conf_loss = 0 conf_loss = 0
for iteration in range(epoch_size): for iteration, batch in enumerate(gen):
images, targets = next(gen) if iteration >= epoch_size:
break
start_time = time.time()
images, targets = batch[0], batch[1]
with torch.no_grad(): with torch.no_grad():
if Cuda: if Cuda:
images = Variable(torch.from_numpy(images).type(torch.FloatTensor)).cuda() images = Variable(torch.from_numpy(images).type(torch.FloatTensor)).cuda()
...@@ -118,8 +122,11 @@ if __name__ == "__main__": ...@@ -118,8 +122,11 @@ if __name__ == "__main__":
adjust_learning_rate(optimizer,freeze_lr,0.9,epoch) adjust_learning_rate(optimizer,freeze_lr,0.9,epoch)
loc_loss = 0 loc_loss = 0
conf_loss = 0 conf_loss = 0
for iteration in range(epoch_size): for iteration, batch in enumerate(gen):
images, targets = next(gen) if iteration >= epoch_size:
break
start_time = time.time()
images, targets = batch[0], batch[1]
with torch.no_grad(): with torch.no_grad():
if Cuda: if Cuda:
images = Variable(torch.from_numpy(images).type(torch.FloatTensor)).cuda() images = Variable(torch.from_numpy(images).type(torch.FloatTensor)).cuda()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册