diff --git a/nets/yolo_training.py b/nets/yolo_training.py index d965e42fe250dc82579137e50b34aed96b7683f2..f59278c08c9a42b15f197d00d5d3721e672e62f1 100644 --- a/nets/yolo_training.py +++ b/nets/yolo_training.py @@ -448,6 +448,7 @@ class YOLOLoss(nn.Module): # anchors_i [num_anchor, 2] #----------------------------------------------------# anchors_i = torch.from_numpy(self.anchors[i] / self.stride[i]).type_as(predictions[i]) + anchors_i, shape = torch.from_numpy(self.anchors[i] / self.stride[i]).type_as(predictions[i]), predictions[i].shape #-------------------------------------------# # 计算获得对应特征层的高宽 #-------------------------------------------# @@ -509,7 +510,7 @@ class YOLOLoss(nn.Module): # a代表属于该特征点的第几个先验框 #-------------------------------------------# a = t[:, 6].long() # anchor indices - indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1))) # image, anchor, grid indices + indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1))) # image, anchor, grid indices anchors.append(anchors_i[a]) # anchors return indices, anchors