network_conf.py 1.8 KB
Newer Older
Z
zhaopu 已提交
1 2 3 4 5
# coding=utf-8

import paddle.v2 as paddle


C
caoying03 已提交
6 7 8 9 10 11
def rnn_lm(vocab_dim,
           emb_dim,
           hidden_size,
           stacked_rnn_num,
           rnn_type="lstm",
           is_infer=False):
    """
    RNN language model definition.

    Network: integer word-id sequence -> embedding -> ``stacked_rnn_num``
    stacked LSTM or GRU layers -> fully connected softmax layer over the
    vocabulary.

    :param vocab_dim: size of vocabulary.
    :type vocab_dim: int
    :param emb_dim: dimension of the embedding vector.
    :type emb_dim: int
    :param hidden_size: number of hidden units in each RNN layer.
    :type hidden_size: int
    :param stacked_rnn_num: number of stacked RNN layers; must be >= 1.
    :type stacked_rnn_num: int
    :param rnn_type: the type of RNN cell, either "lstm" or "gru".
    :type rnn_type: str
    :param is_infer: if True, build the inference network (no "target"
        data layer and no cost layer).
    :type is_infer: bool
    :return: when ``is_infer`` is False, a ``(cost, output)`` tuple where
        ``cost`` is the classification cost layer and ``output`` is the
        softmax output layer; when ``is_infer`` is True, the ``last_seq``
        layer taken over the softmax output.
    :rtype: tuple of LayerOutput, or LayerOutput when ``is_infer`` is True
    :raises ValueError: if ``rnn_type`` is not "lstm" or "gru", or if
        ``stacked_rnn_num`` is less than 1.
    """
    # Validate the configuration before any layer is created so a bad
    # argument fails fast with a clear message. ValueError subclasses
    # Exception, so callers that caught the old generic Exception still work.
    if rnn_type not in ("lstm", "gru"):
        raise ValueError(
            "rnn_type error! expected 'lstm' or 'gru', got %r" % (rnn_type,))
    if stacked_rnn_num < 1:
        # With zero layers the stacking loop would never run and the fc
        # layer below would see an undefined RNN output.
        raise ValueError(
            "stacked_rnn_num must be >= 1, got %r" % (stacked_rnn_num,))

    # input layers
    input_layer = paddle.layer.data(
        name="input", type=paddle.data_type.integer_value_sequence(vocab_dim))
    if not is_infer:
        target = paddle.layer.data(
            name="target",
            type=paddle.data_type.integer_value_sequence(vocab_dim))

    # embedding layer
    input_emb = paddle.layer.embedding(input=input_layer, size=emb_dim)

    # rnn layers: the first layer reads the embedding, every later layer
    # reads the output of the layer below it.
    if rnn_type == "lstm":
        rnn_builder = paddle.networks.simple_lstm
    else:
        rnn_builder = paddle.networks.simple_gru
    rnn_cell = input_emb
    for _ in range(stacked_rnn_num):
        rnn_cell = rnn_builder(input=rnn_cell, size=hidden_size)

    # fc (fully connected) output layer with softmax over the vocabulary
    output = paddle.layer.fc(
        input=[rnn_cell], size=vocab_dim, act=paddle.activation.Softmax())

    if is_infer:
        # last_seq keeps the output at the last position of each sequence.
        return paddle.layer.last_seq(input=output)

    cost = paddle.layer.classification_cost(input=output, label=target)
    return cost, output