train.py
import os
import sys
import gzip

import paddle.v2 as paddle

import reader
from utils import logger, parse_train_cmd, build_dict, load_dict
from network_conf import fc_net, convolution_net


def train(topology,
          train_data_dir=None,
          test_data_dir=None,
          word_dict_path=None,
          label_dict_path=None,
          model_save_dir="models",
          batch_size=32,
          num_passes=10):
    """
    Train the text classification model.

    :params topology: function that builds the network topology
        (fc_net or convolution_net).
    :type topology: callable
    :params train_data_dir: path of the training data directory. If it is
        not specified, paddle.dataset.imdb will be used to run this example.
    :type train_data_dir: str
    :params test_data_dir: path of the testing data directory. If it is
        not specified, paddle.dataset.imdb will be used to run this example.
    :type test_data_dir: str
    :params word_dict_path: path of the word dictionary. If it is not
        specified, the dictionary will be built from the training data.
    :type word_dict_path: str
    :params label_dict_path: path of the label dictionary. If it is not
        specified, the dictionary will be built from the training data.
    :type label_dict_path: str
    :params model_save_dir: directory to which model snapshots are saved.
    :type model_save_dir: str
    :params batch_size: number of training examples in one mini-batch.
    :type batch_size: int
    :params num_passes: number of passes over the training data.
    :type num_passes: int
    """
    if not os.path.exists(model_save_dir):
        os.mkdir(model_save_dir)

    use_default_data = (train_data_dir is None)

    if use_default_data:
        logger.info(("No training data are provided; "
                     "paddle.dataset.imdb will be used to train the model."))
        logger.info("please wait while the word dictionary is built ...")

        word_dict = paddle.dataset.imdb.word_dict()
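        # paddle.dataset.imdb downloads the IMDB data to a local cache on
        # first use; the lambdas below are reader creators, so a fresh sample
        # iterator is produced for every pass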
        train_reader = paddle.batch(
            paddle.reader.shuffle(
                lambda: paddle.dataset.imdb.train(word_dict)(), buf_size=51200),
            batch_size=100)
        test_reader = paddle.batch(
            lambda: paddle.dataset.imdb.test(word_dict)(), batch_size=100)

        class_num = 2
    else:
        if word_dict_path is None or not os.path.exists(word_dict_path):
            logger.info(("the word dictionary is not given; it will be "
                         "built automatically from the training data."))

            # build the word dictionary to map the original string-typed
            # words into integer-typed indices
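            # (words occurring fewer than cutoff_fre times are presumably
            # dropped, and the extra "<UNK>" token stands in for
            # out-of-vocabulary words)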
            build_dict(
                data_dir=train_data_dir,
                save_path=word_dict_path,
                use_col=1,
                cutoff_fre=5,
                insert_extra_words=["<UNK>"])

        if not os.path.exists(label_dict_path):
            logger.info(("the label dictionary is not given; it will be "
                         "built automatically from the training data."))
            # build the label dictionary to map the original string-typed
            # labels into integer-typed indices
            build_dict(
                data_dir=train_data_dir, save_path=label_dict_path, use_col=0)

        word_dict = load_dict(word_dict_path)

        lbl_dict = load_dict(label_dict_path)
        class_num = len(lbl_dict)
        logger.info("the number of classes is %d." % class_num)

        train_reader = paddle.batch(
            paddle.reader.shuffle(
                reader.train_reader(train_data_dir, word_dict, lbl_dict),
                buf_size=51200),
            batch_size=batch_size)

        if test_data_dir is not None:
            # because the training and testing data share the same format,
            # reader.train_reader is also used to read the testing data.
            test_reader = paddle.batch(
                reader.train_reader(test_data_dir, word_dict, lbl_dict),
                batch_size=batch_size)
        else:
            test_reader = None

    dict_dim = len(word_dict)
    logger.info("the size of the word dictionary is %d." % dict_dim)

    paddle.init(use_gpu=False, trainer_count=1)

    # network config
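    # the topology returns the cost to optimize plus the prediction and
    # label layers, which are wired into the AUC evaluator below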
    cost, prob, label = topology(dict_dim, class_num)

    # create parameters
    parameters = paddle.parameters.create(cost)

    # create optimizer
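    # Adam with a small L2 penalty on the weights; ModelAverage keeps a
    # running average of the parameters over a sliding window of updates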
    adam_optimizer = paddle.optimizer.Adam(
        learning_rate=1e-3,
        regularization=paddle.optimizer.L2Regularization(rate=1e-3),
        model_average=paddle.optimizer.ModelAverage(average_window=0.5))

    # create trainer
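    # the AUC evaluator passed through extra_layers is computed along with
    # the cost, so AUC shows up in event.metrics and in the test results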
    trainer = paddle.trainer.SGD(
        cost=cost,
        extra_layers=paddle.evaluator.auc(input=prob, label=label),
        parameters=parameters,
        update_equation=adam_optimizer)

    # begin training network
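    # every sample yielded by the readers is a (word_ids, label) tuple; the
    # feeding dict maps column 0 to the "word" data layer and column 1 to
    # the "label" data layer of the topology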
    feeding = {"word": 0, "label": 1}

    def _event_handler(event):
        """
        Log the training cost every 100 batches; at the end of each pass,
        evaluate on the test set (if provided) and save the parameters.
        """
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 100 == 0:
                logger.info("Pass %d, Batch %d, Cost %f, %s\n" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics))

        if isinstance(event, paddle.event.EndPass):
            if test_reader is not None:
                result = trainer.test(reader=test_reader, feeding=feeding)
                logger.info("Test at Pass %d, %s \n" % (event.pass_id,
                                                        result.metrics))
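            # snapshot the parameters learned in this pass into a gzipped
            # tar file, one file per pass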
            with gzip.open(
                    os.path.join(model_save_dir, "dnn_params_pass_%05d.tar.gz" %
                                 event.pass_id), "w") as f:
                trainer.save_parameter_to_tar(f)

    trainer.train(
        reader=train_reader,
        event_handler=_event_handler,
        feeding=feeding,
        num_passes=num_passes)

    logger.info("Training has finished.")


def main(args):
    if args.nn_type == "dnn":
        topology = fc_net
    elif args.nn_type == "cnn":
        topology = convolution_net

    train(
        topology=topology,
        train_data_dir=args.train_data_dir,
        test_data_dir=args.test_data_dir,
        word_dict_path=args.word_dict,
        label_dict_path=args.label_dict,
        batch_size=args.batch_size,
        num_passes=args.num_passes,
        model_save_dir=args.model_save_dir)


if __name__ == "__main__":
    args = parse_train_cmd()
    if args.train_data_dir is not None:
        assert args.word_dict and args.label_dict, (
            "the parameters train_data_dir, word_dict and label_dict "
            "should be set at the same time.")
    main(args)
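
# A minimal usage sketch. The exact command-line flags are defined in
# utils.parse_train_cmd, so the flag names below are assumptions inferred
# from the attributes accessed on `args` above:
#
#   # train the DNN model on the built-in paddle.dataset.imdb data
#   python train.py --nn_type dnn
#
#   # train the CNN model on user-provided data
#   python train.py --nn_type cnn \
#       --train_data_dir data/train --test_data_dir data/test \
#       --word_dict word_dict.txt --label_dict label_dict.txt \
#       --batch_size 32 --num_passes 10 --model_save_dir models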