Commit ac97d775 authored by Scott Zhu, committed by A. Unique TensorFlower

Prepare for upcoming keras initializer change.

PiperOrigin-RevId: 451481056
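The change swaps direct reuse of a single `initializer` object for `tf_utils.clone_initializer(initializer)` at every layer-construction site. The motivation: once Keras initializers become stateless, calling the same instance twice with the same shape yields identical values, so layers sharing one instance would all start from the same weights. For context, a minimal sketch of what the helper does, assuming the `official/modeling/tf_utils.py` implementation (close to, but not guaranteed identical to, the real source):

    import tensorflow as tf

    def clone_initializer(initializer):
      # Re-create Initializer instances from their config so each layer owns
      # a distinct copy and draws independent initial values.
      if isinstance(initializer, tf.keras.initializers.Initializer):
        return initializer.__class__.from_config(initializer.get_config())
      # Strings and dicts (e.g. 'glorot_uniform') are deserialized into a
      # fresh instance by each layer anyway, so pass them through unchanged.
      return initializer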
Parent b58f478c
--- a/official/projects/roformer/roformer_encoder.py
+++ b/official/projects/roformer/roformer_encoder.py
@@ -19,6 +19,7 @@ import collections
 from absl import logging
 import tensorflow as tf
+from official.modeling import tf_utils
 from official.nlp.modeling import layers
 from official.projects.roformer import roformer_encoder_block
@@ -115,7 +116,7 @@ class RoformerEncoder(tf.keras.Model):
       embedding_layer_inst = layers.on_device_embedding.OnDeviceEmbedding(
           vocab_size=vocab_size,
           embedding_width=embedding_width,
-          initializer=initializer,
+          initializer=tf_utils.clone_initializer(initializer),
           name='word_embeddings')
     else:
       embedding_layer_inst = embedding_layer
@@ -125,7 +126,7 @@ class RoformerEncoder(tf.keras.Model):
     type_embedding_layer = layers.on_device_embedding.OnDeviceEmbedding(
         vocab_size=type_vocab_size,
         embedding_width=embedding_width,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
         use_one_hot=True,
         name='type_embeddings')
     type_embeddings = type_embedding_layer(type_ids)
@@ -146,7 +147,7 @@ class RoformerEncoder(tf.keras.Model):
           '...x,xy->...y',
           output_shape=hidden_size,
           bias_axes='y',
-          kernel_initializer=initializer,
+          kernel_initializer=tf_utils.clone_initializer(initializer),
           name='embedding_projection')
       embeddings = embedding_projection(embeddings)
     else:
@@ -171,7 +172,7 @@ class RoformerEncoder(tf.keras.Model):
           attention_dropout=attention_dropout,
           norm_first=norm_first,
           output_range=transformer_output_range,
-          kernel_initializer=initializer,
+          kernel_initializer=tf_utils.clone_initializer(initializer),
           name='roformer/layer_%d' % i)
       transformer_layers.append(layer)
       data = layer([data, attention_mask])
@@ -185,7 +186,7 @@ class RoformerEncoder(tf.keras.Model):
     pooler_layer = tf.keras.layers.Dense(
         units=hidden_size,
         activation='tanh',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         name='pooler_transform')
     cls_output = pooler_layer(first_token_tensor)
--- a/official/projects/roformer/roformer_encoder_block.py
+++ b/official/projects/roformer/roformer_encoder_block.py
@@ -15,6 +15,7 @@
 """Roformer TransformerEncoder block layer."""
 import tensorflow as tf
+from official.modeling import tf_utils
 from official.projects.roformer import roformer_attention
@@ -111,7 +112,8 @@ class RoformerEncoderBlock(tf.keras.layers.Layer):
       self._attention_initializer = tf.keras.initializers.get(
           attention_initializer)
     else:
-      self._attention_initializer = self._kernel_initializer
+      self._attention_initializer = tf_utils.clone_initializer(
+          self._kernel_initializer)
     self._attention_axes = attention_axes

   def build(self, input_shape):
@@ -164,7 +166,7 @@ class RoformerEncoderBlock(tf.keras.layers.Layer):
         einsum_equation,
         output_shape=(None, self._inner_dim),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         name="intermediate",
         **common_kwargs)
     policy = tf.keras.mixed_precision.global_policy()
@@ -182,7 +184,7 @@ class RoformerEncoderBlock(tf.keras.layers.Layer):
         output_shape=(None, hidden_size),
         bias_axes="d",
         name="output",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         **common_kwargs)
     self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
     # Use float32 in layernorm for numeric stability.
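Why cloning matters in practice: every `RoformerEncoder` builds its embedding, projection, transformer, and pooler layers from the one `initializer` argument, and `RoformerEncoderBlock` likewise reuses `self._kernel_initializer`. A hedged illustration of the failure mode this commit prepares for (exact behavior depends on the TF/Keras version; the shapes here are made up for the example):

    import tensorflow as tf
    from official.modeling import tf_utils

    init = tf.keras.initializers.GlorotUniform()

    # With stateless initializers, two calls on one instance with the same
    # shape are expected to return the same values, so layers sharing `init`
    # would be initialized identically.
    w1 = init(shape=(4, 4))
    w2 = init(shape=(4, 4))   # same as w1 under stateless semantics

    # A clone is a fresh instance, so it draws independently.
    w3 = tf_utils.clone_initializer(init)(shape=(4, 4))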