提交 391f22eb 编写于 作者: A A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 462680315
上级 b0a37140
......@@ -65,12 +65,11 @@ class DbofModel(tf.keras.Model):
norm_epsilon: A `float` added to variance to avoid dividing by zero.
**kwargs: keyword arguments to be passed.
"""
del num_frames
self._self_setattr_tracking = False
self._config_dict = {
"input_specs": input_specs,
"num_classes": num_classes,
"num_frames": num_frames,
"params": params
}
self._num_classes = num_classes
......@@ -80,26 +79,23 @@ class DbofModel(tf.keras.Model):
self._norm = layers.experimental.SyncBatchNormalization
else:
self._norm = layers.BatchNormalization
if tf.keras.backend.image_data_format() == "channels_last":
bn_axis = -1
else:
bn_axis = 1
bn_axis = -1
# [batch_size x num_frames x num_features]
feature_size = input_specs.shape[-1]
# shape 'excluding' batch_size
model_input = tf.keras.Input(shape=self._input_specs.shape[1:])
reshaped_input = tf.reshape(model_input, [-1, feature_size])
tf.summary.histogram("input_hist", model_input)
input_data = model_input
# configure model
if params.add_batch_norm:
reshaped_input = self._norm(
input_data = self._norm(
axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
name="input_bn")(
reshaped_input)
input_data)
# activation = reshaped input * cluster weights
if params.cluster_size > 0:
......@@ -108,7 +104,7 @@ class DbofModel(tf.keras.Model):
kernel_regularizer=kernel_regularizer,
kernel_initializer=tf.random_normal_initializer(
stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32))))(
reshaped_input)
input_data)
if params.add_batch_norm:
activation = self._norm(
......@@ -142,7 +138,7 @@ class DbofModel(tf.keras.Model):
pooling_method=pooling_method,
hidden_layer_size=params.context_gate_cluster_bottleneck_size,
kernel_regularizer=kernel_regularizer)
activation = tf.reshape(activation, [-1, num_frames, params.cluster_size])
activation = utils.frame_pooling(activation, params.pooling_method)
# activation = activation * hidden1_weights
......
......@@ -34,7 +34,7 @@ class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase):
num_frames: number of frames.
feature_dims: indicates total dimension size of the features.
"""
input_specs = tf.keras.layers.InputSpec(shape=[num_frames, feature_dims])
input_specs = tf.keras.layers.InputSpec(shape=[None, None, feature_dims])
num_classes = 3862
model = yt8m_model.DbofModel(
......@@ -44,7 +44,7 @@ class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase):
input_specs=input_specs)
# batch = 2 -> arbitrary value for test
inputs = np.random.rand(2 * num_frames, feature_dims)
inputs = np.random.rand(2, num_frames, feature_dims)
logits = model(inputs)
self.assertAllEqual([2, num_classes], logits.numpy().shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册