未验证 提交 d48b1d96 编写于 作者: B Bin Long 提交者: GitHub

Merge pull request #270 from ShenYuhan/bert_as_service

fix bs multiprocessing problem
...@@ -179,18 +179,22 @@ Server[baidu::paddle_serving::predictor::bert_service::BertServiceImpl] is servi ...@@ -179,18 +179,22 @@ Server[baidu::paddle_serving::predictor::bert_service::BertServiceImpl] is servi
首先导入客户端依赖。 首先导入客户端依赖。
```python ```python
from paddlehub.serving.bert_serving import bert_service from paddlehub.serving.bert_serving import bs_client
``` ```
接着输入文本信息。
接着启动并初始化`bert service`客户端`BSClient`(这里的server为虚拟地址,需根据自己实际IP设置)。
```python
bc = bs_client.BSClient(module_name="ernie_tiny", server="127.0.0.1:8866")
```
然后输入文本信息。
```python ```python
input_text = [["西风吹老洞庭波"], ["一夜湘君白发多"], ["醉后不知天在水"], ["满船清梦压星河"], ] input_text = [["西风吹老洞庭波"], ["一夜湘君白发多"], ["醉后不知天在水"], ["满船清梦压星河"], ]
``` ```
然后利用客户端接口发送文本到服务端,以获取embedding结果(server为虚拟地址,需根据自己实际ip设置)。
随后利用客户端接口`get_result`发送文本到服务端,以获取embedding结果。
```python ```python
result = bert_service.connect( result = bc.get_result(input_text=input_text)
input_text=input_text,
model_name="ernie_tiny",
server="127.0.0.1:8866")
``` ```
最后即可得到embedding结果(此处只展示部分结果)。 最后即可得到embedding结果(此处只展示部分结果)。
```python ```python
...@@ -229,8 +233,8 @@ Paddle Inference Server exit successfully! ...@@ -229,8 +233,8 @@ Paddle Inference Server exit successfully!
browser.",这个页面有什么作用。 browser.",这个页面有什么作用。
> A : 这是`BRPC`的内置服务,主要用于查看请求数、资源占用等信息,可对server端性能有大致了解,具体信息可查看[BRPC内置服务](https://github.com/apache/incubator-brpc/blob/master/docs/cn/builtin_service.md)。 > A : 这是`BRPC`的内置服务,主要用于查看请求数、资源占用等信息,可对server端性能有大致了解,具体信息可查看[BRPC内置服务](https://github.com/apache/incubator-brpc/blob/master/docs/cn/builtin_service.md)。
> Q : 为什么输入文本的格式为[["文本1"], ["文本2"], ],而不是["文本1", "文本2", ]? > Q : 为什么输入文本的格式为[["文本1"], ["文本2"], ],而不是["文本1", "文本2", ]?
> A : 因为Bert模型可以对一轮对话生成向量表示,例如[["问题1","回答1"],["问题2","回答2"]],为了防止使用时混乱,每个样本使用一个list表示,一个样本list内部可以是1条string或2条string,如下面的文本: > A : 因为Bert模型可以对一轮对话生成向量表示,例如[["问题1","回答1"],["问题2","回答2"]],为了防止使用时混乱,每个样本使用一个list表示,一个样本list内部可以是1条string或2条string,如下面的文本:
> ```python > ```python
> input_text = [ > input_text = [
> ["你今天吃饭了吗","我已经吃过饭了"], > ["你今天吃饭了吗","我已经吃过饭了"],
......
# coding: utf8 # coding: utf8
from paddlehub.serving.bert_serving import bert_service from paddlehub.serving.bert_serving import bs_client
if __name__ == "__main__": if __name__ == "__main__":
# 初始化bert_service客户端BSClient
bc = bs_client.BSClient(module_name="ernie_tiny", server="127.0.0.1:8866")
# 输入要做embedding的文本 # 输入要做embedding的文本
# 文本格式为[["文本1"], ["文本2"], ] # 文本格式为[["文本1"], ["文本2"], ]
input_text = [ input_text = [
...@@ -10,10 +13,10 @@ if __name__ == "__main__": ...@@ -10,10 +13,10 @@ if __name__ == "__main__":
["醉后不知天在水"], ["醉后不知天在水"],
["满船清梦压星河"], ["满船清梦压星河"],
] ]
# 调用客户端接口bert_service.connect()获取结果
result = bert_service.connect(
input_text=input_text, model_name="ernie_tiny", server="127.0.0.1:8866")
# 打印embedding结果 # BSClient.get_result()获取结果
result = bc.get_result(input_text=input_text)
# 打印输入文本的embedding结果
for item in result: for item in result:
print(item) print(item)
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import sys import sys
import time
import paddlehub as hub import paddlehub as hub
import ujson import ujson
import random import random
...@@ -30,7 +29,7 @@ if is_py3: ...@@ -30,7 +29,7 @@ if is_py3:
import http.client as httplib import http.client as httplib
class BertService(): class BertService(object):
def __init__(self, def __init__(self,
profile=False, profile=False,
max_seq_len=128, max_seq_len=128,
...@@ -42,7 +41,7 @@ class BertService(): ...@@ -42,7 +41,7 @@ class BertService():
load_balance='round_robin'): load_balance='round_robin'):
self.process_id = process_id self.process_id = process_id
self.reader_flag = False self.reader_flag = False
self.batch_size = 16 self.batch_size = 0
self.max_seq_len = max_seq_len self.max_seq_len = max_seq_len
self.profile = profile self.profile = profile
self.model_name = model_name self.model_name = model_name
...@@ -55,34 +54,29 @@ class BertService(): ...@@ -55,34 +54,29 @@ class BertService():
self.feed_var_names = '' self.feed_var_names = ''
self.retry = retry self.retry = retry
def connect(self, server='127.0.0.1:8010'): module = hub.Module(name=self.model_name)
inputs, outputs, program = module.context(
trainable=True, max_seq_len=self.max_seq_len)
input_ids = inputs["input_ids"]
position_ids = inputs["position_ids"]
segment_ids = inputs["segment_ids"]
input_mask = inputs["input_mask"]
self.feed_var_names = input_ids.name + ';' + position_ids.name + ';' + segment_ids.name + ';' + input_mask.name
self.reader = hub.reader.ClassifyReader(
vocab_path=module.get_vocab_path(),
dataset=None,
max_seq_len=self.max_seq_len,
do_lower_case=self.do_lower_case)
self.reader_flag = True
def add_server(self, server='127.0.0.1:8010'):
self.server_list.append(server) self.server_list.append(server)
def connect_all_server(self, server_list): def add_server_list(self, server_list):
for server_str in server_list: for server_str in server_list:
self.server_list.append(server_str) self.server_list.append(server_str)
def data_convert(self, text): def request_server(self, request_msg):
if self.reader_flag == False:
module = hub.Module(name=self.model_name)
inputs, outputs, program = module.context(
trainable=True, max_seq_len=self.max_seq_len)
input_ids = inputs["input_ids"]
position_ids = inputs["position_ids"]
segment_ids = inputs["segment_ids"]
input_mask = inputs["input_mask"]
self.feed_var_names = input_ids.name + ';' + position_ids.name + ';' + segment_ids.name + ';' + input_mask.name
self.reader = hub.reader.ClassifyReader(
vocab_path=module.get_vocab_path(),
dataset=None,
max_seq_len=self.max_seq_len,
do_lower_case=self.do_lower_case)
self.reader_flag = True
return self.reader.data_generator(
batch_size=self.batch_size, phase='predict', data=text)
def infer(self, request_msg):
if self.load_balance == 'round_robin': if self.load_balance == 'round_robin':
try: try:
cur_con = httplib.HTTPConnection( cur_con = httplib.HTTPConnection(
...@@ -157,17 +151,13 @@ class BertService(): ...@@ -157,17 +151,13 @@ class BertService():
self.server_list) self.server_list)
return 'retry' return 'retry'
def encode(self, text): def prepare_data(self, text):
if type(text) != list:
raise TypeError('Only support list')
self.batch_size = len(text) self.batch_size = len(text)
data_generator = self.data_convert(text) data_generator = self.reader.data_generator(
start = time.time() batch_size=self.batch_size, phase='predict', data=text)
request_time = 0 request_msg = ""
result = []
for run_step, batch in enumerate(data_generator(), start=1): for run_step, batch in enumerate(data_generator(), start=1):
request = [] request = []
copy_start = time.time()
token_list = batch[0][0].reshape(-1).tolist() token_list = batch[0][0].reshape(-1).tolist()
pos_list = batch[0][1].reshape(-1).tolist() pos_list = batch[0][1].reshape(-1).tolist()
sent_list = batch[0][2].reshape(-1).tolist() sent_list = batch[0][2].reshape(-1).tolist()
...@@ -184,54 +174,34 @@ class BertService(): ...@@ -184,54 +174,34 @@ class BertService():
si + 1) * self.max_seq_len] si + 1) * self.max_seq_len]
request.append(instance_dict) request.append(instance_dict)
copy_time = time.time() - copy_start
request = {"instances": request} request = {"instances": request}
request["max_seq_len"] = self.max_seq_len request["max_seq_len"] = self.max_seq_len
request["feed_var_names"] = self.feed_var_names request["feed_var_names"] = self.feed_var_names
request_msg = ujson.dumps(request) request_msg = ujson.dumps(request)
if self.show_ids: if self.show_ids:
logger.info(request_msg) logger.info(request_msg)
request_start = time.time()
response_msg = self.infer(request_msg) return request_msg
retry = 0
while type(response_msg) == str and response_msg == 'retry': def encode(self, text):
if retry < self.retry: if type(text) != list:
retry += 1 raise TypeError('Only support list')
logger.info('Try to connect another servers') request_msg = self.prepare_data(text)
response_msg = self.infer(request_msg)
else: response_msg = self.request_server(request_msg)
logger.error('Infer failed after {} times retry'.format( retry = 0
self.retry)) while type(response_msg) == str and response_msg == 'retry':
break if retry < self.retry:
for msg in response_msg["instances"]: retry += 1
for sample in msg["instances"]: logger.info('Try to connect another servers')
result.append(sample["values"]) response_msg = self.request_server(request_msg)
else:
request_time += time.time() - request_start logger.error('Request failed after {} times retry'.format(
total_time = time.time() - start self.retry))
if self.profile: break
return [ result = []
total_time, request_time, response_msg['op_time'], for msg in response_msg["instances"]:
response_msg['infer_time'], copy_time for sample in msg["instances"]:
] result.append(sample["values"])
else:
return result return result
def connect(input_text,
            model_name,
            max_seq_len=128,
            show_ids=False,
            do_lower_case=True,
            server="127.0.0.1:8866",
            retry=3):
    """Encode ``input_text`` through a freshly created ``BertService``.

    A new ``BertService`` is built on every call from the given options,
    ``server`` is registered on it with ``BertService.connect`` and the
    value returned by ``BertService.encode`` is handed back unchanged.

    Args:
        input_text: Samples to encode, e.g. ``[["As long as"],]``.
        model_name: Forwarded to ``BertService`` as ``model_name``.
        max_seq_len: Forwarded to ``BertService`` as ``max_seq_len``.
        show_ids: Forwarded to ``BertService`` as ``show_ids``.
        do_lower_case: Forwarded to ``BertService`` as ``do_lower_case``.
        server: Server address string passed to ``BertService.connect``.
        retry: Forwarded to ``BertService`` as ``retry``.

    Returns:
        Whatever ``BertService.encode(input_text)`` returns.
    """
    service = BertService(
        max_seq_len=max_seq_len,
        model_name=model_name,
        show_ids=show_ids,
        do_lower_case=do_lower_case,
        retry=retry)
    service.connect(server)
    return service.encode(input_text)
