train.py 4.3 KB
Newer Older
C
caoying03 已提交
1
import os
Z
zhaopu 已提交
2
import sys
C
caoying03 已提交
3 4
import gzip

Z
zhaopu 已提交
5
import paddle.v2 as paddle
C
caoying03 已提交
6
import config as conf
Z
zhaopu 已提交
7
import reader
C
caoying03 已提交
8 9
from network_conf import rnn_lm
from utils import logger, build_dict, load_dict
Z
zhaopu 已提交
10 11


C
caoying03 已提交
12 13 14 15 16
def train(topology,
          train_reader,
          test_reader,
          model_save_dir="models",
          num_passes=10):
    """
    Train the model whose cost layer is ``topology`` and save its parameters.

    Parameters are written as gzip-compressed tar archives into
    ``model_save_dir``:

    * every ``conf.save_period_by_batches`` batches (batch 0 is skipped);
    * at the end of every pass.

    When ``test_reader`` is not None, the model is also evaluated on it at
    the end of every pass and the resulting metrics are logged.

    :param topology: cost layer of the model to train.
    :type topology: LayerOutput
    :param train_reader: train data reader.
    :type train_reader: collections.Iterable
    :param test_reader: test data reader, or None to skip evaluation.
    :type test_reader: collections.Iterable
    :param model_save_dir: path to save the trained model; it is created,
        together with any missing parent directories, if it does not exist.
    :type model_save_dir: str
    :param num_passes: number of epochs.
    :type num_passes: int
    """
    if not os.path.exists(model_save_dir):
        # makedirs (unlike mkdir) also creates missing parent directories,
        # so nested paths such as "output/models" work.
        os.makedirs(model_save_dir)

    # initialize PaddlePaddle
    paddle.init(use_gpu=conf.use_gpu, trainer_count=conf.trainer_count)

    # create optimizer
    adam_optimizer = paddle.optimizer.Adam(
        learning_rate=1e-3,
        regularization=paddle.optimizer.L2Regularization(rate=1e-3),
        model_average=paddle.optimizer.ModelAverage(
            average_window=0.5, max_average_window=10000))

    # create parameters
    parameters = paddle.parameters.create(topology)
    # create sum evaluator
    sum_eval = paddle.evaluator.sum(topology)
    # create trainer
    trainer = paddle.trainer.SGD(cost=topology,
                                 parameters=parameters,
                                 update_equation=adam_optimizer,
                                 extra_layers=sum_eval)

    # define the event_handler callback
    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            if not event.batch_id % conf.log_period:
                logger.info("Pass %d, Batch %d, Cost %f, %s" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics))

            # Periodic checkpoint inside a pass; batch 0 is skipped because
            # nothing has been learned in the current pass yet.
            if (not event.batch_id %
                    conf.save_period_by_batches) and event.batch_id:
                save_name = os.path.join(model_save_dir,
                                         "rnn_lm_pass_%05d_batch_%03d.tar.gz" %
                                         (event.pass_id, event.batch_id))
                with gzip.open(save_name, "w") as f:
                    trainer.save_parameter_to_tar(f)

        if isinstance(event, paddle.event.EndPass):
            if test_reader is not None:
                result = trainer.test(reader=test_reader)
                logger.info("Test with Pass %d, %s" %
                            (event.pass_id, result.metrics))
            save_name = os.path.join(model_save_dir, "rnn_lm_pass_%05d.tar.gz" %
                                     (event.pass_id))
            with gzip.open(save_name, "w") as f:
                trainer.save_parameter_to_tar(f)

    logger.info("start training...")
    trainer.train(
        reader=train_reader, event_handler=event_handler, num_passes=num_passes)

    logger.info("Training is finished.")
Z
zhaopu 已提交
84 85 86 87


def main():
    """
    Prepare the vocabulary, build the RNN language model and train it.

    All settings (file paths, model sizes, batch size, number of passes,
    output directory) are read from the ``config`` module imported as
    ``conf``.
    """
    # prepare vocab: build it from the training data when the vocab file is
    # missing or empty, then load it.
    if not (os.path.exists(conf.vocab_file) and
            os.path.getsize(conf.vocab_file)):
        logger.info(("word dictionary does not exist, "
                     "build it from the training data"))
        build_dict(conf.train_file, conf.vocab_file, conf.max_word_num,
                   conf.cutoff_word_fre)
    logger.info("load word dictionary.")
    word_dict = load_dict(conf.vocab_file)
    logger.info("dictionary size = %d" % (len(word_dict)))

    cost = rnn_lm(
        len(word_dict), conf.emb_dim, conf.hidden_size, conf.stacked_rnn_num,
        conf.rnn_type)

    # define reader
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            reader.rnn_reader(file_name=conf.train_file, word_dict=word_dict),
            buf_size=102400),
        batch_size=conf.batch_size)

    # The test reader is optional: it is only created when the test file
    # exists and is non-empty. It must read conf.test_file (not the training
    # file) so that the per-pass test metrics reflect held-out data.
    test_reader = None
    if os.path.exists(conf.test_file) and os.path.getsize(conf.test_file):
        test_reader = paddle.batch(
            paddle.reader.shuffle(
                reader.rnn_reader(
                    file_name=conf.test_file, word_dict=word_dict),
                buf_size=65536),
            batch_size=conf.batch_size)

    train(
        topology=cost,
        train_reader=train_reader,
        test_reader=test_reader,
        model_save_dir=conf.model_save_dir,
        num_passes=conf.num_passes)
Z
zhaopu 已提交
124 125 126 127


if "__main__" == __name__:
    # Only start training when run as a script, not when imported.
    main()