train.py 8.9 KB
Newer Older
S
Superjom 已提交
1
import argparse
R
ranqiu 已提交
2
import distutils.util
S
Superjom 已提交
3 4 5 6

import paddle.v2 as paddle
from network_conf import DSSM
import reader
S
Superjom 已提交
7
from utils import TaskType, load_dic, logger, ModelType, ModelArch, display_args
S
Superjom 已提交
8 9 10 11

# Command-line interface of the DSSM training script.
parser = argparse.ArgumentParser(description="PaddlePaddle DSSM example")

# ---- Data and dictionary locations ----
parser.add_argument(
    "-i",
    "--train_data_path",
    type=str,
    required=False,
    help="The path of training data.")
parser.add_argument(
    "-t",
    "--test_data_path",
    type=str,
    required=False,
    help="The path of testing data.")
parser.add_argument(
    "-s",
    "--source_dic_path",
    type=str,
    required=False,
    help="The path of the source's word dictionary.")
parser.add_argument(
    "--target_dic_path",
    type=str,
    required=False,
    help=("The path of the target's word dictionary, "
          "if this parameter is not set, the `source_dic_path` will be used"))

# ---- Training schedule ----
parser.add_argument(
    "-b",
    "--batch_size",
    type=int,
    default=32,
    help="The size of mini-batch (default:32).")
parser.add_argument(
    "-p",
    "--num_passes",
    type=int,
    default=10,
    help="The number of passes to run(default:10).")

# ---- Model selection ----
# NOTE(review): with required=True argparse never falls back to `default`,
# so the "(default: classification)" wording in the help below is not
# reachable as written -- confirm which behaviour is intended.
parser.add_argument(
    "-y",
    "--model_type",
    type=int,
    required=True,
    default=ModelType.CLASSIFICATION_MODE,
    help=("model type, %d for classification, %d for pairwise rank, "
          "%d for regression (default: classification).") %
    (ModelType.CLASSIFICATION_MODE, ModelType.RANK_MODE,
     ModelType.REGRESSION_MODE))
# NOTE(review): same required/default combination as --model_type above.
parser.add_argument(
    "-a",
    "--model_arch",
    type=int,
    required=True,
    default=ModelArch.CNN_MODE,
    help="The model architecture, %d for CNN, %d for FC, %d for RNN." %
    (ModelArch.CNN_MODE, ModelArch.FC_MODE, ModelArch.RNN_MODE))
# Boolean switches are parsed with strtobool, so they take an explicit value
# on the command line (e.g. "--share_embed true").
parser.add_argument(
    "--share_network_between_source_target",
    type=distutils.util.strtobool,
    default=False,
    help="Whether to share network parameters between source and target.")
parser.add_argument(
    "--share_embed",
    type=distutils.util.strtobool,
    default=False,
    help="Whether to share word embedding between source and target.")
parser.add_argument(
    "--dnn_dims",
    type=str,
    default="256,128,64,32",
    help=("The dimensions of dnn layers, default is '256,128,64,32', "
          "which means create a 4-layer dnn. The dimension of each layer is "
          "'256, 128, 64 and 32'."))

# ---- Runtime resources ----
parser.add_argument(
    "--num_workers",
    type=int,
    default=1,
    help="The number of worker threads, default 1.")
parser.add_argument(
    "--use_gpu",
    type=distutils.util.strtobool,
    default=False,
    help="Whether to use GPU devices (default: False)")
parser.add_argument(
    "-c",
    "--class_num",
    type=int,
    default=0,
    help="The number of categories for classification task.")

# ---- Output, logging, testing and checkpoint periods ----
parser.add_argument(
    "--model_output_prefix",
    type=str,
    default="./",
    help="The prefix of the path to store the trained models (default: ./).")
parser.add_argument(
    "-g",
    "--num_batches_to_log",
    type=int,
    default=100,
    help=("The log period. Every num_batches_to_log batches, "
          "a training log will be printed. (default: 100)"))
parser.add_argument(
    "-e",
    "--num_batches_to_test",
    type=int,
    default=200,
    help=("The test period. Every num_batches_to_test batches, "
          "the specified test sample will be tested (default: 200)."))
parser.add_argument(
    "-z",
    "--num_batches_to_save_model",
    type=int,
    default=400,
    help=("Every num_batches_to_save_model batches, "
          "a trained model will be saved (default: 400)."))
S
Superjom 已提交
124 125

# Parse the command line once at module level; train() below reads `args`
# and `layer_dims` as module-level globals.
args = parser.parse_args()
# Wrap the raw integer codes in the project's ModelType / ModelArch helpers.
args.model_type = ModelType(args.model_type)
args.model_arch = ModelArch(args.model_arch)
# Report a missing/invalid class_num as a usage error rather than with an
# assert, which would be silently stripped when running under `python -O`.
if args.model_type.is_classification() and args.class_num <= 1:
    parser.error("The parameter class_num should be set in "
                 "classification task.")

# "256,128,64,32" -> [256, 128, 64, 32]
layer_dims = [int(dim) for dim in args.dnn_dims.split(",")]
# Fall back to the source dictionary when no target dictionary is given.
args.target_dic_path = args.target_dic_path or args.source_dic_path
S
Superjom 已提交
135

S
Superjom 已提交
136 137 138 139 140

