Commit 704872d9 authored by Chaochao Yan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 505739149
Parent 03c9bfb7
@@ -68,6 +68,10 @@ class TransformerScaffold(tf.keras.layers.Layer):
         "name": "feedforward" }.
     dropout_rate: Dropout probability for the post-attention and output dropout.
     attention_dropout_rate: Dropout probability for within the attention layer.
+    norm_first: Whether to normalize inputs to attention and intermediate
+      dense layers. If set False, output of attention and intermediate dense
+      layers is normalized.
+    norm_epsilon: Epsilon value to initialize normalization layers.
     kernel_initializer: Initializer for dense layer kernels.
     bias_initializer: Initializer for dense layer biases.
     kernel_regularizer: Regularizer for dense layer kernels.
@@ -88,6 +92,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
                dropout_rate=0.0,
                attention_dropout_rate=0.0,
                norm_first=False,
+               norm_epsilon=1e-12,
                kernel_initializer="glorot_uniform",
                bias_initializer="zeros",
                kernel_regularizer=None,
@@ -106,6 +111,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
     self._feedforward_cls = feedforward_cls
     self._feedforward_cfg = feedforward_cfg
     self._norm_first = norm_first
+    self._norm_epsilon = norm_epsilon
     self._num_heads = num_attention_heads
     self._inner_dim = inner_dim
     self._inner_activation = inner_activation
@@ -201,7 +207,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
           tf.keras.layers.LayerNormalization(
               name="self_attention_layer_norm",
               axis=-1,
-              epsilon=1e-12,
+              epsilon=self._norm_epsilon,
               dtype=tf.float32))
     if self._feedforward_block is None:
@@ -235,7 +241,10 @@ class TransformerScaffold(tf.keras.layers.Layer):
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     # Use float32 in layernorm for numeric stability.
     self._output_layer_norm = tf.keras.layers.LayerNormalization(
-        name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
+        name="output_layer_norm",
+        axis=-1,
+        epsilon=self._norm_epsilon,
+        dtype=tf.float32)
     super().build(input_shape)
     logging.info("%s configs: %s", self.__class__.__name__, self.get_config())
@@ -258,6 +267,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
             self._attention_dropout_rate,
         "norm_first":
             self._norm_first,
+        "norm_epsilon":
+            self._norm_epsilon,
         "kernel_initializer":
             tf.keras.initializers.serialize(self._kernel_initializer),
         "bias_initializer":
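For context, a minimal usage sketch (not part of this commit) of the new argument. It assumes the layer is exposed as layers.TransformerScaffold under official.nlp.modeling, as in the TF Model Garden; the hyperparameter values below are illustrative assumptions.

import tensorflow as tf
from official.nlp.modeling import layers

# Before this change, both LayerNormalization layers in the scaffold used a
# hard-coded epsilon of 1e-12; norm_epsilon makes that value configurable.
scaffold = layers.TransformerScaffold(
    num_attention_heads=8,
    inner_dim=2048,
    inner_activation="relu",
    norm_first=False,
    norm_epsilon=1e-6)

# The argument is also serialized in get_config() (see the last hunk), so a
# layer rebuilt from its config keeps the customized epsilon rather than
# silently reverting to 1e-12.
assert scaffold.get_config()["norm_epsilon"] == 1e-6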