model.py 16.3 KB
Newer Older
1 2
from functools import partial
import numpy as np
Y
ying 已提交
3

4 5 6
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import paddle.v2.fluid.layers as layers
Y
ying 已提交
7 8 9 10 11

from config import TrainTaskConfig, input_data_names, pos_enc_param_names

# Batch size read once from the training config; the placeholder data-layer
# shapes and several reshapes below are built from this module-level value.
# FIXME(guosheng): Remove out the batch_size from the model.
batch_size = getattr(TrainTaskConfig, "batch_size")


def position_encoding_init(n_position, d_pos_vec):
    """
    Generate the initial values for the sinusoid position encoding table.

    Entry ``[pos, j]`` is built from the angle
    ``pos / 10000 ** (2 * (j // 2) / d_pos_vec)``: ``sin`` is applied on even
    columns (dim 2i) and ``cos`` on odd columns (dim 2i+1). Row 0 is kept as
    all zeros (presumably the padding position, matching the default
    ``pos_pad_idx=0`` used by prepare_encoder).

    Args:
        n_position: Number of positions (rows) in the table.
        d_pos_vec: Size of each position encoding vector (columns).

    Returns:
        A float32 numpy array of shape ``[n_position, d_pos_vec]``.
    """
    # Use float arithmetic explicitly: a plain ``2 * (j // 2) / d_pos_vec``
    # is integer division under Python 2, which truncates every exponent to
    # 0 and degrades the table to sin(pos) / cos(pos).
    positions = np.arange(n_position, dtype="float64")[:, np.newaxis]
    exponents = 2.0 * (np.arange(d_pos_vec) // 2) / d_pos_vec
    # Broadcast [n_position, 1] / [d_pos_vec] -> [n_position, d_pos_vec].
    position_enc = positions / np.power(10000.0, exponents)
    # Row 0 is 0 / x == 0 and is deliberately left untouched (not sin/cos).
    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
    return position_enc.astype("float32")


def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         num_heads=1,
                         dropout_rate=0.):
    """
    Multi-Head Attention. Note that attn_bias is added to the logit before
    computing softmax activation to mask certain selected positions so that
    they will not be considered in attention weights.

    Args:
        queries: 3-D query tensor; presumably [batch_size, q_len, d_model]
            (TODO confirm against callers).
        keys: 3-D key tensor.
        values: 3-D value tensor.
        attn_bias: Tensor added to the scaled dot-product logits before the
            softmax. The transformer inputs in this file declare it as
            [batch_size, n_head, q_len, k_len].
        d_key: Per-head size of the query/key projections.
        d_value: Per-head size of the value projection.
        d_model: Size of the final output projection.
        num_heads: Number of attention heads.
        dropout_rate: Dropout probability on the attention weights; a falsy
            value disables dropout.

    Returns:
        The attention output projected to ``d_model`` units on the last axis.

    Raises:
        ValueError: If queries, keys or values is not a 3-D tensor.
    """
    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: quries, keys and values should all be 3-D tensors.")

    def __compute_qkv(queries, keys, values, num_heads, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
                      size=d_key * num_heads,
                      bias_attr=False,
                      num_flatten_dims=2)
        k = layers.fc(input=keys,
                      size=d_key * num_heads,
                      bias_attr=False,
                      num_flatten_dims=2)
        v = layers.fc(input=values,
                      size=d_value * num_heads,
                      bias_attr=False,
                      num_flatten_dims=2)
        return q, k, v

    def __split_heads(x, num_heads):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions and then transpose. Specifically, input a tensor with shape
        [bs, max_sequence_length, num_heads * hidden_dim] then output a tensor
        with shape [bs, num_heads, max_sequence_length, hidden_dim].
        """
        if num_heads == 1:
            return x

        hidden_size = x.shape[-1]
        # FIXME(guosheng): Decouple the program desc with batch_size.
        reshaped = layers.reshape(
            x=x, shape=[batch_size, -1, num_heads, hidden_size // num_heads])

        # permute the dimensions into:
        # [batch_size, num_heads, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that it becomes one dimension, which is reverse to __split_heads.
        """
        if len(x.shape) == 3:
            return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # FIXME(guosheng): Decouple the program desc with batch_size.
        # Materialize the shape as a real list: under Python 3 ``map`` returns
        # a one-shot iterator instead of the list of ints reshape expects.
        new_shape = [
            int(dim)
            for dim in [batch_size, -1, trans_x.shape[2] * trans_x.shape[3]]
        ]
        return layers.reshape(x=trans_x, shape=new_shape)

    def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
        """
        Scaled Dot-Product Attention
        """

        # FIXME(guosheng): Optimize the shape in reshape_op or softmax_op.

        # The current implementation of softmax_op only supports 2D tensor,
        # consequently it cannot be directly used here.
        # Using reshape_op is not an option either: the shape of product
        # inferred at compile time is not its actual run-time shape, so it
        # cannot be used to set the attribute of reshape_op.
        # So, here define the softmax for temporary solution.

        def __softmax(x):
            # Softmax over the last axis. NOTE(review): the row maximum is not
            # subtracted before exp, so very large logits could overflow.
            exp_out = layers.exp(x=x)
            sum_out = layers.reduce_sum(exp_out, dim=-1, keep_dim=False)
            # axis=0 aligns sum_out (one rank lower) with the leading
            # dimensions of exp_out for broadcasting.
            return layers.elementwise_div(x=exp_out, y=sum_out, axis=0)

        # Scale queries by 1 / sqrt(d_key) before the dot product.
        scaled_q = layers.scale(x=q, scale=d_key**-0.5)
        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
        weights = __softmax(layers.elementwise_add(x=product, y=attn_bias))
        if dropout_rate:
            weights = layers.dropout(
                weights, dropout_prob=dropout_rate, is_test=False)
        out = layers.matmul(weights, v)
        return out

    q, k, v = __compute_qkv(queries, keys, values, num_heads, d_key, d_value)

    q = __split_heads(q, num_heads)
    k = __split_heads(k, num_heads)
    v = __split_heads(v, num_heads)

    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key,
                                                  dropout_rate)

    out = __combine_heads(ctx_multiheads)

    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         bias_attr=False,
                         num_flatten_dims=2)
    return proj_out


