train.py 1.7 KB
Newer Older
C
caoying03 已提交
1 2
import os
import logging
P
pakchoi 已提交
3 4
import gzip

C
caoying03 已提交
5 6 7 8 9
import paddle.v2 as paddle
from network_conf import ngram_lm

# Module-wide logger: fetch the logger registered under the name "paddle" and
# set its level to INFO so the progress messages logged by this script are
# not filtered out at the logger level.
logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)
P
pakchoi 已提交
10 11


C
caoying03 已提交
12 13 14 15
def train(model_save_dir):
    """Train the n-gram language model on imikolov, checkpointing every pass.

    Builds the word dictionary from ``paddle.dataset.imikolov``, constructs
    the cost from ``ngram_lm`` and runs CPU training with the Adam optimizer.
    During training the cost is logged every tenth batch; at the end of each
    pass the model is evaluated on the imikolov test reader and its parameters
    are written to ``<model_save_dir>/model_pass_<pass_id>.tar.gz``.

    Args:
        model_save_dir: Directory that receives the per-pass checkpoints. It
            is created, including any missing parent directories, when it
            does not exist yet.

    Raises:
        OSError: If ``model_save_dir`` cannot be created, for example because
            a regular file already occupies that path.
    """
    # makedirs (not mkdir) so nested paths such as "out/run1" work. Testing
    # isdir rather than exists makes a stray file at this path fail right
    # here instead of at the first checkpoint, after a full training pass.
    if not os.path.isdir(model_save_dir):
        os.makedirs(model_save_dir)

    # Shared by the training and test readers so the two cannot drift apart.
    # NOTE(review): ngram_window presumably has to agree with the context
    # size ngram_lm expects -- confirm against network_conf.
    ngram_window = 5
    batch_size = 64

    paddle.init(use_gpu=False, trainer_count=1)
    word_dict = paddle.dataset.imikolov.build_dict()
    dict_size = len(word_dict)

    optimizer = paddle.optimizer.Adam(learning_rate=1e-4)

    cost = ngram_lm(hidden_size=128, emb_size=512, dict_size=dict_size)
    parameters = paddle.parameters.create(cost)
    trainer = paddle.trainer.SGD(cost, parameters, optimizer)

    def event_handler(event):
        """Log training progress and, after each pass, evaluate and save."""
        if isinstance(event, paddle.event.EndIteration):
            # Every 10th batch is reported; batch 0 is excluded because a
            # batch_id of 0 is falsy.
            if event.batch_id and event.batch_id % 10 == 0:
                logger.info("Pass %d, Batch %d, Cost %f", event.pass_id,
                            event.batch_id, event.cost)
        elif isinstance(event, paddle.event.EndPass):
            result = trainer.test(
                paddle.batch(
                    paddle.dataset.imikolov.test(word_dict, ngram_window),
                    batch_size))
            logger.info("Test Pass %d, Cost %f", event.pass_id, result.cost)

            save_path = os.path.join(model_save_dir,
                                     "model_pass_%05d.tar.gz" % event.pass_id)
            logger.info("Save model into %s ...", save_path)
            with gzip.open(save_path, "w") as f:
                trainer.save_parameter_to_tar(f)

    trainer.train(
        paddle.batch(
            paddle.reader.shuffle(
                # The lambda builds a fresh imikolov training reader each
                # time the shuffled reader is invoked.
                lambda: paddle.dataset.imikolov.train(
                    word_dict, ngram_window)(),
                buf_size=1000),
            batch_size),
        num_passes=1000,
        event_handler=event_handler)
P
pakchoi 已提交
50 51


C
caoying03 已提交
52 53
# Script entry point: when this file is executed directly, run training and
# store the per-pass checkpoints under the relative directory "models".
if __name__ == "__main__":
    train("models")