提交 caaa39c2 编写于 作者: K Krzysztof Choromanski 提交者: A. Unique TensorFlower

Adding FAVOR++ mechanism from the paper: "Chefs' Random Tables:...

Adding FAVOR++ mechanism from the paper: "Chefs' Random Tables: Non-Trigonometric Random Features" to TF MODEL GARDEN.

PiperOrigin-RevId: 460625574
上级 20685639
......@@ -18,6 +18,8 @@ import functools
import math
import tensorflow as tf
from official.modeling import tf_utils
......@@ -56,8 +58,8 @@ def create_projection_matrix(m, d, seed=None):
The matrix of random projections of the shape [m, d].
nb_full_blocks = math.ceil(m / d)
block_list = tf.TensorArray(tf.float32,
size=tf.cast(nb_full_blocks, dtype=tf.int32))
block_list = tf.TensorArray(
tf.float32, size=tf.cast(nb_full_blocks, dtype=tf.int32))
stateful = False
if seed is None:
stateful = True
......@@ -108,6 +110,122 @@ def _generalized_kernel(x, projection_matrix, f, h):
tf.cast(tf.shape(projection_matrix)[0], tf.float32))
def expplus(data_orig,
"""FAVOR++ mechanism from the CRT paper: https://arxiv.org/abs/2205.15317 .
data_orig: data tensor of shape [B,T,H,D] for which random features are to
be computed
other_data: additional tensor of the shape [B,F,H,D] used to collect stats
to determine the exact instantiation of the random feature mechanism
is_query: boolean indicating whether <data_orig> tensor is a query tensor
projection_matrix: tensor of the shape [M,D] encoding random projections for
random features (M stands for the number of random features)
numerical_stabilizer: numerical stabilizer for the kernel features
normalize_data: whether to sqrt-d-normalize queries/keys as in the regular
numerical_renormalizer: whether to apply additional renormalization for
numerical stability
extra_renormalize_exp_fun: extra renormalizer for the exponential mapping
applied to construct random features
Random feature map tensor for the unbiased softmax-kernel estimation.
data = data_orig
if projection_matrix is None:
return data_orig
projection_matrix = tf.cast(projection_matrix, data.dtype)
if normalize_data:
data_normalizer = 1.0 / tf.math.sqrt(
(tf.math.sqrt(tf.dtypes.cast(data.shape[-1], data.dtype))))
data_normalizer = 1.0
lengths = tf.math.square(data)
lengths = tf.reduce_sum(lengths, axis=tf.keras.backend.ndim(data) - 1)
lengths = tf.expand_dims(lengths, axis=tf.keras.backend.ndim(data) - 1)
lengths = tf.math.sqrt(lengths)
data /= lengths
ratio = 1.0 / tf.math.sqrt(
tf.dtypes.cast(projection_matrix.shape[0], data.dtype))
data_dash = tf.einsum("blhd,md->blhm", data_normalizer * data,
diag_data = tf.math.square(data)
diag_data = tf.math.reduce_sum(
diag_data, axis=tf.keras.backend.ndim(data) - 1)
diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
diag_data = tf.expand_dims(diag_data, axis=tf.keras.backend.ndim(data) - 1)
# Calculating coefficients A, B of the FAVOR++ mechanism:
_, l, _, _ = tf_utils.get_shape_list(data_orig)
l = tf.cast(l, dtype=tf.float32)
first_sum_of_squares = tf.math.square(data)
first_sum_of_squares = tf.math.reduce_sum(
first_sum_of_squares, axis=(1, -1), keepdims=True)
first_sum_of_squares *= (data_normalizer * data_normalizer)
first_sum_of_squares /= l # data.shape[1]
second_sum_of_squares = tf.math.square(other_data)
second_sum_of_squares = tf.math.reduce_sum(
second_sum_of_squares, axis=(1, -1), keepdims=True)
second_sum_of_squares *= (data_normalizer * data_normalizer)
second_sum_of_squares /= l # other_data.shape[1]
data_sum = tf.math.reduce_sum(data, axis=(1,), keepdims=True)
other_data_sum = tf.math.reduce_sum(other_data, axis=(1,), keepdims=True)
d_prod = tf.einsum("blhd,blhd->blh", data_sum, other_data_sum)
d_prod = tf.expand_dims(d_prod, axis=-1)
d_prod *= (data_normalizer * data_normalizer)
d_prod *= (2.0 / (l * l))
ave = first_sum_of_squares + second_sum_of_squares + d_prod
dim = projection_matrix.shape[-1]
A = (1.0 / (4.0 * ave)) * (
tf.math.sqrt((2.0 * ave + dim) *
(2.0 * ave + dim) + 8.0 * dim * ave) - 2.0 * ave - dim)
A = (1.0 - 1.0 / A) / 8.0
B = tf.math.sqrt(1.0 - 4.0 * A)
D = tf.math.pow(1.0 - 4.0 * A, dim / 4.0)
A = tf.stop_gradient(A)
B = tf.stop_gradient(B)
D = tf.stop_gradient(D)
# Calculating diag_omega for the FAVOR++ mechanism:
diag_omega = tf.math.square(projection_matrix)
diag_omega = tf.math.reduce_sum(
diag_omega, axis=tf.keras.backend.ndim(projection_matrix) - 1)
diag_omega = tf.expand_dims(diag_omega, axis=0)
diag_omega = tf.expand_dims(diag_omega, axis=0)
diag_omega = tf.expand_dims(diag_omega, axis=0)
diag_omega = A * diag_omega
if numerical_renormalizer:
if is_query:
last_dims_t = (len(data_dash.shape) - 1,)
stab = B * tf.math.reduce_max(data_dash, axis=last_dims_t, keepdims=True)
stab = B * tf.math.reduce_max(data_dash, keepdims=True)
if extra_renormalize_exp_fun:
extra_stab = tf.reduce_max(diag_data, axis=1, keepdims=True)
stab = tf.math.maximum(stab, extra_stab)
data_dash = ratio * D * (
tf.math.exp(B * data_dash - stab - diag_data + diag_omega) +
data_dash = ratio * D * (
tf.math.exp(B * data_dash - diag_data + diag_omega) +
return data_dash
# pylint: disable=g-long-lambda
......@@ -120,19 +238,19 @@ _TRANSFORM_MAP = {
# Improve numerical stability and avoid NaNs in some cases by adding
# a tiny epsilon.
f=lambda x: tf.keras.activations.relu(x) + 1e-3, h=lambda x: 1),
f=lambda x: tf.keras.activations.relu(x) + 1e-3,
h=lambda x: 1),
_generalized_kernel, f=tf.math.square, h=lambda x: 1),
functools.partial(_generalized_kernel, f=tf.math.square, h=lambda x: 1),
# Avoid exp explosion by shifting.
f=lambda x: tf.math.exp(
x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
h=lambda x: tf.math.exp(
-0.5 * tf.math.reduce_sum(
tf.math.square(x), axis=-1, keepdims=True)),),
f=lambda x: tf.math.exp(x - tf.math.reduce_max(
x, axis=[1, 2, 3], keepdims=True)),
h=lambda x: tf.math.exp(-0.5 * tf.math.reduce_sum(
tf.math.square(x), axis=-1, keepdims=True)),
......@@ -157,6 +275,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
- exp (Lemma 1, positive), relu
- random/deterministic projection
Chefs' Random Tables: Non-Trigonometric Random Features
- expplus (OPRF mechanism)
Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
......@@ -187,7 +308,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
feature_transform: A non-linear transform of the keys and queries. Possible
transforms are "elu", "relu", "square", "exp", "expmod", "identity".
transforms are "elu", "relu", "square", "exp", "expplus", "expmod",
num_random_features: Number of random features to be used for projection.
If num_random_features <= 0, no projection is used before transform.
seed: The seed to begin drawing random features. Once the seed is set, the
......@@ -209,7 +331,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
https://kexue.fm/archives/8823 for details.
**kwargs: The same arguments `MultiHeadAttention` layer.
if feature_transform not in _TRANSFORM_MAP:
if feature_transform not in _TRANSFORM_MAP and feature_transform != "expplus":
raise ValueError("Unsupported feature_transform. The supported "
"feature_transform are %s. "
"Got '%s'." % (_TRANSFORM_MAP.keys(), feature_transform))
......@@ -296,23 +418,27 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
key *= tf.math.sqrt(scale)
query *= tf.math.sqrt(scale)
key = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
query = _TRANSFORM_MAP[feature_transform](query, projection_matrix)
if feature_transform != "expplus":
key_prime = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
query_prime = _TRANSFORM_MAP[feature_transform](query, projection_matrix)
key_prime = expplus(key, query, False, projection_matrix)
query_prime = expplus(query, key, True, projection_matrix)
if attention_mask is not None:
key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)
key_prime = tf.einsum("BSNH,BS->BSNH", key_prime, attention_mask)
if is_short_seq:
attention_scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
attention_scores = tf.einsum("BTNH,BSNH->BTSN", query_prime, key_prime)
attention_scores = tf.nn.softmax(attention_scores, axis=2)
attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
kv = tf.einsum("BSNH,BSND->BNDH", key, value)
kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
denominator = 1.0 / (
tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) +
attention_output = tf.einsum(
"BTNH,BNDH,BTN->BTND", query, kv, denominator)
tf.einsum("BTNH,BNH->BTN", query_prime,
tf.reduce_sum(key_prime, axis=1)) + _NUMERIC_STABLER)
attention_output = tf.einsum("BTNH,BNDH,BTN->BTND", query_prime, kv,
return attention_output
def _build_from_signature(self, query, value, key=None):
......@@ -327,16 +453,12 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
self._output_dense_softmax = self._make_output_dense(
self._query_shape.rank - 1, common_kwargs,
self._query_shape.rank - 1,
self._dropout_softmax = tf.keras.layers.Dropout(rate=self._dropout)
def call(self,
def call(self, query, value, key=None, attention_mask=None, training=False):
"""Compute attention with kernel mechanism.
......@@ -371,19 +493,17 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
if self._begin_kernel > 0:
attention_output_softmax = self._compute_attention(
query[:, :self._begin_kernel],
key, value, "identity", True, attention_mask, training)
query[:, :self._begin_kernel], key, value, "identity", True,
attention_mask, training)
attention_output_softmax = self._dropout_softmax(attention_output_softmax)
attention_output_softmax = self._output_dense_softmax(
attention_output_kernel = self._compute_attention(
query[:, self._begin_kernel:],
key, value, self._feature_transform, self._is_short_seq,
attention_mask, training)
query[:, self._begin_kernel:], key, value, self._feature_transform,
self._is_short_seq, attention_mask, training)
attention_output_kernel = self._dropout_layer(attention_output_kernel)
attention_output_kernel = self._output_dense(
attention_output_kernel = self._output_dense(attention_output_kernel)
attention_output = tf.concat(
[attention_output_softmax, attention_output_kernel], axis=1)
......@@ -21,7 +21,7 @@ import tensorflow as tf
from official.nlp.modeling.layers import kernel_attention as attention
_FEATURE_TRANSFORM = ['relu', 'elu', 'exp']
_FEATURE_TRANSFORM = ['relu', 'elu', 'exp', 'expplus']
_REDRAW = [True, False]
_TRAINING = [True, False]
_IS_SHORT_SEQ = [True, False]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
想要评论请 注册