diff --git a/mrcnn/model.py b/mrcnn/model.py index 0b485dfea5d4e6a591df07dfc8f38584721ee376..62cb2b0951a200a40c56c2db4d650b8566f191d9 100644 --- a/mrcnn/model.py +++ b/mrcnn/model.py @@ -1495,8 +1495,8 @@ def build_rpn_targets(image_shape, anchors, gt_class_ids, gt_boxes, config): anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax] rpn_match[(anchor_iou_max < 0.3) & (no_crowd_bool)] = -1 # 2. Set an anchor for each GT box (regardless of IoU value). - # TODO: If multiple anchors have the same IoU match all of them - gt_iou_argmax = np.argmax(overlaps, axis=0) + # If multiple anchors have the same IoU match all of them + gt_iou_argmax = np.argwhere(overlaps == np.max(overlaps, axis=0))[:,0] rpn_match[gt_iou_argmax] = 1 # 3. Set anchors with high overlap as positive. rpn_match[anchor_iou_max >= 0.7] = 1