From 8e108329eb491a6ca5d1a000b79f68e2c8276b49 Mon Sep 17 00:00:00 2001 From: James Lee-Thorp Date: Wed, 19 Oct 2022 16:19:40 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 482331634 --- official/nlp/modeling/layers/__init__.py | 4 + official/nlp/modeling/networks/README.md | 87 ++-- official/nlp/modeling/networks/__init__.py | 1 + official/nlp/modeling/networks/fnet.py | 2 +- .../nlp/modeling/networks/sparse_mixer.py | 407 ++++++++++++++++++ .../modeling/networks/sparse_mixer_test.py | 143 ++++++ 6 files changed, 603 insertions(+), 41 deletions(-) create mode 100644 official/nlp/modeling/networks/sparse_mixer.py create mode 100644 official/nlp/modeling/networks/sparse_mixer_test.py diff --git a/official/nlp/modeling/layers/__init__.py b/official/nlp/modeling/layers/__init__.py index 27a161b69..b94404c53 100644 --- a/official/nlp/modeling/layers/__init__.py +++ b/official/nlp/modeling/layers/__init__.py @@ -37,6 +37,10 @@ from official.nlp.modeling.layers.mixing import MixingMechanism from official.nlp.modeling.layers.mobile_bert_layers import MobileBertEmbedding from official.nlp.modeling.layers.mobile_bert_layers import MobileBertMaskedLM from official.nlp.modeling.layers.mobile_bert_layers import MobileBertTransformer +from official.nlp.modeling.layers.moe import ExpertsChooseMaskedRouter +from official.nlp.modeling.layers.moe import FeedForwardExperts +from official.nlp.modeling.layers.moe import MoeLayer +from official.nlp.modeling.layers.moe import MoeLayerWithBackbone from official.nlp.modeling.layers.multi_channel_attention import * from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding from official.nlp.modeling.layers.pack_optimization import PackBertEmbeddings diff --git a/official/nlp/modeling/networks/README.md b/official/nlp/modeling/networks/README.md index b32a30775..87cc571e8 100644 --- a/official/nlp/modeling/networks/README.md +++ b/official/nlp/modeling/networks/README.md @@ -2,43 +2,50 @@ Networks are combinations of `tf.keras` layers (and possibly other networks). They are `tf.keras` models that would not be trained alone. It encapsulates -common network structures like a transformer encoder into an easily -handled object with a standardized configuration. - -* [`BertEncoder`](bert_encoder.py) implements a bi-directional -Transformer-based encoder as described in ["BERT: Pre-training of Deep -Bidirectional Transformers for Language Understanding"](https://arxiv.org/abs/1810.04805). -It includes the embedding lookups, transformer layers and pooling layer. - -* [`AlbertEncoder`](albert_encoder.py) implements a -Transformer-encoder described in the paper ["ALBERT: A Lite BERT for -Self-supervised Learning of Language Representations"] -(https://arxiv.org/abs/1909.11942). Compared with [BERT](https://arxiv.org/abs/1810.04805), -ALBERT refactorizes embedding parameters into two smaller matrices and shares -parameters across layers. - -* [`MobileBERTEncoder`](mobile_bert_encoder.py) implements the -MobileBERT network described in the paper ["MobileBERT: a Compact Task-Agnostic -BERT for Resource-Limited Devices"](https://arxiv.org/abs/2004.02984). - -* [`Classification`](classification.py) contains a single hidden layer, and is -intended for use as a classification or regression (if number of classes is set -to 1) head. - -* [`PackedSequenceEmbedding`](packed_sequence_embedding.py) implements an -embedding network that supports packed sequences and position ids. - -* [`SpanLabeling`](span_labeling.py) implements a single-span labeler -(that is, a prediction head that can predict one start and end index per batch -item) based on a single dense hidden layer. It can be used in the SQuAD task. - -* [`XLNetBase`](xlnet_base.py) implements the base network used in "XLNet: -Generalized Autoregressive Pretraining for Language Understanding" -(https://arxiv.org/abs/1906.08237). It includes embedding lookups, -relative position encodings, mask computations, segment matrix computations and -Transformer XL layers using one or two stream relative self-attention. - -* [`FNet`](fnet.py) implements the encoder model from ["FNet: Mixing Tokens with -Fourier Transforms"](https://aclanthology.org/2022.naacl-main.319/). FNet has -the same structure as a Transformer encoder, except that all or most of the -self-attention sublayers are replaced with Fourier sublayers. +common network structures like a transformer encoder into an easily handled +object with a standardized configuration. + +* [`BertEncoder`](bert_encoder.py) implements a bi-directional + Transformer-based encoder as described in ["BERT: Pre-training of Deep + Bidirectional Transformers for Language + Understanding"](https://arxiv.org/abs/1810.04805). It includes the embedding + lookups, transformer layers and pooling layer. + +* [`AlbertEncoder`](albert_encoder.py) implements a Transformer-encoder + described in the paper ["ALBERT: A Lite BERT for Self-supervised Learning of + Language Representations"](https://arxiv.org/abs/1909.11942). Compared with + [BERT](https://arxiv.org/abs/1810.04805), ALBERT refactorizes embedding + parameters into two smaller matrices and shares parameters across layers. + +* [`MobileBERTEncoder`](mobile_bert_encoder.py) implements the MobileBERT + network described in the paper + ["MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices"](https://arxiv.org/abs/2004.02984). + +* [`Classification`](classification.py) contains a single hidden layer, and is + intended for use as a classification or regression (if number of classes is + set to 1) head. + +* [`PackedSequenceEmbedding`](packed_sequence_embedding.py) implements an + embedding network that supports packed sequences and position ids. + +* [`SpanLabeling`](span_labeling.py) implements a single-span labeler (that + is, a prediction head that can predict one start and end index per batch + item) based on a single dense hidden layer. It can be used in the SQuAD + task. + +* [`XLNetBase`](xlnet_base.py) implements the base network used in "XLNet: + Generalized Autoregressive Pretraining for Language Understanding" + (https://arxiv.org/abs/1906.08237). It includes embedding lookups, relative + position encodings, mask computations, segment matrix computations and + Transformer XL layers using one or two stream relative self-attention. + +* [`FNet`](fnet.py) implements the encoder model from + ["FNet: Mixing Tokens with Fourier Transforms"](https://aclanthology.org/2022.naacl-main.319/). + FNet has the same structure as a Transformer encoder, except that all or + most of the self-attention sublayers are replaced with Fourier sublayers. + +* [`Sparse Mixer`](sparse_mixer.py) implements the encoder model from + ["Sparse Mixers: Combining MoE and Mixing to build a more efficient BERT "](https://arxiv.org/abs/2205.12399/). + Sparse Mixer consists of layers of heterogeneous encoder blocks. Each + encoder block contains a linear mixing or an attention sublayer together + with a (dense) MLP or sparsely activated Mixture-of-Experts sublayer. diff --git a/official/nlp/modeling/networks/__init__.py b/official/nlp/modeling/networks/__init__.py index 0128481d9..bda1a339f 100644 --- a/official/nlp/modeling/networks/__init__.py +++ b/official/nlp/modeling/networks/__init__.py @@ -29,4 +29,5 @@ from official.nlp.modeling.networks.mobile_bert_encoder import MobileBERTEncoder from official.nlp.modeling.networks.packed_sequence_embedding import PackedSequenceEmbedding from official.nlp.modeling.networks.span_labeling import SpanLabeling from official.nlp.modeling.networks.span_labeling import XLNetSpanLabeling +from official.nlp.modeling.networks.sparse_mixer import SparseMixer from official.nlp.modeling.networks.xlnet_base import XLNetBase diff --git a/official/nlp/modeling/networks/fnet.py b/official/nlp/modeling/networks/fnet.py index ac9676699..7d5d09424 100644 --- a/official/nlp/modeling/networks/fnet.py +++ b/official/nlp/modeling/networks/fnet.py @@ -51,7 +51,7 @@ class FNet(tf.keras.layers.Layer): num_layers: The number of transformer layers. mixing_mechanism: Type of mixing mechanism used in place of self-attention layers. Defaults to FNet ('Fourier') mixing. - use_fft: Only used for spectral mixing mechanims. Determines whether to use + use_fft: Only used for spectral mixing mechanisms. Determines whether to use Fast Fourier Transform (True) or the Discrete Fourier Transform (DFT) matrix (False; default) to compute the Fourier Transform. See layers.FourierTransformLayer or layers.HartleyTransformLayer for advice. diff --git a/official/nlp/modeling/networks/sparse_mixer.py b/official/nlp/modeling/networks/sparse_mixer.py new file mode 100644 index 000000000..c69e7940e --- /dev/null +++ b/official/nlp/modeling/networks/sparse_mixer.py @@ -0,0 +1,407 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sparse Mixer encoder network. + +Based on ["Sparse Mixers: Combining MoE and Mixing to build a more efficient +BERT"](https://arxiv.org/abs/2205.12399). +""" +# pylint: disable=g-classes-have-attributes + +from typing import Any, Callable, Optional, Sequence, Union +from absl import logging +import tensorflow as tf + +from official.modeling import tf_utils +from official.nlp.modeling import layers + +_Activation = Union[str, Callable[..., Any]] +_Initializer = Union[str, tf.keras.initializers.Initializer] + +_approx_gelu = lambda x: tf.keras.activations.gelu(x, approximate=True) + + +class SparseMixer(tf.keras.layers.Layer): + """Sparse Mixer encoder network. + + Based on ["Sparse Mixers: Combining MoE and Mixing to build a more efficient + BERT"](https://arxiv.org/abs/2205.12399). Sparse Mixer is an efficient + encoder network that replaces typical Transformer encoder blocks with a + combination of linear mixing and sparsely activated Mixture-of-Experts (MoE) + sublayers. + + This implementation defaults to the canonical Sparse Mixer Base model. To use + the "Fast Sparse Mixer" configuration, set `*_capacity_factor`=0.5. This + yields a sparser and faster variant of the canonical Sparse Mixer model, in + which each expert processes roughly 50% less tokens. + + Notes: + - The underlying MoeLayer uses the Keras add_loss() and add_metric() APIs to + propagate auxiliary MoE losses and metrics. Any model using this network, + should collect these losses/metrics. + - The input length is fixed to 'max_sequence_length' to accomodate the mixing + mechanisms. + + Args: + vocab_size: The size of the token vocabulary. + hidden_size: The size of the transformer hidden layers. + num_layers: The number of transformer layers. + moe_layers: Specifies which layers, if any, should be sparsely activated + Mixture-of-Experts (MoE) layers. The remaining [0, num_layers) setminus + moe_layers will use the vanilla MLP sublayers. Defaults to placing MoE + layers in the middle of the model. + attention_layers: Specifies which layers, if any, should be attention layers + in the encoder. The remaining [0, num_layers) setminus attention_layers + will use the specified `mixing_mechanism`. If using attention layers, a + good rule of thumb is to place them in the final few layers. + num_experts: Number of experts. Experts are themselves MLP modules, with the + same `inner_dim` and `inner_activation` as the vanilla MLP sublayers. + train_capacity_factor: Scaling factor to increase the expert token capacity + during training. See layers.MoeLayer for further details. The "Fast Sparse + Mixer" increases model sparsity (and speed) by using a capacity factor of + 0.5. + eval_capacity_factor: As above, but used during evaluation. + max_group_size: The total number of tokens on each device is subdivided into + groups of this size. Router computations are then performed on a per-group + basis. See layers.MoeLayer for further details. + mixing_mechanism: Type of mixing mechanism used in place of self-attention + layers. Defaults to 'Linear' mixing. + use_fft: Only used for spectral mixing mechanisms. Determines whether to use + Fast Fourier Transform (True) or the Discrete Fourier Transform (DFT) + matrix (False; default) to compute the Fourier Transform. See + layers.FourierTransformLayer or layers.HartleyTransformLayer for advice. + num_attention_heads: The number of attention heads for each transformer. The + hidden size must be divisible by the number of attention heads. + max_sequence_length: The only sequence length that this encoder can consume. + This determines the variable shape for positional embeddings and the size + of the mixing matrices. + type_vocab_size: The number of types that the 'type_ids' input can take. + inner_dim: The output dimension of the first Dense layer in a two-layer + feedforward network for each transformer. + inner_activation: The activation for the first Dense layer in a two-layer + feedforward network for each transformer. + output_dropout: Dropout probability for the post-attention and output + dropout. + attention_dropout: The dropout rate to use for the attention layers within + the transformer layers. + initializer: The initializer to use for all weights in this encoder. + output_range: The sequence output range, [0, output_range), by slicing the + target sequence of the last transformer layer. `None` means the entire + target sequence will attend to the source sequence, which yields the full + output. + embedding_width: The width of the word embeddings. If the embedding width is + not equal to hidden size, embedding parameters will be factorized into two + matrices in the shape of ['vocab_size', 'embedding_width'] and + ['embedding_width', 'hidden_size'] ('embedding_width' is usually much + smaller than 'hidden_size'). + embedding_layer: An optional Layer instance which will be called to generate + embeddings for the input word IDs. + norm_first: Whether to normalize inputs to attention and intermediate dense + layers. If set False, output of attention and intermediate dense layers is + normalized. + with_dense_inputs: Whether to accept dense embeddings as the input. + """ + + def __init__( + self, + vocab_size: int, + hidden_size: int = 512, + num_layers: int = 14, + moe_layers: Sequence[int] = (5, 6, 7, 8), + attention_layers: Sequence[int] = (10, 11, 12, 13), + num_experts: int = 16, + train_capacity_factor: float = 1., + eval_capacity_factor: float = 1., + max_group_size: int = 4096, + mixing_mechanism: layers.MixingMechanism = layers.MixingMechanism.LINEAR, + use_fft: bool = False, + num_attention_heads: int = 8, + max_sequence_length: int = 512, + type_vocab_size: int = 16, + inner_dim: int = 2056, + inner_activation: _Activation = _approx_gelu, + output_dropout: float = 0.1, + attention_dropout: float = 0.1, + initializer: _Initializer = tf.keras.initializers.TruncatedNormal( + stddev=0.02), + output_range: Optional[int] = None, + embedding_width: Optional[int] = None, + embedding_layer: Optional[tf.keras.layers.Layer] = None, + norm_first: bool = False, + with_dense_inputs: bool = False, + **kwargs): + super().__init__(**kwargs) + + activation = tf.keras.activations.get(inner_activation) + initializer = tf.keras.initializers.get(initializer) + + if embedding_width is None: + embedding_width = hidden_size + + self._config = { + 'vocab_size': vocab_size, + 'hidden_size': hidden_size, + 'num_layers': num_layers, + 'moe_layers': moe_layers, + 'num_experts': num_experts, + 'train_capacity_factor': train_capacity_factor, + 'eval_capacity_factor': eval_capacity_factor, + 'max_group_size': max_group_size, + 'mixing_mechanism': mixing_mechanism, + 'use_fft': use_fft, + 'attention_layers': attention_layers, + 'num_attention_heads': num_attention_heads, + 'max_sequence_length': max_sequence_length, + 'type_vocab_size': type_vocab_size, + 'inner_dim': inner_dim, + 'inner_activation': tf.keras.activations.serialize(activation), + 'output_dropout': output_dropout, + 'attention_dropout': attention_dropout, + 'initializer': tf.keras.initializers.serialize(initializer), + 'output_range': output_range, + 'embedding_width': embedding_width, + 'embedding_layer': embedding_layer, + 'norm_first': norm_first, + 'with_dense_inputs': with_dense_inputs, + } + + if embedding_layer is None: + self._embedding_layer = layers.OnDeviceEmbedding( + vocab_size=vocab_size, + embedding_width=embedding_width, + initializer=tf_utils.clone_initializer(initializer), + name='word_embeddings') + else: + self._embedding_layer = embedding_layer + + self._position_embedding_layer = layers.PositionEmbedding( + initializer=tf_utils.clone_initializer(initializer), + max_length=max_sequence_length, + name='position_embedding') + + self._type_embedding_layer = layers.OnDeviceEmbedding( + vocab_size=type_vocab_size, + embedding_width=embedding_width, + initializer=tf_utils.clone_initializer(initializer), + use_one_hot=True, + name='type_embeddings') + + self._embedding_norm_layer = tf.keras.layers.LayerNormalization( + name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32) + + self._embedding_dropout = tf.keras.layers.Dropout( + rate=output_dropout, name='embedding_dropout') + + # We project the 'embedding' output to 'hidden_size' if it is not already + # 'hidden_size'. + self._embedding_projection = None + if embedding_width != hidden_size: + self._embedding_projection = tf.keras.layers.EinsumDense( + '...x,xy->...y', + output_shape=hidden_size, + bias_axes='y', + kernel_initializer=tf_utils.clone_initializer(initializer), + name='embedding_projection') + + self._transformer_layers = [] + for layer in range(num_layers): + if layer in attention_layers: + mixing_layer = layers.MultiHeadAttention( + num_heads=num_attention_heads, + key_dim=int(hidden_size // num_attention_heads), + dropout=attention_dropout, + use_bias=True, + kernel_initializer=tf_utils.clone_initializer(initializer), + name='self_attention', + ) + else: + mixing_layer = self._init_mixing_sublayer(layer) + + if layer in moe_layers: + feedforward_layer = layers.MoeLayer( + experts=layers.FeedForwardExperts( + num_experts=num_experts, + d_ff=hidden_size, + dropout_rate=output_dropout, + activation=inner_activation, + kernel_initializer=tf_utils.clone_initializer(initializer), + name='experts'), + router=layers.ExpertsChooseMaskedRouter( + num_experts=num_experts, + kernel_initializer=tf_utils.clone_initializer(initializer), + name='router'), + train_capacity_factor=train_capacity_factor, + eval_capacity_factor=eval_capacity_factor, + max_group_size=max_group_size, + name='moe') + else: + feedforward_layer = None # Fallback to default (dense) MLP class + + block = layers.TransformerScaffold( + num_attention_heads=num_attention_heads, + inner_dim=inner_dim, + inner_activation=inner_activation, + attention_cls=mixing_layer, + feedforward_cls=feedforward_layer, + output_dropout=output_dropout, + attention_dropout=attention_dropout, + norm_first=norm_first, + output_range=output_range if layer == num_layers - 1 else None, + kernel_initializer=tf_utils.clone_initializer(initializer), + name='transformer/layer_%d' % layer) + self._transformer_layers.append(block) + + self._attention_mask_layer = layers.SelfAttentionMask( + name='self_attention_mask') + + self._pooler_layer = tf.keras.layers.Dense( + units=hidden_size, + activation='tanh', + kernel_initializer=tf_utils.clone_initializer(initializer), + name='pooler_transform') + + if with_dense_inputs: + self.inputs = dict( + input_word_ids=tf.keras.Input( + shape=(max_sequence_length,), dtype=tf.int32), + input_mask=tf.keras.Input( + shape=(max_sequence_length,), dtype=tf.int32), + input_type_ids=tf.keras.Input( + shape=(max_sequence_length,), dtype=tf.int32), + dense_inputs=tf.keras.Input( + shape=(max_sequence_length, embedding_width), dtype=tf.float32), + dense_mask=tf.keras.Input( + shape=(max_sequence_length,), dtype=tf.int32), + dense_type_ids=tf.keras.Input( + shape=(max_sequence_length,), dtype=tf.int32), + ) + else: + self.inputs = dict( + input_word_ids=tf.keras.Input( + shape=(max_sequence_length,), dtype=tf.int32), + input_mask=tf.keras.Input( + shape=(max_sequence_length,), dtype=tf.int32), + input_type_ids=tf.keras.Input( + shape=(max_sequence_length,), dtype=tf.int32)) + self._max_sequence_length = max_sequence_length + + def call(self, inputs): + word_embeddings = None + if isinstance(inputs, dict): + word_ids = inputs.get('input_word_ids') + mask = inputs.get('input_mask') + type_ids = inputs.get('input_type_ids') + word_embeddings = inputs.get('input_word_embeddings', None) + + dense_inputs = inputs.get('dense_inputs', None) + dense_mask = inputs.get('dense_mask', None) + dense_type_ids = inputs.get('dense_type_ids', None) + else: + raise ValueError('Unexpected inputs type (%s) to %s.' % + (type(inputs), self.__class__)) + + if word_embeddings is None: + word_embeddings = self._embedding_layer(word_ids) + + if dense_inputs is not None: + # Concat the dense embeddings at sequence end. + word_embeddings = tf.concat([word_embeddings, dense_inputs], axis=1) + type_ids = tf.concat([type_ids, dense_type_ids], axis=1) + mask = tf.concat([mask, dense_mask], axis=1) + + seq_length = word_embeddings.shape[1] + if seq_length != self._max_sequence_length: + raise ValueError('Sparse Mixer: Sequence length must be the same as ' + '`max_sequence_length` ({}), but it is {}.'.format( + self._max_sequence_length, seq_length)) + + # Absolute position embeddings. + position_embeddings = self._position_embedding_layer(word_embeddings) + type_embeddings = self._type_embedding_layer(type_ids) + + embeddings = word_embeddings + position_embeddings + type_embeddings + embeddings = self._embedding_norm_layer(embeddings) + embeddings = self._embedding_dropout(embeddings) + + if self._embedding_projection is not None: + embeddings = self._embedding_projection(embeddings) + + attention_mask = self._attention_mask_layer(embeddings, mask) + + encoder_outputs = [] + x = embeddings + for layer in self._transformer_layers: + x = layer([x, attention_mask]) + encoder_outputs.append(x) + + last_encoder_output = encoder_outputs[-1] + first_token_tensor = last_encoder_output[:, 0, :] + pooled_output = self._pooler_layer(first_token_tensor) + + output = dict( + sequence_output=encoder_outputs[-1], + pooled_output=pooled_output, + encoder_outputs=encoder_outputs) + return output + + def get_embedding_table(self): + return self._embedding_layer.embeddings + + def get_embedding_layer(self): + return self._embedding_layer + + def get_config(self): + return dict(self._config) + + @property + def transformer_layers(self): + """List of Transformer layers in the encoder.""" + return self._transformer_layers + + @property + def pooler_layer(self): + """The pooler dense layer after the transformer layers.""" + return self._pooler_layer + + @classmethod + def from_config(cls, config, custom_objects=None): + if 'embedding_layer' in config and config['embedding_layer'] is not None: + warn_string = ( + 'You are reloading a model that was saved with a ' + 'potentially-shared embedding layer object. If you contine to ' + 'train this model, the embedding layer will no longer be shared. ' + 'To work around this, load the model outside of the Keras API.') + print('WARNING: ' + warn_string) + logging.warn(warn_string) + + return cls(**config) + + def _init_mixing_sublayer(self, layer: int): + """Initializes config-dependent mixing sublayer.""" + if self._config['mixing_mechanism'] == layers.MixingMechanism.FOURIER: + mixing_sublayer = layers.FourierTransformLayer( + use_fft=self._config['use_fft'], name='fourier_transform') + elif self._config['mixing_mechanism'] == layers.MixingMechanism.HARTLEY: + mixing_sublayer = layers.HartleyTransformLayer( + use_fft=self._config['use_fft'], name='hartley_transform') + elif self._config['mixing_mechanism'] == layers.MixingMechanism.LINEAR: + mixing_sublayer = layers.LinearTransformLayer( + kernel_initializer=tf_utils.clone_initializer( + self._config['initializer']), + name='linear_transform') + else: + raise ValueError('Unsupported mixing mechanism: %s' % + self._config['mixing_mechanism']) + + return mixing_sublayer diff --git a/official/nlp/modeling/networks/sparse_mixer_test.py b/official/nlp/modeling/networks/sparse_mixer_test.py new file mode 100644 index 000000000..3a7db920b --- /dev/null +++ b/official/nlp/modeling/networks/sparse_mixer_test.py @@ -0,0 +1,143 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Sparse Mixer encoder network.""" + +from typing import Sequence + +from absl.testing import parameterized +import tensorflow as tf + +from official.nlp.modeling import layers +from official.nlp.modeling.networks import sparse_mixer + + +class SparseMixerTest(parameterized.TestCase, tf.test.TestCase): + + def tearDown(self): + super().tearDown() + tf.keras.mixed_precision.set_global_policy("float32") + + @parameterized.named_parameters( + dict( + testcase_name="sparse_mixer", + mixing_mechanism=layers.MixingMechanism.LINEAR, + moe_layers=(1,), + attention_layers=(2,)), + dict( + testcase_name="fnet", + mixing_mechanism=layers.MixingMechanism.FOURIER, + moe_layers=(), + attention_layers=()), + dict( + testcase_name="sparse_hnet", + mixing_mechanism=layers.MixingMechanism.HARTLEY, + moe_layers=(0, 1, 2), + attention_layers=(1, 2)), + dict( + testcase_name="sparse_bert", + mixing_mechanism=layers.MixingMechanism.LINEAR, + moe_layers=(0, 1, 2), # All layers use MoE + attention_layers=(0, 1, 2)), # All layers use attention + ) + def test_network(self, mixing_mechanism: layers.MixingMechanism, + attention_layers: Sequence[int], moe_layers: Sequence[int]): + num_layers = 3 + hidden_size = 16 + sequence_length = 32 + test_network = sparse_mixer.SparseMixer( + vocab_size=100, + hidden_size=hidden_size, + num_attention_heads=2, + max_sequence_length=sequence_length, + num_layers=num_layers, + moe_layers=moe_layers, + num_experts=8, + mixing_mechanism=mixing_mechanism, + attention_layers=attention_layers) + + batch_size = 4 + word_ids = tf.keras.Input( + shape=(sequence_length,), batch_size=batch_size, dtype=tf.int32) + mask = tf.keras.Input( + shape=(sequence_length,), batch_size=batch_size, dtype=tf.int32) + type_ids = tf.keras.Input( + shape=(sequence_length,), batch_size=batch_size, dtype=tf.int32) + + dict_outputs = test_network( + dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids)) + data = dict_outputs["sequence_output"] + pooled = dict_outputs["pooled_output"] + + self.assertIsInstance(test_network.transformer_layers, list) + self.assertLen(test_network.transformer_layers, 3) + self.assertIsInstance(test_network.pooler_layer, tf.keras.layers.Dense) + + expected_data_shape = [batch_size, sequence_length, hidden_size] + expected_pooled_shape = [batch_size, hidden_size] + self.assertAllEqual(expected_data_shape, data.shape.as_list()) + self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list()) + + # The default output dtype is float32. + self.assertAllEqual(tf.float32, data.dtype) + self.assertAllEqual(tf.float32, pooled.dtype) + + def test_embeddings_as_inputs(self): + hidden_size = 32 + sequence_length = 8 + test_network = sparse_mixer.SparseMixer( + vocab_size=100, + hidden_size=hidden_size, + num_attention_heads=2, + max_sequence_length=sequence_length, + num_layers=3, + moe_layers=(1,), + num_experts=4, + attention_layers=(2,)) + + batch_size = 2 + word_ids = tf.keras.Input( + shape=(sequence_length), batch_size=batch_size, dtype=tf.int32) + mask = tf.keras.Input( + shape=(sequence_length,), batch_size=batch_size, dtype=tf.int32) + type_ids = tf.keras.Input( + shape=(sequence_length,), batch_size=batch_size, dtype=tf.int32) + + test_network.build( + dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids)) + embeddings = test_network.get_embedding_layer()(word_ids) + + # Calls with the embeddings. + dict_outputs = test_network( + dict( + input_word_embeddings=embeddings, + input_mask=mask, + input_type_ids=type_ids)) + all_encoder_outputs = dict_outputs["encoder_outputs"] + pooled = dict_outputs["pooled_output"] + + expected_data_shape = [batch_size, sequence_length, hidden_size] + expected_pooled_shape = [batch_size, hidden_size] + self.assertLen(all_encoder_outputs, 3) + for data in all_encoder_outputs: + self.assertAllEqual(expected_data_shape, data.shape.as_list()) + self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list()) + + # The default output dtype is float32. + self.assertAllEqual(tf.float32, all_encoder_outputs[-1].dtype) + self.assertAllEqual(tf.float32, pooled.dtype) + + +if __name__ == "__main__": + tf.test.main() -- GitLab