diff --git a/python/paddle/nn/layer/transformer.py b/python/paddle/nn/layer/transformer.py index 50a8755ac9f7b0a8e35c60f02a9fb825195ab80f..63069e83952172df3136458ebfee4b446749934d 100644 --- a/python/paddle/nn/layer/transformer.py +++ b/python/paddle/nn/layer/transformer.py @@ -25,12 +25,13 @@ __all__ = [ import copy import collections +from .common import Linear, Dropout +from .norm import LayerNorm +from .. import functional as F +from ... import tensor from ...fluid import layers +from ...fluid.dygraph import Layer, LayerList from ...fluid.param_attr import ParamAttr -from ...fluid.dygraph import Layer, Linear, Dropout, LayerNorm, LayerList -from .. import functional as F -from ...fluid.layers import utils -from ...fluid.layers.utils import map_structure def _convert_param_attr_to_list(param_attr, n): @@ -103,7 +104,7 @@ class MultiHeadAttention(Layer): # self attention mask: [batch_size, num_heads, query_len, query_len] attn_mask = paddle.rand((2, 2, 4, 4)) multi_head_attn = paddle.MultiHeadAttention(128, 2) - output = multi_head_attn(query, attn_mask=attn_mask) # [2, 4, 128] + output = multi_head_attn(query, None, None, attn_mask=attn_mask) # [2, 4, 128] """ Cache = collections.namedtuple("Cache", ["k", "v"]) @@ -176,8 +177,8 @@ class MultiHeadAttention(Layer): and their data types are same as inputs. """ q = self.q_proj(query) - q = layers.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) - q = layers.transpose(x=q, perm=[0, 2, 1, 3]) + q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) if isinstance(cache, self.StaticCache): # for encoder-decoder attention in inference and has cached @@ -187,8 +188,8 @@ class MultiHeadAttention(Layer): if isinstance(cache, self.Cache): # for decoder self-attention in inference - k = layers.concat([cache.k, k], axis=2) - v = layers.concat([cache.v, v], axis=2) + k = tensor.concat([cache.k, k], axis=2) + v = tensor.concat([cache.v, v], axis=2) cache = self.Cache(k, v) return (q, k, v) if cache is None else (q, k, v, cache) @@ -219,10 +220,10 @@ class MultiHeadAttention(Layer): """ k = self.k_proj(key) v = self.v_proj(value) - k = layers.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) - k = layers.transpose(x=k, perm=[0, 2, 1, 3]) - v = layers.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) - v = layers.transpose(x=v, perm=[0, 2, 1, 3]) + k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) + k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) + v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) + v = tensor.transpose(x=v, perm=[0, 2, 1, 3]) return k, v def gen_cache(self, key, value=None, type=Cache): @@ -352,24 +353,25 @@ class MultiHeadAttention(Layer): q, k, v, cache = self._prepare_qkv(query, key, value, cache) # scale dot product attention + # TODO(guosheng): use tensor.matmul, however it doesn't support `alpha` product = layers.matmul( x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5) if attn_mask is not None: # TODO(guosheng): support bool mask product = product + attn_mask - weights = layers.softmax(product) + weights = F.softmax(product) if self.dropout: - weights = layers.dropout( + weights = F.dropout( weights, - dropout_prob=self.dropout, - dropout_implementation="upscale_in_train", - is_test=False) + self.dropout, + training=self.training, + mode="upscale_in_train") - out = layers.matmul(weights, v) + out = tensor.matmul(weights, v) # combine heads - out = layers.transpose(out, perm=[0, 2, 1, 3]) - out = layers.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + out = tensor.transpose(out, perm=[0, 2, 1, 3]) + out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) # project to output out = self.out_proj(out) @@ -429,7 +431,7 @@ class TransformerEncoderLayer(Layer): .. code-block:: python import paddle - from paddle import TransformerEncoderLayer + from paddle.nn import TransformerEncoderLayer # encoder input: [batch_size, src_len, d_model] enc_input = paddle.rand((2, 4, 128)) @@ -470,17 +472,14 @@ class TransformerEncoderLayer(Layer): bias_attr=bias_attrs[0]) self.linear1 = Linear( d_model, dim_feedforward, weight_attrs[1], bias_attr=bias_attrs[1]) - self.dropout = Dropout( - act_dropout, dropout_implementation="upscale_in_train") + self.dropout = Dropout(act_dropout, mode="upscale_in_train") self.linear2 = Linear( dim_feedforward, d_model, weight_attrs[1], bias_attr=bias_attrs[1]) self.norm1 = LayerNorm(d_model) self.norm2 = LayerNorm(d_model) - self.dropout1 = Dropout( - dropout, dropout_implementation="upscale_in_train") - self.dropout2 = Dropout( - dropout, dropout_implementation="upscale_in_train") - self.activation = getattr(layers, activation) + self.dropout1 = Dropout(dropout, mode="upscale_in_train") + self.dropout2 = Dropout(dropout, mode="upscale_in_train") + self.activation = getattr(F, activation) def forward(self, src, src_mask=None): """ @@ -539,7 +538,7 @@ class TransformerEncoder(Layer): .. code-block:: python import paddle - from paddle import TransformerEncoderLayer, TransformerEncoder + from paddle.nn import TransformerEncoderLayer, TransformerEncoder # encoder input: [batch_size, src_len, d_model] enc_input = paddle.rand((2, 4, 128)) @@ -643,7 +642,7 @@ class TransformerDecoderLayer(Layer): .. code-block:: python import paddle - from paddle import TransformerDecoderLayer + from paddle.nn import TransformerDecoderLayer # decoder input: [batch_size, tgt_len, d_model] dec_input = paddle.rand((2, 4, 128)) @@ -697,20 +696,16 @@ class TransformerDecoderLayer(Layer): bias_attr=bias_attrs[1]) self.linear1 = Linear( d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2]) - self.dropout = Dropout( - act_dropout, dropout_implementation="upscale_in_train") + self.dropout = Dropout(act_dropout, mode="upscale_in_train") self.linear2 = Linear( dim_feedforward, d_model, weight_attrs[2], bias_attr=bias_attrs[2]) self.norm1 = LayerNorm(d_model) self.norm2 = LayerNorm(d_model) self.norm3 = LayerNorm(d_model) - self.dropout1 = Dropout( - dropout, dropout_implementation="upscale_in_train") - self.dropout2 = Dropout( - dropout, dropout_implementation="upscale_in_train") - self.dropout3 = Dropout( - dropout, dropout_implementation="upscale_in_train") - self.activation = getattr(layers, activation) + self.dropout1 = Dropout(dropout, mode="upscale_in_train") + self.dropout2 = Dropout(dropout, mode="upscale_in_train") + self.dropout3 = Dropout(dropout, mode="upscale_in_train") + self.activation = getattr(F, activation) def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None): """ @@ -834,7 +829,7 @@ class TransformerDecoder(Layer): .. code-block:: python import paddle - from paddle import TransformerDecoderLayer, TransformerDecoder + from paddle.nn import TransformerDecoderLayer, TransformerDecoder # decoder input: [batch_size, tgt_len, d_model] dec_input = paddle.rand((2, 4, 128)) @@ -1017,7 +1012,7 @@ class Transformer(Layer): .. code-block:: python import paddle - from paddle import Transformer + from paddle.nn import Transformer # src: [batch_size, tgt_len, d_model] enc_input = paddle.rand((2, 4, 128))