# train.py -- trains the n-gram language model defined in network_conf.py.
import os
import logging
import gzip
import paddle.v2 as paddle
from network_conf import ngram_lm

# Fetch the logger registered under the name "paddle" and let INFO-level
# records through, so the progress and checkpoint messages logged during
# training are not filtered out by this logger.
logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)


def main(save_dir="models"):
    """Train the n-gram language model and checkpoint it after every pass.

    The model is trained for 30 passes on the imikolov training reader
    (shuffled, batch size 64) with the Adam optimizer. Every 10th batch
    (except batch 0) the current cost and the cost on the imikolov test
    reader are logged.

    Args:
        save_dir: Directory that receives one gzip-compressed parameter
            archive per pass, named ``hsigmoid_pass_<pass_id>.tar.gz``.
            It is created, together with any missing parent directories,
            when it does not exist yet.

    Raises:
        OSError: If ``save_dir`` cannot be created and is not already a
            directory.
    """
    # Try to create first and inspect afterwards: this supports nested paths
    # and tolerates the directory appearing between a check and the creation.
    # (Written without ``exist_ok`` so it also runs on Python 2.)
    try:
        os.makedirs(save_dir)
    except OSError:
        if not os.path.isdir(save_dir):
            raise

    paddle.init(use_gpu=False, trainer_count=1)
    word_dict = paddle.dataset.imikolov.build_dict(min_word_freq=2)
    dict_size = len(word_dict)

    adam_optimizer = paddle.optimizer.Adam(
        learning_rate=3e-3,
        regularization=paddle.optimizer.L2Regularization(8e-4))

    cost = ngram_lm(hidden_size=256, embed_size=32, dict_size=dict_size)
    parameters = paddle.parameters.create(cost)

    # Second argument handed to both imikolov readers; train and test must
    # use the same value. Presumably the n-gram window length expected by
    # ngram_lm -- confirm against network_conf before changing it.
    window = 5

    # The test reader never changes, so build it once instead of on every
    # evaluation inside the event handler.
    test_reader = paddle.batch(
        paddle.dataset.imikolov.test(word_dict, window), 32)

    def event_handler(event):
        """Save parameters at the end of a pass; log costs every 10th batch."""
        if isinstance(event, paddle.event.EndPass):
            model_name = os.path.join(
                save_dir, "hsigmoid_pass_%05d.tar.gz" % event.pass_id)
            logger.info("Save model into %s ...", model_name)
            with gzip.open(model_name, "wb") as f:
                parameters.to_tar(f)

        if isinstance(event, paddle.event.EndIteration):
            # The truthiness check deliberately skips batch 0.
            if event.batch_id and event.batch_id % 10 == 0:
                # `trainer` is assigned further down in main(); the closure
                # looks it up only when the handler is actually invoked.
                result = trainer.test(test_reader)
                logger.info(
                    "Pass %d, Batch %d, Cost %f, Test Cost %f",
                    event.pass_id, event.batch_id, event.cost, result.cost)

    trainer = paddle.trainer.SGD(cost, parameters, adam_optimizer)

    # imikolov.train() is passed to shuffle() directly as the reader; no
    # extra lambda wrapper is needed around it.
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.imikolov.train(word_dict, window),
            buf_size=1000), 64)

    trainer.train(
        train_reader,
        num_passes=30,
        event_handler=event_handler)


if __name__ == "__main__":
    # Executed as a script: train with the default output directory.
    main()