run_classifier.py 19.7 KB
Newer Older
L
Li Fuchen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Y
Yibing Liu 已提交
14 15 16 17 18 19 20 21 22 23 24 25 26 27
"""
SimNet Task
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import argparse
import multiprocessing
import sys

D
Dilyar 已提交
28 29 30 31 32
# Python 2 compatibility shim: make UTF-8 the interpreter-wide default string
# encoding. On Python 3 `sys.getdefaultencoding()` already returns 'utf-8', so
# the branch below is skipped there (and the Python 2-only builtin `reload` is
# never reached). On Python 2 `sys.setdefaultencoding` is removed from `sys`
# at startup, which is why the module is reloaded before calling it.
defaultencoding = 'utf-8'
if sys.getdefaultencoding() != defaultencoding:
    reload(sys)
    sys.setdefaultencoding(defaultencoding)

P
pkpk 已提交
33
# Extend the import search path with the sibling shared-modules directory
# (presumably where the `models` package imported below lives -- confirm).
# The path is relative, so it is resolved against the current working
# directory, not against this file's location.
sys.path.append("../shared_modules/")
Y
Yibing Liu 已提交
34 35 36 37 38 39 40 41

import paddle
import paddle.fluid as fluid
import numpy as np
import config
import utils
import reader
import models.matching.paddle_layers as layers
42 43
import io
import logging
D
Dilyar 已提交
44 45

from utils import ArgConfig
46 47
from models.model_check import check_version
from models.model_check import check_cuda
Y
Yibing Liu 已提交
48 49


50
def create_model(args, is_inference=False, is_pointwise=False):
    """
    Declare the input variables and the data loader for one SimNet program.

    Must be called inside the ``fluid.program_guard`` of the program being
    built. ``args`` is accepted for interface compatibility and is not read
    here.

    Returns a tuple whose first element is the (non-iterable) data loader,
    followed by the input variables in feed order:

    * ``is_inference=True``            -> (loader, left, pos_right)
    * ``is_pointwise=True``            -> (loader, left, right, label)
    * otherwise (pairwise training)    -> (loader, left, pos_right, neg_right)

    ``is_inference`` takes precedence over ``is_pointwise``.
    """

    def _id_sequence(var_name):
        # int64 input declared with lod_level=1 and an unspecified length.
        return fluid.data(
            name=var_name, shape=[None], dtype='int64', lod_level=1)

    def _make_loader(feed_vars):
        # Every mode uses the same loader settings; only the feed list differs.
        return fluid.io.DataLoader.from_generator(
            capacity=16,
            feed_list=feed_vars,
            iterable=False,
            use_double_buffer=False)

    # The query side is common to all three modes and is always declared first.
    left = _id_sequence('left')

    if is_inference:
        pos_right = _id_sequence('pos_right')
        return _make_loader([left, pos_right]), left, pos_right

    if is_pointwise:
        right = _id_sequence('right')
        # The label has no lod_level, unlike the id-sequence inputs.
        label = fluid.data(name='label', shape=[None], dtype='int64')
        return _make_loader([left, right, label]), left, right, label

    pos_right = _id_sequence('pos_right')
    neg_right = _id_sequence('neg_right')
    return (_make_loader([left, pos_right, neg_right]), left, pos_right,
            neg_right)
P
pkpk 已提交
95 96


Y
Yibing Liu 已提交
97 98 99 100 101 102 103 104 105
def train(conf_dict, args):
    """
    Train a SimNet model, optionally validating, testing and exporting it.

    Args:
        conf_dict: model configuration mapping. ``dict_size`` is written into
            it; the ``net``, ``loss``, ``optimizer`` and ``model_path``
            entries are read.
        args: parsed run options. The attributes read here are
            ``vocab_path``, ``use_cuda``, ``enable_ce``, ``task_mode``,
            ``epoch``, ``do_valid``, ``do_test``, ``init_checkpoint``,
            ``batch_size``, ``validation_steps``, ``save_steps``,
            ``output_dir``, ``compute_accuracy``, ``lamda`` and
            ``task_name``.

    Side effects:
        Writes inference models under
        ``os.path.join(args.output_dir, conf_dict["model_path"])``, one
        sub-directory per saved step, logs progress, and prints ``kpis`` lines
        when ``args.enable_ce`` is set.
    """
    # loading vocabulary
    vocab = utils.load_vocab(args.vocab_path)
    # get vocab size
    conf_dict['dict_size'] = len(vocab)
    # Load network structure dynamically
    net = utils.import_class("../shared_modules/models/matching",
                             conf_dict["net"]["module_name"],
                             conf_dict["net"]["class_name"])(conf_dict)
    # Load loss function dynamically
    loss = utils.import_class("../shared_modules/models/matching/losses",
                              conf_dict["loss"]["module_name"],
                              conf_dict["loss"]["class_name"])(conf_dict)
    # Load Optimization method
    optimizer = utils.import_class(
        "../shared_modules/models/matching/optimizers", "paddle_optimizers",
        conf_dict["optimizer"]["class_name"])(conf_dict)
    # load auc method
    metric = fluid.metrics.Auc(name="auc")
    # Get device
    if args.use_cuda:
        place = fluid.CUDAPlace(0)
    else:
        place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    startup_prog = fluid.Program()
    train_program = fluid.Program()

    # used for continuous evaluation: fix the seeds so runs are repeatable
    if args.enable_ce:
        SEED = 102
        startup_prog.random_seed = SEED
        train_program.random_seed = SEED

    simnet_process = reader.SimNetProcessor(args, vocab)
    is_pairwise = args.task_mode == "pairwise"

    # Build the training network
    with fluid.program_guard(train_program, startup_prog):
        with fluid.unique_name.guard():
            if is_pairwise:
                train_loader, left, pos_right, neg_right = create_model(args)
                _, pos_score = net.predict(left, pos_right)
                _, neg_score = net.predict(left, neg_right)
                avg_cost = loss.compute(pos_score, neg_score)
            else:
                train_loader, left, right, label = create_model(
                    args, is_pointwise=True)
                _, train_pred = net.predict(left, right)
                avg_cost = loss.compute(train_pred, label)
            avg_cost.persistable = True
            optimizer.ops(avg_cost)

    # Get Reader
    get_train_examples = simnet_process.get_reader("train", epoch=args.epoch)

    # Build the inference program unconditionally. It is needed for every
    # model export and for the do_test pass, not only when do_valid is set;
    # building it only under do_valid left test_prog/test_loader unbound.
    test_prog = fluid.Program()
    with fluid.program_guard(test_prog, startup_prog):
        with fluid.unique_name.guard():
            test_loader, infer_left, infer_right = create_model(
                args, is_inference=True)
            infer_left_feat, pred = net.predict(infer_left, infer_right)
    test_prog = test_prog.clone(for_test=True)

    # Feed/fetch description of the exported inference model.
    feed_var_names = [infer_left.name, infer_right.name]
    target_vars = [infer_left_feat, pred]

    # NOTE(review): the checkpoint is applied before `exe.run(startup_prog)`
    # further down; confirm that running the startup program afterwards does
    # not overwrite the restored parameters.
    if args.init_checkpoint != "":
        utils.init_checkpoint(exe, args.init_checkpoint, startup_prog)

    def save_inference_model(step):
        """
        export the inference program into a directory named after `step`
        """
        model_save_dir = os.path.join(args.output_dir, conf_dict["model_path"])
        model_path = os.path.join(model_save_dir, str(step))
        if not os.path.exists(model_save_dir):
            os.makedirs(model_save_dir)
        fluid.io.save_inference_model(model_path, feed_var_names, target_vars,
                                      exe, test_prog)
        logging.info("saving infer model in %s" % model_path)

    def valid_and_test(test_program, test_loader, get_valid_examples, process,
                       mode, exe, fetch_list):
        """
        run `test_program` over the examples and return auc, or (auc, acc)
        when args.compute_accuracy is set; `mode` is "valid" or "test"
        """
        if mode not in ("test", "valid"):
            raise ValueError("mode must be 'test' or 'valid', got %r" % mode)
        # Get Batch Data
        batch_data = fluid.io.batch(
            get_valid_examples, args.batch_size, drop_last=False)
        test_loader.set_sample_list_generator(batch_data)
        test_loader.start()
        pred_list = []
        while True:
            try:
                _pred = exe.run(program=test_program, fetch_list=fetch_list)
                pred_list.extend(_pred)
            except fluid.core.EOFException:
                # The loader is exhausted; reset it so it can be reused.
                test_loader.reset()
                break
        pred_list = np.vstack(pred_list)
        # Labels are fetched only after the reader has been consumed, as in
        # the original ordering.
        if mode == "test":
            label_list = process.get_test_label()
        else:
            label_list = process.get_valid_label()
        if args.task_mode == "pairwise":
            # Rescale the score with (x + 1) / 2 and expand it into two
            # columns [1 - p, p] before handing it to the metric.
            pred_list = (pred_list + 1) / 2
            pred_list = np.hstack(
                (np.ones_like(pred_list) - pred_list, pred_list))
        metric.reset()
        metric.update(pred_list, label_list)
        auc = metric.eval()
        if args.compute_accuracy:
            acc = utils.get_accuracy(pred_list, label_list, args.task_mode,
                                     args.lamda)
            return auc, acc
        else:
            return auc

    # run train
    logging.info("start train process ...")
    # set global step
    global_step = 0
    ce_info = []
    # used for continuous evaluation: keep the sample order deterministic
    if args.enable_ce:
        train_batch_data = fluid.io.batch(
            get_train_examples, args.batch_size, drop_last=False)
    else:
        train_batch_data = fluid.io.batch(
            fluid.io.shuffle(
                get_train_examples, buf_size=10000),
            args.batch_size,
            drop_last=False)
    train_loader.set_sample_list_generator(train_batch_data)
    train_loader.start()
    exe.run(startup_prog)
    losses = []
    train_fetch_list = [avg_cost.name]
    start_time = time.time()
    while True:
        try:
            # The counter is bumped before the step runs, so when the loader
            # is exhausted it is one past the last completed step.
            global_step += 1
            avg_loss = exe.run(program=train_program,
                               fetch_list=train_fetch_list)
            losses.append(np.mean(avg_loss[0]))
            if args.do_valid and global_step % args.validation_steps == 0:
                get_valid_examples = simnet_process.get_reader("valid")
                valid_result = valid_and_test(
                    test_prog, test_loader, get_valid_examples, simnet_process,
                    "valid", exe, [pred.name])
                if args.compute_accuracy:
                    valid_auc, valid_acc = valid_result
                    logging.info(
                        "global_steps: %d, valid_auc: %f, valid_acc: %f, valid_loss: %f"
                        % (global_step, valid_auc, valid_acc, np.mean(losses)))
                else:
                    valid_auc = valid_result
                    logging.info(
                        "global_steps: %d, valid_auc: %f, valid_loss: %f" %
                        (global_step, valid_auc, np.mean(losses)))
            if global_step % args.save_steps == 0:
                save_inference_model(global_step)
        except fluid.core.EOFException:
            train_loader.reset()
            break
    end_time = time.time()
    ce_info.append([np.mean(losses), end_time - start_time])
    # final save
    logging.info("the final step is %s" % global_step)
    save_inference_model(global_step)

    # used for continuous evaluation
    if args.enable_ce:
        card_num = get_cards()
        ce_loss = 0
        ce_time = 0
        try:
            ce_loss = ce_info[-1][0]
            ce_time = ce_info[-1][1]
        except IndexError:
            logging.info("ce info err!")
        print("kpis\teach_step_duration_%s_card%s\t%s" %
              (args.task_name, card_num, ce_time))
        print("kpis\ttrain_loss_%s_card%s\t%f" %
              (args.task_name, card_num, ce_loss))

    if args.do_test:
        # Get Feeder and Reader (identical for both task modes)
        get_test_examples = simnet_process.get_reader("test")
        test_result = valid_and_test(test_prog, test_loader, get_test_examples,
                                     simnet_process, "test", exe, [pred.name])
        if args.compute_accuracy:
            test_auc, test_acc = test_result
            logging.info("AUC of test is %f, Accuracy of test is %f" %
                         (test_auc, test_acc))
        else:
            test_auc = test_result
            logging.info("AUC of test is %f" % test_auc)


def test(conf_dict, args):
    """
    Evaluate a saved model on the test set.

    Loads the checkpoint named by ``args.init_checkpoint``, runs the inference
    program over the "test" reader, writes one prediction per line to
    ``predictions.txt`` in the working directory, and logs AUC (plus accuracy
    when ``args.compute_accuracy`` is set). ``conf_dict['dict_size']`` is
    overwritten with the vocabulary size.
    """
    place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    vocab = utils.load_vocab(args.vocab_path)
    simnet_process = reader.SimNetProcessor(args, vocab)

    startup_prog = fluid.Program()

    batch_data = fluid.io.batch(
        simnet_process.get_reader("test"), args.batch_size, drop_last=False)
    test_prog = fluid.Program()

    conf_dict['dict_size'] = len(vocab)

    net_class = utils.import_class("../shared_modules/models/matching",
                                   conf_dict["net"]["module_name"],
                                   conf_dict["net"]["class_name"])
    net = net_class(conf_dict)

    metric = fluid.metrics.Auc(name="auc")
    is_pairwise = args.task_mode == "pairwise"

    with io.open("predictions.txt", "w", encoding="utf8") as predictions_file:
        # Both task modes build the same two-input inference graph and fetch
        # the second output of net.predict.
        with fluid.program_guard(test_prog, startup_prog):
            with fluid.unique_name.guard():
                test_loader, left, right = create_model(
                    args, is_inference=True)
                _, pred = net.predict(left, right)
        test_prog = test_prog.clone(for_test=True)

        exe.run(startup_prog)

        utils.init_checkpoint(exe, args.init_checkpoint, main_program=test_prog)

        test_loader.set_sample_list_generator(batch_data)

        logging.info("start test process ...")
        test_loader.start()
        pred_list = []
        fetch_list = [pred.name]
        while True:
            try:
                batch_out = exe.run(program=test_prog,
                                    fetch_list=fetch_list)[0]
                if is_pairwise:
                    pred_list.extend([float(row[0]) for row in batch_out])
                    # The file gets the rescaled score (x + 1) / 2.
                    lines = [str((row[0] + 1) / 2) for row in batch_out]
                else:
                    pred_list.extend(batch_out)
                    # The file gets the arg-max index of each output row.
                    lines = [str(np.argmax(row)) for row in batch_out]
                predictions_file.write(u"\n".join(lines) + "\n")
            except fluid.core.EOFException:
                # Loader exhausted: reset it and stop.
                test_loader.reset()
                break

        if is_pairwise:
            # Column vector of rescaled scores p, expanded to [1 - p, p].
            scores = np.array(pred_list).reshape((-1, 1))
            scores = (scores + 1) / 2
            pred_list = np.hstack((np.ones_like(scores) - scores, scores))
        else:
            pred_list = np.array(pred_list)
        labels = simnet_process.get_test_label()

        metric.update(pred_list, labels)
        if args.compute_accuracy:
            acc = utils.get_accuracy(pred_list, labels, args.task_mode,
                                     args.lamda)
            logging.info("AUC of test is %f, Accuracy of test is %f" %
                         (metric.eval(), acc))
        else:
            logging.info("AUC of test is %f" % metric.eval())

    # Runs only after predictions.txt has been closed.
    if args.verbose_result:
        utils.get_result_file(args)
        result_path = os.path.join(os.getcwd(), args.test_result_path)
        logging.info("test result saved in %s" % result_path)
Y
Yibing Liu 已提交
442 443


D
Dilyar 已提交
444
def infer(conf_dict, args):
    """
    Run prediction over the inference data and save the results.

    Loads the checkpoint named by ``args.init_checkpoint``, runs the inference
    program over ``simnet_process.get_infer_reader``, and writes one
    ``<input>\\t<prediction>`` line per example to ``args.infer_result_path``.
    ``conf_dict['dict_size']`` is overwritten with the vocabulary size.
    """
    place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    vocab = utils.load_vocab(args.vocab_path)
    simnet_process = reader.SimNetProcessor(args, vocab)

    startup_prog = fluid.Program()

    # The reader attribute is handed over as-is (it is not called here).
    batch_data = fluid.io.batch(
        simnet_process.get_infer_reader, args.batch_size, drop_last=False)

    test_prog = fluid.Program()

    conf_dict['dict_size'] = len(vocab)

    net_class = utils.import_class("../shared_modules/models/matching",
                                   conf_dict["net"]["module_name"],
                                   conf_dict["net"]["class_name"])
    net = net_class(conf_dict)

    is_pairwise = args.task_mode == "pairwise"

    # Both task modes build the same two-input inference graph and fetch the
    # second output of net.predict.
    with fluid.program_guard(test_prog, startup_prog):
        with fluid.unique_name.guard():
            infer_loader, left, right = create_model(args, is_inference=True)
            _, pred = net.predict(left, right)
    test_prog = test_prog.clone(for_test=True)

    exe.run(startup_prog)

    utils.init_checkpoint(exe, args.init_checkpoint, main_program=test_prog)

    infer_loader.set_sample_list_generator(batch_data)

    logging.info("start test process ...")
    preds_list = []
    fetch_list = [pred.name]
    infer_loader.start()
    while True:
        try:
            batch_out = exe.run(program=test_prog, fetch_list=fetch_list)[0]
            if is_pairwise:
                # Rescaled score (x + 1) / 2, stored as text.
                preds_list.extend(
                    [str((row[0] + 1) / 2) for row in batch_out])
            else:
                # Arg-max index of each output row, stored as text.
                preds_list.extend([str(np.argmax(row)) for row in batch_out])
        except fluid.core.EOFException:
            # Loader exhausted: reset it and stop.
            infer_loader.reset()
            break

    with io.open(args.infer_result_path, "w", encoding="utf8") as infer_file:
        # zip stops at the shorter of the two sequences.
        for _data, _pred in zip(simnet_process.get_infer_data(), preds_list):
            infer_file.write(_data + "\t" + _pred + "\n")
    saved_path = os.path.join(os.getcwd(), args.infer_result_path)
    logging.info("infer result saved in %s" % saved_path)
Y
Yibing Liu 已提交
515 516


Z
zhengya01 已提交
517 518 519 520 521 522 523
def get_cards():
    """
    Return how many device ids are listed in CUDA_VISIBLE_DEVICES.

    The variable is read as a comma-separated list. Returns 0 when it is
    unset or empty. Blank entries (a trailing comma, or a value made only of
    whitespace) are ignored rather than counted as devices.
    """
    cards = os.environ.get('CUDA_VISIBLE_DEVICES', '')
    return len([card for card in cards.split(",") if card.strip()])

P
pkpk 已提交
524

D
Dilyar 已提交
525
if __name__ == "__main__":
    # Build the run configuration and echo it.
    args = ArgConfig().build_conf()
    utils.print_arguments(args)

    # Environment checks run before logging is initialised.
    check_cuda(args.use_cuda)
    check_version()

    utils.init_log("./log/TextSimilarityNet")
    conf_dict = config.SimNetConfig(args)

    # Only one entry point runs per invocation; precedence is
    # do_train, then do_test, then do_infer.
    if args.do_train:
        entry_point = train
    elif args.do_test:
        entry_point = test
    elif args.do_infer:
        entry_point = infer
    else:
        raise ValueError(
            "one of do_train and do_test and do_infer must be True")
    entry_point(conf_dict, args)