from functools import partial
import numpy as np

import paddle.fluid as fluid
import paddle.fluid.layers as layers

from config import *


def position_encoding_init(n_position, d_pos_vec):
    """
    Generate the initial values for the sinusoid position encoding table.
    """
    position_enc = np.array([[
        pos / np.power(10000, 2. * (j // 2) / d_pos_vec)
        for j in range(d_pos_vec)
    ] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
    return position_enc.astype("float32")
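
# Usage sketch (hypothetical sizes, not taken from config.py):
#
#     pos_enc = position_encoding_init(n_position=50, d_pos_vec=512)
#     # pos_enc.shape == (50, 512); row 0 (reserved for padding) is all zeros,
#     # even columns hold sin values and odd columns hold cos values.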


def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         n_head=1,
                         dropout_rate=0.,
                         cache=None):
    """
    Multi-Head Attention. Note that attn_bias is added to the logit before
    computing softmax activation to mask certain selected positions so that
    they will not be considered in attention weights.
    """
    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: queries, keys and values should all be 3-D tensors.")

    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
                      size=d_key * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        k = layers.fc(input=keys,
                      size=d_key * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        v = layers.fc(input=values,
                      size=d_value * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        return q, k, v

    def __split_heads(x, n_head):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions and then transpose. Specifically, transform the input tensor
        with shape [bs, max_sequence_length, n_head * hidden_dim] into the
        output tensor with shape [bs, n_head, max_sequence_length, hidden_dim].
        """
        if n_head == 1:
            return x

        hidden_size = x.shape[-1]
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        reshaped = layers.reshape(
            x=x, shape=[0, 0, n_head, hidden_size // n_head])

        # Permute the dimensions into:
        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that it becomes one dimension, which is the reverse of __split_heads.
        """
        if len(x.shape) == 3: return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        return layers.reshape(
            x=trans_x, shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]])

    def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
        """
        Scaled Dot-Product Attention
        """
        scaled_q = layers.scale(x=q, scale=d_model**-0.5)
        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
        if attn_bias:
            product += attn_bias
        weights = layers.softmax(product)
        if dropout_rate:
            weights = layers.dropout(
                weights,
                dropout_prob=dropout_rate,
                seed=ModelHyperParams.dropout_seed,
                is_test=False)
        out = layers.matmul(weights, v)
        return out

    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

    if cache is not None:  # use cache and concat time steps
        k = cache["k"] = layers.concat([cache["k"], k], axis=1)
        v = cache["v"] = layers.concat([cache["v"], v], axis=1)

    q = __split_heads(q, n_head)
    k = __split_heads(k, n_head)
    v = __split_heads(v, n_head)
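    # After splitting (when n_head > 1), q and k have shape
    # [batch_size, n_head, seq_len, d_key] and v has shape
    # [batch_size, n_head, seq_len, d_value].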

    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model,
                                                  dropout_rate)

    out = __combine_heads(ctx_multiheads)

    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         bias_attr=False,
                         num_flatten_dims=2)
    return proj_out
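
# Usage sketch for multi_head_attention (hypothetical names and sizes, not
# taken from config.py):
#
#     # queries/keys/values: 3-D tensors of shape [batch, seq_len, d_model];
#     # attn_bias: e.g. -1e9 at masked positions, broadcast over heads.
#     out = multi_head_attention(q_in, k_in, v_in, bias, d_key=64, d_value=64,
#                                d_model=512, n_head=8, dropout_rate=0.1)
#     # out has shape [batch, q_len, d_model].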


def positionwise_feed_forward(x, d_inner_hid, d_hid):
    """
    Position-wise Feed-Forward Networks.
    This module consists of two linear transformations with a ReLU activation
    in between, which is applied to each position separately and identically.
    """
    hidden = layers.fc(input=x,
                       size=d_inner_hid,
                       num_flatten_dims=2,
                       act="relu")
    out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
    return out


def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
    """
    Add residual connection, layer normalization and dropout to the out tensor
    optionally according to the value of process_cmd.
    This will be used before or after multi-head attention and position-wise
    feed-forward networks.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            out = out + prev_out if prev_out else out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.initializer.Constant(1.),
                bias_attr=fluid.initializer.Constant(0.))
        elif cmd == "d":  # add dropout
            if dropout_rate:
                out = layers.dropout(
                    out,
                    dropout_prob=dropout_rate,
                    seed=ModelHyperParams.dropout_seed,
                    is_test=False)
    return out


pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer
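
# Example (descriptive): with process_cmd="dan", as used below, the layer
# applies dropout, then adds the residual connection, then layer-normalizes:
#     out = layer_norm(prev_out + dropout(out))
# pre_process_layer fixes prev_out to None, so its residual step is a no-op.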


def prepare_encoder(src_word,
                    src_pos,
                    src_vocab_size,
                    src_emb_dim,
                    src_max_len,
                    dropout_rate=0.,
                    word_emb_param_name=None,
                    pos_enc_param_name=None):
    """Add word embeddings and position encodings.
    The output tensor has a shape of:
    [batch_size, max_src_length_in_batch, d_model].
    This module is used at the bottom of the encoder stacks.
    """
    src_word_emb = layers.embedding(
        src_word,
        size=[src_vocab_size, src_emb_dim],
        param_attr=fluid.ParamAttr(
            name=word_emb_param_name,
            initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))

    src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
    src_pos_enc = layers.embedding(
        src_pos,
        size=[src_max_len, src_emb_dim],
        param_attr=fluid.ParamAttr(
            name=pos_enc_param_name, trainable=False))
    enc_input = src_word_emb + src_pos_enc
    return layers.dropout(
        enc_input,
        dropout_prob=dropout_rate,
        seed=ModelHyperParams.dropout_seed,
        is_test=False) if dropout_rate else enc_input


prepare_encoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[0])
prepare_decoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[1])
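
# prepare_encoder and prepare_decoder share the same logic; they differ only in
# which (non-trainable) position-encoding parameter they look up.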


def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.):
    """The encoder layers that can be stacked to form a deep encoder.
    This module consists of a multi-head (self) attention followed by
    position-wise feed-forward networks, and both components are accompanied by
    post_process_layer to add residual connection, layer normalization and
    dropout.
    """
    attn_output = multi_head_attention(enc_input, enc_input, enc_input,
                                       attn_bias, d_key, d_value, d_model,
                                       n_head, dropout_rate)
    attn_output = post_process_layer(enc_input, attn_output, "dan",
                                     dropout_rate)
    ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
    return post_process_layer(attn_output, ffd_output, "dan", dropout_rate)


def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer.
    """
    for i in range(n_layer):
        enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key, d_value,
                                   d_model, d_inner_hid, dropout_rate)
        enc_input = enc_output
    return enc_output
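
# Usage sketch (hypothetical hyperparameters, not taken from config.py):
#
#     enc_out = encoder(enc_in, slf_attn_bias, n_layer=6, n_head=8, d_key=64,
#                       d_value=64, d_model=512, d_inner_hid=2048,
#                       dropout_rate=0.1)
#     # enc_out keeps the [batch, src_len, d_model] shape of enc_in.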


def decoder_layer(dec_input,
                  enc_output,
                  slf_attn_bias,
                  dec_enc_attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.,
                  cache=None):
    """ The layer to be stacked in the decoder part.
    The structure of this module is similar to that in the encoder part, except
    that an extra multi-head attention is added to implement encoder-decoder
    attention.
    """
    slf_attn_output = multi_head_attention(
        dec_input,
        dec_input,
        dec_input,
        slf_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate,
        cache, )
    slf_attn_output = post_process_layer(
        dec_input,
        slf_attn_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    enc_attn_output = multi_head_attention(
        slf_attn_output,
        enc_output,
        enc_output,
        dec_enc_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate, )
    enc_attn_output = post_process_layer(
        slf_attn_output,
        enc_attn_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    ffd_output = positionwise_feed_forward(
        enc_attn_output,
        d_inner_hid,
        d_model, )
    dec_output = post_process_layer(
        enc_attn_output,
        ffd_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    return dec_output


def decoder(dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.,
            caches=None):
    """
    The decoder is composed of a stack of identical decoder_layer layers.
    """
    for i in range(n_layer):
        cache = None
        if caches is not None:
            cache = caches[i]

        dec_output = decoder_layer(
            dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate,
            cache=cache)
        dec_input = dec_output
    return dec_output


def make_all_inputs(input_fields):
    """
    Define the input data layers for the transformer model.
    """
    inputs = []
    for input_field in input_fields:
        input_var = layers.data(
            name=input_field,
            shape=input_descs[input_field][0],
            dtype=input_descs[input_field][1],
            lod_level=input_descs[input_field][2]
            if len(input_descs[input_field]) == 3 else 0,
            append_batch_size=False)
        inputs.append(input_var)
    return inputs
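
# Each entry of input_descs (from config.py) is expected to hold the field's
# shape and dtype, with an optional lod_level as a third element.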


def transformer(
        src_vocab_size,
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        label_smooth_eps, ):
    if weight_sharing:
        assert src_vocab_size == trg_vocab_size, (
            "Vocabularies in source and target should be the same for weight sharing."
        )
    enc_inputs = make_all_inputs(encoder_data_input_fields)

    enc_output = wrap_encoder(
        src_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        enc_inputs, )

    dec_inputs = make_all_inputs(decoder_data_input_fields[:-1])

    predict = wrap_decoder(
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        dec_inputs,
        enc_output, )

    # Padding indices do not contribute to the total loss. The weights are used
    # to cancel padding indices in calculating the loss.
    label, weights = make_all_inputs(label_data_input_fields)
    if label_smooth_eps:
        label = layers.label_smooth(
            label=layers.one_hot(
                input=label, depth=trg_vocab_size),
            epsilon=label_smooth_eps)

    cost = layers.softmax_with_cross_entropy(
        logits=layers.reshape(
            predict, shape=[-1, trg_vocab_size]),
        label=label,
        soft_label=True if label_smooth_eps else False)
    weighted_cost = cost * weights
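    # Normalize the summed, padding-masked loss by the number of real tokens to
    # get the per-token cost.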
    sum_cost = layers.reduce_sum(weighted_cost)
    token_num = layers.reduce_sum(weights)
    avg_cost = sum_cost / token_num
    avg_cost.stop_gradient = True
    return sum_cost, avg_cost, predict, token_num


def wrap_encoder(src_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 weight_sharing,
                 enc_inputs=None):
    """
    The wrapper assembles together all needed layers for the encoder.
    """
    if enc_inputs is None:
        # This is used to implement independent encoder program in inference.
        src_word, src_pos, src_slf_attn_bias = \
            make_all_inputs(encoder_data_input_fields)
    else:
        src_word, src_pos, src_slf_attn_bias = \
            enc_inputs
    enc_input = prepare_encoder(
        src_word,
        src_pos,
        src_vocab_size,
        d_model,
        max_length,
471
        dropout_rate,
G
guosheng 已提交
472
        word_emb_param_name=word_emb_param_names[0])
473 474
    enc_output = encoder(enc_input, src_slf_attn_bias, n_layer, n_head, d_key,
                         d_value, d_model, d_inner_hid, dropout_rate)
475 476 477 478 479 480 481 482 483 484 485 486
    return enc_output


def wrap_decoder(trg_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 weight_sharing,
                 dec_inputs=None,
                 enc_output=None,
                 caches=None):
    """
    The wrapper assembles together all needed layers for the decoder.
    """
    if dec_inputs is None:
        # This is used to implement independent decoder program in inference.
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
        enc_output = make_all_inputs(
            decoder_data_input_fields + decoder_util_input_fields)
    else:
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs

    dec_input = prepare_decoder(
        trg_word,
        trg_pos,
        trg_vocab_size,
        d_model,
        max_length,
        dropout_rate,
        word_emb_param_name=word_emb_param_names[0]
        if weight_sharing else word_emb_param_names[1])
    dec_output = decoder(
        dec_input,
        enc_output,
        trg_slf_attn_bias,
        trg_src_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        caches=caches)
    # Return logits for training and probs for inference.
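    # With weight_sharing, the output projection reuses the (transposed) shared
    # word-embedding matrix, i.e. tied input and output embeddings.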
    if weight_sharing:
        predict = layers.matmul(
            x=dec_output,
            y=fluid.get_var(word_emb_param_names[0]),
            transpose_y=True)
    else:
        predict = layers.fc(input=dec_output,
                            size=trg_vocab_size,
                            bias_attr=False,
                            num_flatten_dims=2)
    if dec_inputs is None:
        predict = layers.softmax(predict)
    return predict


def fast_decode(
        src_vocab_size,
        trg_vocab_size,
        max_in_len,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        weight_sharing,
        beam_size,
        max_out_len,
        eos_idx, ):
    """
    Use beam search to decode. Caches will be used to store states of history
    steps which can make the decoding faster.
    """
    enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
                              d_key, d_value, d_model, d_inner_hid,
                              dropout_rate, weight_sharing)
    start_tokens, init_scores, trg_src_attn_bias = \
        make_all_inputs(fast_decoder_data_input_fields)

    def beam_search():
        max_len = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=max_out_len)
        step_idx = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=0)
        cond = layers.less_than(x=step_idx, y=max_len)
        while_op = layers.While(cond)
        # array states will be stored for each step.
        ids = layers.array_write(
            layers.reshape(start_tokens, (-1, 1)), step_idx)
        scores = layers.array_write(init_scores, step_idx)
        # cell states will be overwritten at each step.
        # caches contain states of history steps to reduce redundant
        # computation in decoder.
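        # The caches start with a zero-length time dimension (the 0 in
        # shape [-1, 0, d_model]), so the first concat in multi_head_attention
        # simply appends the first decoded step.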
        caches = [{
            "k": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_model],
                dtype=enc_output.dtype,
                value=0),
            "v": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_model],
                dtype=enc_output.dtype,
                value=0)
        } for i in range(n_layer)]
        with while_op.block():
            pre_ids = layers.array_read(array=ids, i=step_idx)
            pre_ids = layers.reshape(pre_ids, (-1, 1, 1))
            pre_scores = layers.array_read(array=scores, i=step_idx)
            # sequence_expand can gather sequences according to lod and thus can
            # be used in beam search to select the states corresponding to the
            # selected ids.
            pre_src_attn_bias = layers.sequence_expand(
                x=trg_src_attn_bias, y=pre_scores)
            pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
            pre_caches = [{
                "k": layers.sequence_expand(
                    x=cache["k"], y=pre_scores),
                "v": layers.sequence_expand(
                    x=cache["v"], y=pre_scores),
            } for cache in caches]
            pre_pos = layers.elementwise_mul(
                x=layers.fill_constant_batch_size_like(
                    input=pre_enc_output,  # can't use pre_ids here since it has lod
                    value=1,
                    shape=[-1, 1, 1],
                    dtype=pre_ids.dtype),
                y=layers.increment(
                    x=step_idx, value=1.0, in_place=False),
                axis=0)
            logits = wrap_decoder(
                trg_vocab_size,
                max_in_len,
                n_layer,
                n_head,
                d_key,
                d_value,
                d_model,
                d_inner_hid,
                dropout_rate,
                weight_sharing,
                dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
                enc_output=pre_enc_output,
                caches=pre_caches)
            logits = layers.reshape(logits, (-1, trg_vocab_size))

            topk_scores, topk_indices = layers.topk(
                input=layers.softmax(logits), k=beam_size)
            accu_scores = layers.elementwise_add(
                x=layers.log(topk_scores),
                y=layers.reshape(
                    pre_scores, shape=[-1]),
                axis=0)
            # beam_search op uses lod to distinguish branches.
            topk_indices = layers.lod_reset(topk_indices, pre_ids)
            selected_ids, selected_scores = layers.beam_search(
                pre_ids=pre_ids,
                pre_scores=pre_scores,
                ids=topk_indices,
                scores=accu_scores,
                beam_size=beam_size,
                end_id=eos_idx)

            layers.increment(x=step_idx, value=1.0, in_place=True)
            # update states
            layers.array_write(selected_ids, i=step_idx, array=ids)
            layers.array_write(selected_scores, i=step_idx, array=scores)
            layers.assign(pre_src_attn_bias, trg_src_attn_bias)
            layers.assign(pre_enc_output, enc_output)
            for i in range(n_layer):
                layers.assign(pre_caches[i]["k"], caches[i]["k"])
                layers.assign(pre_caches[i]["v"], caches[i]["v"])
            length_cond = layers.less_than(x=step_idx, y=max_len)
            finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
            layers.logical_and(x=length_cond, y=finish_cond, out=cond)

        finished_ids, finished_scores = layers.beam_search_decode(
            ids, scores, beam_size=beam_size, end_id=eos_idx)
        return finished_ids, finished_scores

    finished_ids, finished_scores = beam_search()
    return finished_ids, finished_scores
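
# Usage sketch (hypothetical values; real settings live in config.py and the
# training/inference scripts):
#
#     finished_ids, finished_scores = fast_decode(
#         src_vocab_size=10000, trg_vocab_size=10000, max_in_len=256,
#         n_layer=6, n_head=8, d_key=64, d_value=64, d_model=512,
#         d_inner_hid=2048, dropout_rate=0.1, weight_sharing=True,
#         beam_size=4, max_out_len=255, eos_idx=1)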