未验证 提交 edacb629 编写于 作者: X xiemoyuan 提交者: GitHub

Optimization of Transformer API (#30957)

* Support 'bool' and 'int' for attention mask.

* Update docs.

* Add unittest for Transformer.

* fix bugs.
上级 ee1801c1
......@@ -51,6 +51,7 @@ def generate_query_key_value_cache(self_attention,
num_heads,
query_length,
embed_dim,
attn_mask_type,
key_length=None,
value_length=None,
kdim=None,
......@@ -58,8 +59,14 @@ def generate_query_key_value_cache(self_attention,
cache=None):
query = np.random.rand(batch_size, query_length,
embed_dim).astype("float32")
attn_mask = np.zeros((batch_size, num_heads, query_length, key_length))
attn_mask[0][0][0][0] = -1e9
attn_mask = np.ones(
(batch_size, num_heads, query_length, key_length), dtype=attn_mask_type)
if attn_mask_type == 'int64':
attn_mask = np.tril(attn_mask)
elif attn_mask_type == 'float64':
attn_mask = (np.tril(attn_mask) - 1.0) * 1e9
else:
raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.")
head_dim = embed_dim // num_heads
if self_attention:
......@@ -115,6 +122,10 @@ def scaled_dot_product_attention(q, k, v, d_key, attn_mask, multi_head_attn):
k = k.transpose([0, 1, 3, 2])
qkt = batch_matmul(q, k / np.sqrt(d_key, dtype=np.float64))
if attn_mask is not None:
if attn_mask.dtype.name == 'int64':
attn_mask = (attn_mask.astype(qkt.dtype) - 1.0) * 1e9
else:
attn_mask = attn_mask.astype(qkt.dtype)
qkt += attn_mask
weight = softmax(qkt)
attn_heads = batch_matmul(weight, v)
......@@ -219,16 +230,19 @@ class TestTransformer(unittest.TestCase):
# generate params for multi_head_attention
batch_size, query_length, key_length, value_length, embed_dim, kdim, vdim, num_heads, attn_dropout = generate_basic_params(
"attn", self_attention)
for attn_mask_type in ['int64', 'float64']:
query, key, value, attn_mask, cache_dict = generate_query_key_value_cache(
self_attention, batch_size, num_heads, query_length,
embed_dim, key_length, value_length, kdim, vdim, cache)
embed_dim, attn_mask_type, key_length, value_length,
kdim, vdim, cache)
if cache and self_attention:
attn_mask = np.concatenate((attn_mask, attn_mask), axis=3)
attn_mask = np.concatenate(
(attn_mask, attn_mask), axis=3)
need_weight, param_attr, bias_attr = False, None, None
# call paddle's function
multi_head_attn = MultiHeadAttention(
embed_dim, num_heads, attn_dropout, kdim, vdim, need_weight,
param_attr, bias_attr)
embed_dim, num_heads, attn_dropout, kdim, vdim,
need_weight, param_attr, bias_attr)
# construct cache object
cache_obj = None
if cache_dict:
......@@ -260,7 +274,8 @@ class TestTransformer(unittest.TestCase):
multi_head_attn, cache_dict)
# scale dot product attention
attn_heads = scaled_dot_product_attention(
q, k, v, embed_dim // num_heads, attn_mask, multi_head_attn)
q, k, v, embed_dim // num_heads, attn_mask,
multi_head_attn)
out_proj_weight = multi_head_attn.out_proj.weight.numpy()
reference = fc(attn_heads, out_proj_weight)
......
......@@ -34,6 +34,7 @@ from ... import tensor
from ...fluid import layers
from ...fluid.dygraph import Layer, LayerList
from ...fluid.param_attr import ParamAttr
from ...fluid.data_feeder import convert_dtype
def _convert_param_attr_to_list(param_attr, n):
......@@ -82,6 +83,35 @@ def _convert_param_attr_to_list(param_attr, n):
return param_attrs
def _convert_attention_mask(attn_mask, dtype):
"""
Convert the attention mask to the target dtype we expect.
Parameters:
attn_mask (Tensor, optional): A tensor used in multi-head attention
to prevents attention to some unwanted positions, usually the
paddings or the subsequent positions. It is a tensor with shape
broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
When the data type is bool, the unwanted positions have `False`
values and the others have `True` values. When the data type is
int, the unwanted positions have 0 values and the others have 1
values. When the data type is float, the unwanted positions have
`-INF` values and the others have 0 values. It can be None when
nothing wanted or needed to be prevented attention to. Default None.
dtype (VarType): The target type of `attn_mask` we expect.
Returns:
Tensor: A Tensor with shape same as input `attn_mask`, with data type `dtype`.
"""
if attn_mask is not None and attn_mask.dtype != dtype:
attn_mask_dtype = convert_dtype(attn_mask.dtype)
if attn_mask_dtype == 'bool' or 'int' in attn_mask_dtype:
attn_mask = (paddle.cast(attn_mask, dtype) - 1.0) * 1e9
else:
attn_mask = paddle.cast(attn_mask, dtype)
return attn_mask
class MultiHeadAttention(Layer):
"""
Attention mapps queries and a set of key-value pairs to outputs, and
......@@ -105,7 +135,7 @@ class MultiHeadAttention(Layer):
weight_attr(ParamAttr, optional): To specify the weight parameter property.
Default: None, which means the default weight parameter property is used.
See usage for details in :code:`ParamAttr` .
bias_attr (ParamAttr, optional): To specify the bias parameter property.
bias_attr (ParamAttr|bool, optional): To specify the bias parameter property.
Default: None, which means the default bias parameter property is used.
If it is set to False, this layer will not have trainable bias parameter.
See usage for details in :code:`ParamAttr` .
......@@ -331,11 +361,13 @@ class MultiHeadAttention(Layer):
attn_mask (Tensor, optional): A tensor used in multi-head attention
to prevents attention to some unwanted positions, usually the
paddings or the subsequent positions. It is a tensor with shape
broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`,
where the unwanted positions have `-INF` values and the others
have 0 values. The data type should be float32 or float64. It can
be None when nothing wanted or needed to be prevented attention to.
Default None
broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
When the data type is bool, the unwanted positions have `False`
values and the others have `True` values. When the data type is
int, the unwanted positions have 0 values and the others have 1
values. When the data type is float, the unwanted positions have
`-INF` values and the others have 0 values. It can be None when
nothing wanted or needed to be prevented attention to. Default None.
cache (MultiHeadAttention.Cache|MultiHeadAttention.StaticCache, optional):
It is a namedtuple with `k` and `v` as fields, and stores tensors
shaped `[batch_size, num_heads, length, embed_dim]` which are results
......@@ -374,7 +406,8 @@ class MultiHeadAttention(Layer):
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
if attn_mask is not None:
# TODO(guosheng): support bool mask
# Support bool or int mask
attn_mask = _convert_attention_mask(attn_mask, product.dtype)
product = product + attn_mask
weights = F.softmax(product)
if self.dropout:
......@@ -509,11 +542,13 @@ class TransformerEncoderLayer(Layer):
src_mask (Tensor, optional): A tensor used in multi-head attention
to prevents attention to some unwanted positions, usually the
paddings or the subsequent positions. It is a tensor with shape
broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`,
where the unwanted positions have `-INF` values and the others
have 0 values. The data type should be float32 or float64. It can
be None when nothing wanted or needed to be prevented attention to.
Default None
broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
When the data type is bool, the unwanted positions have `False`
values and the others have `True` values. When the data type is
int, the unwanted positions have 0 values and the others have 1
values. When the data type is float, the unwanted positions have
`-INF` values and the others have 0 values. It can be None when
nothing wanted or needed to be prevented attention to. Default None.
cache (Tensor, optional): It is an instance of `MultiHeadAttention.Cache`.
See `TransformerEncoderLayer.gen_cache` for more details. It is
only used for inference and should be None for training. Default
......@@ -528,10 +563,12 @@ class TransformerEncoderLayer(Layer):
incremental length. See `MultiHeadAttention.gen_cache` and \
`MultiHeadAttention.forward` for more details.
"""
src_mask = _convert_attention_mask(src_mask, src.dtype)
residual = src
if self.normalize_before:
src = self.norm1(src)
# TODO(guosheng): Add cache for encoder for the usage like UniLM
# Add cache for encoder for the usage like UniLM
if cache is None:
src = self.self_attn(src, src, src, src_mask)
else:
......@@ -622,11 +659,13 @@ class TransformerEncoder(Layer):
src_mask (Tensor, optional): A tensor used in multi-head attention
to prevents attention to some unwanted positions, usually the
paddings or the subsequent positions. It is a tensor with shape
broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`,
where the unwanted positions have `-INF` values and the others
have 0 values. The data type should be float32 or float64. It can
be None when nothing wanted or needed to be prevented attention to.
Default None
broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
When the data type is bool, the unwanted positions have `False`
values and the others have `True` values. When the data type is
int, the unwanted positions have 0 values and the others have 1
values. When the data type is float, the unwanted positions have
`-INF` values and the others have 0 values. It can be None when
nothing wanted or needed to be prevented attention to. Default None.
cache (list, optional): It is a list, and each element in the list
is `incremental_cache` produced by `TransformerEncoderLayer.gen_cache`.
See `TransformerEncoder.gen_cache` for more details. It is only
......@@ -641,6 +680,8 @@ class TransformerEncoder(Layer):
See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \
for more details.
"""
src_mask = _convert_attention_mask(src_mask, src.dtype)
output = src
new_caches = []
for i, mod in enumerate(self.layers):
......@@ -808,18 +849,23 @@ class TransformerDecoderLayer(Layer):
tgt_mask (Tensor, optional): A tensor used in self attention
to prevents attention to some unwanted positions, usually the
the subsequent positions. It is a tensor with shape broadcasted
to `[batch_size, n_head, target_length, target_length]`,
where the unwanted positions have `-INF` values and the others
have 0 values. The data type should be float32 or float64. It can
be None when nothing wanted or needed to be prevented attention to.
Default None
to `[batch_size, n_head, target_length, target_length]`.
When the data type is bool, the unwanted positions have `False`
values and the others have `True` values. When the data type is
int, the unwanted positions have 0 values and the others have 1
values. When the data type is float, the unwanted positions have
`-INF` values and the others have 0 values. It can be None when
nothing wanted or needed to be prevented attention to. Default None.
memory_mask (Tensor, optional): A tensor used in decoder-encoder
cross attention to prevents attention to some unwanted positions,
usually the paddings. It is a tensor with shape broadcasted to
`[batch_size, n_head, target_length, source_length]`, where the
unwanted positions have `-INF` values and the others have 0 values.
The data type should be float32 or float64. It can be None when
nothing wanted or needed to be prevented attention to. Default None
`[batch_size, n_head, target_length, source_length]`. When the
data type is bool, the unwanted positions have `False` values
and the others have `True` values. When the data type is int,
the unwanted positions have 0 values and the others have 1
values. When the data type is float, the unwanted positions have
`-INF` values and the others have 0 values. It can be None when
nothing wanted or needed to be prevented attention to. Default None.
cache (tuple, optional): It is a tuple( :code:`(incremental_cache, static_cache)` ),
`incremental_cache` is an instance of `MultiHeadAttention.Cache`,
`static_cache` is an instance of `MultiHeadAttention.StaticCache.
......@@ -836,6 +882,9 @@ class TransformerDecoderLayer(Layer):
See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \
for more details.
"""
tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
memory_mask = _convert_attention_mask(memory_mask, memory.dtype)
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
......@@ -958,18 +1007,23 @@ class TransformerDecoder(Layer):
tgt_mask (Tensor, optional): A tensor used in self attention
to prevents attention to some unwanted positions, usually the
the subsequent positions. It is a tensor with shape broadcasted
to `[batch_size, n_head, target_length, target_length]`,
where the unwanted positions have `-INF` values and the others
have 0 values. The data type should be float32 or float64. It can
be None when nothing wanted or needed to be prevented attention to.
Default None
to `[batch_size, n_head, target_length, target_length]`. When
the data type is bool, the unwanted positions have `False`
values and the others have `True` values. When the data type is
int, the unwanted positions have 0 values and the others have 1
values. When the data type is float, the unwanted positions have
`-INF` values and the others have 0 values. It can be None when
nothing wanted or needed to be prevented attention to. Default None.
memory_mask (Tensor, optional): A tensor used in decoder-encoder
cross attention to prevents attention to some unwanted positions,
usually the paddings. It is a tensor with shape broadcasted to
`[batch_size, n_head, target_length, source_length]`, where the
unwanted positions have `-INF` values and the others have 0 values.
The data type should be float32 or float64. It can be None when
nothing wanted or needed to be prevented attention to. Default None
`[batch_size, n_head, target_length, source_length]`. When the
data type is bool, the unwanted positions have `False` values
and the others have `True` values. When the data type is int,
the unwanted positions have 0 values and the others have 1
values. When the data type is float, the unwanted positions have
`-INF` values and the others have 0 values. It can be None when
nothing wanted or needed to be prevented attention to. Default None.
cache (list, optional): It is a list, and each element in the list
is a tuple( :code:`(incremental_cache, static_cache)` ). See
`TransformerDecoder.gen_cache` for more details. It is only
......@@ -984,6 +1038,9 @@ class TransformerDecoder(Layer):
See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \
for more details.
"""
tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
memory_mask = _convert_attention_mask(memory_mask, memory.dtype)
output = tgt
new_caches = []
for i, mod in enumerate(self.layers):
......@@ -1222,27 +1279,46 @@ class Transformer(Layer):
memory (Tensor): The output of Transformer encoder. It is a tensor
with shape `[batch_size, source_length, d_model]`. The data type
should be float32 or float64.
src_mask (Tensor, optional): A tensor used in multi-head attention
to prevents attention to some unwanted positions, usually the
paddings or the subsequent positions. It is a tensor with shape
broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
When the data type is bool, the unwanted positions have `False`
values and the others have `True` values. When the data type is
int, the unwanted positions have 0 values and the others have 1
values. When the data type is float, the unwanted positions have
`-INF` values and the others have 0 values. It can be None when
nothing wanted or needed to be prevented attention to. Default None.
tgt_mask (Tensor, optional): A tensor used in self attention
to prevents attention to some unwanted positions, usually the
the subsequent positions. It is a tensor with shape broadcasted
to `[batch_size, n_head, target_length, target_length]`,
where the unwanted positions have `-INF` values and the others
have 0 values. The data type should be float32 or float64. It can
be None when nothing wanted or needed to be prevented attention to.
Default None
to `[batch_size, n_head, target_length, target_length]`. When
the data type is bool, the unwanted positions have `False`
values and the others have `True` values. When the data type is
int, the unwanted positions have 0 values and the others have 1
values. When the data type is float, the unwanted positions have
`-INF` values and the others have 0 values. It can be None when
nothing wanted or needed to be prevented attention to. Default None.
memory_mask (Tensor, optional): A tensor used in decoder-encoder
cross attention to prevents attention to some unwanted positions,
usually the paddings. It is a tensor with shape broadcasted to
`[batch_size, n_head, target_length, source_length]`, where the
unwanted positions have `-INF` values and the others have 0 values.
The data type should be float32 or float64. It can be None when
nothing wanted or needed to be prevented attention to. Default None
`[batch_size, n_head, target_length, source_length]`. When the
data type is bool, the unwanted positions have `False` values
and the others have `True` values. When the data type is int,
the unwanted positions have 0 values and the others have 1
values. When the data type is float, the unwanted positions have
`-INF` values and the others have 0 values. It can be None when
nothing wanted or needed to be prevented attention to. Default None.
Returns:
Tensor: It is a tensor that has the same shape and data type \
as `tgt`, representing the output of Transformer decoder.
"""
src_mask = _convert_attention_mask(src_mask, src.dtype)
memory = self.encoder(src, src_mask=src_mask)
tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
memory_mask = _convert_attention_mask(memory_mask, memory.dtype)
output = self.decoder(
tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
return output
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册