PaddlePaddle / PaddleHub

Commit fa6f8d16 (unverified)
Authored Dec 23, 2019 by Bin Long; committed via GitHub on Dec 23, 2019

Merge branch 'develop' into add_search_tip

Parents: 109d5293, 81dfda51

Showing 19 changed files with 704 additions and 529 deletions (+704 -529)
.travis.yml (+0 -10)
demo/serving/bert_service/README.md (+16 -12)
demo/serving/bert_service/bert_service_client.py (+8 -5)
paddlehub/__init__.py (+1 -1)
paddlehub/autofinetune/autoft.py (+30 -5)
paddlehub/autofinetune/mpi_helper.py (+115 -0)
paddlehub/commands/autofinetune.py (+37 -12)
paddlehub/commands/install.py (+17 -7)
paddlehub/commands/run.py (+29 -26)
paddlehub/commands/serving.py (+4 -2)
paddlehub/commands/show.py (+0 -2)
paddlehub/common/hub_server.py (+13 -7)
paddlehub/module/check_info.proto (+3 -2)
paddlehub/module/check_info_pb2.py (+26 -11)
paddlehub/module/checker.py (+23 -13)
paddlehub/module/manager.py (+80 -8)
paddlehub/module/module.py (+232 -327)
paddlehub/serving/bert_serving/bert_service.py (+49 -79)
paddlehub/serving/bert_serving/bs_client.py (+21 -0)
.travis.yml

```diff
@@ -16,12 +16,6 @@ jobs:
       os: linux
       python: 3.6
       script: /bin/bash ./scripts/check_code_style.sh
-    - name: "CI on Linux/Python3.5"
-      os: linux
-      python: 3.5
-    - name: "CI on Linux/Python2.7"
-      os: linux
-      python: 2.7
       env:
         - PYTHONPATH=${PWD}
@@ -30,10 +24,6 @@ install:
   - pip install --upgrade paddlepaddle
   - pip install -r requirements.txt
-script:
-  - if [[ $TRAVIS_OS_NAME != "windows" ]]; then /bin/bash ./scripts/test_cml.sh; fi
-  - if [[ $TRAVIS_OS_NAME != "windows" ]]; then /bin/bash ./scripts/test_all_module.sh; fi
 notifications:
   email:
     on_success: change
```
demo/serving/bert_service/README.md

````diff
@@ -68,7 +68,7 @@ $ pip install ujson
 |Model|Network|
 |:-|:-:|
-|[ERNIE](https://paddlepaddle.org.cn/hubdetail?name=ERNIE&en_category=SemanticModel)|ERNIE|
+|[ernie](https://paddlepaddle.org.cn/hubdetail?name=ERNIE&en_category=SemanticModel)|ERNIE|
 |[ernie_tiny](https://paddlepaddle.org.cn/hubdetail?name=ernie_tiny&en_category=SemanticModel)|ERNIE|
 |[ernie_v2_eng_large](https://paddlepaddle.org.cn/hubdetail?name=ernie_v2_eng_large&en_category=SemanticModel)|ERNIE|
 |[ernie_v2_eng_base](https://paddlepaddle.org.cn/hubdetail?name=ernie_v2_eng_base&en_category=SemanticModel)|ERNIE|
@@ -179,18 +179,22 @@ Server[baidu::paddle_serving::predictor::bert_service::BertServiceImpl] is servi
 First, import the client dependency.
 ```python
-from paddlehub.serving.bert_serving import bert_service
+from paddlehub.serving.bert_serving import bs_client
 ```
-Next, enter the text to process.
+Next, start and initialize the `bert service` client `BSClient` (the server address here is a placeholder; set it to your actual IP).
+```python
+bc = bs_client.BSClient(module_name="ernie_tiny", server="127.0.0.1:8866")
+```
+Then enter the text to process.
 ```python
 input_text = [["西风吹老洞庭波"], ["一夜湘君白发多"], ["醉后不知天在水"], ["满船清梦压星河"], ]
 ```
-Then use the client interface to send the text to the server and obtain the embedding result (the server address is a placeholder; set it to your actual IP).
+Finally, use the client interface `get_result` to send the text to the server and obtain the embedding result.
 ```python
-result = bert_service.connect(
-    input_text=input_text,
-    model_name="ernie_tiny",
-    server="127.0.0.1:8866")
+result = bc.get_result(input_text=input_text)
 ```
 The embedding result is then returned (only part of it is shown here).
 ```python
@@ -221,8 +225,8 @@ Paddle Inference Server exit successfully!
 > Q : How can multiple models be deployed on one server?
 > A : Start `Bert Service` several times on different ports. If GPUs are used, assign a different card to each. For example, to deploy both `ernie` and `bert_chinese_L-12_H-768_A-12`, run:
 > ```shell
-> $ hub serving start bert_serving -m ernie -p 8866
-> $ hub serving start bert_serving -m bert_chinese_L-12_H-768_A-12 -p 8867
+> $ hub serving start bert_service -m ernie -p 8866
+> $ hub serving start bert_service -m bert_chinese_L-12_H-768_A-12 -p 8867
 > ```
 > Q : On startup it prints "Check out http://yq01-gpu-255-129-12-00.epc.baidu.com:8887 in web
````
demo/serving/bert_service/bert_service_client.py

```diff
 # coding: utf8
-from paddlehub.serving.bert_serving import bert_service
+from paddlehub.serving.bert_serving import bs_client

 if __name__ == "__main__":
+    # Initialize the bert_service client BSClient
+    bc = bs_client.BSClient(module_name="ernie_tiny", server="127.0.0.1:8866")
+
     # Text to embed
     # Text format: [["text1"], ["text2"], ]
     input_text = [
@@ -10,10 +13,10 @@ if __name__ == "__main__":
         ["醉后不知天在水"],
         ["满船清梦压星河"],
     ]
-    # Call the client interface bert_service.connect() to get the result
-    result = bert_service.connect(
-        input_text=input_text,
-        model_name="ernie_tiny",
-        server="127.0.0.1:8866")
-    # Print the embedding result
+    # Get the result via BSClient.get_result()
+    result = bc.get_result(input_text=input_text)
+    # Print the embedding result for the input text
     for item in result:
         print(item)
```
paddlehub/__init__.py

```diff
@@ -38,7 +38,7 @@ from .common.logger import logger
 from .common.paddle_helper import connect_program
 from .common.hub_server import default_hub_server
-from .module.module import Module, create_module
+from .module.module import Module
 from .module.base_processor import BaseProcessor
 from .module.signature import Signature, create_signature
 from .module.manager import default_module_manager
```
paddlehub/autofinetune/autoft.py

```diff
@@ -26,6 +26,7 @@ from tb_paddle import SummaryWriter
 from paddlehub.common.logger import logger
 from paddlehub.common.utils import mkdir
 from paddlehub.autofinetune.evaluator import REWARD_SUM, TMP_HOME
+from paddlehub.autofinetune.mpi_helper import MPIHelper

 if six.PY3:
     INF = math.inf
@@ -75,6 +76,12 @@ class BaseTuningStrategy(object):
                 logdir=self._output_dir + '/visualization/pop_{}'.format(i))
             self.writer_pop_trails.append(writer_pop_trail)

+        # for parallel on mpi
+        self.mpi = MPIHelper()
+        if self.mpi.multi_machine:
+            print("Autofinetune multimachine mode: running on {}".format(
+                self.mpi.gather(self.mpi.name)))
+
     @property
     def thread(self):
         return self._num_thread
@@ -177,16 +184,22 @@ class BaseTuningStrategy(object):
         solutions_modeldirs = {}
         mkdir(output_dir)
-        for idx, solution in enumerate(solutions):
+
+        solutions = self.mpi.bcast(solutions)
+        # split solutions to "solutions for me"
+        range_start, range_end = self.mpi.split_range(len(solutions))
+        my_solutions = solutions[range_start:range_end]
+
+        for idx, solution in enumerate(my_solutions):
             cuda = self.is_cuda_free["free"][0]
             modeldir = output_dir + "/model-" + str(idx) + "/"
             log_file = output_dir + "/log-" + str(idx) + ".info"
             params_cudas_dirs.append([solution, cuda, modeldir, log_file])
-            solutions_modeldirs[tuple(solution)] = modeldir
+            solutions_modeldirs[tuple(solution)] = (modeldir, self.mpi.rank)
             self.is_cuda_free["free"].remove(cuda)
             self.is_cuda_free["busy"].append(cuda)
-            if len(params_cudas_dirs
-                   ) == self.thread or idx == len(solutions) - 1:
+            if len(params_cudas_dirs
+                   ) == self.thread or idx == len(my_solutions) - 1:
                 tp = ThreadPool(len(params_cudas_dirs))
                 solution_results += tp.map(self.evaluator.run,
                                            params_cudas_dirs)
@@ -198,13 +211,25 @@ class BaseTuningStrategy(object):
                 self.is_cuda_free["busy"].remove(param_cuda[1])
                 params_cudas_dirs = []

-        self.feedback(solutions, solution_results)
+        all_solution_results = self.mpi.gather(solution_results)
+        if self.mpi.rank == 0:
+            # only rank 0 need to feedback
+            all_solution_results = [y for x in all_solution_results for y in x]
+            self.feedback(solutions, all_solution_results)

         # remove the tmp.txt which records the eval results for trials
         tmp_file = os.path.join(TMP_HOME, "tmp.txt")
         if os.path.exists(tmp_file):
             os.remove(tmp_file)
-        return solutions_modeldirs
+
+        # collect all solutions_modeldirs
+        collected_solutions_modeldirs = self.mpi.allgather(solutions_modeldirs)
+        return_dict = {}
+        for i in collected_solutions_modeldirs:
+            return_dict.update(i)
+
+        return return_dict


 class HAZero(BaseTuningStrategy):
```
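The `run()` hunk above distributes tuning in three steps: rank 0 broadcasts the candidate solutions, each rank evaluates only its `split_range` slice, and the gathered per-rank result lists are flattened on rank 0 before `feedback`. A minimal single-process sketch of that flow (all names and the "evaluate by multiplying by 10" stand-in are illustrative, not PaddleHub APIs):

```python
def split_range(rank, size, n):
    # Contiguous, near-even partition of range(n) across `size` ranks:
    # the first n % size ranks each take one extra element.
    base, extra = divmod(n, size)
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    return start, end


def run_all(solutions, size):
    # Each simulated "rank" evaluates only its slice; gather() would return
    # a list of per-rank result lists, which rank 0 flattens before feedback.
    gathered = []
    for rank in range(size):
        start, end = split_range(rank, size, len(solutions))
        my_solutions = solutions[start:end]
        gathered.append([s * 10 for s in my_solutions])  # stand-in evaluator
    # rank 0's flattening step: [y for x in gathered for y in x]
    return [y for x in gathered for y in x]


print(run_all([1, 2, 3, 4, 5], size=2))  # → [10, 20, 30, 40, 50]
```

Because the partition is contiguous and covers the whole range, flattening the gathered lists restores one result per solution, in order.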
paddlehub/autofinetune/mpi_helper.py (new file, mode 100755)

```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-


class MPIHelper(object):
    def __init__(self):
        try:
            from mpi4py import MPI
        except:
            # local run
            self._size = 1
            self._rank = 0
            self._multi_machine = False
            import socket
            self._name = socket.gethostname()
        else:
            # in mpi environment
            self._comm = MPI.COMM_WORLD
            self._size = self._comm.Get_size()
            self._rank = self._comm.Get_rank()
            self._name = MPI.Get_processor_name()
            if self._size > 1:
                self._multi_machine = True
            else:
                self._multi_machine = False

    @property
    def multi_machine(self):
        return self._multi_machine

    @property
    def rank(self):
        return self._rank

    @property
    def size(self):
        return self._size

    @property
    def name(self):
        return self._name

    def bcast(self, data):
        if self._multi_machine:
            # call real bcast
            return self._comm.bcast(data, root=0)
        else:
            # do nothing
            return data

    def gather(self, data):
        if self._multi_machine:
            # call real gather
            return self._comm.gather(data, root=0)
        else:
            # do nothing
            return [data]

    def allgather(self, data):
        if self._multi_machine:
            # call real allgather
            return self._comm.allgather(data)
        else:
            # do nothing
            return [data]

    # calculate split range on mpi environment
    def split_range(self, array_length):
        if self._size == 1:
            return 0, array_length
        average_count = array_length / self._size
        if array_length % self._size == 0:
            return average_count * self._rank, average_count * (self._rank + 1)
        else:
            if self._rank < array_length % self._size:
                return (average_count + 1) * self._rank, \
                       (average_count + 1) * (self._rank + 1)
            else:
                start = (average_count + 1) * (array_length % self._size) \
                        + average_count * (self._rank - array_length % self._size)
                return start, start + average_count


if __name__ == "__main__":
    mpi = MPIHelper()
    print("Hello world from process {} of {} at {}.".format(
        mpi.rank, mpi.size, mpi.name))

    all_node_names = mpi.gather(mpi.name)
    print("all node names using gather: {}".format(all_node_names))

    all_node_names = mpi.allgather(mpi.name)
    print("all node names using allgather: {}".format(all_node_names))

    if mpi.rank == 0:
        data = range(10)
    else:
        data = None
    data = mpi.bcast(data)
    print("after bcast, process {} have data {}".format(mpi.rank, data))

    data = [i + mpi.rank for i in data]
    print("after modify, process {} have data {}".format(mpi.rank, data))

    new_data = mpi.gather(data)
    print("after gather, process {} have data {}".format(mpi.rank, new_data))

    # test for split
    for i in range(12):
        length = i + mpi.size  # length should >= mpi.size
        [start, end] = mpi.split_range(length)
        split_result = mpi.gather([start, end])
        print("length {}, split_result {}".format(length, split_result))
```
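In a single-machine run (no mpi4py installed), the helper above degrades gracefully: `bcast` returns the data unchanged and `gather`/`allgather` wrap it in a one-element list, so calling code can iterate the gathered list the same way in both modes. A trimmed sketch of that fallback contract (`LocalMPIHelper` is an illustrative stand-in, not the PaddleHub class):

```python
import socket


class LocalMPIHelper(object):
    """Single-process stand-in mirroring MPIHelper's no-mpi4py fallback."""

    def __init__(self):
        self._size = 1
        self._rank = 0
        self._multi_machine = False
        self._name = socket.gethostname()

    def bcast(self, data):
        return data          # identity: nothing to broadcast locally

    def gather(self, data):
        return [data]        # one "rank", so a one-element list

    def allgather(self, data):
        return [data]

    def split_range(self, array_length):
        return 0, array_length   # the single rank owns the whole range


mpi = LocalMPIHelper()
assert mpi.bcast([1, 2]) == [1, 2]
# gather always yields a list of per-rank payloads, even locally:
flat = [y for x in mpi.gather([3, 4]) for y in x]  # → [3, 4]
```

This is why `BaseTuningStrategy.run()` can unconditionally flatten `self.mpi.gather(solution_results)`: the shape is the same whether one rank or many contributed.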
paddlehub/commands/autofinetune.py

```diff
@@ -188,37 +188,62 @@ class AutoFineTuneCommand(BaseCommand):
             run_round_cnt = run_round_cnt + 1
         print("PaddleHub Autofinetune ends.")
+        best_hparams_origin = autoft.get_best_hparams()
+        best_hparams_origin = autoft.mpi.bcast(best_hparams_origin)
         with open(autoft._output_dir + "/log_file.txt", "w") as f:
-            best_hparams = evaluator.convert_params(autoft.get_best_hparams())
+            best_hparams = evaluator.convert_params(best_hparams_origin)
             print("The final best hyperparameters:")
             f.write("The final best hyperparameters:\n")
             for index, hparam_name in enumerate(autoft.hparams_name_list):
                 print("%s=%s" % (hparam_name, best_hparams[index]))
                 f.write(hparam_name + "\t:\t" + str(best_hparams[index]) + "\n")
+
+            best_hparams_dir, best_hparams_rank = solutions_modeldirs[tuple(
+                best_hparams_origin)]
+
             print("The final best eval score is %s." %
                   autoft.get_best_eval_value())
-            print("The final best model parameters are saved as " +
-                  autoft._output_dir + "/best_model .")
+
+            if autoft.mpi.multi_machine:
+                print("The final best model parameters are saved as " +
+                      autoft._output_dir + "/best_model on rank " +
+                      str(best_hparams_rank) + " .")
+            else:
+                print("The final best model parameters are saved as " +
+                      autoft._output_dir + "/best_model .")
             f.write("The final best eval score is %s.\n" %
                     autoft.get_best_eval_value())
-            f.write("The final best model parameters are saved as ./best_model .")
             best_model_dir = autoft._output_dir + "/best_model"
-            shutil.copytree(solutions_modeldirs[tuple(autoft.get_best_hparams())],
-                            best_model_dir)
-            f.write("\t".join(autoft.hparams_name_list) + "\tsaved_params_dir\n")
+
+            if autoft.mpi.rank == best_hparams_rank:
+                shutil.copytree(best_hparams_dir, best_model_dir)
+
+            if autoft.mpi.multi_machine:
+                f.write("The final best model parameters are saved as ./best_model on rank " \
+                        + str(best_hparams_rank) + " .")
+                f.write("\t".join(autoft.hparams_name_list) +
+                        "\tsaved_params_dir\trank\n")
+            else:
+                f.write("The final best model parameters are saved as ./best_model .")
+                f.write("\t".join(autoft.hparams_name_list) +
+                        "\tsaved_params_dir\n")
             print(
                 "The related infomation about hyperparamemters searched are saved as %s/log_file.txt ."
                 % autoft._output_dir)
             for solution, modeldir in solutions_modeldirs.items():
                 param = evaluator.convert_params(solution)
                 param = [str(p) for p in param]
-                f.write("\t".join(param) + "\t" + modeldir + "\n")
+                if autoft.mpi.multi_machine:
+                    f.write("\t".join(param) + "\t" + modeldir[0] + "\t" +
+                            str(modeldir[1]) + "\n")
+                else:
+                    f.write("\t".join(param) + "\t" + modeldir[0] + "\n")

         return True
```
paddlehub/commands/install.py

```diff
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function

 import argparse
+import os

 from paddlehub.common import utils
 from paddlehub.module.manager import default_module_manager
@@ -42,14 +43,23 @@ class InstallCommand(BaseCommand):
             print("ERROR: Please specify a module name.\n")
             self.help()
             return False
-        module_name = argv[0]
-        module_version = None if "==" not in module_name else module_name.split(
-            "==")[1]
-        module_name = module_name if "==" not in module_name else module_name.split(
-            "==")[0]
-        extra = {"command": "install"}
-        result, tips, module_dir = default_module_manager.install_module(
-            module_name=module_name, module_version=module_version, extra=extra)
+
+        extra = {"command": "install"}
+        if argv[0].endswith("tar.gz") or argv[0].endswith("phm"):
+            result, tips, module_dir = default_module_manager.install_module(
+                module_package=argv[0], extra=extra)
+        elif os.path.exists(argv[0]) and os.path.isdir(argv[0]):
+            result, tips, module_dir = default_module_manager.install_module(
+                module_dir=argv[0], extra=extra)
+        else:
+            module_name = argv[0]
+            module_version = None if "==" not in module_name else module_name.split(
+                "==")[1]
+            module_name = module_name if "==" not in module_name else module_name.split(
+                "==")[0]
+            result, tips, module_dir = default_module_manager.install_module(
+                module_name=module_name,
+                module_version=module_version,
+                extra=extra)
         print(tips)
         return True
```
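The install path above now distinguishes three argument kinds: a package archive (`tar.gz`/`phm`), a local module directory, and a `name==version` spec. The classification and version split can be sketched as (plain Python; `parse_install_arg` is an illustrative helper, not the PaddleHub API):

```python
import os


def parse_install_arg(arg):
    """Classify a `hub install` argument the way the command above does.

    Returns (kind, name, version); the kind labels are illustrative only.
    """
    if arg.endswith("tar.gz") or arg.endswith("phm"):
        return ("package", arg, None)
    if os.path.isdir(arg):
        return ("directory", arg, None)
    # "name==version" spec: split once on "=="
    name, _, version = arg.partition("==")
    return ("module", name, version or None)


print(parse_install_arg("ernie==1.0.2"))  # → ('module', 'ernie', '1.0.2')
print(parse_install_arg("ernie"))         # → ('module', 'ernie', None)
print(parse_install_arg("m.tar.gz"))      # → ('package', 'm.tar.gz', None)
```

Checking the archive suffix before the directory test matters: a path like `m.tar.gz` should install as a package even if a directory of the same name happens to exist.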
paddlehub/commands/run.py

```diff
@@ -71,7 +71,7 @@ class RunCommand(BaseCommand):
         if not result:
             return None
-        return hub.Module(module_dir=module_dir)
+        return hub.Module(directory=module_dir[0])

     def add_module_config_arg(self):
         configs = self.module.processor.configs()
@@ -105,7 +105,7 @@ class RunCommand(BaseCommand):
     def add_module_input_arg(self):
         module_type = self.module.type.lower()
         expect_data_format = self.module.processor.data_format(
-            self.module.default_signature.name)
+            self.module.default_signature)
         self.arg_input_group.add_argument(
             '--input_file',
             type=str,
@@ -152,7 +152,7 @@ class RunCommand(BaseCommand):
     def get_data(self):
         module_type = self.module.type.lower()
         expect_data_format = self.module.processor.data_format(
-            self.module.default_signature.name)
+            self.module.default_signature)
         input_data = {}
         if len(expect_data_format) == 1:
             key = list(expect_data_format.keys())[0]
@@ -177,7 +177,7 @@ class RunCommand(BaseCommand):
     def check_data(self, data):
         expect_data_format = self.module.processor.data_format(
-            self.module.default_signature.name)
+            self.module.default_signature)
         if len(data.keys()) != len(expect_data_format.keys()):
             print(
@@ -236,10 +236,13 @@ class RunCommand(BaseCommand):
             return False

         # If the module is not executable, give an alarm and exit
-        if not self.module.default_signature:
+        if not self.module.is_runable:
             print("ERROR! Module %s is not executable." % module_name)
             return False

-        self.module.check_processor()
-        self.add_module_config_arg()
-        self.add_module_input_arg()
+        if self.module.code_version == "v2":
+            results = self.module(argv[1:])
+        else:
+            self.module.check_processor()
+            self.add_module_config_arg()
+            self.add_module_input_arg()
@@ -260,7 +263,7 @@ class RunCommand(BaseCommand):
             return False

         results = self.module(
-            sign_name=self.module.default_signature.name,
+            sign_name=self.module.default_signature,
             data=data,
             use_gpu=self.args.use_gpu,
             batch_size=self.args.batch_size,
```
paddlehub/commands/serving.py

```diff
@@ -159,7 +159,7 @@ class ServingCommand(BaseCommand):
         module = args.modules
         if module is not None:
             use_gpu = args.use_gpu
-            port = args.port[0]
+            port = args.port
             if ServingCommand.is_port_occupied("127.0.0.1", port) is True:
                 print("Port %s is occupied, please change it." % (port))
                 return False
@@ -206,10 +206,12 @@ class ServingCommand(BaseCommand):
         if args.sub_command == "start":
             if args.bert_service == "bert_service":
                 ServingCommand.start_bert_serving(args)
-            else:
+            elif args.bert_service is None:
                 ServingCommand.start_serving(args)
+            else:
+                ServingCommand.show_help()
         else:
             ServingCommand.show_help()


 command = ServingCommand.instance()
```
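The `start` dispatch above now has three explicit branches: `"bert_service"` starts BERT serving, `None` starts normal serving, and any other value shows help instead of silently falling through to normal serving. A toy sketch of that routing (the return labels are illustrative; the real methods take `args` and return nothing):

```python
def route_start(bert_service):
    # Mirrors the corrected if/elif/else chain: an unexpected value
    # no longer falls through to normal serving.
    if bert_service == "bert_service":
        return "start_bert_serving"
    elif bert_service is None:
        return "start_serving"
    else:
        return "show_help"


print(route_start("bert_service"))  # → start_bert_serving
print(route_start(None))            # → start_serving
print(route_start("typo"))          # → show_help
```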
paddlehub/commands/show.py

```diff
@@ -125,8 +125,6 @@ class ShowCommand(BaseCommand):
         cwd = os.getcwd()
         module_dir = default_module_manager.search_module(module_name)
-        module_dir = (os.path.join(cwd, module_name),
-                      None) if not module_dir else module_dir
         if not module_dir or not os.path.exists(module_dir[0]):
             print("%s is not existed!" % module_name)
             return True
```
paddlehub/common/hub_server.py

```diff
@@ -298,17 +298,23 @@ class CacheUpdater(threading.Thread):
         api_url = srv_utils.uri_path(default_hub_server.get_server_url(),
                                      'search')
         cache_path = os.path.join(CACHE_HOME, RESOURCE_LIST_FILE)
-        extra = {
-            "command": "update_cache",
-            "mtime": os.stat(cache_path).st_mtime
-        }
+        if os.path.exists(cache_path):
+            extra = {
+                "command": "update_cache",
+                "mtime": os.stat(cache_path).st_mtime
+            }
+        else:
+            extra = {
+                "command": "update_cache",
+                "mtime": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+            }
         try:
             r = srv_utils.hub_request(api_url, payload, extra)
-        except Exception as err:
-            pass
-        if r.get("update_cache", 0) == 1:
-            with open(cache_path, 'w+') as fp:
-                yaml.safe_dump({'resource_list': r['data']}, fp)
+            if r.get("update_cache", 0) == 1:
+                with open(cache_path, 'w+') as fp:
+                    yaml.safe_dump({'resource_list': r['data']}, fp)
+        except Exception as err:
+            pass

     def run(self):
         self.update_resource_list_file(self.module, self.version)
```
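The cache updater now reports the cache file's mtime when the file exists and falls back to the current local time string otherwise, and the cache write is moved inside the `try` so a failed request can no longer dereference an unset response. The `extra` construction can be sketched as (standalone; `build_extra` is an illustrative helper, not the PaddleHub API):

```python
import os
import tempfile
import time


def build_extra(cache_path):
    # Mirrors CacheUpdater: report the cache file's mtime if present,
    # otherwise fall back to the current local time string.
    if os.path.exists(cache_path):
        mtime = os.stat(cache_path).st_mtime
    else:
        mtime = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    return {"command": "update_cache", "mtime": mtime}


with tempfile.NamedTemporaryFile() as f:
    extra = build_extra(f.name)
    assert isinstance(extra["mtime"], float)   # existing file: numeric mtime

extra = build_extra(os.path.join(tempfile.gettempdir(), "no_such_cache.yml"))
assert isinstance(extra["mtime"], str)         # missing file: timestamp string
```

Note the two branches deliberately send different types (a float epoch mtime vs. a formatted string), so the server side must accept both.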
paddlehub/module/check_info.proto

```diff
@@ -50,6 +50,7 @@ message CheckInfo {
   string paddle_version = 1;
   string hub_version = 2;
   string module_proto_version = 3;
-  repeated FileInfo file_infos = 4;
-  repeated Requires requires = 5;
+  string module_code_version = 4;
+  repeated FileInfo file_infos = 5;
+  repeated Requires requires = 6;
 };
```
paddlehub/module/check_info_pb2.py
浏览文件 @
fa6f8d16
#coding:utf-8
# Generated by the protocol buffer compiler. DO NOT EDIT!
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: check_info.proto
# source: check_info.proto
...
@@ -19,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
...
@@ -19,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package
=
'paddlehub.module.checkinfo'
,
package
=
'paddlehub.module.checkinfo'
,
syntax
=
'proto3'
,
syntax
=
'proto3'
,
serialized_pb
=
_b
(
serialized_pb
=
_b
(
'
\n\x10\x63
heck_info.proto
\x12\x1a
paddlehub.module.checkinfo
\"\x85\x01\n\x08\x46
ileInfo
\x12\x11\n\t
file_name
\x18\x01
\x01
(
\t\x12\x33\n\x04
type
\x18\x02
\x01
(
\x0e\x32
%.paddlehub.module.checkinfo.FILE_TYPE
\x12\x0f\n\x07
is_need
\x18\x03
\x01
(
\x08\x12\x0b\n\x03
md5
\x18\x04
\x01
(
\t\x12\x13\n\x0b\x64\x65
scription
\x18\x05
\x01
(
\t\"\x84\x01\n\x08
Requires
\x12
>
\n\x0c
require_type
\x18\x01
\x01
(
\x0e\x32
(.paddlehub.module.checkinfo.REQUIRE_TYPE
\x12\x0f\n\x07
version
\x18\x02
\x01
(
\t\x12\x12\n\n
great_than
\x18\x03
\x01
(
\x08\x12\x13\n\x0b\x64\x65
scription
\x18\x04
\x01
(
\t\"\x
c8\x01\n\t
CheckInfo
\x12\x16\n\x0e
paddle_version
\x18\x01
\x01
(
\t\x12\x13\n\x0b
hub_version
\x18\x02
\x01
(
\t\x12\x1c\n\x14
module_proto_version
\x18\x03
\x01
(
\t\x12\x38\n\n
file_infos
\x18\x04
\x03
(
\x0b\x32
$.paddlehub.module.checkinfo.FileInfo
\x12\x36\n\x08
requires
\x18\x05
\x03
(
\x0b\x32
$.paddlehub.module.checkinfo.Requires*
\x1e\n\t
FILE_TYPE
\x12\x08\n\x04\x46
ILE
\x10\x00\x12\x07\n\x03\x44
IR
\x10\x01
*[
\n\x0c
REQUIRE_TYPE
\x12\x12\n\x0e
PYTHON_PACKAGE
\x10\x00\x12\x0e\n\n
HUB_MODULE
\x10\x01\x12\n\n\x06
SYSTEM
\x10\x02\x12\x0b\n\x07\x43
OMMAND
\x10\x03\x12\x0e\n\n
PY_VERSION
\x10\x04\x42\x02
H
\x03\x62\x06
proto3'
'
\n\x10\x63
heck_info.proto
\x12\x1a
paddlehub.module.checkinfo
\"\x85\x01\n\x08\x46
ileInfo
\x12\x11\n\t
file_name
\x18\x01
\x01
(
\t\x12\x33\n\x04
type
\x18\x02
\x01
(
\x0e\x32
%.paddlehub.module.checkinfo.FILE_TYPE
\x12\x0f\n\x07
is_need
\x18\x03
\x01
(
\x08\x12\x0b\n\x03
md5
\x18\x04
\x01
(
\t\x12\x13\n\x0b\x64\x65
scription
\x18\x05
\x01
(
\t\"\x84\x01\n\x08
Requires
\x12
>
\n\x0c
require_type
\x18\x01
\x01
(
\x0e\x32
(.paddlehub.module.checkinfo.REQUIRE_TYPE
\x12\x0f\n\x07
version
\x18\x02
\x01
(
\t\x12\x12\n\n
great_than
\x18\x03
\x01
(
\x08\x12\x13\n\x0b\x64\x65
scription
\x18\x04
\x01
(
\t\"\x
e5\x01\n\t
CheckInfo
\x12\x16\n\x0e
paddle_version
\x18\x01
\x01
(
\t\x12\x13\n\x0b
hub_version
\x18\x02
\x01
(
\t\x12\x1c\n\x14
module_proto_version
\x18\x03
\x01
(
\t\x12\x1b\n\x13
module_code_version
\x18\x04
\x01
(
\t\x12\x38\n\n
file_infos
\x18\x05
\x03
(
\x0b\x32
$.paddlehub.module.checkinfo.FileInfo
\x12\x36\n\x08
requires
\x18\x06
\x03
(
\x0b\x32
$.paddlehub.module.checkinfo.Requires*
\x1e\n\t
FILE_TYPE
\x12\x08\n\x04\x46
ILE
\x10\x00\x12\x07\n\x03\x44
IR
\x10\x01
*[
\n\x0c
REQUIRE_TYPE
\x12\x12\n\x0e
PYTHON_PACKAGE
\x10\x00\x12\x0e\n\n
HUB_MODULE
\x10\x01\x12\n\n\x06
SYSTEM
\x10\x02\x12\x0b\n\x07\x43
OMMAND
\x10\x03\x12\x0e\n\n
PY_VERSION
\x10\x04\x42\x02
H
\x03\x62\x06
proto3'
))
))
_sym_db
.
RegisterFileDescriptor
(
DESCRIPTOR
)
_sym_db
.
RegisterFileDescriptor
(
DESCRIPTOR
)
...
@@ -36,8 +35,8 @@ _FILE_TYPE = _descriptor.EnumDescriptor(
...
@@ -36,8 +35,8 @@ _FILE_TYPE = _descriptor.EnumDescriptor(
],
],
containing_type
=
None
,
containing_type
=
None
,
options
=
None
,
options
=
None
,
serialized_start
=
5
22
,
serialized_start
=
5
51
,
serialized_end
=
5
52
,
serialized_end
=
5
81
,
)
)
_sym_db
.
RegisterEnumDescriptor
(
_FILE_TYPE
)
_sym_db
.
RegisterEnumDescriptor
(
_FILE_TYPE
)
...
@@ -61,8 +60,8 @@ _REQUIRE_TYPE = _descriptor.EnumDescriptor(
   ],
   containing_type=None,
   options=None,
-  serialized_start=554,
-  serialized_end=645,
+  serialized_start=583,
+  serialized_end=674,
 )
 _sym_db.RegisterEnumDescriptor(_REQUIRE_TYPE)
...
@@ -316,10 +315,26 @@ _CHECKINFO = _descriptor.Descriptor(
       extension_scope=None,
       options=None),
     _descriptor.FieldDescriptor(
-      name='file_infos', full_name='paddlehub.module.checkinfo.CheckInfo.file_infos', index=3,
-      number=4, type=11, cpp_type=10, label=3,
+      name='module_code_version', full_name='paddlehub.module.checkinfo.CheckInfo.module_code_version', index=3,
+      number=4, type=9, cpp_type=9, label=1,
+      has_default_value=False, default_value=_b("").decode('utf-8'),
+      message_type=None, enum_type=None, containing_type=None,
+      is_extension=False, extension_scope=None, options=None),
+    _descriptor.FieldDescriptor(
+      name='file_infos', full_name='paddlehub.module.checkinfo.CheckInfo.file_infos', index=4,
+      number=5, type=11, cpp_type=10, label=3,
...
@@ -334,8 +349,8 @@ _CHECKINFO = _descriptor.Descriptor(
     _descriptor.FieldDescriptor(
-      name='requires', full_name='paddlehub.module.checkinfo.CheckInfo.requires', index=4,
-      number=5, type=11, cpp_type=10, label=3,
+      name='requires', full_name='paddlehub.module.checkinfo.CheckInfo.requires', index=5,
+      number=6, type=11, cpp_type=10, label=3,
...
@@ -357,7 +372,7 @@ _CHECKINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[],
   serialized_start=320,
-  serialized_end=520,
+  serialized_end=549,
 )
 _FILEINFO.fields_by_name['type'].enum_type = _FILE_TYPE
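A quick consistency check on the offset edits above: every `serialized_start`/`serialized_end` moves by exactly 29 bytes, the encoded size of the new `module_code_version` entry (`\x12\x1b\n\x13module_code_version\x18\x04 \x01(\t`) added to the serialized `FileDescriptorProto`. A sketch using the numbers from this diff:

```python
# (old, new) byte offsets copied from the descriptor edits above
shifts = [
    (520, 549),  # _CHECKINFO serialized_end
    (522, 551),  # _FILE_TYPE serialized_start
    (552, 581),  # _FILE_TYPE serialized_end
    (554, 583),  # _REQUIRE_TYPE serialized_start
    (645, 674),  # _REQUIRE_TYPE serialized_end
]
deltas = [new - old for old, new in shifts]
# 2 (field tag+len) + 2 (name tag+len) + 19 (name) + 6 (number/label/type) = 29
assert all(d == 29 for d in deltas)
```

This is why every descriptor after `CheckInfo` shifts in lockstep: the offsets index into one serialized blob, so inserting a field moves everything behind it.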
...
paddlehub/module/checker.py (View file @ fa6f8d16)
...
@@ -32,20 +32,22 @@ FILE_SEP = "/"
 class ModuleChecker(object):
-    def __init__(self, module_path):
-        self.module_path = module_path
+    def __init__(self, directory):
+        self._directory = directory
+        self._pb_path = os.path.join(self.directory, CHECK_INFO_PB_FILENAME)

     def generate_check_info(self):
         check_info = check_info_pb2.CheckInfo()
         check_info.paddle_version = paddle.__version__
         check_info.hub_version = hub_version
         check_info.module_proto_version = module_proto_version
+        check_info.module_code_version = "v2"
         file_infos = check_info.file_infos
-        file_list = [file for file in os.listdir(self.module_path)]
+        file_list = [file for file in os.listdir(self.directory)]
         while file_list:
             file = file_list[0]
             file_list = file_list[1:]
-            abs_path = os.path.join(self.module_path, file)
+            abs_path = os.path.join(self.directory, file)
             if os.path.isdir(abs_path):
                 for sub_file in os.listdir(abs_path):
                     sub_file = os.path.join(file, sub_file)
...
@@ -62,9 +64,12 @@ class ModuleChecker(object):
                 file_info.type = check_info_pb2.FILE
                 file_info.is_need = True

-        with open(os.path.join(self.module_path, CHECK_INFO_PB_FILENAME),
-                  "wb") as file:
-            file.write(check_info.SerializeToString())
+        with open(self.pb_path, "wb") as fi:
+            fi.write(check_info.SerializeToString())
+
+    @property
+    def module_code_version(self):
+        return self.check_info.module_code_version

     @property
     def module_proto_version(self):
...
@@ -82,20 +87,25 @@ class ModuleChecker(object):
     def file_infos(self):
         return self.check_info.file_infos

+    @property
+    def directory(self):
+        return self._directory
+
+    @property
+    def pb_path(self):
+        return self._pb_path
+
     def check(self):
         result = True
-        self.check_info_pb_path = os.path.join(self.module_path,
-                                               CHECK_INFO_PB_FILENAME)
-        if not (os.path.exists(self.check_info_pb_path)
-                or os.path.isfile(self.check_info_pb_path)):
+        if not (os.path.exists(self.pb_path)
+                or os.path.isfile(self.pb_path)):
             logger.warning(
                 "This module lacks core file %s" % CHECK_INFO_PB_FILENAME)
             result = False

         self.check_info = check_info_pb2.CheckInfo()
         try:
-            with open(self.check_info_pb_path, "rb") as fi:
+            with open(self.pb_path, "rb") as fi:
                 pb_string = fi.read()
                 result = self.check_info.ParseFromString(pb_string)
                 if len(pb_string) == 0 or (result is not None
...
@@ -182,7 +192,7 @@ class ModuleChecker(object):
         for file_info in self.file_infos:
             file_type = file_info.type
             file_path = file_info.file_name.replace(FILE_SEP, os.sep)
-            file_path = os.path.join(self.module_path, file_path)
+            file_path = os.path.join(self.directory, file_path)
             if not os.path.exists(file_path):
                 if file_info.is_need:
                     logger.warning(
...
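The refactor above replaces the ad-hoc `check_info_pb_path` attribute (computed inside `check()`) with `directory`/`pb_path` properties set once in `__init__` and reused by both the write and the check paths. A minimal, dependency-free sketch of that pattern (`DirChecker` is a hypothetical stand-in, not PaddleHub's actual class; plain bytes replace the CheckInfo proto):

```python
import os
import tempfile

CHECK_INFO_PB_FILENAME = "check_info.pb"  # constant name mirrors checker.py

class DirChecker:
    """Stand-in for ModuleChecker: path logic lives in one place."""

    def __init__(self, directory):
        self._directory = directory
        self._pb_path = os.path.join(self._directory, CHECK_INFO_PB_FILENAME)

    @property
    def directory(self):
        return self._directory

    @property
    def pb_path(self):
        return self._pb_path

    def generate_check_info(self, payload=b"v2"):
        # ModuleChecker serializes a CheckInfo proto here; bytes keep it simple
        with open(self.pb_path, "wb") as fi:
            fi.write(payload)

    def check(self):
        # same shape as the diff: missing or empty file fails the check
        if not os.path.isfile(self.pb_path):
            return False
        with open(self.pb_path, "rb") as fi:
            return len(fi.read()) > 0
```

Compared with recomputing `os.path.join(...)` at every call site, the property keeps the path derivation in one place, which is exactly what the `pb_path` change does.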
paddlehub/module/manager.py (View file @ fa6f8d16)
...
@@ -19,7 +19,9 @@ from __future__ import print_function
 import os
 import shutil
 from functools import cmp_to_key
+import tarfile

 from paddlehub.common import utils
 from paddlehub.common import srv_utils
...
@@ -79,10 +81,15 @@ class LocalModuleManager(object):
         return self.modules_dict.get(module_name, None)

     def install_module(self,
-                       module_name,
+                       module_name=None,
+                       module_dir=None,
+                       module_package=None,
                        module_version=None,
                        upgrade=False,
                        extra=None):
+        md5_value = installed_module_version = None
+        from_user_dir = True if module_dir else False
+        if module_name:
             self.all_modules(update=True)
             module_info = self.modules_dict.get(module_name, None)
             if module_info:
...
@@ -95,6 +102,62 @@ class LocalModuleManager(object):
                     module_dir)
                 return True, tips, self.modules_dict[module_name]
+            search_result = hub.default_hub_server.get_module_url(
+                module_name, version=module_version, extra=extra)
+            name = search_result.get('name', None)
+            url = search_result.get('url', None)
+            md5_value = search_result.get('md5', None)
+            installed_module_version = search_result.get('version', None)
+            if not url or (module_version is not None and
+                           installed_module_version != module_version) or (
+                               name != module_name):
+                if default_hub_server._server_check() is False:
+                    tips = "Request Hub-Server unsuccessfully, please check your network."
+                else:
+                    tips = "Can't find module %s" % module_name
+                    if module_version:
+                        tips += " with version %s" % module_version
+                    module_tag = module_name if not module_version else \
+                        '%s-%s' % (module_name, module_version)
+                return False, tips, None
+
+            result, tips, module_zip_file = default_downloader.download_file(
+                url=url,
+                save_path=hub.CACHE_HOME,
+                save_name=module_name,
+                replace=True,
+                print_progress=True)
+            result, tips, module_dir = default_downloader.uncompress(
+                file=module_zip_file,
+                dirname=MODULE_HOME,
+                delete_file=True,
+                print_progress=True)
+
+        if module_package:
+            with tarfile.open(module_package, "r:gz") as tar:
+                file_names = tar.getnames()
+                size = len(file_names) - 1
+                module_dir = os.path.split(file_names[0])[0]
+                module_dir = os.path.join(hub.CACHE_HOME, module_dir)
+                # remove cache
+                if os.path.exists(module_dir):
+                    shutil.rmtree(module_dir)
+                for index, file_name in enumerate(file_names):
+                    tar.extract(file_name, hub.CACHE_HOME)
+
+        if module_dir:
+            if not module_name:
+                module_name = hub.Module(directory=module_dir).name
+            self.all_modules(update=False)
+            module_info = self.modules_dict.get(module_name, None)
+            if module_info:
+                module_dir = self.modules_dict[module_name][0]
+                module_tag = module_name if not module_version else \
+                    '%s-%s' % (module_name, module_version)
+                tips = "Module %s already installed in %s" % (module_tag,
+                                                              module_dir)
+                return True, tips, self.modules_dict[module_name]
+
         search_result = hub.default_hub_server.get_module_url(
             module_name, version=module_version, extra=extra)
         name = search_result.get('name', None)
...
@@ -162,9 +225,18 @@ class LocalModuleManager(object):
-        with open(os.path.join(MODULE_HOME, module_dir, "md5.txt"),
-                  "w") as fp:
-            fp.write(md5_value)
+        if md5_value:
+            with open(os.path.join(MODULE_HOME, module_dir, "md5.txt"),
+                      "w") as fp:
+                fp.write(md5_value)

         save_path = os.path.join(MODULE_HOME, module_name)
         if os.path.exists(save_path):
             shutil.rmtree(save_path)
-        shutil.move(module_dir, save_path)
+        if from_user_dir:
+            shutil.copytree(module_dir, save_path)
+        else:
+            shutil.move(module_dir, save_path)
         module_dir = save_path
         tips = "Successfully installed %s" % module_name
...
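The new `module_package` branch unpacks a `.phm` archive (a gzipped tar, as `create_module` builds below) into the cache directory and derives the module directory from the first archive member. A standalone sketch of that extraction step (function name and paths are illustrative, not PaddleHub API):

```python
import os
import shutil
import tarfile

def extract_module_package(module_package, cache_home):
    """Unpack a .phm (tar.gz) package into cache_home, return the module dir."""
    with tarfile.open(module_package, "r:gz") as tar:
        file_names = tar.getnames()
        # the top-level directory of the first entry names the module
        module_dir = os.path.join(cache_home, os.path.split(file_names[0])[0])
        if os.path.exists(module_dir):
            # drop any stale cached copy before re-extracting
            shutil.rmtree(module_dir)
        for file_name in file_names:
            tar.extract(file_name, cache_home)
    return module_dir
```

Clearing the stale cache directory before extraction mirrors the diff's `# remove cache` step: otherwise leftovers from an older version of the same module would survive alongside the new files.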
paddlehub/module/module.py (View file @ fa6f8d16)
...
@@ -21,6 +21,10 @@ import os
 import time
 import sys
 import functools
+import inspect
+import importlib
+import tarfile
+from collections import defaultdict
 from shutil import copyfile

 import paddle
...
@@ -28,22 +32,19 @@ import paddle.fluid as fluid
 from paddlehub.common import utils
 from paddlehub.common import paddle_helper
-from paddlehub.common.logger import logger
+from paddlehub.common.dir import CACHE_HOME
 from paddlehub.common.lock import lock
-from paddlehub.common.downloader import default_downloader
+from paddlehub.common.logger import logger
+from paddlehub.common.hub_server import CacheUpdater
 from paddlehub.module import module_desc_pb2
-from paddlehub.common.dir import CONF_HOME
 from paddlehub.module import check_info_pb2
-from paddlehub.common.hub_server import CacheUpdater
-from paddlehub.module.signature import Signature, create_signature
+from paddlehub.module.checker import ModuleChecker
 from paddlehub.module.manager import default_module_manager
-from paddlehub.module.checker import ModuleChecker
+from paddlehub.module.signature import Signature, create_signature
 from paddlehub.module.base_processor import BaseProcessor
 from paddlehub.io.parser import yaml_parser
 from paddlehub import version

+__all__ = ['Module', 'create_module']
+
 # PaddleHub module dir name
 ASSETS_DIRNAME = "assets"
 MODEL_DIRNAME = "model"
...
@@ -52,67 +53,227 @@ PYTHON_DIR = "python"
 PROCESSOR_NAME = "processor"
 # PaddleHub var prefix
 HUB_VAR_PREFIX = "@HUB_%s@"
+# PaddleHub Module package suffix
+HUB_PACKAGE_SUFFIX = "phm"


-def create_module(sign_arr,
-                  module_dir,
-                  processor=None,
-                  assets=None,
-                  module_info=None,
-                  exe=None,
-                  extra_info=None):
-    sign_arr = utils.to_list(sign_arr)
-    module = Module(
-        signatures=sign_arr,
-        processor=processor,
-        assets=assets,
-        module_info=module_info,
-        extra_info=extra_info)
-    module.serialize_to_path(path=module_dir, exe=exe)
+def create_module(directory, name, author, email, module_type, summary,
+                  version):
+    save_file_name = "{}-{}.{}".format(name, version, HUB_PACKAGE_SUFFIX)
+    # record module info and serialize
+    desc = module_desc_pb2.ModuleDesc()
+    attr = desc.attr
+    attr.type = module_desc_pb2.MAP
+    module_info = attr.map.data['module_info']
+    module_info.type = module_desc_pb2.MAP
+    utils.from_pyobj_to_module_attr(name, module_info.map.data['name'])
+    utils.from_pyobj_to_module_attr(author, module_info.map.data['author'])
+    utils.from_pyobj_to_module_attr(email,
+                                    module_info.map.data['author_email'])
+    utils.from_pyobj_to_module_attr(module_type, module_info.map.data['type'])
+    utils.from_pyobj_to_module_attr(summary, module_info.map.data['summary'])
+    utils.from_pyobj_to_module_attr(version, module_info.map.data['version'])
+
+    module_desc_path = os.path.join(directory, "module_desc.pb")
+    with open(module_desc_path, "wb") as f:
+        f.write(desc.SerializeToString())
+
+    # generate check info
+    checker = ModuleChecker(directory)
+    checker.generate_check_info()
+
+    # add __init__
+    module_init_1 = os.path.join(directory, "__init__.py")
+    with open(module_init_1, "a") as file:
+        file.write("")
+    module_init_2 = os.path.join(directory, "python", "__init__.py")
+    with open(module_init_2, "a") as file:
+        file.write("")
+
+    # package the module
+    with tarfile.open(save_file_name, "w:gz") as tar:
+        for dirname, _, files in os.walk(directory):
+            for file in files:
+                tar.add(os.path.join(dirname, file))
+
+    os.remove(module_desc_path)
+    os.remove(checker.pb_path)
+    os.remove(module_init_1)
+    os.remove(module_init_2)
+
+
+class Module(object):
+    def __new__(cls, name=None, directory=None, module_dir=None,
+                version=None):
+        module = None
+        if cls.__name__ == "Module":
+            if name:
+                module = cls.init_with_name(name=name, version=version)
+            elif directory:
+                module = cls.init_with_directory(directory=directory)
+            elif module_dir:
+                logger.warning(
+                    "Parameter module_dir is deprecated, please use directory to specify the path"
+                )
+                if isinstance(module_dir, list) or isinstance(
+                        module_dir, tuple):
+                    directory = module_dir[0]
+                    version = module_dir[1]
+                else:
+                    directory = module_dir
+                module = cls.init_with_directory(directory=directory)
+        if not module:
+            module = object.__new__(cls)
+        else:
+            CacheUpdater(module.name, module.version).start()
+        return module
+
+    def __init__(self, name=None, directory=None, module_dir=None,
+                 version=None):
+        if not directory:
+            return
+        self._code_version = "v2"
+        self._directory = directory
+        self.module_desc_path = os.path.join(self.directory,
+                                             MODULE_DESC_PBNAME)
+        self._desc = module_desc_pb2.ModuleDesc()
+        with open(self.module_desc_path, "rb") as file:
+            self._desc.ParseFromString(file.read())
+
+        module_info = self.desc.attr.map.data['module_info']
+        self._name = utils.from_module_attr_to_pyobj(
+            module_info.map.data['name'])
+        self._author = utils.from_module_attr_to_pyobj(
+            module_info.map.data['author'])
+        self._author_email = utils.from_module_attr_to_pyobj(
+            module_info.map.data['author_email'])
+        self._version = utils.from_module_attr_to_pyobj(
+            module_info.map.data['version'])
+        self._type = utils.from_module_attr_to_pyobj(
+            module_info.map.data['type'])
+        self._summary = utils.from_module_attr_to_pyobj(
+            module_info.map.data['summary'])
+        self._initialize()
+
+    @classmethod
+    def init_with_name(cls, name, version=None):
+        fp_lock = open(os.path.join(CACHE_HOME, name), "a")
+        lock.flock(fp_lock, lock.LOCK_EX)
+        log_msg = "Installing %s module" % name
+        if version:
+            log_msg += "-%s" % version
+        logger.info(log_msg)
+        extra = {"command": "install"}
+        result, tips, module_dir = default_module_manager.install_module(
+            module_name=name, module_version=version, extra=extra)
+        if not result:
+            logger.error(tips)
+            raise RuntimeError(tips)
+        logger.info(tips)
+        lock.flock(fp_lock, lock.LOCK_UN)
+        return cls.init_with_directory(directory=module_dir[0])
+
+    @classmethod
+    def init_with_directory(cls, directory):
+        desc_file = os.path.join(directory, MODULE_DESC_PBNAME)
+        checker = ModuleChecker(directory)
+        checker.check()
+        module_code_version = checker.module_code_version
+        if module_code_version == "v2":
+            basename = os.path.split(directory)[-1]
+            dirname = os.path.join(*list(os.path.split(directory)[:-1]))
+            sys.path.append(dirname)
+            pymodule = importlib.import_module(
+                "{}.python.module".format(basename))
+            return pymodule.HubModule(directory=directory)
+        return ModuleV1(directory=directory)
+
+    @property
+    def desc(self):
+        return self._desc
+
+    @property
+    def directory(self):
+        return self._directory
+
+    @property
+    def author(self):
+        return self._author
+
+    @property
+    def author_email(self):
+        return self._author_email
+
+    @property
+    def summary(self):
+        return self._summary
+
+    @property
+    def type(self):
+        return self._type
+
+    @property
+    def version(self):
+        return self._version
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def name_prefix(self):
+        return self._name_prefix
+
+    @property
+    def code_version(self):
+        return self._code_version
+
+    @property
+    def is_runable(self):
+        return False
+
+    def _initialize(self):
+        pass


 class ModuleHelper(object):
-    def __init__(self, module_dir):
-        self.module_dir = module_dir
+    def __init__(self, directory):
+        self.directory = directory

     def module_desc_path(self):
-        return os.path.join(self.module_dir, MODULE_DESC_PBNAME)
+        return os.path.join(self.directory, MODULE_DESC_PBNAME)

     def model_path(self):
-        return os.path.join(self.module_dir, MODEL_DIRNAME)
+        return os.path.join(self.directory, MODEL_DIRNAME)

     def processor_path(self):
-        return os.path.join(self.module_dir, PYTHON_DIR)
+        return os.path.join(self.directory, PYTHON_DIR)

     def processor_name(self):
         return PROCESSOR_NAME

     def assets_path(self):
-        return os.path.join(self.module_dir, ASSETS_DIRNAME)
+        return os.path.join(self.directory, ASSETS_DIRNAME)


-class Module(object):
-    def __init__(self,
-                 name=None,
-                 module_dir=None,
-                 signatures=None,
-                 module_info=None,
-                 assets=None,
-                 processor=None,
-                 extra_info=None,
-                 version=None):
-        self.desc = module_desc_pb2.ModuleDesc()
+class ModuleV1(Module):
+    def __init__(self, name=None, directory=None, module_dir=None,
+                 version=None):
+        if not directory:
+            return
+        super(ModuleV1, self).__init__(name, directory, module_dir, version)
+        self._code_version = "v1"
         self.program = None
         self.assets = []
         self.helper = None
         self.signatures = {}
         self.default_signature = None
-        self.module_info = None
         self.processor = None
-        self.extra_info = {} if extra_info is None else extra_info
-        if not isinstance(self.extra_info, dict):
-            raise TypeError(
-                "The extra_info should be an instance of python dict")
+        self.extra_info = {}

         # cache data
         self.last_call_name = None
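The v2 branch of `init_with_directory` above loads a module's own Python code dynamically: a directory named `<basename>` is expected to ship `<basename>/python/module.py` defining a `HubModule` class, which is imported as an ordinary package via `importlib`. A self-contained sketch of that dispatch (helper name is illustrative; the `__init__.py` files mirror what `create_module` writes):

```python
import importlib
import os
import sys

def load_module_class(directory):
    """Import <basename>.python.module from a module directory and
    return its HubModule class, as the v2 code path does."""
    basename = os.path.split(directory)[-1]
    dirname = os.path.join(*list(os.path.split(directory)[:-1]))
    sys.path.append(dirname)  # make the module dir importable as a package
    pymodule = importlib.import_module("{}.python.module".format(basename))
    return pymodule.HubModule
```

Note the import only works because `create_module` drops empty `__init__.py` files into both the module root and its `python/` subdirectory, turning the layout into a regular Python package.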
...
@@ -120,62 +281,21 @@ class Module(object):
         self.cache_fetch_dict = None
         self.cache_program = None

-        fp_lock = open(os.path.join(CONF_HOME, 'config.json'))
-        lock.flock(fp_lock, lock.LOCK_EX)
-        if name:
-            self._init_with_name(name=name, version=version)
-            lock.flock(fp_lock, lock.LOCK_UN)
-        elif module_dir:
-            self._init_with_module_file(module_dir=module_dir[0])
-            lock.flock(fp_lock, lock.LOCK_UN)
-            name = module_dir[0].split("/")[-1]
-            if len(module_dir) > 1:
-                version = module_dir[1]
-            else:
-                version = default_module_manager.search_module(name)[1]
-        elif signatures:
-            if processor:
-                if not issubclass(processor, BaseProcessor):
-                    raise TypeError(
-                        "Processor shoule be an instance of paddlehub.BaseProcessor"
-                    )
-            if assets:
-                self.assets = utils.to_list(assets)
-                # for asset in assets:
-                #     utils.check_path(assets)
-            self.processor = processor
-            self._generate_module_info(module_info)
-            self._init_with_signature(signatures=signatures)
-            lock.flock(fp_lock, lock.LOCK_UN)
-        else:
-            lock.flock(fp_lock, lock.LOCK_UN)
-            raise ValueError("Module initialized parameter is empty")
-        CacheUpdater(name, version).start()
-
-    def _init_with_name(self, name, version=None):
-        log_msg = "Installing %s module" % name
-        if version:
-            log_msg += "-%s" % version
-        logger.info(log_msg)
-        extra = {"command": "install"}
-        result, tips, module_dir = default_module_manager.install_module(
-            module_name=name, module_version=version, extra=extra)
-        if not result:
-            logger.error(tips)
-            raise RuntimeError(tips)
-        else:
-            logger.info(tips)
-        self._init_with_module_file(module_dir[0])
-
-    def _init_with_url(self, url):
-        utils.check_url(url)
-        result, tips, module_dir = default_downloader.download_file_and_uncompress(
-            url, save_path=".")
-        if not result:
-            logger.error(tips)
-            raise RuntimeError(tips)
-        else:
-            self._init_with_module_file(module_dir)
+        self.helper = ModuleHelper(directory)
+        exe = fluid.Executor(fluid.CPUPlace())
+        self.program, _, _ = fluid.io.load_inference_model(
+            self.helper.model_path(), executor=exe)
+        for block in self.program.blocks:
+            for op in block.ops:
+                if "op_callstack" in op.all_attrs():
+                    op._set_attr("op_callstack", [""])
+
+        self._load_processor()
+        self._load_assets()
+        self._recover_from_desc()
+        self._generate_sign_attr()
+        self._generate_extra_info()
+        self._restore_parameter(self.program)
+        self._recover_variable_info(self.program)

     def _dump_processor(self):
         import inspect
...
@@ -216,52 +336,6 @@ class Module(object):
             filepath = os.path.join(self.helper.assets_path(), file)
             self.assets.append(filepath)

-    def _init_with_module_file(self, module_dir):
-        checker = ModuleChecker(module_dir)
-        checker.check()
-        self.helper = ModuleHelper(module_dir)
-        with open(self.helper.module_desc_path(), "rb") as fi:
-            self.desc.ParseFromString(fi.read())
-        exe = fluid.Executor(fluid.CPUPlace())
-        self.program, _, _ = fluid.io.load_inference_model(
-            self.helper.model_path(), executor=exe)
-        for block in self.program.blocks:
-            for op in block.ops:
-                if "op_callstack" in op.all_attrs():
-                    op._set_attr("op_callstack", [""])
-        self._load_processor()
-        self._load_assets()
-        self._recover_from_desc()
-        self._generate_sign_attr()
-        self._generate_extra_info()
-        self._restore_parameter(self.program)
-        self._recover_variable_info(self.program)
-
-    def _init_with_signature(self, signatures):
-        self.name_prefix = HUB_VAR_PREFIX % self.name
-        self._process_signatures(signatures)
-        self._check_signatures()
-        self._generate_desc()
-        self._generate_sign_attr()
-        self._generate_extra_info()
-
-    def _init_with_program(self, program):
-        pass
-
-    def _process_signatures(self, signatures):
-        self.signatures = {}
-        self.program = signatures[0].inputs[0].block.program
-        for sign in signatures:
-            if sign.name in self.signatures:
-                raise ValueError(
-                    "Error! Signature array contains duplicated signatrues %s"
-                    % sign)
-            if self.default_signature is None and sign.for_predict:
-                self.default_signature = sign
-            self.signatures[sign.name] = sign
-
     def _restore_parameter(self, program):
         global_block = program.global_block()
         param_attrs = self.desc.attr.map.data['param_attrs']
...
@@ -302,21 +376,6 @@ class Module(object):
             self.__dict__["get_%s" % key] = functools.partial(
                 self.get_extra_info, key=key)

-    def _generate_module_info(self, module_info=None):
-        if not module_info:
-            self.module_info = {}
-        else:
-            if not utils.is_yaml_file(module_info):
-                logger.critical("Module info file should be yaml format")
-                exit(1)
-            self.module_info = yaml_parser.parse(module_info)
-        self.author = self.module_info.get('author', 'UNKNOWN')
-        self.author_email = self.module_info.get('author_email', 'UNKNOWN')
-        self.summary = self.module_info.get('summary', 'UNKNOWN')
-        self.type = self.module_info.get('type', 'UNKNOWN')
-        self.version = self.module_info.get('version', 'UNKNOWN')
-        self.name = self.module_info.get('name', 'UNKNOWN')
-
     def _generate_sign_attr(self):
         self._check_signatures()
         for sign in self.signatures:
...
@@ -369,21 +428,21 @@ class Module(object):
         default_signature_name = utils.from_module_attr_to_pyobj(
             self.desc.attr.map.data['default_signature'])
-        self.default_signature = self.signatures[
-            default_signature_name] if default_signature_name else None
+        self.default_signature = self.signatures[
+            default_signature_name].name if default_signature_name else None

         # recover module info
         module_info = self.desc.attr.map.data['module_info']
-        self.name = utils.from_module_attr_to_pyobj(
-            module_info.map.data['name'])
-        self.author = utils.from_module_attr_to_pyobj(
-            module_info.map.data['author'])
-        self.author_email = utils.from_module_attr_to_pyobj(
-            module_info.map.data['author_email'])
-        self.version = utils.from_module_attr_to_pyobj(
-            module_info.map.data['version'])
-        self.type = utils.from_module_attr_to_pyobj(
-            module_info.map.data['type'])
-        self.summary = utils.from_module_attr_to_pyobj(
-            module_info.map.data['summary'])
+        self._name = utils.from_module_attr_to_pyobj(
+            module_info.map.data['name'])
+        self._author = utils.from_module_attr_to_pyobj(
+            module_info.map.data['author'])
+        self._author_email = utils.from_module_attr_to_pyobj(
+            module_info.map.data['author_email'])
+        self._version = utils.from_module_attr_to_pyobj(
+            module_info.map.data['version'])
+        self._type = utils.from_module_attr_to_pyobj(
+            module_info.map.data['type'])
+        self._summary = utils.from_module_attr_to_pyobj(
+            module_info.map.data['summary'])

         # recover extra info
...
@@ -393,77 +452,9 @@ class Module(object):
...
@@ -393,77 +452,9 @@ class Module(object):
self
.
extra_info
[
key
]
=
utils
.
from_module_attr_to_pyobj
(
value
)
self
.
extra_info
[
key
]
=
utils
.
from_module_attr_to_pyobj
(
value
)
# recover name prefix
# recover name prefix
self
.
name_prefix
=
utils
.
from_module_attr_to_pyobj
(
self
.
_
name_prefix
=
utils
.
from_module_attr_to_pyobj
(
self
.
desc
.
attr
.
map
.
data
[
"name_prefix"
])
self
.
desc
.
attr
.
map
.
data
[
"name_prefix"
])
    def _generate_desc(self):
        # save fluid Parameter
        attr = self.desc.attr
        attr.type = module_desc_pb2.MAP
        param_attrs = attr.map.data['param_attrs']
        param_attrs.type = module_desc_pb2.MAP
        for param in self.program.global_block().iter_parameters():
            param_attr = param_attrs.map.data[param.name]
            paddle_helper.from_param_to_module_attr(param, param_attr)

        # save Variable Info
        var_infos = attr.map.data['var_infos']
        var_infos.type = module_desc_pb2.MAP
        for block in self.program.blocks:
            for var in block.vars.values():
                var_info = var_infos.map.data[var.name]
                var_info.type = module_desc_pb2.MAP
                utils.from_pyobj_to_module_attr(
                    var.stop_gradient, var_info.map.data['stop_gradient'])
                utils.from_pyobj_to_module_attr(block.idx,
                                                var_info.map.data['block_id'])

        # save signature info
        for key, sign in self.signatures.items():
            var = self.desc.sign2var[sign.name]
            feed_desc = var.feed_desc
            fetch_desc = var.fetch_desc
            feed_names = sign.feed_names
            fetch_names = sign.fetch_names
            for index, input in enumerate(sign.inputs):
                feed_var = feed_desc.add()
                feed_var.var_name = self.get_var_name_with_prefix(input.name)
                feed_var.alias = feed_names[index]
            for index, output in enumerate(sign.outputs):
                fetch_var = fetch_desc.add()
                fetch_var.var_name = self.get_var_name_with_prefix(output.name)
                fetch_var.alias = fetch_names[index]

        # save default signature
        utils.from_pyobj_to_module_attr(
            self.default_signature.name if self.default_signature else None,
            attr.map.data['default_signature'])

        # save name prefix
        utils.from_pyobj_to_module_attr(
            self.name_prefix, self.desc.attr.map.data["name_prefix"])

        # save module info
        module_info = attr.map.data['module_info']
        module_info.type = module_desc_pb2.MAP
        utils.from_pyobj_to_module_attr(self.name, module_info.map.data['name'])
        utils.from_pyobj_to_module_attr(self.version,
                                        module_info.map.data['version'])
        utils.from_pyobj_to_module_attr(self.author,
                                        module_info.map.data['author'])
        utils.from_pyobj_to_module_attr(self.author_email,
                                        module_info.map.data['author_email'])
        utils.from_pyobj_to_module_attr(self.type, module_info.map.data['type'])
        utils.from_pyobj_to_module_attr(self.summary,
                                        module_info.map.data['summary'])

        # save extra info
        extra_info = attr.map.data['extra_info']
        extra_info.type = module_desc_pb2.MAP
        for key, value in self.extra_info.items():
            utils.from_pyobj_to_module_attr(value, extra_info.map.data[key])
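The `_generate_desc` method above fills the protobuf MAP attributes level by level: module info and extra info each become a nested map under `desc.attr`. As a rough illustration of that layout, here is a hypothetical stand-in using plain dicts instead of the real `module_desc_pb2` messages (all function and variable names below are illustrative only, not the PaddleHub API):

```python
# Hypothetical sketch: mirror _generate_desc's nested-MAP layout with plain
# dicts instead of module_desc_pb2 protobuf messages. Illustrative only.
def build_desc_attr(module_info, extra_info):
    attr = {"module_info": {}, "extra_info": {}}
    # module info is written field by field, mirroring the calls above
    for field in ("name", "version", "author", "author_email", "type",
                  "summary"):
        attr["module_info"][field] = module_info.get(field)
    # extra info is copied entry by entry, as in the final loop above
    for key, value in extra_info.items():
        attr["extra_info"][key] = value
    return attr

desc_attr = build_desc_attr(
    {"name": "ernie", "version": "1.0.0"}, {"license": "Apache-2.0"})
```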
    def __call__(self, sign_name, data, use_gpu=False, batch_size=1, **kwargs):
        self.check_processor()
...
@@ -525,6 +516,10 @@ class Module(object):
        if not self.processor:
            raise ValueError("This Module is not callable!")
+    @property
+    def is_runable(self):
+        return self.default_signature != None
+
     def context(self,
                 sign_name=None,
                 for_test=False,
...
@@ -664,93 +659,3 @@ class Module(object):
                raise ValueError(
                    "All input and outputs variables in signature should come from the same Program"
                )
-    def serialize_to_path(self, path=None, exe=None):
-        self._check_signatures()
-        self._generate_desc()
-        # create module path for saving
-        if path is None:
-            path = os.path.join(".", self.name)
-        self.helper = ModuleHelper(path)
-        utils.mkdir(self.helper.module_dir)
-        # create module pb
-        module_desc = module_desc_pb2.ModuleDesc()
-        logger.info("PaddleHub version = %s" % version.hub_version)
-        logger.info(
-            "PaddleHub Module proto version = %s" % version.module_proto_version)
-        logger.info("Paddle version = %s" % paddle.__version__)
-        feeded_var_names = [
-            input.name for key, sign in self.signatures.items()
-            for input in sign.inputs
-        ]
-        target_vars = [
-            output for key, sign in self.signatures.items()
-            for output in sign.outputs
-        ]
-        feeded_var_names = list(set(feeded_var_names))
-        target_vars = list(set(target_vars))
-        # save inference program
-        program = self.program.clone()
-        for block in program.blocks:
-            for op in block.ops:
-                if "op_callstack" in op.all_attrs():
-                    op._set_attr("op_callstack", [""])
-        if not exe:
-            place = fluid.CPUPlace()
-            exe = fluid.Executor(place=place)
-        utils.mkdir(self.helper.model_path())
-        fluid.io.save_inference_model(
-            self.helper.model_path(),
-            feeded_var_names=list(feeded_var_names),
-            target_vars=list(target_vars),
-            main_program=program,
-            executor=exe)
-
-        with open(os.path.join(self.helper.model_path(), "__model__"),
-                  "rb") as file:
-            program_desc_str = file.read()
-            rename_program = fluid.framework.Program.parse_from_string(
-                program_desc_str)
-            varlist = {
-                var: block
-                for block in rename_program.blocks for var in block.vars
-                if self.get_name_prefix() not in var
-            }
-            for var, block in varlist.items():
-                old_name = var
-                new_name = self.get_var_name_with_prefix(old_name)
-                block._rename_var(old_name, new_name)
-            utils.mkdir(self.helper.model_path())
-            with open(os.path.join(self.helper.model_path(), "__model__"),
-                      "wb") as f:
-                f.write(rename_program.desc.serialize_to_string())
-
-        for file in os.listdir(self.helper.model_path()):
-            if file == "__model__" or self.get_name_prefix() in file:
-                continue
-            os.rename(
-                os.path.join(self.helper.model_path(), file),
-                os.path.join(self.helper.model_path(),
-                             self.get_var_name_with_prefix(file)))
-
-        # create processor file
-        if self.processor:
-            self._dump_processor()
-        # create assets
-        self._dump_assets()
-        # create check info
-        checker = ModuleChecker(self.helper.module_dir)
-        checker.generate_check_info()
-        # Serialize module_desc pb
-        module_pb = self.desc.SerializeToString()
-        with open(self.helper.module_desc_path(), "wb") as f:
-            f.write(module_pb)
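The `serialize_to_path` body shown above prefixes every variable name and parameter file that does not already carry the module's name prefix, so that programs merged from several modules cannot collide. A minimal sketch of that renaming rule, reduced to plain name strings (`@HUB_m@` is just an example prefix here; the real one comes from `get_name_prefix()`):

```python
def add_name_prefix(names, prefix):
    """Map each name to its prefixed form, skipping already-prefixed names.

    Sketch of the rename loop in serialize_to_path; operates on plain
    strings instead of fluid blocks and model files.
    """
    return {
        name: (name if prefix in name else prefix + name)
        for name in names
    }

mapping = add_name_prefix(["fc_0.w_0", "@HUB_m@fc_1.b_0"], "@HUB_m@")
```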
paddlehub/serving/bert_serving/bert_service.py
View file @ fa6f8d16
...
@@ -14,7 +14,6 @@
 # limitations under the License.
 import sys
-import time
 import paddlehub as hub
 import ujson
 import random
...
@@ -30,7 +29,7 @@ if is_py3:
     import http.client as httplib


-class BertService():
+class BertService(object):
     def __init__(self,
                  profile=False,
                  max_seq_len=128,
...
@@ -42,7 +41,7 @@ class BertService():
                  load_balance='round_robin'):
         self.process_id = process_id
         self.reader_flag = False
-        self.batch_size = 16
+        self.batch_size = 0
         self.max_seq_len = max_seq_len
         self.profile = profile
         self.model_name = model_name
...
@@ -55,15 +54,6 @@ class BertService():
         self.feed_var_names = ''
         self.retry = retry

-    def connect(self, server='127.0.0.1:8010'):
-        self.server_list.append(server)
-
-    def connect_all_server(self, server_list):
-        for server_str in server_list:
-            self.server_list.append(server_str)
-
     def data_convert(self, text):
         if self.reader_flag == False:
             module = hub.Module(name=self.model_name)
             inputs, outputs, program = module.context(
                 trainable=True, max_seq_len=self.max_seq_len)
...
@@ -79,10 +69,14 @@ class BertService():
                 do_lower_case=self.do_lower_case)
             self.reader_flag = True
-        return self.reader.data_generator(
-            batch_size=self.batch_size, phase='predict', data=text)

+    def add_server(self, server='127.0.0.1:8010'):
+        self.server_list.append(server)
+
+    def add_server_list(self, server_list):
+        for server_str in server_list:
+            self.server_list.append(server_str)

-    def infer(self, request_msg):
+    def request_server(self, request_msg):
         if self.load_balance == 'round_robin':
             try:
                 cur_con = httplib.HTTPConnection(
...
@@ -157,17 +151,13 @@ class BertService():
                     self.server_list)
                 return 'retry'
-    def encode(self, text):
+    def prepare_data(self, text):
-        if type(text) != list:
-            raise TypeError('Only support list')
         self.batch_size = len(text)
-        data_generator = self.data_convert(text)
-        start = time.time()
+        data_generator = self.reader.data_generator(
+            batch_size=self.batch_size, phase='predict', data=text)
-        request_time = 0
         request_msg = ""
-        result = []
         for run_step, batch in enumerate(data_generator(), start=1):
             request = []
-            copy_start = time.time()
             token_list = batch[0][0].reshape(-1).tolist()
             pos_list = batch[0][1].reshape(-1).tolist()
             sent_list = batch[0][2].reshape(-1).tolist()
...
@@ -184,54 +174,34 @@ class BertService():
                           si + 1) * self.max_seq_len]
                 request.append(instance_dict)
-            copy_time = time.time() - copy_start
             request = {"instances": request}
             request["max_seq_len"] = self.max_seq_len
             request["feed_var_names"] = self.feed_var_names
             request_msg = ujson.dumps(request)
             if self.show_ids:
                 logger.info(request_msg)
-            request_start = time.time()
-            response_msg = self.infer(request_msg)
+        return request_msg
+
+    def encode(self, text):
+        if type(text) != list:
+            raise TypeError('Only support list')
+        request_msg = self.prepare_data(text)
+        response_msg = self.request_server(request_msg)
         retry = 0
         while type(response_msg) == str and response_msg == 'retry':
             if retry < self.retry:
                 retry += 1
                 logger.info('Try to connect another servers')
-                response_msg = self.infer(request_msg)
+                response_msg = self.request_server(request_msg)
             else:
-                logger.error('Infer failed after {} times retry'.format(
+                logger.error('Request failed after {} times retry'.format(
                     self.retry))
                 break
+        result = []
         for msg in response_msg["instances"]:
             for sample in msg["instances"]:
                 result.append(sample["values"])
-            request_time += time.time() - request_start
-        total_time = time.time() - start
-        if self.profile:
-            return [
-                total_time, request_time, response_msg['op_time'],
-                response_msg['infer_time'], copy_time
-            ]
-        else:
-            return result
-
-
-def connect(input_text,
-            model_name,
-            max_seq_len=128,
-            show_ids=False,
-            do_lower_case=True,
-            server="127.0.0.1:8866",
-            retry=3):
-    # format of input_text like [["As long as"],]
-    bc = BertService(
-        max_seq_len=max_seq_len,
-        model_name=model_name,
-        show_ids=show_ids,
-        do_lower_case=do_lower_case,
-        retry=retry)
-    bc.connect(server)
-    result = bc.encode(input_text)
-    return result
+        return result
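`request_server` walks the server list in round-robin order and returns the string `'retry'` when a request fails, while `encode` retries up to `self.retry` times before giving up. That policy can be sketched in isolation with a simulated transport (everything below is a mock; nothing touches the real HTTP client or the BertService class):

```python
import itertools

def encode_with_retry(servers, send, max_retry=3):
    # Cycle through servers round-robin, as BertService's 'round_robin'
    # load_balance does; give up after max_retry failed attempts.
    pool = itertools.cycle(servers)
    for attempt in range(max_retry + 1):
        response = send(next(pool))
        if response != 'retry':
            return response
    return None  # all retries exhausted

# Simulated transport: the first server always fails, the second answers.
def fake_send(server):
    return 'retry' if server.endswith(':8010') else {'instances': []}

result = encode_with_retry(['127.0.0.1:8010', '127.0.0.1:8011'], fake_send)
```

With both servers configured, the first failed attempt rolls over to the second server; with only the failing server in the list, every attempt returns `'retry'` and the helper gives up with `None`.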
paddlehub/serving/bert_serving/bs_client.py
0 → 100644
View file @ fa6f8d16
from paddlehub.serving.bert_serving import bert_service


class BSClient(object):
    def __init__(self,
                 module_name,
                 server,
                 max_seq_len=20,
                 show_ids=False,
                 do_lower_case=True,
                 retry=3):
        self.bs = bert_service.BertService(
            model_name=module_name,
            max_seq_len=max_seq_len,
            show_ids=show_ids,
            do_lower_case=do_lower_case,
            retry=retry)
        self.bs.add_server(server=server)

    def get_result(self, input_text):
        return self.bs.encode(input_text)
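`BSClient` simply wires `add_server` into `__init__` and forwards `get_result` to `encode`. That call flow can be exercised without a running Bert Service server by stubbing the backend; everything below is a mock, and only the method names `add_server`/`encode`/`get_result` match the classes above (the 768-wide fake vectors are an arbitrary placeholder):

```python
class StubBertService(object):
    # Mock backend: records servers and returns fixed-size fake embeddings.
    def __init__(self):
        self.server_list = []

    def add_server(self, server='127.0.0.1:8010'):
        self.server_list.append(server)

    def encode(self, input_text):
        if type(input_text) != list:
            raise TypeError('Only support list')
        return [[0.0] * 768 for _ in input_text]  # one vector per sample


class StubBSClient(object):
    # Mirrors BSClient's shape: configure the server up front, then forward.
    def __init__(self, module_name, server):
        self.bs = StubBertService()
        self.bs.add_server(server=server)

    def get_result(self, input_text):
        return self.bs.encode(input_text)


client = StubBSClient("bert_chinese_L-12_H-768_A-12", "127.0.0.1:8866")
result = client.get_result([["As long as"]])
```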