from paddlehub.serving.bert_serving import bert_service
class BSClient(object):
    """Client-side wrapper that owns one ``bert_service.BertService``.

    The wrapped service is stored on ``self.bs``; the server address given
    at construction time is registered on it via ``add_server``.
    """

    def __init__(self,
                 module_name,
                 server,
                 max_seq_len=20,
                 show_ids=False,
                 do_lower_case=True,
                 retry=3):
        """Create the underlying service and register ``server`` on it.

        Args:
            module_name: Passed to ``BertService`` as ``model_name``.
            server: Server address string handed to ``add_server``.
            max_seq_len: Passed to ``BertService`` as ``max_seq_len``.
            show_ids: Passed to ``BertService`` as ``show_ids``.
            do_lower_case: Passed to ``BertService`` as ``do_lower_case``.
            retry: Passed to ``BertService`` as ``retry``.
        """
        # Note the rename: callers say ``module_name``, the service expects
        # ``model_name``.
        service_options = {
            "model_name": module_name,
            "max_seq_len": max_seq_len,
            "show_ids": show_ids,
            "do_lower_case": do_lower_case,
            "retry": retry,
        }
        self.bs = bert_service.BertService(**service_options)
        self.bs.add_server(server=server)

    def get_result(self, input_text):
        """Return the result of ``self.bs.encode(input_text)`` unchanged."""
        encoded = self.bs.encode(input_text)
        return encoded
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册