# api_train_v2.py: trains an N-gram word-embedding language model on the
# paddle.dataset.imikolov data with the PaddlePaddle v2 API.
import gzip
import math

import paddle.v2 as paddle

# Model hyper-parameters:
#   embsize    - width of each word embedding produced by wordemb()
#   hiddensize - width of the sigmoid hidden layer
#   N          - passed to paddle.dataset.imikolov.train/test; presumably the
#                n-gram length (4 context words + 1 target word) -- confirm
embsize, hiddensize, N = 32, 256, 5


def wordemb(inlayer):
    """Embed a word-id input layer into an ``embsize``-wide vector.

    Every call uses the same parameter name ("_proj"). PaddlePaddle shares
    parameters that have the same name, so all context words presumably look
    up one common embedding table.

    Args:
        inlayer: a ``paddle.layer.data`` layer of integer word ids.

    Returns:
        The ``paddle.layer.embedding`` layer built on ``inlayer``.
    """
    # Use a distinct local name so it does not shadow this function's name.
    emb = paddle.layer.embedding(
        input=inlayer,
        size=embsize,
        param_attr=paddle.attr.Param(
            name="_proj",
            initial_std=0.001,
            learning_rate=1,
            l2_rate=0,
            sparse_update=True))
    return emb


def main(cluster_train=False):
    """Build and train the N-gram language model on paddle.dataset.imikolov.

    The network embeds four context words through ``wordemb`` and
    concatenates the embeddings. The result feeds a sigmoid hidden layer
    with dropout, followed by a softmax over the whole dictionary that
    predicts the next word.

    Args:
        cluster_train: when False (the default, as before), PaddlePaddle is
            initialised for local training with one trainer. When True, it
            uses the cluster port/server settings below and the SGD trainer
            is created with ``is_local=False``.
    """
    if not cluster_train:
        paddle.init(use_gpu=False, trainer_count=1)
    else:
        paddle.init(
            use_gpu=False,
            trainer_count=2,
            port=7164,
            ports_num=1,
            ports_num_for_sparse=1,
            num_gradient_servers=1)

    word_dict = paddle.dataset.imikolov.build_dict()
    dict_size = len(word_dict)

    def word_input(name):
        # One integer word-id input layer ranging over the whole dictionary.
        return paddle.layer.data(
            name=name, type=paddle.data_type.integer_value(dict_size))

    # Layers are created in the same order as before (four context words,
    # then the target word). This presumably matches the field order of the
    # imikolov reader tuples -- confirm before reordering.
    context_words = [
        word_input(name) for name in ("firstw", "secondw", "thirdw", "fourthw")
    ]
    nextword = word_input("fifthw")

    contextemb = paddle.layer.concat(
        input=[wordemb(word) for word in context_words])
    hidden1 = paddle.layer.fc(input=contextemb,
                              size=hiddensize,
                              act=paddle.activation.Sigmoid(),
                              layer_attr=paddle.attr.Extra(drop_rate=0.5),
                              bias_attr=paddle.attr.Param(learning_rate=2),
                              # NOTE(review): the concatenated input has
                              # embsize * 4 features; confirm that
                              # embsize * 8 is the intended fan-in here.
                              param_attr=paddle.attr.Param(
                                  initial_std=1. / math.sqrt(embsize * 8),
                                  learning_rate=1))
    predictword = paddle.layer.fc(input=hidden1,
                                  size=dict_size,
                                  bias_attr=paddle.attr.Param(learning_rate=2),
                                  act=paddle.activation.Softmax())

    def event_handler(event):
        """Every 100 batches, checkpoint the parameters and print test metrics."""
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 100 == 0:
                # `trainer` is assigned below, before train() invokes this
                # handler, so the closure lookup succeeds at call time.
                with gzip.open("batch-" + str(event.batch_id), 'w') as f:
                    trainer.save_parameter_to_tar(f)
                result = trainer.test(
                    paddle.batch(
                        paddle.dataset.imikolov.test(word_dict, N), 32))
                # Single-argument print(...) behaves the same on Python 2 and 3.
                print("Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics,
                    result.metrics))

    cost = paddle.layer.classification_cost(input=predictword, label=nextword)

    parameters = paddle.parameters.create(cost)
    adagrad = paddle.optimizer.AdaGrad(
        learning_rate=3e-3,
        regularization=paddle.optimizer.L2Regularization(8e-4))
    trainer = paddle.trainer.SGD(cost,
                                 parameters,
                                 adagrad,
                                 is_local=not cluster_train)
    trainer.train(
        paddle.batch(paddle.dataset.imikolov.train(word_dict, N), 32),
        num_passes=30,
        event_handler=event_handler)


if __name__ == '__main__':
    # Start training only when run as a script, never on import.
    main()