Commit 35ed75e3 authored by Fan Yang, committed by A. Unique TensorFlower

Add option to disable sharing convs across levels in RetinaNet head.

Also significantly refactor the code: move the component builders into individual functions and reduce duplicated code to improve readability.

PiperOrigin-RevId: 533598893
Parent 8d92444c
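For context before the diff, here is a minimal usage sketch of the new option. The constructor arguments mirror the updated unit test at the bottom of this commit; nothing here goes beyond what the diff itself exercises.

```python
from official.vision.modeling.heads import dense_prediction_heads

# Minimal sketch of the new option, using the same arguments as the updated
# test in this commit. With share_level_convs=False, each pyramid level gets
# its own conv stack and predictor instead of one set of weights reused
# across all levels.
head = dense_prediction_heads.RetinaNetHead(
    min_level=3,
    max_level=4,
    num_classes=3,
    num_anchors_per_location=3,
    num_convs=2,
    num_filters=256,
    attribute_heads=None,
    use_separable_conv=False,
    activation='relu',
    use_sync_bn=False,
    norm_momentum=0.99,
    norm_epsilon=0.001,
    kernel_regularizer=None,
    bias_regularizer=None,
    share_level_convs=False,  # new in this commit; defaults to True
)
```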
@@ -125,6 +125,7 @@ class RetinaNetHeadQuantized(tf.keras.layers.Layer):
       bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
       num_params_per_anchor: int = 4,
       share_classification_heads: bool = False,
+      share_level_convs: bool = True,
       **kwargs):
     """Initializes a RetinaNet quantized head.
@@ -161,9 +162,14 @@ class RetinaNetHeadQuantized(tf.keras.layers.Layer):
       share_classification_heads: A `bool` that indicates whether to share
         weights among the main and attribute classification heads. Not used
         in the QAT model.
+      share_level_convs: An optional bool to enable sharing convs across
+        levels for classnet, boxnet, classifier and box regressor. If True,
+        convs will be shared across all levels. Not used in the QAT model.
       **kwargs: Additional keyword arguments to be passed.
     """
     del share_classification_heads
+    del share_level_convs
     super().__init__(**kwargs)
     self._config_dict = {
......
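Note that the quantized head accepts the new argument only for interface parity and immediately discards it. A standalone sketch of that pattern (these class names are hypothetical, not from the diff):

```python
# Hypothetical mirror of the `del` pattern above: the QAT head keeps the same
# constructor signature as the float head but drops options it does not
# support, so a single config can build either head.
class FloatHead:

  def __init__(self, num_convs: int = 4, share_level_convs: bool = True):
    self.num_convs = num_convs
    self.share_level_convs = share_level_convs


class QuantizedHead(FloatHead):

  def __init__(self, num_convs: int = 4, share_level_convs: bool = True):
    del share_level_convs  # Not used in the QAT model.
    super().__init__(num_convs=num_convs)
```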
@@ -126,6 +126,7 @@ class RetinaNetHead(hyperparams.Config):
   use_separable_conv: bool = False
   attribute_heads: List[AttributeHead] = dataclasses.field(default_factory=list)
   share_classification_heads: bool = False
+  share_level_convs: Optional[bool] = True
 
 @dataclasses.dataclass
......
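The new config field defaults to True, so existing experiments keep the old cross-level sharing behavior. A self-contained sketch of the override (a plain dataclass stands in for hyperparams.Config here):

```python
import dataclasses
from typing import Optional


# Simplified stand-in for the RetinaNetHead hyperparams.Config shown above.
@dataclasses.dataclass
class RetinaNetHeadConfig:
  use_separable_conv: bool = False
  share_classification_heads: bool = False
  share_level_convs: Optional[bool] = True  # new field; True keeps old behavior


# Disabling cross-level sharing is a one-line config override.
head_config = RetinaNetHeadConfig(share_level_convs=False)
assert head_config.share_level_convs is False
```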
@@ -301,7 +301,9 @@ def build_retinanet(
       use_sync_bn=norm_activation_config.use_sync_bn,
       norm_momentum=norm_activation_config.norm_momentum,
       norm_epsilon=norm_activation_config.norm_epsilon,
-      kernel_regularizer=l2_regularizer)
+      kernel_regularizer=l2_regularizer,
+      share_level_convs=head_config.share_level_convs,
+  )
 
   # Builds decoder and head so that their trainable weights are initialized
   if decoder:
......
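The factory change is pure plumbing: the config field is forwarded verbatim to the head constructor. A tiny self-contained sketch of that flow (names here are hypothetical stand-ins for build_retinanet):

```python
import dataclasses


@dataclasses.dataclass
class HeadConfig:  # hypothetical stand-in for the head config
  share_level_convs: bool = True


def build_head(head_config: HeadConfig) -> dict:
  # Stand-in for build_retinanet: the config value is passed through as a
  # keyword argument when constructing the head.
  head_kwargs = dict(share_level_convs=head_config.share_level_convs)
  return head_kwargs


assert build_head(HeadConfig(share_level_convs=False)) == {
    'share_level_convs': False
}
```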
@@ -46,7 +46,9 @@ class RetinaNetHead(tf.keras.layers.Layer):
       kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
       bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
       num_params_per_anchor: int = 4,
-      **kwargs):
+      share_level_convs: bool = True,
+      **kwargs,
+  ):
     """Initializes a RetinaNet head.
 
     Args:
@@ -80,9 +82,12 @@ class RetinaNetHead(tf.keras.layers.Layer):
         box. For example, `num_params_per_anchor` would be 4 for axis-aligned
         anchor boxes specified by their y-centers, x-centers, heights, and
         widths.
+      share_level_convs: An optional bool to enable sharing convs across
+        levels for classnet, boxnet, classifier and box regressor. If True,
+        convs will be shared across all levels.
       **kwargs: Additional keyword arguments to be passed.
     """
-    super(RetinaNetHead, self).__init__(**kwargs)
+    super().__init__(**kwargs)
     self._config_dict = {
         'min_level': min_level,
         'max_level': max_level,
@@ -100,6 +105,7 @@ class RetinaNetHead(tf.keras.layers.Layer):
         'kernel_regularizer': kernel_regularizer,
         'bias_regularizer': bias_regularizer,
         'num_params_per_anchor': num_params_per_anchor,
+        'share_level_convs': share_level_convs,
     }
 
     if tf.keras.backend.image_data_format() == 'channels_last':
@@ -159,9 +165,17 @@ class RetinaNetHead(tf.keras.layers.Layer):
         'kernel_regularizer': self._config_dict['kernel_regularizer'],
     })
 
-    if not self._config_dict['attribute_heads']:
-      return
+    if self._config_dict['attribute_heads']:
+      self._init_attribute_kwargs()
+
+  def _conv_kwargs_new_kernel_init(self, conv_kwargs):
+    if 'kernel_initializer' in conv_kwargs:
+      conv_kwargs['kernel_initializer'] = tf_utils.clone_initializer(
+          conv_kwargs['kernel_initializer']
+      )
+    return conv_kwargs
+
+  def _init_attribute_kwargs(self):
     self._attribute_kwargs = []
     for att_config in self._config_dict['attribute_heads']:
       att_type = att_config['type']
@@ -210,12 +224,146 @@ class RetinaNetHead(tf.keras.layers.Layer):
       })
       self._attribute_kwargs.append(att_predictor_kwargs)
 
-  def _conv_kwargs_new_kernel_init(self, conv_kwargs):
-    if 'kernel_initializer' in conv_kwargs:
-      conv_kwargs['kernel_initializer'] = tf_utils.clone_initializer(
-          conv_kwargs['kernel_initializer']
-      )
-    return conv_kwargs
+  def _apply_prediction_tower(self, features, convs, norms) -> tf.Tensor:
+    x = features
+    for conv, norm in zip(convs, norms):
+      x = conv(x)
+      x = norm(x)
+      x = self._activation(x)
+    return x
+
+  def _apply_attribute_net(
+      self, attributes, level, level_idx, this_level_features, classnet_x
+  ):
+    prediction_tower_output = {}
+    for att_config in self._config_dict['attribute_heads']:
+      att_name = att_config['name']
+      att_type = att_config['type']
+      if (
+          self._config_dict['share_classification_heads']
+          and att_type == 'classification'
+      ):
+        attributes[att_name][str(level)] = self._att_predictors[att_name](
+            classnet_x
+        )
+      else:
+
+        def _apply_attribute_prediction_tower(
+            attribute_name, features, feature_level
+        ):
+          return self._apply_prediction_tower(
+              features,
+              self._att_convs[attribute_name],
+              self._att_norms[attribute_name][feature_level],
+          )
+
+        prediction_tower_name = att_config['prediction_tower_name']
+        if not prediction_tower_name:
+          attributes[att_name][str(level)] = self._att_predictors[att_name](
+              _apply_attribute_prediction_tower(
+                  att_name, this_level_features, level_idx
+              )
+          )
+        else:
+          if prediction_tower_name not in prediction_tower_output:
+            prediction_tower_output[prediction_tower_name] = (
+                _apply_attribute_prediction_tower(
+                    att_name, this_level_features, level_idx
+                )
+            )
+          attributes[att_name][str(level)] = self._att_predictors[att_name](
+              prediction_tower_output[prediction_tower_name]
+          )
+
+  def _build_prediction_tower(
+      self, net_name, predictor_name, conv_op, bn_op, predictor_kwargs
+  ):
+    """Builds the prediction tower. Convs across levels can be shared or not."""
+    convs = []
+    norms = []
+    for level in range(
+        self._config_dict['min_level'], self._config_dict['max_level'] + 1
+    ):
+      if not self._config_dict['share_level_convs']:
+        this_level_convs = []
+      this_level_norms = []
+      for i in range(self._config_dict['num_convs']):
+        conv_kwargs = self._conv_kwargs_new_kernel_init(self._conv_kwargs)
+        if not self._config_dict['share_level_convs']:
+          # Do not share convs.
+          this_level_convs.append(
+              conv_op(name=f'{net_name}-conv_{level}_{i}', **conv_kwargs)
+          )
+        elif level == self._config_dict['min_level']:
+          convs.append(conv_op(name=f'{net_name}-conv_{i}', **conv_kwargs))
+        this_level_norms.append(
+            bn_op(name=f'{net_name}-conv-norm_{level}_{i}', **self._bn_kwargs)
+        )
+      norms.append(this_level_norms)
+      if not self._config_dict['share_level_convs']:
+        convs.append(this_level_convs)
+
+    # Create predictors after additional convs.
+    if self._config_dict['share_level_convs']:
+      predictors = conv_op(name=predictor_name, **predictor_kwargs)
+    else:
+      predictors = []
+      for level in range(
+          self._config_dict['min_level'], self._config_dict['max_level'] + 1
+      ):
+        predictors.append(
+            conv_op(name=f'{predictor_name}-{level}', **predictor_kwargs)
+        )
+
+    return convs, norms, predictors
+
+  def _build_attribute_net(self, conv_op, bn_op):
+    self._att_predictors = {}
+    self._att_convs = {}
+    self._att_norms = {}
+
+    for att_config, att_predictor_kwargs in zip(
+        self._config_dict['attribute_heads'], self._attribute_kwargs
+    ):
+      att_name = att_config['name']
+      att_num_convs = (
+          att_config.get('num_convs') or self._config_dict['num_convs']
+      )
+      att_num_filters = (
+          att_config.get('num_filters') or self._config_dict['num_filters']
+      )
+      if att_num_convs < 0:
+        raise ValueError(f'Invalid `num_convs` {att_num_convs} for {att_name}.')
+      if att_num_filters < 0:
+        raise ValueError(
+            f'Invalid `num_filters` {att_num_filters} for {att_name}.'
+        )
+      att_conv_kwargs = self._conv_kwargs.copy()
+      att_conv_kwargs['filters'] = att_num_filters
+      att_convs_i = []
+      att_norms_i = []
+
+      # Build conv and norm layers.
+      for level in range(
+          self._config_dict['min_level'], self._config_dict['max_level'] + 1
+      ):
+        this_level_att_norms = []
+        for i in range(att_num_convs):
+          if level == self._config_dict['min_level']:
+            att_conv_name = '{}-conv_{}'.format(att_name, i)
+            att_convs_i.append(conv_op(name=att_conv_name, **att_conv_kwargs))
+          att_norm_name = '{}-conv-norm_{}_{}'.format(att_name, level, i)
+          this_level_att_norms.append(
+              bn_op(name=att_norm_name, **self._bn_kwargs)
+          )
+        att_norms_i.append(this_level_att_norms)
+      self._att_convs[att_name] = att_convs_i
+      self._att_norms[att_name] = att_norms_i
+
+      # Build the final prediction layer.
+      self._att_predictors[att_name] = conv_op(
+          name='{}_attributes'.format(att_name), **att_predictor_kwargs
+      )
   def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
     """Creates the variables of the head."""
@@ -231,94 +379,24 @@ class RetinaNetHead(tf.keras.layers.Layer):
     )
 
     # Class net.
-    self._cls_convs = []
-    self._cls_norms = []
-    for level in range(
-        self._config_dict['min_level'], self._config_dict['max_level'] + 1):
-      this_level_cls_norms = []
-      for i in range(self._config_dict['num_convs']):
-        if level == self._config_dict['min_level']:
-          cls_conv_name = 'classnet-conv_{}'.format(i)
-          conv_kwargs = self._conv_kwargs_new_kernel_init(self._conv_kwargs)
-          self._cls_convs.append(conv_op(name=cls_conv_name, **conv_kwargs))
-        cls_norm_name = 'classnet-conv-norm_{}_{}'.format(level, i)
-        this_level_cls_norms.append(
-            bn_op(name=cls_norm_name, **self._bn_kwargs)
-        )
-      self._cls_norms.append(this_level_cls_norms)
-    self._classifier = conv_op(name='scores', **self._classifier_kwargs)
+    self._cls_convs, self._cls_norms, self._classifier = (
+        self._build_prediction_tower(
+            'classnet', 'scores', conv_op, bn_op, self._classifier_kwargs
+        )
+    )
 
     # Box net.
-    self._box_convs = []
-    self._box_norms = []
-    for level in range(
-        self._config_dict['min_level'], self._config_dict['max_level'] + 1):
-      this_level_box_norms = []
-      for i in range(self._config_dict['num_convs']):
-        if level == self._config_dict['min_level']:
-          box_conv_name = 'boxnet-conv_{}'.format(i)
-          conv_kwargs = self._conv_kwargs_new_kernel_init(self._conv_kwargs)
-          self._box_convs.append(conv_op(name=box_conv_name, **conv_kwargs))
-        box_norm_name = 'boxnet-conv-norm_{}_{}'.format(level, i)
-        this_level_box_norms.append(
-            bn_op(name=box_norm_name, **self._bn_kwargs)
-        )
-      self._box_norms.append(this_level_box_norms)
-    self._box_regressor = conv_op(name='boxes', **self._box_regressor_kwargs)
+    self._box_convs, self._box_norms, self._box_regressor = (
+        self._build_prediction_tower(
+            'boxnet', 'boxes', conv_op, bn_op, self._box_regressor_kwargs
+        )
+    )
 
     # Attribute learning nets.
     if self._config_dict['attribute_heads']:
-      self._att_predictors = {}
-      self._att_convs = {}
-      self._att_norms = {}
-      for att_config, att_predictor_kwargs in zip(
-          self._config_dict['attribute_heads'], self._attribute_kwargs
-      ):
-        att_name = att_config['name']
-        att_num_convs = (
-            att_config.get('num_convs') or self._config_dict['num_convs']
-        )
-        att_num_filters = (
-            att_config.get('num_filters') or self._config_dict['num_filters']
-        )
-        if att_num_convs < 0:
-          raise ValueError(
-              f'Invalid `num_convs` {att_num_convs} for {att_name}.'
-          )
-        if att_num_filters < 0:
-          raise ValueError(
-              f'Invalid `num_filters` {att_num_filters} for {att_name}.'
-          )
-        att_conv_kwargs = self._conv_kwargs.copy()
-        att_conv_kwargs['filters'] = att_num_filters
-        att_convs_i = []
-        att_norms_i = []
-        # Build conv and norm layers.
-        for level in range(self._config_dict['min_level'],
-                           self._config_dict['max_level'] + 1):
-          this_level_att_norms = []
-          for i in range(att_num_convs):
-            if level == self._config_dict['min_level']:
-              att_conv_name = '{}-conv_{}'.format(att_name, i)
-              conv_kwargs = self._conv_kwargs_new_kernel_init(self._conv_kwargs)
-              att_convs_i.append(conv_op(name=att_conv_name, **att_conv_kwargs))
-            att_norm_name = '{}-conv-norm_{}_{}'.format(att_name, level, i)
-            this_level_att_norms.append(
-                bn_op(name=att_norm_name, **self._bn_kwargs)
-            )
-          att_norms_i.append(this_level_att_norms)
-        self._att_convs[att_name] = att_convs_i
-        self._att_norms[att_name] = att_norms_i
-        # Build the final prediction layer.
-        self._att_predictors[att_name] = conv_op(
-            name='{}_attributes'.format(att_name), **att_predictor_kwargs)
+      self._build_attribute_net(conv_op, bn_op)
 
-    super(RetinaNetHead, self).build(input_shape)
+    super().build(input_shape)
   def call(self, features: Mapping[str, tf.Tensor]):
     """Forward pass of the RetinaNet head.
@@ -366,57 +444,35 @@ class RetinaNetHead(tf.keras.layers.Layer):
         self._config_dict['max_level'] + 1)):
       this_level_features = features[str(level)]
 
-      # class net.
-      x = this_level_features
-      for conv, norm in zip(self._cls_convs, self._cls_norms[i]):
-        x = conv(x)
-        x = norm(x)
-        x = self._activation(x)
+      if self._config_dict['share_level_convs']:
+        cls_convs = self._cls_convs
+        box_convs = self._box_convs
+        classifier = self._classifier
+        box_regressor = self._box_regressor
+      else:
+        cls_convs = self._cls_convs[i]
+        box_convs = self._box_convs[i]
+        classifier = self._classifier[i]
+        box_regressor = self._box_regressor[i]
+
+      # Apply class net.
+      x = self._apply_prediction_tower(
+          this_level_features, cls_convs, self._cls_norms[i]
+      )
+      scores[str(level)] = classifier(x)
       classnet_x = x
-      scores[str(level)] = self._classifier(classnet_x)
 
-      # box net.
-      x = this_level_features
-      for conv, norm in zip(self._box_convs, self._box_norms[i]):
-        x = conv(x)
-        x = norm(x)
-        x = self._activation(x)
-      boxes[str(level)] = self._box_regressor(x)
+      # Apply box net.
+      x = self._apply_prediction_tower(
+          this_level_features, box_convs, self._box_norms[i]
+      )
+      boxes[str(level)] = box_regressor(x)
 
-      # attribute nets.
+      # Apply attribute nets.
       if self._config_dict['attribute_heads']:
-        prediction_tower_output = {}
-        for att_config in self._config_dict['attribute_heads']:
-          att_name = att_config['name']
-          att_type = att_config['type']
-          if self._config_dict[
-              'share_classification_heads'] and att_type == 'classification':
-            attributes[att_name][str(level)] = self._att_predictors[att_name](
-                classnet_x)
-          else:
-            def build_prediction_tower(atttribute_name, features,
-                                       feature_level):
-              x = features
-              for conv, norm in zip(
-                  self._att_convs[atttribute_name],
-                  self._att_norms[atttribute_name][feature_level]):
-                x = conv(x)
-                x = norm(x)
-                x = self._activation(x)
-              return x
-            prediction_tower_name = att_config['prediction_tower_name']
-            if not prediction_tower_name:
-              attributes[att_name][str(level)] = self._att_predictors[att_name](
-                  build_prediction_tower(att_name, this_level_features, i))
-            else:
-              if prediction_tower_name not in prediction_tower_output:
-                prediction_tower_output[
-                    prediction_tower_name] = build_prediction_tower(
-                        att_name, this_level_features, i)
-              attributes[att_name][str(level)] = self._att_predictors[att_name](
-                  prediction_tower_output[prediction_tower_name])
+        self._apply_attribute_net(
+            attributes, level, i, this_level_features, classnet_x
+        )
 
     return scores, boxes, attributes
......
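The practical effect of the refactored _build_prediction_tower is on how many distinct weights the head owns: with sharing on, one conv stack and one predictor serve every level; with sharing off, each level gets its own. A rough check, assuming the TF Model Garden import and reusing the arguments from the tests below:

```python
import numpy as np

from official.vision.modeling.heads import dense_prediction_heads


def make_head(share_level_convs):
  # Arguments mirror the updated unit test below.
  return dense_prediction_heads.RetinaNetHead(
      min_level=3, max_level=4, num_classes=3, num_anchors_per_location=3,
      num_convs=2, num_filters=256, attribute_heads=None,
      use_separable_conv=False, activation='relu', use_sync_bn=False,
      norm_momentum=0.99, norm_epsilon=0.001, kernel_regularizer=None,
      bias_regularizer=None, share_level_convs=share_level_convs)


features = {
    '3': np.random.rand(2, 128, 128, 16).astype(np.float32),
    '4': np.random.rand(2, 64, 64, 16).astype(np.float32),
}
shared, per_level = make_head(True), make_head(False)
shared(features)     # one conv stack + one predictor reused across levels
per_level(features)  # per-level conv stacks and per-level predictors
assert len(per_level.trainable_variables) > len(shared.trainable_variables)
```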
@@ -22,6 +22,7 @@ from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
 
+from tensorflow.python.distribute import combinations
 from official.vision.modeling.heads import dense_prediction_heads
@@ -69,20 +70,54 @@ def get_attribute_heads(att_head_type):
 class RetinaNetHeadTest(parameterized.TestCase, tf.test.TestCase):
 
-  @parameterized.parameters(
-      (False, False, False, None, False),
-      (False, True, False, None, False),
-      (True, False, True, 'regression_head', False),
-      (True, True, True, 'classification_head', True),
-      (True, True, True, 'shared_prediction_tower_attribute_heads', False),
-  )
-  def test_forward(self, use_separable_conv, use_sync_bn, has_att_heads,
-                   att_head_type, share_classification_heads):
-    if has_att_heads:
-      attribute_heads = get_attribute_heads(att_head_type)
-    else:
-      attribute_heads = None
+  @combinations.generate(
+      combinations.combine(
+          use_separable_conv=[True, False],
+          use_sync_bn=[True, False],
+          share_level_convs=[True, False],
+      )
+  )
+  def test_forward_without_attribute_head(
+      self, use_separable_conv, use_sync_bn, share_level_convs
+  ):
+    retinanet_head = dense_prediction_heads.RetinaNetHead(
+        min_level=3,
+        max_level=4,
+        num_classes=3,
+        num_anchors_per_location=3,
+        num_convs=2,
+        num_filters=256,
+        attribute_heads=None,
+        use_separable_conv=use_separable_conv,
+        activation='relu',
+        use_sync_bn=use_sync_bn,
+        norm_momentum=0.99,
+        norm_epsilon=0.001,
+        kernel_regularizer=None,
+        bias_regularizer=None,
+        share_level_convs=share_level_convs,
+    )
+    features = {
+        '3': np.random.rand(2, 128, 128, 16),
+        '4': np.random.rand(2, 64, 64, 16),
+    }
+    scores, boxes, _ = retinanet_head(features)
+    self.assertAllEqual(scores['3'].numpy().shape, [2, 128, 128, 9])
+    self.assertAllEqual(scores['4'].numpy().shape, [2, 64, 64, 9])
+    self.assertAllEqual(boxes['3'].numpy().shape, [2, 128, 128, 12])
+    self.assertAllEqual(boxes['4'].numpy().shape, [2, 64, 64, 12])
+
+  @parameterized.parameters(
+      (False, 'regression_head', False),
+      (True, 'classification_head', True),
+      (True, 'shared_prediction_tower_attribute_heads', False),
+  )
+  def test_forward_with_attribute_head(
+      self,
+      use_sync_bn,
+      att_head_type,
+      share_classification_heads,
+  ):
     retinanet_head = dense_prediction_heads.RetinaNetHead(
         min_level=3,
         max_level=4,
@@ -90,9 +125,9 @@ class RetinaNetHeadTest(parameterized.TestCase, tf.test.TestCase):
         num_anchors_per_location=3,
         num_convs=2,
         num_filters=256,
-        attribute_heads=attribute_heads,
+        attribute_heads=get_attribute_heads(att_head_type),
         share_classification_heads=share_classification_heads,
-        use_separable_conv=use_separable_conv,
+        use_separable_conv=True,
         activation='relu',
         use_sync_bn=use_sync_bn,
         norm_momentum=0.99,
@@ -109,13 +144,12 @@ class RetinaNetHeadTest(parameterized.TestCase, tf.test.TestCase):
     self.assertAllEqual(scores['4'].numpy().shape, [2, 64, 64, 9])
     self.assertAllEqual(boxes['3'].numpy().shape, [2, 128, 128, 12])
     self.assertAllEqual(boxes['4'].numpy().shape, [2, 64, 64, 12])
-    if has_att_heads:
-      for att in attributes.values():
-        self.assertAllEqual(att['3'].numpy().shape, [2, 128, 128, 3])
-        self.assertAllEqual(att['4'].numpy().shape, [2, 64, 64, 3])
-      if att_head_type == 'regression_head':
-        self.assertLen(retinanet_head._att_convs['depth'], 1)
-        self.assertEqual(retinanet_head._att_convs['depth'][0].filters, 128)
+    for att in attributes.values():
+      self.assertAllEqual(att['3'].numpy().shape, [2, 128, 128, 3])
+      self.assertAllEqual(att['4'].numpy().shape, [2, 64, 64, 3])
+    if att_head_type == 'regression_head':
+      self.assertLen(retinanet_head._att_convs['depth'], 1)
+      self.assertEqual(retinanet_head._att_convs['depth'][0].filters, 128)
 
   @unittest.expectedFailure
   def test_forward_shared_prediction_tower_with_share_classification_heads(
......
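One reviewer-facing note on the test change: combinations.generate expands keyword lists into a cartesian product of named test cases, replacing the flat parameterized.parameters tuples. A rough standard-library sketch of what the decorator does:

```python
import itertools

# Rough illustration of combinations.combine in the test above: each keyword
# list is expanded into the cartesian product of named parameterizations.
grid = {
    'use_separable_conv': [True, False],
    'use_sync_bn': [True, False],
    'share_level_convs': [True, False],
}
cases = [
    dict(zip(grid, values)) for values in itertools.product(*grid.values())
]
assert len(cases) == 8  # 2 * 2 * 2 test cases
```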