From 756c4c2e184116b96391554a8f911eaa36f76550 Mon Sep 17 00:00:00 2001 From: Bubbliiiing <47347516+bubbliiiing@users.noreply.github.com> Date: Sat, 11 Jul 2020 17:13:05 +0800 Subject: [PATCH] Update train.py --- train.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index cd5fdef..da053df 100644 --- a/train.py +++ b/train.py @@ -83,6 +83,10 @@ def fit_one_epoch(net,focal_loss,epoch,epoch_size,epoch_size_val,gen,genval,Epoc torch.save(model.state_dict(), 'logs/Epoch%d-Total_Loss%.4f-Val_Loss%.4f.pth'%((epoch+1),total_loss/(epoch_size+1),val_loss/(epoch_size_val+1))) return val_loss/(epoch_size_val+1) +#----------------------------------------------------# +# 检测精度mAP和pr曲线计算参考视频 +# https://www.bilibili.com/video/BV1zE411u7Vw +#----------------------------------------------------# if __name__ == "__main__": #-------------------------------------------# # 训练前,请指定好phi和model_path @@ -106,6 +110,9 @@ if __name__ == "__main__": # 创建模型 model = EfficientDetBackbone(num_classes,phi) + #------------------------------------------------------# + # 权值文件请看README,百度网盘下载 + #------------------------------------------------------# model_path = "model_data/efficientdet-d0.pth" # 加快模型训练的效率 print('Loading weights into state dict...') @@ -135,6 +142,14 @@ if __name__ == "__main__": num_val = int(len(lines)*val_split) num_train = len(lines) - num_val + #------------------------------------------------------# + # 主干特征提取网络特征通用,冻结训练可以加快训练速度 + # 也可以在训练初期防止权值被破坏。 + # Init_Epoch为起始世代 + # Freeze_Epoch为冻结训练的世代 + # Epoch总训练世代 + # 提示OOM或者显存不足请调小Batch_size + #------------------------------------------------------# if True: #--------------------------------------------# # BATCH_SIZE不要太小,不然训练效果很差 @@ -208,4 +223,4 @@ if __name__ == "__main__": for epoch in range(Freeze_Epoch,Unfreeze_Epoch): val_loss = fit_one_epoch(net,efficient_loss,epoch,epoch_size,epoch_size_val,gen,gen_val,Unfreeze_Epoch,Cuda) - lr_scheduler.step(val_loss) \ No newline at end of file + lr_scheduler.step(val_loss) -- GitLab