# model.py: Transformer model (encoder, decoder, beam-search decoding)
# built on PaddlePaddle Fluid layers.
from functools import partial
import numpy as np

import paddle.fluid as fluid
import paddle.fluid.layers as layers

from config import *


def position_encoding_init(n_position, d_pos_vec):
    """
    Generate the initial values for the sinusoid position encoding table.

    Row ``p`` of the table is
    ``[sin(p * w_0), ..., sin(p * w_{k-1}), cos(p * w_0), ..., cos(p * w_{k-1})]``
    with ``k = d_pos_vec // 2``. The inverse timescales ``w_i`` decay
    geometrically from 1 down to 1e-4. When ``d_pos_vec`` is odd, a trailing
    column of zeros is appended so the width equals ``d_pos_vec``.

    Args:
        n_position (int): Number of positions (rows) in the table.
        d_pos_vec (int): Size of each position encoding vector (columns).

    Returns:
        numpy.ndarray: float32 array of shape ``[n_position, d_pos_vec]``.
    """
    channels = d_pos_vec
    position = np.arange(n_position)
    num_timescales = channels // 2
    # max(..., 1) avoids a division by zero (and NaNs) when there is only a
    # single timescale, i.e. d_pos_vec is 2 or 3.
    log_timescale_increment = (np.log(float(1e4) / float(1)) /
                               max(num_timescales - 1, 1))
    # The increment must be applied inside exp() so that the inverse
    # timescales are exp(-i * log(1e4) / (k - 1)), i.e. geometric decay.
    inv_timescales = np.exp(
        np.arange(num_timescales) * -log_timescale_increment)
    scaled_time = (np.expand_dims(position, 1) *
                   np.expand_dims(inv_timescales, 0))
    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    # Pad a zero column for odd channel counts.
    signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant')
    return signal.astype("float32")


def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         n_head=1,
                         dropout_rate=0.,
                         cache=None):
    """
    Multi-Head Attention. Note that attn_bias is added to the logit before
    computing softmax activation to mask certain selected positions so that
    they will not be considered in attention weights.

    Args:
        queries: 3-D input from which the queries are projected.
        keys: 3-D input for keys; ``queries`` is used when None.
        values: 3-D input for values; ``keys`` is used when None.
        attn_bias: Tensor added to the attention logits, or None for no bias.
        d_key (int): Per-head size of the query/key projections. Also used
            for the 1/sqrt(d_key) logit scaling.
        d_value (int): Per-head size of the value projections.
        d_model (int): Output size of the final linear projection.
        n_head (int): Number of attention heads.
        dropout_rate (float): Dropout on the attention weights; falsy disables.
        cache (dict|None): When given, must hold "k" and "v". The freshly
            projected keys/values are concatenated to them along axis 1 and
            the results are written back into the dict.

    Returns:
        Variable: The attention output projected to size d_model.

    Raises:
        ValueError: If queries, keys and values are not all 3-D.
    """
    keys = queries if keys is None else keys
    values = keys if values is None else values

    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: quries, keys and values should all be 3-D tensors.")

    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
                      size=d_key * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        k = layers.fc(input=keys,
                      size=d_key * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        v = layers.fc(input=values,
                      size=d_value * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        return q, k, v

    def __split_heads(x, n_head):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions and then transpose. Specifically, input a tensor with shape
        [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
        with shape [bs, n_head, max_sequence_length, hidden_dim].
        """
        if n_head == 1:
            return x

        hidden_size = x.shape[-1]
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        reshaped = layers.reshape(
            x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)

        # permute the dimensions into:
        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that it becomes one dimension, which is reverse to __split_heads.
        """
        if len(x.shape) == 3:
            return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        return layers.reshape(
            x=trans_x,
            shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
            inplace=True)

    def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
        """
        Scaled Dot-Product Attention:
        softmax(q * k^T / sqrt(d_key) + attn_bias) * v.
        """
        scaled_q = layers.scale(x=q, scale=d_key**-0.5)
        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
        # Explicit None check instead of relying on the truthiness of a
        # framework variable.
        if attn_bias is not None:
            product += attn_bias
        weights = layers.softmax(product)
        if dropout_rate:
            weights = layers.dropout(
                weights,
                dropout_prob=dropout_rate,
                seed=ModelHyperParams.dropout_seed,
                is_test=False)
        out = layers.matmul(weights, v)
        return out

    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

    if cache is not None:  # use cache and concat time steps
        k = cache["k"] = layers.concat([cache["k"], k], axis=1)
        v = cache["v"] = layers.concat([cache["v"], v], axis=1)

    q = __split_heads(q, n_head)
    k = __split_heads(k, n_head)
    v = __split_heads(v, n_head)

    # Scale logits by the per-head key size d_key (1/sqrt(d_k) in the
    # Transformer paper), not by the model width d_model.
    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key,
                                                  dropout_rate)

    out = __combine_heads(ctx_multiheads)

    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         bias_attr=False,
                         num_flatten_dims=2)
    return proj_out


def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate):
    """
    Position-wise Feed-Forward Networks.

    Two fully connected layers with a ReLU activation in between. Both use
    num_flatten_dims=2, so the same weights are applied to every position
    along the leading two dimensions of ``x``.

    Args:
        x: Input tensor.
        d_inner_hid (int): Width of the inner (ReLU) layer.
        d_hid (int): Width of the output layer.
        dropout_rate (float): Dropout probability on the inner activations;
            a falsy value disables dropout.

    Returns:
        Variable: Output of the second fully connected layer.
    """
    inner = layers.fc(input=x,
                      size=d_inner_hid,
                      num_flatten_dims=2,
                      act="relu")
    inner = layers.dropout(
        inner,
        dropout_prob=dropout_rate,
        seed=ModelHyperParams.dropout_seed,
        is_test=False) if dropout_rate else inner
    return layers.fc(input=inner, size=d_hid, num_flatten_dims=2)


