# network_conf.py
import math
import paddle.v2 as paddle
from paddle.v2.layer import parse_network


def ngram_lm(hidden_size,
             emb_size,
             dict_size,
             gram_num=4,
             is_train=True,
             num_neg_samples=25):
    """Build an n-gram neural language model network.

    Each of the ``gram_num`` context words is read from its own integer data
    layer and looked up through one embedding parameter named "_proj". The
    embeddings are concatenated and passed through a tanh fully-connected
    hidden layer. In training mode the hidden layer feeds an NCE cost layer.
    In inference mode the output parameters named "nce_w" / "nce_b" are
    reused in a softmax layer over the whole vocabulary.

    Args:
        hidden_size: Width of the tanh hidden layer.
        emb_size: Width of each word embedding.
        dict_size: Vocabulary size; used as the data-layer value range and
            as the number of output classes.
        gram_num: Number of context words (one data layer each, named
            "__word00__", "__word01__", ...). Must be >= 1.
        is_train: If True, return the NCE cost layer; otherwise return the
            softmax prediction layer.
        num_neg_samples: Number of negative samples passed to the NCE layer
            (training mode only). Defaults to 25, the previously hard-coded
            value.

    Returns:
        The NCE cost layer when ``is_train`` is True, otherwise a
        softmax-activated mixed layer of size ``dict_size``.

    Raises:
        ValueError: If ``gram_num`` is less than 1.
    """
    if gram_num < 1:
        # concat() over an empty list of embeddings cannot form a layer.
        raise ValueError("gram_num must be >= 1, got %r" % (gram_num, ))

    # Every context position uses the same named parameter ("_proj"), so all
    # positions share a single embedding table.
    embed_param_attr = paddle.attr.Param(
        name="_proj", initial_std=0.001, learning_rate=1, l2_rate=0)
    emb_layers = []
    for i in range(gram_num):
        word = paddle.layer.data(
            name="__word%02d__" % (i),
            type=paddle.data_type.integer_value(dict_size))
        emb_layers.append(
            paddle.layer.embedding(
                input=word, size=emb_size, param_attr=embed_param_attr))

    context_embedding = paddle.layer.concat(input=emb_layers)

    hidden_layer = paddle.layer.fc(
        input=context_embedding,
        size=hidden_size,
        act=paddle.activation.Tanh(),
        # NOTE(review): this layer's fan-in is emb_size * gram_num, but the
        # init scale uses a fixed factor of 8. Kept as-is to preserve the
        # existing initialization -- confirm whether this is intentional.
        param_attr=paddle.attr.Param(initial_std=1. / math.sqrt(emb_size * 8)))

    if is_train:
        # The target word is only consumed by the training cost.
        next_word = paddle.layer.data(
            name="__target_word__",
            type=paddle.data_type.integer_value(dict_size))
        return paddle.layer.nce(input=hidden_layer,
                                label=next_word,
                                num_classes=dict_size,
                                param_attr=paddle.attr.Param(name="nce_w"),
                                bias_attr=paddle.attr.Param(name="nce_b"),
                                num_neg_samples=num_neg_samples,
                                neg_distribution=None)

    # Inference: reuse the trained NCE parameters by name. The weight is
    # applied through a transposed projection, presumably because the NCE
    # layer stores it as (dict_size, hidden_size) -- TODO confirm.
    return paddle.layer.mixed(
        size=dict_size,
        input=paddle.layer.trans_full_matrix_projection(
            hidden_layer, param_attr=paddle.attr.Param(name="nce_w")),
        act=paddle.activation.Softmax(),
        bias_attr=paddle.attr.Param(name="nce_b"))

if __name__ == "__main__":
    # Debug helper: build the training-mode topology and print what
    # parse_network produces for it, to check the network definition.
    # Tune the hyper-parameters below as needed.
    train_cost = ngram_lm(
        hidden_size=256,
        emb_size=256,
        dict_size=1024,
        gram_num=4,
        is_train=True)
    print(parse_network(train_cost))