model.py 25.5 KB
Newer Older
1 2
from functools import partial
import numpy as np
Y
ying 已提交
3

L
Luo Tao 已提交
4 5
import paddle.fluid as fluid
import paddle.fluid.layers as layers
Y
ying 已提交
6

7
from config import *
Y
ying 已提交
8

9 10

def position_encoding_init(n_position, d_pos_vec):
    """
    Generate the initial values for the sinusoid position encoding table.

    For every row ``pos >= 1``, even column ``2i`` holds
    ``sin(pos / 10000^(2i / d_pos_vec))`` and odd column ``2i + 1`` holds
    ``cos(pos / 10000^(2i / d_pos_vec))``. Row 0 is left all zeros.

    Args:
        n_position: Number of positions (rows) in the table.
        d_pos_vec: Length of each position encoding vector (columns).

    Returns:
        A float32 numpy array with shape [n_position, d_pos_vec].
    """
    positions = np.arange(n_position, dtype="float64")[:, np.newaxis]
    # Columns 2i and 2i+1 share one frequency, hence (j // 2) * 2. The
    # division is forced to float explicitly: with Python 2 semantics
    # `2 * (j // 2) / d_pos_vec` is integer division and turns every
    # exponent into 0.
    exponents = (np.arange(d_pos_vec) // 2) * 2 / float(d_pos_vec)
    position_enc = positions / np.power(10000.0, exponents)
    # Row 0 is 0 / x == 0 everywhere; it is excluded from sin/cos so that it
    # stays all zeros.
    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
    return position_enc.astype("float32")


def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         n_head=1,
                         dropout_rate=0.,
                         pre_softmax_shape=None,
                         post_softmax_shape=None,
                         cache=None):
    """
    Multi-Head Attention. Note that attn_bias is added to the logit before
    computing softmax activation to mask certain selected positions so that
    they will not be considered in attention weights.

    Args:
        queries, keys, values: Input variables; all three must be 3-D.
        attn_bias: Variable added to the attention logits, or None to add
            nothing.
        d_key: Per-head size of the query and key projections.
        d_value: Per-head size of the value projection.
        d_model: Size of the final output projection; also used for the
            logit scaling factor.
        n_head: Number of attention heads.
        dropout_rate: Dropout probability applied to the attention weights;
            0 disables dropout.
        pre_softmax_shape, post_softmax_shape: Passed as ``actual_shape`` to
            the reshapes before and after the softmax.
        cache: Optional dict with "k" and "v" entries. When given, this
            call's key/value projections are concatenated after the cached
            ones along axis 1 and the dict entries are replaced with the
            result.

    Returns:
        The output of a bias-free fc projection of size d_model.

    Raises:
        ValueError: If queries, keys or values is not 3-D.
    """
    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: queries, keys and values should all be 3-D tensors.")

    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
                      size=d_key * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        k = layers.fc(input=keys,
                      size=d_key * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        v = layers.fc(input=values,
                      size=d_value * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        return q, k, v

    def __split_heads(x, n_head):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions and then transpose. Specifically, input a tensor with shape
        [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
        with shape [bs, n_head, max_sequence_length, hidden_dim].
        When n_head is 1, x is returned unchanged.
        """
        if n_head == 1:
            return x

        hidden_size = x.shape[-1]
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        reshaped = layers.reshape(
            x=x, shape=[0, 0, n_head, hidden_size // n_head])

        # permute the dimensions into:
        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that it becomes one dimension, which is reverse to __split_heads.
        A 3-D input (what __split_heads returns for n_head == 1) is returned
        unchanged.
        """
        if len(x.shape) == 3:
            return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        # A concrete list is built here: under Python 3 `map(...)` is a lazy
        # iterator, while reshape expects a list or tuple for `shape`.
        return layers.reshape(
            x=trans_x,
            shape=[0, 0, int(trans_x.shape[2] * trans_x.shape[3])])

    def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
        """
        Scaled Dot-Product Attention
        """
        # NOTE(review): logits are scaled by d_model**-0.5, whereas Vaswani et
        # al. scale by the per-head key size (d_key**-0.5). Left unchanged so
        # existing trained models keep behaving the same.
        scaled_q = layers.scale(x=q, scale=d_model**-0.5)
        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
        logits = product
        # Explicit None check: the intent is "was a bias supplied", not the
        # truthiness of a framework variable.
        if attn_bias is not None:
            logits = layers.elementwise_add(x=product, y=attn_bias)
        # Softmax runs on a 2-D view whose last dimension is that of
        # `product`; the weights are then reshaped back to product.shape.
        weights = layers.reshape(
            x=logits,
            shape=[-1, product.shape[-1]],
            actual_shape=pre_softmax_shape,
            act="softmax")
        weights = layers.reshape(
            x=weights, shape=product.shape, actual_shape=post_softmax_shape)
        if dropout_rate:
            weights = layers.dropout(
                weights,
                dropout_prob=dropout_rate,
                seed=ModelHyperParams.dropout_seed,
                is_test=False)
        out = layers.matmul(weights, v)
        return out

    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

    if cache is not None:  # use cache and concat time steps
        # Append this call's projections after the cached ones along axis 1
        # and store the result back into the cache dict.
        k = cache["k"] = layers.concat([cache["k"], k], axis=1)
        v = cache["v"] = layers.concat([cache["v"], v], axis=1)

    q = __split_heads(q, n_head)
    k = __split_heads(k, n_head)
    v = __split_heads(v, n_head)

    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model,
                                                  dropout_rate)

    out = __combine_heads(ctx_multiheads)

    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         bias_attr=False,
                         num_flatten_dims=2)
    return proj_out


def positionwise_feed_forward(x, d_inner_hid, d_hid):
    """
    Position-wise Feed-Forward Networks.

    Two fc layers with num_flatten_dims=2, applied to each position
    separately and identically: the first widens to d_inner_hid with a ReLU,
    the second maps back to d_hid without an activation argument.
    """
    inner = layers.fc(input=x, size=d_inner_hid, num_flatten_dims=2, act="relu")
    return layers.fc(input=inner, size=d_hid, num_flatten_dims=2)


158
def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
    """
    Add residual connection, layer normalization and dropout to the out
    tensor optionally according to the value of process_cmd.
    This will be used before or after multi-head attention and position-wise
    feed-forward networks.

    Args:
        prev_out: Residual input added by the "a" command, or None to skip
            the residual connection.
        out: The tensor to process.
        process_cmd: String of single-character commands applied in order:
            "a" adds prev_out, "n" applies layer normalization starting at
            the last axis, "d" applies dropout when dropout_rate is non-zero.
            Other characters are ignored.
        dropout_rate: Dropout probability used by the "d" command.

    Returns:
        The processed tensor.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            # Compare with None explicitly: the question is whether a
            # residual was supplied, and the truthiness of a tensor/array is
            # ambiguous (numpy arrays raise on it).
            if prev_out is not None:
                out = out + prev_out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.initializer.Constant(1.),
                bias_attr=fluid.initializer.Constant(0.))
        elif cmd == "d":  # add dropout
            if dropout_rate:
                out = layers.dropout(
                    out,
                    dropout_prob=dropout_rate,
                    seed=ModelHyperParams.dropout_seed,
                    is_test=False)
    return out


# Post-processing receives the residual input explicitly, so it is just
# pre_post_process_layer itself.
post_process_layer = pre_post_process_layer
# Pre-processing has no residual input: prev_out is bound to None, so callers
# pass only (out, process_cmd[, dropout_rate]).
pre_process_layer = partial(pre_post_process_layer, None)


def prepare_encoder(src_word,
                    src_pos,
                    src_vocab_size,
                    src_emb_dim,
                    src_max_len,
                    dropout_rate=0.,
                    src_data_shape=None,
                    word_emb_param_name=None,
                    pos_enc_param_name=None):
    """Add word embeddings and position encodings.

    The word embedding (initialized from Normal(0, src_emb_dim**-0.5)) is
    scaled by src_emb_dim**0.5. It is then summed with a non-trainable
    position embedding looked up by src_pos. The sum is reshaped with
    src_data_shape as the actual shape, and dropout is applied when
    dropout_rate is non-zero.

    The output tensor has a shape of:
    [batch_size, max_src_length_in_batch, d_model].
    This module is used at the bottom of the encoder stacks.
    """
    word_init = fluid.initializer.Normal(0., src_emb_dim**-0.5)
    word_emb = layers.embedding(
        src_word,
        size=[src_vocab_size, src_emb_dim],
        param_attr=fluid.ParamAttr(
            name=word_emb_param_name, initializer=word_init))
    word_emb = layers.scale(x=word_emb, scale=src_emb_dim**0.5)

    pos_emb = layers.embedding(
        src_pos,
        size=[src_max_len, src_emb_dim],
        param_attr=fluid.ParamAttr(
            name=pos_enc_param_name, trainable=False))

    # NOTE(review): batch_size and seq_len are not defined in this function;
    # presumably they come from `from config import *` — TODO confirm.
    embedded = layers.reshape(
        x=word_emb + pos_emb,
        shape=[batch_size, seq_len, src_emb_dim],
        actual_shape=src_data_shape)
    if not dropout_rate:
        return embedded
    return layers.dropout(
        embedded,
        dropout_prob=dropout_rate,
        seed=ModelHyperParams.dropout_seed,
        is_test=False)
224 225 226 227 228 229 230 231


# Encoder and decoder share one embedding-preparation routine. They differ
# only in which position-encoding parameter they read. Both partials wrap the
# original function; the right-hand side is evaluated before either name is
# rebound.
prepare_encoder, prepare_decoder = (
    partial(prepare_encoder, pos_enc_param_name=pos_enc_param_names[0]),
    partial(prepare_encoder, pos_enc_param_name=pos_enc_param_names[1]),
)


Y
ying 已提交
232 233 234 235 236 237 238
def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.,
                  pre_softmax_shape=None,
                  post_softmax_shape=None):
    """The encoder layers that can be stacked to form a deep encoder.

    A multi-head self-attention sub-layer is followed by a position-wise
    feed-forward sub-layer. Each sub-layer's output goes through
    post_process_layer with the command string "dan": dropout, then a
    residual add of the sub-layer's input, then layer normalization.
    """
    self_attn = multi_head_attention(
        queries=enc_input,
        keys=enc_input,
        values=enc_input,
        attn_bias=attn_bias,
        d_key=d_key,
        d_value=d_value,
        d_model=d_model,
        n_head=n_head,
        dropout_rate=dropout_rate,
        pre_softmax_shape=pre_softmax_shape,
        post_softmax_shape=post_softmax_shape)
    attn_block = post_process_layer(enc_input, self_attn, "dan", dropout_rate)

    ffn_out = positionwise_feed_forward(attn_block, d_inner_hid, d_model)
    return post_process_layer(attn_block, ffn_out, "dan", dropout_rate)


def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.,
            pre_softmax_shape=None,
            post_softmax_shape=None):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer.

    Each layer's output is the next layer's input. Every layer receives the
    same attn_bias, dropout_rate and softmax shapes. When n_layer is 0 the
    stack is empty and enc_input is returned unchanged.
    """
    enc_output = enc_input
    for _ in range(n_layer):
        enc_output = encoder_layer(
            enc_output,
            attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate,
            pre_softmax_shape,
            post_softmax_shape, )
    return enc_output


Y
ying 已提交
288 289 290 291 292 293 294 295 296
def decoder_layer(dec_input,
                  enc_output,
                  slf_attn_bias,
                  dec_enc_attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.,
                  slf_attn_pre_softmax_shape=None,
                  slf_attn_post_softmax_shape=None,
                  src_attn_pre_softmax_shape=None,
                  src_attn_post_softmax_shape=None,
                  cache=None):
    """ The layer to be stacked in decoder part.

    Three sub-layers run in sequence:
    1. self-attention over dec_input with slf_attn_bias (using `cache` for
       keys/values when given),
    2. encoder-decoder attention whose keys and values are enc_output, with
       dec_enc_attn_bias,
    3. a position-wise feed-forward network.
    Each sub-layer's output goes through post_process_layer with "dan":
    dropout, then a residual add of the sub-layer's input, then layer
    normalization.
    """
    self_attn = multi_head_attention(
        queries=dec_input,
        keys=dec_input,
        values=dec_input,
        attn_bias=slf_attn_bias,
        d_key=d_key,
        d_value=d_value,
        d_model=d_model,
        n_head=n_head,
        dropout_rate=dropout_rate,
        pre_softmax_shape=slf_attn_pre_softmax_shape,
        post_softmax_shape=slf_attn_post_softmax_shape,
        cache=cache)
    # dropout + residual connection + layer normalization
    self_attn = post_process_layer(dec_input, self_attn, "dan", dropout_rate)

    cross_attn = multi_head_attention(
        queries=self_attn,
        keys=enc_output,
        values=enc_output,
        attn_bias=dec_enc_attn_bias,
        d_key=d_key,
        d_value=d_value,
        d_model=d_model,
        n_head=n_head,
        dropout_rate=dropout_rate,
        pre_softmax_shape=src_attn_pre_softmax_shape,
        post_softmax_shape=src_attn_post_softmax_shape)
    # dropout + residual connection + layer normalization
    cross_attn = post_process_layer(self_attn, cross_attn, "dan", dropout_rate)

    ffn_out = positionwise_feed_forward(cross_attn, d_inner_hid, d_model)
    # dropout + residual connection + layer normalization
    return post_process_layer(cross_attn, ffn_out, "dan", dropout_rate)


Y
ying 已提交
354 355 356 357 358 359 360 361 362 363
def decoder(dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.,
            slf_attn_pre_softmax_shape=None,
            slf_attn_post_softmax_shape=None,
            src_attn_pre_softmax_shape=None,
            src_attn_post_softmax_shape=None,
            caches=None):
    """
    The decoder is composed of a stack of identical decoder_layer layers.

    Each layer's output is the next layer's input. Layer i receives
    caches[i] when caches is given, otherwise None. When n_layer is 0 the
    stack is empty and dec_input is returned unchanged.
    """
    dec_output = dec_input
    for layer_idx in range(n_layer):
        layer_cache = None if caches is None else caches[layer_idx]
        dec_output = decoder_layer(
            dec_output,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate,
            slf_attn_pre_softmax_shape,
            slf_attn_post_softmax_shape,
            src_attn_pre_softmax_shape,
            src_attn_post_softmax_shape,
            layer_cache, )
    return dec_output


394
def make_all_inputs(input_fields):
    """
    Define the input data layers for the transformer model.

    For each name in input_fields, the entry input_descs[name] supplies the
    shape (item 0), the dtype (item 1) and, only when the entry has exactly
    three items, the lod_level (item 2; otherwise 0). Every layer is created
    with append_batch_size=False.

    Returns:
        A list of the created data variables, in the order of input_fields.
    """
    def _data_layer(field_name):
        desc = input_descs[field_name]
        return layers.data(
            name=field_name,
            shape=desc[0],
            dtype=desc[1],
            lod_level=desc[2] if len(desc) == 3 else 0,
            append_batch_size=False)

    return [_data_layer(field_name) for field_name in input_fields]
409 410


Y
ying 已提交
411 412 413 414 415 416 417 418 419 420
def transformer(
        src_vocab_size,
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        label_smooth_eps, ):
    """
    Build the training graph of the Transformer together with its loss.

    Creates the encoder, decoder and label input layers. The loss is a
    softmax cross entropy, with label smoothing when label_smooth_eps is
    non-zero, weighted by the `weights` input.

    Returns:
        (sum_cost, avg_cost, predict, token_num), where avg_cost is
        sum_cost / token_num and token_num is the sum of the weights.

    Raises:
        ValueError: If weight_sharing is set but src_vocab_size and
            trg_vocab_size differ.
    """
    if weight_sharing and src_vocab_size != trg_vocab_size:
        # With weight sharing the decoder reuses the source word embedding
        # table (word_emb_param_names[0]) for its input embedding and output
        # projection, so both vocabularies must have the same size.
        raise ValueError(
            "Vocabularies in source and target should be same for weight sharing."
        )

    enc_inputs = make_all_inputs(encoder_data_input_fields +
                                 encoder_util_input_fields)

    enc_output = wrap_encoder(
        src_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        enc_inputs, )

    # The last decoder data field is not created here. Judging by the
    # standalone unpacking in wrap_decoder it is the encoder output, which is
    # supplied by wrap_encoder above instead — TODO confirm against config.
    dec_inputs = make_all_inputs(decoder_data_input_fields[:-1] +
                                 decoder_util_input_fields)

    predict = wrap_decoder(
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        dec_inputs,
        enc_output, )

    # Padding indices do not contribute to the total loss. The weights are
    # used to cancel padding indices in calculating the loss.
    label, weights = make_all_inputs(label_data_input_fields)
    if label_smooth_eps:
        label = layers.label_smooth(
            label=layers.one_hot(
                input=label, depth=trg_vocab_size),
            epsilon=label_smooth_eps)
    cost = layers.softmax_with_cross_entropy(
        logits=predict,
        label=label,
        soft_label=bool(label_smooth_eps))
    weighted_cost = cost * weights
    sum_cost = layers.reduce_sum(weighted_cost)
    token_num = layers.reduce_sum(weights)
    avg_cost = sum_cost / token_num
    return sum_cost, avg_cost, predict, token_num
478 479 480 481 482 483 484 485 486 487 488


def wrap_encoder(src_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 weight_sharing,
                 enc_inputs=None):
    """
    The wrapper assembles together all needed layers for the encoder.

    When enc_inputs is None, the six encoder inputs are created here with
    make_all_inputs. This is used to build an independent encoder program
    for inference. Otherwise enc_inputs must hold those same six items in
    the same order. The word embedding always uses word_emb_param_names[0];
    weight_sharing is not read by this function.
    """
    if enc_inputs is None:
        enc_inputs = make_all_inputs(encoder_data_input_fields +
                                     encoder_util_input_fields)
    (src_word, src_pos, src_slf_attn_bias, src_data_shape,
     slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape) = enc_inputs

    embedded = prepare_encoder(
        src_word,
        src_pos,
        src_vocab_size,
        d_model,
        max_length,
        dropout_rate,
        src_data_shape,
        word_emb_param_name=word_emb_param_names[0])
    return encoder(
        embedded,
        src_slf_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape, )


def wrap_decoder(trg_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 weight_sharing,
                 dec_inputs=None,
                 enc_output=None,
                 caches=None):
    """
    The wrapper assembles together all needed layers for the decoder.

    When dec_inputs is None, every decoder input, including the encoder
    output (which overrides the enc_output argument), is created here with
    make_all_inputs. This is used for an independent decoder program in
    inference. In that case the result has a softmax applied. Otherwise
    dec_inputs supplies nine items and the raw logits are returned.
    With weight_sharing, the word embedding and the output projection both
    use word_emb_param_names[0].

    Returns:
        A 2-D tensor reshaped to [-1, trg_vocab_size].
    """
    standalone = dec_inputs is None
    if standalone:
        (trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output,
         trg_data_shape, slf_attn_pre_softmax_shape,
         slf_attn_post_softmax_shape, src_attn_pre_softmax_shape,
         src_attn_post_softmax_shape) = make_all_inputs(
             decoder_data_input_fields + decoder_util_input_fields)
    else:
        (trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
         trg_data_shape, slf_attn_pre_softmax_shape,
         slf_attn_post_softmax_shape, src_attn_pre_softmax_shape,
         src_attn_post_softmax_shape) = dec_inputs

    if weight_sharing:
        emb_param_name = word_emb_param_names[0]
    else:
        emb_param_name = word_emb_param_names[1]
    embedded = prepare_decoder(
        trg_word,
        trg_pos,
        trg_vocab_size,
        d_model,
        max_length,
        dropout_rate,
        trg_data_shape,
        word_emb_param_name=emb_param_name)
    dec_output = decoder(
        embedded,
        enc_output,
        trg_slf_attn_bias,
        trg_src_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape,
        src_attn_pre_softmax_shape,
        src_attn_post_softmax_shape,
        caches, )

    if weight_sharing:
        # Output projection reuses the shared embedding table, transposed.
        vocab_logits = layers.matmul(
            x=dec_output,
            y=fluid.get_var(word_emb_param_names[0]),
            transpose_y=True)
    else:
        vocab_logits = layers.fc(input=dec_output,
                                 size=trg_vocab_size,
                                 bias_attr=False,
                                 num_flatten_dims=2)
    # Softmax probabilities for the standalone (inference) program, logits
    # otherwise.
    return layers.reshape(
        x=vocab_logits,
        shape=[-1, trg_vocab_size],
        act="softmax" if standalone else None)
602 603 604 605 606 607 608 609 610 611 612 613 614


def fast_decode(
        src_vocab_size,
        trg_vocab_size,
        max_in_len,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        beam_size,
        max_out_len,
        eos_idx, ):
    """
    Use beam search to decode. Caches will be used to store states of history
    steps which can make the decoding faster.

    The encoder runs once. A While loop then extends the beams one step at a
    time until max_out_len steps have run or layers.beam_search selects no
    ids. Per-step ids and scores are kept in tensor arrays and finally
    decoded with layers.beam_search_decode.

    Returns:
        (finished_ids, finished_scores) from layers.beam_search_decode.
    """
    enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
                              d_key, d_value, d_model, d_inner_hid,
                              dropout_rate, weight_sharing)
    (start_tokens, init_scores, trg_src_attn_bias, trg_data_shape,
     slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
     src_attn_pre_softmax_shape, src_attn_post_softmax_shape,
     attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta) = \
        make_all_inputs(fast_decoder_data_input_fields +
                        fast_decoder_util_input_fields)

    def _empty_kv():
        # Zero-filled tensor of shape [-1, 0, d_model] whose batch size is
        # taken from start_tokens; the decoder concatenates onto it per step.
        return layers.fill_constant_batch_size_like(
            input=start_tokens,
            shape=[-1, 0, d_model],
            dtype=enc_output.dtype,
            value=0)

    def _beam_search():
        max_len = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=max_out_len)
        step_idx = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=0)
        cond = layers.less_than(x=step_idx, y=max_len)
        while_op = layers.While(cond)
        # Array states hold the result of every step, indexed by step_idx.
        ids = layers.array_write(start_tokens, step_idx)
        scores = layers.array_write(init_scores, step_idx)
        # Cell states are overwritten at each step. The per-layer key/value
        # caches keep the history so the decoder only processes the newest
        # position.
        caches = [{"k": _empty_kv(), "v": _empty_kv()} for _ in range(n_layer)]

        with while_op.block():
            pre_ids = layers.array_read(array=ids, i=step_idx)
            pre_scores = layers.array_read(array=scores, i=step_idx)
            # sequence_expand gathers sequences according to the LoD of
            # pre_scores, so the states follow the beams selected so far.
            pre_src_attn_bias = layers.sequence_expand(
                x=trg_src_attn_bias, y=pre_scores)
            pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
            pre_caches = [{
                key: layers.sequence_expand(x=cache[key], y=pre_scores)
                for key in ("k", "v")
            } for cache in caches]

            # Position of the token being generated: a ones column (built from
            # pre_enc_output, since pre_ids carries LoD) times step_idx + 1.
            ones = layers.fill_constant_batch_size_like(
                input=pre_enc_output,
                value=1,
                shape=[-1, 1],
                dtype=pre_ids.dtype)
            next_step = layers.increment(x=step_idx, value=1.0, in_place=False)
            pre_pos = layers.elementwise_mul(x=ones, y=next_step, axis=0)

            logits = wrap_decoder(
                trg_vocab_size,
                max_in_len,
                n_layer,
                n_head,
                d_key,
                d_value,
                d_model,
                d_inner_hid,
                dropout_rate,
                weight_sharing,
                dec_inputs=(
                    pre_ids, pre_pos, None, pre_src_attn_bias, trg_data_shape,
                    slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
                    src_attn_pre_softmax_shape, src_attn_post_softmax_shape),
                enc_output=pre_enc_output,
                caches=pre_caches)

            probs = layers.softmax(logits)
            topk_scores, topk_indices = layers.topk(input=probs, k=beam_size)
            # Accumulated log-probability of each candidate extension.
            accu_scores = layers.elementwise_add(
                x=layers.log(topk_scores),
                y=layers.reshape(
                    pre_scores, shape=[-1]),
                axis=0)
            # beam_search op uses lod to distinguish branches.
            topk_indices = layers.lod_reset(topk_indices, pre_ids)
            selected_ids, selected_scores = layers.beam_search(
                pre_ids=pre_ids,
                pre_scores=pre_scores,
                ids=topk_indices,
                scores=accu_scores,
                beam_size=beam_size,
                end_id=eos_idx)
            layers.increment(x=step_idx, value=1.0, in_place=True)

            # Update states for the next iteration.
            layers.array_write(selected_ids, i=step_idx, array=ids)
            layers.array_write(selected_scores, i=step_idx, array=scores)
            layers.assign(pre_src_attn_bias, trg_src_attn_bias)
            layers.assign(pre_enc_output, enc_output)
            for expanded, cache in zip(pre_caches, caches):
                layers.assign(expanded["k"], cache["k"])
                layers.assign(expanded["v"], cache["v"])
            for shape_var, delta in (
                (slf_attn_pre_softmax_shape, attn_pre_softmax_shape_delta),
                (slf_attn_post_softmax_shape, attn_post_softmax_shape_delta),
            ):
                layers.assign(
                    layers.elementwise_add(x=shape_var, y=delta), shape_var)

            # Continue while under max_len and some ids are still selected.
            length_cond = layers.less_than(x=step_idx, y=max_len)
            finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
            layers.logical_and(x=length_cond, y=finish_cond, out=cond)

        finished_ids, finished_scores = layers.beam_search_decode(
            ids, scores, beam_size=beam_size, end_id=eos_idx)
        return finished_ids, finished_scores

    return _beam_search()