提交 69bbdc1c，作者：A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 485690158
上级 66f76664
......@@ -30,6 +30,7 @@ import enum
import functools
from typing import Callable, Tuple, Union
import gin
import numpy as np
from scipy import linalg
import tensorflow as tf
......@@ -41,6 +42,7 @@ _Initializer = Union[str, tf.keras.initializers.Initializer]
default_kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev=2e-2)
@gin.constants_from_enum
class MixingMechanism(enum.Enum):
"""Determines the type of mixing layer.
......
......@@ -89,6 +89,8 @@ class FNet(tf.keras.layers.Layer):
layers. If set False, output of attention and intermediate dense layers is
normalized.
with_dense_inputs: Whether to accept dense embeddings as the input.
num_dense_tokens: Length of the token dimension of dense inputs if dense
inputs are used. This counts towards max_sequence_length.
"""
def __init__(
......@@ -113,6 +115,7 @@ class FNet(tf.keras.layers.Layer):
embedding_layer: Optional[tf.keras.layers.Layer] = None,
norm_first: bool = False,
with_dense_inputs: bool = False,
num_dense_tokens: int = 0,
**kwargs):
super().__init__(**kwargs)
......@@ -142,6 +145,7 @@ class FNet(tf.keras.layers.Layer):
'embedding_layer': embedding_layer,
'norm_first': norm_first,
'with_dense_inputs': with_dense_inputs,
'num_dense_tokens': num_dense_tokens,
}
if embedding_layer is None:
......@@ -220,20 +224,26 @@ class FNet(tf.keras.layers.Layer):
name='pooler_transform')
if with_dense_inputs:
if max_sequence_length - num_dense_tokens < 0:
raise ValueError(
'FNet: `max_sequence_length` should include dense tokens, but got '
'`max_sequence_length` - `num_dense_tokens` = {} - {} < 0.'.format(
max_sequence_length, num_dense_tokens))
self.inputs = dict(
input_word_ids=tf.keras.Input(
shape=(max_sequence_length,), dtype=tf.int32),
shape=(max_sequence_length - num_dense_tokens,), dtype=tf.int32),
input_mask=tf.keras.Input(
shape=(max_sequence_length,), dtype=tf.int32),
shape=(max_sequence_length - num_dense_tokens,), dtype=tf.int32),
input_type_ids=tf.keras.Input(
shape=(max_sequence_length,), dtype=tf.int32),
shape=(max_sequence_length - num_dense_tokens,), dtype=tf.int32),
dense_inputs=tf.keras.Input(
shape=(max_sequence_length, embedding_width), dtype=tf.float32),
shape=(num_dense_tokens, embedding_width), dtype=tf.float32),
dense_mask=tf.keras.Input(
shape=(max_sequence_length,), dtype=tf.int32),
shape=(num_dense_tokens,), dtype=tf.int32),
dense_type_ids=tf.keras.Input(
shape=(max_sequence_length,), dtype=tf.int32),
shape=(num_dense_tokens,), dtype=tf.int32),
)
else:
self.inputs = dict(
input_word_ids=tf.keras.Input(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册