提交 a8ae1619 编写于 作者: A A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 366939051
上级 fab47e9e
......@@ -238,7 +238,15 @@ def build_retinanet(input_specs: tf.keras.layers.InputSpec,
use_batched_nms=generator_config.use_batched_nms)
model = retinanet_model.RetinaNetModel(
backbone, decoder, head, detection_generator_obj)
backbone,
decoder,
head,
detection_generator_obj,
min_level=model_config.min_level,
max_level=model_config.max_level,
num_scales=model_config.anchor.num_scales,
aspect_ratios=model_config.anchor.aspect_ratios,
anchor_size=model_config.anchor.anchor_size)
return model
......
......@@ -13,10 +13,13 @@
# limitations under the License.
"""RetinaNet."""
from typing import List, Optional
# Import libraries
import tensorflow as tf
from official.vision.beta.ops import anchor
@tf.keras.utils.register_keras_serializable(package='Vision')
class RetinaNetModel(tf.keras.Model):
......@@ -27,6 +30,11 @@ class RetinaNetModel(tf.keras.Model):
decoder,
head,
detection_generator,
min_level: Optional[int] = None,
max_level: Optional[int] = None,
num_scales: Optional[int] = None,
aspect_ratios: Optional[List[float]] = None,
anchor_size: Optional[float] = None,
**kwargs):
"""RetinaNet model initialization function.
......@@ -35,6 +43,17 @@ class RetinaNetModel(tf.keras.Model):
decoder: `tf.keras.Model` a decoder network.
head: `RetinaNetHead`, the RetinaNet head.
detection_generator: the detection generator.
min_level: Minimum level in output feature maps.
max_level: Maximum level in output feature maps.
num_scales: A number representing intermediate scales added
on each level. For instance, num_scales=2 adds one additional
intermediate anchor scale [2^0, 2^0.5] on each level.
aspect_ratios: A list representing the aspect ratio
anchors added on each level. Each number indicates the ratio of width to
height. For instance, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors
on each scale level.
anchor_size: A number representing the scale of size of the base
anchor to the feature stride 2^level.
**kwargs: keyword arguments to be passed.
"""
super(RetinaNetModel, self).__init__(**kwargs)
......@@ -43,6 +62,11 @@ class RetinaNetModel(tf.keras.Model):
'decoder': decoder,
'head': head,
'detection_generator': detection_generator,
'min_level': min_level,
'max_level': max_level,
'num_scales': num_scales,
'aspect_ratios': aspect_ratios,
'anchor_size': anchor_size,
}
self._backbone = backbone
self._decoder = decoder
......@@ -105,6 +129,21 @@ class RetinaNetModel(tf.keras.Model):
outputs.update({'att_outputs': raw_attributes})
return outputs
else:
# Generate anchor boxes for this batch if not provided.
if anchor_boxes is None:
_, image_height, image_width, _ = images.get_shape().as_list()
anchor_boxes = anchor.Anchor(
min_level=self._config_dict['min_level'],
max_level=self._config_dict['max_level'],
num_scales=self._config_dict['num_scales'],
aspect_ratios=self._config_dict['aspect_ratios'],
anchor_size=self._config_dict['anchor_size'],
image_size=(image_height, image_width)).multilevel_boxes
for l in anchor_boxes:
anchor_boxes[l] = tf.tile(
tf.expand_dims(anchor_boxes[l], axis=0),
[tf.shape(images)[0], 1, 1, 1])
# Post-processing.
final_results = self.detection_generator(
raw_boxes, raw_scores, anchor_boxes, image_shape, raw_attributes)
......
......@@ -33,37 +33,79 @@ from official.vision.beta.ops import anchor
class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(3, 3, 7, 3, [1.0], 50, False, 256, 4, 256, 32244949),
{
'use_separable_conv': True,
'build_anchor_boxes': True,
'is_training': False,
'has_att_heads': False
},
{
'use_separable_conv': False,
'build_anchor_boxes': True,
'is_training': False,
'has_att_heads': False
},
{
'use_separable_conv': False,
'build_anchor_boxes': False,
'is_training': False,
'has_att_heads': False
},
{
'use_separable_conv': False,
'build_anchor_boxes': False,
'is_training': True,
'has_att_heads': False
},
{
'use_separable_conv': False,
'build_anchor_boxes': True,
'is_training': True,
'has_att_heads': True
},
{
'use_separable_conv': False,
'build_anchor_boxes': True,
'is_training': False,
'has_att_heads': True
},
)
def test_num_params(self,
num_classes,
min_level,
max_level,
num_scales,
aspect_ratios,
resnet_model_id,
use_separable_conv,
fpn_num_filters,
head_num_convs,
head_num_filters,
expected_num_params):
def test_build_model(self, use_separable_conv, build_anchor_boxes,
is_training, has_att_heads):
num_classes = 3
min_level = 3
max_level = 7
num_scales = 3
aspect_ratios = [1.0]
anchor_size = 3
fpn_num_filters = 256
head_num_convs = 4
head_num_filters = 256
num_anchors_per_location = num_scales * len(aspect_ratios)
image_size = 384
images = np.random.rand(2, image_size, image_size, 3)
image_shape = np.array([[image_size, image_size], [image_size, image_size]])
anchor_boxes = anchor.Anchor(
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=3,
image_size=(image_size, image_size)).multilevel_boxes
for l in anchor_boxes:
anchor_boxes[l] = tf.tile(
tf.expand_dims(anchor_boxes[l], axis=0), [2, 1, 1, 1])
if build_anchor_boxes:
anchor_boxes = anchor.Anchor(
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=anchor_size,
image_size=(image_size, image_size)).multilevel_boxes
for l in anchor_boxes:
anchor_boxes[l] = tf.tile(
tf.expand_dims(anchor_boxes[l], axis=0), [2, 1, 1, 1])
else:
anchor_boxes = None
backbone = resnet.ResNet(model_id=resnet_model_id)
if has_att_heads:
attribute_heads = {'depth': ('regression', 1)}
else:
attribute_heads = None
backbone = resnet.ResNet(model_id=50)
decoder = fpn.FPN(
input_specs=backbone.output_specs,
min_level=min_level,
......@@ -74,6 +116,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
min_level=min_level,
max_level=max_level,
num_classes=num_classes,
attribute_heads=attribute_heads,
num_anchors_per_location=num_anchors_per_location,
use_separable_conv=use_separable_conv,
num_convs=head_num_convs,
......@@ -84,10 +127,14 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
backbone=backbone,
decoder=decoder,
head=head,
detection_generator=generator)
detection_generator=generator,
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=anchor_size)
_ = model(images, image_shape, anchor_boxes, training=True)
self.assertEqual(expected_num_params, model.count_params())
_ = model(images, image_shape, anchor_boxes, training=is_training)
@combinations.generate(
combinations.combine(
......@@ -226,7 +273,12 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
backbone=backbone,
decoder=decoder,
head=head,
detection_generator=generator)
detection_generator=generator,
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=3)
config = model.get_config()
new_model = retinanet_model.RetinaNetModel.from_config(config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册