model.py 28.0 KB
Newer Older
1 2
from functools import partial
import numpy as np
Y
ying 已提交
3

L
Luo Tao 已提交
4 5
import paddle.fluid as fluid
import paddle.fluid.layers as layers
Y
ying 已提交
6

7
from config import *
Y
ying 已提交
8

9 10

def position_encoding_init(n_position, d_pos_vec):
    """
    Generate the initial values for the sinusoid position encoding table.

    Args:
        n_position: number of rows (positions) in the table.
        d_pos_vec: number of columns (encoding dimensions) per position.

    Returns:
        A float32 numpy array of shape [n_position, d_pos_vec]. Row 0 is all
        zeros; for every other row, even columns hold the sine and odd
        columns hold the cosine of pos / 10000^(2 * (j // 2) / d_pos_vec).
    """
    positions = np.arange(n_position, dtype="float64").reshape(-1, 1)
    dims = np.arange(d_pos_vec)
    # Force floating point division: with plain integers the exponent would
    # be floored to 0 on interpreters where `/` is integer division.
    exponents = 2.0 * (dims // 2) / float(d_pos_vec)
    # Broadcasting [n_position, 1] against [d_pos_vec] builds the full table;
    # row 0 is 0 / x == 0 and is deliberately excluded from sin/cos below.
    position_enc = positions / np.power(10000.0, exponents)
    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
    return position_enc.astype("float32")


def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         n_head=1,
                         dropout_rate=0.,
                         pre_softmax_shape=None,
                         post_softmax_shape=None,
                         cache=None):
    """
    Multi-Head Attention. Note that attn_bias is added to the logit before
    computing softmax activation to mask certain selected positions so that
    they will not be considered in attention weights.

    Args:
        queries: 3-D tensor of queries.
        keys: 3-D tensor of keys; when None, `queries` is used instead.
        values: 3-D tensor of values; when None, `keys` is used instead.
        attn_bias: tensor added to the scaled query/key product before the
            softmax, or None to skip the addition.
        d_key: per-head size of the query and key projections.
        d_value: per-head size of the value projection.
        d_model: size of the final output projection.
        n_head: number of heads. With n_head == 1 the reshape/transpose into
            separate heads is skipped.
        dropout_rate: dropout probability applied to the attention weights;
            a falsy value disables dropout.
        pre_softmax_shape: forwarded as `actual_shape` to the reshape that
            precedes the softmax.
        post_softmax_shape: forwarded as `actual_shape` to the reshape that
            follows the softmax.
        cache: optional dict with "k" and "v" entries. The newly projected
            keys/values are concatenated to them along axis 1 and the dict is
            updated in place.

    Returns:
        The attention output projected to `d_model` by a bias-free fc layer.

    Raises:
        ValueError: if queries, keys and values are not all 3-D.
    """
    keys = queries if keys is None else keys
    values = keys if values is None else values

    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: quries, keys and values should all be 3-D tensors.")

    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
                      size=d_key * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        k = layers.fc(input=keys,
                      size=d_key * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        v = layers.fc(input=values,
                      size=d_value * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        return q, k, v

    def __split_heads(x, n_head):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions and then transpose. Specifically, input a tensor with shape
        [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
        with shape [bs, n_head, max_sequence_length, hidden_dim].
        """
        if n_head == 1:
            return x

        hidden_size = x.shape[-1]
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        reshaped = layers.reshape(
            x=x, shape=[0, 0, n_head, hidden_size // n_head])

        # permute the dimensions into:
        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that it becomes one dimension, which is reverse to __split_heads.
        """
        if len(x.shape) == 3:
            # Single-head case: __split_heads left the tensor untouched.
            return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        # Build a real list: `map` is a lazy iterator on Python 3.
        return layers.reshape(
            x=trans_x,
            shape=[0, 0, int(trans_x.shape[2] * trans_x.shape[3])])

    def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
        """
        Scaled Dot-Product Attention
        """
        scaled_q = layers.scale(x=q, scale=d_key**-0.5)
        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
        # Compare against None explicitly instead of relying on the
        # truthiness of a tensor object.
        if attn_bias is not None:
            logits = layers.elementwise_add(x=product, y=attn_bias)
        else:
            logits = product
        weights = layers.reshape(
            x=logits,
            shape=[-1, product.shape[-1]],
            actual_shape=pre_softmax_shape,
            act="softmax")
        weights = layers.reshape(
            x=weights, shape=product.shape, actual_shape=post_softmax_shape)
        if dropout_rate:
            weights = layers.dropout(
                weights, dropout_prob=dropout_rate, is_test=False)
        out = layers.matmul(weights, v)
        return out

    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

    if cache is not None:  # use cache and concat time steps
        k = cache["k"] = layers.concat([cache["k"], k], axis=1)
        v = cache["v"] = layers.concat([cache["v"], v], axis=1)

    q = __split_heads(q, n_head)
    k = __split_heads(k, n_head)
    v = __split_heads(v, n_head)

    # Scale by the per-head key size (1 / sqrt(d_key)); the helper's parameter
    # is d_key, and passing d_model here scaled the logits by the wrong factor.
    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key,
                                                  dropout_rate)

    out = __combine_heads(ctx_multiheads)

    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         bias_attr=False,
                         num_flatten_dims=2)
    return proj_out


144
def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate):
    """
    Position-wise Feed-Forward Networks.

    Two fc layers (both with num_flatten_dims=2) are stacked: the first has
    size `d_inner_hid` and a ReLU activation, the second has size `d_hid` and
    no activation. When `dropout_rate` is truthy, dropout is applied to the
    output of the first layer.
    """
    inner = layers.fc(
        input=x, size=d_inner_hid, num_flatten_dims=2, act="relu")
    if dropout_rate:
        inner = layers.dropout(
            inner, dropout_prob=dropout_rate, is_test=False)
    return layers.fc(input=inner, size=d_hid, num_flatten_dims=2)


