model.py 18.9 KB
Newer Older
1 2
from functools import partial
import numpy as np
Y
ying 已提交
3

L
Luo Tao 已提交
4 5
import paddle.fluid as fluid
import paddle.fluid.layers as layers
Y
ying 已提交
6

7
from config import *
Y
ying 已提交
8

9 10

def position_encoding_init(n_position, d_pos_vec):
    """
    Generate the initial values for the sinusoid position encoding table.

    For every position ``pos >= 1`` the table holds
    ``sin(pos / 10000^(2i / d_pos_vec))`` in even column ``2i`` and
    ``cos(pos / 10000^(2i / d_pos_vec))`` in odd column ``2i + 1``.
    Row 0 is all zeros.

    Args:
        n_position: Number of positions (rows) in the table.
        d_pos_vec: Length of each position encoding vector (columns).

    Returns:
        A float32 numpy array of shape [n_position, d_pos_vec].
    """
    dims = np.arange(d_pos_vec)
    # Cast to float before dividing so the exponent is not truncated to 0 by
    # Python 2 integer division.
    exponents = (2 * (dims // 2)).astype("float64") / d_pos_vec
    positions = np.arange(n_position, dtype="float64")[:, np.newaxis]
    position_enc = positions / np.power(10000., exponents)
    position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])  # dim 2i
    position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])  # dim 2i+1
    # Position 0 keeps an all-zero encoding. Slicing (rather than indexing
    # row 0) also works when n_position == 0.
    position_enc[:1, :] = 0.
    return position_enc.astype("float32")


def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         n_head=1,
                         dropout_rate=0.,
                         pre_softmax_shape=None,
                         post_softmax_shape=None):
    """
    Multi-Head Attention. Note that attn_bias is added to the logits before
    computing the softmax activation to mask certain selected positions so
    that they will not be considered in attention weights.

    Args:
        queries: 3-D query tensor.
        keys: 3-D key tensor.
        values: 3-D value tensor.
        attn_bias: Tensor added to the attention logits before softmax, or
            None to skip the addition.
        d_key: Per-head size of the query and key projections.
        d_value: Per-head size of the value projection.
        d_model: Output size of the final linear projection.
        n_head: Number of attention heads.
        dropout_rate: Dropout probability applied to the attention weights;
            0 disables dropout.
        pre_softmax_shape: Optional ``actual_shape`` for the reshape that
            flattens the logits to 2-D before the softmax.
        post_softmax_shape: Optional ``actual_shape`` for the reshape that
            restores the logits' shape after the softmax.

    Returns:
        The attention context projected back to ``d_model`` features.

    Raises:
        ValueError: If queries, keys and values are not all 3-D.
    """
    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: quries, keys and values should all be 3-D tensors.")

    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
                      size=d_key * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_key,
                          fan_out=n_head * d_key),
                      bias_attr=False,
                      num_flatten_dims=2)
        k = layers.fc(input=keys,
                      size=d_key * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_key,
                          fan_out=n_head * d_key),
                      bias_attr=False,
                      num_flatten_dims=2)
        v = layers.fc(input=values,
                      size=d_value * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_value,
                          fan_out=n_head * d_value),
                      bias_attr=False,
                      num_flatten_dims=2)
        return q, k, v

    def __split_heads(x, n_head):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions and then transpose. Specifically, input a tensor with shape
        [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
        with shape [bs, n_head, max_sequence_length, hidden_dim].
        """
        if n_head == 1:
            return x

        hidden_size = x.shape[-1]
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        reshaped = layers.reshape(
            x=x, shape=[0, -1, n_head, hidden_size // n_head])

        # permute the dimensions into:
        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that it becomes one dimension, which is reverse to __split_heads.
        """
        if len(x.shape) == 3:
            return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        # Build a real list: on Python 3, map() returns a lazy iterator,
        # which is not a valid shape argument.
        return layers.reshape(
            x=trans_x,
            shape=[0, -1, int(trans_x.shape[2] * trans_x.shape[3])])

    def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
        """
        Scaled Dot-Product Attention: softmax(q k^T / sqrt(d_key) + bias) v.
        """
        # Scale by the per-head key size, as in "Attention Is All You Need".
        # Scaling by d_model instead would shrink the logits further
        # whenever d_model > d_key.
        scaled_q = layers.scale(x=q, scale=d_key**-0.5)
        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
        # attn_bias may be None; test identity, not the truth value of a
        # tensor object.
        logits = (layers.elementwise_add(x=product, y=attn_bias)
                  if attn_bias is not None else product)
        # Flatten to 2-D so the softmax runs over the last dimension, then
        # restore the original logits shape.
        weights = layers.reshape(
            x=logits,
            shape=[-1, product.shape[-1]],
            actual_shape=pre_softmax_shape,
            act="softmax")
        weights = layers.reshape(
            x=weights, shape=product.shape, actual_shape=post_softmax_shape)
        if dropout_rate:
            weights = layers.dropout(
                weights, dropout_prob=dropout_rate, is_test=False)
        out = layers.matmul(weights, v)
        return out

    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

    q = __split_heads(q, n_head)
    k = __split_heads(k, n_head)
    v = __split_heads(v, n_head)

    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key,
                                                  dropout_rate)

    out = __combine_heads(ctx_multiheads)

    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         param_attr=fluid.initializer.Xavier(uniform=False),
                         bias_attr=False,
                         num_flatten_dims=2)
    return proj_out


