提交 0d96899f 编写于 作者: Y ying

fix the documentation.

上级 d163592a
......@@ -1968,7 +1968,7 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
data = fluid.layers.data(name="data",
shape=(3, 17, 13),
dtype="float32")
fc = fluid.layers.l2_normalize(x=data, axis=1)
normed = fluid.layers.l2_normalize(x=data, axis=1)
"""
if len(x.shape) == 1: axis = 0
......
......@@ -182,28 +182,28 @@ def scaled_dot_product_attention(queries,
Refer to `Attention Is All You Need
<https://arxiv.org/pdf/1706.03762.pdf>`_.
Note that batch data containing sequences with different lengths is not
supported by this because of the (batch) matrix multipication.
Args:
queries (Variable): The input variable which is a Tensor or
LoDTensor.
keys (Variable): The input variable which is a Tensor or LoDTensor.
values (Variable): The input variable which is a Tensor or
LoDTensor.
num_heads (int): Head number to compute the dot product attention.
dropout_rate (float): The dropout rate for attention weight.
queries (Variable): The input variable which should be a 3-D Tensor.
keys (Variable): The input variable which should be a 3-D Tensor.
values (Variable): The input variable which should be a 3-D Tensor.
num_heads (int): Head number to compute the scaled dot product
attention. Default value is 1.
dropout_rate (float): The dropout rate to drop the attention weight.
Default value is 0.
Returns:
Variable: The context Tensor computed by multi-head scaled dot product
Variable: A 3-D Tensor computed by multi-head scaled dot product
attention.
Examples:
.. code-block:: python
# Suppose q, k, v are tensor variables with the following
# shape: q: [3, 5, 9], k: [3, 6, 9], v: [3, 6, 10]
out, attn_scores = fluid.nets.dot_product_attention(q, k, v)
# Suppose q, k, v are Tensors with the following shape:
# q: [3, 5, 9], k: [3, 6, 9], v: [3, 6, 10]
contexts = fluid.nets.dot_product_attention(q, k, v)
out.shape # [3, 5, 10]
attn_scores.shape # [3, 5, 6]
"""
......@@ -227,19 +227,30 @@ def scaled_dot_product_attention(queries,
"by the number of attention heads (%d)." %
(values.shape[-1], num_heads))
def __compute_qkv(queries, keys, values, num_heads):
if num_heads == 1:
return queries, keys, values
q = layers.fc(input=queries, size=queries.shape[-1], num_flatten_dims=2)
k = layers.fc(input=keys, size=keys.shape[-1], num_flatten_dims=2)
v = layers.fc(input=values, size=values.shape[-1], num_flatten_dims=2)
return q, k, v
def __split_heads(x, num_heads):
"""
Reshape the last dimension of inpunt tensor x so that it becomes two
dimensions.
Args:
x(Tensor): a 3-D input Tensor.
num_heads(int): The number of heads.
x(Tensor): a 3-D input Tensor.
num_heads(int): The number of heads.
Returns:
a Tensor with shape [..., n, m/n]
Tensor: a Tensor with shape [..., n, m/num_heads], where m is size
of the last dimension of x.
"""
if num_heads == 1: return x
if num_heads == 1:
return x
hidden_size = x.shape[-1]
# reshape the 3-D input: [batch_size, max_sequence_length, hidden_dim]
......@@ -254,6 +265,19 @@ def scaled_dot_product_attention(queries,
return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
def __combine_heads(x):
"""
Reshape the last two dimensions of inpunt tensor x so that it becomes
one dimension.
Args:
x(Tensor): a 4-D input Tensor with shape
[bs, num_heads, max_sequence_length, hidden_dim].
Returns:
Tensor: a Tensor with shape
[bs, max_sequence_length, num_heads * hidden_dim].
"""
if len(x.shape) == 3: return
if len(x.shape) != 4:
raise ValueError("Input(x) should be a 4-D Tensor.")
......@@ -266,9 +290,11 @@ def scaled_dot_product_attention(queries,
trans_x.shape[2] * trans_x.shape[3]
]))
q = __split_heads(queries, num_heads)
k = __split_heads(keys, num_heads)
v = __split_heads(values, num_heads)
q, k, v = __compute_qkv(queries, keys, values, num_heads)
q = __split_heads(q, num_heads)
k = __split_heads(k, num_heads)
v = __split_heads(v, num_heads)
key_dim_per_head = keys.shape[-1] // num_heads
scaled_q = layers.scale(x=q, scale=key_dim_per_head**-0.5)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册