diff --git a/model.py b/model.py
index 8652bd71f07c87da9ed60eb033bf05f56415de6a..b17ded1212a36dad1875570ccbf9470fcd268f4e 100644
--- a/model.py
+++ b/model.py
@@ -263,21 +263,17 @@ class ProposalLayer(KE.Layer):
     Inputs:
         rpn_probs: [batch, anchors, (bg prob, fg prob)]
         rpn_bbox: [batch, anchors, (dy, dx, log(dh), log(dw))]
+        anchors: [batch, anchors, (y1, x1, y2, x2)] anchors in normalized coordinates
 
     Returns:
         Proposals in normalized coordinates [batch, rois, (y1, x1, y2, x2)]
     """
 
-    def __init__(self, proposal_count, nms_threshold, anchors,
-                 config=None, **kwargs):
-        """
-        anchors: [N, (y1, x1, y2, x2)] anchors in normalized coordinates
-        """
+    def __init__(self, proposal_count, nms_threshold, config=None, **kwargs):
         super(ProposalLayer, self).__init__(**kwargs)
         self.config = config
         self.proposal_count = proposal_count
         self.nms_threshold = nms_threshold
-        self.anchors = anchors
 
     def call(self, inputs):
         # Box Scores. Use the foreground class confidence. [Batch, num_rois, 1]
@@ -285,25 +281,25 @@ class ProposalLayer(KE.Layer):
         # Box deltas [batch, num_rois, 4]
         deltas = inputs[1]
         deltas = deltas * np.reshape(self.config.RPN_BBOX_STD_DEV, [1, 1, 4])
-        # Base anchors
-        anchors = self.anchors
+        # Anchors
+        anchors = inputs[2]
 
         # Improve performance by trimming to top anchors by score
         # and doing the rest on the smaller subset.
-        pre_nms_limit = min(6000, self.anchors.shape[0])
+        pre_nms_limit = tf.minimum(6000, tf.shape(anchors)[1])
         ix = tf.nn.top_k(scores, pre_nms_limit, sorted=True,
                          name="top_anchors").indices
         scores = utils.batch_slice([scores, ix], lambda x, y: tf.gather(x, y),
                                    self.config.IMAGES_PER_GPU)
         deltas = utils.batch_slice([deltas, ix], lambda x, y: tf.gather(x, y),
                                    self.config.IMAGES_PER_GPU)
-        anchors = utils.batch_slice(ix, lambda x: tf.gather(anchors, x),
+        pre_nms_anchors = utils.batch_slice([anchors, ix], lambda a, x: tf.gather(a, x),
                                     self.config.IMAGES_PER_GPU,
                                     names=["pre_nms_anchors"])
 
         # Apply deltas to anchors to get refined anchors.
         # [batch, N, (y1, x1, y2, x2)]
-        boxes = utils.batch_slice([anchors, deltas],
+        boxes = utils.batch_slice([pre_nms_anchors, deltas],
                                   lambda x, y: apply_box_deltas_graph(x, y),
                                   self.config.IMAGES_PER_GPU,
                                   names=["refined_anchors"])
@@ -1847,7 +1843,7 @@ class MaskRCNN():
 
         # Inputs
         input_image = KL.Input(
-            shape=config.IMAGE_SHAPE.tolist(), name="input_image")
+            shape=[None, None, 3], name="input_image")
         input_image_meta = KL.Input(shape=[config.IMAGE_META_SIZE],
                                     name="input_image_meta")
         if mode == "training":
@@ -1879,6 +1875,9 @@
                 input_gt_masks = KL.Input(
                     shape=[config.IMAGE_SHAPE[0], config.IMAGE_SHAPE[1], None],
                     name="input_gt_masks", dtype=bool)
+        elif mode == "inference":
+            # Anchors in normalized coordinates
+            input_anchors = KL.Input(shape=[None, 4], name="input_anchors")
 
         # Build the shared convolutional layers.
         # Bottom-up Layers
@@ -1911,13 +1910,16 @@
         rpn_feature_maps = [P2, P3, P4, P5, P6]
         mrcnn_feature_maps = [P2, P3, P4, P5]
 
-        # Generate Anchors
-        backbone_shapes = compute_backbone_shapes(config, config.IMAGE_SHAPE)
-        self.anchors = utils.generate_pyramid_anchors(config.RPN_ANCHOR_SCALES,
-                                                      config.RPN_ANCHOR_RATIOS,
-                                                      backbone_shapes,
-                                                      config.BACKBONE_STRIDES,
-                                                      config.RPN_ANCHOR_STRIDE)
+        # Anchors
+        if mode == "training":
+            anchors = self.get_anchors(config.IMAGE_SHAPE)
+            # Duplicate across the batch dimension because Keras requires it
+            # TODO: can this be optimized to avoid duplicating the anchors?
+            anchors = np.broadcast_to(anchors, (config.BATCH_SIZE,) + anchors.shape)
+            # A hack to get around Keras's bad support for constants
+            anchors = KL.Lambda(lambda x: tf.constant(anchors), name="anchors")(input_image)
+        else:
+            anchors = input_anchors
 
         # RPN Model
         rpn = build_rpn_model(config.RPN_ANCHOR_STRIDE,
@@ -1937,19 +1939,16 @@
 
         rpn_class_logits, rpn_class, rpn_bbox = outputs
 
-        # Normalize anchors coordinates
-        normalized_anchors = utils.norm_boxes(self.anchors, self.config.IMAGE_SHAPE[:2])
-
         # Generate proposals
         # Proposals are [batch, N, (y1, x1, y2, x2)] in normalized coordinates
         # and zero padded.
         proposal_count = config.POST_NMS_ROIS_TRAINING if mode == "training"\
             else config.POST_NMS_ROIS_INFERENCE
-        rpn_rois = ProposalLayer(proposal_count=proposal_count,
-                                 nms_threshold=config.RPN_NMS_THRESHOLD,
-                                 name="ROI",
-                                 anchors=normalized_anchors,
-                                 config=config)([rpn_class, rpn_bbox])
+        rpn_rois = ProposalLayer(
+            proposal_count=proposal_count,
+            nms_threshold=config.RPN_NMS_THRESHOLD,
+            name="ROI",
+            config=config)([rpn_class, rpn_bbox, anchors])
 
         if mode == "training":
             # Class ID mask to mark class IDs supported by the dataset the image
@@ -2036,7 +2035,7 @@
                                               config.NUM_CLASSES,
                                               train_bn=config.TRAIN_BN)
 
-            model = KM.Model([input_image, input_image_meta],
+            model = KM.Model([input_image, input_image_meta, input_anchors],
                              [detections, mrcnn_class, mrcnn_bbox,
                                  mrcnn_mask, rpn_rois, rpn_class, rpn_bbox],
                              name='mask_rcnn')
@@ -2446,15 +2445,32 @@
             log("Processing {} images".format(len(images)))
             for image in images:
                 log("image", image)
+
+        # Validate image sizes
+        if self.config.IMAGE_RESIZE_MODE == "square":
+            image_shape = self.config.IMAGE_SHAPE
+        else:
+            # All images MUST be of the same size
+            image_shape = images[0].shape
+            for g in images[1:]:
+                assert g.shape == image_shape,\
+                    "Images must have the same size unless IMAGE_RESIZE_MODE is 'square'"
+
+        # Anchors
+        anchors = self.get_anchors(image_shape)
+        # Duplicate across the batch dimension because Keras requires it
+        # TODO: can this be optimized to avoid duplicating the anchors?
+        anchors = np.broadcast_to(anchors, (self.config.BATCH_SIZE,) + anchors.shape)
+
         # Mold inputs to format expected by the neural network
         molded_images, image_metas, windows = self.mold_inputs(images)
         if verbose:
             log("molded_images", molded_images)
             log("image_metas", image_metas)
+            log("anchors", anchors)
         # Run object detection
-        detections, mrcnn_class, mrcnn_bbox, mrcnn_mask, \
-            rois, rpn_class, rpn_bbox =\
-            self.keras_model.predict([molded_images, image_metas], verbose=0)
+        detections, _, _, mrcnn_mask, _, _, _ =\
+            self.keras_model.predict([molded_images, image_metas, anchors], verbose=0)
         # Process detections
         results = []
         for i, image in enumerate(images):
@@ -2470,6 +2486,28 @@
             })
         return results
 
+    def get_anchors(self, image_shape):
+        """Returns anchor pyramid for the given image size."""
+        backbone_shapes = compute_backbone_shapes(self.config, image_shape)
+        # Cache anchors and reuse if image shape is the same
+        if not hasattr(self, "_anchor_cache"):
+            self._anchor_cache = {}
+        if not tuple(image_shape) in self._anchor_cache:
+            # Generate Anchors
+            a = utils.generate_pyramid_anchors(
+                self.config.RPN_ANCHOR_SCALES,
+                self.config.RPN_ANCHOR_RATIOS,
+                backbone_shapes,
+                self.config.BACKBONE_STRIDES,
+                self.config.RPN_ANCHOR_STRIDE)
+            # Keep a copy of the latest anchors in pixel coordinates because
+            # it's used in inspect_model notebooks.
+            # TODO: Remove this after the notebooks are refactored to not use it
+            self.anchors = a
+            # Normalize coordinates
+            self._anchor_cache[tuple(image_shape)] = utils.norm_boxes(a, image_shape[:2])
+        return self._anchor_cache[tuple(image_shape)]
+
     def ancestor(self, tensor, name, checked=None):
        """Finds the ancestor of a TF tensor in the computation graph.
        tensor: TensorFlow symbolic tensor.
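
Usage note (not part of the patch): a minimal sketch of how the anchors now reach the network at inference time. It mirrors the updated detect() above; the names model (a MaskRCNN instance built with mode="inference") and images (a list of config.BATCH_SIZE images, assuming IMAGE_RESIZE_MODE == "square") are illustrative assumptions, not part of the change.

    import numpy as np

    # Mold inputs exactly as the updated detect() does.
    molded_images, image_metas, windows = model.mold_inputs(images)

    # With IMAGE_RESIZE_MODE == "square" every molded image has config.IMAGE_SHAPE,
    # so the anchor pyramid from get_anchors() is cached and reused across calls.
    anchors = model.get_anchors(model.config.IMAGE_SHAPE)
    # Duplicate across the batch dimension because Keras expects one entry per image.
    anchors = np.broadcast_to(anchors, (model.config.BATCH_SIZE,) + anchors.shape)

    # Anchors are now a third model input rather than a constant baked into the graph.
    detections, _, _, mrcnn_mask, _, _, _ = model.keras_model.predict(
        [molded_images, image_metas, anchors], verbose=0)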