transformer_model.py 15.8 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
16

Y
Yu Yang 已提交
17 18
import numpy as np

19
import paddle
Y
Yu Yang 已提交
20 21 22 23 24
import paddle.fluid as fluid
import paddle.fluid.layers as layers

# Parameter names of the fixed (non-trainable) position-encoding tables:
# index 0 is the source/encoder table, index 1 the target/decoder table.
pos_enc_param_names = ("src_pos_enc_table", "trg_pos_enc_table")

# Static batch size baked into the reshape shapes and the declared input
# shapes of the program (see the FIXME(guosheng) notes where it is used).
batch_size = 2


def position_encoding_init(n_position, d_pos_vec):
    """
    Build the sinusoid position-encoding table used to initialise the
    position embeddings.

    Row ``pos`` encodes position ``pos``: column ``j`` starts as
    ``pos / 10000 ** (2 * (j // 2) / d_pos_vec)``; even columns then take the
    sine of that angle and odd columns the cosine. Row 0 is left all zeros.

    Args:
        n_position: number of rows (positions) in the table.
        d_pos_vec: width of each encoding vector.

    Returns:
        A ``float32`` numpy array with one row per position and
        ``d_pos_vec`` columns.
    """

    def _angles(pos):
        # Position 0 stays at zero; the sin/cos below only touch rows 1+.
        if pos == 0:
            return np.zeros(d_pos_vec)
        return [
            pos / np.power(10000, 2 * (col // 2) / d_pos_vec)
            for col in range(d_pos_vec)
        ]

    table = np.array([_angles(pos) for pos in range(n_position)])
    even_cols = table[1:, 0::2]
    odd_cols = table[1:, 1::2]
    table[1:, 0::2] = np.sin(even_cols)  # dim 2i
    table[1:, 1::2] = np.cos(odd_cols)  # dim 2i+1
    return table.astype("float32")


def multi_head_attention(
    queries,
    keys,
    values,
    attn_bias,
    d_key,
    d_value,
    d_model,
    n_head=1,
    dropout_rate=0.0,
):
    """
    Multi-head attention.

    ``attn_bias`` is added to the attention logits before the softmax is
    taken, which lets the caller mask selected positions so that they do not
    contribute to the attention weights.

    Args:
        queries, keys, values: input tensors; each must be 3-D.
        attn_bias: tensor added to the scaled ``q @ k^T`` logits.
        d_key: per-head size of the query and key projections.
        d_value: per-head size of the value projection.
        d_model: model width; used for the logit scaling factor
            (``d_model ** -0.5``) and as the size of the output projection.
        n_head: number of attention heads.
        dropout_rate: dropout probability for the attention weights; no
            dropout op is added when it is falsy.

    Returns:
        The attention context projected back to ``d_model`` features.

    Raises:
        ValueError: if ``queries``, ``keys`` or ``values`` is not 3-D.
    """
    if any(len(t.shape) != 3 for t in (queries, keys, values)):
        raise ValueError(
            "Inputs: queries, keys and values should all be 3-D tensors."
        )

    def _project(x, per_head_dim):
        """Bias-free linear map of ``x`` to ``n_head * per_head_dim`` features."""
        return layers.fc(
            input=x,
            size=per_head_dim * n_head,
            param_attr=fluid.initializer.Xavier(
                uniform=False,
                fan_in=d_model * per_head_dim,
                fan_out=n_head * per_head_dim,
            ),
            bias_attr=False,
            num_flatten_dims=2,
        )

    def _split_heads(x):
        """
        Reshape ``[bs, seq_len, n_head * d]`` to ``[bs, seq_len, n_head, d]``
        and transpose it to ``[bs, n_head, seq_len, d]``. Identity when
        ``n_head`` is 1.
        """
        if n_head == 1:
            return x
        per_head = x.shape[-1] // n_head
        # FIXME(guosheng): Decouple the program desc with batch_size.
        heads = paddle.reshape(x=x, shape=[batch_size, -1, n_head, per_head])
        return paddle.transpose(x=heads, perm=[0, 2, 1, 3])

    def _combine_heads(x):
        """
        Inverse of ``_split_heads``: transpose a 4-D tensor back to
        ``[bs, seq_len, n_head, d]`` and merge its last two axes. 3-D input
        is returned unchanged.
        """
        rank = len(x.shape)
        if rank == 3:
            return x
        if rank != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")
        swapped = paddle.transpose(x, perm=[0, 2, 1, 3])
        merged_dim = swapped.shape[2] * swapped.shape[3]
        # FIXME(guosheng): Decouple the program desc with batch_size.
        return paddle.reshape(
            x=swapped, shape=[int(batch_size), -1, int(merged_dim)]
        )

    # FIXME(guosheng): Optimize the shape in reshape_op or softmax_op.
    # The current implementation of softmax_op only supports 2D tensor,
    # consequently it cannot be directly used here. If the reshape_op were
    # used instead, the shape of the product inferred at compile time is not
    # the actual shape at run time, so it can't be used to set the attribute
    # of reshape_op. So the softmax is defined here as a temporary solution.
    def _softmax(logits):
        """``exp(logits)`` normalised by its sum over the last axis."""
        numer = paddle.exp(x=logits)
        denom = paddle.sum(numer, axis=-1, keepdim=False)
        return layers.elementwise_div(x=numer, y=denom, axis=0)

    def _attend(q, k, v):
        """Scaled dot-product attention."""
        logits = paddle.matmul(
            x=paddle.scale(x=q, scale=d_model**-0.5), y=k, transpose_y=True
        )
        weights = _softmax(paddle.add(x=logits, y=attn_bias))
        if dropout_rate:
            weights = layers.dropout(
                weights, dropout_prob=dropout_rate, is_test=False
            )
        return paddle.matmul(weights, v)

    # Create all three projections first (q, k, v in that order), then split
    # each of them into heads.
    projected = (
        _project(queries, d_key),
        _project(keys, d_key),
        _project(values, d_value),
    )
    q, k, v = [_split_heads(t) for t in projected]

    context = _combine_heads(_attend(q, k, v))

    # Project back to the model size.
    return layers.fc(
        input=context,
        size=d_model,
        param_attr=fluid.initializer.Xavier(uniform=False),
        bias_attr=False,
        num_flatten_dims=2,
    )


def positionwise_feed_forward(x, d_inner_hid, d_hid):
    """
    Position-wise feed-forward network: a ReLU fc layer followed by a plain
    fc layer, applied to each position separately and identically
    (both layers use ``num_flatten_dims=2``).

    Weights are drawn uniformly from ``[-d_hid**-0.5, d_hid**-0.5]`` for
    the first layer and ``[-d_inner_hid**-0.5, d_inner_hid**-0.5]`` for the
    second.

    Args:
        x: input tensor.
        d_inner_hid: width of the hidden (ReLU) layer.
        d_hid: width of the output layer.

    Returns:
        The output of the second fc layer, with ``d_hid`` features.
    """

    def _uniform_init(fan):
        limit = fan**-0.5
        return fluid.initializer.Uniform(low=-limit, high=limit)

    hidden = layers.fc(
        input=x,
        size=d_inner_hid,
        num_flatten_dims=2,
        param_attr=_uniform_init(d_hid),
        act="relu",
    )
    return layers.fc(
        input=hidden,
        size=d_hid,
        num_flatten_dims=2,
        param_attr=_uniform_init(d_inner_hid),
    )


def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.0):
    """
    Apply residual connection, layer normalization and dropout to ``out``
    as selected by ``process_cmd``.

    ``process_cmd`` is processed one character at a time, in order:

    * ``"a"``: add ``prev_out`` as a residual connection (skipped when
      ``prev_out`` is None);
    * ``"n"``: layer normalization over the last axis, with the scale
      initialised to 1 and the bias to 0;
    * ``"d"``: dropout with probability ``dropout`` (skipped when
      ``dropout`` is falsy).

    Any other character is ignored. This is used before or after the
    multi-head attention and position-wise feed-forward sub-layers.

    Returns:
        The processed tensor (``out`` itself if no command applied).
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            # Test for None explicitly instead of relying on the truth value
            # of a tensor, which can be ambiguous or raise.
            if prev_out is not None:
                out = out + prev_out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.initializer.Constant(1.0),
                bias_attr=fluid.initializer.Constant(0.0),
            )
        elif cmd == "d":  # add dropout
            if dropout:
                out = layers.dropout(out, dropout_prob=dropout, is_test=False)
    return out


# Pre-processing has no residual input, so prev_out is bound positionally to
# None; post-processing exposes the full signature unchanged.
pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer


def prepare_encoder(
    src_word,
    src_pos,
    src_vocab_size,
    src_emb_dim,
    src_pad_idx,
    src_max_len,
    dropout=0.0,
    pos_pad_idx=0,
    pos_enc_param_name=None,
):
    """
    Sum the word embeddings and position encodings of the input tokens.

    Used at the bottom of the encoder stack (and, through a partial, the
    decoder stack). The position table is the parameter named
    ``pos_enc_param_name``; it is created non-trainable and gradients are
    stopped on its lookup.

    Returns:
        The summed embeddings reshaped to ``[batch_size, -1, src_emb_dim]``
        (i.e. ``[batch_size, max_src_length_in_batch, d_model]``), with
        dropout applied when ``dropout`` is non-zero.
    """
    word_emb = layers.embedding(
        src_word,
        size=[src_vocab_size, src_emb_dim],
        padding_idx=src_pad_idx,
        param_attr=fluid.initializer.Normal(0.0, 1.0),
    )
    pos_emb = layers.embedding(
        src_pos,
        size=[src_max_len, src_emb_dim],
        padding_idx=pos_pad_idx,
        param_attr=fluid.ParamAttr(name=pos_enc_param_name, trainable=False),
    )
    # The position table is fixed: also block gradients through its lookup.
    pos_emb.stop_gradient = True

    # FIXME(guosheng): Decouple the program desc with batch_size.
    enc_input = paddle.reshape(
        x=word_emb + pos_emb, shape=[batch_size, -1, src_emb_dim]
    )
    if not dropout:
        return enc_input
    return layers.dropout(enc_input, dropout_prob=dropout, is_test=False)


# Specialise the embedding helper defined above for each side of the model,
# binding the name of its position-encoding table. prepare_decoder is bound
# first so that both partials wrap the plain function directly instead of the
# decoder partial being stacked on top of the encoder partial.
prepare_decoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[1]
)
prepare_encoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[0]
)


def encoder_layer(
    enc_input,
    attn_bias,
    n_head,
    d_key,
    d_value,
    d_model,
    d_inner_hid,
    dropout_rate=0.0,
):
    """
    One encoder layer; several can be stacked to form a deep encoder.

    Multi-head self-attention over ``enc_input`` is followed by a
    position-wise feed-forward network. Each sub-layer's output goes through
    ``post_process_layer`` with the ``"dan"`` command: dropout, then the
    residual connection, then layer normalization.
    """
    self_attn = multi_head_attention(
        enc_input,  # queries
        enc_input,  # keys
        enc_input,  # values
        attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate,
    )
    self_attn = post_process_layer(enc_input, self_attn, "dan", dropout_rate)

    ffn_out = positionwise_feed_forward(self_attn, d_inner_hid, d_model)
    return post_process_layer(self_attn, ffn_out, "dan", dropout_rate)


def encoder(
    enc_input,
    attn_bias,
    n_layer,
    n_head,
    d_key,
    d_value,
    d_model,
    d_inner_hid,
    dropout_rate=0.0,
):
    """
    Stack ``n_layer`` encoder_layer blocks, feeding each layer's output into
    the next one.

    Returns:
        The output of the last layer, or ``enc_input`` unchanged when
        ``n_layer`` is not positive (previously this raised
        UnboundLocalError).
    """
    for _ in range(n_layer):
        enc_input = encoder_layer(
            enc_input,
            attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate,
        )
    return enc_input


def decoder_layer(
    dec_input,
    enc_output,
    slf_attn_bias,
    dec_enc_attn_bias,
    n_head,
    d_key,
    d_value,
    d_model,
    d_inner_hid,
    dropout_rate=0.0,
):
    """
    One decoder layer; several can be stacked to form the decoder.

    Same structure as the encoder layer, with an extra encoder-decoder
    attention sub-layer between the self-attention and the feed-forward
    network. Every sub-layer is followed by ``post_process_layer`` with the
    ``"dan"`` command (dropout, residual connection, layer normalization).
    """
    attn_args = (d_key, d_value, d_model, n_head, dropout_rate)

    # Self-attention over the decoder input.
    slf_out = multi_head_attention(
        dec_input, dec_input, dec_input, slf_attn_bias, *attn_args
    )
    slf_out = post_process_layer(dec_input, slf_out, "dan", dropout_rate)

    # Attention from the decoder states to the encoder output.
    cross_out = multi_head_attention(
        slf_out, enc_output, enc_output, dec_enc_attn_bias, *attn_args
    )
    cross_out = post_process_layer(slf_out, cross_out, "dan", dropout_rate)

    # Position-wise feed-forward network.
    ffn_out = positionwise_feed_forward(cross_out, d_inner_hid, d_model)
    return post_process_layer(cross_out, ffn_out, "dan", dropout_rate)


def decoder(
    dec_input,
    enc_output,
    dec_slf_attn_bias,
    dec_enc_attn_bias,
    n_layer,
    n_head,
    d_key,
    d_value,
    d_model,
    d_inner_hid,
    dropout_rate=0.0,
):
    """
    Stack ``n_layer`` decoder_layer blocks. Each layer attends to
    ``enc_output`` and feeds its output into the next layer.

    Returns:
        The output of the last layer, or ``dec_input`` unchanged when
        ``n_layer`` is not positive (previously this raised
        UnboundLocalError).
    """
    for _ in range(n_layer):
        dec_input = decoder_layer(
            dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate,
        )
    return dec_input


def build_inputs(max_length, n_head):
    """
    Declare the feed variables of the model.

    Token-level inputs (words, positions, ``gold`` and ``weights``) have
    shape ``[batch_size * max_length, 1]``. The three attention-bias inputs
    have shape ``[batch_size, n_head, max_length, max_length]``. No extra
    batch axis is prepended (``append_batch_size=False``).

    Returns:
        A list of data variables in this order: src_word, src_pos,
        trg_word, trg_pos, src_slf_attn_bias, trg_slf_attn_bias,
        trg_src_attn_bias, gold, weights.
    """
    token_shape = [batch_size * max_length, 1]
    bias_shape = [batch_size, n_head, max_length, max_length]
    specs = [
        ('src_word', token_shape, 'int64'),
        ('src_pos', token_shape, 'int64'),
        ('trg_word', token_shape, 'int64'),
        ('trg_pos', token_shape, 'int64'),
        ('src_slf_attn_bias', bias_shape, 'float32'),
        ('trg_slf_attn_bias', bias_shape, 'float32'),
        ('trg_src_attn_bias', bias_shape, 'float32'),
        ('gold', token_shape, 'int64'),
        ('weights', token_shape, 'float32'),
    ]
    # Hand each declaration its own shape list so none of them share one.
    return [
        fluid.layers.data(
            name=name, shape=list(shape), dtype=dtype, append_batch_size=False
        )
        for name, shape, dtype in specs
    ]


def transformer(
    src_vocab_size,
    trg_vocab_size,
    max_length,
    n_layer,
    n_head,
    d_key,
    d_value,
    d_model,
    d_inner_hid,
    dropout_rate,
    src_pad_idx,
    trg_pad_idx,
    pos_pad_idx,
):
    """
    Build the Transformer training graph and return its loss.

    The graph is built in these steps:

    1. Declare the inputs with ``build_inputs``.
    2. Embed the source and target tokens.
    3. Run the encoder and decoder stacks.
    4. Project the decoder output to ``trg_vocab_size`` classes and apply a
       softmax.
    5. Compute the per-token cross-entropy against ``gold``.

    Args:
        src_vocab_size, trg_vocab_size: vocabulary sizes of the two sides.
        max_length: maximum sequence length; also sizes the position tables.
        n_layer, n_head, d_key, d_value, d_model, d_inner_hid: model sizes.
        dropout_rate: dropout probability used throughout the model.
        src_pad_idx, trg_pad_idx: padding indices of the word embeddings.
        pos_pad_idx: padding index of the position-encoding lookups.

    Returns:
        Sum over all tokens of the per-token cross-entropy cost multiplied
        by the ``weights`` input.
    """
    (
        src_word,
        src_pos,
        trg_word,
        trg_pos,
        src_slf_attn_bias,
        trg_slf_attn_bias,
        trg_src_attn_bias,
        gold,
        weights,
    ) = build_inputs(max_length, n_head)

    # pos_pad_idx is forwarded explicitly; it used to be accepted here but
    # silently ignored, so the position lookups always used the default.
    enc_input = prepare_encoder(
        src_word,
        src_pos,
        src_vocab_size,
        d_model,
        src_pad_idx,
        max_length,
        dropout_rate,
        pos_pad_idx=pos_pad_idx,
    )
    enc_output = encoder(
        enc_input,
        src_slf_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
    )

    dec_input = prepare_decoder(
        trg_word,
        trg_pos,
        trg_vocab_size,
        d_model,
        trg_pad_idx,
        max_length,
        dropout_rate,
        pos_pad_idx=pos_pad_idx,
    )
    dec_output = decoder(
        dec_input,
        enc_output,
        trg_slf_attn_bias,
        trg_src_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
    )

    # TODO(guosheng): Share the weight matrix between the embedding layers and
    # the pre-softmax linear transformation.
    predict = paddle.reshape(
        x=layers.fc(
            input=dec_output,
            size=trg_vocab_size,
            param_attr=fluid.initializer.Xavier(uniform=False),
            bias_attr=False,
            num_flatten_dims=2,
        ),
        shape=[-1, trg_vocab_size],
    )
    predict = paddle.nn.functional.softmax(predict)

    # The softmax is already applied above, hence use_softmax=False.
    cost = paddle.nn.functional.cross_entropy(
        input=predict, label=gold, reduction='none', use_softmax=False
    )
    weighted_cost = cost * weights
    return paddle.sum(weighted_cost)