提交 a4a5704a 编写于 作者: M MRXLT

remove bert_rpc_client

上级 7abe8a1c
# coding:utf-8
# pylint: disable=doc-string-missing
import os
import sys
import numpy as np
import paddlehub as hub
import ujson
import random
import time
from paddlehub.common.logger import logger
import socket
from paddle_serving_client import Client
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args
from bert_reader import BertReader
args = benchmark_args()
_ver = sys.version_info
is_py2 = (_ver[0] == 2)
is_py3 = (_ver[0] == 3)
if is_py2:
import httplib
if is_py3:
import http.client as httplib
class BertService():
def __init__(self,
max_seq_len=128,
model_name="bert_uncased_L-12_H-768_A-12",
show_ids=False,
do_lower_case=True,
process_id=0,
retry=3):
self.process_id = process_id
self.reader_flag = False
self.batch_size = 0
self.max_seq_len = max_seq_len
self.model_name = model_name
self.show_ids = show_ids
self.do_lower_case = do_lower_case
self.retry = retry
self.pid = os.getpid()
self.profile = True if ("FLAGS_profile_client" in os.environ and
os.environ["FLAGS_profile_client"]) else False
self.reader = BertReader(vocab_file="vocab.txt", max_seq_len=20)
self.reader_flag = True
def load_client(self, config_file, server_addr):
self.client = Client()
self.client.load_client_config(config_file)
self.client.connect(server_addr)
def run_general(self, text, fetch):
result = []
prepro_start = time.time()
feed = self.reader.process(text)
if self.profile:
print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
self.pid,
int(round(prepro_start * 1000000)),
int(round(prepro_end * 1000000))))
fetch_map = self.client.predict(feed=feed, fetch=fetch)
return fetch_map
def run_batch_general(self, text, fetch):
self.batch_size = len(text)
result = []
prepro_start = time.time()
feed_batch = []
for si in range(self.batch_size):
feed = self.reader.process(text[si])
feed_batch.append(feed)
prepro_end = time.time()
if self.profile:
print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
self.pid,
int(round(prepro_start * 1000000)),
int(round(prepro_end * 1000000))))
fetch_map_batch = self.client.batch_predict(
feed_batch=feed_batch, fetch=fetch)
return fetch_map_batch
def single_func(idx, resource):
bc = BertService(
model_name='bert_chinese_L-12_H-768_A-12',
max_seq_len=20,
show_ids=False,
do_lower_case=True)
config_file = './serving_client_conf/serving_client_conf.prototxt'
fetch = ["pooled_output"]
server_addr = [resource["endpoint"][idx % len(resource["endpoint"])]]
bc.load_client(config_file, server_addr)
batch_size = 1
use_batch = False if batch_size == 1 else True
feed_batch = []
start = time.time()
fin = open("data-c.txt")
for line in fin:
if not use_batch:
result = bc.run_general(line.strip(), fetch)
else:
if len(feed_batch) == batch_size:
result = bc.run_batch_general(feed_batch, fetch)
feed_batch = []
else:
feed_batch.append(line.strip())
if use_batch and len(feed_batch) > 0:
result = bc.run_batch_general(feed_batch, fetch)
feed_batch = []
end = time.time()
return [[end - start]]
if __name__ == '__main__':
multi_thread_runner = MultiThreadRunner()
result = multi_thread_runner.run(single_func, args.thread,
{"endpoint": ["127.0.0.1:9292"]})
print("time cost for each thread {}".format(result))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册