# api_train_v2.py
import gzip
import math

import paddle.v2 as paddle

# Model hyper-parameters: word-embedding width, hidden fully-connected layer
# width, and the n-gram length passed to the imikolov dataset readers.
embsize, hiddensize, N = 32, 256, 5


def wordemb(inlayer):
    """Map a word-id data layer to a dense embedding layer of width ``embsize``.

    Every call passes the same parameter name ``"_proj"``. Presumably this
    makes all calls share one embedding table, via Paddle's name-based
    parameter sharing; confirm against the Paddle v2 ``ParamAttr`` docs.

    Args:
        inlayer: a ``paddle.layer.data`` layer whose values are integer word ids.

    Returns:
        The ``paddle.layer.embedding`` layer built on top of ``inlayer``.
    """
    # Named `emb` (not `wordemb`) so the local does not shadow this function.
    emb = paddle.layer.embedding(
        input=inlayer,
        size=embsize,
        param_attr=paddle.attr.Param(
            name="_proj",  # shared name -> one table for all context words
            initial_std=0.001,
            learning_rate=1,
            l2_rate=0,
            # Requests sparse gradient updates for the embedding table.
            sparse_update=True))
    return emb


def _init_paddle(cluster_train):
    """Initialise the Paddle runtime for local (1 trainer) or cluster training."""
    if not cluster_train:
        paddle.init(use_gpu=False, trainer_count=1)
    else:
        paddle.init(
            use_gpu=False,
            trainer_count=2,
            port=7164,
            ports_num=1,
            ports_num_for_sparse=1,
            num_gradient_servers=1)


def _build_cost(dict_size):
    """Build the n-gram network and return its classification cost layer.

    Four context-word data layers are embedded with ``wordemb``, concatenated,
    and passed through a sigmoid hidden layer (with dropout). They then go
    through a softmax over ``dict_size`` words that predicts the fifth word
    (data layer ``"fifthw"``).

    Args:
        dict_size: vocabulary size; the range of the integer word-id inputs
            and the width of the softmax output.

    Returns:
        The ``paddle.layer.classification_cost`` layer.
    """
    # Data layers are created in the original order (four context words,
    # then the target). Presumably the reader's n-gram tuples are fed to
    # data layers in declaration order, so do not reorder these.
    context_words = [
        paddle.layer.data(
            name=name, type=paddle.data_type.integer_value(dict_size))
        for name in ("firstw", "secondw", "thirdw", "fourthw")
    ]
    nextword = paddle.layer.data(
        name="fifthw", type=paddle.data_type.integer_value(dict_size))

    contextemb = paddle.layer.concat(
        input=[wordemb(word) for word in context_words])
    hidden1 = paddle.layer.fc(input=contextemb,
                              size=hiddensize,
                              act=paddle.activation.Sigmoid(),
                              layer_attr=paddle.attr.Extra(drop_rate=0.5),
                              bias_attr=paddle.attr.Param(learning_rate=2),
                              # NOTE(review): the concatenated input is
                              # 4 * embsize wide; the `* 8` factor is kept
                              # from the original - confirm intent.
                              param_attr=paddle.attr.Param(
                                  initial_std=1. / math.sqrt(embsize * 8),
                                  learning_rate=1))
    predictword = paddle.layer.fc(input=hidden1,
                                  size=dict_size,
                                  bias_attr=paddle.attr.Param(learning_rate=2),
                                  act=paddle.activation.Softmax())
    return paddle.layer.classification_cost(input=predictword, label=nextword)


def main(cluster_train=False):
    """Train the n-gram word-embedding model on the imikolov dataset.

    Every 100 batches (including batch 0) the event handler does three things:
    saves the parameters to ``batch-<batch_id>.tar.gz``, evaluates on the
    imikolov test set, and prints the training cost/metrics next to the
    test metrics.

    Args:
        cluster_train: False (default, the original behaviour) for local
            single-trainer training; True to initialise Paddle with the
            hard-coded cluster settings and train with ``is_local=False``.
    """
    _init_paddle(cluster_train)

    word_dict = paddle.dataset.imikolov.build_dict()
    cost = _build_cost(len(word_dict))

    parameters = paddle.parameters.create(cost)
    adagrad = paddle.optimizer.AdaGrad(
        learning_rate=3e-3,
        regularization=paddle.optimizer.L2Regularization(8e-4))
    trainer = paddle.trainer.SGD(cost,
                                 parameters,
                                 adagrad,
                                 is_local=not cluster_train)

    def event_handler(event):
        # Only act at the end of every 100th batch.
        if not isinstance(event, paddle.event.EndIteration):
            return
        if event.batch_id % 100 != 0:
            return
        # Tar data is binary; 'wb' makes that explicit on Python 2 and 3.
        with gzip.open("batch-" + str(event.batch_id) + ".tar.gz",
                       'wb') as f:
            trainer.save_parameter_to_tar(f)
        result = trainer.test(
            paddle.batch(paddle.dataset.imikolov.test(word_dict, N), 32))
        # Parenthesised single argument: valid on both Python 2 and 3.
        print("Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % (
            event.pass_id, event.batch_id, event.cost, event.metrics,
            result.metrics))

    trainer.train(
        paddle.batch(paddle.dataset.imikolov.train(word_dict, N), 32),
        num_passes=30,
        event_handler=event_handler)


if __name__ == '__main__':
    # Start training only when run as a script, not when imported.
    main()