diff --git a/official/nlp/modeling/networks/funnel_transformer.py b/official/nlp/modeling/networks/funnel_transformer.py
index 2e8d3938d4f9e265834af42b1fa9259aca78faa9..0be87d450e32cfe94128f2a13b6554852cb881bb 100644
--- a/official/nlp/modeling/networks/funnel_transformer.py
+++ b/official/nlp/modeling/networks/funnel_transformer.py
@@ -298,6 +298,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
           str, tf.keras.layers.Layer
       ] = layers.TransformerEncoderBlock,
       share_rezero: bool = False,
+      append_dense_inputs: bool = False,
       **kwargs
   ):
     super().__init__(**kwargs)
@@ -420,6 +421,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
     self._pool_strides = pool_strides  # This is a list here.
     self._unpool_length = unpool_length
     self._pool_type = pool_type
+    self._append_dense_inputs = append_dense_inputs
 
     self._config = {
         'vocab_size': vocab_size,
@@ -485,11 +487,22 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
     word_embeddings = self._embedding_layer(word_ids)
     if dense_inputs is not None:
-      # Concat the dense embeddings at sequence begin so unpool_len can control
-      # embedding not being pooled.
-      word_embeddings = tf.concat([dense_inputs, word_embeddings], axis=1)
-      type_ids = tf.concat([dense_type_ids, type_ids], axis=1)
-      mask = tf.concat([dense_mask, mask], axis=1)
+      # Allow concatenation of the dense embeddings at sequence end if requested
+      # and `unpool_length` is set to zero.
+      if self._append_dense_inputs:
+        if self._unpool_length != 0:
+          raise ValueError(
+              'unpool_length is not supported by append_dense_inputs now.'
+          )
+        word_embeddings = tf.concat([word_embeddings, dense_inputs], axis=1)
+        type_ids = tf.concat([type_ids, dense_type_ids], axis=1)
+        mask = tf.concat([mask, dense_mask], axis=1)
+      else:
+        # Concat the dense embeddings at sequence begin so unpool_len can
+        # control embedding not being pooled.
+        word_embeddings = tf.concat([dense_inputs, word_embeddings], axis=1)
+        type_ids = tf.concat([dense_type_ids, type_ids], axis=1)
+        mask = tf.concat([dense_mask, mask], axis=1)
 
     # absolute position embeddings
     position_embeddings = self._position_embedding_layer(word_embeddings)
     type_embeddings = self._type_embedding_layer(type_ids)
diff --git a/official/nlp/modeling/networks/funnel_transformer_test.py b/official/nlp/modeling/networks/funnel_transformer_test.py
index 9a4e49733d0d471d6991e0f182799e6a781ecb71..fdeb9bc1380f4234dd3a3864e8831473f89baaa1 100644
--- a/official/nlp/modeling/networks/funnel_transformer_test.py
+++ b/official/nlp/modeling/networks/funnel_transformer_test.py
@@ -101,7 +101,11 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
       self.assertAllEqual(tf.float32, data.dtype)
       self.assertAllEqual(pooled_dtype, pooled.dtype)
 
-  def test_network_creation_dense(self):
+  @parameterized.named_parameters(
+      ("append_dense_inputs", True),
+      ("dense_inputs_at_sequence_begin", False),
+  )
+  def test_network_creation_dense(self, append_dense_inputs):
     tf.keras.mixed_precision.set_global_policy("mixed_float16")
 
     pool_type = "avg"
@@ -120,7 +124,8 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
         pool_type=pool_type,
         max_sequence_length=sequence_length + dense_sequence_length,
         unpool_length=0,
-        transformer_cls="TransformerEncoderBlock")
+        transformer_cls="TransformerEncoderBlock",
+        append_dense_inputs=append_dense_inputs)
     # Create the inputs (note that the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
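
Usage sketch (not part of the diff): a minimal illustration of the new `append_dense_inputs` flag. The arguments that appear in the updated test (`pool_type`, `max_sequence_length`, `unpool_length`, `transformer_cls`, `append_dense_inputs`) come from this change; the remaining constructor values and the input-dictionary keys (`input_word_ids`, `input_mask`, `input_type_ids`, `dense_inputs`, `dense_mask`, `dense_type_ids`) are assumptions about the encoder's existing signature, not part of this diff.

    import tensorflow as tf

    from official.nlp.modeling.networks import funnel_transformer

    # Assumed hyperparameters for illustration only. With
    # append_dense_inputs=True the pre-embedded dense features are
    # concatenated after the word embeddings; unpool_length must stay 0.
    encoder = funnel_transformer.FunnelTransformerEncoder(
        vocab_size=100,
        hidden_size=32,
        num_layers=2,
        num_attention_heads=2,
        pool_type="avg",
        max_sequence_length=16 + 4,  # token length + dense length
        unpool_length=0,
        transformer_cls="TransformerEncoderBlock",
        append_dense_inputs=True)

    outputs = encoder({
        "input_word_ids": tf.zeros((2, 16), dtype=tf.int32),
        "input_mask": tf.ones((2, 16), dtype=tf.int32),
        "input_type_ids": tf.zeros((2, 16), dtype=tf.int32),
        # Dense features are assumed to already be embedded to hidden_size.
        "dense_inputs": tf.random.uniform((2, 4, 32)),
        "dense_mask": tf.ones((2, 4), dtype=tf.int32),
        "dense_type_ids": tf.zeros((2, 4), dtype=tf.int32),
    })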