Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
9e56448b
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9e56448b
编写于
12月 19, 2019
作者:
走神的阿圆
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bs multiprocessing problem
上级
5797e613
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
91 addition
and
93 deletion
+91
-93
demo/serving/bert_service/README.md
demo/serving/bert_service/README.md
+13
-9
demo/serving/bert_service/bert_service_client.py
demo/serving/bert_service/bert_service_client.py
+8
-5
paddlehub/serving/bert_serving/bert_service.py
paddlehub/serving/bert_serving/bert_service.py
+49
-79
paddlehub/serving/bert_serving/bs_client.py
paddlehub/serving/bert_serving/bs_client.py
+21
-0
未找到文件。
demo/serving/bert_service/README.md
浏览文件 @
9e56448b
...
...
@@ -179,18 +179,22 @@ Server[baidu::paddle_serving::predictor::bert_service::BertServiceImpl] is servi
首先导入客户端依赖。
```
python
from
paddlehub.serving.bert_serving
import
b
ert_service
from
paddlehub.serving.bert_serving
import
b
s_client
```
接着输入文本信息。
接着启动并初始化
`bert service`
客户端
`BSClient`
(这里的server为虚拟地址,需根据自己实际ip设置)
```
python
bc
=
bs_client
.
BSClient
(
module_name
=
"ernie_tiny"
,
server
=
"127.0.0.1:8866"
)
```
然后输入文本信息。
```
python
input_text
=
[[
"西风吹老洞庭波"
],
[
"一夜湘君白发多"
],
[
"醉后不知天在水"
],
[
"满船清梦压星河"
],
]
```
然后利用客户端接口发送文本到服务端,以获取embedding结果(server为虚拟地址,需根据自己实际ip设置)。
最后利用客户端接口
`get_result`
发送文本到服务端,以获取embedding结果。
```
python
result
=
bert_service
.
connect
(
input_text
=
input_text
,
model_name
=
"ernie_tiny"
,
server
=
"127.0.0.1:8866"
)
result
=
bc
.
get_result
(
input_text
=
input_text
)
```
最后即可得到embedding结果(此处只展示部分结果)。
```
python
...
...
@@ -229,8 +233,8 @@ Paddle Inference Server exit successfully!
browser.",这个页面有什么作用。
> A : 这是`BRPC`的内置服务,主要用于查看请求数、资源占用等信息,可对server端性能有大致了解,具体信息可查看[BRPC内置服务](https://github.com/apache/incubator-brpc/blob/master/docs/cn/builtin_service.md)。
> Q : 为什么输入文本的格式为[["文本1"], ["文本2"], ],而不是["文本1", "文本2", ]?
> A : 因为Bert模型可以对一轮对话生成向量表示,例如[["问题1","回答1"],["问题2","回答2"]],为了防止使用时混乱,每个样本使用一个list表示,一个样本list内部可以是1条string或2条string,如下面的文本:
> Q : 为什么输入文本的格式为[["文本1"], ["文本2"], ],而不是["文本1", "文本2", ]?
> A : 因为Bert模型可以对一轮对话生成向量表示,例如[["问题1","回答1"],["问题2","回答2"]],为了防止使用时混乱,每个样本使用一个list表示,一个样本list内部可以是1条string或2条string,如下面的文本:
> ```python
> input_text = [
> ["你今天吃饭了吗","我已经吃过饭了"],
...
...
demo/serving/bert_service/bert_service_client.py
浏览文件 @
9e56448b
# coding: utf8
from
paddlehub.serving.bert_serving
import
b
ert_service
from
paddlehub.serving.bert_serving
import
b
s_client
if
__name__
==
"__main__"
:
# 初始化bert_service客户端BSClient
bc
=
bs_client
.
BSClient
(
module_name
=
"ernie_tiny"
,
server
=
"127.0.0.1:8866"
)
# 输入要做embedding的文本
# 文本格式为[["文本1"], ["文本2"], ]
input_text
=
[
...
...
@@ -10,10 +13,10 @@ if __name__ == "__main__":
[
"醉后不知天在水"
],
[
"满船清梦压星河"
],
]
# 调用客户端接口bert_service.connect()获取结果
result
=
bert_service
.
connect
(
input_text
=
input_text
,
model_name
=
"ernie_tiny"
,
server
=
"127.0.0.1:8866"
)
# 打印embedding结果
# BSClient.get_result()获取结果
result
=
bc
.
get_result
(
input_text
=
input_text
)
# 打印输入文本的embedding结果
for
item
in
result
:
print
(
item
)
paddlehub/serving/bert_serving/bert_service.py
浏览文件 @
9e56448b
...
...
@@ -14,7 +14,6 @@
# limitations under the License.
import
sys
import
time
import
paddlehub
as
hub
import
ujson
import
random
...
...
@@ -30,7 +29,7 @@ if is_py3:
import
http.client
as
httplib
class
BertService
():
class
BertService
(
object
):
def
__init__
(
self
,
profile
=
False
,
max_seq_len
=
128
,
...
...
@@ -42,7 +41,7 @@ class BertService():
load_balance
=
'round_robin'
):
self
.
process_id
=
process_id
self
.
reader_flag
=
False
self
.
batch_size
=
16
self
.
batch_size
=
0
self
.
max_seq_len
=
max_seq_len
self
.
profile
=
profile
self
.
model_name
=
model_name
...
...
@@ -55,34 +54,29 @@ class BertService():
self
.
feed_var_names
=
''
self
.
retry
=
retry
def
connect
(
self
,
server
=
'127.0.0.1:8010'
):
module
=
hub
.
Module
(
name
=
self
.
model_name
)
inputs
,
outputs
,
program
=
module
.
context
(
trainable
=
True
,
max_seq_len
=
self
.
max_seq_len
)
input_ids
=
inputs
[
"input_ids"
]
position_ids
=
inputs
[
"position_ids"
]
segment_ids
=
inputs
[
"segment_ids"
]
input_mask
=
inputs
[
"input_mask"
]
self
.
feed_var_names
=
input_ids
.
name
+
';'
+
position_ids
.
name
+
';'
+
segment_ids
.
name
+
';'
+
input_mask
.
name
self
.
reader
=
hub
.
reader
.
ClassifyReader
(
vocab_path
=
module
.
get_vocab_path
(),
dataset
=
None
,
max_seq_len
=
self
.
max_seq_len
,
do_lower_case
=
self
.
do_lower_case
)
self
.
reader_flag
=
True
def
add_server
(
self
,
server
=
'127.0.0.1:8010'
):
self
.
server_list
.
append
(
server
)
def
connect_all_server
(
self
,
server_list
):
def
add_server_list
(
self
,
server_list
):
for
server_str
in
server_list
:
self
.
server_list
.
append
(
server_str
)
def
data_convert
(
self
,
text
):
if
self
.
reader_flag
==
False
:
module
=
hub
.
Module
(
name
=
self
.
model_name
)
inputs
,
outputs
,
program
=
module
.
context
(
trainable
=
True
,
max_seq_len
=
self
.
max_seq_len
)
input_ids
=
inputs
[
"input_ids"
]
position_ids
=
inputs
[
"position_ids"
]
segment_ids
=
inputs
[
"segment_ids"
]
input_mask
=
inputs
[
"input_mask"
]
self
.
feed_var_names
=
input_ids
.
name
+
';'
+
position_ids
.
name
+
';'
+
segment_ids
.
name
+
';'
+
input_mask
.
name
self
.
reader
=
hub
.
reader
.
ClassifyReader
(
vocab_path
=
module
.
get_vocab_path
(),
dataset
=
None
,
max_seq_len
=
self
.
max_seq_len
,
do_lower_case
=
self
.
do_lower_case
)
self
.
reader_flag
=
True
return
self
.
reader
.
data_generator
(
batch_size
=
self
.
batch_size
,
phase
=
'predict'
,
data
=
text
)
def
infer
(
self
,
request_msg
):
def
request_server
(
self
,
request_msg
):
if
self
.
load_balance
==
'round_robin'
:
try
:
cur_con
=
httplib
.
HTTPConnection
(
...
...
@@ -157,17 +151,13 @@ class BertService():
self
.
server_list
)
return
'retry'
def
encode
(
self
,
text
):
if
type
(
text
)
!=
list
:
raise
TypeError
(
'Only support list'
)
def
prepare_data
(
self
,
text
):
self
.
batch_size
=
len
(
text
)
data_generator
=
self
.
data_convert
(
text
)
start
=
time
.
time
()
request_time
=
0
result
=
[]
data_generator
=
self
.
reader
.
data_generator
(
batch_size
=
self
.
batch_size
,
phase
=
'predict'
,
data
=
text
)
request_msg
=
""
for
run_step
,
batch
in
enumerate
(
data_generator
(),
start
=
1
):
request
=
[]
copy_start
=
time
.
time
()
token_list
=
batch
[
0
][
0
].
reshape
(
-
1
).
tolist
()
pos_list
=
batch
[
0
][
1
].
reshape
(
-
1
).
tolist
()
sent_list
=
batch
[
0
][
2
].
reshape
(
-
1
).
tolist
()
...
...
@@ -184,54 +174,34 @@ class BertService():
si
+
1
)
*
self
.
max_seq_len
]
request
.
append
(
instance_dict
)
copy_time
=
time
.
time
()
-
copy_start
request
=
{
"instances"
:
request
}
request
[
"max_seq_len"
]
=
self
.
max_seq_len
request
[
"feed_var_names"
]
=
self
.
feed_var_names
request_msg
=
ujson
.
dumps
(
request
)
if
self
.
show_ids
:
logger
.
info
(
request_msg
)
request_start
=
time
.
time
()
response_msg
=
self
.
infer
(
request_msg
)
retry
=
0
while
type
(
response_msg
)
==
str
and
response_msg
==
'retry'
:
if
retry
<
self
.
retry
:
retry
+=
1
logger
.
info
(
'Try to connect another servers'
)
response_msg
=
self
.
infer
(
request_msg
)
else
:
logger
.
error
(
'Infer failed after {} times retry'
.
format
(
self
.
retry
))
break
for
msg
in
response_msg
[
"instances"
]:
for
sample
in
msg
[
"instances"
]:
result
.
append
(
sample
[
"values"
])
request_time
+=
time
.
time
()
-
request_start
total_time
=
time
.
time
()
-
start
if
self
.
profile
:
return
[
total_time
,
request_time
,
response_msg
[
'op_time'
],
response_msg
[
'infer_time'
],
copy_time
]
else
:
return
result
def
connect
(
input_text
,
model_name
,
max_seq_len
=
128
,
show_ids
=
False
,
do_lower_case
=
True
,
server
=
"127.0.0.1:8866"
,
retry
=
3
):
# format of input_text like [["As long as"],]
bc
=
BertService
(
max_seq_len
=
max_seq_len
,
model_name
=
model_name
,
show_ids
=
show_ids
,
do_lower_case
=
do_lower_case
,
retry
=
retry
)
bc
.
connect
(
server
)
result
=
bc
.
encode
(
input_text
)
return
result
return
request_msg
def
encode
(
self
,
text
):
if
type
(
text
)
!=
list
:
raise
TypeError
(
'Only support list'
)
request_msg
=
self
.
prepare_data
(
text
)
response_msg
=
self
.
request_server
(
request_msg
)
retry
=
0
while
type
(
response_msg
)
==
str
and
response_msg
==
'retry'
:
if
retry
<
self
.
retry
:
retry
+=
1
logger
.
info
(
'Try to connect another servers'
)
response_msg
=
self
.
request_server
(
request_msg
)
else
:
logger
.
error
(
'Request failed after {} times retry'
.
format
(
self
.
retry
))
break
result
=
[]
for
msg
in
response_msg
[
"instances"
]:
for
sample
in
msg
[
"instances"
]:
result
.
append
(
sample
[
"values"
])
return
result
paddlehub/serving/bert_serving/bs_client.py
0 → 100644
浏览文件 @
9e56448b
from
paddlehub.serving.bert_serving
import
bert_service
class
BSClient
(
object
):
def
__init__
(
self
,
module_name
,
server
,
max_seq_len
=
20
,
show_ids
=
False
,
do_lower_case
=
True
,
retry
=
3
):
self
.
bs
=
bert_service
.
BertService
(
model_name
=
module_name
,
max_seq_len
=
max_seq_len
,
show_ids
=
show_ids
,
do_lower_case
=
do_lower_case
,
retry
=
retry
)
self
.
bs
.
add_server
(
server
=
server
)
def
get_result
(
self
,
input_text
):
return
self
.
bs
.
encode
(
input_text
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录