train.py 8.5 KB
Newer Older
S
Superjom 已提交
1
import argparse
R
ranqiu 已提交
2
import distutils.util
S
Superjom 已提交
3 4 5 6

import paddle.v2 as paddle
from network_conf import DSSM
import reader
S
Superjom 已提交
7
from utils import TaskType, load_dic, logger, ModelType, ModelArch, display_args
S
Superjom 已提交
8 9 10 11

# Command-line interface of the DSSM training script.
parser = argparse.ArgumentParser(description="PaddlePaddle DSSM example")

parser.add_argument(
    '-i',
    '--train_data_path',
    type=str,
    required=False,
    help="path of training dataset")
parser.add_argument(
    '-t',
    '--test_data_path',
    type=str,
    required=False,
    help="path of testing dataset")
parser.add_argument(
    '-s',
    '--source_dic_path',
    type=str,
    required=False,
    help="path of the source's word dic")
parser.add_argument(
    '--target_dic_path',
    type=str,
    required=False,
    help=("path of the target's word dictionary, "
          "if not set, the `source_dic_path` will be used"))
parser.add_argument(
    '-b',
    '--batch_size',
    type=int,
    default=32,
    help="size of mini-batch (default:32)")
parser.add_argument(
    '-p',
    '--num_passes',
    type=int,
    default=10,
    help="number of passes to run(default:10)")
# NOTE: --model_type and --model_arch are deliberately not `required`:
# argparse never applies `default` to a required option, which made the
# documented defaults below unreachable.
parser.add_argument(
    '-y',
    '--model_type',
    type=int,
    default=ModelType.CLASSIFICATION_MODE,
    help="model type, %d for classification, %d for pairwise rank, %d for regression (default: classification)"
    % (ModelType.CLASSIFICATION_MODE, ModelType.RANK_MODE,
       ModelType.REGRESSION_MODE))
parser.add_argument(
    '-a',
    '--model_arch',
    type=int,
    default=ModelArch.CNN_MODE,
    help="model architecture, %d for CNN, %d for FC, %d for RNN (default: CNN)"
    % (ModelArch.CNN_MODE, ModelArch.FC_MODE, ModelArch.RNN_MODE))
# Boolean switches are parsed with strtobool, so values such as
# "true"/"false", "1"/"0", "yes"/"no" are accepted.
parser.add_argument(
    '--share_network_between_source_target',
    type=distutils.util.strtobool,
    default=False,
    help="whether to share network parameters between source and target")
parser.add_argument(
    '--share_embed',
    type=distutils.util.strtobool,
    default=False,
    help="whether to share word embedding between source and target")
parser.add_argument(
    '--dnn_dims',
    type=str,
    default='256,128,64,32',
    help="dimensions of dnn layers, default is '256,128,64,32', which means create a 4-layer dnn, dimension of each layer is 256, 128, 64 and 32"
)
parser.add_argument(
    '--num_workers', type=int, default=1, help="num worker threads, default 1")
parser.add_argument(
    '--use_gpu',
    type=distutils.util.strtobool,
    default=False,
    help="whether to use GPU devices (default: False)")
parser.add_argument(
    '-c',
    '--class_num',
    type=int,
    default=0,
    help="number of categories for classification task.")
parser.add_argument(
    '--model_output_prefix',
    type=str,
    default="./",
    help="prefix of the path for model to store, (default: ./)")
parser.add_argument(
    '-g',
    '--num_batches_to_log',
    type=int,
    default=100,
    help="number of batches to output train log, (default: 100)")
parser.add_argument(
    '-e',
    '--num_batches_to_test',
    type=int,
    default=200,
    help="number of batches to test, (default: 200)")
parser.add_argument(
    '-z',
    '--num_batches_to_save_model',
    type=int,
    default=400,
    help="number of batches to output model, (default: 400)")

# arguments check.
args = parser.parse_args()
args.model_type = ModelType(args.model_type)
args.model_arch = ModelArch(args.model_arch)
# Validate via parser.error instead of `assert`: asserts are stripped when
# Python runs with -O, whereas parser.error always prints the usage and exits.
if args.model_type.is_classification() and args.class_num <= 1:
    parser.error("--class_num should be set in classification task.")

# Sizes of the DNN layers, e.g. "256,128,64,32" -> [256, 128, 64, 32].
# Empty items (such as the one produced by a trailing comma) are ignored.
try:
    layer_dims = [int(dim) for dim in args.dnn_dims.split(',') if dim.strip()]
except ValueError:
    parser.error("--dnn_dims should be a comma-separated list of integers, "
                 "got %r" % args.dnn_dims)
if not layer_dims:
    parser.error("--dnn_dims should contain at least one layer size.")

# Use the source dictionary for the target side when none is given.
args.target_dic_path = args.target_dic_path or args.source_dic_path


