Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Serving
提交
f3eb9d4a
S
Serving
项目概览
PaddlePaddle
/
Serving
大约 1 年 前同步成功
通知
186
Star
833
Fork
253
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
105
列表
看板
标记
里程碑
合并请求
10
Wiki
2
Wiki
分析
仓库
DevOps
项目成员
Pages
S
Serving
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
105
Issue
105
列表
看板
标记
里程碑
合并请求
10
合并请求
10
Pages
分析
分析
仓库分析
DevOps
Wiki
2
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f3eb9d4a
编写于
7月 02, 2020
作者:
B
barriery
提交者:
GitHub
7月 02, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #710 from barrierye/pipeline-update
update pipeline
上级
0dc8b905
22b168ca
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
127 addition
and
49 deletion
+127
-49
python/pipeline/channel.py
python/pipeline/channel.py
+37
-8
python/pipeline/operator.py
python/pipeline/operator.py
+78
-33
python/pipeline/pipeline_client.py
python/pipeline/pipeline_client.py
+5
-5
python/pipeline/pipeline_server.py
python/pipeline/pipeline_server.py
+5
-1
python/pipeline/profiler.py
python/pipeline/profiler.py
+2
-2
未找到文件。
python/pipeline/channel.py
浏览文件 @
f3eb9d4a
...
@@ -27,7 +27,7 @@ import logging
...
@@ -27,7 +27,7 @@ import logging
import
enum
import
enum
import
copy
import
copy
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
()
class
ChannelDataEcode
(
enum
.
Enum
):
class
ChannelDataEcode
(
enum
.
Enum
):
...
@@ -92,7 +92,16 @@ class ChannelData(object):
...
@@ -92,7 +92,16 @@ class ChannelData(object):
def
check_dictdata
(
dictdata
):
def
check_dictdata
(
dictdata
):
ecode
=
ChannelDataEcode
.
OK
.
value
ecode
=
ChannelDataEcode
.
OK
.
value
error_info
=
None
error_info
=
None
if
not
isinstance
(
dictdata
,
dict
):
if
isinstance
(
dictdata
,
list
):
# batch data
for
sample
in
dictdata
:
if
not
isinstance
(
sample
,
dict
):
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
"be dict, but get {}."
.
format
(
type
(
sample
))
break
elif
not
isinstance
(
dictdata
,
dict
):
# batch size = 1
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
error_info
=
"the value of data must "
\
"be dict, but get {}."
.
format
(
type
(
dictdata
))
"be dict, but get {}."
.
format
(
type
(
dictdata
))
...
@@ -102,12 +111,32 @@ class ChannelData(object):
...
@@ -102,12 +111,32 @@ class ChannelData(object):
def
check_npdata
(
npdata
):
def
check_npdata
(
npdata
):
ecode
=
ChannelDataEcode
.
OK
.
value
ecode
=
ChannelDataEcode
.
OK
.
value
error_info
=
None
error_info
=
None
if
isinstance
(
npdata
,
list
):
# batch data
for
sample
in
npdata
:
if
not
isinstance
(
sample
,
dict
):
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
"be dict, but get {}."
.
format
(
type
(
sample
))
break
for
_
,
value
in
sample
.
items
():
if
not
isinstance
(
value
,
np
.
ndarray
):
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
"be np.ndarray, but get {}."
.
format
(
type
(
value
))
return
ecode
,
error_info
elif
isinstance
(
npdata
,
dict
):
# batch_size = 1
for
_
,
value
in
npdata
.
items
():
for
_
,
value
in
npdata
.
items
():
if
not
isinstance
(
value
,
np
.
ndarray
):
if
not
isinstance
(
value
,
np
.
ndarray
):
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
error_info
=
"the value of data must "
\
"be np.ndarray, but get {}."
.
format
(
type
(
value
))
"be np.ndarray, but get {}."
.
format
(
type
(
value
))
break
break
else
:
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
"be dict, but get {}."
.
format
(
type
(
npdata
))
return
ecode
,
error_info
return
ecode
,
error_info
def
parse
(
self
):
def
parse
(
self
):
...
...
python/pipeline/operator.py
浏览文件 @
f3eb9d4a
...
@@ -19,13 +19,14 @@ from paddle_serving_client import MultiLangClient, Client
...
@@ -19,13 +19,14 @@ from paddle_serving_client import MultiLangClient, Client
from
concurrent
import
futures
from
concurrent
import
futures
import
logging
import
logging
import
func_timeout
import
func_timeout
import
os
from
numpy
import
*
from
numpy
import
*
from
.proto
import
pipeline_service_pb2
from
.proto
import
pipeline_service_pb2
from
.channel
import
ThreadChannel
,
ProcessChannel
,
ChannelDataEcode
,
ChannelData
,
ChannelDataType
from
.channel
import
ThreadChannel
,
ProcessChannel
,
ChannelDataEcode
,
ChannelData
,
ChannelDataType
from
.util
import
NameGenerator
from
.util
import
NameGenerator
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
()
_op_name_gen
=
NameGenerator
(
"Op"
)
_op_name_gen
=
NameGenerator
(
"Op"
)
...
@@ -59,6 +60,10 @@ class Op(object):
...
@@ -59,6 +60,10 @@ class Op(object):
self
.
_outputs
=
[]
self
.
_outputs
=
[]
self
.
_profiler
=
None
self
.
_profiler
=
None
# only for multithread
self
.
_for_init_op_lock
=
threading
.
Lock
()
self
.
_succ_init_op
=
False
def
init_profiler
(
self
,
profiler
):
def
init_profiler
(
self
,
profiler
):
self
.
_profiler
=
profiler
self
.
_profiler
=
profiler
...
@@ -71,18 +76,19 @@ class Op(object):
...
@@ -71,18 +76,19 @@ class Op(object):
fetch_names
):
fetch_names
):
if
self
.
with_serving
==
False
:
if
self
.
with_serving
==
False
:
_LOGGER
.
debug
(
"{} no client"
.
format
(
self
.
name
))
_LOGGER
.
debug
(
"{} no client"
.
format
(
self
.
name
))
return
return
None
_LOGGER
.
debug
(
"{} client_config: {}"
.
format
(
self
.
name
,
client_config
))
_LOGGER
.
debug
(
"{} client_config: {}"
.
format
(
self
.
name
,
client_config
))
_LOGGER
.
debug
(
"{} fetch_names: {}"
.
format
(
self
.
name
,
fetch_names
))
_LOGGER
.
debug
(
"{} fetch_names: {}"
.
format
(
self
.
name
,
fetch_names
))
if
client_type
==
'brpc'
:
if
client_type
==
'brpc'
:
self
.
_
client
=
Client
()
client
=
Client
()
self
.
_
client
.
load_client_config
(
client_config
)
client
.
load_client_config
(
client_config
)
elif
client_type
==
'grpc'
:
elif
client_type
==
'grpc'
:
self
.
_
client
=
MultiLangClient
()
client
=
MultiLangClient
()
else
:
else
:
raise
ValueError
(
"unknow client type: {}"
.
format
(
client_type
))
raise
ValueError
(
"unknow client type: {}"
.
format
(
client_type
))
self
.
_
client
.
connect
(
server_endpoints
)
client
.
connect
(
server_endpoints
)
self
.
_fetch_names
=
fetch_names
self
.
_fetch_names
=
fetch_names
return
client
def
_get_input_channel
(
self
):
def
_get_input_channel
(
self
):
return
self
.
_input
return
self
.
_input
...
@@ -130,19 +136,17 @@ class Op(object):
...
@@ -130,19 +136,17 @@ class Op(object):
(
_
,
input_dict
),
=
input_dicts
.
items
()
(
_
,
input_dict
),
=
input_dicts
.
items
()
return
input_dict
return
input_dict
def
process
(
self
,
feed_dict
):
def
process
(
self
,
client_predict_handler
,
feed_dict
):
err
,
err_info
=
ChannelData
.
check_npdata
(
feed_dict
)
err
,
err_info
=
ChannelData
.
check_npdata
(
feed_dict
)
if
err
!=
0
:
if
err
!=
0
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"{} Please override preprocess func."
.
format
(
err_info
))
"{} Please override preprocess func."
.
format
(
err_info
))
_LOGGER
.
debug
(
self
.
_log
(
'feed_dict: {}'
.
format
(
feed_dict
)))
call_result
=
client_predict_handler
(
_LOGGER
.
debug
(
self
.
_log
(
'fetch: {}'
.
format
(
self
.
_fetch_names
)))
call_result
=
self
.
_client
.
predict
(
feed
=
feed_dict
,
fetch
=
self
.
_fetch_names
)
feed
=
feed_dict
,
fetch
=
self
.
_fetch_names
)
_LOGGER
.
debug
(
self
.
_log
(
"get call_result"
))
_LOGGER
.
debug
(
self
.
_log
(
"get call_result"
))
return
call_result
return
call_result
def
postprocess
(
self
,
fetch_dict
):
def
postprocess
(
self
,
input_dict
,
fetch_dict
):
return
fetch_dict
return
fetch_dict
def
stop
(
self
):
def
stop
(
self
):
...
@@ -174,7 +178,7 @@ class Op(object):
...
@@ -174,7 +178,7 @@ class Op(object):
p
=
multiprocessing
.
Process
(
p
=
multiprocessing
.
Process
(
target
=
self
.
_run
,
target
=
self
.
_run
,
args
=
(
concurrency_idx
,
self
.
_get_input_channel
(),
args
=
(
concurrency_idx
,
self
.
_get_input_channel
(),
self
.
_get_output_channels
(),
client_type
))
self
.
_get_output_channels
(),
client_type
,
False
))
p
.
start
()
p
.
start
()
proces
.
append
(
p
)
proces
.
append
(
p
)
return
proces
return
proces
...
@@ -185,12 +189,12 @@ class Op(object):
...
@@ -185,12 +189,12 @@ class Op(object):
t
=
threading
.
Thread
(
t
=
threading
.
Thread
(
target
=
self
.
_run
,
target
=
self
.
_run
,
args
=
(
concurrency_idx
,
self
.
_get_input_channel
(),
args
=
(
concurrency_idx
,
self
.
_get_input_channel
(),
self
.
_get_output_channels
(),
client_type
))
self
.
_get_output_channels
(),
client_type
,
True
))
t
.
start
()
t
.
start
()
threads
.
append
(
t
)
threads
.
append
(
t
)
return
threads
return
threads
def
load_user_resources
(
self
):
def
init_op
(
self
):
pass
pass
def
_run_preprocess
(
self
,
parsed_data
,
data_id
,
log_func
):
def
_run_preprocess
(
self
,
parsed_data
,
data_id
,
log_func
):
...
@@ -222,13 +226,15 @@ class Op(object):
...
@@ -222,13 +226,15 @@ class Op(object):
data_id
=
data_id
)
data_id
=
data_id
)
return
preped_data
,
error_channeldata
return
preped_data
,
error_channeldata
def
_run_process
(
self
,
preped_data
,
data_id
,
log_func
):
def
_run_process
(
self
,
client_predict_handler
,
preped_data
,
data_id
,
log_func
):
midped_data
,
error_channeldata
=
None
,
None
midped_data
,
error_channeldata
=
None
,
None
if
self
.
with_serving
:
if
self
.
with_serving
:
ecode
=
ChannelDataEcode
.
OK
.
value
ecode
=
ChannelDataEcode
.
OK
.
value
if
self
.
_timeout
<=
0
:
if
self
.
_timeout
<=
0
:
try
:
try
:
midped_data
=
self
.
process
(
preped_data
)
midped_data
=
self
.
process
(
client_predict_handler
,
preped_data
)
except
Exception
as
e
:
except
Exception
as
e
:
ecode
=
ChannelDataEcode
.
UNKNOW
.
value
ecode
=
ChannelDataEcode
.
UNKNOW
.
value
error_info
=
log_func
(
e
)
error_info
=
log_func
(
e
)
...
@@ -237,7 +243,11 @@ class Op(object):
...
@@ -237,7 +243,11 @@ class Op(object):
for
i
in
range
(
self
.
_retry
):
for
i
in
range
(
self
.
_retry
):
try
:
try
:
midped_data
=
func_timeout
.
func_timeout
(
midped_data
=
func_timeout
.
func_timeout
(
self
.
_timeout
,
self
.
process
,
args
=
(
preped_data
,
))
self
.
_timeout
,
self
.
process
,
args
=
(
client_predict_handler
,
preped_data
,
))
except
func_timeout
.
FunctionTimedOut
as
e
:
except
func_timeout
.
FunctionTimedOut
as
e
:
if
i
+
1
>=
self
.
_retry
:
if
i
+
1
>=
self
.
_retry
:
ecode
=
ChannelDataEcode
.
TIMEOUT
.
value
ecode
=
ChannelDataEcode
.
TIMEOUT
.
value
...
@@ -267,10 +277,10 @@ class Op(object):
...
@@ -267,10 +277,10 @@ class Op(object):
midped_data
=
preped_data
midped_data
=
preped_data
return
midped_data
,
error_channeldata
return
midped_data
,
error_channeldata
def
_run_postprocess
(
self
,
midped_data
,
data_id
,
log_func
):
def
_run_postprocess
(
self
,
input_dict
,
midped_data
,
data_id
,
log_func
):
output_data
,
error_channeldata
=
None
,
None
output_data
,
error_channeldata
=
None
,
None
try
:
try
:
postped_data
=
self
.
postprocess
(
midped_data
)
postped_data
=
self
.
postprocess
(
input_dict
,
midped_data
)
except
Exception
as
e
:
except
Exception
as
e
:
error_info
=
log_func
(
e
)
error_info
=
log_func
(
e
)
_LOGGER
.
error
(
error_info
)
_LOGGER
.
error
(
error_info
)
...
@@ -303,8 +313,8 @@ class Op(object):
...
@@ -303,8 +313,8 @@ class Op(object):
data_id
=
data_id
)
data_id
=
data_id
)
return
output_data
,
error_channeldata
return
output_data
,
error_channeldata
def
_run
(
self
,
concurrency_idx
,
input_channel
,
output_channels
,
def
_run
(
self
,
concurrency_idx
,
input_channel
,
output_channels
,
client_type
,
client_type
):
use_multithread
):
def
get_log_func
(
op_info_prefix
):
def
get_log_func
(
op_info_prefix
):
def
log_func
(
info_str
):
def
log_func
(
info_str
):
return
"{} {}"
.
format
(
op_info_prefix
,
info_str
)
return
"{} {}"
.
format
(
op_info_prefix
,
info_str
)
...
@@ -315,12 +325,30 @@ class Op(object):
...
@@ -315,12 +325,30 @@ class Op(object):
log
=
get_log_func
(
op_info_prefix
)
log
=
get_log_func
(
op_info_prefix
)
tid
=
threading
.
current_thread
().
ident
tid
=
threading
.
current_thread
().
ident
client
=
None
client_predict_handler
=
None
# create client based on client_type
# create client based on client_type
self
.
init_client
(
client_type
,
self
.
_client_config
,
try
:
client
=
self
.
init_client
(
client_type
,
self
.
_client_config
,
self
.
_server_endpoints
,
self
.
_fetch_names
)
self
.
_server_endpoints
,
self
.
_fetch_names
)
if
client
is
not
None
:
client_predict_handler
=
client
.
predict
except
Exception
as
e
:
_LOGGER
.
error
(
log
(
e
))
os
.
_exit
(
-
1
)
# load user resources
# load user resources
self
.
load_user_resources
()
try
:
if
use_multithread
:
with
self
.
_for_init_op_lock
:
if
not
self
.
_succ_init_op
:
self
.
init_op
()
self
.
_succ_init_op
=
True
else
:
self
.
init_op
()
except
Exception
as
e
:
_LOGGER
.
error
(
log
(
e
))
os
.
_exit
(
-
1
)
self
.
_is_run
=
True
self
.
_is_run
=
True
while
self
.
_is_run
:
while
self
.
_is_run
:
...
@@ -349,8 +377,8 @@ class Op(object):
...
@@ -349,8 +377,8 @@ class Op(object):
# process
# process
self
.
_profiler_record
(
"{}-midp#{}_0"
.
format
(
op_info_prefix
,
tid
))
self
.
_profiler_record
(
"{}-midp#{}_0"
.
format
(
op_info_prefix
,
tid
))
midped_data
,
error_channeldata
=
self
.
_run_process
(
preped_data
,
midped_data
,
error_channeldata
=
self
.
_run_process
(
data_id
,
log
)
client_predict_handler
,
preped_data
,
data_id
,
log
)
self
.
_profiler_record
(
"{}-midp#{}_1"
.
format
(
op_info_prefix
,
tid
))
self
.
_profiler_record
(
"{}-midp#{}_1"
.
format
(
op_info_prefix
,
tid
))
if
error_channeldata
is
not
None
:
if
error_channeldata
is
not
None
:
self
.
_push_to_output_channels
(
error_channeldata
,
self
.
_push_to_output_channels
(
error_channeldata
,
...
@@ -359,8 +387,8 @@ class Op(object):
...
@@ -359,8 +387,8 @@ class Op(object):
# postprocess
# postprocess
self
.
_profiler_record
(
"{}-postp#{}_0"
.
format
(
op_info_prefix
,
tid
))
self
.
_profiler_record
(
"{}-postp#{}_0"
.
format
(
op_info_prefix
,
tid
))
output_data
,
error_channeldata
=
self
.
_run_postprocess
(
midped_data
,
output_data
,
error_channeldata
=
self
.
_run_postprocess
(
data_id
,
log
)
parsed_data
,
midped_data
,
data_id
,
log
)
self
.
_profiler_record
(
"{}-postp#{}_1"
.
format
(
op_info_prefix
,
tid
))
self
.
_profiler_record
(
"{}-postp#{}_1"
.
format
(
op_info_prefix
,
tid
))
if
error_channeldata
is
not
None
:
if
error_channeldata
is
not
None
:
self
.
_push_to_output_channels
(
error_channeldata
,
self
.
_push_to_output_channels
(
error_channeldata
,
...
@@ -384,7 +412,11 @@ class RequestOp(Op):
...
@@ -384,7 +412,11 @@ class RequestOp(Op):
super
(
RequestOp
,
self
).
__init__
(
super
(
RequestOp
,
self
).
__init__
(
name
=
"#G"
,
input_ops
=
[],
concurrency
=
concurrency
)
name
=
"#G"
,
input_ops
=
[],
concurrency
=
concurrency
)
# load user resources
# load user resources
self
.
load_user_resources
()
try
:
self
.
init_op
()
except
Exception
as
e
:
_LOGGER
.
error
(
e
)
os
.
_exit
(
-
1
)
def
unpack_request_package
(
self
,
request
):
def
unpack_request_package
(
self
,
request
):
dictdata
=
{}
dictdata
=
{}
...
@@ -405,7 +437,11 @@ class ResponseOp(Op):
...
@@ -405,7 +437,11 @@ class ResponseOp(Op):
super
(
ResponseOp
,
self
).
__init__
(
super
(
ResponseOp
,
self
).
__init__
(
name
=
"#R"
,
input_ops
=
input_ops
,
concurrency
=
concurrency
)
name
=
"#R"
,
input_ops
=
input_ops
,
concurrency
=
concurrency
)
# load user resources
# load user resources
self
.
load_user_resources
()
try
:
self
.
init_op
()
except
Exception
as
e
:
_LOGGER
.
error
(
e
)
os
.
_exit
(
-
1
)
def
pack_response_package
(
self
,
channeldata
):
def
pack_response_package
(
self
,
channeldata
):
resp
=
pipeline_service_pb2
.
Response
()
resp
=
pipeline_service_pb2
.
Response
()
...
@@ -450,17 +486,26 @@ class VirtualOp(Op):
...
@@ -450,17 +486,26 @@ class VirtualOp(Op):
def
add_virtual_pred_op
(
self
,
op
):
def
add_virtual_pred_op
(
self
,
op
):
self
.
_virtual_pred_ops
.
append
(
op
)
self
.
_virtual_pred_ops
.
append
(
op
)
def
_actual_pred_op_names
(
self
,
op
):
if
not
isinstance
(
op
,
VirtualOp
):
return
[
op
.
name
]
names
=
[]
for
x
in
op
.
_virtual_pred_ops
:
names
.
extend
(
self
.
_actual_pred_op_names
(
x
))
return
names
def
add_output_channel
(
self
,
channel
):
def
add_output_channel
(
self
,
channel
):
if
not
isinstance
(
channel
,
(
ThreadChannel
,
ProcessChannel
)):
if
not
isinstance
(
channel
,
(
ThreadChannel
,
ProcessChannel
)):
raise
TypeError
(
raise
TypeError
(
self
.
_log
(
'output channel must be Channel type, not {}'
.
format
(
self
.
_log
(
'output channel must be Channel type, not {}'
.
format
(
type
(
channel
))))
type
(
channel
))))
for
op
in
self
.
_virtual_pred_ops
:
for
op
in
self
.
_virtual_pred_ops
:
channel
.
add_producer
(
op
.
name
)
for
op_name
in
self
.
_actual_pred_op_names
(
op
):
channel
.
add_producer
(
op_name
)
self
.
_outputs
.
append
(
channel
)
self
.
_outputs
.
append
(
channel
)
def
_run
(
self
,
concurrency_idx
,
input_channel
,
output_channels
,
def
_run
(
self
,
concurrency_idx
,
input_channel
,
output_channels
,
client_type
,
client_type
):
use_multithread
):
def
get_log_func
(
op_info_prefix
):
def
get_log_func
(
op_info_prefix
):
def
log_func
(
info_str
):
def
log_func
(
info_str
):
return
"{} {}"
.
format
(
op_info_prefix
,
info_str
)
return
"{} {}"
.
format
(
op_info_prefix
,
info_str
)
...
...
python/pipeline/pipeline_client.py
浏览文件 @
f3eb9d4a
...
@@ -20,7 +20,7 @@ import functools
...
@@ -20,7 +20,7 @@ import functools
from
.proto
import
pipeline_service_pb2
from
.proto
import
pipeline_service_pb2
from
.proto
import
pipeline_service_pb2_grpc
from
.proto
import
pipeline_service_pb2_grpc
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
()
class
PipelineClient
(
object
):
class
PipelineClient
(
object
):
...
@@ -52,7 +52,7 @@ class PipelineClient(object):
...
@@ -52,7 +52,7 @@ class PipelineClient(object):
return
{
"ecode"
:
resp
.
ecode
,
"error_info"
:
resp
.
error_info
}
return
{
"ecode"
:
resp
.
ecode
,
"error_info"
:
resp
.
error_info
}
fetch_map
=
{
"ecode"
:
resp
.
ecode
}
fetch_map
=
{
"ecode"
:
resp
.
ecode
}
for
idx
,
key
in
enumerate
(
resp
.
key
):
for
idx
,
key
in
enumerate
(
resp
.
key
):
if
key
not
in
fetch
:
if
fetch
is
not
None
and
key
not
in
fetch
:
continue
continue
data
=
resp
.
value
[
idx
]
data
=
resp
.
value
[
idx
]
try
:
try
:
...
@@ -62,16 +62,16 @@ class PipelineClient(object):
...
@@ -62,16 +62,16 @@ class PipelineClient(object):
fetch_map
[
key
]
=
data
fetch_map
[
key
]
=
data
return
fetch_map
return
fetch_map
def
predict
(
self
,
feed_dict
,
fetch
,
asyn
=
False
):
def
predict
(
self
,
feed_dict
,
fetch
=
None
,
asyn
=
False
):
if
not
isinstance
(
feed_dict
,
dict
):
if
not
isinstance
(
feed_dict
,
dict
):
raise
TypeError
(
raise
TypeError
(
"feed must be dict type with format: {name: value}."
)
"feed must be dict type with format: {name: value}."
)
if
not
isinstance
(
fetch
,
list
):
if
fetch
is
not
None
and
not
isinstance
(
fetch
,
list
):
raise
TypeError
(
"fetch must be list type with format: [name]."
)
raise
TypeError
(
"fetch must be list type with format: [name]."
)
req
=
self
.
_pack_request_package
(
feed_dict
)
req
=
self
.
_pack_request_package
(
feed_dict
)
if
not
asyn
:
if
not
asyn
:
resp
=
self
.
_stub
.
inference
(
req
)
resp
=
self
.
_stub
.
inference
(
req
)
return
self
.
_unpack_response_package
(
resp
)
return
self
.
_unpack_response_package
(
resp
,
fetch
)
else
:
else
:
call_future
=
self
.
_stub
.
inference
.
future
(
req
)
call_future
=
self
.
_stub
.
inference
.
future
(
req
)
return
PipelinePredictFuture
(
return
PipelinePredictFuture
(
...
...
python/pipeline/pipeline_server.py
浏览文件 @
f3eb9d4a
...
@@ -45,7 +45,7 @@ from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcod
...
@@ -45,7 +45,7 @@ from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcod
from
.profiler
import
TimeProfiler
from
.profiler
import
TimeProfiler
from
.util
import
NameGenerator
from
.util
import
NameGenerator
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
()
_profiler
=
TimeProfiler
()
_profiler
=
TimeProfiler
()
...
@@ -235,6 +235,10 @@ class PipelineServer(object):
...
@@ -235,6 +235,10 @@ class PipelineServer(object):
return
use_ops
,
succ_ops_of_use_op
return
use_ops
,
succ_ops_of_use_op
use_ops
,
out_degree_ops
=
get_use_ops
(
response_op
)
use_ops
,
out_degree_ops
=
get_use_ops
(
response_op
)
_LOGGER
.
info
(
"================= use op =================="
)
for
op
in
use_ops
:
_LOGGER
.
info
(
op
.
name
)
_LOGGER
.
info
(
"==========================================="
)
if
len
(
use_ops
)
<=
1
:
if
len
(
use_ops
)
<=
1
:
raise
Exception
(
raise
Exception
(
"Besides RequestOp and ResponseOp, there should be at least one Op in DAG."
"Besides RequestOp and ResponseOp, there should be at least one Op in DAG."
...
...
python/pipeline/profiler.py
浏览文件 @
f3eb9d4a
...
@@ -24,7 +24,7 @@ else:
...
@@ -24,7 +24,7 @@ else:
raise
Exception
(
"Error Python version"
)
raise
Exception
(
"Error Python version"
)
import
time
import
time
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
()
class
TimeProfiler
(
object
):
class
TimeProfiler
(
object
):
...
@@ -58,7 +58,7 @@ class TimeProfiler(object):
...
@@ -58,7 +58,7 @@ class TimeProfiler(object):
print_str
+=
"{}_{}:{} "
.
format
(
name
,
tag
,
timestamp
)
print_str
+=
"{}_{}:{} "
.
format
(
name
,
tag
,
timestamp
)
else
:
else
:
tmp
[
name
]
=
(
tag
,
timestamp
)
tmp
[
name
]
=
(
tag
,
timestamp
)
print_str
+=
"
\n
"
print_str
=
"
\n
{}
\n
"
.
format
(
print_str
)
sys
.
stderr
.
write
(
print_str
)
sys
.
stderr
.
write
(
print_str
)
for
name
,
item
in
tmp
.
items
():
for
name
,
item
in
tmp
.
items
():
tag
,
timestamp
=
item
tag
,
timestamp
=
item
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录