#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import gzip

import paddle.v2 as paddle

import reader
from network_conf import nest_net
from utils import logger, parse_train_cmd


def train(train_data_dir=None,
          test_data_dir=None,
          word_dict_path=None,
          model_save_dir="models",
          batch_size=32,
          num_passes=10,
          class_num=2):
    """
    :params train_data_path: path of training data, if this parameter
        is not specified, imdb dataset will be used to run this example
    :type train_data_path: str
    :params test_data_path: path of testing data, if this parameter
        is not specified, imdb dataset will be used to run this example
    :type test_data_path: str
    :params word_dict_path: path of training data, if this parameter
        is not specified, imdb dataset will be used to run this example
    :type word_dict_path: str
    :params model_save_dir: dir where models saved
    :type num_pass: str
    :params batch_size: train batch size
    :type num_pass: int
    :params num_pass: train pass number
    :type num_pass: int
    """
    if not os.path.exists(model_save_dir):
        os.mkdir(model_save_dir)

    use_default_data = (train_data_dir is None)

    if use_default_data:
        logger.info(("No training data are porivided, "
                     "use imdb to train the model."))
        logger.info("please wait to build the word dictionary ...")

        word_dict = reader.imdb_word_dict()

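        # wrap the raw sample generators: shuffle within a 1000-sample
        # buffer, then group the samples into mini-batches of 100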
        train_reader = paddle.batch(
            paddle.reader.shuffle(
                lambda: reader.imdb_train(word_dict), buf_size=1000),
            batch_size=100)
        test_reader = paddle.batch(
            lambda: reader.imdb_test(word_dict), batch_size=100)
        class_num = 2  # the imdb dataset is a binary classification task
    else:
        if word_dict_path is None or not os.path.exists(word_dict_path):
            logger.info(("word dictionary is not given, the dictionary "
                         "is automatically built from the training data."))

            # build the word dictionary that maps the original string-typed
            # words into integer-typed indices
            reader.build_dict(
                data_dir=train_data_dir,
                save_path=word_dict_path,
                use_col=1,
                cutoff_fre=0)

        word_dict = reader.load_dict(word_dict_path)
        logger.info("class number is : %d." % class_num)

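        # the user-provided training data go through the same
        # shuffle-then-batch pipeline as the default dataset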
        train_reader = paddle.batch(
            paddle.reader.shuffle(
                reader.train_reader(train_data_dir, word_dict), buf_size=1000),
            batch_size=batch_size)

        if test_data_dir is not None:
            # because the training and testing data share the same format,
            # reader.train_reader is reused to read the testing data
            test_reader = paddle.batch(
                paddle.reader.shuffle(
                    reader.train_reader(test_data_dir, word_dict),
                    buf_size=1000),
                batch_size=batch_size)
        else:
            test_reader = None

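    # model hyper-parameters: the dictionary size determines the height of
    # the embedding table; emb_size and hidden_size set the layer widths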
    dict_dim = len(word_dict)
    emb_size = 28
    hidden_size = 128

    logger.info("length of word dictionary is : %d." % (dict_dim))

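    # initialize PaddlePaddle; set use_gpu=False to run on CPU, and adjust
    # trainer_count to the number of available GPUs (or CPU threads)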
    paddle.init(use_gpu=True, trainer_count=4)

    # network config: nest_net returns the cost, the predicted
    # probabilities, and the label layer
    cost, prob, label = nest_net(
        dict_dim, emb_size, hidden_size, class_num, is_infer=False)

    # create parameters
    parameters = paddle.parameters.create(cost)

    # create optimizer
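    # Adam with L2 weight decay; ModelAverage keeps a moving average of the
    # parameters (average_window=0.5 makes the averaging window cover
    # roughly half of the accumulated updates)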
    adam_optimizer = paddle.optimizer.Adam(
        learning_rate=1e-3,
        regularization=paddle.optimizer.L2Regularization(rate=1e-3),
        model_average=paddle.optimizer.ModelAverage(average_window=0.5))

    # create trainer
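    # the AUC evaluator is attached through extra_layers, so AUC is
    # reported together with the cost in the training/test metrics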
    trainer = paddle.trainer.SGD(
        cost=cost,
        extra_layers=paddle.evaluator.auc(input=prob, label=label),
        parameters=parameters,
        update_equation=adam_optimizer)

    # begin training network
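    # feeding maps each data layer name to the index of the corresponding
    # field in the sample tuples yielded by the readers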
    feeding = {"word": 0, "label": 1}

    def _event_handler(event):
        """
        Define end batch and end pass event handler
        """
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 100 == 0:
                logger.info("Pass %d, Batch %d, Cost %f, %s\n" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics))

        if isinstance(event, paddle.event.EndPass):
            if test_reader is not None:
                result = trainer.test(reader=test_reader, feeding=feeding)
                logger.info("Test at Pass %d, %s \n" % (event.pass_id,
                                                        result.metrics))
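            # snapshot the parameters at the end of each pass as a
            # gzipped tar archive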
            with gzip.open(
                    os.path.join(model_save_dir, "params_pass_%05d.tar.gz" %
                                 event.pass_id), "w") as f:
                parameters.to_tar(f)

    trainer.train(
        reader=train_reader,
        event_handler=_event_handler,
        feeding=feeding,
        num_passes=num_passes)

    logger.info("Training has finished.")


def main(args):
    train(
        train_data_dir=args.train_data_dir,
        test_data_dir=args.test_data_dir,
        word_dict_path=args.word_dict,
        batch_size=args.batch_size,
        num_passes=args.num_passes,
        model_save_dir=args.model_save_dir,
        class_num=args.class_num)


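# example invocation (flag names assume that parse_train_cmd in utils.py
# exposes options matching the attribute names used above):
#   python train.py --train_data_dir data/train --word_dict dict.txt \
#       --class_num 2 --batch_size 32 --num_passes 10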
if __name__ == "__main__":
    args = parse_train_cmd()
    if args.train_data_dir is not None:
        assert args.word_dict, ("the parameters train_data_dir and "
                                "word_dict should be set at the same time.")
    main(args)