From 317f7ce2ef72b27d513d7d92ffdc953aad23153b Mon Sep 17 00:00:00 2001 From: Guo Sheng Date: Mon, 24 Aug 2020 13:50:45 +0800 Subject: [PATCH] [API 2.0] Add transformer apis (#26418) * Add MultiHeadAttention api. test=develop * Add MultiHeadAttention cache type and gen_cache. test=develop * Add TransformerEncoderLayer and TransformerEncoder. test=develop * Add Transformer decoder apis. test=develop * Add Transformer api. test=develop * add unittests for transformer api * add unittests for transformer api * Fix some bugs in Transformer apis. test=develop * add unittests for encoder, decoder and transformer * clean conflicts infor in code * clean Chinese comments * Add TransformerDecoderCell and TransformerBeamSearchDecoder. test=develop * Remove TransformerDecoderCell and TransformerBeamSearchDecoder temporarily. test=develop * Add import for Transformer apis. test=develop * Update usage of weight_attr and Tensor in Transformer api docs. test=develop * Update Transformer apis by renaming MultiheadAttention and cal_kv according to comments. test=develop * Fix MultiHeadAttention in test_transformer_api.py. test=develop Co-authored-by: LiuChiaChi <709153940@qq.com> --- .../tests/unittests/test_transformer_api.py | 477 +++++++ python/paddle/nn/__init__.py | 6 + python/paddle/nn/layer/__init__.py | 2 + python/paddle/nn/layer/transformer.py | 1105 ++++++++++++++++- 4 files changed, 1589 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/tests/unittests/test_transformer_api.py diff --git a/python/paddle/fluid/tests/unittests/test_transformer_api.py b/python/paddle/fluid/tests/unittests/test_transformer_api.py new file mode 100644 index 0000000000..c8d1e77134 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_transformer_api.py @@ -0,0 +1,477 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +import paddle.fluid as fluid +from paddle.nn.layer.transformer import MultiHeadAttention, TransformerEncoderLayer, TransformerDecoderLayer, TransformerEncoder, TransformerDecoder, Transformer + +import unittest + + +def generate_basic_params(mode="attn", self_attention=True): + batch_size, query_length = [np.random.randint(2, 10) for _ in range(2)] + d_head, num_heads = [np.random.randint(3, 10) for _ in range(2)] + attn_dropout = 0.0 + embed_dim = d_head * num_heads + if mode == "attn": + if self_attention: + kdim, vdim = embed_dim, embed_dim + key_length, value_length = query_length, query_length + else: + kdim, vdim = [np.random.randint(5, 20) for _ in range(2)] + key_length = np.random.randint(2, 10) + value_length = key_length + return batch_size, query_length, key_length, value_length, embed_dim, kdim, vdim, num_heads, attn_dropout + + else: + dropout, act_dropout = 0.0, 0.0 + dim_feedforward = np.random.randint(128, 1024) + sequence_length = np.random.randint(2, 10) + if mode == "encoder_layer": + return batch_size, embed_dim, num_heads, dim_feedforward, dropout, attn_dropout, act_dropout, sequence_length + elif mode == "decoder_layer": + target_length = np.random.randint(2, 10) + return batch_size, embed_dim, num_heads, dim_feedforward, dropout, attn_dropout, act_dropout, sequence_length, target_length + + +def generate_query_key_value_cache(self_attention, + batch_size, + num_heads, + query_length, + embed_dim, + key_length=None, + value_length=None, + kdim=None, + vdim=None, + cache=None): + query = np.random.rand(batch_size, query_length, + embed_dim).astype("float32") + attn_mask = np.zeros((batch_size, num_heads, query_length, key_length)) + attn_mask[0][0][0][0] = -1e9 + + head_dim = embed_dim // num_heads + if self_attention: + key, value = query, query + else: + key = np.random.rand(batch_size, key_length, kdim).astype("float32") + value = np.random.rand(batch_size, value_length, vdim).astype("float32") + cache_dict = {} + if cache: + if not self_attention: + cache_dict["static_k"] = np.random.rand( + batch_size, num_heads, key_length, head_dim).astype("float32") + cache_dict["static_v"] = np.random.rand( + batch_size, num_heads, value_length, head_dim).astype("float32") + else: + cache_dict["k"] = np.random.rand(batch_size, num_heads, key_length, + head_dim).astype("float32") + cache_dict["v"] = np.random.rand( + batch_size, num_heads, value_length, head_dim).astype("float32") + else: + cache_dict = None + return query, key, value, attn_mask, cache_dict + + +def fc(x, weight): + return np.matmul(x, weight) + + +def softmax(x): + np.seterr(invalid='ignore') + output = np.zeros(x.shape, dtype=np.float64) + for i in range(x.shape[0]): + for j in range(x.shape[1]): + for k in range(x.shape[2]): + x_curr = x[i, j, k, :] + e_x = np.exp(x_curr - np.amax(x_curr)) + output[i, j, k, :] = e_x / np.sum(e_x) + return output + + +def batch_matmul(x, y): + assert x.shape[0] == y.shape[0] + assert x.shape[1] == y.shape[1] + retval = np.zeros( + (x.shape[0], x.shape[1], x.shape[2], y.shape[3]), dtype=np.float64) + for i in range(x.shape[0]): + for j in range(x.shape[1]): + retval[i, j, :, :] = np.matmul(x[i, j, :, :], y[i, j, :, :]) + return retval + + +def scaled_dot_product_attention(q, k, v, d_key, attn_mask, multi_head_attn): + k = k.transpose([0, 1, 3, 2]) + qkt = batch_matmul(q, k / np.sqrt(d_key, dtype=np.float64)) + if attn_mask is not None: + qkt += attn_mask + weight = softmax(qkt) + attn_heads = batch_matmul(weight, v) + attn_heads = attn_heads.transpose((0, 2, 1, 3)) + attn_heads = attn_heads.reshape((attn_heads.shape[0], attn_heads.shape[1], + attn_heads.shape[2] * attn_heads.shape[3])) + return attn_heads + + +def cal_qkv(key, value, num_heads, embed_dim, multi_head_attn): + with fluid.dygraph.guard(): + head_dim = embed_dim // num_heads + k_weight = multi_head_attn.k_proj.weight.numpy() + v_weight = multi_head_attn.v_proj.weight.numpy() + k = fc(key, k_weight) + v = fc(value, v_weight) + k = k.reshape((k.shape[0], k.shape[1], num_heads, head_dim)) + k = k.transpose((0, 2, 1, 3)) + v = v.reshape((v.shape[0], v.shape[1], num_heads, head_dim)) + v = v.transpose((0, 2, 1, 3)) + return k, v + + +def prepare_qkv(query, key, value, num_heads, embed_dim, self_attention, + multi_head_attn, cache_dict): + q_weight = multi_head_attn.q_proj.weight.numpy() + q = fc(query, q_weight) + q = q.reshape((q.shape[0], q.shape[1], num_heads, embed_dim // num_heads)) + q = q.transpose((0, 2, 1, 3)) + + if not self_attention and cache_dict: + k, v = cache_dict["static_k"], cache_dict["static_v"] + else: + k, v = cal_qkv(key, value, num_heads, embed_dim, multi_head_attn) + if cache_dict is not None: + k = np.concatenate((cache_dict["k"], k), axis=2) + v = np.concatenate((cache_dict["v"], v), axis=2) + return (q, k, v, cache_dict) + + +def add(x, y=None): + fluid.enable_dygraph() + with fluid.dygraph.guard(): + x = x.numpy() if not isinstance(x, np.ndarray) else x + if y is not None: + x += y + return x + return x + + +def relu(x): + compare = x > 0 + return x * compare + + +def layer_norm(x, normalized_shape, norm, epsilon=1e-05, act=None): + fluid.enable_dygraph() + with fluid.dygraph.guard(): + # scale: + weight = norm.weight.numpy() + # shift: + bias = norm.bias.numpy() + + batch_size, src_len, d_model = x.shape + x = x.reshape((batch_size * src_len, d_model)) + mu = np.mean(x, axis=1, keepdims=True) + sigma_squar = np.sum(np.square(x - mu), axis=1) / d_model + x1_up = (x - mu) + x1_down_1 = sigma_squar + epsilon + x1_down = np.sqrt(x1_down_1) + x1_down = x1_down.reshape((x1_down.shape[0], 1)) + x1 = x1_up / x1_down + x_scaled = weight * x1 + x_scaled_bias = x_scaled + bias + x_scaled_bias = x_scaled_bias.reshape((batch_size, src_len, d_model)) + return x_scaled_bias + + +def ffn(src, encoder_layer, ffn_fc1_act="relu"): + assert ffn_fc1_act == "relu", "only relu is supported" + fluid.enable_dygraph() + with fluid.dygraph.guard(): + src = src.numpy() if not isinstance(src, np.ndarray) else src + w1 = encoder_layer.linear1.weight.numpy() + w2 = encoder_layer.linear2.weight.numpy() + # fc1 + x1 = fc(src, w1) + x1 = relu(x1) + # fc2 + x2 = fc(x1, w2) + return x2 + + +class TestTransformer(unittest.TestCase): + def test_multi_head_attention(self): + def multihead_attention_test_helper(self_attention, cache): + paddle.framework.manual_seed(2020) + # self_attention|cross_attention, cache|No cache + with fluid.dygraph.guard(fluid.CPUPlace()): + + # generate params for multi_head_attention + batch_size, query_length, key_length, value_length, embed_dim, kdim, vdim, num_heads, attn_dropout = generate_basic_params( + "attn", self_attention) + query, key, value, attn_mask, cache_dict = generate_query_key_value_cache( + self_attention, batch_size, num_heads, query_length, + embed_dim, key_length, value_length, kdim, vdim, cache) + if cache and self_attention: + attn_mask = np.concatenate((attn_mask, attn_mask), axis=3) + need_weight, param_attr, bias_attr = False, None, None + # call paddle's function + multi_head_attn = MultiHeadAttention( + embed_dim, num_heads, attn_dropout, kdim, vdim, need_weight, + param_attr, bias_attr) + # construct cache object + cache_obj = None + if cache_dict: + if 'k' and 'v' in cache_dict: + cache_obj = multi_head_attn.Cache( + paddle.to_variable(cache_dict['k']), + paddle.to_variable(cache_dict['v'])) + elif 'static_k' and 'static_v' in cache_dict: + cache_obj = multi_head_attn.StaticCache( + paddle.to_variable(cache_dict['static_k']), + paddle.to_variable(cache_dict['static_v'])) + if attn_mask is not None: + attn_output = multi_head_attn( + paddle.to_variable(query), + paddle.to_variable(key), + paddle.to_variable(value), + paddle.to_variable(attn_mask), cache_obj) + else: + attn_output = multi_head_attn( + paddle.to_variable(query), + paddle.to_variable(key), + paddle.to_variable(value), attn_mask, cache_obj) + attn_output = attn_output[0] if cache_dict else attn_output + + # implementation by numpy + # compute q, k, v + q, k, v, _ = prepare_qkv(query, key, value, num_heads, + embed_dim, self_attention, + multi_head_attn, cache_dict) + # scale dot product attention + attn_heads = scaled_dot_product_attention( + q, k, v, embed_dim // num_heads, attn_mask, multi_head_attn) + out_proj_weight = multi_head_attn.out_proj.weight.numpy() + reference = fc(attn_heads, out_proj_weight) + + np.testing.assert_allclose( + attn_output.numpy(), reference, atol=1e-6) + + multihead_attention_test_helper(True, True) + multihead_attention_test_helper(True, False) + multihead_attention_test_helper(False, True) + multihead_attention_test_helper(False, False) + + def test_transformer_encoder_layer(self): + + with fluid.dygraph.guard(fluid.CPUPlace()): + paddle.framework.manual_seed(2020) + + ffn_fc1_act = "relu" + # 1.generate basic params + batch_size, d_model, n_head, dim_feedforward, dropout, attn_dropout, act_dropout, sequence_length = generate_basic_params( + mode="encoder_layer") + # 2.generate input for encoder + src = np.random.rand(batch_size, sequence_length, + d_model).astype("float32") + residual = src + src_mask = np.zeros((batch_size, n_head, sequence_length, + sequence_length)).astype("float32") + src_mask[0][0][0][0] = -np.inf + + # paddle + encoder_layer = TransformerEncoderLayer( + d_model, n_head, dim_feedforward, dropout, ffn_fc1_act, + attn_dropout, act_dropout) + + encoder_output = encoder_layer( + paddle.to_variable(src), + paddle.to_variable(src_mask)) # paddle.to_variable(src_mask)) + # 4.numpy: + # paddle self attention + self_attn = MultiHeadAttention( + d_model, n_head, dropout=attn_dropout) + attn_output = self_attn( + paddle.to_variable(src), + paddle.to_variable(src), + paddle.to_variable(src), paddle.to_variable(src_mask)).numpy() + + src = attn_output + residual + src_norm = layer_norm(src, d_model, encoder_layer.norm1) + residual = src_norm + + ffn_output = ffn(src_norm, encoder_layer, ffn_fc1_act) + src = residual + ffn_output + src = layer_norm(src, d_model, encoder_layer.norm2) + + np.testing.assert_allclose( + encoder_output.numpy(), src, rtol=1e-5, atol=1e-6) + + def test_transformer_decoder_layer(self): + with fluid.dygraph.guard(fluid.CPUPlace()): + paddle.framework.manual_seed(2020) + activation = "relu" + normalize_before = False + batch_size, d_model, n_head, dim_feedforward, dropout, attn_dropout, act_dropout, source_length, target_length = generate_basic_params( + mode="decoder_layer") + tgt = np.random.rand(batch_size, target_length, + d_model).astype("float32") + memory = np.random.rand(batch_size, source_length, + d_model).astype("float32") + tgt_mask = np.zeros((batch_size, n_head, target_length, + target_length)).astype("float32") + tgt_mask[0][0][0][0] = -1e9 + memory_mask = np.zeros((batch_size, n_head, target_length, + source_length)).astype("float32") + memory_mask[0][0][0][0] = -1e9 + for cache in [True, False]: + self_attn = MultiHeadAttention( + d_model, n_head, dropout=attn_dropout) + cross_attn = MultiHeadAttention( + d_model, n_head, dropout=attn_dropout) + + # paddle decoderlayer: + decoder_layer = TransformerDecoderLayer( + d_model, n_head, dim_feedforward, dropout, activation, + attn_dropout, act_dropout, normalize_before) + cache_objs = None + if cache: + cache_objs = decoder_layer.gen_cache( + paddle.to_variable(memory)) + + decoder_output = decoder_layer( + paddle.to_variable(tgt), + paddle.to_variable(memory), + paddle.to_variable(tgt_mask), + paddle.to_variable(memory_mask), cache_objs) + + decoder_output = decoder_output[0].numpy( + ) if cache else decoder_output.numpy() + + # numpy: + residual = tgt + # self-attn + self_attn_cache = cache_objs[ + 0] if cache_objs is not None else None + tgt = self_attn( + paddle.to_variable(tgt), + paddle.to_variable(tgt), + paddle.to_variable(tgt), + paddle.to_variable(tgt_mask), self_attn_cache) + + tgt = tgt[0].numpy() if cache else tgt.numpy() + + tgt = residual + tgt + # postprocess + tgt_norm = layer_norm(tgt, d_model, decoder_layer.norm1) + residual = tgt_norm + # cross-attn + cross_attn_cache = cache_objs[ + 1] if cache_objs is not None else None + tgt = cross_attn( + paddle.to_variable(tgt_norm), + paddle.to_variable(memory), + paddle.to_variable(memory), + paddle.to_variable(memory_mask), cross_attn_cache) + tgt = tgt[0].numpy() if cache else tgt.numpy() + + # postprocess + tgt = tgt + residual + tgt_norm = layer_norm(tgt, d_model, decoder_layer.norm2) + residual = tgt_norm + # FFN + ffn_output = ffn(tgt_norm, decoder_layer, activation) + # post process + tgt = residual + ffn_output + tgt_norm = layer_norm(tgt, d_model, decoder_layer.norm3) + + np.testing.assert_allclose( + decoder_output, tgt_norm, rtol=1e-5, atol=1e-6) + + def test_encoder(self): + batch_size, d_model, n_head, dim_feedforward, dropout, attn_dropout, act_dropout, sequence_length = generate_basic_params( + mode="encoder_layer") + + src = np.random.rand(batch_size, sequence_length, + d_model).astype("float32") + + src_mask = np.zeros((batch_size, n_head, sequence_length, + sequence_length)).astype("float32") + src_mask[0][0][0][0] = -np.inf + with fluid.dygraph.guard(fluid.CPUPlace()): + encoder_layer = TransformerEncoderLayer(d_model, n_head, + dim_feedforward, dropout) + num_layers = 6 + encoder = TransformerEncoder(encoder_layer, num_layers) + # src, src_mask + enc_output = encoder( + paddle.to_variable(src), paddle.to_variable(src_mask)) + + def test_decoder(self): + batch_size, d_model, n_head, dim_feedforward, dropout, _, _, source_length, target_length = generate_basic_params( + mode="decoder_layer") + tgt = np.random.rand(batch_size, target_length, + d_model).astype("float32") + memory = np.random.rand(batch_size, source_length, + d_model).astype("float32") + tgt_mask = np.zeros((batch_size, n_head, target_length, + target_length)).astype("float32") + tgt_mask[0][0][0][0] = -1e9 + memory_mask = np.zeros((batch_size, n_head, target_length, + source_length)).astype("float32") + memory_mask[0][0][0][0] = -1e9 + with fluid.dygraph.guard(fluid.CPUPlace()): + decoder_layer = TransformerDecoderLayer(d_model, n_head, + dim_feedforward, dropout) + num_layers = 6 + decoder = TransformerDecoder(decoder_layer, num_layers) + + output = decoder( + paddle.to_variable(tgt), + paddle.to_variable(memory), + paddle.to_variable(tgt_mask), paddle.to_variable(memory_mask)) + + def test_transformer(self): + batch_size, d_model, n_head, dim_feedforward, dropout, _, _, source_length, target_length = generate_basic_params( + mode="decoder_layer") + + # batch_size, source_length, target_length, d_model, n_head = 4, 8, 8, 64, 8 + with fluid.dygraph.guard(fluid.CPUPlace()): + transformer = Transformer( + d_model, + n_head, + dim_feedforward=dim_feedforward, + dropout=dropout) + src = paddle.to_variable( + np.random.rand(batch_size, source_length, d_model).astype( + "float32")) + tgt = paddle.to_variable( + np.random.rand(batch_size, target_length, d_model).astype( + "float32")) + src_mask = np.zeros((batch_size, n_head, source_length, + source_length)).astype("float32") + src_mask[0][0][0][0] = -np.inf + src_mask = paddle.to_variable(src_mask) + tgt_mask = np.zeros((batch_size, n_head, target_length, + target_length)).astype("float32") + tgt_mask[0][0][0][0] = -1e9 + memory_mask = np.zeros((batch_size, n_head, target_length, + source_length)).astype("float32") + memory_mask[0][0][0][0] = -1e9 + tgt_mask, memory_mask = paddle.to_variable( + tgt_mask), paddle.to_variable(memory_mask) + trans_output = transformer(src, tgt, src_mask, tgt_mask, + memory_mask) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index b2fe248c26..290622450a 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -130,6 +130,12 @@ from .layer.norm import InstanceNorm #DEFINE_ALIAS # from .layer.rnn import RNNCell #DEFINE_ALIAS # from .layer.rnn import GRUCell #DEFINE_ALIAS # from .layer.rnn import LSTMCell #DEFINE_ALIAS +from .layer.transformer import MultiHeadAttention +from .layer.transformer import TransformerEncoderLayer +from .layer.transformer import TransformerEncoder +from .layer.transformer import TransformerDecoderLayer +from .layer.transformer import TransformerDecoder +from .layer.transformer import Transformer from .layer.distance import PairwiseDistance #DEFINE_ALIAS from .layer import loss #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 2fa248450b..de52744e65 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -21,6 +21,7 @@ from . import extension from . import activation from . import norm from . import distance +from . import transformer from .activation import * from .loss import * @@ -28,6 +29,7 @@ from .conv import * from .extension import * from .activation import * from .norm import * +from .transformer import * # from .activation import PReLU #DEFINE_ALIAS from .activation import ReLU #DEFINE_ALIAS from .activation import LeakyReLU #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/transformer.py b/python/paddle/nn/layer/transformer.py index 2b926b5ab3..50a8755ac9 100644 --- a/python/paddle/nn/layer/transformer.py +++ b/python/paddle/nn/layer/transformer.py @@ -13,4 +13,1107 @@ # limitations under the License. # TODO: define the classes of Transformer neural network -# __all__ = [ ] +__all__ = [ + 'MultiHeadAttention', + 'TransformerEncoderLayer', + 'TransformerEncoder', + 'TransformerDecoderLayer', + 'TransformerDecoder', + 'Transformer', +] + +import copy +import collections + +from ...fluid import layers +from ...fluid.param_attr import ParamAttr +from ...fluid.dygraph import Layer, Linear, Dropout, LayerNorm, LayerList +from .. import functional as F +from ...fluid.layers import utils +from ...fluid.layers.utils import map_structure + + +def _convert_param_attr_to_list(param_attr, n): + """ + If `param_attr` is a list or tuple, convert every element in it to a + ParamAttr instance. Otherwise, repeat `param_attr` `n` times to + construct a list, and rename every one by appending a increasing index + suffix to avoid having same names when `param_attr` contains a name. + + Parameters: + param_attr (list|tuple|ParamAttr): A list, tuple or something can be + converted to a ParamAttr instance by `ParamAttr._to_attr`. + n (int): The times to repeat to construct a list when `param_attr` + is not a list or tuple. + + Returns: + list: A list composed of each including cell's `param_attr`. + """ + if isinstance(param_attr, (list, tuple)): + assert len(param_attr) == n, ( + "length of param_attr should be %d when it is a list/tuple" % n) + param_attrs = [ParamAttr._to_attr(attr) for attr in param_attr] + else: + param_attrs = [] + attr = ParamAttr._to_attr(param_attr) + for i in range(n): + attr_i = copy.deepcopy(attr) + if attr.name: + attr_i.name = attr_i.name + "_" + str(i) + param_attrs.append(attr_i) + return param_attrs + + +class MultiHeadAttention(Layer): + """ + Attention mapps queries and a set of key-value pairs to outputs, and + Multi-Head Attention performs multiple parallel attention to jointly attending + to information from different representation subspaces. + + Please refer to `Attention Is All You Need `_ + for more details. + + Parameters: + embed_dim (int): The expected feature size in the input and output. + num_heads (int): The number of heads in multi-head attention. + dropout (float, optional): The dropout probability used on attention + weights to drop some attention targets. 0 for no dropout. Default 0 + kdim (int, optional): The feature size in key. If None, assumed equal to + `embed_dim`. Default None. + vdim (int, optional): The feature size in value. If None, assumed equal to + `embed_dim`. Default None. + need_weights (bool, optional): Indicate whether to return the attention + weights. Default False. + weight_attr(ParamAttr, optional): To specify the weight parameter property. + Default: None, which means the default weight parameter property is used. + See usage for details in :code:`ParamAttr` . + bias_attr (ParamAttr, optional): To specify the bias parameter property. + Default: None, which means the default bias parameter property is used. + If it is set to False, this layer will not have trainable bias parameter. + See usage for details in :code:`ParamAttr` . + + Examples: + + .. code-block:: python + + import paddle + + # encoder input: [batch_size, sequence_length, d_model] + query = paddle.rand((2, 4, 128)) + # self attention mask: [batch_size, num_heads, query_len, query_len] + attn_mask = paddle.rand((2, 2, 4, 4)) + multi_head_attn = paddle.MultiHeadAttention(128, 2) + output = multi_head_attn(query, attn_mask=attn_mask) # [2, 4, 128] + """ + + Cache = collections.namedtuple("Cache", ["k", "v"]) + StaticCache = collections.namedtuple("StaticCache", ["k", "v"]) + + def __init__(self, + embed_dim, + num_heads, + dropout=0., + kdim=None, + vdim=None, + need_weights=False, + weight_attr=None, + bias_attr=None): + super(MultiHeadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.need_weights = need_weights + + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + self.q_proj = Linear( + embed_dim, embed_dim, weight_attr, bias_attr=bias_attr) + self.k_proj = Linear( + self.kdim, embed_dim, weight_attr, bias_attr=bias_attr) + self.v_proj = Linear( + self.vdim, embed_dim, weight_attr, bias_attr=bias_attr) + self.out_proj = Linear( + embed_dim, embed_dim, weight_attr, bias_attr=bias_attr) + + def _prepare_qkv(self, query, key, value, cache=None): + """ + Prapares linear projected queries, keys and values for usage of subsequnt + multiple parallel attention. If `cache` is not None, using cached results + to reduce redundant calculations. + + Parameters: + query (Tensor): The queries for multi-head attention. It is a + tensor with shape `[batch_size, query_length, embed_dim]`. The + data type should be float32 or float64. + key (Tensor): The keys for multi-head attention. It is + a tensor with shape `[batch_size, key_length, kdim]`. The + data type should be float32 or float64. If None, use `query` as + `key`. + value (Tensor): The values for multi-head attention. It + is a tensor with shape `[batch_size, value_length, vdim]`. + The data type should be float32 or float64. If None, use `query` as + `value`. + cache (MultiHeadAttention.Cache|MultiHeadAttention.StaticCache, optional): + It is a namedtuple with `k` and `v` as fields, and stores tensors + shaped `[batch_size, num_heads, length, embed_dim]` which are results + of linear projection, reshape and transpose calculations in + MultiHeadAttention. If is an instance of `Cache`, `k` and `v` + fields reserve intermediate results of previous positions, which + mostly used for decoder self attention. If it is an instance of + `StaticCache`, `key` and `value` args would be ignored, `k` and + `v` fields would be used as calculated results on `key` and + `value`, which mostly used for decoder-encoder cross attention. + It is only used for inference and should be None for training. + Default None. + + Returns: + tuple: A tuple including linear projected keys and values. These two \ + tensors have shapes `[batch_size, n_head, sequence_length, d_key]` \ + and `[batch_size, n_head, sequence_length, d_value]` separately, \ + and their data types are same as inputs. + """ + q = self.q_proj(query) + q = layers.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q = layers.transpose(x=q, perm=[0, 2, 1, 3]) + + if isinstance(cache, self.StaticCache): + # for encoder-decoder attention in inference and has cached + k, v = cache.k, cache.v + else: + k, v = self.compute_kv(key, value) + + if isinstance(cache, self.Cache): + # for decoder self-attention in inference + k = layers.concat([cache.k, k], axis=2) + v = layers.concat([cache.v, v], axis=2) + cache = self.Cache(k, v) + + return (q, k, v) if cache is None else (q, k, v, cache) + + def compute_kv(self, key, value): + """ + Applies linear projection on input keys and values, then splits heads + (reshape and transpose) to get keys and values from different representation + subspaces. The results are used as key-values pairs for subsequent multiple + parallel attention. + + It is part of calculations in multi-head attention, and is provided as + a method to pre-compute and prefetch these results, thus we can use them + to construct cache for inference. + + Parameters: + key (Tensor): The keys for multi-head attention. It is a tensor + with shape `[batch_size, sequence_length, kdim]`. The data type + should be float32 or float64. + value (Tensor): The values for multi-head attention. It is a tensor + with shape `[batch_size, sequence_length, vdim]`. The data type + should be float32 or float64. + + Returns: + tuple: A tuple including transformed keys and values. Their shapes \ + both are `[batch_size, num_heads, sequence_length, embed_dim // num_heads]`, \ + and their data types are same as inputs. + """ + k = self.k_proj(key) + v = self.v_proj(value) + k = layers.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) + k = layers.transpose(x=k, perm=[0, 2, 1, 3]) + v = layers.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) + v = layers.transpose(x=v, perm=[0, 2, 1, 3]) + return k, v + + def gen_cache(self, key, value=None, type=Cache): + """ + Generates cache for `forward` usage in inference accroding to arguments. + The generated cache is an instance of `MultiHeadAttention.Cache` or an + instance of `MultiHeadAttention.StaticCache`. + + `Cache` or `StaticCache` is namedtuple with `k` and `v` as fields, + and it stores tensors shaped `[batch_size, num_heads, length, embed_dim]` + which are results of linear projection, reshape and transpose calculations + in MultiHeadAttention. + + If the generated cache is an instance of `Cache`, `k` and `v` fields + reserve intermediate result tensors of previous positions, and the tensors + are incremental among decoding steps, which mostly are used for decoder + decoder self attention. + + If the generated cache is an instance of `StaticCache`, `k` and `v` fields + would be used as calculated result tensors on keys an values in `forward`, + and the tensors keep unchanged among decoding steps, which are mostly used + for decoder-encoder cross attention. + + The cache is generated as follows: + + 1. If `type` is `StaticCache`, apply `compute_kv(key, value)` and use the + results to create an instance of `StaticCache`. + + 2. If `type` is `Cache` and `value` is None, generate empty tensors shaped + `[batch_size, num_heads, 0, embed_dim // num_heads]` and use the results + to create an instance of `Cache`, where `batch_size` is from the first + dimension of `key`. + + 3. If `type` is `Cache` and `value` is not None, use `key`, `value` to create + an instance of `Cache`. + + Parameters: + key (Tensor): The keys for multi-head attention. It is + a tensor with shape `[batch_size, key_length, kdim]`. The + data type should be float32 or float64. If `value` is None, + it is only for batch size and data type reference. + value (Tensor, optional): The values for multi-head attention. It + is a tensor with shape `[batch_size, value_length, vdim]`. + The data type should be float32 or float64. If None, `key` is only + for batch size reference. Default None. + type (type): It should be `MultiHeadAttention.StaticCache` or + `MultiHeadAttention.Cache` to indicate the cache type to generate. + + Returns: + namedtuple: an instance of `Cache` or `StaticCache` accordingly. + """ + if type == MultiHeadAttention.StaticCache: # static_kv + k, v = self.compute_kv(key, value) + return self.StaticCache(k, v) + elif value is None: # incremental_state + k = layers.fill_constant_batch_size_like( + input=key, + shape=[-1, self.num_heads, 0, self.head_dim], + dtype=key.dtype, + value=0) + v = layers.fill_constant_batch_size_like( + input=key, + shape=[-1, self.num_heads, 0, self.head_dim], + dtype=key.dtype, + value=0) + return self.Cache(k, v) + else: + # incremental_state with initial value, mainly for usage like UniLM + return self.Cache(key, value) + + def forward(self, query, key, value, attn_mask=None, cache=None): + """ + Applies multi-head attention to map queries and a set of key-value pairs + to outputs. + + Parameters: + query (Tensor): The queries for multi-head attention. It is a + tensor with shape `[batch_size, query_length, embed_dim]`. The + data type should be float32 or float64. + key (Tensor, optional): The keys for multi-head attention. It is + a tensor with shape `[batch_size, key_length, kdim]`. The + data type should be float32 or float64. If None, use `query` as + `key`. Default None. + value (Tensor, optional): The values for multi-head attention. It + is a tensor with shape `[batch_size, value_length, vdim]`. + The data type should be float32 or float64. If None, use `query` as + `value`. Default None. + attn_mask (Tensor, optional): A tensor used in multi-head attention + to prevents attention to some unwanted positions, usually the + paddings or the subsequent positions. It is a tensor with shape + broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`, + where the unwanted positions have `-INF` values and the others + have 0 values. The data type should be float32 or float64. It can + be None when nothing wanted or needed to be prevented attention to. + Default None + cache (MultiHeadAttention.Cache|MultiHeadAttention.StaticCache, optional): + It is a namedtuple with `k` and `v` as fields, and stores tensors + shaped `[batch_size, num_heads, length, embed_dim]` which are results + of linear projection, reshape and transpose calculations in + MultiHeadAttention. If it is an instance of `Cache`, `k` and `v` + fields reserve intermediate results of previous positions, which + mostly used for decoder self attention. If it is an instance of + `StaticCache`, `key` and `value` args would be ignored, `k` and + `v` fields would be used as calculated results on `key` and + `value`, which mostly used for decoder-encoder cross attention. + It is only used for inference and should be None for training. + Default None. + + Returns: + Tensor|tuple: It is a tensor that has the same shape and data type \ + as `query`, representing attention output. Or a tuple if \ + `need_weights` is True or `cache` is not None. If `need_weights` \ + is True, except for attention output, the tuple also includes \ + the attention weights tensor shaped `[batch_size, num_heads, query_length, key_length]`. \ + If `cache` is not None, the tuple then includes the new cache \ + having the same type as `cache`, and if it is `StaticCache`, it \ + is same as the input `cache`, if it is `Cache`, the new cache \ + reserves tensors concatanating raw tensors with intermediate \ + results of current query. + """ + key = query if key is None else key + value = query if value is None else value + # compute q ,k ,v + if cache is None: + q, k, v = self._prepare_qkv(query, key, value, cache) + else: + q, k, v, cache = self._prepare_qkv(query, key, value, cache) + + # scale dot product attention + product = layers.matmul( + x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5) + if attn_mask is not None: + # TODO(guosheng): support bool mask + product = product + attn_mask + weights = layers.softmax(product) + if self.dropout: + weights = layers.dropout( + weights, + dropout_prob=self.dropout, + dropout_implementation="upscale_in_train", + is_test=False) + + out = layers.matmul(weights, v) + + # combine heads + out = layers.transpose(out, perm=[0, 2, 1, 3]) + out = layers.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + + # project to output + out = self.out_proj(out) + + outs = [out] + if self.need_weights: + outs.append(weights) + if cache is not None: + outs.append(cache) + return out if len(outs) == 1 else tuple(outs) + + +class TransformerEncoderLayer(Layer): + """ + TransformerEncoderLayer is composed of two sub-layers which are self (multi-head) + attention and feedforward network. Before and after each sub-layer, pre-process + and post-precess would be applied on the input and output accordingly. If + `normalize_before` is True, pre-process is layer normalization and post-precess + includes dropout, residual connection. Otherwise, no pre-process and post-precess + includes dropout, residual connection, layer normalization. + + Parameters: + d_model (int): The expected feature size in the input and output. + nhead (int): The number of heads in multi-head attention(MHA). + dim_feedforward (int): The hidden layer size in the feedforward network(FFN). + dropout (float, optional): The dropout probability used in pre-process + and post-precess of MHA and FFN sub-layer. Default 0.1 + activation (str, optional): The activation function in the feedforward + network. Default relu. + attn_dropout (float, optional): The dropout probability used + in MHA to drop some attention target. If None, use the value of + `dropout`. Default None + act_dropout (float, optional): The dropout probability used after FFN + activition. If None, use the value of `dropout`. Default None + normalize_before (bool, optional): Indicate whether to put layer normalization + into preprocessing of MHA and FFN sub-layers. If True, pre-process is layer + normalization and post-precess includes dropout, residual connection. + Otherwise, no pre-process and post-precess includes dropout, residual + connection, layer normalization. Default False + weight_attr(ParamAttr|tuple, optional): To specify the weight parameter property. + If it is a tuple, `weight_attr[0]` would be used as `weight_attr` for + MHA, and `weight_attr[1]` would be used as `weight_attr` for linear in FFN. + Otherwise, MHA and FFN both use it as `weight_attr` to create parameters. + Default: None, which means the default weight parameter property is used. + See usage for details in :code:`ParamAttr` . + bias_attr (ParamAttr|tuple, optional): To specify the bias parameter property. + If it is a tuple, `bias_attr[0]` would be used as `bias_attr` for + MHA, and `bias_attr[1]` would be used as `bias_attr` for linear in FFN. + Otherwise, MHA and FFN both use it as `bias_attr` to create parameters. + The `False` value means the corresponding layer would not have trainable + bias parameter. See usage for details in :code:`ParamAttr` . Default: None, + which means the default bias parameter property is used. + + + Examples: + + .. code-block:: python + + import paddle + from paddle import TransformerEncoderLayer + + # encoder input: [batch_size, src_len, d_model] + enc_input = paddle.rand((2, 4, 128)) + # self attention mask: [batch_size, n_head, src_len, src_len] + attn_mask = paddle.rand((2, 2, 4, 4)) + encoder_layer = TransformerEncoderLayer(128, 2, 512) + enc_output = encoder_layer(enc_input, attn_mask) # [2, 4, 128] + """ + + def __init__(self, + d_model, + nhead, + dim_feedforward, + dropout=0.1, + activation="relu", + attn_dropout=None, + act_dropout=None, + normalize_before=False, + weight_attr=None, + bias_attr=None): + self._config = locals() + self._config.pop("self") + self._config.pop("__class__", None) # py3 + + super(TransformerEncoderLayer, self).__init__() + attn_dropout = dropout if attn_dropout is None else attn_dropout + act_dropout = dropout if act_dropout is None else act_dropout + self.normalize_before = normalize_before + + weight_attrs = _convert_param_attr_to_list(weight_attr, 2) + bias_attrs = _convert_param_attr_to_list(bias_attr, 2) + + self.self_attn = MultiHeadAttention( + d_model, + nhead, + dropout=attn_dropout, + weight_attr=weight_attrs[0], + bias_attr=bias_attrs[0]) + self.linear1 = Linear( + d_model, dim_feedforward, weight_attrs[1], bias_attr=bias_attrs[1]) + self.dropout = Dropout( + act_dropout, dropout_implementation="upscale_in_train") + self.linear2 = Linear( + dim_feedforward, d_model, weight_attrs[1], bias_attr=bias_attrs[1]) + self.norm1 = LayerNorm(d_model) + self.norm2 = LayerNorm(d_model) + self.dropout1 = Dropout( + dropout, dropout_implementation="upscale_in_train") + self.dropout2 = Dropout( + dropout, dropout_implementation="upscale_in_train") + self.activation = getattr(layers, activation) + + def forward(self, src, src_mask=None): + """ + Applies a Transformer encoder layer on the input. + + Parameters: + src (Tensor): The input of Transformer encoder layer. It is + a tensor with shape `[batch_size, sequence_length, d_model]`. + The data type should be float32 or float64. + src_mask (Tensor, optional): A tensor used in multi-head attention + to prevents attention to some unwanted positions, usually the + paddings or the subsequent positions. It is a tensor with shape + broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`, + where the unwanted positions have `-INF` values and the others + have 0 values. The data type should be float32 or float64. It can + be None when nothing wanted or needed to be prevented attention to. + Default None + + Returns: + Tensor: The output of Transformer encoder layer. It is a tensor that \ + has the same shape and data type as `enc_input`. + """ + residual = src + if self.normalize_before: + src = self.norm1(src) + # TODO(guosheng): Add cache for encoder for the usage like UniLM + src = self.self_attn(src, src, src, src_mask) + src = residual + self.dropout1(src) + if not self.normalize_before: + src = self.norm1(src) + + residual = src + if self.normalize_before: + src = self.norm2(src) + src = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = residual + self.dropout2(src) + if not self.normalize_before: + src = self.norm2(src) + return src + + +class TransformerEncoder(Layer): + """ + TransformerEncoder is a stack of N encoder layers. + + Parameters: + encoder_layer (Layer): an instance of the `TransformerEncoderLayer`. It + would be used as the first layer, and the other layers would be created + according to the configurations of it. + num_layers (int): The number of encoder layers to be stacked. + norm (LayerNorm, optional): the layer normalization component. If provided, + apply layer normalization on the output of last encoder layer. + + Examples: + + .. code-block:: python + + import paddle + from paddle import TransformerEncoderLayer, TransformerEncoder + + # encoder input: [batch_size, src_len, d_model] + enc_input = paddle.rand((2, 4, 128)) + # self attention mask: [batch_size, n_head, src_len, src_len] + attn_mask = paddle.rand((2, 2, 4, 4)) + encoder_layer = TransformerEncoderLayer(128, 2, 512) + encoder = TransformerEncoder(encoder_layer, 2) + enc_output = encoder(enc_input, attn_mask) # [2, 4, 128] + """ + + def __init__(self, encoder_layer, num_layers, norm=None): + super(TransformerEncoder, self).__init__() + self.layers = LayerList([(encoder_layer if i == 0 else + type(encoder_layer)(**encoder_layer._config)) + for i in range(num_layers)]) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src, src_mask=None): + """ + Applies a stack of N Transformer encoder layers on inputs. If `norm` is + provided, also applies layer normalization on the output of last encoder + layer. + + Parameters: + src (Tensor): The input of Transformer encoder. It is a tensor + with shape `[batch_size, sequence_length, d_model]`. The data + type should be float32 or float64. + src_mask (Tensor, optional): A tensor used in multi-head attention + to prevents attention to some unwanted positions, usually the + paddings or the subsequent positions. It is a tensor with shape + broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`, + where the unwanted positions have `-INF` values and the others + have 0 values. The data type should be float32 or float64. It can + be None when nothing wanted or needed to be prevented attention to. + Default None + + Returns: + Tensor: The output of Transformer encoder. It is a tensor that \ + has the same shape and data type as `src`. + """ + output = src + + for mod in self.layers: + output = mod(output, src_mask=src_mask) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoderLayer(Layer): + """ + TransformerDecoderLayer is composed of three sub-layers which are decoder + self (multi-head) attention, decoder-encoder cross attention and feedforward + network. Before and after each sub-layer, pre-process and post-precess would + be applied on the input and output accordingly. If `normalize_before` is True, + pre-process is layer normalization and post-precess includes dropout, residual + connection. Otherwise, no pre-process and post-precess includes dropout, residual + connection, layer normalization. + + Parameters: + d_model (int): The expected feature size in the input and output. + nhead (int): The number of heads in multi-head attention(MHA). + dim_feedforward (int): The hidden layer size in the feedforward network(FFN). + dropout (float, optional): The dropout probability used in pre-process + and post-precess of MHA and FFN sub-layer. Default 0.1 + activation (str, optional): The activation function in the feedforward + network. Default relu. + attn_dropout (float, optional): The dropout probability used + in MHA to drop some attention target. If None, use the value of + `dropout`. Default None + act_dropout (float, optional): The dropout probability used after FFN + activition. If None, use the value of `dropout`. Default None + normalize_before (bool, optional): Indicate whether to put layer normalization + into preprocessing of MHA and FFN sub-layers. If True, pre-process is layer + normalization and post-precess includes dropout, residual connection. + Otherwise, no pre-process and post-precess includes dropout, residual + connection, layer normalization. Default False + weight_attr(ParamAttr|tuple, optional): To specify the weight parameter property. + If it is a tuple, `weight_attr[0]` would be used as `weight_attr` for + self attention, `weight_attr[1]` would be used as `weight_attr` for + cross attention, and `weight_attr[2]` would be used as `weight_attr` + for linear in FFN. Otherwise, the three sub-layers all uses it as + `weight_attr` to create parameters. Default: None, which means the + default weight parameter property is used. See usage for details + in :ref:`api_fluid_ParamAttr` . + bias_attr (ParamAttr|tuple, optional): To specify the bias parameter property. + If it is a tuple, `bias_attr[0]` would be used as `bias_attr` for + self attention, `bias_attr[1]` would be used as `bias_attr` for + cross attention, and `bias_attr[2]` would be used as `bias_attr` + for linear in FFN. Otherwise, the three sub-layers all uses it as + `bias_attr` to create parameters. The `False` value means the + corresponding layer would not have trainable bias parameter. See + usage for details in :code:`ParamAttr` . Default: None,which means + the default bias parameter property is used. + + Examples: + + .. code-block:: python + + import paddle + from paddle import TransformerDecoderLayer + + # decoder input: [batch_size, tgt_len, d_model] + dec_input = paddle.rand((2, 4, 128)) + # encoder output: [batch_size, src_len, d_model] + enc_output = paddle.rand((2, 6, 128)) + # self attention mask: [batch_size, n_head, tgt_len, tgt_len] + self_attn_mask = paddle.rand((2, 2, 4, 4)) + # cross attention mask: [batch_size, n_head, tgt_len, src_len] + cross_attn_mask = paddle.rand((2, 2, 4, 6)) + decoder_layer = TransformerDecoderLayer(128, 2, 512) + output = decoder_layer(dec_input, + enc_output, + self_attn_mask, + cross_attn_mask) # [2, 4, 128] + """ + + def __init__(self, + d_model, + nhead, + dim_feedforward, + dropout=0.1, + activation="relu", + attn_dropout=None, + act_dropout=None, + normalize_before=False, + weight_attr=None, + bias_attr=None): + self._config = locals() + self._config.pop("self") + self._config.pop("__class__", None) # py3 + + super(TransformerDecoderLayer, self).__init__() + attn_dropout = dropout if attn_dropout is None else attn_dropout + act_dropout = dropout if act_dropout is None else act_dropout + self.normalize_before = normalize_before + + weight_attrs = _convert_param_attr_to_list(weight_attr, 3) + bias_attrs = _convert_param_attr_to_list(bias_attr, 3) + + self.self_attn = MultiHeadAttention( + d_model, + nhead, + dropout=attn_dropout, + weight_attr=weight_attrs[0], + bias_attr=bias_attrs[0]) + self.cross_attn = MultiHeadAttention( + d_model, + nhead, + dropout=attn_dropout, + weight_attr=weight_attrs[1], + bias_attr=bias_attrs[1]) + self.linear1 = Linear( + d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2]) + self.dropout = Dropout( + act_dropout, dropout_implementation="upscale_in_train") + self.linear2 = Linear( + dim_feedforward, d_model, weight_attrs[2], bias_attr=bias_attrs[2]) + self.norm1 = LayerNorm(d_model) + self.norm2 = LayerNorm(d_model) + self.norm3 = LayerNorm(d_model) + self.dropout1 = Dropout( + dropout, dropout_implementation="upscale_in_train") + self.dropout2 = Dropout( + dropout, dropout_implementation="upscale_in_train") + self.dropout3 = Dropout( + dropout, dropout_implementation="upscale_in_train") + self.activation = getattr(layers, activation) + + def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None): + """ + Applies a Transformer decoder layer on the input. + + Parameters: + tgt (Tensor): The input of Transformer decoder layer. It is a tensor + with shape `[batch_size, target_length, d_model]`. The data type + should be float32 or float64. + memory (Tensor): The output of Transformer encoder. It is a tensor + with shape `[batch_size, source_length, d_model]`. The data type + should be float32 or float64. + tgt_mask (Tensor, optional): A tensor used in self attention + to prevents attention to some unwanted positions, usually the + the subsequent positions. It is a tensor with shape broadcasted + to `[batch_size, n_head, target_length, target_length]`, + where the unwanted positions have `-INF` values and the others + have 0 values. The data type should be float32 or float64. It can + be None when nothing wanted or needed to be prevented attention to. + Default None + memory_mask (Tensor, optional): A tensor used in decoder-encoder + cross attention to prevents attention to some unwanted positions, + usually the paddings. It is a tensor with shape broadcasted to + `[batch_size, n_head, target_length, source_length]`, where the + unwanted positions have `-INF` values and the others have 0 values. + The data type should be float32 or float64. It can be None when + nothing wanted or needed to be prevented attention to. Default None + cache (tuple, optional): It is a tuple( :code:`(incremental_cache, static_cache)` ), + `incremental_cache` is an instance of `MultiHeadAttention.Cache`, + `static_cache` is an instance of `MultiHeadAttention.StaticCache. + See `TransformerDecoderLayer.gen_cache` for more details. It is + only used for inference and should be None for training. Default + None. + + Returns: + Tensor|tuple: It is a tensor that has the same shape and data type \ + as `tgt`, representing the output of Transformer decoder layer. \ + Or a tuple if `cache` is not None, except for decoder layer output, \ + the tuple includes the new cache which is same as input `cache` \ + argument but `incremental_cache` in it has an incremental length. \ + See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \ + for more details. + """ + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + if cache is None: + tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, None) + else: + tgt, incremental_cache = self.self_attn(tgt, tgt, tgt, tgt_mask, + cache[0]) + tgt = residual + self.dropout1(tgt) + if not self.normalize_before: + tgt = self.norm1(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm2(tgt) + if cache is None: + tgt = self.cross_attn(tgt, memory, memory, memory_mask, None) + else: + tgt, static_cache = self.cross_attn(tgt, memory, memory, + memory_mask, cache[1]) + tgt = residual + self.dropout2(tgt) + if not self.normalize_before: + tgt = self.norm2(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm3(tgt) + tgt = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = residual + self.dropout3(tgt) + if not self.normalize_before: + tgt = self.norm3(tgt) + return tgt if cache is None else (tgt, (incremental_cache, + static_cache)) + + def gen_cache(self, memory): + """ + Generates cache for `forward` usage. The generated cache is a tuple + composed of an instance of `MultiHeadAttention.Cache` and an instance + of `MultiHeadAttention.StaticCache`. + + Parameters: + memory (Tensor): The output of Transformer encoder. It is a tensor + with shape `[batch_size, source_length, d_model]`. The data type + should be float32 or float64. + + Returns: + tuple: It is a tuple( :code:`(incremental_cache, static_cache)` ). \ + `incremental_cache` is an instance of `MultiHeadAttention.Cache` \ + produced by `self_attn.gen_cache(memory, MultiHeadAttention.Cache)`, \ + it reserves two tensors shaped `[batch_size, nhead, 0, d_model // nhead]`. \ + `static_cache` is an instance of `MultiHeadAttention.StaticCache` \ + produced by `cross_attn.gen_cache(memory, MultiHeadAttention.StaticCache)`, \ + it reserves two tensors shaped `[batch_size, nhead, source_length, d_model // nhead]`. + See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \ + for more details. + """ + incremental_cache = self.self_attn.gen_cache( + memory, type=self.self_attn.Cache) + static_cache = self.cross_attn.gen_cache( + memory, memory, type=self.cross_attn.StaticCache) + return incremental_cache, static_cache + + +class TransformerDecoder(Layer): + """ + TransformerDecoder is a stack of N decoder layers. + + Parameters: + decoder_layer (Layer): an instance of the `TransformerDecoderLayer`. It + would be used as the first layer, and the other layers would be created + according to the configurations of it. + num_layers (int): The number of decoder layers to be stacked. + norm (LayerNorm, optional): the layer normalization component. If provided, + apply layer normalization on the output of last encoder layer. + + Examples: + + .. code-block:: python + + import paddle + from paddle import TransformerDecoderLayer, TransformerDecoder + + # decoder input: [batch_size, tgt_len, d_model] + dec_input = paddle.rand((2, 4, 128)) + # encoder output: [batch_size, src_len, d_model] + enc_output = paddle.rand((2, 6, 128)) + # self attention mask: [batch_size, n_head, tgt_len, tgt_len] + self_attn_mask = paddle.rand((2, 2, 4, 4)) + # cross attention mask: [batch_size, n_head, tgt_len, src_len] + cross_attn_mask = paddle.rand((2, 2, 4, 6)) + decoder_layer = TransformerDecoderLayer(128, 2, 512) + decoder = TransformerDecoder(decoder_layer, 2) + output = decoder(dec_input, + enc_output, + self_attn_mask, + cross_attn_mask) # [2, 4, 128] + """ + + def __init__(self, decoder_layer, num_layers, norm=None): + super(TransformerDecoder, self).__init__() + self.layers = LayerList([(decoder_layer if i == 0 else + type(decoder_layer)(**decoder_layer._config)) + for i in range(num_layers)]) + self.num_layers = num_layers + self.norm = norm + + def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None): + """ + Applies a stack of N Transformer decoder layers on inputs. If `norm` is + provided, also applies layer normalization on the output of last decoder + layer. + + Parameters: + tgt (Tensor): The input of Transformer decoder. It is a tensor + with shape `[batch_size, target_length, d_model]`. The data type + should be float32 or float64. + memory (Tensor): The output of Transformer encoder. It is a tensor + with shape `[batch_size, source_length, d_model]`. The data type + should be float32 or float64. + tgt_mask (Tensor, optional): A tensor used in self attention + to prevents attention to some unwanted positions, usually the + the subsequent positions. It is a tensor with shape broadcasted + to `[batch_size, n_head, target_length, target_length]`, + where the unwanted positions have `-INF` values and the others + have 0 values. The data type should be float32 or float64. It can + be None when nothing wanted or needed to be prevented attention to. + Default None + memory_mask (Tensor, optional): A tensor used in decoder-encoder + cross attention to prevents attention to some unwanted positions, + usually the paddings. It is a tensor with shape broadcasted to + `[batch_size, n_head, target_length, source_length]`, where the + unwanted positions have `-INF` values and the others have 0 values. + The data type should be float32 or float64. It can be None when + nothing wanted or needed to be prevented attention to. Default None + cache (list, optional): It is a list, and each element in the list + is a tuple( :code:`(incremental_cache, static_cache)` ). See + `TransformerDecoder.gen_cache` for more details. It is only + used for inference and should be None for training. Default None. + + Returns: + Tensor|tuple: It is a tensor that has the same shape and data type \ + as `tgt`, representing the output of Transformer decoder. \ + Or a tuple if `cache` is not None, except for decoder output, \ + the tuple includes the new cache which is same as input `cache` \ + argument but `incremental_cache` in it has an incremental length. \ + See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \ + for more details. + """ + output = tgt + new_caches = [] + for i, mod in enumerate(self.layers): + if cache is None: + output = mod(output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + cache=None) + else: + output, new_cache = mod(output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + cache=cache[i]) + new_caches.append(new_cache) + + if self.norm is not None: + output = self.norm(output) + + return output if cache is None else (output, new_caches) + + def gen_cache(self, memory, do_zip=False): + """ + Generates cache for `forward` usage. The generated cache is a list, and + each element in it is a tuple( :code:`(incremental_cache, static_cache)` ) + produced by `TransformerDecoderLayer.gen_cache`. See `TransformerDecoderLayer.gen_cache` + for more details. If `do_zip` is True, apply `zip` on these tuples to get + a list with two elements. + + + Parameters: + memory (Tensor): The output of Transformer encoder. It is a tensor + with shape `[batch_size, source_length, d_model]`. The data type + should be float32 or float64. + do_zip (bool, optional): Indicate whether to apply `zip` on the tuples. + If True, return a list with two elements. Default False + + Returns: + list: It is a list, and each element in the list is a tuple produced \ + by `TransformerDecoderLayer.gen_cache(memory)`. See `TransformerDecoderLayer.gen_cache` \ + for more details. If `do_zip` is True, apply `zip` on these tuples \ + and return a list with two elements. + """ + cache = [layer.gen_cache(memory) for layer in self.layers] + if do_zip: + cache = list(zip(*cache)) + return cache + + +class Transformer(Layer): + """ + A Transformer model composed of an instance of `TransformerEncoder` and an + instance of `TransformerDecoder`. While the embedding layer and output layer + are not included. + + Please refer to `Attention is all you need `_ , + and see `TransformerEncoder` and `TransformerDecoder` for more details. + + Users can configurate the model architecture with corresponding parameters. + Note the usage of `normalize_before` representing where to apply layer + normalization (in pre-process or post-precess of multi-head attention or FFN), + and some transformer like models are different on this, such as + `BERT `_ and `GPT2 `_ . + The default architecture here places layer normalization in post-process and + applies another layer normalization on the output of last encoder/decoder layer. + + Parameters: + d_model (int): The expected feature size in the encoder/decoder input + and output. + nhead (int): The number of heads in multi-head attention(MHA). + num_encoder_layers (int): The number of layers in encoder. + num_encoder_layers (int): The number of layers in decoder. + dim_feedforward (int): The hidden layer size in the feedforward network(FFN). + dropout (float, optional): The dropout probability used in pre-process + and post-precess of MHA and FFN sub-layer. Default 0.1 + activation (str, optional): The activation function in the feedforward + network. Default relu. + attn_dropout (float, optional): The dropout probability used + in MHA to drop some attention target. If None, use the value of + `dropout`. Default None + act_dropout (float, optional): The dropout probability used after FFN + activition. If None, use the value of `dropout`. Default None + normalize_before (bool, optional): Indicate whether to put layer normalization + into preprocessing of MHA and FFN sub-layers. If True, pre-process is layer + normalization and post-precess includes dropout, residual connection. + Otherwise, no pre-process and post-precess includes dropout, residual + connection, layer normalization. Default False + weight_attr(ParamAttr|tuple, optional): To specify the weight parameter property. + If it is a tuple, `weight_attr[0]` would be used as `weight_attr` for + self attention, `weight_attr[1]` would be used as `weight_attr` for + cross attention, and `weight_attr[2]` would be used as `weight_attr` + for linear in FFN. Otherwise, the three sub-layers all uses it as + `weight_attr` to create parameters. Default: None, which means the + default weight parameter property is used. See usage for details + in :code:`ParamAttr` . + bias_attr (ParamAttr|tuple, optional): To specify the bias parameter property. + If it is a tuple, `bias_attr[0]` would be used as `bias_attr` for + self attention, `bias_attr[1]` would be used as `bias_attr` for + cross attention, and `bias_attr[2]` would be used as `bias_attr` + for linear in FFN. Otherwise, the three sub-layers all uses it as + `bias_attr` to create parameters. The `False` value means the + corresponding layer would not have trainable bias parameter. See + usage for details in :code:`ParamAttr` . Default: None,which means + the default bias parameter property is used. + custom_encoder (Layer): If custom encoder is provided, use it as the encoder. + Default None + custom_decoder (Layer): If custom decoder is provided, use it as the decoder. + Default None + + Examples: + + .. code-block:: python + + import paddle + from paddle import Transformer + + # src: [batch_size, tgt_len, d_model] + enc_input = paddle.rand((2, 4, 128)) + # tgt: [batch_size, src_len, d_model] + dec_input = paddle.rand((2, 6, 128)) + # src_mask: [batch_size, n_head, src_len, src_len] + enc_self_attn_mask = paddle.rand((2, 2, 4, 4)) + # tgt_mask: [batch_size, n_head, tgt_len, tgt_len] + dec_self_attn_mask = paddle.rand((2, 2, 6, 6)) + # memory_mask: [batch_size, n_head, tgt_len, src_len] + cross_attn_mask = paddle.rand((2, 2, 6, 4)) + transformer = Transformer(128, 2, 4, 4, 512) + output = transformer(enc_input, + dec_input, + enc_self_attn_mask, + dec_self_attn_mask, + cross_attn_mask) # [2, 6, 128] + """ + + def __init__(self, + d_model=512, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + attn_dropout=None, + act_dropout=None, + normalize_before=False, + weight_attr=None, + bias_attr=None, + custom_encoder=None, + custom_decoder=None): + super(Transformer, self).__init__() + + if custom_encoder is not None: + self.encoder = custom_encoder + else: + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, + attn_dropout, act_dropout, normalize_before, weight_attr, + bias_attr) + encoder_norm = LayerNorm(d_model) + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, + encoder_norm) + + if custom_decoder is not None: + self.decoder = custom_decoder + else: + decoder_layer = TransformerDecoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, + attn_dropout, act_dropout, normalize_before, weight_attr, + bias_attr) + decoder_norm = LayerNorm(d_model) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, + decoder_norm) + + self.d_model = d_model + self.nhead = nhead + + def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None): + """ + Applies a Transformer model on the inputs. + + Parameters: + src (Tensor): The input of Transformer encoder. It is a tensor + with shape `[batch_size, source_length, d_model]`. The data type + should be float32 or float64. + tgt (Tensor): The input of Transformer decoder. It is a tensor + with shape `[batch_size, target_length, d_model]`. The data type + should be float32 or float64. + memory (Tensor): The output of Transformer encoder. It is a tensor + with shape `[batch_size, source_length, d_model]`. The data type + should be float32 or float64. + tgt_mask (Tensor, optional): A tensor used in self attention + to prevents attention to some unwanted positions, usually the + the subsequent positions. It is a tensor with shape broadcasted + to `[batch_size, n_head, target_length, target_length]`, + where the unwanted positions have `-INF` values and the others + have 0 values. The data type should be float32 or float64. It can + be None when nothing wanted or needed to be prevented attention to. + Default None + memory_mask (Tensor, optional): A tensor used in decoder-encoder + cross attention to prevents attention to some unwanted positions, + usually the paddings. It is a tensor with shape broadcasted to + `[batch_size, n_head, target_length, source_length]`, where the + unwanted positions have `-INF` values and the others have 0 values. + The data type should be float32 or float64. It can be None when + nothing wanted or needed to be prevented attention to. Default None + + Returns: + Tensor: It is a tensor that has the same shape and data type \ + as `tgt`, representing the output of Transformer decoder. + """ + memory = self.encoder(src, src_mask=src_mask) + output = self.decoder( + tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask) + return output -- GitLab