model.py 24.9 KB
Newer Older
1 2
from functools import partial
import numpy as np
Y
ying 已提交
3

L
Luo Tao 已提交
4 5
import paddle.fluid as fluid
import paddle.fluid.layers as layers
Y
ying 已提交
6

7
from config import *
Y
ying 已提交
8

9 10

def position_encoding_init(n_position, d_pos_vec):
    """
    Generate the initial values for the sinusoid position encoding table.

    Args:
        n_position: number of positions (rows) in the table.
        d_pos_vec: size of each position vector (columns).

    Returns:
        A float32 numpy array of shape [n_position, d_pos_vec]. Row 0 is all
        zeros. In every other row `pos`, column j holds
        sin(pos / 10000^(2 * (j // 2) / d_pos_vec)) when j is even and the
        cos of the same argument when j is odd.
    """
    positions = np.arange(n_position, dtype="float64").reshape(-1, 1)
    # Use explicit float arithmetic so the exponent is never truncated by
    # integer division (which `/` performs on ints under Python 2).
    exponents = 2.0 * (np.arange(d_pos_vec) // 2) / d_pos_vec
    position_enc = positions / np.power(10000.0, exponents)
    # Row 0 is skipped on purpose: it must stay all zeros (cos(0) would
    # otherwise put ones in the odd columns).
    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
    return position_enc.astype("float32")


def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         n_head=1,
                         dropout_rate=0.,
                         pre_softmax_shape=None,
                         post_softmax_shape=None,
                         cache=None):
    """
    Multi-Head Attention. Note that attn_bias is added to the logit before
    computing softmax activation to mask certain selected positions so that
    they will not be considered in attention weights.

    Args:
        queries, keys, values: input tensors; all three must be 3-D.
        attn_bias: tensor added to the attention logits before the softmax,
            or None to skip the addition.
        d_key: per-head size of the projected queries and keys.
        d_value: per-head size of the projected values.
        d_model: size of the final output projection.
        n_head: number of attention heads.
        dropout_rate: dropout probability applied to the attention weights;
            a falsy value disables dropout.
        pre_softmax_shape, post_softmax_shape: forwarded as `actual_shape`
            to the reshapes performed before and after the softmax.
        cache: optional dict with entries "k" and "v". When given, the newly
            projected keys/values are concatenated to the cached ones along
            axis 1 and the dict entries are replaced with the result.

    Returns:
        The combined attention output projected to size d_model.

    Raises:
        ValueError: if queries, keys and values are not all 3-D.
    """
    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: quries, keys and values should all be 3-D tensors.")

    def __project(x, d_head):
        """
        Bias-free linear projection of x to n_head * d_head features.
        """
        return layers.fc(input=x,
                         size=d_head * n_head,
                         param_attr=fluid.initializer.Xavier(
                             uniform=False,
                             fan_in=d_model * d_head,
                             fan_out=n_head * d_head),
                         bias_attr=False,
                         num_flatten_dims=2)

    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        # The projections are created in q, k, v order, same as before.
        q = __project(queries, d_key)
        k = __project(keys, d_key)
        v = __project(values, d_value)
        return q, k, v

    def __split_heads(x, n_head):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions and then transpose. Specifically, input a tensor with shape
        [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
        with shape [bs, n_head, max_sequence_length, hidden_dim].
        """
        if n_head == 1:
            return x

        hidden_size = x.shape[-1]
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        reshaped = layers.reshape(
            x=x, shape=[0, 0, n_head, hidden_size // n_head])

        # Permute the dimensions into:
        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that it becomes one dimension, which is reverse to __split_heads.
        """
        # A 3-D input means the heads were never split (n_head == 1).
        if len(x.shape) == 3:
            return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        # A real list is built here: `map(...)` is a lazy iterator under
        # Python 3 and is not a valid shape argument.
        return layers.reshape(
            x=trans_x,
            shape=[0, 0, int(trans_x.shape[2] * trans_x.shape[3])])

    def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
        """
        Scaled Dot-Product Attention
        """
        # NOTE(review): the scale uses d_model rather than the per-head key
        # size; kept as-is so numerical behaviour is unchanged -- confirm
        # whether d_key was intended.
        scaled_q = layers.scale(x=q, scale=d_model**-0.5)
        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
        # Compare against None explicitly instead of truth-testing a tensor.
        if attn_bias is not None:
            logits = layers.elementwise_add(x=product, y=attn_bias)
        else:
            logits = product
        # Flatten to 2-D, apply softmax over the last dimension, then restore
        # the original shape of the product.
        weights = layers.reshape(
            x=logits,
            shape=[-1, product.shape[-1]],
            actual_shape=pre_softmax_shape,
            act="softmax")
        weights = layers.reshape(
            x=weights, shape=product.shape, actual_shape=post_softmax_shape)
        if dropout_rate:
            weights = layers.dropout(
                weights, dropout_prob=dropout_rate, is_test=False)

        out = layers.matmul(weights, v)
        return out

    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

    if cache is not None:  # use cache and concat time steps
        k = cache["k"] = layers.concat([cache["k"], k], axis=1)
        v = cache["v"] = layers.concat([cache["v"], v], axis=1)
    q = __split_heads(q, n_head)
    k = __split_heads(k, n_head)
    v = __split_heads(v, n_head)

    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model,
                                                  dropout_rate)

    out = __combine_heads(ctx_multiheads)
    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         param_attr=fluid.initializer.Xavier(uniform=False),
                         bias_attr=False,
                         num_flatten_dims=2)
    return proj_out


