rec_nrtr_head.py 32.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

T
Topdu 已提交
15 16 17
import math
import paddle
import copy
18
from paddle import nn
T
Topdu 已提交
19 20 21 22 23
import paddle.nn.functional as F
from paddle.nn import LayerList
from paddle.nn.initializer import XavierNormal as xavier_uniform_
from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
import numpy as np
T
Topdu 已提交
24
from ppocr.modeling.heads.multiheadAttention import MultiheadAttention
T
Topdu 已提交
25 26 27 28 29 30
from paddle.nn.initializer import Constant as constant_
from paddle.nn.initializer import XavierNormal as xavier_normal_

# Shared constant initializer instances. Calling one on a parameter fills it
# with the given value; `zeros_` is applied to Conv2D biases in
# Transformer._init_weights.
zeros_ = constant_(value=0.)
ones_ = constant_(value=1.)

31

T
Topdu 已提交
32
class Transformer(nn.Layer):
33
    """A transformer model. User is able to modify the attributes as needed. The architechture
T
Topdu 已提交
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
    is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
    Processing Systems, pages 6000-6010.

    Args:
        d_model: the number of expected features in the encoder/decoder inputs (default=512).
        nhead: the number of heads in the multiheadattention models (default=8).
        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        custom_encoder: custom encoder (default=None).
        custom_decoder: custom decoder (default=None).

    """

