from functools import partial
import numpy as np

import paddle.fluid as fluid
import paddle.fluid.layers as layers

from config import *


def position_encoding_init(n_position, d_pos_vec):
    """
    Generate the initial values for the sinusoid position encoding table.
    """
    position_enc = np.array([[
        pos / np.power(10000, 2 * (j // 2) / d_pos_vec)
        for j in range(d_pos_vec)
    ] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
    return position_enc.astype("float32")
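
# A minimal usage sketch (illustrative only, not executed as part of the
# model): the table has one row per position and alternates sin/cos over the
# feature dimension, i.e.
#     PE(pos, 2i)   = sin(pos / 10000^(2i/d)),
#     PE(pos, 2i+1) = cos(pos / 10000^(2i/d)).
#
#     pe = position_encoding_init(n_position=8, d_pos_vec=4)
#     assert pe.shape == (8, 4)      # one encoding vector per position
#     assert np.allclose(pe[0], 0.)  # row 0 is reserved for padding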


def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         n_head=1,
                         dropout_rate=0.,
                         pre_softmax_shape=None,
                         post_softmax_shape=None,
                         cache=None):
    """
    Multi-Head Attention. Note that attn_bias is added to the logit before
    computing softmax activation to mask certain selected positions so that
    they will not be considered in attention weights.
    """
    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: quries, keys and values should all be 3-D tensors.")

    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
                      size=d_key * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_key,
                          fan_out=n_head * d_key),
                      bias_attr=False,
                      num_flatten_dims=2)
        k = layers.fc(input=keys,
                      size=d_key * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_key,
                          fan_out=n_head * d_key),
                      bias_attr=False,
                      num_flatten_dims=2)
        v = layers.fc(input=values,
                      size=d_value * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_value,
                          fan_out=n_head * d_value),
                      bias_attr=False,
                      num_flatten_dims=2)
        return q, k, v

    def __split_heads(x, n_head):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions and then transpose. Specifically, input a tensor with shape
        [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
        with shape [bs, n_head, max_sequence_length, hidden_dim].
        """
        if n_head == 1:
            return x

        hidden_size = x.shape[-1]
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        reshaped = layers.reshape(
            x=x, shape=[0, 0, n_head, hidden_size // n_head])

        # permute the dimensions into:
        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that it becomes one dimension, which is reverse to __split_heads.
        """
        if len(x.shape) == 3: return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        return layers.reshape(
            x=trans_x,
            shape=list(map(int, [0, 0, trans_x.shape[2] * trans_x.shape[3]])))

    def scaled_dot_product_attention(q, k, v, attn_bias, d_model,
                                     dropout_rate):
        """
        Scaled Dot-Product Attention
        """
        scaled_q = layers.scale(x=q, scale=d_model**-0.5)
        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
        weights = layers.reshape(
            x=layers.elementwise_add(
                x=product, y=attn_bias) if attn_bias is not None else product,
            shape=[-1, product.shape[-1]],
            actual_shape=pre_softmax_shape,
            act="softmax")
        weights = layers.reshape(
            x=weights, shape=product.shape, actual_shape=post_softmax_shape)
        if dropout_rate:
            weights = layers.dropout(
                weights, dropout_prob=dropout_rate, is_test=False)

        out = layers.matmul(weights, v)
        return out
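
    # In formula form the helper above computes
    #     Attention(Q, K, V) = softmax(Q K^T / sqrt(d_model) + attn_bias) V;
    # the reshapes around the softmax only flatten the score matrix so that
    # the softmax normalizes over the key dimension.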

    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

    if cache is not None:  # use cache and concat time steps
        k = cache["k"] = layers.concat([cache["k"], k], axis=1)
        v = cache["v"] = layers.concat([cache["v"], v], axis=1)
    q = __split_heads(q, n_head)
    k = __split_heads(k, n_head)
    v = __split_heads(v, n_head)

    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model,
                                                  dropout_rate)

    out = __combine_heads(ctx_multiheads)
    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         param_attr=fluid.initializer.Xavier(uniform=False),
                         bias_attr=False,
                         num_flatten_dims=2)
    return proj_out
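
# A minimal self-attention usage sketch (names and the d_key/d_value choice
# are illustrative): queries, keys and values are the same tensor, so every
# position attends over the whole input sequence:
#
#     ctx = multi_head_attention(
#         queries=enc_input, keys=enc_input, values=enc_input,
#         attn_bias=slf_attn_bias, d_key=d_model // n_head,
#         d_value=d_model // n_head, d_model=d_model, n_head=n_head,
#         dropout_rate=0.1)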


def positionwise_feed_forward(x, d_inner_hid, d_hid):
    """
    Position-wise Feed-Forward Networks.
    This module consists of two linear transformations with a ReLU activation
    in between, which is applied to each position separately and identically.
    """
    hidden = layers.fc(input=x,
                       size=d_inner_hid,
                       num_flatten_dims=2,
                       param_attr=fluid.initializer.Uniform(
                           low=-(d_hid**-0.5), high=(d_hid**-0.5)),
                       act="relu")
    out = layers.fc(input=hidden,
                    size=d_hid,
                    num_flatten_dims=2,
                    param_attr=fluid.initializer.Uniform(
                        low=-(d_inner_hid**-0.5), high=(d_inner_hid**-0.5)))
    return out
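
# In formula form the two fc layers above compute
#     FFN(x) = max(0, x W1 + b1) W2 + b2,
# applied to every position independently; d_inner_hid is the width of the
# hidden layer and d_hid the model (output) width.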


def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
    """
    Add residual connection, layer normalization and dropout to the out tensor
    optionally according to the value of process_cmd.
    This will be used before or after multi-head attention and position-wise
    feed-forward networks.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            out = out + prev_out if prev_out is not None else out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.initializer.Constant(1.),
                bias_attr=fluid.initializer.Constant(0.))
        elif cmd == "d":  # add dropout
            if dropout_rate:
                out = layers.dropout(
                    out, dropout_prob=dropout_rate, is_test=False)
    return out


pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer
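
# A usage sketch: the command string is interpreted left to right, so "dan"
# applies dropout to the sub-layer output, adds the residual input, and then
# layer-normalizes, i.e. LayerNorm(x + Dropout(sublayer_out)):
#
#     out = post_process_layer(x, sublayer_out, "dan", dropout_rate=0.1)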


def prepare_encoder(src_word,
                    src_pos,
                    src_vocab_size,
                    src_emb_dim,
                    src_max_len,
                    dropout_rate=0.,
                    src_data_shape=None,
                    pos_enc_param_name=None):
    """Add word embeddings and position encodings.
    The output tensor has a shape of:
    [batch_size, max_src_length_in_batch, d_model].
    This module is used at the bottom of the encoder stacks.
    """
    src_word_emb = layers.embedding(
        src_word,
        size=[src_vocab_size, src_emb_dim],
        param_attr=fluid.initializer.Normal(0., 1.))
    src_pos_enc = layers.embedding(
        src_pos,
        size=[src_max_len, src_emb_dim],
        param_attr=fluid.ParamAttr(
            name=pos_enc_param_name, trainable=False))
    enc_input = src_word_emb + src_pos_enc
    enc_input = layers.reshape(
        x=enc_input,
        shape=[batch_size, seq_len, src_emb_dim],
        actual_shape=src_data_shape)
    return layers.dropout(
        enc_input, dropout_prob=dropout_rate,
        is_test=False) if dropout_rate else enc_input


prepare_encoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[0])
prepare_decoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[1])
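
# A usage sketch (mirroring the call made later in wrap_encoder): the result
# is the word embeddings plus the fixed position encodings, with dropout
# applied, of shape [batch_size, max_src_length_in_batch, d_model]:
#
#     enc_input = prepare_encoder(src_word, src_pos, src_vocab_size, d_model,
#                                 max_length, dropout_rate, src_data_shape)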


def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.,
                  pre_softmax_shape=None,
                  post_softmax_shape=None):
    """The encoder layers that can be stacked to form a deep encoder.
    This module consits of a multi-head (self) attention followed by
    position-wise feed-forward networks and both the two components companied
    with the post_process_layer to add residual connection, layer normalization
    and droput.
    """
    attn_output = multi_head_attention(
        enc_input, enc_input, enc_input, attn_bias, d_key, d_value, d_model,
        n_head, dropout_rate, pre_softmax_shape, post_softmax_shape)
    attn_output = post_process_layer(enc_input, attn_output, "dan",
                                     dropout_rate)
    ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
    return post_process_layer(attn_output, ffd_output, "dan", dropout_rate)


def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.,
            pre_softmax_shape=None,
            post_softmax_shape=None):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer.
    """
    for i in range(n_layer):
        enc_output = encoder_layer(
            enc_input,
            attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate,
            pre_softmax_shape,
            post_softmax_shape, )
        enc_input = enc_output
    return enc_output


def decoder_layer(dec_input,
                  enc_output,
                  slf_attn_bias,
                  dec_enc_attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.,
                  slf_attn_pre_softmax_shape=None,
                  slf_attn_post_softmax_shape=None,
                  src_attn_pre_softmax_shape=None,
                  src_attn_post_softmax_shape=None,
                  cache=None):
    """ The layer to be stacked in decoder part.
    The structure of this module is similar to that in the encoder part except
    a multi-head attention is added to implement encoder-decoder attention.
    """
    slf_attn_output = multi_head_attention(
        dec_input,
        dec_input,
        dec_input,
        slf_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate,
        slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape,
        cache, )
    slf_attn_output = post_process_layer(
        dec_input,
        slf_attn_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    enc_attn_output = multi_head_attention(
        slf_attn_output,
        enc_output,
        enc_output,
        dec_enc_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate,
        src_attn_pre_softmax_shape,
        src_attn_post_softmax_shape, )
    enc_attn_output = post_process_layer(
        slf_attn_output,
        enc_attn_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    ffd_output = positionwise_feed_forward(
        enc_attn_output,
        d_inner_hid,
        d_model, )
    dec_output = post_process_layer(
        enc_attn_output,
        ffd_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    return dec_output


def decoder(dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.,
            slf_attn_pre_softmax_shape=None,
            slf_attn_post_softmax_shape=None,
            src_attn_pre_softmax_shape=None,
            src_attn_post_softmax_shape=None,
            caches=None):
    """
    The decoder is composed of a stack of identical decoder_layer layers.
    """
    for i in range(n_layer):
        dec_output = decoder_layer(
            dec_input, enc_output, dec_slf_attn_bias, dec_enc_attn_bias, n_head,
            d_key, d_value, d_model, d_inner_hid, dropout_rate,
            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
            src_attn_pre_softmax_shape, src_attn_post_softmax_shape, None
            if caches is None else caches[i])
        dec_input = dec_output
    return dec_output


def make_all_inputs(input_fields):
    """
    Define the input data layers for the transformer model.
    """
    inputs = []
    for input_field in input_fields:
        input_var = layers.data(
            name=input_field,
            shape=input_descs[input_field][0],
            dtype=input_descs[input_field][1],
            lod_level=input_descs[input_field][2]
            if len(input_descs[input_field]) == 3 else 0,
            append_batch_size=False)
        inputs.append(input_var)
    return inputs
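
# A usage sketch (hypothetical unpacking; the field tuples live in config.py):
# each entry of input_descs maps a field name to (shape, dtype) or
# (shape, dtype, lod_level), so a whole group of data layers can be created
# in one call, e.g.
#
#     src_word, src_pos, src_slf_attn_bias = make_all_inputs(
#         encoder_data_input_fields)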


def transformer(
        src_vocab_size,
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        label_smooth_eps, ):
    enc_inputs = make_all_inputs(encoder_data_input_fields +
                                 encoder_util_input_fields)

    enc_output = wrap_encoder(
        src_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        enc_inputs, )

    dec_inputs = make_all_inputs(decoder_data_input_fields[:-1] +
                                 decoder_util_input_fields)

    predict = wrap_decoder(
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        dec_inputs,
        enc_output, )

    # The padding index does not contribute to the total loss. The weights are
    # used to cancel the padding index when calculating the loss.
    label, weights = make_all_inputs(label_data_input_fields)
    if label_smooth_eps:
        label = layers.label_smooth(
            label=layers.one_hot(
                input=label, depth=trg_vocab_size),
            epsilon=label_smooth_eps)
    cost = layers.softmax_with_cross_entropy(
        logits=predict,
        label=label,
        soft_label=True if label_smooth_eps else False)
    weighted_cost = cost * weights
    sum_cost = layers.reduce_sum(weighted_cost)
    token_num = layers.reduce_sum(weights)
    avg_cost = sum_cost / token_num
    return sum_cost, avg_cost, predict, token_num
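
# Note on the returned loss terms: sum_cost is the total (optionally
# label-smoothed) cross-entropy over non-padding target tokens, token_num
# counts those tokens, and avg_cost = sum_cost / token_num is the per-token
# loss that is typically monitored during training.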


def wrap_encoder(src_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 enc_inputs=None):
    """
    The wrapper assembles together all needed layers for the encoder.
    """
    if enc_inputs is None:
        # This is used to build an independent encoder program for inference.
        src_word, src_pos, src_slf_attn_bias, src_data_shape, \
            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
            make_all_inputs(encoder_data_input_fields +
                                 encoder_util_input_fields)
    else:
        src_word, src_pos, src_slf_attn_bias, src_data_shape, \
            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
            enc_inputs
    enc_input = prepare_encoder(
        src_word,
        src_pos,
        src_vocab_size,
        d_model,
        max_length,
        dropout_rate,
        src_data_shape, )
    enc_output = encoder(
        enc_input,
        src_slf_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape, )
    return enc_output


def wrap_decoder(trg_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 dec_inputs=None,
                 enc_output=None,
                 caches=None):
    """
    The wrapper assembles together all needed layers for the decoder.
    """
    if dec_inputs is None:
        # This is used to build an independent decoder program for inference.
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
            enc_output, trg_data_shape, slf_attn_pre_softmax_shape, \
            slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
            src_attn_post_softmax_shape = make_all_inputs(
            decoder_data_input_fields + decoder_util_input_fields)
    else:
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
            trg_data_shape, slf_attn_pre_softmax_shape, \
            slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
            src_attn_post_softmax_shape = dec_inputs

    dec_input = prepare_decoder(
        trg_word,
        trg_pos,
        trg_vocab_size,
        d_model,
        max_length,
        dropout_rate,
        trg_data_shape, )
    dec_output = decoder(
        dec_input,
        enc_output,
        trg_slf_attn_bias,
        trg_src_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape,
        src_attn_pre_softmax_shape,
        src_attn_post_softmax_shape,
        caches, )
    # Return logits for training and probs for inference.
    predict = layers.reshape(
        x=layers.fc(input=dec_output,
                    size=trg_vocab_size,
                    bias_attr=False,
                    num_flatten_dims=2),
        shape=[-1, trg_vocab_size],
        act="softmax" if dec_inputs is None else None)
    return predict


def fast_decode(
        src_vocab_size,
        trg_vocab_size,
        max_in_len,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        beam_size,
        max_out_len,
        eos_idx, ):
    enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
                              d_key, d_value, d_model, d_inner_hid,
                              dropout_rate)
    start_tokens, init_scores, trg_src_attn_bias, trg_data_shape, \
        slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
        src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
        attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta = \
        make_all_inputs(fast_decoder_data_input_fields +
                            fast_decoder_util_input_fields)

    def beam_search():
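        # Decoding state kept across loop steps (a sketch of the layout):
        #   * ids / scores are tensor arrays indexed by step_idx, holding the
        #     selected token ids and accumulated log-probabilities per beam;
        #   * caches holds, per decoder layer, the concatenated keys/values of
        #     previously decoded positions so self-attention over the decoded
        #     prefix is not recomputed from scratch at every step;
        #   * sequence_expand replicates the encoder output, attention bias
        #     and caches to follow the beams selected at the previous step.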
        max_len = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=max_out_len)
        step_idx = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=0)
        cond = layers.less_than(x=step_idx, y=max_len)
        while_op = layers.While(cond)
        # array states
        ids = layers.array_write(start_tokens, step_idx)
        scores = layers.array_write(init_scores, step_idx)
        # cell states (can be overwritten)
        caches = [{
            "k": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_model],
                dtype=enc_output.dtype,
                value=0),
            "v": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_model],
                dtype=enc_output.dtype,
                value=0)
        } for i in range(n_layer)]
        with while_op.block():
            pre_ids = layers.array_read(array=ids, i=step_idx)
            pre_scores = layers.array_read(array=scores, i=step_idx)
            pre_pos = layers.elementwise_mul(
                x=layers.fill_constant_batch_size_like(
                    input=pre_ids, value=1, shape=[-1, 1], dtype=pre_ids.dtype),
                y=layers.increment(
                    x=step_idx, value=1.0, in_place=False),
                axis=0)
            pre_src_attn_bias = layers.sequence_expand(
                x=trg_src_attn_bias, y=pre_scores)
            pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
            pre_caches = [{
                "k": layers.sequence_expand(
                    x=cache["k"], y=pre_scores),
                "v": layers.sequence_expand(
                    x=cache["v"], y=pre_scores),
            } for cache in caches]
            logits = wrap_decoder(
                trg_vocab_size,
                max_in_len,
                n_layer,
                n_head,
                d_key,
                d_value,
                d_model,
                d_inner_hid,
                dropout_rate,
                dec_inputs=(
                    pre_ids, pre_pos, None, pre_src_attn_bias, trg_data_shape,
                    slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
                    src_attn_pre_softmax_shape, src_attn_post_softmax_shape),
                enc_output=pre_enc_output,
                caches=pre_caches)
            # wrap_decoder returns unnormalized logits when dec_inputs is
            # given, so convert them to probabilities over the full target
            # vocabulary before taking the top-k candidates and accumulating
            # log-probabilities.
            topk_scores, topk_indices = layers.topk(
                layers.softmax(logits), k=beam_size)
            accu_scores = layers.elementwise_add(
                x=layers.log(topk_scores),
                y=layers.reshape(
                    pre_scores, shape=[-1]),
                axis=0)
            # beam_search op uses lod to distinguish branches.
            topk_indices = layers.lod_reset(topk_indices, pre_ids)
            selected_ids, selected_scores = layers.beam_search(
                pre_ids=pre_ids,
                ids=topk_indices,
                scores=accu_scores,
                beam_size=beam_size,
                end_id=eos_idx)
            layers.increment(x=step_idx, value=1.0, in_place=True)
            # update states
            layers.array_write(selected_ids, i=step_idx, array=ids)
            layers.array_write(selected_scores, i=step_idx, array=scores)
            layers.assign(pre_src_attn_bias, trg_src_attn_bias)
            layers.assign(pre_enc_output, enc_output)
            for i in range(n_layer):
                layers.assign(pre_caches[i]["k"], caches[i]["k"])
                layers.assign(pre_caches[i]["v"], caches[i]["v"])
            layers.assign(
                layers.elementwise_add(
                    x=slf_attn_pre_softmax_shape,
                    y=attn_pre_softmax_shape_delta),
                slf_attn_pre_softmax_shape)
            layers.assign(
                layers.elementwise_add(
                    x=slf_attn_post_softmax_shape,
                    y=attn_post_softmax_shape_delta),
                slf_attn_post_softmax_shape)

            # Keep looping while the maximum length has not been reached and
            # there are still unfinished beams; beam_search returns an empty
            # selected_ids tensor once every beam has produced the end token.
            max_len_cond = layers.less_than(x=step_idx, y=max_len)
            unfinished_cond = layers.logical_not(
                layers.is_empty(x=selected_ids))
            layers.logical_and(x=max_len_cond, y=unfinished_cond, out=cond)

        finished_ids, finished_scores = layers.beam_search_decode(ids, scores,
                                                                  eos_idx)
        return finished_ids, finished_scores

    finished_ids, finished_scores = beam_search()
    return finished_ids, finished_scores