def positionwise_feed_forward(x, d_inner_hid, d_hid):
    """
    Position-wise Feed-Forward Networks.

    Two fully connected layers are stacked: the first maps x to d_inner_hid
    features with a ReLU activation, the second maps the result to d_hid
    features. Both use num_flatten_dims=2 and a uniform weight initializer
    bounded by +/- (input size ** -0.5).
    """
    first_bound = d_hid**-0.5
    second_bound = d_inner_hid**-0.5

    first_init = fluid.initializer.Uniform(low=-first_bound, high=first_bound)
    hidden = layers.fc(input=x,
                       size=d_inner_hid,
                       num_flatten_dims=2,
                       param_attr=first_init,
                       act="relu")

    second_init = fluid.initializer.Uniform(
        low=-second_bound, high=second_bound)
    return layers.fc(input=hidden,
                     size=d_hid,
                     num_flatten_dims=2,
                     param_attr=second_init)


173
def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
    """
    Add residual connection, layer normalization and dropout to the out tensor
    optionally according to the value of process_cmd.
    This will be used before or after multi-head attention and position-wise
    feed-forward networks.

    Args:
        prev_out: tensor added to `out` by the "a" command, or None to make
            "a" a no-op.
        out: tensor to process.
        process_cmd: iterable of single-character commands applied in order:
            "a" (add residual), "n" (layer normalization over the last
            axis), "d" (dropout). Any other character is ignored.
        dropout_rate: dropout probability used by "d"; a falsy value turns
            "d" into a no-op.

    Returns:
        The processed tensor.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            # Compare against None explicitly: truth-testing a tensor-like
            # object is ambiguous and may raise.
            if prev_out is not None:
                out = out + prev_out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.initializer.Constant(1.),
                bias_attr=fluid.initializer.Constant(0.))
        elif cmd == "d":  # add dropout
            if dropout_rate:
                out = layers.dropout(
                    out, dropout_prob=dropout_rate, is_test=False)
    return out


# Pre-processing has no previous output to add, so prev_out is bound to None.
pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer


def prepare_encoder(src_word,
                    src_pos,
                    src_vocab_size,
                    src_emb_dim,
                    src_max_len,
                    dropout_rate=0.,
                    src_data_shape=None,
                    pos_enc_param_name=None):
    """Sum word embeddings and position encodings.

    src_word is looked up in a [src_vocab_size, src_emb_dim] table and
    src_pos in a non-trainable [src_max_len, src_emb_dim] table whose
    parameter is named pos_enc_param_name. The sum is reshaped to
    [batch_size, seq_len, src_emb_dim] (with src_data_shape as the actual
    shape) and, when dropout_rate is truthy, passed through dropout.

    batch_size and seq_len are not defined in this function; presumably they
    come from `from config import *` -- confirm.
    """
    word_table_attr = fluid.initializer.Normal(0., 1.)
    word_emb = layers.embedding(
        src_word,
        size=[src_vocab_size, src_emb_dim],
        param_attr=word_table_attr)

    pos_table_attr = fluid.ParamAttr(name=pos_enc_param_name, trainable=False)
    pos_enc = layers.embedding(
        src_pos, size=[src_max_len, src_emb_dim], param_attr=pos_table_attr)

    summed = word_emb + pos_enc
    enc_input = layers.reshape(
        x=summed,
        shape=[batch_size, seq_len, src_emb_dim],
        actual_shape=src_data_shape)

    if not dropout_rate:
        return enc_input
    return layers.dropout(
        enc_input, dropout_prob=dropout_rate, is_test=False)
230 231 232 233 234 235 236 237


# Bind the position-encoding parameter name for each side: element 0 of
# pos_enc_param_names for the encoder and element 1 for the decoder
# (pos_enc_param_names is not defined in this file; presumably it comes from
# `from config import *`).
# Order matters: prepare_decoder wraps the already re-bound prepare_encoder,
# and the keyword given here overrides the one bound on the previous statement.
prepare_encoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[0])
prepare_decoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[1])


Y
ying 已提交
238 239 240 241 242 243 244
def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.,
                  pre_softmax_shape=None,
                  post_softmax_shape=None):
    """One encoder layer, stackable to form a deep encoder.

    Runs multi-head self-attention over enc_input, then a position-wise
    feed-forward network. Each sub-layer output goes through
    post_process_layer with the command string "dan" together with the
    sub-layer's input.
    """
    process_cmd = "dan"

    self_attn = multi_head_attention(
        queries=enc_input,
        keys=enc_input,
        values=enc_input,
        attn_bias=attn_bias,
        d_key=d_key,
        d_value=d_value,
        d_model=d_model,
        n_head=n_head,
        dropout_rate=dropout_rate,
        pre_softmax_shape=pre_softmax_shape,
        post_softmax_shape=post_softmax_shape)
    self_attn = post_process_layer(enc_input, self_attn, process_cmd,
                                   dropout_rate)

    feed_forward = positionwise_feed_forward(self_attn, d_inner_hid, d_model)
    return post_process_layer(self_attn, feed_forward, process_cmd,
                              dropout_rate)


def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.,
            pre_softmax_shape=None,
            post_softmax_shape=None):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer.

    Args:
        enc_input: input to the first layer.
        attn_bias: attention bias passed to every layer.
        n_layer: number of stacked layers; with 0 the input is returned
            unchanged.
        n_head, d_key, d_value, d_model, d_inner_hid, dropout_rate,
        pre_softmax_shape, post_softmax_shape: forwarded unchanged to
            encoder_layer.

    Returns:
        The output of the last layer.
    """
    # Start from the input so an empty stack returns it instead of hitting an
    # unbound local.
    enc_output = enc_input
    for _ in range(n_layer):
        enc_output = encoder_layer(
            enc_output,
            attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate,
            pre_softmax_shape,
            post_softmax_shape, )
    return enc_output


