from functools import partial
import numpy as np

import paddle.fluid as fluid
import paddle.fluid.layers as layers

from config import TrainTaskConfig, input_data_names, pos_enc_param_names

# FIXME(guosheng): Remove out the batch_size from the model.
batch_size = TrainTaskConfig.batch_size


def position_encoding_init(n_position, d_pos_vec):
    """
    Generate the initial values for the sinusoid position encoding table.
    """
    position_enc = np.array([[
        pos / np.power(10000, 2 * (j // 2) / d_pos_vec)
        for j in range(d_pos_vec)
    ] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
    return position_enc.astype("float32")
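
# Example (sketch): position_encoding_init(5, 8) returns a float32 array of
# shape (5, 8); row 0 is all zeros (reserved for the padding position) and the
# other rows interleave sine (even columns) and cosine (odd columns) values.
# One common way to use the table, assuming the surrounding training script
# follows the usual fluid pattern, is to copy it into the non-trainable
# position-encoding parameters after running the startup program:
#
#     pos_enc = position_encoding_init(n_position=51, d_pos_vec=512)
#     fluid.global_scope().find_var(pos_enc_param_names[0]).get_tensor().set(
#         pos_enc, fluid.CPUPlace())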


def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         n_head=1,
                         dropout_rate=0.):
    """
    Multi-Head Attention. Note that attn_bias is added to the logits before
    computing the softmax activation, to mask certain selected positions so
    that they will not be considered in the attention weights.
    """
    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: queries, keys and values should all be 3-D tensors.")

    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
                      size=d_key * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_key,
                          fan_out=n_head * d_key),
                      bias_attr=False,
                      num_flatten_dims=2)
        k = layers.fc(input=keys,
                      size=d_key * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_key,
                          fan_out=n_head * d_key),
                      bias_attr=False,
                      num_flatten_dims=2)
        v = layers.fc(input=values,
                      size=d_value * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_value,
                          fan_out=n_head * d_value),
                      bias_attr=False,
                      num_flatten_dims=2)
        return q, k, v

    def __split_heads(x, n_head):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions and then transpose. Specifically, input a tensor with shape
        [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
        with shape [bs, n_head, max_sequence_length, hidden_dim].
        """
        if n_head == 1:
            return x

        hidden_size = x.shape[-1]
        # FIXME(guosheng): Decouple the program desc with batch_size.
        reshaped = layers.reshape(
            x=x, shape=[batch_size, -1, n_head, hidden_size // n_head])

        # permute the dimensions into:
        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that they become one dimension, which is the reverse of
        __split_heads.
        """
        if len(x.shape) == 3: return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # FIXME(guosheng): Decouple the program desc with batch_size.
        # Wrap map() in list() so that the shape argument is a concrete list
        # under both Python 2 and Python 3 (map returns an iterator in Py3).
        return layers.reshape(
            x=trans_x,
            shape=list(
                map(int,
                    [batch_size, -1, trans_x.shape[2] * trans_x.shape[3]])))

    def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
        """
        Scaled Dot-Product Attention
        """

        # FIXME(guosheng): Optimize the shape in reshape_op or softmax_op.

        # The current implementation of softmax_op only supports 2-D tensors,
        # so it cannot be used directly here. Reshaping first does not help
        # either, because the shape of `product` inferred at compile time is
        # not its actual run-time shape and thus cannot be used to set the
        # reshape_op attribute. A softmax is therefore defined below as a
        # temporary solution.

        def __softmax(x, eps=1e-9):
            exp_out = layers.exp(x=x)
            sum_out = layers.reduce_sum(exp_out, dim=-1, keep_dim=False)
            return layers.elementwise_div(x=exp_out, y=sum_out, axis=0)

        scaled_q = layers.scale(x=q, scale=d_model**-0.5)
        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
        weights = __softmax(layers.elementwise_add(x=product, y=attn_bias))
        if dropout_rate:
            weights = layers.dropout(
                weights, dropout_prob=dropout_rate, is_test=False)
        out = layers.matmul(weights, v)
        return out

    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

    q = __split_heads(q, n_head)
    k = __split_heads(k, n_head)
    v = __split_heads(v, n_head)

    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model,
                                                  dropout_rate)

    out = __combine_heads(ctx_multiheads)

    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         param_attr=fluid.initializer.Xavier(uniform=False),
                         bias_attr=False,
                         num_flatten_dims=2)
    return proj_out
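
# Shape sketch for multi_head_attention (illustrative, with n_head > 1; T_q and
# T_k denote the padded query and key sequence lengths in a batch):
#   queries: [batch_size, T_q, d_model]; keys/values: [batch_size, T_k, d_model]
#   q after __compute_qkv + __split_heads: [batch_size, n_head, T_q, d_key]
#   k/v after __compute_qkv + __split_heads: [batch_size, n_head, T_k, d_key]
#       and [batch_size, n_head, T_k, d_value]
#   attention weights (attn_bias is added before the softmax):
#       [batch_size, n_head, T_q, T_k]
#   output after __combine_heads + projection: [batch_size, T_q, d_model]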


def positionwise_feed_forward(x, d_inner_hid, d_hid):
    """
    Position-wise Feed-Forward Networks.
    This module consists of two linear transformations with a ReLU activation
    in between, which is applied to each position separately and identically.
    """
    hidden = layers.fc(input=x,
                       size=d_inner_hid,
                       num_flatten_dims=2,
                       param_attr=fluid.initializer.Uniform(
                           low=-(d_hid**-0.5), high=(d_hid**-0.5)),
                       act="relu")
    out = layers.fc(input=hidden,
                    size=d_hid,
                    num_flatten_dims=2,
                    param_attr=fluid.initializer.Uniform(
                        low=-(d_inner_hid**-0.5), high=(d_inner_hid**-0.5)))
    return out


def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.):
    """
    Add residual connection, layer normalization and dropout to the out tensor
    optionally according to the value of process_cmd.

    This will be used before or after multi-head attention and position-wise
    feed-forward networks.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            out = out + prev_out if prev_out is not None else out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.initializer.Constant(1.),
                bias_attr=fluid.initializer.Constant(0.))
        elif cmd == "d":  # add dropout
            if dropout:
                out = layers.dropout(out, dropout_prob=dropout, is_test=False)
    return out


pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer
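
# Example (sketch): post_process_layer(prev, out, "dan", 0.1) walks the command
# string left to right, so it applies dropout ("d") to out, adds the residual
# connection out + prev ("a"), and finally layer-normalizes the result ("n").
# pre_process_layer is the same helper with prev_out bound to None, so the
# residual addition becomes a no-op.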


def prepare_encoder(src_word,
                    src_pos,
                    src_vocab_size,
                    src_emb_dim,
                    src_pad_idx,
                    src_max_len,
                    dropout=0.,
                    pos_pad_idx=0,
                    pos_enc_param_name=None):
    """Add word embeddings and position encodings.
    The output tensor has a shape of:
    [batch_size, max_src_length_in_batch, d_model].

    This module is used at the bottom of the encoder stacks.
    """
    src_word_emb = layers.embedding(
        src_word,
        size=[src_vocab_size, src_emb_dim],
        padding_idx=src_pad_idx,
        param_attr=fluid.initializer.Normal(0., 1.))
    src_pos_enc = layers.embedding(
        src_pos,
        size=[src_max_len, src_emb_dim],
        padding_idx=pos_pad_idx,
        param_attr=fluid.ParamAttr(
            name=pos_enc_param_name, trainable=False))
    enc_input = src_word_emb + src_pos_enc

    # FIXME(guosheng): Decouple the program desc with batch_size.
    enc_input = layers.reshape(x=enc_input, shape=[batch_size, -1, src_emb_dim])
    return layers.dropout(
        enc_input, dropout_prob=dropout,
        is_test=False) if dropout else enc_input


prepare_encoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[0])
prepare_decoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[1])
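
# prepare_encoder and prepare_decoder share the implementation above and differ
# only in which non-trainable position-encoding table they look up:
# pos_enc_param_names[0] on the encoder side and pos_enc_param_names[1] on the
# decoder side.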


def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.):
    """The encoder layers that can be stacked to form a deep encoder.

    This module consists of a multi-head (self) attention sub-layer followed
    by a position-wise feed-forward network, each accompanied by
    post_process_layer to add residual connection, layer normalization and
    dropout.
    """
    attn_output = multi_head_attention(enc_input, enc_input, enc_input,
                                       attn_bias, d_key, d_value, d_model,
                                       n_head, dropout_rate)
    attn_output = post_process_layer(enc_input, attn_output, "dan",
                                     dropout_rate)
    ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
    return post_process_layer(attn_output, ffd_output, "dan", dropout_rate)


def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer.
    """
    for i in range(n_layer):
        enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key, d_value,
                                   d_model, d_inner_hid, dropout_rate)
        enc_input = enc_output
    return enc_output


def decoder_layer(dec_input,
                  enc_output,
                  slf_attn_bias,
                  dec_enc_attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.):
    """ The layer to be stacked in decoder part.

    The structure of this module is similar to that of the encoder layer
    except that an extra multi-head attention is added to implement
    encoder-decoder attention.
    """
    slf_attn_output = multi_head_attention(
        dec_input,
        dec_input,
        dec_input,
        slf_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate, )
    slf_attn_output = post_process_layer(
        dec_input,
        slf_attn_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    enc_attn_output = multi_head_attention(
        slf_attn_output,
        enc_output,
        enc_output,
        dec_enc_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate, )
    enc_attn_output = post_process_layer(
        slf_attn_output,
        enc_attn_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    ffd_output = positionwise_feed_forward(
        enc_attn_output,
        d_inner_hid,
        d_model, )
    dec_output = post_process_layer(
        enc_attn_output,
        ffd_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    return dec_output


def decoder(dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.):
    """
    The decoder is composed of a stack of identical decoder_layer layers.
    """
    for i in range(n_layer):
        dec_output = decoder_layer(
            dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate, )
        dec_input = dec_output
    return dec_output


def transformer(
        src_vocab_size,
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        src_pad_idx,
        trg_pad_idx,
        pos_pad_idx, ):
    # The shapes here act as placeholders.
    # The shapes set here are only used to pass shape inference at compile
    # time. The actual shape of src_word in runtime is:
    # [batch_size * max_src_length_in_a_batch, 1].
    src_word = layers.data(
        name=input_data_names[0],
        shape=[batch_size * max_length, 1],
        dtype="int64",
        append_batch_size=False)
    # The actual shape of src_pos in runtime is:
    # [batch_size * max_src_length_in_a_batch, 1].
    src_pos = layers.data(
        name=input_data_names[1],
        shape=[batch_size * max_length, 1],
        dtype="int64",
        append_batch_size=False)
    # The actual shape of trg_word in runtime is:
    # [batch_size * max_trg_length_in_a_batch, 1].
    trg_word = layers.data(
        name=input_data_names[2],
        shape=[batch_size * max_length, 1],
        dtype="int64",
        append_batch_size=False)
    # The actual shape of trg_pos in runtime is:
    # [batch_size * max_trg_length_in_a_batch, 1].
    trg_pos = layers.data(
        name=input_data_names[3],
        shape=[batch_size * max_length, 1],
        dtype="int64",
        append_batch_size=False)
    # The actual shape of src_slf_attn_bias in runtime is:
    # [batch_size, n_head, max_src_length_in_a_batch, max_src_length_in_a_batch].
    # This input is used to remove attention weights on paddings.
    src_slf_attn_bias = layers.data(
        name=input_data_names[4],
        shape=[batch_size, n_head, max_length, max_length],
        dtype="float32",
        append_batch_size=False)
    # The actual shape of trg_slf_attn_bias in runtime is:
    # [batch_size, n_head, max_trg_length_in_batch, max_trg_length_in_batch].
    # This is used to remove attention weights on paddings and subsequent words.
    trg_slf_attn_bias = layers.data(
        name=input_data_names[5],
        shape=[batch_size, n_head, max_length, max_length],
        dtype="float32",
        append_batch_size=False)
    # The actual shape of trg_src_attn_bias in runtime is:
    # [batch_size, n_head, max_trg_length_in_batch, max_src_length_in_batch].
    # This is used to remove attention weights on paddings.
    trg_src_attn_bias = layers.data(
        name=input_data_names[6],
        shape=[batch_size, n_head, max_length, max_length],
        dtype="float32",
        append_batch_size=False)

    enc_input = prepare_encoder(
        src_word,
        src_pos,
        src_vocab_size,
        d_model,
        src_pad_idx,
        max_length,
        dropout_rate, )
    enc_output = encoder(
        enc_input,
        src_slf_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate, )

    dec_input = prepare_decoder(
        trg_word,
        trg_pos,
        trg_vocab_size,
        d_model,
        trg_pad_idx,
        max_length,
        dropout_rate, )
    dec_output = decoder(
        dec_input,
        enc_output,
        trg_slf_attn_bias,
        trg_src_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate, )

    # TODO(guosheng): Share the weight matrix between the embedding layers and
    # the pre-softmax linear transformation.
    predict = layers.reshape(
        x=layers.fc(input=dec_output,
                    size=trg_vocab_size,
                    param_attr=fluid.initializer.Xavier(uniform=False),
                    bias_attr=False,
                    num_flatten_dims=2),
        shape=[-1, trg_vocab_size],
        act="softmax")
    # The actual shape of gold in runtime is:
    # [batch_size * max_trg_length_in_a_batch, 1].
    gold = layers.data(
        name=input_data_names[7],
        shape=[batch_size * max_length, 1],
        dtype="int64",
        append_batch_size=False)
    cost = layers.cross_entropy(input=predict, label=gold)
    # The actual shape of weights in runtime is:
    # [batch_size * max_trg_length_in_a_batch, 1].
    # Padding indices do not contribute to the total loss. This weight is used
    # to mask out the loss at padding positions.
    weights = layers.data(
        name=input_data_names[8],
        shape=[batch_size * max_length, 1],
        dtype="float32",
        append_batch_size=False)
    weighted_cost = cost * weights
    return layers.reduce_sum(weighted_cost)
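

if __name__ == "__main__":
    # Illustrative smoke test (a sketch, not part of the original training
    # pipeline): build the weighted training cost with placeholder
    # hyper-parameter values. Real values come from the config module.
    sum_cost = transformer(
        src_vocab_size=10000,
        trg_vocab_size=10000,
        max_length=50,
        n_layer=6,
        n_head=8,
        d_key=64,
        d_value=64,
        d_model=512,
        d_inner_hid=1024,
        dropout_rate=0.1,
        src_pad_idx=0,
        trg_pad_idx=0,
        pos_pad_idx=0)
    print("Built the Transformer training cost: %s" % sum_cost.name)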