Commit 441e5ae0 authored by Liangzhe Yuan, committed by A. Unique TensorFlower

#movinet Support 'none' squeeze and excitation layers in Movinet.

PiperOrigin-RevId: 424743840
Parent a033df77
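For reference, a minimal usage sketch based on the new streaming test in this diff (the `official.projects.movinet` import path is an assumption and may differ by release): it builds the 'a0' backbone with squeeze-and-excitation disabled and runs it frame by frame with external states.

import tensorflow as tf
from official.projects.movinet.modeling import movinet  # assumed import path

# Build a causal MoViNet-A0 backbone with SE layers disabled via the new option.
backbone = movinet.Movinet(
    model_id='a0',
    causal=True,
    use_external_states=True,
    se_type='none',
)

# Streaming inference: feed one frame at a time, carrying the states forward.
inputs = tf.ones([1, 5, 128, 128, 3])
states = backbone.init_states(tf.shape(inputs))
for frame in tf.split(inputs, inputs.shape[1], axis=1):
  endpoints, states = backbone({**states, 'image': frame})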
@@ -338,7 +338,7 @@ class Movinet(tf.keras.Model):
         3x3 followed by 5x1 conv). '3d_2plus1d' uses (2+1)D convolution with
         Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 followed
         by 5x1x1 conv).
-      se_type: '3d', '2d', or '2plus3d'. '3d' uses the default 3D
+      se_type: '3d', '2d', '2plus3d' or 'none'. '3d' uses the default 3D
         spatiotemporal global average pooling for squeeze excitation. '2d'
         uses 2D spatial global average pooling on each frame. '2plus3d'
         concatenates both 3D and 2D global average pooling.
@@ -369,7 +369,7 @@ class Movinet(tf.keras.Model):
     if conv_type not in ('3d', '2plus1d', '3d_2plus1d'):
       raise ValueError('Unknown conv type: {}'.format(conv_type))
-    if se_type not in ('3d', '2d', '2plus3d'):
+    if se_type not in ('3d', '2d', '2plus3d', 'none'):
       raise ValueError('Unknown squeeze excitation type: {}'.format(se_type))
 
     self._model_id = model_id
@@ -602,10 +602,11 @@ class Movinet(tf.keras.Model):
               expand_filters,
           )
 
-          states[f'{prefix}_pool_buffer'] = (
-              input_shape[0], 1, 1, 1, expand_filters,
-          )
-          states[f'{prefix}_pool_frame_count'] = (1,)
+          if '3d' in self._se_type:
+            states[f'{prefix}_pool_buffer'] = (
+                input_shape[0], 1, 1, 1, expand_filters,
+            )
+            states[f'{prefix}_pool_frame_count'] = (1,)
 
           if use_positional_encoding:
             name = f'{prefix}_pos_enc_frame_count'
@@ -885,7 +885,8 @@ class MobileBottleneck(tf.keras.layers.Layer):
     x = self._expansion_layer(inputs)
     x, states = self._feature_layer(x, states=states)
-    x, states = self._attention_layer(x, states=states)
+    if self._attention_layer is not None:
+      x, states = self._attention_layer(x, states=states)
     x = self._projection_layer(x)
 
     # Add identity so that the ops are ordered as written. This is useful for,
@@ -1136,18 +1137,20 @@ class MovinetBlock(tf.keras.layers.Layer):
         batch_norm_momentum=self._batch_norm_momentum,
         batch_norm_epsilon=self._batch_norm_epsilon,
         name='projection')
-    self._attention = StreamSqueezeExcitation(
-        se_hidden_filters,
-        se_type=se_type,
-        activation=activation,
-        gating_activation=gating_activation,
-        causal=self._causal,
-        conv_type=conv_type,
-        use_positional_encoding=use_positional_encoding,
-        kernel_initializer=kernel_initializer,
-        kernel_regularizer=kernel_regularizer,
-        state_prefix=state_prefix,
-        name='se')
+    self._attention = None
+    if se_type != 'none':
+      self._attention = StreamSqueezeExcitation(
+          se_hidden_filters,
+          se_type=se_type,
+          activation=activation,
+          gating_activation=gating_activation,
+          causal=self._causal,
+          conv_type=conv_type,
+          use_positional_encoding=use_positional_encoding,
+          kernel_initializer=kernel_initializer,
+          kernel_regularizer=kernel_regularizer,
+          state_prefix=state_prefix,
+          name='se')
 
   def get_config(self):
     """Returns a dictionary containing the config used for initialization."""
@@ -378,6 +378,35 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase):
       self.assertEqual(predicted.shape, expected.shape)
       self.assertAllClose(predicted, expected)
 
+  def test_stream_movinet_block_none_se(self):
+    block = movinet_layers.MovinetBlock(
+        out_filters=3,
+        expand_filters=6,
+        kernel_size=(3, 3, 3),
+        strides=(1, 2, 2),
+        causal=True,
+        se_type='none',
+        state_prefix='test',
+    )
+
+    inputs = tf.range(4, dtype=tf.float32) + 1.
+    inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
+    inputs = tf.tile(inputs, [1, 1, 2, 1, 3])
+
+    expected, expected_states = block(inputs)
+
+    for num_splits in [1, 2, 4]:
+      frames = tf.split(inputs, inputs.shape[1] // num_splits, axis=1)
+      states = {}
+      predicted = []
+      for frame in frames:
+        x, states = block(frame, states=states)
+        predicted.append(x)
+      predicted = tf.concat(predicted, axis=1)
+
+      self.assertEqual(predicted.shape, expected.shape)
+      self.assertAllClose(predicted, expected)
+      self.assertAllEqual(list(expected_states.keys()), ['test_stream_buffer'])
+
   def test_stream_classifier_head(self):
     head = movinet_layers.Head(project_filters=5)
     classifier_head = movinet_layers.ClassifierHead(
@@ -99,6 +99,49 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
     self.assertEqual(predicted.shape, expected.shape)
     self.assertAllClose(predicted, expected, 1e-5, 1e-5)
 
+  def test_movinet_stream_nse(self):
+    """Tests that the backbone can be run in streaming mode without SE layers."""
+    tf.keras.backend.set_image_data_format('channels_last')
+
+    backbone = movinet.Movinet(
+        model_id='a0',
+        causal=True,
+        use_external_states=True,
+        se_type='none',
+    )
+    inputs = tf.ones([1, 5, 128, 128, 3])
+
+    init_states = backbone.init_states(tf.shape(inputs))
+    expected_endpoints, _ = backbone({**init_states, 'image': inputs})
+
+    frames = tf.split(inputs, inputs.shape[1], axis=1)
+
+    states = init_states
+    for frame in frames:
+      output, states = backbone({**states, 'image': frame})
+    predicted_endpoints = output
+
+    predicted = predicted_endpoints['head']
+
+    # The expected final output is simply the mean across frames.
+    expected = expected_endpoints['head']
+    expected = tf.reduce_mean(expected, 1, keepdims=True)
+
+    self.assertEqual(predicted.shape, expected.shape)
+    self.assertAllClose(predicted, expected, 1e-5, 1e-5)
+
+    # Check the contents of the states dictionary.
+    state_keys = list(init_states.keys())
+    self.assertIn('state_head_pool_buffer', state_keys)
+    self.assertIn('state_head_pool_frame_count', state_keys)
+    state_keys.remove('state_head_pool_buffer')
+    state_keys.remove('state_head_pool_frame_count')
+    # All remaining states should be 'stream_buffer' states for the convolutions.
+    for state_key in state_keys:
+      self.assertIn(
+          'stream_buffer', state_key,
+          msg=f'Expecting stream_buffer only, found {state_key}')
+
   def test_movinet_2plus1d_stream(self):
     tf.keras.backend.set_image_data_format('channels_last')