PaddlePaddle / PaddleHub
Unverified commit fa6f8d16

Authored on Dec 23, 2019 by Bin Long; committed via GitHub on Dec 23, 2019.

    Merge branch 'develop' into add_search_tip

Parents: 109d5293, 81dfda51
Showing 19 changed files with 704 additions and 529 deletions.
- .travis.yml (+0 −10)
- demo/serving/bert_service/README.md (+16 −12)
- demo/serving/bert_service/bert_service_client.py (+8 −5)
- paddlehub/__init__.py (+1 −1)
- paddlehub/autofinetune/autoft.py (+30 −5)
- paddlehub/autofinetune/mpi_helper.py (+115 −0)
- paddlehub/commands/autofinetune.py (+37 −12)
- paddlehub/commands/install.py (+17 −7)
- paddlehub/commands/run.py (+29 −26)
- paddlehub/commands/serving.py (+4 −2)
- paddlehub/commands/show.py (+0 −2)
- paddlehub/common/hub_server.py (+13 −7)
- paddlehub/module/check_info.proto (+3 −2)
- paddlehub/module/check_info_pb2.py (+26 −11)
- paddlehub/module/checker.py (+23 −13)
- paddlehub/module/manager.py (+80 −8)
- paddlehub/module/module.py (+232 −327)
- paddlehub/serving/bert_serving/bert_service.py (+49 −79)
- paddlehub/serving/bert_serving/bs_client.py (+21 −0)
.travis.yml (+0 −10)

```diff
@@ -16,12 +16,6 @@ jobs:
       os: linux
       python: 3.6
       script: /bin/bash ./scripts/check_code_style.sh
-    - name: "CI on Linux/Python3.5"
-      os: linux
-      python: 3.5
-    - name: "CI on Linux/Python2.7"
-      os: linux
-      python: 2.7
     env:
       - PYTHONPATH=${PWD}
@@ -30,10 +24,6 @@ install:
   - pip install --upgrade paddlepaddle
   - pip install -r requirements.txt
 script:
-  - if [[ $TRAVIS_OS_NAME != "windows" ]]; then /bin/bash ./scripts/test_cml.sh; fi
-  - if [[ $TRAVIS_OS_NAME != "windows" ]]; then /bin/bash ./scripts/test_all_module.sh; fi
 notifications:
   email:
     on_success: change
```
demo/serving/bert_service/README.md (+16 −12)

````diff
@@ -68,7 +68,7 @@ $ pip install ujson
 |Model|Network|
 |:-|:-:|
-|[ERNIE](https://paddlepaddle.org.cn/hubdetail?name=ERNIE&en_category=SemanticModel)|ERNIE|
+|[ernie](https://paddlepaddle.org.cn/hubdetail?name=ERNIE&en_category=SemanticModel)|ERNIE|
 |[ernie_tiny](https://paddlepaddle.org.cn/hubdetail?name=ernie_tiny&en_category=SemanticModel)|ERNIE|
 |[ernie_v2_eng_large](https://paddlepaddle.org.cn/hubdetail?name=ernie_v2_eng_large&en_category=SemanticModel)|ERNIE|
 |[ernie_v2_eng_base](https://paddlepaddle.org.cn/hubdetail?name=ernie_v2_eng_base&en_category=SemanticModel)|ERNIE|
@@ -179,18 +179,22 @@ Server[baidu::paddle_serving::predictor::bert_service::BertServiceImpl] is serving...
 First, import the client dependency.
 ```python
-from paddlehub.serving.bert_serving import bert_service
+from paddlehub.serving.bert_serving import bs_client
 ```
-Next, enter the text information.
+Next, start and initialize the `bert service` client `BSClient` (the server here is a placeholder address; set it according to your actual IP).
+```python
+bc = bs_client.BSClient(module_name="ernie_tiny", server="127.0.0.1:8866")
+```
+Then enter the text information.
 ```python
 input_text = [["西风吹老洞庭波"], ["一夜湘君白发多"], ["醉后不知天在水"], ["满船清梦压星河"], ]
 ```
-Then use the client interface to send the text to the server to obtain the embedding result (server is a placeholder address; set it according to your actual IP).
+Finally, use the client interface `get_result` to send the text to the server and obtain the embedding result.
 ```python
-result = bert_service.connect(input_text=input_text, model_name="ernie_tiny", server="127.0.0.1:8866")
+result = bc.get_result(input_text=input_text)
 ```
 The embedding result is then obtained (only part of it is shown here).
 ```python
@@ -221,8 +225,8 @@ Paddle Inference Server exit successfully!
 > Q: How can several models be deployed on one server?
 > A: Start `Bert Service` several times, assigning a different port to each. When using GPUs, each instance also needs its own card. For example, to deploy both `ernie` and `bert_chinese_L-12_H-768_A-12`:
 > ```shell
-> $ hub serving start bert_serving -m ernie -p 8866
-> $ hub serving start bert_serving -m bert_serving -m bert_chinese_L-12_H-768_A-12 -p 8867
+> $ hub serving start bert_service -m ernie -p 8866
+> $ hub serving start bert_service -m bert_chinese_L-12_H-768_A-12 -p 8867
 > ```
 > Q: At startup it shows "Check out http://yq01-gpu-255-129-12-00.epc.baidu.com:8887 in web...
````
demo/serving/bert_service/bert_service_client.py (+8 −5)

```diff
 # coding: utf8
-from paddlehub.serving.bert_serving import bert_service
+from paddlehub.serving.bert_serving import bs_client

 if __name__ == "__main__":
+    # Initialize the bert_service client BSClient
+    bc = bs_client.BSClient(module_name="ernie_tiny", server="127.0.0.1:8866")
     # Text to embed
     # Format: [["text_1"], ["text_2"], ]
     input_text = [
@@ -10,10 +13,10 @@ if __name__ == "__main__":
         ["醉后不知天在水"],
         ["满船清梦压星河"],
     ]
-    # Call the client interface bert_service.connect() to get the result
-    result = bert_service.connect(
-        input_text=input_text, model_name="ernie_tiny", server="127.0.0.1:8866")
-    # Print the embedding result
+    # Get the result via BSClient.get_result()
+    result = bc.get_result(input_text=input_text)
+    # Print the embedding result for the input text
     for item in result:
         print(item)
```
paddlehub/__init__.py (+1 −1)

```diff
@@ -38,7 +38,7 @@ from .common.logger import logger
 from .common.paddle_helper import connect_program
 from .common.hub_server import default_hub_server
-from .module.module import Module, create_module
+from .module.module import Module
 from .module.base_processor import BaseProcessor
 from .module.signature import Signature, create_signature
 from .module.manager import default_module_manager
```
paddlehub/autofinetune/autoft.py (+30 −5)

```diff
@@ -26,6 +26,7 @@ from tb_paddle import SummaryWriter
 from paddlehub.common.logger import logger
 from paddlehub.common.utils import mkdir
 from paddlehub.autofinetune.evaluator import REWARD_SUM, TMP_HOME
+from paddlehub.autofinetune.mpi_helper import MPIHelper

 if six.PY3:
     INF = math.inf
@@ -75,6 +76,12 @@ class BaseTuningStrategy(object):
                 logdir=self._output_dir + '/visualization/pop_{}'.format(i))
             self.writer_pop_trails.append(writer_pop_trail)

+        # for parallel on mpi
+        self.mpi = MPIHelper()
+        if self.mpi.multi_machine:
+            print("Autofinetune multimachine mode: running on {}".format(
+                self.mpi.gather(self.mpi.name)))
+
     @property
     def thread(self):
         return self._num_thread
@@ -177,16 +184,22 @@ class BaseTuningStrategy(object):
         solutions_modeldirs = {}
         mkdir(output_dir)

-        for idx, solution in enumerate(solutions):
+        solutions = self.mpi.bcast(solutions)
+
+        # split solutions to "solutions for me"
+        range_start, range_end = self.mpi.split_range(len(solutions))
+        my_solutions = solutions[range_start:range_end]
+
+        for idx, solution in enumerate(my_solutions):
             cuda = self.is_cuda_free["free"][0]
             modeldir = output_dir + "/model-" + str(idx) + "/"
             log_file = output_dir + "/log-" + str(idx) + ".info"
             params_cudas_dirs.append([solution, cuda, modeldir, log_file])
-            solutions_modeldirs[tuple(solution)] = modeldir
+            solutions_modeldirs[tuple(solution)] = (modeldir, self.mpi.rank)
             self.is_cuda_free["free"].remove(cuda)
             self.is_cuda_free["busy"].append(cuda)
-            if len(params_cudas_dirs) == self.thread or idx == len(solutions) - 1:
+            if len(params_cudas_dirs) == self.thread or idx == len(my_solutions) - 1:
                 tp = ThreadPool(len(params_cudas_dirs))
                 solution_results += tp.map(self.evaluator.run, params_cudas_dirs)
@@ -198,13 +211,25 @@ class BaseTuningStrategy(object):
                 self.is_cuda_free["busy"].remove(param_cuda[1])
                 params_cudas_dirs = []

-        self.feedback(solutions, solution_results)
+        all_solution_results = self.mpi.gather(solution_results)
+        if self.mpi.rank == 0:
+            # only rank 0 need to feedback
+            all_solution_results = [y for x in all_solution_results for y in x]
+            self.feedback(solutions, all_solution_results)

         # remove the tmp.txt which records the eval results for trials
         tmp_file = os.path.join(TMP_HOME, "tmp.txt")
         if os.path.exists(tmp_file):
             os.remove(tmp_file)
-        return solutions_modeldirs
+
+        # collect all solutions_modeldirs
+        collected_solutions_modeldirs = self.mpi.allgather(solutions_modeldirs)
+        return_dict = {}
+        for i in collected_solutions_modeldirs:
+            return_dict.update(i)
+
+        return return_dict


 class HAZero(BaseTuningStrategy):
```
paddlehub/autofinetune/mpi_helper.py (new executable file, mode 100755, +115 −0)

```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-


class MPIHelper(object):
    def __init__(self):
        try:
            from mpi4py import MPI
        except:
            # local run
            self._size = 1
            self._rank = 0
            self._multi_machine = False
            import socket
            self._name = socket.gethostname()
        else:
            # in mpi environment
            self._comm = MPI.COMM_WORLD
            self._size = self._comm.Get_size()
            self._rank = self._comm.Get_rank()
            self._name = MPI.Get_processor_name()
            if self._size > 1:
                self._multi_machine = True
            else:
                self._multi_machine = False

    @property
    def multi_machine(self):
        return self._multi_machine

    @property
    def rank(self):
        return self._rank

    @property
    def size(self):
        return self._size

    @property
    def name(self):
        return self._name

    def bcast(self, data):
        if self._multi_machine:
            # call real bcast
            return self._comm.bcast(data, root=0)
        else:
            # do nothing
            return data

    def gather(self, data):
        if self._multi_machine:
            # call real gather
            return self._comm.gather(data, root=0)
        else:
            # do nothing
            return [data]

    def allgather(self, data):
        if self._multi_machine:
            # call real allgather
            return self._comm.allgather(data)
        else:
            # do nothing
            return [data]

    # calculate split range on mpi environment
    def split_range(self, array_length):
        if self._size == 1:
            return 0, array_length
        average_count = array_length / self._size
        if array_length % self._size == 0:
            return average_count * self._rank, average_count * (self._rank + 1)
        else:
            if self._rank < array_length % self._size:
                return (average_count + 1) * self._rank, \
                       (average_count + 1) * (self._rank + 1)
            else:
                start = (average_count + 1) * (array_length % self._size) \
                        + average_count * (self._rank - array_length % self._size)
                return start, start + average_count


if __name__ == "__main__":
    mpi = MPIHelper()
    print("Hello world from process {} of {} at {}.".format(
        mpi.rank, mpi.size, mpi.name))

    all_node_names = mpi.gather(mpi.name)
    print("all node names using gather: {}".format(all_node_names))

    all_node_names = mpi.allgather(mpi.name)
    print("all node names using allgather: {}".format(all_node_names))

    if mpi.rank == 0:
        data = range(10)
    else:
        data = None
    data = mpi.bcast(data)
    print("after bcast, process {} have data {}".format(mpi.rank, data))

    data = [i + mpi.rank for i in data]
    print("after modify, process {} have data {}".format(mpi.rank, data))

    new_data = mpi.gather(data)
    print("after gather, process {} have data {}".format(mpi.rank, new_data))

    # test for split
    for i in range(12):
        length = i + mpi.size  # length should >= mpi.size
        [start, end] = mpi.split_range(length)
        split_result = mpi.gather([start, end])
        print("length {}, split_result {}".format(length, split_result))
```
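One caveat in the new helper: `split_range` computes `average_count = array_length / self._size`, and on Python 3 `/` returns a float, so the bounds it produces cannot be used directly to slice the solutions list. A standalone sketch of the same contiguous-block partition using floor division (the free function and its signature are illustrative here, not PaddleHub API):

```python
def split_range(array_length, size, rank):
    """Partition [0, array_length) into `size` contiguous chunks and return
    the half-open bounds owned by `rank`. Floor division keeps the bounds
    integers; ranks below the remainder each take one extra element."""
    if size == 1:
        return 0, array_length
    average_count = array_length // size
    remainder = array_length % size
    if rank < remainder:
        return (average_count + 1) * rank, (average_count + 1) * (rank + 1)
    start = (average_count + 1) * remainder + average_count * (rank - remainder)
    return start, start + average_count

# Every index is covered exactly once, in rank order:
print([split_range(10, 3, r) for r in range(3)])  # [(0, 4), (4, 7), (7, 10)]
```

Chunk sizes differ by at most one, so the per-rank workload stays balanced no matter how many trials a tuning round produces.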
paddlehub/commands/autofinetune.py (+37 −12)

```diff
@@ -188,37 +188,62 @@ class AutoFineTuneCommand(BaseCommand):
             run_round_cnt = run_round_cnt + 1
         print("PaddleHub Autofinetune ends.")
+        best_hparams_origin = autoft.get_best_hparams()
+        best_hparams_origin = autoft.mpi.bcast(best_hparams_origin)
         with open(autoft._output_dir + "/log_file.txt", "w") as f:
-            best_hparams = evaluator.convert_params(autoft.get_best_hparams())
+            best_hparams = evaluator.convert_params(best_hparams_origin)
             print("The final best hyperparameters:")
             f.write("The final best hyperparameters:\n")
             for index, hparam_name in enumerate(autoft.hparams_name_list):
                 print("%s=%s" % (hparam_name, best_hparams[index]))
                 f.write(hparam_name + "\t:\t" + str(best_hparams[index]) + "\n")
+
+            best_hparams_dir, best_hparams_rank = solutions_modeldirs[tuple(
+                best_hparams_origin)]
+
             print("The final best eval score is %s." % autoft.get_best_eval_value())
-            print("The final best model parameters are saved as " +
-                  autoft._output_dir + "/best_model .")
+
+            if autoft.mpi.multi_machine:
+                print("The final best model parameters are saved as " +
+                      autoft._output_dir + "/best_model on rank " +
+                      str(best_hparams_rank) + " .")
+            else:
+                print("The final best model parameters are saved as " +
+                      autoft._output_dir + "/best_model .")
             f.write("The final best eval score is %s.\n" %
                     autoft.get_best_eval_value())
-            f.write("The final best model parameters are saved as ./best_model .")
             best_model_dir = autoft._output_dir + "/best_model"
-            shutil.copytree(solutions_modeldirs[tuple(autoft.get_best_hparams())],
-                            best_model_dir)
+
+            if autoft.mpi.rank == best_hparams_rank:
+                shutil.copytree(best_hparams_dir, best_model_dir)
+
+            if autoft.mpi.multi_machine:
+                f.write("The final best model parameters are saved as ./best_model on rank " \
+                        + str(best_hparams_rank) + " .")
+                f.write("\t".join(autoft.hparams_name_list) +
+                        "\tsaved_params_dir\trank\n")
+            else:
+                f.write("The final best model parameters are saved as ./best_model .")
+                f.write("\t".join(autoft.hparams_name_list) + "\tsaved_params_dir\n")
             print(
                 "The related infomation about hyperparamemters searched are saved as %s/log_file.txt ."
                 % autoft._output_dir)
             for solution, modeldir in solutions_modeldirs.items():
                 param = evaluator.convert_params(solution)
                 param = [str(p) for p in param]
-                f.write("\t".join(param) + "\t" + modeldir + "\n")
+                if autoft.mpi.multi_machine:
+                    f.write("\t".join(param) + "\t" + modeldir[0] + "\t" +
+                            str(modeldir[1]) + "\n")
+                else:
+                    f.write("\t".join(param) + "\t" + modeldir[0] + "\n")
         return True
```
paddlehub/commands/install.py (+17 −7)

```diff
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function

 import argparse
+import os

 from paddlehub.common import utils
 from paddlehub.module.manager import default_module_manager
@@ -42,14 +43,23 @@ class InstallCommand(BaseCommand):
             print("ERROR: Please specify a module name.\n")
             self.help()
             return False

+        extra = {"command": "install"}
+        if argv[0].endswith("tar.gz") or argv[0].endswith("phm"):
+            result, tips, module_dir = default_module_manager.install_module(
+                module_package=argv[0], extra=extra)
+        elif os.path.exists(argv[0]) and os.path.isdir(argv[0]):
+            result, tips, module_dir = default_module_manager.install_module(
+                module_dir=argv[0], extra=extra)
+        else:
             module_name = argv[0]
             module_version = None if "==" not in module_name else module_name.split("==")[1]
             module_name = module_name if "==" not in module_name else module_name.split("==")[0]
-            extra = {"command": "install"}
             result, tips, module_dir = default_module_manager.install_module(
                 module_name=module_name, module_version=module_version, extra=extra)

         print(tips)
         return True
```
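The `==` handling in the else-branch above can be read as a tiny version-spec parser; a sketch of the same split (the helper name is illustrative, not part of PaddleHub):

```python
def parse_module_spec(spec):
    """Split a 'name==version' module spec into (name, version).
    When no '==' is present, version is None and the whole spec is the name."""
    if "==" in spec:
        name, _, version = spec.partition("==")
        return name, version
    return spec, None

print(parse_module_spec("ernie==1.0.2"))  # ('ernie', '1.0.2')
print(parse_module_spec("ernie"))         # ('ernie', None)
```

Using `str.partition` rather than two `split` calls keeps the name/version extraction to a single scan of the string.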
paddlehub/commands/run.py (+29 −26)

```diff
@@ -71,7 +71,7 @@ class RunCommand(BaseCommand):
         if not result:
             return None

-        return hub.Module(module_dir=module_dir)
+        return hub.Module(directory=module_dir[0])

     def add_module_config_arg(self):
         configs = self.module.processor.configs()
@@ -105,7 +105,7 @@ class RunCommand(BaseCommand):
     def add_module_input_arg(self):
         module_type = self.module.type.lower()
         expect_data_format = self.module.processor.data_format(
-            self.module.default_signature.name)
+            self.module.default_signature)
         self.arg_input_group.add_argument(
             '--input_file',
             type=str,
@@ -152,7 +152,7 @@ class RunCommand(BaseCommand):
     def get_data(self):
         module_type = self.module.type.lower()
         expect_data_format = self.module.processor.data_format(
-            self.module.default_signature.name)
+            self.module.default_signature)
         input_data = {}
         if len(expect_data_format) == 1:
             key = list(expect_data_format.keys())[0]
@@ -177,7 +177,7 @@ class RunCommand(BaseCommand):
     def check_data(self, data):
         expect_data_format = self.module.processor.data_format(
-            self.module.default_signature.name)
+            self.module.default_signature)
         if len(data.keys()) != len(expect_data_format.keys()):
             print(
@@ -236,10 +236,13 @@ class RunCommand(BaseCommand):
             return False

         # If the module is not executable, give an alarm and exit
-        if not self.module.default_signature:
+        if not self.module.is_runable:
             print("ERROR! Module %s is not executable." % module_name)
             return False

+        if self.module.code_version == "v2":
+            results = self.module(argv[1:])
+        else:
             self.module.check_processor()
             self.add_module_config_arg()
             self.add_module_input_arg()
@@ -260,7 +263,7 @@ class RunCommand(BaseCommand):
             return False

         results = self.module(
-            sign_name=self.module.default_signature.name,
+            sign_name=self.module.default_signature,
             data=data,
             use_gpu=self.args.use_gpu,
             batch_size=self.args.batch_size,
```
paddlehub/commands/serving.py (+4 −2)

```diff
@@ -159,7 +159,7 @@ class ServingCommand(BaseCommand):
         module = args.modules
         if module is not None:
             use_gpu = args.use_gpu
-            port = args.port[0]
+            port = args.port
             if ServingCommand.is_port_occupied("127.0.0.1", port) is True:
                 print("Port %s is occupied, please change it." % (port))
                 return False
@@ -206,10 +206,12 @@ class ServingCommand(BaseCommand):
         if args.sub_command == "start":
             if args.bert_service == "bert_service":
                 ServingCommand.start_bert_serving(args)
-            else:
+            elif args.bert_service is None:
                 ServingCommand.start_serving(args)
+            else:
+                ServingCommand.show_help()
         else:
             ServingCommand.show_help()


 command = ServingCommand.instance()
```
paddlehub/commands/show.py (+0 −2)

```diff
@@ -125,8 +125,6 @@ class ShowCommand(BaseCommand):
         cwd = os.getcwd()
         module_dir = default_module_manager.search_module(module_name)
-        module_dir = (os.path.join(cwd, module_name),
-                      None) if not module_dir else module_dir
         if not module_dir or not os.path.exists(module_dir[0]):
             print("%s is not existed!" % module_name)
             return True
```
paddlehub/common/hub_server.py (+13 −7)

```diff
@@ -298,17 +298,23 @@ class CacheUpdater(threading.Thread):
         api_url = srv_utils.uri_path(default_hub_server.get_server_url(), 'search')
         cache_path = os.path.join(CACHE_HOME, RESOURCE_LIST_FILE)
         if os.path.exists(cache_path):
             extra = {
                 "command": "update_cache",
                 "mtime": os.stat(cache_path).st_mtime
             }
         else:
             extra = {
                 "command": "update_cache",
                 "mtime": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
             }
         try:
             r = srv_utils.hub_request(api_url, payload, extra)
-        except Exception as err:
-            pass
-        if r.get("update_cache", 0) == 1:
-            with open(cache_path, 'w+') as fp:
-                yaml.safe_dump({'resource_list': r['data']}, fp)
+            if r.get("update_cache", 0) == 1:
+                with open(cache_path, 'w+') as fp:
+                    yaml.safe_dump({'resource_list': r['data']}, fp)
+        except Exception as err:
+            pass

     def run(self):
         self.update_resource_list_file(self.module, self.version)
```
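The rework above keeps the response handling inside the same `try` block as the request; with the old layout, a failed `hub_request` left `r` unbound and the subsequent `r.get(...)` raised `NameError` instead of being swallowed. A minimal sketch of the pattern (the `request` callable and dict-shaped cache are stand-ins, not PaddleHub API):

```python
def fetch_and_cache(request, cache):
    """Handle the response inside the same try block as the request, so a
    failed request cannot leave the response name unbound."""
    try:
        r = request()  # may raise (network failure, bad payload, ...)
        if r.get("update_cache", 0) == 1:
            cache["resource_list"] = r["data"]
    except Exception:
        # request or parse failure: leave the existing cache untouched
        pass
    return cache
```

With the request and the `r.get(...)` check in one block, every failure mode degrades to "keep the stale cache" rather than crashing the updater thread.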
paddlehub/module/check_info.proto (+3 −2)

```diff
@@ -50,6 +50,7 @@ message CheckInfo {
   string paddle_version = 1;
   string hub_version = 2;
   string module_proto_version = 3;
-  repeated FileInfo file_infos = 4;
-  repeated Requires requires = 5;
+  string module_code_version = 4;
+  repeated FileInfo file_infos = 5;
+  repeated Requires requires = 6;
 };
```
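It is worth noting what this renumbering means on the wire: a protobuf field's key is `(field_number << 3) | wire_type`, so giving `module_code_version` number 4 and shifting `file_infos`/`requires` up changes the serialized bytes, and a `check_info` blob written by an older version will decode its fields under different numbers. A quick check of the key bytes (plain Python, no protobuf dependency needed):

```python
def tag_byte(field_number, wire_type):
    """First key byte of a protobuf field: (field_number << 3) | wire_type.
    Wire type 2 is 'length-delimited' (strings and embedded messages)."""
    return (field_number << 3) | wire_type

LENGTH_DELIMITED = 2

# file_infos moved from field 4 to field 5, so its key byte changes:
print(hex(tag_byte(4, LENGTH_DELIMITED)))  # 0x22 (old file_infos key)
print(hex(tag_byte(5, LENGTH_DELIMITED)))  # 0x2a (new file_infos / old requires key)
print(hex(tag_byte(6, LENGTH_DELIMITED)))  # 0x32 (new requires key)
```

This is why protobuf style guides reserve retired numbers rather than reuse them; here the shifted fields all stay length-delimited messages vs. a string, so old payloads fail to parse loudly rather than silently misread.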
paddlehub/module/check_info_pb2.py (+26 −11)

```diff
 #coding:utf-8
 # Generated by the protocol buffer compiler. DO NOT EDIT!
 # source: check_info.proto
@@ -19,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
   package='paddlehub.module.checkinfo',
   syntax='proto3',
-  serialized_pb=_b('\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xc8\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x38\n\nfile_infos\x18\x04 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x05 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3'))
+  serialized_pb=_b('\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xe5\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x1b\n\x13module_code_version\x18\x04 \x01(\t\x12\x38\n\nfile_infos\x18\x05 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x06 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3'))
 _sym_db.RegisterFileDescriptor(DESCRIPTOR)
@@ -36,8 +35,8 @@ _FILE_TYPE = _descriptor.EnumDescriptor(
   ],
   containing_type=None,
   options=None,
-  serialized_start=522,
-  serialized_end=552,
+  serialized_start=551,
+  serialized_end=581,
 )
 _sym_db.RegisterEnumDescriptor(_FILE_TYPE)
@@ -61,8 +60,8 @@ _REQUIRE_TYPE = _descriptor.EnumDescriptor(
   ],
   containing_type=None,
   options=None,
-  serialized_start=554,
-  serialized_end=645,
+  serialized_start=583,
+  serialized_end=674,
 )
 _sym_db.RegisterEnumDescriptor(_REQUIRE_TYPE)
@@ -316,10 +315,26 @@ _CHECKINFO = _descriptor.Descriptor(
       extension_scope=None,
       options=None),
     _descriptor.FieldDescriptor(
-      name='file_infos',
-      full_name='paddlehub.module.checkinfo.CheckInfo.file_infos',
+      name='module_code_version',
+      full_name='paddlehub.module.checkinfo.CheckInfo.module_code_version',
       index=3, number=4, type=9, cpp_type=9, label=1,
       has_default_value=False, default_value=_b("").decode('utf-8'),
       message_type=None, enum_type=None, containing_type=None,
       is_extension=False, extension_scope=None, options=None),
+    _descriptor.FieldDescriptor(
+      name='file_infos',
+      full_name='paddlehub.module.checkinfo.CheckInfo.file_infos',
+      index=4, number=5, type=11, cpp_type=10, label=3,
@@ -334,8 +349,8 @@ _CHECKINFO = _descriptor.Descriptor(
     _descriptor.FieldDescriptor(
       name='requires',
       full_name='paddlehub.module.checkinfo.CheckInfo.requires',
-      index=4, number=5,
+      index=5, number=6,
       type=11, cpp_type=10, label=3,
@@ -357,7 +372,7 @@ _CHECKINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[],
   serialized_start=320,
-  serialized_end=520,
+  serialized_end=549,
 )
 _FILEINFO.fields_by_name['type'].enum_type = _FILE_TYPE
```
paddlehub/module/checker.py (+23 −13)

```diff
@@ -32,20 +32,22 @@ FILE_SEP = "/"
 class ModuleChecker(object):
-    def __init__(self, module_path):
-        self.module_path = module_path
+    def __init__(self, directory):
+        self._directory = directory
+        self._pb_path = os.path.join(self.directory, CHECK_INFO_PB_FILENAME)

     def generate_check_info(self):
         check_info = check_info_pb2.CheckInfo()
         check_info.paddle_version = paddle.__version__
         check_info.hub_version = hub_version
         check_info.module_proto_version = module_proto_version
+        check_info.module_code_version = "v2"
         file_infos = check_info.file_infos
-        file_list = [file for file in os.listdir(self.module_path)]
+        file_list = [file for file in os.listdir(self.directory)]
         while file_list:
             file = file_list[0]
             file_list = file_list[1:]
-            abs_path = os.path.join(self.module_path, file)
+            abs_path = os.path.join(self.directory, file)
             if os.path.isdir(abs_path):
                 for sub_file in os.listdir(abs_path):
                     sub_file = os.path.join(file, sub_file)
@@ -62,9 +64,12 @@ class ModuleChecker(object):
                 file_info.type = check_info_pb2.FILE
                 file_info.is_need = True

-        with open(os.path.join(self.module_path, CHECK_INFO_PB_FILENAME), "wb") as fi:
-            fi.write(check_info.SerializeToString())
+        with open(self.pb_path, "wb") as file:
+            file.write(check_info.SerializeToString())
+
+    @property
+    def module_code_version(self):
+        return self.check_info.module_code_version

     @property
     def module_proto_version(self):
@@ -82,20 +87,25 @@ class ModuleChecker(object):
     def file_infos(self):
         return self.check_info.file_infos

+    @property
+    def directory(self):
+        return self._directory
+
+    @property
+    def pb_path(self):
+        return self._pb_path
+
     def check(self):
         result = True
-        self.check_info_pb_path = os.path.join(self.module_path, CHECK_INFO_PB_FILENAME)

-        if not (os.path.exists(self.check_info_pb_path)
-                or os.path.isfile(self.check_info_pb_path)):
+        if not (os.path.exists(self.pb_path) or os.path.isfile(self.pb_path)):
             logger.warning("This module lacks core file %s" % CHECK_INFO_PB_FILENAME)
             result = False

         self.check_info = check_info_pb2.CheckInfo()
         try:
-            with open(self.check_info_pb_path, "rb") as fi:
+            with open(self.pb_path, "rb") as fi:
                 pb_string = fi.read()
                 result = self.check_info.ParseFromString(pb_string)
                 if len(pb_string) == 0 or (result is not None
@@ -182,7 +192,7 @@ class ModuleChecker(object):
         for file_info in self.file_infos:
             file_type = file_info.type
             file_path = file_info.file_name.replace(FILE_SEP, os.sep)
-            file_path = os.path.join(self.module_path, file_path)
+            file_path = os.path.join(self.directory, file_path)
             if not os.path.exists(file_path):
                 if file_info.is_need:
                     logger.warning(
```
paddlehub/module/manager.py (+80 −8)

```diff
@@ -19,7 +19,9 @@ from __future__ import print_function
 import os
 import shutil
+from functools import cmp_to_key
+import tarfile

 from paddlehub.common import utils
 from paddlehub.common import srv_utils
@@ -79,10 +81,15 @@ class LocalModuleManager(object):
         return self.modules_dict.get(module_name, None)

     def install_module(self,
-                       module_name,
+                       module_name=None,
+                       module_dir=None,
+                       module_package=None,
                        module_version=None,
                        upgrade=False,
                        extra=None):
+        md5_value = installed_module_version = None
+        from_user_dir = True if module_dir else False
+        if module_name:
             self.all_modules(update=True)
             module_info = self.modules_dict.get(module_name, None)
             if module_info:
@@ -95,6 +102,62 @@ class LocalModuleManager(object):
                     module_dir)
                 return True, tips, self.modules_dict[module_name]

+            search_result = hub.default_hub_server.get_module_url(
+                module_name, version=module_version, extra=extra)
+            name = search_result.get('name', None)
+            url = search_result.get('url', None)
+            md5_value = search_result.get('md5', None)
+            installed_module_version = search_result.get('version', None)
+            if not url or (module_version is not None and
+                           installed_module_version != module_version) or (
+                               name != module_name):
+                if default_hub_server._server_check() is False:
+                    tips = "Request Hub-Server unsuccessfully, please check your network."
+                else:
+                    tips = "Can't find module %s" % module_name
+                    if module_version:
+                        tips += " with version %s" % module_version
+                    module_tag = module_name if not module_version else '%s-%s' % (
+                        module_name, module_version)
+                return False, tips, None
+
+            result, tips, module_zip_file = default_downloader.download_file(
+                url=url,
+                save_path=hub.CACHE_HOME,
+                save_name=module_name,
+                replace=True,
+                print_progress=True)
+            result, tips, module_dir = default_downloader.uncompress(
+                file=module_zip_file,
+                dirname=MODULE_HOME,
+                delete_file=True,
+                print_progress=True)
+
+        if module_package:
+            with tarfile.open(module_package, "r:gz") as tar:
+                file_names = tar.getnames()
+                size = len(file_names) - 1
+                module_dir = os.path.split(file_names[0])[0]
+                module_dir = os.path.join(hub.CACHE_HOME, module_dir)
+                # remove cache
+                if os.path.exists(module_dir):
+                    shutil.rmtree(module_dir)
+                for index, file_name in enumerate(file_names):
+                    tar.extract(file_name, hub.CACHE_HOME)
+
+        if module_dir:
+            if not module_name:
+                module_name = hub.Module(directory=module_dir).name
+            self.all_modules(update=False)
+            module_info = self.modules_dict.get(module_name, None)
+            if module_info:
+                module_dir = self.modules_dict[module_name][0]
+                module_tag = module_name if not module_version else '%s-%s' % (
+                    module_name, module_version)
+                tips = "Module %s already installed in %s" % (module_tag, module_dir)
+                return True, tips, self.modules_dict[module_name]
-        search_result = hub.default_hub_server.get_module_url(
-            module_name, version=module_version, extra=extra)
-        name = search_result.get('name', None)
@@ -162,9 +225,18 @@ class LocalModuleManager(object):
-        with open(os.path.join(MODULE_HOME, module_dir, "md5.txt"), "w") as fp:
-            fp.write(md5_value)
+        if md5_value:
+            with open(os.path.join(MODULE_HOME, module_dir, "md5.txt"), "w") as fp:
+                fp.write(md5_value)
+
+        save_path = os.path.join(MODULE_HOME, module_name)
+        if os.path.exists(save_path):
+            shutil.rmtree(save_path)
-        shutil.move(module_dir, save_path)
+        if from_user_dir:
+            shutil.copytree(module_dir, save_path)
+        else:
+            shutil.move(module_dir, save_path)
+        module_dir = save_path
         tips = "Successfully installed %s" % module_name
```
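The `module_package` branch above derives the install directory from the first entry returned by `tarfile.getnames()`, assuming the archive has a single top-level module directory. A minimal sketch of that convention, building a throwaway in-memory archive (the `ernie_tiny` layout here is illustrative):

```python
import io
import os
import tarfile

# Build a small in-memory .tar.gz laid out like a module package:
# one top-level directory containing the payload files.
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
    for name in ("ernie_tiny/module_desc.pb", "ernie_tiny/assets/vocab.txt"):
        data = b"placeholder"
        info = tarfile.TarInfo(name)
        info.size = len(data)
        tar.addfile(info, io.BytesIO(data))

buf.seek(0)
with tarfile.open(fileobj=buf, mode="r:gz") as tar:
    file_names = tar.getnames()
    # the directory to install from is the top-level component of entry 0
    module_dir = os.path.split(file_names[0])[0]

print(module_dir)  # ernie_tiny
```

The convention is fragile if an archive's first entry is not nested under the module directory, which is presumably why the real code also removes any stale cache directory before extracting every entry.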
paddlehub/module/module.py
浏览文件 @
fa6f8d16
...
...
@@ -21,6 +21,10 @@ import os
import time
import sys
import functools
import inspect
import importlib
import tarfile
from collections import defaultdict
from shutil import copyfile
import paddle
...
...
@@ -28,22 +32,19 @@ import paddle.fluid as fluid
from paddlehub.common import utils
from paddlehub.common import paddle_helper
from paddlehub.common.logger import logger
from paddlehub.common.dir import CACHE_HOME
from paddlehub.common.lock import lock
from paddlehub.common.downloader import default_downloader
from paddlehub.common.logger import logger
from paddlehub.common.hub_server import CacheUpdater
from paddlehub.module import module_desc_pb2
from paddlehub.common.dir import CONF_HOME
from paddlehub.module import check_info_pb2
from paddlehub.common.hub_server import CacheUpdater
from paddlehub.module.signature import Signature, create_signature
from paddlehub.module.checker import ModuleChecker
from paddlehub.module.manager import default_module_manager
from paddlehub.module.checker import ModuleChecker
from paddlehub.module.signature import Signature, create_signature
from paddlehub.module.base_processor import BaseProcessor
from paddlehub.io.parser import yaml_parser
from paddlehub import version

__all__ = ['Module', 'create_module']
# PaddleHub module dir name
ASSETS_DIRNAME = "assets"
MODEL_DIRNAME = "model"
...
...
@@ -52,67 +53,227 @@ PYTHON_DIR = "python"
PROCESSOR_NAME = "processor"
# PaddleHub var prefix
HUB_VAR_PREFIX = "@HUB_%s@"
# PaddleHub Module package suffix
HUB_PACKAGE_SUFFIX = "phm"
def create_module(directory, name, author, email, module_type, summary,
                  version):
    save_file_name = "{}-{}.{}".format(name, version, HUB_PACKAGE_SUFFIX)
    # record module info and serialize
    desc = module_desc_pb2.ModuleDesc()
    attr = desc.attr
    attr.type = module_desc_pb2.MAP
    module_info = attr.map.data['module_info']
    module_info.type = module_desc_pb2.MAP
    utils.from_pyobj_to_module_attr(name, module_info.map.data['name'])
    utils.from_pyobj_to_module_attr(author, module_info.map.data['author'])
    utils.from_pyobj_to_module_attr(email, module_info.map.data['author_email'])
    utils.from_pyobj_to_module_attr(module_type, module_info.map.data['type'])
    utils.from_pyobj_to_module_attr(summary, module_info.map.data['summary'])
    utils.from_pyobj_to_module_attr(version, module_info.map.data['version'])
    module_desc_path = os.path.join(directory, "module_desc.pb")
    with open(module_desc_path, "wb") as f:
        f.write(desc.SerializeToString())
    # generate check info
    checker = ModuleChecker(directory)
    checker.generate_check_info()
    # add __init__
    module_init_1 = os.path.join(directory, "__init__.py")
    with open(module_init_1, "a") as file:
        file.write("")
    module_init_2 = os.path.join(directory, "python", "__init__.py")
    with open(module_init_2, "a") as file:
        file.write("")
    # package the module
    with tarfile.open(save_file_name, "w:gz") as tar:
        for dirname, _, files in os.walk(directory):
            for file in files:
                tar.add(os.path.join(dirname, file))
    os.remove(module_desc_path)
    os.remove(checker.pb_path)
    os.remove(module_init_1)
    os.remove(module_init_2)
class Module(object):
    def __new__(cls, name=None, directory=None, module_dir=None, version=None):
        module = None
        if cls.__name__ == "Module":
            if name:
                module = cls.init_with_name(name=name, version=version)
            elif directory:
                module = cls.init_with_directory(directory=directory)
            elif module_dir:
                logger.warning(
                    "Parameter module_dir is deprecated, please use directory to specify the path"
                )
                if isinstance(module_dir, list) or isinstance(module_dir, tuple):
                    directory = module_dir[0]
                    version = module_dir[1]
                else:
                    directory = module_dir
                module = cls.init_with_directory(directory=directory)
        if not module:
            module = object.__new__(cls)
        else:
            CacheUpdater(module.name, module.version).start()
        return module
    def __init__(self, name=None, directory=None, module_dir=None, version=None):
        if not directory:
            return
        self._code_version = "v2"
        self._directory = directory
        self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME)
        self._desc = module_desc_pb2.ModuleDesc()
        with open(self.module_desc_path, "rb") as file:
            self._desc.ParseFromString(file.read())
        module_info = self.desc.attr.map.data['module_info']
        self._name = utils.from_module_attr_to_pyobj(module_info.map.data['name'])
        self._author = utils.from_module_attr_to_pyobj(
            module_info.map.data['author'])
        self._author_email = utils.from_module_attr_to_pyobj(
            module_info.map.data['author_email'])
        self._version = utils.from_module_attr_to_pyobj(
            module_info.map.data['version'])
        self._type = utils.from_module_attr_to_pyobj(module_info.map.data['type'])
        self._summary = utils.from_module_attr_to_pyobj(
            module_info.map.data['summary'])
        self._initialize()
    @classmethod
    def init_with_name(cls, name, version=None):
        fp_lock = open(os.path.join(CACHE_HOME, name), "a")
        lock.flock(fp_lock, lock.LOCK_EX)
        log_msg = "Installing %s module" % name
        if version:
            log_msg += "-%s" % version
        logger.info(log_msg)
        extra = {"command": "install"}
        result, tips, module_dir = default_module_manager.install_module(
            module_name=name, module_version=version, extra=extra)
        if not result:
            logger.error(tips)
            raise RuntimeError(tips)
        logger.info(tips)
        lock.flock(fp_lock, lock.LOCK_UN)
        return cls.init_with_directory(directory=module_dir[0])
    @classmethod
    def init_with_directory(cls, directory):
        desc_file = os.path.join(directory, MODULE_DESC_PBNAME)
        checker = ModuleChecker(directory)
        checker.check()


def create_module(sign_arr,
                  module_dir,
                  processor=None,
                  assets=None,
                  module_info=None,
                  exe=None,
                  extra_info=None):
    sign_arr = utils.to_list(sign_arr)
    module = Module(
        signatures=sign_arr,
        processor=processor,
        assets=assets,
        module_info=module_info,
        extra_info=extra_info)
    module.serialize_to_path(path=module_dir, exe=exe)
        module_code_version = checker.module_code_version
        if module_code_version == "v2":
            basename = os.path.split(directory)[-1]
            dirname = os.path.join(*list(os.path.split(directory)[:-1]))
            sys.path.append(dirname)
            pymodule = importlib.import_module(
                "{}.python.module".format(basename))
            return pymodule.HubModule(directory=directory)
        return ModuleV1(directory=directory)
    @property
    def desc(self):
        return self._desc

    @property
    def directory(self):
        return self._directory

    @property
    def author(self):
        return self._author

    @property
    def author_email(self):
        return self._author_email

    @property
    def summary(self):
        return self._summary

    @property
    def type(self):
        return self._type

    @property
    def version(self):
        return self._version

    @property
    def name(self):
        return self._name

    @property
    def name_prefix(self):
        return self._name_prefix

    @property
    def code_version(self):
        return self._code_version

    @property
    def is_runable(self):
        return False

    def _initialize(self):
        pass
class ModuleHelper(object):
    def __init__(self, module_dir):
        self.module_dir = module_dir
    def __init__(self, directory):
        self.directory = directory

    def module_desc_path(self):
        return os.path.join(self.module_dir, MODULE_DESC_PBNAME)
        return os.path.join(self.directory, MODULE_DESC_PBNAME)

    def model_path(self):
        return os.path.join(self.module_dir, MODEL_DIRNAME)
        return os.path.join(self.directory, MODEL_DIRNAME)

    def processor_path(self):
        return os.path.join(self.module_dir, PYTHON_DIR)
        return os.path.join(self.directory, PYTHON_DIR)

    def processor_name(self):
        return PROCESSOR_NAME

    def assets_path(self):
        return os.path.join(self.module_dir, ASSETS_DIRNAME)
        return os.path.join(self.directory, ASSETS_DIRNAME)
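The diff above only renames `ModuleHelper`'s attribute from `module_dir` to `directory`; the class itself is a thin path resolver over the fixed module layout (`module_desc.pb`, `model/`, `python/`, `assets/`). A simplified stand-in (illustrative names, not the PaddleHub class) makes that layout explicit:

```python
import os

# Directory-layout constants, as defined earlier in module.py.
MODULE_DESC_PBNAME = "module_desc.pb"
MODEL_DIRNAME = "model"
PYTHON_DIR = "python"
ASSETS_DIRNAME = "assets"

class SimpleModuleHelper(object):
    # Every accessor is just the module directory joined with one
    # fixed sub-path of the on-disk module layout.
    def __init__(self, directory):
        self.directory = directory

    def module_desc_path(self):
        return os.path.join(self.directory, MODULE_DESC_PBNAME)

    def model_path(self):
        return os.path.join(self.directory, MODEL_DIRNAME)

    def processor_path(self):
        return os.path.join(self.directory, PYTHON_DIR)

    def assets_path(self):
        return os.path.join(self.directory, ASSETS_DIRNAME)
```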
class Module(object):
    def __init__(self,
                 name=None,
                 module_dir=None,
                 signatures=None,
                 module_info=None,
                 assets=None,
                 processor=None,
                 extra_info=None,
class ModuleV1(Module):
    def __init__(self, name=None, directory=None, module_dir=None, version=None):
        self.desc = module_desc_pb2.ModuleDesc()
        if not directory:
            return
        super(ModuleV1, self).__init__(name, directory, module_dir, version)
        self._code_version = "v1"
        self.program = None
        self.assets = []
        self.helper = None
        self.signatures = {}
        self.default_signature = None
        self.module_info = None
        self.processor = None
        self.extra_info = {} if extra_info is None else extra_info
        if not isinstance(self.extra_info, dict):
            raise TypeError("The extra_info should be an instance of python dict")
        self.extra_info = {}
        # cache data
        self.last_call_name = None
...
...
@@ -120,62 +281,21 @@ class Module(object):
        self.cache_fetch_dict = None
        self.cache_program = None
        fp_lock = open(os.path.join(CONF_HOME, 'config.json'))
        lock.flock(fp_lock, lock.LOCK_EX)
        if name:
            self._init_with_name(name=name, version=version)
            lock.flock(fp_lock, lock.LOCK_UN)
        elif module_dir:
            self._init_with_module_file(module_dir=module_dir[0])
            lock.flock(fp_lock, lock.LOCK_UN)
            name = module_dir[0].split("/")[-1]
            if len(module_dir) > 1:
                version = module_dir[1]
            else:
                version = default_module_manager.search_module(name)[1]
        elif signatures:
            if processor:
                if not issubclass(processor, BaseProcessor):
                    raise TypeError(
                        "Processor should be an instance of paddlehub.BaseProcessor")
            if assets:
                self.assets = utils.to_list(assets)
                # for asset in assets:
                #     utils.check_path(assets)
            self.processor = processor
            self._generate_module_info(module_info)
            self._init_with_signature(signatures=signatures)
            lock.flock(fp_lock, lock.LOCK_UN)
        else:
            lock.flock(fp_lock, lock.LOCK_UN)
            raise ValueError("Module initialized parameter is empty")
        CacheUpdater(name, version).start()
    def _init_with_name(self, name, version=None):
        log_msg = "Installing %s module" % name
        if version:
            log_msg += "-%s" % version
        logger.info(log_msg)
        extra = {"command": "install"}
        result, tips, module_dir = default_module_manager.install_module(
            module_name=name, module_version=version, extra=extra)
        if not result:
            logger.error(tips)
            raise RuntimeError(tips)
        else:
            logger.info(tips)
        self._init_with_module_file(module_dir[0])
    def _init_with_url(self, url):
        utils.check_url(url)
        result, tips, module_dir = default_downloader.download_file_and_uncompress(
            url, save_path=".")
        if not result:
            logger.error(tips)
            raise RuntimeError(tips)
        else:
            self._init_with_module_file(module_dir)
        self.helper = ModuleHelper(directory)
        exe = fluid.Executor(fluid.CPUPlace())
        self.program, _, _ = fluid.io.load_inference_model(
            self.helper.model_path(), executor=exe)
        for block in self.program.blocks:
            for op in block.ops:
                if "op_callstack" in op.all_attrs():
                    op._set_attr("op_callstack", [""])
        self._load_processor()
        self._load_assets()
        self._recover_from_desc()
        self._generate_sign_attr()
        self._generate_extra_info()
        self._restore_parameter(self.program)
        self._recover_variable_info(self.program)

    def _dump_processor(self):
        import inspect
...
...
@@ -216,52 +336,6 @@ class Module(object):
            filepath = os.path.join(self.helper.assets_path(), file)
            self.assets.append(filepath)

    def _init_with_module_file(self, module_dir):
        checker = ModuleChecker(module_dir)
        checker.check()
        self.helper = ModuleHelper(module_dir)
        with open(self.helper.module_desc_path(), "rb") as fi:
            self.desc.ParseFromString(fi.read())
        exe = fluid.Executor(fluid.CPUPlace())
        self.program, _, _ = fluid.io.load_inference_model(
            self.helper.model_path(), executor=exe)
        for block in self.program.blocks:
            for op in block.ops:
                if "op_callstack" in op.all_attrs():
                    op._set_attr("op_callstack", [""])
        self._load_processor()
        self._load_assets()
        self._recover_from_desc()
        self._generate_sign_attr()
        self._generate_extra_info()
        self._restore_parameter(self.program)
        self._recover_variable_info(self.program)

    def _init_with_signature(self, signatures):
        self.name_prefix = HUB_VAR_PREFIX % self.name
        self._process_signatures(signatures)
        self._check_signatures()
        self._generate_desc()
        self._generate_sign_attr()
        self._generate_extra_info()

    def _init_with_program(self, program):
        pass

    def _process_signatures(self, signatures):
        self.signatures = {}
        self.program = signatures[0].inputs[0].block.program
        for sign in signatures:
            if sign.name in self.signatures:
                raise ValueError(
                    "Error! Signature array contains duplicated signatures %s" % sign)
            if self.default_signature is None and sign.for_predict:
                self.default_signature = sign
            self.signatures[sign.name] = sign

    def _restore_parameter(self, program):
        global_block = program.global_block()
        param_attrs = self.desc.attr.map.data['param_attrs']
...
...
@@ -302,21 +376,6 @@ class Module(object):
        self.__dict__["get_%s" % key] = functools.partial(
            self.get_extra_info, key=key)

    def _generate_module_info(self, module_info=None):
        if not module_info:
            self.module_info = {}
        else:
            if not utils.is_yaml_file(module_info):
                logger.critical("Module info file should be yaml format")
                exit(1)
            self.module_info = yaml_parser.parse(module_info)
        self.author = self.module_info.get('author', 'UNKNOWN')
        self.author_email = self.module_info.get('author_email', 'UNKNOWN')
        self.summary = self.module_info.get('summary', 'UNKNOWN')
        self.type = self.module_info.get('type', 'UNKNOWN')
        self.version = self.module_info.get('version', 'UNKNOWN')
        self.name = self.module_info.get('name', 'UNKNOWN')

    def _generate_sign_attr(self):
        self._check_signatures()
        for sign in self.signatures:
...
...
@@ -369,21 +428,21 @@ class Module(object):
        default_signature_name = utils.from_module_attr_to_pyobj(
            self.desc.attr.map.data['default_signature'])
        self.default_signature = self.signatures[
            default_signature_name] if default_signature_name else None
            default_signature_name].name if default_signature_name else None
        # recover module info
        module_info = self.desc.attr.map.data['module_info']
        self.name = utils.from_module_attr_to_pyobj(
        self._name = utils.from_module_attr_to_pyobj(
            module_info.map.data['name'])
        self.author = utils.from_module_attr_to_pyobj(
        self._author = utils.from_module_attr_to_pyobj(
            module_info.map.data['author'])
        self.author_email = utils.from_module_attr_to_pyobj(
        self._author_email = utils.from_module_attr_to_pyobj(
            module_info.map.data['author_email'])
        self.version = utils.from_module_attr_to_pyobj(
        self._version = utils.from_module_attr_to_pyobj(
            module_info.map.data['version'])
        self.type = utils.from_module_attr_to_pyobj(
        self._type = utils.from_module_attr_to_pyobj(
            module_info.map.data['type'])
        self.summary = utils.from_module_attr_to_pyobj(
        self._summary = utils.from_module_attr_to_pyobj(
            module_info.map.data['summary'])
        # recover extra info
...
...
@@ -393,77 +452,9 @@ class Module(object):
            self.extra_info[key] = utils.from_module_attr_to_pyobj(value)
        # recover name prefix
        self.name_prefix = utils.from_module_attr_to_pyobj(
        self._name_prefix = utils.from_module_attr_to_pyobj(
            self.desc.attr.map.data["name_prefix"])
    def _generate_desc(self):
        # save fluid Parameter
        attr = self.desc.attr
        attr.type = module_desc_pb2.MAP
        param_attrs = attr.map.data['param_attrs']
        param_attrs.type = module_desc_pb2.MAP
        for param in self.program.global_block().iter_parameters():
            param_attr = param_attrs.map.data[param.name]
            paddle_helper.from_param_to_module_attr(param, param_attr)
        # save Variable Info
        var_infos = attr.map.data['var_infos']
        var_infos.type = module_desc_pb2.MAP
        for block in self.program.blocks:
            for var in block.vars.values():
                var_info = var_infos.map.data[var.name]
                var_info.type = module_desc_pb2.MAP
                utils.from_pyobj_to_module_attr(
                    var.stop_gradient, var_info.map.data['stop_gradient'])
                utils.from_pyobj_to_module_attr(
                    block.idx, var_info.map.data['block_id'])
        # save signature info
        for key, sign in self.signatures.items():
            var = self.desc.sign2var[sign.name]
            feed_desc = var.feed_desc
            fetch_desc = var.fetch_desc
            feed_names = sign.feed_names
            fetch_names = sign.fetch_names
            for index, input in enumerate(sign.inputs):
                feed_var = feed_desc.add()
                feed_var.var_name = self.get_var_name_with_prefix(input.name)
                feed_var.alias = feed_names[index]
            for index, output in enumerate(sign.outputs):
                fetch_var = fetch_desc.add()
                fetch_var.var_name = self.get_var_name_with_prefix(output.name)
                fetch_var.alias = fetch_names[index]
        # save default signature
        utils.from_pyobj_to_module_attr(
            self.default_signature.name if self.default_signature else None,
            attr.map.data['default_signature'])
        # save name prefix
        utils.from_pyobj_to_module_attr(
            self.name_prefix, self.desc.attr.map.data["name_prefix"])
        # save module info
        module_info = attr.map.data['module_info']
        module_info.type = module_desc_pb2.MAP
        utils.from_pyobj_to_module_attr(self.name, module_info.map.data['name'])
        utils.from_pyobj_to_module_attr(self.version,
                                        module_info.map.data['version'])
        utils.from_pyobj_to_module_attr(self.author,
                                        module_info.map.data['author'])
        utils.from_pyobj_to_module_attr(self.author_email,
                                        module_info.map.data['author_email'])
        utils.from_pyobj_to_module_attr(self.type, module_info.map.data['type'])
        utils.from_pyobj_to_module_attr(self.summary,
                                        module_info.map.data['summary'])
        # save extra info
        extra_info = attr.map.data['extra_info']
        extra_info.type = module_desc_pb2.MAP
        for key, value in self.extra_info.items():
            utils.from_pyobj_to_module_attr(value, extra_info.map.data[key])

    def __call__(self, sign_name, data, use_gpu=False, batch_size=1, **kwargs):
        self.check_processor()
...
...
@@ -525,6 +516,10 @@ class Module(object):
        if not self.processor:
            raise ValueError("This Module is not callable!")

    @property
    def is_runable(self):
        return self.default_signature != None

    def context(self,
                sign_name=None,
                for_test=False,
...
@@ -664,93 +659,3 @@ class Module(object):
            raise ValueError(
                "All input and outputs variables in signature should come from the same Program"
            )

    def serialize_to_path(self, path=None, exe=None):
        self._check_signatures()
        self._generate_desc()
        # create module path for saving
        if path is None:
            path = os.path.join(".", self.name)
        self.helper = ModuleHelper(path)
        utils.mkdir(self.helper.module_dir)
        # create module pb
        module_desc = module_desc_pb2.ModuleDesc()
        logger.info("PaddleHub version = %s" % version.hub_version)
        logger.info(
            "PaddleHub Module proto version = %s" % version.module_proto_version)
        logger.info("Paddle version = %s" % paddle.__version__)
        feeded_var_names = [
            input.name for key, sign in self.signatures.items()
            for input in sign.inputs
        ]
        target_vars = [
            output for key, sign in self.signatures.items()
            for output in sign.outputs
        ]
        feeded_var_names = list(set(feeded_var_names))
        target_vars = list(set(target_vars))
        # save inference program
        program = self.program.clone()
        for block in program.blocks:
            for op in block.ops:
                if "op_callstack" in op.all_attrs():
                    op._set_attr("op_callstack", [""])
        if not exe:
            place = fluid.CPUPlace()
            exe = fluid.Executor(place=place)
        utils.mkdir(self.helper.model_path())
        fluid.io.save_inference_model(
            self.helper.model_path(),
            feeded_var_names=list(feeded_var_names),
            target_vars=list(target_vars),
            main_program=program,
            executor=exe)
        with open(os.path.join(self.helper.model_path(), "__model__"),
                  "rb") as file:
            program_desc_str = file.read()
        rename_program = fluid.framework.Program.parse_from_string(
            program_desc_str)
        varlist = {
            var: block
            for block in rename_program.blocks for var in block.vars
            if self.get_name_prefix() not in var
        }
        for var, block in varlist.items():
            old_name = var
            new_name = self.get_var_name_with_prefix(old_name)
            block._rename_var(old_name, new_name)
        utils.mkdir(self.helper.model_path())
        with open(os.path.join(self.helper.model_path(), "__model__"),
                  "wb") as f:
            f.write(rename_program.desc.serialize_to_string())
        for file in os.listdir(self.helper.model_path()):
            if (file == "__model__" or self.get_name_prefix() in file):
                continue
            os.rename(
                os.path.join(self.helper.model_path(), file),
                os.path.join(self.helper.model_path(),
                             self.get_var_name_with_prefix(file)))
        # create processor file
        if self.processor:
            self._dump_processor()
        # create assets
        self._dump_assets()
        # create check info
        checker = ModuleChecker(self.helper.module_dir)
        checker.generate_check_info()
        # Serialize module_desc pb
        module_pb = self.desc.SerializeToString()
        with open(self.helper.module_desc_path(), "wb") as f:
            f.write(module_pb)
paddlehub/serving/bert_serving/bert_service.py
View file @ fa6f8d16
...
...
@@ -14,7 +14,6 @@
# limitations under the License.
import sys
import time
import paddlehub as hub
import ujson
import random
...
...
@@ -30,7 +29,7 @@ if is_py3:
    import http.client as httplib


class BertService():
class BertService(object):
    def __init__(self,
                 profile=False,
                 max_seq_len=128,
...
...
@@ -42,7 +41,7 @@ class BertService():
                 load_balance='round_robin'):
        self.process_id = process_id
        self.reader_flag = False
        self.batch_size = 16
        self.batch_size = 0
        self.max_seq_len = max_seq_len
        self.profile = profile
        self.model_name = model_name
...
...
@@ -55,15 +54,6 @@ class BertService():
        self.feed_var_names = ''
        self.retry = retry

    def connect(self, server='127.0.0.1:8010'):
        self.server_list.append(server)

    def connect_all_server(self, server_list):
        for server_str in server_list:
            self.server_list.append(server_str)

    def data_convert(self, text):
        if self.reader_flag == False:
            module = hub.Module(name=self.model_name)
            inputs, outputs, program = module.context(
                trainable=True, max_seq_len=self.max_seq_len)
...
...
@@ -79,10 +69,14 @@ class BertService():
                do_lower_case=self.do_lower_case)
            self.reader_flag = True
        return self.reader.data_generator(
            batch_size=self.batch_size, phase='predict', data=text)

    def add_server(self, server='127.0.0.1:8010'):
        self.server_list.append(server)

    def infer(self, request_msg):
    def add_server_list(self, server_list):
        for server_str in server_list:
            self.server_list.append(server_str)

    def request_server(self, request_msg):
        if self.load_balance == 'round_robin':
            try:
                cur_con = httplib.HTTPConnection(
...
...
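`request_server` dispatches each request to a backend chosen from `server_list` with round-robin load balancing. The selection policy on its own can be sketched like this (a standalone illustration, not the PaddleHub implementation):

```python
import itertools

def round_robin(servers):
    # Cycle through the configured servers endlessly: each request
    # goes to the next server in order, wrapping around at the end.
    pool = itertools.cycle(servers)
    while True:
        yield next(pool)

picker = round_robin(["127.0.0.1:8010", "127.0.0.1:8011"])
targets = [next(picker) for _ in range(3)]
# targets → ["127.0.0.1:8010", "127.0.0.1:8011", "127.0.0.1:8010"]
```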
@@ -157,17 +151,13 @@ class BertService():
                    self.server_list)
                return 'retry'

    def encode(self, text):
        if type(text) != list:
            raise TypeError('Only support list')
    def prepare_data(self, text):
        self.batch_size = len(text)
        data_generator = self.data_convert(text)
        start = time.time()
        request_time = 0
        result = []
        data_generator = self.reader.data_generator(
            batch_size=self.batch_size, phase='predict', data=text)
        request_msg = ""
        for run_step, batch in enumerate(data_generator(), start=1):
            request = []
            copy_start = time.time()
            token_list = batch[0][0].reshape(-1).tolist()
            pos_list = batch[0][1].reshape(-1).tolist()
            sent_list = batch[0][2].reshape(-1).tolist()
...
...
@@ -184,54 +174,34 @@ class BertService():
                    si + 1) * self.max_seq_len]
                request.append(instance_dict)
            copy_time = time.time() - copy_start
            request = {"instances": request}
            request["max_seq_len"] = self.max_seq_len
            request["feed_var_names"] = self.feed_var_names
            request_msg = ujson.dumps(request)
            if self.show_ids:
                logger.info(request_msg)
            request_start = time.time()
            response_msg = self.infer(request_msg)
        return request_msg

    def encode(self, text):
        if type(text) != list:
            raise TypeError('Only support list')
        request_msg = self.prepare_data(text)
        response_msg = self.request_server(request_msg)
        retry = 0
        while type(response_msg) == str and response_msg == 'retry':
            if retry < self.retry:
                retry += 1
                logger.info('Try to connect to another server')
                response_msg = self.infer(request_msg)
                response_msg = self.request_server(request_msg)
            else:
                logger.error('Infer failed after {} times retry'.format(
                logger.error('Request failed after {} times retry'.format(
                    self.retry))
                break
        result = []
        for msg in response_msg["instances"]:
            for sample in msg["instances"]:
                result.append(sample["values"])
        request_time += time.time() - request_start
        total_time = time.time() - start
        if self.profile:
            return [
                total_time, request_time, response_msg['op_time'],
                response_msg['infer_time'], copy_time
            ]
        else:
            return result
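The refactored `encode` retries a request that answers `'retry'` up to `self.retry` times before giving up. That bounded-retry pattern in isolation looks like this (an illustrative sketch with a stubbed send function, not the BertService code):

```python
def request_with_retry(send, request_msg, max_retry=3):
    # Keep re-sending while the server answers 'retry', up to
    # max_retry additional attempts; return the last response.
    response = send(request_msg)
    retries = 0
    while response == 'retry' and retries < max_retry:
        retries += 1
        response = send(request_msg)
    return response

# Stub server: answers 'retry' twice, then succeeds.
attempts = {"n": 0}
def fake_send(msg):
    attempts["n"] += 1
    return 'retry' if attempts["n"] < 3 else {"instances": []}

result = request_with_retry(fake_send, "{}")
# result → {"instances": []} after 3 attempts
```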
def connect(input_text,
            model_name,
            max_seq_len=128,
            show_ids=False,
            do_lower_case=True,
            server="127.0.0.1:8866",
            retry=3):
    # format of input_text like [["As long as"],]
    bc = BertService(
        max_seq_len=max_seq_len,
        model_name=model_name,
        show_ids=show_ids,
        do_lower_case=do_lower_case,
        retry=retry)
    bc.connect(server)
    result = bc.encode(input_text)
    return result
paddlehub/serving/bert_serving/bs_client.py
0 → 100644
View file @ fa6f8d16
from paddlehub.serving.bert_serving import bert_service


class BSClient(object):
    def __init__(self,
                 module_name,
                 server,
                 max_seq_len=20,
                 show_ids=False,
                 do_lower_case=True,
                 retry=3):
        self.bs = bert_service.BertService(
            model_name=module_name,
            max_seq_len=max_seq_len,
            show_ids=show_ids,
            do_lower_case=do_lower_case,
            retry=retry)
        self.bs.add_server(server=server)

    def get_result(self, input_text):
        return self.bs.encode(input_text)