提交 cdf815f8 编写于 作者: A A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 366052172
上级 eb6687ac
......@@ -16,7 +16,7 @@
"""RetinaNet configuration definition."""
import os
from typing import List, Optional
from typing import Dict, List, Optional, Tuple
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
......@@ -93,6 +93,7 @@ class RetinaNetHead(hyperparams.Config):
num_convs: int = 4
num_filters: int = 256
use_separable_conv: bool = False
attribute_heads: Optional[Dict[str, Tuple[str, int]]] = None
@dataclasses.dataclass
......
......@@ -221,6 +221,7 @@ def build_retinanet(input_specs: tf.keras.layers.InputSpec,
num_anchors_per_location=num_anchors_per_location,
num_convs=head_config.num_convs,
num_filters=head_config.num_filters,
attribute_heads=head_config.attribute_heads,
use_separable_conv=head_config.use_separable_conv,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
......
......@@ -32,6 +32,7 @@ class RetinaNetHead(tf.keras.layers.Layer):
num_anchors_per_location,
num_convs=4,
num_filters=256,
attribute_heads=None,
use_separable_conv=False,
activation='relu',
use_sync_bn=False,
......@@ -52,6 +53,9 @@ class RetinaNetHead(tf.keras.layers.Layer):
conv layers before the prediction.
num_filters: An `int` number that represents the number of filters of the
intermediate conv layers.
attribute_heads: If not None, a dict that contains
(attribute_name, attribute_config) for additional attribute heads.
`attribute_config` is a tuple of (attribute_type, attribute_size).
use_separable_conv: A `bool` that indicates whether the separable
convolution layers is used.
activation: A `str` that indicates which activation is used, e.g. 'relu',
......@@ -73,6 +77,7 @@ class RetinaNetHead(tf.keras.layers.Layer):
'num_anchors_per_location': num_anchors_per_location,
'num_convs': num_convs,
'num_filters': num_filters,
'attribute_heads': attribute_heads,
'use_separable_conv': use_separable_conv,
'activation': activation,
'use_sync_bn': use_sync_bn,
......@@ -174,6 +179,64 @@ class RetinaNetHead(tf.keras.layers.Layer):
})
self._box_regressor = conv_op(name='boxes', **box_regressor_kwargs)
# Attribute learning nets.
if self._config_dict['attribute_heads']:
self._att_predictors = {}
self._att_convs = {}
self._att_norms = {}
for att_name, att_head in self._config_dict['attribute_heads'].items():
att_convs_i = []
att_norms_i = []
att_type = att_head[0]
att_size = att_head[1]
# Build conv and norm layers.
for level in range(self._config_dict['min_level'],
self._config_dict['max_level'] + 1):
this_level_att_norms = []
for i in range(self._config_dict['num_convs']):
if level == self._config_dict['min_level']:
att_conv_name = '{}-conv_{}'.format(att_name, i)
att_convs_i.append(conv_op(name=att_conv_name, **conv_kwargs))
att_norm_name = '{}-conv-norm_{}_{}'.format(att_name, level, i)
this_level_att_norms.append(bn_op(name=att_norm_name, **bn_kwargs))
att_norms_i.append(this_level_att_norms)
self._att_convs[att_name] = att_convs_i
self._att_norms[att_name] = att_norms_i
# Build the final prediction layer.
att_predictor_kwargs = {
'filters':
(att_size * self._config_dict['num_anchors_per_location']),
'kernel_size': 3,
'padding': 'same',
'bias_initializer': tf.zeros_initializer(),
'bias_regularizer': self._config_dict['bias_regularizer'],
}
if att_type == 'regression':
att_predictor_kwargs.update(
{'bias_initializer': tf.zeros_initializer()})
elif att_type == 'classification':
att_predictor_kwargs.update({
'bias_initializer':
tf.constant_initializer(-np.log((1 - 0.01) / 0.01))
})
else:
raise ValueError(
'Attribute head type {} not supported.'.format(att_type))
if not self._config_dict['use_separable_conv']:
att_predictor_kwargs.update({
'kernel_initializer':
tf.keras.initializers.RandomNormal(stddev=1e-5),
'kernel_regularizer':
self._config_dict['kernel_regularizer'],
})
self._att_predictors[att_name] = conv_op(
name='{}_attributes'.format(att_name), **att_predictor_kwargs)
super(RetinaNetHead, self).build(input_shape)
def call(self, features):
......@@ -197,9 +260,25 @@ class RetinaNetHead(tf.keras.layers.Layer):
- values: A `tf.Tensor` of the box scores predicted from a particular
feature level, whose shape is
[batch, height_l, width_l, 4 * num_anchors_per_location].
attributes: a dict of (attribute_name, attribute_prediction). Each
`attribute_prediction` is a dict of:
- key: `str`, the level of the multilevel predictions.
- values: `Tensor`, the box scores predicted from a particular feature
level, whose shape is
[batch, height_l, width_l,
attribute_size * num_anchors_per_location].
Can be an empty dictionary if no attribute learning is required.
"""
scores = {}
boxes = {}
if self._config_dict['attribute_heads']:
attributes = {
att_name: {}
for att_name in self._config_dict['attribute_heads'].keys()
}
else:
attributes = {}
for i, level in enumerate(
range(self._config_dict['min_level'],
self._config_dict['max_level'] + 1)):
......@@ -220,7 +299,19 @@ class RetinaNetHead(tf.keras.layers.Layer):
x = norm(x)
x = self._activation(x)
boxes[str(level)] = self._box_regressor(x)
return scores, boxes
# attribute nets.
if self._config_dict['attribute_heads']:
for att_name in self._config_dict['attribute_heads'].keys():
x = this_level_features
for conv, norm in zip(self._att_convs[att_name],
self._att_norms[att_name][i]):
x = conv(x)
x = norm(x)
x = self._activation(x)
attributes[att_name][str(level)] = self._att_predictors[att_name](x)
return scores, boxes, attributes
def get_config(self):
return self._config_dict
......
......@@ -26,12 +26,17 @@ from official.vision.beta.modeling.heads import dense_prediction_heads
class RetinaNetHeadTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(False, False),
(False, True),
(True, False),
(True, True),
(False, False, False),
(False, True, False),
(True, False, True),
(True, True, True),
)
def test_forward(self, use_separable_conv, use_sync_bn):
def test_forward(self, use_separable_conv, use_sync_bn, has_att_heads):
if has_att_heads:
attribute_heads = {'depth': ('regression', 1)}
else:
attribute_heads = None
retinanet_head = dense_prediction_heads.RetinaNetHead(
min_level=3,
max_level=4,
......@@ -39,6 +44,7 @@ class RetinaNetHeadTest(parameterized.TestCase, tf.test.TestCase):
num_anchors_per_location=3,
num_convs=2,
num_filters=256,
attribute_heads=attribute_heads,
use_separable_conv=use_separable_conv,
activation='relu',
use_sync_bn=use_sync_bn,
......@@ -51,11 +57,15 @@ class RetinaNetHeadTest(parameterized.TestCase, tf.test.TestCase):
'3': np.random.rand(2, 128, 128, 16),
'4': np.random.rand(2, 64, 64, 16),
}
scores, boxes = retinanet_head(features)
scores, boxes, attributes = retinanet_head(features)
self.assertAllEqual(scores['3'].numpy().shape, [2, 128, 128, 9])
self.assertAllEqual(scores['4'].numpy().shape, [2, 64, 64, 9])
self.assertAllEqual(boxes['3'].numpy().shape, [2, 128, 128, 12])
self.assertAllEqual(boxes['4'].numpy().shape, [2, 64, 64, 12])
if has_att_heads:
for att in attributes.values():
self.assertAllEqual(att['3'].numpy().shape, [2, 128, 128, 3])
self.assertAllEqual(att['4'].numpy().shape, [2, 64, 64, 3])
def test_serialize_deserialize(self):
retinanet_head = dense_prediction_heads.RetinaNetHead(
......@@ -65,6 +75,7 @@ class RetinaNetHeadTest(parameterized.TestCase, tf.test.TestCase):
num_anchors_per_location=9,
num_convs=2,
num_filters=16,
attribute_heads=None,
use_separable_conv=False,
activation='relu',
use_sync_bn=False,
......
......@@ -15,7 +15,6 @@
"""Contains definitions of generators to generate the final detections."""
# Import libraries
import tensorflow as tf
from official.vision.beta.ops import box_ops
......@@ -24,6 +23,7 @@ from official.vision.beta.ops import nms
def _generate_detections_v1(boxes,
scores,
attributes=None,
pre_nms_top_k=5000,
pre_nms_score_threshold=0.05,
nms_iou_threshold=0.5,
......@@ -36,12 +36,18 @@ def _generate_detections_v1(boxes,
Args:
boxes: A `tf.Tensor` with shape `[batch_size, N, num_classes, 4]` or
`[batch_size, N, 1, 4]`, which box predictions on all feature levels. The
`[batch_size, N, 1, 4]` for box predictions on all feature levels. The
N is the number of total anchors on all levels.
scores: A `tf.Tensor` with shape `[batch_size, N, num_classes]`, which
stacks class probability on all feature levels. The N is the number of
total anchors on all levels. The num_classes is the number of classes
predicted by the model. Note that the class_outputs here is the raw score.
attributes: None or a dict of (attribute_name, attributes) pairs. Each
attributes is a `tf.Tensor` with shape
`[batch_size, N, num_classes, attribute_size]` or
`[batch_size, N, 1, attribute_size]` for attribute predictions on all
feature levels. The N is the number of total anchors on all levels. Can
be None if no attribute learning is required.
pre_nms_top_k: An `int` number of top candidate detections per class before
NMS.
pre_nms_score_threshold: A `float` representing the threshold for deciding
......@@ -63,6 +69,11 @@ def _generate_detections_v1(boxes,
boxes.
valid_detections: An `int` type `tf.Tensor` of shape `[batch_size]` only the
top `valid_detections` boxes are valid detections.
nms_attributes: None or a dict of (attribute_name, attributes). Each
attribute is a `float` type `tf.Tensor` of shape
`[batch_size, max_num_detections, attribute_size]` representing attribute
predictions for detected boxes. Can be an empty dict if no attribute
learning is required.
"""
with tf.name_scope('generate_detections'):
batch_size = scores.get_shape().as_list()[0]
......@@ -70,28 +81,45 @@ def _generate_detections_v1(boxes,
nmsed_classes = []
nmsed_scores = []
valid_detections = []
if attributes:
nmsed_attributes = {att_name: [] for att_name in attributes.keys()}
else:
nmsed_attributes = {}
for i in range(batch_size):
(nmsed_boxes_i, nmsed_scores_i, nmsed_classes_i,
valid_detections_i) = _generate_detections_per_image(
(nmsed_boxes_i, nmsed_scores_i, nmsed_classes_i, valid_detections_i,
nmsed_att_i) = _generate_detections_per_image(
boxes[i],
scores[i],
max_num_detections,
nms_iou_threshold,
pre_nms_score_threshold,
pre_nms_top_k)
attributes={
att_name: att[i] for att_name, att in attributes.items()
} if attributes else {},
pre_nms_top_k=pre_nms_top_k,
pre_nms_score_threshold=pre_nms_score_threshold,
nms_iou_threshold=nms_iou_threshold,
max_num_detections=max_num_detections)
nmsed_boxes.append(nmsed_boxes_i)
nmsed_scores.append(nmsed_scores_i)
nmsed_classes.append(nmsed_classes_i)
valid_detections.append(valid_detections_i)
if attributes:
for att_name in attributes.keys():
nmsed_attributes[att_name].append(nmsed_att_i[att_name])
nmsed_boxes = tf.stack(nmsed_boxes, axis=0)
nmsed_scores = tf.stack(nmsed_scores, axis=0)
nmsed_classes = tf.stack(nmsed_classes, axis=0)
valid_detections = tf.stack(valid_detections, axis=0)
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
if attributes:
for att_name in attributes.keys():
nmsed_attributes[att_name] = tf.stack(nmsed_attributes[att_name], axis=0)
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, nmsed_attributes
def _generate_detections_per_image(boxes,
scores,
attributes=None,
pre_nms_top_k=5000,
pre_nms_score_threshold=0.05,
nms_iou_threshold=0.5,
......@@ -106,6 +134,10 @@ def _generate_detections_per_image(boxes,
probability on all feature levels. The N is the number of total anchors on
all levels. The num_classes is the number of classes predicted by the
model. Note that the class_outputs here is the raw score.
attributes: If not None, a dict of `tf.Tensor`. Each value is in shape
`[N, num_classes, attribute_size]` or `[N, 1, attribute_size]` of
attribute predictions on all feature levels. The N is the number of total
anchors on all levels.
pre_nms_top_k: An `int` number of top candidate detections per class before
NMS.
pre_nms_score_threshold: A `float` representing the threshold for deciding
......@@ -125,16 +157,23 @@ def _generate_detections_per_image(boxes,
classes for detected boxes.
valid_detections: An `int` tf.Tensor of shape [1] only the top
`valid_detections` boxes are valid detections.
nms_attributes: None or a dict. Each value is a `float` tf.Tensor of shape
`[max_num_detections, attribute_size]` representing attribute predictions
for detected boxes. Can be an empty dict if `attributes` is None.
"""
nmsed_boxes = []
nmsed_scores = []
nmsed_classes = []
num_classes_for_box = boxes.get_shape().as_list()[1]
num_classes = scores.get_shape().as_list()[1]
if attributes:
nmsed_attributes = {att_name: [] for att_name in attributes.keys()}
else:
nmsed_attributes = {}
for i in range(num_classes):
boxes_i = boxes[:, min(num_classes_for_box - 1, i)]
scores_i = scores[:, i]
# Obtains pre_nms_top_k before running NMS.
scores_i, indices = tf.nn.top_k(
scores_i, k=tf.minimum(tf.shape(scores_i)[-1], pre_nms_top_k))
......@@ -159,6 +198,13 @@ def _generate_detections_per_image(boxes,
nmsed_boxes.append(nmsed_boxes_i)
nmsed_scores.append(nmsed_scores_i)
nmsed_classes.append(nmsed_classes_i)
if attributes:
for att_name, att in attributes.items():
num_classes_for_attr = att.get_shape().as_list()[1]
att_i = att[:, min(num_classes_for_attr - 1, i)]
att_i = tf.gather(att_i, indices)
nmsed_att_i = tf.gather(att_i, nmsed_indices_i)
nmsed_attributes[att_name].append(nmsed_att_i)
# Concats results from all classes and sort them.
nmsed_boxes = tf.concat(nmsed_boxes, axis=0)
......@@ -170,7 +216,13 @@ def _generate_detections_per_image(boxes,
nmsed_classes = tf.gather(nmsed_classes, indices)
valid_detections = tf.reduce_sum(
tf.cast(tf.greater(nmsed_scores, -1), tf.int32))
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
if attributes:
for att_name in attributes.keys():
nmsed_attributes[att_name] = tf.concat(nmsed_attributes[att_name], axis=0)
nmsed_attributes[att_name] = tf.gather(nmsed_attributes[att_name],
indices)
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, nmsed_attributes
def _select_top_k_scores(scores_in, pre_nms_num_detections):
......@@ -532,7 +584,8 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
raw_boxes,
raw_scores,
anchor_boxes,
image_shape):
image_shape,
raw_attributes=None):
"""Generates final detections.
Args:
......@@ -547,6 +600,11 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
image_shape: A `tf.Tensor` of shape of [batch_size, 2] storing the image
height and width w.r.t. the scaled image, i.e. the same image space as
`box_outputs` and `anchor_boxes`.
raw_attributes: If not None, a `dict` of
(attribute_name, attribute_prediction) pairs. `attribute_prediction`
is a dict that contains keys representing FPN levels and values
representing tenors of shape `[batch, feature_h, feature_w,
num_anchors * attribute_size]`.
Returns:
If `apply_nms` = True, the return is a dictionary with keys:
......@@ -560,15 +618,26 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
[batch, max_num_detections] representing classes for detected boxes.
`num_detections`: An `int` tf.Tensor of shape [batch] only the first
`num_detections` boxes are valid detections
`detection_attributes`: A dict. Values of the dict is a `float`
tf.Tensor of shape [batch, max_num_detections, attribute_size]
representing attribute predictions for detected boxes.
If `apply_nms` = False, the return is a dictionary with keys:
`decoded_boxes`: A `float` tf.Tensor of shape [batch, num_raw_boxes, 4]
representing all the decoded boxes.
`decoded_box_scores`: A `float` tf.Tensor of shape
[batch, num_raw_boxes] representing socres of all the decoded boxes.
`decoded_box_attributes`: A dict. Values in the dict is a
`float` tf.Tensor of shape [batch, num_raw_boxes, attribute_size]
representing attribute predictions of all the decoded boxes.
"""
# Collects outputs from all levels into a list.
boxes = []
scores = []
if raw_attributes:
attributes = {att_name: [] for att_name in raw_attributes.keys()}
else:
attributes = {}
levels = list(raw_boxes.keys())
min_level = int(min(levels))
max_level = int(max(levels))
......@@ -597,17 +666,34 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
boxes.append(boxes_i)
scores.append(scores_i)
if raw_attributes:
for att_name, raw_att in raw_attributes.items():
attribute_size = tf.shape(
raw_att[str(i)])[-1] // num_anchors_per_locations
att_i = tf.reshape(raw_att[str(i)], [batch_size, -1, attribute_size])
attributes[att_name].append(att_i)
boxes = tf.concat(boxes, axis=1)
boxes = tf.expand_dims(boxes, axis=2)
scores = tf.concat(scores, axis=1)
if raw_attributes:
for att_name in raw_attributes.keys():
attributes[att_name] = tf.concat(attributes[att_name], axis=1)
attributes[att_name] = tf.expand_dims(attributes[att_name], axis=2)
if not self._config_dict['apply_nms']:
return {
'decoded_boxes': boxes,
'decoded_box_scores': scores,
'decoded_box_attributes': attributes,
}
if self._config_dict['use_batched_nms']:
if raw_attributes:
raise ValueError('Attribute learning is not supported for batched NMS.')
nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
_generate_detections_batched(
boxes,
......@@ -615,16 +701,28 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
self._config_dict['pre_nms_score_threshold'],
self._config_dict['nms_iou_threshold'],
self._config_dict['max_num_detections']))
# Set `nmsed_attributes` to None for batched NMS.
nmsed_attributes = {}
else:
nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
_generate_detections_v2(
boxes,
scores,
self._config_dict['pre_nms_top_k'],
self._config_dict['pre_nms_score_threshold'],
self._config_dict['nms_iou_threshold'],
self._config_dict['max_num_detections']))
if raw_attributes:
nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, nmsed_attributes = (
_generate_detections_v1(
boxes,
scores,
attributes=attributes if raw_attributes else None,
pre_nms_top_k=self._config_dict['pre_nms_top_k'],
pre_nms_score_threshold=self
._config_dict['pre_nms_score_threshold'],
nms_iou_threshold=self._config_dict['nms_iou_threshold'],
max_num_detections=self._config_dict['max_num_detections']))
else:
nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
_generate_detections_v2(
boxes, scores, self._config_dict['pre_nms_top_k'],
self._config_dict['pre_nms_score_threshold'],
self._config_dict['nms_iou_threshold'],
self._config_dict['max_num_detections']))
nmsed_attributes = {}
# Adds 1 to offset the background class which has index 0.
nmsed_classes += 1
......@@ -633,6 +731,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
'detection_boxes': nmsed_boxes,
'detection_classes': nmsed_classes,
'detection_scores': nmsed_scores,
'detection_attributes': nmsed_attributes,
}
def get_config(self):
......
......@@ -116,10 +116,11 @@ class MultilevelDetectionGeneratorTest(
parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(True),
(False),
(True, False),
(False, False),
(False, True),
)
def testDetectionsOutputShape(self, use_batched_nms):
def testDetectionsOutputShape(self, use_batched_nms, has_att_heads):
min_level = 4
max_level = 6
num_scales = 2
......@@ -169,11 +170,34 @@ class MultilevelDetectionGeneratorTest(
'6': tf.reshape(tf.convert_to_tensor(
box_outputs_all[80:84], dtype=tf.float32), [1, 2, 2, 4]),
}
if has_att_heads:
att_outputs_all = np.random.rand(84, 1) # random attributes.
att_outputs = {
'depth': {
'4':
tf.reshape(
tf.convert_to_tensor(
att_outputs_all[0:64], dtype=tf.float32),
[1, 8, 8, 1]),
'5':
tf.reshape(
tf.convert_to_tensor(
att_outputs_all[64:80], dtype=tf.float32),
[1, 4, 4, 1]),
'6':
tf.reshape(
tf.convert_to_tensor(
att_outputs_all[80:84], dtype=tf.float32),
[1, 2, 2, 1]),
}
}
else:
att_outputs = None
image_info = tf.constant([[[1000, 1000], [100, 100], [0.1, 0.1], [0, 0]]],
dtype=tf.float32)
generator = detection_generator.MultilevelDetectionGenerator(**kwargs)
results = generator(box_outputs, class_outputs, anchor_boxes,
image_info[:, 1, :])
image_info[:, 1, :], att_outputs)
boxes = results['detection_boxes']
classes = results['detection_classes']
scores = results['detection_scores']
......@@ -183,6 +207,9 @@ class MultilevelDetectionGeneratorTest(
self.assertEqual(scores.numpy().shape, (batch_size, max_num_detections,))
self.assertEqual(classes.numpy().shape, (batch_size, max_num_detections,))
self.assertEqual(valid_detections.numpy().shape, (batch_size,))
if has_att_heads:
for att in results['detection_attributes'].values():
self.assertEqual(att.numpy().shape, (batch_size, max_num_detections, 1))
def test_serialize_deserialize(self):
kwargs = {
......
......@@ -81,31 +81,41 @@ class RetinaNetModel(tf.keras.Model):
- values: `Tensor`, the box coordinates predicted from a particular
feature level, whose shape is
[batch, height_l, width_l, 4 * num_anchors_per_location].
attributes: a dict of (attribute_name, attribute_predictions). Each
attribute prediction is a dict that includes:
- key: `str`, the level of the multilevel predictions.
- values: `Tensor`, the attribute predictions from a particular
feature level, whose shape is
[batch, height_l, width_l, att_size * num_anchors_per_location].
"""
# Feature extraction.
features = self.backbone(images)
if self.decoder:
features = self.decoder(features)
# Dense prediction.
raw_scores, raw_boxes = self.head(features)
# Dense prediction. `raw_attributes` can be empty.
raw_scores, raw_boxes, raw_attributes = self.head(features)
if training:
return {
'cls_outputs': raw_scores,
'box_outputs': raw_boxes,
'att_outputs': raw_attributes,
}
else:
# Post-processing.
final_results = self.detection_generator(
raw_boxes, raw_scores, anchor_boxes, image_shape)
final_results = self.detection_generator(raw_boxes, raw_scores,
anchor_boxes, image_shape,
raw_attributes)
return {
'detection_boxes': final_results['detection_boxes'],
'detection_scores': final_results['detection_scores'],
'detection_classes': final_results['detection_classes'],
'detection_attributes': final_results['detection_attributes'],
'num_detections': final_results['num_detections'],
'cls_outputs': raw_scores,
'box_outputs': raw_boxes
'box_outputs': raw_boxes,
'att_outputs': raw_attributes,
}
@property
......
......@@ -95,11 +95,13 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
image_size=[(128, 128),],
image_size=[
(128, 128),
],
training=[True, False],
)
)
def test_forward(self, strategy, image_size, training):
has_att_heads=[True, False],
))
def test_forward(self, strategy, image_size, training, has_att_heads):
"""Test for creation of a R50-FPN RetinaNet."""
tf.keras.backend.set_image_data_format('channels_last')
num_classes = 3
......@@ -130,10 +132,16 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
input_specs=backbone.output_specs,
min_level=min_level,
max_level=max_level)
if has_att_heads:
attribute_heads = {'depth': ('regression', 1)}
else:
attribute_heads = None
head = dense_prediction_heads.RetinaNetHead(
min_level=min_level,
max_level=max_level,
num_classes=num_classes,
attribute_heads=attribute_heads,
num_anchors_per_location=num_anchors_per_location)
generator = detection_generator.MultilevelDetectionGenerator(
max_num_detections=10)
......@@ -152,6 +160,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
if training:
cls_outputs = model_outputs['cls_outputs']
box_outputs = model_outputs['box_outputs']
att_outputs = model_outputs['att_outputs']
for level in range(min_level, max_level + 1):
self.assertIn(str(level), cls_outputs)
self.assertIn(str(level), box_outputs)
......@@ -167,10 +176,17 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
image_size[1] // 2**level,
4 * num_anchors_per_location
], box_outputs[str(level)].numpy().shape)
if has_att_heads:
for att in att_outputs.values():
self.assertAllEqual([
2, image_size[0] // 2**level, image_size[1] // 2**level,
1 * num_anchors_per_location
], att[str(level)].numpy().shape)
else:
self.assertIn('detection_boxes', model_outputs)
self.assertIn('detection_scores', model_outputs)
self.assertIn('detection_classes', model_outputs)
self.assertIn('detection_attributes', model_outputs)
self.assertIn('num_detections', model_outputs)
self.assertAllEqual(
[2, 10, 4], model_outputs['detection_boxes'].numpy().shape)
......@@ -180,6 +196,10 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
[2, 10], model_outputs['detection_classes'].numpy().shape)
self.assertAllEqual(
[2,], model_outputs['num_detections'].numpy().shape)
if has_att_heads:
self.assertAllEqual(
[2, 10, 1],
model_outputs['detection_attributes']['depth'].numpy().shape)
def test_serialize_deserialize(self):
"""Validate the network can be serialized and deserialized."""
......@@ -220,4 +240,3 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册