def train(train_data_path=None,
          test_data_path=None,
          source_dic_path=None,
          target_dic_path=None,
          model_type=ModelType.create_classification(),
          model_arch=ModelArch.create_cnn(),
          batch_size=32,
          num_passes=10,
          share_semantic_generator=False,
          share_embed=False,
          class_num=None,
          num_workers=1,
          use_gpu=False):
    """
    Train the DSSM.

    Besides its parameters, this function reads the module-level globals
    `layer_dims` (DNN layer sizes) and `args` (log / test / save periods,
    output prefix, and the model type / arch used in the checkpoint name).

    Args:
        train_data_path: path of the training data. When empty, the bundled
            default train/test data AND the default dictionary are used, and
            the other three path arguments are ignored.
        test_data_path: path of the testing data.
        source_dic_path: path of the source word dictionary.
        target_dic_path: path of the target word dictionary.
        model_type: ModelType selecting classification, rank or regression.
        model_arch: ModelArch passed through to the DSSM network.
        batch_size: mini-batch size for both the train and test readers.
        num_passes: number of passes over the training data.
        share_semantic_generator: passed through to DSSM.
        share_embed: passed through to DSSM.
        class_num: number of classes, passed through to DSSM.
        num_workers: passed to paddle.init as `trainer_count`.
        use_gpu: passed to paddle.init as `use_gpu`.
    """
    # Pick the bundled sample data set matching the task.
    default_dic_path = "./data/vocab.txt"
    if model_type.is_rank():
        default_train_path = "./data/rank/train.txt"
        default_test_path = "./data/rank/test.txt"
    else:
        default_train_path = "./data/classification/train.txt"
        default_test_path = "./data/classification/test.txt"

    # No training data given: switch everything over to the defaults.
    if not train_data_path:
        train_data_path = default_train_path
        test_data_path = default_test_path
        source_dic_path = default_dic_path
        target_dic_path = default_dic_path

    dataset = reader.Dataset(
        train_path=train_data_path,
        test_path=test_data_path,
        source_dic_path=source_dic_path,
        target_dic_path=target_dic_path,
        model_type=model_type, )

    train_reader = paddle.batch(
        paddle.reader.shuffle(
            dataset.train, buf_size=1000),
        batch_size=batch_size)

    test_reader = paddle.batch(
        paddle.reader.shuffle(
            dataset.test, buf_size=1000),
        batch_size=batch_size)

    paddle.init(use_gpu=use_gpu, trainer_count=num_workers)

    # DSSM(...) builds a callable; calling it yields the network outputs.
    cost, prediction, label = DSSM(
        dnn_dims=layer_dims,
        vocab_sizes=[
            len(load_dic(path)) for path in [source_dic_path, target_dic_path]
        ],
        model_type=model_type,
        model_arch=model_arch,
        share_semantic_generator=share_semantic_generator,
        class_num=class_num,
        share_embed=share_embed)()

    parameters = paddle.parameters.create(cost)

    adam_optimizer = paddle.optimizer.Adam(
        learning_rate=2e-4,
        regularization=paddle.optimizer.L2Regularization(rate=1e-3),
        model_average=paddle.optimizer.ModelAverage(average_window=0.5))

    # The AUC evaluator is attached for every task except pairwise rank.
    trainer = paddle.trainer.SGD(
        cost=cost,
        extra_layers=paddle.evaluator.auc(input=prediction, label=label)
        if not model_type.is_rank() else None,
        parameters=parameters,
        update_equation=adam_optimizer)

    # Map data-layer names to the column index of each sample tuple.
    if model_type.is_classification() or model_type.is_regression():
        feeding = {"source_input": 0, "target_input": 1, "label_input": 2}
    else:
        feeding = {
            "source_input": 0,
            "left_target_input": 1,
            "right_target_input": 2,
            "label_input": 3
        }

    def _event_handler(event):
        """
        Define batch handler: periodic logging, testing and checkpointing.
        """
        if not isinstance(event, paddle.event.EndIteration):
            return

        # output train log
        if event.batch_id % args.num_batches_to_log == 0:
            logger.info("Pass %d, Batch %d, Cost %f, %s" % (
                event.pass_id, event.batch_id, event.cost, event.metrics))

        # test model (only the classification task is evaluated here)
        if event.batch_id > 0 and \
                event.batch_id % args.num_batches_to_test == 0:
            if model_type.is_classification():
                result = trainer.test(reader=test_reader, feeding=feeding)
                logger.info("Test at Pass %d, %s" % (event.pass_id,
                                                     result.metrics))

        # save model
        if event.batch_id > 0 and \
                event.batch_id % args.num_batches_to_save_model == 0:
            model_desc = "{type}_{arch}".format(
                type=str(args.model_type), arch=str(args.model_arch))
            model_path = "%sdssm_%s_pass_%05d.tar" % (
                args.model_output_prefix, model_desc, event.pass_id)
            # A tar archive is binary data, so open the file in binary mode;
            # text mode corrupts it on platforms that translate newlines and
            # fails outright on Python 3.
            with open(model_path, "wb") as f:
                trainer.save_parameter_to_tar(f)

    trainer.train(
        reader=train_reader,
        event_handler=_event_handler,
        feeding=feeding,
        num_passes=num_passes)

    logger.info("Training has finished.")


C
caoying03 已提交
263
if __name__ == "__main__":
    # Echo the parsed command-line arguments before training starts.
    display_args(args)

    # Collect the train() keyword arguments from the parsed command line;
    # --share_network_between_source_target maps onto
    # `share_semantic_generator`.
    train_kwargs = dict(
        train_data_path=args.train_data_path,
        test_data_path=args.test_data_path,
        source_dic_path=args.source_dic_path,
        target_dic_path=args.target_dic_path,
        model_type=ModelType(args.model_type),
        model_arch=ModelArch(args.model_arch),
        batch_size=args.batch_size,
        num_passes=args.num_passes,
        share_semantic_generator=args.share_network_between_source_target,
        share_embed=args.share_embed,
        class_num=args.class_num,
        num_workers=args.num_workers,
        use_gpu=args.use_gpu)
    train(**train_kwargs)