model.py 25.3 KB
Newer Older
1 2
from functools import partial
import numpy as np
Y
ying 已提交
3

L
Luo Tao 已提交
4 5
import paddle.fluid as fluid
import paddle.fluid.layers as layers
Y
ying 已提交
6

7
from config import *
Y
ying 已提交
8

9 10

def position_encoding_init(n_position, d_pos_vec):
    """
    Generate the initial values for the sinusoid position encoding table.

    For every row ``pos >= 1``, even columns ``2i`` hold
    ``sin(pos / 10000^(2i / d_pos_vec))`` and odd columns ``2i+1`` hold
    ``cos(pos / 10000^(2i / d_pos_vec))``. Row 0 is left as all zeros (it is
    never passed through sin/cos).

    Args:
        n_position (int): Number of positions, i.e. rows of the table.
        d_pos_vec (int): Size of each position encoding vector (columns).

    Returns:
        numpy.ndarray: float32 array of shape [n_position, d_pos_vec].
    """
    # Per-column denominators 10000^(2 * (j // 2) / d_pos_vec). The float
    # literal keeps the exponent from truncating to 0 under Python 2 integer
    # division.
    col_idx = np.arange(d_pos_vec)
    denominators = np.power(10000, 2.0 * (col_idx // 2) / d_pos_vec)
    positions = np.arange(n_position, dtype="float64").reshape(-1, 1)
    # Broadcasting yields the [n_position, d_pos_vec] angle table. Row 0 is
    # 0 / denominator, i.e. all zeros, and stays that way below. Building it
    # as a 2-D array also keeps n_position == 0 from breaking the slicing.
    position_enc = positions / denominators
    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
    return position_enc.astype("float32")


def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         n_head=1,
                         dropout_rate=0.,
                         pre_softmax_shape=None,
                         post_softmax_shape=None,
                         cache=None):
    """
    Multi-Head Attention. Note that attn_bias is added to the logit before
    computing softmax activation to mask certain selected positions so that
    they will not be considered in attention weights.

    Args:
        queries, keys, values: Input tensors; all three must be 3-D.
        attn_bias: Tensor added to the attention logits before the softmax,
            or None for no bias.
        d_key (int): Per-head size of the query/key projections.
        d_value (int): Per-head size of the value projection.
        d_model (int): Output size of the final projection; also used for
            the logit scaling factor.
        n_head (int): Number of attention heads.
        dropout_rate (float): Dropout probability on the attention weights;
            0 disables dropout.
        pre_softmax_shape, post_softmax_shape: Optional runtime shape tensors
            passed as ``actual_shape`` to the reshapes around the softmax.
        cache (dict, optional): Holds "k" and "v" tensors from previous
            steps. The new key/value projections are concatenated onto them
            along axis 1 and the results are written back into the dict.

    Returns:
        The attention output after a bias-free projection to size d_model.

    Raises:
        ValueError: If queries, keys and values are not all 3-D.
    """
    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: quries, keys and values should all be 3-D tensors.")

    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.

        Each projection is a bias-free fc layer over the last dimension
        (num_flatten_dims=2): d_key * n_head outputs for queries and keys,
        d_value * n_head outputs for values.
        """
        q = layers.fc(input=queries,
                      size=d_key * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        k = layers.fc(input=keys,
                      size=d_key * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        v = layers.fc(input=values,
                      size=d_value * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        return q, k, v

    def __split_heads(x, n_head):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions and then transpose. Specifically, input a tensor with shape
        [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
        with shape [bs, n_head, max_sequence_length, hidden_dim].
        With a single head, x is returned unchanged (still 3-D).
        """
        if n_head == 1:
            return x

        hidden_size = x.shape[-1]
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        reshaped = layers.reshape(
            x=x, shape=[0, 0, n_head, hidden_size // n_head])

        # permute the dimensions into:
        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that it becomes one dimension, which is reverse to __split_heads.
        3-D inputs (the single-head case) are returned unchanged.
        """
        if len(x.shape) == 3:
            return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size. The shape is built
        # as a real list: under Python 3, map() would hand reshape a lazy
        # iterator instead of a list.
        return layers.reshape(
            x=trans_x,
            shape=[0, 0, int(trans_x.shape[2] * trans_x.shape[3])])

    def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
        """
        Scaled Dot-Product Attention: softmax(q * k^T * d_model**-0.5 + bias)
        applied to v, with optional dropout on the attention weights.
        """
        # NOTE(review): "Attention Is All You Need" scales the logits by
        # d_key**-0.5; this scales by d_model**-0.5. Kept unchanged because
        # changing it alters the outputs of existing models. Confirm intent.
        scaled_q = layers.scale(x=q, scale=d_model**-0.5)
        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
        # Add the bias only when one was supplied. Compare against None
        # explicitly instead of relying on the truth value of a Variable.
        logits = (layers.elementwise_add(x=product, y=attn_bias)
                  if attn_bias is not None else product)
        # Flatten to 2-D so the softmax runs over the last dimension, then
        # restore the original shape.
        weights = layers.reshape(
            x=logits,
            shape=[-1, product.shape[-1]],
            actual_shape=pre_softmax_shape,
            act="softmax")
        weights = layers.reshape(
            x=weights, shape=product.shape, actual_shape=post_softmax_shape)
        if dropout_rate:
            weights = layers.dropout(
                weights, dropout_prob=dropout_rate, is_test=False)
        out = layers.matmul(weights, v)
        return out

    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

    if cache is not None:  # use cache and concat time steps
        k = cache["k"] = layers.concat([cache["k"], k], axis=1)
        v = cache["v"] = layers.concat([cache["v"], v], axis=1)

    q = __split_heads(q, n_head)
    k = __split_heads(k, n_head)
    v = __split_heads(v, n_head)

    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model,
                                                  dropout_rate)

    out = __combine_heads(ctx_multiheads)

    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         bias_attr=False,
                         num_flatten_dims=2)
    return proj_out


def positionwise_feed_forward(x, d_inner_hid, d_hid):
    """
    Position-wise Feed-Forward Networks.

    Two fully-connected layers applied over the last dimension of ``x``
    (``num_flatten_dims=2``), so every position is transformed separately
    and identically. The first layer expands to ``d_inner_hid`` units with a
    ReLU. The second projects back to ``d_hid`` units with no activation.
    """
    relu_hidden = layers.fc(input=x,
                            size=d_inner_hid,
                            num_flatten_dims=2,
                            act="relu")
    return layers.fc(input=relu_hidden, size=d_hid, num_flatten_dims=2)


155
def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
    """
    Add residual connection, layer normalization and dropout to the out
    tensor optionally according to the value of process_cmd.
    This will be used before or after multi-head attention and position-wise
    feed-forward networks.

    Args:
        prev_out: Residual input added by the "a" command, or None to skip
            the residual add.
        out: The tensor to process.
        process_cmd (str): Single-character commands applied left to right.
            "a" adds ``prev_out``, "n" applies layer normalization over the
            last dimension, and "d" applies dropout. Any other character is
            ignored.
        dropout_rate (float): Dropout probability used by "d"; 0 disables it.

    Returns:
        The processed tensor.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            # Test for None explicitly. The truth value of a tensor-like
            # prev_out may be undefined or ambiguous (e.g. multi-element
            # arrays raise).
            if prev_out is not None:
                out = out + prev_out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.initializer.Constant(1.),
                bias_attr=fluid.initializer.Constant(0.))
        elif cmd == "d":  # add dropout
            if dropout_rate:
                out = layers.dropout(
                    out, dropout_prob=dropout_rate, is_test=False)
    return out


# Post-processing receives the residual input explicitly. Pre-processing has
# none, so its prev_out is fixed to None.
post_process_layer = pre_post_process_layer
pre_process_layer = partial(pre_post_process_layer, None)


def prepare_encoder(src_word,
                    src_pos,
                    src_vocab_size,
                    src_emb_dim,
                    src_max_len,
                    dropout_rate=0.,
                    src_data_shape=None,
                    word_emb_param_name=None,
                    pos_enc_param_name=None):
    """Add word embeddings and position encodings.

    The output tensor has a shape of:
    [batch_size, max_src_length_in_batch, d_model].
    This module is used at the bottom of the encoder stacks.

    Word embeddings are initialized from a normal distribution with mean 0
    and standard deviation ``src_emb_dim**-0.5``, then scaled by
    ``src_emb_dim**0.5``. Position encodings are looked up in a
    non-trainable table named ``pos_enc_param_name``. Its values are
    presumably filled elsewhere, e.g. from position_encoding_init (TODO
    confirm). Dropout is applied to the sum when ``dropout_rate`` is
    non-zero.
    """
    word_emb = layers.embedding(
        src_word,
        size=[src_vocab_size, src_emb_dim],
        param_attr=fluid.ParamAttr(
            name=word_emb_param_name,
            initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))
    word_emb = layers.scale(x=word_emb, scale=src_emb_dim**0.5)
    pos_enc = layers.embedding(
        src_pos,
        size=[src_max_len, src_emb_dim],
        param_attr=fluid.ParamAttr(
            name=pos_enc_param_name, trainable=False))
    # batch_size and seq_len are not defined in this file. Presumably they
    # come from `from config import *` (TODO confirm). src_data_shape
    # supplies the actual runtime shape.
    summed = layers.reshape(
        x=word_emb + pos_enc,
        shape=[batch_size, seq_len, src_emb_dim],
        actual_shape=src_data_shape)
    if not dropout_rate:
        return summed
    return layers.dropout(summed, dropout_prob=dropout_rate, is_test=False)
216 217 218 219 220 221 222 223


# Bind the position-encoding table for each side of the model. Both partials
# wrap the original (undecorated) function: prepare_decoder is created before
# the name prepare_encoder is rebound.
prepare_decoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[1])
prepare_encoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[0])