161
def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
    """
    Add residual connection, layer normalization and dropout to the out tensor
    optionally according to the value of process_cmd.
    This will be used before or after multi-head attention and position-wise
    feed-forward networks.

    Args:
        prev_out: tensor added to `out` by the "a" command, or None to make
            "a" a no-op.
        out: tensor to process.
        process_cmd: iterable of single-character commands applied in order:
            "a" adds the residual, "n" applies layer normalization over the
            last axis, "d" applies dropout. Other characters are ignored.
        dropout_rate: dropout probability for "d"; a falsy value disables it.

    Returns:
        The processed tensor.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            # Test against None explicitly: the truth value of a tensor-like
            # object is not a reliable "is it present" check.
            if prev_out is not None:
                out = out + prev_out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.initializer.Constant(1.),
                bias_attr=fluid.initializer.Constant(0.))
        elif cmd == "d":  # add dropout
            if dropout_rate:
                out = layers.dropout(
                    out, dropout_prob=dropout_rate, is_test=False)
    return out


# pre_process_layer fixes prev_out to None, so it is called as
# (out, process_cmd, dropout_rate=0.) and never has a residual to add.
pre_process_layer = partial(pre_post_process_layer, None)
# post_process_layer is the plain function and takes prev_out explicitly:
# (prev_out, out, process_cmd, dropout_rate=0.).
post_process_layer = pre_post_process_layer


def prepare_encoder(src_word,
                    src_pos,
                    src_vocab_size,
                    src_emb_dim,
                    src_max_len,
                    dropout_rate=0.,
                    src_data_shape=None,
                    word_emb_param_name=None,
                    pos_enc_param_name=None):
    """Add word embeddings and position encodings.

    The word embedding table has size [src_vocab_size, src_emb_dim], is
    initialized from Normal(0, src_emb_dim**-0.5) and its lookups are scaled
    by src_emb_dim**0.5. The position table has size
    [src_max_len, src_emb_dim] and is not trainable. The sum of both is
    reshaped (with `src_data_shape` as the actual shape) and, when
    `dropout_rate` is truthy, passed through dropout.

    This module is used at the bottom of the encoder stacks.
    """
    word_attr = fluid.ParamAttr(
        name=word_emb_param_name,
        initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5))
    word_emb = layers.embedding(
        src_word, size=[src_vocab_size, src_emb_dim], param_attr=word_attr)
    word_emb = layers.scale(x=word_emb, scale=src_emb_dim**0.5)

    pos_attr = fluid.ParamAttr(name=pos_enc_param_name, trainable=False)
    pos_enc = layers.embedding(
        src_pos, size=[src_max_len, src_emb_dim], param_attr=pos_attr)

    # NOTE(review): batch_size and seq_len are not parameters of this
    # function; presumably they come from `from config import *` - confirm.
    combined = layers.reshape(
        x=word_emb + pos_enc,
        shape=[batch_size, seq_len, src_emb_dim],
        actual_shape=src_data_shape)

    if not dropout_rate:
        return combined
    return layers.dropout(
        combined, dropout_prob=dropout_rate, is_test=False)
222 223 224 225 226 227 228 229


# Rebind prepare_encoder to a partial that pre-fills pos_enc_param_name with
# the first entry of pos_enc_param_names (presumably provided by
# `from config import *` - confirm).
prepare_encoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[0])
# prepare_decoder wraps the partial created just above; the keyword given
# here overrides the pre-filled pos_enc_param_name with the second entry.
prepare_decoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[1])


Y
ying 已提交
230 231 232 233 234 235 236
def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  prepostprocess_dropout,
                  attention_dropout,
                  relu_dropout,
                  preprocess_cmd,
                  postprocess_cmd,
                  pre_softmax_shape=None,
                  post_softmax_shape=None):
    """One encoder layer, stackable to form a deep encoder.

    The layer runs multi-head self-attention followed by the position-wise
    feed-forward network. Each sub-layer's input goes through
    pre_process_layer (driven by `preprocess_cmd`) and its output through
    post_process_layer (driven by `postprocess_cmd`) together with the
    sub-layer's un-preprocessed input.
    """
    # Self-attention sub-layer: keys and values are None, so the attention
    # falls back to the queries.
    attn_in = pre_process_layer(enc_input, preprocess_cmd,
                                prepostprocess_dropout)
    attn_out = multi_head_attention(
        attn_in, None, None, attn_bias, d_key, d_value, d_model, n_head,
        attention_dropout, pre_softmax_shape, post_softmax_shape)
    attn_out = post_process_layer(enc_input, attn_out, postprocess_cmd,
                                  prepostprocess_dropout)

    # Feed-forward sub-layer.
    ffn_in = pre_process_layer(attn_out, preprocess_cmd,
                               prepostprocess_dropout)
    ffn_out = positionwise_feed_forward(ffn_in, d_inner_hid, d_model,
                                        relu_dropout)
    return post_process_layer(attn_out, ffn_out, postprocess_cmd,
                              prepostprocess_dropout)
Y
ying 已提交
261 262 263 264 265 266 267 268 269 270


def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            preprocess_cmd,
            postprocess_cmd,
            pre_softmax_shape=None,
            post_softmax_shape=None):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer.

    `n_layer` layers are chained, each consuming the previous layer's output,
    and the result of the last one is passed through pre_process_layer once
    more. With n_layer == 0 the input itself is passed through
    pre_process_layer instead of failing on an unbound variable.
    """
    # Start from the input so the final step is well defined even when the
    # loop body never runs.
    enc_output = enc_input
    for _ in range(n_layer):
        enc_output = encoder_layer(
            enc_output,
            attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            preprocess_cmd,
            postprocess_cmd,
            pre_softmax_shape,
            post_softmax_shape, )
    enc_output = pre_process_layer(enc_output, preprocess_cmd,
                                   prepostprocess_dropout)
    return enc_output


