Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
f6fb1c20
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f6fb1c20
编写于
12月 26, 2019
作者:
M
MRXLT
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add model_name check
上级
e4263533
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
49 addition
and
15 deletion
+49
-15
paddlehub/commands/serving.py
paddlehub/commands/serving.py
+1
-1
paddlehub/serving/bert_serving/bert_service.py
paddlehub/serving/bert_serving/bert_service.py
+48
-14
未找到文件。
paddlehub/commands/serving.py
浏览文件 @
f6fb1c20
...
...
@@ -66,7 +66,7 @@ class ServingCommand(BaseCommand):
from
paddle_gpu_serving.run
import
BertServer
bs
=
BertServer
(
with_gpu
=
args
.
use_gpu
)
bs
.
with_model
(
model_name
=
args
.
modules
[
0
])
bs
.
run
(
gpu_index
=
args
.
gpu
,
port
=
args
.
port
)
bs
.
run
(
gpu_index
=
args
.
gpu
,
port
=
int
(
args
.
port
)
)
@
staticmethod
def
is_port_occupied
(
ip
,
port
):
...
...
paddlehub/serving/bert_serving/bert_service.py
浏览文件 @
f6fb1c20
...
...
@@ -18,6 +18,7 @@ import paddlehub as hub
import
ujson
import
random
from
paddlehub.common.logger
import
logger
import
socket
_ver
=
sys
.
version_info
is_py2
=
(
_ver
[
0
]
==
2
)
...
...
@@ -51,6 +52,7 @@ class BertService(object):
self
.
con_index
=
0
self
.
load_balance
=
load_balance
self
.
server_list
=
[]
self
.
serving_list
=
[]
self
.
feed_var_names
=
''
self
.
retry
=
retry
...
...
@@ -71,29 +73,58 @@ class BertService(object):
def
add_server
(
self
,
server
=
'127.0.0.1:8010'
):
self
.
server_list
.
append
(
server
)
self
.
check_server
()
def
add_server_list
(
self
,
server_list
):
for
server_str
in
server_list
:
self
.
server_list
.
append
(
server_str
)
self
.
check_server
()
def
check_server
(
self
):
for
server
in
self
.
server_list
:
client
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
server_ip
=
server
.
split
(
':'
)[
0
]
server_port
=
int
(
server
.
split
(
':'
)[
1
])
client
.
connect
((
server_ip
,
server_port
))
client
.
send
(
b
'pending server'
)
response
=
client
.
recv
(
1024
).
decode
()
response_list
=
response
.
split
(
'
\t
'
)
status_code
=
int
(
response_list
[
0
].
split
(
':'
)[
1
])
if
status_code
==
0
:
server_model
=
response_list
[
1
].
split
(
':'
)[
1
]
if
server_model
==
self
.
model_name
:
serving_port
=
response_list
[
2
].
split
(
':'
)[
1
]
serving_ip
=
server_ip
self
.
serving_list
.
append
(
serving_ip
+
':'
+
serving_port
)
else
:
logger
.
error
(
'model_name not match, server {} using : {} '
.
format
(
server
,
server_model
))
else
:
error_msg
=
response_list
[
1
]
logger
.
error
(
'connect server {} failed. {}'
.
format
(
server
,
error_msg
))
def
request_server
(
self
,
request_msg
):
if
self
.
load_balance
==
'round_robin'
:
try
:
cur_con
=
httplib
.
HTTPConnection
(
self
.
serv
er
_list
[
self
.
con_index
])
self
.
serv
ing
_list
[
self
.
con_index
])
cur_con
.
request
(
'POST'
,
"/BertService/inference"
,
request_msg
,
{
"Content-Type"
:
"application/json"
})
response
=
cur_con
.
getresponse
()
response_msg
=
response
.
read
()
response_msg
=
ujson
.
loads
(
response_msg
)
self
.
con_index
+=
1
self
.
con_index
=
self
.
con_index
%
len
(
self
.
serv
er
_list
)
self
.
con_index
=
self
.
con_index
%
len
(
self
.
serv
ing
_list
)
return
response_msg
except
BaseException
as
err
:
logger
.
warning
(
"Infer Error with server {} : {}"
.
format
(
self
.
serv
er
_list
[
self
.
con_index
],
err
))
if
len
(
self
.
serv
er
_list
)
==
0
:
self
.
serv
ing
_list
[
self
.
con_index
],
err
))
if
len
(
self
.
serv
ing
_list
)
==
0
:
logger
.
error
(
'All server failed, process will exit'
)
return
'fail'
else
:
...
...
@@ -103,10 +134,10 @@ class BertService(object):
elif
self
.
load_balance
==
'random'
:
try
:
random
.
seed
()
self
.
con_index
=
random
.
randint
(
0
,
len
(
self
.
serv
er
_list
)
-
1
)
self
.
con_index
=
random
.
randint
(
0
,
len
(
self
.
serv
ing
_list
)
-
1
)
logger
.
info
(
self
.
con_index
)
cur_con
=
httplib
.
HTTPConnection
(
self
.
serv
er
_list
[
self
.
con_index
])
self
.
serv
ing
_list
[
self
.
con_index
])
cur_con
.
request
(
'POST'
,
"/BertService/inference"
,
request_msg
,
{
"Content-Type"
:
"application/json"
})
response
=
cur_con
.
getresponse
()
...
...
@@ -117,21 +148,21 @@ class BertService(object):
except
BaseException
as
err
:
logger
.
warning
(
"Infer Error with server {} : {}"
.
format
(
self
.
serv
er
_list
[
self
.
con_index
],
err
))
if
len
(
self
.
serv
er
_list
)
==
0
:
self
.
serv
ing
_list
[
self
.
con_index
],
err
))
if
len
(
self
.
serv
ing
_list
)
==
0
:
logger
.
error
(
'All server failed, process will exit'
)
return
'fail'
else
:
self
.
con_index
=
random
.
randint
(
0
,
len
(
self
.
serv
er
_list
)
-
1
)
len
(
self
.
serv
ing
_list
)
-
1
)
return
'retry'
elif
self
.
load_balance
==
'bind'
:
try
:
self
.
con_index
=
int
(
self
.
process_id
)
%
len
(
self
.
serv
er
_list
)
self
.
con_index
=
int
(
self
.
process_id
)
%
len
(
self
.
serv
ing
_list
)
cur_con
=
httplib
.
HTTPConnection
(
self
.
serv
er
_list
[
self
.
con_index
])
self
.
serv
ing
_list
[
self
.
con_index
])
cur_con
.
request
(
'POST'
,
"/BertService/inference"
,
request_msg
,
{
"Content-Type"
:
"application/json"
})
response
=
cur_con
.
getresponse
()
...
...
@@ -142,13 +173,13 @@ class BertService(object):
except
BaseException
as
err
:
logger
.
warning
(
"Infer Error with server {} : {}"
.
format
(
self
.
serv
er
_list
[
self
.
con_index
],
err
))
if
len
(
self
.
serv
er
_list
)
==
0
:
self
.
serv
ing
_list
[
self
.
con_index
],
err
))
if
len
(
self
.
serv
ing
_list
)
==
0
:
logger
.
error
(
'All server failed, process will exit'
)
return
'fail'
else
:
self
.
con_index
=
int
(
self
.
process_id
)
%
len
(
self
.
serv
er
_list
)
self
.
serv
ing
_list
)
return
'retry'
def
prepare_data
(
self
,
text
):
...
...
@@ -184,6 +215,9 @@ class BertService(object):
return
request_msg
def
encode
(
self
,
text
):
if
len
(
self
.
serving_list
)
==
0
:
logger
.
error
(
'No match server.'
)
return
-
1
if
type
(
text
)
!=
list
:
raise
TypeError
(
'Only support list'
)
request_msg
=
self
.
prepare_data
(
text
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录