Commit b7f6e079 authored by Allen Wang, committed by A. Unique TensorFlower

MultiHeadRelativeAttention compatibility changes with XLNet

PiperOrigin-RevId: 330751568
Parent da4aca1c
@@ -20,14 +20,29 @@ import string
import tensorflow as tf
from official.nlp.modeling.layers import masked_softmax
EinsumDense = tf.keras.layers.experimental.EinsumDense
MultiHeadAttention = tf.keras.layers.MultiHeadAttention
_CHR_IDX = string.ascii_lowercase
def _large_compatible_negative(tensor_type):
"""Large negative number as Tensor.
This function is necessary because the standard value for epsilon
in this module (-1e9) cannot be represented using tf.float16.
Args:
tensor_type: a dtype to determine the type.
Returns:
a large negative number.
"""
if tensor_type == tf.float16:
return tf.float16.min
return -1e9
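# Illustrative note on the helper above: -1e9 overflows tf.float16 (largest
# finite magnitude 65504) and would turn into -inf, so the dtype's own finite
# minimum is returned instead. For example:
#   _large_compatible_negative(tf.float16)  # -> -65504.0 (tf.float16.min)
#   _large_compatible_negative(tf.float32)  # -> -1e9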
@tf.keras.utils.register_keras_serializable(package="Text")
class CachedAttention(tf.keras.layers.MultiHeadAttention):
"""Attention layer with cache used for auto-agressive decoding.
@@ -116,14 +131,15 @@ class CachedAttention(tf.keras.layers.MultiHeadAttention):
def _rel_shift(x, klen=-1):
"""Performs relative shift to form the relative attention score."""
x = tf.transpose(x, perm=[1, 2, 0, 3])
x = tf.transpose(x, perm=[2, 3, 0, 1])
x_size = tf.shape(x)
x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1])
x = tf.transpose(x, perm=[2, 0, 1, 3])
x = tf.transpose(x, perm=[2, 3, 0, 1])
return x
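# Worked shape sketch for the updated shift (illustrative, assuming the
# [B, N, T, L] score layout produced by the dot-product einsum used in
# compute_attention below):
#   scores = tf.reshape(tf.range(18, dtype=tf.float32), [1, 1, 3, 6])
#   _rel_shift(scores, klen=3)  # -> shape [1, 1, 3, 3]: one aligned score per
#                               #    (query position, key position) pair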
@@ -200,15 +216,17 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
to certain positions.
"""
def __init__(self,
kernel_initializer="variance_scaling",
**kwargs):
super().__init__(kernel_initializer=kernel_initializer,
**kwargs)
def _build_from_signature(self, query, value, key=None):
super(MultiHeadRelativeAttention, self)._build_from_signature(
query=query,
value=value,
key=key)
if hasattr(query, "shape"):
query_shape = tf.TensorShape(query.shape)
else:
query_shape = query
if hasattr(value, "shape"):
value_shape = tf.TensorShape(value.shape)
else:
@@ -230,36 +248,16 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
bias_constraint=self._bias_constraint)
with tf.init_scope():
free_dims = query_shape.rank - 1
einsum_equation, bias_axes, output_rank = _build_proj_equation(
einsum_equation, _, output_rank = _build_proj_equation(
key_shape.rank - 1, bound_dims=1, output_dims=2)
self._encoding_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
bias_axes=None,
name="encoding",
**common_kwargs)
output_shape = [query_shape[-1]]
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=2, output_dims=len(output_shape))
# TODO(allencwang) - replace all einsums with programmatic equations.
einsum_equation = "abcd,ecd->abe"
self._output_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape),
bias_axes=bias_axes if self._use_bias else None,
name="attention_output",
**common_kwargs)
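# Descriptive note on the hard-coded equation above: "abcd,ecd->abe" multiplies
# the per-head attention output `[B (a), S (b), N (c), key_dim (d)]` by a
# kernel of shape `[output_dim (e), N, key_dim]`, folding the head and key_dim
# axes back into a single `[B, S, output_dim]` projection, where output_dim is
# the last dimension of the query input.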
def _build_attention(self, rank):
self._masked_softmax = masked_softmax.MaskedSoftmax(
mask_expansion_axes=[1], normalization_axes=[2])
self._dropout_layer = tf.keras.layers.Dropout(
rate=self._dropout)
def compute_attention(self,
query,
key,
@@ -267,6 +265,9 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
position,
content_attention_bias,
positional_attention_bias,
segment_matrix=None,
segment_encoding=None,
segment_attention_bias=None,
attention_mask=None):
"""Computes the attention.
@@ -282,33 +283,59 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
when calculating the content-based attention score.
positional_attention_bias: Trainable bias parameter added to the query
head when calculating the position-based attention score.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet.
segment_encoding: Optional trainable `Tensor` representing the
segmentation encoding as used in XLNet.
segment_attention_bias: Optional trainable bias parameter added to the
query head when calculating the segment-based attention score used in
XLNet.
attention_mask: (default None) Optional mask that is added to attention
logits. If state is not None, the mask's source sequence dimension should
be extended by M.
Returns:
attention_output: Multi-headed output of attention computation of shape
`[B, T, N, key_dim]`.
`[B, S, N, key_dim]`.
"""
content_attention = tf.einsum("bind,bjnd->bijn",
query + content_attention_bias,
key)
content_attention = tf.einsum(self._dot_product_equation,
key,
query + content_attention_bias)
positional_attention = tf.einsum(self._dot_product_equation,
position,
query + positional_attention_bias)
positional_attention = _rel_shift(
positional_attention, klen=tf.shape(content_attention)[3])
if segment_matrix is not None:
segment_attention = tf.einsum("bind,snd->bnis",
query + segment_attention_bias,
segment_encoding)
target_shape = tf.shape(positional_attention)
segment_attention = tf.where(
tf.broadcast_to(tf.expand_dims(segment_matrix, 1), target_shape),
tf.broadcast_to(segment_attention[:, :, :, 1:], target_shape),
tf.broadcast_to(segment_attention[:, :, :, :1], target_shape))
attention_sum = (
content_attention + positional_attention + segment_attention)
else:
attention_sum = content_attention + positional_attention
positional_attention = tf.einsum("bind,bjnd->bijn",
query + positional_attention_bias,
position)
attention_scores = tf.multiply(
attention_sum, 1.0 / math.sqrt(float(self._key_dim)))
positional_attention = _rel_shift(
positional_attention, klen=tf.shape(content_attention)[2])
# `attention_scores`: `[B, N, S, S + M]`
if attention_mask is not None:
attention_scores += (_large_compatible_negative(attention_scores.dtype)
* attention_mask)
attention_scores = tf.multiply((content_attention + positional_attention),
1.0 / math.sqrt(float(self._key_dim)))
attention_scores = self._masked_softmax(attention_scores, attention_mask)
attention_scores = tf.nn.softmax(attention_scores, 3)
attention_output = self._dropout_layer(attention_scores)
attention_output = tf.einsum("bijn,bjnd->bind", attention_output, value)
attention_output = tf.einsum(self._combine_equation,
attention_output,
value)
return attention_output
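# Shape sketch for the segment branch above (illustrative assumptions):
# `segment_encoding` holds one row per branch of the tf.where selection
# (row 1 where `segment_matrix` is True, row 0 elsewhere), i.e. a trainable
# tensor of shape [2, N, key_dim], while `segment_matrix` is a boolean
# [B, T, S] tensor derived from segment IDs, e.g. one possible construction:
#   segment_matrix = tf.logical_not(
#       tf.equal(segment_ids[:, :, None], segment_ids[:, None, :]))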
def call(self,
@@ -318,6 +345,9 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
positional_attention_bias,
key=None,
relative_position_encoding=None,
segment_matrix=None,
segment_encoding=None,
segment_attention_bias=None,
state=None,
attention_mask=None):
"""Compute multi-head relative attention over inputs.
@@ -342,6 +372,13 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
key: attention input.
relative_position_encoding: relative positional encoding for key and
value.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet.
segment_encoding: Optional `Tensor` representing the segmentation
encoding as used in XLNet.
segment_attention_bias: Optional trainable bias parameter added to the
query head when calculating the segment-based attention score used in
XLNet.
state: (default None) optional state. If passed, this is also attended
over as in TransformerXL.
attention_mask: (default None) Optional mask that is added to attention
@@ -381,7 +418,12 @@ class MultiHeadRelativeAttention(MultiHeadAttention):
position=position,
content_attention_bias=content_attention_bias,
positional_attention_bias=positional_attention_bias,
segment_matrix=segment_matrix,
segment_encoding=segment_encoding,
segment_attention_bias=segment_attention_bias,
attention_mask=attention_mask)
# `attention_output` = [B, S, N, H]
attention_output = self._output_dense(attention_output)
return attention_output
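For context, here is a minimal usage sketch of the layer this commit modifies. It is illustrative only: the shapes, the zero-valued biases, and the `official.nlp.modeling.layers` import path (where `MultiHeadRelativeAttention` is exported in recent Model Garden versions) are assumptions; in a real TransformerXL/XLNet block the biases and the relative position encoding are produced by the surrounding network.

import tensorflow as tf
from official.nlp.modeling import layers

batch, seq_len, hidden, num_heads, key_dim = 2, 4, 16, 2, 8
attention_layer = layers.MultiHeadRelativeAttention(
    num_heads=num_heads, key_dim=key_dim)

query = tf.random.normal([batch, seq_len, hidden])
# Relative position encoding spanning both directions of the sequence.
relative_encoding = tf.random.normal([batch, 2 * seq_len, hidden])
# Trainable parameters in a real model; zeros suffice for a shape check.
content_bias = tf.zeros([num_heads, key_dim])
positional_bias = tf.zeros([num_heads, key_dim])

output = attention_layer(
    query=query,
    value=query,
    content_attention_bias=content_bias,
    positional_attention_bias=positional_bias,
    relative_position_encoding=relative_encoding)
print(output.shape)  # (2, 4, 16)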