rec_nrtr_head.py 26.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

T
Topdu 已提交
15 16
import math
import paddle
17
from paddle import nn
T
Topdu 已提交
18 19
import paddle.nn.functional as F
from paddle.nn import LayerList
20 21
# from paddle.nn.initializer import XavierNormal as xavier_uniform_
from paddle.nn import Dropout, Linear, LayerNorm
T
Topdu 已提交
22
import numpy as np
23
from ppocr.modeling.backbones.rec_svtrnet import Mlp, zeros_, ones_
T
Topdu 已提交
24 25
from paddle.nn.initializer import XavierNormal as xavier_normal_

26

T
Topdu 已提交
27
class Transformer(nn.Layer):
    """A transformer model. User is able to modify the attributes as needed. The architecture
    is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
    Processing Systems, pages 6000-6010.

    All tensors are batch-first: features are ``(B, N, C)`` and token ids
    ``(B, T)``. During decoding, token id 2 is the start symbol and token id 3
    the end symbol.

    Args:
        d_model: the number of expected features in the encoder/decoder inputs (default=512).
        nhead: the number of heads in the multi-head attention layers (default=8).
        num_encoder_layers: the number of encoder blocks; 0 disables the encoder and
            the input features are used directly as decoder memory (default=6).
        beam_size: beam width at inference; 0 selects greedy decoding (default=0).
        num_decoder_layers: the number of decoder blocks (default=6).
        max_len: bound on inference decoding; at most ``max_len - 1`` tokens are
            generated after the start token (default=25).
        dim_feedforward: the hidden size of each block's MLP (default=1024).
        attention_dropout_rate: dropout on attention weights (default=0.0).
        residual_dropout_rate: dropout on residual branches and after the
            positional encoding (default=0.1).
        in_channels: accepted for interface compatibility; unused here (default=0).
        out_channels: size of the character set; one extra class is added on top
            of it for the embedding and output projection (default=0).
        scale_embedding: multiply token embeddings by sqrt(d_model) (default=True).
    """

    def __init__(self,
                 d_model=512,
                 nhead=8,
                 num_encoder_layers=6,
                 beam_size=0,
                 num_decoder_layers=6,
                 max_len=25,
                 dim_feedforward=1024,
                 attention_dropout_rate=0.0,
                 residual_dropout_rate=0.1,
                 in_channels=0,
                 out_channels=0,
                 scale_embedding=True):
        super(Transformer, self).__init__()
        self.out_channels = out_channels + 1
        self.max_len = max_len
        self.embedding = Embeddings(
            d_model=d_model,
            vocab=self.out_channels,
            padding_idx=0,
            scale_embedding=scale_embedding)
        self.positional_encoding = PositionalEncoding(
            dropout=residual_dropout_rate, dim=d_model)

        if num_encoder_layers > 0:
            self.encoder = nn.LayerList([
                TransformerBlock(
                    d_model,
                    nhead,
                    dim_feedforward,
                    attention_dropout_rate,
                    residual_dropout_rate,
                    with_self_attn=True,
                    with_cross_attn=False) for _ in range(num_encoder_layers)
            ])
        else:
            self.encoder = None

        self.decoder = nn.LayerList([
            TransformerBlock(
                d_model,
                nhead,
                dim_feedforward,
                attention_dropout_rate,
                residual_dropout_rate,
                with_self_attn=True,
                with_cross_attn=True) for _ in range(num_decoder_layers)
        ])

        self.beam_size = beam_size
        self.d_model = d_model
        self.nhead = nhead
        self.tgt_word_prj = nn.Linear(
            d_model, self.out_channels, bias_attr=False)
        w0 = np.random.normal(0.0, d_model**-0.5,
                              (d_model, self.out_channels)).astype(np.float32)
        self.tgt_word_prj.weight.set_value(w0)
        # NOTE(review): _init_weights re-initialises every nn.Linear with
        # Xavier-normal, which appears to include tgt_word_prj and therefore
        # overwrite the normal init just above -- confirm this is intended.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-normal weights and zero bias for every nn.Linear sublayer."""
        if isinstance(m, nn.Linear):
            xavier_normal_(m.weight)
            if m.bias is not None:
                zeros_(m.bias)

    def _encode(self, src):
        """Run the optional encoder stack over ``src`` (B, N, C).

        ``self.encoder`` is an ``nn.LayerList``, which is not callable, so its
        blocks are applied one by one. Without an encoder ``src`` is returned
        unchanged (no positional encoding is added in that case).
        """
        if self.encoder is None:
            return src
        src = self.positional_encoding(src)
        for encoder_layer in self.encoder:
            src = encoder_layer(src)
        return src

    def _decode(self, tgt, memory, tgt_mask):
        """Apply every decoder block to ``tgt`` attending over ``memory``."""
        for decoder_layer in self.decoder:
            tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
        return tgt

    def forward_train(self, src, tgt):
        """Teacher-forced decoding.

        Args:
            src: encoder input features ``(B, sN, C)``.
            tgt: target token ids ``(B, tN)``; the last position is dropped so
                each input position predicts the following token.

        Returns:
            Logits ``(B, tN - 1, out_channels)``.
        """
        tgt = tgt[:, :-1]
        tgt = self.embedding(tgt)
        tgt = self.positional_encoding(tgt)
        tgt_mask = self.generate_square_subsequent_mask(tgt.shape[1])

        memory = self._encode(src)  # B N C
        output = self._decode(tgt, memory, tgt_mask)
        logit = self.tgt_word_prj(output)
        return logit

    def forward(self, src, targets=None):
        """Dispatch to training or inference depending on ``self.training``.

        Args:
            src: encoder input features ``(B, sN, C)`` (required).
            targets: required in training mode. ``targets[0]`` holds target token
                ids ``(B, tN)``; ``targets[1]`` presumably holds per-sample label
                lengths (its max bounds how many target columns are kept).
                Ignored at inference.

        Returns:
            Training: logits from ``forward_train``.
            Inference: ``[token_ids, probabilities]`` from greedy decoding, or
            from beam search when ``beam_size > 0``.

        Examples:
            >>> output = transformer_model(src, targets)
        """
        if self.training:
            max_len = targets[1].max()
            tgt = targets[0][:, :2 + max_len]
            return self.forward_train(src, tgt)
        else:
            if self.beam_size > 0:
                return self.forward_beam(src)
            else:
                return self.forward_test(src)

    def forward_test(self, src):
        """Greedy decoding.

        Starts every sequence with token 2 and appends the argmax token each
        step for up to ``max_len - 1`` steps, stopping early once every sample
        predicts token 3 in the same step.

        Returns:
            ``[dec_seq, dec_prob]``: token ids (starting with 2) and the
            matching per-step max probabilities (1.0 for the start token).
        """
        bs = paddle.shape(src)[0]
        memory = self._encode(src)  # B N C
        dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64)
        dec_prob = paddle.full((bs, 1), 1., dtype=paddle.float32)
        for _ in range(1, self.max_len):
            dec_seq_embed = self.embedding(dec_seq)
            dec_seq_embed = self.positional_encoding(dec_seq_embed)
            tgt_mask = self.generate_square_subsequent_mask(
                paddle.shape(dec_seq_embed)[1])
            dec_output = self._decode(dec_seq_embed, memory, tgt_mask)
            # Only the newest position is needed to pick the next token.
            dec_output = dec_output[:, -1, :]
            word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=-1)
            preds_idx = paddle.argmax(word_prob, axis=-1)
            if paddle.equal_all(
                    preds_idx,
                    paddle.full(
                        paddle.shape(preds_idx), 3, dtype='int64')):
                break
            preds_prob = paddle.max(word_prob, axis=-1)
            dec_seq = paddle.concat(
                [dec_seq, paddle.reshape(preds_idx, [-1, 1])], axis=1)
            dec_prob = paddle.concat(
                [dec_prob, paddle.reshape(preds_prob, [-1, 1])], axis=1)
        return [dec_seq, dec_prob]

    def forward_beam(self, images):
        """Beam-search decoding with beam width ``self.beam_size``.

        Only one ``Beam`` instance is tracked, so ``images`` is assumed to be a
        batch of one, ``(1, N, C)`` -- TODO confirm callers never batch here.

        Returns:
            ``[token_ids, scores]``: ids of the best hypothesis padded with the
            end id 3 to ``self.max_len`` columns, and its score divided by the
            hypothesis length, repeated ``self.max_len`` times.
        """
        n_bm = self.beam_size

        def get_inst_idx_to_tensor_position_map(inst_idx_list):
            """ Indicate the position of an instance in a tensor. """
            return {
                inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)
            }

        def collect_active_part(beamed_tensor, curr_active_inst_idx,
                                n_prev_active_inst, n_bm):
            """Keep the rows of a beam-tiled ``(n_prev * n_bm, N, C)`` tensor
            that belong to the still-active instances."""
            beamed_tensor_shape = paddle.shape(beamed_tensor)
            n_curr_active_inst = len(curr_active_inst_idx)
            new_shape = (n_curr_active_inst * n_bm, beamed_tensor_shape[1],
                         beamed_tensor_shape[2])
            # One row per instance (its n_bm beams flattened together), so
            # whole instances can be selected with a single index_select.
            beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])
            beamed_tensor = beamed_tensor.index_select(
                curr_active_inst_idx, axis=0)
            beamed_tensor = beamed_tensor.reshape(new_shape)
            return beamed_tensor

        def collate_active_info(src_enc, inst_idx_to_position_map,
                                active_inst_idx_list):
            # Sentences which are still active are collected,
            # so the decoder will not run on completed sentences.
            n_prev_active_inst = len(inst_idx_to_position_map)
            active_inst_idx = [
                inst_idx_to_position_map[k] for k in active_inst_idx_list
            ]
            active_inst_idx = paddle.to_tensor(active_inst_idx, dtype='int64')
            # src_enc is batch-first with the beam rows on axis 0, so it is
            # selected directly (a seq-first transpose here breaks the reshape).
            active_src_enc = collect_active_part(
                src_enc, active_inst_idx, n_prev_active_inst, n_bm)
            active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)
            return active_src_enc, active_inst_idx_to_position_map

        def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output,
                             inst_idx_to_position_map, n_bm):
            """ Decode and update beam status, and then return active beam idx """

            def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
                # Partial hypotheses of unfinished instances, flattened to
                # (n_active_inst * n_bm, len_dec_seq).
                dec_partial_seq = [
                    b.get_current_state() for b in inst_dec_beams if not b.done
                ]
                dec_partial_seq = paddle.stack(dec_partial_seq)
                dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq])
                return dec_partial_seq

            def predict_word(dec_seq, enc_output, n_active_inst, n_bm):
                dec_seq = self.embedding(dec_seq)
                dec_seq = self.positional_encoding(dec_seq)
                tgt_mask = self.generate_square_subsequent_mask(
                    paddle.shape(dec_seq)[1])
                dec_output = self._decode(dec_seq, enc_output, tgt_mask)
                # Pick the last step: (n_active_inst * n_bm, d_model)
                dec_output = dec_output[:, -1, :]
                # NOTE(review): Beam.advance adds these probabilities (not
                # log-probabilities) to the running scores -- confirm intended.
                word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1)
                word_prob = paddle.reshape(word_prob, [n_active_inst, n_bm, -1])
                return word_prob

            def collect_active_inst_idx_list(inst_beams, word_prob,
                                             inst_idx_to_position_map):
                active_inst_idx_list = []
                for inst_idx, inst_position in inst_idx_to_position_map.items():
                    is_inst_complete = inst_beams[inst_idx].advance(
                        word_prob[inst_position])
                    if not is_inst_complete:
                        active_inst_idx_list += [inst_idx]

                return active_inst_idx_list

            n_active_inst = len(inst_idx_to_position_map)
            dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
            word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm)
            # Update the beam with predicted word prob information and collect incomplete instances
            active_inst_idx_list = collect_active_inst_idx_list(
                inst_dec_beams, word_prob, inst_idx_to_position_map)
            return active_inst_idx_list

        def collect_hypothesis_and_scores(inst_dec_beams, n_best):
            all_hyp, all_scores = [], []
            for inst_idx in range(len(inst_dec_beams)):
                scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
                all_scores += [scores[:n_best]]
                hyps = [
                    inst_dec_beams[inst_idx].get_hypothesis(i)
                    for i in tail_idxs[:n_best]
                ]
                all_hyp += [hyps]
            return all_hyp, all_scores

        with paddle.no_grad():
            # Encode (iterates the LayerList; calling it directly would fail).
            src_enc = self._encode(images)

            inst_dec_beams = [Beam(n_bm) for _ in range(1)]
            active_inst_idx_list = list(range(1))
            # Repeat the memory once per beam along the batch axis so it lines
            # up row-for-row with the (n_bm, len) decoder input.
            src_enc = paddle.tile(src_enc, [n_bm, 1, 1])
            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)
            # Decode
            for len_dec_seq in range(1, self.max_len):
                src_enc_copy = src_enc.clone()
                active_inst_idx_list = beam_decode_step(
                    inst_dec_beams, len_dec_seq, src_enc_copy,
                    inst_idx_to_position_map, n_bm)
                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>
                src_enc, inst_idx_to_position_map = collate_active_info(
                    src_enc_copy, inst_idx_to_position_map,
                    active_inst_idx_list)
        batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams,
                                                                1)
        result_hyp = []
        hyp_scores = []
        for bs_hyp, score in zip(batch_hyp, batch_scores):
            hyp_len = len(bs_hyp[0])
            # Pad with the end id up to max_len (previously hard-coded to 25).
            result_hyp.append(bs_hyp[0] + [3] * (self.max_len - hyp_len))
            # Length-normalised score; an empty hypothesis (max_len <= 1)
            # would otherwise divide by zero.
            norm_score = float(score) / max(hyp_len, 1)
            hyp_scores.append([norm_score] * self.max_len)
        return [
            paddle.to_tensor(
                np.array(result_hyp), dtype=paddle.int64),
            paddle.to_tensor(hyp_scores)
        ]

    def generate_square_subsequent_mask(self, sz):
        """Build a causal additive attention mask of shape ``(1, 1, sz, sz)``.

        Positions above the diagonal (future tokens) are ``-inf``; all other
        positions are ``0.0``, so adding the mask to attention logits blocks
        attention to later positions.
        """
        mask = paddle.triu(
            paddle.full(
                shape=[sz, sz], dtype='float32', fill_value='-inf'),
            diagonal=1)
        return mask.unsqueeze([0, 1])
