import os
import sys
import gzip
import click

import paddle.v2 as paddle

import reader
from network_conf import nested_net
from utils import build_word_dict, build_label_dict, load_dict, logger
from config import TrainerConfig as conf


def _load_imdb_data():
    """Build the word dictionary and batched readers for the imdb dataset.

    :return: A tuple ``(word_dict, class_num, train_reader, test_reader)``.
    """
    logger.info(("No training data are provided, "
                 "use imdb to train the model."))
    logger.info("Please wait to build the word dictionary ...")

    word_dict = reader.imdb_word_dict()
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            lambda: reader.imdb_train(word_dict), buf_size=1000),
        batch_size=100)
    test_reader = paddle.batch(
        lambda: reader.imdb_test(word_dict), batch_size=100)
    # The imdb setup always trains with a fixed number of two classes.
    class_num = 2
    return word_dict, class_num, train_reader, test_reader


def _load_custom_data(train_data_dir, test_data_dir, word_dict_path,
                      label_dict_path):
    """Load the dictionaries and batched readers for user-provided data.

    A dictionary file that does not exist yet is built from the training
    data and saved to the given path before being loaded.

    :param train_data_dir: Directory holding the training data.
    :type train_data_dir: str
    :param test_data_dir: Directory holding the testing data, or None to
        skip testing.
    :type test_data_dir: str
    :param word_dict_path: Path of the word dictionary file.
    :type word_dict_path: str
    :param label_dict_path: Path of the label dictionary file.
    :type label_dict_path: str
    :return: A tuple ``(word_dict, class_num, train_reader, test_reader)``;
        ``test_reader`` is None when ``test_data_dir`` is None.
    """
    if not os.path.exists(word_dict_path):
        logger.info(("Word dictionary is not given, the dictionary "
                     "is automatically built from the training data."))
        # Build the word dictionary to map the original string-typed
        # words into integer-typed indices.
        build_word_dict(
            data_dir=train_data_dir,
            save_path=word_dict_path,
            use_col=1,
            cutoff_fre=0)

    if not os.path.exists(label_dict_path):
        logger.info(("Label dictionary is not given, the dictionary "
                     "is automatically built from the training data."))
        # Build the label dictionary to map the original string-typed
        # labels into integer-typed indices.
        build_label_dict(
            data_dir=train_data_dir, save_path=label_dict_path, use_col=0)

    word_dict = load_dict(word_dict_path)
    label_dict = load_dict(label_dict_path)

    class_num = len(label_dict)
    logger.info("Class number is : %d." % class_num)

    train_reader = paddle.batch(
        paddle.reader.shuffle(
            reader.train_reader(train_data_dir, word_dict, label_dict),
            buf_size=conf.buf_size),
        batch_size=conf.batch_size)

    test_reader = None
    if test_data_dir is not None:
        # Training and testing data share the same format, so
        # reader.train_reader is reused to read the testing data.
        test_reader = paddle.batch(
            paddle.reader.shuffle(
                reader.train_reader(test_data_dir, word_dict, label_dict),
                buf_size=conf.buf_size),
            batch_size=conf.batch_size)

    return word_dict, class_num, train_reader, test_reader


@click.command('train')
@click.option(
    "--train_data_dir",
    default=None,
    help=("The path of training dataset (default: None). "
          "If this parameter is not set, "
          "imdb dataset will be used."))
@click.option(
    "--test_data_dir",
    default=None,
    help=("The path of testing dataset (default: None). "
          "If this parameter is not set, "
          "imdb dataset will be used."))
@click.option(
    "--word_dict_path",
    type=str,
    default=None,
    help=("The path of word dictionary (default: None). "
          "If this parameter is not set, imdb dataset will be used. "
          "If this parameter is set, but the file does not exist, "
          "word dictionary will be built from "
          "the training data automatically."))
@click.option(
    "--label_dict_path",
    type=str,
    default=None,
    help=("The path of label dictionary (default: None). "
          "If this parameter is not set, imdb dataset will be used. "
          "If this parameter is set, but the file does not exist, "
          "label dictionary will be built from "
          "the training data automatically."))
@click.option(
    "--model_save_dir",
    type=str,
    default="models",
    help="The path to save the trained models (default: 'models').")
