from functools import partial
import numpy as np

import paddle.fluid as fluid
import paddle.fluid.layers as layers

from config import TrainTaskConfig, pos_enc_param_names, \
    encoder_input_data_names, decoder_input_data_names, label_data_names


def position_encoding_init(n_position, d_pos_vec):
    """
    Generate the initial values for the sinusoid position encoding table.
    """
    position_enc = np.array([[
        pos / np.power(10000, 2 * (j // 2) / d_pos_vec)
        for j in range(d_pos_vec)
    ] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
    return position_enc.astype("float32")
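
# A minimal usage sketch of the function above (hypothetical sizes, not part
# of the original training pipeline): the table has one row per position and
# even/odd columns hold the sine/cosine components.
#
#   pos_enc = position_encoding_init(n_position=10, d_pos_vec=8)
#   assert pos_enc.shape == (10, 8)  # row 0 stays all zeros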


def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         n_head=1,
                         dropout_rate=0.,
                         pre_softmax_shape=None,
                         post_softmax_shape=None):
    """
    Multi-Head Attention. Note that attn_bias is added to the logits before
    computing the softmax activation to mask certain selected positions so
    that they will not be considered in the attention weights.
    """
    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: queries, keys and values should all be 3-D tensors.")

    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
                      size=d_key * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_key,
                          fan_out=n_head * d_key),
                      bias_attr=False,
                      num_flatten_dims=2)
        k = layers.fc(input=keys,
                      size=d_key * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_key,
                          fan_out=n_head * d_key),
                      bias_attr=False,
                      num_flatten_dims=2)
        v = layers.fc(input=values,
                      size=d_value * n_head,
                      param_attr=fluid.initializer.Xavier(
                          uniform=False,
                          fan_in=d_model * d_value,
                          fan_out=n_head * d_value),
                      bias_attr=False,
                      num_flatten_dims=2)
        return q, k, v

    def __split_heads(x, n_head):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions and then transpose. Specifically, input a tensor with shape
        [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
        with shape [bs, n_head, max_sequence_length, hidden_dim].
        """
        if n_head == 1:
            return x

        hidden_size = x.shape[-1]
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        reshaped = layers.reshape(
            x=x, shape=[0, -1, n_head, hidden_size // n_head])

        # permute the dimensions into:
        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that it becomes one dimension, the inverse of __split_heads.
        """
        if len(x.shape) == 3: return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # The value 0 in shape attr means copying the corresponding dimension
        # size of the input as the output dimension size.
        return layers.reshape(
            x=trans_x,
            shape=list(map(int, [0, -1, trans_x.shape[2] * trans_x.shape[3]])))

    def scaled_dot_product_attention(q, k, v, attn_bias, d_model,
                                     dropout_rate):
        """
        Scaled Dot-Product Attention
        """
        scaled_q = layers.scale(x=q, scale=d_model**-0.5)
        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
        weights = layers.reshape(
            x=layers.elementwise_add(
                x=product, y=attn_bias) if attn_bias else product,
            shape=[-1, product.shape[-1]],
            actual_shape=pre_softmax_shape,
            act="softmax")
        weights = layers.reshape(
            x=weights, shape=product.shape, actual_shape=post_softmax_shape)
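        # The two reshapes above flatten the attention logits to 2-D for the
        # softmax and then restore the 4-D shape; the pre/post softmax shape
        # inputs carry the actual runtime sizes, since the max length in a
        # batch is not known at compile time.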
        if dropout_rate:
            weights = layers.dropout(
                weights, dropout_prob=dropout_rate, is_test=False)
        out = layers.matmul(weights, v)
        return out

    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

    q = __split_heads(q, n_head)
    k = __split_heads(k, n_head)
    v = __split_heads(v, n_head)

    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model,
                                                  dropout_rate)

    out = __combine_heads(ctx_multiheads)

    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         param_attr=fluid.initializer.Xavier(uniform=False),
                         bias_attr=False,
                         num_flatten_dims=2)
    return proj_out
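
# A shape-level sketch of a typical call (hypothetical dimensions, for
# illustration only): with d_model=512 and n_head=8, inputs of shape
# [batch, seq_len, 512] are projected to 8 heads with d_key=d_value=64,
# attended over, recombined and projected back to [batch, seq_len, 512].
#
#   out = multi_head_attention(q, k, v, attn_bias, d_key=64, d_value=64,
#                              d_model=512, n_head=8, dropout_rate=0.1)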


def positionwise_feed_forward(x, d_inner_hid, d_hid):
    """
    Position-wise Feed-Forward Networks.
    This module consists of two linear transformations with a ReLU activation
    in between, which is applied to each position separately and identically.
    """
    hidden = layers.fc(input=x,
                       size=d_inner_hid,
                       num_flatten_dims=2,
                       param_attr=fluid.initializer.Uniform(
                           low=-(d_hid**-0.5), high=(d_hid**-0.5)),
                       act="relu")
    out = layers.fc(input=hidden,
                    size=d_hid,
                    num_flatten_dims=2,
                    param_attr=fluid.initializer.Uniform(
                        low=-(d_inner_hid**-0.5), high=(d_inner_hid**-0.5)))
    return out
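
# In formula form the block above computes FFN(x) = relu(x W1 + b1) W2 + b2,
# applied identically at every position along the last dimension of x.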


def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
    """
    Add residual connection, layer normalization and dropout to the out tensor
    optionally according to the value of process_cmd.

    This will be used before or after multi-head attention and position-wise
    feed-forward networks.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            out = out + prev_out if prev_out else out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.initializer.Constant(1.),
                bias_attr=fluid.initializer.Constant(0.))
        elif cmd == "d":  # add dropout
            if dropout_rate:
                out = layers.dropout(
                    out, dropout_prob=dropout_rate, is_test=False)
    return out


pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer
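# With the command string "dan" used throughout this file, post-processing
# applies dropout, then adds the residual connection, then layer-normalizes;
# pre_process_layer fixes prev_out to None, so no residual is added.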


def prepare_encoder(src_word,
                    src_pos,
                    src_vocab_size,
                    src_emb_dim,
                    src_max_len,
                    dropout_rate=0.,
                    src_data_shape=None,
                    pos_enc_param_name=None):
    """Add word embeddings and position encodings.
    The output tensor has a shape of:
    [batch_size, max_src_length_in_batch, d_model].

    This module is used at the bottom of the encoder stacks.
    """
    src_word_emb = layers.embedding(
        src_word,
        size=[src_vocab_size, src_emb_dim],
        param_attr=fluid.initializer.Normal(0., 1.))
    src_pos_enc = layers.embedding(
        src_pos,
        size=[src_max_len, src_emb_dim],
        param_attr=fluid.ParamAttr(
            name=pos_enc_param_name, trainable=False))
    enc_input = src_word_emb + src_pos_enc
    enc_input = layers.reshape(
        x=enc_input,
        shape=[-1, src_max_len, src_emb_dim],
        actual_shape=src_data_shape)
    return layers.dropout(
        enc_input, dropout_prob=dropout_rate,
        is_test=False) if dropout_rate else enc_input


prepare_encoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[0])
prepare_decoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[1])
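# prepare_decoder reuses the same logic but looks up its own non-trainable
# position-encoding table through pos_enc_param_names[1].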


def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.,
                  pre_softmax_shape=None,
                  post_softmax_shape=None):
    """The encoder layer that can be stacked to form a deep encoder.

    This module consists of a multi-head (self) attention followed by a
    position-wise feed-forward network, with both components accompanied by
    post_process_layer to add residual connection, layer normalization and
    dropout.
    """
    attn_output = multi_head_attention(
        enc_input, enc_input, enc_input, attn_bias, d_key, d_value, d_model,
        n_head, dropout_rate, pre_softmax_shape, post_softmax_shape)
    attn_output = post_process_layer(enc_input, attn_output, "dan",
                                     dropout_rate)
    ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
    return post_process_layer(attn_output, ffd_output, "dan", dropout_rate)


def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.,
            pre_softmax_shape=None,
            post_softmax_shape=None):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer.
    """
    for i in range(n_layer):
        enc_output = encoder_layer(
            enc_input,
            attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate,
            pre_softmax_shape,
            post_softmax_shape, )
        enc_input = enc_output
    return enc_output


def decoder_layer(dec_input,
                  enc_output,
                  slf_attn_bias,
                  dec_enc_attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.,
                  slf_attn_pre_softmax_shape=None,
                  slf_attn_post_softmax_shape=None,
                  src_attn_pre_softmax_shape=None,
                  src_attn_post_softmax_shape=None):
    """ The layer to be stacked in the decoder part.

    The structure of this module is similar to that in the encoder part except
    that an additional multi-head attention is added to implement
    encoder-decoder attention.
    """
    slf_attn_output = multi_head_attention(
        dec_input,
        dec_input,
        dec_input,
        slf_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate,
        slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape, )
    slf_attn_output = post_process_layer(
        dec_input,
        slf_attn_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    enc_attn_output = multi_head_attention(
        slf_attn_output,
        enc_output,
        enc_output,
        dec_enc_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate,
        src_attn_pre_softmax_shape,
        src_attn_post_softmax_shape, )
    enc_attn_output = post_process_layer(
        slf_attn_output,
        enc_attn_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    ffd_output = positionwise_feed_forward(
        enc_attn_output,
        d_inner_hid,
        d_model, )
    dec_output = post_process_layer(
        enc_attn_output,
        ffd_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    return dec_output


def decoder(dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.,
            slf_attn_pre_softmax_shape=None,
            slf_attn_post_softmax_shape=None,
            src_attn_pre_softmax_shape=None,
            src_attn_post_softmax_shape=None):
    """
    The decoder is composed of a stack of identical decoder_layer layers.
    """
    for i in range(n_layer):
        dec_output = decoder_layer(
            dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate,
            slf_attn_pre_softmax_shape,
            slf_attn_post_softmax_shape,
            src_attn_pre_softmax_shape,
            src_attn_post_softmax_shape, )
        dec_input = dec_output
    return dec_output


def make_inputs(input_data_names,
                n_head,
                d_model,
                max_length,
                is_pos,
                slf_attn_bias_flag,
                src_attn_bias_flag,
                enc_output_flag=False,
                data_shape_flag=True,
                slf_attn_shape_flag=True,
                src_attn_shape_flag=True):
    """
    Define the input data layers for the transformer model.
    """
    input_layers = []
    batch_size = 1  # Only for the infer-shape in compile time.
    # The shapes here act as placeholder and are set to pass the infer-shape in
    # compile time.
    # The actual data shape of word is:
    # [batch_size * max_len_in_batch, 1]
    word = layers.data(
        name=input_data_names[len(input_layers)],
        shape=[batch_size * max_length, 1],
        dtype="int64",
        append_batch_size=False)
    input_layers += [word]
    # This is used for position data or label weight.
    # The actual data shape of pos is:
    # [batch_size * max_len_in_batch, 1]
    pos = layers.data(
        name=input_data_names[len(input_layers)],
        shape=[batch_size * max_length, 1],
        dtype="int64" if is_pos else "float32",
        append_batch_size=False)
    input_layers += [pos]
    if slf_attn_bias_flag:
        # This input is used to remove attention weights on paddings for the
        # encoder and to remove attention weights on subsequent words for the
        # decoder.
        # The actual data shape of slf_attn_bias is:
        # [batch_size, n_head, max_len_in_batch, max_len_in_batch]
        slf_attn_bias = layers.data(
            name=input_data_names[len(input_layers)],
            shape=[batch_size, n_head, max_length, max_length],
            dtype="float32",
            append_batch_size=False)
        input_layers += [slf_attn_bias]
    if src_attn_bias_flag:
        # This input is used to remove attention weights on paddings. It's used
        # in encoder-decoder attention.
        # The actual data shape of src_attn_bias is:
        # [batch_size, n_head, trg_max_len_in_batch, src_max_len_in_batch]
        src_attn_bias = layers.data(
            name=input_data_names[len(input_layers)],
            shape=[batch_size, n_head, max_length, max_length],
            dtype="float32",
            append_batch_size=False)
        input_layers += [src_attn_bias]
    if data_shape_flag:
        # This input is used to reshape the output of embedding layer.
        data_shape = layers.data(
            name=input_data_names[len(input_layers)],
            shape=[3],
            dtype="int32",
            append_batch_size=False)
        input_layers += [data_shape]
    if slf_attn_shape_flag:
        # This shape input is used to reshape before softmax in self attention.
        slf_attn_pre_softmax_shape = layers.data(
            name=input_data_names[len(input_layers)],
            shape=[2],
            dtype="int32",
            append_batch_size=False)
        input_layers += [slf_attn_pre_softmax_shape]
        # This shape input is used to reshape after softmax in self attention.
        slf_attn_post_softmax_shape = layers.data(
            name=input_data_names[len(input_layers)],
            shape=[4],
            dtype="int32",
            append_batch_size=False)
        input_layers += [slf_attn_post_softmax_shape]
    if src_attn_shape_flag:
        # This shape input is used to reshape before softmax in encoder-decoder
        # attention.
        src_attn_pre_softmax_shape = layers.data(
            name=input_data_names[len(input_layers)],
            shape=[2],
            dtype="int32",
            append_batch_size=False)
        input_layers += [src_attn_pre_softmax_shape]
        # This shape input is used to reshape after softmax in encoder-decoder
        # attention.
        src_attn_post_softmax_shape = layers.data(
            name=input_data_names[len(input_layers)],
            shape=[4],
            dtype="int32",
            append_batch_size=False)
        input_layers += [src_attn_post_softmax_shape]
    if enc_output_flag:
        # This input is used in the independent decoder program for inference.
        # The actual data shape of enc_output is:
        # [batch_size, max_len_in_batch, d_model]
        enc_output = layers.data(
            name=input_data_names[len(input_layers)],
            shape=[batch_size, max_length, d_model],
            dtype="float32",
            append_batch_size=False)
        input_layers += [enc_output]

    return input_layers
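
# The input layers are returned in the order they are appended above: word,
# pos, then whichever optional bias, shape and enc_output inputs the flags
# enable. The callers below rely on this order when unpacking the result.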


def transformer(
        src_vocab_size,
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate, ):
    enc_inputs = make_inputs(
        encoder_input_data_names,
        n_head,
        d_model,
        max_length,
        is_pos=True,
        slf_attn_bias_flag=True,
        src_attn_bias_flag=False,
        enc_output_flag=False,
        data_shape_flag=True,
        slf_attn_shape_flag=True,
        src_attn_shape_flag=False)

    enc_output = wrap_encoder(
        src_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        enc_inputs, )

    dec_inputs = make_inputs(
        decoder_input_data_names,
        n_head,
        d_model,
        max_length,
        is_pos=True,
        slf_attn_bias_flag=True,
        src_attn_bias_flag=True,
        enc_output_flag=False,
        data_shape_flag=True,
        slf_attn_shape_flag=True,
        src_attn_shape_flag=True)

    predict = wrap_decoder(
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        dec_inputs,
        enc_output, )

    # Padding indices do not contribute to the total loss. The weights are
    # used to cancel padding indices when calculating the loss.
    gold, weights = make_inputs(
        label_data_names,
        n_head,
        d_model,
        max_length,
        is_pos=False,
        slf_attn_bias_flag=False,
        src_attn_bias_flag=False,
        enc_output_flag=False,
        data_shape_flag=False,
        slf_attn_shape_flag=False,
        src_attn_shape_flag=False)
    cost = layers.softmax_with_cross_entropy(logits=predict, label=gold)
    weighted_cost = cost * weights
    sum_cost = layers.reduce_sum(weighted_cost)
    token_num = layers.reduce_sum(weights)
    avg_cost = sum_cost / token_num
    return sum_cost, avg_cost, predict, token_num
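
# A minimal call sketch with hypothetical hyper-parameter values (illustrative
# only, not taken from the project's config):
#
#   sum_cost, avg_cost, predict, token_num = transformer(
#       src_vocab_size=10000, trg_vocab_size=10000, max_length=256,
#       n_layer=6, n_head=8, d_key=64, d_value=64, d_model=512,
#       d_inner_hid=2048, dropout_rate=0.1)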


def wrap_encoder(src_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 enc_inputs=None):
    """
    The wrapper assembles all the layers needed for the encoder.
    """
    if enc_inputs is None:
        # This is used to build an independent encoder program for inference.
        src_word, src_pos, src_slf_attn_bias, src_data_shape, \
            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
            make_inputs(
                encoder_input_data_names,
                n_head,
                d_model,
                max_length,
                is_pos=True,
                slf_attn_bias_flag=True,
                src_attn_bias_flag=False,
                enc_output_flag=False,
                data_shape_flag=True,
                slf_attn_shape_flag=True,
                src_attn_shape_flag=False)
    else:
        src_word, src_pos, src_slf_attn_bias, src_data_shape, \
            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
            enc_inputs
    enc_input = prepare_encoder(
        src_word,
        src_pos,
        src_vocab_size,
        d_model,
        max_length,
        dropout_rate,
        src_data_shape, )
    enc_output = encoder(
        enc_input,
        src_slf_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape, )
    return enc_output


def wrap_decoder(trg_vocab_size,
                 max_length,
                 n_layer,
                 n_head,
                 d_key,
                 d_value,
                 d_model,
                 d_inner_hid,
                 dropout_rate,
                 dec_inputs=None,
                 enc_output=None):
    """
    The wrapper assembles together all needed layers for the decoder.
    """
663
    if dec_inputs is None:
        # This is used to build an independent decoder program for inference.
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
            trg_data_shape, slf_attn_pre_softmax_shape, \
            slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
            src_attn_post_softmax_shape, enc_output = make_inputs(
                decoder_input_data_names,
                n_head,
                d_model,
                max_length,
                is_pos=True,
                slf_attn_bias_flag=True,
                src_attn_bias_flag=True,
                enc_output_flag=True,
                data_shape_flag=True,
                slf_attn_shape_flag=True,
                src_attn_shape_flag=True)
    else:
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
            trg_data_shape, slf_attn_pre_softmax_shape, \
            slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
            src_attn_post_softmax_shape = dec_inputs

    dec_input = prepare_decoder(
        trg_word,
        trg_pos,
        trg_vocab_size,
        d_model,
        max_length,
        dropout_rate,
        trg_data_shape, )
    dec_output = decoder(
        dec_input,
        enc_output,
        trg_slf_attn_bias,
        trg_src_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape,
        src_attn_pre_softmax_shape,
        src_attn_post_softmax_shape, )
    # Return logits for training and probs for inference.
    predict = layers.reshape(
        x=layers.fc(input=dec_output,
                    size=trg_vocab_size,
                    bias_attr=False,
                    num_flatten_dims=2),
        shape=[-1, trg_vocab_size],
        act="softmax" if dec_inputs is None else None)
    return predict
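
# When dec_inputs is None, wrap_decoder defines its own inputs (including
# enc_output) and applies a softmax, so it can serve as a standalone inference
# program; during training, transformer() passes dec_inputs and the raw logits
# go to softmax_with_cross_entropy instead.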