提交 6e897a35 编写于 作者: A A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 341693911
上级 041f6976
......@@ -26,6 +26,8 @@ from official.modeling import hyperparams
class ResNet(hyperparams.Config):
"""ResNet config."""
model_id: int = 50
stem_type: str = 'v0'
se_ratio: float = 0.0
@dataclasses.dataclass
......
......@@ -35,11 +35,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
"""Test creation of ResNet models."""
network = backbones.ResNet(
model_id=model_id, norm_momentum=0.99, norm_epsilon=1e-5)
model_id=model_id, se_ratio=0.0, norm_momentum=0.99, norm_epsilon=1e-5)
backbone_config = backbones_cfg.Backbone(
type='resnet',
resnet=backbones_cfg.ResNet(model_id=model_id))
resnet=backbones_cfg.ResNet(model_id=model_id, se_ratio=0.0))
norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet(
......
......@@ -78,6 +78,8 @@ class ResNet(tf.keras.Model):
def __init__(self,
model_id,
input_specs=layers.InputSpec(shape=[None, None, None, 3]),
stem_type='v0',
se_ratio=None,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
......@@ -91,6 +93,9 @@ class ResNet(tf.keras.Model):
Args:
model_id: `int` depth of ResNet backbone model.
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
stem_type: `str` stem type of ResNet. Defaults to `v0`. If set to `v1`,
use ResNet-C type stem (https://arxiv.org/abs/1812.01187).
se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: `float` normalization momentum for the moving average.
......@@ -105,6 +110,8 @@ class ResNet(tf.keras.Model):
"""
self._model_id = model_id
self._input_specs = input_specs
self._stem_type = stem_type
self._se_ratio = se_ratio
self._use_sync_bn = use_sync_bn
self._activation = activation
self._norm_momentum = norm_momentum
......@@ -125,8 +132,13 @@ class ResNet(tf.keras.Model):
# Build ResNet.
inputs = tf.keras.Input(shape=input_specs.shape[1:])
if stem_type == 'v0':
x = layers.Conv2D(
filters=64, kernel_size=7, strides=2, use_bias=False, padding='same',
filters=64,
kernel_size=7,
strides=2,
use_bias=False,
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
......@@ -135,6 +147,52 @@ class ResNet(tf.keras.Model):
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
elif stem_type == 'v1':
x = layers.Conv2D(
filters=32,
kernel_size=3,
strides=2,
use_bias=False,
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
x = layers.Conv2D(
filters=32,
kernel_size=3,
strides=1,
use_bias=False,
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
x = layers.Conv2D(
filters=64,
kernel_size=3,
strides=1,
use_bias=False,
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
else:
raise ValueError('Stem type {} not supported.'.format(stem_type))
x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)
# TODO(xianzhi): keep a list of blocks to make blocks accessible.
......@@ -184,6 +242,7 @@ class ResNet(tf.keras.Model):
filters=filters,
strides=strides,
use_projection=True,
se_ratio=self._se_ratio,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
......@@ -198,6 +257,7 @@ class ResNet(tf.keras.Model):
filters=filters,
strides=1,
use_projection=False,
se_ratio=self._se_ratio,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
......@@ -212,7 +272,9 @@ class ResNet(tf.keras.Model):
def get_config(self):
config_dict = {
'model_id': self._model_id,
'stem_type': self._stem_type,
'activation': self._activation,
'se_ratio': self._se_ratio,
'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon,
......@@ -247,6 +309,8 @@ def build_resnet(
return ResNet(
model_id=backbone_cfg.model_id,
input_specs=input_specs,
stem_type=backbone_cfg.stem_type,
se_ratio=backbone_cfg.se_ratio,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
......
......@@ -83,6 +83,21 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
network = resnet.ResNet(model_id=50, use_sync_bn=use_sync_bn)
_ = network(inputs)
# Smoke-test ResNet construction across stem variants and SE ratios.
# Each tuple is (input_size, model_id, endpoint_filter_scale, stem_type,
# se_ratio).
@parameterized.parameters(
(128, 34, 1, 'v0', None),
(128, 34, 1, 'v1', 0.25),
(128, 50, 4, 'v0', None),
(128, 50, 4, 'v1', 0.25),
)
def test_resnet_addons(self, input_size, model_id, endpoint_filter_scale,
stem_type, se_ratio):
"""Test creation of ResNet family models with stem and SE add-ons.

Builds a `resnet.ResNet` with the given `stem_type` and `se_ratio` and calls
it on a symbolic Keras input. The output is discarded, so this test only
verifies that construction and the call complete without raising.

Args:
input_size: `int` height and width of the square, 3-channel input.
model_id: `int` forwarded to `resnet.ResNet` as `model_id`.
endpoint_filter_scale: not referenced in this test body.
stem_type: `str` forwarded to `resnet.ResNet` as `stem_type`.
se_ratio: `float` or None forwarded to `resnet.ResNet` as `se_ratio`.
"""
# The input below is laid out as (height, width, channels).
tf.keras.backend.set_image_data_format('channels_last')
network = resnet.ResNet(
model_id=model_id, stem_type=stem_type, se_ratio=se_ratio)
inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
# Result intentionally discarded; no output shapes are asserted here.
_ = network(inputs)
@parameterized.parameters(1, 3, 4)
def test_input_specs(self, input_dim):
"""Test different input feature dimensions."""
......@@ -98,6 +113,8 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
# Create a network object that sets all of its config options.
kwargs = dict(
model_id=50,
stem_type='v0',
se_ratio=None,
use_sync_bn=False,
activation='relu',
norm_momentum=0.99,
......
......@@ -62,6 +62,7 @@ class ResidualBlock(tf.keras.layers.Layer):
filters,
strides,
use_projection=False,
se_ratio=None,
stochastic_depth_drop_rate=None,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
......@@ -82,6 +83,7 @@ class ResidualBlock(tf.keras.layers.Layer):
shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of
filters and the resolution.
se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer. The
layer is only added when 0 < se_ratio <= 1.
stochastic_depth_drop_rate: `float` or None. if not None, drop rate for
the stochastic depth layer.
kernel_initializer: kernel_initializer for convolutional layers.
......@@ -101,6 +103,7 @@ class ResidualBlock(tf.keras.layers.Layer):
self._filters = filters
self._strides = strides
self._use_projection = use_projection
self._se_ratio = se_ratio
self._use_sync_bn = use_sync_bn
self._activation = activation
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
......@@ -163,6 +166,17 @@ class ResidualBlock(tf.keras.layers.Layer):
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1:
self._squeeze_excitation = nn_layers.SqueezeExcitation(
in_filters=self._filters,
out_filters=self._filters,
se_ratio=self._se_ratio,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
else:
self._squeeze_excitation = None
if self._stochastic_depth_drop_rate:
self._stochastic_depth = nn_layers.StochasticDepth(
self._stochastic_depth_drop_rate)
......@@ -176,6 +190,7 @@ class ResidualBlock(tf.keras.layers.Layer):
'filters': self._filters,
'strides': self._strides,
'use_projection': self._use_projection,
'se_ratio': self._se_ratio,
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
......@@ -201,6 +216,9 @@ class ResidualBlock(tf.keras.layers.Layer):
x = self._conv2(x)
x = self._norm2(x)
if self._squeeze_excitation:
x = self._squeeze_excitation(x)
if self._stochastic_depth:
x = self._stochastic_depth(x, training=training)
......@@ -216,6 +234,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
strides,
dilation_rate=1,
use_projection=False,
se_ratio=None,
stochastic_depth_drop_rate=None,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
......@@ -237,6 +256,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of
filters and the resolution.
se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer. The
layer is only added when 0 < se_ratio <= 1.
stochastic_depth_drop_rate: `float` or None. if not None, drop rate for
the stochastic depth layer.
kernel_initializer: kernel_initializer for convolutional layers.
......@@ -257,6 +277,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._strides = strides
self._dilation_rate = dilation_rate
self._use_projection = use_projection
self._se_ratio = se_ratio
self._use_sync_bn = use_sync_bn
self._activation = activation
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
......@@ -331,6 +352,17 @@ class BottleneckBlock(tf.keras.layers.Layer):
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1:
self._squeeze_excitation = nn_layers.SqueezeExcitation(
in_filters=self._filters * 4,
out_filters=self._filters * 4,
se_ratio=self._se_ratio,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
else:
self._squeeze_excitation = None
if self._stochastic_depth_drop_rate:
self._stochastic_depth = nn_layers.StochasticDepth(
self._stochastic_depth_drop_rate)
......@@ -345,6 +377,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
'strides': self._strides,
'dilation_rate': self._dilation_rate,
'use_projection': self._use_projection,
'se_ratio': self._se_ratio,
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
......@@ -374,6 +407,9 @@ class BottleneckBlock(tf.keras.layers.Layer):
x = self._conv3(x)
x = self._norm3(x)
if self._squeeze_excitation:
x = self._squeeze_excitation(x)
if self._stochastic_depth:
x = self._stochastic_depth(x, training=training)
......
......@@ -39,11 +39,11 @@ def distribution_strategy_combinations() -> Iterable[Tuple[Any, ...]]:
class NNBlocksTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(nn_blocks.ResidualBlock, 1, False, 0.0),
(nn_blocks.ResidualBlock, 2, True, 0.2),
(nn_blocks.ResidualBlock, 1, False, 0.0, None),
(nn_blocks.ResidualBlock, 2, True, 0.2, 0.25),
)
def test_residual_block_creation(
self, block_fn, strides, use_projection, stochastic_depth_drop_rate):
def test_residual_block_creation(self, block_fn, strides, use_projection,
stochastic_depth_drop_rate, se_ratio):
input_size = 128
filter_size = 256
inputs = tf.keras.Input(
......@@ -52,6 +52,7 @@ class NNBlocksTest(parameterized.TestCase, tf.test.TestCase):
filter_size,
strides,
use_projection=use_projection,
se_ratio=se_ratio,
stochastic_depth_drop_rate=stochastic_depth_drop_rate,
)
......@@ -62,11 +63,11 @@ class NNBlocksTest(parameterized.TestCase, tf.test.TestCase):
features.shape.as_list())
@parameterized.parameters(
(nn_blocks.BottleneckBlock, 1, False, 0.0),
(nn_blocks.BottleneckBlock, 2, True, 0.2),
(nn_blocks.BottleneckBlock, 1, False, 0.0, None),
(nn_blocks.BottleneckBlock, 2, True, 0.2, 0.25),
)
def test_bottleneck_block_creation(
self, block_fn, strides, use_projection, stochastic_depth_drop_rate):
def test_bottleneck_block_creation(self, block_fn, strides, use_projection,
stochastic_depth_drop_rate, se_ratio):
input_size = 128
filter_size = 256
inputs = tf.keras.Input(
......@@ -75,8 +76,8 @@ class NNBlocksTest(parameterized.TestCase, tf.test.TestCase):
filter_size,
strides,
use_projection=use_projection,
stochastic_depth_drop_rate=stochastic_depth_drop_rate
)
se_ratio=se_ratio,
stochastic_depth_drop_rate=stochastic_depth_drop_rate)
features = block(inputs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册