from functools import partial
import numpy as np

import paddle.fluid as fluid
import paddle.fluid.layers as layers

from config import *


def position_encoding_init(n_position, d_pos_vec):
    """
    Generate the initial values for the sinusoid position encoding table.
    """
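    # Sinusoid position encoding as in "Attention Is All You Need": the
    # timescales form a geometric progression from 1 to 1e4; the first half of
    # the channels stores sin(position * inv_timescale) and the second half
    # stores cos(...). Concatenating sin/cos (rather than interleaving them)
    # only permutes the channels of the table.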
    channels = d_pos_vec
    position = np.arange(n_position)
    num_timescales = channels // 2
    log_timescale_increment = (np.log(float(1e4) / float(1)) /
                               (num_timescales - 1))
    inv_timescales = np.exp(np.arange(
        num_timescales)) * -log_timescale_increment
    scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales,
                                                               0)
    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant')
    position_enc = signal
    return position_enc.astype("float32")


def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         n_head=1,
                         dropout_rate=0.,
                         cache=None):
    """
    Multi-Head Attention. Note that attn_bias is added to the logit before
    computing softmax activation to mask certain selected positions so that
    they will not be considered in attention weights.
    """
    keys = queries if keys is None else keys
    values = keys if values is None else values

    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: queries, keys and values should all be 3-D tensors.")

    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
                      size=d_key * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        k = layers.fc(input=keys,
                      size=d_key * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        v = layers.fc(input=values,
                      size=d_value * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        return q, k, v

    def __split_heads(x, n_head):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions and then transpose. Specifically, input a tensor with shape
        [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
        with shape [bs, n_head, max_sequence_length, hidden_dim].
        """
        if n_head == 1:
            return x

        hidden_size = x.shape[-1]
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        reshaped = layers.reshape(
            x=x, shape=[0, 0, n_head, hidden_size // n_head])

        # permute the dimensions into:
        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that it becomes one dimension, which is reverse to __split_heads.
        """
        if len(x.shape) == 3: return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        return layers.reshape(
            x=trans_x, shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]])

    def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
        """
        Scaled Dot-Product Attention
        """
        scaled_q = layers.scale(x=q, scale=d_key**-0.5)
        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
        if attn_bias:
            product += attn_bias
        weights = layers.softmax(product)
        if dropout_rate:
            weights = layers.dropout(
                weights,
                dropout_prob=dropout_rate,
                seed=ModelHyperParams.dropout_seed,
                is_test=False)
        out = layers.matmul(weights, v)
        return out

    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

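    # For incremental (fast) decoding, keys/values projected at the current
    # step are appended to those cached from previous steps along the time
    # dimension, so earlier steps need not be recomputed.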
    if cache is not None:  # use cache and concat time steps
        k = cache["k"] = layers.concat([cache["k"], k], axis=1)
        v = cache["v"] = layers.concat([cache["v"], v], axis=1)

    q = __split_heads(q, n_head)
    k = __split_heads(k, n_head)
    v = __split_heads(v, n_head)

    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key,
                                                  dropout_rate)

    out = __combine_heads(ctx_multiheads)

    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         bias_attr=False,
                         num_flatten_dims=2)
    return proj_out


def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate):
    """
    Position-wise Feed-Forward Networks.
    This module consists of two linear transformations with a ReLU activation
    in between, which is applied to each position separately and identically.
    """
    hidden = layers.fc(input=x,
                       size=d_inner_hid,
                       num_flatten_dims=2,
                       act="relu")
    if dropout_rate:
        hidden = layers.dropout(
            hidden,
            dropout_prob=dropout_rate,
            seed=ModelHyperParams.dropout_seed,
            is_test=False)
    out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
    return out


def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
    """
    Add residual connection, layer normalization and dropout to the out tensor
    optionally according to the value of process_cmd.
    This will be used before or after multi-head attention and position-wise
    feed-forward networks.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            out = out + prev_out if prev_out else out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.initializer.Constant(1.),
                bias_attr=fluid.initializer.Constant(0.))
        elif cmd == "d":  # add dropout
            if dropout_rate:
                out = layers.dropout(
                    out,
                    dropout_prob=dropout_rate,
                    seed=ModelHyperParams.dropout_seed,
                    is_test=False)
    return out


pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer


def prepare_encoder_decoder(src_word,
                            src_pos,
                            src_vocab_size,
                            src_emb_dim,
                            src_max_len,
                            dropout_rate=0.,
                            word_emb_param_name=None,
                            pos_enc_param_name=None):
    """Add word embeddings and position encodings.
    The output tensor has a shape of:
    [batch_size, max_src_length_in_batch, d_model].
    This module is used at the bottom of the encoder and decoder stacks.
    """
    src_word_emb = layers.embedding(
        src_word,
        size=[src_vocab_size, src_emb_dim],
        padding_idx=ModelHyperParams.bos_idx,  # set embedding of bos to 0
        param_attr=fluid.ParamAttr(
            name=word_emb_param_name,
            initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))

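    # As in the Transformer paper, word embeddings are scaled by
    # sqrt(emb_dim) before the (non-trainable) position encodings are added.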
    src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
    src_pos_enc = layers.embedding(
        src_pos,
        size=[src_max_len, src_emb_dim],
        param_attr=fluid.ParamAttr(
            name=pos_enc_param_name, trainable=False))
    enc_input = src_word_emb + src_pos_enc
    return layers.dropout(
        enc_input,
        dropout_prob=dropout_rate,
        seed=ModelHyperParams.dropout_seed,
        is_test=False) if dropout_rate else enc_input


prepare_encoder = partial(
    prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[0])
prepare_decoder = partial(
    prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[1])


def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  prepostprocess_dropout,
                  attention_dropout,
                  relu_dropout,
                  preprocess_cmd="n",
                  postprocess_cmd="da"):
    """The encoder layers that can be stacked to form a deep encoder.
    This module consists of a multi-head (self) attention followed by
    position-wise feed-forward networks, and both components are accompanied
    by post_process_layer to add residual connection, layer normalization
    and dropout.
    """
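    # With the default commands ("n" before, "da" after), each sub-layer
    # computes y = x + Dropout(SubLayer(LayerNorm(x))), i.e. the pre-norm
    # form of the residual block.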
    attn_output = multi_head_attention(
        pre_process_layer(enc_input, preprocess_cmd,
                          prepostprocess_dropout), None, None, attn_bias, d_key,
        d_value, d_model, n_head, attention_dropout)
    attn_output = post_process_layer(enc_input, attn_output, postprocess_cmd,
                                     prepostprocess_dropout)
    ffd_output = positionwise_feed_forward(
        pre_process_layer(attn_output, preprocess_cmd, prepostprocess_dropout),
        d_inner_hid, d_model, relu_dropout)
    return post_process_layer(attn_output, ffd_output, postprocess_cmd,
                              prepostprocess_dropout)


def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            preprocess_cmd="n",
            postprocess_cmd="da"):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer.
    """
    for i in range(n_layer):
        enc_output = encoder_layer(
            enc_input,
            attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            preprocess_cmd,
            postprocess_cmd, )
        enc_input = enc_output
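    # Since every sub-layer only pre-processes its input (preprocess_cmd is
    # "n" by default), apply one more normalization to the output of the last
    # encoder layer.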
    enc_output = pre_process_layer(enc_output, preprocess_cmd,
                                   prepostprocess_dropout)
    return enc_output


def decoder_layer(dec_input,
                  enc_output,
                  slf_attn_bias,
                  dec_enc_attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  prepostprocess_dropout,
                  attention_dropout,
                  relu_dropout,
                  preprocess_cmd,
                  postprocess_cmd,
                  cache=None):
    """ The layer to be stacked in the decoder part.
    The structure of this module is similar to that in the encoder part except
    that a multi-head attention is added to implement encoder-decoder attention.
    """
    slf_attn_output = multi_head_attention(
        pre_process_layer(dec_input, preprocess_cmd, prepostprocess_dropout),
        None,
        None,
        slf_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        attention_dropout,
        cache, )
    slf_attn_output = post_process_layer(
        dec_input,
        slf_attn_output,
        postprocess_cmd,
        prepostprocess_dropout, )
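    # Encoder-decoder attention: queries come from the decoder self-attention
    # output while keys and values come from the encoder output.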
    enc_attn_output = multi_head_attention(
        pre_process_layer(slf_attn_output, preprocess_cmd,
                          prepostprocess_dropout),
        enc_output,
        enc_output,
        dec_enc_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        attention_dropout, )
    enc_attn_output = post_process_layer(
        slf_attn_output,
        enc_attn_output,
        postprocess_cmd,
        prepostprocess_dropout, )
    ffd_output = positionwise_feed_forward(
        pre_process_layer(enc_attn_output, preprocess_cmd,
                          prepostprocess_dropout),
        d_inner_hid,
        d_model,
        relu_dropout, )
    dec_output = post_process_layer(
        enc_attn_output,
        ffd_output,
        postprocess_cmd,
        prepostprocess_dropout, )
    return dec_output


def decoder(dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            preprocess_cmd,
            postprocess_cmd,
            caches=None):
    """
    The decoder is composed of a stack of identical decoder_layer layers.
    """
    for i in range(n_layer):
        dec_output = decoder_layer(
            dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            preprocess_cmd,
            postprocess_cmd,
            cache=None if caches is None else caches[i])
        dec_input = dec_output
    dec_output = pre_process_layer(dec_output, preprocess_cmd,
                                   prepostprocess_dropout)
    return dec_output


def make_all_inputs(input_fields):
    """
    Define the input data layers for the transformer model.
    """
    inputs = []
    for input_field in input_fields:
        input_var = layers.data(
            name=input_field,
            shape=input_descs[input_field][0],
            dtype=input_descs[input_field][1],
            lod_level=input_descs[input_field][2]
            if len(input_descs[input_field]) == 3 else 0,
            append_batch_size=False)
        inputs.append(input_var)
    return inputs


def make_all_py_reader_inputs(input_fields, is_test=False):
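    """
    Define the input data layers through a py_reader, which feeds data
    asynchronously, and return the unpacked data layers together with the
    reader itself.
    """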
    reader = layers.py_reader(
        capacity=20,
        name="test_reader" if is_test else "train_reader",
        shapes=[input_descs[input_field][0] for input_field in input_fields],
        dtypes=[input_descs[input_field][1] for input_field in input_fields],
        lod_levels=[
            input_descs[input_field][2]
            if len(input_descs[input_field]) == 3 else 0
            for input_field in input_fields
        ])
    return layers.read_file(reader), reader


def transformer(src_vocab_size,
                trg_vocab_size,
                max_length,
                n_layer,
                n_head,
                d_key,
                d_value,
                d_model,
                d_inner_hid,
                prepostprocess_dropout,
                attention_dropout,
                relu_dropout,
                preprocess_cmd,
                postprocess_cmd,
                weight_sharing,
                label_smooth_eps,
                use_py_reader=False,
                is_test=False):
    if weight_sharing:
        assert src_vocab_size == trg_vocab_size, (
            "Vocabularies in source and target should be same for weight sharing."
        )

    data_input_names = encoder_data_input_fields + \
                decoder_data_input_fields[:-1] + label_data_input_fields

    if use_py_reader:
        all_inputs, reader = make_all_py_reader_inputs(data_input_names,
                                                       is_test)
    else:
        all_inputs = make_all_inputs(data_input_names)

    enc_inputs_len = len(encoder_data_input_fields)
    dec_inputs_len = len(decoder_data_input_fields[:-1])
    enc_inputs = all_inputs[0:enc_inputs_len]
    dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len]
    label = all_inputs[-2]
    weights = all_inputs[-1]

    enc_output = wrap_encoder(
        src_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        prepostprocess_dropout,
        attention_dropout,
        relu_dropout,
        preprocess_cmd,
        postprocess_cmd,
        weight_sharing,
        enc_inputs, )

    predict = wrap_decoder(
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        prepostprocess_dropout,
        attention_dropout,
        relu_dropout,
        preprocess_cmd,
        postprocess_cmd,
        weight_sharing,
        dec_inputs,
        enc_output, )

    # Padding indices do not contribute to the total loss. The weights are
    # used to cancel the effect of padding indices in calculating the loss.
    if label_smooth_eps:
        label = layers.label_smooth(
            label=layers.one_hot(
                input=label, depth=trg_vocab_size),
            epsilon=label_smooth_eps)

    cost = layers.softmax_with_cross_entropy(
        logits=layers.reshape(
            predict, shape=[-1, trg_vocab_size]),
        label=label,
        soft_label=True if label_smooth_eps else False)
    weighted_cost = cost * weights
    sum_cost = layers.reduce_sum(weighted_cost)
    token_num = layers.reduce_sum(weights)
    token_num.stop_gradient = True
    avg_cost = sum_cost / token_num
    return sum_cost, avg_cost, predict, token_num, reader if use_py_reader else None


def wrap_encoder(src_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 prepostprocess_dropout,
                 attention_dropout,
                 relu_dropout,
                 preprocess_cmd,
                 postprocess_cmd,
                 weight_sharing,
                 enc_inputs=None):
    """
    The wrapper assembles together all needed layers for the encoder.
    """
    if enc_inputs is None:
        # This is used to implement an independent encoder program in inference.
        src_word, src_pos, src_slf_attn_bias = make_all_inputs(
            encoder_data_input_fields)
    else:
        src_word, src_pos, src_slf_attn_bias = enc_inputs
    enc_input = prepare_encoder(
        src_word,
        src_pos,
        src_vocab_size,
        d_model,
        max_length,
        prepostprocess_dropout,
        word_emb_param_name=word_emb_param_names[0])
    enc_output = encoder(
        enc_input,
        src_slf_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        prepostprocess_dropout,
        attention_dropout,
        relu_dropout,
        preprocess_cmd,
        postprocess_cmd, )
    return enc_output


def wrap_decoder(trg_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 prepostprocess_dropout,
                 attention_dropout,
                 relu_dropout,
                 preprocess_cmd,
                 postprocess_cmd,
                 weight_sharing,
                 dec_inputs=None,
                 enc_output=None,
                 caches=None):
    """
    The wrapper assembles together all needed layers for the decoder.
    """
    if dec_inputs is None:
        # This is used to implement an independent decoder program in inference.
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = \
            make_all_inputs(decoder_data_input_fields)
    else:
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs

    dec_input = prepare_decoder(
        trg_word,
        trg_pos,
        trg_vocab_size,
        d_model,
        max_length,
        prepostprocess_dropout,
        word_emb_param_name=word_emb_param_names[0]
        if weight_sharing else word_emb_param_names[1])
    dec_output = decoder(
        dec_input,
        enc_output,
        trg_slf_attn_bias,
        trg_src_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        prepostprocess_dropout,
        attention_dropout,
        relu_dropout,
        preprocess_cmd,
        postprocess_cmd,
        caches=caches)
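    # When weight sharing is enabled, the output projection reuses the
    # (transposed) word embedding matrix instead of a separate FC layer.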
    if weight_sharing:
        predict = layers.matmul(
            x=dec_output,
            y=fluid.default_main_program().global_block().var(
                word_emb_param_names[0]),
            transpose_y=True)
    else:
        predict = layers.fc(input=dec_output,
                            size=trg_vocab_size,
                            bias_attr=False,
                            num_flatten_dims=2)
    if dec_inputs is None:
        # Return probs for independent decoder program.
        predict = layers.softmax(predict)
    return predict


def fast_decode(
        src_vocab_size,
        trg_vocab_size,
        max_in_len,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        prepostprocess_dropout,
        attention_dropout,
        relu_dropout,
        preprocess_cmd,
        postprocess_cmd,
        weight_sharing,
        beam_size,
        max_out_len,
        eos_idx, ):
    """
    Use beam search to decode. Caches are used to store the states of history
    steps, which makes decoding faster.
    """
    enc_output = wrap_encoder(
        src_vocab_size, max_in_len, n_layer, n_head, d_key, d_value, d_model,
        d_inner_hid, prepostprocess_dropout, attention_dropout, relu_dropout,
        preprocess_cmd, postprocess_cmd, weight_sharing)
    start_tokens, init_scores, trg_src_attn_bias = make_all_inputs(
        fast_decoder_data_input_fields)

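    # The whole search runs inside a fluid While block; the ids and scores
    # selected at each step are kept in tensor arrays and assembled into final
    # sequences by beam_search_decode at the end.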
    def beam_search():
        max_len = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=max_out_len)
        step_idx = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=0)
        cond = layers.less_than(x=step_idx, y=max_len)
        while_op = layers.While(cond)
        # array states will be stored for each step.
        ids = layers.array_write(
            layers.reshape(start_tokens, (-1, 1)), step_idx)
        scores = layers.array_write(init_scores, step_idx)
        # cell states will be overwritten at each step.
        # caches contains states of history steps to reduce redundant
        # computation in decoder.
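        # The time dimension of each cache starts at size 0 and grows by one
        # at every step as new keys/values are concatenated inside
        # multi_head_attention.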
        caches = [{
            "k": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_model],
                dtype=enc_output.dtype,
                value=0),
            "v": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_model],
                dtype=enc_output.dtype,
                value=0)
        } for i in range(n_layer)]
        with while_op.block():
            pre_ids = layers.array_read(array=ids, i=step_idx)
            pre_ids = layers.reshape(pre_ids, (-1, 1, 1))
            pre_scores = layers.array_read(array=scores, i=step_idx)
            # sequence_expand can gather sequences according to lod, so it is
            # used in beam search to pick the states of the selected candidates.
            pre_src_attn_bias = layers.sequence_expand(
                x=trg_src_attn_bias, y=pre_scores)
            pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
            pre_caches = [{
                "k": layers.sequence_expand(
                    x=cache["k"], y=pre_scores),
                "v": layers.sequence_expand(
                    x=cache["v"], y=pre_scores),
            } for cache in caches]
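            # The decoder consumes a single token per step, so its position is
            # just the current step index: build a tensor of ones with the
            # proper batch size (and lod) and multiply it by step_idx.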
            pre_pos = layers.elementwise_mul(
                x=layers.fill_constant_batch_size_like(
                    input=pre_enc_output,  # can't use pre_ids here since it has lod
                    value=1,
                    shape=[-1, 1, 1],
                    dtype=pre_ids.dtype),
                y=step_idx,
                axis=0)
            logits = wrap_decoder(
                trg_vocab_size,
                max_in_len,
                n_layer,
                n_head,
                d_key,
                d_value,
                d_model,
                d_inner_hid,
                prepostprocess_dropout,
                attention_dropout,
                relu_dropout,
                preprocess_cmd,
                postprocess_cmd,
                weight_sharing,
                dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
                enc_output=pre_enc_output,
                caches=pre_caches)
            logits = layers.reshape(logits, (-1, trg_vocab_size))

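            # Accumulate scores in log space: each candidate's score is the
            # log probability of the new token plus the score of the beam it
            # extends.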
            topk_scores, topk_indices = layers.topk(
                input=layers.softmax(logits), k=beam_size)
            accu_scores = layers.elementwise_add(
                x=layers.log(topk_scores),
                y=layers.reshape(
                    pre_scores, shape=[-1]),
                axis=0)
            # beam_search op uses lod to distinguish branches.
            topk_indices = layers.lod_reset(topk_indices, pre_ids)
            selected_ids, selected_scores = layers.beam_search(
                pre_ids=pre_ids,
                pre_scores=pre_scores,
                ids=topk_indices,
                scores=accu_scores,
                beam_size=beam_size,
                end_id=eos_idx)

            layers.increment(x=step_idx, value=1.0, in_place=True)
            # update states
            layers.array_write(selected_ids, i=step_idx, array=ids)
            layers.array_write(selected_scores, i=step_idx, array=scores)
            layers.assign(pre_src_attn_bias, trg_src_attn_bias)
            layers.assign(pre_enc_output, enc_output)
            for i in range(n_layer):
                layers.assign(pre_caches[i]["k"], caches[i]["k"])
                layers.assign(pre_caches[i]["v"], caches[i]["v"])
            length_cond = layers.less_than(x=step_idx, y=max_len)
            finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
            layers.logical_and(x=length_cond, y=finish_cond, out=cond)

        finished_ids, finished_scores = layers.beam_search_decode(
            ids, scores, beam_size=beam_size, end_id=eos_idx)
        return finished_ids, finished_scores

    finished_ids, finished_scores = beam_search()
    return finished_ids, finished_scores