diff --git a/nets/yolo_training.py b/nets/yolo_training.py index 484ea20ef880a52a3f6b3eb7a566a93cb9b6cd0a..85e33f930b337cafdf14036818946a291e09cfe1 100644 --- a/nets/yolo_training.py +++ b/nets/yolo_training.py @@ -403,14 +403,14 @@ class YOLOLoss(nn.Module): matching_anchs[i].append(all_anch[layer_idx]) for i in range(num_layer): - matching_bs[i] = torch.cat(matching_bs[i], dim=0) - matching_as[i] = torch.cat(matching_as[i], dim=0) - matching_gjs[i] = torch.cat(matching_gjs[i], dim=0) - matching_gis[i] = torch.cat(matching_gis[i], dim=0) - matching_targets[i] = torch.cat(matching_targets[i], dim=0) - matching_anchs[i] = torch.cat(matching_anchs[i], dim=0) - - return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs + matching_bs[i] = torch.cat(matching_bs[i], dim=0) if len(matching_bs[i]) != 0 else torch.Tensor(matching_bs[i]) + matching_as[i] = torch.cat(matching_as[i], dim=0) if len(matching_as[i]) != 0 else torch.Tensor(matching_as[i]) + matching_gjs[i] = torch.cat(matching_gjs[i], dim=0) if len(matching_gjs[i]) != 0 else torch.Tensor(matching_gjs[i]) + matching_gis[i] = torch.cat(matching_gis[i], dim=0) if len(matching_gis[i]) != 0 else torch.Tensor(matching_gis[i]) + matching_targets[i] = torch.cat(matching_targets[i], dim=0) if len(matching_targets[i]) != 0 else torch.Tensor(matching_targets[i]) + matching_anchs[i] = torch.cat(matching_anchs[i], dim=0) if len(matching_anchs[i]) != 0 else torch.Tensor(matching_anchs[i]) + + return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs def find_3_positive(self, predictions, targets): #------------------------------------#