model.py 18.9 KB
Newer Older
1 2
from functools import partial
import numpy as np
Y
ying 已提交
3

L
Luo Tao 已提交
4 5
import paddle.fluid as fluid
import paddle.fluid.layers as layers
Y
ying 已提交
6

7 8
from config import TrainTaskConfig, pos_enc_param_names, \
    encoder_input_data_names, decoder_input_data_names, label_data_names
Y
ying 已提交
9 10 11

# FIXME(guosheng): Remove the batch_size from the model.
# Module-level batch size read from the training config. Several graph
# builders below (the head split/combine reshapes in multi_head_attention,
# prepare_encoder, and the make_inputs calls in transformer / wrap_encoder /
# wrap_decoder) bake this value into reshape and data-layer shapes, so any
# program built from this module is tied to this batch size.
batch_size = TrainTaskConfig.batch_size
12 13 14


def position_encoding_init(n_position, d_pos_vec):
    """
    Generate the initial values for the sinusoid position encoding table.

    For row ``pos`` and column ``j`` the angle is
    ``pos / 10000 ** (2 * (j // 2) / d_pos_vec)``. Sine is applied to even
    columns and cosine to odd columns. Row 0 is kept all-zero, matching the
    default ``pos_pad_idx=0`` that prepare_encoder uses as padding index.

    Args:
        n_position: Number of positions (rows) in the table.
        d_pos_vec: Size of each position encoding vector (columns).

    Returns:
        A float32 numpy array of shape [n_position, d_pos_vec].
    """
    positions = np.arange(n_position, dtype="float64")[:, np.newaxis]
    # Force float arithmetic: with integer operands, Python 2 would truncate
    # the exponent to 0 and every column would collapse to the raw position.
    exponents = 2.0 * (np.arange(d_pos_vec) // 2) / d_pos_vec
    position_enc = positions / np.power(10000.0, exponents)
    position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])  # dim 2i
    position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])  # dim 2i+1
    # Position 0 is reserved for padding and keeps an all-zero encoding.
    # Slicing (rather than indexing row 0) also works when n_position == 0.
    position_enc[:1] = 0.0
    return position_enc.astype("float32")


def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         n_head=1,
                         dropout_rate=0.):
    """
    Multi-Head Attention. Note that attn_bias is added to the logit before
    computing softmax activation to mask certain selected positions so that
    they will not be considered in attention weights.

    Args:
        queries: 3-D query tensor.
        keys: 3-D key tensor.
        values: 3-D value tensor.
        attn_bias: Tensor added to the attention logits before the softmax,
            or None to add no bias.
        d_key: Per-head size of the query/key projections.
        d_value: Per-head size of the value projection.
        d_model: Output size of the final projection.
        n_head: Number of attention heads.
        dropout_rate: Dropout probability on the attention weights; 0
            disables dropout.

    Returns:
        The attention output after projection back to size d_model.

    Raises:
        ValueError: If queries, keys or values is not a 3-D tensor.
    """
    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: queries, keys and values should all be 3-D tensors.")

    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
                      size=d_key * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_key,
                          fan_out=n_head * d_key),
                      bias_attr=False,
                      num_flatten_dims=2)
        k = layers.fc(input=keys,
                      size=d_key * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_key,
                          fan_out=n_head * d_key),
                      bias_attr=False,
                      num_flatten_dims=2)
        v = layers.fc(input=values,
                      size=d_value * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_value,
                          fan_out=n_head * d_value),
                      bias_attr=False,
                      num_flatten_dims=2)
        return q, k, v

    def __split_heads(x, n_head):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions and then transpose. Specifically, input a tensor with shape
        [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
        with shape [bs, n_head, max_sequence_length, hidden_dim].
        """
        if n_head == 1:
            return x

        hidden_size = x.shape[-1]
        # FIXME(guosheng): Decouple the program desc with batch_size.
        reshaped = layers.reshape(
            x=x, shape=[batch_size, -1, n_head, hidden_size // n_head])

        # permute the dimensions into:
        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that it becomes one dimension, which is reverse to __split_heads.
        """
        if len(x.shape) == 3:
            return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # FIXME(guosheng): Decouple the program desc with batch_size.
        # Build a real list: under Python 3 ``map`` returns a lazy iterator,
        # which cannot be used as the reshape shape attribute.
        new_shape = [
            int(dim)
            for dim in [batch_size, -1, trans_x.shape[2] * trans_x.shape[3]]
        ]
        return layers.reshape(x=trans_x, shape=new_shape)

    def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
        """
        Scaled Dot-Product Attention:
        softmax(q * k^T / sqrt(d_key) + attn_bias) * v.
        """

        # FIXME(guosheng): Optimize the shape in reshape_op or softmax_op.

        # The current implementation of softmax_op only supports 2D tensor,
        # consequently it cannot be directly used here.
        # Using reshape_op is not an option either: the shape of product
        # inferred at compile time is not the actual run-time shape, so it
        # cannot be used to set the attribute of reshape_op.
        # So, here define the softmax as a temporary solution.

        def __softmax(x):
            exp_out = layers.exp(x=x)
            sum_out = layers.reduce_sum(exp_out, dim=-1, keep_dim=False)
            # sum_out drops the last axis; axis=0 aligns it with the leading
            # dimensions of exp_out for broadcasting.
            return layers.elementwise_div(x=exp_out, y=sum_out, axis=0)

        # Scale by the per-head key size as scaled dot-product attention
        # specifies; scaling by d_model shrinks the logits too much whenever
        # d_key != d_model.
        scaled_q = layers.scale(x=q, scale=d_key**-0.5)
        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
        # Compare with None explicitly instead of truth-testing a tensor.
        if attn_bias is not None:
            product = layers.elementwise_add(x=product, y=attn_bias)
        weights = __softmax(product)
        if dropout_rate:
            weights = layers.dropout(
                weights, dropout_prob=dropout_rate, is_test=False)
        out = layers.matmul(weights, v)
        return out

    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

    q = __split_heads(q, n_head)
    k = __split_heads(k, n_head)
    v = __split_heads(v, n_head)

    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key,
                                                  dropout_rate)

    out = __combine_heads(ctx_multiheads)

    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         param_attr=fluid.initializer.Xavier(uniform=False),
                         bias_attr=False,
                         num_flatten_dims=2)
    return proj_out


