未验证 提交 756c4c2e 编写于 作者: B Bubbliiiing 提交者: GitHub

Update train.py

上级 275a7713
......@@ -83,6 +83,10 @@ def fit_one_epoch(net,focal_loss,epoch,epoch_size,epoch_size_val,gen,genval,Epoc
torch.save(model.state_dict(), 'logs/Epoch%d-Total_Loss%.4f-Val_Loss%.4f.pth'%((epoch+1),total_loss/(epoch_size+1),val_loss/(epoch_size_val+1)))
return val_loss/(epoch_size_val+1)
#----------------------------------------------------#
# 检测精度mAP和pr曲线计算参考视频
# https://www.bilibili.com/video/BV1zE411u7Vw
#----------------------------------------------------#
if __name__ == "__main__":
#-------------------------------------------#
# 训练前,请指定好phi和model_path
......@@ -106,6 +110,9 @@ if __name__ == "__main__":
# 创建模型
model = EfficientDetBackbone(num_classes,phi)
#------------------------------------------------------#
# 权值文件请看README,百度网盘下载
#------------------------------------------------------#
model_path = "model_data/efficientdet-d0.pth"
# 加快模型训练的效率
print('Loading weights into state dict...')
......@@ -135,6 +142,14 @@ if __name__ == "__main__":
num_val = int(len(lines)*val_split)
num_train = len(lines) - num_val
#------------------------------------------------------#
# 主干特征提取网络特征通用,冻结训练可以加快训练速度
# 也可以在训练初期防止权值被破坏。
# Init_Epoch为起始世代
# Freeze_Epoch为冻结训练的世代
# Epoch总训练世代
# 提示OOM或者显存不足请调小Batch_size
#------------------------------------------------------#
if True:
#--------------------------------------------#
# BATCH_SIZE不要太小,不然训练效果很差
......@@ -208,4 +223,4 @@ if __name__ == "__main__":
for epoch in range(Freeze_Epoch,Unfreeze_Epoch):
val_loss = fit_one_epoch(net,efficient_loss,epoch,epoch_size,epoch_size_val,gen,gen_val,Unfreeze_Epoch,Cuda)
lr_scheduler.step(val_loss)
\ No newline at end of file
lr_scheduler.step(val_loss)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册