bert_client.py 5.6 KB
Newer Older
M
MRXLT 已提交
1
# coding:utf-8
M
MRXLT 已提交
2
import os
M
MRXLT 已提交
3 4 5 6 7
import sys
import numpy as np
import paddlehub as hub
import ujson
import random
M
MRXLT 已提交
8
import time
M
MRXLT 已提交
9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
from paddlehub.common.logger import logger
import socket
from paddle_serving_client import Client

# Major interpreter version; the client supports both Python 2 and Python 3.
_ver = sys.version_info
is_py2 = _ver[0] == 2
is_py3 = _ver[0] == 3

# The stdlib HTTP client module was renamed in Python 3; expose it under the
# Python 2 name `httplib` on either interpreter.
if is_py2:
    import httplib
elif is_py3:
    import http.client as httplib


class BertService(object):
    """Tokenize text with a PaddleHub BERT reader and query a Paddle Serving server.

    Usage: construct, call ``load_client`` with the serving client config and
    server address(es), then call ``run_general`` (one sample) or
    ``run_batch_general`` (several samples) with the output names to fetch.
    """

    # Feed variable names, in the order the reader's batch arrays are read
    # (batch[0][0] .. batch[0][3]).
    _FEED_NAMES = ("input_ids", "position_ids", "segment_ids", "input_mask")

    def __init__(self,
                 max_seq_len=128,
                 model_name="bert_uncased_L-12_H-768_A-12",
                 show_ids=False,
                 do_lower_case=True,
                 process_id=0,
                 retry=3):
        """Load the PaddleHub module and build the text reader.

        Args:
            max_seq_len: sequence length given to the module context and the
                reader; also the per-sample slice width used by
                ``run_batch_general``.
            model_name: PaddleHub module name to load.
            show_ids: stored on the instance; not read within this class.
            do_lower_case: forwarded to the reader's tokenizer.
            process_id: stored on the instance; not read within this class.
            retry: stored on the instance; not read within this class.
        """
        self.process_id = process_id
        self.reader_flag = False
        self.batch_size = 0
        self.max_seq_len = max_seq_len
        self.model_name = model_name
        self.show_ids = show_ids
        self.do_lower_case = do_lower_case
        self.retry = retry
        self.pid = os.getpid()
        # Any non-empty value of FLAGS_profile_client turns on PROFILE output.
        self.profile = bool(os.environ.get("FLAGS_profile_client"))

        module = hub.Module(name=self.model_name)
        # NOTE(review): the (inputs, outputs, program) this returns are unused
        # by this client; the call is kept in case the module needs it before
        # get_vocab_path() — confirm before removing.
        module.context(trainable=True, max_seq_len=self.max_seq_len)
        self.reader = hub.reader.ClassifyReader(
            vocab_path=module.get_vocab_path(),
            dataset=None,
            max_seq_len=self.max_seq_len,
            do_lower_case=self.do_lower_case)
        self.reader_flag = True

    def load_client(self, config_file, server_addr):
        """Create the serving client, load its config and connect.

        Args:
            config_file: path to the serving client config (prototxt).
            server_addr: server endpoint(s), e.g. ``["127.0.0.1:9292"]``.
        """
        self.client = Client()
        self.client.load_client_config(config_file)
        self.client.connect(server_addr)

    def _batch_features(self, batch):
        """Flatten one reader batch into ``{feed_name: list_of_values}``.

        Assumes ``batch[0]`` holds the input-id, position-id, segment-id and
        input-mask arrays in that order (the mapping this client has always
        used) — TODO confirm against the reader version in use.
        """
        fields = batch[0]
        return {
            name: fields[i].reshape(-1).tolist()
            for i, name in enumerate(self._FEED_NAMES)
        }

    def _print_profile(self, prepro_start, prepro_end):
        """Print one PROFILE line (times in microseconds) when profiling is on."""
        if self.profile:
            print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
                self.pid,
                int(round(prepro_start * 1000000)),
                int(round(prepro_end * 1000000))))

    def run_general(self, text, fetch):
        """Run inference for a single sample and return its fetch map.

        Args:
            text: non-empty list of samples in the reader's input format,
                e.g. ``[[sentence]]``. Intended for exactly one sample; the
                flattened features of all samples go into one request, so use
                ``run_batch_general`` for several samples.
            fetch: list of output variable names to fetch.

        Returns:
            The result of ``client.predict`` for the last batch the reader
            produced (None only if the reader produced no batch).

        Raises:
            ValueError: if ``text`` is empty.
        """
        if not text:
            raise ValueError("text must contain at least one sample")
        self.batch_size = len(text)
        data_generator = self.reader.data_generator(
            batch_size=self.batch_size, phase='predict', data=text)
        fetch_map = None
        prepro_start = time.time()
        for batch in data_generator():
            feed = self._batch_features(batch)
            self._print_profile(prepro_start, time.time())
            # The feed does not depend on any sample index, so one request per
            # batch is enough; repeating it per sample only duplicated work.
            fetch_map = self.client.predict(feed=feed, fetch=fetch)
        return fetch_map

    def run_batch_general(self, text, fetch):
        """Run batched inference, one feed dict per sample.

        Args:
            text: non-empty list of samples in the reader's input format,
                e.g. ``[[sentence1], [sentence2]]``.
            fetch: list of output variable names to fetch.

        Returns:
            The result of ``client.batch_predict`` for the last batch the
            reader produced (the reader is asked for one batch of
            ``len(text)`` samples).

        Raises:
            ValueError: if ``text`` is empty.
        """
        if not text:
            raise ValueError("text must contain at least one sample")
        self.batch_size = len(text)
        data_generator = self.reader.data_generator(
            batch_size=self.batch_size, phase='predict', data=text)
        seq_len = self.max_seq_len
        fetch_map_batch = None
        prepro_start = time.time()
        for batch in data_generator():
            features = self._batch_features(batch)
            # Assumes every sample occupies exactly max_seq_len flattened
            # entries (reader pads to max_seq_len) — TODO confirm.
            feed_batch = [{
                name: values[si * seq_len:(si + 1) * seq_len]
                for name, values in features.items()
            } for si in range(self.batch_size)]
            self._print_profile(prepro_start, time.time())
            fetch_map_batch = self.client.batch_predict(
                feed_batch=feed_batch, fetch=fetch)
        return fetch_map_batch


