From 25d9604e9d9ce5212cd0a205e9a9d6c244fd3514 Mon Sep 17 00:00:00 2001 From: Bubbliiiing <47347516+bubbliiiing@users.noreply.github.com> Date: Wed, 11 Nov 2020 14:15:10 +0800 Subject: [PATCH] Update train_with_tensorboard.py --- train_with_tensorboard.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_with_tensorboard.py b/train_with_tensorboard.py index eba0c33..d85b4eb 100644 --- a/train_with_tensorboard.py +++ b/train_with_tensorboard.py @@ -79,7 +79,7 @@ def fit_ont_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,genval,Epo # 将loss写入tensorboard,下面注释的是每个世代保存一次 # writer.add_scalar('Train_loss', total_loss/(iteration+1), epoch) - + net.eval() print('Start Validation') with tqdm(total=epoch_size_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar: for iteration, batch in enumerate(genval): @@ -108,7 +108,7 @@ def fit_ont_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,genval,Epo pbar.set_postfix(**{'total_loss': val_loss.item() / (iteration + 1)}) pbar.update(1) - + net.train() # 将loss写入tensorboard,每个世代保存一次 writer.add_scalar('Val_loss',val_loss/(epoch_size_val+1), epoch) print('Finish Validation') -- GitLab