diff --git a/demo/serving/bert_service/README.md b/demo/serving/bert_service/README.md index 8c8d52d0372032c3a5c92b34f63e48c0193bba66..1df7a1c580311cdf95d484ce75c5b000887a1b27 100644 --- a/demo/serving/bert_service/README.md +++ b/demo/serving/bert_service/README.md @@ -179,18 +179,22 @@ Server[baidu::paddle_serving::predictor::bert_service::BertServiceImpl] is servi 首先导入客户端依赖。 ```python -from paddlehub.serving.bert_serving import bert_service +from paddlehub.serving.bert_serving import bs_client ``` -接着输入文本信息。 + +接着启动并初始化`bert service`客户端`BSClient`(这里的server为虚拟地址,需根据自己实际ip设置) +```python +bc = bs_client.BSClient(module_name="ernie_tiny", server="127.0.0.1:8866") +``` + +然后输入文本信息。 ```python input_text = [["西风吹老洞庭波"], ["一夜湘君白发多"], ["醉后不知天在水"], ["满船清梦压星河"], ] ``` -然后利用客户端接口发送文本到服务端,以获取embedding结果(server为虚拟地址,需根据自己实际ip设置)。 + +最后利用客户端接口`get_result`发送文本到服务端,以获取embedding结果。 ```python -result = bert_service.connect( - input_text=input_text, - model_name="ernie_tiny", - server="127.0.0.1:8866") +result = bc.get_result(input_text=input_text) ``` 最后即可得到embedding结果(此处只展示部分结果)。 ```python @@ -229,8 +233,8 @@ Paddle Inference Server exit successfully! browser.",这个页面有什么作用。 > A : 这是`BRPC`的内置服务,主要用于查看请求数、资源占用等信息,可对server端性能有大致了解,具体信息可查看[BRPC内置服务](https://github.com/apache/incubator-brpc/blob/master/docs/cn/builtin_service.md)。 -> Q : 为什么输入文本的格式为[["文本1"], ["文本2"], ],而不是["文本1", "文本2", ]? -> A : 因为Bert模型可以对一轮对话生成向量表示,例如[["问题1","回答1"],["问题2","回答2"]],为了防止使用时混乱,每个样本使用一个list表示,一个样本list内部可以是1条string或2条string,如下面的文本: +> Q : 为什么输入文本的格式为[["文本1"], ["文本2"], ],而不是["文本1", "文本2", ]? +> A : 因为Bert模型可以对一轮对话生成向量表示,例如[["问题1","回答1"],["问题2","回答2"]],为了防止使用时混乱,每个样本使用一个list表示,一个样本list内部可以是1条string或2条string,如下面的文本: > ```python > input_text = [ > ["你今天吃饭了吗","我已经吃过饭了"], diff --git a/demo/serving/bert_service/bert_service_client.py b/demo/serving/bert_service/bert_service_client.py index a8c2533641301bc0659699cd54a7e15fcce9bda3..a7a02183ea7279707070c380ec909b82de0ea0db 100644 --- a/demo/serving/bert_service/bert_service_client.py +++ b/demo/serving/bert_service/bert_service_client.py @@ -1,7 +1,10 @@ # coding: utf8 -from paddlehub.serving.bert_serving import bert_service +from paddlehub.serving.bert_serving import bs_client if __name__ == "__main__": + # 初始化bert_service客户端BSClient + bc = bs_client.BSClient(module_name="ernie_tiny", server="127.0.0.1:8866") + # 输入要做embedding的文本 # 文本格式为[["文本1"], ["文本2"], ] input_text = [ @@ -10,10 +13,10 @@ if __name__ == "__main__": ["醉后不知天在水"], ["满船清梦压星河"], ] - # 调用客户端接口bert_service.connect()获取结果 - result = bert_service.connect( - input_text=input_text, model_name="ernie_tiny", server="127.0.0.1:8866") - # 打印embedding结果 + # BSClient.get_result()获取结果 + result = bc.get_result(input_text=input_text) + + # 打印输入文本的embedding结果 for item in result: print(item) diff --git a/paddlehub/serving/bert_serving/bert_service.py b/paddlehub/serving/bert_serving/bert_service.py index fa873d9d1b7010a847caddbc7f75da56311f772c..d78698fa965fabe9664e5089792f56093e33e792 100644 --- a/paddlehub/serving/bert_serving/bert_service.py +++ b/paddlehub/serving/bert_serving/bert_service.py @@ -14,7 +14,6 @@ # limitations under the License. import sys -import time import paddlehub as hub import ujson import random @@ -30,7 +29,7 @@ if is_py3: import http.client as httplib -class BertService(): +class BertService(object): def __init__(self, profile=False, max_seq_len=128, @@ -42,7 +41,7 @@ class BertService(): load_balance='round_robin'): self.process_id = process_id self.reader_flag = False - self.batch_size = 16 + self.batch_size = 0 self.max_seq_len = max_seq_len self.profile = profile self.model_name = model_name @@ -55,34 +54,29 @@ class BertService(): self.feed_var_names = '' self.retry = retry - def connect(self, server='127.0.0.1:8010'): + module = hub.Module(name=self.model_name) + inputs, outputs, program = module.context( + trainable=True, max_seq_len=self.max_seq_len) + input_ids = inputs["input_ids"] + position_ids = inputs["position_ids"] + segment_ids = inputs["segment_ids"] + input_mask = inputs["input_mask"] + self.feed_var_names = input_ids.name + ';' + position_ids.name + ';' + segment_ids.name + ';' + input_mask.name + self.reader = hub.reader.ClassifyReader( + vocab_path=module.get_vocab_path(), + dataset=None, + max_seq_len=self.max_seq_len, + do_lower_case=self.do_lower_case) + self.reader_flag = True + + def add_server(self, server='127.0.0.1:8010'): self.server_list.append(server) - def connect_all_server(self, server_list): + def add_server_list(self, server_list): for server_str in server_list: self.server_list.append(server_str) - def data_convert(self, text): - if self.reader_flag == False: - module = hub.Module(name=self.model_name) - inputs, outputs, program = module.context( - trainable=True, max_seq_len=self.max_seq_len) - input_ids = inputs["input_ids"] - position_ids = inputs["position_ids"] - segment_ids = inputs["segment_ids"] - input_mask = inputs["input_mask"] - self.feed_var_names = input_ids.name + ';' + position_ids.name + ';' + segment_ids.name + ';' + input_mask.name - self.reader = hub.reader.ClassifyReader( - vocab_path=module.get_vocab_path(), - dataset=None, - max_seq_len=self.max_seq_len, - do_lower_case=self.do_lower_case) - self.reader_flag = True - - return self.reader.data_generator( - batch_size=self.batch_size, phase='predict', data=text) - - def infer(self, request_msg): + def request_server(self, request_msg): if self.load_balance == 'round_robin': try: cur_con = httplib.HTTPConnection( @@ -157,17 +151,13 @@ class BertService(): self.server_list) return 'retry' - def encode(self, text): - if type(text) != list: - raise TypeError('Only support list') + def prepare_data(self, text): self.batch_size = len(text) - data_generator = self.data_convert(text) - start = time.time() - request_time = 0 - result = [] + data_generator = self.reader.data_generator( + batch_size=self.batch_size, phase='predict', data=text) + request_msg = "" for run_step, batch in enumerate(data_generator(), start=1): request = [] - copy_start = time.time() token_list = batch[0][0].reshape(-1).tolist() pos_list = batch[0][1].reshape(-1).tolist() sent_list = batch[0][2].reshape(-1).tolist() @@ -184,54 +174,34 @@ class BertService(): si + 1) * self.max_seq_len] request.append(instance_dict) - copy_time = time.time() - copy_start request = {"instances": request} request["max_seq_len"] = self.max_seq_len request["feed_var_names"] = self.feed_var_names request_msg = ujson.dumps(request) if self.show_ids: logger.info(request_msg) - request_start = time.time() - response_msg = self.infer(request_msg) - retry = 0 - while type(response_msg) == str and response_msg == 'retry': - if retry < self.retry: - retry += 1 - logger.info('Try to connect another servers') - response_msg = self.infer(request_msg) - else: - logger.error('Infer failed after {} times retry'.format( - self.retry)) - break - for msg in response_msg["instances"]: - for sample in msg["instances"]: - result.append(sample["values"]) - - request_time += time.time() - request_start - total_time = time.time() - start - if self.profile: - return [ - total_time, request_time, response_msg['op_time'], - response_msg['infer_time'], copy_time - ] - else: - return result - - -def connect(input_text, - model_name, - max_seq_len=128, - show_ids=False, - do_lower_case=True, - server="127.0.0.1:8866", - retry=3): - # format of input_text like [["As long as"],] - bc = BertService( - max_seq_len=max_seq_len, - model_name=model_name, - show_ids=show_ids, - do_lower_case=do_lower_case, - retry=retry) - bc.connect(server) - result = bc.encode(input_text) - return result + + return request_msg + + def encode(self, text): + if type(text) != list: + raise TypeError('Only support list') + request_msg = self.prepare_data(text) + + response_msg = self.request_server(request_msg) + retry = 0 + while type(response_msg) == str and response_msg == 'retry': + if retry < self.retry: + retry += 1 + logger.info('Try to connect another servers') + response_msg = self.request_server(request_msg) + else: + logger.error('Request failed after {} times retry'.format( + self.retry)) + break + result = [] + for msg in response_msg["instances"]: + for sample in msg["instances"]: + result.append(sample["values"]) + + return result diff --git a/paddlehub/serving/bert_serving/bs_client.py b/paddlehub/serving/bert_serving/bs_client.py new file mode 100644 index 0000000000000000000000000000000000000000..f2871a6109e15d11cc534431a698bdd471c74652 --- /dev/null +++ b/paddlehub/serving/bert_serving/bs_client.py @@ -0,0 +1,21 @@ +from paddlehub.serving.bert_serving import bert_service + + +class BSClient(object): + def __init__(self, + module_name, + server, + max_seq_len=20, + show_ids=False, + do_lower_case=True, + retry=3): + self.bs = bert_service.BertService( + model_name=module_name, + max_seq_len=max_seq_len, + show_ids=show_ids, + do_lower_case=do_lower_case, + retry=retry) + self.bs.add_server(server=server) + + def get_result(self, input_text): + return self.bs.encode(input_text)