from functools import partial
import numpy as np

import paddle.fluid as fluid
import paddle.fluid.layers as layers

from config import *


def position_encoding_init(n_position, d_pos_vec):
    """
    Generate the initial values for the sinusoid position encoding table.
    """
    position_enc = np.array([[
        pos / np.power(10000, 2 * (j // 2) / d_pos_vec)
        for j in range(d_pos_vec)
    ] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
    return position_enc.astype("float32")
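
# The table above follows the sinusoid scheme from "Attention Is All You Need":
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_pos_vec))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_pos_vec))
# with row 0 (the padding position) left as all zeros. A quick, NumPy-only
# sanity check (illustrative only, not part of the model definition):
#   table = position_encoding_init(n_position=50, d_pos_vec=8)
#   assert table.shape == (50, 8) and not table[0].any()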


def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         n_head=1,
                         dropout_rate=0.,
                         cache=None):
    """
    Multi-Head Attention. Note that attn_bias is added to the logit before
    computing softmax activation to mask certain selected positions so that
    they will not be considered in attention weights.
    """
    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: queries, keys and values should all be 3-D tensors.")

    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
                      size=d_key * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        k = layers.fc(input=keys,
                      size=d_key * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        v = layers.fc(input=values,
                      size=d_value * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        return q, k, v

    def __split_heads(x, n_head):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions and then transpose. Specifically, input a tensor with shape
        [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
        with shape [bs, n_head, max_sequence_length, hidden_dim].
        """
        if n_head == 1:
            return x

        hidden_size = x.shape[-1]
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        reshaped = layers.reshape(
            x=x, shape=[0, 0, n_head, hidden_size // n_head])

        # permute the dimensions into:
        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
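
    # Shape walk-through for __split_heads (the example sizes are illustrative,
    # not fixed by this file):
    #   input     : [batch_size, max_sequence_len, n_head * hidden_per_head]
    #   reshape   : [batch_size, max_sequence_len, n_head, hidden_per_head]
    #   transpose : [batch_size, n_head, max_sequence_len, hidden_per_head]
    # __combine_heads below applies the inverse transpose and reshape.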

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that it becomes one dimension, which is the reverse of __split_heads.
        """
        if len(x.shape) == 3: return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        return layers.reshape(
            x=trans_x,
            # wrap in list() so the shape is a concrete list of ints under
            # both Python 2 and Python 3 (map returns an iterator in Python 3)
            shape=list(map(int, [0, 0, trans_x.shape[2] * trans_x.shape[3]])))

    def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
        """
        Scaled Dot-Product Attention
        """
        scaled_q = layers.scale(x=q, scale=d_model**-0.5)
        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
        if attn_bias:
            product += attn_bias
        weights = layers.softmax(product)
        if dropout_rate:
            weights = layers.dropout(
                weights,
                dropout_prob=dropout_rate,
                seed=ModelHyperParams.dropout_seed,
                is_test=False)
        out = layers.matmul(weights, v)
        return out
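
    # scaled_dot_product_attention above computes softmax(scale * Q K^T + bias) V,
    # with attn_bias masking selected positions and dropout applied to the
    # attention weights. Note that the scaling factor used here is
    # d_model ** -0.5, applied to q before the matmul.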

    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

    if cache is not None:  # use cache and concat time steps
        k = cache["k"] = layers.concat([cache["k"], k], axis=1)
        v = cache["v"] = layers.concat([cache["v"], v], axis=1)

    q = __split_heads(q, n_head)
    k = __split_heads(k, n_head)
    v = __split_heads(v, n_head)

    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model,
                                                  dropout_rate)

    out = __combine_heads(ctx_multiheads)

    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         bias_attr=False,
                         num_flatten_dims=2)
    return proj_out
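
# A rough shape sketch for multi_head_attention (tensor names and the masking
# value below are illustrative assumptions, not values fixed in this file):
#   queries/keys/values: [batch_size, seq_len, d_model]
#   attn_bias          : [batch_size, n_head, seq_len, seq_len], e.g. a large
#                        negative value (such as -1e9) at positions to ignore
#                        and 0 elsewhere, added to the logits before softmax.
#   returned output    : [batch_size, seq_len, d_model]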


def positionwise_feed_forward(x, d_inner_hid, d_hid):
    """
    Position-wise Feed-Forward Networks.
    This module consists of two linear transformations with a ReLU activation
    in between, which is applied to each position separately and identically.
    """
    hidden = layers.fc(input=x,
                       size=d_inner_hid,
                       num_flatten_dims=2,
                       act="relu")
    out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
    return out
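
# positionwise_feed_forward implements FFN(x) = relu(x W1 + b1) W2 + b2, applied
# to the last dimension at every position independently; num_flatten_dims=2
# keeps the leading [batch_size, seq_len] dimensions intact.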


def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
    """
    Add residual connection, layer normalization and dropout to the out tensor
    optionally according to the value of process_cmd.
    This will be used before or after multi-head attention and position-wise
    feed-forward networks.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            out = out + prev_out if prev_out else out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.initializer.Constant(1.),
                bias_attr=fluid.initializer.Constant(0.))
        elif cmd == "d":  # add dropout
            if dropout_rate:
                out = layers.dropout(
                    out,
                    dropout_prob=dropout_rate,
                    seed=ModelHyperParams.dropout_seed,
                    is_test=False)
    return out


pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer
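
# The process_cmd string is interpreted one character at a time, so the "dan"
# used below means: dropout, then add the residual, then layer normalization.
# pre_process_layer fixes prev_out to None and is intended for processing
# applied before a sub-layer; post_process_layer wraps a sub-layer's output.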


def prepare_encoder(src_word,
                    src_pos,
                    src_vocab_size,
                    src_emb_dim,
                    src_max_len,
                    dropout_rate=0.,
                    word_emb_param_name=None,
                    pos_enc_param_name=None):
    """Add word embeddings and position encodings.
    The output tensor has a shape of:
    [batch_size, max_src_length_in_batch, d_model].
    This module is used at the bottom of the encoder stacks.
    """
    src_word_emb = layers.embedding(
        src_word,
        size=[src_vocab_size, src_emb_dim],
        param_attr=fluid.ParamAttr(
            name=word_emb_param_name,
            initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))

    src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
    src_pos_enc = layers.embedding(
        src_pos,
        size=[src_max_len, src_emb_dim],
        param_attr=fluid.ParamAttr(
            name=pos_enc_param_name, trainable=False))
    enc_input = src_word_emb + src_pos_enc
    return layers.dropout(
        enc_input,
        dropout_prob=dropout_rate,
        seed=ModelHyperParams.dropout_seed,
        is_test=False) if dropout_rate else enc_input


prepare_encoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[0])
prepare_decoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[1])
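
# prepare_decoder reuses prepare_encoder and only switches to the target-side
# position-encoding parameter; pos_enc_param_names is expected to come from
# config (imported above with `from config import *`).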


def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.):
    """The encoder layers that can be stacked to form a deep encoder.
    This module consists of a multi-head (self) attention followed by
    position-wise feed-forward networks, and both components are wrapped
    with post_process_layer to add residual connection, layer normalization
    and dropout.
    """
    attn_output = multi_head_attention(enc_input, enc_input, enc_input,
                                       attn_bias, d_key, d_value, d_model,
                                       n_head, dropout_rate)
    attn_output = post_process_layer(enc_input, attn_output, "dan",
                                     dropout_rate)
    ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
    return post_process_layer(attn_output, ffd_output, "dan", dropout_rate)
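
# Each encoder_layer is therefore: self-attention -> "dan" post-processing ->
# position-wise feed-forward -> "dan" post-processing, i.e. dropout, residual
# addition and layer normalization are applied after every sub-layer.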


def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer.
    """
    for i in range(n_layer):
        enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key, d_value,
                                   d_model, d_inner_hid, dropout_rate)
        enc_input = enc_output
    return enc_output


def decoder_layer(dec_input,
                  enc_output,
                  slf_attn_bias,
                  dec_enc_attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.,
                  cache=None):
    """ The layer to be stacked in decoder part.
    The structure of this module is similar to that in the encoder part except
    a multi-head attention is added to implement encoder-decoder attention.
    """
    slf_attn_output = multi_head_attention(
        dec_input,
        dec_input,
        dec_input,
        slf_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate,
        cache, )
    slf_attn_output = post_process_layer(
        dec_input,
        slf_attn_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    enc_attn_output = multi_head_attention(
        slf_attn_output,
        enc_output,
        enc_output,
        dec_enc_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate, )
    enc_attn_output = post_process_layer(
        slf_attn_output,
        enc_attn_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    ffd_output = positionwise_feed_forward(
        enc_attn_output,
        d_inner_hid,
        d_model, )
    dec_output = post_process_layer(
        enc_attn_output,
        ffd_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    return dec_output
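
# A decoder_layer stacks three sub-layers: masked self-attention,
# encoder-decoder attention (queries from the decoder side, keys/values from
# enc_output), and the position-wise feed-forward network, each followed by
# the "dan" post-processing.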


def decoder(dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.,
            caches=None):
    """
    The decoder is composed of a stack of identical decoder_layer layers.
    """
    for i in range(n_layer):
        dec_output = decoder_layer(
            dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate, )
        dec_input = dec_output
    return dec_output


def make_all_inputs(input_fields):
    """
    Define the input data layers for the transformer model.
    """
    inputs = []
    for input_field in input_fields:
        input_var = layers.data(
            name=input_field,
            shape=input_descs[input_field][0],
            dtype=input_descs[input_field][1],
            lod_level=input_descs[input_field][2]
            if len(input_descs[input_field]) == 3 else 0,
            append_batch_size=False)
        inputs.append(input_var)
    return inputs


def transformer(
        src_vocab_size,
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        label_smooth_eps, ):
    if weight_sharing:
        assert src_vocab_size == trg_vocab_size, (
            "Vocabularies in source and target should be the same for weight sharing."
        )
    enc_inputs = make_all_inputs(encoder_data_input_fields)

    enc_output = wrap_encoder(
        src_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        enc_inputs, )

    dec_inputs = make_all_inputs(decoder_data_input_fields[:-1])

    predict = wrap_decoder(
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        dec_inputs,
        enc_output, )

    # The padding index does not contribute to the total loss. The weights are
    # used to cancel the padding index when calculating the loss.
    label, weights = make_all_inputs(label_data_input_fields)
    if label_smooth_eps:
        label = layers.label_smooth(
            label=layers.one_hot(
                input=label, depth=trg_vocab_size),
            epsilon=label_smooth_eps)

    cost = layers.softmax_with_cross_entropy(
        logits=layers.reshape(
            predict, shape=[-1, trg_vocab_size]),
        label=label,
        soft_label=True if label_smooth_eps else False)
    weighted_cost = cost * weights
    sum_cost = layers.reduce_sum(weighted_cost)
    token_num = layers.reduce_sum(weights)
    avg_cost = sum_cost / token_num
    avg_cost.stop_gradient = True
    return sum_cost, avg_cost, predict, token_num
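
# In symbols, with y the (optionally label-smoothed) targets, p the predicted
# distribution and w the per-token weights that zero out padding positions:
#   sum_cost = sum over tokens of w * CE(p, y)
#   avg_cost = sum_cost / sum(w)
# so avg_cost is the cross entropy per real (non-padding) target token.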


def wrap_encoder(src_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 weight_sharing,
                 enc_inputs=None):
    """
    The wrapper assembles together all needed layers for the encoder.
    """
    if enc_inputs is None:
        # This is used to implement independent encoder program in inference.
        src_word, src_pos, src_slf_attn_bias = \
            make_all_inputs(encoder_data_input_fields)
    else:
        src_word, src_pos, src_slf_attn_bias = \
            enc_inputs
    enc_input = prepare_encoder(
        src_word,
        src_pos,
        src_vocab_size,
        d_model,
        max_length,
        dropout_rate,
        word_emb_param_name=word_emb_param_names[0])
    enc_output = encoder(enc_input, src_slf_attn_bias, n_layer, n_head, d_key,
                         d_value, d_model, d_inner_hid, dropout_rate)
    return enc_output


def wrap_decoder(trg_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 weight_sharing,
                 dec_inputs=None,
                 enc_output=None,
                 caches=None):
    """
    The wrapper assembles together all needed layers for the decoder.
    """
    if dec_inputs is None:
        # This is used to implement independent decoder program in inference.
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
        enc_output = make_all_inputs(
            decoder_data_input_fields + decoder_util_input_fields)
    else:
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs

    dec_input = prepare_decoder(
        trg_word,
        trg_pos,
        trg_vocab_size,
        d_model,
        max_length,
        dropout_rate,
        word_emb_param_name=word_emb_param_names[0]
        if weight_sharing else word_emb_param_names[1])
    dec_output = decoder(
        dec_input,
        enc_output,
        trg_slf_attn_bias,
        trg_src_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate, )
    # Return logits for training and probs for inference.
    if weight_sharing:
        predict = layers.matmul(
            x=dec_output,
            y=fluid.get_var(word_emb_param_names[0]),
            transpose_y=True)
    else:
        predict = layers.fc(input=dec_output,
                            size=trg_vocab_size,
                            bias_attr=False,
                            num_flatten_dims=2)
    if dec_inputs is None:
        predict = layers.softmax(predict)
    return predict
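
# When weight_sharing is on, the output projection reuses the (transposed)
# word-embedding matrix instead of a separate fc weight, so the pre-softmax
# projection and the embeddings share parameters. During training wrap_decoder
# returns raw logits (softmax is folded into softmax_with_cross_entropy in
# transformer()); only the stand-alone inference path (dec_inputs is None)
# applies softmax here.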


def fast_decode(
        src_vocab_size,
        trg_vocab_size,
        max_in_len,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        beam_size,
        max_out_len,
        eos_idx, ):
    """
    Use beam search to decode. Caches will be used to store states of history
    steps which can make the decoding faster.
    """
    enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
                              d_key, d_value, d_model, d_inner_hid,
                              dropout_rate, weight_sharing)
    start_tokens, init_scores, trg_src_attn_bias = \
        make_all_inputs(fast_decoder_data_input_fields)

    def beam_search():
        max_len = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=max_out_len)
        step_idx = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=0)
        cond = layers.less_than(x=step_idx, y=max_len)
        while_op = layers.While(cond)
        # array states will be stored for each step.
        ids = layers.array_write(start_tokens, step_idx)
        ids_flatten = layers.array_write(
            layers.reshape(start_tokens, (-1, 1)), step_idx)
        scores = layers.array_write(init_scores, step_idx)
        # cell states will be overwritten at each step.
        # caches contains states of history steps to reduce redundant
        # computation in decoder.
        caches = [{
            "k": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_model],
                dtype=enc_output.dtype,
                value=0),
            "v": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_model],
                dtype=enc_output.dtype,
                value=0)
        } for i in range(n_layer)]
        with while_op.block():
            pre_ids = layers.array_read(array=ids, i=step_idx)
            pre_scores = layers.array_read(array=scores, i=step_idx)
            # sequence_expand can gather sequences according to lod thus can be
            # used in beam search to sift states corresponding to selected ids.
            pre_src_attn_bias = layers.sequence_expand(
                x=trg_src_attn_bias, y=pre_scores)
            pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
            pre_caches = [{
                "k": layers.sequence_expand(
                    x=cache["k"], y=pre_scores),
                "v": layers.sequence_expand(
                    x=cache["v"], y=pre_scores),
            } for cache in caches]
            pre_pos = layers.elementwise_mul(
                x=layers.fill_constant_batch_size_like(
                    input=pre_enc_output,  # can't use pre_ids here since it has lod
                    value=1,
                    shape=[-1, 1, 1],
                    dtype=pre_ids.dtype),
                y=layers.increment(
                    x=step_idx, value=1.0, in_place=False),
                axis=0)
            logits = wrap_decoder(
                trg_vocab_size,
                max_in_len,
                n_layer,
                n_head,
                d_key,
                d_value,
                d_model,
                d_inner_hid,
                dropout_rate,
                weight_sharing,
                dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
                enc_output=pre_enc_output,
                caches=pre_caches)
            logits = layers.reshape(logits, (-1, trg_vocab_size))

            topk_scores, topk_indices = layers.topk(
                input=layers.softmax(logits), k=beam_size)
            accu_scores = layers.elementwise_add(
                x=layers.log(topk_scores),
                y=layers.reshape(
                    pre_scores, shape=[-1]),
                axis=0)
            # beam_search op uses lod to distinguish branches.
            topk_indices = layers.lod_reset(topk_indices, pre_ids)
            selected_ids, selected_scores = layers.beam_search(
                pre_ids=pre_ids,
                pre_scores=pre_scores,
                ids=topk_indices,
                scores=accu_scores,
                beam_size=beam_size,
                end_id=eos_idx)

            layers.increment(x=step_idx, value=1.0, in_place=True)
            # update states
            layers.array_write(selected_ids, i=step_idx, array=ids_flatten)
            selected_ids = layers.reshape(selected_ids, shape=(-1, 1, 1))
            layers.array_write(selected_ids, i=step_idx, array=ids)
            layers.array_write(selected_scores, i=step_idx, array=scores)
            layers.assign(pre_src_attn_bias, trg_src_attn_bias)
            layers.assign(pre_enc_output, enc_output)
            for i in range(n_layer):
                layers.assign(pre_caches[i]["k"], caches[i]["k"])
                layers.assign(pre_caches[i]["v"], caches[i]["v"])
            length_cond = layers.less_than(x=step_idx, y=max_len)
            finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
            layers.logical_and(x=length_cond, y=finish_cond, out=cond)

        finished_ids, finished_scores = layers.beam_search_decode(
            ids_flatten, scores, beam_size=beam_size, end_id=eos_idx)
        return finished_ids, finished_scores

    finished_ids, finished_scores = beam_search()
    return finished_ids, finished_scores
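

# A minimal call sketch for fast_decode (the hyper-parameter values below are
# illustrative assumptions, not values taken from config):
#
#   finished_ids, finished_scores = fast_decode(
#       src_vocab_size=10000, trg_vocab_size=10000, max_in_len=256,
#       n_layer=6, n_head=8, d_key=64, d_value=64, d_model=512,
#       d_inner_hid=2048, dropout_rate=0.1, weight_sharing=True,
#       beam_size=4, max_out_len=255, eos_idx=1)
#
# finished_ids and finished_scores are LoDTensors produced by beam_search_decode,
# holding the selected token ids and their accumulated log-probabilities for
# each returned hypothesis.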