51 52 53 54 55 56 57 58 59 60 61 62 63 64
    def __init__(self,
                 d_model=512,
                 nhead=8,
                 num_encoder_layers=6,
                 beam_size=0,
                 num_decoder_layers=6,
                 dim_feedforward=1024,
                 attention_dropout_rate=0.0,
                 residual_dropout_rate=0.1,
                 custom_encoder=None,
                 custom_decoder=None,
                 in_channels=0,
                 out_channels=0,
                 scale_embedding=True):
        """Build the embedding, encoder/decoder stacks and output projection.

        Args:
            d_model: feature width shared by the embedding, encoder and decoder.
            nhead: number of attention heads in every attention module.
            num_encoder_layers: encoder depth; when it is not positive (and no
                custom encoder is given) ``self.encoder`` is set to None.
            beam_size: ``forward`` uses beam search at inference when this is
                greater than 0 and greedy decoding otherwise.
            num_decoder_layers: decoder depth.
            dim_feedforward: hidden width of the feed-forward sub-layers.
            attention_dropout_rate: dropout rate handed to the attention modules.
            residual_dropout_rate: dropout rate used on the residual branches
                and in the positional encoding.
            custom_encoder: ready-made encoder used instead of the built-in one.
            custom_decoder: ready-made decoder used instead of the built-in one.
            in_channels: accepted but not read by this constructor.
            out_channels: the model embeds and predicts ``out_channels + 1``
                classes.
            scale_embedding: forwarded unchanged to ``Embeddings``.
        """
        super(Transformer, self).__init__()

        vocab_size = out_channels + 1
        self.out_channels = vocab_size
        self.embedding = Embeddings(
            d_model=d_model,
            vocab=vocab_size,
            padding_idx=0,
            scale_embedding=scale_embedding)
        self.positional_encoding = PositionalEncoding(
            dropout=residual_dropout_rate, dim=d_model)

        # Positional arguments shared by every encoder and decoder layer.
        layer_args = (d_model, nhead, dim_feedforward, attention_dropout_rate,
                      residual_dropout_rate)

        if custom_encoder is not None:
            self.encoder = custom_encoder
        elif num_encoder_layers > 0:
            self.encoder = TransformerEncoder(
                TransformerEncoderLayer(*layer_args), num_encoder_layers)
        else:
            self.encoder = None

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            self.decoder = TransformerDecoder(
                TransformerDecoderLayer(*layer_args), num_decoder_layers)

        # Re-initialise everything registered so far; the projection below is
        # created afterwards and receives its own explicit initial value.
        self._reset_parameters()
        self.beam_size = beam_size
        self.d_model = d_model
        self.nhead = nhead

        self.tgt_word_prj = nn.Linear(d_model, vocab_size, bias_attr=False)
        init_std = d_model**-0.5
        init_weight = np.random.normal(
            0.0, init_std, (d_model, vocab_size)).astype(np.float32)
        self.tgt_word_prj.weight.set_value(init_weight)

        # Finally give every Conv2D sub-layer its dedicated initialisation.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise a Conv2D sub-layer; any other layer type is left untouched.

        Args:
            m: a sub-layer visited by ``self.apply``.
        """
        if not isinstance(m, nn.Conv2D):
            return

        xavier_normal_(m.weight)
        if m.bias is not None:
            zeros_(m.bias)

113 114
    def forward_train(self, src, tgt):
        """Run the teacher-forced training pass and return per-step logits.

        Args:
            src: visual features; transposed with ``[1, 0, 2]`` when an encoder
                is present, otherwise squeezed on axis 2 and transposed with
                ``[2, 0, 1]``.
            tgt: target token ids; the last column is dropped before use.

        Returns:
            Output of ``tgt_word_prj`` applied to the decoder output after it
            has been transposed back with ``[1, 0, 2]``.
        """
        decoder_ids = tgt[:, :-1]

        # Positions equal to 0 are treated as padding keys.
        pad_mask = self.generate_padding_mask(decoder_ids)
        decoder_in = self.embedding(decoder_ids).transpose([1, 0, 2])
        decoder_in = self.positional_encoding(decoder_in)
        causal_mask = self.generate_square_subsequent_mask(decoder_in.shape[0])

        if self.encoder is None:
            memory = src.squeeze(2).transpose([2, 0, 1])
        else:
            encoder_in = self.positional_encoding(src.transpose([1, 0, 2]))
            memory = self.encoder(encoder_in)

        decoded = self.decoder(
            decoder_in,
            memory,
            tgt_mask=causal_mask,
            memory_mask=None,
            tgt_key_padding_mask=pad_mask,
            memory_key_padding_mask=None)
        return self.tgt_word_prj(decoded.transpose([1, 0, 2]))

    def forward(self, src, targets=None):
        """Dispatch to the training, beam-search or greedy decoding path.

        Args:
            src: input features handed unchanged to the selected path.
            targets: only read in training mode. ``targets[0]`` is sliced on
                axis 1 to ``2 + targets[1].max()`` columns and used as the
                decoder target (presumably label ids and label lengths —
                confirm against the data loader).

        Returns:
            ``forward_train`` output in training mode; otherwise the output of
            ``forward_beam`` when ``beam_size > 0`` or of ``forward_test``.
        """
        if self.training:
            longest = targets[1].max()
            tgt = targets[0][:, :2 + longest]
            return self.forward_train(src, tgt)

        if self.beam_size > 0:
            return self.forward_beam(src)
        return self.forward_test(src)

    def forward_test(self, src):
        """Greedy decoding for at most 24 steps.

        Args:
            src: input features (same layouts as in ``forward_train``).

        Returns:
            ``[dec_seq, dec_prob]``: the int64 token ids (first column is the
            initial id 2) and the float32 probabilities of the chosen tokens
            (first column is 1.0), both with one row per batch element.
        """
        batch = paddle.shape(src)[0]
        if self.encoder is None:
            memory = paddle.transpose(paddle.squeeze(src, 2), [2, 0, 1])
        else:
            encoder_in = self.positional_encoding(
                paddle.transpose(src, [1, 0, 2]))
            memory = self.encoder(encoder_in)

        # Every sequence starts from token id 2 with probability 1.
        dec_seq = paddle.full((batch, 1), 2, dtype=paddle.int64)
        dec_prob = paddle.full((batch, 1), 1., dtype=paddle.float32)

        for _step in range(1, 25):
            step_in = paddle.transpose(self.embedding(dec_seq), [1, 0, 2])
            step_in = self.positional_encoding(step_in)
            causal_mask = self.generate_square_subsequent_mask(
                paddle.shape(step_in)[0])
            decoded = self.decoder(
                step_in,
                memory,
                tgt_mask=causal_mask,
                memory_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None)

            # Only the newest position is needed to pick the next token.
            last_state = paddle.transpose(decoded, [1, 0, 2])[:, -1, :]
            word_prob = F.softmax(self.tgt_word_prj(last_state), axis=1)
            preds_idx = paddle.argmax(word_prob, axis=1)

            # Stop once every element of the batch predicts id 3 in this step.
            all_three = paddle.full(paddle.shape(preds_idx), 3, dtype='int64')
            if paddle.equal_all(preds_idx, all_three):
                break

            preds_prob = paddle.max(word_prob, axis=1)
            next_ids = paddle.reshape(preds_idx, [-1, 1])
            next_probs = paddle.reshape(preds_prob, [-1, 1])
            dec_seq = paddle.concat([dec_seq, next_ids], axis=1)
            dec_prob = paddle.concat([dec_prob, next_probs], axis=1)
        return [dec_seq, dec_prob]
T
Topdu 已提交
195

196
    def forward_beam(self, images):
        """Decode with beam search of width ``self.beam_size``.

        Only one instance is decoded: the beam list and the active-instance
        list are both built with ``range(1)``.

        Args:
            images: input features. With an encoder they are transposed with
                ``[1, 0, 2]``; without one they are squeezed on axis 2 and
                transposed with ``[2, 0, 1]``, exactly as in ``forward_train``
                and ``forward_test``.

        Returns:
            ``[ids, scores]``: an int64 tensor with the best hypothesis padded
            with id 3 up to 25 entries, and a tensor in which the
            length-normalised score of that hypothesis is repeated 25 times.
        """
        max_len = 25  # decoding horizon and padded output length
        pad_id = 3  # id appended after the end of a hypothesis

        def get_inst_idx_to_tensor_position_map(inst_idx_list):
            ''' Indicate the position of an instance in a tensor. '''
            return {
                inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)
            }

        def collect_active_part(beamed_tensor, curr_active_inst_idx,
                                n_prev_active_inst, n_bm):
            ''' Collect tensor parts associated to active instances. '''
            beamed_tensor_shape = paddle.shape(beamed_tensor)
            n_curr_active_inst = len(curr_active_inst_idx)
            new_shape = (n_curr_active_inst * n_bm, beamed_tensor_shape[1],
                         beamed_tensor_shape[2])

            # Group the beams of each instance into one row, keep the rows of
            # the still-active instances, then restore the per-beam layout.
            beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])
            beamed_tensor = beamed_tensor.index_select(
                curr_active_inst_idx, axis=0)
            beamed_tensor = beamed_tensor.reshape(new_shape)

            return beamed_tensor

        def collate_active_info(src_enc, inst_idx_to_position_map,
                                active_inst_idx_list):
            # Sentences which are still active are collected,
            # so the decoder will not run on completed sentences.
            n_prev_active_inst = len(inst_idx_to_position_map)
            active_inst_idx = [
                inst_idx_to_position_map[k] for k in active_inst_idx_list
            ]
            active_inst_idx = paddle.to_tensor(active_inst_idx, dtype='int64')
            # collect_active_part works batch-first, hence the transposes
            # around it.
            active_src_enc = collect_active_part(
                src_enc.transpose([1, 0, 2]), active_inst_idx,
                n_prev_active_inst, n_bm).transpose([1, 0, 2])
            active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)
            return active_src_enc, active_inst_idx_to_position_map

        def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output,
                             inst_idx_to_position_map, n_bm,
                             memory_key_padding_mask):
            ''' Decode and update beam status, and then return active beam idx '''

            def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
                dec_partial_seq = [
                    b.get_current_state() for b in inst_dec_beams if not b.done
                ]
                dec_partial_seq = paddle.stack(dec_partial_seq)
                dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq])
                return dec_partial_seq

            def predict_word(dec_seq, enc_output, n_active_inst, n_bm,
                             memory_key_padding_mask):
                dec_seq = paddle.transpose(self.embedding(dec_seq), [1, 0, 2])
                dec_seq = self.positional_encoding(dec_seq)
                tgt_mask = self.generate_square_subsequent_mask(
                    paddle.shape(dec_seq)[0])
                dec_output = self.decoder(
                    dec_seq,
                    enc_output,
                    tgt_mask=tgt_mask,
                    tgt_key_padding_mask=None,
                    memory_key_padding_mask=memory_key_padding_mask, )
                dec_output = paddle.transpose(dec_output, [1, 0, 2])
                # Pick the last step: (bh * bm) * d_h
                dec_output = dec_output[:, -1, :]
                word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1)
                word_prob = paddle.reshape(word_prob, [n_active_inst, n_bm, -1])
                return word_prob

            def collect_active_inst_idx_list(inst_beams, word_prob,
                                             inst_idx_to_position_map):
                active_inst_idx_list = []
                for inst_idx, inst_position in inst_idx_to_position_map.items():
                    is_inst_complete = inst_beams[inst_idx].advance(word_prob[
                        inst_position])
                    if not is_inst_complete:
                        active_inst_idx_list += [inst_idx]

                return active_inst_idx_list

            n_active_inst = len(inst_idx_to_position_map)
            dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
            # NOTE: as in the original code, the memory padding mask given to
            # this step is not forwarded; the decoder always receives None.
            word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm,
                                     None)
            # Update the beam with predicted word prob information and collect incomplete instances
            active_inst_idx_list = collect_active_inst_idx_list(
                inst_dec_beams, word_prob, inst_idx_to_position_map)
            return active_inst_idx_list

        def collect_hypothesis_and_scores(inst_dec_beams, n_best):
            all_hyp, all_scores = [], []
            for beam in inst_dec_beams:
                scores, tail_idxs = beam.sort_scores()
                all_scores += [scores[:n_best]]
                hyps = [beam.get_hypothesis(i) for i in tail_idxs[:n_best]]
                all_hyp += [hyps]
            return all_hyp, all_scores

        with paddle.no_grad():
            #-- Encode
            if self.encoder is not None:
                src = self.positional_encoding(images.transpose([1, 0, 2]))
                src_enc = self.encoder(src)
            else:
                # Sequence-first layout, matching forward_train/forward_test.
                # The tiling on axis 1 below and collate_active_info both
                # expect the batch/beam axis at position 1.
                src_enc = images.squeeze(2).transpose([2, 0, 1])

            n_bm = self.beam_size
            inst_dec_beams = [Beam(n_bm) for _ in range(1)]
            active_inst_idx_list = list(range(1))
            # Repeat data for beam search
            src_enc = paddle.tile(src_enc, [1, n_bm, 1])
            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)
            # Decode
            for len_dec_seq in range(1, max_len):
                src_enc_copy = src_enc.clone()
                active_inst_idx_list = beam_decode_step(
                    inst_dec_beams, len_dec_seq, src_enc_copy,
                    inst_idx_to_position_map, n_bm, None)
                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>
                src_enc, inst_idx_to_position_map = collate_active_info(
                    src_enc_copy, inst_idx_to_position_map,
                    active_inst_idx_list)
        batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams,
                                                                1)
        result_hyp = []
        hyp_scores = []
        for bs_hyp, score in zip(batch_hyp, batch_scores):
            best_hyp = bs_hyp[0]
            hyp_len = len(best_hyp)
            result_hyp.append(best_hyp + [pad_id] * (max_len - hyp_len))
            # Length-normalise the score; an empty hypothesis must not raise
            # ZeroDivisionError.
            score = float(score) / max(hyp_len, 1)
            hyp_scores.append([score] * max_len)
        return [
            paddle.to_tensor(
                np.array(result_hyp), dtype=paddle.int64),
            paddle.to_tensor(hyp_scores)
        ]
T
Topdu 已提交
346 347

    def generate_square_subsequent_mask(self, sz):
        """Build the additive ``sz`` x ``sz`` float32 mask for causal attention.

        Entries strictly above the main diagonal are ``-inf``; the diagonal and
        everything below it are ``0.0``.
        """
        upper_inf = paddle.triu(
            paddle.full(
                shape=[sz, sz], dtype='float32', fill_value='-inf'),
            diagonal=1)
        return paddle.zeros([sz, sz], dtype='float32') + upper_inf

    def generate_padding_mask(self, x):
        """Return a boolean tensor shaped like ``x`` that is True where ``x == 0``."""
        zero = paddle.to_tensor(0, dtype=x.dtype)
        return paddle.equal(x, zero)

    def _reset_parameters(self):
364
        """Initiate parameters in the transformer model."""
T
Topdu 已提交
365 366 367 368 369 370 371

        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)


class TransformerEncoder(nn.Layer):
    """A stack of identical encoder layers applied one after another.

    Args:
        encoder_layer: the layer instance that is deep-copied for each level
            (required).
        num_layers: how many copies to stack (required).
    """

    def __init__(self, encoder_layer, num_layers):
        super(TransformerEncoder, self).__init__()
        self.num_layers = num_layers
        self.layers = _get_clones(encoder_layer, num_layers)

    def forward(self, src):
        """Feed ``src`` through every encoder layer in order.

        No attention mask and no key-padding mask are used: both are passed
        to the layers as None.

        Args:
            src: the sequence to the encoder (required).
        """
        hidden = src
        for layer in self.layers:
            hidden = layer(hidden, src_mask=None, src_key_padding_mask=None)
        return hidden


class TransformerDecoder(nn.Layer):
    """A stack of identical decoder layers applied one after another.

    Args:
        decoder_layer: the layer instance that is deep-copied for each level
            (required).
        num_layers: how many copies to stack (required).
    """

    def __init__(self, decoder_layer, num_layers):
        super(TransformerDecoder, self).__init__()
        self.num_layers = num_layers
        self.layers = _get_clones(decoder_layer, num_layers)

    def forward(self,
                tgt,
                memory,
                tgt_mask=None,
                memory_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Feed ``tgt`` through every decoder layer in order.

        Every layer receives the same ``memory`` and the same four masks.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch
                (optional).
        """
        mask_kwargs = dict(
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask)

        hidden = tgt
        for layer in self.layers:
            hidden = layer(hidden, memory, **mask_kwargs)
        return hidden