Y
ying 已提交
304 305 306 307 308 309 310 311 312
def decoder_layer(dec_input,
                  enc_output,
                  slf_attn_bias,
                  dec_enc_attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  prepostprocess_dropout,
                  attention_dropout,
                  relu_dropout,
                  preprocess_cmd,
                  postprocess_cmd,
                  slf_attn_pre_softmax_shape=None,
                  slf_attn_post_softmax_shape=None,
                  src_attn_pre_softmax_shape=None,
                  src_attn_post_softmax_shape=None,
                  cache=None):
    """One decoder layer, stackable to form the decoder.

    Three sub-layers run in sequence: self-attention (optionally using
    `cache`), attention over `enc_output`, and the position-wise feed-forward
    network. Each sub-layer's input goes through pre_process_layer and its
    output through post_process_layer together with the sub-layer's
    un-preprocessed input.
    """
    # Self-attention: keys/values are None, so the queries are reused.
    slf_in = pre_process_layer(dec_input, preprocess_cmd,
                               prepostprocess_dropout)
    slf_out = multi_head_attention(
        slf_in, None, None, slf_attn_bias, d_key, d_value, d_model, n_head,
        attention_dropout, slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape, cache)
    slf_out = post_process_layer(dec_input, slf_out, postprocess_cmd,
                                 prepostprocess_dropout)

    # Encoder-decoder attention: the encoder output supplies keys and values.
    src_in = pre_process_layer(slf_out, preprocess_cmd,
                               prepostprocess_dropout)
    src_out = multi_head_attention(
        src_in, enc_output, enc_output, dec_enc_attn_bias, d_key, d_value,
        d_model, n_head, attention_dropout, src_attn_pre_softmax_shape,
        src_attn_post_softmax_shape)
    src_out = post_process_layer(slf_out, src_out, postprocess_cmd,
                                 prepostprocess_dropout)

    # Position-wise feed-forward network.
    ffn_in = pre_process_layer(src_out, preprocess_cmd,
                               prepostprocess_dropout)
    ffn_out = positionwise_feed_forward(ffn_in, d_inner_hid, d_model,
                                        relu_dropout)
    return post_process_layer(src_out, ffn_out, postprocess_cmd,
                              prepostprocess_dropout)


