From 0f579fe05a5c4b606b5a03627bdcef527bfcaeb3 Mon Sep 17 00:00:00 2001
From: Waleed Abdulla
Date: Sun, 12 Nov 2017 23:18:26 -0800
Subject: [PATCH] Cleanup: split GT boxes and class IDs tensors.

---
 inspect_data.ipynb  |  11 +--
 inspect_model.ipynb |  21 +++--
 model.py            | 182 ++++++++++++++++++++++++--------------------
 train_shapes.ipynb  |   9 ++-
 4 files changed, 119 insertions(+), 104 deletions(-)

diff --git a/inspect_data.ipynb b/inspect_data.ipynb
index 3de3cfd..2cfd219 100644
--- a/inspect_data.ipynb
+++ b/inspect_data.ipynb
@@ -438,11 +438,12 @@
    ],
    "source": [
     "image_id = np.random.choice(dataset.image_ids, 1)[0]\n",
-    "image, image_meta, bbox, mask = modellib.load_image_gt(\n",
+    "image, image_meta, class_ids, bbox, mask = modellib.load_image_gt(\n",
     "    dataset, config, image_id, use_mini_mask=False)\n",
     "\n",
     "log(\"image\", image)\n",
     "log(\"image_meta\", image_meta)\n",
+    "log(\"class_ids\", class_ids)\n",
     "log(\"bbox\", bbox)\n",
     "log(\"mask\", mask)\n",
     "\n",
@@ -466,7 +467,7 @@
    }
   ],
   "source": [
-    "visualize.display_instances(image, bbox[:,:4], mask, bbox[:,4], dataset.class_names)"
+    "visualize.display_instances(image, bbox, mask, class_ids, dataset.class_names)"
   ]
  },
 {
@@ -496,7 +497,7 @@
    ],
    "source": [
     "# Add augmentation and mask resizing.\n",
-    "image, image_meta, bbox, mask = modellib.load_image_gt(\n",
+    "image, image_meta, class_ids, bbox, mask = modellib.load_image_gt(\n",
     "    dataset, config, image_id, augment=True, use_mini_mask=True)\n",
     "log(\"mask\", mask)\n",
     "display_images([image]+[mask[:,:,i] for i in range(min(mask.shape[-1], 7))])"
@@ -520,7 +521,7 @@
    ],
    "source": [
     "mask = utils.expand_mask(bbox, mask, image.shape)\n",
-    "visualize.display_instances(image, bbox[:,:4], mask, bbox[:,4], dataset.class_names)"
+    "visualize.display_instances(image, bbox, mask, class_ids, dataset.class_names)"
   ]
  },
 {
@@ -634,7 +635,7 @@
     "\n",
     "# Load and draw random image\n",
     "image_id = np.random.choice(dataset.image_ids, 1)[0]\n",
-    "image, image_meta, _, _ = modellib.load_image_gt(dataset, config, image_id)\n",
+    "image, image_meta, _, _, _ = modellib.load_image_gt(dataset, config, image_id)\n",
     "fig, ax = plt.subplots(1, figsize=(10, 10))\n",
     "ax.imshow(image)\n",
     "levels = len(config.BACKBONE_SHAPES)\n",
diff --git a/inspect_model.ipynb b/inspect_model.ipynb
index 93b3a87..7da2f7e 100644
--- a/inspect_model.ipynb
+++ b/inspect_model.ipynb
@@ -313,21 +313,17 @@
    ],
    "source": [
     "image_id = random.choice(dataset.image_ids)\n",
-    "image, image_meta, gt_bbox, gt_mask =\\\n",
+    "image, image_meta, gt_class_id, gt_bbox, gt_mask =\\\n",
     "    modellib.load_image_gt(dataset, config, image_id, use_mini_mask=False)\n",
     "info = dataset.image_info[image_id]\n",
     "print(\"image ID: {}.{} ({}) {}\".format(info[\"source\"], info[\"id\"], image_id, \n",
     "                                       dataset.image_reference(image_id)))\n",
-    "gt_class_id = gt_bbox[:, 4]\n",
-    "\n",
     "# Run object detection\n",
     "results = model.detect([image], verbose=1)\n",
     "\n",
     "# Display results\n",
     "ax = get_ax(1)\n",
     "r = results[0]\n",
-    "# visualize.display_instances(image, gt_bbox[:,:4], gt_mask, gt_bbox[:,4], \n",
-    "#                             dataset.class_names, ax=ax[0], title=\"Ground Truth\")\n",
     "visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'], \n",
     "                            dataset.class_names, r['scores'], ax=ax,\n",
     "                            title=\"Predictions\")\n",
@@ -361,7 +357,7 @@
    ],
    "source": [
     "# Draw precision-recall curve\n",
-    "AP, precisions, recalls, overlaps = utils.compute_ap(gt_bbox[:,:4], gt_bbox[:,4], \n",
+    "AP, precisions, recalls, overlaps = utils.compute_ap(gt_bbox, gt_class_id, \n",
     "                                                     r['rois'], r['class_ids'], r['scores'])\n",
     "visualize.plot_precision_recall(AP, precisions, recalls)"
   ]
@@ -384,7 +380,7 @@
    ],
    "source": [
     "# Grid of ground truth objects and their predictions\n",
-    "visualize.plot_overlaps(gt_bbox[:, 4], r['class_ids'], r['scores'],\n",
+    "visualize.plot_overlaps(gt_class_id, r['class_ids'], r['scores'],\n",
     "                        overlaps, dataset.class_names)"
   ]
  },
@@ -422,7 +418,7 @@
     "    APs = []\n",
     "    for image_id in image_ids:\n",
     "        # Load image\n",
-    "        image, image_meta, gt_bbox, gt_mask =\\\n",
+    "        image, image_meta, gt_class_id, gt_bbox, gt_mask =\\\n",
     "            modellib.load_image_gt(dataset, config,\n",
     "                                   image_id, use_mini_mask=False)\n",
     "        # Run object detection\n",
@@ -430,7 +426,7 @@
     "        # Compute AP\n",
     "        r = results[0]\n",
     "        AP, precisions, recalls, overlaps =\\\n",
-    "            utils.compute_ap(gt_bbox[:,:4], gt_bbox[:,4],\n",
+    "            utils.compute_ap(gt_bbox, gt_class_id,\n",
     "                             r['rois'], r['class_ids'], r['scores'])\n",
     "        APs.append(AP)\n",
     "    return APs\n",
@@ -493,7 +489,7 @@
     "# target_rpn_match is 1 for positive anchors, -1 for negative anchors\n",
     "# and 0 for neutral anchors.\n",
     "target_rpn_match, target_rpn_bbox = modellib.build_rpn_targets(\n",
-    "    image.shape, model.anchors, gt_bbox, model.config)\n",
+    "    image.shape, model.anchors, gt_class_id, gt_bbox, model.config)\n",
     "log(\"target_rpn_match\", target_rpn_match)\n",
     "log(\"target_rpn_bbox\", target_rpn_bbox)\n",
     "\n",
@@ -568,8 +564,9 @@
     "pillar = model.keras_model.get_layer(\"ROI\").output  # node to start searching from\n",
     "\n",
     "# TF 1.4 introduces a new version of NMS. Search for both names to support TF 1.3 and 1.4\n",
-    "nms_node = (model.ancestor(pillar, \"ROI/rpn_non_max_suppression:0\")\n",
-    "            or model.ancestor(pillar, \"ROI/rpn_non_max_suppression/NonMaxSuppressionV2:0\"))\n",
+    "nms_node = model.ancestor(pillar, \"ROI/rpn_non_max_suppression:0\")\n",
+    "if nms_node is None:\n",
+    "    nms_node = model.ancestor(pillar, \"ROI/rpn_non_max_suppression/NonMaxSuppressionV2:0\")\n",
     "\n",
     "rpn = model.run_graph([image], [\n",
     "    (\"rpn_class\", model.keras_model.get_layer(\"rpn_class\").output),\n",
diff --git a/model.py b/model.py
index f21f4c3..825a789 100644
--- a/model.py
+++ b/model.py
@@ -419,15 +419,15 @@ class PyramidROIAlign(KE.Layer):
 #  Detection Target Layer
 ############################################################
 
-def detection_targets_graph(proposals, gt_boxes, gt_masks, config):
+def detection_targets_graph(proposals, gt_class_ids, gt_boxes, gt_masks, config):
     """Generates detection targets for one image. Subsamples proposals and
     generates target class IDs, bounding box deltas, and masks for each.
 
     Inputs:
     proposals: [N, (y1, x1, y2, x2)] in normalized coordinates. Might
                be zero padded if there are not enough proposals.
-    gt_boxes: [MAX_GT_INSTANCES, (y1, x1, y2, x2, class_id)] in
-              normalized coordinates.
+    gt_class_ids: [MAX_GT_INSTANCES] int class IDs
+    gt_boxes: [MAX_GT_INSTANCES, (y1, x1, y2, x2)] in normalized coordinates.
     gt_masks: [height, width, MAX_GT_INSTANCES] of boolean type.
 
     Returns: Target ROIs and corresponding class IDs, bounding box shifts,
@@ -452,6 +452,8 @@ def detection_targets_graph(proposals, gt_boxes, gt_masks, config):
     # Remove zero padding
     proposals, _ = trim_zeros_graph(proposals, name="trim_proposals")
     gt_boxes, non_zeros = trim_zeros_graph(gt_boxes, name="trim_gt_boxes")
+    gt_class_ids = tf.boolean_mask(gt_class_ids, non_zeros,
+                                   name="trim_gt_class_ids")
     gt_masks = tf.gather(gt_masks, tf.where(non_zeros)[:, 0], axis=2,
                          name="trim_gt_masks")
 
@@ -461,12 +463,12 @@ def detection_targets_graph(proposals, gt_boxes, gt_masks, config):
     # allows us to compare every ROI against every GT box without loops.
     # TF doesn't have an equivalent to np.repeate() so simulate it
     # using tf.tile() and tf.reshape.
-    rois = tf.reshape(tf.tile(tf.expand_dims(proposals, 1), 
+    rois = tf.reshape(tf.tile(tf.expand_dims(proposals, 1),
                               [1, 1, tf.shape(gt_boxes)[0]]), [-1, 4])
     boxes = tf.tile(gt_boxes, [tf.shape(proposals)[0], 1])
     # 2. Compute intersections
     roi_y1, roi_x1, roi_y2, roi_x2 = tf.split(rois, 4, axis=1)
-    box_y1, box_x1, box_y2, box_x2, class_ids = tf.split(boxes, 5, axis=1)
+    box_y1, box_x1, box_y2, box_x2 = tf.split(boxes, 4, axis=1)
     y1 = tf.maximum(roi_y1, box_y1)
     x1 = tf.maximum(roi_x1, box_x1)
     y2 = tf.minimum(roi_y2, box_y2)
@@ -503,9 +505,10 @@ def detection_targets_graph(proposals, gt_boxes, gt_masks, config):
     positive_overlaps = tf.gather(overlaps, positive_indices)
     roi_gt_box_assignment = tf.argmax(positive_overlaps, axis=1)
     roi_gt_boxes = tf.gather(gt_boxes, roi_gt_box_assignment)
+    roi_gt_class_ids = tf.gather(gt_class_ids, roi_gt_box_assignment)
 
     # Compute bbox refinement for positive ROIs
-    deltas = utils.box_refinement_graph(positive_rois, roi_gt_boxes[:,:4])
+    deltas = utils.box_refinement_graph(positive_rois, roi_gt_boxes)
     deltas /= config.BBOX_STD_DEV
 
     # Assign positive ROIs to GT masks
@@ -520,7 +523,7 @@ def detection_targets_graph(proposals, gt_boxes, gt_masks, config):
         # Transform ROI corrdinates from normalized image space
         # to normalized mini-mask space.
         y1, x1, y2, x2 = tf.split(positive_rois, 4, axis=1)
-        gt_y1, gt_x1, gt_y2, gt_x2, _ = tf.split(roi_gt_boxes, 5, axis=1)
+        gt_y1, gt_x1, gt_y2, gt_x2 = tf.split(roi_gt_boxes, 4, axis=1)
         gt_h = gt_y2 - gt_y1
         gt_w = gt_x2 - gt_x1
         y1 = (y1 - gt_y1) / gt_h
@@ -546,10 +549,11 @@ def detection_targets_graph(proposals, gt_boxes, gt_masks, config):
     P = tf.maximum(config.TRAIN_ROIS_PER_IMAGE - tf.shape(rois)[0], 0)
     rois = tf.pad(rois, [(0, P), (0, 0)])
     roi_gt_boxes = tf.pad(roi_gt_boxes, [(0, N+P), (0, 0)])
+    roi_gt_class_ids = tf.pad(roi_gt_class_ids, [(0, N+P)])
     deltas = tf.pad(deltas, [(0, N+P), (0, 0)])
     masks = tf.pad(masks, [[0, N+P], (0, 0), (0, 0)])
 
-    return rois, roi_gt_boxes[:, 4], deltas, masks
+    return rois, roi_gt_class_ids, deltas, masks
 
 
 class DetectionTargetLayer(KE.Layer):
@@ -559,8 +563,9 @@ class DetectionTargetLayer(KE.Layer):
     Inputs:
     proposals: [batch, N, (y1, x1, y2, x2)] in normalized coordinates. Might
               be zero padded if there are not enough proposals.
-    gt_boxes: [batch, MAX_GT_INSTANCES, (y1, x1, y2, x2, class_id)] in
-              normalized coordinates.
+    gt_class_ids: [batch, MAX_GT_INSTANCES] Integer class IDs.
+    gt_boxes: [batch, MAX_GT_INSTANCES, (y1, x1, y2, x2)] in normalized
+              coordinates.
     gt_masks: [batch, height, width, MAX_GT_INSTANCES] of boolean type
 
     Returns: Target ROIs and corresponding class IDs, bounding box shifts,
@@ -583,16 +588,16 @@
     def call(self, inputs):
         proposals = inputs[0]
-        gt_boxes = inputs[1]
-        gt_masks = inputs[2]
+        gt_class_ids = inputs[1]
+        gt_boxes = inputs[2]
+        gt_masks = inputs[3]
 
         # Slice the batch and run a graph for each slice
-        # TODO: Optimize by supporting batch > 1
         # TODO: Rename target_bbox to target_deltas for clarity
         names = ["rois", "target_class_ids", "target_bbox", "target_mask"]
         outputs = utils.batch_slice(
-            [proposals, gt_boxes, gt_masks],
-            lambda x, y, z: detection_targets_graph(x, y, z, self.config),
+            [proposals, gt_class_ids, gt_boxes, gt_masks],
+            lambda w, x, y, z: detection_targets_graph(w, x, y, z, self.config),
             self.config.IMAGES_PER_GPU, names=names)
         return outputs
 
 
@@ -1107,7 +1112,8 @@ def load_image_gt(dataset, config, image_id, augment=False,
     Returns:
     image: [height, width, 3]
     shape: the original shape of the image before resizing and cropping.
-    bbox: [instance_count, (y1, x1, y2, x2, class_id)]
+    class_ids: [instance_count] Integer class IDs
+    bbox: [instance_count, (y1, x1, y2, x2)]
     mask: [height, width, instance_count]. The height and width are those
         of the image unless use_mini_mask is True, in which case they are
         defined in MINI_MASK_SHAPE.
@@ -1134,15 +1140,12 @@ def load_image_gt(dataset, config, image_id, augment=False,
     # bbox: [num_instances, (y1, x1, y2, x2)]
     bbox = utils.extract_bboxes(mask)
 
-    # Add class_id as the last value in bbox
-    bbox = np.hstack([bbox, class_ids[:, np.newaxis]])
-
     # Active classes
     # Different datasets have different classes, so track the
     # classes supported in the dataset of this image.
     active_class_ids = np.zeros([dataset.num_classes], dtype=np.int32)
-    class_ids = dataset.source_class_ids[dataset.image_info[image_id]["source"]]
-    active_class_ids[class_ids] = 1
+    source_class_ids = dataset.source_class_ids[dataset.image_info[image_id]["source"]]
+    active_class_ids[source_class_ids] = 1
 
     # Resize masks to smaller size to reduce memory usage
     if use_mini_mask:
@@ -1151,27 +1154,31 @@ def load_image_gt(dataset, config, image_id, augment=False,
     # Image meta data
     image_meta = compose_image_meta(image_id, shape, window, active_class_ids)
-    return image, image_meta, bbox, mask
+    return image, image_meta, class_ids, bbox, mask
 
 
-def build_detection_targets(rpn_rois, gt_boxes, gt_masks, config):
+def build_detection_targets(rpn_rois, gt_class_ids, gt_boxes, gt_masks, config):
     """Generate targets for training Stage 2 classifier and mask heads.
+    This is not used in normal training. It's useful for debugging or to train
+    the Mask RCNN heads without using the RPN head.
 
     Inputs:
     rpn_rois: [N, (y1, x1, y2, x2)] proposal boxes.
-    gt_boxes: [instance count, (y1, x1, y2, x2, class_id)]
+    gt_class_ids: [instance count] Integer class IDs
+    gt_boxes: [instance count, (y1, x1, y2, x2)]
     gt_masks: [height, width, instance count] Grund truth masks. Can be full
               size or mini-masks.
 
     Returns:
     rois: [TRAIN_ROIS_PER_IMAGE, (y1, x1, y2, x2)]
-    class_ids: [TRAIN_ROIS_PER_IMAGE]. Int class IDs.
-    bboxes: [TRAIN_ROIS_PER_IMAGE, NUM_CLASSES, 5]. Rows are class-specific
-        bbox refinments [y, x, log(h), log(w), weight].
+    class_ids: [TRAIN_ROIS_PER_IMAGE]. Integer class IDs.
+    bboxes: [TRAIN_ROIS_PER_IMAGE, NUM_CLASSES, (y, x, log(h), log(w))]. Class-specific
+        bbox refinements.
     masks: [TRAIN_ROIS_PER_IMAGE, height, width, NUM_CLASSES).
         Class specific masks cropped to bbox boundaries and resized to
         neural network output size.
     """
     assert rpn_rois.shape[0] > 0
+    assert gt_class_ids.dtype == np.int32, "Expected int but got {}".format(gt_class_ids.dtype)
     assert gt_boxes.dtype == np.int32, "Expected int but got {}".format(gt_boxes.dtype)
     assert gt_masks.dtype == np.bool_, "Expected bool but got {}".format(gt_masks.dtype)
 
@@ -1179,8 +1186,9 @@ def build_detection_targets(rpn_rois, gt_boxes, gt_masks, config):
     # according to XinLei Chen's paper, it doesn't help.
 
     # Trim empty padding in gt_boxes and gt_masks parts
-    instance_ids = np.where(gt_boxes[:, 4] > 0)[0]
+    instance_ids = np.where(gt_class_ids > 0)[0]
     assert instance_ids.shape[0] > 0, "Image must contain instances."
+    gt_class_ids = gt_class_ids[instance_ids]
     gt_boxes = gt_boxes[instance_ids]
     gt_masks = gt_masks[:, :, instance_ids]
 
@@ -1191,15 +1199,16 @@ def build_detection_targets(rpn_rois, gt_boxes, gt_masks, config):
     # Compute overlaps [rpn_rois, gt_boxes]
     overlaps = np.zeros((rpn_rois.shape[0], gt_boxes.shape[0]))
     for i in range(overlaps.shape[1]):
-        gt = gt_boxes[i][:4]
-        overlaps[:,i] = utils.compute_iou(gt, rpn_rois, gt_box_area[i], rpn_roi_area)
+        gt = gt_boxes[i]
+        overlaps[:, i] = utils.compute_iou(gt, rpn_rois, gt_box_area[i], rpn_roi_area)
 
     # Assign ROIs to GT boxes
     rpn_roi_iou_argmax = np.argmax(overlaps, axis=1)
     rpn_roi_iou_max = overlaps[np.arange(overlaps.shape[0]), rpn_roi_iou_argmax]
     rpn_roi_gt_boxes = gt_boxes[rpn_roi_iou_argmax]  # GT box assigned to each ROI
+    rpn_roi_gt_class_ids = gt_class_ids[rpn_roi_iou_argmax]
 
-    # Positive ROIs are those with >= 0.5 IoU with a GT box. 
+    # Positive ROIs are those with >= 0.5 IoU with a GT box.
     fg_ids = np.where(rpn_roi_iou_max > 0.5)[0]
 
     # Negative ROIs are those with max IoU 0.1-0.5 (hard example mining)
@@ -1247,36 +1256,35 @@ def build_detection_targets(rpn_rois, gt_boxes, gt_masks, config):
 
     # Reset the gt boxes assigned to BG ROIs.
     rpn_roi_gt_boxes[keep_bg_ids, :] = 0
+    rpn_roi_gt_class_ids[keep_bg_ids] = 0
 
     # For each kept ROI, assign a class_id, and for FG ROIs also add bbox refinement.
-    rois = rpn_rois[keep, :4]
+    rois = rpn_rois[keep]
    roi_gt_boxes = rpn_roi_gt_boxes[keep]
-    class_ids = roi_gt_boxes[:,4].astype(np.int32)
+    roi_gt_class_ids = rpn_roi_gt_class_ids[keep]
     roi_gt_assignment = rpn_roi_iou_argmax[keep]
 
-    # Class-aware bbox shifts. [y, x, log(h), log(w), weight]. Weight is 0 or 1 to
-    # determine if a bbox is included in the loss.
-    bboxes = np.zeros((config.TRAIN_ROIS_PER_IMAGE, config.NUM_CLASSES, 5), dtype=np.float32)
-    pos_ids = np.where(class_ids > 0)[0]
-    bboxes[pos_ids, class_ids[pos_ids], :4] = utils.box_refinement(rois[pos_ids], roi_gt_boxes[pos_ids, :4])
-    bboxes[pos_ids, class_ids[pos_ids], 4] = 1  # weight = 1 to influence the loss
+    # Class-aware bbox deltas. [y, x, log(h), log(w)]
+    bboxes = np.zeros((config.TRAIN_ROIS_PER_IMAGE, config.NUM_CLASSES, 4), dtype=np.float32)
+    pos_ids = np.where(roi_gt_class_ids > 0)[0]
+    bboxes[pos_ids, roi_gt_class_ids[pos_ids]] = utils.box_refinement(rois[pos_ids], roi_gt_boxes[pos_ids, :4])
 
     # Normalize bbox refinments
-    bboxes[:, :, :4] /= config.BBOX_STD_DEV
+    bboxes /= config.BBOX_STD_DEV
 
     # Generate class-specific target masks.
     masks = np.zeros((config.TRAIN_ROIS_PER_IMAGE, config.MASK_SHAPE[0], config.MASK_SHAPE[1], config.NUM_CLASSES), dtype=np.float32)
     for i in pos_ids:
-        class_id = class_ids[i]
+        class_id = roi_gt_class_ids[i]
         assert class_id > 0, "class id must be greater than 0"
         gt_id = roi_gt_assignment[i]
         class_mask = gt_masks[:, :, gt_id]
-        
+
         if config.USE_MINI_MASK:
             # Create a mask placeholder, the size of the image
             placeholder = np.zeros(config.IMAGE_SHAPE[:2], dtype=bool)
             # GT box
-            gt_y1, gt_x1, gt_y2, gt_x2 = gt_boxes[gt_id][:4]
+            gt_y1, gt_x1, gt_y2, gt_x2 = gt_boxes[gt_id]
             gt_w = gt_x2 - gt_x1
             gt_h = gt_y2 - gt_y1
             # Resize mini mask to size of GT box
@@ -1285,22 +1293,23 @@ def build_detection_targets(rpn_rois, gt_boxes, gt_masks, config):
                                          interp='nearest') / 255.0).astype(bool)
             # Place the mini batch in the placeholder
             class_mask = placeholder
-        
+
         # Pick part of the mask and resize it
-        y1, x1, y2, x2 = rois[i][:4].astype(np.int32)
+        y1, x1, y2, x2 = rois[i].astype(np.int32)
         m = class_mask[y1:y2, x1:x2]
         mask = scipy.misc.imresize(m.astype(float), config.MASK_SHAPE, interp='nearest') / 255.0
-        masks[i,:,:,class_id] = mask
-    
-    return rois, class_ids, bboxes, masks
+        masks[i, :, :, class_id] = mask
+
+    return rois, roi_gt_class_ids, bboxes, masks
 
 
-def build_rpn_targets(image_shape, anchors, gt_boxes, config):
+def build_rpn_targets(image_shape, anchors, gt_class_ids, gt_boxes, config):
     """Given the anchors and GT boxes, compute overlaps and identify positive
     anchors and deltas to refine them to match their corresponding GT boxes.
 
     anchors: [num_anchors, (y1, x1, y2, x2)]
-    gt_boxes: [num_gt_boxes, (y1, x1, y2, x2, class_id)]
+    gt_class_ids: [num_gt_boxes] Integer class IDs.
+    gt_boxes: [num_gt_boxes, (y1, x1, y2, x2)]
 
     Returns:
     rpn_match: [N] (int32) matches between anchors and GT boxes.
@@ -1320,7 +1329,7 @@ def build_rpn_targets(image_shape, anchors, gt_boxes, config):
     # Each cell contains the IoU of an anchor and GT box.
     overlaps = np.zeros((anchors.shape[0], gt_boxes.shape[0]))
     for i in range(overlaps.shape[1]):
-        gt = gt_boxes[i][:4]
+        gt = gt_boxes[i]
         overlaps[:, i] = utils.compute_iou(gt, anchors, gt_box_area[i], anchor_area)
 
     # Match anchors to GT Boxes
@@ -1365,7 +1374,7 @@ def build_rpn_targets(image_shape, anchors, gt_boxes, config):
     # TODO: use box_refinment() rather than duplicating the code here
     for i, a in zip(ids, anchors[ids]):
         # Closest gt box (it might have IoU < 0.7)
-        gt = gt_boxes[anchor_iou_argmax[i], :4]
+        gt = gt_boxes[anchor_iou_argmax[i]]
 
         # Convert coordinates to center plus width/height.
         # GT Box
@@ -1393,23 +1402,24 @@ def build_rpn_targets(image_shape, anchors, gt_boxes, config):
     return rpn_match, rpn_bbox
 
 
-def generate_random_rois(image_shape, count, gt_boxes):
+def generate_random_rois(image_shape, count, gt_class_ids, gt_boxes):
     """Generates ROI proposals similar to what a region proposal network
     would generate.
 
     image_shape: [Height, Width, Depth]
     count: Number of ROIs to generate
-    gt_boxes: [N, (y1, x1, y2, x2, class_id)] Ground trugh boxes in pixels.
+    gt_class_ids: [N] Integer ground truth class IDs
+    gt_boxes: [N, (y1, x1, y2, x2)] Ground truth boxes in pixels.
 
     Returns: [count, (y1, x1, y2, x2)] ROI boxes in pixels.
""" # placeholder rois = np.zeros((count, 4), dtype=np.int32) - + # Generate random ROIs around GT boxes (90% of count) rois_per_box = int(0.9 * count / gt_boxes.shape[0]) for i in range(gt_boxes.shape[0]): - gt_y1, gt_x1, gt_y2, gt_x2 = gt_boxes[i,:4] + gt_y1, gt_x1, gt_y2, gt_x2 = gt_boxes[i] h = gt_y2 - gt_y1 w = gt_x2 - gt_x1 # random boundaries @@ -1417,7 +1427,7 @@ def generate_random_rois(image_shape, count, gt_boxes): r_y2 = min(gt_y2+h, image_shape[0]) r_x1 = max(gt_x1-w, 0) r_x2 = min(gt_x2+w, image_shape[1]) - + # To avoid generating boxes with zero area, we generate double what # we need and filter out the extra. If we get fewer valid boxes # than we need, we loop and try again. @@ -1430,14 +1440,14 @@ def generate_random_rois(image_shape, count, gt_boxes): x1x2 = x1x2[np.abs(x1x2[:,0] - x1x2[:,1]) >= threshold][:rois_per_box] if y1y2.shape[0] == rois_per_box and x1x2.shape[0] == rois_per_box: break - + # Sort on axis 1 to ensure x1 <= x2 and y1 <= y2 and then reshape # into x1, y1, x2, y2 order x1, x2 = np.split(np.sort(x1x2, axis=1), 2, axis=1) y1, y2 = np.split(np.sort(y1y2, axis=1), 2, axis=1) box_rois = np.hstack([y1, x1, y2, x2]) rois[rois_per_box*i:rois_per_box*(i+1)] = box_rois - + # Generate random ROIs anywhere in the image (10% of count) remaining_count = count - (rois_per_box * gt_boxes.shape[0]) # To avoid generating boxes with zero area, we generate double what @@ -1452,7 +1462,7 @@ def generate_random_rois(image_shape, count, gt_boxes): x1x2 = x1x2[np.abs(x1x2[:,0] - x1x2[:,1]) >= threshold][:remaining_count] if y1y2.shape[0] == remaining_count and x1x2.shape[0] == remaining_count: break - + # Sort on axis 1 to ensure x1 <= x2 and y1 <= y2 and then reshape # into x1, y1, x2, y2 order x1, x2 = np.split(np.sort(x1x2, axis=1), 2, axis=1) @@ -1488,7 +1498,8 @@ def data_generator(dataset, config, shuffle=True, augment=True, random_rois=0, - image_meta: [batch, size of image meta] - rpn_match: [batch, N] Integer (1=positive anchor, -1=negative, 0=neutral) - rpn_bbox: [batch, N, (dy, dx, log(dh), log(dw))] Anchor bbox deltas. - - gt_boxes: [batch, MAX_GT_INSTANCES, (y1, x1, y2, x2, class_id)] + - gt_class_ids: [batch, MAX_GT_INSTANCES] Integer class IDs + - gt_boxes: [batch, MAX_GT_INSTANCES, (y1, x1, y2, x2)] - gt_masks: [batch, height, width, MAX_GT_INSTANCES]. The height and width are those of the image unless use_mini_mask is True, in which case they are defined in MINI_MASK_SHAPE. @@ -1520,26 +1531,25 @@ def data_generator(dataset, config, shuffle=True, augment=True, random_rois=0, # Get GT bounding boxes and masks for image. image_id = image_ids[image_index] - image, image_meta, gt_boxes, gt_masks = \ + image, image_meta, gt_class_ids, gt_boxes, gt_masks = \ load_image_gt(dataset, config, image_id, augment=augment, use_mini_mask=config.USE_MINI_MASK) - # Skip images that have no instances. This can happen in cases + # Skip images that have no instances. This can happen in cases # where we train on a subset of classes and the image doesn't # have any of the classes we care about. 
-            if not np.any(gt_boxes):
+            if not np.any(gt_class_ids):
                 continue
 
             # RPN Targets
-            rpn_match, rpn_bbox = build_rpn_targets(image.shape, anchors, gt_boxes, config)
+            rpn_match, rpn_bbox = build_rpn_targets(image.shape, anchors,
+                                                    gt_class_ids, gt_boxes, config)
 
             # Mask R-CNN Targets
             if random_rois:
-                rpn_rois = generate_random_rois(image.shape, random_rois, gt_boxes)
+                rpn_rois = generate_random_rois(image.shape, random_rois, gt_class_ids, gt_boxes)
                 if detection_targets:
-                    # Append two columns of zeros. TODO: needed?
-                    rpn_rois = np.hstack([rpn_rois, np.zeros([rpn_rois.shape[0], 2], dtype=np.int32)])
                     rois, mrcnn_class_ids, mrcnn_bbox, mrcnn_mask =\
-                        build_detection_targets(rpn_rois, gt_boxes, gt_masks, config)
+                        build_detection_targets(rpn_rois, gt_class_ids, gt_boxes, gt_masks, config)
 
             # Init batch arrays
             if b == 0:
@@ -1547,7 +1557,8 @@ def data_generator(dataset, config, shuffle=True, augment=True, random_rois=0,
                 batch_rpn_match = np.zeros([batch_size, anchors.shape[0], 1], dtype=rpn_match.dtype)
                 batch_rpn_bbox = np.zeros([batch_size, config.RPN_TRAIN_ANCHORS_PER_IMAGE, 4], dtype=rpn_bbox.dtype)
                 batch_images = np.zeros((batch_size,)+image.shape, dtype=np.float32)
-                batch_gt_boxes = np.zeros((batch_size, config.MAX_GT_INSTANCES, 5), dtype=np.int32)
+                batch_gt_class_ids = np.zeros((batch_size, config.MAX_GT_INSTANCES), dtype=np.int32)
+                batch_gt_boxes = np.zeros((batch_size, config.MAX_GT_INSTANCES, 4), dtype=np.int32)
                 if config.USE_MINI_MASK:
                     batch_gt_masks = np.zeros((batch_size, config.MINI_MASK_SHAPE[0], config.MINI_MASK_SHAPE[1],
                                                config.MAX_GT_INSTANCES))
@@ -1564,6 +1575,7 @@ def data_generator(dataset, config, shuffle=True, augment=True, random_rois=0,
                 # If more instances than fits in the array, sub-sample from them.
                 if gt_boxes.shape[0] > config.MAX_GT_INSTANCES:
                     ids = np.random.choice(np.arange(gt_boxes.shape[0]), config.MAX_GT_INSTANCES, replace=False)
+                    gt_class_ids = gt_class_ids[ids]
                     gt_boxes = gt_boxes[ids]
                     gt_masks = gt_masks[:,:,ids]
 
@@ -1572,10 +1584,11 @@ def data_generator(dataset, config, shuffle=True, augment=True, random_rois=0,
             batch_rpn_match[b] = rpn_match[:, np.newaxis]
             batch_rpn_bbox[b] = rpn_bbox
             batch_images[b] = mold_image(image.astype(np.float32), config)
+            batch_gt_class_ids[b,:gt_class_ids.shape[0]] = gt_class_ids
             batch_gt_boxes[b,:gt_boxes.shape[0]] = gt_boxes
             batch_gt_masks[b,:,:,:gt_masks.shape[-1]] = gt_masks
             if random_rois:
-                batch_rpn_rois[b] = rpn_rois[:,:4]
+                batch_rpn_rois[b] = rpn_rois
                 if detection_targets:
                     batch_rois[b] = rois
                     batch_mrcnn_class_ids[b] = mrcnn_class_ids
@@ -1586,7 +1599,7 @@ def data_generator(dataset, config, shuffle=True, augment=True, random_rois=0,
             # Batch full?
             if b >= batch_size:
                 inputs = [batch_images, batch_image_meta, batch_rpn_match, batch_rpn_bbox,
-                          batch_gt_boxes, batch_gt_masks]
+                          batch_gt_class_ids, batch_gt_boxes, batch_gt_masks]
                 outputs = []
 
                 if random_rois:
@@ -1621,7 +1634,6 @@ class MaskRCNN():
     The actual Keras model is in the keras_model property.
     """
 
-
     def __init__(self, mode, config, model_dir):
         """
         mode: Either "training" or "inference"
@@ -1657,14 +1669,18 @@ class MaskRCNN():
             # RPN GT
             input_rpn_match = KL.Input(shape=[None, 1], name="input_rpn_match", dtype=tf.int32)
             input_rpn_bbox = KL.Input(shape=[None, 4], name="input_rpn_bbox", dtype=tf.float32)
-            # GT Boxes (zero padded)
-            # [batch, MAX_GT_INSTANCES, (y1, x1, y2, x2, class_id)] in image coordinates
-            input_gt_boxes = KL.Input(shape=[None, 5], name="input_gt_boxes", dtype=tf.int32)
+
+            # Detection GT (class IDs, bounding boxes, and masks)
+            # 1. GT Class IDs (zero padded)
+            input_gt_class_ids = KL.Input(shape=[None], name="input_gt_class_ids", dtype=tf.int32)
+            # 2. GT Boxes in pixels (zero padded)
+            # [batch, MAX_GT_INSTANCES, (y1, x1, y2, x2)] in image coordinates
+            input_gt_boxes = KL.Input(shape=[None, 4], name="input_gt_boxes", dtype=tf.float32)
             # Normalize coordinates
             h, w = K.shape(input_image)[1], K.shape(input_image)[2]
-            image_scale = K.cast(K.stack([h, w, h, w, 1], axis=0), tf.float32)
-            gt_boxes = KL.Lambda(lambda x: K.cast(x, tf.float32) / image_scale)(input_gt_boxes)
-            # GT Masks (zero padded)
+            image_scale = K.cast(K.stack([h, w, h, w], axis=0), tf.float32)
+            gt_boxes = KL.Lambda(lambda x: x / image_scale)(input_gt_boxes)
+            # 3. GT Masks (zero padded)
             # [batch, height, width, MAX_GT_INSTANCES]
             if config.USE_MINI_MASK:
                 input_gt_masks = KL.Input(
@@ -1757,11 +1773,11 @@ class MaskRCNN():
 
             # Generate detection targets
             # Subsamples proposals and generates target outputs for training
-            # Note that proposals, gt_boxes, and gt_masks might be zero padded
-            # Equally, returned rois and targets might be zero padded as well
+            # Note that proposal class IDs, gt_boxes, and gt_masks are zero
+            # padded. Equally, returned rois and targets are zero padded.
             rois, target_class_ids, target_bbox, target_mask =\
                 DetectionTargetLayer(config, name="proposal_targets")([
-                    target_rois, gt_boxes, input_gt_masks])
+                    target_rois, input_gt_class_ids, gt_boxes, input_gt_masks])
 
             # Network Heads
             # TODO: verify that this handles zero padded ROIs
@@ -1791,7 +1807,7 @@ class MaskRCNN():
 
             # Model
             inputs = [input_image, input_image_meta,
-                      input_rpn_match, input_rpn_bbox, input_gt_boxes, input_gt_masks]
+                      input_rpn_match, input_rpn_bbox, input_gt_class_ids, input_gt_boxes, input_gt_masks]
             if not config.USE_RPN_ROIS:
                 inputs.append(input_rois)
             outputs = [rpn_class_logits, rpn_class, rpn_bbox,
diff --git a/train_shapes.ipynb b/train_shapes.ipynb
index e8e5365..82b643c 100644
--- a/train_shapes.ipynb
+++ b/train_shapes.ipynb
@@ -907,16 +907,17 @@
    "source": [
     "# Test on a random image\n",
     "image_id = random.choice(dataset_val.image_ids)\n",
-    "original_image, image_meta, gt_bbox, gt_mask =\\\n",
+    "original_image, image_meta, gt_class_id, gt_bbox, gt_mask =\\\n",
     "    modellib.load_image_gt(dataset_val, inference_config, \n",
     "                           image_id, use_mini_mask=False)\n",
     "\n",
     "log(\"original_image\", original_image)\n",
     "log(\"image_meta\", image_meta)\n",
+    "log(\"gt_class_id\", gt_class_id)\n",
     "log(\"gt_bbox\", gt_bbox)\n",
     "log(\"gt_mask\", gt_mask)\n",
     "\n",
-    "visualize.display_instances(original_image, gt_bbox[:,:4], gt_mask, gt_bbox[:,4], \n",
+    "visualize.display_instances(original_image, gt_bbox, gt_mask, gt_class_id, \n",
     "                            dataset_train.class_names, figsize=(8, 8))"
   ]
  },
@@ -981,7 +982,7 @@
     "APs = []\n",
     "for image_id in image_ids:\n",
     "    # Load image and ground truth data\n",
-    "    image, image_meta, gt_bbox, gt_mask =\\\n",
+    "    image, image_meta, gt_class_id, gt_bbox, gt_mask =\\\n",
     "        modellib.load_image_gt(dataset_val, inference_config,\n",
     "                               image_id, use_mini_mask=False)\n",
     "    molded_images = np.expand_dims(modellib.mold_image(image, inference_config), 0)\n",
@@ -990,7 +991,7 @@
     "    r = results[0]\n",
     "    # Compute AP\n",
     "    AP, precisions, recalls, overlaps =\\\n",
-    "        utils.compute_ap(gt_bbox[:,:4], gt_bbox[:,4],\n",
+    "        utils.compute_ap(gt_bbox, gt_class_id,\n",
     "                         r[\"rois\"], r[\"class_ids\"], r[\"scores\"])\n",
     "    APs.append(AP)\n",
     "    \n",
-- 
GitLab
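
A minimal usage sketch of the calling convention after this patch, placed after
the signature line so it is ignored when the patch is applied. It assumes
`dataset`, `config`, `image_id`, `anchors`, and detection results `r` are
already set up, as in inspect_model.ipynb; nothing here is part of the patch
itself, only an illustration of the changed signatures above.

    import model as modellib
    import utils

    # Ground truth is now split: class IDs are a separate [N] int array and
    # bbox is [N, 4], replacing the old combined [N, (y1, x1, y2, x2, class_id)].
    image, image_meta, gt_class_id, gt_bbox, gt_mask = modellib.load_image_gt(
        dataset, config, image_id, use_mini_mask=False)

    # RPN targets take the class IDs and boxes as separate arguments.
    rpn_match, rpn_bbox = modellib.build_rpn_targets(
        image.shape, anchors, gt_class_id, gt_bbox, config)

    # Evaluation passes boxes and class IDs directly, instead of slicing
    # gt_bbox[:, :4] and gt_bbox[:, 4] from the combined tensor.
    AP, precisions, recalls, overlaps = utils.compute_ap(
        gt_bbox, gt_class_id, r['rois'], r['class_ids'], r['scores'])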