Commit 1a273519 authored by Zhenyu Tan, committed by A. Unique TensorFlower

aliasing OnDeviceEmbedding inside tensorflow_models.

PiperOrigin-RevId: 331173006
Parent 3bac1426
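At a glance, call sites move from `layers.OnDeviceEmbedding(..., use_scale=True)` to `keras_nlp.layers.OnDeviceEmbedding(..., scale_factor=...)`. A minimal before/after sketch (illustrative only; the vocabulary size and embedding width below are arbitrary placeholders, not values from this commit):

```python
# Illustrative before/after for the OnDeviceEmbedding change in this commit.
from official.nlp import keras_nlp

# Before: boolean switch on the modeling-layers class.
#   embedding = layers.OnDeviceEmbedding(
#       vocab_size=30522, embedding_width=768, use_scale=True)

# After: the keras_nlp alias with an explicit float scale factor;
# embedding_width ** 0.5 reproduces the old use_scale=True scaling.
embedding = keras_nlp.layers.OnDeviceEmbedding(
    vocab_size=30522,
    embedding_width=768,
    scale_factor=768 ** 0.5)
```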
@@ -26,7 +26,7 @@ import tensorflow as tf
from official.modeling import hyperparams
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp import keras_nlp
from official.nlp.modeling import networks
@@ -137,8 +137,9 @@ ENCODER_CLS = {
@gin.configurable
def build_encoder(config: EncoderConfig,
embedding_layer: Optional[layers.OnDeviceEmbedding] = None,
def build_encoder(
config: EncoderConfig,
embedding_layer: Optional[keras_nlp.layers.OnDeviceEmbedding] = None,
encoder_cls=None,
bypass_config: bool = False):
"""Instantiate a Transformer encoder network from EncoderConfig.
......
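With the signature change above, an externally built `keras_nlp.layers.OnDeviceEmbedding` can be passed straight into `build_encoder`, for example to share word embeddings across networks. A rough sketch, assuming the configs live at `official.nlp.configs.encoders` and that `EncoderConfig(type='bert')` with its defaults is a valid construction; these and the BERT-base-sized dimensions are assumptions, not taken from this diff:

```python
from official.nlp import keras_nlp
from official.nlp.configs import encoders  # assumed module path

# Hypothetical shared embedding layer; vocab_size/embedding_width are
# BERT-base-like placeholders, not values from this commit.
shared_embeddings = keras_nlp.layers.OnDeviceEmbedding(
    vocab_size=30522, embedding_width=768, name='shared_word_embeddings')

# Assumed: EncoderConfig selects a default BERT encoder via type='bert'.
config = encoders.EncoderConfig(type='bert')
encoder = encoders.build_encoder(config, embedding_layer=shared_embeddings)
```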
@@ -34,9 +34,9 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
lookup. Defaults to False (that is, using tf.gather). Setting this option
to True may improve performance, especially on small vocabulary sizes, but
will generally require more memory.
use_scale: Whether to scale the output embeddings. Defaults to False (that
is, not to scale). Setting this option to True will let values in output
embeddings multiplied by self._embedding_width ** 0.5.
scale_factor: The factor by which to scale the output embeddings. Defaults to
None (that is, no scaling). If set to a float, the output embeddings are
multiplied by scale_factor.
"""
def __init__(self,
@@ -44,7 +44,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
embedding_width,
initializer="glorot_uniform",
use_one_hot=False,
use_scale=False,
scale_factor=None,
**kwargs):
super(OnDeviceEmbedding, self).__init__(**kwargs)
@@ -52,7 +52,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
self._embedding_width = embedding_width
self._initializer = initializer
self._use_one_hot = use_one_hot
self._use_scale = use_scale
self._scale_factor = scale_factor
def get_config(self):
config = {
@@ -60,7 +60,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
"embedding_width": self._embedding_width,
"initializer": self._initializer,
"use_one_hot": self._use_one_hot,
"use_scale": self._use_scale,
"scale_factor": self._scale_factor,
}
base_config = super(OnDeviceEmbedding, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@@ -87,6 +87,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
# Work around b/142213824: prefer concat to shape over a Python list.
tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0))
embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width])
if self._use_scale:
embeddings *= self._embedding_width**0.5
if self._scale_factor:
embeddings *= self._scale_factor
return embeddings
@@ -192,7 +192,8 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
vocab_size = 31
embedding_width = 27
test_layer = on_device_embedding.OnDeviceEmbedding(
vocab_size=vocab_size, embedding_width=embedding_width, use_scale=True)
vocab_size=vocab_size, embedding_width=embedding_width,
scale_factor=embedding_width**0.5)
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length = 23
input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
......
@@ -142,12 +142,12 @@ class Seq2SeqTransformer(tf.keras.Model):
self._beam_size = beam_size
self._alpha = alpha
self._dtype = dtype
self.embedding_lookup = layers.OnDeviceEmbedding(
self.embedding_lookup = keras_nlp.layers.OnDeviceEmbedding(
vocab_size=self._vocab_size,
embedding_width=self._embedding_width,
initializer=tf.random_normal_initializer(
mean=0., stddev=self._embedding_width**-0.5),
use_scale=True)
scale_factor=self._embedding_width**0.5)
self.encoder_layer = encoder_layer
self.decoder_layer = decoder_layer
self.position_embedding = layers.RelativePositionEmbedding(
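Note on the Seq2SeqTransformer hunk above: the embedding table is still initialized with stddev `embedding_width**-0.5`, and the new `scale_factor=embedding_width**0.5` applies the same sqrt-of-width multiplier that `use_scale=True` used to, so the effective scale of the looked-up embeddings is unchanged. A minimal arithmetic sketch, with a hypothetical width:

```python
# Hypothetical width; any value gives the same cancellation.
embedding_width = 512

stddev = embedding_width ** -0.5       # spread of the random_normal_initializer
scale_factor = embedding_width ** 0.5  # multiplier applied to embedding lookups

# The two factors cancel, so scaled lookups keep roughly unit standard deviation,
# matching what use_scale=True did before this change.
print(stddev * scale_factor)  # 1.0
```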
@@ -472,7 +472,7 @@ class TransformerEncoder(tf.keras.layers.Layer):
self.encoder_layers = []
for i in range(self.num_layers):
self.encoder_layers.append(
keras_nlp.TransformerEncoderBlock(
keras_nlp.layers.TransformerEncoderBlock(
num_attention_heads=self.num_attention_heads,
inner_dim=self._intermediate_size,
inner_activation=self._activation,
......
@@ -141,7 +141,7 @@ class EncoderScaffold(tf.keras.Model):
shape=(seq_length,), dtype=tf.int32, name='input_type_ids')
inputs = [word_ids, mask, type_ids]
self._embedding_layer = layers.OnDeviceEmbedding(
self._embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
vocab_size=embedding_cfg['vocab_size'],
embedding_width=embedding_cfg['hidden_size'],
initializer=embedding_cfg['initializer'],
@@ -150,13 +150,13 @@ class EncoderScaffold(tf.keras.Model):
word_embeddings = self._embedding_layer(word_ids)
# Always uses dynamic slicing for simplicity.
self._position_embedding_layer = keras_nlp.PositionEmbedding(
self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
initializer=embedding_cfg['initializer'],
max_length=embedding_cfg['max_seq_length'],
name='position_embedding')
position_embeddings = self._position_embedding_layer(word_embeddings)
self._type_embedding_layer = layers.OnDeviceEmbedding(
self._type_embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
vocab_size=embedding_cfg['type_vocab_size'],
embedding_width=embedding_cfg['hidden_size'],
initializer=embedding_cfg['initializer'],
......
@@ -101,18 +101,18 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
self.max_sequence_length = max_sequence_length
self.dropout_rate = dropout_rate
self.word_embedding = layers.OnDeviceEmbedding(
self.word_embedding = keras_nlp.layers.OnDeviceEmbedding(
self.word_vocab_size,
self.word_embed_size,
initializer=initializer,
name='word_embedding')
self.type_embedding = layers.OnDeviceEmbedding(
self.type_embedding = keras_nlp.layers.OnDeviceEmbedding(
self.type_vocab_size,
self.output_embed_size,
use_one_hot=True,
initializer=initializer,
name='type_embedding')
self.pos_embedding = keras_nlp.PositionEmbedding(
self.pos_embedding = keras_nlp.layers.PositionEmbedding(
max_length=max_sequence_length,
initializer=initializer,
name='position_embedding')
......