#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""BERT (PaddlePaddle) model wrapper"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import collections
import multiprocessing

import paddle.fluid as fluid

from squad_reader import DataProcessor, get_answers
from model.xlnet import XLNetConfig


conf_dir = "xlnet_config"
bert_config_path = conf_dir+'/xlnet_config.json'
spiece_model_file = conf_dir+'/spiece.model'
ema_decay = 0.9999
verbose = False
vocab_path = conf_dir+'/vocab.txt'
max_seq_len = 800
max_query_length = 64
max_answer_length = 30
in_tokens = False
do_lower_case = False
doc_stride = 128
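# Answer-decoding hyperparameters: keep the n_best_size best candidate spans,
# scoring the start_n_top most likely start positions against the end_n_top
# most likely end positions (XLNet-style span extraction).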
n_best_size = 20
start_n_top = 5
end_n_top = 5
use_cuda = True


class BertModelWrapper():
    """
    Wrap a tnet model
     the basic processes include input checking, preprocessing, calling tf-serving
     and postprocessing
    """
    def __init__(self, model_dir):
        """ """
        xlnet_config = XLNetConfig(bert_config_path)
        xlnet_config.print_config()

        if use_cuda:
            place = fluid.CUDAPlace(0)
            dev_count = fluid.core.get_cuda_device_count()
        else:
            place = fluid.CPUPlace()
            dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
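        # NOTE: dev_count is informational only; inference below runs on a
        # single place with a plain Executor.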
        self.exe = fluid.Executor(place)

        self.processor = DataProcessor(
            spiece_model_file=spiece_model_file,
            uncased=do_lower_case,
            max_seq_length=max_seq_len,
            doc_stride=doc_stride,
            max_query_length=max_query_length)
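        # The processor splits long passages into overlapping windows of up
        # to max_seq_length tokens, advancing doc_stride tokens per window,
        # and truncates questions to max_query_length tokens.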

        self.inference_program, self.feed_target_names, self.fetch_targets = \
            fluid.io.load_inference_model(dirname=model_dir, executor=self.exe)
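        # load_inference_model returns the deserialized inference program
        # together with the names of its feed variables and the fetch targets
        # to run, in the order they were persisted at save time.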

        # self.inference_program = fluid.compiler.CompiledProgram(self.inference_program)
        # self.exe = fluid.ParallelExecutor(
        #     use_cuda=use_cuda,
        #     main_program=self.inference_program)

    def preprocessor(self, samples, batch_size):
        """Preprocess the input samples, including word seg, padding, token to ids"""
        # Tokenization and paragraph padding
        examples, features, batch = self.processor.data_generator(
            samples, batch_size)
        self.samples = samples
        return examples, features, batch

    def call_mrc(self, batch, squeeze_dim0=False, return_list=False):
        """MRC"""
        if squeeze_dim0 and return_list:
            raise ValueError("squeeze_dim0 only work for dict-type return value.")
        src_ids = batch[0]
        pos_ids = batch[1]
        sent_ids = batch[2]
        input_mask = batch[3]
        unique_id = batch[4]
        # The sixth feed tensor's meaning depends on the exported model
        # (e.g. cls_index or p_mask in XLNet SQuAD pipelines); it is
        # forwarded unchanged.
        extra_input = batch[5]
        feed_dict = {
            self.feed_target_names[0]: src_ids,
            self.feed_target_names[1]: pos_ids,
            self.feed_target_names[2]: sent_ids,
            self.feed_target_names[3]: input_mask,
            self.feed_target_names[4]: unique_id,
            self.feed_target_names[5]: extra_input
        }
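        # Assumed fetch-target order (as persisted with the inference model):
        # unique_ids, start_top_log_probs, start_top_index,
        # end_top_log_probs, end_top_index, cls_logits.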
        
        (np_unique_ids, np_start_logits, np_start_top_index,
         np_end_logits, np_end_top_index, np_cls_logits) = self.exe.run(
            self.inference_program, feed=feed_dict,
            fetch_list=self.fetch_targets, use_program_cache=True)

        # np_unique_ids, np_start_logits, np_end_logits, np_num_seqs = \
        #     self.exe.run(feed=feed_dict, fetch_list=self.fetch_targets)

        if len(np_unique_ids) == 1 and squeeze_dim0:
            np_unique_ids = np_unique_ids[0]
            np_start_logits = np_start_logits[0]
            np_end_logits = np_end_logits[0]

        if return_list:
            # One dict per example, zipped in fetch-target order.
            mrc_results = [
                {'unique_ids': uid, 'start_logits': st, 'start_idx': st_idx,
                 'end_logits': end, 'end_idx': end_idx, 'cls': cls}
                for uid, st, st_idx, end, end_idx, cls in zip(
                    np_unique_ids, np_start_logits, np_start_top_index,
                    np_end_logits, np_end_top_index, np_cls_logits)]
        else:
            # Dict-type return values (the squeeze_dim0 path) are not
            # supported yet.
            raise NotImplementedError()
        return mrc_results

    def postprocessor(self, examples, features, mrc_results):
        """Extract answer
         batch: [examples, features] from preprocessor
         mrc_results: model results from call_mrc. if mrc_results is list, each element of which is a size=1 batch.
        """
        RawResult = collections.namedtuple("RawResult",
                                            ["unique_id", "start_top_log_probs", "start_top_index",
                                            "end_top_log_probs", "end_top_index", "cls_logits"])
        results = []
        if isinstance(mrc_results, list):
            for res in mrc_results:
                unique_id = res['unique_ids'][0]
                start_logits = [float(x) for x in res['start_logits'].flat]
                start_idx = [int(x) for x in res['start_idx'].flat]
                end_logits = [float(x) for x in res['end_logits'].flat]
                end_idx = [int(x) for x in res['end_idx'].flat]
                cls_logits = float(res['cls'].flat[0])

                results.append(
                    RawResult(
                        unique_id=unique_id,
                        start_top_log_probs=start_logits,
                        start_top_index=start_idx,
                        end_top_log_probs=end_logits,
                        end_top_index=end_idx,
                        cls_logits=cls_logits))
        else:
            assert isinstance(mrc_results, dict)
            # Dict-type mrc_results are not supported: the XLNet-style
            # RawResult above needs the top-k start/end indices, which this
            # path never carried.
            raise NotImplementedError()
        
        answers = get_answers(
            examples, features, results, n_best_size,
            max_answer_length, start_n_top, end_n_top)
        return answers
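

if __name__ == '__main__':
    # Minimal usage sketch, assuming a hypothetical sample schema and model
    # directory; the real input format is whatever
    # squad_reader.DataProcessor.data_generator expects.
    wrapper = BertModelWrapper(model_dir='output/inference_model')
    samples = [{'context': 'PaddlePaddle is a deep learning platform.',
                'question': 'What is PaddlePaddle?'}]  # hypothetical schema
    examples, features, batch = wrapper.preprocessor(samples, batch_size=1)
    # NOTE: data_generator may yield multiple batches for long inputs; a
    # single feed batch is assumed here.
    mrc_results = wrapper.call_mrc(batch, return_list=True)
    answers = wrapper.postprocessor(examples, features, mrc_results)
    print(answers)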