from functools import partial
import numpy as np

import paddle.fluid as fluid
import paddle.fluid.layers as layers

from config import *


def wrap_layer_with_block(layer, block_idx):
    """
    Make layer define support indicating block, by which we can add layers
    to other blocks within current block. This will make it easy to define
    cache among while loop.
    """

    class BlockGuard(object):
        """
        BlockGuard class.

        BlockGuard class is used to switch to the given block in a program by
        using the Python `with` keyword.
        """

        def __init__(self, block_idx=None, main_program=None):
            self.main_program = fluid.default_main_program(
            ) if main_program is None else main_program
            self.old_block_idx = self.main_program.current_block().idx
            self.new_block_idx = block_idx

        def __enter__(self):
            self.main_program.current_block_idx = self.new_block_idx

        def __exit__(self, exc_type, exc_val, exc_tb):
            self.main_program.current_block_idx = self.old_block_idx
            if exc_type is not None:
                return False  # re-raise exception
            return True

    def layer_wrapper(*args, **kwargs):
        with BlockGuard(block_idx):
            return layer(*args, **kwargs)

    return layer_wrapper
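
# Illustrative usage sketch for wrap_layer_with_block (assumption: we are
# building ops inside a sub-block, e.g. a while-loop block, and want the new
# ops/vars to live in the parent block instead), mirroring how it is used in
# multi_head_attention below:
#
#     fc_in_parent_block = wrap_layer_with_block(
#         layers.fc, fluid.default_main_program().current_block().parent_idx)
#     out = fc_in_parent_block(input=x, size=d_model, num_flatten_dims=2)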


def position_encoding_init(n_position, d_pos_vec):
    """
    Generate the initial values for the sinusoidal position encoding table.
    """
    channels = d_pos_vec
    position = np.arange(n_position)
    num_timescales = channels // 2
    log_timescale_increment = (np.log(float(1e4) / float(1)) /
                               (num_timescales - 1))
    # Timescales form a geometric sequence from 1 to 1e4.
    inv_timescales = np.exp(np.arange(
        num_timescales) * -log_timescale_increment)
    scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales,
                                                               0)
    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant')
    position_enc = signal
    return position_enc.astype("float32")
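

# Minimal sanity-check sketch for position_encoding_init (assumption: purely
# illustrative, not called anywhere in the model). Each row corresponds to one
# position; the first half of the row holds the sin terms and the second half
# the cos terms, so position 0 yields zeros followed by ones.
def _position_encoding_init_example():
    table = position_encoding_init(n_position=4, d_pos_vec=6)
    assert table.shape == (4, 6)
    assert np.allclose(table[0, :3], 0.) and np.allclose(table[0, 3:], 1.)
    return table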


def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         n_head=1,
                         dropout_rate=0.,
                         cache=None,
                         gather_idx=None,
                         static_kv=False):
    """
    Multi-Head Attention. Note that attn_bias is added to the logits before
    computing the softmax activation to mask certain selected positions so
    that they will not be considered in the attention weights.
    """
    keys = queries if keys is None else keys
    values = keys if values is None else values

    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: queries, keys and values should all be 3-D tensors.")

    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
                      size=d_key * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        # For encoder-decoder attention in inference, insert the ops and vars
        # into global block to use as cache among beam search.
        fc_layer = wrap_layer_with_block(
            layers.fc, fluid.default_main_program().current_block()
            .parent_idx) if cache is not None and static_kv else layers.fc
        k = fc_layer(
            input=keys,
            size=d_key * n_head,
            bias_attr=False,
            num_flatten_dims=2)
        v = fc_layer(
            input=values,
            size=d_value * n_head,
            bias_attr=False,
            num_flatten_dims=2)
        return q, k, v

    def __split_heads_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Reshape input tensors at the last dimension to split multi-heads
        and then transpose. Specifically, transform the input tensor with shape
        [bs, max_sequence_length, n_head * hidden_dim] to the output tensor
        with shape [bs, n_head, max_sequence_length, hidden_dim].
        """
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        reshaped_q = layers.reshape(
            x=queries, shape=[0, 0, n_head, d_key], inplace=True)
        # permute the dimensions into:
        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
        q = layers.transpose(x=reshaped_q, perm=[0, 2, 1, 3])
        # For encoder-decoder attention in inference, insert the ops and vars
        # into global block to use as cache among beam search.
        reshape_layer = wrap_layer_with_block(
            layers.reshape,
            fluid.default_main_program().current_block()
            .parent_idx) if cache is not None and static_kv else layers.reshape
        transpose_layer = wrap_layer_with_block(
            layers.transpose,
            fluid.default_main_program().current_block().
            parent_idx) if cache is not None and static_kv else layers.transpose
        reshaped_k = reshape_layer(
            x=keys, shape=[0, 0, n_head, d_key], inplace=True)
        k = transpose_layer(x=reshaped_k, perm=[0, 2, 1, 3])
        reshaped_v = reshape_layer(
            x=values, shape=[0, 0, n_head, d_value], inplace=True)
        v = transpose_layer(x=reshaped_v, perm=[0, 2, 1, 3])

        if cache is not None:  # only for faster inference
            if static_kv:  # For encoder-decoder attention in inference
                cache_k, cache_v = cache["static_k"], cache["static_v"]
                # Initialize static_k and static_v in the cache.
                # This could also be done with a conditional op at the first
                # step of the while loop, but that would likely be less
                # efficient.
                static_cache_init = wrap_layer_with_block(
                    layers.assign,
                    fluid.default_main_program().current_block().parent_idx)
                static_cache_init(k, cache_k)
                static_cache_init(v, cache_v)
            else:  # For decoder self-attention in inference
                cache_k, cache_v = cache["k"], cache["v"]
            # gather cell states corresponding to selected parent
            select_k = layers.gather(cache_k, index=gather_idx)
            select_v = layers.gather(cache_v, index=gather_idx)
            if not static_kv:
                # For self attention in inference, use cache and concat time steps.
                select_k = layers.concat([select_k, k], axis=2)
                select_v = layers.concat([select_v, v], axis=2)
            # update cell states (caches) stored in the global block
            layers.assign(select_k, cache_k)
            layers.assign(select_v, cache_v)
            return q, select_k, select_v
        return q, k, v

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that they become one dimension, which is the reverse of
        __split_heads_qkv.
        """
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        return layers.reshape(
            x=trans_x,
            shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
            inplace=True)

    def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
        """
        Scaled Dot-Product Attention
        """
        product = layers.matmul(x=q, y=k, transpose_y=True, alpha=d_key**-0.5)
        if attn_bias:
            product += attn_bias
        weights = layers.softmax(product)
        if dropout_rate:
            weights = layers.dropout(
                weights,
                dropout_prob=dropout_rate,
                seed=ModelHyperParams.dropout_seed,
                is_test=False)
        out = layers.matmul(weights, v)
        return out

    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
    q, k, v = __split_heads_qkv(q, k, v, n_head, d_key, d_value)

    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model,
                                                  dropout_rate)

    out = __combine_heads(ctx_multiheads)

    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         bias_attr=False,
                         num_flatten_dims=2)
    return proj_out
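

# Illustrative numpy sketch of the scaled dot-product attention used above
# (assumption: single head, no dropout, plain numpy instead of fluid ops).
# A large negative attention bias added to the logits drives the softmax
# weights of masked positions towards zero, which is how padding/future
# positions are excluded from the attention.
def _scaled_dot_product_attention_example():
    q = np.random.rand(2, 4, 8)  # [batch_size, q_len, d_key]
    k = np.random.rand(2, 4, 8)  # [batch_size, k_len, d_key]
    v = np.random.rand(2, 4, 8)  # [batch_size, k_len, d_value]
    bias = np.zeros((2, 4, 4))
    bias[:, :, -1] = -1e9  # mask the last key position
    logits = np.matmul(q, k.transpose(0, 2, 1)) * 8**-0.5 + bias
    weights = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
    assert np.allclose(weights[:, :, -1], 0.)  # masked position is ignored
    return np.matmul(weights, v)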


def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate):
    """
    Position-wise Feed-Forward Networks.
    This module consists of two linear transformations with a ReLU activation
    in between, which is applied to each position separately and identically.
    """
    hidden = layers.fc(input=x,
                       size=d_inner_hid,
                       num_flatten_dims=2,
                       act="relu")
    if dropout_rate:
        hidden = layers.dropout(
            hidden,
            dropout_prob=dropout_rate,
            seed=ModelHyperParams.dropout_seed,
            is_test=False)
    out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
    return out


def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
    """
    Add residual connection, layer normalization and dropout to the out tensor
    optionally according to the value of process_cmd.
    This will be used before or after multi-head attention and position-wise
    feed-forward networks.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            out = out + prev_out if prev_out else out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.initializer.Constant(1.),
                bias_attr=fluid.initializer.Constant(0.))
        elif cmd == "d":  # add dropout
            if dropout_rate:
                out = layers.dropout(
                    out,
                    dropout_prob=dropout_rate,
                    seed=ModelHyperParams.dropout_seed,
                    is_test=False)
    return out


pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer
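
# Usage sketch (assumption: illustrative only). With the command strings used
# by the encoder/decoder below, preprocess_cmd="n" applies layer normalization
# before a sub-layer and postprocess_cmd="da" applies dropout and then adds the
# residual connection after it:
#
#     normed_in = pre_process_layer(x, "n", prepostprocess_dropout)
#     out = post_process_layer(x, sublayer(normed_in), "da",
#                              prepostprocess_dropout)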


def prepare_encoder_decoder(src_word,
                            src_pos,
                            src_vocab_size,
                            src_emb_dim,
                            src_max_len,
                            dropout_rate=0.,
                            word_emb_param_name=None,
                            pos_enc_param_name=None):
    """Add word embeddings and position encodings.
    The output tensor has a shape of:
    [batch_size, max_src_length_in_batch, d_model].
    This module is used at the bottom of the encoder stacks.
    """
    src_word_emb = layers.embedding(
        src_word,
        size=[src_vocab_size, src_emb_dim],
        padding_idx=ModelHyperParams.bos_idx,  # set embedding of bos to 0
        param_attr=fluid.ParamAttr(
            name=word_emb_param_name,
            initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))

    src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
    src_pos_enc = layers.embedding(
        src_pos,
        size=[src_max_len, src_emb_dim],
        param_attr=fluid.ParamAttr(
            name=pos_enc_param_name, trainable=False))
    src_pos_enc.stop_gradient = True
    enc_input = src_word_emb + src_pos_enc
    return layers.dropout(
        enc_input,
        dropout_prob=dropout_rate,
        seed=ModelHyperParams.dropout_seed,
        is_test=False) if dropout_rate else enc_input


prepare_encoder = partial(
    prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[0])
prepare_decoder = partial(
    prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[1])


def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  prepostprocess_dropout,
                  attention_dropout,
                  relu_dropout,
                  preprocess_cmd="n",
                  postprocess_cmd="da"):
    """The encoder layers that can be stacked to form a deep encoder.
    This module consists of a multi-head (self) attention sub-layer followed by
    position-wise feed-forward networks, both accompanied by post_process_layer
    to add residual connection, layer normalization and dropout.
    """
    attn_output = multi_head_attention(
        pre_process_layer(enc_input, preprocess_cmd,
                          prepostprocess_dropout), None, None, attn_bias, d_key,
        d_value, d_model, n_head, attention_dropout)
    attn_output = post_process_layer(enc_input, attn_output, postprocess_cmd,
                                     prepostprocess_dropout)
    ffd_output = positionwise_feed_forward(
        pre_process_layer(attn_output, preprocess_cmd, prepostprocess_dropout),
        d_inner_hid, d_model, relu_dropout)
    return post_process_layer(attn_output, ffd_output, postprocess_cmd,
                              prepostprocess_dropout)


def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            preprocess_cmd="n",
            postprocess_cmd="da"):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer.
    """
    for i in range(n_layer):
        enc_output = encoder_layer(
            enc_input,
            attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            preprocess_cmd,
            postprocess_cmd, )
        enc_input = enc_output
    enc_output = pre_process_layer(enc_output, preprocess_cmd,
                                   prepostprocess_dropout)
    return enc_output


def decoder_layer(dec_input,
                  enc_output,
                  slf_attn_bias,
                  dec_enc_attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  prepostprocess_dropout,
                  attention_dropout,
                  relu_dropout,
                  preprocess_cmd,
                  postprocess_cmd,
                  cache=None,
                  gather_idx=None):
    """ The layer to be stacked in the decoder part.
    The structure of this module is similar to that in the encoder part except
    that an additional multi-head attention is added to implement
    encoder-decoder attention.
    """
    slf_attn_output = multi_head_attention(
        pre_process_layer(dec_input, preprocess_cmd, prepostprocess_dropout),
        None,
        None,
        slf_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        attention_dropout,
        cache=cache,
        gather_idx=gather_idx)
    slf_attn_output = post_process_layer(
        dec_input,
        slf_attn_output,
        postprocess_cmd,
        prepostprocess_dropout, )
    enc_attn_output = multi_head_attention(
        pre_process_layer(slf_attn_output, preprocess_cmd,
                          prepostprocess_dropout),
        enc_output,
        enc_output,
        dec_enc_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        attention_dropout,
        cache=cache,
        gather_idx=gather_idx,
        static_kv=True)
    enc_attn_output = post_process_layer(
        slf_attn_output,
        enc_attn_output,
        postprocess_cmd,
        prepostprocess_dropout, )
    ffd_output = positionwise_feed_forward(
        pre_process_layer(enc_attn_output, preprocess_cmd,
                          prepostprocess_dropout),
        d_inner_hid,
        d_model,
        relu_dropout, )
    dec_output = post_process_layer(
        enc_attn_output,
        ffd_output,
        postprocess_cmd,
        prepostprocess_dropout, )
    return dec_output


def decoder(dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            preprocess_cmd,
            postprocess_cmd,
            caches=None,
            gather_idx=None):
    """
    The decoder is composed of a stack of identical decoder_layer layers.
    """
    for i in range(n_layer):
        dec_output = decoder_layer(
            dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            preprocess_cmd,
            postprocess_cmd,
            cache=None if caches is None else caches[i],
            gather_idx=gather_idx)
        dec_input = dec_output
    dec_output = pre_process_layer(dec_output, preprocess_cmd,
                                   prepostprocess_dropout)
    return dec_output


def make_all_inputs(input_fields):
    """
    Define the input data layers for the transformer model.
    """
    inputs = []
    for input_field in input_fields:
        input_var = layers.data(
            name=input_field,
            shape=input_descs[input_field][0],
            dtype=input_descs[input_field][1],
            lod_level=input_descs[input_field][2]
            if len(input_descs[input_field]) == 3 else 0,
            append_batch_size=False)
        inputs.append(input_var)
    return inputs


def make_all_py_reader_inputs(input_fields, is_test=False):
    """
    Define the input data layers, fed by a py_reader, for the transformer model.
    """
    reader = layers.py_reader(
        capacity=20,
        name="test_reader" if is_test else "train_reader",
        shapes=[input_descs[input_field][0] for input_field in input_fields],
        dtypes=[input_descs[input_field][1] for input_field in input_fields],
        lod_levels=[
            input_descs[input_field][2]
            if len(input_descs[input_field]) == 3 else 0
            for input_field in input_fields
        ])
    return layers.read_file(reader), reader


def transformer(src_vocab_size,
                trg_vocab_size,
                max_length,
                n_layer,
                n_head,
                d_key,
                d_value,
                d_model,
                d_inner_hid,
                prepostprocess_dropout,
                attention_dropout,
                relu_dropout,
                preprocess_cmd,
                postprocess_cmd,
                weight_sharing,
                label_smooth_eps,
                use_py_reader=False,
                is_test=False):
    if weight_sharing:
        assert src_vocab_size == trg_vocab_size, (
            "Vocabularies in source and target should be the same for weight sharing."
        )

    data_input_names = encoder_data_input_fields + \
                decoder_data_input_fields[:-1] + label_data_input_fields

    if use_py_reader:
        all_inputs, reader = make_all_py_reader_inputs(data_input_names,
                                                       is_test)
    else:
        all_inputs = make_all_inputs(data_input_names)

    enc_inputs_len = len(encoder_data_input_fields)
    dec_inputs_len = len(decoder_data_input_fields[:-1])
    enc_inputs = all_inputs[0:enc_inputs_len]
    dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len]
    label = all_inputs[-2]
    weights = all_inputs[-1]

    enc_output = wrap_encoder(
        src_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        prepostprocess_dropout,
        attention_dropout,
        relu_dropout,
        preprocess_cmd,
        postprocess_cmd,
        weight_sharing,
        enc_inputs, )

    predict = wrap_decoder(
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        prepostprocess_dropout,
        attention_dropout,
        relu_dropout,
        preprocess_cmd,
        postprocess_cmd,
        weight_sharing,
        dec_inputs,
        enc_output, )

    # Padding indices do not contribute to the total loss; the weights are used
    # to cancel them out when calculating the loss.
    if label_smooth_eps:
        label = layers.label_smooth(
            label=layers.one_hot(
                input=label, depth=trg_vocab_size),
            epsilon=label_smooth_eps)
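        # Hedged note: with a uniform prior, label_smooth softens the one-hot
        # target to roughly (1 - epsilon) on the gold token with epsilon spread
        # uniformly over the vocabulary, hence soft_label=True in the loss
        # below.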

    cost = layers.softmax_with_cross_entropy(
        logits=predict,
        label=label,
        soft_label=True if label_smooth_eps else False)
    weighted_cost = cost * weights
    sum_cost = layers.reduce_sum(weighted_cost)
    token_num = layers.reduce_sum(weights)
    token_num.stop_gradient = True
    avg_cost = sum_cost / token_num
    return sum_cost, avg_cost, predict, token_num, reader if use_py_reader else None


def wrap_encoder(src_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 prepostprocess_dropout,
                 attention_dropout,
                 relu_dropout,
                 preprocess_cmd,
                 postprocess_cmd,
                 weight_sharing,
                 enc_inputs=None):
    """
    The wrapper assembles together all needed layers for the encoder.
    """
    if enc_inputs is None:
        # This is used to implement independent encoder program in inference.
        src_word, src_pos, src_slf_attn_bias = make_all_inputs(
            encoder_data_input_fields)
    else:
        src_word, src_pos, src_slf_attn_bias = enc_inputs
    enc_input = prepare_encoder(
        src_word,
        src_pos,
        src_vocab_size,
        d_model,
        max_length,
        prepostprocess_dropout,
        word_emb_param_name=word_emb_param_names[0])
    enc_output = encoder(
        enc_input,
        src_slf_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        prepostprocess_dropout,
        attention_dropout,
        relu_dropout,
        preprocess_cmd,
        postprocess_cmd, )
    return enc_output


def wrap_decoder(trg_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 prepostprocess_dropout,
                 attention_dropout,
                 relu_dropout,
                 preprocess_cmd,
                 postprocess_cmd,
                 weight_sharing,
                 dec_inputs=None,
                 enc_output=None,
                 caches=None,
                 gather_idx=None):
    """
    The wrapper assembles together all needed layers for the decoder.
    """
    if dec_inputs is None:
        # This is used to implement independent decoder program in inference.
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = \
            make_all_inputs(decoder_data_input_fields)
    else:
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs

    dec_input = prepare_decoder(
        trg_word,
        trg_pos,
        trg_vocab_size,
        d_model,
        max_length,
        prepostprocess_dropout,
        word_emb_param_name=word_emb_param_names[0]
        if weight_sharing else word_emb_param_names[1])
    dec_output = decoder(
        dec_input,
        enc_output,
        trg_slf_attn_bias,
        trg_src_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        prepostprocess_dropout,
        attention_dropout,
        relu_dropout,
        preprocess_cmd,
        postprocess_cmd,
        caches=caches,
        gather_idx=gather_idx)
    # Reshape to 2D tensor to use GEMM instead of BatchedGEMM
    dec_output = layers.reshape(
        dec_output, shape=[-1, dec_output.shape[-1]], inplace=True)
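    # When weight sharing is enabled, the output projection reuses the shared
    # word embedding matrix (transposed) instead of a separately trained FC
    # layer, so the softmax weights are tied to the embeddings.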
    if weight_sharing:
        predict = layers.matmul(
            x=dec_output,
            y=fluid.default_main_program().global_block().var(
                word_emb_param_names[0]),
            transpose_y=True)
    else:
        predict = layers.fc(input=dec_output,
                            size=trg_vocab_size,
                            bias_attr=False)
    if dec_inputs is None:
        # Return probs for independent decoder program.
        predict = layers.softmax(predict)
    return predict


def fast_decode(src_vocab_size,
                trg_vocab_size,
                max_in_len,
                n_layer,
                n_head,
                d_key,
                d_value,
                d_model,
                d_inner_hid,
                prepostprocess_dropout,
                attention_dropout,
                relu_dropout,
                preprocess_cmd,
                postprocess_cmd,
                weight_sharing,
                beam_size,
                max_out_len,
                eos_idx,
                use_py_reader=False):
    """
    Use beam search to decode. Caches will be used to store states of history
    steps which can make the decoding faster.
    """
    data_input_names = encoder_data_input_fields + fast_decoder_data_input_fields

    if use_py_reader:
        all_inputs, reader = make_all_py_reader_inputs(data_input_names)
    else:
        all_inputs = make_all_inputs(data_input_names)

    enc_inputs_len = len(encoder_data_input_fields)
    dec_inputs_len = len(fast_decoder_data_input_fields)
    enc_inputs = all_inputs[0:enc_inputs_len]
    dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len]

    enc_output = wrap_encoder(
        src_vocab_size,
        max_in_len,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        prepostprocess_dropout,
        attention_dropout,
        relu_dropout,
        preprocess_cmd,
        postprocess_cmd,
        weight_sharing,
        enc_inputs, )
    start_tokens, init_scores, parent_idx, trg_src_attn_bias = dec_inputs

    def beam_search():
        max_len = layers.fill_constant(
            shape=[1],
            dtype=start_tokens.dtype,
            value=max_out_len,
            force_cpu=True)
        step_idx = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=0, force_cpu=True)
        cond = layers.less_than(x=step_idx, y=max_len)  # default force_cpu=True
        while_op = layers.While(cond)
        # array states will be stored for each step.
        ids = layers.array_write(
            layers.reshape(start_tokens, (-1, 1)), step_idx)
        scores = layers.array_write(init_scores, step_idx)
        # cell states will be overwritten at each step.
        # caches contains states of history steps in decoder self-attention
        # and static encoder output projections in encoder-decoder attention
        # to reduce redundant computation.
        caches = [
            {
                "k":  # for self attention
                layers.fill_constant_batch_size_like(
                    input=start_tokens,
                    shape=[-1, n_head, 0, d_key],
                    dtype=enc_output.dtype,
                    value=0),
                "v":  # for self attention
                layers.fill_constant_batch_size_like(
                    input=start_tokens,
                    shape=[-1, n_head, 0, d_value],
                    dtype=enc_output.dtype,
                    value=0),
                "static_k":  # for encoder-decoder attention
                layers.create_tensor(dtype=enc_output.dtype),
                "static_v":  # for encoder-decoder attention
                layers.create_tensor(dtype=enc_output.dtype)
            } for i in range(n_layer)
        ]
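        # Note: the k/v caches start with a time dimension of size 0 and grow
        # by one step per iteration, because multi_head_attention concatenates
        # the newly computed k/v onto the gathered cache along axis 2.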

        with while_op.block():
            pre_ids = layers.array_read(array=ids, i=step_idx)
            # Since beam_search_op doesn't enforce pre_ids' shape, we can do
            # inplace reshape here which actually changes the shape of pre_ids.
            pre_ids = layers.reshape(pre_ids, (-1, 1, 1), inplace=True)
            pre_scores = layers.array_read(array=scores, i=step_idx)
            # gather cell states corresponding to selected parent
            pre_src_attn_bias = layers.gather(
                trg_src_attn_bias, index=parent_idx)
            pre_pos = layers.elementwise_mul(
                x=layers.fill_constant_batch_size_like(
                    input=pre_src_attn_bias,  # can't use lod tensor here
                    value=1,
                    shape=[-1, 1, 1],
                    dtype=pre_ids.dtype),
                y=step_idx,
                axis=0)
            logits = wrap_decoder(
                trg_vocab_size,
                max_in_len,
                n_layer,
                n_head,
                d_key,
                d_value,
                d_model,
                d_inner_hid,
                prepostprocess_dropout,
                attention_dropout,
                relu_dropout,
                preprocess_cmd,
                postprocess_cmd,
                weight_sharing,
                dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
                enc_output=enc_output,
                caches=caches,
                gather_idx=parent_idx)
            # intra-beam topK
            topk_scores, topk_indices = layers.topk(
                input=layers.softmax(logits), k=beam_size)
            accu_scores = layers.elementwise_add(
                x=layers.log(topk_scores), y=pre_scores, axis=0)
            # beam_search op uses lod to differentiate branches.
            topk_indices = layers.lod_reset(accu_scores, pre_ids)
            # topK reduction across beams, also contain special handle of
            # end beams and end sentences(batch reduction)
            selected_ids, selected_scores, gather_idx = layers.beam_search(
                pre_ids=pre_ids,
                pre_scores=pre_scores,
                ids=topk_indices,
                scores=accu_scores,
                beam_size=beam_size,
                end_id=eos_idx,
                return_parent_idx=True)
            layers.increment(x=step_idx, value=1.0, in_place=True)
            # cell states (caches) have been updated in wrap_decoder,
            # only need to update beam search states here.
            layers.array_write(selected_ids, i=step_idx, array=ids)
            layers.array_write(selected_scores, i=step_idx, array=scores)
            layers.assign(gather_idx, parent_idx)
            layers.assign(pre_src_attn_bias, trg_src_attn_bias)
            length_cond = layers.less_than(x=step_idx, y=max_len)
            finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
            layers.logical_and(x=length_cond, y=finish_cond, out=cond)

        finished_ids, finished_scores = layers.beam_search_decode(
            ids, scores, beam_size=beam_size, end_id=eos_idx)
        return finished_ids, finished_scores

    finished_ids, finished_scores = beam_search()
    return finished_ids, finished_scores, reader if use_py_reader else None