Commit 35ed75e3 authored by Fan Yang, committed by A. Unique TensorFlower

Add option to disable sharing convs across levels in RetinaNet head.

Also significantly refactor the code: move the component builders into individual functions and reduce duplicated code to improve readability.

PiperOrigin-RevId: 533598893
Parent 8d92444c
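For context, a minimal usage sketch of the new option. The constructor arguments and shapes mirror the updated tests below; the remaining arguments are assumed to keep their defaults.

import numpy as np
from official.vision.modeling.heads import dense_prediction_heads

# Build a RetinaNet head with per-level (unshared) convs.
head = dense_prediction_heads.RetinaNetHead(
    min_level=3,
    max_level=4,
    num_classes=3,
    num_anchors_per_location=3,
    num_convs=2,
    num_filters=256,
    attribute_heads=None,
    share_level_convs=False,  # New option; True keeps the previous behavior.
)
features = {
    '3': np.random.rand(2, 128, 128, 16),
    '4': np.random.rand(2, 64, 64, 16),
}
scores, boxes, _ = head(features)
# scores['3'] has shape [2, 128, 128, 9]: num_classes * num_anchors_per_location.
# boxes['3'] has shape [2, 128, 128, 12]: num_params_per_anchor * num_anchors_per_location.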
......@@ -125,6 +125,7 @@ class RetinaNetHeadQuantized(tf.keras.layers.Layer):
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
num_params_per_anchor: int = 4,
share_classification_heads: bool = False,
share_level_convs: bool = True,
**kwargs):
"""Initializes a RetinaNet quantized head.
......@@ -161,9 +162,14 @@ class RetinaNetHeadQuantized(tf.keras.layers.Layer):
share_classification_heads: A `bool` that indicates whether to share
weights between the main and attribute classification heads. Not used in
the QAT model.
share_level_convs: An optional bool that enables sharing convs
across levels for classnet, boxnet, classifier, and box regressor.
If True, convs are shared across all levels. Not used in the QAT
model.
**kwargs: Additional keyword arguments to be passed.
"""
del share_classification_heads
del share_level_convs
super().__init__(**kwargs)
self._config_dict = {
......
......@@ -126,6 +126,7 @@ class RetinaNetHead(hyperparams.Config):
use_separable_conv: bool = False
attribute_heads: List[AttributeHead] = dataclasses.field(default_factory=list)
share_classification_heads: bool = False
share_level_convs: Optional[bool] = True
@dataclasses.dataclass
......
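The new config field can be toggled through the usual experiment-config override path. A sketch, assuming a registered RetinaNet experiment name such as `retinanet_resnetfpn_coco` and the standard `task.model.head` config path:

from official.core import exp_factory

config = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
# Disable conv sharing across levels in the RetinaNet head.
config.task.model.head.share_level_convs = False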
......@@ -301,7 +301,9 @@ def build_retinanet(
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
kernel_regularizer=l2_regularizer,
share_level_convs=head_config.share_level_convs,
)
# Builds decoder and head so that their trainable weights are initialized
if decoder:
......
......@@ -46,7 +46,9 @@ class RetinaNetHead(tf.keras.layers.Layer):
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
num_params_per_anchor: int = 4,
**kwargs):
share_level_convs: bool = True,
**kwargs,
):
"""Initializes a RetinaNet head.
Args:
......@@ -80,9 +82,12 @@ class RetinaNetHead(tf.keras.layers.Layer):
box. For example, `num_params_per_anchor` would be 4 for axis-aligned
anchor boxes specified by their y-centers, x-centers, heights, and
widths.
share_level_convs: An optional bool that enables sharing convs
across levels for classnet, boxnet, classifier, and box regressor.
If True, convs are shared across all levels.
**kwargs: Additional keyword arguments to be passed.
"""
super(RetinaNetHead, self).__init__(**kwargs)
super().__init__(**kwargs)
self._config_dict = {
'min_level': min_level,
'max_level': max_level,
......@@ -100,6 +105,7 @@ class RetinaNetHead(tf.keras.layers.Layer):
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
'num_params_per_anchor': num_params_per_anchor,
'share_level_convs': share_level_convs,
}
if tf.keras.backend.image_data_format() == 'channels_last':
......@@ -159,9 +165,17 @@ class RetinaNetHead(tf.keras.layers.Layer):
'kernel_regularizer': self._config_dict['kernel_regularizer'],
})
if not self._config_dict['attribute_heads']:
return
if self._config_dict['attribute_heads']:
self._init_attribute_kwargs()
def _conv_kwargs_new_kernel_init(self, conv_kwargs):
if 'kernel_initializer' in conv_kwargs:
conv_kwargs['kernel_initializer'] = tf_utils.clone_initializer(
conv_kwargs['kernel_initializer']
)
return conv_kwargs
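# NOTE: cloning the kernel initializer gives each conv built from the shared
# `conv_kwargs` template its own initializer instance, so layers do not share
# one (possibly seeded) initializer object and end up with identical weights.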
def _init_attribute_kwargs(self):
self._attribute_kwargs = []
for att_config in self._config_dict['attribute_heads']:
att_type = att_config['type']
......@@ -210,12 +224,146 @@ class RetinaNetHead(tf.keras.layers.Layer):
})
self._attribute_kwargs.append(att_predictor_kwargs)
def _conv_kwargs_new_kernel_init(self, conv_kwargs):
if 'kernel_initializer' in conv_kwargs:
conv_kwargs['kernel_initializer'] = tf_utils.clone_initializer(
conv_kwargs['kernel_initializer']
def _apply_prediction_tower(self, features, convs, norms) -> tf.Tensor:
x = features
for conv, norm in zip(convs, norms):
x = conv(x)
x = norm(x)
x = self._activation(x)
return x
def _apply_attribute_net(
self, attributes, level, level_idx, this_level_features, classnet_x
):
prediction_tower_output = {}
for att_config in self._config_dict['attribute_heads']:
att_name = att_config['name']
att_type = att_config['type']
if (
self._config_dict['share_classification_heads']
and att_type == 'classification'
):
attributes[att_name][str(level)] = self._att_predictors[att_name](
classnet_x
)
else:
def _apply_attribute_prediction_tower(
attribute_name, features, feature_level
):
return self._apply_prediction_tower(
features,
self._att_convs[attribute_name],
self._att_norms[attribute_name][feature_level],
)
prediction_tower_name = att_config['prediction_tower_name']
if not prediction_tower_name:
attributes[att_name][str(level)] = self._att_predictors[att_name](
_apply_attribute_prediction_tower(
att_name, this_level_features, level_idx
)
)
else:
if prediction_tower_name not in prediction_tower_output:
prediction_tower_output[prediction_tower_name] = (
_apply_attribute_prediction_tower(
att_name, this_level_features, level_idx
)
)
attributes[att_name][str(level)] = self._att_predictors[att_name](
prediction_tower_output[prediction_tower_name]
)
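# NOTE: attribute heads that set `prediction_tower_name` share one prediction
# tower output per level, cached in `prediction_tower_output` above, while
# classification-type heads reuse `classnet_x` when
# `share_classification_heads` is enabled.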
def _build_prediction_tower(
self, net_name, predictor_name, conv_op, bn_op, predictor_kwargs
):
"""Builds the prediction tower. Convs across levels can be shared or not."""
convs = []
norms = []
for level in range(
self._config_dict['min_level'], self._config_dict['max_level'] + 1
):
if not self._config_dict['share_level_convs']:
this_level_convs = []
this_level_norms = []
for i in range(self._config_dict['num_convs']):
conv_kwargs = self._conv_kwargs_new_kernel_init(self._conv_kwargs)
if not self._config_dict['share_level_convs']:
# Do not share convs.
this_level_convs.append(
conv_op(name=f'{net_name}-conv_{level}_{i}', **conv_kwargs)
)
elif level == self._config_dict['min_level']:
convs.append(conv_op(name=f'{net_name}-conv_{i}', **conv_kwargs))
this_level_norms.append(
bn_op(name=f'{net_name}-conv-norm_{level}_{i}', **self._bn_kwargs)
)
norms.append(this_level_norms)
if not self._config_dict['share_level_convs']:
convs.append(this_level_convs)
# Create predictors after additional convs.
if self._config_dict['share_level_convs']:
predictors = conv_op(name=predictor_name, **predictor_kwargs)
else:
predictors = []
for level in range(
self._config_dict['min_level'], self._config_dict['max_level'] + 1
):
predictors.append(
conv_op(name=f'{predictor_name}-{level}', **predictor_kwargs)
)
return convs, norms, predictors
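# NOTE: with share_level_convs=True, `convs` is a flat list of `num_convs`
# layers reused at every level and `predictors` is a single conv; with
# share_level_convs=False, both become per-level lists. `norms` is per-level
# in either case, since normalization statistics are not shared across levels.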
def _build_attribute_net(self, conv_op, bn_op):
self._att_predictors = {}
self._att_convs = {}
self._att_norms = {}
for att_config, att_predictor_kwargs in zip(
self._config_dict['attribute_heads'], self._attribute_kwargs
):
att_name = att_config['name']
att_num_convs = (
att_config.get('num_convs') or self._config_dict['num_convs']
)
att_num_filters = (
att_config.get('num_filters') or self._config_dict['num_filters']
)
if att_num_convs < 0:
raise ValueError(f'Invalid `num_convs` {att_num_convs} for {att_name}.')
if att_num_filters < 0:
raise ValueError(
f'Invalid `num_filters` {att_num_filters} for {att_name}.'
)
att_conv_kwargs = self._conv_kwargs.copy()
att_conv_kwargs['filters'] = att_num_filters
att_convs_i = []
att_norms_i = []
# Build conv and norm layers.
for level in range(
self._config_dict['min_level'], self._config_dict['max_level'] + 1
):
this_level_att_norms = []
for i in range(att_num_convs):
if level == self._config_dict['min_level']:
att_conv_name = '{}-conv_{}'.format(att_name, i)
att_convs_i.append(conv_op(name=att_conv_name, **att_conv_kwargs))
att_norm_name = '{}-conv-norm_{}_{}'.format(att_name, level, i)
this_level_att_norms.append(
bn_op(name=att_norm_name, **self._bn_kwargs)
)
att_norms_i.append(this_level_att_norms)
self._att_convs[att_name] = att_convs_i
self._att_norms[att_name] = att_norms_i
# Build the final prediction layer.
self._att_predictors[att_name] = conv_op(
name='{}_attributes'.format(att_name), **att_predictor_kwargs
)
return conv_kwargs
def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
"""Creates the variables of the head."""
......@@ -231,94 +379,24 @@ class RetinaNetHead(tf.keras.layers.Layer):
)
# Class net.
self._cls_convs = []
self._cls_norms = []
for level in range(
self._config_dict['min_level'], self._config_dict['max_level'] + 1):
this_level_cls_norms = []
for i in range(self._config_dict['num_convs']):
if level == self._config_dict['min_level']:
cls_conv_name = 'classnet-conv_{}'.format(i)
conv_kwargs = self._conv_kwargs_new_kernel_init(self._conv_kwargs)
self._cls_convs.append(conv_op(name=cls_conv_name, **conv_kwargs))
cls_norm_name = 'classnet-conv-norm_{}_{}'.format(level, i)
this_level_cls_norms.append(
bn_op(name=cls_norm_name, **self._bn_kwargs)
self._cls_convs, self._cls_norms, self._classifier = (
self._build_prediction_tower(
'classnet', 'scores', conv_op, bn_op, self._classifier_kwargs
)
self._cls_norms.append(this_level_cls_norms)
self._classifier = conv_op(name='scores', **self._classifier_kwargs)
)
# Box net.
self._box_convs = []
self._box_norms = []
for level in range(
self._config_dict['min_level'], self._config_dict['max_level'] + 1):
this_level_box_norms = []
for i in range(self._config_dict['num_convs']):
if level == self._config_dict['min_level']:
box_conv_name = 'boxnet-conv_{}'.format(i)
conv_kwargs = self._conv_kwargs_new_kernel_init(self._conv_kwargs)
self._box_convs.append(conv_op(name=box_conv_name, **conv_kwargs))
box_norm_name = 'boxnet-conv-norm_{}_{}'.format(level, i)
this_level_box_norms.append(
bn_op(name=box_norm_name, **self._bn_kwargs)
self._box_convs, self._box_norms, self._box_regressor = (
self._build_prediction_tower(
'boxnet', 'boxes', conv_op, bn_op, self._box_regressor_kwargs
)
self._box_norms.append(this_level_box_norms)
self._box_regressor = conv_op(name='boxes', **self._box_regressor_kwargs)
)
# Attribute learning nets.
if self._config_dict['attribute_heads']:
self._att_predictors = {}
self._att_convs = {}
self._att_norms = {}
self._build_attribute_net(conv_op, bn_op)
for att_config, att_predictor_kwargs in zip(
self._config_dict['attribute_heads'], self._attribute_kwargs
):
att_name = att_config['name']
att_num_convs = (
att_config.get('num_convs') or self._config_dict['num_convs']
)
att_num_filters = (
att_config.get('num_filters') or self._config_dict['num_filters']
)
if att_num_convs < 0:
raise ValueError(
f'Invalid `num_convs` {att_num_convs} for {att_name}.'
)
if att_num_filters < 0:
raise ValueError(
f'Invalid `num_filters` {att_num_filters} for {att_name}.'
)
att_conv_kwargs = self._conv_kwargs.copy()
att_conv_kwargs['filters'] = att_num_filters
att_convs_i = []
att_norms_i = []
# Build conv and norm layers.
for level in range(self._config_dict['min_level'],
self._config_dict['max_level'] + 1):
this_level_att_norms = []
for i in range(att_num_convs):
if level == self._config_dict['min_level']:
att_conv_name = '{}-conv_{}'.format(att_name, i)
conv_kwargs = self._conv_kwargs_new_kernel_init(self._conv_kwargs)
att_convs_i.append(conv_op(name=att_conv_name, **att_conv_kwargs))
att_norm_name = '{}-conv-norm_{}_{}'.format(att_name, level, i)
this_level_att_norms.append(
bn_op(name=att_norm_name, **self._bn_kwargs)
)
att_norms_i.append(this_level_att_norms)
self._att_convs[att_name] = att_convs_i
self._att_norms[att_name] = att_norms_i
# Build the final prediction layer.
self._att_predictors[att_name] = conv_op(
name='{}_attributes'.format(att_name), **att_predictor_kwargs)
super(RetinaNetHead, self).build(input_shape)
super().build(input_shape)
def call(self, features: Mapping[str, tf.Tensor]):
"""Forward pass of the RetinaNet head.
......@@ -366,57 +444,35 @@ class RetinaNetHead(tf.keras.layers.Layer):
self._config_dict['max_level'] + 1)):
this_level_features = features[str(level)]
# class net.
x = this_level_features
for conv, norm in zip(self._cls_convs, self._cls_norms[i]):
x = conv(x)
x = norm(x)
x = self._activation(x)
if self._config_dict['share_level_convs']:
cls_convs = self._cls_convs
box_convs = self._box_convs
classifier = self._classifier
box_regressor = self._box_regressor
else:
cls_convs = self._cls_convs[i]
box_convs = self._box_convs[i]
classifier = self._classifier[i]
box_regressor = self._box_regressor[i]
# Apply class net.
x = self._apply_prediction_tower(
this_level_features, cls_convs, self._cls_norms[i]
)
scores[str(level)] = classifier(x)
classnet_x = x
scores[str(level)] = self._classifier(classnet_x)
# box net.
x = this_level_features
for conv, norm in zip(self._box_convs, self._box_norms[i]):
x = conv(x)
x = norm(x)
x = self._activation(x)
boxes[str(level)] = self._box_regressor(x)
# Apply box net.
x = self._apply_prediction_tower(
this_level_features, box_convs, self._box_norms[i]
)
boxes[str(level)] = box_regressor(x)
# attribute nets.
# Apply attribute nets.
if self._config_dict['attribute_heads']:
prediction_tower_output = {}
for att_config in self._config_dict['attribute_heads']:
att_name = att_config['name']
att_type = att_config['type']
if self._config_dict[
'share_classification_heads'] and att_type == 'classification':
attributes[att_name][str(level)] = self._att_predictors[att_name](
classnet_x)
else:
def build_prediction_tower(atttribute_name, features,
feature_level):
x = features
for conv, norm in zip(
self._att_convs[atttribute_name],
self._att_norms[atttribute_name][feature_level]):
x = conv(x)
x = norm(x)
x = self._activation(x)
return x
prediction_tower_name = att_config['prediction_tower_name']
if not prediction_tower_name:
attributes[att_name][str(level)] = self._att_predictors[att_name](
build_prediction_tower(att_name, this_level_features, i))
else:
if prediction_tower_name not in prediction_tower_output:
prediction_tower_output[
prediction_tower_name] = build_prediction_tower(
att_name, this_level_features, i)
attributes[att_name][str(level)] = self._att_predictors[att_name](
prediction_tower_output[prediction_tower_name])
self._apply_attribute_net(
attributes, level, i, this_level_features, classnet_x
)
return scores, boxes, attributes
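# NOTE: the lookups at the top of the level loop index `self._cls_convs`,
# `self._box_convs`, `self._classifier`, and `self._box_regressor` by level
# only when share_level_convs is False; the norm lists are always indexed by
# the level counter `i`.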
......
......@@ -22,6 +22,7 @@ from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from official.vision.modeling.heads import dense_prediction_heads
......@@ -69,20 +70,54 @@ def get_attribute_heads(att_head_type):
class RetinaNetHeadTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(False, False, False, None, False),
(False, True, False, None, False),
(True, False, True, 'regression_head', False),
(True, True, True, 'classification_head', True),
(True, True, True, 'shared_prediction_tower_attribute_heads', False),
@combinations.generate(
combinations.combine(
use_separable_conv=[True, False],
use_sync_bn=[True, False],
share_level_convs=[True, False],
)
)
def test_forward(self, use_separable_conv, use_sync_bn, has_att_heads,
att_head_type, share_classification_heads):
if has_att_heads:
attribute_heads = get_attribute_heads(att_head_type)
else:
attribute_heads = None
def test_forward_without_attribute_head(
self, use_separable_conv, use_sync_bn, share_level_convs
):
retinanet_head = dense_prediction_heads.RetinaNetHead(
min_level=3,
max_level=4,
num_classes=3,
num_anchors_per_location=3,
num_convs=2,
num_filters=256,
attribute_heads=None,
use_separable_conv=use_separable_conv,
activation='relu',
use_sync_bn=use_sync_bn,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_regularizer=None,
bias_regularizer=None,
share_level_convs=share_level_convs,
)
features = {
'3': np.random.rand(2, 128, 128, 16),
'4': np.random.rand(2, 64, 64, 16),
}
scores, boxes, _ = retinanet_head(features)
self.assertAllEqual(scores['3'].numpy().shape, [2, 128, 128, 9])
self.assertAllEqual(scores['4'].numpy().shape, [2, 64, 64, 9])
self.assertAllEqual(boxes['3'].numpy().shape, [2, 128, 128, 12])
self.assertAllEqual(boxes['4'].numpy().shape, [2, 64, 64, 12])
@parameterized.parameters(
(False, 'regression_head', False),
(True, 'classification_head', True),
(True, 'shared_prediction_tower_attribute_heads', False),
)
def test_forward_with_attribute_head(
self,
use_sync_bn,
att_head_type,
share_classification_heads,
):
retinanet_head = dense_prediction_heads.RetinaNetHead(
min_level=3,
max_level=4,
......@@ -90,9 +125,9 @@ class RetinaNetHeadTest(parameterized.TestCase, tf.test.TestCase):
num_anchors_per_location=3,
num_convs=2,
num_filters=256,
attribute_heads=attribute_heads,
attribute_heads=get_attribute_heads(att_head_type),
share_classification_heads=share_classification_heads,
use_separable_conv=use_separable_conv,
use_separable_conv=True,
activation='relu',
use_sync_bn=use_sync_bn,
norm_momentum=0.99,
......@@ -109,13 +144,12 @@ class RetinaNetHeadTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual(scores['4'].numpy().shape, [2, 64, 64, 9])
self.assertAllEqual(boxes['3'].numpy().shape, [2, 128, 128, 12])
self.assertAllEqual(boxes['4'].numpy().shape, [2, 64, 64, 12])
if has_att_heads:
for att in attributes.values():
self.assertAllEqual(att['3'].numpy().shape, [2, 128, 128, 3])
self.assertAllEqual(att['4'].numpy().shape, [2, 64, 64, 3])
if att_head_type == 'regression_head':
self.assertLen(retinanet_head._att_convs['depth'], 1)
self.assertEqual(retinanet_head._att_convs['depth'][0].filters, 128)
for att in attributes.values():
self.assertAllEqual(att['3'].numpy().shape, [2, 128, 128, 3])
self.assertAllEqual(att['4'].numpy().shape, [2, 64, 64, 3])
if att_head_type == 'regression_head':
self.assertLen(retinanet_head._att_convs['depth'], 1)
self.assertEqual(retinanet_head._att_convs['depth'][0].filters, 128)
@unittest.expectedFailure
def test_forward_shared_prediction_tower_with_share_classification_heads(
......