def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
    """
    Add residual connection, layer normalization and dropout to the out tensor
    optionally according to the value of process_cmd.
    This will be used before or after multi-head attention and position-wise
    feed-forward networks.

    Args:
        prev_out: Residual input added by the "a" command; None skips the
            addition.
        out: Tensor to process.
        process_cmd (str): Commands applied left to right:
            "a" adds the residual connection,
            "n" applies layer normalization over the last dimension,
            "d" applies dropout (only when dropout_rate is truthy).
            Any other character is ignored.
        dropout_rate (float): Dropout probability used by the "d" command.

    Returns:
        The processed tensor.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            # Compare with None explicitly: the truthiness of a tensor is not
            # a reliable "residual was given" test (numpy arrays even raise).
            if prev_out is not None:
                out = out + prev_out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.initializer.Constant(1.),
                bias_attr=fluid.initializer.Constant(0.))
        elif cmd == "d":  # add dropout
            if dropout_rate:
                out = layers.dropout(
                    out,
                    dropout_prob=dropout_rate,
                    seed=ModelHyperParams.dropout_seed,
                    is_test=False)
    return out


# Post-processing is the generic routine, called with an explicit residual.
post_process_layer = pre_post_process_layer
# Pre-processing never has a residual input: bind prev_out=None positionally.
pre_process_layer = partial(post_process_layer, None)


def prepare_encoder_decoder(src_word,
                            src_pos,
                            src_vocab_size,
                            src_emb_dim,
                            src_max_len,
                            dropout_rate=0.,
                            word_emb_param_name=None,
                            pos_enc_param_name=None):
    """Add word embeddings and position encodings.

    The word embedding (normal init with std src_emb_dim**-0.5, padding_idx
    set to ModelHyperParams.bos_idx) is scaled by sqrt(src_emb_dim) and summed
    with a non-trainable position encoding looked up by src_pos. Dropout is
    applied to the sum when dropout_rate is truthy.
    The output tensor has a shape of:
    [batch_size, max_src_length_in_batch, d_model].
    This module is used at the bottom of the encoder and decoder stacks.
    """
    word_emb_attr = fluid.ParamAttr(
        name=word_emb_param_name,
        initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5))
    word_emb = layers.embedding(
        src_word,
        size=[src_vocab_size, src_emb_dim],
        padding_idx=ModelHyperParams.bos_idx,  # set embedding of bos to 0
        param_attr=word_emb_attr)
    word_emb = layers.scale(x=word_emb, scale=src_emb_dim**0.5)

    # Position encodings are a fixed lookup table: no training, no gradient.
    pos_enc = layers.embedding(
        src_pos,
        size=[src_max_len, src_emb_dim],
        param_attr=fluid.ParamAttr(
            name=pos_enc_param_name, trainable=False))
    pos_enc.stop_gradient = True

    enc_input = word_emb + pos_enc
    if not dropout_rate:
        return enc_input
    return layers.dropout(
        enc_input,
        dropout_prob=dropout_rate,
        seed=ModelHyperParams.dropout_seed,
        is_test=False)


# Encoder and decoder share prepare_encoder_decoder but read different
# position-encoding parameters: index 0 for the encoder, 1 for the decoder.
prepare_encoder = partial(prepare_encoder_decoder,
                          pos_enc_param_name=pos_enc_param_names[0])
prepare_decoder = partial(prepare_encoder_decoder,
                          pos_enc_param_name=pos_enc_param_names[1])


def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  prepostprocess_dropout,
                  attention_dropout,
                  relu_dropout,
                  preprocess_cmd="n",
                  postprocess_cmd="da"):
    """The encoder layers that can be stacked to form a deep encoder.

    This module consists of a multi-head (self) attention followed by
    position-wise feed-forward networks. Each sub-layer input goes through
    pre_process_layer (preprocess_cmd), and each output goes through
    post_process_layer (postprocess_cmd) with the sub-layer input as the
    residual, which adds residual connection, layer normalization and dropout.
    """
    attn_input = pre_process_layer(enc_input, preprocess_cmd,
                                   prepostprocess_dropout)
    # Self-attention: keys and values default to the queries.
    attn_output = multi_head_attention(
        attn_input,
        None,
        None,
        attn_bias,
        d_key,
        d_value,
        d_model,
        n_head=n_head,
        dropout_rate=attention_dropout)
    attn_output = post_process_layer(enc_input, attn_output, postprocess_cmd,
                                     prepostprocess_dropout)

    ffd_input = pre_process_layer(attn_output, preprocess_cmd,
                                  prepostprocess_dropout)
    ffd_output = positionwise_feed_forward(ffd_input, d_inner_hid, d_model,
                                           relu_dropout)
    return post_process_layer(attn_output, ffd_output, postprocess_cmd,
                              prepostprocess_dropout)


def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            preprocess_cmd="n",
            postprocess_cmd="da"):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer, followed by a final pre_process_layer on the stack output.

    Args:
        enc_input: Input to the first encoder layer.
        attn_bias: Self-attention bias shared by every layer.
        n_layer (int): Number of stacked encoder layers. When 0, the input is
            passed straight to the final pre_process_layer.
        n_head, d_key, d_value, d_model, d_inner_hid: Layer sizes forwarded
            to encoder_layer.
        prepostprocess_dropout, attention_dropout, relu_dropout: Dropout
            rates forwarded to encoder_layer.
        preprocess_cmd, postprocess_cmd (str): Process commands forwarded to
            encoder_layer; preprocess_cmd is also used for the final step.

    Returns:
        Variable: The encoder output.
    """
    # Start from the input so that n_layer == 0 does not leave enc_output
    # unbound (the original loop raised NameError in that case).
    enc_output = enc_input
    for _ in range(n_layer):
        enc_output = encoder_layer(
            enc_output,
            attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            preprocess_cmd,
            postprocess_cmd, )
    return pre_process_layer(enc_output, preprocess_cmd,
                             prepostprocess_dropout)


