# network_conf.py: network configuration for the RNN language model.
import paddle.v2 as paddle


# RNN language model network definition.
def rnn_lm(vocab_dim,
           emb_dim,
           hidden_size,
           stacked_rnn_num,
           rnn_type="lstm",
           is_infer=False):
    """
    RNN language model definition.

    :param vocab_dim: size of vocabulary.
    :type vocab_dim: int
    :param emb_dim: dimension of the embedding vector.
    :type emb_dim: int
    :param hidden_size: number of hidden units in each RNN layer.
    :type hidden_size: int
    :param stacked_rnn_num: number of stacked RNN layers; must be >= 1.
    :type stacked_rnn_num: int
    :param rnn_type: the type of RNN cell, either "lstm" or "gru".
    :type rnn_type: str
    :param is_infer: if True, build the inference network (no "target"
        data layer); otherwise build the training network.
    :type is_infer: bool
    :return: in training mode, the classification cost layer; in inference
        mode, the softmax output layer restricted to the last position of
        the input sequence.
    :rtype: LayerOutput
    :raises ValueError: if ``rnn_type`` is not "lstm" or "gru", or if
        ``stacked_rnn_num`` is less than 1.
    """
    # Validate arguments before any layer is created. With
    # stacked_rnn_num < 1 the stacking loop below would never run, and the
    # fc layer would then fail on an undefined RNN output.
    if rnn_type not in ("lstm", "gru"):
        raise ValueError(
            'rnn_type error! Expected "lstm" or "gru", got %r.' % (rnn_type, ))
    if stacked_rnn_num < 1:
        raise ValueError("stacked_rnn_num must be at least 1, got %r." %
                         (stacked_rnn_num, ))

    # input layers
    input_layer = paddle.layer.data(
        name="input", type=paddle.data_type.integer_value_sequence(vocab_dim))
    if not is_infer:
        target = paddle.layer.data(
            name="target",
            type=paddle.data_type.integer_value_sequence(vocab_dim))

    # embedding layer
    input_emb = paddle.layer.embedding(input=input_layer, size=emb_dim)

    # rnn layers: stack `stacked_rnn_num` cells of the chosen type, feeding
    # each cell's output into the next; the first cell reads the embedding.
    rnn_builder = (paddle.networks.simple_lstm
                   if rnn_type == "lstm" else paddle.networks.simple_gru)
    rnn_cell = input_emb
    for _ in range(stacked_rnn_num):
        rnn_cell = rnn_builder(input=rnn_cell, size=hidden_size)

    # fc (fully connected) output layer with a softmax over the vocabulary
    output = paddle.layer.fc(input=[rnn_cell],
                             size=vocab_dim,
                             act=paddle.activation.Softmax())

    if is_infer:
        # keep only the output at the last sequence position
        # (presumably the next-word distribution used for generation)
        return paddle.layer.last_seq(input=output)

    cost = paddle.layer.classification_cost(input=output, label=target)
    return cost