# model.py
from functools import partial
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from config import TrainTaskConfig, pos_enc_param_names, \
    encoder_input_data_names, decoder_input_data_names, label_data_names


def position_encoding_init(n_position, d_pos_vec):
    """
    Generate the initial values for the sinusoid position encoding table.

    Entry [pos, j] is built from the angle
    ``pos / 10000 ** (2 * (j // 2) / d_pos_vec)``; even columns hold the sine
    of that angle and odd columns the cosine. Row 0 is left all-zero.

    Args:
        n_position: Number of positions (rows) in the table.
        d_pos_vec: Size of each position vector (columns).

    Returns:
        A float32 numpy array of shape [n_position, d_pos_vec]. An empty
        table is returned when n_position is 0.
    """
    positions = np.arange(n_position, dtype="float64")[:, np.newaxis]
    # Float operands force true division, so the exponent is not truncated to
    # zero by integer division on Python 2.
    exponents = 2.0 * (np.arange(d_pos_vec) // 2) / float(d_pos_vec)
    position_enc = positions / np.power(10000.0, exponents)
    # Row 0 is already zero (0 / x) and is deliberately skipped below so that
    # its odd columns are not turned into cos(0) == 1.
    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
    return position_enc.astype("float32")


def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         n_head=1,
                         dropout_rate=0.,
                         pre_softmax_shape=None,
                         post_softmax_shape=None):
    """
    Multi-Head Attention. Note that attn_bias is added to the logit before
    computing softmax activation to mask certain selected positions so that
    they will not be considered in attention weights.

    Args:
        queries, keys, values: 3-D tensors to attend over.
        attn_bias: Tensor added to the attention logits, or None to skip it.
        d_key: Per-head size of the query/key projections.
        d_value: Per-head size of the value projection.
        d_model: Size of the output projection.
        n_head: Number of attention heads.
        dropout_rate: Dropout probability applied to the attention weights;
            a falsy value disables dropout.
        pre_softmax_shape: Passed as actual_shape to the reshape done before
            softmax.
        post_softmax_shape: Passed as actual_shape to the reshape done after
            softmax.

    Returns:
        The attention output projected to size d_model.

    Raises:
        ValueError: If queries, keys and values are not all 3-D.
    """
    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: quries, keys and values should all be 3-D tensors.")

    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
                      size=d_key * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_key,
                          fan_out=n_head * d_key),
                      bias_attr=False,
                      num_flatten_dims=2)
        k = layers.fc(input=keys,
                      size=d_key * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_key,
                          fan_out=n_head * d_key),
                      bias_attr=False,
                      num_flatten_dims=2)
        v = layers.fc(input=values,
                      size=d_value * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_value,
                          fan_out=n_head * d_value),
                      bias_attr=False,
                      num_flatten_dims=2)
        return q, k, v

    def __split_heads(x, n_head):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions and then transpose. Specifically, input a tensor with shape
        [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
        with shape [bs, n_head, max_sequence_length, hidden_dim].
        """
        if n_head == 1:
            return x

        hidden_size = x.shape[-1]
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        reshaped = layers.reshape(
            x=x, shape=[0, -1, n_head, hidden_size // n_head])

        # Permute the dimensions into:
        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that it becomes one dimension, which is reverse to __split_heads.
        """
        if len(x.shape) == 3:
            return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        # A concrete list is built here: a bare map() is a lazy iterator on
        # Python 3 and is not a usable shape argument.
        return layers.reshape(
            x=trans_x,
            shape=[0, -1, int(trans_x.shape[2] * trans_x.shape[3])])

    def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
        """
        Scaled Dot-Product Attention
        """
        # NOTE(review): the logits are scaled by d_model**-0.5, whereas the
        # Transformer paper scales by the per-head key size (d_key**-0.5).
        # Confirm which is intended before changing it, since it alters the
        # numerics of already-trained parameters.
        scaled_q = layers.scale(x=q, scale=d_model**-0.5)
        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
        # Compare against None explicitly rather than relying on the
        # truthiness of a tensor object.
        if attn_bias is not None:
            logits = layers.elementwise_add(x=product, y=attn_bias)
        else:
            logits = product
        # Flatten to 2-D so that softmax runs over the last dimension, then
        # restore the original layout.
        weights = layers.reshape(
            x=logits,
            shape=[-1, product.shape[-1]],
            actual_shape=pre_softmax_shape,
            act="softmax")
        weights = layers.reshape(
            x=weights, shape=product.shape, actual_shape=post_softmax_shape)
        if dropout_rate:
            weights = layers.dropout(
                weights, dropout_prob=dropout_rate, is_test=False)
        out = layers.matmul(weights, v)
        return out

    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

    q = __split_heads(q, n_head)
    k = __split_heads(k, n_head)
    v = __split_heads(v, n_head)

    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model,
                                                  dropout_rate)

    out = __combine_heads(ctx_multiheads)

    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         param_attr=fluid.initializer.Xavier(uniform=False),
                         bias_attr=False,
                         num_flatten_dims=2)
    return proj_out