Y
ying 已提交
294 295 296 297 298 299 300 301 302
def decoder_layer(dec_input,
                  enc_output,
                  slf_attn_bias,
                  dec_enc_attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.,
                  slf_attn_pre_softmax_shape=None,
                  slf_attn_post_softmax_shape=None,
                  src_attn_pre_softmax_shape=None,
                  src_attn_post_softmax_shape=None,
                  cache=None):
    """One decoder layer, stackable to form the decoder.

    Three sub-layers run in sequence: multi-head self-attention over
    dec_input (using `cache` when given), multi-head attention whose queries
    come from the previous sub-layer and whose keys/values are enc_output,
    and a position-wise feed-forward network. Each sub-layer output goes
    through post_process_layer with the command string "dan" together with
    the sub-layer's input.
    """
    # residual connection + dropout + layer normalization
    process_cmd = "dan"

    self_attn = multi_head_attention(
        queries=dec_input,
        keys=dec_input,
        values=dec_input,
        attn_bias=slf_attn_bias,
        d_key=d_key,
        d_value=d_value,
        d_model=d_model,
        n_head=n_head,
        dropout_rate=dropout_rate,
        pre_softmax_shape=slf_attn_pre_softmax_shape,
        post_softmax_shape=slf_attn_post_softmax_shape,
        cache=cache)
    self_attn = post_process_layer(dec_input, self_attn, process_cmd,
                                   dropout_rate)

    # The encoder-decoder attention does not use the cache.
    cross_attn = multi_head_attention(
        queries=self_attn,
        keys=enc_output,
        values=enc_output,
        attn_bias=dec_enc_attn_bias,
        d_key=d_key,
        d_value=d_value,
        d_model=d_model,
        n_head=n_head,
        dropout_rate=dropout_rate,
        pre_softmax_shape=src_attn_pre_softmax_shape,
        post_softmax_shape=src_attn_post_softmax_shape)
    cross_attn = post_process_layer(self_attn, cross_attn, process_cmd,
                                    dropout_rate)

    feed_forward = positionwise_feed_forward(cross_attn, d_inner_hid, d_model)
    return post_process_layer(cross_attn, feed_forward, process_cmd,
                              dropout_rate)


