#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""BERT (PaddlePaddle) model wrapper"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import json
import collections
import multiprocessing
import argparse
import numpy as np
import paddle.fluid as fluid
from task_reader.mrqa import DataProcessor, get_answers

# Fixed inference-time settings; they are expected to match the configuration
# used when the served model was exported.
ema_decay = 0.9999
verbose = False
max_seq_len = 512
max_query_length = 64
max_answer_length = 30
in_tokens = False
do_lower_case = False
doc_stride = 128
n_best_size = 20
use_cuda = True

class BertModelWrapper(object):
    """
    Wrap a BERT (PaddlePaddle) MRC model.
    The basic processes include input checking, preprocessing, running the
    inference program with a Paddle executor, and postprocessing.
    """
    def __init__(self, model_dir):
        """Set up the executor, the data processor and the inference program from model_dir."""
        if use_cuda:
            place = fluid.CUDAPlace(0)
            dev_count = fluid.core.get_cuda_device_count()
        else:
            place = fluid.CPUPlace()
            dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
        self.exe = fluid.Executor(place)

        self.bert_preprocessor = DataProcessor(
            vocab_path=os.path.join(model_dir, 'vocab.txt'),
            do_lower_case=do_lower_case,
            max_seq_length=max_seq_len,
            in_tokens=in_tokens,
            doc_stride=doc_stride,
            max_query_length=max_query_length)

        # load_inference_model restores the saved inference program, the names
        # of its feed variables, and its fetch variables.
        self.inference_program, self.feed_target_names, self.fetch_targets = \
            fluid.io.load_inference_model(dirname=model_dir, executor=self.exe)

    def preprocessor(self, samples, batch_size, examples_start_id, features_start_id):
        """Preprocess the input samples: tokenization, paragraph striding/padding and token-to-id conversion."""
        examples, features, batch = self.bert_preprocessor.data_generator(
            samples, batch_size, max_len=max_seq_len,
            examples_start_id=examples_start_id,
            features_start_id=features_start_id)
        self.samples = samples
        return examples, features, batch

    def call_mrc(self, batch, squeeze_dim0=False, return_list=False):
        """Run the MRC inference program on one batch and return start/end logits."""
        if squeeze_dim0 and return_list:
            raise ValueError("squeeze_dim0 only works for dict-type return values.")
        # Fields produced by the data generator, in feeding order
        src_ids, pos_ids, sent_ids, input_mask, unique_id = batch[:5]
        feed_dict = {
            self.feed_target_names[0]: src_ids,
            self.feed_target_names[1]: pos_ids,
            self.feed_target_names[2]: sent_ids,
            self.feed_target_names[3]: input_mask,
            self.feed_target_names[4]: unique_id
        }
        
        # The fetch targets are expected to be (unique_ids, start_logits,
        # end_logits, num_seqs); num_seqs is not used further in this wrapper.
        np_unique_ids, np_start_logits, np_end_logits, np_num_seqs = \
            self.exe.run(self.inference_program, feed=feed_dict, fetch_list=self.fetch_targets)

        # For a single-example batch, optionally drop the leading batch dimension
        if len(np_unique_ids) == 1 and squeeze_dim0:
            np_unique_ids = np_unique_ids[0]
            np_start_logits = np_start_logits[0]
            np_end_logits = np_end_logits[0]

        if return_list:
            mrc_results = [{'unique_ids': uid, 'start_logits': start, 'end_logits': end}
                           for uid, start, end in zip(np_unique_ids, np_start_logits, np_end_logits)]
        else:
            mrc_results = {
                'unique_ids': np_unique_ids,
                'start_logits': np_start_logits,
                'end_logits': np_end_logits,
            }
        return mrc_results

    def postprocessor(self, examples, features, mrc_results):
        """Extract answers.
         examples, features: outputs of preprocessor
         mrc_results: model outputs from call_mrc; if mrc_results is a list, each element is a batch of size 1.
        """
        RawResult = collections.namedtuple("RawResult",
                                           ["unique_id", "start_logits", "end_logits"])
        results = []
        if isinstance(mrc_results, list):
            for res in mrc_results:
                unique_id = int(res['unique_ids'][0])
                start_logits = [float(x) for x in res['start_logits'].flat]
                end_logits = [float(x) for x in res['end_logits'].flat]
                results.append(
                    RawResult(
                        unique_id=unique_id,
                        start_logits=start_logits,
                        end_logits=end_logits))
        else:
            assert isinstance(mrc_results, dict)
            for idx in range(mrc_results['unique_ids'].shape[0]):
                unique_id = int(mrc_results['unique_ids'][idx])
                start_logits = [float(x) for x in mrc_results['start_logits'][idx].flat]
                end_logits = [float(x) for x in mrc_results['end_logits'][idx].flat]
                results.append(
                    RawResult(
                        unique_id=unique_id,
                        start_logits=start_logits,
                        end_logits=end_logits))
        
        answers = get_answers(
            examples, features, results, n_best_size,
            max_answer_length, do_lower_case, verbose)
        return answers
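

if __name__ == '__main__':
    # Minimal usage sketch, not part of the serving code above. It assumes
    # `model_dir` holds an exported Paddle inference model plus vocab.txt, and
    # that `samples` are MRQA-format samples accepted by DataProcessor; the
    # empty list and the argument values below are only placeholders to
    # illustrate the call sequence.
    import sys

    model_dir = sys.argv[1] if len(sys.argv) > 1 else './infer_model'
    wrapper = BertModelWrapper(model_dir)

    samples = []  # fill with MRQA-format samples before running for real
    examples, features, batch = wrapper.preprocessor(
        samples, batch_size=8, examples_start_id=0, features_start_id=1000)
    mrc_results = wrapper.call_mrc(batch, return_list=True)
    answers = wrapper.postprocessor(examples, features, mrc_results)
    print(json.dumps(answers, ensure_ascii=False))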