未验证 提交 559a296c 编写于 作者: B Bubbliiiing 提交者: GitHub

Update train.py

上级 f469fe52
......@@ -75,6 +75,10 @@ def fit_ont_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,genval,Epo
print('Saving state, iter:', str(epoch+1))
torch.save(model.state_dict(), 'logs/Epoch%d-Total_Loss%.4f-Val_Loss%.4f.pth'%((epoch+1),total_loss/(epoch_size+1),val_loss/(epoch_size_val+1)))
#----------------------------------------------------#
# 检测精度mAP和pr曲线计算参考视频
# https://www.bilibili.com/video/BV1zE411u7Vw
#----------------------------------------------------#
if __name__ == "__main__":
# 参数初始化
annotation_path = '2007_train.txt'
......@@ -121,6 +125,14 @@ if __name__ == "__main__":
num_train = len(lines) - num_val
#------------------------------------------------------#
# 主干特征提取网络特征通用,冻结训练可以加快训练速度
# 也可以在训练初期防止权值被破坏。
# Init_Epoch为起始世代
# Freeze_Epoch为冻结训练的世代
# Epoch总训练世代
# 提示OOM或者显存不足请调小Batch_size
#------------------------------------------------------#
if True:
# 最开始使用1e-3的学习率可以收敛的更快
lr = 1e-3
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册