def decoder_layer(dec_input,
                  enc_output,
                  slf_attn_bias,
                  dec_enc_attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  prepostprocess_dropout,
                  attention_dropout,
                  relu_dropout,
                  preprocess_cmd,
                  postprocess_cmd,
                  cache=None):
    """The layer to be stacked in the decoder part.

    The structure mirrors encoder_layer, with an extra multi-head attention
    between self-attention and the feed-forward network that attends to
    enc_output (encoder-decoder attention). Each sub-layer input goes through
    pre_process_layer, and each output goes through post_process_layer with
    the sub-layer input as residual.

    Args:
        cache (dict|None): Forwarded to the self-attention sub-layer only.

    Returns:
        Variable: The output of this decoder layer.
    """

    def _pre(x):
        return pre_process_layer(x, preprocess_cmd, prepostprocess_dropout)

    def _post(residual, x):
        return post_process_layer(residual, x, postprocess_cmd,
                                  prepostprocess_dropout)

    # 1. Self-attention over the decoder input (optionally cached).
    slf_attn_output = multi_head_attention(
        _pre(dec_input), None, None, slf_attn_bias, d_key, d_value, d_model,
        n_head, attention_dropout, cache)
    slf_attn_output = _post(dec_input, slf_attn_output)

    # 2. Encoder-decoder attention: keys and values come from enc_output.
    enc_attn_output = multi_head_attention(
        _pre(slf_attn_output), enc_output, enc_output, dec_enc_attn_bias,
        d_key, d_value, d_model, n_head, attention_dropout)
    enc_attn_output = _post(slf_attn_output, enc_attn_output)

    # 3. Position-wise feed-forward network.
    ffd_output = positionwise_feed_forward(
        _pre(enc_attn_output), d_inner_hid, d_model, relu_dropout)
    return _post(enc_attn_output, ffd_output)