Y
ying 已提交
360 361 362 363 364 365 366 367 368 369
def decoder(dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.,
            slf_attn_pre_softmax_shape=None,
            slf_attn_post_softmax_shape=None,
            src_attn_pre_softmax_shape=None,
            src_attn_post_softmax_shape=None,
            caches=None):
    """
    The decoder is composed of a stack of identical decoder_layer layers.

    Args:
        dec_input: input to the first layer.
        enc_output: encoder output passed to every layer.
        dec_slf_attn_bias, dec_enc_attn_bias: attention biases passed to
            every layer.
        n_layer: number of stacked layers; with 0 the input is returned
            unchanged.
        caches: optional sequence indexed by layer number; element i is
            passed to layer i as its cache. None disables caching.
        Remaining arguments are forwarded unchanged to decoder_layer.

    Returns:
        The output of the last layer.
    """
    # Start from the input so an empty stack returns it instead of hitting an
    # unbound local.
    dec_output = dec_input
    for i in range(n_layer):
        layer_cache = None if caches is None else caches[i]
        dec_output = decoder_layer(
            dec_output, enc_output, dec_slf_attn_bias, dec_enc_attn_bias,
            n_head, d_key, d_value, d_model, d_inner_hid, dropout_rate,
            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
            src_attn_pre_softmax_shape, src_attn_post_softmax_shape,
            layer_cache)
    return dec_output


390
def make_all_inputs(input_fields):
    """
    Declare one input data layer per name in input_fields.

    Each name is looked up in input_descs: element 0 is used as the shape,
    element 1 as the dtype, and element 2 as the LoD level when the
    description has exactly three elements (otherwise 0). The layers are
    returned as a list in the same order as input_fields.
    """

    def _declare(field):
        desc = input_descs[field]
        return layers.data(
            name=field,
            shape=desc[0],
            dtype=desc[1],
            lod_level=desc[2] if len(desc) == 3 else 0,
            append_batch_size=False)

    return [_declare(field) for field in input_fields]