def test(batch_size=1):
    """Read sentences from stdin, query the local BERT server, print results.

    Connects to ``127.0.0.1:9292`` using the client config under
    ``./serving_client_conf`` and fetches ``pooled_output``.

    Args:
        batch_size: number of stdin lines per request. With 1, each line is
            sent through ``run_general``. Otherwise lines are grouped and sent
            through ``run_batch_general``, with a final partial batch flushed at
            EOF.
    """
    bc = BertService(
        model_name='bert_chinese_L-12_H-768_A-12',
        max_seq_len=20,
        show_ids=False,
        do_lower_case=True)
    server_addr = ["127.0.0.1:9292"]
    config_file = './serving_client_conf/serving_client_conf.prototxt'
    fetch = ["pooled_output"]
    bc.load_client(config_file, server_addr)
    batch = []
    for line in sys.stdin:
        if batch_size == 1:
            result = bc.run_general([[line.strip()]], fetch)
            print(result)
            continue
        batch.append([line.strip()])
        # Flush as soon as the batch is full so no input line is skipped.
        if len(batch) >= batch_size:
            for r in bc.run_batch_general(batch, fetch):
                print(r)
            batch = []
    # Flush the final, partially filled batch.
    if batch:
        for r in bc.run_batch_general(batch, fetch):
            print(r)


if __name__ == "__main__":
    # Script entry point: run the stdin-driven demo client.
    test()