提交 d00eb53a 编写于 作者: Y ying

add linear projection to q, k and v.

上级 0d96899f
...@@ -108,16 +108,17 @@ def fc(input, ...@@ -108,16 +108,17 @@ def fc(input,
into a 2-dimensional matrix. The parameter into a 2-dimensional matrix. The parameter
`num_flatten_dims` determines how the input tensor `num_flatten_dims` determines how the input tensor
is flattened: the first `num_flatten_dims` is flattened: the first `num_flatten_dims`
dimensions will be flatten to form the first (inclusive, index starts from 1) dimensions will
dimension of the final matrix (height of the be flatten to form the first dimension of the
matrix), and the rest `rank(X) - num_flatten_dims` final matrix (height of the matrix), and the rest
dimensions are flattened to form the second `rank(X) - num_flatten_dims` dimensions are
dimension of the final matrix (width of the matrix). flattened to form the second dimension of the
For example, suppose `X` is a 6-dimensional tensor final matrix (width of the matrix). For example,
with a shape [2, 3, 4, 5, 6], and suppose `X` is a 6-dimensional tensor with a shape
`num_flatten_dims` = 3. Then, the flattened matrix [2, 3, 4, 5, 6], and `num_flatten_dims` = 3. Then,
will have a shape [2 x 3 x 4, 5 x 6] = [24, 30]. the flattened matrix will have a shape
By default, `num_flatten_dims` is set to 1. [2 x 3 x 4, 5 x 6] = [24, 30]. By default,
`num_flatten_dims` is set to 1.
param_attr(ParamAttr|list): The parameter attribute for learnable param_attr(ParamAttr|list): The parameter attribute for learnable
parameters/weights of the fully connected parameters/weights of the fully connected
layer. layer.
...@@ -158,6 +159,7 @@ def fc(input, ...@@ -158,6 +159,7 @@ def fc(input,
param_shape = [ param_shape = [
reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1) reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1)
] + [size] ] + [size]
w = helper.create_parameter( w = helper.create_parameter(
attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False) attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False)
tmp = helper.create_tmp_variable(dtype) tmp = helper.create_tmp_variable(dtype)
......
...@@ -197,15 +197,27 @@ def scaled_dot_product_attention(queries, ...@@ -197,15 +197,27 @@ def scaled_dot_product_attention(queries,
Variable: A 3-D Tensor computed by multi-head scaled dot product Variable: A 3-D Tensor computed by multi-head scaled dot product
attention. attention.
Raises:
ValueError: If input queries, keys, values are not 3-D Tensors.
NOTE:
1. When num_heads > 1, three linear projections are learned respectively
to map input queries, keys and values into queries', keys' and values'.
queries', keys' and values' have the same shapes with queries, keys
and values.
1. When num_heads == 1, scaled_dot_product_attention has no learnable
parameters.
Examples: Examples:
.. code-block:: python .. code-block:: python
# Suppose q, k, v are Tensors with the following shape: # Suppose q, k, v are Tensors with the following shape:
# q: [3, 5, 9], k: [3, 6, 9], v: [3, 6, 10] # q: [3, 5, 9], k: [3, 6, 9], v: [3, 6, 10]
contexts = fluid.nets.dot_product_attention(q, k, v) contexts = fluid.nets.scaled_dot_product_attention(q, k, v)
out.shape # [3, 5, 10] contexts.shape # [3, 5, 10]
attn_scores.shape # [3, 5, 6]
""" """
if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3): if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
raise ValueError( raise ValueError(
...@@ -228,6 +240,22 @@ def scaled_dot_product_attention(queries, ...@@ -228,6 +240,22 @@ def scaled_dot_product_attention(queries,
(values.shape[-1], num_heads)) (values.shape[-1], num_heads))
def __compute_qkv(queries, keys, values, num_heads): def __compute_qkv(queries, keys, values, num_heads):
"""
Add linear projection to queries, keys, and values.
Args:
queries(Tensor): a 3-D input Tensor.
keys(Tensor): a 3-D input Tensor.
values(Tensor): a 3-D input Tensor.
num_heads(int): The number of heads. Linearly project the inputs
ONLY when num_heads > 1.
Returns:
Tensor: linearly projected output Tensors: queries', keys' and
values'. They have the same shapes with queries, keys and
values.
"""
if num_heads == 1: if num_heads == 1:
return queries, keys, values return queries, keys, values
......
文件模式从 100755 更改为 100644
...@@ -65,6 +65,7 @@ class TestMultiheadAttention(unittest.TestCase): ...@@ -65,6 +65,7 @@ class TestMultiheadAttention(unittest.TestCase):
self.set_inputs(place) self.set_inputs(place)
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
output = exe.run(fluid.default_main_program(), output = exe.run(fluid.default_main_program(),
feed=self.inputs, feed=self.inputs,
fetch_list=self.fetch_list, fetch_list=self.fetch_list,
...@@ -90,6 +91,8 @@ class TestMultiheadAttention(unittest.TestCase): ...@@ -90,6 +91,8 @@ class TestMultiheadAttention(unittest.TestCase):
self.set_program() self.set_program()
self.run_program() self.run_program()
#fixme(caoying) add more meaningfull unittest.
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册