405 406


Y
ying 已提交
407 408 409 410 411 412 413 414 415 416
def transformer(
        src_vocab_size,
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        label_smooth_eps, ):
    """
    Build the training network: encoder, decoder and the weighted loss.

    Returns the tuple (sum_cost, avg_cost, predict, token_num), where
    sum_cost is the weighted cost summed over all elements, token_num is the
    sum of the weights, avg_cost is their ratio and predict is the output of
    wrap_decoder.
    """
    # Arguments shared by the encoder and decoder wrappers.
    stack_args = (max_length, n_layer, n_head, d_key, d_value, d_model,
                  d_inner_hid, dropout_rate)

    enc_inputs = make_all_inputs(encoder_data_input_fields +
                                 encoder_util_input_fields)
    enc_output = wrap_encoder(src_vocab_size, *stack_args,
                              enc_inputs=enc_inputs)

    # The last decoder data field is left out here; the encoder output built
    # above is passed to wrap_decoder directly.
    dec_inputs = make_all_inputs(decoder_data_input_fields[:-1] +
                                 decoder_util_input_fields)
    predict = wrap_decoder(
        trg_vocab_size,
        *stack_args,
        dec_inputs=dec_inputs,
        enc_output=enc_output)

    # Padding index do not contribute to the total loss. The weights is used to
    # cancel padding index in calculating the loss.
    label, weights = make_all_inputs(label_data_input_fields)
    if label_smooth_eps:
        one_hot_label = layers.one_hot(input=label, depth=trg_vocab_size)
        label = layers.label_smooth(
            label=one_hot_label, epsilon=label_smooth_eps)
    # Soft labels are only used when label smoothing is enabled.
    cost = layers.softmax_with_cross_entropy(
        logits=predict, label=label, soft_label=bool(label_smooth_eps))

    sum_cost = layers.reduce_sum(cost * weights)
    token_num = layers.reduce_sum(weights)
    avg_cost = sum_cost / token_num
    return sum_cost, avg_cost, predict, token_num
467 468 469 470 471 472 473 474 475 476 477


def wrap_encoder(src_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 enc_inputs=None):
    """
    Assemble all layers needed for the encoder and return its output.

    enc_inputs must unpack into six values: source words, source positions,
    self-attention bias, source data shape and the pre-/post-softmax shapes
    for self-attention. When it is None, the input layers are declared here
    via make_all_inputs.
    """
    if enc_inputs is None:
        # This is used to implement independent encoder program in inference.
        enc_inputs = make_all_inputs(encoder_data_input_fields +
                                     encoder_util_input_fields)

    (src_word, src_pos, src_slf_attn_bias, src_data_shape,
     slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape) = enc_inputs

    embedded = prepare_encoder(
        src_word,
        src_pos,
        src_vocab_size,
        d_model,
        max_length,
        dropout_rate,
        src_data_shape)

    return encoder(
        embedded,
        src_slf_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape)


