Commit ce8f9010 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 535815173
Parent 799abc1c
@@ -298,6 +298,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
           str, tf.keras.layers.Layer
       ] = layers.TransformerEncoderBlock,
       share_rezero: bool = False,
+      append_dense_inputs: bool = False,
       **kwargs
   ):
     super().__init__(**kwargs)
@@ -420,6 +421,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
     self._pool_strides = pool_strides  # This is a list here.
     self._unpool_length = unpool_length
     self._pool_type = pool_type
+    self._append_dense_inputs = append_dense_inputs
     self._config = {
         'vocab_size': vocab_size,
@@ -485,11 +487,22 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
     word_embeddings = self._embedding_layer(word_ids)
     if dense_inputs is not None:
-      # Concat the dense embeddings at sequence begin so unpool_len can control
-      # embedding not being pooled.
-      word_embeddings = tf.concat([dense_inputs, word_embeddings], axis=1)
-      type_ids = tf.concat([dense_type_ids, type_ids], axis=1)
-      mask = tf.concat([dense_mask, mask], axis=1)
+      # Allow concatenation of the dense embeddings at sequence end if
+      # requested and `unpool_length` is set to zero.
+      if self._append_dense_inputs:
+        if self._unpool_length != 0:
+          raise ValueError(
+              'unpool_length is not supported by append_dense_inputs now.'
+          )
+        word_embeddings = tf.concat([word_embeddings, dense_inputs], axis=1)
+        type_ids = tf.concat([type_ids, dense_type_ids], axis=1)
+        mask = tf.concat([mask, dense_mask], axis=1)
+      else:
+        # Concat the dense embeddings at sequence begin so unpool_len can
+        # control embedding not being pooled.
+        word_embeddings = tf.concat([dense_inputs, word_embeddings], axis=1)
+        type_ids = tf.concat([dense_type_ids, type_ids], axis=1)
+        mask = tf.concat([dense_mask, mask], axis=1)
     # absolute position embeddings
     position_embeddings = self._position_embedding_layer(word_embeddings)
     type_embeddings = self._type_embedding_layer(type_ids)
......
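For readers skimming the diff, the sketch below isolates the concatenation order that the new flag toggles. It is a toy illustration rather than code from the library; the tensor shapes (batch 1, four text positions, two dense positions, hidden size 3) are assumptions made up for the example.

```python
import tensorflow as tf

# Illustrative shapes only (assumptions for this sketch): batch=1,
# four text positions, two dense positions, hidden size 3.
word_embeddings = tf.zeros([1, 4, 3])  # embeddings of the tokenized text
dense_inputs = tf.ones([1, 2, 3])      # externally provided dense embeddings

append_dense_inputs = True
if append_dense_inputs:
  # New behavior: dense embeddings are appended at the end of the sequence;
  # the change only allows this when unpool_length == 0.
  combined = tf.concat([word_embeddings, dense_inputs], axis=1)
else:
  # Default behavior: dense embeddings are prepended, so unpool_length can
  # keep them out of the pooling stages.
  combined = tf.concat([dense_inputs, word_embeddings], axis=1)

print(combined.shape)  # (1, 6, 3) either way; only the ordering differs
```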
@@ -101,7 +101,11 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
     self.assertAllEqual(tf.float32, data.dtype)
     self.assertAllEqual(pooled_dtype, pooled.dtype)
 
-  def test_network_creation_dense(self):
+  @parameterized.named_parameters(
+      ("append_dense_inputs", True),
+      ("dense_inputs_at_sequence_begin", False),
+  )
+  def test_network_creation_dense(self, append_dense_inputs):
     tf.keras.mixed_precision.set_global_policy("mixed_float16")
     pool_type = "avg"
@@ -120,7 +124,8 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
         pool_type=pool_type,
         max_sequence_length=sequence_length + dense_sequence_length,
         unpool_length=0,
-        transformer_cls="TransformerEncoderBlock")
+        transformer_cls="TransformerEncoderBlock",
+        append_dense_inputs=append_dense_inputs)
     # Create the inputs (note that the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
......
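As a standalone illustration of the parameterization pattern the updated test uses, here is a self-contained sketch that runs a toy check once per flag value. The class name, tensor shapes, and the check itself are assumptions for the example and do not exercise the real encoder.

```python
from absl.testing import parameterized
import tensorflow as tf


class DensePlacementTest(parameterized.TestCase, tf.test.TestCase):

  # Each named parameter runs the test once with the given flag value,
  # mirroring how the updated encoder test covers both placements.
  @parameterized.named_parameters(
      ("append_dense_inputs", True),
      ("dense_inputs_at_sequence_begin", False),
  )
  def test_concat_order(self, append_dense_inputs):
    word = tf.zeros([1, 4, 3])
    dense = tf.ones([1, 2, 3])
    if append_dense_inputs:
      combined = tf.concat([word, dense], axis=1)
      # Dense rows land at the tail of the sequence.
      self.assertAllEqual(combined[:, -2:, :], dense)
    else:
      combined = tf.concat([dense, word], axis=1)
      # Dense rows land at the head of the sequence.
      self.assertAllEqual(combined[:, :2, :], dense)


if __name__ == "__main__":
  tf.test.main()
```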