Commit 5842f323 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 492385091
Parent bdf4a7bb
@@ -258,13 +258,37 @@ class FNetEncoderConfig(hyperparams.Config):
initializer_range: float = 0.02
embedding_width: Optional[int] = None
output_range: Optional[int] = None
return_all_encoder_outputs: bool = False
# Pre/Post-LN Transformer
norm_first: bool = False
use_fft: bool = False
attention_layers: Sequence[int] = ()
@dataclasses.dataclass
class SparseMixerEncoderConfig(hyperparams.Config):
"""SparseMixer encoder configuration."""
vocab_size: int = 30522
hidden_size: int = 768
num_layers: int = 14
moe_layers: Sequence[int] = (5, 6, 7, 8)
attention_layers: Sequence[int] = (10, 11, 12, 13)
num_experts: int = 16
train_capacity_factor: float = 1.
eval_capacity_factor: float = 1.
examples_per_group: float = 1.
use_fft: bool = False
num_attention_heads: int = 8
max_sequence_length: int = 512
type_vocab_size: int = 2
inner_dim: int = 3072
inner_activation: str = "gelu"
output_dropout: float = 0.1
attention_dropout: float = 0.1
initializer_range: float = 0.02
output_range: Optional[int] = None
embedding_width: Optional[int] = None
norm_first: bool = False
@dataclasses.dataclass
class EncoderConfig(hyperparams.OneOfConfig):
"""Encoder configuration."""
@@ -279,6 +303,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
query_bert: QueryBertConfig = QueryBertConfig()
fnet: FNetEncoderConfig = FNetEncoderConfig()
sparse_mixer: SparseMixerEncoderConfig = SparseMixerEncoderConfig()
# If `any` is used, the encoder building relies on any.BUILDER.
any: hyperparams.Config = hyperparams.Config()
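The new encoder is selected through the `OneOfConfig` mechanism: setting `type` to a field name activates that sub-config. A minimal sketch, assuming the usual `official.nlp.configs.encoders` import path; unset fields keep the dataclass defaults shown above:

```python
from official.nlp.configs import encoders

# Activate the SparseMixer sub-config by naming its field in `type`.
encoder_config = encoders.EncoderConfig(
    type="sparse_mixer",
    sparse_mixer=encoders.SparseMixerEncoderConfig(
        num_layers=14,
        moe_layers=(5, 6, 7, 8),
        attention_layers=(10, 11, 12, 13),
        num_experts=16))
```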
@@ -607,6 +632,32 @@ def build_encoder(config: EncoderConfig,
use_fft=encoder_cfg.use_fft,
attention_layers=encoder_cfg.attention_layers)
if encoder_type == "sparse_mixer":
return networks.SparseMixer(
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
moe_layers=encoder_cfg.moe_layers,
attention_layers=encoder_cfg.attention_layers,
num_experts=encoder_cfg.num_experts,
train_capacity_factor=encoder_cfg.train_capacity_factor,
eval_capacity_factor=encoder_cfg.eval_capacity_factor,
examples_per_group=encoder_cfg.examples_per_group,
use_fft=encoder_cfg.use_fft,
num_attention_heads=encoder_cfg.num_attention_heads,
max_sequence_length=encoder_cfg.max_sequence_length,
type_vocab_size=encoder_cfg.type_vocab_size,
inner_dim=encoder_cfg.inner_dim,
inner_activation=tf_utils.get_activation(encoder_cfg.inner_activation),
output_dropout=encoder_cfg.output_dropout,
attention_dropout=encoder_cfg.attention_dropout,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
output_range=encoder_cfg.output_range,
embedding_width=encoder_cfg.embedding_width,
norm_first=encoder_cfg.norm_first,
embedding_layer=embedding_layer)
bert_encoder_cls = networks.BertEncoder
if encoder_type == "bert_v2":
bert_encoder_cls = networks.BertEncoderV2
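A hedged usage sketch of the new branch: building the encoder from the config and running it on dummy inputs. This assumes the dictionary input signature and output keys used by the other Model Garden encoders; the sequence length must match `max_sequence_length` (512 by default, see the shape check at the end of this diff):

```python
import tensorflow as tf
from official.nlp.configs import encoders

encoder = encoders.build_encoder(encoders.EncoderConfig(type="sparse_mixer"))
outputs = encoder(dict(
    input_word_ids=tf.ones((2, 512), tf.int32),
    input_mask=tf.ones((2, 512), tf.int32),
    input_type_ids=tf.zeros((2, 512), tf.int32)))
# Assumed standard encoder output dict:
# outputs["sequence_output"]: [2, 512, hidden_size]
```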
......
@@ -17,8 +17,6 @@
import dataclasses
from typing import Any, Callable, Optional, Tuple
from absl import logging
import numpy as np
import tensorflow as tf
from official.modeling import tf_utils
@@ -48,11 +46,13 @@ def _router_z_loss(router_logits: tf.Tensor) -> float:
Returns:
Scalar router z-loss <float32>.
"""
num_groups, tokens_per_group, _ = router_logits.shape
num_groups = tf.shape(router_logits)[0]
tokens_per_group = router_logits.shape[1]
log_z = tf.math.reduce_logsumexp(router_logits, axis=-1)
z_loss = log_z**2
return tf.math.reduce_sum(z_loss) / (num_groups * tokens_per_group)
return tf.math.reduce_sum(z_loss) / tf.cast(
num_groups * tokens_per_group, tf.float32)
@dataclasses.dataclass
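The switch to `tf.shape` keeps the group dimension dynamic while the loss math is unchanged. A small worked check of that math, mirroring the unit test later in this diff: for a single token with logits `[10, 5]`, `log_z = log(e^10 + e^5)` and the loss is `log_z**2` divided by one token:

```python
import tensorflow as tf

# [num_groups=1, tokens_per_group=1, num_experts=2]
logits = tf.constant([[[10.0, 5.0]]])
log_z = tf.math.reduce_logsumexp(logits, axis=-1)  # log(e^10 + e^5) ~= 10.0067
z_loss = tf.math.reduce_sum(log_z**2) / tf.cast(1 * 1, tf.float32)  # ~= 100.13
```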
@@ -187,7 +187,7 @@ class Router(tf.keras.layers.Layer):
"""
if apply_jitter and self.jitter_noise > 0:
inputs *= tf.random.uniform(
inputs.shape,
tf.shape(inputs),
minval=1.0 - self.jitter_noise,
maxval=1.0 + self.jitter_noise,
dtype=inputs.dtype)
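Using `tf.shape(inputs)` instead of the static `inputs.shape` lets the same multiplicative jitter work with an unknown batch dimension. A sketch of the effect, with illustrative values:

```python
import tensorflow as tf

jitter_noise = 0.1
x = tf.ones((2, 3, 4))
# Each element is scaled by an independent factor drawn from
# [1 - jitter_noise, 1 + jitter_noise).
jittered = x * tf.random.uniform(
    tf.shape(x), minval=1.0 - jitter_noise, maxval=1.0 + jitter_noise)
```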
@@ -259,7 +259,9 @@ class ExpertsChooseMaskedRouter(MaskedRouter):
Returns:
Dispatch and combine arrays for routing with masked matmuls.
"""
num_groups, tokens_per_group, _ = router_probs.shape
num_groups = tf.shape(router_probs)[0]
tokens_per_group = router_probs.shape[1]
router_probs_t = tf.transpose(router_probs, perm=[0, 2, 1])
# router_probs_t: <float32>[num_groups, num_experts, tokens_per_group]
@@ -296,8 +298,10 @@ class ExpertsChooseMaskedRouter(MaskedRouter):
num_tokens = num_groups * tokens_per_group
num_tokens_dispatched_somewhere = tf.math.reduce_sum(tf.math.reduce_max(
dispatch_mask, axis=(-1, -2)))
fraction_tokens_left_behind = 1.0 - num_tokens_dispatched_somewhere / float(
num_tokens)
fraction_tokens_left_behind = 1.0 - tf.cast(
num_tokens_dispatched_somewhere, tf.float32) / tf.cast(
num_tokens, tf.float32)
# Total number of tokens that were dispatched (one token could be
# dispatched to multiple experts).
num_tokens_dispatched = tf.math.reduce_sum(dispatch_mask)
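The explicit casts avoid mixing integer token counts into float arithmetic now that `num_groups` is a dynamic tensor. A toy illustration of the statistic itself, with a hypothetical dispatch mask for 3 tokens, 2 experts, and capacity 1, where one token is dropped:

```python
import tensorflow as tf

# dispatch_mask: [num_groups=1, tokens_per_group=3, num_experts=2, expert_capacity=1]
dispatch_mask = tf.constant(
    [[[[1.0], [0.0]],    # token 0 -> expert 0
      [[0.0], [1.0]],    # token 1 -> expert 1
      [[0.0], [0.0]]]])  # token 2 dropped
num_tokens = 1 * 3
dispatched = tf.math.reduce_sum(tf.math.reduce_max(dispatch_mask, axis=(-1, -2)))
fraction_left_behind = 1.0 - dispatched / tf.cast(num_tokens, tf.float32)  # -> 1/3
```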
@@ -513,9 +517,7 @@ class MoeLayer(tf.keras.layers.Layer):
*,
train_capacity_factor: float = 1.0,
eval_capacity_factor: float = 1.0,
min_expert_capacity: int = 4,
max_group_size: int = 4096,
strict_group_size: bool = False,
examples_per_group: float = 1.0,
name: str = "moe",
**kwargs):
"""Init.
@@ -537,22 +539,18 @@ class MoeLayer(tf.keras.layers.Layer):
tokens that an expert will process AND will indirectly increase the
number of experts that a given token is routed to.
eval_capacity_factor: As above, but used during evaluation.
min_expert_capacity: Minimum token processing capacity for each expert.
max_group_size: The total number of tokens on each device is subdivided
into groups of this size. Router computations are then performed on a
per-group basis. A larger group size will result in slower but more
accurate top-k and sorting computations, whereas a smaller group size
will result in faster but more approximate (and potentially less stable)
routing choices. Note that actual group size may be smaller than
max_group_size for consistency with the number of experts and tokens;
see also `strict_group_size` attribute. In practice,
we find that imperfect routing choices are tolerable and recommend
choosing a group size on the order of 4096 tokens, although this number
will vary based on model configuration and size.
strict_group_size: If True, fail if unable to set the token group size
equal to max_group_size. If False (default), the actual group size may
be smaller than max_group_size for consistency with the number of
experts and tokens.
examples_per_group: Number of examples to form a group. Router then
performs top_k token selection for each expert on a per group basis.
E.g. when `examples_per_group=4.0`, tokens are assigned to experts in
groups formed from 4 examples. When `examples_per_group=0.5`,
each example is split into 2 groups.
`examples_per_group` must divide the local batch size.
A larger group size will result in slower but more accurate top-k and
sorting computations, whereas a smaller group size will result in faster
but more approximate (and potentially less stable) routing choices.
In practice, we find that imperfect routing choices are tolerable and
recommend choosing a group size on the order of 4096 tokens, although
this number will vary based on model configuration and size.
name: Layer name.
**kwargs: Forwarded to super.
"""
@@ -565,9 +563,7 @@ class MoeLayer(tf.keras.layers.Layer):
self._train_capacity_factor = train_capacity_factor
self._eval_capacity_factor = eval_capacity_factor
self._max_group_size = max_group_size
self._min_expert_capacity = min_expert_capacity
self._strict_group_size = strict_group_size
self._examples_per_group = examples_per_group
def call(self,
inputs: tf.Tensor,
@@ -592,10 +588,15 @@ class MoeLayer(tf.keras.layers.Layer):
training = tf.keras.backend.learning_phase()
# inputs shape [batch_size, seq_length, hidden_dim]
per_device_batch_size, seq_length, hidden_dim = inputs.shape
num_tokens = per_device_batch_size * seq_length
num_groups = self._num_groups(num_tokens, self._max_group_size)
tokens_per_group = num_tokens // num_groups
batch_size, seq_length, hidden_dim = inputs.shape
if batch_size is not None:
if self._examples_per_group > batch_size:
raise ValueError(
f"examples_per_group={self._examples_per_group} is larger than the "
"number of examples available in the local (per-device) batch_size="
f"{batch_size}. Either decrease examples_per_group or increase the "
"batch_size.")
tokens_per_group = int(seq_length * self._examples_per_group)
if training:
capacity_factor = self._train_capacity_factor
@@ -604,71 +605,16 @@ class MoeLayer(tf.keras.layers.Layer):
# Each group will send expert_capacity tokens to each expert.
expert_capacity = int(
round(capacity_factor * tokens_per_group / self.num_experts))
expert_capacity = max(expert_capacity, self._min_expert_capacity)
logging.info(
"Selected expert_capacity=%d for num_experts=%d and training=%r.",
expert_capacity, self.num_experts, training)
# Reshape batch and sequence/token dimensions for expert routing.
x = tf.reshape(inputs, (num_groups, tokens_per_group, hidden_dim))
x = tf.reshape(inputs, (-1, tokens_per_group, hidden_dim))
x = self._mask_and_dispatch_to_experts(x, expert_capacity, training)
# Return to original input shape.
x = tf.reshape(x, (per_device_batch_size, seq_length, hidden_dim))
x = tf.reshape(x, (-1, seq_length, hidden_dim))
return x
def _num_groups(self, num_tokens: int, max_group_size: int) -> int:
"""Returns the number of token routing groups.
Note that the quantities are local to the device.
We select the smallest num_groups such that:
- num_groups >= num_tokens / max_group_size (ensuring the group size is no
larger than max_group_size),
- num_tokens % num_groups = 0 (ensuring that the group size evenly divides
into the num_tokens),
Args:
num_tokens: Number of tokens from input batch.
max_group_size: Maximum size of each token routing group. Actual group
size may end up being smaller unless strict_group_size==True.
Returns:
Number of token routing groups.
Raises:
ValueError if we cannot find a group_size satisfying the above
requirements.
"""
# Increase the number of groups (and decrease the group size) until we have
# a viable number of groups.
min_num_groups = int(np.ceil(num_tokens / max_group_size))
num_groups = min_num_groups
while num_groups < num_tokens and num_tokens % num_groups != 0:
num_groups += 1
group_size = num_tokens // num_groups
logging.info(
"Selected group_size=%d and num_groups=%d for input num_tokens=%d, "
"max_group_size=%d, num_experts=%d.",
group_size, num_groups, num_tokens, max_group_size, self.num_experts)
if group_size < self._min_expert_capacity:
raise ValueError(
f"Local (per-device) group_size {group_size} is smaller than "
f"min_expert_capacity {self._min_expert_capacity}, which is probably "
"not intended. Please increase max_group_size {max_group_size} to"
" seq_length or increase batch_size or decrease min_expert_capacity.")
if self._strict_group_size and group_size != self._max_group_size:
raise ValueError(
f"Selected group_size={group_size} is less than the "
f"max_group_size={max_group_size}. Exiting because strict mode is "
"active (strict_group_size=True)")
return num_groups
def _mask_and_dispatch_to_experts(self, inputs: tf.Tensor,
expert_capacity: int,
training: bool) -> tf.Tensor:
......
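With `_num_groups` gone, the group count in `call()` above falls out of the reshape directly. A sketch of that round-trip using the dimensions from the updated tests below (batch 4, seq_length 10, hidden 7, `examples_per_group=2.0`, hence `tokens_per_group=20`):

```python
import tensorflow as tf

x = tf.ones((4, 10, 7))                       # [batch, seq_length, hidden_dim]
grouped = tf.reshape(x, (-1, 20, 7))          # [2 groups, 20 tokens, 7]
restored = tf.reshape(grouped, (-1, 10, 7))   # back to [4, 10, 7]
```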
@@ -32,14 +32,13 @@ def small_config():
config['jitter_noise'] = 0.1
config['train_capacity_factor'] = 1.0
config['eval_capacity_factor'] = 1.0
config['min_expert_capacity'] = 1
config['max_group_size'] = 9
config['examples_per_group'] = 2.0
config['backbone_d_ff'] = 13
return config
def make_input_ones(batch_size: int = 2,
def make_input_ones(batch_size: int = 4,
seq_length: int = 10,
hidden_dim: int = 7) -> tf.Tensor:
return tf.ones((batch_size, seq_length, hidden_dim), dtype=tf.float32)
@@ -64,11 +63,7 @@ class MoeTest(tf.test.TestCase):
y = moe._router_z_loss(x)
expected = (5 + np.log(np.exp(5) + 1))**2
self.assertAllClose(expected, y, atol=1e-7)
x = tf.constant([[[10.0, 5.0]]], dtype=tf.bfloat16)
y = moe._router_z_loss(x)
expected = 100.0
self.assertAllClose(expected, y, atol=1e-7)
self.assertDTypeEqual(y, tf.float32)
def test_router_z_loss_shape(self):
x = make_input_ones(2, 5, 7)
@@ -199,21 +194,12 @@ class MoeTest(tf.test.TestCase):
router,
train_capacity_factor=config['train_capacity_factor'],
eval_capacity_factor=config['eval_capacity_factor'],
max_group_size=config['max_group_size'],
min_expert_capacity=config['min_expert_capacity'])
examples_per_group=config['examples_per_group'])
inputs = make_input_ones()
with self.assertLogs('absl', level='INFO') as cm:
outputs = moe_layer(inputs, training=True)
outputs = moe_layer(inputs, training=True)
self.assertAllEqual(tf.shape(inputs), tf.shape(outputs))
self.assertEqual(
cm.output,
[('INFO:absl:Selected group_size=5 and num_groups=4 for input '
'num_tokens=20, max_group_size=9, num_experts=2.'),
('INFO:absl:Selected expert_capacity=2 for num_experts=2 and '
'training=True.')])
var_names = sorted([v.name for v in moe_layer.trainable_variables])
self.assertAllEqual([
'moe/experts/intermediate/bias:0', 'moe/experts/intermediate/kernel:0',
@@ -241,8 +227,7 @@ class MoeTest(tf.test.TestCase):
router,
train_capacity_factor=config['train_capacity_factor'],
eval_capacity_factor=config['eval_capacity_factor'],
max_group_size=config['max_group_size'],
min_expert_capacity=config['min_expert_capacity'])
examples_per_group=config['examples_per_group'])
layer = moe.MoeLayerWithBackbone(moe_layer, config['backbone_d_ff'])
inputs = make_input_ones()
......
@@ -123,7 +123,7 @@ class SparseMixer(tf.keras.layers.Layer):
num_experts: int = 16,
train_capacity_factor: float = 1.,
eval_capacity_factor: float = 1.,
max_group_size: int = 4096,
examples_per_group: float = 1.,
mixing_mechanism: layers.MixingMechanism = layers.MixingMechanism.LINEAR,
use_fft: bool = False,
num_attention_heads: int = 8,
@@ -157,7 +157,7 @@ class SparseMixer(tf.keras.layers.Layer):
'num_experts': num_experts,
'train_capacity_factor': train_capacity_factor,
'eval_capacity_factor': eval_capacity_factor,
'max_group_size': max_group_size,
'examples_per_group': examples_per_group,
'mixing_mechanism': mixing_mechanism,
'use_fft': use_fft,
'attention_layers': attention_layers,
@@ -243,7 +243,7 @@ class SparseMixer(tf.keras.layers.Layer):
name='router'),
train_capacity_factor=train_capacity_factor,
eval_capacity_factor=eval_capacity_factor,
max_group_size=max_group_size,
examples_per_group=examples_per_group,
name='moe')
else:
feedforward_layer = None # Fallback to default (dense) MLP class
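The same parameter swap reaches the network constructor. A hedged sketch of direct instantiation with the new argument; the values mirror the config defaults above, and arguments not visible in this diff (e.g. `vocab_size`, `hidden_size`) are assumed from the config fields:

```python
from official.nlp.modeling import networks

encoder = networks.SparseMixer(
    vocab_size=30522,
    hidden_size=768,
    num_layers=14,
    moe_layers=(5, 6, 7, 8),
    attention_layers=(10, 11, 12, 13),
    num_experts=16,
    examples_per_group=1.0)
```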
@@ -273,18 +273,15 @@ class SparseMixer(tf.keras.layers.Layer):
if with_dense_inputs:
self.inputs = dict(
input_word_ids=tf.keras.Input(
shape=(max_sequence_length,), dtype=tf.int32),
input_mask=tf.keras.Input(
shape=(max_sequence_length,), dtype=tf.int32),
input_type_ids=tf.keras.Input(
shape=(max_sequence_length,), dtype=tf.int32),
# The total length of token ids and dense inputs still has to be
# max_sequence_length. It is checked in call().
input_word_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
dense_inputs=tf.keras.Input(
shape=(max_sequence_length, embedding_width), dtype=tf.float32),
dense_mask=tf.keras.Input(
shape=(max_sequence_length,), dtype=tf.int32),
dense_type_ids=tf.keras.Input(
shape=(max_sequence_length,), dtype=tf.int32),
shape=(None, embedding_width), dtype=tf.float32),
dense_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
dense_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
)
else:
self.inputs = dict(
@@ -320,11 +317,9 @@ class SparseMixer(tf.keras.layers.Layer):
type_ids = tf.concat([type_ids, dense_type_ids], axis=1)
mask = tf.concat([mask, dense_mask], axis=1)
seq_length = word_embeddings.shape[1]
if seq_length != self._max_sequence_length:
raise ValueError('Sparse Mixer: Sequence length must be the same as '
'`max_sequence_length` ({}), but it is {}.'.format(
self._max_sequence_length, seq_length))
# SparseMixer: Sequence length must be the same as `max_sequence_length`.
word_embeddings = tf.ensure_shape(word_embeddings,
[None, self._max_sequence_length, None])
# Absolute position embeddings.
position_embeddings = self._position_embedding_layer(word_embeddings)
......
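`tf.ensure_shape` replaces the Python-level length check above: it is a graph-friendly assertion that validates the sequence dimension at trace or run time while leaving the batch and hidden dimensions dynamic. A small sketch:

```python
import tensorflow as tf

embeddings = tf.ones((2, 512, 768))
# Passes and returns the tensor unchanged; only dim 1 is constrained.
checked = tf.ensure_shape(embeddings, [None, 512, None])
# tf.ensure_shape(tf.ones((2, 384, 768)), [None, 512, None])  # would raise
```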