PaddlePaddle / PaddleHub · commit fa6f8d16 (unverified)
Authored by Bin Long on Dec 23, 2019; committed via GitHub on Dec 23, 2019.

    Merge branch 'develop' into add_search_tip

Parents: 109d5293, 81dfda51

19 changed files with 704 additions and 529 deletions (+704 −529)
Changed files:

.travis.yml                                          +0   −10
demo/serving/bert_service/README.md                  +16  −12
demo/serving/bert_service/bert_service_client.py     +8   −5
paddlehub/__init__.py                                +1   −1
paddlehub/autofinetune/autoft.py                     +30  −5
paddlehub/autofinetune/mpi_helper.py                 +115 −0
paddlehub/commands/autofinetune.py                   +37  −12
paddlehub/commands/install.py                        +17  −7
paddlehub/commands/run.py                            +29  −26
paddlehub/commands/serving.py                        +4   −2
paddlehub/commands/show.py                           +0   −2
paddlehub/common/hub_server.py                       +13  −7
paddlehub/module/check_info.proto                    +3   −2
paddlehub/module/check_info_pb2.py                   +26  −11
paddlehub/module/checker.py                          +23  −13
paddlehub/module/manager.py                          +80  −8
paddlehub/module/module.py                           +232 −327
paddlehub/serving/bert_serving/bert_service.py       +49  −79
paddlehub/serving/bert_serving/bs_client.py          +21  −0
.travis.yml

```diff
@@ -16,12 +16,6 @@ jobs:
       os: linux
       python: 3.6
       script: /bin/bash ./scripts/check_code_style.sh
-    - name: "CI on Linux/Python3.5"
-      os: linux
-      python: 3.5
-    - name: "CI on Linux/Python2.7"
-      os: linux
-      python: 2.7
 env:
   - PYTHONPATH=${PWD}
@@ -30,10 +24,6 @@ install:
   - pip install --upgrade paddlepaddle
   - pip install -r requirements.txt
 script:
-  - if [[ $TRAVIS_OS_NAME != "windows" ]]; then /bin/bash ./scripts/test_cml.sh; fi
-  - if [[ $TRAVIS_OS_NAME != "windows" ]]; then /bin/bash ./scripts/test_all_module.sh; fi
 notifications:
   email:
     on_success: change
```
demo/serving/bert_service/README.md

````diff
@@ -68,7 +68,7 @@ $ pip install ujson
 |Model|Network|
 |:-|:-:|
-|[ERNIE](https://paddlepaddle.org.cn/hubdetail?name=ERNIE&en_category=SemanticModel)|ERNIE|
+|[ernie](https://paddlepaddle.org.cn/hubdetail?name=ERNIE&en_category=SemanticModel)|ERNIE|
 |[ernie_tiny](https://paddlepaddle.org.cn/hubdetail?name=ernie_tiny&en_category=SemanticModel)|ERNIE|
 |[ernie_v2_eng_large](https://paddlepaddle.org.cn/hubdetail?name=ernie_v2_eng_large&en_category=SemanticModel)|ERNIE|
 |[ernie_v2_eng_base](https://paddlepaddle.org.cn/hubdetail?name=ernie_v2_eng_base&en_category=SemanticModel)|ERNIE|
@@ -179,18 +179,22 @@ Server[baidu::paddle_serving::predictor::bert_service::BertServiceImpl] is serving
 First, import the client dependency.
 ```python
-from paddlehub.serving.bert_serving import bert_service
+from paddlehub.serving.bert_serving import bs_client
 ```
-Next, enter the text to be embedded.
+Next, start and initialize the `bert service` client `BSClient` (the server address here is a placeholder; set it to your actual IP).
+```python
+bc = bs_client.BSClient(module_name="ernie_tiny", server="127.0.0.1:8866")
+```
+Then enter the text to be embedded.
 ```python
 input_text = [["西风吹老洞庭波"], ["一夜湘君白发多"], ["醉后不知天在水"], ["满船清梦压星河"], ]
 ```
-Then send the text to the server through the client interface to obtain the embedding result (the server address is a placeholder; set it to your actual IP).
+Finally, send the text to the server through the client interface `get_result` to obtain the embedding result.
 ```python
-result = bert_service.connect(
-    input_text=input_text,
-    model_name="ernie_tiny",
-    server="127.0.0.1:8866")
+result = bc.get_result(input_text=input_text)
 ```
 The embedding result follows (only part of it is shown here).
 ```python
@@ -221,8 +225,8 @@ Paddle Inference Server exit successfully!
 > Q: How can several models be deployed on one server?
 > A: Start `Bert Service` several times on different ports (with GPUs, assign a different card to each instance). For example, to deploy `ernie` and `bert_chinese_L-12_H-768_A-12` at the same time, run:
 > ```shell
-> $ hub serving start bert_serving -m ernie -p 8866
-> $ hub serving start bert_serving -m bert_serving -m bert_chinese_L-12_H-768_A-12 -p 8867
+> $ hub serving start bert_service -m ernie -p 8866
+> $ hub serving start bert_service -m bert_chinese_L-12_H-768_A-12 -p 8867
 > ```
 > Q: On startup it prints "Check out http://yq01-gpu-255-129-12-00.epc.baidu.com:8887 in web
````
demo/serving/bert_service/bert_service_client.py

```diff
 # coding: utf8
-from paddlehub.serving.bert_serving import bert_service
+from paddlehub.serving.bert_serving import bs_client

 if __name__ == "__main__":
+    # Initialize the bert_service client BSClient
+    bc = bs_client.BSClient(module_name="ernie_tiny", server="127.0.0.1:8866")
+
     # Text to embed
     # Format: [["text_1"], ["text_2"], ]
     input_text = [
@@ -10,10 +13,10 @@ if __name__ == "__main__":
         ["醉后不知天在水"],
         ["满船清梦压星河"],
     ]
-    # Call the client interface bert_service.connect() to obtain the result
-    result = bert_service.connect(
-        input_text=input_text,
-        model_name="ernie_tiny",
-        server="127.0.0.1:8866")
-    # Print the embedding result
+    # Obtain the result via BSClient.get_result()
+    result = bc.get_result(input_text=input_text)
+    # Print the embedding result for the input text
     for item in result:
         print(item)
```
paddlehub/__init__.py

```diff
@@ -38,7 +38,7 @@ from .common.logger import logger
 from .common.paddle_helper import connect_program
 from .common.hub_server import default_hub_server

-from .module.module import Module, create_module
+from .module.module import Module
 from .module.base_processor import BaseProcessor
 from .module.signature import Signature, create_signature
 from .module.manager import default_module_manager
```
paddlehub/autofinetune/autoft.py

```diff
@@ -26,6 +26,7 @@ from tb_paddle import SummaryWriter
 from paddlehub.common.logger import logger
 from paddlehub.common.utils import mkdir
 from paddlehub.autofinetune.evaluator import REWARD_SUM, TMP_HOME
+from paddlehub.autofinetune.mpi_helper import MPIHelper

 if six.PY3:
     INF = math.inf
@@ -75,6 +76,12 @@ class BaseTuningStrategy(object):
                 logdir=self._output_dir + '/visualization/pop_{}'.format(i))
             self.writer_pop_trails.append(writer_pop_trail)

+        # for parallel on mpi
+        self.mpi = MPIHelper()
+        if self.mpi.multi_machine:
+            print("Autofinetune multimachine mode: running on {}".format(
+                self.mpi.gather(self.mpi.name)))
+
     @property
     def thread(self):
         return self._num_thread
@@ -177,16 +184,22 @@ class BaseTuningStrategy(object):
         solutions_modeldirs = {}
         mkdir(output_dir)

-        for idx, solution in enumerate(solutions):
+        solutions = self.mpi.bcast(solutions)
+
+        # split solutions to "solutions for me"
+        range_start, range_end = self.mpi.split_range(len(solutions))
+        my_solutions = solutions[range_start:range_end]
+
+        for idx, solution in enumerate(my_solutions):
             cuda = self.is_cuda_free["free"][0]
             modeldir = output_dir + "/model-" + str(idx) + "/"
             log_file = output_dir + "/log-" + str(idx) + ".info"
             params_cudas_dirs.append([solution, cuda, modeldir, log_file])
-            solutions_modeldirs[tuple(solution)] = modeldir
+            solutions_modeldirs[tuple(solution)] = (modeldir, self.mpi.rank)
             self.is_cuda_free["free"].remove(cuda)
             self.is_cuda_free["busy"].append(cuda)
-            if len(params_cudas_dirs) == self.thread or idx == len(solutions) - 1:
+            if len(params_cudas_dirs
+                   ) == self.thread or idx == len(my_solutions) - 1:
                 tp = ThreadPool(len(params_cudas_dirs))
                 solution_results += tp.map(self.evaluator.run,
                                            params_cudas_dirs)
@@ -198,13 +211,25 @@ class BaseTuningStrategy(object):
                 self.is_cuda_free["busy"].remove(param_cuda[1])
                 params_cudas_dirs = []

-        self.feedback(solutions, solution_results)
+        all_solution_results = self.mpi.gather(solution_results)
+        if self.mpi.rank == 0:
+            # only rank 0 need to feedback
+            all_solution_results = [y for x in all_solution_results for y in x]
+            self.feedback(solutions, all_solution_results)

         # remove the tmp.txt which records the eval results for trials
         tmp_file = os.path.join(TMP_HOME, "tmp.txt")
         if os.path.exists(tmp_file):
             os.remove(tmp_file)

-        return solutions_modeldirs
+        # collect all solutions_modeldirs
+        collected_solutions_modeldirs = self.mpi.allgather(solutions_modeldirs)
+        return_dict = {}
+        for i in collected_solutions_modeldirs:
+            return_dict.update(i)
+
+        return return_dict


 class HAZero(BaseTuningStrategy):
```
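The tuning loop now follows a bcast/split/gather pattern across ranks. A standalone sketch of that pattern using only `MPIHelper` from this commit (the squaring step is a hypothetical stand-in for `self.evaluator.run`):

```python
from paddlehub.autofinetune.mpi_helper import MPIHelper

mpi = MPIHelper()
solutions = mpi.bcast(list(range(8)))          # every rank adopts rank 0's solutions
start, end = mpi.split_range(len(solutions))   # contiguous slice owned by this rank
my_solutions = solutions[start:end]
local_results = [s * s for s in my_solutions]  # stand-in for evaluator.run
gathered = mpi.gather(local_results)           # rank 0 receives one list per rank
if mpi.rank == 0:
    flat = [y for x in gathered for y in x]    # flatten, as feedback() expects
    print(flat)
```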
paddlehub/autofinetune/mpi_helper.py (new file, mode 100755)

```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-


class MPIHelper(object):
    def __init__(self):
        try:
            from mpi4py import MPI
        except:
            # local run
            self._size = 1
            self._rank = 0
            self._multi_machine = False
            import socket
            self._name = socket.gethostname()
        else:
            # in mpi environment
            self._comm = MPI.COMM_WORLD
            self._size = self._comm.Get_size()
            self._rank = self._comm.Get_rank()
            self._name = MPI.Get_processor_name()
            if self._size > 1:
                self._multi_machine = True
            else:
                self._multi_machine = False

    @property
    def multi_machine(self):
        return self._multi_machine

    @property
    def rank(self):
        return self._rank

    @property
    def size(self):
        return self._size

    @property
    def name(self):
        return self._name

    def bcast(self, data):
        if self._multi_machine:
            # call real bcast
            return self._comm.bcast(data, root=0)
        else:
            # do nothing
            return data

    def gather(self, data):
        if self._multi_machine:
            # call real gather
            return self._comm.gather(data, root=0)
        else:
            # do nothing
            return [data]

    def allgather(self, data):
        if self._multi_machine:
            # call real allgather
            return self._comm.allgather(data)
        else:
            # do nothing
            return [data]

    # calculate split range on mpi environment
    def split_range(self, array_length):
        if self._size == 1:
            return 0, array_length
        average_count = array_length / self._size
        if array_length % self._size == 0:
            return average_count * self._rank, average_count * (self._rank + 1)
        else:
            if self._rank < array_length % self._size:
                return (average_count + 1) * self._rank, \
                       (average_count + 1) * (self._rank + 1)
            else:
                start = (average_count + 1) * (array_length % self._size) \
                        + average_count * (self._rank - array_length % self._size)
                return start, start + average_count


if __name__ == "__main__":
    mpi = MPIHelper()
    print("Hello world from process {} of {} at {}.".format(
        mpi.rank, mpi.size, mpi.name))

    all_node_names = mpi.gather(mpi.name)
    print("all node names using gather: {}".format(all_node_names))

    all_node_names = mpi.allgather(mpi.name)
    print("all node names using allgather: {}".format(all_node_names))

    if mpi.rank == 0:
        data = range(10)
    else:
        data = None
    data = mpi.bcast(data)
    print("after bcast, process {} have data {}".format(mpi.rank, data))

    data = [i + mpi.rank for i in data]
    print("after modify, process {} have data {}".format(mpi.rank, data))

    new_data = mpi.gather(data)
    print("after gather, process {} have data {}".format(mpi.rank, new_data))

    # test for split
    for i in range(12):
        length = i + mpi.size  # length should >= mpi.size
        [start, end] = mpi.split_range(length)
        split_result = mpi.gather([start, end])
        print("length {}, split_result {}".format(length, split_result))
```
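Without `mpi4py` installed, `MPIHelper` degrades to a single-process no-op, which makes the fallback path easy to check locally; a sketch:

```python
from paddlehub.autofinetune.mpi_helper import MPIHelper

helper = MPIHelper()
print(helper.multi_machine)    # False when mpi4py is unavailable
print(helper.split_range(10))  # (0, 10): one process owns the whole range
print(helper.bcast("x"))       # "x" (identity)
print(helper.gather("x"))      # ["x"]
```

To exercise the real collectives, run the demo under an MPI launcher, e.g. `mpirun -n 2 python mpi_helper.py` (assuming mpi4py and an MPI runtime are installed). One caveat: under Python 3, `array_length / self._size` is float division, so on a multi-process run `split_range` would return float bounds; the code reads as if written for Python 2 integer-division semantics.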
paddlehub/commands/autofinetune.py

```diff
@@ -188,37 +188,62 @@ class AutoFineTuneCommand(BaseCommand):
             run_round_cnt = run_round_cnt + 1
         print("PaddleHub Autofinetune ends.")
+        best_hparams_origin = autoft.get_best_hparams()
+        best_hparams_origin = autoft.mpi.bcast(best_hparams_origin)
         with open(autoft._output_dir + "/log_file.txt", "w") as f:
-            best_hparams = evaluator.convert_params(autoft.get_best_hparams())
+            best_hparams = evaluator.convert_params(best_hparams_origin)
             print("The final best hyperparameters:")
             f.write("The final best hyperparameters:\n")
             for index, hparam_name in enumerate(autoft.hparams_name_list):
                 print("%s=%s" % (hparam_name, best_hparams[index]))
                 f.write(hparam_name + "\t:\t" + str(best_hparams[index]) + "\n")
+            best_hparams_dir, best_hparams_rank = solutions_modeldirs[tuple(
+                best_hparams_origin)]
             print("The final best eval score is %s." %
                   autoft.get_best_eval_value())
-            print("The final best model parameters are saved as " +
-                  autoft._output_dir + "/best_model .")
+            if autoft.mpi.multi_machine:
+                print("The final best model parameters are saved as " +
+                      autoft._output_dir + "/best_model on rank " +
+                      str(best_hparams_rank) + " .")
+            else:
+                print("The final best model parameters are saved as " +
+                      autoft._output_dir + "/best_model .")
             f.write("The final best eval score is %s.\n" %
                     autoft.get_best_eval_value())
-            f.write("The final best model parameters are saved as ./best_model .")
             best_model_dir = autoft._output_dir + "/best_model"
-            shutil.copytree(
-                solutions_modeldirs[tuple(autoft.get_best_hparams())],
-                best_model_dir)
+            if autoft.mpi.rank == best_hparams_rank:
+                shutil.copytree(best_hparams_dir, best_model_dir)
+            if autoft.mpi.multi_machine:
+                f.write("The final best model parameters are saved as ./best_model on rank " \
+                    + str(best_hparams_rank) + " .")
+                f.write("\t".join(autoft.hparams_name_list) +
+                        "\tsaved_params_dir\trank\n")
+            else:
+                f.write("The final best model parameters are saved as ./best_model .")
+                f.write("\t".join(autoft.hparams_name_list) +
+                        "\tsaved_params_dir\n")
             print(
                 "The related infomation about hyperparamemters searched are saved as %s/log_file.txt ."
                 % autoft._output_dir)
-            f.write("\t".join(autoft.hparams_name_list) + "\tsaved_params_dir\n")
             for solution, modeldir in solutions_modeldirs.items():
                 param = evaluator.convert_params(solution)
                 param = [str(p) for p in param]
-                f.write("\t".join(param) + "\t" + modeldir + "\n")
+                if autoft.mpi.multi_machine:
+                    f.write("\t".join(param) + "\t" + modeldir[0] + "\t" +
+                            str(modeldir[1]) + "\n")
+                else:
+                    f.write("\t".join(param) + "\t" + modeldir[0] + "\n")
         return True
```
paddlehub/commands/install.py

```diff
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function

 import argparse
+import os

 from paddlehub.common import utils
 from paddlehub.module.manager import default_module_manager
@@ -42,14 +43,23 @@ class InstallCommand(BaseCommand):
             print("ERROR: Please specify a module name.\n")
             self.help()
             return False
+        extra = {"command": "install"}
+        if argv[0].endswith("tar.gz") or argv[0].endswith("phm"):
+            result, tips, module_dir = default_module_manager.install_module(
+                module_package=argv[0], extra=extra)
+        elif os.path.exists(argv[0]) and os.path.isdir(argv[0]):
+            result, tips, module_dir = default_module_manager.install_module(
+                module_dir=argv[0], extra=extra)
+        else:
-        module_name = argv[0]
-        module_version = None if "==" not in module_name else \
-            module_name.split("==")[1]
-        module_name = module_name if "==" not in module_name else \
-            module_name.split("==")[0]
-        extra = {"command": "install"}
-        result, tips, module_dir = default_module_manager.install_module(
-            module_name=module_name, module_version=module_version, extra=extra)
+            module_name = argv[0]
+            module_version = None if "==" not in module_name else \
+                module_name.split("==")[1]
+            module_name = module_name if "==" not in module_name else \
+                module_name.split("==")[0]
+            result, tips, module_dir = default_module_manager.install_module(
+                module_name=module_name,
+                module_version=module_version,
+                extra=extra)
         print(tips)
         return True
```
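With this change, `hub install` dispatches on the shape of its argument; a sketch of the resulting invocation styles (the module name `lac` and the paths are illustrative, not taken from the commit):

```shell
$ hub install lac            # by name, resolved through the hub server
$ hub install lac==1.0.0     # by name, pinned to a version
$ hub install ./lac.phm      # from a local package (tar.gz or phm suffix)
$ hub install ./lac/         # from a local module directory
```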
paddlehub/commands/run.py

```diff
@@ -71,7 +71,7 @@ class RunCommand(BaseCommand):
         if not result:
             return None

-        return hub.Module(module_dir=module_dir)
+        return hub.Module(directory=module_dir[0])

     def add_module_config_arg(self):
         configs = self.module.processor.configs()
@@ -105,7 +105,7 @@ class RunCommand(BaseCommand):
     def add_module_input_arg(self):
         module_type = self.module.type.lower()
         expect_data_format = self.module.processor.data_format(
-            self.module.default_signature.name)
+            self.module.default_signature)
         self.arg_input_group.add_argument(
             '--input_file',
             type=str,
@@ -152,7 +152,7 @@ class RunCommand(BaseCommand):
     def get_data(self):
         module_type = self.module.type.lower()
         expect_data_format = self.module.processor.data_format(
-            self.module.default_signature.name)
+            self.module.default_signature)
         input_data = {}
         if len(expect_data_format) == 1:
             key = list(expect_data_format.keys())[0]
@@ -177,7 +177,7 @@ class RunCommand(BaseCommand):
     def check_data(self, data):
         expect_data_format = self.module.processor.data_format(
-            self.module.default_signature.name)
+            self.module.default_signature)
         if len(data.keys()) != len(expect_data_format.keys()):
             print(
@@ -236,10 +236,13 @@ class RunCommand(BaseCommand):
             return False

         # If the module is not executable, give an alarm and exit
-        if not self.module.default_signature:
+        if not self.module.is_runable:
             print("ERROR! Module %s is not executable." % module_name)
             return False

+        if self.module.code_version == "v2":
+            results = self.module(argv[1:])
+        else:
             self.module.check_processor()
             self.add_module_config_arg()
             self.add_module_input_arg()
@@ -260,7 +263,7 @@ class RunCommand(BaseCommand):
                 return False

             results = self.module(
-                sign_name=self.module.default_signature.name,
+                sign_name=self.module.default_signature,
                 data=data,
                 use_gpu=self.args.use_gpu,
                 batch_size=self.args.batch_size,
```
paddlehub/commands/serving.py

```diff
@@ -159,7 +159,7 @@ class ServingCommand(BaseCommand):
         module = args.modules
         if module is not None:
             use_gpu = args.use_gpu
-            port = args.port[0]
+            port = args.port
             if ServingCommand.is_port_occupied("127.0.0.1", port) is True:
                 print("Port %s is occupied, please change it." % (port))
                 return False
@@ -206,10 +206,12 @@ class ServingCommand(BaseCommand):
         if args.sub_command == "start":
             if args.bert_service == "bert_service":
                 ServingCommand.start_bert_serving(args)
-            else:
+            elif args.bert_service is None:
                 ServingCommand.start_serving(args)
+            else:
+                ServingCommand.show_help()
         else:
             ServingCommand.show_help()


 command = ServingCommand.instance()
```
paddlehub/commands/show.py

```diff
@@ -125,8 +125,6 @@ class ShowCommand(BaseCommand):
         cwd = os.getcwd()
         module_dir = default_module_manager.search_module(module_name)
-        module_dir = (os.path.join(cwd, module_name),
-                      None) if not module_dir else module_dir
         if not module_dir or not os.path.exists(module_dir[0]):
             print("%s is not existed!" % module_name)
             return True
```
paddlehub/common/hub_server.py

```diff
@@ -298,17 +298,23 @@ class CacheUpdater(threading.Thread):
         api_url = srv_utils.uri_path(default_hub_server.get_server_url(),
                                      'search')
         cache_path = os.path.join(CACHE_HOME, RESOURCE_LIST_FILE)
-        extra = {
-            "command": "update_cache",
-            "mtime": os.stat(cache_path).st_mtime
-        }
+        if os.path.exists(cache_path):
+            extra = {
+                "command": "update_cache",
+                "mtime": os.stat(cache_path).st_mtime
+            }
+        else:
+            extra = {
+                "command": "update_cache",
+                "mtime": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+            }
         try:
             r = srv_utils.hub_request(api_url, payload, extra)
+            if r.get("update_cache", 0) == 1:
+                with open(cache_path, 'w+') as fp:
+                    yaml.safe_dump({'resource_list': r['data']}, fp)
         except Exception as err:
             pass
-        if r.get("update_cache", 0) == 1:
-            with open(cache_path, 'w+') as fp:
-                yaml.safe_dump({'resource_list': r['data']}, fp)

     def run(self):
         self.update_resource_list_file(self.module, self.version)
```
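Beyond adding the mtime probe for the cached resource list, moving the cache write inside the `try` closes a gap in the old flow: if `hub_request` raised, the bare `except ... pass` swallowed the error and the following `r.get(...)` dereferenced an unbound `r`. With the new shape, any network or parse failure simply skips the cache update.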
paddlehub/module/check_info.proto

```diff
@@ -50,6 +50,7 @@ message CheckInfo {
   string paddle_version = 1;
   string hub_version = 2;
   string module_proto_version = 3;
-  repeated FileInfo file_infos = 4;
-  repeated Requires requires = 5;
+  string module_code_version = 4;
+  repeated FileInfo file_infos = 5;
+  repeated Requires requires = 6;
 };
```
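Note that `file_infos` and `requires` change tag numbers (4→5, 5→6); renumbering existing fields is a wire-compatibility break, so check-info files serialized by older releases will not parse meaningfully under the new schema. The new field itself round-trips as expected; a minimal sketch with the regenerated bindings (shown next):

```python
from paddlehub.module import check_info_pb2

info = check_info_pb2.CheckInfo()
info.module_code_version = "v2"  # the field added by this commit
payload = info.SerializeToString()

parsed = check_info_pb2.CheckInfo()
parsed.ParseFromString(payload)
print(parsed.module_code_version)  # "v2"
```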
paddlehub/module/check_info_pb2.py

```diff
-#coding:utf-8
 # Generated by the protocol buffer compiler.  DO NOT EDIT!
 # source: check_info.proto
@@ -19,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
     name='check_info.proto',
     package='paddlehub.module.checkinfo',
     syntax='proto3',
-    serialized_pb=_b('\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xc8\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x38\n\nfile_infos\x18\x04 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x05 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3'))
+    serialized_pb=_b('\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xe5\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x1b\n\x13module_code_version\x18\x04 \x01(\t\x12\x38\n\nfile_infos\x18\x05 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x06 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3'))
 _sym_db.RegisterFileDescriptor(DESCRIPTOR)
@@ -36,8 +35,8 @@ _FILE_TYPE = _descriptor.EnumDescriptor(
     ],
     containing_type=None,
     options=None,
-    serialized_start=522,
-    serialized_end=552,
+    serialized_start=551,
+    serialized_end=581,
 )
 _sym_db.RegisterEnumDescriptor(_FILE_TYPE)
@@ -61,8 +60,8 @@ _REQUIRE_TYPE = _descriptor.EnumDescriptor(
     ],
     containing_type=None,
     options=None,
-    serialized_start=554,
-    serialized_end=645,
+    serialized_start=583,
+    serialized_end=674,
 )
 _sym_db.RegisterEnumDescriptor(_REQUIRE_TYPE)
@@ -316,10 +315,26 @@ _CHECKINFO = _descriptor.Descriptor(
             extension_scope=None,
             options=None),
         _descriptor.FieldDescriptor(
-            name='file_infos',
-            full_name='paddlehub.module.checkinfo.CheckInfo.file_infos',
+            name='module_code_version',
+            full_name='paddlehub.module.checkinfo.CheckInfo.module_code_version',
             index=3,
             number=4,
-            type=11,
-            cpp_type=10,
-            label=3,
+            type=9,
+            cpp_type=9,
+            label=1,
+            has_default_value=False,
+            default_value=_b("").decode('utf-8'),
+            message_type=None,
+            enum_type=None,
+            containing_type=None,
+            is_extension=False,
+            extension_scope=None,
+            options=None),
+        _descriptor.FieldDescriptor(
+            name='file_infos',
+            full_name='paddlehub.module.checkinfo.CheckInfo.file_infos',
+            index=4,
+            number=5,
+            type=11,
+            cpp_type=10,
+            label=3,
@@ -334,8 +349,8 @@ _CHECKINFO = _descriptor.Descriptor(
         _descriptor.FieldDescriptor(
             name='requires',
             full_name='paddlehub.module.checkinfo.CheckInfo.requires',
-            index=4,
-            number=5,
+            index=5,
+            number=6,
             type=11,
             cpp_type=10,
             label=3,
@@ -357,7 +372,7 @@ _CHECKINFO = _descriptor.Descriptor(
     extension_ranges=[],
     oneofs=[],
     serialized_start=320,
-    serialized_end=520,
+    serialized_end=549,
 )
 _FILEINFO.fields_by_name['type'].enum_type = _FILE_TYPE
```
paddlehub/module/checker.py

```diff
@@ -32,20 +32,22 @@ FILE_SEP = "/"

 class ModuleChecker(object):
-    def __init__(self, module_path):
-        self.module_path = module_path
+    def __init__(self, directory):
+        self._directory = directory
+        self._pb_path = os.path.join(self.directory, CHECK_INFO_PB_FILENAME)

     def generate_check_info(self):
         check_info = check_info_pb2.CheckInfo()
         check_info.paddle_version = paddle.__version__
         check_info.hub_version = hub_version
         check_info.module_proto_version = module_proto_version
+        check_info.module_code_version = "v2"
         file_infos = check_info.file_infos
-        file_list = [file for file in os.listdir(self.module_path)]
+        file_list = [file for file in os.listdir(self.directory)]
         while file_list:
             file = file_list[0]
             file_list = file_list[1:]
-            abs_path = os.path.join(self.module_path, file)
+            abs_path = os.path.join(self.directory, file)
             if os.path.isdir(abs_path):
                 for sub_file in os.listdir(abs_path):
                     sub_file = os.path.join(file, sub_file)
@@ -62,9 +64,12 @@ class ModuleChecker(object):
                 file_info.type = check_info_pb2.FILE
                 file_info.is_need = True

-        with open(os.path.join(self.module_path, CHECK_INFO_PB_FILENAME),
-                  "wb") as fi:
-            fi.write(check_info.SerializeToString())
+        with open(self.pb_path, "wb") as file:
+            file.write(check_info.SerializeToString())

+    @property
+    def module_code_version(self):
+        return self.check_info.module_code_version
+
     @property
     def module_proto_version(self):
@@ -82,20 +87,25 @@ class ModuleChecker(object):
     def file_infos(self):
         return self.check_info.file_infos

+    @property
+    def directory(self):
+        return self._directory
+
+    @property
+    def pb_path(self):
+        return self._pb_path
+
     def check(self):
         result = True
-        self.check_info_pb_path = os.path.join(self.module_path,
-                                               CHECK_INFO_PB_FILENAME)
-
-        if not (os.path.exists(self.check_info_pb_path)
-                or os.path.isfile(self.check_info_pb_path)):
+        if not (os.path.exists(self.pb_path)
+                or os.path.isfile(self.pb_path)):
             logger.warning(
                 "This module lacks core file %s" % CHECK_INFO_PB_FILENAME)
             result = False

         self.check_info = check_info_pb2.CheckInfo()
         try:
-            with open(self.check_info_pb_path, "rb") as fi:
+            with open(self.pb_path, "rb") as fi:
                 pb_string = fi.read()
                 result = self.check_info.ParseFromString(pb_string)
                 if len(pb_string) == 0 or (result is not None
@@ -182,7 +192,7 @@ class ModuleChecker(object):
         for file_info in self.file_infos:
             file_type = file_info.type
             file_path = file_info.file_name.replace(FILE_SEP, os.sep)
-            file_path = os.path.join(self.module_path, file_path)
+            file_path = os.path.join(self.directory, file_path)
             if not os.path.exists(file_path):
                 if file_info.is_need:
                     logger.warning(
```
paddlehub/module/manager.py

```diff
@@ -19,7 +19,9 @@ from __future__ import print_function

 import os
 import shutil
 from functools import cmp_to_key
+import tarfile

 from paddlehub.common import utils
 from paddlehub.common import srv_utils
@@ -79,10 +81,15 @@ class LocalModuleManager(object):
         return self.modules_dict.get(module_name, None)

     def install_module(self,
-                       module_name,
+                       module_name=None,
+                       module_dir=None,
+                       module_package=None,
                        module_version=None,
                        upgrade=False,
                        extra=None):
         md5_value = installed_module_version = None
-        self.all_modules(update=True)
-        module_info = self.modules_dict.get(module_name, None)
-        if module_info:
+        from_user_dir = True if module_dir else False
+        if module_name:
+            self.all_modules(update=True)
+            module_info = self.modules_dict.get(module_name, None)
+            if module_info:
@@ -95,6 +102,62 @@ class LocalModuleManager(object):
                     module_dir)
                 return True, tips, self.modules_dict[module_name]

+            search_result = hub.default_hub_server.get_module_url(
+                module_name, version=module_version, extra=extra)
+            name = search_result.get('name', None)
+            url = search_result.get('url', None)
+            md5_value = search_result.get('md5', None)
+            installed_module_version = search_result.get('version', None)
+            if not url or (module_version is not None and
+                           installed_module_version != module_version) or (
+                               name != module_name):
+                if default_hub_server._server_check() is False:
+                    tips = "Request Hub-Server unsuccessfully, please check your network."
+                else:
+                    tips = "Can't find module %s" % module_name
+                    if module_version:
+                        tips += " with version %s" % module_version
+                    module_tag = module_name if not module_version else \
+                        '%s-%s' % (module_name, module_version)
+                return False, tips, None
+
+            result, tips, module_zip_file = default_downloader.download_file(
+                url=url,
+                save_path=hub.CACHE_HOME,
+                save_name=module_name,
+                replace=True,
+                print_progress=True)
+            result, tips, module_dir = default_downloader.uncompress(
+                file=module_zip_file,
+                dirname=MODULE_HOME,
+                delete_file=True,
+                print_progress=True)
+
+        if module_package:
+            with tarfile.open(module_package, "r:gz") as tar:
+                file_names = tar.getnames()
+                size = len(file_names) - 1
+                module_dir = os.path.split(file_names[0])[0]
+                module_dir = os.path.join(hub.CACHE_HOME, module_dir)
+                # remove cache
+                if os.path.exists(module_dir):
+                    shutil.rmtree(module_dir)
+                for index, file_name in enumerate(file_names):
+                    tar.extract(file_name, hub.CACHE_HOME)
+
+        if module_dir:
+            if not module_name:
+                module_name = hub.Module(directory=module_dir).name
+            self.all_modules(update=False)
+            module_info = self.modules_dict.get(module_name, None)
+            if module_info:
+                module_dir = self.modules_dict[module_name][0]
+                module_tag = module_name if not module_version else \
+                    '%s-%s' % (module_name, module_version)
+                tips = "Module %s already installed in %s" % (module_tag,
+                                                              module_dir)
+                return True, tips, self.modules_dict[module_name]
+
         search_result = hub.default_hub_server.get_module_url(
             module_name, version=module_version, extra=extra)
         name = search_result.get('name', None)
@@ -162,9 +225,18 @@ class LocalModuleManager(object):
-        with open(os.path.join(MODULE_HOME, module_dir, "md5.txt"), "w") as fp:
-            fp.write(md5_value)
+        if md5_value:
+            with open(os.path.join(MODULE_HOME, module_dir, "md5.txt"),
+                      "w") as fp:
+                fp.write(md5_value)

         save_path = os.path.join(MODULE_HOME, module_name)
         if os.path.exists(save_path):
             shutil.rmtree(save_path)
-        shutil.move(module_dir, save_path)
+        if from_user_dir:
+            shutil.copytree(module_dir, save_path)
+        else:
+            shutil.move(module_dir, save_path)
+        module_dir = save_path

         tips = "Successfully installed %s" % module_name
```
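For the Python side of the same feature, a sketch of the three call forms the widened `install_module` signature accepts, mirroring how `paddlehub/commands/install.py` drives it (paths and the module name are illustrative):

```python
from paddlehub.module.manager import default_module_manager

# By name, optionally pinned to a version, resolved through the hub server.
ok, tips, module_info = default_module_manager.install_module(
    module_name="lac", module_version=None, extra={"command": "install"})
# From a local tar.gz / phm package.
ok, tips, module_info = default_module_manager.install_module(
    module_package="./lac.phm")
# From a local module directory; from_user_dir makes this a copy, not a move.
ok, tips, module_info = default_module_manager.install_module(
    module_dir="./lac/")
print(tips)
```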
paddlehub/module/module.py

This diff is collapsed in the source view (+232 −327) and its contents are not shown.
paddlehub/serving/bert_serving/bert_service.py

```diff
@@ -14,7 +14,6 @@
 # limitations under the License.
 import sys
 import time
 import paddlehub as hub
 import ujson
 import random
-
@@ -30,7 +29,7 @@ if is_py3:
     import http.client as httplib


-class BertService():
+class BertService(object):
     def __init__(self,
                  profile=False,
                  max_seq_len=128,
@@ -42,7 +41,7 @@ class BertService():
                  load_balance='round_robin'):
         self.process_id = process_id
         self.reader_flag = False
-        self.batch_size = 16
+        self.batch_size = 0
         self.max_seq_len = max_seq_len
         self.profile = profile
         self.model_name = model_name
@@ -55,15 +54,6 @@ class BertService():
         self.feed_var_names = ''
         self.retry = retry

-    def connect(self, server='127.0.0.1:8010'):
-        self.server_list.append(server)
-
-    def connect_all_server(self, server_list):
-        for server_str in server_list:
-            self.server_list.append(server_str)
-
     def data_convert(self, text):
         if self.reader_flag == False:
             module = hub.Module(name=self.model_name)
             inputs, outputs, program = module.context(
                 trainable=True, max_seq_len=self.max_seq_len)
@@ -79,10 +69,14 @@ class BertService():
                 do_lower_case=self.do_lower_case)
             self.reader_flag = True

         return self.reader.data_generator(
             batch_size=self.batch_size, phase='predict', data=text)

-    def infer(self, request_msg):
+    def add_server(self, server='127.0.0.1:8010'):
+        self.server_list.append(server)
+
+    def add_server_list(self, server_list):
+        for server_str in server_list:
+            self.server_list.append(server_str)
+
+    def request_server(self, request_msg):
         if self.load_balance == 'round_robin':
             try:
                 cur_con = httplib.HTTPConnection(
@@ -157,17 +151,13 @@ class BertService():
                     self.server_list)
                 return 'retry'

-    def encode(self, text):
-        if type(text) != list:
-            raise TypeError('Only support list')
+    def prepare_data(self, text):
         self.batch_size = len(text)
-        data_generator = self.data_convert(text)
-        start = time.time()
-        request_time = 0
-        result = []
+        data_generator = self.reader.data_generator(
+            batch_size=self.batch_size, phase='predict', data=text)
         request_msg = ""
         for run_step, batch in enumerate(data_generator(), start=1):
             request = []
             copy_start = time.time()
             token_list = batch[0][0].reshape(-1).tolist()
             pos_list = batch[0][1].reshape(-1).tolist()
             sent_list = batch[0][2].reshape(-1).tolist()
@@ -184,54 +174,34 @@ class BertService():
                                   si + 1) * self.max_seq_len]
                 request.append(instance_dict)
             copy_time = time.time() - copy_start
             request = {"instances": request}
             request["max_seq_len"] = self.max_seq_len
             request["feed_var_names"] = self.feed_var_names
             request_msg = ujson.dumps(request)
             if self.show_ids:
                 logger.info(request_msg)
-            request_start = time.time()
-            response_msg = self.infer(request_msg)
+        return request_msg
+
+    def encode(self, text):
+        if type(text) != list:
+            raise TypeError('Only support list')
+        request_msg = self.prepare_data(text)
+        response_msg = self.request_server(request_msg)
         retry = 0
         while type(response_msg) == str and response_msg == 'retry':
             if retry < self.retry:
                 retry += 1
                 logger.info('Try to connect another servers')
-                response_msg = self.infer(request_msg)
+                response_msg = self.request_server(request_msg)
             else:
-                logger.error('Infer failed after {} times retry'.format(
+                logger.error('Request failed after {} times retry'.format(
                     self.retry))
                 break
+
+        result = []
         for msg in response_msg["instances"]:
             for sample in msg["instances"]:
                 result.append(sample["values"])
-            request_time += time.time() - request_start
-        total_time = time.time() - start
-        if self.profile:
-            return [
-                total_time, request_time, response_msg['op_time'],
-                response_msg['infer_time'], copy_time
-            ]
-        else:
-            return result
-
-
-def connect(input_text,
-            model_name,
-            max_seq_len=128,
-            show_ids=False,
-            do_lower_case=True,
-            server="127.0.0.1:8866",
-            retry=3):
-    # format of input_text like [["As long as"],]
-    bc = BertService(
-        max_seq_len=max_seq_len,
-        model_name=model_name,
-        show_ids=show_ids,
-        do_lower_case=do_lower_case,
-        retry=retry)
-    bc.connect(server)
-    result = bc.encode(input_text)
-    return result
+        return result
```
paddlehub/serving/bert_serving/bs_client.py (new file, mode 100644)

```python
from paddlehub.serving.bert_serving import bert_service


class BSClient(object):
    def __init__(self,
                 module_name,
                 server,
                 max_seq_len=20,
                 show_ids=False,
                 do_lower_case=True,
                 retry=3):
        self.bs = bert_service.BertService(
            model_name=module_name,
            max_seq_len=max_seq_len,
            show_ids=show_ids,
            do_lower_case=do_lower_case,
            retry=retry)
        self.bs.add_server(server=server)

    def get_result(self, input_text):
        return self.bs.encode(input_text)
```
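A minimal usage sketch of the new client, mirroring `demo/serving/bert_service/bert_service_client.py` (the server address is a placeholder for a running Bert Service instance):

```python
from paddlehub.serving.bert_serving import bs_client

bc = bs_client.BSClient(module_name="ernie_tiny", server="127.0.0.1:8866")
result = bc.get_result(input_text=[["西风吹老洞庭波"]])
for item in result:
    print(item)  # one embedding vector per input sentence
```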