445

T
Topdu 已提交
446
class TransformerEncoderLayer(nn.Layer):
    """One encoder layer: self-attention followed by a feed-forward block.

    Follows the layer described in "Attention Is All You Need" (Vaswani et
    al., 2017). The feed-forward block is implemented with two 1x1 Conv2D
    layers, and each sub-block is wrapped as ``norm(x + dropout(sub(x)))``.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the attention module (required).
        dim_feedforward: the hidden width of the feed-forward block
            (default=2048).
        attention_dropout_rate: dropout passed to the attention module
            (default=0.0).
        residual_dropout_rate: dropout applied to each residual branch
            (default=0.1).
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 attention_dropout_rate=0.0,
                 residual_dropout_rate=0.1):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model, nhead, dropout=attention_dropout_rate)

        # Position-wise feed-forward network expressed as 1x1 convolutions.
        pointwise = (1, 1)
        self.conv1 = Conv2D(
            in_channels=d_model,
            out_channels=dim_feedforward,
            kernel_size=pointwise)
        self.conv2 = Conv2D(
            in_channels=dim_feedforward,
            out_channels=d_model,
            kernel_size=pointwise)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(residual_dropout_rate)
        self.dropout2 = Dropout(residual_dropout_rate)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Apply self-attention and the feed-forward block to ``src``.

        Args:
            src: the sequence to the encoder layer (required); a 3-D tensor
                whose last axis holds the ``d_model`` features.
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
        """
        attended = self.self_attn(
            src,
            src,
            src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask)
        src = self.norm1(src + self.dropout1(attended))

        # Move the feature axis to the channel position and add a dummy
        # spatial axis so the 1x1 convolutions act on every position.
        conv_in = paddle.unsqueeze(paddle.transpose(src, [1, 2, 0]), 2)
        ffn_out = self.conv2(F.relu(self.conv1(conv_in)))
        # Undo the layout change on the feed-forward output.
        ffn_out = paddle.transpose(paddle.squeeze(ffn_out, 2), [2, 0, 1])

        return self.norm2(src + self.dropout2(ffn_out))

