from functools import partial
import numpy as np

import paddle.fluid as fluid
import paddle.fluid.layers as layers

from config import *


def position_encoding_init(n_position, d_pos_vec):
    """
    Generate the initial values for the sinusoid position encoding table.
    """
    position_enc = np.array([[
        pos / np.power(10000, 2 * (j // 2) / d_pos_vec)
        for j in range(d_pos_vec)
    ] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
    return position_enc.astype("float32")


def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         n_head=1,
                         dropout_rate=0.,
                         cache=None):
    """
    Multi-Head Attention. Note that attn_bias is added to the logits before
    computing the softmax activation to mask certain selected positions so
    that they will not be considered in the attention weights.
    """
    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: queries, keys and values should all be 3-D tensors.")

    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
                      size=d_key * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        k = layers.fc(input=keys,
                      size=d_key * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        v = layers.fc(input=values,
                      size=d_value * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        return q, k, v

    def __split_heads(x, n_head):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions and then transpose. Specifically, input a tensor with shape
        [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
        with shape [bs, n_head, max_sequence_length, hidden_dim].
        """
        if n_head == 1:
            return x

        hidden_size = x.shape[-1]
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        reshaped = layers.reshape(
            x=x, shape=[0, 0, n_head, hidden_size // n_head])

        # permute the dimensions into:
        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that it becomes one dimension, which is the inverse of __split_heads.
        """
        if len(x.shape) == 3: return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        return layers.reshape(
            x=trans_x,
            shape=list(map(int, [0, 0, trans_x.shape[2] * trans_x.shape[3]])))

    def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
        """
        Scaled Dot-Product Attention
        """
        scaled_q = layers.scale(x=q, scale=d_model**-0.5)
        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
        if attn_bias:
            product += attn_bias
        weights = layers.softmax(product)
        if dropout_rate:
            weights = layers.dropout(
                weights,
                dropout_prob=dropout_rate,
                seed=ModelHyperParams.dropout_seed,
                is_test=False)
        out = layers.matmul(weights, v)
        return out

    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

    if cache is not None:  # use cache and concat time steps
        k = cache["k"] = layers.concat([cache["k"], k], axis=1)
        v = cache["v"] = layers.concat([cache["v"], v], axis=1)

    q = __split_heads(q, n_head)
    k = __split_heads(k, n_head)
    v = __split_heads(v, n_head)

    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model,
                                                  dropout_rate)

    out = __combine_heads(ctx_multiheads)

    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         bias_attr=False,
                         num_flatten_dims=2)
    return proj_out


def positionwise_feed_forward(x, d_inner_hid, d_hid):
    """
    Position-wise Feed-Forward Networks.
    This module consists of two linear transformations with a ReLU activation
    in between, which is applied to each position separately and identically.
    """
    hidden = layers.fc(input=x,
                       size=d_inner_hid,
                       num_flatten_dims=2,
                       act="relu")
    out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
    return out


def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
    """
    Add residual connection, layer normalization and dropout to the out tensor
    optionally according to the value of process_cmd.
    This will be used before or after multi-head attention and position-wise
    feed-forward networks.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            out = out + prev_out if prev_out else out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.initializer.Constant(1.),
                bias_attr=fluid.initializer.Constant(0.))
        elif cmd == "d":  # add dropout
            if dropout_rate:
                out = layers.dropout(
                    out,
                    dropout_prob=dropout_rate,
                    seed=ModelHyperParams.dropout_seed,
                    is_test=False)
    return out


pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer


def prepare_encoder(src_word,
                    src_pos,
                    src_vocab_size,
                    src_emb_dim,
                    src_max_len,
                    dropout_rate=0.,
                    word_emb_param_name=None,
                    pos_enc_param_name=None):
    """Add word embeddings and position encodings.
    The output tensor has a shape of:
    [batch_size, max_src_length_in_batch, d_model].
    This module is used at the bottom of the encoder stacks.
    """
    src_word_emb = layers.embedding(
        src_word,
        size=[src_vocab_size, src_emb_dim],
        param_attr=fluid.ParamAttr(
            name=word_emb_param_name,
            initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))
    src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
    src_pos_enc = layers.embedding(
        src_pos,
        size=[src_max_len, src_emb_dim],
        param_attr=fluid.ParamAttr(
            name=pos_enc_param_name, trainable=False))
    enc_input = src_word_emb + src_pos_enc
    return layers.dropout(
        enc_input,
        dropout_prob=dropout_rate,
        seed=ModelHyperParams.dropout_seed,
        is_test=False) if dropout_rate else enc_input


prepare_encoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[0])
prepare_decoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[1])


def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.):
    """The encoder layers that can be stacked to form a deep encoder.
    This module consits of a multi-head (self) attention followed by
    position-wise feed-forward networks and both the two components companied
    with the post_process_layer to add residual connection, layer normalization
    and droput.
233
    """
    attn_output = multi_head_attention(enc_input, enc_input, enc_input,
                                       attn_bias, d_key, d_value, d_model,
                                       n_head, dropout_rate)
    attn_output = post_process_layer(enc_input, attn_output, "dan",
                                     dropout_rate)
    ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
    return post_process_layer(attn_output, ffd_output, "dan", dropout_rate)


def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer.
    """
    for i in range(n_layer):
        enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key, d_value,
                                   d_model, d_inner_hid, dropout_rate)
        enc_input = enc_output
    return enc_output


def decoder_layer(dec_input,
                  enc_output,
                  slf_attn_bias,
                  dec_enc_attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.,
                  cache=None):
    """ The layer to be stacked in decoder part.
    The structure of this module is similar to that in the encoder part except
    a multi-head attention is added to implement encoder-decoder attention.
277
    """
    slf_attn_output = multi_head_attention(
        dec_input,
        dec_input,
        dec_input,
        slf_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate,
        cache, )
    slf_attn_output = post_process_layer(
        dec_input,
        slf_attn_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    enc_attn_output = multi_head_attention(
        slf_attn_output,
        enc_output,
        enc_output,
        dec_enc_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate, )
    enc_attn_output = post_process_layer(
        slf_attn_output,
        enc_attn_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    ffd_output = positionwise_feed_forward(
        enc_attn_output,
        d_inner_hid,
        d_model, )
    dec_output = post_process_layer(
        enc_attn_output,
        ffd_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    return dec_output


def decoder(dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.,
            caches=None):
    """
    The decoder is composed of a stack of identical decoder_layer layers.
    """
    for i in range(n_layer):
        dec_output = decoder_layer(
            dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate, )
        dec_input = dec_output
    return dec_output


def make_all_inputs(input_fields):
    """
    Define the input data layers for the transformer model.
    """
    inputs = []
    for input_field in input_fields:
        input_var = layers.data(
            name=input_field,
            shape=input_descs[input_field][0],
            dtype=input_descs[input_field][1],
            lod_level=input_descs[input_field][2]
            if len(input_descs[input_field]) == 3 else 0,
            append_batch_size=False)
        inputs.append(input_var)
    return inputs


def transformer(
        src_vocab_size,
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        label_smooth_eps, ):
    if weight_sharing:
        assert src_vocab_size == trg_vocab_size, (
            "Vocabularies in source and target should be the same for weight sharing."
        )
    enc_inputs = make_all_inputs(encoder_data_input_fields)

    enc_output = wrap_encoder(
        src_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        enc_inputs, )

    dec_inputs = make_all_inputs(decoder_data_input_fields[:-1])

    predict = wrap_decoder(
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        dec_inputs,
        enc_output, )

    # Padding indices do not contribute to the total loss. The weights are used
    # to mask out the padding positions when calculating the loss.
    label, weights = make_all_inputs(label_data_input_fields)
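    # Assuming label_smooth's default uniform prior (no prior_dist is given),
    # the smoothed target below is
    #     (1 - epsilon) * one_hot + epsilon / trg_vocab_size.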
    if label_smooth_eps:
        label = layers.label_smooth(
            label=layers.one_hot(
                input=label, depth=trg_vocab_size),
            epsilon=label_smooth_eps)

    cost = layers.softmax_with_cross_entropy(
        logits=layers.reshape(
            predict, shape=[-1, trg_vocab_size]),
        label=label,
        soft_label=True if label_smooth_eps else False)
    weighted_cost = cost * weights
    sum_cost = layers.reduce_sum(weighted_cost)
    token_num = layers.reduce_sum(weights)
    avg_cost = sum_cost / token_num
    avg_cost.stop_gradient = True
    return sum_cost, avg_cost, predict, token_num


def wrap_encoder(src_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 weight_sharing,
                 enc_inputs=None):
    """
    The wrapper assembles together all needed layers for the encoder.
    """
    if enc_inputs is None:
        # This is used to implement independent encoder program in inference.
        src_word, src_pos, src_slf_attn_bias = \
            make_all_inputs(encoder_data_input_fields +
                            encoder_util_input_fields)
    else:
        src_word, src_pos, src_slf_attn_bias = \
            enc_inputs
    enc_input = prepare_encoder(
        src_word,
        src_pos,
        src_vocab_size,
        d_model,
        max_length,
        dropout_rate,
        word_emb_param_name=word_emb_param_names[0])
    enc_output = encoder(enc_input, src_slf_attn_bias, n_layer, n_head, d_key,
                         d_value, d_model, d_inner_hid, dropout_rate)
    return enc_output


def wrap_decoder(trg_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 weight_sharing,
                 dec_inputs=None,
                 enc_output=None,
                 caches=None):
    """
    The wrapper assembles together all needed layers for the decoder.
    """
    if dec_inputs is None:
        # This is used to implement independent decoder program in inference.
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
        enc_output = make_all_inputs(
            decoder_data_input_fields + decoder_util_input_fields)
    else:
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs

    dec_input = prepare_decoder(
        trg_word,
        trg_pos,
        trg_vocab_size,
        d_model,
        max_length,
        dropout_rate,
        word_emb_param_name=word_emb_param_names[0]
        if weight_sharing else word_emb_param_names[1])
    dec_output = decoder(
        dec_input,
        enc_output,
        trg_slf_attn_bias,
        trg_src_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate, )
    # Return logits for training and probs for inference.
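    # When weight_sharing is enabled, the output projection reuses the
    # (transposed) word embedding matrix rather than a separate fc weight,
    # i.e. the embedding and pre-softmax projection are tied.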
    if weight_sharing:
        predict = layers.matmul(
            x=dec_output,
            y=fluid.get_var(word_emb_param_names[0]),
            transpose_y=True)
    else:
        predict = layers.fc(input=dec_output,
                            size=trg_vocab_size,
                            bias_attr=False,
                            num_flatten_dims=2)
    if dec_inputs is None:
        # Only the independent inference decoder returns probabilities; the
        # training program applies softmax_with_cross_entropy and fast_decode
        # applies softmax on these logits itself.
        predict = layers.softmax(predict)
    return predict


def fast_decode(
        src_vocab_size,
        trg_vocab_size,
        max_in_len,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        beam_size,
        max_out_len,
        eos_idx, ):
    """
    Use beam search to decode. Caches are used to store the states of previous
    steps, which makes decoding faster.
    """
    enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
                              d_key, d_value, d_model, d_inner_hid,
                              dropout_rate, weight_sharing)
    start_tokens, init_scores, trg_src_attn_bias, trg_data_shape, \
    slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
    src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
    attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta = \
        make_all_inputs(fast_decoder_data_input_fields +
                        fast_decoder_util_input_fields)

    def beam_search():
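        # Sketch of the loop below: ids/scores for every step are kept in LoD
        # tensor arrays; at each step the encoder output, the encoder-decoder
        # attention bias and the k/v caches are expanded to the currently
        # selected beams with sequence_expand, one decoder step is run, top-k
        # candidates are scored, and the beam_search op prunes/finishes beams
        # until max_out_len is reached or all beams emit eos_idx.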
        max_len = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=max_out_len)
        step_idx = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=0)
        cond = layers.less_than(x=step_idx, y=max_len)
        while_op = layers.While(cond)
        # array states will be stored for each step.
        ids = layers.array_write(start_tokens, step_idx)
        scores = layers.array_write(init_scores, step_idx)
        # cell states will be overwritten at each step.
        # caches contain the states of previous steps to reduce redundant
        # computation in the decoder.
        caches = [{
            "k": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_model],
                dtype=enc_output.dtype,
                value=0),
            "v": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_model],
                dtype=enc_output.dtype,
                value=0)
        } for i in range(n_layer)]
        with while_op.block():
            pre_ids = layers.array_read(array=ids, i=step_idx)
            pre_scores = layers.array_read(array=scores, i=step_idx)
            # sequence_expand can gather sequences according to lod, so it can
            # be used in beam search to gather the states corresponding to the
            # selected ids.
            pre_src_attn_bias = layers.sequence_expand(
                x=trg_src_attn_bias, y=pre_scores)
            pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
            pre_caches = [{
                "k": layers.sequence_expand(
                    x=cache["k"], y=pre_scores),
                "v": layers.sequence_expand(
                    x=cache["v"], y=pre_scores),
            } for cache in caches]
            pre_pos = layers.elementwise_mul(
                x=layers.fill_constant_batch_size_like(
                    input=pre_enc_output,  # can't use pre_ids here since it has lod
                    value=1,
                    shape=[-1, 1],
                    dtype=pre_ids.dtype),
                y=layers.increment(
                    x=step_idx, value=1.0, in_place=False),
                axis=0)
            logits = wrap_decoder(
                trg_vocab_size,
                max_in_len,
                n_layer,
                n_head,
                d_key,
                d_value,
                d_model,
                d_inner_hid,
                dropout_rate,
                weight_sharing,
                dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
                enc_output=pre_enc_output,
                caches=pre_caches)
            topk_scores, topk_indices = layers.topk(
                input=layers.softmax(logits), k=beam_size)
            accu_scores = layers.elementwise_add(
                x=layers.log(topk_scores),
                y=layers.reshape(
                    pre_scores, shape=[-1]),
                axis=0)
            # beam_search op uses lod to distinguish branches.
            topk_indices = layers.lod_reset(topk_indices, pre_ids)
            selected_ids, selected_scores = layers.beam_search(
                pre_ids=pre_ids,
                pre_scores=pre_scores,
                ids=topk_indices,
                scores=accu_scores,
                beam_size=beam_size,
                end_id=eos_idx)
            layers.increment(x=step_idx, value=1.0, in_place=True)
            # update states
            layers.array_write(selected_ids, i=step_idx, array=ids)
            layers.array_write(selected_scores, i=step_idx, array=scores)
            layers.assign(pre_src_attn_bias, trg_src_attn_bias)
            layers.assign(pre_enc_output, enc_output)
            for i in range(n_layer):
                layers.assign(pre_caches[i]["k"], caches[i]["k"])
                layers.assign(pre_caches[i]["v"], caches[i]["v"])
            layers.assign(
                layers.elementwise_add(
                    x=slf_attn_pre_softmax_shape,
                    y=attn_pre_softmax_shape_delta),
                slf_attn_pre_softmax_shape)
            layers.assign(
                layers.elementwise_add(
                    x=slf_attn_post_softmax_shape,
                    y=attn_post_softmax_shape_delta),
                slf_attn_post_softmax_shape)

            length_cond = layers.less_than(x=step_idx, y=max_len)
            finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
            layers.logical_and(x=length_cond, y=finish_cond, out=cond)

        finished_ids, finished_scores = layers.beam_search_decode(
            ids, scores, beam_size=beam_size, end_id=eos_idx)
        return finished_ids, finished_scores

    finished_ids, finished_scores = beam_search()
    return finished_ids, finished_scores