From 391f22eb7a9c2dde6315c1018981fce9748cf70e Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 22 Jul 2022 12:34:09 -0700
Subject: [PATCH] Internal change

PiperOrigin-RevId: 462680315
---
 official/projects/yt8m/modeling/yt8m_model.py      | 18 +++++++-----------
 .../projects/yt8m/modeling/yt8m_model_test.py      |  4 ++--
 2 files changed, 9 insertions(+), 13 deletions(-)

diff --git a/official/projects/yt8m/modeling/yt8m_model.py b/official/projects/yt8m/modeling/yt8m_model.py
index 24b67d8c5..339ba9fa7 100644
--- a/official/projects/yt8m/modeling/yt8m_model.py
+++ b/official/projects/yt8m/modeling/yt8m_model.py
@@ -65,12 +65,11 @@ class DbofModel(tf.keras.Model):
       norm_epsilon: A `float` added to variance to avoid dividing by zero.
       **kwargs: keyword arguments to be passed.
     """
-
+    del num_frames
     self._self_setattr_tracking = False
     self._config_dict = {
         "input_specs": input_specs,
         "num_classes": num_classes,
-        "num_frames": num_frames,
         "params": params
     }
     self._num_classes = num_classes
@@ -80,26 +79,23 @@ class DbofModel(tf.keras.Model):
       self._norm = layers.experimental.SyncBatchNormalization
     else:
       self._norm = layers.BatchNormalization
-    if tf.keras.backend.image_data_format() == "channels_last":
-      bn_axis = -1
-    else:
-      bn_axis = 1
+    bn_axis = -1
 
     # [batch_size x num_frames x num_features]
     feature_size = input_specs.shape[-1]
     # shape 'excluding' batch_size
     model_input = tf.keras.Input(shape=self._input_specs.shape[1:])
-    reshaped_input = tf.reshape(model_input, [-1, feature_size])
     tf.summary.histogram("input_hist", model_input)
+    input_data = model_input
 
     # configure model
     if params.add_batch_norm:
-      reshaped_input = self._norm(
+      input_data = self._norm(
           axis=bn_axis,
           momentum=norm_momentum,
           epsilon=norm_epsilon,
           name="input_bn")(
-              reshaped_input)
+              input_data)
 
     # activation = reshaped input * cluster weights
     if params.cluster_size > 0:
@@ -108,7 +104,7 @@ class DbofModel(tf.keras.Model):
           kernel_regularizer=kernel_regularizer,
           kernel_initializer=tf.random_normal_initializer(
               stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32))))(
-                  reshaped_input)
+                  input_data)
 
     if params.add_batch_norm:
       activation = self._norm(
@@ -142,7 +138,7 @@ class DbofModel(tf.keras.Model):
             pooling_method=pooling_method,
             hidden_layer_size=params.context_gate_cluster_bottleneck_size,
             kernel_regularizer=kernel_regularizer)
-    activation = tf.reshape(activation, [-1, num_frames, params.cluster_size])
+    activation = utils.frame_pooling(activation, params.pooling_method)
 
     # activation = activation * hidden1_weights
 
diff --git a/official/projects/yt8m/modeling/yt8m_model_test.py b/official/projects/yt8m/modeling/yt8m_model_test.py
index 50883fa5a..b204ec6cf 100644
--- a/official/projects/yt8m/modeling/yt8m_model_test.py
+++ b/official/projects/yt8m/modeling/yt8m_model_test.py
@@ -34,7 +34,7 @@ class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase):
       num_frames: number of frames.
       feature_dims: indicates total dimension size of the features.
     """
-    input_specs = tf.keras.layers.InputSpec(shape=[num_frames, feature_dims])
+    input_specs = tf.keras.layers.InputSpec(shape=[None, None, feature_dims])
     num_classes = 3862
 
     model = yt8m_model.DbofModel(
@@ -44,7 +44,7 @@ class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase):
         num_classes=num_classes,
         input_specs=input_specs)
 
     # batch = 2 -> arbitrary value for test
-    inputs = np.random.rand(2 * num_frames, feature_dims)
+    inputs = np.random.rand(2, num_frames, feature_dims)
     logits = model(inputs)
     self.assertAllEqual([2, num_classes], logits.numpy().shape)
--
GitLab