train.py 7.3 KB
Newer Older
S
Superjom 已提交
1 2 3 4 5 6 7 8
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import gzip

import paddle.v2 as paddle
from network_conf import DSSM
import reader
S
Superjom 已提交
9
from utils import TaskType, load_dic, logger, ModelType, ModelArch
S
Superjom 已提交
10 11 12 13

def _str_to_bool(value):
    '''
    Convert a command-line string such as "True", "false", "1" or "no" to a bool.

    argparse's ``type=bool`` calls ``bool(text)``, which is True for every
    non-empty string (so ``--use_gpu False`` would enable the GPU); this
    converter parses the text instead.

    :raises argparse.ArgumentTypeError: if the text is not a recognized boolean.
    '''
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("boolean value expected, got %r" % value)


parser = argparse.ArgumentParser(description="PaddlePaddle DSSM example")

parser.add_argument(
    '-i',
    '--train_data_path',
    type=str,
    required=False,
    help="path of training dataset")
parser.add_argument(
    '-t',
    '--test_data_path',
    type=str,
    required=False,
    help="path of testing dataset")
parser.add_argument(
    '-s',
    '--source_dic_path',
    type=str,
    required=False,
    help="path of the source's word dic")
parser.add_argument(
    '--target_dic_path',
    type=str,
    required=False,
    help="path of the target's word dic, if not set, the `source_dic_path` will be used"
)
parser.add_argument(
    '-b',
    '--batch_size',
    type=int,
    default=10,
    help="size of mini-batch (default:10)")
parser.add_argument(
    '-p',
    '--num_passes',
    type=int,
    default=10,
    help="number of passes to run(default:10)")
# --model_type / --model_arch were declared required=True together with a
# default, which made the default unreachable; they are optional so that the
# documented defaults actually apply.
parser.add_argument(
    '-y',
    '--model_type',
    type=int,
    required=False,
    default=ModelType.CLASSIFICATION_MODE,
    help="model type, %d for classification, %d for pairwise rank (default: classification)"
    % (ModelType.CLASSIFICATION_MODE, ModelType.RANK_MODE))
parser.add_argument(
    '--model_arch',
    type=int,
    required=False,
    default=ModelArch.CNN_MODE,
    help="model architecture, %d for CNN, %d for FC (default: CNN)" %
    (ModelArch.CNN_MODE, ModelArch.FC_MODE))
# Boolean switches: accept an explicit value (`--share_embed False`) or a bare
# flag (`--share_embed`, meaning True).
parser.add_argument(
    '--share_network_between_source_target',
    type=_str_to_bool,
    nargs='?',
    const=True,
    default=False,
    help="whether to share network parameters between source and target")
parser.add_argument(
    '--share_embed',
    type=_str_to_bool,
    nargs='?',
    const=True,
    default=False,
    help="whether to share word embedding between source and target")
parser.add_argument(
    '--dnn_dims',
    type=str,
    default='256,128,64,32',
    help="dimensions of dnn layers, default is '256,128,64,32', which means create a 4-layer dnn, dimension of each layer is 256, 128, 64 and 32"
)
parser.add_argument(
    '--num_workers', type=int, default=1, help="num worker threads, default 1")
parser.add_argument(
    '--use_gpu',
    type=_str_to_bool,
    nargs='?',
    const=True,
    default=False,
    help="whether to use GPU devices (default: False)")
parser.add_argument(
    '-c',
    '--class_num',
    type=int,
    default=0,
    help="number of categories for classification task.")

# Parse the command line and validate the combination of arguments.
args = parser.parse_args()
# Replace the raw integer codes with their ModelType / ModelArch wrappers.
args.model_type = ModelType(args.model_type)
args.model_arch = ModelArch(args.model_arch)
if args.model_type.is_classification() and args.class_num <= 1:
    # parser.error prints the usage and exits; an assert would be stripped
    # when Python runs with -O.
    parser.error("--class_num should be set in classification task.")

# Layer dimensions of the DNN, e.g. '256,128,64,32' -> [256, 128, 64, 32].
# Empty items (e.g. from a trailing comma) are ignored.
try:
    layer_dims = [int(dim) for dim in args.dnn_dims.split(',') if dim.strip()]
except ValueError:
    layer_dims = []
if not layer_dims:
    parser.error(
        "--dnn_dims should be a comma-separated list of integers, got %r" %
        args.dnn_dims)

# The target dictionary falls back to the source dictionary when not given.
target_dic_path = args.target_dic_path or args.source_dic_path

model_save_name_prefix = "dssm_pass_%s_%s" % (args.model_type,
                                              args.model_arch, )


