#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse

import paddle.v2 as paddle
from network_conf import DSSM
import reader
from utils import TaskType, load_dic, logger, ModelType, ModelArch, display_args


def _str2bool(value):
    '''
    Convert a command-line string into a bool.

    argparse's ``type=bool`` must not be used for flags: it calls ``bool()``
    on the raw string, and every non-empty string (including "False" and
    "0") is truthy, so such a flag could never be switched off explicitly.

    :param value: the raw command-line value, or a bool when argparse
        substitutes ``const``/``default`` directly.
    :return: the parsed boolean.
    :raises argparse.ArgumentTypeError: if the value is not a recognised
        boolean spelling.
    '''
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    # An empty string was also False under the previous ``type=bool``.
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got %r" % value)


parser = argparse.ArgumentParser(description="PaddlePaddle DSSM example")

parser.add_argument(
    '-i',
    '--train_data_path',
    type=str,
    required=False,
    help="path of training dataset")
parser.add_argument(
    '-t',
    '--test_data_path',
    type=str,
    required=False,
    help="path of testing dataset")
parser.add_argument(
    '-s',
    '--source_dic_path',
    type=str,
    required=False,
    help="path of the source's word dic")
parser.add_argument(
    '--target_dic_path',
    type=str,
    required=False,
    help="path of the target's word dic, if not set, the `source_dic_path` will be used"
)
parser.add_argument(
    '-b',
    '--batch_size',
    type=int,
    default=10,
    help="size of mini-batch (default:10)")
parser.add_argument(
    '-p',
    '--num_passes',
    type=int,
    default=10,
    help="number of passes to run(default:10)")
# Not `required`: the help text advertises a default, and a required option
# would never use it.
parser.add_argument(
    '-y',
    '--model_type',
    type=int,
    default=ModelType.CLASSIFICATION_MODE,
    choices=(ModelType.CLASSIFICATION_MODE, ModelType.RANK_MODE,
             ModelType.REGRESSION_MODE),
    help="model type, %d for classification, %d for pairwise rank, %d for regression (default: classification)"
    % (ModelType.CLASSIFICATION_MODE, ModelType.RANK_MODE,
       ModelType.REGRESSION_MODE))
parser.add_argument(
    '-a',
    '--model_arch',
    type=int,
    default=ModelArch.CNN_MODE,
    choices=(ModelArch.CNN_MODE, ModelArch.FC_MODE, ModelArch.RNN_MODE),
    help="model architecture, %d for CNN, %d for FC, %d for RNN" %
    (ModelArch.CNN_MODE, ModelArch.FC_MODE, ModelArch.RNN_MODE))
# Boolean flags accept an explicit value (`--use_gpu false`) or can be given
# bare (`--use_gpu`), which means True.
parser.add_argument(
    '--share_network_between_source_target',
    type=_str2bool,
    nargs='?',
    const=True,
    default=False,
    help="whether to share network parameters between source and target "
    "(default: False)")
parser.add_argument(
    '--share_embed',
    type=_str2bool,
    nargs='?',
    const=True,
    default=False,
    help="whether to share word embedding between source and target "
    "(default: False)")
parser.add_argument(
    '--dnn_dims',
    type=str,
    default='256,128,64,32',
    help="dimensions of dnn layers, default is '256,128,64,32', which means create a 4-layer dnn, dimension of each layer is 256, 128, 64 and 32"
)
parser.add_argument(
    '--num_workers', type=int, default=1, help="num worker threads, default 1")
parser.add_argument(
    '--use_gpu',
    type=_str2bool,
    nargs='?',
    const=True,
    default=False,
    help="whether to use GPU devices (default: False)")
parser.add_argument(
    '-c',
    '--class_num',
    type=int,
    default=0,
    help="number of categories for classification task.")
parser.add_argument(
    '--model_output_prefix',
    type=str,
    default="./",
    help="prefix of the path for model to store, (default: ./)")
parser.add_argument(
    '-g',
    '--num_batches_to_log',
    type=int,
    default=100,
    help="number of batches to output train log, (default: 100)")
parser.add_argument(
    '-e',
    '--num_batches_to_test',
    type=int,
    default=200,
    help="number of batches to test, (default: 200)")
parser.add_argument(
    '-z',
    '--num_batches_to_save_model',
    type=int,
    default=400,
    help="number of batches to output model, (default: 400)")

