model.py 20.2 KB
Newer Older
1 2
from functools import partial
import numpy as np
Y
ying 已提交
3

L
Luo Tao 已提交
4 5
import paddle.fluid as fluid
import paddle.fluid.layers as layers
Y
ying 已提交
6

7 8
from config import TrainTaskConfig, pos_enc_param_names, \
    encoder_input_data_names, decoder_input_data_names, label_data_names
Y
ying 已提交
9

10 11

def position_encoding_init(n_position, d_pos_vec):
    """
    Generate the initial values for the sinusoid position encoding table.

    For every row ``pos >= 1`` and column ``j``, the angle is
    ``pos / 10000 ** (2 * (j // 2) / d_pos_vec)``. Even columns hold its sine
    and odd columns hold its cosine. Row 0 is left as all zeros.

    Args:
        n_position: Number of rows (positions) in the table.
        d_pos_vec: Number of columns (encoding size) per position.

    Returns:
        A float32 numpy array of shape [n_position, d_pos_vec].
    """
    positions = np.arange(n_position, dtype="float64").reshape(-1, 1)
    # Compute the exponent in floating point. Under Python 2, the plain
    # int / int form collapses to integer division, which makes every
    # exponent 0.
    exponents = (2 * (np.arange(d_pos_vec) // 2)).astype("float64") / d_pos_vec
    position_enc = positions / np.power(10000.0, exponents)
    # Row 0 is 0 / x == 0. sin/cos is only applied to rows 1:, so row 0 stays
    # all zeros.
    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
    return position_enc.astype("float32")


def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         n_head=1,
                         dropout_rate=0.):
    """
    Multi-Head Attention. Note that attn_bias is added to the logit before
    computing softmax activation to mask certain selected positions so that
    they will not be considered in attention weights.

    Args:
        queries, keys, values: 3-D tensors (checked below). Presumably
            [batch_size, seq_len, d_model]; TODO confirm against callers.
        attn_bias: Tensor added to the attention logits, or None for no bias.
        d_key: Per-head size of the query/key projections.
        d_value: Per-head size of the value projection.
        d_model: Output size of the final projection. It is also used to
            scale the dot products (see NOTE in scaled_dot_product_attention).
        n_head: Number of attention heads.
        dropout_rate: Dropout probability applied to the attention weights.
            0 disables dropout.

    Returns:
        The output of the final fc projection, whose last dimension is d_model.

    Raises:
        ValueError: If queries, keys or values is not a 3-D tensor.
    """
    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: queries, keys and values should all be 3-D tensors.")

    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
                      size=d_key * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_key,
                          fan_out=n_head * d_key),
                      bias_attr=False,
                      num_flatten_dims=2)
        k = layers.fc(input=keys,
                      size=d_key * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_key,
                          fan_out=n_head * d_key),
                      bias_attr=False,
                      num_flatten_dims=2)
        v = layers.fc(input=values,
                      size=d_value * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_value,
                          fan_out=n_head * d_value),
                      bias_attr=False,
                      num_flatten_dims=2)
        return q, k, v

    def __split_heads(x, n_head):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions and then transpose. Specifically, input a tensor with shape
        [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
        with shape [bs, n_head, max_sequence_length, hidden_dim].
        """
        if n_head == 1:
            return x

        hidden_size = x.shape[-1]
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        reshaped = layers.reshape(
            x=x, shape=[0, -1, n_head, hidden_size // n_head])

        # Permute the dimensions into:
        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that it becomes one dimension, which is reverse to __split_heads.
        """
        if len(x.shape) == 3:
            return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size. Build a concrete
        # list here: map() returns a lazy iterator on Python 3.
        return layers.reshape(
            x=trans_x,
            shape=[0, -1, int(trans_x.shape[2] * trans_x.shape[3])])

    def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
        """
        Scaled Dot-Product Attention
        """

        # FIXME(guosheng): Remove __softmax when softmax_op supporting high
        # rank tensors. softmax_op only supports 2D tensor currently.
        # Otherwise, add extra input data to reshape.
        def __softmax(x):
            # Subtract the max over the last axis before exp(). Softmax is
            # invariant to this shift, and the shift keeps exp() from
            # overflowing on large logits.
            max_out = layers.reduce_max(x, dim=-1, keep_dim=False)
            shifted = layers.elementwise_sub(x=x, y=max_out, axis=0)
            exp_out = layers.exp(x=shifted)
            sum_out = layers.reduce_sum(exp_out, dim=-1, keep_dim=False)
            return layers.elementwise_div(x=exp_out, y=sum_out, axis=0)

        # NOTE(review): the dot products are scaled by d_model**-0.5, while the
        # usual scaled dot-product attention uses d_key**-0.5. Kept unchanged
        # to preserve existing model behaviour; confirm intent.
        scaled_q = layers.scale(x=q, scale=d_model**-0.5)
        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
        # Compare against None explicitly rather than relying on the
        # truthiness of a tensor object.
        if attn_bias is not None:
            product = layers.elementwise_add(x=product, y=attn_bias)
        weights = __softmax(product)
        if dropout_rate:
            weights = layers.dropout(
                weights, dropout_prob=dropout_rate, is_test=False)
        return layers.matmul(weights, v)

    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

    q = __split_heads(q, n_head)
    k = __split_heads(k, n_head)
    v = __split_heads(v, n_head)

    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model,
                                                  dropout_rate)

    out = __combine_heads(ctx_multiheads)

    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         param_attr=fluid.initializer.Xavier(uniform=False),
                         bias_attr=False,
                         num_flatten_dims=2)
    return proj_out