def decoder(dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            preprocess_cmd,
            postprocess_cmd,
            caches=None):
    """
    The decoder is composed of a stack of identical decoder_layer layers,
    followed by a final pre_process_layer on the stack output.

    Args:
        dec_input: Input to the first decoder layer.
        enc_output: Encoder output attended to by every layer.
        dec_slf_attn_bias, dec_enc_attn_bias: Attention biases shared by
            every layer.
        n_layer (int): Number of stacked decoder layers. When 0, the input is
            passed straight to the final pre_process_layer.
        caches (list|None): If given, caches[i] is passed as the cache of
            layer i.
        Remaining arguments are forwarded to decoder_layer.

    Returns:
        Variable: The decoder output.
    """
    # Start from the input so that n_layer == 0 does not leave dec_output
    # unbound (the original loop raised NameError in that case).
    dec_output = dec_input
    for i in range(n_layer):
        dec_output = decoder_layer(
            dec_output,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            preprocess_cmd,
            postprocess_cmd,
            cache=None if caches is None else caches[i])
    return pre_process_layer(dec_output, preprocess_cmd,
                             prepostprocess_dropout)


def make_all_inputs(input_fields):
    """
    Define the input data layers for the transformer model.

    Each name in input_fields is looked up in input_descs, whose entries are
    (shape, dtype) or (shape, dtype, lod_level); lod_level defaults to 0.
    Shapes are used as-is (append_batch_size=False).

    Returns:
        list: One data layer per field, in the order of input_fields.
    """

    def _data_layer(field):
        desc = input_descs[field]
        lod_level = desc[2] if len(desc) == 3 else 0
        return layers.data(
            name=field,
            shape=desc[0],
            dtype=desc[1],
            lod_level=lod_level,
            append_batch_size=False)

    return [_data_layer(field) for field in input_fields]


def make_all_py_reader_inputs(input_fields, is_test=False):
    """
    Create a py_reader that feeds input_fields and return its data variables.

    Shapes, dtypes and lod levels come from input_descs (lod_level defaults to
    0 when an entry has only two items). The reader is named "test_reader"
    when is_test is true and "train_reader" otherwise, with capacity 20.

    Returns:
        tuple: (layers.read_file(reader), reader).
    """
    descs = [input_descs[field] for field in input_fields]
    reader = layers.py_reader(
        capacity=20,
        name="test_reader" if is_test else "train_reader",
        shapes=[desc[0] for desc in descs],
        dtypes=[desc[1] for desc in descs],
        lod_levels=[desc[2] if len(desc) == 3 else 0 for desc in descs])
    return layers.read_file(reader), reader


