train.py 1.7 KB
Newer Older
C
caoying03 已提交
1
#!/usr/bin/env python
P
pakchoi 已提交
2
# -*- encoding:utf-8 -*-
C
caoying03 已提交
3 4
import os
import logging
P
pakchoi 已提交
5 6
import gzip

C
caoying03 已提交
7 8 9 10 11
import paddle.v2 as paddle
from network_conf import ngram_lm

# Log through the logger named "paddle" and let INFO-level records pass its
# level threshold (handlers/propagation are configured elsewhere, if at all).
logger = logging.getLogger(name="paddle")
logger.setLevel(level=logging.INFO)
P
pakchoi 已提交
12 13


C
caoying03 已提交
14 15 16 17
def train(model_save_dir,
          batch_size=64,
          num_passes=1000,
          learning_rate=1e-4,
          hidden_size=128,
          emb_size=512,
          use_gpu=False,
          trainer_count=1):
    """Train the n-gram language model on the imikolov dataset.

    After every pass the model is evaluated on the imikolov test set and its
    parameters are saved to ``model_save_dir`` as
    ``model_pass_<5-digit pass id>.tar.gz``.

    Args:
        model_save_dir: Directory for per-pass parameter checkpoints. It is
            created, including any missing parent directories, if needed.
        batch_size: Mini-batch size used for both training and testing.
        num_passes: Number of passes over the training data.
        learning_rate: Learning rate of the Adam optimizer.
        hidden_size: Hidden layer size forwarded to ``ngram_lm``.
        emb_size: Word embedding size forwarded to ``ngram_lm``.
        use_gpu: Forwarded to ``paddle.init``.
        trainer_count: Forwarded to ``paddle.init``.

    Raises:
        OSError: if ``model_save_dir`` cannot be created (e.g. a regular
            file already exists at that path).
    """
    # os.makedirs (unlike os.mkdir) also creates missing parents, so nested
    # paths like "output/models" work. isdir() instead of exists() makes a
    # plain file at this path fail here, before any training time is spent,
    # rather than at the first checkpoint save.
    if not os.path.isdir(model_save_dir):
        os.makedirs(model_save_dir)

    paddle.init(use_gpu=use_gpu, trainer_count=trainer_count)
    word_dict = paddle.dataset.imikolov.build_dict()
    dict_size = len(word_dict)

    # Window length passed to the imikolov readers. NOTE(review): presumably
    # this must agree with the number of word inputs ngram_lm declares --
    # confirm against network_conf before changing it.
    gram_size = 5

    optimizer = paddle.optimizer.Adam(learning_rate=learning_rate)
    cost = ngram_lm(
        hidden_size=hidden_size, emb_size=emb_size, dict_size=dict_size)
    parameters = paddle.parameters.create(cost)
    trainer = paddle.trainer.SGD(cost, parameters, optimizer)

    # Build the test reader once. paddle.batch returns a reader creator (a
    # callable); the training reader below is likewise reused for all passes.
    test_reader = paddle.batch(
        paddle.dataset.imikolov.test(word_dict, gram_size), batch_size)

    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            # Report every 10th batch; batch 0 is skipped.
            if event.batch_id and not event.batch_id % 10:
                logger.info("Pass %d, Batch %d, Cost %f", event.pass_id,
                            event.batch_id, event.cost)

        if isinstance(event, paddle.event.EndPass):
            result = trainer.test(test_reader)
            logger.info("Test Pass %d, Cost %f", event.pass_id, result.cost)

            save_path = os.path.join(model_save_dir,
                                     "model_pass_%05d.tar.gz" % event.pass_id)
            logger.info("Save model into %s ...", save_path)
            with gzip.open(save_path, "w") as f:
                parameters.to_tar(f)

    trainer.train(
        paddle.batch(
            paddle.dataset.imikolov.train(word_dict, gram_size), batch_size),
        num_passes=num_passes,
        event_handler=event_handler)
P
pakchoi 已提交
49 50


C
caoying03 已提交
51 52
if __name__ == "__main__":
    # Run as a script: train and write per-pass checkpoints into ./models.
    train("models")