attention_model.py 9.0 KB
Newer Older
L
LiuChiaChi 已提交
1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14 15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
L
LiuChiaChi 已提交
16 17
import paddle
from paddle.nn import Layer, Linear, Dropout, Embedding, LayerList, RNN, LSTM, LSTMCell, RNNCellBase
18
import paddle.nn.initializer as I
L
LiuChiaChi 已提交
19 20 21
import paddle.nn.functional as F
SEED = 102
paddle.framework.manual_seed(SEED)
22

23

L
LiuChiaChi 已提交
24
class AttentionModel(Layer):
    """Seq2seq model: LSTM encoder plus attentional LSTM decoder.

    ``forward`` runs a training step and returns the decoder loss.

    Args:
        hidden_size: Embedding width and LSTM hidden size.
        src_vocab_size: Source vocabulary size.
        trg_vocab_size: Target vocabulary size.
        num_layers: Number of stacked LSTM layers (encoder and decoder).
        init_scale: Uniform-init half-width, forwarded to encoder and decoder.
        padding_idx: Token id treated as padding, forwarded to both halves.
        dropout: Dropout probability, forwarded to encoder and decoder.
        beam_size, beam_start_token, beam_end_token, beam_max_step_num:
            Beam-search settings. They are stored on the instance but not
            read by ``forward``.
        dtype: Floating dtype name, forwarded to encoder and decoder.
    """

    def __init__(self,
                 hidden_size,
                 src_vocab_size,
                 trg_vocab_size,
                 num_layers=1,
                 init_scale=0.1,
                 padding_idx=0,
                 dropout=None,
                 beam_size=1,
                 beam_start_token=1,
                 beam_end_token=2,
                 beam_max_step_num=100,
                 dtype="float32"):
        super(AttentionModel, self).__init__()
        # Architecture hyper-parameters.
        self.hidden_size = hidden_size
        self.src_vocab_size = src_vocab_size
        self.trg_vocab_size = trg_vocab_size
        self.num_layers = num_layers
        self.init_scale = init_scale
        self.dropout = dropout
        # Beam-search settings (kept for callers; unused by forward()).
        self.beam_size = beam_size
        self.beam_start_token = beam_start_token
        self.beam_end_token = beam_end_token
        self.beam_max_step_num = beam_max_step_num
        self.kinf = 1e9

        # Encoder and decoder share every setting except the vocabulary.
        shared_kwargs = dict(
            num_layers=num_layers,
            init_scale=init_scale,
            padding_idx=padding_idx,
            dropout=dropout,
            dtype=dtype)
        self.encoder = Encoder(src_vocab_size, hidden_size, **shared_kwargs)
        self.decoder = Decoder(trg_vocab_size, hidden_size, **shared_kwargs)

    def forward(self, inputs):
        """Encode the source, decode the target and return the loss.

        Args:
            inputs: A 5-item sequence
                ``(src, trg, label, src_seq_len, trg_seq_len)``.

        Returns:
            The loss value produced by ``self.decoder``.
        """
        src, trg, label, src_seq_len, trg_seq_len = inputs
        final_states, enc_outputs, enc_padding_mask = self.encoder(
            src, src_seq_len)
        final_h = final_states[0]
        final_c = final_states[1]
        # Regroup the stacked final (h, c) into one (h_k, c_k) pair per
        # layer, the initial-state layout the decoder cell indexes by layer.
        init_states = [(final_h[layer], final_c[layer])
                       for layer in range(self.num_layers)]
        return self.decoder(trg, trg_seq_len, init_states, enc_outputs,
                            enc_padding_mask, label)
67 68


L
LiuChiaChi 已提交
69 70 71 72 73 74 75 76 77 78 79 80 81
class Encoder(Layer):
    """Source encoder: token embedding followed by a unidirectional LSTM.

    Args:
        vocab_size: Source vocabulary size.
        hidden_size: Embedding width and LSTM hidden size.
        num_layers: Number of stacked LSTM layers.
        init_scale: Embedding weights are drawn from
            ``U(-init_scale, init_scale)``.
        padding_idx: Token id treated as padding (embedding and mask).
        dropout: LSTM inter-layer dropout probability. ``None`` is treated
            as ``0.``.
        dtype: dtype name used for the padding mask.
    """

    def __init__(self,
                 vocab_size,
                 hidden_size,
                 num_layers=1,
                 init_scale=0.1,
                 padding_idx=0,
                 dropout=None,
                 dtype="float32"):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dtype = dtype
        self.padding_idx = padding_idx
        self.embedder = Embedding(
            vocab_size,
            hidden_size,
            padding_idx=padding_idx,
            weight_attr=paddle.ParamAttr(
                name='source_embedding',
                initializer=I.Uniform(
                    low=-init_scale, high=init_scale)))
        # LSTM dropout only acts between stacked layers, so it is forced to
        # 0. for a single layer. The default ``dropout=None`` must also be
        # mapped to 0.: passing None through would make the LSTM's
        # ``dropout > 0`` comparison fail once num_layers > 1.
        self.lstm = LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            direction="forward",
            dropout=(dropout or 0.) if num_layers > 1 else 0., )

    def forward(self, src, src_sequence_length):
        """Encode a batch of source token ids.

        Args:
            src: Source token ids.
            src_sequence_length: Per-example valid lengths, passed to the
                LSTM as ``sequence_length``.

        Returns:
            A tuple of three items:
            - ``[final_h, final_c]``: the LSTM's final states.
            - ``outs``: the LSTM's per-step outputs.
            - An additive attention mask that is 0 where ``src`` is not
              padding and -1e9 where it is.
        """
        src_emb = self.embedder(src)
        outs, (final_h, final_c) = self.lstm(
            src_emb, sequence_length=src_sequence_length)

        enc_len_mask = (src != self.padding_idx).astype(self.dtype)
        # (mask - 1) * 1e9: 0 for real tokens, -1e9 for padding. Adding this
        # to attention logits drives padding weights to ~0 after softmax.
        enc_padding_mask = (enc_len_mask - 1.0) * 1e9
        return [final_h, final_c], outs, enc_padding_mask