def train(train_data_dir, test_data_dir, word_dict_path, label_dict_path,
          model_save_dir):
    """Train the ``nested_net`` classifier and save its parameters each pass.

    :param train_data_dir: Directory of the training data. If it is not
        set, the imdb dataset is used and the remaining data options are
        ignored. If it is set, ``word_dict_path`` and ``label_dict_path``
        must be set as well.
    :type train_data_dir: str
    :param test_data_dir: Directory of the testing data. It is only used
        together with ``train_data_dir``; if it is not set in that case, no
        testing is done after each pass.
    :type test_data_dir: str
    :param word_dict_path: Path of the word dictionary. If the file does not
        exist, it is built from the training data.
    :type word_dict_path: str
    :param label_dict_path: Path of the label dictionary. If the file does
        not exist, it is built from the training data.
    :type label_dict_path: str
    :param model_save_dir: Directory where the parameters of every pass are
        saved as ``params_pass_XXXXX.tar.gz``; created if missing.
    :type model_save_dir: str
    :raises click.UsageError: if ``train_data_dir`` is given without both
        dictionary paths.
    """
    if train_data_dir is not None and not (word_dict_path and
                                           label_dict_path):
        # Raise instead of ``assert`` so the check survives ``python -O``
        # and click reports it as a normal usage error.
        raise click.UsageError(
            "The parameter train_data_dir, word_dict_path, label_dict_path "
            "should be set at the same time.")

    if not os.path.exists(model_save_dir):
        # makedirs (rather than mkdir) also creates missing parent
        # directories, e.g. for "output/models".
        os.makedirs(model_save_dir)

    if train_data_dir is None:
        word_dict, class_num, train_reader, test_reader = _load_imdb_data()
    else:
        word_dict, class_num, train_reader, test_reader = _load_custom_data(
            train_data_dir, test_data_dir, word_dict_path, label_dict_path)

    dict_dim = len(word_dict)
    logger.info("Length of word dictionary is : %d." % (dict_dim))

    paddle.init(use_gpu=conf.use_gpu, trainer_count=conf.trainer_count)

    # Adam with L2 regularization and model averaging; every
    # hyper-parameter is taken from TrainerConfig.
    adam_optimizer = paddle.optimizer.Adam(
        learning_rate=conf.learning_rate,
        regularization=paddle.optimizer.L2Regularization(
            rate=conf.l2_learning_rate),
        model_average=paddle.optimizer.ModelAverage(
            average_window=conf.average_window))

    # Define the network topology.
    cost, prob, label = nested_net(dict_dim, class_num, is_infer=False)

    # Create all the trainable parameters.
    parameters = paddle.parameters.create(cost)

    # The AUC of ``prob`` against ``label`` is reported alongside the cost.
    trainer = paddle.trainer.SGD(
        cost=cost,
        extra_layers=paddle.evaluator.auc(input=prob, label=label),
        parameters=parameters,
        update_equation=adam_optimizer)

    # Maps each data layer name to its column index in the samples produced
    # by the readers.
    feeding = {"word": 0, "label": 1}

    def _event_handler(event):
        """Log training progress; test and checkpoint after every pass.

        :param event: A paddle training event. ``EndIteration`` events are
            logged every ``conf.log_period`` batches. On ``EndPass`` the
            model is evaluated with ``test_reader`` (when there is one) and
            the parameters of that pass are written to ``model_save_dir``.
        """
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % conf.log_period == 0:
                logger.info("Pass %d, Batch %d, Cost %f, %s\n" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics))

        if isinstance(event, paddle.event.EndPass):
            if test_reader is not None:
                result = trainer.test(reader=test_reader, feeding=feeding)
                logger.info("Test at Pass %d, %s \n" % (event.pass_id,
                                                        result.metrics))
            with gzip.open(
                    os.path.join(model_save_dir, "params_pass_%05d.tar.gz" %
                                 event.pass_id), "w") as f:
                trainer.save_parameter_to_tar(f)

    # Begin training the network.
    trainer.train(
        reader=train_reader,
        event_handler=_event_handler,
        feeding=feeding,
        num_passes=conf.num_passes)

    logger.info("Training has finished.")


if __name__ == "__main__":
    # click parses the command-line options and then runs the training job.
    train()