def wrap_decoder(trg_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 dec_inputs=None,
                 enc_output=None,
                 caches=None):
    """
    Assemble all layers needed for the decoder and return its prediction.

    When dec_inputs is given it must unpack into nine values and enc_output
    is taken from the argument. When dec_inputs is None, ten input layers
    are declared here (the fifth one replaces enc_output) and softmax is
    applied to the result; otherwise the result has no activation.
    The result is reshaped to [-1, trg_vocab_size].
    """
    declare_inputs = dec_inputs is None
    if declare_inputs:
        # This is used to implement independent decoder program in inference.
        (trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output,
         trg_data_shape, slf_attn_pre_softmax_shape,
         slf_attn_post_softmax_shape, src_attn_pre_softmax_shape,
         src_attn_post_softmax_shape) = make_all_inputs(
             decoder_data_input_fields + decoder_util_input_fields)
    else:
        (trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
         trg_data_shape, slf_attn_pre_softmax_shape,
         slf_attn_post_softmax_shape, src_attn_pre_softmax_shape,
         src_attn_post_softmax_shape) = dec_inputs

    embedded = prepare_decoder(
        trg_word,
        trg_pos,
        trg_vocab_size,
        d_model,
        max_length,
        dropout_rate,
        trg_data_shape)

    dec_output = decoder(
        embedded,
        enc_output,
        trg_slf_attn_bias,
        trg_src_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape,
        src_attn_pre_softmax_shape,
        src_attn_post_softmax_shape,
        caches)

    # Softmax is applied only when the inputs were declared here; otherwise
    # the raw projection is returned.
    vocab_proj = layers.fc(input=dec_output,
                           size=trg_vocab_size,
                           bias_attr=False,
                           num_flatten_dims=2)
    return layers.reshape(
        x=vocab_proj,
        shape=[-1, trg_vocab_size],
        act="softmax" if declare_inputs else None)
577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595