def positionwise_feed_forward(x, d_inner_hid, d_hid):
    """
    Position-wise Feed-Forward Networks.

    Two fully-connected layers with a ReLU in between: the first projects to
    ``d_inner_hid`` units, the second (linear) back to ``d_hid`` units.
    Both use ``num_flatten_dims=2`` so the leading two dimensions (presumably
    [batch, seq]) are kept and every position is transformed separately and
    identically.
    """
    inner = layers.fc(input=x, size=d_inner_hid, num_flatten_dims=2,
                      act="relu")
    return layers.fc(input=inner, size=d_hid, num_flatten_dims=2)


def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.):
    """
    Add residual connection, layer normalization and dropout to the out tensor
    optionally according to the value of process_cmd.

    This will be used before or after multi-head attention and position-wise
    feed-forward networks.

    Args:
        prev_out: Tensor added to ``out`` by the residual command, or None to
            skip the residual add.
        out: Tensor to process.
        process_cmd: String of single-character commands applied in order:
            "a" adds the residual connection, "n" applies layer normalization
            over the last dimension, "d" applies dropout. Any other character
            is ignored.
        dropout: Dropout probability used by "d"; a falsy value disables it.

    Returns:
        The processed tensor.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            # Compare with None explicitly rather than relying on the
            # truthiness of a tensor/Variable object.
            if prev_out is not None:
                out = out + prev_out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(out, begin_norm_axis=len(out.shape) - 1)
        elif cmd == "d":  # add dropout
            if dropout:
                out = layers.dropout(out, dropout_prob=dropout, is_test=False)
    return out


# Post-processing takes the full (prev_out, out, process_cmd, dropout) form.
post_process_layer = pre_post_process_layer
# Pre-processing has no residual input: prev_out is pinned to None, leaving
# (out, process_cmd, dropout) for callers.
pre_process_layer = partial(post_process_layer, None)


def prepare_encoder(src_word,
                    src_pos,
                    src_vocab_size,
                    src_emb_dim,
                    src_pad_idx,
                    src_max_len,
                    dropout=0.,
                    pos_pad_idx=0,
                    pos_enc_param_name=None):
    """Add word embeddings and position encodings.

    Word ids are looked up in a [src_vocab_size, src_emb_dim] embedding and
    position ids in a non-trainable [src_max_len, src_emb_dim] table whose
    parameter is named ``pos_enc_param_name``. The two lookups are summed,
    reshaped, and passed through dropout when ``dropout`` is non-zero.

    The output tensor has a shape of:
    [batch_size, max_src_length_in_batch, d_model].

    This module is used at the bottom of the encoder stacks.
    """
    word_emb = layers.embedding(
        src_word, size=[src_vocab_size, src_emb_dim], padding_idx=src_pad_idx)
    frozen_pos_attr = fluid.ParamAttr(
        name=pos_enc_param_name, trainable=False)
    pos_enc = layers.embedding(
        src_pos,
        size=[src_max_len, src_emb_dim],
        padding_idx=pos_pad_idx,
        param_attr=frozen_pos_attr)

    # FIXME(guosheng): Decouple the program desc with batch_size.
    enc_input = layers.reshape(
        x=word_emb + pos_enc, shape=[batch_size, -1, src_emb_dim])
    if not dropout:
        return enc_input
    return layers.dropout(enc_input, dropout_prob=dropout, is_test=False)


# Bind each input-preparation step to its own position-encoding parameter.
# The decoder partial is created from the plain function before
# ``prepare_encoder`` is rebound, so neither partial wraps the other.
prepare_decoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[1])
prepare_encoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[0])


Y
ying 已提交
220 221 222 223 224 225 226 227 228 229 230 231 232 233
def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.):
    """The encoder layers that can be stacked to form a deep encoder.

    This module consists of a multi-head (self) attention followed by
    position-wise feed-forward networks. Each of the two sub-layers is
    followed by post_process_layer with the "dan" command: dropout, then the
    residual add, then layer normalization.
    """
    self_attn = multi_head_attention(
        enc_input, enc_input, enc_input, attn_bias, d_key, d_value, d_model,
        n_head, dropout_rate)
    self_attn = post_process_layer(enc_input, self_attn, "dan", dropout_rate)

    ffn = positionwise_feed_forward(self_attn, d_inner_hid, d_model)
    return post_process_layer(self_attn, ffn, "dan", dropout_rate)


def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer.

    Each layer's output is the next layer's input. When ``n_layer <= 0`` no
    layer is built and ``enc_input`` is returned unchanged.

    Returns:
        The output of the last encoder layer.
    """
    # Seed with the input so the result is defined even for an empty stack.
    enc_output = enc_input
    for _ in range(n_layer):
        enc_output = encoder_layer(enc_output, attn_bias, n_head, d_key,
                                   d_value, d_model, d_inner_hid, dropout_rate)
    return enc_output


