Unverified · Commit a13180fc authored by Wenyu, committed by GitHub

fix import problem of _convert_attention_mask (#3631)

Parent c8f8a3b0
@@ -31,8 +31,6 @@ from . import ops
 from .initializer import xavier_uniform_, constant_
 from paddle.vision.ops import DeformConv2D
-from paddle.nn.layer import transformer
-_convert_attention_mask = transformer._convert_attention_mask
 
 
 def _to_list(l):
@@ -1195,6 +1193,27 @@ class Concat(nn.Layer):
         return 'dim={}'.format(self.dim)
 
 
+def _convert_attention_mask(attn_mask, dtype):
+    """
+    Convert the attention mask to the target dtype we expect.
+
+    Parameters:
+        attn_mask (Tensor, optional): A tensor used in multi-head attention
+            to prevent attention to some unwanted positions, usually the
+            paddings or the subsequent positions. It is a tensor with a shape
+            broadcastable to `[batch_size, n_head, sequence_length, sequence_length]`.
+            When the data type is bool, the unwanted positions have `False`
+            values and the others have `True` values. When the data type is
+            int, the unwanted positions have 0 values and the others have 1
+            values. When the data type is float, the unwanted positions have
+            `-INF` values and the others have 0 values. It can be None when
+            nothing needs to be masked out. Default None.
+        dtype (VarType): The target type of `attn_mask` we expect.
+
+    Returns:
+        Tensor: A Tensor with the same shape as the input `attn_mask`, with data type `dtype`.
+    """
+    return nn.layer.transformer._convert_attention_mask(attn_mask, dtype)
+
+
 class MultiHeadAttention(nn.Layer):
     """
     Attention maps queries and a set of key-value pairs to outputs, and
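For context, here is a minimal sketch (not part of the commit) of what this re-exported wrapper does with a boolean mask. It assumes PaddlePaddle and PaddleDetection are importable; the mask shape is illustrative.

```python
import paddle
# After this commit, downstream code imports the helper from ppdet's own
# layers module instead of reaching into paddle.nn.layer.transformer.
from ppdet.modeling.layers import _convert_attention_mask

# Bool mask: True marks positions that may be attended to,
# False marks positions to ignore (e.g. padding).
attn_mask = paddle.to_tensor([[True, True, False, False]])

# Converting to float yields an additive mask: kept positions become 0 and
# masked positions become a large negative number (effectively -INF), so
# they vanish after softmax once added to the attention logits.
float_mask = _convert_attention_mask(attn_mask, paddle.float32)
print(float_mask)  # e.g. [[0., 0., -1e9, -1e9]]
```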
@@ -18,11 +18,10 @@ from __future__ import print_function
 import paddle
 import paddle.nn as nn
-from paddle.nn.layer.transformer import _convert_attention_mask
 import paddle.nn.functional as F
 
 from ppdet.core.workspace import register
-from ..layers import MultiHeadAttention
+from ..layers import MultiHeadAttention, _convert_attention_mask
 from .position_encoding import PositionEmbedding
 from .utils import *
 from ..initializer import *
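To show how such an additive mask is consumed, here is a minimal sketch of scaled dot-product attention in the same style; `masked_attention` and its shapes are illustrative, not part of the commit.

```python
import paddle
import paddle.nn.functional as F
from ppdet.modeling.layers import _convert_attention_mask

def masked_attention(q, k, v, attn_mask=None):
    # q, k, v: [batch_size, n_head, sequence_length, head_dim]
    scores = paddle.matmul(q, k, transpose_y=True) / (q.shape[-1] ** 0.5)
    if attn_mask is not None:
        # Normalize the mask to q's dtype, then add it to the logits:
        # 0 keeps a position, a large negative value suppresses it.
        scores = scores + _convert_attention_mask(attn_mask, q.dtype)
    weights = F.softmax(scores, axis=-1)
    return paddle.matmul(weights, v)
```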