from functools import partial
import numpy as np

import paddle.fluid as fluid
import paddle.fluid.layers as layers

from config import *


def position_encoding_init(n_position, d_pos_vec):
    """
    Generate the initial values for the sinusoid position encoding table.
    """
    position_enc = np.array([[
        pos / np.power(10000, 2 * (j // 2) / d_pos_vec)
        for j in range(d_pos_vec)
    ] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
    return position_enc.astype("float32")
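
# A comment-only sketch of the returned table (kept as a comment so importing
# this module stays side-effect free): position_encoding_init(n_position=4,
# d_pos_vec=6) gives a (4, 6) float32 array whose row 0 is all zeros (the
# padding position) and whose remaining rows are, with true division in the
# exponent,
#   pe[pos, 2 * i]     == sin(pos / 10000 ** (2. * i / 6))
#   pe[pos, 2 * i + 1] == cos(pos / 10000 ** (2. * i / 6))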


def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         n_head=1,
                         dropout_rate=0.,
                         cache=None):
    """
    Multi-Head Attention. Note that attn_bias is added to the logit before
    computing the softmax activation to mask certain selected positions so
    that they will not be considered in the attention weights.
    """
    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: queries, keys and values should all be 3-D tensors.")

    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
                      size=d_key * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        k = layers.fc(input=keys,
                      size=d_key * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        v = layers.fc(input=values,
                      size=d_value * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        return q, k, v

    def __split_heads(x, n_head):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions and then transpose. Specifically, input a tensor with shape
        [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
        with shape [bs, n_head, max_sequence_length, hidden_dim].
        """
        if n_head == 1:
            return x

        hidden_size = x.shape[-1]
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        reshaped = layers.reshape(
            x=x, shape=[0, 0, n_head, hidden_size // n_head])

        # permute the dimensions into:
        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that it becomes one dimension, which is the reverse of __split_heads.
        """
        if len(x.shape) == 3:
            return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        return layers.reshape(
            x=trans_x,
            shape=list(map(int, [0, 0, trans_x.shape[2] * trans_x.shape[3]])))

    def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
        """
        Scaled Dot-Product Attention
        """
        scaled_q = layers.scale(x=q, scale=d_model**-0.5)
        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
        if attn_bias:
            product += attn_bias
        weights = layers.softmax(product)
        if dropout_rate:
            weights = layers.dropout(
                weights,
                dropout_prob=dropout_rate,
                seed=ModelHyperParams.dropout_seed,
                is_test=False)
        out = layers.matmul(weights, v)
        return out
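
    # In formula form, the helper above computes (a sketch)
    #     Attention(Q, K, V) = softmax(Q * K^T * d_model**-0.5 + attn_bias) * V
    # with the scaling applied to Q before the matmul.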

    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

    if cache is not None:  # use cache and concat time steps
        k = cache["k"] = layers.concat([cache["k"], k], axis=1)
        v = cache["v"] = layers.concat([cache["v"], v], axis=1)

    q = __split_heads(q, n_head)
    k = __split_heads(k, n_head)
    v = __split_heads(v, n_head)

    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model,
                                                  dropout_rate)

    out = __combine_heads(ctx_multiheads)

    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         bias_attr=False,
                         num_flatten_dims=2)
    return proj_out
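
# Shape walk-through for multi_head_attention (a sketch; bs is the batch size
# and len the padded sequence length):
#   queries [bs, len, d_model] --fc--> q [bs, len, n_head * d_key]
#   __split_heads              --> [bs, n_head, len, d_key]  (same for k, v)
#   scaled dot-product attn    --> [bs, n_head, len, d_value]
#   __combine_heads            --> [bs, len, n_head * d_value]
#   final fc                   --> [bs, len, d_model]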


def positionwise_feed_forward(x, d_inner_hid, d_hid):
    """
    Position-wise Feed-Forward Networks.
    This module consists of two linear transformations with a ReLU activation
    in between, which is applied to each position separately and identically.
    """
    hidden = layers.fc(input=x,
                       size=d_inner_hid,
                       num_flatten_dims=2,
                       act="relu")
    out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
    return out
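
# In formula form, positionwise_feed_forward computes
#     FFN(x) = relu(x * W1 + b1) * W2 + b2
# applied to every position independently (num_flatten_dims=2 keeps the
# leading [batch, seq] dimensions intact).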


def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
    """
    Add residual connection, layer normalization and dropout to the out tensor
    optionally according to the value of process_cmd.
    This will be used before or after multi-head attention and position-wise
    feed-forward networks.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            out = out + prev_out if prev_out else out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.initializer.Constant(1.),
                bias_attr=fluid.initializer.Constant(0.))
        elif cmd == "d":  # add dropout
            if dropout_rate:
                out = layers.dropout(
                    out,
                    dropout_prob=dropout_rate,
                    seed=ModelHyperParams.dropout_seed,
                    is_test=False)
    return out


pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer
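
# Example of the command string: post_process_layer(prev_out, out, "dan", 0.1)
# processes the characters left to right, so it applies dropout to `out`, then
# the residual addition `out + prev_out`, then layer normalization. Because
# pre_process_layer binds prev_out to None, its "a" command is a no-op.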


def prepare_encoder(src_word,
                    src_pos,
                    src_vocab_size,
                    src_emb_dim,
                    src_max_len,
                    dropout_rate=0.,
                    word_emb_param_name=None,
                    pos_enc_param_name=None):
    """Add word embeddings and position encodings.
    The output tensor has a shape of:
    [batch_size, max_src_length_in_batch, d_model].
    This module is used at the bottom of the encoder stacks.
    """
    src_word_emb = layers.embedding(
        src_word,
        size=[src_vocab_size, src_emb_dim],
        param_attr=fluid.ParamAttr(
            name=word_emb_param_name,
            initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))

    src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
    src_pos_enc = layers.embedding(
        src_pos,
        size=[src_max_len, src_emb_dim],
        param_attr=fluid.ParamAttr(
            name=pos_enc_param_name, trainable=False))
    enc_input = src_word_emb + src_pos_enc
    return layers.dropout(
        enc_input,
        dropout_prob=dropout_rate,
        seed=ModelHyperParams.dropout_seed,
        is_test=False) if dropout_rate else enc_input


prepare_encoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[0])
prepare_decoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[1])
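
# prepare_decoder reuses the same logic as prepare_encoder and differs only in
# which (non-trainable) position-encoding table it reads, as selected by
# pos_enc_param_name above.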


def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.):
    """The encoder layers that can be stacked to form a deep encoder.
    This module consists of a multi-head (self) attention followed by
    position-wise feed-forward networks, and both components are wrapped with
    post_process_layer to add residual connection, layer normalization and
    dropout.
    """
    attn_output = multi_head_attention(enc_input, enc_input, enc_input,
                                       attn_bias, d_key, d_value, d_model,
                                       n_head, dropout_rate)
    attn_output = post_process_layer(enc_input, attn_output, "dan",
                                     dropout_rate)
    ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
    return post_process_layer(attn_output, ffd_output, "dan", dropout_rate)


def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer.
    """
    for i in range(n_layer):
        enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key, d_value,
                                   d_model, d_inner_hid, dropout_rate)
        enc_input = enc_output
    return enc_output


def decoder_layer(dec_input,
                  enc_output,
                  slf_attn_bias,
                  dec_enc_attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.,
                  cache=None):
    """ The layer to be stacked in the decoder part.
    The structure of this module is similar to that in the encoder part except
    that an extra multi-head attention is added to implement encoder-decoder
    attention.
    """
    slf_attn_output = multi_head_attention(
        dec_input,
        dec_input,
        dec_input,
        slf_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate,
        cache, )
    slf_attn_output = post_process_layer(
        dec_input,
        slf_attn_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    enc_attn_output = multi_head_attention(
        slf_attn_output,
        enc_output,
        enc_output,
        dec_enc_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate, )
    enc_attn_output = post_process_layer(
        slf_attn_output,
        enc_attn_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    ffd_output = positionwise_feed_forward(
        enc_attn_output,
        d_inner_hid,
        d_model, )
    dec_output = post_process_layer(
        enc_attn_output,
        ffd_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    return dec_output


def decoder(dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.,
            caches=None):
    """
    The decoder is composed of a stack of identical decoder_layer layers.
    """
    for i in range(n_layer):
        cache = None
        if caches is not None:
            cache = caches[i]

        dec_output = decoder_layer(
            dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate,
            cache=cache)
        dec_input = dec_output
    return dec_output


def make_all_inputs(input_fields):
    """
    Define the input data layers for the transformer model.
    """
    inputs = []
    for input_field in input_fields:
        input_var = layers.data(
            name=input_field,
            shape=input_descs[input_field][0],
            dtype=input_descs[input_field][1],
            lod_level=input_descs[input_field][2]
            if len(input_descs[input_field]) == 3 else 0,
            append_batch_size=False)
        inputs.append(input_var)
    return inputs
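
# Each input_descs entry is read above as (shape, dtype[, lod_level]); the
# optional third element marks inputs that carry LoD information.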


def transformer(
        src_vocab_size,
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        label_smooth_eps, ):
    if weight_sharing:
        assert src_vocab_size == trg_vocab_size, (
            "Vocabularies in source and target should be same for weight sharing."
        )
    enc_inputs = make_all_inputs(encoder_data_input_fields)

    enc_output = wrap_encoder(
        src_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        enc_inputs, )

    dec_inputs = make_all_inputs(decoder_data_input_fields[:-1])

    predict = wrap_decoder(
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        dec_inputs,
        enc_output, )

    # The padding index does not contribute to the total loss. The weights are
    # used to cancel out the padding index when calculating the loss.
    label, weights = make_all_inputs(label_data_input_fields)
    if label_smooth_eps:
        label = layers.label_smooth(
            label=layers.one_hot(
                input=label, depth=trg_vocab_size),
            epsilon=label_smooth_eps)
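        # Assuming label_smooth's default uniform prior, the smoothed target
        # per class is
        #     (1 - label_smooth_eps) * one_hot + label_smooth_eps / trg_vocab_size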

    cost = layers.softmax_with_cross_entropy(
        logits=layers.reshape(
            predict, shape=[-1, trg_vocab_size]),
        label=label,
        soft_label=True if label_smooth_eps else False)
    weighted_cost = cost * weights
    sum_cost = layers.reduce_sum(weighted_cost)
    token_num = layers.reduce_sum(weights)
    avg_cost = sum_cost / token_num
    avg_cost.stop_gradient = True
    return sum_cost, avg_cost, predict, token_num


def wrap_encoder(src_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 weight_sharing,
                 enc_inputs=None):
    """
    The wrapper assembles all the layers needed by the encoder.
    """
    if enc_inputs is None:
        # This is used to implement an independent encoder program in inference.
        src_word, src_pos, src_slf_attn_bias = \
            make_all_inputs(encoder_data_input_fields)
    else:
        src_word, src_pos, src_slf_attn_bias = \
            enc_inputs
    enc_input = prepare_encoder(
        src_word,
        src_pos,
        src_vocab_size,
        d_model,
        max_length,
        dropout_rate,
        word_emb_param_name=word_emb_param_names[0])
    enc_output = encoder(enc_input, src_slf_attn_bias, n_layer, n_head, d_key,
                         d_value, d_model, d_inner_hid, dropout_rate)
    return enc_output


def wrap_decoder(trg_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 weight_sharing,
                 dec_inputs=None,
                 enc_output=None,
                 caches=None):
    """
    The wrapper assembles all the layers needed by the decoder.
    """
    if dec_inputs is None:
        # This is used to implement an independent decoder program in inference.
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
        enc_output = make_all_inputs(
            decoder_data_input_fields + decoder_util_input_fields)
    else:
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs

    dec_input = prepare_decoder(
        trg_word,
        trg_pos,
        trg_vocab_size,
        d_model,
        max_length,
        dropout_rate,
        word_emb_param_name=word_emb_param_names[0]
        if weight_sharing else word_emb_param_names[1])
    dec_output = decoder(
        dec_input,
        enc_output,
        trg_slf_attn_bias,
        trg_src_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        caches=caches)
    # Return logits for training and probs for inference.
    if weight_sharing:
        predict = layers.matmul(
            x=dec_output,
            y=fluid.get_var(word_emb_param_names[0]),
            transpose_y=True)
    else:
        predict = layers.fc(input=dec_output,
                            size=trg_vocab_size,
                            bias_attr=False,
                            num_flatten_dims=2)
    if dec_inputs is None:
        predict = layers.softmax(predict)
    return predict


def fast_decode(
        src_vocab_size,
        trg_vocab_size,
        max_in_len,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        beam_size,
        max_out_len,
        eos_idx, ):
    """
    Use beam search to decode. Caches are used to store the states of previous
    steps, which makes the decoding faster.
    """
    enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
                              d_key, d_value, d_model, d_inner_hid,
                              dropout_rate, weight_sharing)
    start_tokens, init_scores, trg_src_attn_bias = \
        make_all_inputs(fast_decoder_data_input_fields)

    def beam_search():
        max_len = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=max_out_len)
        step_idx = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=0)
        cond = layers.less_than(x=step_idx, y=max_len)
        while_op = layers.While(cond)
        # Array states (ids and scores) will be stored for each step.
        ids = layers.array_write(
            layers.reshape(start_tokens, (-1, 1)), step_idx)
        scores = layers.array_write(init_scores, step_idx)
        # The cell states will be overwritten at each step.
        # caches contains the states of history steps to reduce redundant
        # computation in the decoder.
        caches = [{
            "k": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_model],
                dtype=enc_output.dtype,
                value=0),
            "v": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_model],
                dtype=enc_output.dtype,
                value=0)
        } for i in range(n_layer)]
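        # Each cache starts with an empty time dimension (the 0 in
        # [-1, 0, d_model]); multi_head_attention concatenates the current
        # step's keys and values onto it along axis=1, so the stored history
        # grows by one entry per decoding step.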
        with while_op.block():
            pre_ids = layers.array_read(array=ids, i=step_idx)
            pre_ids = layers.reshape(pre_ids, (-1, 1, 1))
            pre_scores = layers.array_read(array=scores, i=step_idx)
            # sequence_expand can gather sequences according to lod, so it is
            # used in beam search to keep the states corresponding to the
            # selected ids.
            pre_src_attn_bias = layers.sequence_expand(
                x=trg_src_attn_bias, y=pre_scores)
            pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
            pre_caches = [{
                "k": layers.sequence_expand(
                    x=cache["k"], y=pre_scores),
                "v": layers.sequence_expand(
                    x=cache["v"], y=pre_scores),
            } for cache in caches]
            pre_pos = layers.elementwise_mul(
                x=layers.fill_constant_batch_size_like(
                    input=pre_enc_output,  # can't use pre_ids here since it has lod
                    value=1,
                    shape=[-1, 1, 1],
                    dtype=pre_ids.dtype),
                y=layers.increment(
                    x=step_idx, value=1.0, in_place=False),
                axis=0)
            logits = wrap_decoder(
                trg_vocab_size,
                max_in_len,
                n_layer,
                n_head,
                d_key,
                d_value,
                d_model,
                d_inner_hid,
                dropout_rate,
                weight_sharing,
                dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
                enc_output=pre_enc_output,
                caches=pre_caches)
            logits = layers.reshape(logits, (-1, trg_vocab_size))

            topk_scores, topk_indices = layers.topk(
                input=layers.softmax(logits), k=beam_size)
            accu_scores = layers.elementwise_add(
                x=layers.log(topk_scores),
                y=layers.reshape(
                    pre_scores, shape=[-1]),
                axis=0)
            # beam_search op uses lod to distinguish branches.
            topk_indices = layers.lod_reset(topk_indices, pre_ids)
            selected_ids, selected_scores = layers.beam_search(
                pre_ids=pre_ids,
                pre_scores=pre_scores,
                ids=topk_indices,
                scores=accu_scores,
                beam_size=beam_size,
                end_id=eos_idx)

            layers.increment(x=step_idx, value=1.0, in_place=True)
            # update states
            layers.array_write(selected_ids, i=step_idx, array=ids)
            layers.array_write(selected_scores, i=step_idx, array=scores)
            layers.assign(pre_src_attn_bias, trg_src_attn_bias)
            layers.assign(pre_enc_output, enc_output)
            for i in range(n_layer):
                layers.assign(pre_caches[i]["k"], caches[i]["k"])
                layers.assign(pre_caches[i]["v"], caches[i]["v"])
            length_cond = layers.less_than(x=step_idx, y=max_len)
            finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
            layers.logical_and(x=length_cond, y=finish_cond, out=cond)
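            # The loop ends once step_idx reaches max_len or beam_search
            # leaves no surviving branch (selected_ids is empty).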

        finished_ids, finished_scores = layers.beam_search_decode(
            ids, scores, beam_size=beam_size, end_id=eos_idx)
        return finished_ids, finished_scores

    finished_ids, finished_scores = beam_search()
    return finished_ids, finished_scores