diff --git a/official/projects/yt8m/modeling/yt8m_model.py b/official/projects/yt8m/modeling/yt8m_model.py index 24b67d8c558dd8198886a029f1ff5e501fd19e66..339ba9fa72ec8ab286bcc609088e43d8d52703d1 100644 --- a/official/projects/yt8m/modeling/yt8m_model.py +++ b/official/projects/yt8m/modeling/yt8m_model.py @@ -65,12 +65,11 @@ class DbofModel(tf.keras.Model): norm_epsilon: A `float` added to variance to avoid dividing by zero. **kwargs: keyword arguments to be passed. """ - + del num_frames self._self_setattr_tracking = False self._config_dict = { "input_specs": input_specs, "num_classes": num_classes, - "num_frames": num_frames, "params": params } self._num_classes = num_classes @@ -80,26 +79,23 @@ class DbofModel(tf.keras.Model): self._norm = layers.experimental.SyncBatchNormalization else: self._norm = layers.BatchNormalization - if tf.keras.backend.image_data_format() == "channels_last": - bn_axis = -1 - else: - bn_axis = 1 + bn_axis = -1 # [batch_size x num_frames x num_features] feature_size = input_specs.shape[-1] # shape 'excluding' batch_size model_input = tf.keras.Input(shape=self._input_specs.shape[1:]) - reshaped_input = tf.reshape(model_input, [-1, feature_size]) tf.summary.histogram("input_hist", model_input) + input_data = model_input # configure model if params.add_batch_norm: - reshaped_input = self._norm( + input_data = self._norm( axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon, name="input_bn")( - reshaped_input) + input_data) # activation = reshaped input * cluster weights if params.cluster_size > 0: @@ -108,7 +104,7 @@ class DbofModel(tf.keras.Model): kernel_regularizer=kernel_regularizer, kernel_initializer=tf.random_normal_initializer( stddev=1 / tf.sqrt(tf.cast(feature_size, tf.float32))))( - reshaped_input) + input_data) if params.add_batch_norm: activation = self._norm( @@ -142,7 +138,7 @@ class DbofModel(tf.keras.Model): pooling_method=pooling_method, hidden_layer_size=params.context_gate_cluster_bottleneck_size, kernel_regularizer=kernel_regularizer) - activation = tf.reshape(activation, [-1, num_frames, params.cluster_size]) + activation = utils.frame_pooling(activation, params.pooling_method) # activation = activation * hidden1_weights diff --git a/official/projects/yt8m/modeling/yt8m_model_test.py b/official/projects/yt8m/modeling/yt8m_model_test.py index 50883fa5ab16b4b0c793af4b3a6debcc26a2c4b1..b204ec6cf91f0849229bd43bd9f0d7e8afedd305 100644 --- a/official/projects/yt8m/modeling/yt8m_model_test.py +++ b/official/projects/yt8m/modeling/yt8m_model_test.py @@ -34,7 +34,7 @@ class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase): num_frames: number of frames. feature_dims: indicates total dimension size of the features. """ - input_specs = tf.keras.layers.InputSpec(shape=[num_frames, feature_dims]) + input_specs = tf.keras.layers.InputSpec(shape=[None, None, feature_dims]) num_classes = 3862 model = yt8m_model.DbofModel( @@ -44,7 +44,7 @@ class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase): input_specs=input_specs) # batch = 2 -> arbitrary value for test - inputs = np.random.rand(2 * num_frames, feature_dims) + inputs = np.random.rand(2, num_frames, feature_dims) logits = model(inputs) self.assertAllEqual([2, num_classes], logits.numpy().shape)