# coding:utf-8
# pylint: disable=doc-string-missing
import os
import sys
import numpy as np
import paddlehub as hub
import ujson
import random
import time
from paddlehub.common.logger import logger
import socket
from paddle_serving_client import Client
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args
# Parse benchmark CLI flags (e.g. --thread) once at import time; `args` is
# read by the __main__ block below.
args = benchmark_args()

# Select the HTTP client implementation for the running interpreter:
# Python 2 ships `httplib`, which became `http.client` in Python 3.
_ver = sys.version_info
is_py2 = (_ver[0] == 2)
is_py3 = (_ver[0] == 3)

if is_py3:
    import http.client as httplib
elif is_py2:
    import httplib


class BertService():
    """Client-side helper that tokenizes raw text into BERT feeds and sends
    them to a Paddle Serving endpoint.

    Feed preparation (tokenization, padding) runs locally through a PaddleHub
    ``ClassifyReader``; the forward pass runs on the remote server via
    ``paddle_serving_client.Client``.
    """

    def __init__(self,
                 max_seq_len=128,
                 model_name="bert_uncased_L-12_H-768_A-12",
                 show_ids=False,
                 do_lower_case=True,
                 process_id=0,
                 retry=3):
        """Load the PaddleHub BERT module and build the tokenizing reader.

        Args:
            max_seq_len: maximum token sequence length per sample.
            model_name: PaddleHub BERT module supplying vocab/config.
            show_ids: stored for interface compatibility; not read here.
            do_lower_case: lowercase text before tokenizing.
            process_id: caller-assigned id, stored for bookkeeping.
            retry: stored for interface compatibility; not read here.
        """
        self.process_id = process_id
        self.reader_flag = False
        self.batch_size = 0
        self.max_seq_len = max_seq_len
        self.model_name = model_name
        self.show_ids = show_ids
        self.do_lower_case = do_lower_case
        self.retry = retry
        self.pid = os.getpid()
        # NOTE: only the string's truthiness is checked, so any non-empty
        # value (including "0") enables profiling.
        self.profile = True if ("FLAGS_profile_client" in os.environ and
                                os.environ["FLAGS_profile_client"]) else False

        module = hub.Module(name=self.model_name)
        # context() materializes the module; the returned inputs/outputs/
        # program are not needed on the client side.
        module.context(trainable=True, max_seq_len=self.max_seq_len)
        self.reader = hub.reader.ClassifyReader(
            vocab_path=module.get_vocab_path(),
            dataset=None,
            max_seq_len=self.max_seq_len,
            do_lower_case=self.do_lower_case)
        self.reader_flag = True

    def load_client(self, config_file, server_addr):
        """Create a serving Client from `config_file` and connect it to
        the endpoint list `server_addr`."""
        self.client = Client()
        self.client.load_client_config(config_file)
        self.client.connect(server_addr)

    def _profile_preprocess(self, start, end):
        """Emit a parseable preprocess timing line when profiling is on.

        Timestamps are printed in microseconds to match the serving
        profiler's log format.
        """
        if self.profile:
            print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
                self.pid,
                int(round(start * 1000000)), int(round(end * 1000000))))

    def run_general(self, text, fetch):
        """Predict the samples in `text` one request at a time.

        Args:
            text: list of samples, each itself a list of text fields.
            fetch: output variable names to fetch from the server.

        Returns:
            The fetch map of the last predicted sample, or None when the
            reader yields no batches (e.g. empty input).
        """
        self.batch_size = len(text)
        data_generator = self.reader.data_generator(
            batch_size=self.batch_size, phase='predict', data=text)
        fetch_map = None
        prepro_start = time.time()
        for batch in data_generator():
            token_list = batch[0][0].reshape(-1).tolist()
            pos_list = batch[0][1].reshape(-1).tolist()
            sent_list = batch[0][2].reshape(-1).tolist()
            mask_list = batch[0][3].reshape(-1).tolist()
            for si in range(self.batch_size):
                # BUG FIX: slice out sample `si` instead of feeding the whole
                # flattened batch for every sample (mirrors
                # run_batch_general; identical behavior for batch size 1).
                lo, hi = si * self.max_seq_len, (si + 1) * self.max_seq_len
                feed = {
                    "input_ids": token_list[lo:hi],
                    "position_ids": pos_list[lo:hi],
                    "segment_ids": sent_list[lo:hi],
                    "input_mask": mask_list[lo:hi]
                }
                prepro_end = time.time()
                self._profile_preprocess(prepro_start, prepro_end)
                fetch_map = self.client.predict(feed=feed, fetch=fetch)
        return fetch_map

    def run_batch_general(self, text, fetch):
        """Predict all samples in `text` as one batched request.

        Args:
            text: list of samples, each itself a list of text fields.
            fetch: output variable names to fetch from the server.

        Returns:
            The batched fetch map from the last generator step, or None
            when the reader yields no batches (e.g. empty input).
        """
        self.batch_size = len(text)
        data_generator = self.reader.data_generator(
            batch_size=self.batch_size, phase='predict', data=text)
        fetch_map_batch = None
        prepro_start = time.time()
        for batch in data_generator():
            token_list = batch[0][0].reshape(-1).tolist()
            pos_list = batch[0][1].reshape(-1).tolist()
            sent_list = batch[0][2].reshape(-1).tolist()
            mask_list = batch[0][3].reshape(-1).tolist()
            feed_batch = []
            for si in range(self.batch_size):
                lo, hi = si * self.max_seq_len, (si + 1) * self.max_seq_len
                feed_batch.append({
                    "input_ids": token_list[lo:hi],
                    "position_ids": pos_list[lo:hi],
                    "segment_ids": sent_list[lo:hi],
                    "input_mask": mask_list[lo:hi]
                })
            prepro_end = time.time()
            self._profile_preprocess(prepro_start, prepro_end)
            fetch_map_batch = self.client.batch_predict(
                feed_batch=feed_batch, fetch=fetch)
        return fetch_map_batch


def single_func(idx, resource):
    """Benchmark worker: run single-sample predictions over data-c.txt.

    Args:
        idx: worker index, used to pick this worker's endpoint.
        resource: dict with an "endpoint" list of "host:port" strings.

    Returns:
        [[elapsed_seconds]] in the shape MultiThreadRunner expects.
    """
    bc = BertService(
        model_name='bert_chinese_L-12_H-768_A-12',
        max_seq_len=20,
        show_ids=False,
        do_lower_case=True)
    config_file = './serving_client_conf/serving_client_conf.prototxt'
    fetch = ["pooled_output"]
    server_addr = [resource["endpoint"][idx]]
    bc.load_client(config_file, server_addr)
    start = time.time()
    # FIX: close the input file deterministically (it was left open before).
    with open("data-c.txt") as fin:
        for line in fin:
            bc.run_general([[line.strip()]], fetch)
    end = time.time()
    return [[end - start]]


if __name__ == '__main__':
    # Fan the benchmark out over `args.thread` workers; each worker picks
    # its endpoint by index inside single_func.
    endpoints = [
        "127.0.0.1:9494", "127.0.0.1:9495", "127.0.0.1:9496",
        "127.0.0.1:9497"
    ]
    runner = MultiThreadRunner()
    result = runner.run(single_func, args.thread, {"endpoint": endpoints})