def transformer(src_vocab_size,
                trg_vocab_size,
                max_length,
                n_layer,
                n_head,
                d_key,
                d_value,
                d_model,
                d_inner_hid,
                prepostprocess_dropout,
                attention_dropout,
                relu_dropout,
                preprocess_cmd,
                postprocess_cmd,
                weight_sharing,
                label_smooth_eps,
                use_py_reader=False,
                is_test=False):
    """
    Build the full training graph: inputs, encoder, decoder and the
    token-weighted (optionally label-smoothed) cross-entropy loss.

    Returns:
        tuple: (sum_cost, avg_cost, predict, token_num, reader), where reader
        is the py_reader when use_py_reader is true and None otherwise.
    """
    if weight_sharing:
        assert src_vocab_size == trg_vocab_size, (
            "Vocabularies in source and target should be same for weight sharing."
        )

    # The last decoder field is the encoder output, which is produced inside
    # this graph rather than fed, so it is excluded from the data inputs.
    dec_fields = decoder_data_input_fields[:-1]
    enc_count = len(encoder_data_input_fields)
    dec_count = len(dec_fields)
    data_input_names = (encoder_data_input_fields + dec_fields +
                        label_data_input_fields)

    reader = None
    if use_py_reader:
        all_inputs, reader = make_all_py_reader_inputs(data_input_names,
                                                       is_test)
    else:
        all_inputs = make_all_inputs(data_input_names)

    enc_inputs = all_inputs[:enc_count]
    dec_inputs = all_inputs[enc_count:enc_count + dec_count]
    label, weights = all_inputs[-2], all_inputs[-1]

    # Hyper-parameters shared positionally by wrap_encoder and wrap_decoder.
    shared_args = (n_layer, n_head, d_key, d_value, d_model, d_inner_hid,
                   prepostprocess_dropout, attention_dropout, relu_dropout,
                   preprocess_cmd, postprocess_cmd, weight_sharing)
    enc_output = wrap_encoder(
        src_vocab_size, max_length, *shared_args, enc_inputs=enc_inputs)
    predict = wrap_decoder(
        trg_vocab_size,
        max_length,
        *shared_args,
        dec_inputs=dec_inputs,
        enc_output=enc_output)

    # Padding index do not contribute to the total loss. The weights is used to
    # cancel padding index in calculating the loss.
    if label_smooth_eps:
        label = layers.label_smooth(
            label=layers.one_hot(
                input=label, depth=trg_vocab_size),
            epsilon=label_smooth_eps)

    cost = layers.softmax_with_cross_entropy(
        logits=predict, label=label, soft_label=bool(label_smooth_eps))
    weighted_cost = cost * weights
    sum_cost = layers.reduce_sum(weighted_cost)
    token_num = layers.reduce_sum(weights)
    token_num.stop_gradient = True
    avg_cost = sum_cost / token_num
    return sum_cost, avg_cost, predict, token_num, reader


def wrap_encoder(src_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 prepostprocess_dropout,
                 attention_dropout,
                 relu_dropout,
                 preprocess_cmd,
                 postprocess_cmd,
                 weight_sharing,
                 enc_inputs=None):
    """
    The wrapper assembles together all needed layers for the encoder.

    Args:
        weight_sharing: Not used here; the encoder always names its word
            embedding word_emb_param_names[0].
        enc_inputs (tuple|None): (src_word, src_pos, src_slf_attn_bias).
            When None, these are created as data layers from
            encoder_data_input_fields (independent encoder program for
            inference).
        Remaining arguments are forwarded to prepare_encoder / encoder.

    Returns:
        Variable: The encoder output.
    """
    if enc_inputs is None:
        # This is used to implement independent encoder program in inference.
        enc_inputs = make_all_inputs(encoder_data_input_fields)
    src_word, src_pos, src_slf_attn_bias = enc_inputs

    embedded = prepare_encoder(
        src_word,
        src_pos,
        src_vocab_size,
        src_emb_dim=d_model,
        src_max_len=max_length,
        dropout_rate=prepostprocess_dropout,
        word_emb_param_name=word_emb_param_names[0])
    return encoder(embedded, src_slf_attn_bias, n_layer, n_head, d_key,
                   d_value, d_model, d_inner_hid, prepostprocess_dropout,
                   attention_dropout, relu_dropout, preprocess_cmd,
                   postprocess_cmd)


