Commit d00eb53a authored by: Y ying

add linear projection to q, k and v.

Parent 0d96899f
@@ -108,16 +108,17 @@ def fc(input,
into a 2-dimensional matrix. The parameter
`num_flatten_dims` determines how the input tensor
is flattened: the first `num_flatten_dims`
(inclusive, index starts from 1) dimensions will
be flattened to form the first dimension of the
final matrix (height of the matrix), and the rest
`rank(X) - num_flatten_dims` dimensions are
flattened to form the second dimension of the
final matrix (width of the matrix). For example,
suppose `X` is a 5-dimensional tensor with a shape
[2, 3, 4, 5, 6], and `num_flatten_dims` = 3. Then,
the flattened matrix will have a shape
[2 x 3 x 4, 5 x 6] = [24, 30]. By default,
`num_flatten_dims` is set to 1.
param_attr(ParamAttr|list): The parameter attribute for learnable
parameters/weights of the fully connected
layer.
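To make the `num_flatten_dims` rule above concrete, here is a minimal sketch of the shapes involved; the variable names and `size=8` are illustrative assumptions, not part of this patch:

.. code-block:: python

    import paddle.fluid as fluid

    # A 5-D input matching the docstring example [2, 3, 4, 5, 6].
    x = fluid.layers.data(name='x', shape=[2, 3, 4, 5, 6],
                          dtype='float32', append_batch_size=False)

    # With num_flatten_dims=3, fc views x as a [2*3*4, 5*6] = [24, 30]
    # matrix, multiplies it by a [30, 8] weight, and restores the
    # leading dimensions, giving an output of shape [2, 3, 4, 8].
    out = fluid.layers.fc(input=x, size=8, num_flatten_dims=3)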
@@ -158,6 +159,7 @@ def fc(input,
param_shape = [
reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1)
] + [size]
w = helper.create_parameter(
attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False)
tmp = helper.create_tmp_variable(dtype)
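As a standalone check of the weight-shape arithmetic above, with values taken from the docstring example (`size=8` is an assumption):

.. code-block:: python

    from functools import reduce

    input_shape = [2, 3, 4, 5, 6]
    num_flatten_dims = 3
    size = 8

    # Width of the flattened input, followed by the output size.
    param_shape = [
        reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1)
    ] + [size]
    assert param_shape == [30, 8]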
@@ -747,7 +749,7 @@ def square_error_cost(input, label, **kwargs):
This layer accepts input predictions and target labels and returns the
squared error cost.
For predictions :math:`X` and target labels :math:`Y`, the equation is:

.. math::

    Out = (X - Y)^2
......
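A minimal usage sketch of this layer; tensor names and shapes are illustrative, not from the patch:

.. code-block:: python

    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[1], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')

    # Element-wise squared difference, same shape as the inputs.
    cost = fluid.layers.square_error_cost(input=x, label=y)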
@@ -197,15 +197,27 @@ def scaled_dot_product_attention(queries,
Variable: A 3-D Tensor computed by multi-head scaled dot product
attention.
Raises:
ValueError: If the input queries, keys, and values are not 3-D Tensors.
NOTE:
1. When num_heads > 1, three linear projections are learned respectively
to map the input queries, keys, and values into queries', keys', and
values'. queries', keys', and values' have the same shapes as queries,
keys, and values.
2. When num_heads == 1, scaled_dot_product_attention has no learnable
parameters.
Examples:
.. code-block:: python
# Suppose q, k, v are Tensors with the following shape:
# q: [3, 5, 9], k: [3, 6, 9], v: [3, 6, 10]
contexts = fluid.nets.scaled_dot_product_attention(q, k, v)
contexts.shape  # [3, 5, 10]
"""
if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
raise ValueError(
@@ -228,6 +240,22 @@ def scaled_dot_product_attention(queries,
(values.shape[-1], num_heads))
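For reference, the quantity computed per head is the standard scaled dot-product attention over queries :math:`Q`, keys :math:`K`, and values :math:`V`, with :math:`d_k` the key width; this formula is implied by the function's name rather than spelled out in the patch:

.. math::

    \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^{T}}{\sqrt{d_k}}\right) V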
def __compute_qkv(queries, keys, values, num_heads):
"""
Add linear projection to queries, keys, and values.
Args:
queries(Tensor): a 3-D input Tensor.
keys(Tensor): a 3-D input Tensor.
values(Tensor): a 3-D input Tensor.
num_heads(int): The number of heads. Linearly project the inputs
ONLY when num_heads > 1.
Returns:
Tensor: the linearly projected output Tensors queries', keys', and
values'. They have the same shapes as queries, keys, and values.
"""
if num_heads == 1:
return queries, keys, values
......
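The projection itself is elided by the collapsed diff above. Below is a minimal sketch of what the num_heads > 1 branch presumably does, assuming one learned, bias-free fc layer per input that preserves the input's last dimension; the helper name and the bias choice are assumptions:

.. code-block:: python

    import paddle.fluid as fluid

    def compute_qkv_sketch(queries, keys, values, num_heads):
        # num_heads == 1: pass the inputs through unchanged, as above.
        if num_heads == 1:
            return queries, keys, values
        # num_heads > 1: project each 3-D input with its own fc layer;
        # num_flatten_dims=2 keeps the first two dimensions intact, so
        # each output has the same shape as its input.
        q = fluid.layers.fc(input=queries, size=queries.shape[-1],
                            num_flatten_dims=2, bias_attr=False)
        k = fluid.layers.fc(input=keys, size=keys.shape[-1],
                            num_flatten_dims=2, bias_attr=False)
        v = fluid.layers.fc(input=values, size=values.shape[-1],
                            num_flatten_dims=2, bias_attr=False)
        return q, k, v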
File mode changed from 100755 to 100644
@@ -65,6 +65,7 @@ class TestMultiheadAttention(unittest.TestCase):
self.set_inputs(place)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
output = exe.run(fluid.default_main_program(),
feed=self.inputs,
fetch_list=self.fetch_list,
@@ -90,6 +91,8 @@ class TestMultiheadAttention(unittest.TestCase):
self.set_program()
self.run_program()
#fixme(caoying) add more meaningful unittests.
if __name__ == '__main__':
unittest.main()