def positionwise_feed_forward(x, d_inner_hid, d_hid):
    """
    Position-wise feed-forward network.

    Two fully connected layers over the trailing dimension of ``x``
    (``num_flatten_dims=2``) with a ReLU after the first one, so the same
    transformation is applied separately and identically at every position.

    Args:
        x: Input tensor.
        d_inner_hid: Width of the hidden (ReLU) layer.
        d_hid: Width of the output layer.

    Returns:
        Output of the second fully connected layer (width d_hid).
    """
    # Uniform init bounds: the first layer uses 1/sqrt(d_hid), the second
    # 1/sqrt(d_inner_hid) (presumably each layer's input width).
    first_bound = d_hid**-0.5
    second_bound = d_inner_hid**-0.5

    hidden = layers.fc(
        input=x,
        size=d_inner_hid,
        num_flatten_dims=2,
        param_attr=fluid.initializer.Uniform(
            low=-first_bound, high=first_bound),
        act="relu")
    return layers.fc(
        input=hidden,
        size=d_hid,
        num_flatten_dims=2,
        param_attr=fluid.initializer.Uniform(
            low=-second_bound, high=second_bound))


def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.):
    """
    Add residual connection, layer normalization and dropout to the out tensor
    optionally according to the value of process_cmd.

    This will be used before or after multi-head attention and position-wise
    feed-forward networks.

    Args:
        prev_out: Tensor added to ``out`` by the "a" command (residual
            connection), or None to skip the residual add.
        out: Tensor to process.
        process_cmd: String of single-character commands applied left to
            right: "a" adds the residual connection, "n" applies layer
            normalization over the last dimension, "d" applies dropout.
            Any other character is ignored.
        dropout: Dropout probability used by "d"; 0 makes "d" a no-op.

    Returns:
        The processed tensor.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            # Compare with None explicitly: truth-testing a tensor is
            # ambiguous (numpy raises) and would wrongly skip a residual
            # that merely evaluates as falsy.
            if prev_out is not None:
                out = out + prev_out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.initializer.Constant(1.),
                bias_attr=fluid.initializer.Constant(0.))
        elif cmd == "d":  # add dropout
            if dropout:
                out = layers.dropout(out, dropout_prob=dropout, is_test=False)
    return out


# Pre-processing has no residual input: prev_out is fixed to None, so the
# residual "a" command adds nothing and callers pass
# (out, process_cmd, dropout).
pre_process_layer = partial(pre_post_process_layer, None)
# Post-processing uses the full (prev_out, out, process_cmd, dropout) form.
post_process_layer = pre_post_process_layer


def prepare_encoder(src_word,
                    src_pos,
                    src_vocab_size,
                    src_emb_dim,
                    src_pad_idx,
                    src_max_len,
                    dropout=0.,
                    pos_pad_idx=0,
                    pos_enc_param_name=None):
    """Add word embeddings and position encodings.

    The output tensor has a shape of:
    [batch_size, max_src_length_in_batch, d_model].

    This module is used at the bottom of the encoder stacks.

    Args:
        src_word: Token id tensor.
        src_pos: Position id tensor.
        src_vocab_size: Number of rows in the word embedding table.
        src_emb_dim: Embedding width.
        src_pad_idx: Padding index of the word embedding.
        src_max_len: Number of rows in the position embedding table.
        dropout: Dropout probability on the summed embeddings; 0 disables it.
        pos_pad_idx: Padding index of the position embedding.
        pos_enc_param_name: Name of the non-trainable position table
            parameter.
    """
    word_embedding = layers.embedding(
        src_word,
        size=[src_vocab_size, src_emb_dim],
        padding_idx=src_pad_idx,
        param_attr=fluid.initializer.Normal(0., 1.))
    # The position table is a named, frozen parameter; its values are
    # presumably loaded from position_encoding_init elsewhere - verify in
    # the training/inference driver.
    position_embedding = layers.embedding(
        src_pos,
        size=[src_max_len, src_emb_dim],
        padding_idx=pos_pad_idx,
        param_attr=fluid.ParamAttr(
            name=pos_enc_param_name, trainable=False))
    summed = word_embedding + position_embedding

    # FIXME(guosheng): Decouple the program desc with batch_size.
    summed = layers.reshape(x=summed, shape=[batch_size, -1, src_emb_dim])
    if not dropout:
        return summed
    return layers.dropout(summed, dropout_prob=dropout, is_test=False)


# Bind both public entry points to their position-encoding parameter names.
# Keep a handle on the unbound builder so each alias wraps it directly
# instead of one partial being nested inside the other.
_prepare_embedding = prepare_encoder
prepare_encoder = partial(
    _prepare_embedding, pos_enc_param_name=pos_enc_param_names[0])
prepare_decoder = partial(
    _prepare_embedding, pos_enc_param_name=pos_enc_param_names[1])


Y
ying 已提交
248 249 250 251 252 253 254 255 256 257 258 259 260 261
def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.):
    """The encoder layers that can be stacked to form a deep encoder.

    Multi-head self-attention followed by a position-wise feed-forward
    network. Each sub-layer's output goes through post_process_layer with
    "dan": dropout, then the residual add with the sub-layer's input, then
    layer normalization.

    Returns:
        The output of this encoder layer.
    """
    self_attn = multi_head_attention(
        enc_input,
        enc_input,
        enc_input,
        attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate)
    self_attn = post_process_layer(enc_input, self_attn, "dan", dropout_rate)

    feed_forward = positionwise_feed_forward(self_attn, d_inner_hid, d_model)
    return post_process_layer(self_attn, feed_forward, "dan", dropout_rate)


def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer. Each layer's output is the next layer's input.

    Returns:
        The output of the last layer, or ``enc_input`` unchanged when
        ``n_layer`` is 0.
    """
    # Thread the running output through the stack; returning it after the
    # loop avoids an unbound local when n_layer == 0.
    for _ in range(n_layer):
        enc_input = encoder_layer(enc_input, attn_bias, n_head, d_key, d_value,
                                  d_model, d_inner_hid, dropout_rate)
    return enc_input


