Commit 8e108329 authored by James Lee-Thorp, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 482331634
Parent 35785639
......@@ -37,6 +37,10 @@ from official.nlp.modeling.layers.mixing import MixingMechanism
from official.nlp.modeling.layers.mobile_bert_layers import MobileBertEmbedding
from official.nlp.modeling.layers.mobile_bert_layers import MobileBertMaskedLM
from official.nlp.modeling.layers.mobile_bert_layers import MobileBertTransformer
from official.nlp.modeling.layers.moe import ExpertsChooseMaskedRouter
from official.nlp.modeling.layers.moe import FeedForwardExperts
from official.nlp.modeling.layers.moe import MoeLayer
from official.nlp.modeling.layers.moe import MoeLayerWithBackbone
from official.nlp.modeling.layers.multi_channel_attention import *
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
from official.nlp.modeling.layers.pack_optimization import PackBertEmbeddings
......
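For orientation (not part of this change), the newly exported MoE pieces compose the way the SparseMixer network added below uses them; a minimal sketch with illustrative hyperparameter values:

```python
import tensorflow as tf
from official.nlp.modeling import layers

# Assemble a sparse feedforward sublayer the way SparseMixer does below.
# All hyperparameter values here are illustrative.
moe_sublayer = layers.MoeLayer(
    experts=layers.FeedForwardExperts(
        num_experts=16,
        d_ff=512,
        dropout_rate=0.1,
        activation=lambda x: tf.keras.activations.gelu(x, approximate=True),
        name='experts'),
    router=layers.ExpertsChooseMaskedRouter(num_experts=16, name='router'),
    train_capacity_factor=1.0,
    eval_capacity_factor=1.0,
    max_group_size=4096,
    name='moe')
```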
......@@ -2,43 +2,50 @@
Networks are combinations of `tf.keras` layers (and possibly other networks).
They are `tf.keras` models that would not be trained alone. They encapsulate
common network structures like a transformer encoder into an easily handled
object with a standardized configuration.
* [`BertEncoder`](bert_encoder.py) implements a bi-directional
Transformer-based encoder as described in ["BERT: Pre-training of Deep
Bidirectional Transformers for Language
Understanding"](https://arxiv.org/abs/1810.04805). It includes the embedding
lookups, transformer layers and pooling layer.
* [`AlbertEncoder`](albert_encoder.py) implements a Transformer-encoder
described in the paper ["ALBERT: A Lite BERT for Self-supervised Learning of
Language Representations"](https://arxiv.org/abs/1909.11942). Compared with
[BERT](https://arxiv.org/abs/1810.04805), ALBERT refactorizes embedding
parameters into two smaller matrices and shares parameters across layers.
* [`MobileBERTEncoder`](mobile_bert_encoder.py) implements the MobileBERT
network described in the paper
["MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices"](https://arxiv.org/abs/2004.02984).
* [`Classification`](classification.py) contains a single hidden layer, and is
intended for use as a classification or regression (if number of classes is
set to 1) head.
* [`PackedSequenceEmbedding`](packed_sequence_embedding.py) implements an
embedding network that supports packed sequences and position ids.
* [`SpanLabeling`](span_labeling.py) implements a single-span labeler (that
is, a prediction head that can predict one start and end index per batch
item) based on a single dense hidden layer. It can be used in the SQuAD
task.
* [`XLNetBase`](xlnet_base.py) implements the base network used in "XLNet:
Generalized Autoregressive Pretraining for Language Understanding"
(https://arxiv.org/abs/1906.08237). It includes embedding lookups, relative
position encodings, mask computations, segment matrix computations and
Transformer XL layers using one or two stream relative self-attention.
* [`FNet`](fnet.py) implements the encoder model from
["FNet: Mixing Tokens with Fourier Transforms"](https://aclanthology.org/2022.naacl-main.319/).
FNet has the same structure as a Transformer encoder, except that all or
most of the self-attention sublayers are replaced with Fourier sublayers.
* [`Sparse Mixer`](sparse_mixer.py) implements the encoder model from
["Sparse Mixers: Combining MoE and Mixing to build a more efficient BERT"](https://arxiv.org/abs/2205.12399/).
Sparse Mixer consists of layers of heterogeneous encoder blocks. Each
encoder block contains a linear mixing or an attention sublayer together
with a (dense) MLP or sparsely activated Mixture-of-Experts sublayer. A
minimal usage sketch follows this list.
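A minimal usage sketch of the `SparseMixer` network (illustrative only; the shapes and hyperparameters below are not prescribed by the library):

```python
import tensorflow as tf
from official.nlp.modeling import networks

# The sequence length fed to the encoder must equal `max_sequence_length`,
# since the mixing matrices have a fixed size.
encoder = networks.SparseMixer(
    vocab_size=100,
    hidden_size=16,
    num_layers=3,
    moe_layers=(1,),
    attention_layers=(2,),
    num_experts=8,
    num_attention_heads=2,
    max_sequence_length=32)

batch_size, seq_len = 4, 32
outputs = encoder(
    dict(input_word_ids=tf.ones((batch_size, seq_len), dtype=tf.int32),
         input_mask=tf.ones((batch_size, seq_len), dtype=tf.int32),
         input_type_ids=tf.zeros((batch_size, seq_len), dtype=tf.int32)))
print(outputs['sequence_output'].shape)  # (4, 32, 16)
print(outputs['pooled_output'].shape)    # (4, 16)
```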
......@@ -29,4 +29,5 @@ from official.nlp.modeling.networks.mobile_bert_encoder import MobileBERTEncoder
from official.nlp.modeling.networks.packed_sequence_embedding import PackedSequenceEmbedding
from official.nlp.modeling.networks.span_labeling import SpanLabeling
from official.nlp.modeling.networks.span_labeling import XLNetSpanLabeling
from official.nlp.modeling.networks.sparse_mixer import SparseMixer
from official.nlp.modeling.networks.xlnet_base import XLNetBase
......@@ -51,7 +51,7 @@ class FNet(tf.keras.layers.Layer):
num_layers: The number of transformer layers.
mixing_mechanism: Type of mixing mechanism used in place of self-attention
layers. Defaults to FNet ('Fourier') mixing.
use_fft: Only used for spectral mixing mechanisms. Determines whether to use
Fast Fourier Transform (True) or the Discrete Fourier Transform (DFT)
matrix (False; default) to compute the Fourier Transform. See
layers.FourierTransformLayer or layers.HartleyTransformLayer for advice.
......
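For illustration (an assumed sketch, not part of this change), the same `use_fft` flag is exposed by the mixing layers themselves; which option is faster depends on sequence length and hardware, as the referenced layer docstrings advise:

```python
from official.nlp.modeling import layers

# DFT-matrix implementation (the default).
fourier_dft = layers.FourierTransformLayer(use_fft=False, name='fourier_dft')
# FFT implementation.
fourier_fft = layers.FourierTransformLayer(use_fft=True, name='fourier_fft')
```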
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sparse Mixer encoder network.
Based on ["Sparse Mixers: Combining MoE and Mixing to build a more efficient
BERT"](https://arxiv.org/abs/2205.12399).
"""
# pylint: disable=g-classes-have-attributes
from typing import Any, Callable, Optional, Sequence, Union
from absl import logging
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
_Activation = Union[str, Callable[..., Any]]
_Initializer = Union[str, tf.keras.initializers.Initializer]
_approx_gelu = lambda x: tf.keras.activations.gelu(x, approximate=True)
class SparseMixer(tf.keras.layers.Layer):
"""Sparse Mixer encoder network.
Based on ["Sparse Mixers: Combining MoE and Mixing to build a more efficient
BERT"](https://arxiv.org/abs/2205.12399). Sparse Mixer is an efficient
encoder network that replaces typical Transformer encoder blocks with a
combination of linear mixing and sparsely activated Mixture-of-Experts (MoE)
sublayers.
This implementation defaults to the canonical Sparse Mixer Base model. To use
the "Fast Sparse Mixer" configuration, set `*_capacity_factor`=0.5. This
yields a sparser and faster variant of the canonical Sparse Mixer model, in
which each expert processes roughly 50% fewer tokens.
Notes:
- The underlying MoeLayer uses the Keras add_loss() and add_metric() APIs to
propagate auxiliary MoE losses and metrics. Any model using this network
should collect these losses/metrics.
- The input length is fixed to 'max_sequence_length' to accommodate the mixing
mechanisms.
Args:
vocab_size: The size of the token vocabulary.
hidden_size: The size of the transformer hidden layers.
num_layers: The number of transformer layers.
moe_layers: Specifies which layers, if any, should be sparsely activated
Mixture-of-Experts (MoE) layers. The remaining [0, num_layers) setminus
moe_layers will use the vanilla MLP sublayers. Defaults to placing MoE
layers in the middle of the model.
attention_layers: Specifies which layers, if any, should be attention layers
in the encoder. The remaining [0, num_layers) setminus attention_layers
will use the specified `mixing_mechanism`. If using attention layers, a
good rule of thumb is to place them in the final few layers.
num_experts: Number of experts. Experts are themselves MLP modules, with the
same `inner_dim` and `inner_activation` as the vanilla MLP sublayers.
train_capacity_factor: Scaling factor to increase the expert token capacity
during training. See layers.MoeLayer for further details. The "Fast Sparse
Mixer" increases model sparsity (and speed) by using a capacity factor of
0.5.
eval_capacity_factor: As above, but used during evaluation.
max_group_size: The total number of tokens on each device is subdivided into
groups of this size. Router computations are then performed on a per-group
basis. See layers.MoeLayer for further details.
mixing_mechanism: Type of mixing mechanism used in place of self-attention
layers. Defaults to 'Linear' mixing.
use_fft: Only used for spectral mixing mechanisms. Determines whether to use
Fast Fourier Transform (True) or the Discrete Fourier Transform (DFT)
matrix (False; default) to compute the Fourier Transform. See
layers.FourierTransformLayer or layers.HartleyTransformLayer for advice.
num_attention_heads: The number of attention heads for each transformer. The
hidden size must be divisible by the number of attention heads.
max_sequence_length: The only sequence length that this encoder can consume.
This determines the variable shape for positional embeddings and the size
of the mixing matrices.
type_vocab_size: The number of types that the 'type_ids' input can take.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network for each transformer.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network for each transformer.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: The dropout rate to use for the attention layers within
the transformer layers.
initializer: The initializer to use for all weights in this encoder.
output_range: The sequence output range, [0, output_range), by slicing the
target sequence of the last transformer layer. `None` means the entire
target sequence will attend to the source sequence, which yields the full
output.
embedding_width: The width of the word embeddings. If the embedding width is
not equal to hidden size, embedding parameters will be factorized into two
matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
embedding_layer: An optional Layer instance which will be called to generate
embeddings for the input word IDs.
norm_first: Whether to normalize inputs to attention and intermediate dense
layers. If set False, output of attention and intermediate dense layers is
normalized.
with_dense_inputs: Whether to accept dense embeddings as the input.
"""
def __init__(
self,
vocab_size: int,
hidden_size: int = 512,
num_layers: int = 14,
moe_layers: Sequence[int] = (5, 6, 7, 8),
attention_layers: Sequence[int] = (10, 11, 12, 13),
num_experts: int = 16,
train_capacity_factor: float = 1.,
eval_capacity_factor: float = 1.,
max_group_size: int = 4096,
mixing_mechanism: layers.MixingMechanism = layers.MixingMechanism.LINEAR,
use_fft: bool = False,
num_attention_heads: int = 8,
max_sequence_length: int = 512,
type_vocab_size: int = 16,
inner_dim: int = 2056,
inner_activation: _Activation = _approx_gelu,
output_dropout: float = 0.1,
attention_dropout: float = 0.1,
initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
stddev=0.02),
output_range: Optional[int] = None,
embedding_width: Optional[int] = None,
embedding_layer: Optional[tf.keras.layers.Layer] = None,
norm_first: bool = False,
with_dense_inputs: bool = False,
**kwargs):
super().__init__(**kwargs)
activation = tf.keras.activations.get(inner_activation)
initializer = tf.keras.initializers.get(initializer)
if embedding_width is None:
embedding_width = hidden_size
self._config = {
'vocab_size': vocab_size,
'hidden_size': hidden_size,
'num_layers': num_layers,
'moe_layers': moe_layers,
'num_experts': num_experts,
'train_capacity_factor': train_capacity_factor,
'eval_capacity_factor': eval_capacity_factor,
'max_group_size': max_group_size,
'mixing_mechanism': mixing_mechanism,
'use_fft': use_fft,
'attention_layers': attention_layers,
'num_attention_heads': num_attention_heads,
'max_sequence_length': max_sequence_length,
'type_vocab_size': type_vocab_size,
'inner_dim': inner_dim,
'inner_activation': tf.keras.activations.serialize(activation),
'output_dropout': output_dropout,
'attention_dropout': attention_dropout,
'initializer': tf.keras.initializers.serialize(initializer),
'output_range': output_range,
'embedding_width': embedding_width,
'embedding_layer': embedding_layer,
'norm_first': norm_first,
'with_dense_inputs': with_dense_inputs,
}
if embedding_layer is None:
self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=tf_utils.clone_initializer(initializer),
name='word_embeddings')
else:
self._embedding_layer = embedding_layer
self._position_embedding_layer = layers.PositionEmbedding(
initializer=tf_utils.clone_initializer(initializer),
max_length=max_sequence_length,
name='position_embedding')
self._type_embedding_layer = layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=tf_utils.clone_initializer(initializer),
use_one_hot=True,
name='type_embeddings')
self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
self._embedding_dropout = tf.keras.layers.Dropout(
rate=output_dropout, name='embedding_dropout')
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
self._embedding_projection = None
if embedding_width != hidden_size:
self._embedding_projection = tf.keras.layers.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=tf_utils.clone_initializer(initializer),
name='embedding_projection')
self._transformer_layers = []
for layer in range(num_layers):
if layer in attention_layers:
mixing_layer = layers.MultiHeadAttention(
num_heads=num_attention_heads,
key_dim=int(hidden_size // num_attention_heads),
dropout=attention_dropout,
use_bias=True,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='self_attention',
)
else:
mixing_layer = self._init_mixing_sublayer(layer)
if layer in moe_layers:
feedforward_layer = layers.MoeLayer(
experts=layers.FeedForwardExperts(
num_experts=num_experts,
d_ff=hidden_size,
dropout_rate=output_dropout,
activation=inner_activation,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='experts'),
router=layers.ExpertsChooseMaskedRouter(
num_experts=num_experts,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='router'),
train_capacity_factor=train_capacity_factor,
eval_capacity_factor=eval_capacity_factor,
max_group_size=max_group_size,
name='moe')
else:
feedforward_layer = None # Fallback to default (dense) MLP class
block = layers.TransformerScaffold(
num_attention_heads=num_attention_heads,
inner_dim=inner_dim,
inner_activation=inner_activation,
attention_cls=mixing_layer,
feedforward_cls=feedforward_layer,
output_dropout=output_dropout,
attention_dropout=attention_dropout,
norm_first=norm_first,
output_range=output_range if layer == num_layers - 1 else None,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='transformer/layer_%d' % layer)
self._transformer_layers.append(block)
self._attention_mask_layer = layers.SelfAttentionMask(
name='self_attention_mask')
self._pooler_layer = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=tf_utils.clone_initializer(initializer),
name='pooler_transform')
if with_dense_inputs:
self.inputs = dict(
input_word_ids=tf.keras.Input(
shape=(max_sequence_length,), dtype=tf.int32),
input_mask=tf.keras.Input(
shape=(max_sequence_length,), dtype=tf.int32),
input_type_ids=tf.keras.Input(
shape=(max_sequence_length,), dtype=tf.int32),
dense_inputs=tf.keras.Input(
shape=(max_sequence_length, embedding_width), dtype=tf.float32),
dense_mask=tf.keras.Input(
shape=(max_sequence_length,), dtype=tf.int32),
dense_type_ids=tf.keras.Input(
shape=(max_sequence_length,), dtype=tf.int32),
)
else:
self.inputs = dict(
input_word_ids=tf.keras.Input(
shape=(max_sequence_length,), dtype=tf.int32),
input_mask=tf.keras.Input(
shape=(max_sequence_length,), dtype=tf.int32),
input_type_ids=tf.keras.Input(
shape=(max_sequence_length,), dtype=tf.int32))
self._max_sequence_length = max_sequence_length
def call(self, inputs):
word_embeddings = None
if isinstance(inputs, dict):
word_ids = inputs.get('input_word_ids')
mask = inputs.get('input_mask')
type_ids = inputs.get('input_type_ids')
word_embeddings = inputs.get('input_word_embeddings', None)
dense_inputs = inputs.get('dense_inputs', None)
dense_mask = inputs.get('dense_mask', None)
dense_type_ids = inputs.get('dense_type_ids', None)
else:
raise ValueError('Unexpected inputs type (%s) to %s.' %
(type(inputs), self.__class__))
if word_embeddings is None:
word_embeddings = self._embedding_layer(word_ids)
if dense_inputs is not None:
# Concat the dense embeddings at sequence end.
word_embeddings = tf.concat([word_embeddings, dense_inputs], axis=1)
type_ids = tf.concat([type_ids, dense_type_ids], axis=1)
mask = tf.concat([mask, dense_mask], axis=1)
seq_length = word_embeddings.shape[1]
if seq_length != self._max_sequence_length:
raise ValueError('Sparse Mixer: Sequence length must be the same as '
'`max_sequence_length` ({}), but it is {}.'.format(
self._max_sequence_length, seq_length))
# Absolute position embeddings.
position_embeddings = self._position_embedding_layer(word_embeddings)
type_embeddings = self._type_embedding_layer(type_ids)
embeddings = word_embeddings + position_embeddings + type_embeddings
embeddings = self._embedding_norm_layer(embeddings)
embeddings = self._embedding_dropout(embeddings)
if self._embedding_projection is not None:
embeddings = self._embedding_projection(embeddings)
attention_mask = self._attention_mask_layer(embeddings, mask)
encoder_outputs = []
x = embeddings
for layer in self._transformer_layers:
x = layer([x, attention_mask])
encoder_outputs.append(x)
last_encoder_output = encoder_outputs[-1]
first_token_tensor = last_encoder_output[:, 0, :]
pooled_output = self._pooler_layer(first_token_tensor)
output = dict(
sequence_output=encoder_outputs[-1],
pooled_output=pooled_output,
encoder_outputs=encoder_outputs)
return output
def get_embedding_table(self):
return self._embedding_layer.embeddings
def get_embedding_layer(self):
return self._embedding_layer
def get_config(self):
return dict(self._config)
@property
def transformer_layers(self):
"""List of Transformer layers in the encoder."""
return self._transformer_layers
@property
def pooler_layer(self):
"""The pooler dense layer after the transformer layers."""
return self._pooler_layer
@classmethod
def from_config(cls, config, custom_objects=None):
if 'embedding_layer' in config and config['embedding_layer'] is not None:
warn_string = (
'You are reloading a model that was saved with a '
'potentially-shared embedding layer object. If you continue to '
'train this model, the embedding layer will no longer be shared. '
'To work around this, load the model outside of the Keras API.')
print('WARNING: ' + warn_string)
logging.warn(warn_string)
return cls(**config)
def _init_mixing_sublayer(self, layer: int):
"""Initializes config-dependent mixing sublayer."""
if self._config['mixing_mechanism'] == layers.MixingMechanism.FOURIER:
mixing_sublayer = layers.FourierTransformLayer(
use_fft=self._config['use_fft'], name='fourier_transform')
elif self._config['mixing_mechanism'] == layers.MixingMechanism.HARTLEY:
mixing_sublayer = layers.HartleyTransformLayer(
use_fft=self._config['use_fft'], name='hartley_transform')
elif self._config['mixing_mechanism'] == layers.MixingMechanism.LINEAR:
mixing_sublayer = layers.LinearTransformLayer(
kernel_initializer=tf_utils.clone_initializer(
self._config['initializer']),
name='linear_transform')
else:
raise ValueError('Unsupported mixing mechanism: %s' %
self._config['mixing_mechanism'])
return mixing_sublayer
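As the Notes in the class docstring above state, the MoE sublayers report auxiliary losses and metrics through the Keras `add_loss()`/`add_metric()` APIs. A minimal, assumed training-step sketch that folds these losses into the total loss (the hyperparameters and the stand-in task loss are illustrative):

```python
import tensorflow as tf
from official.nlp.modeling.networks import sparse_mixer

encoder = sparse_mixer.SparseMixer(
    vocab_size=100, hidden_size=16, num_layers=3, moe_layers=(1,),
    attention_layers=(2,), num_experts=4, num_attention_heads=2,
    max_sequence_length=32)

inputs = dict(
    input_word_ids=tf.ones((2, 32), dtype=tf.int32),
    input_mask=tf.ones((2, 32), dtype=tf.int32),
    input_type_ids=tf.zeros((2, 32), dtype=tf.int32))

with tf.GradientTape() as tape:
  outputs = encoder(inputs)
  # Stand-in task loss; a real model would attach a task head here.
  total_loss = tf.reduce_mean(tf.square(outputs['pooled_output']))
  if encoder.losses:
    # Auxiliary MoE load-balancing losses registered via add_loss().
    total_loss += tf.add_n(encoder.losses)
grads = tape.gradient(total_loss, encoder.trainable_variables)
```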
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Sparse Mixer encoder network."""
from typing import Sequence
from absl.testing import parameterized
import tensorflow as tf
from official.nlp.modeling import layers
from official.nlp.modeling.networks import sparse_mixer
class SparseMixerTest(parameterized.TestCase, tf.test.TestCase):
def tearDown(self):
super().tearDown()
tf.keras.mixed_precision.set_global_policy("float32")
@parameterized.named_parameters(
dict(
testcase_name="sparse_mixer",
mixing_mechanism=layers.MixingMechanism.LINEAR,
moe_layers=(1,),
attention_layers=(2,)),
dict(
testcase_name="fnet",
mixing_mechanism=layers.MixingMechanism.FOURIER,
moe_layers=(),
attention_layers=()),
dict(
testcase_name="sparse_hnet",
mixing_mechanism=layers.MixingMechanism.HARTLEY,
moe_layers=(0, 1, 2),
attention_layers=(1, 2)),
dict(
testcase_name="sparse_bert",
mixing_mechanism=layers.MixingMechanism.LINEAR,
moe_layers=(0, 1, 2), # All layers use MoE
attention_layers=(0, 1, 2)), # All layers use attention
)
def test_network(self, mixing_mechanism: layers.MixingMechanism,
attention_layers: Sequence[int], moe_layers: Sequence[int]):
num_layers = 3
hidden_size = 16
sequence_length = 32
test_network = sparse_mixer.SparseMixer(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
max_sequence_length=sequence_length,
num_layers=num_layers,
moe_layers=moe_layers,
num_experts=8,
mixing_mechanism=mixing_mechanism,
attention_layers=attention_layers)
batch_size = 4
word_ids = tf.keras.Input(
shape=(sequence_length,), batch_size=batch_size, dtype=tf.int32)
mask = tf.keras.Input(
shape=(sequence_length,), batch_size=batch_size, dtype=tf.int32)
type_ids = tf.keras.Input(
shape=(sequence_length,), batch_size=batch_size, dtype=tf.int32)
dict_outputs = test_network(
dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
data = dict_outputs["sequence_output"]
pooled = dict_outputs["pooled_output"]
self.assertIsInstance(test_network.transformer_layers, list)
self.assertLen(test_network.transformer_layers, 3)
self.assertIsInstance(test_network.pooler_layer, tf.keras.layers.Dense)
expected_data_shape = [batch_size, sequence_length, hidden_size]
expected_pooled_shape = [batch_size, hidden_size]
self.assertAllEqual(expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
# The default output dtype is float32.
self.assertAllEqual(tf.float32, data.dtype)
self.assertAllEqual(tf.float32, pooled.dtype)
def test_embeddings_as_inputs(self):
hidden_size = 32
sequence_length = 8
test_network = sparse_mixer.SparseMixer(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
max_sequence_length=sequence_length,
num_layers=3,
moe_layers=(1,),
num_experts=4,
attention_layers=(2,))
batch_size = 2
word_ids = tf.keras.Input(
shape=(sequence_length), batch_size=batch_size, dtype=tf.int32)
mask = tf.keras.Input(
shape=(sequence_length,), batch_size=batch_size, dtype=tf.int32)
type_ids = tf.keras.Input(
shape=(sequence_length,), batch_size=batch_size, dtype=tf.int32)
test_network.build(
dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
embeddings = test_network.get_embedding_layer()(word_ids)
# Calls with the embeddings.
dict_outputs = test_network(
dict(
input_word_embeddings=embeddings,
input_mask=mask,
input_type_ids=type_ids))
all_encoder_outputs = dict_outputs["encoder_outputs"]
pooled = dict_outputs["pooled_output"]
expected_data_shape = [batch_size, sequence_length, hidden_size]
expected_pooled_shape = [batch_size, hidden_size]
self.assertLen(all_encoder_outputs, 3)
for data in all_encoder_outputs:
self.assertAllEqual(expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
# The default output dtype is float32.
self.assertAllEqual(tf.float32, all_encoder_outputs[-1].dtype)
self.assertAllEqual(tf.float32, pooled.dtype)
if __name__ == "__main__":
tf.test.main()
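Finally, a compact sketch (assumed, with illustrative sizes) of the "Fast Sparse Mixer" configuration mentioned in the class docstring, obtained by halving both capacity factors:

```python
from official.nlp.modeling.networks import sparse_mixer

# Each expert's token capacity is roughly halved, giving a sparser and
# faster variant of the canonical model.
fast_encoder = sparse_mixer.SparseMixer(
    vocab_size=100,
    hidden_size=16,
    num_layers=3,
    moe_layers=(1,),
    attention_layers=(2,),
    num_experts=8,
    num_attention_heads=2,
    max_sequence_length=32,
    train_capacity_factor=0.5,
    eval_capacity_factor=0.5)
```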