diff --git a/nets/yolo_training.py b/nets/yolo_training.py index f59278c08c9a42b15f197d00d5d3721e672e62f1..4a4bd00e300b84b1ceb41e7fca0395f22599cbf2 100644 --- a/nets/yolo_training.py +++ b/nets/yolo_training.py @@ -384,7 +384,7 @@ class YOLOLoss(nn.Module): #-------------------------------------------# # 取出符合条件的框 #-------------------------------------------# - from_which_layer = from_which_layer[fg_mask_inboxes] + from_which_layer = from_which_layer.to(fg_mask_inboxes.device)[fg_mask_inboxes] all_b = all_b[fg_mask_inboxes] all_a = all_a[fg_mask_inboxes] all_gj = all_gj[fg_mask_inboxes]