514

T
Topdu 已提交
515
class TransformerDecoderLayer(nn.Layer):
    """One decoder layer: self-attention, attention over memory, feed-forward.

    Follows the layer described in "Attention Is All You Need" (Vaswani et
    al., 2017). The feed-forward block is implemented with two 1x1 Conv2D
    layers, and each sub-block is wrapped as ``norm(x + dropout(sub(x)))``.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the attention modules (required).
        dim_feedforward: the hidden width of the feed-forward block
            (default=2048).
        attention_dropout_rate: dropout passed to both attention modules
            (default=0.0).
        residual_dropout_rate: dropout applied to each residual branch
            (default=0.1).
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 attention_dropout_rate=0.0,
                 residual_dropout_rate=0.1):
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model, nhead, dropout=attention_dropout_rate)
        self.multihead_attn = MultiheadAttention(
            d_model, nhead, dropout=attention_dropout_rate)

        # Position-wise feed-forward network expressed as 1x1 convolutions.
        pointwise = (1, 1)
        self.conv1 = Conv2D(
            in_channels=d_model,
            out_channels=dim_feedforward,
            kernel_size=pointwise)
        self.conv2 = Conv2D(
            in_channels=dim_feedforward,
            out_channels=d_model,
            kernel_size=pointwise)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = Dropout(residual_dropout_rate)
        self.dropout2 = Dropout(residual_dropout_rate)
        self.dropout3 = Dropout(residual_dropout_rate)

    def forward(self,
                tgt,
                memory,
                tgt_mask=None,
                memory_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Apply the three sub-blocks of the decoder layer to ``tgt``.

        Args:
            tgt: the sequence to the decoder layer (required); a 3-D tensor
                whose last axis holds the ``d_model`` features.
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch
                (optional).
        """
        # 1) Self-attention over the target sequence.
        self_attended = self.self_attn(
            tgt,
            tgt,
            tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask)
        tgt = self.norm1(tgt + self.dropout1(self_attended))

        # 2) Attention with the target as query and memory as key/value.
        cross_attended = self.multihead_attn(
            tgt,
            memory,
            memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask)
        tgt = self.norm2(tgt + self.dropout2(cross_attended))

        # 3) Feed-forward: move the feature axis to the channel position and
        # add a dummy spatial axis so the 1x1 convolutions act per position.
        conv_in = paddle.unsqueeze(paddle.transpose(tgt, [1, 2, 0]), 2)
        ffn_out = self.conv2(F.relu(self.conv1(conv_in)))
        # Undo the layout change on the feed-forward output.
        ffn_out = paddle.transpose(paddle.squeeze(ffn_out, 2), [2, 0, 1])

        return self.norm3(tgt + self.dropout3(ffn_out))