Y
ying 已提交
264 265 266 267 268 269 270 271 272 273 274 275 276 277
def decoder_layer(dec_input,
                  enc_output,
                  slf_attn_bias,
                  dec_enc_attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.):
    """ The layer to be stacked in decoder part.

    The structure of this module is similar to that in the encoder part except
    a multi-head attention is added to implement encoder-decoder attention.
    Every sub-layer is followed by post_process_layer with the "dan" command:
    dropout, then the residual add, then layer normalization.
    """
    # Self-attention over the decoder input; slf_attn_bias masks positions.
    slf_attn = multi_head_attention(dec_input, dec_input, dec_input,
                                    slf_attn_bias, d_key, d_value, d_model,
                                    n_head, dropout_rate)
    slf_attn = post_process_layer(dec_input, slf_attn, "dan", dropout_rate)

    # Encoder-decoder attention: queries come from the decoder, keys and
    # values from the encoder output.
    enc_attn = multi_head_attention(slf_attn, enc_output, enc_output,
                                    dec_enc_attn_bias, d_key, d_value, d_model,
                                    n_head, dropout_rate)
    enc_attn = post_process_layer(slf_attn, enc_attn, "dan", dropout_rate)

    # Position-wise feed-forward sub-layer.
    ffn = positionwise_feed_forward(enc_attn, d_inner_hid, d_model)
    return post_process_layer(enc_attn, ffn, "dan", dropout_rate)


Y
ying 已提交
321 322 323 324 325 326 327 328 329 330 331
def decoder(dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.):
    """
    The decoder is composed of a stack of identical decoder_layer layers.

    Each layer's output is the next layer's input; every layer attends to the
    same ``enc_output``. When ``n_layer <= 0`` no layer is built and
    ``dec_input`` is returned unchanged.

    Returns:
        The output of the last decoder layer.
    """
    # Seed with the input so the result is defined even for an empty stack.
    dec_output = dec_input
    for _ in range(n_layer):
        dec_output = decoder_layer(
            dec_output,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate, )
    return dec_output