T
Topdu 已提交
345 346


347 348 349 350
class MultiheadAttention(nn.Layer):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model; must be divisible by ``num_heads``.
        num_heads: parallel attention layers, or heads.
        dropout: dropout probability applied to the attention weights.
        self_attn: if truthy, queries, keys and values are all projected from
            ``query`` with one fused linear layer; otherwise queries come from
            ``query`` and keys/values from ``key``.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, dropout=0., self_attn=False):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Raise rather than assert so the check is not stripped under `python -O`.
        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.scale = self.head_dim**-0.5
        self.self_attn = self_attn
        if self_attn:
            self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        else:
            self.q = nn.Linear(embed_dim, embed_dim)
            self.kv = nn.Linear(embed_dim, embed_dim * 2)
        self.attn_drop = nn.Dropout(dropout)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key=None, attn_mask=None):
        """Attend from ``query`` to itself (self-attention) or to ``key``.

        Args:
            query: ``(B, qN, embed_dim)``.
            key: ``(B, kN, embed_dim)``; required when the layer was built with
                ``self_attn`` falsy, ignored otherwise.
            attn_mask: optional additive mask broadcast onto the
                ``(B, num_heads, qN, kN)`` attention logits.

        Returns:
            ``(B, qN, embed_dim)``.

        Raises:
            ValueError: if ``key`` is missing for cross-attention.
        """
        qN = query.shape[1]

        # In paddle.reshape a 0 entry copies that dimension from the input,
        # so axis 0 keeps the batch size.
        if self.self_attn:
            # (B, qN, 3*E) -> (3, B, heads, qN, head_dim)
            qkv = self.qkv(query).reshape(
                (0, qN, 3, self.num_heads, self.head_dim)).transpose(
                    (2, 0, 3, 1, 4))
            q, k, v = qkv[0], qkv[1], qkv[2]
        else:
            if key is None:
                raise ValueError("key is required for cross-attention")
            kN = key.shape[1]
            # (B, qN, E) -> (B, heads, qN, head_dim)
            q = self.q(query).reshape(
                [0, qN, self.num_heads, self.head_dim]).transpose([0, 2, 1, 3])
            # (B, kN, 2*E) -> (2, B, heads, kN, head_dim)
            kv = self.kv(key).reshape(
                (0, kN, 2, self.num_heads, self.head_dim)).transpose(
                    (2, 0, 3, 1, 4))
            k, v = kv[0], kv[1]

        # Scaled dot-product logits: (B, heads, qN, kN)
        attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale

        if attn_mask is not None:
            attn += attn_mask

        attn = F.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)

        # Merge heads back: (B, heads, qN, head_dim) -> (B, qN, E)
        x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape(
            (0, qN, self.embed_dim))
        x = self.out_proj(x)

        return x
