train.py 4.2 KB
Newer Older
C
caoying03 已提交
1
#!/usr/bin/env python
Z
zhaopu 已提交
2
# coding=utf-8
C
caoying03 已提交
3
import os
Z
zhaopu 已提交
4
import sys
C
caoying03 已提交
5 6 7
import gzip
import pdb

Z
zhaopu 已提交
8
import paddle.v2 as paddle
C
caoying03 已提交
9
import config as conf
Z
zhaopu 已提交
10
import reader
C
caoying03 已提交
11 12
from network_conf import rnn_lm
from utils import logger, build_dict, load_dict
Z
zhaopu 已提交
13 14


C
caoying03 已提交
15 16 17 18 19
def train(topology,
          train_reader,
          test_reader,
          model_save_dir="models",
          num_passes=10):
    """
    train model.

    :param topology: cost layer of the model to train.
    :type topology: LayerOutput
    :param train_reader: train data reader.
    :type train_reader: collections.Iterable
    :param test_reader: test data reader. If None, the evaluation run at
                        the end of each pass is skipped.
    :type test_reader: collections.Iterable
    :param model_save_dir: path to save the trained model. It is created,
                           together with any missing parent directories,
                           if it does not exist.
    :type model_save_dir: str
    :param num_passes: number of epoch
    :type num_passes: int
    """
    # os.makedirs (unlike os.mkdir) also creates missing parent directories,
    # so a nested path such as "output/models" works.
    if not os.path.exists(model_save_dir):
        os.makedirs(model_save_dir)

    # initialize PaddlePaddle
    paddle.init(use_gpu=conf.use_gpu, trainer_count=conf.trainer_count)

    # create optimizer
    adam_optimizer = paddle.optimizer.Adam(
        learning_rate=1e-3,
        regularization=paddle.optimizer.L2Regularization(rate=1e-3),
        model_average=paddle.optimizer.ModelAverage(
            average_window=0.5, max_average_window=10000))

    # create parameters
    parameters = paddle.parameters.create(topology)
    # create trainer
    trainer = paddle.trainer.SGD(
        cost=topology, parameters=parameters, update_equation=adam_optimizer)

    def save_parameters(file_name):
        """Serialize `parameters` via `to_tar` into a gzip file in
        `model_save_dir`."""
        save_name = os.path.join(model_save_dir, file_name)
        with gzip.open(save_name, "w") as f:
            parameters.to_tar(f)

    # define the event_handler callback
    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            if not event.batch_id % conf.log_period:
                logger.info("Pass %d, Batch %d, Cost %f, %s" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics))

            # Periodic mid-pass checkpoint. Batch 0 is skipped because the
            # trailing `event.batch_id` term is falsy for it.
            if (not event.batch_id %
                    conf.save_period_by_batches) and event.batch_id:
                save_parameters("rnn_lm_pass_%05d_batch_%03d.tar.gz" %
                                (event.pass_id, event.batch_id))

        if isinstance(event, paddle.event.EndPass):
            if test_reader is not None:
                result = trainer.test(reader=test_reader)
                logger.info("Test with Pass %d, %s" %
                            (event.pass_id, result.metrics))
            # End-of-pass checkpoint, written whether or not a test ran.
            save_parameters("rnn_lm_pass_%05d.tar.gz" % (event.pass_id))

    logger.info("start training...")
    trainer.train(
        reader=train_reader, event_handler=event_handler, num_passes=num_passes)

    logger.info("Training is finished.")
Z
zhaopu 已提交
83 84 85 86


def main():
    """Prepare the vocabulary and data readers, then train the RNN LM.

    All settings (file paths, model sizes, batch size, number of passes,
    save directory) are read from the ``config`` module, imported as
    ``conf``.
    """
    # prepare vocab: build it from the training data when the vocabulary
    # file is missing or empty.
    if not (os.path.exists(conf.vocab_file) and
            os.path.getsize(conf.vocab_file)):
        logger.info(("word dictionary does not exist, "
                     "build it from the training data"))
        build_dict(conf.train_file, conf.vocab_file, conf.max_word_num,
                   conf.cutoff_word_fre)
    logger.info("load word dictionary.")
    word_dict = load_dict(conf.vocab_file)
    logger.info("dictionay size = %d" % (len(word_dict)))

    cost = rnn_lm(
        len(word_dict), conf.emb_dim, conf.hidden_size, conf.stacked_rnn_num,
        conf.rnn_type)

    # define readers
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            reader.rnn_reader(file_name=conf.train_file, word_dict=word_dict),
            buf_size=102400),
        batch_size=conf.batch_size)

    # Evaluate only when a non-empty test file exists. The test reader must
    # read conf.test_file, not the training file, or the reported test
    # metrics would just measure fit on the training data. Note the module
    # alias is `conf`; the name `config` is not bound in this file.
    test_reader = None
    if os.path.exists(conf.test_file) and os.path.getsize(conf.test_file):
        test_reader = paddle.batch(
            paddle.reader.shuffle(
                reader.rnn_reader(file_name=conf.test_file, word_dict=word_dict),
                buf_size=65536),
            batch_size=conf.batch_size)

    train(
        topology=cost,
        train_reader=train_reader,
        test_reader=test_reader,
        model_save_dir=conf.model_save_dir,
        num_passes=conf.num_passes)
Z
zhaopu 已提交
123 124 125 126


if __name__ == "__main__":
    # Run the full vocabulary + training pipeline only when this file is
    # executed as a script, not when it is imported.
    main()