train.py 1.9 KB
Newer Older
C
caoying03 已提交
1 2 3
import os
import logging
import gzip
4 5

import paddle.v2 as paddle
C
caoying03 已提交
6 7 8 9
from network_conf import ngram_lm

# Reuse the logger named "paddle" and make INFO-level messages visible.
logger = logging.getLogger(name="paddle")
logger.setLevel(level=logging.INFO)
10 11


C
caoying03 已提交
12 13 14 15
def main(save_dir="models"):
    """Train the n-gram language model and save its parameters after each pass.

    Builds the imikolov word dictionary, constructs the ``ngram_lm`` cost
    network, and trains it for 30 passes on CPU with an Adam optimizer.
    At the end of every pass the trainer's parameters are written to a
    gzip-compressed tar archive inside ``save_dir``. Every 10th batch
    (batch 0 excluded) the test cost is evaluated and logged.

    Args:
        save_dir: Directory that receives the per-pass parameter archives.
            It is created (including missing parent directories) when it
            does not exist yet.
    """
    if not os.path.exists(save_dir):
        # makedirs also creates missing parents, so nested paths work too.
        os.makedirs(save_dir)

    paddle.init(use_gpu=False, trainer_count=1)
    word_dict = paddle.dataset.imikolov.build_dict(min_word_freq=2)
    dict_size = len(word_dict)

    # Second argument handed to both imikolov readers; kept in one place so
    # the train and test readers cannot drift apart.
    # NOTE(review): presumably the n-gram window length -- confirm against
    # the paddle.dataset.imikolov documentation.
    reader_n = 5

    cost = ngram_lm(hidden_size=256, embed_size=32, dict_size=dict_size)

    parameters = paddle.parameters.create(cost)
    adam_optimizer = paddle.optimizer.Adam(
        learning_rate=3e-3,
        regularization=paddle.optimizer.L2Regularization(8e-4))
    trainer = paddle.trainer.SGD(cost, parameters, adam_optimizer)

    def event_handler(event):
        """Checkpoint at the end of a pass; log test cost every 10 batches."""
        if isinstance(event, paddle.event.EndPass):
            model_name = os.path.join(save_dir, "hsigmoid_pass_%05d.tar.gz" %
                                      event.pass_id)
            logger.info("Save model into %s ...", model_name)
            with gzip.open(model_name, "w") as f:
                trainer.save_parameter_to_tar(f)

        if isinstance(event, paddle.event.EndIteration):
            # A batch_id of 0 is falsy, so the first batch is skipped.
            if event.batch_id and event.batch_id % 10 == 0:
                result = trainer.test(
                    paddle.batch(
                        paddle.dataset.imikolov.test(word_dict, reader_n),
                        32))
                logger.info(
                    "Pass %d, Batch %d, Cost %f, Test Cost %f",
                    event.pass_id, event.batch_id, event.cost, result.cost)

    trainer.train(
        paddle.batch(
            paddle.reader.shuffle(
                lambda: paddle.dataset.imikolov.train(word_dict, reader_n)(),
                buf_size=1000), 64),
        num_passes=30,
        event_handler=event_handler)
56 57


C
caoying03 已提交
58
# Run training with the default save directory when executed as a script.
if __name__ == "__main__":
    main()