diff --git a/python/paddle/v2/fluid/nets.py b/python/paddle/v2/fluid/nets.py
index 3390fa5946168dbeb3dc4e216b5d45fc870c934f..dfae9c9391a930d9833496669a383e45a3090399 100644
--- a/python/paddle/v2/fluid/nets.py
+++ b/python/paddle/v2/fluid/nets.py
@@ -11,14 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import layers
 
 __all__ = [
     "simple_img_conv_pool",
     "sequence_conv_pool",
     "glu",
-    "dot_product_attention",
+    "scaled_dot_product_attention",
 ]
 
 
@@ -179,7 +179,7 @@ def scaled_dot_product_attention(queries,
 
     .. math::
 
-        Attention(Q, K, V)= softmax(QK^\mathrm{T})V
+        Attention(Q, K, V)= softmax(\frac{QK^\mathrm{T}}{\sqrt{d_{key}}})V
 
     Refer to `Attention Is All You Need
     <https://arxiv.org/pdf/1706.03762.pdf>`_.
@@ -195,8 +195,8 @@ def scaled_dot_product_attention(queries,
             LoDTensor.
 
     Returns:
-        tuple: The Tensor variables representing the output and attention
-        scores.
+        Variable: The context Tensor computed by multi-head scaled dot product
+            attention.
 
     Examples:
         .. code-block:: python
@@ -239,26 +239,42 @@ def scaled_dot_product_attention(queries,
         Returns: a Tensor with shape [..., n, m/n]
         """
+        if num_heads == 1: return x
+
         hidden_size = x.shape[-1]
-        #
+        # reshape the 3-D input: [batch_size, max_sequence_length, hidden_dim]
+        # into a 4-D output:
+        # [batch_size, max_sequence_length, num_heads, hidden_size_per_head].
         reshaped = layers.reshape(
-            x=x, shape=x.shape[:-1] + [num_heads, hidden_size // num_heads])
-        pass
-
-    def __combine_heads():
-        pass
-
-    q = __split_heads(quries, num_heads)
+            x=x,
+            shape=list(x.shape[:-1]) + [num_heads, hidden_size // num_heads])
+        # permute the original dimensions into:
+        # [batch_size, num_heads, max_sequence_len, hidden_size_per_head]
+        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
+
+    def __combine_heads(x):
+        if len(x.shape) == 3: return x
+        if len(x.shape) != 4:
+            raise ValueError("Input(x) should be a 4-D Tensor.")
+
+        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
+        return layers.reshape(
+            x=trans_x,
+            shape=[trans_x.shape[0], trans_x.shape[1],
+                   trans_x.shape[2] * trans_x.shape[3]])
+
+    q = __split_heads(queries, num_heads)
     k = __split_heads(keys, num_heads)
     v = __split_heads(values, num_heads)
 
     key_dim_per_head = keys.shape[-1] // num_heads
-    scale = key_dim_per_head**-0.5
+    scaled_q = layers.scale(x=q, scale=key_dim_per_head**-0.5)
+    product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
 
-    product = layers.matmul(x=k, y=q, transpose_y=True)
     attn_scores = layers.reshape(
         x=layers.reshape(
             x=product, shape=[-1, product.shape[-1]], act="softmax"),
         shape=product.shape)
 
-    context = layers.matmul(attn_scores, values)
-    return context, attn_scores
+    ctx_multiheads = layers.matmul(attn_scores, v)
+    context = __combine_heads(ctx_multiheads)
+    return context
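
For reference, here is a minimal NumPy sketch (not part of the patch) of the computation the revised scaled_dot_product_attention composes out of fluid layers: split the hidden dimension into heads, scale the queries by key_dim_per_head ** -0.5, apply softmax(QK^T) per head, weight the values, and merge the heads back. The helper names split_heads, combine_heads, and reference_attention are illustrative only and assume 3-D [batch, seq_len, hidden] inputs; the patch itself builds the same steps from layers.reshape, layers.transpose, layers.scale, and layers.matmul, obtaining the softmax by flattening the product to 2-D with act="softmax" and reshaping back.

    import numpy as np


    def split_heads(x, num_heads):
        # [batch, seq_len, hidden] -> [batch, num_heads, seq_len, hidden // num_heads]
        batch, seq_len, hidden = x.shape
        x = x.reshape(batch, seq_len, num_heads, hidden // num_heads)
        return x.transpose(0, 2, 1, 3)


    def combine_heads(x):
        # [batch, num_heads, seq_len, head_dim] -> [batch, seq_len, num_heads * head_dim]
        batch, num_heads, seq_len, head_dim = x.shape
        return x.transpose(0, 2, 1, 3).reshape(batch, seq_len, num_heads * head_dim)


    def reference_attention(queries, keys, values, num_heads=1):
        q = split_heads(queries, num_heads)
        k = split_heads(keys, num_heads)
        v = split_heads(values, num_heads)

        key_dim_per_head = keys.shape[-1] // num_heads
        scaled_q = q * key_dim_per_head**-0.5
        # [batch, num_heads, q_len, k_len]
        product = np.matmul(scaled_q, np.swapaxes(k, -1, -2))

        # softmax over the key axis
        weights = np.exp(product - product.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)

        context = np.matmul(weights, v)   # [batch, num_heads, q_len, head_dim]
        return combine_heads(context)     # [batch, q_len, hidden]


    # Shape check: 2 sequences of length 5, hidden size 16, 4 heads.
    x = np.random.rand(2, 5, 16)
    out = reference_attention(x, x, x, num_heads=4)
    assert out.shape == (2, 5, 16)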