Y
ying 已提交
292 293 294 295 296 297 298 299 300 301 302 303 304 305
def decoder_layer(dec_input,
                  enc_output,
                  slf_attn_bias,
                  dec_enc_attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.):
    """The layer to be stacked in the decoder part.

    Same structure as encoder_layer, plus an encoder-decoder attention
    sub-layer between the self-attention and the feed-forward network. Every
    sub-layer is followed by post_process_layer with "dan", i.e. dropout,
    then the residual add, then layer normalization.

    Returns:
        The output of this decoder layer.
    """
    # Self-attention over the decoder input.
    self_attn = multi_head_attention(
        dec_input,
        dec_input,
        dec_input,
        slf_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate)
    self_attn = post_process_layer(dec_input, self_attn, "dan", dropout_rate)

    # Attention from the decoder states to the encoder output.
    cross_attn = multi_head_attention(
        self_attn,
        enc_output,
        enc_output,
        dec_enc_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate)
    cross_attn = post_process_layer(self_attn, cross_attn, "dan",
                                    dropout_rate)

    feed_forward = positionwise_feed_forward(cross_attn, d_inner_hid, d_model)
    return post_process_layer(cross_attn, feed_forward, "dan", dropout_rate)


Y
ying 已提交
349 350 351 352 353 354 355 356 357 358 359
def decoder(dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.):
    """
    The decoder is composed of a stack of identical decoder_layer layers.
    Each layer's output is the next layer's input; all layers attend to the
    same enc_output.

    Returns:
        The output of the last layer, or ``dec_input`` unchanged when
        ``n_layer`` is 0.
    """
    # Thread the running output through the stack; returning it after the
    # loop avoids an unbound local when n_layer == 0.
    for _ in range(n_layer):
        dec_input = decoder_layer(
            dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate)
    return dec_input