Y
ying 已提交
351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368
def transformer(
        src_vocab_size,
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        src_pad_idx,
        trg_pad_idx,
        pos_pad_idx, ):
    """
    Build the Transformer training network and return its scalar loss.

    Declares the input data layers named by ``input_data_names`` (indices 0-8:
    src_word, src_pos, trg_word, trg_pos, src_slf_attn_bias,
    trg_slf_attn_bias, trg_src_attn_bias, gold, weights). It then stacks the
    encoder and decoder, projects the decoder output onto the target
    vocabulary with a softmax, and returns the sum of the per-token
    cross-entropy multiplied by the ``weights`` input.

    ``pos_pad_idx`` is used as the padding index of both position-embedding
    lookups.
    """

    def _data_layer(index, shape, dtype):
        # Declare input ``input_data_names[index]`` with a fixed placeholder
        # shape; every input here uses append_batch_size=False.
        return layers.data(
            name=input_data_names[index],
            shape=list(shape),
            dtype=dtype,
            append_batch_size=False)

    # The shapes set here are placeholders that only need to pass the
    # compile-time shape inference; the actual run-time shape of each input is
    # noted next to it.
    token_shape = [batch_size * max_length, 1]
    bias_shape = [batch_size, n_head, max_length, max_length]

    # Run time: [batch_size * max_src_length_in_a_batch, 1].
    src_word = _data_layer(0, token_shape, "int64")
    # Run time: [batch_size * max_src_length_in_a_batch, 1].
    src_pos = _data_layer(1, token_shape, "int64")
    # Run time: [batch_size * max_trg_length_in_a_batch, 1].
    trg_word = _data_layer(2, token_shape, "int64")
    # Run time: [batch_size * max_trg_length_in_a_batch, 1].
    trg_pos = _data_layer(3, token_shape, "int64")
    # Run time:
    # [batch_size, n_head, max_src_length_in_a_batch, max_src_length_in_a_batch].
    # This input is used to remove attention weights on paddings.
    src_slf_attn_bias = _data_layer(4, bias_shape, "float32")
    # Run time:
    # [batch_size, n_head, max_trg_length_in_batch, max_trg_length_in_batch].
    # This is used to remove attention weights on paddings and subsequent words.
    trg_slf_attn_bias = _data_layer(5, bias_shape, "float32")
    # Run time:
    # [batch_size, n_head, max_trg_length_in_batch, max_src_length_in_batch].
    # This is used to remove attention weights on paddings.
    trg_src_attn_bias = _data_layer(6, bias_shape, "float32")

    # Forward pos_pad_idx so the caller's value reaches the position-embedding
    # lookups instead of silently falling back to prepare_encoder's default.
    enc_input = prepare_encoder(
        src_word,
        src_pos,
        src_vocab_size,
        d_model,
        src_pad_idx,
        max_length,
        dropout_rate,
        pos_pad_idx=pos_pad_idx)
    enc_output = encoder(
        enc_input,
        src_slf_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate, )

    dec_input = prepare_decoder(
        trg_word,
        trg_pos,
        trg_vocab_size,
        d_model,
        trg_pad_idx,
        max_length,
        dropout_rate,
        pos_pad_idx=pos_pad_idx)
    dec_output = decoder(
        dec_input,
        enc_output,
        trg_slf_attn_bias,
        trg_src_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate, )

    # TODO(guosheng): Share the weight matrix between the embedding layers and
    # the pre-softmax linear transformation.
    predict = layers.reshape(
        x=layers.fc(input=dec_output,
                    size=trg_vocab_size,
                    bias_attr=False,
                    num_flatten_dims=2),
        shape=[-1, trg_vocab_size],
        act="softmax")
    # Run time: [batch_size * max_trg_length_in_a_batch, 1].
    gold = _data_layer(7, token_shape, "int64")
    cost = layers.cross_entropy(input=predict, label=gold)
    # Run time: [batch_size * max_trg_length_in_a_batch, 1].
    # Padding indices must not contribute to the total loss; these weights
    # cancel them when the loss is computed.
    weights = _data_layer(8, token_shape, "float32")
    weighted_cost = cost * weights
    return layers.reduce_sum(weighted_cost)