# arguments check.
args = parser.parse_args()
args.model_type = ModelType(args.model_type)
args.model_arch = ModelArch(args.model_arch)
if args.model_type.is_classification() and args.class_num <= 1:
    # parser.error (unlike assert) still runs under `python -O`, and it
    # prints the usage line before exiting.
    parser.error("--class_num should be set in classification task.")

# Parse "256,128,64,32" into [256, 128, 64, 32]. Empty segments (e.g. from a
# trailing comma) are skipped.
try:
    layer_dims = [
        int(dim) for dim in args.dnn_dims.split(',') if dim.strip()
    ]
except ValueError:
    parser.error("--dnn_dims should be comma-separated integers, got %r" %
                 args.dnn_dims)
if not layer_dims:
    parser.error("--dnn_dims should contain at least one layer dimension")

# Fall back to the source dictionary when no target dictionary is given.
if not args.target_dic_path:
    args.target_dic_path = args.source_dic_path


def train(train_data_path=None,
          test_data_path=None,
          source_dic_path=None,
          target_dic_path=None,
          model_type=ModelType.create_classification(),
          model_arch=ModelArch.create_cnn(),
          batch_size=10,
          num_passes=10,
          share_semantic_generator=False,
          share_embed=False,
          class_num=None,
          num_workers=1,
          use_gpu=False,
          dnn_dims=None,
          num_batches_to_log=None,
          num_batches_to_test=None,
          num_batches_to_save_model=None,
          model_output_prefix=None):
    '''
    Train the DSSM.

    :param train_data_path: path of the training dataset. If empty, the
        sample data under ./data is used, and test_data_path,
        source_dic_path and target_dic_path are overridden with the
        matching sample paths.
    :param test_data_path: path of the testing dataset.
    :param source_dic_path: path of the source's word dictionary.
    :param target_dic_path: path of the target's word dictionary.
    :param model_type: a ModelType instance. It selects the default sample
        data, the input feeding layout, and whether the AUC evaluator is
        attached.
    :param model_arch: a ModelArch instance passed to the network builder.
    :param batch_size: size of each mini-batch.
    :param num_passes: number of passes over the training data.
    :param share_semantic_generator: whether source and target share network
        parameters.
    :param share_embed: whether source and target share word embeddings.
    :param class_num: number of categories, forwarded to DSSM.
    :param num_workers: trainer count passed to paddle.init.
    :param use_gpu: whether paddle.init should use GPU devices.
    :param dnn_dims: list of dnn layer dimensions; defaults to the parsed
        command-line --dnn_dims.
    :param num_batches_to_log: log the training cost every N batches;
        defaults to --num_batches_to_log. Values <= 0 disable logging.
    :param num_batches_to_test: run the test set every N batches
        (classification models only); defaults to --num_batches_to_test.
        Values <= 0 disable testing.
    :param num_batches_to_save_model: save parameters every N batches;
        defaults to --num_batches_to_save_model. Values <= 0 disable saving.
    :param model_output_prefix: prefix prepended to saved model file names;
        defaults to --model_output_prefix.
    '''
    # Optional settings fall back to the module-level command-line values so
    # existing callers keep their behaviour.
    if dnn_dims is None:
        dnn_dims = layer_dims
    if num_batches_to_log is None:
        num_batches_to_log = args.num_batches_to_log
    if num_batches_to_test is None:
        num_batches_to_test = args.num_batches_to_test
    if num_batches_to_save_model is None:
        num_batches_to_save_model = args.num_batches_to_save_model
    if model_output_prefix is None:
        model_output_prefix = args.model_output_prefix

    if model_type.is_rank():
        default_train_path = './data/rank/train.txt'
        default_test_path = './data/rank/test.txt'
    else:
        default_train_path = './data/classification/train.txt'
        default_test_path = './data/classification/test.txt'
    default_dic_path = './data/vocab.txt'

    # Without a training set, switch entirely to the bundled sample data so
    # the dataset and the dictionaries stay consistent with each other.
    if not train_data_path:
        train_data_path = default_train_path
        test_data_path = default_test_path
        source_dic_path = default_dic_path
        target_dic_path = default_dic_path

    dataset = reader.Dataset(
        train_path=train_data_path,
        test_path=test_data_path,
        source_dic_path=source_dic_path,
        target_dic_path=target_dic_path,
        model_type=model_type, )

    train_reader = paddle.batch(
        paddle.reader.shuffle(dataset.train, buf_size=1000),
        batch_size=batch_size)

    test_reader = paddle.batch(
        paddle.reader.shuffle(dataset.test, buf_size=1000),
        batch_size=batch_size)

    paddle.init(use_gpu=use_gpu, trainer_count=num_workers)

    cost, prediction, label = DSSM(
        dnn_dims=dnn_dims,
        vocab_sizes=[
            len(load_dic(path)) for path in [source_dic_path, target_dic_path]
        ],
        model_type=model_type,
        model_arch=model_arch,
        share_semantic_generator=share_semantic_generator,
        class_num=class_num,
        share_embed=share_embed)()

    parameters = paddle.parameters.create(cost)

    adam_optimizer = paddle.optimizer.Adam(
        learning_rate=1e-3,
        regularization=paddle.optimizer.L2Regularization(rate=1e-3),
        model_average=paddle.optimizer.ModelAverage(average_window=0.5))

    trainer = paddle.trainer.SGD(
        cost=cost,
        # The AUC evaluator is only attached for non-rank (pointwise) models.
        extra_layers=paddle.evaluator.auc(input=prediction, label=label)
        if not model_type.is_rank() else None,
        parameters=parameters,
        update_equation=adam_optimizer)

    # Map data layer names to column positions in each sample: pointwise
    # models read (source, target, label); the pairwise rank model reads
    # (source, left target, right target, label).
    if model_type.is_classification() or model_type.is_regression():
        feeding = {'source_input': 0, 'target_input': 1, 'label_input': 2}
    else:
        feeding = {
            'source_input': 0,
            'left_target_input': 1,
            'right_target_input': 2,
            'label_input': 3
        }

    # Name saved models after this call's own configuration, not the global
    # command-line arguments, so programmatic callers get accurate names.
    model_desc = "{type}_{arch}".format(
        type=str(model_type), arch=str(model_arch))

    def _is_due(batch_id, interval):
        '''
        Tell whether a periodic action should fire on this batch.

        A non-positive interval disables the action instead of raising
        ZeroDivisionError.
        '''
        return interval > 0 and batch_id % interval == 0

    def _event_handler(event):
        '''
        Define batch handler: log, test and save the model periodically.
        '''
        if not isinstance(event, paddle.event.EndIteration):
            return

        # output train log
        if _is_due(event.batch_id, num_batches_to_log):
            logger.info("Pass %d, Batch %d, Cost %f, %s" % (
                event.pass_id, event.batch_id, event.cost, event.metrics))

        # test model (only classification models are evaluated here)
        if (event.batch_id > 0 and
                _is_due(event.batch_id, num_batches_to_test) and
                model_type.is_classification()):
            result = trainer.test(reader=test_reader, feeding=feeding)
            logger.info("Test at Pass %d, %s" % (event.pass_id,
                                                 result.metrics))

        # save model; the file name carries only the pass id, so later saves
        # within the same pass overwrite earlier ones.
        if event.batch_id > 0 and _is_due(event.batch_id,
                                          num_batches_to_save_model):
            model_path = "%sdssm_%s_pass_%05d.tar" % (
                model_output_prefix, model_desc, event.pass_id)
            # A tar archive is binary data, so open the file in binary mode.
            with open(model_path, "wb") as f:
                parameters.to_tar(f)

    trainer.train(
        reader=train_reader,
        event_handler=_event_handler,
        feeding=feeding,
        num_passes=num_passes)

    logger.info("Training has finished.")


if __name__ == '__main__':
    display_args(args)
    # args.model_type / args.model_arch were already converted from their
    # integer command-line values into ModelType / ModelArch instances during
    # the argument checks, so they are passed through as-is. Wrapping them a
    # second time would hand an instance to a constructor that is otherwise
    # only ever given an integer mode.
    train(
        train_data_path=args.train_data_path,
        test_data_path=args.test_data_path,
        source_dic_path=args.source_dic_path,
        target_dic_path=args.target_dic_path,
        model_type=args.model_type,
        model_arch=args.model_arch,
        batch_size=args.batch_size,
        num_passes=args.num_passes,
        share_semantic_generator=args.share_network_between_source_target,
        share_embed=args.share_embed,
        class_num=args.class_num,
        num_workers=args.num_workers,
        use_gpu=args.use_gpu)