def positionwise_feed_forward(x, d_inner_hid, d_hid):
    """
    Position-wise Feed-Forward Networks.
    This module consists of two linear transformations with a ReLU activation
    in between, which is applied to each position separately and identically.

    Args:
        x: Input tensor; the fc layers use num_flatten_dims=2, so the first
            two dimensions are kept as they are.
        d_inner_hid: Output size of the first (ReLU) linear layer.
        d_hid: Output size of the second linear layer.

    Returns:
        The output of the second linear layer.
    """
    # Each layer's weights are drawn uniformly from +/- fan_in ** -0.5.
    hidden = layers.fc(input=x,
                       size=d_inner_hid,
                       num_flatten_dims=2,
                       param_attr=fluid.initializer.Uniform(
                           low=-(d_hid**-0.5), high=(d_hid**-0.5)),
                       act="relu")
    out = layers.fc(input=hidden,
                    size=d_hid,
                    num_flatten_dims=2,
                    param_attr=fluid.initializer.Uniform(
                        low=-(d_inner_hid**-0.5), high=(d_inner_hid**-0.5)))
    return out


def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
    """
    Add residual connection, layer normalization and dropout to the out tensor
    optionally according to the value of process_cmd.

    This will be used before or after multi-head attention and position-wise
    feed-forward networks.

    Args:
        prev_out: Tensor added to out for the residual connection, or None to
            skip the residual connection.
        out: Tensor to be processed.
        process_cmd: Iterable of one-character commands applied in order:
            "a" adds the residual connection, "n" applies layer
            normalization, "d" applies dropout. Other characters are ignored.
        dropout_rate: Dropout probability used by "d"; a falsy value makes
            "d" a no-op.

    Returns:
        The processed tensor.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            # Test for None explicitly: the truthiness of a tensor-like
            # object is not a reliable "was it supplied" check.
            if prev_out is not None:
                out = out + prev_out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.initializer.Constant(1.),
                bias_attr=fluid.initializer.Constant(0.))
        elif cmd == "d":  # add dropout
            if dropout_rate:
                out = layers.dropout(
                    out, dropout_prob=dropout_rate, is_test=False)
    return out


# Pre-processing variant: prev_out is bound to None, so the "a" (residual)
# command has nothing to add.
pre_process_layer = partial(pre_post_process_layer, None)
# Post-processing variant: the unmodified function, taking prev_out explicitly.
post_process_layer = pre_post_process_layer


def prepare_encoder(src_word,
                    src_pos,
                    src_vocab_size,
                    src_emb_dim,
                    src_pad_idx,
                    src_max_len,
                    dropout_rate=0.,
                    pos_pad_idx=0,
                    src_data_shape=None,
                    pos_enc_param_name=None):
    """Add word embeddings and position encodings.
    The output tensor has a shape of:
    [batch_size, max_src_length_in_batch, d_model].

    This module is used at the bottom of the encoder stacks.

    Args:
        src_word: Word id input fed to the word embedding lookup.
        src_pos: Position id input fed to the position embedding lookup.
        src_vocab_size: Number of rows of the word embedding table.
        src_emb_dim: Embedding size, shared by both tables.
        src_pad_idx: padding_idx of the word embedding.
        src_max_len: Number of rows of the position embedding table.
        dropout_rate: Dropout probability applied to the summed embeddings;
            a falsy value disables dropout.
        pos_pad_idx: padding_idx of the position embedding.
        src_data_shape: Passed as actual_shape to the final reshape.
        pos_enc_param_name: Parameter name of the position embedding table.

    Returns:
        The sum of word and position embeddings, reshaped and optionally
        passed through dropout.
    """
    src_word_emb = layers.embedding(
        src_word,
        size=[src_vocab_size, src_emb_dim],
        padding_idx=src_pad_idx,
        param_attr=fluid.initializer.Normal(0., 1.))
    # The position table is created as a non-trainable parameter.
    src_pos_enc = layers.embedding(
        src_pos,
        size=[src_max_len, src_emb_dim],
        padding_idx=pos_pad_idx,
        param_attr=fluid.ParamAttr(
            name=pos_enc_param_name, trainable=False))
    enc_input = src_word_emb + src_pos_enc
    enc_input = layers.reshape(
        x=enc_input,
        shape=[-1, src_max_len, src_emb_dim],
        actual_shape=src_data_shape)
    return layers.dropout(
        enc_input, dropout_prob=dropout_rate,
        is_test=False) if dropout_rate else enc_input


# Rebind prepare_encoder so that it uses the first position-encoding
# parameter name by default.
prepare_encoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[0])
# prepare_decoder wraps the already-rebound prepare_encoder; the keyword given
# here overrides the one bound above, selecting the second parameter name.
prepare_decoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[1])


def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.,
                  pre_softmax_shape=None,
                  post_softmax_shape=None):
    """The encoder layers that can be stacked to form a deep encoder.

    This module consists of a multi-head (self) attention followed by
    position-wise feed-forward networks, and both components are accompanied
    by the post_process_layer to add residual connection, layer normalization
    and dropout.

    Args:
        enc_input: Input tensor, used as queries, keys and values.
        attn_bias: Bias passed to the self attention.
        n_head, d_key, d_value, d_model: Attention hyper-parameters, forwarded
            to multi_head_attention.
        d_inner_hid: Inner size of the position-wise feed-forward network.
        dropout_rate: Dropout probability used in attention and in both
            post-processing steps.
        pre_softmax_shape, post_softmax_shape: Forwarded to
            multi_head_attention.

    Returns:
        The output tensor of the layer.
    """
    attn_output = multi_head_attention(
        enc_input, enc_input, enc_input, attn_bias, d_key, d_value, d_model,
        n_head, dropout_rate, pre_softmax_shape, post_softmax_shape)
    # "dan": dropout, then add the residual, then layer normalization.
    attn_output = post_process_layer(enc_input, attn_output, "dan",
                                     dropout_rate)
    ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
    return post_process_layer(attn_output, ffd_output, "dan", dropout_rate)


def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.,
            pre_softmax_shape=None,
            post_softmax_shape=None):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer.

    Args:
        enc_input: Input to the first layer.
        attn_bias: Self-attention bias shared by every layer.
        n_layer: Number of stacked layers.
        n_head, d_key, d_value, d_model, d_inner_hid: Hyper-parameters
            forwarded to every encoder_layer.
        dropout_rate: Dropout probability forwarded to every layer.
        pre_softmax_shape, post_softmax_shape: Forwarded to every layer.

    Returns:
        The output of the last layer, or enc_input unchanged when n_layer
        is 0.
    """
    # Start from the input so that an empty stack returns it as-is instead of
    # failing on an unbound local.
    enc_output = enc_input
    for _ in range(n_layer):
        enc_output = encoder_layer(
            enc_output,
            attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate,
            pre_softmax_shape,
            post_softmax_shape, )
    return enc_output