Y
ying 已提交
377 378 379 380 381 382 383 384 385 386
def decoder(dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            preprocess_cmd,
            postprocess_cmd,
            slf_attn_pre_softmax_shape=None,
            slf_attn_post_softmax_shape=None,
            src_attn_pre_softmax_shape=None,
            src_attn_post_softmax_shape=None,
            caches=None):
    """
    The decoder is composed of a stack of identical decoder_layer layers.

    `n_layer` layers are chained, each consuming the previous layer's output;
    layer i receives caches[i] when `caches` is given and None otherwise. The
    output of the last layer is passed through pre_process_layer once more.
    With n_layer == 0 the input itself is passed through pre_process_layer
    instead of failing on an unbound variable.
    """
    # Start from the input so the final step is well defined even when the
    # loop body never runs.
    dec_output = dec_input
    for i in range(n_layer):
        dec_output = decoder_layer(
            dec_output,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            preprocess_cmd,
            postprocess_cmd,
            slf_attn_pre_softmax_shape,
            slf_attn_post_softmax_shape,
            src_attn_pre_softmax_shape,
            src_attn_post_softmax_shape,
            None if caches is None else caches[i], )
    dec_output = pre_process_layer(dec_output, preprocess_cmd,
                                   prepostprocess_dropout)
    return dec_output


427
def make_all_inputs(input_fields):
    """
    Define the input data layers for the transformer model.

    For every name in `input_fields`, the matching entry of `input_descs` is
    read as (shape, dtype[, lod_level]); lod_level defaults to 0 unless the
    entry has exactly three items. Returns the created data layers as a list,
    in the order of `input_fields`.
    """

    def _declare(field):
        desc = input_descs[field]
        lod_level = desc[2] if len(desc) == 3 else 0
        return layers.data(
            name=field,
            shape=desc[0],
            dtype=desc[1],
            lod_level=lod_level,
            append_batch_size=False)

    return [_declare(field) for field in input_fields]
442 443


