提交 0b930c81 编写于 作者: Bubbliiiing's avatar Bubbliiiing

update val

上级 16c84698
......@@ -77,12 +77,11 @@ def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callbac
for iteration, batch in enumerate(gen_val):
if iteration >= epoch_step_val:
break
images, targets, y_trues = batch[0], batch[1], batch[2]
images, targets = batch[0], batch[1]
with torch.no_grad():
if cuda:
images = images.cuda(local_rank)
targets = [ann.cuda(local_rank) for ann in targets]
y_trues = [ann.cuda(local_rank) for ann in y_trues]
targets = targets.cuda(local_rank)
#----------------------#
# 清零梯度
#----------------------#
......@@ -91,15 +90,7 @@ def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callbac
# 前向传播
#----------------------#
outputs = model_train_eval(images)
loss_value_all = 0
#----------------------#
# 计算损失
#----------------------#
for l in range(len(outputs)):
loss_item = yolo_loss(l, outputs[l], targets, y_trues[l])
loss_value_all += loss_item
loss_value = loss_value_all
loss_value = yolo_loss(outputs, targets, images)
val_loss += loss_value.item()
if local_rank == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册