class AttentionLayer(Layer):
    """Attention of a decoder hidden state over projected encoder outputs.

    The context vector is concatenated with the hidden state and projected
    back to ``hidden_size``.

    Args:
        hidden_size: Width of the hidden state and of the encoder outputs.
        bias: Passed as ``bias_attr`` to both projections (``False`` = no
            bias).
        init_scale: Projection weights are drawn from
            ``U(-init_scale, init_scale)``.
    """

    def __init__(self, hidden_size, bias=False, init_scale=0.1):
        super(AttentionLayer, self).__init__()
        self.input_proj = Linear(
            hidden_size,
            hidden_size,
            weight_attr=paddle.ParamAttr(initializer=I.Uniform(
                low=-init_scale, high=init_scale)),
            bias_attr=bias)
        self.output_proj = Linear(
            hidden_size + hidden_size,
            hidden_size,
            weight_attr=paddle.ParamAttr(initializer=I.Uniform(
                low=-init_scale, high=init_scale)),
            bias_attr=bias)

    def forward(self, hidden, encoder_output, encoder_padding_mask):
        """Compute the attentional output for one decoding step.

        Args:
            hidden: Decoder hidden state. It is unsqueezed at axis 1 and
                concatenated along axis 1 below; presumably shaped
                (batch, hidden_size). TODO confirm.
            encoder_output: Encoder outputs; presumably shaped
                (batch, src_len, hidden_size). TODO confirm.
            encoder_padding_mask: Additive mask (0 = attend, large negative
                = ignore), or ``None`` to attend to every position.

        Returns:
            ``output_proj(concat([context, hidden], axis=1))``.
        """
        # The projected encoder outputs serve as both keys and values.
        encoder_output = self.input_proj(encoder_output)

        attn_scores = paddle.matmul(
            paddle.unsqueeze(hidden, [1]), encoder_output, transpose_y=True)
        # Unsqueeze only when a mask is given. Doing it before the check
        # crashed on the ``None`` case (DecoderCell.forward defaults the
        # mask to None).
        if encoder_padding_mask is not None:
            encoder_padding_mask = paddle.unsqueeze(encoder_padding_mask, [1])
            attn_scores = paddle.add(attn_scores, encoder_padding_mask)
        attn_scores = F.softmax(attn_scores)

        attn_out = paddle.matmul(attn_scores, encoder_output)
        attn_out = paddle.squeeze(attn_out, [1])
        attn_out = paddle.concat([attn_out, hidden], 1)
        attn_out = self.output_proj(attn_out)
        return attn_out