Y
ying 已提交
444 445 446 447 448 449 450 451 452 453
def transformer(
        src_vocab_size,
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        prepostprocess_dropout,
        attention_dropout,
        relu_dropout,
        preprocess_cmd,
        postprocess_cmd,
        weight_sharing,
        label_smooth_eps, ):
    """
    Build the training network: encoder, decoder and the weighted loss.

    Input data layers are declared via make_all_inputs, the encoder and
    decoder are assembled with wrap_encoder / wrap_decoder, and a softmax
    cross entropy is computed against the label input. When
    `label_smooth_eps` is truthy the labels are one-hot encoded and smoothed,
    and the loss uses soft labels.

    Returns:
        A tuple (sum_cost, avg_cost, predict, token_num), where sum_cost is
        the sum of the per-token cost multiplied by the weights input,
        token_num is the sum of the weights, and avg_cost is their ratio.

    Raises:
        AssertionError: if `weight_sharing` is set while the source and
            target vocabulary sizes differ.
    """
    if weight_sharing:
        # The check must compare source against target; comparing
        # src_vocab_size with itself could never fail.
        assert src_vocab_size == trg_vocab_size, (
            "Vocabularies in source and target should be same for weight sharing."
        )
    enc_inputs = make_all_inputs(encoder_data_input_fields +
                                 encoder_util_input_fields)

    enc_output = wrap_encoder(
        src_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        prepostprocess_dropout,
        attention_dropout,
        relu_dropout,
        preprocess_cmd,
        postprocess_cmd,
        weight_sharing,
        enc_inputs, )

    # The last decoder data field is dropped here because enc_output is
    # passed to wrap_decoder directly rather than declared as an input.
    dec_inputs = make_all_inputs(decoder_data_input_fields[:-1] +
                                 decoder_util_input_fields)

    predict = wrap_decoder(
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        prepostprocess_dropout,
        attention_dropout,
        relu_dropout,
        preprocess_cmd,
        postprocess_cmd,
        weight_sharing,
        dec_inputs,
        enc_output, )

    # Padding index do not contribute to the total loss. The weights is used to
    # cancel padding index in calculating the loss.
    label, weights = make_all_inputs(label_data_input_fields)
    if label_smooth_eps:
        label = layers.label_smooth(
            label=layers.one_hot(
                input=label, depth=trg_vocab_size),
            epsilon=label_smooth_eps)
    cost = layers.softmax_with_cross_entropy(
        logits=predict,
        label=label,
        soft_label=True if label_smooth_eps else False)
    weighted_cost = cost * weights
    sum_cost = layers.reduce_sum(weighted_cost)
    token_num = layers.reduce_sum(weights)
    avg_cost = sum_cost / token_num
    return sum_cost, avg_cost, predict, token_num
523 524 525 526 527 528 529 530 531 532


def wrap_encoder(src_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 prepostprocess_dropout,
                 attention_dropout,
                 relu_dropout,
                 preprocess_cmd,
                 postprocess_cmd,
                 weight_sharing,
                 enc_inputs=None):
    """
    The wrapper assembles together all needed layers for the encoder.

    `enc_inputs` must unpack into six items: source words, source positions,
    self-attention bias, data shape, and the pre/post softmax shapes. When it
    is None the inputs are declared here via make_all_inputs. Returns the
    output of the encoder stack. `weight_sharing` is accepted but not used in
    this function.
    """
    if enc_inputs is None:
        # This is used to implement independent encoder program in inference.
        enc_inputs = make_all_inputs(encoder_data_input_fields +
                                     encoder_util_input_fields)
    (src_word, src_pos, src_slf_attn_bias, src_data_shape,
     slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape) = enc_inputs

    embedded = prepare_encoder(
        src_word,
        src_pos,
        src_vocab_size,
        d_model,
        max_length,
        prepostprocess_dropout,
        src_data_shape,
        word_emb_param_name=word_emb_param_names[0])
    return encoder(
        embedded,
        src_slf_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        prepostprocess_dropout,
        attention_dropout,
        relu_dropout,
        preprocess_cmd,
        postprocess_cmd,
        slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape, )