def decoder_layer(dec_input,
                  enc_output,
                  slf_attn_bias,
                  dec_enc_attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.,
                  slf_attn_pre_softmax_shape=None,
                  slf_attn_post_softmax_shape=None,
                  src_attn_pre_softmax_shape=None,
                  src_attn_post_softmax_shape=None):
    """ The layer to be stacked in decoder part.

    The structure of this module is similar to that in the encoder part except
    a multi-head attention is added to implement encoder-decoder attention.

    Args:
        dec_input: Input tensor of the layer.
        enc_output: Encoder output, used as keys and values of the
            encoder-decoder attention.
        slf_attn_bias: Bias for the decoder self attention.
        dec_enc_attn_bias: Bias for the encoder-decoder attention.
        n_head, d_key, d_value, d_model: Attention hyper-parameters.
        d_inner_hid: Inner size of the position-wise feed-forward network.
        dropout_rate: Dropout probability used throughout the layer.
        slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape: Forwarded to
            the self attention.
        src_attn_pre_softmax_shape, src_attn_post_softmax_shape: Forwarded to
            the encoder-decoder attention.

    Returns:
        The output tensor of the layer.
    """
    slf_attn_output = multi_head_attention(
        dec_input,
        dec_input,
        dec_input,
        slf_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate,
        slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape, )
    slf_attn_output = post_process_layer(
        dec_input,
        slf_attn_output,
        "dan",  # dropout + residual connection + layer normalization
        dropout_rate, )
    # Queries come from the decoder; keys and values from the encoder output.
    enc_attn_output = multi_head_attention(
        slf_attn_output,
        enc_output,
        enc_output,
        dec_enc_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate,
        src_attn_pre_softmax_shape,
        src_attn_post_softmax_shape, )
    enc_attn_output = post_process_layer(
        slf_attn_output,
        enc_attn_output,
        "dan",  # dropout + residual connection + layer normalization
        dropout_rate, )
    ffd_output = positionwise_feed_forward(
        enc_attn_output,
        d_inner_hid,
        d_model, )
    dec_output = post_process_layer(
        enc_attn_output,
        ffd_output,
        "dan",  # dropout + residual connection + layer normalization
        dropout_rate, )
    return dec_output


