train.py 5.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import gzip
import os
import sys

import paddle.v2 as paddle

import network_conf
import reader
from utils import *


def train(topology,
          train_data_dir=None,
          test_data_dir=None,
          word_dict_path=None,
          label_dict_path=None,
          batch_size=32,
          num_passes=10):
    """
    Train a text-classification model.

    :params topology: network configuration function; given the word
        dictionary size and the class number it returns the
        (cost, prob, label) layers used to build the trainer.
    :type topology: callable
    :params train_data_dir: directory of the training data; if this
        parameter is not specified, paddle.dataset.imdb is used to run
        this example
    :type train_data_dir: str
    :params test_data_dir: directory of the testing data; if this
        parameter is not specified, paddle.dataset.imdb is used to run
        this example
    :type test_data_dir: str
    :params word_dict_path: path of the word dictionary; if the file does
        not exist, the dictionary is built automatically from the
        training data
    :type word_dict_path: str
    :params label_dict_path: path of the label dictionary; if the file
        does not exist, the dictionary is built automatically from the
        training data
    :type label_dict_path: str
    :params batch_size: number of instances in a mini-batch
    :type batch_size: int
    :params num_passes: number of passes over the training data
    :type num_passes: int
    """

    use_default_data = (train_data_dir is None)

    if use_default_data:
        logger.info(("No training data are provided, "
                     "use paddle.dataset.imdb to train the model."))
        logger.info("please wait to build the word dictionary ...")

        word_dict = paddle.dataset.imdb.word_dict()
        train_reader = paddle.batch(
            paddle.reader.shuffle(
                lambda: paddle.dataset.imdb.train(word_dict), buf_size=1000),
            batch_size=100)
        test_reader = paddle.batch(
            lambda: paddle.dataset.imdb.test(word_dict), batch_size=100)

        # The imdb dataset is a binary sentiment-classification task.
        class_num = 2
    else:
        if word_dict_path is None or not os.path.exists(word_dict_path):
            logger.info(("word dictionary is not given, the dictionary "
                         "is automatically built from the training data."))

            # build the word dictionary to map the original string-typed
            # words into integer-typed index
            build_dict(
                data_dir=train_data_dir,
                save_path=word_dict_path,
                use_col=1,
                cutoff_fre=5,
                insert_extra_words=["<UNK>"])

        if not os.path.exists(label_dict_path):
            logger.info(("label dictionary is not given, the dictionary "
                         "is automatically built from the training data."))
            # build the label dictionary to map the original string-typed
            # label into integer-typed index
            build_dict(
                data_dir=train_data_dir, save_path=label_dict_path, use_col=0)

        word_dict = load_dict(word_dict_path)

        lbl_dict = load_dict(label_dict_path)
        class_num = len(lbl_dict)
        logger.info("class number is : %d." % (len(lbl_dict)))

        train_reader = paddle.batch(
            paddle.reader.shuffle(
                reader.train_reader(train_data_dir, word_dict, lbl_dict),
                buf_size=1000),
            batch_size=batch_size)

        if test_data_dir is not None:
            # here, because training and testing data share a same format,
            # we still use the reader.train_reader to read the testing data.
            test_reader = paddle.batch(
                paddle.reader.shuffle(
                    reader.train_reader(test_data_dir, word_dict, lbl_dict),
                    buf_size=1000),
                batch_size=batch_size)
        else:
            test_reader = None

    dict_dim = len(word_dict)
    logger.info("length of word dictionary is : %d." % (dict_dim))

    paddle.init(use_gpu=False, trainer_count=1)

    # network config
    cost, prob, label = topology(dict_dim, class_num)

    # create parameters
    parameters = paddle.parameters.create(cost)

    # create optimizer
    adam_optimizer = paddle.optimizer.Adam(
        learning_rate=1e-3,
        regularization=paddle.optimizer.L2Regularization(rate=1e-3),
        model_average=paddle.optimizer.ModelAverage(average_window=0.5))

    # create trainer
    trainer = paddle.trainer.SGD(
        cost=cost,
        extra_layers=paddle.evaluator.auc(input=prob, label=label),
        parameters=parameters,
        update_equation=adam_optimizer)

    # begin training network; data slot 0 feeds "word", slot 1 feeds "label"
    feeding = {"word": 0, "label": 1}

    def _event_handler(event):
        """
        Define end batch and end pass event handler
        """
        if isinstance(event, paddle.event.EndIteration):
            # Log every 100 batches to avoid flooding the output.
            if event.batch_id % 100 == 0:
                logger.info("Pass %d, Batch %d, Cost %f, %s\n" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics))

        if isinstance(event, paddle.event.EndPass):
            if test_reader is not None:
                result = trainer.test(reader=test_reader, feeding=feeding)
                logger.info("Test at Pass %d, %s \n" % (event.pass_id,
                                                        result.metrics))
            # Snapshot the parameters after every pass.
            with gzip.open("dnn_params_pass_%05d.tar.gz" % event.pass_id,
                           "w") as f:
                parameters.to_tar(f)

    trainer.train(
        reader=train_reader,
        event_handler=_event_handler,
        feeding=feeding,
        num_passes=num_passes)

    logger.info("Training has finished.")


def main(args):
    """
    Select the network topology from the parsed command-line arguments
    and start training.

    :params args: parsed command-line arguments; must provide nn_type,
        train_data_dir, test_data_dir, word_dict, label_dict,
        batch_size and num_passes attributes.
    """
    if args.nn_type == "dnn":
        topology = network_conf.fc_net
    elif args.nn_type == "cnn":
        topology = network_conf.convolution_net
    else:
        # Fail fast with a clear message instead of hitting an unbound
        # `topology` NameError in the call below.
        raise ValueError(
            "Unsupported nn_type '%s', expected 'dnn' or 'cnn'." %
            args.nn_type)

    train(
        topology=topology,
        train_data_dir=args.train_data_dir,
        test_data_dir=args.test_data_dir,
        word_dict_path=args.word_dict,
        label_dict_path=args.label_dict,
        batch_size=args.batch_size,
        num_passes=args.num_passes)


if __name__ == "__main__":
    args = parse_train_cmd()
    if args.train_data_dir is not None:
        # A user-supplied corpus needs explicit dictionary paths, since
        # the dictionaries are read from (or written to) those locations.
        # Use an explicit check rather than `assert`, which is stripped
        # when Python runs with the -O flag.
        if not (args.word_dict and args.label_dict):
            raise ValueError(
                "the parameters train_data_dir, word_dict, and label_dict "
                "should be set at the same time.")
    main(args)