bert_rpc_client.py 3.9 KB
Newer Older
M
MRXLT 已提交
1
# coding:utf-8
B
barrierye 已提交
2
# pylint: disable=doc-string-missing
M
MRXLT 已提交
3
import os
M
MRXLT 已提交
4 5 6 7 8
import sys
import numpy as np
import paddlehub as hub
import ujson
import random
M
MRXLT 已提交
9
import time
M
MRXLT 已提交
10 11 12
from paddlehub.common.logger import logger
import socket
from paddle_serving_client import Client
G
guru4elephant 已提交
13 14
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args
M
MRXLT 已提交
15 16
from bert_reader import BertReader

G
guru4elephant 已提交
17
# Benchmark options, built once at import time by paddle_serving_client's
# benchmark_args() (presumably from command-line flags). Only ``args.thread``
# is read in this file (by the __main__ entry point).
args = benchmark_args()

# Python 2/3 compatibility flags, derived from the interpreter's major version.
_ver = sys.version_info
is_py2 = _ver[0] == 2
is_py3 = _ver[0] == 3

# Bind ``httplib`` to the version-appropriate HTTP client module. The two
# flags are mutually exclusive, so at most one branch runs.
# NOTE(review): ``httplib`` is not referenced anywhere else in this file.
if is_py2:
    import httplib
elif is_py3:
    import http.client as httplib


class BertService():
    """Client wrapper that preprocesses text with ``BertReader`` and sends it
    to a Paddle Serving server through ``paddle_serving_client.Client``.

    Call ``load_client`` before ``run_general`` / ``run_batch_general``.
    When the ``FLAGS_profile_client`` environment variable is set to any
    non-empty string, each request prints a ``PROFILE`` line with the
    preprocessing start/end timestamps in microseconds.
    """

    def __init__(self,
                 max_seq_len=128,
                 model_name="bert_uncased_L-12_H-768_A-12",
                 show_ids=False,
                 do_lower_case=True,
                 process_id=0,
                 retry=3):
        """Store configuration and build the text reader.

        ``max_seq_len``, ``model_name``, ``show_ids``, ``do_lower_case``,
        ``process_id`` and ``retry`` are stored as attributes but are not
        otherwise read within this class.
        """
        self.process_id = process_id
        self.reader_flag = False
        self.batch_size = 0
        self.max_seq_len = max_seq_len
        self.model_name = model_name
        self.show_ids = show_ids
        self.do_lower_case = do_lower_case
        self.retry = retry
        self.pid = os.getpid()
        # Any non-empty value enables profiling (note: "0" enables it too).
        self.profile = bool(os.environ.get("FLAGS_profile_client"))
        # NOTE(review): the reader uses a hard-coded max_seq_len=20 and
        # "vocab.txt" from the working directory; the max_seq_len argument is
        # NOT forwarded. Presumably 20 must match the served model's config --
        # confirm before wiring self.max_seq_len through.
        self.reader = BertReader(vocab_file="vocab.txt", max_seq_len=20)
        self.reader_flag = True

    def load_client(self, config_file, server_addr):
        """Create a serving ``Client``, load its config and connect.

        Args:
            config_file: path to the client config (prototxt).
            server_addr: list of "host:port" endpoint strings.
        """
        self.client = Client()
        self.client.load_client_config(config_file)
        self.client.connect(server_addr)

    def _log_prepro(self, prepro_start, prepro_end):
        """Print a PROFILE line with preprocessing start/end in microseconds."""
        print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
            self.pid,
            int(round(prepro_start * 1000000)),
            int(round(prepro_end * 1000000))))

    def run_general(self, text, fetch):
        """Preprocess one text and run a single prediction.

        Args:
            text: input string, passed to ``self.reader.process``.
            fetch: list of output variable names to fetch.

        Returns:
            The value returned by ``self.client.predict``.
        """
        prepro_start = time.time()
        feed = self.reader.process(text)
        # Both timestamps are needed by the profile line below.
        prepro_end = time.time()
        if self.profile:
            self._log_prepro(prepro_start, prepro_end)
        return self.client.predict(feed=feed, fetch=fetch)

    def run_batch_general(self, text, fetch):
        """Preprocess a list of texts and run one batched prediction.

        Args:
            text: sequence of input strings; its length is recorded in
                ``self.batch_size``.
            fetch: list of output variable names to fetch.

        Returns:
            The value returned by ``self.client.batch_predict``.
        """
        self.batch_size = len(text)
        prepro_start = time.time()
        feed_batch = [self.reader.process(t) for t in text]
        prepro_end = time.time()
        if self.profile:
            self._log_prepro(prepro_start, prepro_end)
        return self.client.batch_predict(feed_batch=feed_batch, fetch=fetch)


G
guru4elephant 已提交
88
def single_func(idx, resource):
    """Benchmark worker: send every line of ``data-c.txt`` to the server.

    Args:
        idx: worker index; picks an endpoint round-robin from
            ``resource["endpoint"]``.
        resource: dict holding an ``"endpoint"`` list of "host:port" strings.

    Returns:
        ``[[elapsed_seconds]]`` -- wall-clock time spent reading the file and
        issuing requests (client construction and connection are excluded).
    """
    bc = BertService(
        model_name='bert_chinese_L-12_H-768_A-12',
        max_seq_len=20,
        show_ids=False,
        do_lower_case=True)
    config_file = './serving_client_conf/serving_client_conf.prototxt'
    fetch = ["pooled_output"]
    endpoints = resource["endpoint"]
    server_addr = [endpoints[idx % len(endpoints)]]
    bc.load_client(config_file, server_addr)
    batch_size = 1
    use_batch = batch_size != 1
    feed_batch = []
    start = time.time()
    with open("data-c.txt") as fin:
        for line in fin:
            text = line.strip()
            if not use_batch:
                bc.run_general(text, fetch)
                continue
            # Append before checking the size so every line is sent; flushing
            # first would skip the current line.
            feed_batch.append(text)
            if len(feed_batch) == batch_size:
                bc.run_batch_general(feed_batch, fetch)
                feed_batch = []
    # Send whatever is left in a final, possibly partial, batch.
    if feed_batch:
        bc.run_batch_general(feed_batch, fetch)
    end = time.time()
    return [[end - start]]
M
MRXLT 已提交
118

B
barrierye 已提交
119

M
MRXLT 已提交
120
if __name__ == '__main__':
    # Run single_func through MultiThreadRunner with ``args.thread`` workers
    # against a local server, then print what the runner returns (one
    # elapsed-time entry per thread, per the message below).
    server_resource = {"endpoint": ["127.0.0.1:9292"]}
    runner = MultiThreadRunner()
    thread_costs = runner.run(single_func, args.thread, server_resource)
    print("time cost for each thread {}".format(thread_costs))