def decoder(dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.,
            slf_attn_pre_softmax_shape=None,
            slf_attn_post_softmax_shape=None,
            src_attn_pre_softmax_shape=None,
            src_attn_post_softmax_shape=None):
    """
    The decoder is composed of a stack of identical decoder_layer layers.

    Args:
        dec_input: Input to the first layer.
        enc_output: Encoder output, shared by every layer.
        dec_slf_attn_bias: Self-attention bias shared by every layer.
        dec_enc_attn_bias: Encoder-decoder attention bias shared by every
            layer.
        n_layer: Number of stacked layers.
        n_head, d_key, d_value, d_model, d_inner_hid: Hyper-parameters
            forwarded to every decoder_layer.
        dropout_rate: Dropout probability forwarded to every layer.
        slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
        src_attn_pre_softmax_shape, src_attn_post_softmax_shape: Forwarded to
            every layer.

    Returns:
        The output of the last layer, or dec_input unchanged when n_layer
        is 0.
    """
    # Start from the input so that an empty stack returns it as-is instead of
    # failing on an unbound local.
    dec_output = dec_input
    for _ in range(n_layer):
        dec_output = decoder_layer(
            dec_output,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate,
            slf_attn_pre_softmax_shape,
            slf_attn_post_softmax_shape,
            src_attn_pre_softmax_shape,
            src_attn_post_softmax_shape, )
    return dec_output


def make_inputs(input_data_names,
                n_head,
                d_model,
                max_length,
                is_pos,
                slf_attn_bias_flag,
                src_attn_bias_flag,
                enc_output_flag=False,
                data_shape_flag=True,
                slf_attn_shape_flag=True,
                src_attn_shape_flag=True):
    """
    Define the input data layers for the transformer model.

    The layers are created in a fixed order (word, pos, self-attention bias,
    source-attention bias, data shape, self-attention softmax shapes,
    source-attention softmax shapes, encoder output), skipping the optional
    ones whose flag is false. Each created layer takes the next unused name
    from input_data_names.

    Args:
        input_data_names: Names for the data layers, in creation order.
        n_head: Number of attention heads, used in the bias shapes.
        d_model: Model size, used in the encoder-output shape.
        max_length: Maximum sequence length, used in the placeholder shapes.
        is_pos: If true the second input is int64 (position data), otherwise
            float32 (label weight).
        slf_attn_bias_flag: Whether to add the self-attention bias input.
        src_attn_bias_flag: Whether to add the source-attention bias input.
        enc_output_flag: Whether to add the encoder-output input.
        data_shape_flag: Whether to add the data-shape input.
        slf_attn_shape_flag: Whether to add the two self-attention softmax
            shape inputs.
        src_attn_shape_flag: Whether to add the two source-attention softmax
            shape inputs.

    Returns:
        The list of created data layers, in creation order.
    """
    input_layers = []
    batch_size = 1  # Only for the infer-shape in compile time.

    def __add_data_layer(shape, dtype):
        # The name is chosen by position: the number of layers created so far
        # indexes into input_data_names.
        input_layers.append(
            layers.data(
                name=input_data_names[len(input_layers)],
                shape=shape,
                dtype=dtype,
                append_batch_size=False))

    # The shapes here act as placeholder and are set to pass the infer-shape in
    # compile time.
    # The actual data shape of word is:
    # [batch_size * max_len_in_batch, 1]
    __add_data_layer([batch_size * max_length, 1], "int64")
    # This is used for position data or label weight.
    # The actual data shape of pos is:
    # [batch_size * max_len_in_batch, 1]
    __add_data_layer([batch_size * max_length, 1],
                     "int64" if is_pos else "float32")
    if slf_attn_bias_flag:
        # This input is used to remove attention weights on paddings for the
        # encoder and to remove attention weights on subsequent words for the
        # decoder.
        # The actual data shape of slf_attn_bias is:
        # [batch_size, n_head, max_len_in_batch, max_len_in_batch]
        __add_data_layer([batch_size, n_head, max_length, max_length],
                         "float32")
    if src_attn_bias_flag:
        # This input is used to remove attention weights on paddings. It's used
        # in encoder-decoder attention.
        # The actual data shape of src_attn_bias is:
        # [batch_size, n_head, trg_max_len_in_batch, src_max_len_in_batch]
        __add_data_layer([batch_size, n_head, max_length, max_length],
                         "float32")
    if data_shape_flag:
        # This input is used to reshape the output of embedding layer.
        __add_data_layer([3], "int32")
    if slf_attn_shape_flag:
        # This shape input is used to reshape before softmax in self attention.
        __add_data_layer([2], "int32")
        # This shape input is used to reshape after softmax in self attention.
        __add_data_layer([4], "int32")
    if src_attn_shape_flag:
        # Same pair of shape inputs, for the encoder-decoder attention.
        __add_data_layer([2], "int32")
        __add_data_layer([4], "int32")
    if enc_output_flag:
        # This input is used in independent decoder program for inference.
        # The actual data shape of enc_output is:
        # [batch_size, max_len_in_batch, d_model]
        __add_data_layer([batch_size, max_length, d_model], "float32")

    return input_layers