def _get_clones(module, N):
    """Return a LayerList holding ``N`` independent deep copies of ``module``."""
    clones = [copy.deepcopy(module) for _ in range(N)]
    return LayerList(clones)


class PositionalEncoding(nn.Layer):
    r"""Add fixed sinusoidal position information to a sequence, then dropout.

    The encodings have the same width as the embeddings so that the two can
    be summed. Sine and cosine functions of different frequencies are used:

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / dim})
        \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / dim})

    where ``pos`` is the position in the sequence and ``i`` indexes pairs of
    embedding channels.

    Args:
        dropout: the dropout probability applied after the sum (required).
        dim: the embed dim (required).
        max_len: the number of positions that are precomputed (default=5000).
    Examples:
        >>> pos_encoder = PositionalEncoding(dropout=0.1, dim=512)
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        table = paddle.zeros([max_len, dim])
        position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
        # 10000^(-k / dim) for every even channel index k.
        div_term = paddle.exp(
            paddle.arange(0, dim, 2).astype('float32') *
            (-math.log(10000.0) / dim))
        angles = position * div_term
        table[:, 0::2] = paddle.sin(angles)
        table[:, 1::2] = paddle.cos(angles)
        # Stored as [max_len, 1, dim] so it broadcasts over the batch axis.
        table = paddle.unsqueeze(table, 1)
        self.register_buffer('pe', table)

    def forward(self, x):
        """Add the encodings of the first ``len(x)`` positions and apply dropout.

        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        Examples:
            >>> output = pos_encoder(x)
        """
        seq_len = paddle.shape(x)[0]
        return self.dropout(x + self.pe[:seq_len, :])


class PositionalEncoding_2d(nn.Layer):
    """Add sinusoidal position information along both spatial axes of a 4-D input.

    A single table ``pe`` is built with
        pe[pos, 2i]   = sin(pos / 10000^(2i / dim))
        pe[pos, 2i+1] = cos(pos / 10000^(2i / dim))
    and sliced twice in ``forward``: once by the size of the last input axis
    (width) and once by the size of the second-to-last input axis (height).
    Each slice is scaled by a per-sample, per-channel gate (global average
    pool followed by a linear layer) before it is added to the input.

    Args:
        dropout: the dropout value (required).
        dim: the embed dim; must equal the size of axis 1 of the input,
            because the pooled input is fed to ``Linear(dim, dim)`` (required).
        max_len: the number of positions that are precomputed (default=5000).
    Examples:
        >>> pos_encoder = PositionalEncoding_2d(dropout, dim)
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding_2d, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = paddle.zeros([max_len, dim])
        position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
        # exp(2i * -ln(10000) / dim) == 10000^(-2i / dim)
        div_term = paddle.exp(
            paddle.arange(0, dim, 2).astype('float32') *
            (-math.log(10000.0) / dim))
        angles = position * div_term
        pe[:, 0::2] = paddle.sin(angles)
        pe[:, 1::2] = paddle.cos(angles)
        # [max_len, dim] -> [1, max_len, dim] -> [max_len, 1, dim]
        pe = paddle.transpose(paddle.unsqueeze(pe, 0), [1, 0, 2])
        self.register_buffer('pe', pe)

        # Gate for the width term.
        self.avg_pool_1 = nn.AdaptiveAvgPool2D((1, 1))
        self.linear1 = nn.Linear(dim, dim)
        self.linear1.weight.data.fill_(1.)
        # Gate for the height term.
        self.avg_pool_2 = nn.AdaptiveAvgPool2D((1, 1))
        self.linear2 = nn.Linear(dim, dim)
        self.linear2.weight.data.fill_(1.)

    def forward(self, x):
        """Add gated width/height encodings and flatten the spatial axes.

        Args:
            x: the feature map fed to the positional encoder model (required).
        Shape:
            x: [batch size, embed dim, height, width]
            output: [height * width, batch size, embed dim]
        Examples:
            >>> output = pos_encoder(x)
        """
        # NOTE(review): squeeze() without an axis also drops a batch axis of
        # size 1; broadcasting appears to keep the shapes below valid in that
        # case -- confirm if batch size 1 matters.

        # Width term: [W, 1, C] * [1, B, C] -> [W, B, C] -> [B, C, 1, W]
        w_pe = self.pe[:paddle.shape(x)[-1], :]
        w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0)
        w_pe = w_pe * w1
        w_pe = paddle.transpose(w_pe, [1, 2, 0])
        w_pe = paddle.unsqueeze(w_pe, 2)

        # Height term: [H, 1, C] * [1, B, C] -> [H, B, C] -> [B, C, H, 1]
        # The slice length is the size of axis -2 of x, read the same way as
        # the width above (the shape tensor itself has a single axis, so
        # indexing its own .shape with -2 is out of range).
        h_pe = self.pe[:paddle.shape(x)[-2], :]
        w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0)
        h_pe = h_pe * w2
        h_pe = paddle.transpose(h_pe, [1, 2, 0])
        h_pe = paddle.unsqueeze(h_pe, 3)

        x = x + w_pe + h_pe
        # [B, C, H, W] -> [B, C, H * W] -> [H * W, B, C]
        x = paddle.transpose(
            paddle.reshape(x,
                           [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]),
            [2, 0, 1])

        return self.dropout(x)