def wrap_decoder(trg_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 prepostprocess_dropout,
                 attention_dropout,
                 relu_dropout,
                 preprocess_cmd,
                 postprocess_cmd,
                 weight_sharing,
                 dec_inputs=None,
                 enc_output=None,
                 caches=None):
    """
    The wrapper assembles together all needed layers for the decoder.

    Args:
        dec_inputs (tuple|None): (trg_word, trg_pos, trg_slf_attn_bias,
            trg_src_attn_bias). When None, these and enc_output are created as
            data layers from decoder_data_input_fields (independent decoder
            program for inference), and the enc_output argument is ignored.
        enc_output: Encoder output attended to by the decoder.
        caches (list|None): Per-layer self-attention caches for the decoder.
        weight_sharing (bool): Reuse the source word embedding
            (word_emb_param_names[0]) for the target embedding and for the
            output projection.
        Remaining arguments are forwarded to prepare_decoder / decoder.

    Returns:
        Variable: 2-D tensor whose rows are the flattened decoder positions
        and whose last dimension is the vocabulary size. It holds logits when
        dec_inputs is given and softmax probabilities when dec_inputs is None.
    """
    if dec_inputs is None:
        # This is used to implement independent decoder program in inference.
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = \
            make_all_inputs(decoder_data_input_fields)
    else:
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs

    dec_input = prepare_decoder(
        trg_word,
        trg_pos,
        trg_vocab_size,
        d_model,
        max_length,
        prepostprocess_dropout,
        word_emb_param_name=word_emb_param_names[0]
        if weight_sharing else word_emb_param_names[1])
    dec_output = decoder(
        dec_input,
        enc_output,
        trg_slf_attn_bias,
        trg_src_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        prepostprocess_dropout,
        attention_dropout,
        relu_dropout,
        preprocess_cmd,
        postprocess_cmd,
        caches=caches)
    # Reshape to 2D tensor to use GEMM instead of BatchedGEMM
    dec_output = layers.reshape(
        dec_output, shape=[-1, dec_output.shape[-1]], inplace=True)
    if weight_sharing:
        predict = layers.matmul(
            x=dec_output,
            y=fluid.default_main_program().global_block().var(
                word_emb_param_names[0]),
            transpose_y=True)
    else:
        # dec_output is 2-D here, so only the first dimension may be
        # flattened into rows. num_flatten_dims=2 would also fold the feature
        # dimension into the rows and build a [1, trg_vocab_size] weight.
        predict = layers.fc(input=dec_output,
                            size=trg_vocab_size,
                            bias_attr=False,
                            num_flatten_dims=1)
    if dec_inputs is None:
        # Return probs for independent decoder program.
        predict = layers.softmax(predict)
    return predict