def wrap_decoder(trg_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 prepostprocess_dropout,
                 attention_dropout,
                 relu_dropout,
                 preprocess_cmd,
                 postprocess_cmd,
                 weight_sharing,
                 dec_inputs=None,
                 enc_output=None,
                 caches=None):
    """
    The wrapper assembles together all needed layers for the decoder.

    When `dec_inputs` is None, all inputs (including enc_output) are declared
    here via make_all_inputs and a softmax is applied to the final
    projection. Otherwise `dec_inputs` must unpack into nine items, the
    `enc_output` argument is used, and no softmax is applied.

    With `weight_sharing`, the output projection multiplies by the transposed
    word embedding parameter named word_emb_param_names[0]; otherwise a
    separate bias-free fc layer is used. The result is reshaped to
    [-1, trg_vocab_size].
    """
    # Remember how we were called: it also decides the final activation.
    standalone = dec_inputs is None
    if standalone:
        # This is used to implement independent decoder program in inference.
        (trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output,
         trg_data_shape, slf_attn_pre_softmax_shape,
         slf_attn_post_softmax_shape, src_attn_pre_softmax_shape,
         src_attn_post_softmax_shape) = make_all_inputs(
             decoder_data_input_fields + decoder_util_input_fields)
    else:
        (trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
         trg_data_shape, slf_attn_pre_softmax_shape,
         slf_attn_post_softmax_shape, src_attn_pre_softmax_shape,
         src_attn_post_softmax_shape) = dec_inputs

    if weight_sharing:
        trg_emb_name = word_emb_param_names[0]
    else:
        trg_emb_name = word_emb_param_names[1]

    dec_input = prepare_decoder(
        trg_word,
        trg_pos,
        trg_vocab_size,
        d_model,
        max_length,
        prepostprocess_dropout,
        trg_data_shape,
        word_emb_param_name=trg_emb_name)
    dec_output = decoder(
        dec_input,
        enc_output,
        trg_slf_attn_bias,
        trg_src_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        prepostprocess_dropout,
        attention_dropout,
        relu_dropout,
        preprocess_cmd,
        postprocess_cmd,
        slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape,
        src_attn_pre_softmax_shape,
        src_attn_post_softmax_shape,
        caches, )

    # Return logits for training and probs for inference.
    if weight_sharing:
        projected = layers.matmul(
            x=dec_output,
            y=fluid.get_var(word_emb_param_names[0]),
            transpose_y=True)
    else:
        projected = layers.fc(input=dec_output,
                              size=trg_vocab_size,
                              bias_attr=False,
                              num_flatten_dims=2)
    return layers.reshape(
        x=projected,
        shape=[-1, trg_vocab_size],
        act="softmax" if standalone else None)
663 664 665 666 667 668 669 670 671 672 673 674


