Commit 113cd6b3 authored by ying

add multi-head scaled_dot_product attention.

Parent abf9395d
@@ -11,14 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import pdb
 import layers

 __all__ = [
     "simple_img_conv_pool",
     "sequence_conv_pool",
     "glu",
-    "dot_product_attention",
+    "scaled_dot_product_attention",
 ]
@@ -179,7 +179,7 @@ def scaled_dot_product_attention(queries,
     .. math::

         Attention(Q, K, V)= softmax(QK^\mathrm{T})V

     Refer to `Attention Is All You Need
     <https://arxiv.org/pdf/1706.03762.pdf>`_.
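For reference, here is a minimal NumPy sketch of the attention computation the docstring describes. The shapes are illustrative assumptions; note that the implementation additionally scales the queries by d_k**-0.5 (via layers.scale), which the formula above omits.

```python
import numpy as np

def softmax(x):
    # numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# illustrative shapes: batch=2, query_len=5, key_len=6
Q = np.random.rand(2, 5, 16)   # queries
K = np.random.rand(2, 6, 16)   # keys
V = np.random.rand(2, 6, 16)   # values

scores = Q @ K.transpose(0, 2, 1) / np.sqrt(K.shape[-1])  # [2, 5, 6]
context = softmax(scores) @ V                             # [2, 5, 16]
```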
@@ -195,8 +195,8 @@ def scaled_dot_product_attention(queries,
             LoDTensor.

     Returns:
-        tuple: The Tensor variables representing the output and attention
-               scores.
+        Variable: The context Tensor computed by multi-head scaled dot product
+                  attention.

     Examples:
         .. code-block:: python
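The Examples block itself is outside this hunk; a hypothetical call consistent with the new signature and the single-Variable return might look like the sketch below (the import path, the fluid.layers.data usage, and the shapes are assumptions, not part of the patch).

```python
import paddle.v2.fluid as fluid

# Assumed shapes: queries [batch, 5, 64], keys/values [batch, 6, 64].
q = fluid.layers.data(name='queries', shape=[5, 64], dtype='float32')
k = fluid.layers.data(name='keys', shape=[6, 64], dtype='float32')
v = fluid.layers.data(name='values', shape=[6, 64], dtype='float32')

# Returns one context Variable (not a (context, attention scores) tuple).
contexts = fluid.nets.scaled_dot_product_attention(q, k, v, num_heads=8)
```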
@@ -239,26 +239,42 @@ def scaled_dot_product_attention(queries,
         Returns:
             a Tensor with shape [..., n, m/n]
         """
+        if num_heads == 1: return x
+
         hidden_size = x.shape[-1]
-        #
+        # reshape the 3-D input: [batch_size, max_sequence_length, hidden_dim]
+        # into a 4-D output:
+        # [batch_size, max_sequence_length, num_heads, hidden_size_per_head].
         reshaped = layers.reshape(
-            x=x, shape=x.shape[:-1] + [num_heads, hidden_size // num_heads])
-        pass
-
-    def __combine_heads():
-        pass
-
-    q = __split_heads(quries, num_heads)
+            x=x,
+            shape=list(x.shape[:-1]) + [num_heads, hidden_size // num_heads])
+
+        # permute the original dimensions into:
+        # [batch_size, num_heads, max_sequence_len, hidden_size_per_head]
+        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
+
+    def __combine_heads(x):
+        # reverse __split_heads: merge the head dimension back into hidden_size
+        if len(x.shape) == 3: return x
+        if len(x.shape) != 4:
+            raise ValueError("Input(x) should be a 4-D Tensor.")
+
+        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
+        return layers.reshape(
+            x=trans_x,
+            shape=[trans_x.shape[0], trans_x.shape[1],
+                   trans_x.shape[2] * trans_x.shape[3]])
+
+    q = __split_heads(queries, num_heads)
     k = __split_heads(keys, num_heads)
     v = __split_heads(values, num_heads)

     key_dim_per_head = keys.shape[-1] // num_heads
-    scale = key_dim_per_head**-0.5
-
-    product = layers.matmul(x=k, y=q, transpose_y=True)
+    scaled_q = layers.scale(x=q, scale=key_dim_per_head**-0.5)
+    product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
+
     attn_scores = layers.reshape(
         x=layers.reshape(
             x=product, shape=[-1, product.shape[-1]], act="softmax"),
         shape=product.shape)
-    context = layers.matmul(attn_scores, values)
-    return context, attn_scores
+    ctx_multiheads = layers.matmul(attn_scores, v)
+    context = __combine_heads(ctx_multiheads)
+    return context
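To make the head bookkeeping in __split_heads and __combine_heads concrete, here is a plain-NumPy sketch of standard multi-head scaled dot-product attention; function and variable names are illustrative, not the fluid API.

```python
import numpy as np

def split_heads(x, num_heads):
    # [batch, seq, hidden] -> [batch, num_heads, seq, hidden // num_heads]
    b, s, h = x.shape
    return x.reshape(b, s, num_heads, h // num_heads).transpose(0, 2, 1, 3)

def combine_heads(x):
    # [batch, num_heads, seq, head_dim] -> [batch, seq, num_heads * head_dim]
    b, n, s, d = x.shape
    return x.transpose(0, 2, 1, 3).reshape(b, s, n * d)

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(queries, keys, values, num_heads):
    q = split_heads(queries, num_heads)                    # [b, n, q_len, d]
    k = split_heads(keys, num_heads)                       # [b, n, k_len, d]
    v = split_heads(values, num_heads)                     # [b, n, k_len, d]
    d_k = keys.shape[-1] // num_heads
    scores = (q * d_k ** -0.5) @ k.transpose(0, 1, 3, 2)   # [b, n, q_len, k_len]
    return combine_heads(softmax(scores) @ v)              # [b, q_len, hidden]

out = multi_head_attention(np.random.rand(2, 5, 16), np.random.rand(2, 6, 16),
                           np.random.rand(2, 6, 16), num_heads=4)
print(out.shape)  # (2, 5, 16)
```

Splitting turns [batch, seq, hidden] into [batch, num_heads, seq, hidden/num_heads] so each head attends independently over its slice; combining reverses the transpose and merges the head dimension back into hidden.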