T
Topdu 已提交
410 411


412
class TransformerBlock(nn.Layer):
    """One transformer layer made of up to three residual sub-layers.

    Every enabled sub-layer is applied as ``norm(x + dropout(sublayer(x)))``,
    in this order: self-attention (if ``with_self_attn``), attention over
    ``memory`` (if ``with_cross_attn``), then the feed-forward ``Mlp``, which
    is always present.

    Args:
        d_model: feature size of the input and output.
        nhead: number of attention heads.
        dim_feedforward: hidden size of the MLP.
        attention_dropout_rate: dropout on the attention weights.
        residual_dropout_rate: dropout on each residual branch; also passed to
            the MLP as its ``drop`` argument.
        with_self_attn: build the self-attention sub-layer.
        with_cross_attn: build the cross-attention sub-layer.
        epsilon: LayerNorm epsilon.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 attention_dropout_rate=0.0,
                 residual_dropout_rate=0.1,
                 with_self_attn=True,
                 with_cross_attn=False,
                 epsilon=1e-5):
        super(TransformerBlock, self).__init__()

        self.with_self_attn = with_self_attn
        if with_self_attn:
            self.self_attn = MultiheadAttention(
                d_model, nhead, dropout=attention_dropout_rate,
                self_attn=with_self_attn)
            self.norm1 = LayerNorm(d_model, epsilon=epsilon)
            self.dropout1 = Dropout(residual_dropout_rate)

        self.with_cross_attn = with_cross_attn
        if with_cross_attn:
            # Queries from the block input, keys/values from ``memory``.
            self.cross_attn = MultiheadAttention(
                d_model, nhead, dropout=attention_dropout_rate)
            self.norm2 = LayerNorm(d_model, epsilon=epsilon)
            self.dropout2 = Dropout(residual_dropout_rate)

        self.mlp = Mlp(in_features=d_model, hidden_features=dim_feedforward,
                       act_layer=nn.ReLU, drop=residual_dropout_rate)
        self.norm3 = LayerNorm(d_model, epsilon=epsilon)
        self.dropout3 = Dropout(residual_dropout_rate)

    def forward(self, tgt, memory=None, self_mask=None, cross_mask=None):
        """Apply the enabled sub-layers to ``tgt`` and return the result.

        ``self_mask`` is added to the self-attention logits and ``cross_mask``
        to the cross-attention logits; ``memory`` is only used by the
        cross-attention sub-layer.
        """
        x = tgt
        if self.with_self_attn:
            x = self.norm1(x + self.dropout1(
                self.self_attn(x, attn_mask=self_mask)))
        if self.with_cross_attn:
            attended = self.cross_attn(x, key=memory, attn_mask=cross_mask)
            x = self.norm2(x + self.dropout2(attended))
        return self.norm3(x + self.dropout3(self.mlp(x)))


class PositionalEncoding(nn.Layer):
    r"""Add fixed sinusoidal position encodings to a batch-first sequence.

    .. math::
        \text{PE}(pos, 2i) = \sin(pos / 10000^{2i/dim})
        \text{PE}(pos, 2i+1) = \cos(pos / 10000^{2i/dim})

    where ``pos`` is the token position and ``i`` the embedding index.

    Args:
        dropout: dropout probability applied after the encoding is added.
        dim: embedding dimension of the inputs.
        max_len: longest sequence the precomputed table supports (default=5000).

    Examples:
        >>> pos_encoder = PositionalEncoding(dropout=0.1, dim=d_model)
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
        even_dims = paddle.arange(0, dim, 2).astype('float32')
        div_term = paddle.exp(even_dims * (-math.log(10000.0) / dim))
        angles = positions * div_term

        table = paddle.zeros([max_len, dim])
        table[:, 0::2] = paddle.sin(angles)
        table[:, 1::2] = paddle.cos(angles)
        # Stored as (max_len, 1, dim) so it broadcasts over a seq-first batch.
        self.register_buffer('pe', paddle.transpose(table.unsqueeze(0), [1, 0, 2]))

    def forward(self, x):
        """Return ``x`` plus its position encodings, after dropout.

        Args:
            x: ``(batch size, sequence length, embed dim)``.

        Returns:
            Tensor of the same shape as ``x``.

        Examples:
            >>> output = pos_encoder(x)
        """
        seq_first = x.transpose([1, 0, 2])
        seq_len = paddle.shape(seq_first)[0]
        seq_first = seq_first + self.pe[:seq_len, :]
        return self.dropout(seq_first).transpose([1, 0, 2])