class DecoderCell(RNNCellBase):
    """One decoding step of an attentional decoder.

    The step runs stacked LSTM cells with input feeding, followed by
    attention over the encoder outputs.

    Args:
        input_size: Width of the per-step decoder input (the embedding).
        hidden_size: LSTM hidden size; also the width of the input-feed
            vector concatenated to the first layer's input.
        num_layers: Number of stacked ``LSTMCell`` layers.
        init_scale: Forwarded to ``AttentionLayer`` for its projection init.
        dropout: Dropout probability applied to each layer's output before
            it feeds the next layer / attention. ``None`` or ``0`` disables
            it.
    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 num_layers,
                 init_scale=0.1,
                 dropout=0.):
        super(DecoderCell, self).__init__()
        # ``dropout`` may arrive as None (the Decoder / AttentionModel
        # default); ``None > 0.0`` raises TypeError in Python 3.
        if dropout is not None and dropout > 0.0:
            self.dropout = Dropout(dropout)
        else:
            self.dropout = None
        self.lstm_cells = []
        for i in range(num_layers):
            # Layer 0 consumes [step_input, input_feed] concatenated.
            cell_input_size = (input_size + hidden_size
                               if i == 0 else hidden_size)
            self.lstm_cells.append(
                self.add_sublayer(
                    "lstm_%d" % i,
                    LSTMCell(
                        input_size=cell_input_size,
                        hidden_size=hidden_size)))
        self.attention_layer = AttentionLayer(
            hidden_size, init_scale=init_scale)

    def forward(self,
                step_input,
                states,
                encoder_output,
                encoder_padding_mask=None):
        """Advance the decoder by one step.

        Args:
            step_input: Embedded decoder input for this step.
            states: ``[lstm_states, input_feed]``. ``lstm_states[i]`` is
                layer i's state; ``input_feed`` is the previous step's
                attentional output.
            encoder_output: Encoder outputs passed to the attention layer.
            encoder_padding_mask: Optional additive attention mask.

        Returns:
            ``(out, [new_lstm_states, out])``, where ``out`` is the
            attentional output and also becomes the next step's input feed.
        """
        lstm_states, input_feed = states
        new_lstm_states = []

        # Input feeding: append the previous attentional output.
        step_input = paddle.concat([step_input, input_feed], 1)
        for i, lstm_cell in enumerate(self.lstm_cells):
            out, (new_hidden, new_cell) = lstm_cell(step_input,
                                                    lstm_states[i])
            # Carry the un-dropped state to the next timestep. Dropout only
            # acts on the feed-forward path to the next layer / attention.
            new_lstm_states.append([new_hidden, new_cell])
            if self.dropout:
                step_input = self.dropout(out)
            else:
                step_input = out
        out = self.attention_layer(step_input, encoder_output,
                                   encoder_padding_mask)
        return out, [new_lstm_states, out]


class Decoder(Layer):
    """Attentional LSTM decoder that returns a masked cross-entropy loss.

    Args:
        vocab_size: Target vocabulary size.
        hidden_size: Embedding width and LSTM hidden size.
        num_layers: Number of stacked LSTM layers in the decoder cell.
        init_scale: Uniform-init half-width for the embedding and the output
            projection; also forwarded to ``DecoderCell``.
        padding_idx: Token id treated as padding (embedding and loss mask).
        dropout: Dropout probability forwarded to ``DecoderCell``.
        dtype: dtype name for the input-feed zeros and the loss mask.
    """

    def __init__(self,
                 vocab_size,
                 hidden_size,
                 num_layers=1,
                 init_scale=0.1,
                 padding_idx=0,
                 dropout=None,
                 dtype="float32"):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.init_scale = init_scale
        self.dtype = dtype
        self.padding_idx = padding_idx
        self.embedder = Embedding(
            vocab_size,
            hidden_size,
            padding_idx=padding_idx,
            weight_attr=paddle.ParamAttr(
                name='target_embedding',
                initializer=I.Uniform(
                    low=-init_scale, high=init_scale)))
        self.dropout = dropout
        self.lstm_attention = RNN(DecoderCell(hidden_size, hidden_size,
                                              num_layers, init_scale, dropout),
                                  is_reverse=False,
                                  time_major=False)
        self.fc = Linear(
            hidden_size,
            vocab_size,
            weight_attr=paddle.ParamAttr(initializer=I.Uniform(
                low=-init_scale, high=init_scale)),
            bias_attr=False)

    def forward(self, trg, trg_sequence_length, enc_states, enc_outputs,
                enc_padding_mask, label):
        """Teacher-forced decoding and loss computation.

        Args:
            trg: Decoder input token ids (batch-major, since the RNN is built
                with ``time_major=False``).
            trg_sequence_length: Per-example valid target lengths.
            enc_states: Per-layer initial LSTM states for the decoder cell.
            enc_outputs: Encoder outputs attended over at each step.
            enc_padding_mask: Additive attention mask from the encoder.
            label: Target ids for the cross-entropy. ``squeeze(axis=[2])``
                below implies a trailing unit dimension, presumably
                (batch, trg_len, 1). TODO confirm.

        Returns:
            The per-token loss with padding positions zeroed, averaged over
            axis 0 and then summed over the remaining axes.
        """
        trg_emb = self.embedder(trg)
        bsz = paddle.shape(trg)[0]
        # No attentional output exists before the first step, so the input
        # feed starts at zero. It is built on-device: ``bsz`` is a Tensor,
        # which paddle.zeros accepts as a shape entry. This avoids a host
        # numpy allocation that depends on Tensor-to-int conversion.
        input_feed = paddle.zeros([bsz, self.hidden_size], dtype=self.dtype)
        states = [enc_states, input_feed]
        dec_output, _ = self.lstm_attention(
            trg_emb,
            initial_states=states,
            sequence_length=trg_sequence_length,
            encoder_output=enc_outputs,
            encoder_padding_mask=enc_padding_mask)

        dec_output = self.fc(dec_output)

        loss = F.softmax_with_cross_entropy(
            logits=dec_output, label=label, soft_label=False)
        loss = paddle.squeeze(loss, axis=[2])

        # Zero the loss wherever the decoder input is padding. This assumes
        # trg and label are padded at the same positions. TODO confirm.
        trg_mask = (trg != self.padding_idx).astype(self.dtype)

        loss = loss * trg_mask
        loss = paddle.mean(loss, axis=[0])
        loss = paddle.sum(loss)
        return loss