def train(train_data_path=None,
          test_data_path=None,
          source_dic_path=None,
          target_dic_path=None,
          model_type=ModelType.create_classification(),
          model_arch=ModelArch.create_cnn(),
          batch_size=10,
          num_passes=10,
          share_semantic_generator=False,
          share_embed=False,
          class_num=None,
          num_workers=1,
          use_gpu=False):
    '''
    Train the DSSM.

    :param train_data_path: path of the training data. When empty, the demo
        data under ./data/ is used, and test_data_path, source_dic_path and
        target_dic_path are replaced by the demo paths as well.
    :param test_data_path: path of the test data.
    :param source_dic_path: path of the source-side word dictionary.
    :param target_dic_path: path of the target-side word dictionary; falls
        back to source_dic_path when empty.
    :param model_type: a ModelType. Rank models default to ./data/rank/*,
        every other type to ./data/classification/*.
    :param model_arch: a ModelArch selecting the network architecture.
    :param batch_size: size of a mini-batch.
    :param num_passes: number of passes over the training data.
    :param share_semantic_generator: whether source and target share the
        network parameters.
    :param share_embed: whether source and target share the word embedding.
    :param class_num: number of categories for the classification task.
    :param num_workers: trainer count passed to paddle.init.
    :param use_gpu: whether to train on GPU devices.

    NOTE: the DNN layer sizes (module-level `layer_dims`) and the logging,
    testing and checkpoint settings (module-level `args`) are read from the
    parsed command line, so this function expects it to have been parsed.
    '''
    default_train_path = './data/rank/train.txt'
    default_test_path = './data/rank/test.txt'
    default_dic_path = './data/vocab.txt'
    if not model_type.is_rank():
        default_train_path = './data/classification/train.txt'
        default_test_path = './data/classification/test.txt'

    use_default_data = not train_data_path

    if use_default_data:
        train_data_path = default_train_path
        test_data_path = default_test_path
        source_dic_path = default_dic_path
        target_dic_path = default_dic_path

    # Mirror the command-line fallback, so that direct callers omitting the
    # target dictionary do not end up calling load_dic(None) below.
    if not target_dic_path:
        target_dic_path = source_dic_path

    dataset = reader.Dataset(
        train_path=train_data_path,
        test_path=test_data_path,
        source_dic_path=source_dic_path,
        target_dic_path=target_dic_path,
        model_type=model_type, )

    train_reader = paddle.batch(
        paddle.reader.shuffle(dataset.train, buf_size=1000),
        batch_size=batch_size)

    test_reader = paddle.batch(
        paddle.reader.shuffle(dataset.test, buf_size=1000),
        batch_size=batch_size)

    # Select the device and the number of trainers.
    paddle.init(use_gpu=use_gpu, trainer_count=num_workers)

    cost, prediction, label = DSSM(
        dnn_dims=layer_dims,
        vocab_sizes=[
            len(load_dic(path)) for path in [source_dic_path, target_dic_path]
        ],
        model_type=model_type,
        model_arch=model_arch,
        share_semantic_generator=share_semantic_generator,
        class_num=class_num,
        share_embed=share_embed)()

    parameters = paddle.parameters.create(cost)

    adam_optimizer = paddle.optimizer.Adam(
        learning_rate=1e-3,
        regularization=paddle.optimizer.L2Regularization(rate=1e-3),
        model_average=paddle.optimizer.ModelAverage(average_window=0.5))

    trainer = paddle.trainer.SGD(
        cost=cost,
        # AUC is only reported for non-rank models.
        extra_layers=paddle.evaluator.auc(input=prediction, label=label)
        if not model_type.is_rank() else None,
        parameters=parameters,
        update_equation=adam_optimizer)

    # Map each data layer name to its column in the samples yielded by the
    # reader: rank samples carry two targets (left and right), the others one.
    if model_type.is_classification() or model_type.is_regression():
        feeding = {'source_input': 0, 'target_input': 1, 'label_input': 2}
    else:
        feeding = {
            'source_input': 0,
            'left_target_input': 1,
            'right_target_input': 2,
            'label_input': 3
        }

    def _event_handler(event):
        '''
        Log training progress, and periodically test and checkpoint the model.
        '''
        if not isinstance(event, paddle.event.EndIteration):
            return
        batch_id = event.batch_id

        # output train log
        if batch_id % args.num_batches_to_log == 0:
            logger.info("Pass %d, Batch %d, Cost %f, %s" % (
                event.pass_id, batch_id, event.cost, event.metrics))

        # test model; only classification models are evaluated here.
        if batch_id > 0 and batch_id % args.num_batches_to_test == 0 \
                and model_type.is_classification():
            result = trainer.test(reader=test_reader, feeding=feeding)
            logger.info("Test at Pass %d, %s" % (event.pass_id,
                                                 result.metrics))

        # save model
        if batch_id > 0 and batch_id % args.num_batches_to_save_model == 0:
            # Describe the model actually being trained (this function's
            # arguments), not whatever was given on the command line.
            model_desc = "{type}_{arch}".format(
                type=str(model_type), arch=str(model_arch))
            # The file name only depends on the pass id, so a later
            # checkpoint within the same pass overwrites an earlier one.
            # The parameters are written as a tar archive, hence binary mode.
            with open("%sdssm_%s_pass_%05d.tar" %
                      (args.model_output_prefix, model_desc, event.pass_id),
                      "wb") as f:
                trainer.save_parameter_to_tar(f)

    trainer.train(
        reader=train_reader,
        event_handler=_event_handler,
        feeding=feeding,
        num_passes=num_passes)

    logger.info("Training has finished.")


if __name__ == '__main__':
    display_args(args)
    # args.model_type / args.model_arch were already converted to ModelType /
    # ModelArch instances during the argument check, so they are passed
    # through as-is instead of being wrapped in the constructors a second time.
    train(
        train_data_path=args.train_data_path,
        test_data_path=args.test_data_path,
        source_dic_path=args.source_dic_path,
        target_dic_path=args.target_dic_path,
        model_type=args.model_type,
        model_arch=args.model_arch,
        batch_size=args.batch_size,
        num_passes=args.num_passes,
        share_semantic_generator=args.share_network_between_source_target,
        share_embed=args.share_embed,
        class_num=args.class_num,
        num_workers=args.num_workers,
        use_gpu=args.use_gpu)