From 03212064963d7c0a35debcf92b6da432da15a550 Mon Sep 17 00:00:00 2001 From: Bubbliiiing <3323290568@qq.com> Date: Sat, 23 Jul 2022 12:37:14 +0800 Subject: [PATCH] fix no object bugs --- nets/yolo_training.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/nets/yolo_training.py b/nets/yolo_training.py index 484ea20..85e33f9 100644 --- a/nets/yolo_training.py +++ b/nets/yolo_training.py @@ -403,14 +403,14 @@ class YOLOLoss(nn.Module): matching_anchs[i].append(all_anch[layer_idx]) for i in range(num_layer): - matching_bs[i] = torch.cat(matching_bs[i], dim=0) - matching_as[i] = torch.cat(matching_as[i], dim=0) - matching_gjs[i] = torch.cat(matching_gjs[i], dim=0) - matching_gis[i] = torch.cat(matching_gis[i], dim=0) - matching_targets[i] = torch.cat(matching_targets[i], dim=0) - matching_anchs[i] = torch.cat(matching_anchs[i], dim=0) - - return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs + matching_bs[i] = torch.cat(matching_bs[i], dim=0) if len(matching_bs[i]) != 0 else torch.Tensor(matching_bs[i]) + matching_as[i] = torch.cat(matching_as[i], dim=0) if len(matching_as[i]) != 0 else torch.Tensor(matching_as[i]) + matching_gjs[i] = torch.cat(matching_gjs[i], dim=0) if len(matching_gjs[i]) != 0 else torch.Tensor(matching_gjs[i]) + matching_gis[i] = torch.cat(matching_gis[i], dim=0) if len(matching_gis[i]) != 0 else torch.Tensor(matching_gis[i]) + matching_targets[i] = torch.cat(matching_targets[i], dim=0) if len(matching_targets[i]) != 0 else torch.Tensor(matching_targets[i]) + matching_anchs[i] = torch.cat(matching_anchs[i], dim=0) if len(matching_anchs[i]) != 0 else torch.Tensor(matching_anchs[i]) + + return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs def find_3_positive(self, predictions, targets): #------------------------------------# -- GitLab