def transformer(
        src_vocab_size,
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        src_pad_idx,
        trg_pad_idx,
        pos_pad_idx, ):
    """
    Build the whole training network: input layers, encoder, decoder and the
    weighted cross-entropy loss.

    Args:
        src_vocab_size: Source vocabulary size.
        trg_vocab_size: Target vocabulary size.
        max_length: Maximum sequence length.
        n_layer: Number of layers in both the encoder and the decoder.
        n_head, d_key, d_value, d_model, d_inner_hid: Model hyper-parameters.
        dropout_rate: Dropout probability.
        src_pad_idx: Padding index of the source word embedding.
        trg_pad_idx: Padding index of the target word embedding.
        pos_pad_idx: Padding index of the position embedding.

    Returns:
        A tuple (sum_cost, avg_cost, predict, token_num): the weighted cost
        summed over all tokens, that sum divided by the sum of the weights,
        the decoder prediction, and the sum of the weights.
    """
    enc_inputs = make_inputs(
        encoder_input_data_names,
        n_head,
        d_model,
        max_length,
        is_pos=True,
        slf_attn_bias_flag=True,
        src_attn_bias_flag=False,
        enc_output_flag=False,
        data_shape_flag=True,
        slf_attn_shape_flag=True,
        src_attn_shape_flag=False)

    enc_output = wrap_encoder(
        src_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        src_pad_idx,
        pos_pad_idx,
        enc_inputs, )

    dec_inputs = make_inputs(
        decoder_input_data_names,
        n_head,
        d_model,
        max_length,
        is_pos=True,
        slf_attn_bias_flag=True,
        src_attn_bias_flag=True,
        enc_output_flag=False,
        data_shape_flag=True,
        slf_attn_shape_flag=True,
        src_attn_shape_flag=True)

    predict = wrap_decoder(
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        trg_pad_idx,
        pos_pad_idx,
        dec_inputs,
        enc_output, )

    # Padding index do not contribute to the total loss. The weights is used to
    # cancel padding index in calculating the loss.
    # With every optional flag off, make_inputs yields exactly two layers:
    # the labels and their float32 weights.
    gold, weights = make_inputs(
        label_data_names,
        n_head,
        d_model,
        max_length,
        is_pos=False,
        slf_attn_bias_flag=False,
        src_attn_bias_flag=False,
        enc_output_flag=False,
        data_shape_flag=False,
        slf_attn_shape_flag=False,
        src_attn_shape_flag=False)
    cost = layers.softmax_with_cross_entropy(logits=predict, label=gold)
    weighted_cost = cost * weights
    sum_cost = layers.reduce_sum(weighted_cost)
    token_num = layers.reduce_sum(weights)
    avg_cost = sum_cost / token_num
    return sum_cost, avg_cost, predict, token_num


