diff --git a/official/vision/beta/modeling/retinanet_model.py b/official/vision/beta/modeling/retinanet_model.py index a5f61cb182e431adc24dba965154cac28f6c7067..bfb6e73e38298a04a9bd9e9a0df13b1c924692d6 100644 --- a/official/vision/beta/modeling/retinanet_model.py +++ b/official/vision/beta/modeling/retinanet_model.py @@ -77,6 +77,7 @@ class RetinaNetModel(tf.keras.Model): images: tf.Tensor, image_shape: Optional[tf.Tensor] = None, anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None, + output_intermediate_features: bool = False, training: bool = None) -> Mapping[str, tf.Tensor]: """Forward pass of the RetinaNet model. @@ -92,6 +93,8 @@ class RetinaNetModel(tf.keras.Model): - key: `str`, the level of the multilevel predictions. - values: `Tensor`, the anchor coordinates of a particular feature level, whose shape is [height_l, width_l, num_anchors_per_location]. + output_intermediate_features: `bool` indicating whether to return the + intermediate feature maps generated by backbone and decoder. training: `bool`, indicating whether it is in training mode. Returns: @@ -112,19 +115,26 @@ class RetinaNetModel(tf.keras.Model): feature level, whose shape is [batch, height_l, width_l, att_size * num_anchors_per_location]. """ + outputs = {} # Feature extraction. features = self.backbone(images) + if output_intermediate_features: + outputs.update( + {'backbone_{}'.format(k): v for k, v in features.items()}) if self.decoder: features = self.decoder(features) + if output_intermediate_features: + outputs.update( + {'decoder_{}'.format(k): v for k, v in features.items()}) # Dense prediction. `raw_attributes` can be empty. raw_scores, raw_boxes, raw_attributes = self.head(features) if training: - outputs = { + outputs.update({ 'cls_outputs': raw_scores, 'box_outputs': raw_boxes, - } + }) if raw_attributes: outputs.update({'attribute_outputs': raw_attributes}) return outputs @@ -145,12 +155,13 @@ class RetinaNetModel(tf.keras.Model): [tf.shape(images)[0], 1, 1, 1]) # Post-processing. - final_results = self.detection_generator( - raw_boxes, raw_scores, anchor_boxes, image_shape, raw_attributes) - outputs = { + final_results = self.detection_generator(raw_boxes, raw_scores, + anchor_boxes, image_shape, + raw_attributes) + outputs.update({ 'cls_outputs': raw_scores, 'box_outputs': raw_boxes, - } + }) if self.detection_generator.get_config()['apply_nms']: outputs.update({ 'detection_boxes': final_results['detection_boxes'], diff --git a/official/vision/beta/modeling/retinanet_model_test.py b/official/vision/beta/modeling/retinanet_model_test.py index 2f5f0119cef1e608f69af904b0875a9141ea689c..003fdb5a92adf9457947047fd17b8cbc19a44f68 100644 --- a/official/vision/beta/modeling/retinanet_model_test.py +++ b/official/vision/beta/modeling/retinanet_model_test.py @@ -147,8 +147,10 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase): ], training=[True, False], has_att_heads=[True, False], + output_intermediate_features=[True, False], )) - def test_forward(self, strategy, image_size, training, has_att_heads): + def test_forward(self, strategy, image_size, training, has_att_heads, + output_intermediate_features): """Test for creation of a R50-FPN RetinaNet.""" tf.keras.backend.set_image_data_format('channels_last') num_classes = 3 @@ -202,6 +204,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase): images, image_shape, anchor_boxes, + output_intermediate_features=output_intermediate_features, training=training) if training: @@ -247,6 +250,19 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase): self.assertAllEqual( [2, 10, 1], model_outputs['detection_attributes']['depth'].numpy().shape) + if output_intermediate_features: + for l in range(2, 6): + self.assertIn('backbone_{}'.format(l), model_outputs) + self.assertAllEqual([ + 2, image_size[0] // 2**l, image_size[1] // 2**l, + backbone.output_specs[str(l)].as_list()[-1] + ], model_outputs['backbone_{}'.format(l)].numpy().shape) + for l in range(min_level, max_level + 1): + self.assertIn('decoder_{}'.format(l), model_outputs) + self.assertAllEqual([ + 2, image_size[0] // 2**l, image_size[1] // 2**l, + decoder.output_specs[str(l)].as_list()[-1] + ], model_outputs['decoder_{}'.format(l)].numpy().shape) def test_serialize_deserialize(self): """Validate the network can be serialized and deserialized."""