Y
ying 已提交
224 225 226 227 228 229 230
def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.,
                  pre_softmax_shape=None,
                  post_softmax_shape=None):
    """One stackable encoder layer.

    Multi-head self-attention over ``enc_input``, followed by a position-wise
    feed-forward network. Each sub-layer output goes through
    post_process_layer with "dan": dropout, then the residual add of the
    sub-layer input, then layer normalization.
    """
    self_attn_out = multi_head_attention(
        queries=enc_input,
        keys=enc_input,
        values=enc_input,
        attn_bias=attn_bias,
        d_key=d_key,
        d_value=d_value,
        d_model=d_model,
        n_head=n_head,
        dropout_rate=dropout_rate,
        pre_softmax_shape=pre_softmax_shape,
        post_softmax_shape=post_softmax_shape)
    self_attn_out = post_process_layer(enc_input, self_attn_out, "dan",
                                       dropout_rate)
    ffn_out = positionwise_feed_forward(self_attn_out, d_inner_hid, d_model)
    return post_process_layer(self_attn_out, ffn_out, "dan", dropout_rate)


def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.,
            pre_softmax_shape=None,
            post_softmax_shape=None):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer.

    Args:
        enc_input: Input to the first layer.
        attn_bias: Self-attention bias shared by every layer.
        n_layer (int): Number of stacked encoder layers.
        n_head, d_key, d_value, d_model, d_inner_hid, dropout_rate,
        pre_softmax_shape, post_softmax_shape: Forwarded unchanged to each
            encoder_layer call.

    Returns:
        The output of the last layer, or ``enc_input`` itself when
        ``n_layer`` is 0.
    """
    # Seed with the input so that n_layer == 0 returns it rather than raising
    # UnboundLocalError on a never-assigned loop result.
    enc_output = enc_input
    for _ in range(n_layer):
        enc_output = encoder_layer(
            enc_output,
            attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate,
            pre_softmax_shape,
            post_softmax_shape, )
    return enc_output


Y
ying 已提交
280 281 282 283 284 285 286 287 288
def decoder_layer(dec_input,
                  enc_output,
                  slf_attn_bias,
                  dec_enc_attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.,
                  slf_attn_pre_softmax_shape=None,
                  slf_attn_post_softmax_shape=None,
                  src_attn_pre_softmax_shape=None,
                  src_attn_post_softmax_shape=None,
                  cache=None):
    """One stackable decoder layer.

    Like encoder_layer, but with an extra encoder-decoder attention sub-layer:

    1. self-attention over ``dec_input`` (using ``cache`` when given),
    2. attention whose queries come from step 1 and whose keys/values are
       ``enc_output``,
    3. a position-wise feed-forward network.

    Each sub-layer output goes through post_process_layer with "dan":
    dropout, then the residual add of the sub-layer input, then layer
    normalization.
    """
    # 1. Self-attention. Only this sub-layer uses the incremental cache.
    self_attn_out = multi_head_attention(
        queries=dec_input,
        keys=dec_input,
        values=dec_input,
        attn_bias=slf_attn_bias,
        d_key=d_key,
        d_value=d_value,
        d_model=d_model,
        n_head=n_head,
        dropout_rate=dropout_rate,
        pre_softmax_shape=slf_attn_pre_softmax_shape,
        post_softmax_shape=slf_attn_post_softmax_shape,
        cache=cache)
    self_attn_out = post_process_layer(dec_input, self_attn_out, "dan",
                                       dropout_rate)

    # 2. Encoder-decoder attention.
    cross_attn_out = multi_head_attention(
        queries=self_attn_out,
        keys=enc_output,
        values=enc_output,
        attn_bias=dec_enc_attn_bias,
        d_key=d_key,
        d_value=d_value,
        d_model=d_model,
        n_head=n_head,
        dropout_rate=dropout_rate,
        pre_softmax_shape=src_attn_pre_softmax_shape,
        post_softmax_shape=src_attn_post_softmax_shape)
    cross_attn_out = post_process_layer(self_attn_out, cross_attn_out, "dan",
                                        dropout_rate)

    # 3. Position-wise feed-forward network.
    ffn_out = positionwise_feed_forward(cross_attn_out, d_inner_hid, d_model)
    return post_process_layer(cross_attn_out, ffn_out, "dan", dropout_rate)


Y
ying 已提交
346 347 348 349 350 351 352 353 354 355
def decoder(dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.,
            slf_attn_pre_softmax_shape=None,
            slf_attn_post_softmax_shape=None,
            src_attn_pre_softmax_shape=None,
            src_attn_post_softmax_shape=None,
            caches=None):
    """
    The decoder is composed of a stack of identical decoder_layer layers.

    Args:
        dec_input: Input to the first layer.
        enc_output: Encoder output attended to by every layer.
        n_layer (int): Number of stacked decoder layers.
        caches (list, optional): One cache dict per layer. ``caches[i]`` is
            passed to layer ``i``; None disables caching for all layers.
        Remaining arguments: forwarded unchanged to each decoder_layer call.

    Returns:
        The output of the last layer, or ``dec_input`` itself when
        ``n_layer`` is 0.
    """
    # Seed with the input so that n_layer == 0 returns it rather than raising
    # UnboundLocalError on a never-assigned loop result.
    dec_output = dec_input
    for layer_idx in range(n_layer):
        layer_cache = None if caches is None else caches[layer_idx]
        dec_output = decoder_layer(
            dec_output,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate,
            slf_attn_pre_softmax_shape,
            slf_attn_post_softmax_shape,
            src_attn_pre_softmax_shape,
            src_attn_post_softmax_shape,
            layer_cache, )
    return dec_output


386
def make_all_inputs(input_fields):
    """
    Define the input data layers for the transformer model.

    Each name in ``input_fields`` is looked up in ``input_descs``, whose
    entries are ``(shape, dtype)`` or ``(shape, dtype, lod_level)``. A
    missing lod_level means 0. The batch dimension is not appended to the
    declared shape.

    Returns:
        list: The created data variables, in the order of ``input_fields``.
    """
    def _make_input(field_name):
        # Build one data layer from its descriptor.
        desc = input_descs[field_name]
        return layers.data(
            name=field_name,
            shape=desc[0],
            dtype=desc[1],
            lod_level=desc[2] if len(desc) == 3 else 0,
            append_batch_size=False)

    return [_make_input(field_name) for field_name in input_fields]
401 402


Y
ying 已提交
403 404 405 406 407 408 409 410 411 412
def transformer(
        src_vocab_size,
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        label_smooth_eps, ):
    """
    Build the training network: encoder, decoder and a cross-entropy loss
    (optionally label-smoothed) that is weighted to ignore padding.

    Returns:
        tuple: ``(sum_cost, avg_cost, predict, token_num)``.
        ``sum_cost`` is the weighted sum of per-token losses.
        ``token_num`` is the sum of the weights.
        ``avg_cost`` is ``sum_cost / token_num``.
        ``predict`` holds the decoder logits.

    Raises:
        AssertionError: If ``weight_sharing`` is set but the source and
            target vocabulary sizes differ.
    """
    if weight_sharing:
        # The shared embedding table is used on both sides, so the two
        # vocabularies must match. The former check compared src_vocab_size
        # with itself and could never fail.
        assert src_vocab_size == trg_vocab_size, (
            "Vocabularies in source and target should be same for weight sharing."
        )
    enc_inputs = make_all_inputs(encoder_data_input_fields +
                                 encoder_util_input_fields)

    enc_output = wrap_encoder(
        src_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        enc_inputs, )

    # The last decoder data field is not created here because enc_output is
    # handed to wrap_decoder directly. Presumably that field is the
    # standalone-decoder enc_output placeholder (compare the unpacking in
    # wrap_decoder). TODO confirm.
    dec_inputs = make_all_inputs(decoder_data_input_fields[:-1] +
                                 decoder_util_input_fields)

    predict = wrap_decoder(
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        dec_inputs,
        enc_output, )

    # Padding positions do not contribute to the total loss; the weights are
    # used to cancel them out when computing the loss.
    label, weights = make_all_inputs(label_data_input_fields)
    if label_smooth_eps:
        label = layers.label_smooth(
            label=layers.one_hot(
                input=label, depth=trg_vocab_size),
            epsilon=label_smooth_eps)
    cost = layers.softmax_with_cross_entropy(
        logits=predict,
        label=label,
        soft_label=bool(label_smooth_eps))
    weighted_cost = cost * weights
    sum_cost = layers.reduce_sum(weighted_cost)
    token_num = layers.reduce_sum(weights)
    avg_cost = sum_cost / token_num
    return sum_cost, avg_cost, predict, token_num
470 471 472 473 474 475 476 477 478 479 480


def wrap_encoder(src_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 weight_sharing,
                 enc_inputs=None):
    """
    Assemble the embedding layer and the encoder stack.

    Args:
        enc_inputs (sequence, optional): The six encoder inputs, in order:
            src_word, src_pos, src_slf_attn_bias, src_data_shape,
            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape.
            When None, they are created as data layers here.
        weight_sharing: Accepted for symmetry with wrap_decoder. The encoder
            always uses ``word_emb_param_names[0]``.
        Remaining arguments: model hyper-parameters forwarded to
            prepare_encoder and encoder.

    Returns:
        The encoder output.
    """
    if enc_inputs is None:
        # This is used to implement independent encoder program in inference.
        enc_inputs = make_all_inputs(encoder_data_input_fields +
                                     encoder_util_input_fields)
    (src_word, src_pos, src_slf_attn_bias, src_data_shape,
     slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape) = enc_inputs

    embedded = prepare_encoder(
        src_word,
        src_pos,
        src_vocab_size,
        d_model,
        max_length,
        dropout_rate,
        src_data_shape,
        word_emb_param_name=word_emb_param_names[0])
    return encoder(
        embedded,
        src_slf_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape, )


def wrap_decoder(trg_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 weight_sharing,
                 dec_inputs=None,
                 enc_output=None,
                 caches=None):
    """
    Assemble the embedding layer, the decoder stack and the output projection.

    Args:
        dec_inputs (sequence, optional): The nine decoder inputs, in order:
            trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
            trg_data_shape, slf_attn_pre_softmax_shape,
            slf_attn_post_softmax_shape, src_attn_pre_softmax_shape,
            src_attn_post_softmax_shape. When None, they are created as data
            layers here, and enc_output is also read from a data layer.
        enc_output: Encoder output. Used only when ``dec_inputs`` is given.
        weight_sharing (bool): Reuse ``word_emb_param_names[0]`` for the
            target embedding and the output projection.
        caches (list, optional): Per-layer caches forwarded to decoder.

    Returns:
        A [-1, trg_vocab_size] tensor. It holds softmax probabilities when
        ``dec_inputs`` is None (the standalone inference program) and raw
        logits otherwise.
    """
    if dec_inputs is None:
        # This is used to implement independent decoder program in inference.
        (trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output,
         trg_data_shape, slf_attn_pre_softmax_shape,
         slf_attn_post_softmax_shape, src_attn_pre_softmax_shape,
         src_attn_post_softmax_shape) = make_all_inputs(
             decoder_data_input_fields + decoder_util_input_fields)
    else:
        (trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
         trg_data_shape, slf_attn_pre_softmax_shape,
         slf_attn_post_softmax_shape, src_attn_pre_softmax_shape,
         src_attn_post_softmax_shape) = dec_inputs

    if weight_sharing:
        trg_emb_name = word_emb_param_names[0]
    else:
        trg_emb_name = word_emb_param_names[1]
    embedded = prepare_decoder(
        trg_word,
        trg_pos,
        trg_vocab_size,
        d_model,
        max_length,
        dropout_rate,
        trg_data_shape,
        word_emb_param_name=trg_emb_name)
    dec_output = decoder(
        embedded,
        enc_output,
        trg_slf_attn_bias,
        trg_src_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape,
        src_attn_pre_softmax_shape,
        src_attn_post_softmax_shape,
        caches, )

    # Project onto the vocabulary. With weight sharing, multiply by the
    # transposed shared embedding matrix; otherwise use a bias-free fc layer.
    if weight_sharing:
        vocab_scores = layers.matmul(
            x=dec_output,
            y=fluid.get_var(word_emb_param_names[0]),
            transpose_y=True)
    else:
        vocab_scores = layers.fc(input=dec_output,
                                 size=trg_vocab_size,
                                 bias_attr=False,
                                 num_flatten_dims=2)
    return layers.reshape(
        x=vocab_scores,
        shape=[-1, trg_vocab_size],
        act="softmax" if dec_inputs is None else None)
594 595 596 597 598 599 600 601 602 603 604 605 606


def fast_decode(
        src_vocab_size,
        trg_vocab_size,
        max_in_len,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        beam_size,
        max_out_len,
        eos_idx, ):
    """
    Use beam search to decode. Caches will be used to store states of history
    steps which can make the decoding faster.

    Args:
        src_vocab_size ... weight_sharing: Model hyper-parameters forwarded
            to wrap_encoder / wrap_decoder.
        beam_size (int): Beam width used by ``layers.topk`` and
            ``layers.beam_search``.
        max_out_len (int): Maximum number of decoding steps.
        eos_idx (int): End token id passed as ``end_id``.

    Returns:
        tuple: ``(finished_ids, finished_scores)`` from
        ``layers.beam_search_decode``.
    """
    enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
                              d_key, d_value, d_model, d_inner_hid,
                              dropout_rate, weight_sharing)
    start_tokens, init_scores, trg_src_attn_bias, trg_data_shape, \
        slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
        src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
        attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta = \
        make_all_inputs(fast_decoder_data_input_fields +
                        fast_decoder_util_input_fields)

    def beam_search():
        # Build the while-loop beam search and return the decoded result.
        max_len = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=max_out_len)
        step_idx = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=0)
        cond = layers.less_than(x=step_idx, y=max_len)
        while_op = layers.While(cond)
        # array states will be stored for each step.
        ids = layers.array_write(start_tokens, step_idx)
        scores = layers.array_write(init_scores, step_idx)
        # cell states will be overwritten at each step.
        # caches contains states of history steps to reduce redundant
        # computation in decoder. Each cache starts with zero time steps.
        # multi_head_attention concatenates the new key/value projections
        # (widths d_key * n_head and d_value * n_head) onto it along axis 1,
        # so the last dimension must match those widths, not d_model.
        caches = [{
            "k": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_key * n_head],
                dtype=enc_output.dtype,
                value=0),
            "v": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_value * n_head],
                dtype=enc_output.dtype,
                value=0)
        } for _ in range(n_layer)]
        with while_op.block():
            pre_ids = layers.array_read(array=ids, i=step_idx)
            pre_scores = layers.array_read(array=scores, i=step_idx)
            # sequence_expand can gather sequences according to lod thus can be
            # used in beam search to sift states corresponding to selected ids.
            pre_src_attn_bias = layers.sequence_expand(
                x=trg_src_attn_bias, y=pre_scores)
            pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
            pre_caches = [{
                "k": layers.sequence_expand(
                    x=cache["k"], y=pre_scores),
                "v": layers.sequence_expand(
                    x=cache["v"], y=pre_scores),
            } for cache in caches]
            # Position ids for this step: a column of ones (one row per
            # expanded sequence) times step_idx + 1.
            pre_pos = layers.elementwise_mul(
                x=layers.fill_constant_batch_size_like(
                    input=pre_enc_output,  # can't use pre_ids here since it has lod
                    value=1,
                    shape=[-1, 1],
                    dtype=pre_ids.dtype),
                y=layers.increment(
                    x=step_idx, value=1.0, in_place=False),
                axis=0)
            logits = wrap_decoder(
                trg_vocab_size,
                max_in_len,
                n_layer,
                n_head,
                d_key,
                d_value,
                d_model,
                d_inner_hid,
                dropout_rate,
                weight_sharing,
                dec_inputs=(
                    pre_ids, pre_pos, None, pre_src_attn_bias, trg_data_shape,
                    slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
                    src_attn_pre_softmax_shape, src_attn_post_softmax_shape),
                enc_output=pre_enc_output,
                caches=pre_caches)
            topk_scores, topk_indices = layers.topk(
                input=layers.softmax(logits), k=beam_size)
            # Accumulated score = log(prob of candidate) + score of its prefix.
            accu_scores = layers.elementwise_add(
                x=layers.log(topk_scores),
                y=layers.reshape(
                    pre_scores, shape=[-1]),
                axis=0)
            # beam_search op uses lod to distinguish branches.
            topk_indices = layers.lod_reset(topk_indices, pre_ids)
            selected_ids, selected_scores = layers.beam_search(
                pre_ids=pre_ids,
                pre_scores=pre_scores,
                ids=topk_indices,
                scores=accu_scores,
                beam_size=beam_size,
                end_id=eos_idx)
            layers.increment(x=step_idx, value=1.0, in_place=True)
            # update states
            layers.array_write(selected_ids, i=step_idx, array=ids)
            layers.array_write(selected_scores, i=step_idx, array=scores)
            layers.assign(pre_src_attn_bias, trg_src_attn_bias)
            layers.assign(pre_enc_output, enc_output)
            for i in range(n_layer):
                layers.assign(pre_caches[i]["k"], caches[i]["k"])
                layers.assign(pre_caches[i]["v"], caches[i]["v"])
            # Grow the self-attention softmax shapes for the next step.
            layers.assign(
                layers.elementwise_add(
                    x=slf_attn_pre_softmax_shape,
                    y=attn_pre_softmax_shape_delta),
                slf_attn_pre_softmax_shape)
            layers.assign(
                layers.elementwise_add(
                    x=slf_attn_post_softmax_shape,
                    y=attn_post_softmax_shape_delta),
                slf_attn_post_softmax_shape)

            # Continue while under the length limit and some beam is alive.
            length_cond = layers.less_than(x=step_idx, y=max_len)
            finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
            layers.logical_and(x=length_cond, y=finish_cond, out=cond)

        finished_ids, finished_scores = layers.beam_search_decode(
            ids, scores, beam_size=beam_size, end_id=eos_idx)
        return finished_ids, finished_scores

    finished_ids, finished_scores = beam_search()
    return finished_ids, finished_scores