379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425
def make_inputs(input_data_names,
                n_head,
                d_model,
                batch_size,
                max_length,
                slf_attn_bias_flag,
                src_attn_bias_flag,
                pos_flag=1):
    """
    Define the input data layers for the transformer model.

    The shapes only act as placeholders so that compile-time shape inference
    passes.

    Args:
        input_data_names: Names of the data layers.
        n_head: Number of attention heads (used in bias shapes).
        d_model: Model width (used in the encoder-output shape).
        batch_size: Batch size baked into the placeholder shapes.
        max_length: Maximum sequence length baked into the shapes.
        slf_attn_bias_flag: Falsy to skip the third layer. 1 creates an
            attention bias named input_data_names[2] with shape
            [batch_size, n_head, max_length, max_length]. Any other truthy
            value creates input_data_names[-1] with shape
            [batch_size, max_length, d_model] (used for the encoder output).
        src_attn_bias_flag: If truthy, append an attention bias named
            input_data_names[3].
        pos_flag: If truthy the second layer is int64 (positions), otherwise
            float32 (e.g. label weights).

    Returns:
        List of data layers: word, pos, then the optional self-attention
        layer and the optional source-attention bias, in that order.
    """
    n_tokens = batch_size * max_length

    word = layers.data(
        name=input_data_names[0],
        shape=[n_tokens, 1],
        dtype="int64",
        append_batch_size=False)
    # Position ids, or per-token label weights when pos_flag is falsy.
    pos = layers.data(
        name=input_data_names[1],
        shape=[n_tokens, 1],
        dtype="int64" if pos_flag else "float32",
        append_batch_size=False)
    input_layers = [word, pos]

    if slf_attn_bias_flag:
        if slf_attn_bias_flag == 1:
            third_name = input_data_names[2]
            third_shape = [batch_size, n_head, max_length, max_length]
        else:
            # Encoder output fed in as an input (independent decoder program).
            third_name = input_data_names[-1]
            third_shape = [batch_size, max_length, d_model]
        input_layers.append(
            layers.data(
                name=third_name,
                shape=third_shape,
                dtype="float32",
                append_batch_size=False))

    if src_attn_bias_flag:
        input_layers.append(
            layers.data(
                name=input_data_names[3],
                shape=[batch_size, n_head, max_length, max_length],
                dtype="float32",
                append_batch_size=False))
    return input_layers