class Embeddings(nn.Layer):
    """Token embedding table with optional sqrt(d_model) output scaling.

    Args:
        d_model: size of each embedding vector.
        vocab: number of rows in the embedding table.
        padding_idx: forwarded unchanged to ``nn.Embedding``.
        scale_embedding: when truthy, ``forward`` multiplies the looked-up
            vectors by ``sqrt(d_model)``.
    """

    def __init__(self, d_model, vocab, padding_idx, scale_embedding):
        super(Embeddings, self).__init__()
        self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)
        # Replace the default initialisation of the whole table with samples
        # from N(0, d_model^-0.5).
        init_std = d_model**-0.5
        init_weight = np.random.normal(0.0, init_std, (vocab, d_model))
        self.embedding.weight.set_value(init_weight.astype(np.float32))
        self.scale_embedding = scale_embedding
        self.d_model = d_model

    def forward(self, x):
        """Look up ``x`` in the table, scaling the result if configured."""
        embedded = self.embedding(x)
        if not self.scale_embedding:
            return embedded
        return embedded * math.sqrt(self.d_model)


class Beam():
    """Beam-search state for one sequence being decoded.

    For every time step the beam records the token chosen for each slot
    (``next_ys``) and the slot it extended (``prev_ks``), so a complete
    hypothesis can be rebuilt by walking the backpointers.

    Token id 3 is treated as EOS by ``advance``. Token id 2 seeds slot 0 and
    is prepended to every hypothesis (presumably the start-of-sequence id --
    confirm against the vocabulary used by the caller).
    """

    def __init__(self, size, device=False):
        """
        Args:
            size: beam width, i.e. the number of hypotheses kept per step.
            device: not used by this class; kept so existing callers that
                pass it keep working.
        """
        self.size = size
        self._done = False
        # The score for each translation on the beam.
        self.scores = paddle.zeros((size, ), dtype=paddle.float32)
        self.all_scores = []
        # The backpointers at each time-step.
        self.prev_ks = []
        # The outputs at each time-step.
        self.next_ys = [paddle.full((size, ), 0, dtype=paddle.int64)]
        self.next_ys[0][0] = 2

    def get_current_state(self):
        "Get the outputs for the current timestep."
        return self.get_tentative_hypothesis()

    def get_current_origin(self):
        "Get the backpointers for the current timestep."
        return self.prev_ks[-1]

    @property
    def done(self):
        "Whether the top-of-beam hypothesis has produced EOS."
        return self._done

    def advance(self, word_prob):
        """Update beam status and check if finished or not.

        Args:
            word_prob: scores for the next token of every beam slot, indexed
                as [beam slot, word] (presumably log-probabilities, since
                they are summed with the running scores -- confirm).
        Returns:
            True once the token at the top of the beam is EOS (id 3).
        """
        num_words = word_prob.shape[1]

        # Sum the previous scores.
        if self.prev_ks:
            beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
        else:
            # First step: only the first row of word_prob is used.
            beam_lk = word_prob[0]

        flat_beam_lk = beam_lk.reshape([-1])
        # Keep the `size` largest candidates, sorted in descending order.
        best_scores, best_scores_id = flat_beam_lk.topk(
            self.size, axis=0, largest=True, sorted=True)
        self.all_scores.append(self.scores)
        self.scores = best_scores
        # best_scores_id is flattened as a (beam x word) array,
        # so we need to calculate which word and beam each score came from
        prev_k = best_scores_id // num_words
        self.prev_ks.append(prev_k)
        self.next_ys.append(best_scores_id - prev_k * num_words)
        # End condition is when top-of-beam is EOS.
        if self.next_ys[-1][0] == 3:
            self._done = True
            self.all_scores.append(self.scores)

        return self._done

    def sort_scores(self):
        """Return the current scores and their slot indices.

        No re-sorting happens here: ``advance`` stores the already sorted
        top-k result, so the indices are simply 0..len(scores)-1.
        """
        num_scores = int(self.scores.shape[0])
        return self.scores, paddle.to_tensor(
            list(range(num_scores)), dtype='int32')

    def get_the_best_score_and_idx(self):
        "Get the score of the best in the beam."
        scores, ids = self.sort_scores()
        # Index 0 is the best entry; index 1 would be the runner-up and
        # does not exist for a beam of size 1.
        return scores[0], ids[0]

    def get_tentative_hypothesis(self):
        "Get the decoded sequence for the current timestep."
        if len(self.next_ys) == 1:
            # Nothing decoded yet: one column holding the seed tokens.
            return self.next_ys[0].unsqueeze(1)
        _, keys = self.sort_scores()
        # Every hypothesis is prefixed with the seed token id 2.
        hyps = [[2] + self.get_hypothesis(k) for k in keys]
        return paddle.to_tensor(hyps, dtype='int64')

    def get_hypothesis(self, k):
        """Walk back to construct the full hypothesis.

        Args:
            k: beam slot at the latest time step to start from.
        Returns:
            A list of Python scalars, oldest token first.
        """
        hyp = []
        for j in range(len(self.prev_ks) - 1, -1, -1):
            hyp.append(self.next_ys[j + 1][k])
            k = self.prev_ks[j][k]
        return [token.item() for token in reversed(hyp)]