Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Serving
提交
f3eb9d4a
S
Serving
项目概览
PaddlePaddle
/
Serving
大约 1 年 前同步成功
通知
186
Star
833
Fork
253
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
105
列表
看板
标记
里程碑
合并请求
10
Wiki
2
Wiki
分析
仓库
DevOps
项目成员
Pages
S
Serving
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
105
Issue
105
列表
看板
标记
里程碑
合并请求
10
合并请求
10
Pages
分析
分析
仓库分析
DevOps
Wiki
2
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f3eb9d4a
编写于
7月 02, 2020
作者:
B
barriery
提交者:
GitHub
7月 02, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #710 from barrierye/pipeline-update
update pipeline
上级
0dc8b905
22b168ca
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
127 addition
and
49 deletion
+127
-49
python/pipeline/channel.py
python/pipeline/channel.py
+37
-8
python/pipeline/operator.py
python/pipeline/operator.py
+78
-33
python/pipeline/pipeline_client.py
python/pipeline/pipeline_client.py
+5
-5
python/pipeline/pipeline_server.py
python/pipeline/pipeline_server.py
+5
-1
python/pipeline/profiler.py
python/pipeline/profiler.py
+2
-2
未找到文件。
python/pipeline/channel.py
浏览文件 @
f3eb9d4a
...
...
@@ -27,7 +27,7 @@ import logging
import
enum
import
copy
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
()
class
ChannelDataEcode
(
enum
.
Enum
):
...
...
@@ -92,7 +92,16 @@ class ChannelData(object):
def
check_dictdata
(
dictdata
):
ecode
=
ChannelDataEcode
.
OK
.
value
error_info
=
None
if
not
isinstance
(
dictdata
,
dict
):
if
isinstance
(
dictdata
,
list
):
# batch data
for
sample
in
dictdata
:
if
not
isinstance
(
sample
,
dict
):
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
"be dict, but get {}."
.
format
(
type
(
sample
))
break
elif
not
isinstance
(
dictdata
,
dict
):
# batch size = 1
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
"be dict, but get {}."
.
format
(
type
(
dictdata
))
...
...
@@ -102,12 +111,32 @@ class ChannelData(object):
def
check_npdata
(
npdata
):
ecode
=
ChannelDataEcode
.
OK
.
value
error_info
=
None
for
_
,
value
in
npdata
.
items
():
if
not
isinstance
(
value
,
np
.
ndarray
):
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
"be np.ndarray, but get {}."
.
format
(
type
(
value
))
break
if
isinstance
(
npdata
,
list
):
# batch data
for
sample
in
npdata
:
if
not
isinstance
(
sample
,
dict
):
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
"be dict, but get {}."
.
format
(
type
(
sample
))
break
for
_
,
value
in
sample
.
items
():
if
not
isinstance
(
value
,
np
.
ndarray
):
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
"be np.ndarray, but get {}."
.
format
(
type
(
value
))
return
ecode
,
error_info
elif
isinstance
(
npdata
,
dict
):
# batch_size = 1
for
_
,
value
in
npdata
.
items
():
if
not
isinstance
(
value
,
np
.
ndarray
):
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
"be np.ndarray, but get {}."
.
format
(
type
(
value
))
break
else
:
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
"the value of data must "
\
"be dict, but get {}."
.
format
(
type
(
npdata
))
return
ecode
,
error_info
def
parse
(
self
):
...
...
python/pipeline/operator.py
浏览文件 @
f3eb9d4a
...
...
@@ -19,13 +19,14 @@ from paddle_serving_client import MultiLangClient, Client
from
concurrent
import
futures
import
logging
import
func_timeout
import
os
from
numpy
import
*
from
.proto
import
pipeline_service_pb2
from
.channel
import
ThreadChannel
,
ProcessChannel
,
ChannelDataEcode
,
ChannelData
,
ChannelDataType
from
.util
import
NameGenerator
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
()
_op_name_gen
=
NameGenerator
(
"Op"
)
...
...
@@ -59,6 +60,10 @@ class Op(object):
self
.
_outputs
=
[]
self
.
_profiler
=
None
# only for multithread
self
.
_for_init_op_lock
=
threading
.
Lock
()
self
.
_succ_init_op
=
False
def
init_profiler
(
self
,
profiler
):
self
.
_profiler
=
profiler
...
...
@@ -71,18 +76,19 @@ class Op(object):
fetch_names
):
if
self
.
with_serving
==
False
:
_LOGGER
.
debug
(
"{} no client"
.
format
(
self
.
name
))
return
return
None
_LOGGER
.
debug
(
"{} client_config: {}"
.
format
(
self
.
name
,
client_config
))
_LOGGER
.
debug
(
"{} fetch_names: {}"
.
format
(
self
.
name
,
fetch_names
))
if
client_type
==
'brpc'
:
self
.
_
client
=
Client
()
self
.
_
client
.
load_client_config
(
client_config
)
client
=
Client
()
client
.
load_client_config
(
client_config
)
elif
client_type
==
'grpc'
:
self
.
_
client
=
MultiLangClient
()
client
=
MultiLangClient
()
else
:
raise
ValueError
(
"unknow client type: {}"
.
format
(
client_type
))
self
.
_
client
.
connect
(
server_endpoints
)
client
.
connect
(
server_endpoints
)
self
.
_fetch_names
=
fetch_names
return
client
def
_get_input_channel
(
self
):
return
self
.
_input
...
...
@@ -130,19 +136,17 @@ class Op(object):
(
_
,
input_dict
),
=
input_dicts
.
items
()
return
input_dict
def
process
(
self
,
feed_dict
):
def
process
(
self
,
client_predict_handler
,
feed_dict
):
err
,
err_info
=
ChannelData
.
check_npdata
(
feed_dict
)
if
err
!=
0
:
raise
NotImplementedError
(
"{} Please override preprocess func."
.
format
(
err_info
))
_LOGGER
.
debug
(
self
.
_log
(
'feed_dict: {}'
.
format
(
feed_dict
)))
_LOGGER
.
debug
(
self
.
_log
(
'fetch: {}'
.
format
(
self
.
_fetch_names
)))
call_result
=
self
.
_client
.
predict
(
call_result
=
client_predict_handler
(
feed
=
feed_dict
,
fetch
=
self
.
_fetch_names
)
_LOGGER
.
debug
(
self
.
_log
(
"get call_result"
))
return
call_result
def
postprocess
(
self
,
fetch_dict
):
def
postprocess
(
self
,
input_dict
,
fetch_dict
):
return
fetch_dict
def
stop
(
self
):
...
...
@@ -174,7 +178,7 @@ class Op(object):
p
=
multiprocessing
.
Process
(
target
=
self
.
_run
,
args
=
(
concurrency_idx
,
self
.
_get_input_channel
(),
self
.
_get_output_channels
(),
client_type
))
self
.
_get_output_channels
(),
client_type
,
False
))
p
.
start
()
proces
.
append
(
p
)
return
proces
...
...
@@ -185,12 +189,12 @@ class Op(object):
t
=
threading
.
Thread
(
target
=
self
.
_run
,
args
=
(
concurrency_idx
,
self
.
_get_input_channel
(),
self
.
_get_output_channels
(),
client_type
))
self
.
_get_output_channels
(),
client_type
,
True
))
t
.
start
()
threads
.
append
(
t
)
return
threads
def
load_user_resources
(
self
):
def
init_op
(
self
):
pass
def
_run_preprocess
(
self
,
parsed_data
,
data_id
,
log_func
):
...
...
@@ -222,13 +226,15 @@ class Op(object):
data_id
=
data_id
)
return
preped_data
,
error_channeldata
def
_run_process
(
self
,
preped_data
,
data_id
,
log_func
):
def
_run_process
(
self
,
client_predict_handler
,
preped_data
,
data_id
,
log_func
):
midped_data
,
error_channeldata
=
None
,
None
if
self
.
with_serving
:
ecode
=
ChannelDataEcode
.
OK
.
value
if
self
.
_timeout
<=
0
:
try
:
midped_data
=
self
.
process
(
preped_data
)
midped_data
=
self
.
process
(
client_predict_handler
,
preped_data
)
except
Exception
as
e
:
ecode
=
ChannelDataEcode
.
UNKNOW
.
value
error_info
=
log_func
(
e
)
...
...
@@ -237,7 +243,11 @@ class Op(object):
for
i
in
range
(
self
.
_retry
):
try
:
midped_data
=
func_timeout
.
func_timeout
(
self
.
_timeout
,
self
.
process
,
args
=
(
preped_data
,
))
self
.
_timeout
,
self
.
process
,
args
=
(
client_predict_handler
,
preped_data
,
))
except
func_timeout
.
FunctionTimedOut
as
e
:
if
i
+
1
>=
self
.
_retry
:
ecode
=
ChannelDataEcode
.
TIMEOUT
.
value
...
...
@@ -267,10 +277,10 @@ class Op(object):
midped_data
=
preped_data
return
midped_data
,
error_channeldata
def
_run_postprocess
(
self
,
midped_data
,
data_id
,
log_func
):
def
_run_postprocess
(
self
,
input_dict
,
midped_data
,
data_id
,
log_func
):
output_data
,
error_channeldata
=
None
,
None
try
:
postped_data
=
self
.
postprocess
(
midped_data
)
postped_data
=
self
.
postprocess
(
input_dict
,
midped_data
)
except
Exception
as
e
:
error_info
=
log_func
(
e
)
_LOGGER
.
error
(
error_info
)
...
...
@@ -303,8 +313,8 @@ class Op(object):
data_id
=
data_id
)
return
output_data
,
error_channeldata
def
_run
(
self
,
concurrency_idx
,
input_channel
,
output_channels
,
client_type
):
def
_run
(
self
,
concurrency_idx
,
input_channel
,
output_channels
,
client_type
,
use_multithread
):
def
get_log_func
(
op_info_prefix
):
def
log_func
(
info_str
):
return
"{} {}"
.
format
(
op_info_prefix
,
info_str
)
...
...
@@ -315,12 +325,30 @@ class Op(object):
log
=
get_log_func
(
op_info_prefix
)
tid
=
threading
.
current_thread
().
ident
client
=
None
client_predict_handler
=
None
# create client based on client_type
self
.
init_client
(
client_type
,
self
.
_client_config
,
self
.
_server_endpoints
,
self
.
_fetch_names
)
try
:
client
=
self
.
init_client
(
client_type
,
self
.
_client_config
,
self
.
_server_endpoints
,
self
.
_fetch_names
)
if
client
is
not
None
:
client_predict_handler
=
client
.
predict
except
Exception
as
e
:
_LOGGER
.
error
(
log
(
e
))
os
.
_exit
(
-
1
)
# load user resources
self
.
load_user_resources
()
try
:
if
use_multithread
:
with
self
.
_for_init_op_lock
:
if
not
self
.
_succ_init_op
:
self
.
init_op
()
self
.
_succ_init_op
=
True
else
:
self
.
init_op
()
except
Exception
as
e
:
_LOGGER
.
error
(
log
(
e
))
os
.
_exit
(
-
1
)
self
.
_is_run
=
True
while
self
.
_is_run
:
...
...
@@ -349,8 +377,8 @@ class Op(object):
# process
self
.
_profiler_record
(
"{}-midp#{}_0"
.
format
(
op_info_prefix
,
tid
))
midped_data
,
error_channeldata
=
self
.
_run_process
(
preped_data
,
data_id
,
log
)
midped_data
,
error_channeldata
=
self
.
_run_process
(
client_predict_handler
,
preped_data
,
data_id
,
log
)
self
.
_profiler_record
(
"{}-midp#{}_1"
.
format
(
op_info_prefix
,
tid
))
if
error_channeldata
is
not
None
:
self
.
_push_to_output_channels
(
error_channeldata
,
...
...
@@ -359,8 +387,8 @@ class Op(object):
# postprocess
self
.
_profiler_record
(
"{}-postp#{}_0"
.
format
(
op_info_prefix
,
tid
))
output_data
,
error_channeldata
=
self
.
_run_postprocess
(
midped_data
,
data_id
,
log
)
output_data
,
error_channeldata
=
self
.
_run_postprocess
(
parsed_data
,
midped_data
,
data_id
,
log
)
self
.
_profiler_record
(
"{}-postp#{}_1"
.
format
(
op_info_prefix
,
tid
))
if
error_channeldata
is
not
None
:
self
.
_push_to_output_channels
(
error_channeldata
,
...
...
@@ -384,7 +412,11 @@ class RequestOp(Op):
super
(
RequestOp
,
self
).
__init__
(
name
=
"#G"
,
input_ops
=
[],
concurrency
=
concurrency
)
# load user resources
self
.
load_user_resources
()
try
:
self
.
init_op
()
except
Exception
as
e
:
_LOGGER
.
error
(
e
)
os
.
_exit
(
-
1
)
def
unpack_request_package
(
self
,
request
):
dictdata
=
{}
...
...
@@ -405,7 +437,11 @@ class ResponseOp(Op):
super
(
ResponseOp
,
self
).
__init__
(
name
=
"#R"
,
input_ops
=
input_ops
,
concurrency
=
concurrency
)
# load user resources
self
.
load_user_resources
()
try
:
self
.
init_op
()
except
Exception
as
e
:
_LOGGER
.
error
(
e
)
os
.
_exit
(
-
1
)
def
pack_response_package
(
self
,
channeldata
):
resp
=
pipeline_service_pb2
.
Response
()
...
...
@@ -450,17 +486,26 @@ class VirtualOp(Op):
def
add_virtual_pred_op
(
self
,
op
):
self
.
_virtual_pred_ops
.
append
(
op
)
def
_actual_pred_op_names
(
self
,
op
):
if
not
isinstance
(
op
,
VirtualOp
):
return
[
op
.
name
]
names
=
[]
for
x
in
op
.
_virtual_pred_ops
:
names
.
extend
(
self
.
_actual_pred_op_names
(
x
))
return
names
def
add_output_channel
(
self
,
channel
):
if
not
isinstance
(
channel
,
(
ThreadChannel
,
ProcessChannel
)):
raise
TypeError
(
self
.
_log
(
'output channel must be Channel type, not {}'
.
format
(
type
(
channel
))))
for
op
in
self
.
_virtual_pred_ops
:
channel
.
add_producer
(
op
.
name
)
for
op_name
in
self
.
_actual_pred_op_names
(
op
):
channel
.
add_producer
(
op_name
)
self
.
_outputs
.
append
(
channel
)
def
_run
(
self
,
concurrency_idx
,
input_channel
,
output_channels
,
client_type
):
def
_run
(
self
,
concurrency_idx
,
input_channel
,
output_channels
,
client_type
,
use_multithread
):
def
get_log_func
(
op_info_prefix
):
def
log_func
(
info_str
):
return
"{} {}"
.
format
(
op_info_prefix
,
info_str
)
...
...
python/pipeline/pipeline_client.py
浏览文件 @
f3eb9d4a
...
...
@@ -20,7 +20,7 @@ import functools
from
.proto
import
pipeline_service_pb2
from
.proto
import
pipeline_service_pb2_grpc
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
()
class
PipelineClient
(
object
):
...
...
@@ -52,7 +52,7 @@ class PipelineClient(object):
return
{
"ecode"
:
resp
.
ecode
,
"error_info"
:
resp
.
error_info
}
fetch_map
=
{
"ecode"
:
resp
.
ecode
}
for
idx
,
key
in
enumerate
(
resp
.
key
):
if
key
not
in
fetch
:
if
fetch
is
not
None
and
key
not
in
fetch
:
continue
data
=
resp
.
value
[
idx
]
try
:
...
...
@@ -62,16 +62,16 @@ class PipelineClient(object):
fetch_map
[
key
]
=
data
return
fetch_map
def
predict
(
self
,
feed_dict
,
fetch
,
asyn
=
False
):
def
predict
(
self
,
feed_dict
,
fetch
=
None
,
asyn
=
False
):
if
not
isinstance
(
feed_dict
,
dict
):
raise
TypeError
(
"feed must be dict type with format: {name: value}."
)
if
not
isinstance
(
fetch
,
list
):
if
fetch
is
not
None
and
not
isinstance
(
fetch
,
list
):
raise
TypeError
(
"fetch must be list type with format: [name]."
)
req
=
self
.
_pack_request_package
(
feed_dict
)
if
not
asyn
:
resp
=
self
.
_stub
.
inference
(
req
)
return
self
.
_unpack_response_package
(
resp
)
return
self
.
_unpack_response_package
(
resp
,
fetch
)
else
:
call_future
=
self
.
_stub
.
inference
.
future
(
req
)
return
PipelinePredictFuture
(
...
...
python/pipeline/pipeline_server.py
浏览文件 @
f3eb9d4a
...
...
@@ -45,7 +45,7 @@ from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcod
from
.profiler
import
TimeProfiler
from
.util
import
NameGenerator
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
()
_profiler
=
TimeProfiler
()
...
...
@@ -235,6 +235,10 @@ class PipelineServer(object):
return
use_ops
,
succ_ops_of_use_op
use_ops
,
out_degree_ops
=
get_use_ops
(
response_op
)
_LOGGER
.
info
(
"================= use op =================="
)
for
op
in
use_ops
:
_LOGGER
.
info
(
op
.
name
)
_LOGGER
.
info
(
"==========================================="
)
if
len
(
use_ops
)
<=
1
:
raise
Exception
(
"Besides RequestOp and ResponseOp, there should be at least one Op in DAG."
...
...
python/pipeline/profiler.py
浏览文件 @
f3eb9d4a
...
...
@@ -24,7 +24,7 @@ else:
raise
Exception
(
"Error Python version"
)
import
time
_LOGGER
=
logging
.
getLogger
(
__name__
)
_LOGGER
=
logging
.
getLogger
()
class
TimeProfiler
(
object
):
...
...
@@ -58,7 +58,7 @@ class TimeProfiler(object):
print_str
+=
"{}_{}:{} "
.
format
(
name
,
tag
,
timestamp
)
else
:
tmp
[
name
]
=
(
tag
,
timestamp
)
print_str
+=
"
\n
"
print_str
=
"
\n
{}
\n
"
.
format
(
print_str
)
sys
.
stderr
.
write
(
print_str
)
for
name
,
item
in
tmp
.
items
():
tag
,
timestamp
=
item
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录