From a13180fc409af6cf47bbb87ce07f3b90cabc50c1 Mon Sep 17 00:00:00 2001
From: Wenyu
Date: Thu, 8 Jul 2021 12:46:49 +0800
Subject: [PATCH] fix import problem of _convert_attention_mask (#3631)

---
 ppdet/modeling/layers.py                      | 23 +++++++++++++++++--
 .../modeling/transformers/detr_transformer.py |  3 +--
 2 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py
index fd3181265..713880b9a 100644
--- a/ppdet/modeling/layers.py
+++ b/ppdet/modeling/layers.py
@@ -31,8 +31,6 @@
 from . import ops
 from .initializer import xavier_uniform_, constant_
 from paddle.vision.ops import DeformConv2D
-from paddle.nn.layer import transformer
-_convert_attention_mask = transformer._convert_attention_mask
 
 
 def _to_list(l):
@@ -1195,6 +1193,27 @@ class Concat(nn.Layer):
         return 'dim={}'.format(self.dim)
 
 
+def _convert_attention_mask(attn_mask, dtype):
+    """
+    Convert the attention mask to the target dtype we expect.
+    Parameters:
+        attn_mask (Tensor, optional): A tensor used in multi-head attention
+            to prevent attention to some unwanted positions, usually the
+            paddings or the subsequent positions. It is a tensor with shape
+            broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
+            When the data type is bool, the unwanted positions have `False`
+            values and the others have `True` values. When the data type is
+            int, the unwanted positions have 0 values and the others have 1
+            values. When the data type is float, the unwanted positions have
+            `-INF` values and the others have 0 values. It can be None when
+            no position needs to be masked. Default None.
+        dtype (VarType): The target dtype of `attn_mask` we expect.
+    Returns:
+        Tensor: A Tensor with the same shape as `attn_mask`, with data type `dtype`.
+    """
+    return nn.layer.transformer._convert_attention_mask(attn_mask, dtype)
+
+
 class MultiHeadAttention(nn.Layer):
     """
     Attention mapps queries and a set of key-value pairs to outputs, and
diff --git a/ppdet/modeling/transformers/detr_transformer.py b/ppdet/modeling/transformers/detr_transformer.py
index 92d79d53c..9069ee8c4 100644
--- a/ppdet/modeling/transformers/detr_transformer.py
+++ b/ppdet/modeling/transformers/detr_transformer.py
@@ -18,11 +18,10 @@ from __future__ import print_function
 
 import paddle
 import paddle.nn as nn
-from paddle.nn.layer.transformer import _convert_attention_mask
 import paddle.nn.functional as F
 
 from ppdet.core.workspace import register
-from ..layers import MultiHeadAttention
+from ..layers import MultiHeadAttention, _convert_attention_mask
 from .position_encoding import PositionEmbedding
 from .utils import *
 from ..initializer import *
-- 
GitLab
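
For context, a minimal usage sketch of the helper this patch re-exports (assumes Paddle 2.x and a ppdet checkout with this patch applied; the mask values below are illustrative only, not from the patch):

    import paddle
    # After this patch, the helper is importable from ppdet's own layers
    # module instead of paddle's private paddle.nn.layer.transformer path.
    from ppdet.modeling.layers import _convert_attention_mask

    # Boolean mask: True marks positions to attend to, False marks positions
    # to block (e.g. padding tokens).
    bool_mask = paddle.to_tensor([[True, True, False]])

    # Converting to a float dtype maps kept positions to 0 and blocked
    # positions to a large negative number (standing in for -INF), so that
    # softmax over attention scores gives them effectively zero weight.
    float_mask = _convert_attention_mask(bool_mask, paddle.float32)
    print(float_mask.dtype)  # paddle.float32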