predict.py
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Load checkpoint of running classifier to do prediction and save inference model."""

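# Example invocation (a sketch: paths are placeholders and the flag names mirror
# the arguments this script reads via finetune_args):
#
#   python predict.py --task_name udc \
#       --use_cuda true \
#       --bert_config_path /path/to/bert_config.json \
#       --vocab_path /path/to/vocab.txt \
#       --init_checkpoint /path/to/trained_checkpoint \
#       --data_dir /path/to/task_data \
#       --batch_size 32 \
#       --save_inference_model_path /path/to/inference_models
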
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import time
import numpy as np
import multiprocessing

import paddle
import paddle.fluid as fluid

from finetune_args import parser
from utils.args import print_arguments
from utils.init import init_pretraining_params, init_checkpoint

import define_predict_pack
import reader.data_reader as reader

_WORK_DIR = os.path.split(os.path.realpath(__file__))[0]
sys.path.append(
    '../../models/dialogue_model_toolkit/dialogue_general_understanding')
sys.path.append('../../models/')

from bert import BertConfig, BertModel
from create_model import create_model
import define_paradigm

from model_check import check_cuda


def main(args):
    """main function"""
    bert_config = BertConfig(args.bert_config_path)
    bert_config.print_config()

    task_name = args.task_name.lower()
    paradigm_inst = define_paradigm.Paradigm(task_name)
    pred_inst = define_predict_pack.DefinePredict()
    pred_func = getattr(pred_inst, pred_inst.task_map[task_name])

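    # Map each supported task to the data reader that parses its dataset.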
    processors = {
        'udc': reader.UDCProcessor,
        'swda': reader.SWDAProcessor,
        'mrda': reader.MRDAProcessor,
        'atis_slot': reader.ATISSlotProcessor,
        'atis_intent': reader.ATISIntentProcessor,
        'dstc2': reader.DSTC2Processor,
        'dstc2_asr': reader.DSTC2Processor,
    }

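    # in_tokens controls how the reader interprets batch_size for each task:
    # as a number of tokens (True) or a number of examples (False).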
    in_tokens = {
        'udc': True,
        'swda': True,
        'mrda': True,
        'atis_slot': False,
        'atis_intent': True,
        'dstc2': True,
        'dstc2_asr': True
    }

    processor = processors[task_name](data_dir=args.data_dir,
                                      vocab_path=args.vocab_path,
                                      max_seq_len=args.max_seq_len,
                                      do_lower_case=args.do_lower_case,
                                      in_tokens=in_tokens[task_name],
                                      task_name=task_name,
                                      random_seed=args.random_seed)
    num_labels = len(processor.get_labels())

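    # Build the prediction network in a dedicated Program so it stays isolated
    # from any training graph.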
    predict_prog = fluid.Program()
    predict_startup = fluid.Program()
    with fluid.program_guard(predict_prog, predict_startup):
        with fluid.unique_name.guard():
            pred_results = create_model(
                args,
                pyreader_name='predict_reader',
                bert_config=bert_config,
                num_labels=num_labels,
                paradigm_inst=paradigm_inst,
                is_prediction=True)
            predict_pyreader = pred_results.get('pyreader', None)
            probs = pred_results.get('probs', None)
            feed_target_names = pred_results.get('feed_targets_name', None)

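    # Clone with for_test=True so operators such as dropout run in inference mode.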
    predict_prog = predict_prog.clone(for_test=True)

    if args.use_cuda:
        place = fluid.CUDAPlace(0)
        dev_count = fluid.core.get_cuda_device_count()
    else:
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))

    exe = fluid.Executor(place)
    exe.run(predict_startup)

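    # Prediction needs a trained checkpoint; load its parameters into the program.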
    if args.init_checkpoint:
        init_pretraining_params(exe, args.init_checkpoint, predict_prog)
    else:
        raise ValueError("args 'init_checkpoint' should be set for prediction!")

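    # ParallelExecutor runs the prediction program on all visible devices.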
    predict_exe = fluid.ParallelExecutor(
        use_cuda=args.use_cuda, main_program=predict_prog)

    test_data_generator = processor.data_generator(
        batch_size=args.batch_size, phase='test', epoch=1, shuffle=False)
    predict_pyreader.decorate_tensor_provider(test_data_generator)

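    # Drain the pyreader batch by batch; EOFException marks the end of the test set.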
    predict_pyreader.start()
    all_results = []
    time_begin = time.time()
    while True:
        try:
            results = predict_exe.run(fetch_list=[probs.name])
            all_results.extend(results[0])
        except fluid.core.EOFException:
            predict_pyreader.reset()
            break
    time_end = time.time()

    np.set_printoptions(precision=4, suppress=True)
    print("-------------- prediction results --------------")
    print("example_id\t" + '  '.join(processor.get_labels()))
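    # Classification-style tasks (in_tokens=True) yield one prediction per example;
    # the slot-filling task decodes tag sequences from the collected probabilities.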
    if in_tokens[task_name]:
        for index, result in enumerate(all_results):
            tags = pred_func(result)
            print("%s\t%s" % (index, tags))
    else:
        tags = pred_func(all_results, args.max_seq_len)
        for index, tag in enumerate(tags):
            print("%s\t%s" % (index, tag))

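    # Optionally export an inference model so the predictor can be deployed
    # without the training-time code.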
    if args.save_inference_model_path:
        _, ckpt_dir = os.path.split(args.init_checkpoint)
        dir_name = ckpt_dir + '_inference_model'
        model_path = os.path.join(args.save_inference_model_path, dir_name)
        fluid.io.save_inference_model(
            model_path,
            feed_target_names, [probs],
            exe,
            main_program=predict_prog)


if __name__ == '__main__':
    args = parser.parse_args()
    print_arguments(args)

    check_cuda(args.use_cuda)

    main(args)