Commit f6a29a8b authored by Frederick Liu, committed by A. Unique TensorFlower

[kernel] Add streaming support.

PiperOrigin-RevId: 477214841
Parent 910cdfa1
......@@ -160,7 +160,8 @@ def causal_windowed_performer_attention(query_matrix,
chunk_length,
window_length,
window_decay=None,
padding=None):
padding=None,
cache=None):
"""Applies windowed causal kernel attention with query, key, value tensors.
We partition the T-length input sequence into N chunks, each of
......@@ -202,55 +203,70 @@ def causal_windowed_performer_attention(query_matrix,
padding if padding is set to None. In the latter case, the axis dimension
of the query, value and key input tensors must be divisible by the
chunk_length.
cache: Cache to accumulate history in memory. Used at inference time
(streaming, decoding) for causal attention.
Returns:
Window causal performer attention of shape `[B, T, H, out_dim]`.
"""
old_shape = tf.shape(value_matrix)
query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, padding)
key_matrix = pad_to_chunk_length(key_matrix, -3, chunk_length, padding)
value_matrix = pad_to_chunk_length(value_matrix, -3, chunk_length, padding)
new_shape = tf.shape(value_matrix)
chunked_query_matrix = split_tensor_into_chunks(
query_matrix, -3,
chunk_length) # [-1, T//chunk_length, chunk_length, N, dim]
chunked_key_matrix = split_tensor_into_chunks(
key_matrix, -3,
chunk_length) # [-1, T//chunk_length, chunk_length, N, dim]
chunked_value_matrix = split_tensor_into_chunks(
value_matrix, -3,
chunk_length) # [-1, T//chunk_length, chunk_length, N, out_dim]
kp_v = tf.einsum("BTCHD,BTCHO->BTHDO", chunked_key_matrix,
chunked_value_matrix)
k_sum = tf.math.reduce_sum(chunked_key_matrix, axis=-3, keepdims=True)
if window_decay is None:
kp_v_winsum = rectangular_window_sum(kp_v, window_length)
k_winsum = rectangular_window_sum(k_sum, window_length)
else:
# Compute exponentially decaying weights.
decaying_weights = tf.math.pow(
tf.convert_to_tensor(window_decay, dtype=value_matrix.dtype),
tf.range(window_length - 1, -1, delta=-1, dtype=value_matrix.dtype))
kp_v_winsum = weighted_window_sum(kp_v, window_length, decaying_weights)
k_winsum = weighted_window_sum(k_sum, window_length, decaying_weights)
numerator = tf.einsum("BTCHD,BTHDO->BTCHO", chunked_query_matrix, kp_v_winsum)
k_winsum = tf.squeeze(k_winsum, -3)
denominator = tf.einsum("BTCHD,BTHD->BTCH", chunked_query_matrix, k_winsum)
denominator = tf.expand_dims(denominator, -1) + _NUMERIC_STABLER
attention = numerator / denominator
attention = tf.reshape(attention, new_shape)
start = tf.zeros([len(old_shape)], dtype=old_shape.dtype)
attention = tf.slice(attention, start, old_shape)
if cache is None: # Training
old_shape = tf.shape(value_matrix)
query_matrix = pad_to_chunk_length(query_matrix, -3, chunk_length, padding)
key_matrix = pad_to_chunk_length(key_matrix, -3, chunk_length, padding)
value_matrix = pad_to_chunk_length(value_matrix, -3, chunk_length, padding)
new_shape = tf.shape(value_matrix)
chunked_query_matrix = split_tensor_into_chunks(
query_matrix, -3,
chunk_length) # [-1, T//chunk_length, chunk_length, N, dim]
chunked_key_matrix = split_tensor_into_chunks(
key_matrix, -3,
chunk_length) # [-1, T//chunk_length, chunk_length, N, dim]
chunked_value_matrix = split_tensor_into_chunks(
value_matrix, -3,
chunk_length) # [-1, T//chunk_length, chunk_length, N, out_dim]
kp_v = tf.einsum("BTCHD,BTCHO->BTHDO", chunked_key_matrix,
chunked_value_matrix)
k_sum = tf.math.reduce_sum(chunked_key_matrix, axis=-3, keepdims=True)
if window_decay is None:
kp_v_winsum = rectangular_window_sum(kp_v, window_length)
k_winsum = rectangular_window_sum(k_sum, window_length)
else:
# Compute exponentially decaying weights.
decaying_weights = tf.math.pow(
tf.convert_to_tensor(window_decay, dtype=value_matrix.dtype),
tf.range(window_length - 1, -1, delta=-1, dtype=value_matrix.dtype))
kp_v_winsum = weighted_window_sum(kp_v, window_length, decaying_weights)
k_winsum = weighted_window_sum(k_sum, window_length, decaying_weights)
numerator = tf.einsum(
"BTCHD,BTHDO->BTCHO", chunked_query_matrix, kp_v_winsum)
k_winsum = tf.squeeze(k_winsum, -3)
denominator = tf.einsum("BTCHD,BTHD->BTCH", chunked_query_matrix, k_winsum)
denominator = tf.expand_dims(denominator, -1) + _NUMERIC_STABLER
attention = numerator / denominator
attention = tf.reshape(attention, new_shape)
start = tf.zeros([len(old_shape)], dtype=old_shape.dtype)
attention = tf.slice(attention, start, old_shape)
# Queued window cache (drop instead of decay) not yet supported.
else: # Streaming
if window_decay is None or window_decay > 1.0 or window_decay < 0.0:
raise ValueError("window_decay must be set and lie in [0.0, 1.0].")
kv = cache["kv"] + tf.einsum("BTHD,BTHO->BHOD", key_matrix, value_matrix)
cache["kv"] = kv * window_decay
k_sum = cache["k_sum"] + tf.reduce_sum(key_matrix, axis=1)
cache["k_sum"] = k_sum * window_decay
denominator = tf.einsum("BTHD,BHD->BTH", query_matrix, k_sum)
attention = tf.einsum("BTHD,BHOD,BTH->BTHO", query_matrix, kv,
1.0 / (denominator + _NUMERIC_STABLER))
return attention
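For reference, the streaming branch above keeps two running statistics: cache["kv"] accumulates the transformed key/value outer products of past chunks and cache["k_sum"] the transformed keys, both multiplied by window_decay after every chunk so that older history decays exponentially. Below is a minimal single-head sketch of that recurrence; the helper name, simplified shapes, and per-head handling are illustrative assumptions, not the library API.

import tensorflow as tf

def streaming_step(q, k, v, cache, decay, eps=1e-6):
  # q, k: [B, T, D] feature-transformed projections of the current chunk;
  # v: [B, T, O] values. cache = {"kv": [B, O, D] zeros, "k_sum": [B, D] zeros}
  # before the first chunk.
  kv = cache["kv"] + tf.einsum("btd,bto->bod", k, v)  # history + current chunk
  k_sum = cache["k_sum"] + tf.reduce_sum(k, axis=1)
  cache["kv"] = kv * decay        # decay the stored history for the next chunk
  cache["k_sum"] = k_sum * decay
  denominator = tf.einsum("btd,bd->bt", q, k_sum) + eps
  return tf.einsum("btd,bod->bto", q, kv) / denominator[..., None]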
......@@ -443,7 +459,7 @@ def expplus(data_orig,
# pylint: disable=g-long-lambda
_TRANSFORM_MAP = {
_CAUSAL_SUPPORT_TRANSFORM_MAP = {
"elu":
functools.partial(
_generalized_kernel,
......@@ -476,11 +492,19 @@ _TRANSFORM_MAP = {
h=lambda x: tf.math.exp(-0.5 * tf.math.sqrt(
tf.cast(tf.shape(x)[-1], tf.float32))),
),
"expplus":
expplus,
"identity":
functools.partial(_generalized_kernel, f=lambda x: x, h=lambda x: 1)
}
_NON_CAUSAL_SUPPORT_TRANSFORM_MAP = {
"expplus": expplus,
}
_TRANSFORM_MAP = {
**_CAUSAL_SUPPORT_TRANSFORM_MAP,
**_NON_CAUSAL_SUPPORT_TRANSFORM_MAP
}
# pylint: enable=g-long-lambda
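Splitting the map this way lets callers gate streaming support on the feature transform, which is what the new check in call() below does. A small illustrative sketch of the same idea (the helper name is hypothetical, not part of the module):

def supports_streaming_cache(feature_transform):
  # Only transforms in the causal-support map can be combined with the cache;
  # "expplus" is bi-directional only and therefore cannot be streamed causally.
  return feature_transform in _CAUSAL_SUPPORT_TRANSFORM_MAP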
......@@ -609,6 +633,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
feature_transform,
is_short_seq,
attention_mask=None,
cache=None,
training=False,
numeric_stabler=_NUMERIC_STABLER):
"""Applies kernel attention with query, key, value tensors.
......@@ -628,6 +653,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
attention_mask: a boolean mask of shape `[B, S]` that prevents attending
to masked positions. Note that the mask is only applied to the keys. The
user may want to mask the output if the query contains pads.
cache: Cache to accumulate history in memory. Used at inference time
(streaming, decoding) for causal attention.
training: Python boolean indicating whether the layer should behave in
training mode (adding dropout) or in inference mode (doing nothing).
numeric_stabler: A scalar value added to avoid divide by 0.
......@@ -682,7 +709,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
chunk_length=self.causal_chunk_length,
window_length=self.causal_window_length,
window_decay=self.causal_window_decay,
padding=self.causal_padding)
padding=self.causal_padding,
cache=cache)
else:
kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
denominator = 1.0 / (
......@@ -709,7 +737,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
name="attention_output_softmax")
self._dropout_softmax = tf.keras.layers.Dropout(rate=self._dropout)
def call(self, query, value, key=None, attention_mask=None, training=False):
def call(self, query, value, key=None, attention_mask=None, cache=None,
training=False):
"""Compute attention with kernel mechanism.
Args:
......@@ -720,12 +749,29 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
attention_mask: a boolean mask of shape `[B, S]` that prevents attending
to masked positions. Note that the mask is only applied to the keys. The
user may want to mask the output if the query contains pads.
cache: Cache to accumulate history in memory. Used at inference time
(streaming, decoding) for causal attention.
training: Python boolean indicating whether the layer should behave in
training mode (adding dropout) or in inference mode (doing nothing).
Returns:
Multi-headed outputs of attention computation.
"""
if cache is not None:
if training:
raise ValueError(
"Cache is not supported when training is True.")
if not self.use_causal_windowed:
raise ValueError(
"Cache is not supported for non use_causal_windowed case.")
if self._begin_kernel:
raise ValueError(
"Cache is not supported when begin_kernel is set since the bahvior "
"is too complicated.")
if self._feature_transform in _NON_CAUSAL_SUPPORT_TRANSFORM_MAP:
raise ValueError("Cache is not supported for feature_transform %s" %
(self._feature_transform))
if not self._built_from_signature:
self._build_from_signature(query=query, value=value, key=key)
if key is None:
......@@ -761,7 +807,9 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
attention_output = self._compute_attention(query, key, value,
self._feature_transform,
self._is_short_seq,
attention_mask, training)
attention_mask,
cache,
training)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_output = self._dropout_layer(attention_output)
......
......@@ -30,6 +30,64 @@ _BEGIN_KERNEL = [0, 512]
class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
# expplus is only designed for the bi-directional use case.
# exp can be numerically unstable.
@parameterized.parameters(itertools.product(
["relu", "elu"], [1, 4], [0.9]))
def test_causal_windowed_attention_projection_streaming(
self, feature_transform, causal_chunk_length, causal_weight_decay):
num_heads = 12
key_dim = 64
seq_length = 16
num_chunks = seq_length // causal_chunk_length
causal_window_length = num_chunks
batch_size = 2
training = False
num_random_features = 0
test_layer = attention.KernelAttention(
num_heads=num_heads,
key_dim=key_dim,
feature_transform=feature_transform,
num_random_features=num_random_features,
redraw=False,
is_short_seq=False,
begin_kernel=False,
use_causal_windowed=True,
causal_chunk_length=causal_chunk_length,
causal_window_length=causal_window_length,
causal_window_decay=causal_weight_decay,
causal_padding=None,
)
query = tf.random.normal(
shape=(batch_size, seq_length, key_dim), seed=2)
value = query
encoder_inputs_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
masks = tf.cast(encoder_inputs_mask, dtype=tf.float32)
output = test_layer(
query=query,
value=value,
attention_mask=masks,
training=training)
kv_cache = tf.zeros(
(batch_size, num_heads, key_dim,
num_random_features if num_random_features > 0 else key_dim))
k_sum_cache = tf.zeros((batch_size, 1, key_dim))
stream_output = []
cache = {"kv": kv_cache, "k_sum": k_sum_cache}
for i in range(num_chunks):
stream_output.append(
test_layer(
query=query[:, i * causal_chunk_length:(i + 1) *
causal_chunk_length, :],
value=value[:, i * causal_chunk_length:(i + 1) *
causal_chunk_length, :],
attention_mask=masks[:, i * causal_chunk_length:(i + 1) *
causal_chunk_length],
cache=cache,
training=training))
stream_output = tf.concat(stream_output, axis=1)
self.assertAllClose(output, stream_output)
@parameterized.parameters(
itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
_IS_SHORT_SEQ, _BEGIN_KERNEL))
......@@ -196,6 +254,5 @@ class KernelAttentionTest(tf.test.TestCase, parameterized.TestCase):
[2, 1, 2, 2, 2]),
winsum)
if __name__ == "__main__":
tf.test.main()