def fast_decode(
        src_vocab_size,
        trg_vocab_size,
        max_in_len,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        prepostprocess_dropout,
        attention_dropout,
        relu_dropout,
        preprocess_cmd,
        postprocess_cmd,
        weight_sharing,
        beam_size,
        max_out_len,
        eos_idx, ):
    """
    Use beam search to decode. Caches will be used to store states of history
    steps which can make the decoding faster.

    The encoder is built with wrap_encoder, the decoding inputs are declared
    via make_all_inputs, and a While loop runs the decoder one step at a time
    until `max_out_len` steps are reached or beam_search selects no more ids.

    Returns:
        A tuple (finished_ids, finished_scores) produced by
        layers.beam_search_decode.
    """
    enc_output = wrap_encoder(
        src_vocab_size, max_in_len, n_layer, n_head, d_key, d_value, d_model,
        d_inner_hid, prepostprocess_dropout, attention_dropout, relu_dropout,
        preprocess_cmd, postprocess_cmd, weight_sharing)
    start_tokens, init_scores, trg_src_attn_bias, trg_data_shape, \
        slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
        src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
        attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta = \
        make_all_inputs(fast_decoder_data_input_fields +
                        fast_decoder_util_input_fields)

    def beam_search():
        max_len = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=max_out_len)
        step_idx = layers.fill_constant(
            shape=[1], dtype=start_tokens.dtype, value=0)
        cond = layers.less_than(x=step_idx, y=max_len)
        while_op = layers.While(cond)
        # array states will be stored for each step.
        ids = layers.array_write(start_tokens, step_idx)
        scores = layers.array_write(init_scores, step_idx)
        # cell states will be overwritten at each step.
        # caches contains states of history steps to reduce redundant
        # computation in decoder.
        # The caches are concatenated with the projected keys/values inside
        # multi_head_attention, whose last dimension is d_key * n_head and
        # d_value * n_head respectively, so the empty initial tensors must use
        # those widths rather than d_model.
        caches = [{
            "k": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_key * n_head],
                dtype=enc_output.dtype,
                value=0),
            "v": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_value * n_head],
                dtype=enc_output.dtype,
                value=0)
        } for _ in range(n_layer)]
        with while_op.block():
            pre_ids = layers.array_read(array=ids, i=step_idx)
            pre_scores = layers.array_read(array=scores, i=step_idx)
            # sequence_expand can gather sequences according to lod thus can be
            # used in beam search to sift states corresponding to selected ids.
            pre_src_attn_bias = layers.sequence_expand(
                x=trg_src_attn_bias, y=pre_scores)
            pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
            pre_caches = [{
                "k": layers.sequence_expand(
                    x=cache["k"], y=pre_scores),
                "v": layers.sequence_expand(
                    x=cache["v"], y=pre_scores),
            } for cache in caches]
            pre_pos = layers.elementwise_mul(
                x=layers.fill_constant_batch_size_like(
                    input=pre_enc_output,  # can't use pre_ids here since it has lod
                    value=1,
                    shape=[-1, 1],
                    dtype=pre_ids.dtype),
                y=layers.increment(
                    x=step_idx, value=1.0, in_place=False),
                axis=0)
            # dec_inputs is given, so wrap_decoder returns un-normalized
            # logits; the self-attention bias slot is None.
            logits = wrap_decoder(
                trg_vocab_size,
                max_in_len,
                n_layer,
                n_head,
                d_key,
                d_value,
                d_model,
                d_inner_hid,
                prepostprocess_dropout,
                attention_dropout,
                relu_dropout,
                preprocess_cmd,
                postprocess_cmd,
                weight_sharing,
                dec_inputs=(
                    pre_ids, pre_pos, None, pre_src_attn_bias, trg_data_shape,
                    slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
                    src_attn_pre_softmax_shape, src_attn_post_softmax_shape),
                enc_output=pre_enc_output,
                caches=pre_caches)
            topk_scores, topk_indices = layers.topk(
                input=layers.softmax(logits), k=beam_size)
            # Accumulate log-probabilities on top of the previous scores.
            accu_scores = layers.elementwise_add(
                x=layers.log(topk_scores),
                y=layers.reshape(
                    pre_scores, shape=[-1]),
                axis=0)
            # beam_search op uses lod to distinguish branches.
            topk_indices = layers.lod_reset(topk_indices, pre_ids)
            selected_ids, selected_scores = layers.beam_search(
                pre_ids=pre_ids,
                pre_scores=pre_scores,
                ids=topk_indices,
                scores=accu_scores,
                beam_size=beam_size,
                end_id=eos_idx)
            layers.increment(x=step_idx, value=1.0, in_place=True)
            # update states
            layers.array_write(selected_ids, i=step_idx, array=ids)
            layers.array_write(selected_scores, i=step_idx, array=scores)
            layers.assign(pre_src_attn_bias, trg_src_attn_bias)
            layers.assign(pre_enc_output, enc_output)
            for i in range(n_layer):
                layers.assign(pre_caches[i]["k"], caches[i]["k"])
                layers.assign(pre_caches[i]["v"], caches[i]["v"])
            # Advance the self-attention softmax shapes by their per-step
            # deltas for the next iteration.
            layers.assign(
                layers.elementwise_add(
                    x=slf_attn_pre_softmax_shape,
                    y=attn_pre_softmax_shape_delta),
                slf_attn_pre_softmax_shape)
            layers.assign(
                layers.elementwise_add(
                    x=slf_attn_post_softmax_shape,
                    y=attn_post_softmax_shape_delta),
                slf_attn_post_softmax_shape)

            # Continue while under the length limit and beam_search still
            # selected some ids; the result is written into `cond`.
            length_cond = layers.less_than(x=step_idx, y=max_len)
            finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
            layers.logical_and(x=length_cond, y=finish_cond, out=cond)

        finished_ids, finished_scores = layers.beam_search_decode(
            ids, scores, beam_size=beam_size, end_id=eos_idx)
        return finished_ids, finished_scores

    finished_ids, finished_scores = beam_search()
    return finished_ids, finished_scores