diff --git a/nets/yolo_training.py b/nets/yolo_training.py index eac5964c8e6c254d772ff6cddc68d9c496f988d3..484ea20ef880a52a3f6b3eb7a566a93cb9b6cd0a 100644 --- a/nets/yolo_training.py +++ b/nets/yolo_training.py @@ -177,7 +177,7 @@ class YOLOLoss(nn.Module): bs = tobj.shape[0] loss = box_loss + obj_loss + cls_loss - return loss * bs + return loss def xywh2xyxy(self, x): # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2]