PaddlePaddle / Serving

Commit 99f0e9a4
Authored on August 19, 2021 by TeslaZhao
Modify model conversion interfaces and model load methods
Parent: 005f10dc

Showing 5 changed files with 195 additions and 63 deletions (+195 −63)
python/paddle_serving_app/local_predict.py    +94 −25
python/paddle_serving_client/io/__init__.py   +25 −12
python/pipeline/operator.py                   +68 −24
python/pipeline/pipeline_client.py             +7  −1
python/pipeline/util.py                        +1  −1
python/paddle_serving_app/local_predict.py
@@ -22,6 +22,7 @@ import argparse
 from .proto import general_model_config_pb2 as m_config
 import paddle.inference as paddle_infer
 import logging
+import glob
 logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger("LocalPredictor")
@@ -51,6 +52,23 @@ class LocalPredictor(object):
         self.fetch_names_to_idx_ = {}
         self.fetch_names_to_type_ = {}
+
+    def search_suffix_files(self, model_path, target_suffix):
+        """
+        Find all files with the suffix xxx in the specified directory.
+        Args:
+            model_path: model directory, not None.
+            target_suffix: filenames with target suffix, not None. e.g: *.pdmodel
+        Returns:
+            file_list, None, [] or [path, ] .
+        """
+        if model_path is None or target_suffix is None:
+            return None
+        file_list = glob.glob(os.path.join(model_path, target_suffix))
+        return file_list
+
     def load_model_config(self,
                           model_path,
                           use_gpu=False,
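The new search_suffix_files helper is a thin wrapper over glob.glob. A minimal standalone sketch of the same lookup (the serving_server directory and its contents are hypothetical):

import glob
import os

def find_by_suffix(model_path, target_suffix):
    # Mirrors search_suffix_files: None for missing arguments, otherwise
    # every file in model_path matching the suffix pattern.
    if model_path is None or target_suffix is None:
        return None
    return glob.glob(os.path.join(model_path, target_suffix))

print(find_by_suffix("serving_server", "*.pdmodel"))    # e.g. ['serving_server/model.pdmodel']
print(find_by_suffix("serving_server", "*.pdiparams"))  # e.g. ['serving_server/params.pdiparams']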
@@ -97,11 +115,30 @@ class LocalPredictor(object):
         f = open(client_config, 'r')
         model_conf = google.protobuf.text_format.Merge(
             str(f.read()), model_conf)
         # Init paddle_infer config
+        # Paddle's model files and parameter files have multiple naming rules:
+        #   1) __model__, __params__
+        #   2) *.pdmodel, *.pdiparams
+        #   3) __model__, conv2d_1.w_0, conv2d_2.w_0, fc_1.w_0, conv2d_1.b_0, ...
+        pdmodel_file_list = self.search_suffix_files(model_path, "*.pdmodel")
+        pdiparams_file_list = self.search_suffix_files(model_path, "*.pdiparams")
         if os.path.exists(os.path.join(model_path, "__params__")):
+            # case 1) initializing
             config = paddle_infer.Config(
                 os.path.join(model_path, "__model__"),
                 os.path.join(model_path, "__params__"))
+        elif pdmodel_file_list and len(
+                pdmodel_file_list) > 0 and pdiparams_file_list and len(
+                    pdiparams_file_list) > 0:
+            # case 2) initializing
+            logger.info("pdmodel_file_list:{}, pdiparams_file_list:{}".format(
+                pdmodel_file_list, pdiparams_file_list))
+            config = paddle_infer.Config(pdmodel_file_list[0],
+                                         pdiparams_file_list[0])
         else:
+            # case 3) initializing.
             config = paddle_infer.Config(model_path)
         logger.info(
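With this change a caller no longer has to rename exported files to __model__/__params__; any of the three layouts above loads through the same entry point. A hedged usage sketch (the serving_server directory name is hypothetical):

from paddle_serving_app.local_predict import LocalPredictor

predictor = LocalPredictor()
# Works whether the directory holds __model__/__params__, a *.pdmodel/*.pdiparams
# pair, or __model__ plus one file per persistable variable.
predictor.load_model_config("serving_server", use_gpu=False)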
@@ -201,8 +238,9 @@ class LocalPredictor(object):
         Run model inference by Paddle Inference API.
         Args:
-            feed: feed var
-            fetch: fetch var
+            feed: feed var list, None is not allowed.
+            fetch: fetch var list, None allowed. when it is None, all fetch
+                   vars are returned. Otherwise, return fetch specified result.
             batch: batch data or not, False default.If batch is False, a new
                    dimension is added to header of the shape[np.newaxis].
             log_id: for logging
@@ -210,16 +248,8 @@ class LocalPredictor(object):
         Returns:
             fetch_map: dict
         """
-        if feed is None or fetch is None:
-            raise ValueError("You should specify feed and fetch for prediction.\
-                log_id:{}".format(log_id))
-        fetch_list = []
-        if isinstance(fetch, str):
-            fetch_list = [fetch]
-        elif isinstance(fetch, list):
-            fetch_list = fetch
-        else:
-            raise ValueError("Fetch only accepts string and list of string.\
-                log_id:{}".format(log_id))
+        if feed is None:
+            raise ValueError("You should specify feed vars for prediction.\
+                log_id:{}".format(log_id))
         feed_batch = []
@@ -231,18 +261,20 @@ class LocalPredictor(object):
             raise ValueError("Feed only accepts dict and list of dict.\
                 log_id:{}".format(log_id))

-        fetch_names = []
+        fetch_list = []
+        if fetch is not None:
+            if isinstance(fetch, str):
+                fetch_list = [fetch]
+            elif isinstance(fetch, list):
+                fetch_list = fetch

         # Filter invalid fetch names
+        fetch_names = []
         for key in fetch_list:
             if key in self.fetch_names_:
                 fetch_names.append(key)

-        if len(fetch_names) == 0:
-            raise ValueError(
-                "Fetch names should not be empty or out of saved fetch list.\
-                    log_id:{}".format(log_id))
-        # Assemble the input data of paddle predictor
+        # Assemble the input data of paddle predictor, and filter invalid inputs.
         input_names = self.predictor.get_input_names()
         for name in input_names:
             if isinstance(feed[name], list):
@@ -282,11 +314,15 @@ class LocalPredictor(object):
                 input_tensor_handle.copy_from_cpu(feed[name][np.newaxis, :])
             else:
                 input_tensor_handle.copy_from_cpu(feed[name])

         # set output tensor handlers
         output_tensor_handles = []
+        output_name_to_index_dict = {}
         output_names = self.predictor.get_output_names()
-        for output_name in output_names:
+        for i, output_name in enumerate(output_names):
             output_tensor_handle = self.predictor.get_output_handle(output_name)
             output_tensor_handles.append(output_tensor_handle)
+            output_name_to_index_dict[output_name] = i

         # Run inference
         self.predictor.run()
@@ -296,10 +332,43 @@ class LocalPredictor(object):
         for output_tensor_handle in output_tensor_handles:
             output = output_tensor_handle.copy_to_cpu()
             outputs.append(output)
+        outputs_len = len(outputs)

+        # Copy fetch vars. If fetch is None, it will copy all results from output_tensor_handles.
+        # Otherwise, it will copy the fields specified from output_tensor_handles.
         fetch_map = {}
-        for i, name in enumerate(fetch):
-            fetch_map[name] = outputs[i]
-            if len(output_tensor_handles[i].lod()) > 0:
-                fetch_map[name + ".lod"] = np.array(output_tensor_handles[i]
-                                                    .lod()[0]).astype('int32')
+        if fetch is None:
+            for i, name in enumerate(output_names):
+                fetch_map[name] = outputs[i]
+                if len(output_tensor_handles[i].lod()) > 0:
+                    fetch_map[name + ".lod"] = np.array(output_tensor_handles[
+                        i].lod()[0]).astype('int32')
+        else:
+            # Because the save_inference_model interface will increase the scale op
+            # in the network, the name of fetch_var is different from that in prototxt.
+            # Therefore, it is compatible with v0.6.x and the previous model save format,
+            # and here is compatible with the results that do not match.
+            fetch_match_num = 0
+            for i, name in enumerate(fetch):
+                output_index = output_name_to_index_dict.get(name)
+                if output_index is None:
+                    continue
+
+                fetch_map[name] = outputs[output_index]
+                fetch_match_num += 1
+                if len(output_tensor_handles[output_index].lod()) > 0:
+                    fetch_map[name + ".lod"] = np.array(output_tensor_handles[
+                        output_index].lod()[0]).astype('int32')
+
+            # Compatible with v0.6.x and lower versions model saving formats.
+            if fetch_match_num == 0:
+                logger.debug("fetch match num is 0. Retrain the model please!")
+                for i, name in enumerate(fetch):
+                    if i >= outputs_len:
+                        break
+
+                    fetch_map[name] = outputs[i]
+                    if len(output_tensor_handles[i].lod()) > 0:
+                        fetch_map[name + ".lod"] = np.array(
+                            output_tensor_handles[i].lod()[0]).astype('int32')

         return fetch_map
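After this change, fetch is optional in LocalPredictor.predict: omitting it returns every output tensor, while a fetch list is matched by name (with a positional fallback for pre-v0.6.x save formats). A hedged usage sketch; the model directory, input name x, shape, and output name price are hypothetical:

import numpy as np
from paddle_serving_app.local_predict import LocalPredictor

predictor = LocalPredictor()
predictor.load_model_config("serving_server")  # hypothetical model directory

feed = {"x": np.random.rand(1, 13).astype("float32")}  # hypothetical feed var

# fetch omitted: every output tensor is returned, keyed by its output name.
all_outputs = predictor.predict(feed=feed, batch=True)

# fetch given: only matching names are returned; unmatched names fall back
# to positional copying for older save formats.
some_outputs = predictor.predict(feed=feed, fetch=["price"], batch=True)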
python/paddle_serving_client/io/__init__.py
@@ -67,7 +67,6 @@ def save_dygraph_model(serving_model_folder, client_config_folder, model):
     }
     config = model_conf.GeneralModelConfig()
-    #int64 = 0; float32 = 1; int32 = 2;
     for key in feed_var_dict:
         feed_var = model_conf.FeedVar()
         feed_var.alias_name = key
@@ -127,7 +126,6 @@ def save_dygraph_model(serving_model_folder, client_config_folder, model):
 def var_type_conversion(dtype):
     """
     Variable type conversion
-
     Args:
         dtype: type of core.VarDesc.VarType.xxxxx
             (https://github.com/PaddlePaddle/Paddle/blob/release/2.1/python/paddle/framework/dtype.py)
@@ -184,7 +182,9 @@ def save_model(server_model_folder,
                main_program=None,
                encryption=False,
                key_len=128,
-               encrypt_conf=None):
+               encrypt_conf=None,
+               model_filename=None,
+               params_filename=None):
     executor = Executor(place=CPUPlace())
     feed_var_names = [feed_var_dict[x].name for x in feed_var_dict]
@@ -194,15 +194,27 @@ def save_model(server_model_folder,
         target_vars.append(fetch_var_dict[key])
         target_var_names.append(key)

     if not os.path.exists(server_model_folder):
         os.makedirs(server_model_folder)
     if not encryption:
-        save_inference_model(
-            server_model_folder,
-            feed_var_names,
-            target_vars,
-            executor,
-            model_filename="__model__",
-            params_filename="__params__",
-            main_program=main_program)
+        if not model_filename:
+            model_filename = "model.pdmodel"
+        if not params_filename:
+            params_filename = "params.pdiparams"
+
+        new_model_path = os.path.join(server_model_folder, model_filename)
+        new_params_path = os.path.join(server_model_folder, params_filename)
+
+        with open(new_model_path, "wb") as new_model_file:
+            new_model_file.write(main_program.desc.serialize_to_string())
+
+        paddle.static.save_vars(
+            executor=executor,
+            dirname=server_model_folder,
+            main_program=main_program,
+            vars=None,
+            predicate=paddle.static.io.is_persistable,
+            filename=params_filename)
     else:
         if encrypt_conf == None:
             aes_cipher = CipherFactory.create_cipher()
@@ -296,7 +308,8 @@ def inference_model_to_serving(dirname,
     }
     fetch_dict = {x.name: x for x in fetch_targets}
     save_model(serving_server, serving_client, feed_dict, fetch_dict,
-               inference_program, encryption, key_len, encrypt_conf)
+               inference_program, encryption, key_len, encrypt_conf,
+               model_filename, params_filename)
     feed_names = feed_dict.keys()
     fetch_names = fetch_dict.keys()
     return feed_names, fetch_names
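A hedged usage sketch of the updated conversion entry point: with the extra arguments threaded through to save_model, the converted server-side model can keep the original file names instead of being forced to __model__/__params__. The directory and keyword names below follow the snippet above but are otherwise assumptions:

from paddle_serving_client.io import inference_model_to_serving

# Convert a Paddle inference model stored as model.pdmodel / params.pdiparams
# (hypothetical file names) into Serving server/client configurations.
feed_names, fetch_names = inference_model_to_serving(
    dirname="inference_model",
    serving_server="serving_server",
    serving_client="serving_client",
    model_filename="model.pdmodel",
    params_filename="params.pdiparams")
print(feed_names, fetch_names)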
python/pipeline/operator.py
@@ -40,6 +40,7 @@ from .channel import (ThreadChannel, ProcessChannel, ChannelDataErrcode,
 from .util import NameGenerator
 from .profiler import UnsafeTimeProfiler as TimeProfiler
 from . import local_service_handler
+from .pipeline_client import PipelineClient as PPClient
 _LOGGER = logging.getLogger(__name__)
 _op_name_gen = NameGenerator("Op")
@@ -330,9 +331,8 @@ class Op(object):
         if self.client_type == 'brpc':
             client = Client()
             client.load_client_config(client_config)
-        # Replace with brpc-http once testing is complete.
-        # elif self.client_type == 'grpc':
-        #     client = MultiLangClient()
+        elif self.client_type == 'pipeline_grpc':
+            client = PPClient()
         elif self.client_type == 'local_predictor':
             if self.local_predictor is None:
                 raise ValueError("local predictor not yet created")
@@ -531,32 +531,72 @@ class Op(object):
         Returns:
             call_result: predict result
         """
-        err, err_info = ChannelData.check_batch_npdata(feed_batch)
-        if err != 0:
-            _LOGGER.critical(
-                self._log("Failed to run process: {}. Please override "
-                          "preprocess func.".format(err_info)))
-            os._exit(-1)
+        call_result = None
+        err_code = ChannelDataErrcode.OK.value
+        err_info = ""
+
         if self.client_type == "local_predictor":
+            err, err_info = ChannelData.check_batch_npdata(feed_batch)
+            if err != 0:
+                _LOGGER.error(
+                    self._log("Failed to run process: {}. feed_batch must be \
+                        npdata in process for local_predictor mode."
+                              .format(err_info)))
+                return call_result, ChannelDataErrcode.TYPE_ERROR.value, "feed_batch must be npdata"
+
             call_result = self.client.predict(
                 feed=feed_batch[0],
                 fetch=self._fetch_names,
                 batch=True,
                 log_id=typical_logid)
-        else:
+
+        elif self.client_type == "brpc":
+            err, err_info = ChannelData.check_batch_npdata(feed_batch)
+            if err != 0:
+                _LOGGER.error(
+                    self._log("Failed to run process: {}. feed_batch must be \
+                        npdata in process for brpc mode.".format(err_info)))
+                return call_result, ChannelDataErrcode.TYPE_ERROR.value, "feed_batch must be npdata"
             call_result = self.client.predict(
-                feed=feed_batch,
+                feed=feed_batch[0],
                 fetch=self._fetch_names,
                 batch=True,
                 log_id=typical_logid)
-        # Replace with HttpClient later.
-        '''
-        if isinstance(self.client, MultiLangClient):
-            if call_result is None or call_result["serving_status_code"] != 0:
-                return None
-            call_result.pop("serving_status_code")
-        '''
-        return call_result
+
+        elif self.client_type == "pipeline_grpc":
+            err, err_info = ChannelData.check_dictdata(feed_batch)
+            if err != 0:
+                _LOGGER.error(
+                    self._log("Failed to run process: {}. feed_batch must be \
+                        npdata in process for pipeline_grpc mode."
+                              .format(err_info)))
+                return call_result, ChannelDataErrcode.TYPE_ERROR.value, "feed_batch must be dict"
+
+            call_result = self.client.predict(
+                feed_dict=feed_batch[0],
+                fetch=self._fetch_names,
+                asyn=False,
+                profile=False)
+            if call_result is None:
+                _LOGGER.error(
+                    self._log("Failed in pipeline_grpc. call_result is None."))
+                return call_result, ChannelDataErrcode.UNKNOW.value, "pipeline_grpc error"
+            if call_result.err_no != 0:
+                _LOGGER.error(
+                    self._log("Failed in pipeline_grpc. err_no:{}, err_info:{}".
                              format(call_result.err_no, call_result.err_msg)))
+                return call_result, ChannelDataErrcode(
+                    call_result.err_no).value, call_result.err_msg
+
+            new_dict = {}
+            err_code = ChannelDataErrcode(call_result.err_no).value
+            err_info = call_result.err_msg
+            for idx, key in enumerate(call_result.key):
+                new_dict[key] = [call_result.value[idx]]
+            call_result = new_dict
+
+        return call_result, err_code, err_info

     def postprocess(self, input_data, fetch_data, data_id=0, log_id=0):
         """
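Every branch now returns a (call_result, err_code, err_info) tuple, and the pipeline_grpc branch additionally flattens the RPC response's parallel key/value lists into a dict. A minimal sketch of that flattening, using a stand-in response object rather than a real gRPC reply:

from types import SimpleNamespace

# Stand-in for the gRPC response: parallel `key` and `value` lists plus an error code.
resp = SimpleNamespace(err_no=0, err_msg="",
                       key=["label", "score"], value=["cat", "0.97"])

new_dict = {}
for idx, key in enumerate(resp.key):
    # Each value is wrapped in a list, matching how the Op feeds data downstream.
    new_dict[key] = [resp.value[idx]]

call_result, err_code, err_info = new_dict, resp.err_no, resp.err_msg
print(call_result)   # {'label': ['cat'], 'score': ['0.97']}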
@@ -891,16 +931,20 @@ class Op(object):
         midped_batch = None
         error_code = ChannelDataErrcode.OK.value
         error_info = ""
         if self._timeout <= 0:
             # No retry
             try:
                 if batch_input is False:
-                    midped_batch = self.process(feed_batch, typical_logid)
+                    midped_batch, error_code, error_info = self.process(
+                        feed_batch, typical_logid)
                 else:
                     midped_batch = []
                     for idx in range(len(feed_batch)):
-                        predict_res = self.process([feed_batch[idx]],
-                                                   typical_logid)
+                        predict_res, error_code, error_info = self.process(
+                            [feed_batch[idx]], typical_logid)
+                        if error_code != ChannelDataErrcode.OK.value:
+                            break
+
                         midped_batch.append(predict_res)
             except Exception as e:
                 error_code = ChannelDataErrcode.UNKNOW.value
@@ -913,14 +957,14 @@ class Op(object):
             try:
                 # time out for each process
                 if batch_input is False:
-                    midped_batch = func_timeout.func_timeout(
+                    midped_batch, error_code, error_info = func_timeout.func_timeout(
                         self._timeout,
                         self.process,
                         args=(feed_batch, typical_logid))
                 else:
                     midped_batch = []
                     for idx in range(len(feed_batch)):
-                        predict_res = func_timeout.func_timeout(
+                        predict_res, error_code, error_info = func_timeout.func_timeout(
                             self._timeout,
                             self.process,
                             args=([feed_batch[idx]], typical_logid))
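Because process() now returns a tuple, the timeout wrapper unpacks three values instead of one. A small standalone sketch of the func_timeout pattern used here; the worker function is a placeholder, not the real Op.process:

import func_timeout

def fake_process(feed_batch, logid):
    # Placeholder standing in for Op.process: returns (result, err_code, err_info).
    return {"out": feed_batch}, 0, ""

try:
    result, err_code, err_info = func_timeout.func_timeout(
        2.0,                       # seconds, playing the role of self._timeout
        fake_process,
        args=([{"x": 1}], 0))
except func_timeout.FunctionTimedOut:
    result, err_code, err_info = None, -1, "process timed out"

print(result, err_code, err_info)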
python/pipeline/pipeline_client.py
@@ -93,13 +93,19 @@ class PipelineClient(object):
     def _unpack_response_package(self, resp, fetch):
         return resp

-    def predict(self, feed_dict, fetch=None, asyn=False, profile=False):
+    def predict(self,
+                feed_dict,
+                fetch=None,
+                asyn=False,
+                profile=False,
+                log_id=0):
         if not isinstance(feed_dict, dict):
             raise TypeError(
                 "feed must be dict type with format: {name: value}.")
         if fetch is not None and not isinstance(fetch, list):
             raise TypeError("fetch must be list type with format: [name].")
         req = self._pack_request_package(feed_dict, profile)
+        req.logid = log_id
         if not asyn:
             resp = self._stub.inference(req)
             return self._unpack_response_package(resp, fetch)
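A hedged usage sketch of the new log_id argument; the import path, endpoint, and feed/fetch names are assumptions, not taken from this commit:

from paddle_serving_server.pipeline import PipelineClient  # assumed packaging path

client = PipelineClient()
client.connect(['127.0.0.1:18070'])  # hypothetical pipeline service endpoint

# log_id is copied into the request (req.logid) so server-side logs can be
# correlated with this call.
result = client.predict(feed_dict={"image": "base64-encoded-bytes"},
                        fetch=["res"],
                        log_id=10000)
print(result)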
python/pipeline/util.py
@@ -39,7 +39,7 @@ class AvailablePortGenerator(object):
 def port_is_available(port):
     with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
         sock.settimeout(2)
-        result = sock.connect_ex(('127.0.0.1', port))
+        result = sock.connect_ex(('0.0.0.0', port))
     if result != 0:
         return True
     else:
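The single change here switches the probe address from 127.0.0.1 to 0.0.0.0. A self-contained sketch of the resulting check, assuming the same connect_ex convention (non-zero means nothing is listening, so the port is free):

import socket
from contextlib import closing

def port_is_available(port):
    # connect_ex returns 0 only if something already listens on the port.
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        sock.settimeout(2)
        result = sock.connect_ex(('0.0.0.0', port))
    return result != 0

print(port_is_available(12000))  # True when port 12000 is not in use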