Unverified · Commit edacb629, authored by xiemoyuan, committed by GitHub

Optimization of Transformer API (#30957)

* Support 'bool' and 'int' for attention mask.

* Update docs.

* Add unittest for Transformer.

* fix bugs.
Parent ee1801c1
@@ -51,6 +51,7 @@ def generate_query_key_value_cache(self_attention,
                                    num_heads,
                                    query_length,
                                    embed_dim,
+                                   attn_mask_type,
                                    key_length=None,
                                    value_length=None,
                                    kdim=None,
@@ -58,8 +59,14 @@ def generate_query_key_value_cache(self_attention,
                                    cache=None):
     query = np.random.rand(batch_size, query_length,
                            embed_dim).astype("float32")
-    attn_mask = np.zeros((batch_size, num_heads, query_length, key_length))
-    attn_mask[0][0][0][0] = -1e9
+    attn_mask = np.ones(
+        (batch_size, num_heads, query_length, key_length), dtype=attn_mask_type)
+    if attn_mask_type == 'int64':
+        attn_mask = np.tril(attn_mask)
+    elif attn_mask_type == 'float64':
+        attn_mask = (np.tril(attn_mask) - 1.0) * 1e9
+    else:
+        raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.")
     head_dim = embed_dim // num_heads
     if self_attention:
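Note: the test helper now generates the mask in two equivalent conventions, one per `attn_mask_type`. A minimal numpy illustration of the two forms (the tiny shape `[1, 1, 3, 3]` is purely for demonstration, not taken from the test):

    import numpy as np

    # 'int64': 1 marks positions that may be attended to, 0 marks blocked ones.
    int_mask = np.tril(np.ones((1, 1, 3, 3), dtype='int64'))

    # 'float64': additive form, 0.0 for allowed positions and -1e9 for blocked ones.
    float_mask = (np.tril(np.ones((1, 1, 3, 3), dtype='float64')) - 1.0) * 1e9

    # Both encode the same causal pattern; the float form is what the numpy
    # reference (and the layer, after conversion) adds to the attention scores.
    assert ((int_mask == 0) == (float_mask < 0)).all()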
@@ -115,6 +122,10 @@ def scaled_dot_product_attention(q, k, v, d_key, attn_mask, multi_head_attn):
     k = k.transpose([0, 1, 3, 2])
     qkt = batch_matmul(q, k / np.sqrt(d_key, dtype=np.float64))
     if attn_mask is not None:
+        if attn_mask.dtype.name == 'int64':
+            attn_mask = (attn_mask.astype(qkt.dtype) - 1.0) * 1e9
+        else:
+            attn_mask = attn_mask.astype(qkt.dtype)
         qkt += attn_mask
     weight = softmax(qkt)
     attn_heads = batch_matmul(weight, v)
@@ -219,53 +230,57 @@ class TestTransformer(unittest.TestCase):
             # generate params for multi_head_attention
             batch_size, query_length, key_length, value_length, embed_dim, kdim, vdim, num_heads, attn_dropout = generate_basic_params(
                 "attn", self_attention)
-            query, key, value, attn_mask, cache_dict = generate_query_key_value_cache(
-                self_attention, batch_size, num_heads, query_length,
-                embed_dim, key_length, value_length, kdim, vdim, cache)
-            if cache and self_attention:
-                attn_mask = np.concatenate((attn_mask, attn_mask), axis=3)
-            need_weight, param_attr, bias_attr = False, None, None
-            # call paddle's function
-            multi_head_attn = MultiHeadAttention(
-                embed_dim, num_heads, attn_dropout, kdim, vdim, need_weight,
-                param_attr, bias_attr)
-            # construct cache object
-            cache_obj = None
-            if cache_dict:
-                if 'k' and 'v' in cache_dict:
-                    cache_obj = multi_head_attn.Cache(
-                        paddle.to_tensor(cache_dict['k']),
-                        paddle.to_tensor(cache_dict['v']))
-                elif 'static_k' and 'static_v' in cache_dict:
-                    cache_obj = multi_head_attn.StaticCache(
-                        paddle.to_tensor(cache_dict['static_k']),
-                        paddle.to_tensor(cache_dict['static_v']))
-            if attn_mask is not None:
-                attn_output = multi_head_attn(
-                    paddle.to_tensor(query),
-                    paddle.to_tensor(key),
-                    paddle.to_tensor(value),
-                    paddle.to_tensor(attn_mask), cache_obj)
-            else:
-                attn_output = multi_head_attn(
-                    paddle.to_tensor(query),
-                    paddle.to_tensor(key),
-                    paddle.to_tensor(value), attn_mask, cache_obj)
-            attn_output = attn_output[0] if cache_dict else attn_output
-
-            # implementation by numpy
-            # compute q, k, v
-            q, k, v, _ = prepare_qkv(query, key, value, num_heads,
-                                     embed_dim, self_attention,
-                                     multi_head_attn, cache_dict)
-            # scale dot product attention
-            attn_heads = scaled_dot_product_attention(
-                q, k, v, embed_dim // num_heads, attn_mask, multi_head_attn)
-            out_proj_weight = multi_head_attn.out_proj.weight.numpy()
-            reference = fc(attn_heads, out_proj_weight)
-
-            np.testing.assert_allclose(
-                attn_output.numpy(), reference, atol=1e-6)
+            for attn_mask_type in ['int64', 'float64']:
+                query, key, value, attn_mask, cache_dict = generate_query_key_value_cache(
+                    self_attention, batch_size, num_heads, query_length,
+                    embed_dim, attn_mask_type, key_length, value_length,
+                    kdim, vdim, cache)
+                if cache and self_attention:
+                    attn_mask = np.concatenate(
+                        (attn_mask, attn_mask), axis=3)
+                need_weight, param_attr, bias_attr = False, None, None
+                # call paddle's function
+                multi_head_attn = MultiHeadAttention(
+                    embed_dim, num_heads, attn_dropout, kdim, vdim,
+                    need_weight, param_attr, bias_attr)
+                # construct cache object
+                cache_obj = None
+                if cache_dict:
+                    if 'k' and 'v' in cache_dict:
+                        cache_obj = multi_head_attn.Cache(
+                            paddle.to_tensor(cache_dict['k']),
+                            paddle.to_tensor(cache_dict['v']))
+                    elif 'static_k' and 'static_v' in cache_dict:
+                        cache_obj = multi_head_attn.StaticCache(
+                            paddle.to_tensor(cache_dict['static_k']),
+                            paddle.to_tensor(cache_dict['static_v']))
+                if attn_mask is not None:
+                    attn_output = multi_head_attn(
+                        paddle.to_tensor(query),
+                        paddle.to_tensor(key),
+                        paddle.to_tensor(value),
+                        paddle.to_tensor(attn_mask), cache_obj)
+                else:
+                    attn_output = multi_head_attn(
+                        paddle.to_tensor(query),
+                        paddle.to_tensor(key),
+                        paddle.to_tensor(value), attn_mask, cache_obj)
+                attn_output = attn_output[0] if cache_dict else attn_output
+
+                # implementation by numpy
+                # compute q, k, v
+                q, k, v, _ = prepare_qkv(query, key, value, num_heads,
+                                         embed_dim, self_attention,
+                                         multi_head_attn, cache_dict)
+                # scale dot product attention
+                attn_heads = scaled_dot_product_attention(
+                    q, k, v, embed_dim // num_heads, attn_mask,
+                    multi_head_attn)
+                out_proj_weight = multi_head_attn.out_proj.weight.numpy()
+                reference = fc(attn_heads, out_proj_weight)
+                np.testing.assert_allclose(
+                    attn_output.numpy(), reference, atol=1e-6)
         multihead_attention_test_helper(True, True)
         multihead_attention_test_helper(True, False)
...
@@ -34,6 +34,7 @@ from ... import tensor
 from ...fluid import layers
 from ...fluid.dygraph import Layer, LayerList
 from ...fluid.param_attr import ParamAttr
+from ...fluid.data_feeder import convert_dtype


 def _convert_param_attr_to_list(param_attr, n):
@@ -82,6 +83,35 @@ def _convert_param_attr_to_list(param_attr, n):
     return param_attrs


+def _convert_attention_mask(attn_mask, dtype):
+    """
+    Convert the attention mask to the target dtype we expect.
+
+    Parameters:
+        attn_mask (Tensor, optional): A tensor used in multi-head attention
+            to prevents attention to some unwanted positions, usually the
+            paddings or the subsequent positions. It is a tensor with shape
+            broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
+            When the data type is bool, the unwanted positions have `False`
+            values and the others have `True` values. When the data type is
+            int, the unwanted positions have 0 values and the others have 1
+            values. When the data type is float, the unwanted positions have
+            `-INF` values and the others have 0 values. It can be None when
+            nothing wanted or needed to be prevented attention to. Default None.
+        dtype (VarType): The target type of `attn_mask` we expect.
+
+    Returns:
+        Tensor: A Tensor with shape same as input `attn_mask`, with data type `dtype`.
+    """
+    if attn_mask is not None and attn_mask.dtype != dtype:
+        attn_mask_dtype = convert_dtype(attn_mask.dtype)
+        if attn_mask_dtype == 'bool' or 'int' in attn_mask_dtype:
+            attn_mask = (paddle.cast(attn_mask, dtype) - 1.0) * 1e9
+        else:
+            attn_mask = paddle.cast(attn_mask, dtype)
+    return attn_mask
+
+
 class MultiHeadAttention(Layer):
     """
     Attention mapps queries and a set of key-value pairs to outputs, and
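Note: `_convert_attention_mask` is the single place where a bool or int mask is turned into the additive float form that is added to the attention scores. A minimal dygraph sketch of the arithmetic (tensor values chosen purely for illustration):

    import paddle

    # Bool mask: True = attend, False = block (here the last position is padding).
    mask = paddle.to_tensor([[True, True, False]])

    # Casting to float32 yields 1.0 / 0.0; (x - 1) * 1e9 then yields the
    # additive form: 0.0 where attention is allowed, -1e9 where it is blocked.
    additive = (paddle.cast(mask, 'float32') - 1.0) * 1e9
    print(additive.numpy())   # 0.0 for the allowed positions, -1e9 for the blocked one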
@@ -105,7 +135,7 @@ class MultiHeadAttention(Layer):
         weight_attr(ParamAttr, optional): To specify the weight parameter property.
             Default: None, which means the default weight parameter property is used.
             See usage for details in :code:`ParamAttr` .
-        bias_attr (ParamAttr, optional): To specify the bias parameter property.
+        bias_attr (ParamAttr|bool, optional): To specify the bias parameter property.
             Default: None, which means the default bias parameter property is used.
             If it is set to False, this layer will not have trainable bias parameter.
             See usage for details in :code:`ParamAttr` .
@@ -331,11 +361,13 @@ class MultiHeadAttention(Layer):
             attn_mask (Tensor, optional): A tensor used in multi-head attention
                 to prevents attention to some unwanted positions, usually the
                 paddings or the subsequent positions. It is a tensor with shape
-                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`,
-                where the unwanted positions have `-INF` values and the others
-                have 0 values. The data type should be float32 or float64. It can
-                be None when nothing wanted or needed to be prevented attention to.
-                Default None
+                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
+                When the data type is bool, the unwanted positions have `False`
+                values and the others have `True` values. When the data type is
+                int, the unwanted positions have 0 values and the others have 1
+                values. When the data type is float, the unwanted positions have
+                `-INF` values and the others have 0 values. It can be None when
+                nothing wanted or needed to be prevented attention to. Default None.
             cache (MultiHeadAttention.Cache|MultiHeadAttention.StaticCache, optional):
                 It is a namedtuple with `k` and `v` as fields, and stores tensors
                 shaped `[batch_size, num_heads, length, embed_dim]` which are results
@@ -374,7 +406,8 @@ class MultiHeadAttention(Layer):
         product = layers.matmul(
             x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
         if attn_mask is not None:
-            # TODO(guosheng): support bool mask
+            # Support bool or int mask
+            attn_mask = _convert_attention_mask(attn_mask, product.dtype)
             product = product + attn_mask
         weights = F.softmax(product)
         if self.dropout:
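Note: with `_convert_attention_mask` wired into `MultiHeadAttention.forward`, a bool or int mask can be passed directly instead of a pre-built additive float mask. A minimal dygraph sketch (shapes chosen for illustration; assumes a Paddle build that includes this change):

    import paddle
    from paddle.nn import MultiHeadAttention

    query = paddle.rand([2, 4, 128])                   # [batch, seq_len, embed_dim]
    attn = MultiHeadAttention(embed_dim=128, num_heads=2)

    # Causal bool mask: True below the diagonal (allowed), False above (blocked).
    causal = paddle.tril(paddle.ones([4, 4])).astype('bool')

    out = attn(query, query, query, attn_mask=causal)  # broadcast over batch and heads
    print(out.shape)                                    # [2, 4, 128]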
@@ -509,11 +542,13 @@ class TransformerEncoderLayer(Layer):
             src_mask (Tensor, optional): A tensor used in multi-head attention
                 to prevents attention to some unwanted positions, usually the
                 paddings or the subsequent positions. It is a tensor with shape
-                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`,
-                where the unwanted positions have `-INF` values and the others
-                have 0 values. The data type should be float32 or float64. It can
-                be None when nothing wanted or needed to be prevented attention to.
-                Default None
+                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
+                When the data type is bool, the unwanted positions have `False`
+                values and the others have `True` values. When the data type is
+                int, the unwanted positions have 0 values and the others have 1
+                values. When the data type is float, the unwanted positions have
+                `-INF` values and the others have 0 values. It can be None when
+                nothing wanted or needed to be prevented attention to. Default None.
             cache (Tensor, optional): It is an instance of `MultiHeadAttention.Cache`.
                 See `TransformerEncoderLayer.gen_cache` for more details. It is
                 only used for inference and should be None for training. Default
@@ -528,10 +563,12 @@ class TransformerEncoderLayer(Layer):
                 incremental length. See `MultiHeadAttention.gen_cache` and \
                 `MultiHeadAttention.forward` for more details.
         """
+        src_mask = _convert_attention_mask(src_mask, src.dtype)
+
         residual = src
         if self.normalize_before:
             src = self.norm1(src)
-        # TODO(guosheng): Add cache for encoder for the usage like UniLM
+        # Add cache for encoder for the usage like UniLM
         if cache is None:
             src = self.self_attn(src, src, src, src_mask)
         else:
@@ -622,11 +659,13 @@ class TransformerEncoder(Layer):
             src_mask (Tensor, optional): A tensor used in multi-head attention
                 to prevents attention to some unwanted positions, usually the
                 paddings or the subsequent positions. It is a tensor with shape
-                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`,
-                where the unwanted positions have `-INF` values and the others
-                have 0 values. The data type should be float32 or float64. It can
-                be None when nothing wanted or needed to be prevented attention to.
-                Default None
+                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
+                When the data type is bool, the unwanted positions have `False`
+                values and the others have `True` values. When the data type is
+                int, the unwanted positions have 0 values and the others have 1
+                values. When the data type is float, the unwanted positions have
+                `-INF` values and the others have 0 values. It can be None when
+                nothing wanted or needed to be prevented attention to. Default None.
             cache (list, optional): It is a list, and each element in the list
                 is `incremental_cache` produced by `TransformerEncoderLayer.gen_cache`.
                 See `TransformerEncoder.gen_cache` for more details. It is only
@@ -641,6 +680,8 @@ class TransformerEncoder(Layer):
                 See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \
                 for more details.
         """
+        src_mask = _convert_attention_mask(src_mask, src.dtype)
+
         output = src
         new_caches = []
         for i, mod in enumerate(self.layers):
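Note: the conversion is applied once at the top of the encoder call, so user code can hand an int (or bool) padding mask straight to the encoder stack. A minimal sketch demonstrating the int path (hypothetical shapes; assumes a Paddle build that includes this change):

    import paddle
    from paddle.nn import TransformerEncoder, TransformerEncoderLayer

    enc_layer = TransformerEncoderLayer(d_model=128, nhead=2, dim_feedforward=512)
    encoder = TransformerEncoder(enc_layer, num_layers=2)

    src = paddle.rand([2, 4, 128])                     # [batch, src_len, d_model]
    # int padding mask: 1 = real token, 0 = padding; broadcast inside the
    # attention to [batch, n_head, src_len, src_len].
    src_mask = paddle.to_tensor([[1, 1, 1, 0],
                                 [1, 1, 0, 0]], dtype='int64')
    src_mask = src_mask.unsqueeze([1, 2])              # [batch, 1, 1, src_len]

    out = encoder(src, src_mask=src_mask)
    print(out.shape)                                    # [2, 4, 128]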
@@ -808,18 +849,23 @@ class TransformerDecoderLayer(Layer):
             tgt_mask (Tensor, optional): A tensor used in self attention
                 to prevents attention to some unwanted positions, usually the
                 the subsequent positions. It is a tensor with shape broadcasted
-                to `[batch_size, n_head, target_length, target_length]`,
-                where the unwanted positions have `-INF` values and the others
-                have 0 values. The data type should be float32 or float64. It can
-                be None when nothing wanted or needed to be prevented attention to.
-                Default None
+                to `[batch_size, n_head, target_length, target_length]`.
+                When the data type is bool, the unwanted positions have `False`
+                values and the others have `True` values. When the data type is
+                int, the unwanted positions have 0 values and the others have 1
+                values. When the data type is float, the unwanted positions have
+                `-INF` values and the others have 0 values. It can be None when
+                nothing wanted or needed to be prevented attention to. Default None.
             memory_mask (Tensor, optional): A tensor used in decoder-encoder
                 cross attention to prevents attention to some unwanted positions,
                 usually the paddings. It is a tensor with shape broadcasted to
-                `[batch_size, n_head, target_length, source_length]`, where the
-                unwanted positions have `-INF` values and the others have 0 values.
-                The data type should be float32 or float64. It can be None when
-                nothing wanted or needed to be prevented attention to. Default None
+                `[batch_size, n_head, target_length, source_length]`. When the
+                data type is bool, the unwanted positions have `False` values
+                and the others have `True` values. When the data type is int,
+                the unwanted positions have 0 values and the others have 1
+                values. When the data type is float, the unwanted positions have
+                `-INF` values and the others have 0 values. It can be None when
+                nothing wanted or needed to be prevented attention to. Default None.
             cache (tuple, optional): It is a tuple( :code:`(incremental_cache, static_cache)` ),
                 `incremental_cache` is an instance of `MultiHeadAttention.Cache`,
                 `static_cache` is an instance of `MultiHeadAttention.StaticCache.
@@ -836,6 +882,9 @@ class TransformerDecoderLayer(Layer):
                 See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \
                 for more details.
         """
+        tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
+        memory_mask = _convert_attention_mask(memory_mask, memory.dtype)
+
         residual = tgt
         if self.normalize_before:
             tgt = self.norm1(tgt)
@@ -958,18 +1007,23 @@ class TransformerDecoder(Layer):
             tgt_mask (Tensor, optional): A tensor used in self attention
                 to prevents attention to some unwanted positions, usually the
                 the subsequent positions. It is a tensor with shape broadcasted
-                to `[batch_size, n_head, target_length, target_length]`,
-                where the unwanted positions have `-INF` values and the others
-                have 0 values. The data type should be float32 or float64. It can
-                be None when nothing wanted or needed to be prevented attention to.
-                Default None
+                to `[batch_size, n_head, target_length, target_length]`. When
+                the data type is bool, the unwanted positions have `False`
+                values and the others have `True` values. When the data type is
+                int, the unwanted positions have 0 values and the others have 1
+                values. When the data type is float, the unwanted positions have
+                `-INF` values and the others have 0 values. It can be None when
+                nothing wanted or needed to be prevented attention to. Default None.
             memory_mask (Tensor, optional): A tensor used in decoder-encoder
                 cross attention to prevents attention to some unwanted positions,
                 usually the paddings. It is a tensor with shape broadcasted to
-                `[batch_size, n_head, target_length, source_length]`, where the
-                unwanted positions have `-INF` values and the others have 0 values.
-                The data type should be float32 or float64. It can be None when
-                nothing wanted or needed to be prevented attention to. Default None
+                `[batch_size, n_head, target_length, source_length]`. When the
+                data type is bool, the unwanted positions have `False` values
+                and the others have `True` values. When the data type is int,
+                the unwanted positions have 0 values and the others have 1
+                values. When the data type is float, the unwanted positions have
+                `-INF` values and the others have 0 values. It can be None when
+                nothing wanted or needed to be prevented attention to. Default None.
             cache (list, optional): It is a list, and each element in the list
                 is a tuple( :code:`(incremental_cache, static_cache)` ). See
                 `TransformerDecoder.gen_cache` for more details. It is only
@@ -984,6 +1038,9 @@ class TransformerDecoder(Layer):
                 See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \
                 for more details.
         """
+        tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
+        memory_mask = _convert_attention_mask(memory_mask, memory.dtype)
+
         output = tgt
         new_caches = []
         for i, mod in enumerate(self.layers):
@@ -1222,27 +1279,46 @@ class Transformer(Layer):
             memory (Tensor): The output of Transformer encoder. It is a tensor
                 with shape `[batch_size, source_length, d_model]`. The data type
                 should be float32 or float64.
+            src_mask (Tensor, optional): A tensor used in multi-head attention
+                to prevents attention to some unwanted positions, usually the
+                paddings or the subsequent positions. It is a tensor with shape
+                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
+                When the data type is bool, the unwanted positions have `False`
+                values and the others have `True` values. When the data type is
+                int, the unwanted positions have 0 values and the others have 1
+                values. When the data type is float, the unwanted positions have
+                `-INF` values and the others have 0 values. It can be None when
+                nothing wanted or needed to be prevented attention to. Default None.
             tgt_mask (Tensor, optional): A tensor used in self attention
                 to prevents attention to some unwanted positions, usually the
                 the subsequent positions. It is a tensor with shape broadcasted
-                to `[batch_size, n_head, target_length, target_length]`,
-                where the unwanted positions have `-INF` values and the others
-                have 0 values. The data type should be float32 or float64. It can
-                be None when nothing wanted or needed to be prevented attention to.
-                Default None
+                to `[batch_size, n_head, target_length, target_length]`. When
+                the data type is bool, the unwanted positions have `False`
+                values and the others have `True` values. When the data type is
+                int, the unwanted positions have 0 values and the others have 1
+                values. When the data type is float, the unwanted positions have
+                `-INF` values and the others have 0 values. It can be None when
+                nothing wanted or needed to be prevented attention to. Default None.
             memory_mask (Tensor, optional): A tensor used in decoder-encoder
                 cross attention to prevents attention to some unwanted positions,
                 usually the paddings. It is a tensor with shape broadcasted to
-                `[batch_size, n_head, target_length, source_length]`, where the
-                unwanted positions have `-INF` values and the others have 0 values.
-                The data type should be float32 or float64. It can be None when
-                nothing wanted or needed to be prevented attention to. Default None
+                `[batch_size, n_head, target_length, source_length]`. When the
+                data type is bool, the unwanted positions have `False` values
+                and the others have `True` values. When the data type is int,
+                the unwanted positions have 0 values and the others have 1
+                values. When the data type is float, the unwanted positions have
+                `-INF` values and the others have 0 values. It can be None when
+                nothing wanted or needed to be prevented attention to. Default None.

         Returns:
             Tensor: It is a tensor that has the same shape and data type \
                 as `tgt`, representing the output of Transformer decoder.
         """
+        src_mask = _convert_attention_mask(src_mask, src.dtype)
         memory = self.encoder(src, src_mask=src_mask)
+
+        tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
+        memory_mask = _convert_attention_mask(memory_mask, memory.dtype)
         output = self.decoder(
             tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
         return output
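Note: end to end, `paddle.nn.Transformer` now accepts bool, int, or float masks for `src_mask`, `tgt_mask`, and `memory_mask`. A minimal dygraph sketch (shapes are illustrative; assumes a Paddle build that includes this change):

    import paddle
    from paddle.nn import Transformer

    model = Transformer(d_model=128, nhead=2,
                        num_encoder_layers=2, num_decoder_layers=2,
                        dim_feedforward=512)

    src = paddle.rand([2, 4, 128])   # [batch, src_len, d_model]
    tgt = paddle.rand([2, 6, 128])   # [batch, tgt_len, d_model]

    # Causal bool mask for decoder self attention: True = allowed, False = blocked.
    tgt_mask = paddle.tril(paddle.ones([6, 6])).astype('bool')

    out = model(src, tgt, src_mask=None, tgt_mask=tgt_mask, memory_mask=None)
    print(out.shape)                  # [2, 6, 128]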
...