def train(train_data_path=None,
          test_data_path=None,
          source_dic_path=None,
          target_dic_path=None,
          model_type=ModelType.create_classification(),
          model_arch=ModelArch.create_cnn(),
          batch_size=10,
          num_passes=10,
          share_semantic_generator=False,
          share_embed=False,
          class_num=None,
          num_workers=1,
          use_gpu=False,
          dnn_dims=None):
    '''
    Train the DSSM and save its parameters after every pass.

    :param train_data_path: path of the training data. When empty, the bundled
        data under ./data is used, and the test and dictionary paths are
        replaced by the bundled defaults as well.
    :param test_data_path: path of the test data. When empty (and the bundled
        data is not used), per-pass evaluation is skipped.
    :param source_dic_path: path of the source-side word dictionary.
    :param target_dic_path: path of the target-side word dictionary. Falls back
        to `source_dic_path` when empty.
    :param model_type: a ModelType (classification or pairwise rank).
    :param model_arch: a ModelArch (CNN or FC).
    :param batch_size: size of each mini-batch.
    :param num_passes: number of passes over the training data.
    :param share_semantic_generator: whether source and target share network
        parameters.
    :param share_embed: whether source and target share word embeddings.
    :param class_num: number of categories for the classification task.
    :param num_workers: passed to paddle.init as trainer_count.
    :param use_gpu: whether to train on GPU devices.
    :param dnn_dims: list of DNN layer dimensions. Defaults to the module-level
        `layer_dims` parsed from --dnn_dims.
    '''
    default_train_path = './data/rank/train.txt'
    default_test_path = './data/rank/test.txt'
    default_dic_path = './data/vocab.txt'
    if model_type.is_classification():
        default_train_path = './data/classification/train.txt'
        default_test_path = './data/classification/test.txt'

    if not train_data_path:
        # No training data given: use the bundled dataset and its vocabulary.
        train_data_path = default_train_path
        test_data_path = default_test_path
        source_dic_path = default_dic_path
        target_dic_path = default_dic_path

    # As documented for --target_dic_path, the target side reuses the source
    # dictionary when none is given (otherwise load_dic(None) would be called).
    if not target_dic_path:
        target_dic_path = source_dic_path

    if dnn_dims is None:
        dnn_dims = layer_dims

    dataset = reader.Dataset(
        train_path=train_data_path,
        test_path=test_data_path,
        source_dic_path=source_dic_path,
        target_dic_path=target_dic_path,
        model_type=model_type, )

    train_reader = paddle.batch(
        paddle.reader.shuffle(dataset.train, buf_size=1000),
        batch_size=batch_size)

    # Only build a test reader when there is a test set to read.
    test_reader = None
    if test_data_path:
        test_reader = paddle.batch(
            paddle.reader.shuffle(dataset.test, buf_size=1000),
            batch_size=batch_size)

    paddle.init(use_gpu=use_gpu, trainer_count=num_workers)

    cost, prediction, label = DSSM(
        dnn_dims=dnn_dims,
        vocab_sizes=[
            len(load_dic(path)) for path in [source_dic_path, target_dic_path]
        ],
        model_type=model_type,
        model_arch=model_arch,
        share_semantic_generator=share_semantic_generator,
        class_num=class_num,
        share_embed=share_embed)()

    parameters = paddle.parameters.create(cost)

    adam_optimizer = paddle.optimizer.Adam(
        learning_rate=1e-3,
        regularization=paddle.optimizer.L2Regularization(rate=1e-3),
        model_average=paddle.optimizer.ModelAverage(average_window=0.5))

    trainer = paddle.trainer.SGD(
        cost=cost,
        # The AUC evaluator is attached only when DSSM returns a prediction.
        extra_layers=paddle.evaluator.auc(input=prediction, label=label)
        if prediction else None,
        parameters=parameters,
        update_equation=adam_optimizer)

    # Data-layer name -> position of that field in each sample; presumably
    # matches the tuple layout produced by reader.Dataset (verify there).
    if model_type.is_classification():
        feeding = {'source_input': 0, 'target_input': 1, 'label_input': 2}
    else:
        feeding = {
            'source_input': 0,
            'left_target_input': 1,
            'right_target_input': 2,
            'label_input': 3
        }

    def _event_handler(event):
        '''
        Log progress every 100 batches. At the end of each pass, evaluate on
        the test set (classification only) and save the parameters.
        '''
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 100 == 0:
                logger.info("Pass %d, Batch %d, Cost %f, %s\n" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics))

        if isinstance(event, paddle.event.EndPass):
            # Per-pass evaluation is only done for the classification model.
            if test_reader is not None and model_type.is_classification():
                result = trainer.test(reader=test_reader, feeding=feeding)
                logger.info("Test at Pass %d, %s \n" % (event.pass_id,
                                                        result.metrics))
            # NOTE(review): model_save_name_prefix already starts with
            # "dssm_pass_", so files are named "dssm_dssm_pass_..."; the
            # pattern is kept unchanged so existing file names stay stable.
            with gzip.open("dssm_%s_pass_%05d.tar.gz" %
                           (model_save_name_prefix, event.pass_id), "w") as f:
                parameters.to_tar(f)

    trainer.train(
        reader=train_reader,
        event_handler=_event_handler,
        feeding=feeding,
        num_passes=num_passes)

    logger.info("Training has finished.")


if __name__ == '__main__':
    # args.model_type / args.model_arch were already converted to ModelType /
    # ModelArch during argument checking, so they are passed through as-is
    # rather than being wrapped a second time.
    train(
        train_data_path=args.train_data_path,
        test_data_path=args.test_data_path,
        source_dic_path=args.source_dic_path,
        # Module-level value that falls back to --source_dic_path when
        # --target_dic_path is not given, as its help text promises.
        target_dic_path=target_dic_path,
        model_type=args.model_type,
        model_arch=args.model_arch,
        batch_size=args.batch_size,
        num_passes=args.num_passes,
        share_semantic_generator=args.share_network_between_source_target,
        share_embed=args.share_embed,
        class_num=args.class_num,
        num_workers=args.num_workers,
        use_gpu=args.use_gpu)