def positionwise_feed_forward(x, d_inner_hid, d_hid):
    """
    Position-wise Feed-Forward Networks.

    Two fc layers over the last dimension of ``x`` (num_flatten_dims=2):
    the first projects to ``d_inner_hid`` with a ReLU activation, and the
    second projects back to ``d_hid`` with no activation. The same pair of
    weights is applied to every position. The first layer's weights are drawn
    uniformly within +/- d_hid**-0.5. The second layer's are drawn within
    +/- d_inner_hid**-0.5.
    """
    inner_limit = d_hid**-0.5
    hidden = layers.fc(
        input=x,
        size=d_inner_hid,
        num_flatten_dims=2,
        param_attr=fluid.initializer.Uniform(
            low=-inner_limit, high=inner_limit),
        act="relu")

    outer_limit = d_inner_hid**-0.5
    return layers.fc(
        input=hidden,
        size=d_hid,
        num_flatten_dims=2,
        param_attr=fluid.initializer.Uniform(
            low=-outer_limit, high=outer_limit))


def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
    """
    Add residual connection, layer normalization and dropout to the out tensor
    optionally according to the value of process_cmd.

    This will be used before or after multi-head attention and position-wise
    feed-forward networks.

    Args:
        prev_out: Tensor added to ``out`` by the "a" command. Pass None to
            skip the residual connection.
        out: Tensor to process.
        process_cmd: String of one-character commands applied left to right.
            "a" adds the residual connection. "n" applies layer normalization
            over the last dimension. "d" applies dropout. Any other character
            is ignored.
        dropout_rate: Dropout probability used by "d". 0 disables dropout.

    Returns:
        The processed tensor.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            # Compare against None explicitly. The truthiness of a tensor is
            # ambiguous, and for multi-element arrays it raises.
            if prev_out is not None:
                out = out + prev_out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.initializer.Constant(1.),
                bias_attr=fluid.initializer.Constant(0.))
        elif cmd == "d":  # add dropout
            if dropout_rate:
                out = layers.dropout(
                    out, dropout_prob=dropout_rate, is_test=False)
    return out


# Pre-processing variant: prev_out (the first positional parameter) is bound
# to None, so callers pass (out, process_cmd, dropout_rate) and the "a"
# (residual) command has nothing to add.
pre_process_layer = partial(pre_post_process_layer, None)
# Post-processing variant: a plain alias taking the full signature
# (prev_out, out, process_cmd, dropout_rate).
post_process_layer = pre_post_process_layer


def prepare_encoder(src_word,
                    src_pos,
                    src_vocab_size,
                    src_emb_dim,
                    src_pad_idx,
                    src_max_len,
                    dropout_rate=0.,
                    pos_pad_idx=0,
                    src_data_shape=None,
                    pos_enc_param_name=None):
    """Add word embeddings and position encodings.
    The output tensor has a shape of:
    [batch_size, max_src_length_in_batch, d_model].

    This module is used at the bottom of the encoder stacks.

    Word ids are looked up in a table initialized from Normal(0, 1). Position
    ids are looked up in a non-trainable table whose parameter is named
    ``pos_enc_param_name``. The two embeddings are summed, reshaped to
    [-1, src_max_len, src_emb_dim] (``src_data_shape`` is forwarded as the
    reshape's ``actual_shape``), and passed through dropout when
    ``dropout_rate`` is non-zero.
    """
    word_emb = layers.embedding(
        src_word,
        size=[src_vocab_size, src_emb_dim],
        padding_idx=src_pad_idx,
        param_attr=fluid.initializer.Normal(0., 1.))

    frozen_pos_table = fluid.ParamAttr(name=pos_enc_param_name,
                                       trainable=False)
    pos_emb = layers.embedding(
        src_pos,
        size=[src_max_len, src_emb_dim],
        padding_idx=pos_pad_idx,
        param_attr=frozen_pos_table)

    summed = layers.reshape(
        x=word_emb + pos_emb,
        shape=[-1, src_max_len, src_emb_dim],
        actual_shape=src_data_shape)

    if not dropout_rate:
        return summed
    return layers.dropout(summed, dropout_prob=dropout_rate, is_test=False)


# Bind the position-encoding parameter name for each side of the model. Both
# partials wrap the same underlying function, which is captured once before
# the name ``prepare_encoder`` is rebound.
_prepare_embedding_input = prepare_encoder
prepare_encoder = partial(
    _prepare_embedding_input, pos_enc_param_name=pos_enc_param_names[0])
prepare_decoder = partial(
    _prepare_embedding_input, pos_enc_param_name=pos_enc_param_names[1])


def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.):
    """The encoder layers that can be stacked to form a deep encoder.

    This module consists of a multi-head (self) attention followed by
    position-wise feed-forward networks. Each of the two components is
    followed by post_process_layer with command "dan" (dropout, residual
    connection, layer normalization).
    """
    self_attn = multi_head_attention(
        queries=enc_input,
        keys=enc_input,
        values=enc_input,
        attn_bias=attn_bias,
        d_key=d_key,
        d_value=d_value,
        d_model=d_model,
        n_head=n_head,
        dropout_rate=dropout_rate)
    self_attn = post_process_layer(
        prev_out=enc_input,
        out=self_attn,
        process_cmd="dan",
        dropout_rate=dropout_rate)

    ffn = positionwise_feed_forward(
        x=self_attn, d_inner_hid=d_inner_hid, d_hid=d_model)
    return post_process_layer(
        prev_out=self_attn,
        out=ffn,
        process_cmd="dan",
        dropout_rate=dropout_rate)


def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer.

    Each layer's output becomes the next layer's input.

    Returns:
        The output of the last encoder_layer. When ``n_layer`` is 0, this is
        ``enc_input`` unchanged.
    """
    # Rebind enc_input each iteration so that n_layer == 0 returns the input
    # instead of referencing an unassigned local.
    for _ in range(n_layer):
        enc_input = encoder_layer(
            enc_input,
            attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate, )
    return enc_input


def decoder_layer(dec_input,
                  enc_output,
                  slf_attn_bias,
                  dec_enc_attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.):
    """ The layer to be stacked in decoder part.

    The structure of this module is similar to that in the encoder part except
    a multi-head attention is added to implement encoder-decoder attention.

    It has three sub-layers. Each is followed by post_process_layer with
    "dan" (dropout, residual connection, layer normalization):
      1. self-attention over dec_input, biased by slf_attn_bias;
      2. attention whose queries come from step 1 and whose keys/values are
         enc_output, biased by dec_enc_attn_bias;
      3. the position-wise feed-forward network.
    """

    def _attend(queries, memory, bias):
        # Multi-head attention using ``memory`` as both keys and values.
        return multi_head_attention(queries, memory, memory, bias, d_key,
                                    d_value, d_model, n_head, dropout_rate)

    def _residual(prev, cur):
        # Dropout, residual connection and layer normalization.
        return post_process_layer(prev, cur, "dan", dropout_rate)

    self_attended = _residual(
        dec_input, _attend(dec_input, dec_input, slf_attn_bias))
    cross_attended = _residual(
        self_attended, _attend(self_attended, enc_output, dec_enc_attn_bias))
    ffn_out = positionwise_feed_forward(cross_attended, d_inner_hid, d_model)
    return _residual(cross_attended, ffn_out)


def decoder(dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.):
    """
    The decoder is composed of a stack of identical decoder_layer layers.

    Each layer's output becomes the next layer's input. All layers attend to
    the same ``enc_output``.

    Returns:
        The output of the last decoder_layer. When ``n_layer`` is 0, this is
        ``dec_input`` unchanged.
    """
    # Rebind dec_input each iteration so that n_layer == 0 returns the input
    # instead of referencing an unassigned local.
    for _ in range(n_layer):
        dec_input = decoder_layer(
            dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate, )
    return dec_input


def make_inputs(input_data_names,
                n_head,
                d_model,
                max_length,
                is_pos=True,
                slf_attn_bias_flag=True,
                src_attn_bias_flag=True,
                enc_output_flag=False,
                data_shape_flag=True):
    """
    Define the input data layers for the transformer model.

    Data layers are created in the fixed order below. Each one takes the next
    unused entry of ``input_data_names`` as its name:
      1. word ids (int64);
      2. position ids (int64) if ``is_pos``, otherwise a float32 tensor of
         the same shape (used for label weights);
      3. self-attention bias, if ``slf_attn_bias_flag``;
      4. source attention bias, if ``src_attn_bias_flag``;
      5. an int32 data shape of length 3, if ``data_shape_flag``;
      6. the encoder output, if ``enc_output_flag``.

    Returns:
        The list of created data layers, in the order above.
    """
    # Only for the infer-shape in compile time. The shapes below are
    # placeholders chosen to pass compile-time infer-shape.
    batch_size = 1
    input_layers = []

    def _add_data(shape, dtype):
        # Create one data layer named after the next unused name and record it.
        layer = layers.data(
            name=input_data_names[len(input_layers)],
            shape=shape,
            dtype=dtype,
            append_batch_size=False)
        input_layers.append(layer)

    # Word ids. Actual shape: [batch_size * max_len_in_batch, 1].
    _add_data([batch_size * max_length, 1], "int64")
    # Position ids or label weights. Actual shape:
    # [batch_size * max_len_in_batch, 1].
    _add_data([batch_size * max_length, 1], "int64" if is_pos else "float32")
    if slf_attn_bias_flag:
        # Removes attention weights on paddings for the encoder and on
        # subsequent words for the decoder. Actual shape of slf_attn_bias:
        # [batch_size, n_head, max_len_in_batch, max_len_in_batch].
        _add_data([batch_size, n_head, max_length, max_length], "float32")
    if src_attn_bias_flag:
        # Removes attention weights on paddings. Actual shape of
        # src_attn_bias:
        # [batch_size, n_head, trg_max_len_in_batch, src_max_len_in_batch].
        _add_data([batch_size, n_head, max_length, max_length], "float32")
    if data_shape_flag:
        # Used to reshape.
        _add_data([3], "int32")
    if enc_output_flag:
        # Used in the independent decoder program for inference. Actual shape
        # of enc_output: [batch_size, max_len_in_batch, d_model].
        _add_data([batch_size, max_length, d_model], "float32")
    return input_layers


def transformer(
        src_vocab_size,
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        src_pad_idx,
        trg_pad_idx,
        pos_pad_idx, ):
    """
    Build the full training network: encoder, decoder and weighted loss.

    Returns:
        A tuple ``(sum_cost, predict)``. ``sum_cost`` is the reduce_sum of the
        cross-entropy cost multiplied by the label weights. ``predict`` is the
        decoder's softmax output.
    """
    enc_inputs = make_inputs(
        encoder_input_data_names, n_head, d_model, max_length,
        is_pos=True, slf_attn_bias_flag=True, src_attn_bias_flag=False)
    enc_output = wrap_encoder(
        src_vocab_size=src_vocab_size,
        max_length=max_length,
        n_layer=n_layer,
        n_head=n_head,
        d_key=d_key,
        d_value=d_value,
        d_model=d_model,
        d_inner_hid=d_inner_hid,
        dropout_rate=dropout_rate,
        src_pad_idx=src_pad_idx,
        pos_pad_idx=pos_pad_idx,
        enc_inputs=enc_inputs)

    dec_inputs = make_inputs(
        decoder_input_data_names, n_head, d_model, max_length,
        is_pos=True, slf_attn_bias_flag=True, src_attn_bias_flag=True)
    predict = wrap_decoder(
        trg_vocab_size=trg_vocab_size,
        max_length=max_length,
        n_layer=n_layer,
        n_head=n_head,
        d_key=d_key,
        d_value=d_value,
        d_model=d_model,
        d_inner_hid=d_inner_hid,
        dropout_rate=dropout_rate,
        trg_pad_idx=trg_pad_idx,
        pos_pad_idx=pos_pad_idx,
        dec_inputs=dec_inputs,
        enc_output=enc_output)

    # Padding positions must not contribute to the total loss. The label
    # weights are multiplied into the cost to cancel them.
    gold, weights = make_inputs(
        label_data_names, n_head, d_model, max_length,
        is_pos=False, slf_attn_bias_flag=False, src_attn_bias_flag=False,
        enc_output_flag=False, data_shape_flag=False)
    weighted_cost = layers.cross_entropy(input=predict, label=gold) * weights
    return layers.reduce_sum(weighted_cost), predict


def wrap_encoder(src_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 src_pad_idx,
                 pos_pad_idx,
                 enc_inputs=None):
    """
    The wrapper assembles together all needed layers for the encoder.

    Args:
        enc_inputs: Optional sequence (src_word, src_pos, src_slf_attn_bias,
            src_data_shape). When None, these data layers are created here
            via make_inputs. This path builds the independent encoder program
            used for inference.

    Returns:
        The output of the encoder stack.
    """
    if enc_inputs is not None:
        src_word, src_pos, src_slf_attn_bias, src_data_shape = enc_inputs
    else:
        # Independent encoder program for inference.
        src_word, src_pos, src_slf_attn_bias, src_data_shape = make_inputs(
            encoder_input_data_names, n_head, d_model, max_length,
            is_pos=True, slf_attn_bias_flag=True, src_attn_bias_flag=False)

    embedded = prepare_encoder(src_word, src_pos, src_vocab_size, d_model,
                               src_pad_idx, max_length, dropout_rate,
                               pos_pad_idx, src_data_shape)
    return encoder(embedded, src_slf_attn_bias, n_layer, n_head, d_key,
                   d_value, d_model, d_inner_hid, dropout_rate)


def wrap_decoder(trg_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 trg_pad_idx,
                 pos_pad_idx,
                 dec_inputs=None,
                 enc_output=None):
    """
    The wrapper assembles together all needed layers for the decoder.

    Args:
        dec_inputs: Optional sequence (trg_word, trg_pos, trg_slf_attn_bias,
            trg_src_attn_bias, trg_data_shape). When None, these and the
            encoder output are all created as data layers via make_inputs.
            This path builds the independent decoder program used for
            inference, and it ignores the ``enc_output`` argument.
        enc_output: Encoder output the decoder attends to. It is used only
            when ``dec_inputs`` is given.

    Returns:
        Softmax output over the target vocabulary, reshaped to
        [-1, trg_vocab_size].
    """
    if dec_inputs is not None:
        (trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
         trg_data_shape) = dec_inputs
    else:
        # Independent decoder program for inference. The encoder output is
        # also fed in as a data layer.
        (trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
         trg_data_shape, enc_output) = make_inputs(
             decoder_input_data_names, n_head, d_model, max_length,
             is_pos=True, slf_attn_bias_flag=True, src_attn_bias_flag=True,
             enc_output_flag=True)

    embedded = prepare_decoder(trg_word, trg_pos, trg_vocab_size, d_model,
                               trg_pad_idx, max_length, dropout_rate,
                               pos_pad_idx, trg_data_shape)
    dec_output = decoder(embedded, enc_output, trg_slf_attn_bias,
                         trg_src_attn_bias, n_layer, n_head, d_key, d_value,
                         d_model, d_inner_hid, dropout_rate)

    logits = layers.fc(input=dec_output,
                       size=trg_vocab_size,
                       bias_attr=False,
                       num_flatten_dims=2)
    return layers.reshape(x=logits, shape=[-1, trg_vocab_size], act="softmax")