def positionwise_feed_forward(x, d_inner_hid, d_hid):
    """
    Position-wise feed-forward network.

    Two fully connected layers with a ReLU in between. Both use
    ``num_flatten_dims=2``, so the same weights are applied at every position
    along the second dimension.

    Args:
        x: Input tensor.
        d_inner_hid: Width of the hidden (ReLU) layer.
        d_hid: Width of the output layer.

    Returns:
        The output of the second fc layer, with ``d_hid`` features.
    """
    # Uniform init bounds are the inverse square root of each layer's input
    # width.
    first_bound = d_hid**-0.5
    hidden = layers.fc(input=x,
                       size=d_inner_hid,
                       num_flatten_dims=2,
                       param_attr=fluid.initializer.Uniform(
                           low=-first_bound, high=first_bound),
                       act="relu")

    second_bound = d_inner_hid**-0.5
    return layers.fc(input=hidden,
                     size=d_hid,
                     num_flatten_dims=2,
                     param_attr=fluid.initializer.Uniform(
                         low=-second_bound, high=second_bound))


169
def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
    """
    Optionally add a residual connection, layer normalization and dropout to
    the out tensor, according to the value of process_cmd.
    This is used before or after multi-head attention and position-wise
    feed-forward networks.

    Args:
        prev_out: Residual input added by the "a" command, or None to skip
            the residual connection.
        out: Tensor to process.
        process_cmd: String of single-character commands, applied in order:
            "a" adds the residual connection, "n" applies layer normalization
            over the last dimension, "d" applies dropout. Other characters are
            ignored.
        dropout_rate: Dropout probability used by "d"; 0 disables dropout.

    Returns:
        The processed tensor.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            # Check for None explicitly: the truth value of a tensor/array is
            # not a valid "is present" test (numpy arrays even raise on it).
            if prev_out is not None:
                out = out + prev_out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.initializer.Constant(1.),
                bias_attr=fluid.initializer.Constant(0.))
        elif cmd == "d":  # add dropout
            if dropout_rate:
                out = layers.dropout(
                    out, dropout_prob=dropout_rate, is_test=False)
    return out


# Pre-processing has no residual input, so prev_out is bound to None.
pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer


def prepare_encoder(src_word,
                    src_pos,
                    src_vocab_size,
                    src_emb_dim,
                    src_max_len,
                    dropout_rate=0.,
                    src_data_shape=None,
                    pos_enc_param_name=None):
    """Sum word embeddings and position encodings for a stack's input.

    The result is reshaped to [-1, src_max_len, src_emb_dim]; src_data_shape,
    if given, is passed as the reshape's ``actual_shape``. Dropout is applied
    afterwards when dropout_rate is non-zero.

    Args:
        src_word: Word id input.
        src_pos: Position id input.
        src_vocab_size: Number of rows in the word embedding table.
        src_emb_dim: Embedding width (used for both tables).
        src_max_len: Number of rows in the position encoding table.
        dropout_rate: Dropout probability; 0 disables dropout.
        src_data_shape: Optional runtime shape for the final reshape.
        pos_enc_param_name: Name of the (non-trainable) position encoding
            parameter.

    Returns:
        The combined embedding tensor.
    """
    word_emb = layers.embedding(
        src_word,
        size=[src_vocab_size, src_emb_dim],
        param_attr=fluid.initializer.Normal(0., 1.))
    # Position encodings are looked up from a frozen, named table.
    pos_emb = layers.embedding(
        src_pos,
        size=[src_max_len, src_emb_dim],
        param_attr=fluid.ParamAttr(
            name=pos_enc_param_name, trainable=False))
    summed = layers.reshape(
        x=word_emb + pos_emb,
        shape=[-1, src_max_len, src_emb_dim],
        actual_shape=src_data_shape)
    if not dropout_rate:
        return summed
    return layers.dropout(summed, dropout_prob=dropout_rate, is_test=False)
226 227 228 229 230 231 232 233


# Bind each stack to its own position-encoding parameter. Both partials wrap
# the plain function (the tuple is built before either name is rebound).
prepare_encoder, prepare_decoder = (
    partial(prepare_encoder, pos_enc_param_name=pos_enc_param_names[0]),
    partial(prepare_encoder, pos_enc_param_name=pos_enc_param_names[1]), )


Y
ying 已提交
234 235 236 237 238 239 240
def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.,
                  pre_softmax_shape=None,
                  post_softmax_shape=None):
    """One stackable encoder block.

    Multi-head self-attention followed by a position-wise feed-forward
    network. Each sub-layer's output goes through post_process_layer with
    "dan": dropout, then residual add, then layer normalization.

    Args:
        enc_input: Input to the block (used as queries, keys and values).
        attn_bias: Bias added to the self-attention logits, or None.
        n_head, d_key, d_value, d_model: Attention sizes, forwarded to
            multi_head_attention.
        d_inner_hid: Hidden width of the feed-forward network.
        dropout_rate: Dropout probability; 0 disables dropout.
        pre_softmax_shape, post_softmax_shape: Optional runtime shapes
            forwarded to multi_head_attention.

    Returns:
        The block output after the feed-forward sub-layer.
    """
    # Sub-layer 1: self-attention.
    self_attn = multi_head_attention(
        enc_input, enc_input, enc_input, attn_bias, d_key, d_value, d_model,
        n_head, dropout_rate, pre_softmax_shape, post_softmax_shape)
    self_attn = post_process_layer(enc_input, self_attn, "dan", dropout_rate)

    # Sub-layer 2: position-wise feed-forward network.
    ffn = positionwise_feed_forward(self_attn, d_inner_hid, d_model)
    return post_process_layer(self_attn, ffn, "dan", dropout_rate)


def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.,
            pre_softmax_shape=None,
            post_softmax_shape=None):
    """
    The encoder is a stack of n_layer identical layers built by calling
    encoder_layer; each layer's output is the next layer's input.

    Args:
        enc_input: Input to the first layer.
        attn_bias: Self-attention bias shared by every layer.
        n_layer: Number of stacked layers. With 0, enc_input is returned
            unchanged.
        n_head, d_key, d_value, d_model, d_inner_hid, dropout_rate,
        pre_softmax_shape, post_softmax_shape: Forwarded to encoder_layer.

    Returns:
        The output of the last layer (or enc_input when n_layer is 0).
    """
    # Carry the running value in enc_input. This avoids the UnboundLocalError
    # the previous form raised when n_layer == 0.
    for _ in range(n_layer):
        enc_input = encoder_layer(
            enc_input,
            attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate,
            pre_softmax_shape,
            post_softmax_shape, )
    return enc_input


Y
ying 已提交
290 291 292 293 294 295 296 297 298
def decoder_layer(dec_input,
                  enc_output,
                  slf_attn_bias,
                  dec_enc_attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.,
                  slf_attn_pre_softmax_shape=None,
                  slf_attn_post_softmax_shape=None,
                  src_attn_pre_softmax_shape=None,
                  src_attn_post_softmax_shape=None):
    """One stackable decoder block.

    Like encoder_layer, plus an extra multi-head attention whose keys and
    values come from enc_output. Every sub-layer's output goes through
    post_process_layer with "dan": dropout, then residual add, then layer
    normalization.

    Args:
        dec_input: Input to the block.
        enc_output: Encoder output used as keys/values of the second
            attention.
        slf_attn_bias: Bias for the self-attention logits, or None.
        dec_enc_attn_bias: Bias for the encoder-decoder attention logits, or
            None.
        n_head, d_key, d_value, d_model: Attention sizes.
        d_inner_hid: Hidden width of the feed-forward network.
        dropout_rate: Dropout probability; 0 disables dropout.
        slf_attn_*_softmax_shape, src_attn_*_softmax_shape: Optional runtime
            shapes forwarded to the two attentions.

    Returns:
        The block output after the feed-forward sub-layer.
    """
    # Sub-layer 1: self-attention over the decoder input.
    self_attn = multi_head_attention(
        dec_input, dec_input, dec_input, slf_attn_bias, d_key, d_value,
        d_model, n_head, dropout_rate, slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape)
    self_attn = post_process_layer(dec_input, self_attn, "dan", dropout_rate)

    # Sub-layer 2: encoder-decoder attention (queries from the decoder,
    # keys/values from the encoder output).
    cross_attn = multi_head_attention(
        self_attn, enc_output, enc_output, dec_enc_attn_bias, d_key, d_value,
        d_model, n_head, dropout_rate, src_attn_pre_softmax_shape,
        src_attn_post_softmax_shape)
    cross_attn = post_process_layer(self_attn, cross_attn, "dan",
                                    dropout_rate)

    # Sub-layer 3: position-wise feed-forward network.
    ffn = positionwise_feed_forward(cross_attn, d_inner_hid, d_model)
    return post_process_layer(cross_attn, ffn, "dan", dropout_rate)


Y
ying 已提交
354 355 356 357 358 359 360 361 362 363
def decoder(dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.,
            slf_attn_pre_softmax_shape=None,
            slf_attn_post_softmax_shape=None,
            src_attn_pre_softmax_shape=None,
            src_attn_post_softmax_shape=None):
    """
    The decoder is a stack of n_layer identical decoder_layer layers; each
    layer's output is the next layer's input, and all layers attend to the
    same enc_output.

    Args:
        dec_input: Input to the first layer.
        enc_output: Encoder output shared by every layer.
        dec_slf_attn_bias, dec_enc_attn_bias: Attention biases shared by
            every layer.
        n_layer: Number of stacked layers. With 0, dec_input is returned
            unchanged.
        Remaining arguments: Forwarded to decoder_layer.

    Returns:
        The output of the last layer (or dec_input when n_layer is 0).
    """
    # Carry the running value in dec_input. This avoids the UnboundLocalError
    # the previous form raised when n_layer == 0.
    for _ in range(n_layer):
        dec_input = decoder_layer(
            dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate,
            slf_attn_pre_softmax_shape,
            slf_attn_post_softmax_shape,
            src_attn_pre_softmax_shape,
            src_attn_post_softmax_shape, )
    return dec_input


392
def make_all_inputs(input_fields):
    """
    Create the data layers the transformer model is fed through.

    Args:
        input_fields: Iterable of input names. Each name is looked up in
            ``input_descs``; entry [0] is used as the shape and entry [1] as
            the dtype.

    Returns:
        A list of data layers, in the same order as ``input_fields``.
    """
    return [
        layers.data(
            name=field_name,
            shape=input_descs[field_name][0],
            dtype=input_descs[field_name][1],
            append_batch_size=False) for field_name in input_fields
    ]
405 406


Y
ying 已提交
407 408 409 410 411 412 413 414 415 416
def transformer(
        src_vocab_size,
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        label_smooth_eps, ):
    """
    Build the complete training network: encoder, decoder and weighted loss.

    Args:
        src_vocab_size: Source vocabulary size.
        trg_vocab_size: Target vocabulary size (also the logits width).
        max_length: Maximum sequence length (position table size).
        n_layer, n_head, d_key, d_value, d_model, d_inner_hid, dropout_rate:
            Model hyper-parameters forwarded to the encoder and decoder.
        label_smooth_eps: Label smoothing epsilon; a falsy value disables
            smoothing and uses hard labels.

    Returns:
        A tuple (sum_cost, avg_cost, predict, token_num): the weighted summed
        loss, that sum divided by the summed weights, the decoder logits, and
        the summed weights.
    """
    encoder_feeds = make_all_inputs(encoder_data_input_fields +
                                    encoder_util_input_fields)
    enc_output = wrap_encoder(src_vocab_size, max_length, n_layer, n_head,
                              d_key, d_value, d_model, d_inner_hid,
                              dropout_rate, encoder_feeds)

    # The last decoder data field is excluded here; presumably it is the
    # encoder-output field, which is supplied directly as enc_output instead.
    decoder_feeds = make_all_inputs(decoder_data_input_fields[:-1] +
                                    decoder_util_input_fields)
    logits = wrap_decoder(trg_vocab_size, max_length, n_layer, n_head, d_key,
                          d_value, d_model, d_inner_hid, dropout_rate,
                          decoder_feeds, enc_output)

    # Padding positions must not contribute to the loss; the weights zero
    # them out.
    label, weights = make_all_inputs(label_data_input_fields)
    use_soft_label = bool(label_smooth_eps)
    if use_soft_label:
        label = layers.label_smooth(
            label=layers.one_hot(
                input=label, depth=trg_vocab_size),
            epsilon=label_smooth_eps)
    cost = layers.softmax_with_cross_entropy(
        logits=logits, label=label, soft_label=use_soft_label)
    weighted_cost = cost * weights
    sum_cost = layers.reduce_sum(weighted_cost)
    token_num = layers.reduce_sum(weights)
    return sum_cost, sum_cost / token_num, logits, token_num
468 469 470 471 472 473 474 475 476 477 478


def wrap_encoder(src_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 enc_inputs=None):
    """
    Assemble the embedding layer and the encoder stack.

    Args:
        src_vocab_size: Source vocabulary size.
        max_length: Maximum sequence length (position table size).
        n_layer, n_head, d_key, d_value, d_model, d_inner_hid, dropout_rate:
            Encoder hyper-parameters.
        enc_inputs: Sequence of six input layers (src_word, src_pos,
            src_slf_attn_bias, src_data_shape, slf_attn_pre_softmax_shape,
            slf_attn_post_softmax_shape). When None, the data layers are
            created here (standalone encoder program for inference).

    Returns:
        The encoder output.
    """
    if enc_inputs is None:
        enc_inputs = make_all_inputs(encoder_data_input_fields +
                                     encoder_util_input_fields)
    (src_word, src_pos, src_slf_attn_bias, src_data_shape,
     slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape) = enc_inputs

    embedded = prepare_encoder(src_word, src_pos, src_vocab_size, d_model,
                               max_length, dropout_rate, src_data_shape)
    return encoder(embedded, src_slf_attn_bias, n_layer, n_head, d_key,
                   d_value, d_model, d_inner_hid, dropout_rate,
                   slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape)


def wrap_decoder(trg_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 dec_inputs=None,
                 enc_output=None):
    """
    Assemble the embedding layer, the decoder stack and the output projection.

    Args:
        trg_vocab_size: Target vocabulary size (output projection width).
        max_length: Maximum sequence length (position table size).
        n_layer, n_head, d_key, d_value, d_model, d_inner_hid, dropout_rate:
            Decoder hyper-parameters.
        dec_inputs: Sequence of nine input layers (everything below except
            enc_output). When None, all data layers, including enc_output, are
            created here (standalone decoder program for inference).
        enc_output: Encoder output; ignored when dec_inputs is None.

    Returns:
        A [-1, trg_vocab_size] tensor: raw logits when dec_inputs is given
        (training), softmax probabilities when it is None (inference).
    """
    standalone = dec_inputs is None
    if standalone:
        (trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output,
         trg_data_shape, slf_attn_pre_softmax_shape,
         slf_attn_post_softmax_shape, src_attn_pre_softmax_shape,
         src_attn_post_softmax_shape) = make_all_inputs(
             decoder_data_input_fields + decoder_util_input_fields)
    else:
        (trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
         trg_data_shape, slf_attn_pre_softmax_shape,
         slf_attn_post_softmax_shape, src_attn_pre_softmax_shape,
         src_attn_post_softmax_shape) = dec_inputs

    embedded = prepare_decoder(trg_word, trg_pos, trg_vocab_size, d_model,
                               max_length, dropout_rate, trg_data_shape)
    dec_output = decoder(embedded, enc_output, trg_slf_attn_bias,
                         trg_src_attn_bias, n_layer, n_head, d_key, d_value,
                         d_model, d_inner_hid, dropout_rate,
                         slf_attn_pre_softmax_shape,
                         slf_attn_post_softmax_shape,
                         src_attn_pre_softmax_shape,
                         src_attn_post_softmax_shape)

    projected = layers.fc(input=dec_output,
                          size=trg_vocab_size,
                          bias_attr=False,
                          num_flatten_dims=2)
    # Logits for training; probabilities for the standalone inference graph.
    return layers.reshape(
        x=projected,
        shape=[-1, trg_vocab_size],
        act="softmax" if standalone else None)