def wrap_encoder(src_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 src_pad_idx,
                 pos_pad_idx,
                 enc_inputs=None):
    """
    The wrapper assembles together all needed layers for the encoder.

    Args:
        src_vocab_size: Source vocabulary size.
        max_length: Maximum sequence length.
        n_layer: Number of encoder layers.
        n_head, d_key, d_value, d_model, d_inner_hid: Model hyper-parameters.
        dropout_rate: Dropout probability.
        src_pad_idx: Padding index of the source word embedding.
        pos_pad_idx: Padding index of the position embedding.
        enc_inputs: The six encoder input layers (word, position,
            self-attention bias, data shape, pre-softmax shape, post-softmax
            shape). When None they are created here.

    Returns:
        The output of the encoder stack.
    """
    if enc_inputs is None:
        # This is used to implement independent encoder program in inference.
        src_word, src_pos, src_slf_attn_bias, src_data_shape, \
            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
            make_inputs(
                encoder_input_data_names,
                n_head,
                d_model,
                max_length,
                is_pos=True,
                slf_attn_bias_flag=True,
                src_attn_bias_flag=False,
                enc_output_flag=False,
                data_shape_flag=True,
                slf_attn_shape_flag=True,
                src_attn_shape_flag=False)
    else:
        src_word, src_pos, src_slf_attn_bias, src_data_shape, \
            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
            enc_inputs
    enc_input = prepare_encoder(
        src_word,
        src_pos,
        src_vocab_size,
        d_model,
        src_pad_idx,
        max_length,
        dropout_rate,
        pos_pad_idx,
        src_data_shape, )
    enc_output = encoder(
        enc_input,
        src_slf_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape, )
    return enc_output


def wrap_decoder(trg_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 trg_pad_idx,
                 pos_pad_idx,
                 dec_inputs=None,
                 enc_output=None):
    """
    The wrapper assembles together all needed layers for the decoder.

    Args:
        trg_vocab_size: Target vocabulary size.
        max_length: Maximum sequence length.
        n_layer: Number of decoder layers.
        n_head, d_key, d_value, d_model, d_inner_hid: Model hyper-parameters.
        dropout_rate: Dropout probability.
        trg_pad_idx: Padding index of the target word embedding.
        pos_pad_idx: Padding index of the position embedding.
        dec_inputs: The nine decoder input layers (word, position,
            self-attention bias, source-attention bias, data shape and the
            four softmax shapes). When None they are created here, together
            with an encoder-output input layer.
        enc_output: Encoder output to attend over. Replaced by the created
            input layer when dec_inputs is None.

    Returns:
        A 2-D prediction over trg_vocab_size - 1 classes: softmax
        probabilities when dec_inputs is None, raw logits otherwise.
    """
    if dec_inputs is None:
        # This is used to implement independent decoder program in inference.
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
            trg_data_shape, slf_attn_pre_softmax_shape, \
            slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
            src_attn_post_softmax_shape, enc_output = make_inputs(
                decoder_input_data_names,
                n_head,
                d_model,
                max_length,
                is_pos=True,
                slf_attn_bias_flag=True,
                src_attn_bias_flag=True,
                enc_output_flag=True,
                data_shape_flag=True,
                slf_attn_shape_flag=True,
                src_attn_shape_flag=True)
    else:
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
            trg_data_shape, slf_attn_pre_softmax_shape, \
            slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
            src_attn_post_softmax_shape = dec_inputs

    dec_input = prepare_decoder(
        trg_word,
        trg_pos,
        trg_vocab_size,
        d_model,
        trg_pad_idx,
        max_length,
        dropout_rate,
        pos_pad_idx,
        trg_data_shape, )
    dec_output = decoder(
        dec_input,
        enc_output,
        trg_slf_attn_bias,
        trg_src_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape,
        src_attn_pre_softmax_shape,
        src_attn_post_softmax_shape, )
    # Return logits for training and probs for inference.
    predict = layers.reshape(
        x=layers.fc(
            input=dec_output,
            size=trg_vocab_size - 1,  # To exclude <pad>.
            bias_attr=False,
            num_flatten_dims=2),
        shape=[-1, trg_vocab_size - 1],
        act="softmax" if dec_inputs is None else None)
    return predict