def fast_decode(
        src_vocab_size,
        trg_vocab_size,
        max_in_len,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        prepostprocess_dropout,
        attention_dropout,
        relu_dropout,
        preprocess_cmd,
        postprocess_cmd,
        weight_sharing,
        beam_size,
        max_out_len,
        eos_idx, ):
    """
    Use beam search to decode. Caches are used to store the keys and values
    of history steps, so each step feeds only the newest token to the decoder.

    Builds the encoder with wrap_encoder and creates the data layers for
    fast_decoder_data_input_fields (start_tokens, init_scores,
    trg_src_attn_bias). It then runs a While loop for at most max_out_len
    steps. Each step expands the states with sequence_expand, decodes one
    step with wrap_decoder and selects candidates with layers.beam_search.

    Args:
        src_vocab_size (int): Source vocabulary size.
        trg_vocab_size (int): Target vocabulary size.
        max_in_len (int): Passed as max_length to wrap_encoder/wrap_decoder.
        n_layer, n_head, d_key, d_value, d_model, d_inner_hid,
        prepostprocess_dropout, attention_dropout, relu_dropout,
        preprocess_cmd, postprocess_cmd, weight_sharing: Model
            hyper-parameters forwarded to wrap_encoder and wrap_decoder.
        beam_size (int): Beam width for topk, beam_search and
            beam_search_decode.
        max_out_len (int): Maximum number of decoding steps.
        eos_idx (int): End-of-sequence id, passed as end_id.

    Returns:
        tuple: (finished_ids, finished_scores) from layers.beam_search_decode.
    """
    enc_output = wrap_encoder(
        src_vocab_size, max_in_len, n_layer, n_head, d_key, d_value, d_model,
        d_inner_hid, prepostprocess_dropout, attention_dropout, relu_dropout,
        preprocess_cmd, postprocess_cmd, weight_sharing)
    start_tokens, init_scores, trg_src_attn_bias = make_all_inputs(
        fast_decoder_data_input_fields)

    def beam_search():
        # Loop control: step_idx starts at 0 and the loop runs while
        # step_idx < max_len (cond is recomputed at the end of each step).
        max_len = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=max_out_len)
        step_idx = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=0)
        cond = layers.less_than(x=step_idx, y=max_len)
        while_op = layers.While(cond)
        # array states will be stored for each step.
        ids = layers.array_write(
            layers.reshape(start_tokens, (-1, 1)), step_idx)
        scores = layers.array_write(init_scores, step_idx)
        # cell states will be overwritten at each step.
        # caches contains states of history steps to reduce redundant
        # computation in decoder.
        # NOTE(review): the cache width is d_model, while multi_head_attention
        # concatenates projections of width d_key * n_head / d_value * n_head;
        # this presumably assumes those equal d_model — confirm in config.
        caches = [{
            "k": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_model],
                dtype=enc_output.dtype,
                value=0),
            "v": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_model],
                dtype=enc_output.dtype,
                value=0)
        } for i in range(n_layer)]
        with while_op.block():
            pre_ids = layers.array_read(array=ids, i=step_idx)
            pre_ids = layers.reshape(pre_ids, (-1, 1, 1))
            pre_scores = layers.array_read(array=scores, i=step_idx)
            # sequence_expand can gather sequences according to lod thus can be
            # used in beam search to sift states corresponding to selected ids.
            pre_src_attn_bias = layers.sequence_expand(
                x=trg_src_attn_bias, y=pre_scores)
            pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
            pre_caches = [{
                "k": layers.sequence_expand(
                    x=cache["k"], y=pre_scores),
                "v": layers.sequence_expand(
                    x=cache["v"], y=pre_scores),
            } for cache in caches]
            # Position of the current token for every live beam: a tensor of
            # ones shaped after the expanded batch, multiplied by step_idx.
            pre_pos = layers.elementwise_mul(
                x=layers.fill_constant_batch_size_like(
                    input=pre_enc_output,  # can't use pre_ids here since it has lod
                    value=1,
                    shape=[-1, 1, 1],
                    dtype=pre_ids.dtype),
                y=step_idx,
                axis=0)
            # One decoding step. The self-attention bias is None because the
            # cache only holds past positions; dec_inputs is given, so
            # wrap_decoder returns logits rather than probabilities.
            logits = wrap_decoder(
                trg_vocab_size,
                max_in_len,
                n_layer,
                n_head,
                d_key,
                d_value,
                d_model,
                d_inner_hid,
                prepostprocess_dropout,
                attention_dropout,
                relu_dropout,
                preprocess_cmd,
                postprocess_cmd,
                weight_sharing,
                dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
                enc_output=pre_enc_output,
                caches=pre_caches)

            topk_scores, topk_indices = layers.topk(
                input=layers.softmax(logits), k=beam_size)
            # Accumulated score: log of each top-k probability plus the
            # previous score of the beam it extends.
            accu_scores = layers.elementwise_add(
                x=layers.log(topk_scores),
                y=layers.reshape(
                    pre_scores, shape=[-1]),
                axis=0)
            # beam_search op uses lod to distinguish branches.
            topk_indices = layers.lod_reset(topk_indices, pre_ids)
            selected_ids, selected_scores = layers.beam_search(
                pre_ids=pre_ids,
                pre_scores=pre_scores,
                ids=topk_indices,
                scores=accu_scores,
                beam_size=beam_size,
                end_id=eos_idx)

            layers.increment(x=step_idx, value=1.0, in_place=True)
            # update states
            # The selected ids/scores are stored at the new step index so the
            # next iteration reads them back as pre_ids/pre_scores.
            layers.array_write(selected_ids, i=step_idx, array=ids)
            layers.array_write(selected_scores, i=step_idx, array=scores)
            # Carry the expanded bias, encoder output and caches into the next
            # iteration by overwriting the loop-carried variables.
            layers.assign(pre_src_attn_bias, trg_src_attn_bias)
            layers.assign(pre_enc_output, enc_output)
            for i in range(n_layer):
                layers.assign(pre_caches[i]["k"], caches[i]["k"])
                layers.assign(pre_caches[i]["v"], caches[i]["v"])
            # Continue while under the length limit and beam_search still
            # selected some ids.
            length_cond = layers.less_than(x=step_idx, y=max_len)
            finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
            layers.logical_and(x=length_cond, y=finish_cond, out=cond)

        finished_ids, finished_scores = layers.beam_search_decode(
            ids, scores, beam_size=beam_size, end_id=eos_idx)
        return finished_ids, finished_scores

    finished_ids, finished_scores = beam_search()
    return finished_ids, finished_scores