T
Topdu 已提交
507 508 509


class PositionalEncoding_2d(nn.Layer):
    r"""Add input-conditioned 2-D sinusoidal position encodings to a feature map.

    The same sinusoidal table is indexed by column (width) and by row (height).
    Each slice is scaled channel-wise by a linear projection of the globally
    average-pooled input, then added to the feature map.

    .. math::
        \text{PE}(pos, 2i) = \sin(pos / 10000^{2i/dim})
        \text{PE}(pos, 2i+1) = \cos(pos / 10000^{2i/dim})

    where ``pos`` is the row or column index and ``i`` the channel index.

    Args:
        dropout: dropout probability applied to the output.
        dim: channel count of the input feature map.
        max_len: largest height/width the table supports (default=5000).

    Examples:
        >>> pos_encoder = PositionalEncoding_2d(dropout=0.1, dim=d_model)
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding_2d, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = paddle.zeros([max_len, dim])
        position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
        div_term = paddle.exp(
            paddle.arange(0, dim, 2).astype('float32') *
            (-math.log(10000.0) / dim))
        pe[:, 0::2] = paddle.sin(position * div_term)
        pe[:, 1::2] = paddle.cos(position * div_term)
        # (max_len, 1, dim)
        pe = paddle.transpose(paddle.unsqueeze(pe, 0), [1, 0, 2])
        self.register_buffer('pe', pe)

        self.avg_pool_1 = nn.AdaptiveAvgPool2D((1, 1))
        self.linear1 = nn.Linear(dim, dim)
        # Start the gating weights at ones (set_value as elsewhere in the file,
        # instead of the torch-style `.data.fill_`).
        self.linear1.weight.set_value(paddle.ones_like(self.linear1.weight))
        self.avg_pool_2 = nn.AdaptiveAvgPool2D((1, 1))
        self.linear2 = nn.Linear(dim, dim)
        self.linear2.weight.set_value(paddle.ones_like(self.linear2.weight))

    def forward(self, x):
        """Add width and height encodings to ``x`` and flatten the spatial axes.

        Args:
            x: 4-D feature map; assumed ``(B, dim, H, W)`` (NCHW, the default
                layout of the pooling layers) -- TODO confirm with callers.

        Returns:
            ``(H * W, B, dim)`` after dropout.

        Examples:
            >>> output = pos_encoder(x)
        """
        x_shape = paddle.shape(x)

        # Width encoding: (W, 1, dim) * (1, B, dim) -> (B, dim, 1, W).
        w_pe = self.pe[:x_shape[-1], :]
        w1 = self.linear1(self.avg_pool_1(x).squeeze([2, 3])).unsqueeze(0)
        w_pe = w_pe * w1
        w_pe = paddle.transpose(w_pe, [1, 2, 0])
        w_pe = paddle.unsqueeze(w_pe, 2)

        # Height encoding: (H, 1, dim) * (1, B, dim) -> (B, dim, H, 1).
        # Slice by the real height; `paddle.shape(x).shape[-2]` indexed the
        # 1-element shape of the shape tensor and raised IndexError.
        h_pe = self.pe[:x_shape[-2], :]
        w2 = self.linear2(self.avg_pool_2(x).squeeze([2, 3])).unsqueeze(0)
        h_pe = h_pe * w2
        h_pe = paddle.transpose(h_pe, [1, 2, 0])
        h_pe = paddle.unsqueeze(h_pe, 3)

        x = x + w_pe + h_pe
        # (B, C, H, W) -> (B, C, H*W) -> (H*W, B, C)
        x = paddle.transpose(
            paddle.reshape(x,
                           [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]),
            [2, 0, 1])

        return self.dropout(x)


class Embeddings(nn.Layer):
    """Token-id embedding lookup, optionally scaled by ``sqrt(d_model)``.

    The table is initialised from N(0, d_model**-0.5) with NumPy.

    Args:
        d_model: embedding size.
        vocab: number of token ids.
        padding_idx: forwarded to ``nn.Embedding``.
        scale_embedding: multiply looked-up vectors by ``sqrt(d_model)``.
    """

    def __init__(self, d_model, vocab, padding_idx=None, scale_embedding=True):
        super(Embeddings, self).__init__()
        self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)
        init_table = np.random.normal(
            0.0, d_model**-0.5, (vocab, d_model)).astype(np.float32)
        self.embedding.weight.set_value(init_table)
        self.d_model = d_model
        self.scale_embedding = scale_embedding

    def forward(self, x):
        """Look up the ids in ``x`` and apply the optional scaling."""
        embedded = self.embedding(x)
        if not self.scale_embedding:
            return embedded
        return embedded * math.sqrt(self.d_model)


class Beam():
    """Beam-search bookkeeping for a single decoded instance.

    For every step it records the ``size`` best cumulative scores, the
    back-pointer to the beam each one extends, and the chosen token id.
    Beam 0 is seeded with token id 2; the instance is finished once the
    top-scoring beam emits token id 3.

    Args:
        size: beam width.
        device: unused; kept for backward compatibility.
    """

    def __init__(self, size, device=False):

        self.size = size
        self._done = False
        # The score for each translation on the beam.
        self.scores = paddle.zeros((size, ), dtype=paddle.float32)
        self.all_scores = []
        # The backpointers at each time-step.
        self.prev_ks = []
        # The outputs at each time-step.
        self.next_ys = [paddle.full((size, ), 0, dtype=paddle.int64)]
        # Only beam 0 starts from the start token; advance() reads only row 0
        # of word_prob on the first step, so the other beams are ignored then.
        self.next_ys[0][0] = 2

    def get_current_state(self):
        "Get the outputs for the current timestep."
        return self.get_tentative_hypothesis()

    def get_current_origin(self):
        "Get the backpointers for the current timestep."
        return self.prev_ks[-1]

    @property
    def done(self):
        return self._done

    def advance(self, word_prob):
        """Extend every beam by one token and keep the ``size`` best.

        Args:
            word_prob: ``(size, num_words)`` scores for the next token of each beam.

        Returns:
            True once the top beam has produced token id 3.
        """
        num_words = word_prob.shape[1]

        # Sum the previous scores.
        if len(self.prev_ks) > 0:
            beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
        else:
            # First step: all beams share the same prefix, so use one row.
            beam_lk = word_prob[0]

        flat_beam_lk = beam_lk.reshape([-1])
        # topk(k, axis=0, largest=True, sorted=True): scores come out descending.
        best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True,
                                                        True)  # 1st sort
        self.all_scores.append(self.scores)
        self.scores = best_scores
        # bestScoresId is flattened as a (beam x word) array,
        # so we need to calculate which word and beam each score came from
        prev_k = best_scores_id // num_words
        self.prev_ks.append(prev_k)
        self.next_ys.append(best_scores_id - prev_k * num_words)
        # End condition is when top-of-beam is EOS.
        if self.next_ys[-1][0] == 3:
            self._done = True
            self.all_scores.append(self.scores)

        return self._done

    def sort_scores(self):
        """Return the beam scores and their indices.

        ``self.scores`` is already in descending order (topk with sorted=True),
        so the indices are simply ``0 .. size-1``.
        """
        return self.scores, paddle.to_tensor(
            list(range(int(self.scores.shape[0]))), dtype='int32')

    def get_the_best_score_and_idx(self):
        "Get the score of the best in the beam."
        scores, ids = self.sort_scores()
        # Scores are sorted best-first, so the best entry is at index 0
        # (index 1 returned the runner-up and fails for a beam of size 1).
        return scores[0], ids[0]

    def get_tentative_hypothesis(self):
        "Get the decoded sequence for the current timestep."
        if len(self.next_ys) == 1:
            dec_seq = self.next_ys[0].unsqueeze(1)
        else:
            _, keys = self.sort_scores()
            hyps = [self.get_hypothesis(k) for k in keys]
            # Re-attach the start token, which get_hypothesis omits.
            hyps = [[2] + h for h in hyps]
            dec_seq = paddle.to_tensor(hyps, dtype='int64')
        return dec_seq

    def get_hypothesis(self, k):
        """ Walk back to construct the full hypothesis. """
        hyp = []
        for j in range(len(self.prev_ks) - 1, -1, -1):
            hyp.append(self.next_ys[j + 1][k])
            k = self.prev_ks[j][k]
        return list(map(lambda x: x.item(), hyp[::-1]))