Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Serving
提交
99f0e9a4
S
Serving
项目概览
PaddlePaddle
/
Serving
接近 2 年 前同步成功
通知
186
Star
833
Fork
253
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
105
列表
看板
标记
里程碑
合并请求
10
Wiki
2
Wiki
分析
仓库
DevOps
项目成员
Pages
S
Serving
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
105
Issue
105
列表
看板
标记
里程碑
合并请求
10
合并请求
10
Pages
分析
分析
仓库分析
DevOps
Wiki
2
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
99f0e9a4
编写于
8月 19, 2021
作者:
T
TeslaZhao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Modify model conversion interfaces and model load methods
上级
005f10dc
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
195 addition
and
63 deletion
+195
-63
python/paddle_serving_app/local_predict.py
python/paddle_serving_app/local_predict.py
+94
-25
python/paddle_serving_client/io/__init__.py
python/paddle_serving_client/io/__init__.py
+25
-12
python/pipeline/operator.py
python/pipeline/operator.py
+68
-24
python/pipeline/pipeline_client.py
python/pipeline/pipeline_client.py
+7
-1
python/pipeline/util.py
python/pipeline/util.py
+1
-1
未找到文件。
python/paddle_serving_app/local_predict.py
浏览文件 @
99f0e9a4
...
@@ -22,6 +22,7 @@ import argparse
...
@@ -22,6 +22,7 @@ import argparse
from
.proto
import
general_model_config_pb2
as
m_config
from
.proto
import
general_model_config_pb2
as
m_config
import
paddle.inference
as
paddle_infer
import
paddle.inference
as
paddle_infer
import
logging
import
logging
import
glob
logging
.
basicConfig
(
format
=
"%(asctime)s - %(levelname)s - %(message)s"
)
logging
.
basicConfig
(
format
=
"%(asctime)s - %(levelname)s - %(message)s"
)
logger
=
logging
.
getLogger
(
"LocalPredictor"
)
logger
=
logging
.
getLogger
(
"LocalPredictor"
)
...
@@ -51,6 +52,23 @@ class LocalPredictor(object):
...
@@ -51,6 +52,23 @@ class LocalPredictor(object):
self
.
fetch_names_to_idx_
=
{}
self
.
fetch_names_to_idx_
=
{}
self
.
fetch_names_to_type_
=
{}
self
.
fetch_names_to_type_
=
{}
def
search_suffix_files
(
self
,
model_path
,
target_suffix
):
"""
Find all files with the suffix xxx in the specified directory.
Args:
model_path: model directory, not None.
target_suffix: filenames with target suffix, not None. e.g: *.pdmodel
Returns:
file_list, None, [] or [path, ] .
"""
if
model_path
is
None
or
target_suffix
is
None
:
return
None
file_list
=
glob
.
glob
(
os
.
path
.
join
(
model_path
,
target_suffix
))
return
file_list
def
load_model_config
(
self
,
def
load_model_config
(
self
,
model_path
,
model_path
,
use_gpu
=
False
,
use_gpu
=
False
,
...
@@ -97,11 +115,30 @@ class LocalPredictor(object):
...
@@ -97,11 +115,30 @@ class LocalPredictor(object):
f
=
open
(
client_config
,
'r'
)
f
=
open
(
client_config
,
'r'
)
model_conf
=
google
.
protobuf
.
text_format
.
Merge
(
model_conf
=
google
.
protobuf
.
text_format
.
Merge
(
str
(
f
.
read
()),
model_conf
)
str
(
f
.
read
()),
model_conf
)
# Init paddle_infer config
# Paddle's model files and parameter files have multiple naming rules:
# 1) __model__, __params__
# 2) *.pdmodel, *.pdiparams
# 3) __model__, conv2d_1.w_0, conv2d_2.w_0, fc_1.w_0, conv2d_1.b_0, ...
pdmodel_file_list
=
self
.
search_suffix_files
(
model_path
,
"*.pdmodel"
)
pdiparams_file_list
=
self
.
search_suffix_files
(
model_path
,
"*.pdiparams"
)
if
os
.
path
.
exists
(
os
.
path
.
join
(
model_path
,
"__params__"
)):
if
os
.
path
.
exists
(
os
.
path
.
join
(
model_path
,
"__params__"
)):
# case 1) initializing
config
=
paddle_infer
.
Config
(
config
=
paddle_infer
.
Config
(
os
.
path
.
join
(
model_path
,
"__model__"
),
os
.
path
.
join
(
model_path
,
"__model__"
),
os
.
path
.
join
(
model_path
,
"__params__"
))
os
.
path
.
join
(
model_path
,
"__params__"
))
elif
pdmodel_file_list
and
len
(
pdmodel_file_list
)
>
0
and
pdiparams_file_list
and
len
(
pdiparams_file_list
)
>
0
:
# case 2) initializing
logger
.
info
(
"pdmodel_file_list:{}, pdiparams_file_list:{}"
.
format
(
pdmodel_file_list
,
pdiparams_file_list
))
config
=
paddle_infer
.
Config
(
pdmodel_file_list
[
0
],
pdiparams_file_list
[
0
])
else
:
else
:
# case 3) initializing.
config
=
paddle_infer
.
Config
(
model_path
)
config
=
paddle_infer
.
Config
(
model_path
)
logger
.
info
(
logger
.
info
(
...
@@ -201,8 +238,9 @@ class LocalPredictor(object):
...
@@ -201,8 +238,9 @@ class LocalPredictor(object):
Run model inference by Paddle Inference API.
Run model inference by Paddle Inference API.
Args:
Args:
feed: feed var
feed: feed var list, None is not allowed.
fetch: fetch var
fetch: fetch var list, None allowed. when it is None, all fetch
vars are returned. Otherwise, return fetch specified result.
batch: batch data or not, False default.If batch is False, a new
batch: batch data or not, False default.If batch is False, a new
dimension is added to header of the shape[np.newaxis].
dimension is added to header of the shape[np.newaxis].
log_id: for logging
log_id: for logging
...
@@ -210,16 +248,8 @@ class LocalPredictor(object):
...
@@ -210,16 +248,8 @@ class LocalPredictor(object):
Returns:
Returns:
fetch_map: dict
fetch_map: dict
"""
"""
if
feed
is
None
or
fetch
is
None
:
if
feed
is
None
:
raise
ValueError
(
"You should specify feed and fetch for prediction.
\
raise
ValueError
(
"You should specify feed vars for prediction.
\
log_id:{}"
.
format
(
log_id
))
fetch_list
=
[]
if
isinstance
(
fetch
,
str
):
fetch_list
=
[
fetch
]
elif
isinstance
(
fetch
,
list
):
fetch_list
=
fetch
else
:
raise
ValueError
(
"Fetch only accepts string and list of string.
\
log_id:{}"
.
format
(
log_id
))
log_id:{}"
.
format
(
log_id
))
feed_batch
=
[]
feed_batch
=
[]
...
@@ -231,18 +261,20 @@ class LocalPredictor(object):
...
@@ -231,18 +261,20 @@ class LocalPredictor(object):
raise
ValueError
(
"Feed only accepts dict and list of dict.
\
raise
ValueError
(
"Feed only accepts dict and list of dict.
\
log_id:{}"
.
format
(
log_id
))
log_id:{}"
.
format
(
log_id
))
fetch_names
=
[]
fetch_list
=
[]
if
fetch
is
not
None
:
if
isinstance
(
fetch
,
str
):
fetch_list
=
[
fetch
]
elif
isinstance
(
fetch
,
list
):
fetch_list
=
fetch
# Filter invalid fetch names
# Filter invalid fetch names
fetch_names
=
[]
for
key
in
fetch_list
:
for
key
in
fetch_list
:
if
key
in
self
.
fetch_names_
:
if
key
in
self
.
fetch_names_
:
fetch_names
.
append
(
key
)
fetch_names
.
append
(
key
)
if
len
(
fetch_names
)
==
0
:
# Assemble the input data of paddle predictor, and filter invalid inputs.
raise
ValueError
(
"Fetch names should not be empty or out of saved fetch list.
\
log_id:{}"
.
format
(
log_id
))
# Assemble the input data of paddle predictor
input_names
=
self
.
predictor
.
get_input_names
()
input_names
=
self
.
predictor
.
get_input_names
()
for
name
in
input_names
:
for
name
in
input_names
:
if
isinstance
(
feed
[
name
],
list
):
if
isinstance
(
feed
[
name
],
list
):
...
@@ -282,11 +314,15 @@ class LocalPredictor(object):
...
@@ -282,11 +314,15 @@ class LocalPredictor(object):
input_tensor_handle
.
copy_from_cpu
(
feed
[
name
][
np
.
newaxis
,
:])
input_tensor_handle
.
copy_from_cpu
(
feed
[
name
][
np
.
newaxis
,
:])
else
:
else
:
input_tensor_handle
.
copy_from_cpu
(
feed
[
name
])
input_tensor_handle
.
copy_from_cpu
(
feed
[
name
])
# set output tensor handlers
output_tensor_handles
=
[]
output_tensor_handles
=
[]
output_name_to_index_dict
=
{}
output_names
=
self
.
predictor
.
get_output_names
()
output_names
=
self
.
predictor
.
get_output_names
()
for
output_name
in
output_names
:
for
i
,
output_name
in
enumerate
(
output_names
)
:
output_tensor_handle
=
self
.
predictor
.
get_output_handle
(
output_name
)
output_tensor_handle
=
self
.
predictor
.
get_output_handle
(
output_name
)
output_tensor_handles
.
append
(
output_tensor_handle
)
output_tensor_handles
.
append
(
output_tensor_handle
)
output_name_to_index_dict
[
output_name
]
=
i
# Run inference
# Run inference
self
.
predictor
.
run
()
self
.
predictor
.
run
()
...
@@ -296,10 +332,43 @@ class LocalPredictor(object):
...
@@ -296,10 +332,43 @@ class LocalPredictor(object):
for
output_tensor_handle
in
output_tensor_handles
:
for
output_tensor_handle
in
output_tensor_handles
:
output
=
output_tensor_handle
.
copy_to_cpu
()
output
=
output_tensor_handle
.
copy_to_cpu
()
outputs
.
append
(
output
)
outputs
.
append
(
output
)
outputs_len
=
len
(
outputs
)
# Copy fetch vars. If fetch is None, it will copy all results from output_tensor_handles.
# Otherwise, it will copy the fields specified from output_tensor_handles.
fetch_map
=
{}
fetch_map
=
{}
for
i
,
name
in
enumerate
(
fetch
):
if
fetch
is
None
:
fetch_map
[
name
]
=
outputs
[
i
]
for
i
,
name
in
enumerate
(
output_names
):
if
len
(
output_tensor_handles
[
i
].
lod
())
>
0
:
fetch_map
[
name
]
=
outputs
[
i
]
fetch_map
[
name
+
".lod"
]
=
np
.
array
(
output_tensor_handles
[
i
]
if
len
(
output_tensor_handles
[
i
].
lod
())
>
0
:
.
lod
()[
0
]).
astype
(
'int32'
)
fetch_map
[
name
+
".lod"
]
=
np
.
array
(
output_tensor_handles
[
i
].
lod
()[
0
]).
astype
(
'int32'
)
else
:
# Because the save_inference_model interface will increase the scale op
# in the network, the name of fetch_var is different from that in prototxt.
# Therefore, it is compatible with v0.6.x and the previous model save format,
# and here is compatible with the results that do not match.
fetch_match_num
=
0
for
i
,
name
in
enumerate
(
fetch
):
output_index
=
output_name_to_index_dict
.
get
(
name
)
if
output_index
is
None
:
continue
fetch_map
[
name
]
=
outputs
[
output_index
]
fetch_match_num
+=
1
if
len
(
output_tensor_handles
[
output_index
].
lod
())
>
0
:
fetch_map
[
name
+
".lod"
]
=
np
.
array
(
output_tensor_handles
[
output_index
].
lod
()[
0
]).
astype
(
'int32'
)
# Compatible with v0.6.x and lower versions model saving formats.
if
fetch_match_num
==
0
:
logger
.
debug
(
"fetch match num is 0. Retrain the model please!"
)
for
i
,
name
in
enumerate
(
fetch
):
if
i
>=
outputs_len
:
break
fetch_map
[
name
]
=
outputs
[
i
]
if
len
(
output_tensor_handles
[
i
].
lod
())
>
0
:
fetch_map
[
name
+
".lod"
]
=
np
.
array
(
output_tensor_handles
[
i
].
lod
()[
0
]).
astype
(
'int32'
)
return
fetch_map
return
fetch_map
python/paddle_serving_client/io/__init__.py
浏览文件 @
99f0e9a4
...
@@ -67,7 +67,6 @@ def save_dygraph_model(serving_model_folder, client_config_folder, model):
...
@@ -67,7 +67,6 @@ def save_dygraph_model(serving_model_folder, client_config_folder, model):
}
}
config
=
model_conf
.
GeneralModelConfig
()
config
=
model_conf
.
GeneralModelConfig
()
#int64 = 0; float32 = 1; int32 = 2;
for
key
in
feed_var_dict
:
for
key
in
feed_var_dict
:
feed_var
=
model_conf
.
FeedVar
()
feed_var
=
model_conf
.
FeedVar
()
feed_var
.
alias_name
=
key
feed_var
.
alias_name
=
key
...
@@ -127,7 +126,6 @@ def save_dygraph_model(serving_model_folder, client_config_folder, model):
...
@@ -127,7 +126,6 @@ def save_dygraph_model(serving_model_folder, client_config_folder, model):
def
var_type_conversion
(
dtype
):
def
var_type_conversion
(
dtype
):
"""
"""
Variable type conversion
Variable type conversion
Args:
Args:
dtype: type of core.VarDesc.VarType.xxxxx
dtype: type of core.VarDesc.VarType.xxxxx
(https://github.com/PaddlePaddle/Paddle/blob/release/2.1/python/paddle/framework/dtype.py)
(https://github.com/PaddlePaddle/Paddle/blob/release/2.1/python/paddle/framework/dtype.py)
...
@@ -184,7 +182,9 @@ def save_model(server_model_folder,
...
@@ -184,7 +182,9 @@ def save_model(server_model_folder,
main_program
=
None
,
main_program
=
None
,
encryption
=
False
,
encryption
=
False
,
key_len
=
128
,
key_len
=
128
,
encrypt_conf
=
None
):
encrypt_conf
=
None
,
model_filename
=
None
,
params_filename
=
None
):
executor
=
Executor
(
place
=
CPUPlace
())
executor
=
Executor
(
place
=
CPUPlace
())
feed_var_names
=
[
feed_var_dict
[
x
].
name
for
x
in
feed_var_dict
]
feed_var_names
=
[
feed_var_dict
[
x
].
name
for
x
in
feed_var_dict
]
...
@@ -194,15 +194,27 @@ def save_model(server_model_folder,
...
@@ -194,15 +194,27 @@ def save_model(server_model_folder,
target_vars
.
append
(
fetch_var_dict
[
key
])
target_vars
.
append
(
fetch_var_dict
[
key
])
target_var_names
.
append
(
key
)
target_var_names
.
append
(
key
)
if
not
os
.
path
.
exists
(
server_model_folder
):
os
.
makedirs
(
server_model_folder
)
if
not
encryption
:
if
not
encryption
:
save_inference_model
(
if
not
model_filename
:
server_model_folder
,
model_filename
=
"model.pdmodel"
feed_var_names
,
if
not
params_filename
:
target_vars
,
params_filename
=
"params.pdiparams"
executor
,
model_filename
=
"__model__"
,
new_model_path
=
os
.
path
.
join
(
server_model_folder
,
model_filename
)
params_filename
=
"__params__"
,
new_params_path
=
os
.
path
.
join
(
server_model_folder
,
params_filename
)
main_program
=
main_program
)
with
open
(
new_model_path
,
"wb"
)
as
new_model_file
:
new_model_file
.
write
(
main_program
.
desc
.
serialize_to_string
())
paddle
.
static
.
save_vars
(
executor
=
executor
,
dirname
=
server_model_folder
,
main_program
=
main_program
,
vars
=
None
,
predicate
=
paddle
.
static
.
io
.
is_persistable
,
filename
=
params_filename
)
else
:
else
:
if
encrypt_conf
==
None
:
if
encrypt_conf
==
None
:
aes_cipher
=
CipherFactory
.
create_cipher
()
aes_cipher
=
CipherFactory
.
create_cipher
()
...
@@ -296,7 +308,8 @@ def inference_model_to_serving(dirname,
...
@@ -296,7 +308,8 @@ def inference_model_to_serving(dirname,
}
}
fetch_dict
=
{
x
.
name
:
x
for
x
in
fetch_targets
}
fetch_dict
=
{
x
.
name
:
x
for
x
in
fetch_targets
}
save_model
(
serving_server
,
serving_client
,
feed_dict
,
fetch_dict
,
save_model
(
serving_server
,
serving_client
,
feed_dict
,
fetch_dict
,
inference_program
,
encryption
,
key_len
,
encrypt_conf
)
inference_program
,
encryption
,
key_len
,
encrypt_conf
,
model_filename
,
params_filename
)
feed_names
=
feed_dict
.
keys
()
feed_names
=
feed_dict
.
keys
()
fetch_names
=
fetch_dict
.
keys
()
fetch_names
=
fetch_dict
.
keys
()
return
feed_names
,
fetch_names
return
feed_names
,
fetch_names
python/pipeline/operator.py
浏览文件 @
99f0e9a4
...
@@ -40,6 +40,7 @@ from .channel import (ThreadChannel, ProcessChannel, ChannelDataErrcode,
...
@@ -40,6 +40,7 @@ from .channel import (ThreadChannel, ProcessChannel, ChannelDataErrcode,
from
.util
import
NameGenerator
from
.util
import
NameGenerator
from
.profiler
import
UnsafeTimeProfiler
as
TimeProfiler
from
.profiler
import
UnsafeTimeProfiler
as
TimeProfiler
from
.
import
local_service_handler
from
.
import
local_service_handler
from
.pipeline_client
import
PipelineClient
as
PPClient
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
(
__name__
)
_op_name_gen
=
NameGenerator
(
"Op"
)
_op_name_gen
=
NameGenerator
(
"Op"
)
...
@@ -330,9 +331,8 @@ class Op(object):
...
@@ -330,9 +331,8 @@ class Op(object):
if
self
.
client_type
==
'brpc'
:
if
self
.
client_type
==
'brpc'
:
client
=
Client
()
client
=
Client
()
client
.
load_client_config
(
client_config
)
client
.
load_client_config
(
client_config
)
# 待测试完成后,使用brpc-http替代。
elif
self
.
client_type
==
'pipeline_grpc'
:
# elif self.client_type == 'grpc':
client
=
PPClient
()
# client = MultiLangClient()
elif
self
.
client_type
==
'local_predictor'
:
elif
self
.
client_type
==
'local_predictor'
:
if
self
.
local_predictor
is
None
:
if
self
.
local_predictor
is
None
:
raise
ValueError
(
"local predictor not yet created"
)
raise
ValueError
(
"local predictor not yet created"
)
...
@@ -531,32 +531,72 @@ class Op(object):
...
@@ -531,32 +531,72 @@ class Op(object):
Returns:
Returns:
call_result: predict result
call_result: predict result
"""
"""
err
,
err_info
=
ChannelData
.
check_batch_npdata
(
feed_batch
)
if
err
!=
0
:
call_result
=
None
_LOGGER
.
critical
(
err_code
=
ChannelDataErrcode
.
OK
.
value
self
.
_log
(
"Failed to run process: {}. Please override "
err_info
=
""
"preprocess func."
.
format
(
err_info
)))
os
.
_exit
(
-
1
)
if
self
.
client_type
==
"local_predictor"
:
if
self
.
client_type
==
"local_predictor"
:
err
,
err_info
=
ChannelData
.
check_batch_npdata
(
feed_batch
)
if
err
!=
0
:
_LOGGER
.
error
(
self
.
_log
(
"Failed to run process: {}. feed_batch must be
\
npdata in process for local_predictor mode."
.
format
(
err_info
)))
return
call_result
,
ChannelDataErrcode
.
TYPE_ERROR
.
value
,
"feed_batch must be npdata"
call_result
=
self
.
client
.
predict
(
call_result
=
self
.
client
.
predict
(
feed
=
feed_batch
[
0
],
feed
=
feed_batch
[
0
],
fetch
=
self
.
_fetch_names
,
fetch
=
self
.
_fetch_names
,
batch
=
True
,
batch
=
True
,
log_id
=
typical_logid
)
log_id
=
typical_logid
)
else
:
elif
self
.
client_type
==
"brpc"
:
err
,
err_info
=
ChannelData
.
check_batch_npdata
(
feed_batch
)
if
err
!=
0
:
_LOGGER
.
error
(
self
.
_log
(
"Failed to run process: {}. feed_batch must be
\
npdata in process for brpc mode."
.
format
(
err_info
)))
return
call_result
,
ChannelDataErrcode
.
TYPE_ERROR
.
value
,
"feed_batch must be npdata"
call_result
=
self
.
client
.
predict
(
call_result
=
self
.
client
.
predict
(
feed
=
feed_batch
,
feed
=
feed_batch
[
0
]
,
fetch
=
self
.
_fetch_names
,
fetch
=
self
.
_fetch_names
,
batch
=
True
,
batch
=
True
,
log_id
=
typical_logid
)
log_id
=
typical_logid
)
# 后续用HttpClient替代
'''
elif
self
.
client_type
==
"pipeline_grpc"
:
if isinstance(self.client, MultiLangClient):
err
,
err_info
=
ChannelData
.
check_dictdata
(
feed_batch
)
if call_result is None or call_result["serving_status_code"] != 0:
if
err
!=
0
:
return None
_LOGGER
.
error
(
call_result.pop("serving_status_code")
self
.
_log
(
"Failed to run process: {}. feed_batch must be
\
'''
npdata in process for pipeline_grpc mode."
return
call_result
.
format
(
err_info
)))
return
call_result
,
ChannelDataErrcode
.
TYPE_ERROR
.
value
,
"feed_batch must be dict"
call_result
=
self
.
client
.
predict
(
feed_dict
=
feed_batch
[
0
],
fetch
=
self
.
_fetch_names
,
asyn
=
False
,
profile
=
False
)
if
call_result
is
None
:
_LOGGER
.
error
(
self
.
_log
(
"Failed in pipeline_grpc. call_result is None."
))
return
call_result
,
ChannelDataErrcode
.
UNKNOW
.
value
,
"pipeline_grpc error"
if
call_result
.
err_no
!=
0
:
_LOGGER
.
error
(
self
.
_log
(
"Failed in pipeline_grpc. err_no:{}, err_info:{}"
.
format
(
call_result
.
err_no
,
call_result
.
err_msg
)))
return
call_result
,
ChannelDataErrcode
(
call_result
.
err_no
).
value
,
call_result
.
err_msg
new_dict
=
{}
err_code
=
ChannelDataErrcode
(
call_result
.
err_no
).
value
err_info
=
call_result
.
err_msg
for
idx
,
key
in
enumerate
(
call_result
.
key
):
new_dict
[
key
]
=
[
call_result
.
value
[
idx
]]
call_result
=
new_dict
return
call_result
,
err_code
,
err_info
def
postprocess
(
self
,
input_data
,
fetch_data
,
data_id
=
0
,
log_id
=
0
):
def
postprocess
(
self
,
input_data
,
fetch_data
,
data_id
=
0
,
log_id
=
0
):
"""
"""
...
@@ -891,16 +931,20 @@ class Op(object):
...
@@ -891,16 +931,20 @@ class Op(object):
midped_batch
=
None
midped_batch
=
None
error_code
=
ChannelDataErrcode
.
OK
.
value
error_code
=
ChannelDataErrcode
.
OK
.
value
error_info
=
""
if
self
.
_timeout
<=
0
:
if
self
.
_timeout
<=
0
:
# No retry
# No retry
try
:
try
:
if
batch_input
is
False
:
if
batch_input
is
False
:
midped_batch
=
self
.
process
(
feed_batch
,
typical_logid
)
midped_batch
,
error_code
,
error_info
=
self
.
process
(
feed_batch
,
typical_logid
)
else
:
else
:
midped_batch
=
[]
midped_batch
=
[]
for
idx
in
range
(
len
(
feed_batch
)):
for
idx
in
range
(
len
(
feed_batch
)):
predict_res
=
self
.
process
([
feed_batch
[
idx
]],
predict_res
,
error_code
,
error_info
=
self
.
process
(
typical_logid
)
[
feed_batch
[
idx
]],
typical_logid
)
if
error_code
!=
ChannelDataErrcode
.
OK
.
value
:
break
midped_batch
.
append
(
predict_res
)
midped_batch
.
append
(
predict_res
)
except
Exception
as
e
:
except
Exception
as
e
:
error_code
=
ChannelDataErrcode
.
UNKNOW
.
value
error_code
=
ChannelDataErrcode
.
UNKNOW
.
value
...
@@ -913,14 +957,14 @@ class Op(object):
...
@@ -913,14 +957,14 @@ class Op(object):
try
:
try
:
# time out for each process
# time out for each process
if
batch_input
is
False
:
if
batch_input
is
False
:
midped_batch
=
func_timeout
.
func_timeout
(
midped_batch
,
error_code
,
error_info
=
func_timeout
.
func_timeout
(
self
.
_timeout
,
self
.
_timeout
,
self
.
process
,
self
.
process
,
args
=
(
feed_batch
,
typical_logid
))
args
=
(
feed_batch
,
typical_logid
))
else
:
else
:
midped_batch
=
[]
midped_batch
=
[]
for
idx
in
range
(
len
(
feed_batch
)):
for
idx
in
range
(
len
(
feed_batch
)):
predict_res
=
func_timeout
.
func_timeout
(
predict_res
,
error_code
,
error_info
=
func_timeout
.
func_timeout
(
self
.
_timeout
,
self
.
_timeout
,
self
.
process
,
self
.
process
,
args
=
([
feed_batch
[
idx
]],
typical_logid
))
args
=
([
feed_batch
[
idx
]],
typical_logid
))
...
...
python/pipeline/pipeline_client.py
浏览文件 @
99f0e9a4
...
@@ -93,13 +93,19 @@ class PipelineClient(object):
...
@@ -93,13 +93,19 @@ class PipelineClient(object):
def
_unpack_response_package
(
self
,
resp
,
fetch
):
def
_unpack_response_package
(
self
,
resp
,
fetch
):
return
resp
return
resp
def
predict
(
self
,
feed_dict
,
fetch
=
None
,
asyn
=
False
,
profile
=
False
):
def
predict
(
self
,
feed_dict
,
fetch
=
None
,
asyn
=
False
,
profile
=
False
,
log_id
=
0
):
if
not
isinstance
(
feed_dict
,
dict
):
if
not
isinstance
(
feed_dict
,
dict
):
raise
TypeError
(
raise
TypeError
(
"feed must be dict type with format: {name: value}."
)
"feed must be dict type with format: {name: value}."
)
if
fetch
is
not
None
and
not
isinstance
(
fetch
,
list
):
if
fetch
is
not
None
and
not
isinstance
(
fetch
,
list
):
raise
TypeError
(
"fetch must be list type with format: [name]."
)
raise
TypeError
(
"fetch must be list type with format: [name]."
)
req
=
self
.
_pack_request_package
(
feed_dict
,
profile
)
req
=
self
.
_pack_request_package
(
feed_dict
,
profile
)
req
.
logid
=
log_id
if
not
asyn
:
if
not
asyn
:
resp
=
self
.
_stub
.
inference
(
req
)
resp
=
self
.
_stub
.
inference
(
req
)
return
self
.
_unpack_response_package
(
resp
,
fetch
)
return
self
.
_unpack_response_package
(
resp
,
fetch
)
...
...
python/pipeline/util.py
浏览文件 @
99f0e9a4
...
@@ -39,7 +39,7 @@ class AvailablePortGenerator(object):
...
@@ -39,7 +39,7 @@ class AvailablePortGenerator(object):
def
port_is_available
(
port
):
def
port_is_available
(
port
):
with
closing
(
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
))
as
sock
:
with
closing
(
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
))
as
sock
:
sock
.
settimeout
(
2
)
sock
.
settimeout
(
2
)
result
=
sock
.
connect_ex
((
'
127.0.0.1
'
,
port
))
result
=
sock
.
connect_ex
((
'
0.0.0.0
'
,
port
))
if
result
!=
0
:
if
result
!=
0
:
return
True
return
True
else
:
else
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录