Commit 113cd6b3 authored by: Y ying

add multi-head scaled_dot_product_attention.

Parent abf9395d
@@ -11,14 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import layers
__all__ = [
"simple_img_conv_pool",
"sequence_conv_pool",
"glu",
"dot_product_attention",
"scaled_dot_product_attention",
]
@@ -179,7 +179,7 @@ def scaled_dot_product_attention(queries,
.. math::
Attention(Q, K, V) = softmax\left(\frac{QK^\mathrm{T}}{\sqrt{d_k}}\right)V
where :math:`d_k` is the dimension of the keys.
Refer to `Attention Is All You Need
<https://arxiv.org/pdf/1706.03762.pdf>`_.
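The multi-head variant in the code shown here adds no extra projections; it splits the hidden dimension of the queries, keys and values into ``num_heads`` slices, applies the scaled attention to each slice, and concatenates the per-head results. A sketch of that per-head view (here :math:`h` stands for ``num_heads`` and :math:`d_k` for the key dimension of each head; both symbols are introduced only for illustration):
.. math::
    head_i = softmax\left(\frac{Q_i K_i^\mathrm{T}}{\sqrt{d_k}}\right)V_i,
    \quad i = 1, \dots, h
    context = Concat(head_1, \dots, head_h)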
@@ -195,8 +195,8 @@ def scaled_dot_product_attention(queries,
LoDTensor.
Returns:
Variable: The context Tensor computed by multi-head scaled dot-product
attention.
Examples:
.. code-block:: python
@@ -239,26 +239,42 @@ def scaled_dot_product_attention(queries,
Returns:
a 4-D Tensor with shape
[batch_size, num_heads, max_sequence_length, hidden_size // num_heads]
"""
if num_heads == 1: return x
hidden_size = x.shape[-1]
# reshape the 3-D input: [batch_size, max_sequence_length, hidden_dim]
# into a 4-D output:
# [batch_size, max_sequence_length, num_heads, hidden_size_per_head].
reshaped = layers.reshape(
x=x,
shape=list(x.shape[:-1]) + [num_heads, hidden_size // num_heads])
# permute the dimensions into:
# [batch_size, num_heads, max_sequence_len, hidden_size_per_head]
return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
def __combine_heads(x):
if len(x.shape) == 3: return x
if len(x.shape) != 4:
raise ValueError("Input(x) should be a 4-D Tensor.")
# permute back to [batch_size, max_sequence_len, num_heads,
# hidden_size_per_head] and merge the last two dimensions.
trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
return layers.reshape(
x=trans_x,
shape=[trans_x.shape[0], trans_x.shape[1], trans_x.shape[2] * trans_x.shape[3]])
q = __split_heads(queries, num_heads)
k = __split_heads(keys, num_heads)
v = __split_heads(values, num_heads)
key_dim_per_head = keys.shape[-1] // num_heads
# scale the queries by 1 / sqrt(key_dim_per_head) before the dot product.
scaled_q = layers.scale(x=q, scale=key_dim_per_head**-0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
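# apply softmax over the last (key) dimension by flattening the product
# to 2-D, computing softmax, and reshaping back to the original shape.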
attn_scores = layers.reshape(
x=layers.reshape(
x=product, shape=[-1, product.shape[-1]], act="softmax"),
shape=product.shape)
ctx_multiheads = layers.matmul(attn_scores, v)
context = __combine_heads(ctx_multiheads)
return context
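For readers who want to sanity-check the layer graph above, here is a minimal, self-contained NumPy sketch of the same split-heads / scale / softmax / combine-heads pipeline. The function name ref_scaled_dot_product_attention and the example shapes are illustrative only and are not part of this commit.

import numpy as np

def ref_scaled_dot_product_attention(queries, keys, values, num_heads=1):
    # queries/keys/values: [batch_size, max_sequence_length, hidden_size]
    def split_heads(x, n):
        # [bs, seq, hid] -> [bs, n, seq, hid // n]
        bs, seq, hid = x.shape
        return x.reshape(bs, seq, n, hid // n).transpose(0, 2, 1, 3)

    def combine_heads(x):
        # [bs, n, seq, d] -> [bs, seq, n * d]
        bs, n, seq, d = x.shape
        return x.transpose(0, 2, 1, 3).reshape(bs, seq, n * d)

    q = split_heads(queries, num_heads)
    k = split_heads(keys, num_heads)
    v = split_heads(values, num_heads)

    # scale queries, take Q K^T per head, softmax over the key dimension.
    key_dim_per_head = keys.shape[-1] // num_heads
    product = np.matmul(q * key_dim_per_head ** -0.5, k.transpose(0, 1, 3, 2))
    e = np.exp(product - product.max(axis=-1, keepdims=True))
    attn = e / e.sum(axis=-1, keepdims=True)

    return combine_heads(np.matmul(attn, v))

# usage with random data (shapes are illustrative)
x = np.random.rand(2, 5, 8).astype("float32")
ctx = ref_scaled_dot_product_attention(x, x, x, num_heads=2)
print(ctx.shape)  # (2, 5, 8)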