Y
ying 已提交
426 427 428 429 430 431 432 433 434 435 436 437 438 439
def transformer(
        src_vocab_size,
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        src_pad_idx,
        trg_pad_idx,
        pos_pad_idx, ):
    """
    Build the training graph: encoder inputs and encoder, decoder inputs and
    decoder, then the label inputs and the weighted cross-entropy loss.

    Returns:
        A tuple (sum of the weighted cross-entropy cost, decoder prediction).
    """
    enc_input_layers = make_inputs(
        encoder_input_data_names, n_head, d_model, batch_size, max_length,
        slf_attn_bias_flag=1, src_attn_bias_flag=0)
    enc_output = wrap_encoder(
        src_vocab_size, max_length, n_layer, n_head, d_key, d_value, d_model,
        d_inner_hid, dropout_rate, src_pad_idx, pos_pad_idx,
        enc_input_layers=enc_input_layers)

    dec_input_layers = make_inputs(
        decoder_input_data_names, n_head, d_model, batch_size, max_length,
        slf_attn_bias_flag=1, src_attn_bias_flag=1)
    predict = wrap_decoder(
        trg_vocab_size, max_length, n_layer, n_head, d_key, d_value, d_model,
        d_inner_hid, dropout_rate, trg_pad_idx, pos_pad_idx,
        dec_input_layers=dec_input_layers, enc_output=enc_output)

    # Padding positions must not contribute to the total loss; the per-token
    # weights (presumably 0 on padding) cancel them out.
    gold, weights = make_inputs(
        label_data_names, n_head, d_model, batch_size, max_length,
        slf_attn_bias_flag=0, src_attn_bias_flag=0, pos_flag=0)
    weighted_cost = layers.cross_entropy(input=predict, label=gold) * weights
    return layers.reduce_sum(weighted_cost), predict


def wrap_encoder(src_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 src_pad_idx,
                 pos_pad_idx,
                 enc_input_layers=None):
    """
    The wrapper assembles together all needed layers for the encoder.

    Args:
        src_pad_idx: Padding index of the source word embedding.
        pos_pad_idx: Padding index of the position embedding; forwarded to
            prepare_encoder.
        enc_input_layers: (src_word, src_pos, src_slf_attn_bias) data layers.
            When None they are created here via make_inputs, so the encoder
            can be built as an independent inference program.

    Returns:
        The encoder output tensor.
    """
    if enc_input_layers is None:
        # This is used to implement independent encoder program in inference.
        src_word, src_pos, src_slf_attn_bias = make_inputs(
            encoder_input_data_names, n_head, d_model, batch_size, max_length,
            True, False)
    else:
        src_word, src_pos, src_slf_attn_bias = enc_input_layers
    enc_input = prepare_encoder(
        src_word,
        src_pos,
        src_vocab_size,
        d_model,
        src_pad_idx,
        max_length,
        dropout_rate,
        # Previously accepted but never forwarded, so the default was used.
        pos_pad_idx=pos_pad_idx)
    enc_output = encoder(
        enc_input,
        src_slf_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate)
    return enc_output


def wrap_decoder(trg_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 trg_pad_idx,
                 pos_pad_idx,
                 dec_input_layers=None,
                 enc_output=None):
    """
    The wrapper assembles together all needed layers for the decoder.

    Args:
        trg_pad_idx: Padding index of the target word embedding.
        pos_pad_idx: Padding index of the position embedding; forwarded to
            prepare_decoder.
        dec_input_layers: (trg_word, trg_pos, trg_slf_attn_bias,
            trg_src_attn_bias) data layers. When None, the inputs (including
            the encoder output) are created here via make_inputs for an
            independent inference program, and enc_output is replaced.
        enc_output: Encoder output used when dec_input_layers is given.

    Returns:
        The decoder prediction, reshaped to [-1, trg_vocab_size] with a
        softmax activation.
    """
    if dec_input_layers is None:
        # This is used to implement independent decoder program in inference.
        # No need for trg_slf_attn_bias because of no paddings in inference.
        trg_word, trg_pos, enc_output, trg_src_attn_bias = make_inputs(
            decoder_input_data_names, n_head, d_model, batch_size, max_length,
            2, 1)
        trg_slf_attn_bias = None
    else:
        (trg_word, trg_pos, trg_slf_attn_bias,
         trg_src_attn_bias) = dec_input_layers

    dec_input = prepare_decoder(
        trg_word,
        trg_pos,
        trg_vocab_size,
        d_model,
        trg_pad_idx,
        max_length,
        dropout_rate,
        # Previously accepted but never forwarded, so the default was used.
        pos_pad_idx=pos_pad_idx)
    dec_output = decoder(
        dec_input,
        enc_output,
        trg_slf_attn_bias,
        trg_src_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate)

    predict = layers.reshape(
        x=layers.fc(input=dec_output,
                    size=trg_vocab_size,
                    bias_attr=False,
                    num_flatten_dims=2),
        shape=[-1, trg_vocab_size],
        act="softmax")
    return predict