未验证 提交 d48b1d96 编写于 作者: B Bin Long 提交者: GitHub

Merge pull request #270 from ShenYuhan/bert_as_service

fix bs multiprocessing problem
......@@ -179,18 +179,22 @@ Server[baidu::paddle_serving::predictor::bert_service::BertServiceImpl] is servi
首先导入客户端依赖。
```python
from paddlehub.serving.bert_serving import bert_service
from paddlehub.serving.bert_serving import bs_client
```
接着输入文本信息。
接着启动并初始化`bert service`客户端`BSClient`(这里的server为虚拟地址,需根据自己实际ip设置)
```python
bc = bs_client.BSClient(module_name="ernie_tiny", server="127.0.0.1:8866")
```
然后输入文本信息。
```python
input_text = [["西风吹老洞庭波"], ["一夜湘君白发多"], ["醉后不知天在水"], ["满船清梦压星河"], ]
```
然后利用客户端接口发送文本到服务端,以获取embedding结果(server为虚拟地址,需根据自己实际ip设置)。
最后利用客户端接口`get_result`发送文本到服务端,以获取embedding结果。
```python
result = bert_service.connect(
input_text=input_text,
model_name="ernie_tiny",
server="127.0.0.1:8866")
result = bc.get_result(input_text=input_text)
```
最后即可得到embedding结果(此处只展示部分结果)。
```python
......
# coding: utf8
from paddlehub.serving.bert_serving import bert_service
from paddlehub.serving.bert_serving import bs_client
if __name__ == "__main__":
# 初始化bert_service客户端BSClient
bc = bs_client.BSClient(module_name="ernie_tiny", server="127.0.0.1:8866")
# 输入要做embedding的文本
# 文本格式为[["文本1"], ["文本2"], ]
input_text = [
......@@ -10,10 +13,10 @@ if __name__ == "__main__":
["醉后不知天在水"],
["满船清梦压星河"],
]
# 调用客户端接口bert_service.connect()获取结果
result = bert_service.connect(
input_text=input_text, model_name="ernie_tiny", server="127.0.0.1:8866")
# 打印embedding结果
# BSClient.get_result()获取结果
result = bc.get_result(input_text=input_text)
# 打印输入文本的embedding结果
for item in result:
print(item)
......@@ -14,7 +14,6 @@
# limitations under the License.
import sys
import time
import paddlehub as hub
import ujson
import random
......@@ -30,7 +29,7 @@ if is_py3:
import http.client as httplib
class BertService():
class BertService(object):
def __init__(self,
profile=False,
max_seq_len=128,
......@@ -42,7 +41,7 @@ class BertService():
load_balance='round_robin'):
self.process_id = process_id
self.reader_flag = False
self.batch_size = 16
self.batch_size = 0
self.max_seq_len = max_seq_len
self.profile = profile
self.model_name = model_name
......@@ -55,15 +54,6 @@ class BertService():
self.feed_var_names = ''
self.retry = retry
def connect(self, server='127.0.0.1:8010'):
self.server_list.append(server)
def connect_all_server(self, server_list):
for server_str in server_list:
self.server_list.append(server_str)
def data_convert(self, text):
if self.reader_flag == False:
module = hub.Module(name=self.model_name)
inputs, outputs, program = module.context(
trainable=True, max_seq_len=self.max_seq_len)
......@@ -79,10 +69,14 @@ class BertService():
do_lower_case=self.do_lower_case)
self.reader_flag = True
return self.reader.data_generator(
batch_size=self.batch_size, phase='predict', data=text)
def add_server(self, server='127.0.0.1:8010'):
self.server_list.append(server)
def infer(self, request_msg):
def add_server_list(self, server_list):
for server_str in server_list:
self.server_list.append(server_str)
def request_server(self, request_msg):
if self.load_balance == 'round_robin':
try:
cur_con = httplib.HTTPConnection(
......@@ -157,17 +151,13 @@ class BertService():
self.server_list)
return 'retry'
def encode(self, text):
if type(text) != list:
raise TypeError('Only support list')
def prepare_data(self, text):
self.batch_size = len(text)
data_generator = self.data_convert(text)
start = time.time()
request_time = 0
result = []
data_generator = self.reader.data_generator(
batch_size=self.batch_size, phase='predict', data=text)
request_msg = ""
for run_step, batch in enumerate(data_generator(), start=1):
request = []
copy_start = time.time()
token_list = batch[0][0].reshape(-1).tolist()
pos_list = batch[0][1].reshape(-1).tolist()
sent_list = batch[0][2].reshape(-1).tolist()
......@@ -184,54 +174,34 @@ class BertService():
si + 1) * self.max_seq_len]
request.append(instance_dict)
copy_time = time.time() - copy_start
request = {"instances": request}
request["max_seq_len"] = self.max_seq_len
request["feed_var_names"] = self.feed_var_names
request_msg = ujson.dumps(request)
if self.show_ids:
logger.info(request_msg)
request_start = time.time()
response_msg = self.infer(request_msg)
return request_msg
def encode(self, text):
if type(text) != list:
raise TypeError('Only support list')
request_msg = self.prepare_data(text)
response_msg = self.request_server(request_msg)
retry = 0
while type(response_msg) == str and response_msg == 'retry':
if retry < self.retry:
retry += 1
logger.info('Try to connect another servers')
response_msg = self.infer(request_msg)
response_msg = self.request_server(request_msg)
else:
logger.error('Infer failed after {} times retry'.format(
logger.error('Request failed after {} times retry'.format(
self.retry))
break
result = []
for msg in response_msg["instances"]:
for sample in msg["instances"]:
result.append(sample["values"])
request_time += time.time() - request_start
total_time = time.time() - start
if self.profile:
return [
total_time, request_time, response_msg['op_time'],
response_msg['infer_time'], copy_time
]
else:
return result
def connect(input_text,
model_name,
max_seq_len=128,
show_ids=False,
do_lower_case=True,
server="127.0.0.1:8866",
retry=3):
# format of input_text like [["As long as"],]
bc = BertService(
max_seq_len=max_seq_len,
model_name=model_name,
show_ids=show_ids,
do_lower_case=do_lower_case,
retry=retry)
bc.connect(server)
result = bc.encode(input_text)
return result
from paddlehub.serving.bert_serving import bert_service
class BSClient(object):
def __init__(self,
module_name,
server,
max_seq_len=20,
show_ids=False,
do_lower_case=True,
retry=3):
self.bs = bert_service.BertService(
model_name=module_name,
max_seq_len=max_seq_len,
show_ids=show_ids,
do_lower_case=do_lower_case,
retry=retry)
self.bs.add_server(server=server)
def get_result(self, input_text):
return self.bs.encode(input_text)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册