def fast_decode(
        src_vocab_size,
        trg_vocab_size,
        max_in_len,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        beam_size,
        max_out_len,
        eos_idx, ):
    """
    Build the inference network: an encoder followed by a beam-search
    decoding loop that reuses cached keys/values across steps.

    The encoder is built with wrap_encoder (declaring its own inputs), the
    decoding inputs are declared with make_all_inputs, and a While block
    runs wrap_decoder once per step, picks candidates with layers.topk and
    layers.beam_search, and records ids/scores in tensor arrays.

    Returns:
        The (finished_ids, finished_scores) pair produced by
        layers.beam_search_decode.
    """
    enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
                              d_key, d_value, d_model, d_inner_hid,
                              dropout_rate)
    start_tokens, init_scores, trg_src_attn_bias, trg_data_shape, \
        slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
        src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
        attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta = \
        make_all_inputs(fast_decoder_data_input_fields +
                            fast_decoder_util_input_fields)

    def beam_search():
        # Loop bound and step counter, both with the dtype of start_tokens.
        max_len = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=max_out_len)
        step_idx = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=0)
        # cond = layers.fill_constant(
        #     shape=[1], dtype='bool', value=1, force_cpu=True)
        # Loop condition; it is rewritten at the end of every iteration.
        cond = layers.less_than(x=step_idx, y=max_len)
        while_op = layers.While(cond)
        # init_scores = layers.fill_constant_batch_size_like(
        #     input=start_tokens, shape=[-1, 1], dtype="float32", value=0)
        # array states
        ids = layers.array_write(start_tokens, step_idx)
        scores = layers.array_write(init_scores, step_idx)
        # cell states (can be overwritten)
        # One {"k", "v"} cache per layer, created with size 0 along axis 1
        # and d_model along the last axis.
        # NOTE(review): the last axis presumably has to match
        # n_head * d_key / n_head * d_value of the projections concatenated
        # in multi_head_attention -- confirm.
        caches = [{
            "k": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_model],
                dtype=enc_output.dtype,
                value=0),
            "v": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_model],
                dtype=enc_output.dtype,
                value=0)
        } for i in range(n_layer)]
        with while_op.block():
            # Ids and scores written by the previous step.
            pre_ids = layers.array_read(array=ids, i=step_idx)
            pre_scores = layers.array_read(array=scores, i=step_idx)
            # A column of ones multiplied by (step_idx + 1); it is passed to
            # wrap_decoder as the position input. step_idx itself is not
            # modified here (in_place=False).
            pre_pos = layers.elementwise_mul(
                x=layers.fill_constant_batch_size_like(
                    input=pre_ids, value=1, shape=[-1, 1], dtype=pre_ids.dtype),
                y=layers.increment(
                    x=step_idx, value=1.0, in_place=False),
                axis=0)
            # Expand the source-side states and the caches against
            # pre_scores (presumably so that they line up with the beams
            # selected in the previous step -- confirm).
            pre_src_attn_bias = layers.sequence_expand(
                x=trg_src_attn_bias, y=pre_scores)
            pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
            pre_caches = [{
                "k": layers.sequence_expand(
                    x=cache["k"], y=pre_scores),
                "v": layers.sequence_expand(
                    x=cache["v"], y=pre_scores),
            } for cache in caches]
            # NOTE(review): this Print looks like leftover debugging output;
            # the commented-out Prints below are kept as found.
            layers.Print(pre_ids)
            # layers.Print(pre_enc_output)
            # layers.Print(pre_src_attn_bias)
            # layers.Print(pre_caches[0]["k"])
            # layers.Print(pre_caches[0]["v"])
            # layers.Print(slf_attn_post_softmax_shape)
            # dec_inputs is given, so wrap_decoder applies no softmax; the
            # self-attention bias is None.
            logits = wrap_decoder(
                trg_vocab_size,
                max_in_len,
                n_layer,
                n_head,
                d_key,
                d_value,
                d_model,
                d_inner_hid,
                dropout_rate,
                dec_inputs=(
                    pre_ids, pre_pos, None, pre_src_attn_bias, trg_data_shape,
                    slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
                    src_attn_pre_softmax_shape, src_attn_post_softmax_shape),
                enc_output=pre_enc_output,
                caches=pre_caches)
            topk_scores, topk_indices = layers.topk(logits, k=beam_size)
            # Accumulated score = log(softmax over the top-k logits) plus the
            # previous score.
            # NOTE(review): the softmax is taken over the k selected logits
            # only, not over the whole vocabulary -- confirm this is intended.
            accu_scores = layers.elementwise_add(
                x=layers.log(x=layers.softmax(topk_scores)),
                y=layers.reshape(
                    pre_scores, shape=[-1]),
                axis=0)
            # beam_search op uses lod to distinguish branches.
            topk_indices = layers.lod_reset(topk_indices, pre_ids)
            selected_ids, selected_scores = layers.beam_search(
                pre_ids=pre_ids,
                ids=topk_indices,
                scores=accu_scores,
                beam_size=beam_size,
                end_id=eos_idx)
            # Advance the step counter in place before writing the new state.
            layers.increment(x=step_idx, value=1.0, in_place=True)
            # update states
            layers.array_write(selected_ids, i=step_idx, array=ids)
            layers.array_write(selected_scores, i=step_idx, array=scores)
            # Copy the expanded states back into the variables read at the
            # start of the next iteration.
            layers.assign(pre_src_attn_bias, trg_src_attn_bias)
            layers.assign(pre_enc_output, enc_output)
            for i in range(n_layer):
                layers.assign(pre_caches[i]["k"], caches[i]["k"])
                layers.assign(pre_caches[i]["v"], caches[i]["v"])
            # Add the per-step deltas to the self-attention softmax shapes.
            layers.assign(
                slf_attn_pre_softmax_shape + attn_pre_softmax_shape_delta,
                slf_attn_pre_softmax_shape)
            layers.assign(
                layers.elementwise_add(
                    x=slf_attn_post_softmax_shape,
                    y=attn_post_softmax_shape_delta),
                slf_attn_post_softmax_shape)

            # NOTE(review): all_finish_cond is computed exactly like
            # max_len_cond, so the logical_or below only checks the maximum
            # length; nothing here stops the loop early once every beam has
            # finished.
            max_len_cond = layers.less_than(x=step_idx, y=max_len)
            all_finish_cond = layers.less_than(x=step_idx, y=max_len)
            layers.logical_or(x=max_len_cond, y=all_finish_cond, out=cond)

        finished_ids, finished_scores = layers.beam_search_decode(ids, scores)
        return finished_ids, finished_scores

    finished_ids, finished_scores = beam_search()
    return finished_ids, finished_scores