#!/usr/bin/env python
# -*- coding: utf-8 -*-
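"""Training script for the PaddlePaddle DSSM example.

A sketch of a typical invocation (the integer codes accepted by --model_type
and --model_arch come from utils.ModelType / utils.ModelArch; run
`python train.py --help` to list them):

    python train.py --model_type <type_id> --model_arch <arch_id> \
        --train_data_path <train_file> --test_data_path <test_file> \
        --source_dic_path <vocab_file> --class_num 2

If --train_data_path is not given, the bundled sample data under ./data/ is
used. --class_num is only required for the classification task.
"""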
import argparse
import gzip
from distutils.util import strtobool

import paddle.v2 as paddle
from network_conf import DSSM
import reader
from utils import TaskType, load_dic, logger, ModelType, ModelArch

parser = argparse.ArgumentParser(description="PaddlePaddle DSSM example")

parser.add_argument(
    '-i',
    '--train_data_path',
    type=str,
    required=False,
    help="path of training dataset")
parser.add_argument(
    '-t',
    '--test_data_path',
    type=str,
    required=False,
    help="path of testing dataset")
parser.add_argument(
    '-s',
    '--source_dic_path',
    type=str,
    required=False,
    help="path of the source's word dic")
parser.add_argument(
    '--target_dic_path',
    type=str,
    required=False,
    help="path of the target's word dic, if not set, the `source_dic_path` will be used"
)
parser.add_argument(
    '-b',
    '--batch_size',
    type=int,
    default=10,
    help="size of mini-batch (default:10)")
parser.add_argument(
    '-p',
    '--num_passes',
    type=int,
    default=10,
    help="number of passes to run(default:10)")
parser.add_argument(
    '-y',
    '--model_type',
    type=int,
    required=True,
    default=ModelType.CLASSIFICATION_MODE,
    help="model type, %d for classification, %d for pairwise rank, %d for regression (default: classification)"
    % (ModelType.CLASSIFICATION_MODE, ModelType.RANK_MODE,
       ModelType.REGRESSION_MODE))
parser.add_argument(
    '--model_arch',
    type=int,
    required=True,
    default=ModelArch.CNN_MODE,
    help="model architecture, %d for CNN, %d for FC, %d for RNN" %
    (ModelArch.CNN_MODE, ModelArch.FC_MODE, ModelArch.RNN_MODE))
# NOTE: argparse's `type=bool` treats any non-empty string (including
# "False") as True, so boolean flags are parsed with strtobool instead.
parser.add_argument(
    '--share_network_between_source_target',
    type=strtobool,
    default=False,
    help="whether to share network parameters between source and target")
parser.add_argument(
    '--share_embed',
    type=strtobool,
    default=False,
    help="whether to share word embedding between source and target")
parser.add_argument(
    '--dnn_dims',
    type=str,
    default='256,128,64,32',
    help="dimentions of dnn layers, default is '256,128,64,32', which means create a 4-layer dnn, demention of each layer is 256, 128, 64 and 32"
)
parser.add_argument(
    '--num_workers', type=int, default=1, help="num worker threads, default 1")
parser.add_argument(
    '--use_gpu',
    type=strtobool,
    default=False,
    help="whether to use GPU devices (default: False)")
parser.add_argument(
    '-c',
    '--class_num',
    type=int,
    default=0,
    help="number of categories for classification task.")

# arguments check.
args = parser.parse_args()
args.model_type = ModelType(args.model_type)
args.model_arch = ModelArch(args.model_arch)
if args.model_type.is_classification():
    assert args.class_num > 1, "--class_num should be set in classification task."

layer_dims = [int(i) for i in args.dnn_dims.split(',')]
target_dic_path = args.source_dic_path if not args.target_dic_path else args.target_dic_path

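# Prefix for the checkpoint file names; includes the model type and architecture.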
model_save_name_prefix = "dssm_%s_%s" % (args.model_type, args.model_arch)


def train(train_data_path=None,
          test_data_path=None,
          source_dic_path=None,
          target_dic_path=None,
          model_type=ModelType.create_classification(),
          model_arch=ModelArch.create_cnn(),
          batch_size=10,
          num_passes=10,
          share_semantic_generator=False,
          share_embed=False,
          class_num=None,
          num_workers=1,
          use_gpu=False):
    '''
    Train the DSSM.
    '''
    default_train_path = './data/rank/train.txt'
    default_test_path = './data/rank/test.txt'
    default_dic_path = './data/vocab.txt'
    if not model_type.is_rank():
        default_train_path = './data/classification/train.txt'
        default_test_path = './data/classification/test.txt'

    use_default_data = not train_data_path

    if use_default_data:
        train_data_path = default_train_path
        test_data_path = default_test_path
        source_dic_path = default_dic_path
        target_dic_path = default_dic_path

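    # Build the dataset reader; the record layout depends on the model type
    # (pointwise samples for classification/regression, pairwise for rank).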
    dataset = reader.Dataset(
        train_path=train_data_path,
        test_path=test_data_path,
        source_dic_path=source_dic_path,
        target_dic_path=target_dic_path,
        model_type=model_type)

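    # Shuffle samples within a 1000-record buffer and group them into mini-batches.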
    train_reader = paddle.batch(
        paddle.reader.shuffle(dataset.train, buf_size=1000),
        batch_size=batch_size)

    test_reader = paddle.batch(
        paddle.reader.shuffle(dataset.test, buf_size=1000),
        batch_size=batch_size)

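    # Initialize PaddlePaddle with the requested device and worker thread count.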
    paddle.init(use_gpu=use_gpu, trainer_count=num_workers)

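    # Build the DSSM topology; calling the configuration returns the cost,
    # prediction and label layers.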
    cost, prediction, label = DSSM(
        dnn_dims=layer_dims,
        vocab_sizes=[
            len(load_dic(path)) for path in [source_dic_path, target_dic_path]
        ],
        model_type=model_type,
        model_arch=model_arch,
        share_semantic_generator=share_semantic_generator,
        class_num=class_num,
        share_embed=share_embed)()

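    # Create and initialize the trainable parameters of the network.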
    parameters = paddle.parameters.create(cost)

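    # Adam optimizer with L2 regularization and model averaging.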
    adam_optimizer = paddle.optimizer.Adam(
        learning_rate=1e-3,
        regularization=paddle.optimizer.L2Regularization(rate=1e-3),
        model_average=paddle.optimizer.ModelAverage(average_window=0.5))

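    # The SGD trainer drives the parameter updates with the optimizer above.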
    trainer = paddle.trainer.SGD(
        cost=cost,
        extra_layers=None,
        parameters=parameters,
        update_equation=adam_optimizer)
    # trainer = paddle.trainer.SGD(
    #     cost=cost,
    #     extra_layers=paddle.evaluator.auc(input=prediction, label=label)
    #     if prediction and model_type.is_classification() else None,
    #     parameters=parameters,
    #     update_equation=adam_optimizer)

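    # Map each input layer name to the index of the corresponding field in a
    # data record; the pairwise rank model consumes two target sequences.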
    feeding = {}
    if model_type.is_classification() or model_type.is_regression():
        feeding = {'source_input': 0, 'target_input': 1, 'label_input': 2}
    else:
        feeding = {
            'source_input': 0,
            'left_target_input': 1,
            'right_target_input': 2,
            'label_input': 3
        }

    def _event_handler(event):
        '''
        Log the training cost every 100 batches and, at the end of each pass,
        evaluate on the test set (for classification) and save a checkpoint.
        '''
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 100 == 0:
                logger.info("Pass %d, Batch %d, Cost %f, %s\n" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics))

        if isinstance(event, paddle.event.EndPass):
            if test_reader is not None:
                if model_type.is_classification():
                    result = trainer.test(reader=test_reader, feeding=feeding)
                    logger.info("Test at Pass %d, %s \n" % (event.pass_id,
                                                            result.metrics))
                else:
                    result = None
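            # Save a checkpoint of the parameters at the end of every pass.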
            with gzip.open("%s_pass_%05d.tar.gz" %
                           (model_save_name_prefix, event.pass_id), "w") as f:
                parameters.to_tar(f)

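    # Start training; _event_handler logs the cost and saves checkpoints.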
    trainer.train(
        reader=train_reader,
        event_handler=_event_handler,
        feeding=feeding,
        num_passes=num_passes)

    logger.info("Training has finished.")


if __name__ == '__main__':
    train(
        train_data_path=args.train_data_path,
        test_data_path=args.test_data_path,
        source_dic_path=args.source_dic_path,
        # fall back to source_dic_path when --target_dic_path is not given
        target_dic_path=target_dic_path,
        model_type=ModelType(args.model_type),
        model_arch=ModelArch(args.model_arch),
        batch_size=args.batch_size,
        num_passes=args.num_passes,
        share_semantic_generator=args.share_network_between_source_target,
        share_embed=args.share_embed,
        class_num=args.class_num,
        num_workers=args.num_workers,
        use_gpu=args.use_gpu)