Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Serving
提交
0712a90a
S
Serving
项目概览
PaddlePaddle
/
Serving
大约 2 年 前同步成功
通知
187
Star
833
Fork
253
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
105
列表
看板
标记
里程碑
合并请求
10
Wiki
2
Wiki
分析
仓库
DevOps
项目成员
Pages
S
Serving
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
105
Issue
105
列表
看板
标记
里程碑
合并请求
10
合并请求
10
Pages
分析
分析
仓库分析
DevOps
Wiki
2
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0712a90a
编写于
6月 12, 2020
作者:
B
barrierye
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
bug fix
上级
b4bcf477
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
193 addition
and
87 deletion
+193
-87
python/examples/imdb/test_py_server.py
python/examples/imdb/test_py_server.py
+0
-3
python/paddle_serving_server/pyserver.py
python/paddle_serving_server/pyserver.py
+193
-84
未找到文件。
python/examples/imdb/test_py_server.py
浏览文件 @
0712a90a
...
@@ -27,8 +27,6 @@ logging.basicConfig(
...
@@ -27,8 +27,6 @@ logging.basicConfig(
class
CombineOp
(
Op
):
class
CombineOp
(
Op
):
pass
'''
def
preprocess
(
self
,
input_data
):
def
preprocess
(
self
,
input_data
):
combined_prediction
=
0
combined_prediction
=
0
for
op_name
,
channeldata
in
input_data
.
items
():
for
op_name
,
channeldata
in
input_data
.
items
():
...
@@ -37,7 +35,6 @@ class CombineOp(Op):
...
@@ -37,7 +35,6 @@ class CombineOp(Op):
combined_prediction
+=
data
[
"prediction"
]
combined_prediction
+=
data
[
"prediction"
]
data
=
{
"combined_prediction"
:
combined_prediction
/
2
}
data
=
{
"combined_prediction"
:
combined_prediction
/
2
}
return
data
return
data
'''
read_op
=
Op
(
name
=
"read"
,
inputs
=
None
)
read_op
=
Op
(
name
=
"read"
,
inputs
=
None
)
...
...
python/paddle_serving_server/pyserver.py
浏览文件 @
0712a90a
...
@@ -98,10 +98,10 @@ class ChannelData(object):
...
@@ -98,10 +98,10 @@ class ChannelData(object):
'''
'''
There are several ways to use it:
There are several ways to use it:
-
ChannelData(future, pbdata[, callback_func])
1.
ChannelData(future, pbdata[, callback_func])
-
ChannelData(future, data_id[, callback_func])
2.
ChannelData(future, data_id[, callback_func])
-
ChannelData(pbdata)
3.
ChannelData(pbdata)
-
ChannelData(ecode, error_info, data_id)
4.
ChannelData(ecode, error_info, data_id)
'''
'''
if
ecode
is
not
None
:
if
ecode
is
not
None
:
if
data_id
is
None
or
error_info
is
None
:
if
data_id
is
None
or
error_info
is
None
:
...
@@ -138,9 +138,8 @@ class ChannelData(object):
...
@@ -138,9 +138,8 @@ class ChannelData(object):
if
self
.
callback_func
is
not
None
:
if
self
.
callback_func
is
not
None
:
feed
=
self
.
callback_func
(
feed
)
feed
=
self
.
callback_func
(
feed
)
else
:
else
:
raise
TypeError
(
raise
TypeError
(
"Error type({}) in pbdata.type."
.
format
(
self
.
_log
(
"Error type({}) in pbdata.type."
.
format
(
self
.
pbdata
.
type
))
self
.
pbdata
.
type
)))
return
feed
return
feed
...
@@ -334,6 +333,7 @@ class Channel(Queue.Queue):
...
@@ -334,6 +333,7 @@ class Channel(Queue.Queue):
#TODO
#TODO
self
.
close
()
self
.
close
()
self
.
_stop
=
True
self
.
_stop
=
True
self
.
_cv
.
notify_all
()
class
Op
(
object
):
class
Op
(
object
):
...
@@ -358,7 +358,7 @@ class Op(object):
...
@@ -358,7 +358,7 @@ class Op(object):
self
.
_server_port
=
server_port
self
.
_server_port
=
server_port
self
.
_device
=
device
self
.
_device
=
device
self
.
_timeout
=
timeout
self
.
_timeout
=
timeout
self
.
_retry
=
retry
self
.
_retry
=
max
(
1
,
retry
)
self
.
_input
=
None
self
.
_input
=
None
self
.
_outputs
=
[]
self
.
_outputs
=
[]
...
@@ -443,48 +443,51 @@ class Op(object):
...
@@ -443,48 +443,51 @@ class Op(object):
self
.
_run
=
False
self
.
_run
=
False
def
_parse_channeldata
(
self
,
channeldata
):
def
_parse_channeldata
(
self
,
channeldata
):
data_id
,
error_data
=
None
,
None
data_id
,
error_
pb
data
=
None
,
None
if
isinstance
(
channeldata
,
dict
):
if
isinstance
(
channeldata
,
dict
):
parsed_data
=
{}
parsed_data
=
{}
key
=
channeldata
.
keys
()[
0
]
key
=
channeldata
.
keys
()[
0
]
data_id
=
channeldata
[
key
].
pbdata
.
id
data_id
=
channeldata
[
key
].
pbdata
.
id
for
_
,
data
in
channeldata
.
items
():
for
_
,
data
in
channeldata
.
items
():
if
data
.
pbdata
.
ecode
!=
0
:
if
data
.
pbdata
.
ecode
!=
ChannelDataEcode
.
OK
.
value
:
error_data
=
data
.
pbdata
error_
pb
data
=
data
.
pbdata
break
break
else
:
else
:
data_id
=
channeldata
.
pbdata
.
id
data_id
=
channeldata
.
pbdata
.
id
if
channeldata
.
pbdata
.
ecode
!=
0
:
if
channeldata
.
pbdata
.
ecode
!=
ChannelDataEcode
.
OK
.
value
:
error_data
=
channeldata
.
pbdata
error_
pb
data
=
channeldata
.
pbdata
return
data_id
,
error_data
return
data_id
,
error_
pb
data
def
_push_to_output_channels
(
self
,
data
):
def
_push_to_output_channels
(
self
,
data
,
name
=
None
):
if
name
is
None
:
name
=
self
.
name
for
channel
in
self
.
_outputs
:
for
channel
in
self
.
_outputs
:
channel
.
push
(
data
,
self
.
name
)
channel
.
push
(
data
,
name
)
def
start
(
self
,
concurrency_idx
):
def
start
(
self
,
concurrency_idx
):
op_info_prefix
=
"[{}{}]"
.
format
(
self
.
name
,
concurrency_idx
)
op_info_prefix
=
"[{}
|
{}]"
.
format
(
self
.
name
,
concurrency_idx
)
log
=
self
.
_get_log_func
(
op_info_prefix
)
log
=
self
.
_get_log_func
(
op_info_prefix
)
self
.
_run
=
True
self
.
_run
=
True
while
self
.
_run
:
while
self
.
_run
:
_profiler
.
record
(
"{}-get_0"
.
format
(
op_info_prefix
))
_profiler
.
record
(
"{}-get_0"
.
format
(
op_info_prefix
))
input_
data
=
self
.
_input
.
front
(
self
.
name
)
channel
data
=
self
.
_input
.
front
(
self
.
name
)
_profiler
.
record
(
"{}-get_1"
.
format
(
op_info_prefix
))
_profiler
.
record
(
"{}-get_1"
.
format
(
op_info_prefix
))
logging
.
debug
(
log
(
"input_data: {}"
.
format
(
input_
data
)))
logging
.
debug
(
log
(
"input_data: {}"
.
format
(
channel
data
)))
data_id
,
error_
data
=
self
.
_parse_channeldata
(
input_
data
)
data_id
,
error_
pbdata
=
self
.
_parse_channeldata
(
channel
data
)
#
predecessor Op error
#
error data in predecessor Op
if
error_data
is
not
None
:
if
error_
pb
data
is
not
None
:
self
.
_push_to_output_channels
(
ChannelData
(
pbdata
=
error_data
))
self
.
_push_to_output_channels
(
ChannelData
(
pbdata
=
error_
pb
data
))
continue
continue
# prepr
ocess function not implemented
# prepr
ecess
try
:
try
:
_profiler
.
record
(
"{}-prep_0"
.
format
(
op_info_prefix
))
_profiler
.
record
(
"{}-prep_0"
.
format
(
op_info_prefix
))
data
=
self
.
preprocess
(
input_
data
)
preped_data
=
self
.
preprocess
(
channel
data
)
_profiler
.
record
(
"{}-prep_1"
.
format
(
op_info_prefix
))
_profiler
.
record
(
"{}-prep_1"
.
format
(
op_info_prefix
))
except
NotImplementedError
as
e
:
except
NotImplementedError
as
e
:
# preprocess function not implemented
error_info
=
log
(
e
)
error_info
=
log
(
e
)
logging
.
error
(
error_info
)
logging
.
error
(
error_info
)
self
.
_push_to_output_channels
(
self
.
_push_to_output_channels
(
...
@@ -493,18 +496,35 @@ class Op(object):
...
@@ -493,18 +496,35 @@ class Op(object):
error_info
=
error_info
,
error_info
=
error_info
,
data_id
=
data_id
))
data_id
=
data_id
))
continue
continue
except
TypeError
as
e
:
# Error type in channeldata.pbdata.type
error_info
=
log
(
e
)
logging
.
error
(
error_info
)
self
.
_push_to_output_channels
(
ChannelData
(
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
,
error_info
=
error_info
,
data_id
=
data_id
))
continue
except
Exception
as
e
:
error_info
=
log
(
e
)
logging
.
error
(
error_info
)
self
.
_push_to_output_channels
(
ChannelData
(
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
,
error_info
=
error_info
,
data_id
=
data_id
))
continue
# midprocess
# midprocess
call_future
=
None
call_future
=
None
ecode
=
0
error_info
=
None
if
self
.
with_serving
():
if
self
.
with_serving
():
ecode
=
ChannelDataEcode
.
OK
.
value
_profiler
.
record
(
"{}-midp_0"
.
format
(
op_info_prefix
))
_profiler
.
record
(
"{}-midp_0"
.
format
(
op_info_prefix
))
if
self
.
_timeout
<=
0
:
if
self
.
_timeout
<=
0
:
try
:
try
:
call_future
=
self
.
midprocess
(
data
)
call_future
=
self
.
midprocess
(
preped_
data
)
except
Exception
as
e
:
except
Exception
as
e
:
logging
.
error
(
self
.
_log
(
e
))
ecode
=
ChannelDataEcode
.
UNKNOW
.
value
ecode
=
ChannelDataEcode
.
UNKNOW
.
value
error_info
=
log
(
e
)
error_info
=
log
(
e
)
logging
.
error
(
error_info
)
logging
.
error
(
error_info
)
...
@@ -512,15 +532,17 @@ class Op(object):
...
@@ -512,15 +532,17 @@ class Op(object):
for
i
in
range
(
self
.
_retry
):
for
i
in
range
(
self
.
_retry
):
try
:
try
:
call_future
=
func_timeout
.
func_timeout
(
call_future
=
func_timeout
.
func_timeout
(
self
.
_timeout
,
self
.
midprocess
,
args
=
(
data
,
))
self
.
_timeout
,
except
func_timeout
.
FunctionTimedOut
:
self
.
midprocess
,
args
=
(
preped_data
,
))
except
func_timeout
.
FunctionTimedOut
as
e
:
if
i
+
1
>=
self
.
_retry
:
if
i
+
1
>=
self
.
_retry
:
ecode
=
ChannelDataEcode
.
TIMEOUT
.
value
ecode
=
ChannelDataEcode
.
TIMEOUT
.
value
error_info
=
"{} timeout"
.
format
(
op_info_prefix
)
error_info
=
log
(
e
)
logging
.
error
(
error_info
)
else
:
else
:
logging
.
warn
(
logging
.
warn
(
log
(
"warn: timeout, retry({})"
.
format
(
i
+
log
(
"timeout, retry({})"
.
format
(
i
+
1
)))
1
)))
except
Exception
as
e
:
except
Exception
as
e
:
ecode
=
ChannelDataEcode
.
UNKNOW
.
value
ecode
=
ChannelDataEcode
.
UNKNOW
.
value
error_info
=
log
(
e
)
error_info
=
log
(
e
)
...
@@ -528,7 +550,7 @@ class Op(object):
...
@@ -528,7 +550,7 @@ class Op(object):
break
break
else
:
else
:
break
break
if
ecode
!=
0
:
if
ecode
!=
ChannelDataEcode
.
OK
.
value
:
self
.
_push_to_output_channels
(
self
.
_push_to_output_channels
(
ChannelData
(
ChannelData
(
ecode
=
ecode
,
error_info
=
error_info
,
ecode
=
ecode
,
error_info
=
error_info
,
...
@@ -539,32 +561,63 @@ class Op(object):
...
@@ -539,32 +561,63 @@ class Op(object):
# postprocess
# postprocess
output_data
=
None
output_data
=
None
_profiler
.
record
(
"{}-postp_0"
.
format
(
op_info_prefix
))
_profiler
.
record
(
"{}-postp_0"
.
format
(
op_info_prefix
))
if
self
.
with_serving
():
# use call_future
if
self
.
with_serving
():
# use call_future
output_data
=
ChannelData
(
output_data
=
ChannelData
(
future
=
call_future
,
future
=
call_future
,
data_id
=
data_id
,
data_id
=
data_id
,
callback_func
=
self
.
postprocess
)
callback_func
=
self
.
postprocess
)
else
:
else
:
post_data
=
self
.
postprocess
(
data
)
try
:
if
not
isinstance
(
post_data
,
dict
):
postped_data
=
self
.
postprocess
(
preped_data
)
except
Exception
as
e
:
ecode
=
ChannelDataEcode
.
UNKNOW
.
value
error_info
=
log
(
e
)
logging
.
error
(
error_info
)
self
.
_push_to_output_channels
(
ChannelData
(
ecode
=
ecode
,
error_info
=
error_info
,
data_id
=
data_id
))
continue
if
not
isinstance
(
postped_data
,
dict
):
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
log
(
"output of postprocess funticon must be "
\
error_info
=
log
(
"output of postprocess funticon must be "
\
"dict type, but get {}"
.
format
(
type
(
post_data
)))
"dict type, but get {}"
.
format
(
type
(
post
ped
_data
)))
logging
.
error
(
error_info
)
logging
.
error
(
error_info
)
self
.
_push_to_output_channels
(
self
.
_push_to_output_channels
(
ChannelData
(
ChannelData
(
ecode
=
ecode
,
error_info
=
error_info
,
ecode
=
ecode
,
error_info
=
error_info
,
data_id
=
data_id
))
data_id
=
data_id
))
continue
continue
ecode
=
ChannelDataEcode
.
OK
.
value
error_info
=
None
pbdata
=
channel_pb2
.
ChannelData
()
pbdata
=
channel_pb2
.
ChannelData
()
for
name
,
value
in
post_data
.
items
():
for
name
,
value
in
postped_data
.
items
():
if
not
isinstance
(
name
,
(
str
,
unicode
)):
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
log
(
"the key of postped_data must "
\
"be str, but get {}"
.
format
(
type
(
name
)))
break
if
not
isinstance
(
value
,
np
.
ndarray
):
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
error_info
=
log
(
"the value of postped_data must "
\
"be np.ndarray, but get {}"
.
format
(
type
(
value
)))
break
inst
=
channel_pb2
.
Inst
()
inst
=
channel_pb2
.
Inst
()
inst
.
data
=
value
.
tobytes
()
inst
.
data
=
value
.
tobytes
()
inst
.
name
=
name
inst
.
name
=
name
inst
.
shape
=
np
.
array
(
value
.
shape
,
dtype
=
"int32"
).
tobytes
()
inst
.
shape
=
np
.
array
(
value
.
shape
,
dtype
=
"int32"
).
tobytes
()
inst
.
type
=
str
(
value
.
dtype
)
inst
.
type
=
str
(
value
.
dtype
)
pbdata
.
insts
.
append
(
inst
)
pbdata
.
insts
.
append
(
inst
)
pbdata
.
ecode
=
0
if
ecode
!=
ChannelDataEcode
.
OK
.
value
:
logging
.
error
(
error_info
)
self
.
_push_to_output_channels
(
ChannelData
(
ecode
=
ecode
,
error_info
=
error_info
,
data_id
=
data_id
))
continue
pbdata
.
ecode
=
ecode
pbdata
.
id
=
data_id
pbdata
.
id
=
data_id
output_data
=
ChannelData
(
pbdata
=
pbdata
)
output_data
=
ChannelData
(
pbdata
=
pbdata
)
_profiler
.
record
(
"{}-postp_1"
.
format
(
op_info_prefix
))
_profiler
.
record
(
"{}-postp_1"
.
format
(
op_info_prefix
))
...
@@ -587,6 +640,45 @@ class Op(object):
...
@@ -587,6 +640,45 @@ class Op(object):
return
self
.
_concurrency
return
self
.
_concurrency
class
VirtualOp
(
Op
):
''' For connecting two channels. '''
def
__init__
(
self
,
name
,
concurrency
=
1
):
super
(
VirtualOp
,
self
).
__init__
(
name
=
name
,
inputs
=
None
,
concurrency
=
concurrency
)
self
.
_virtual_pred_ops
=
[]
def
add_virtual_pred_op
(
self
,
op
):
self
.
_virtual_pred_ops
.
append
(
op
)
def
add_output_channel
(
self
,
channel
):
if
not
isinstance
(
channel
,
Channel
):
raise
TypeError
(
self
.
_log
(
'output channel must be Channel type, not {}'
.
format
(
type
(
channel
))))
for
op
in
self
.
_virtual_pred_ops
:
channel
.
add_producer
(
op
.
name
)
self
.
_outputs
.
append
(
channel
)
def
start
(
self
,
concurrency_idx
):
op_info_prefix
=
"[{}|{}]"
.
format
(
self
.
name
,
concurrency_idx
)
log
=
self
.
_get_log_func
(
op_info_prefix
)
self
.
_run
=
True
while
self
.
_run
:
_profiler
.
record
(
"{}-get_0"
.
format
(
op_info_prefix
))
channeldata
=
self
.
_input
.
front
(
self
.
name
)
_profiler
.
record
(
"{}-get_1"
.
format
(
op_info_prefix
))
_profiler
.
record
(
"{}-push_0"
.
format
(
op_info_prefix
))
if
isinstance
(
channeldata
,
dict
):
for
name
,
data
in
channeldata
.
items
():
self
.
_push_to_output_channels
(
data
,
name
=
name
)
else
:
self
.
_push_to_output_channels
(
channeldata
,
self
.
_virtual_pred_ops
[
0
].
name
)
_profiler
.
record
(
"{}-push_1"
.
format
(
op_info_prefix
))
class
GeneralPythonService
(
class
GeneralPythonService
(
general_python_service_pb2_grpc
.
GeneralPythonService
):
general_python_service_pb2_grpc
.
GeneralPythonService
):
def
__init__
(
self
,
in_channel
,
out_channel
,
retry
=
2
):
def
__init__
(
self
,
in_channel
,
out_channel
,
retry
=
2
):
...
@@ -668,35 +760,27 @@ class GeneralPythonService(
...
@@ -668,35 +760,27 @@ class GeneralPythonService(
inst
.
name
=
name
inst
.
name
=
name
inst
.
type
=
request
.
type
[
idx
]
inst
.
type
=
request
.
type
[
idx
]
pbdata
.
insts
.
append
(
inst
)
pbdata
.
insts
.
append
(
inst
)
pbdata
.
ecode
=
0
#TODO: parse request error
pbdata
.
ecode
=
ChannelDataEcode
.
OK
.
value
#TODO: parse request error
return
ChannelData
(
pbdata
=
pbdata
),
data_id
return
ChannelData
(
pbdata
=
pbdata
),
data_id
def
_pack_data_for_resp
(
self
,
channeldata
):
def
_pack_data_for_resp
(
self
,
channeldata
):
logging
.
debug
(
self
.
_log
(
'get channeldata'
))
logging
.
debug
(
self
.
_log
(
'get channeldata'
))
logging
.
debug
(
self
.
_log
(
'gen resp'
))
resp
=
pyservice_pb2
.
Response
()
resp
=
pyservice_pb2
.
Response
()
resp
.
ecode
=
channeldata
.
pbdata
.
ecode
resp
.
ecode
=
channeldata
.
pbdata
.
ecode
if
resp
.
ecode
==
0
:
if
resp
.
ecode
==
ChannelDataEcode
.
OK
.
value
:
if
channeldata
.
pbdata
.
type
==
ChannelDataType
.
CHANNEL_PBDATA
.
value
:
if
channeldata
.
pbdata
.
type
==
ChannelDataType
.
CHANNEL_PBDATA
.
value
:
for
inst
in
channeldata
.
pbdata
.
insts
:
for
inst
in
channeldata
.
pbdata
.
insts
:
logging
.
debug
(
self
.
_log
(
'append data'
))
resp
.
fetch_insts
.
append
(
inst
.
data
)
resp
.
fetch_insts
.
append
(
inst
.
data
)
logging
.
debug
(
self
.
_log
(
'append name'
))
resp
.
fetch_var_names
.
append
(
inst
.
name
)
resp
.
fetch_var_names
.
append
(
inst
.
name
)
logging
.
debug
(
self
.
_log
(
'append shape'
))
resp
.
shape
.
append
(
inst
.
shape
)
resp
.
shape
.
append
(
inst
.
shape
)
logging
.
debug
(
self
.
_log
(
'append type'
))
resp
.
type
.
append
(
inst
.
type
)
resp
.
type
.
append
(
inst
.
type
)
elif
channeldata
.
pbdata
.
type
==
ChannelDataType
.
CHANNEL_FUTURE
.
value
:
elif
channeldata
.
pbdata
.
type
==
ChannelDataType
.
CHANNEL_FUTURE
.
value
:
feed
=
channeldata
.
futures
.
result
()
feed
=
channeldata
.
futures
.
result
()
if
channeldata
.
callback_func
is
not
None
:
if
channeldata
.
callback_func
is
not
None
:
feed
=
channeldata
.
callback_func
(
feed
)
feed
=
channeldata
.
callback_func
(
feed
)
for
name
,
var
in
feed
:
for
name
,
var
in
feed
:
logging
.
debug
(
self
.
_log
(
'append data'
))
resp
.
fetch_insts
.
append
(
var
.
tobytes
())
resp
.
fetch_insts
.
append
(
var
.
tobytes
())
logging
.
debug
(
self
.
_log
(
'append name'
))
resp
.
fetch_var_names
.
append
(
name
)
resp
.
fetch_var_names
.
append
(
name
)
logging
.
debug
(
self
.
_log
(
'append shape'
))
resp
.
shape
.
append
(
resp
.
shape
.
append
(
np
.
array
(
np
.
array
(
var
.
shape
,
dtype
=
"int32"
).
tobytes
())
var
.
shape
,
dtype
=
"int32"
).
tobytes
())
...
@@ -726,7 +810,7 @@ class GeneralPythonService(
...
@@ -726,7 +810,7 @@ class GeneralPythonService(
resp_channeldata
=
self
.
_get_data_in_globel_resp_dict
(
data_id
)
resp_channeldata
=
self
.
_get_data_in_globel_resp_dict
(
data_id
)
_profiler
.
record
(
"{}-fetch_1"
.
format
(
self
.
name
))
_profiler
.
record
(
"{}-fetch_1"
.
format
(
self
.
name
))
if
resp_channeldata
.
pbdata
.
ecode
==
0
:
if
resp_channeldata
.
pbdata
.
ecode
==
ChannelDataEcode
.
OK
.
value
:
break
break
if
i
+
1
<
self
.
_retry
:
if
i
+
1
<
self
.
_retry
:
logging
.
warn
(
"retry({}): {}"
.
format
(
logging
.
warn
(
"retry({}): {}"
.
format
(
...
@@ -743,7 +827,7 @@ class PyServer(object):
...
@@ -743,7 +827,7 @@ class PyServer(object):
def
__init__
(
self
,
retry
=
2
,
profile
=
False
):
def
__init__
(
self
,
retry
=
2
,
profile
=
False
):
self
.
_channels
=
[]
self
.
_channels
=
[]
self
.
_user_ops
=
[]
self
.
_user_ops
=
[]
self
.
_
tot
al_ops
=
[]
self
.
_
actu
al_ops
=
[]
self
.
_op_threads
=
[]
self
.
_op_threads
=
[]
self
.
_port
=
None
self
.
_port
=
None
self
.
_worker_num
=
None
self
.
_worker_num
=
None
...
@@ -767,9 +851,13 @@ class PyServer(object):
...
@@ -767,9 +851,13 @@ class PyServer(object):
def
_topo_sort
(
self
):
def
_topo_sort
(
self
):
indeg_num
=
{}
indeg_num
=
{}
outdegs
=
{
op
.
name
:
[]
for
op
in
self
.
_user_ops
}
que_idx
=
0
# scroll queue
que_idx
=
0
# scroll queue
ques
=
[
Queue
.
Queue
()
for
_
in
range
(
2
)]
ques
=
[
Queue
.
Queue
()
for
_
in
range
(
2
)]
for
op
in
self
.
_user_ops
:
if
len
(
op
.
get_input_ops
())
==
0
:
op
.
name
=
"#G"
# update read_op.name
break
outdegs
=
{
op
.
name
:
[]
for
op
in
self
.
_user_ops
}
for
idx
,
op
in
enumerate
(
self
.
_user_ops
):
for
idx
,
op
in
enumerate
(
self
.
_user_ops
):
# check the name of op is globally unique
# check the name of op is globally unique
if
op
.
name
in
indeg_num
:
if
op
.
name
in
indeg_num
:
...
@@ -780,7 +868,7 @@ class PyServer(object):
...
@@ -780,7 +868,7 @@ class PyServer(object):
for
pred_op
in
op
.
get_input_ops
():
for
pred_op
in
op
.
get_input_ops
():
outdegs
[
pred_op
.
name
].
append
(
op
)
outdegs
[
pred_op
.
name
].
append
(
op
)
# get dag_views
#
topo sort to
get dag_views
dag_views
=
[]
dag_views
=
[]
sorted_op_num
=
0
sorted_op_num
=
0
while
True
:
while
True
:
...
@@ -790,9 +878,8 @@ class PyServer(object):
...
@@ -790,9 +878,8 @@ class PyServer(object):
while
que
.
qsize
()
!=
0
:
while
que
.
qsize
()
!=
0
:
op
=
que
.
get
()
op
=
que
.
get
()
dag_view
.
append
(
op
)
dag_view
.
append
(
op
)
op_name
=
op
.
name
sorted_op_num
+=
1
sorted_op_num
+=
1
for
succ_op
in
outdegs
[
op
_
name
]:
for
succ_op
in
outdegs
[
op
.
name
]:
indeg_num
[
succ_op
.
name
]
-=
1
indeg_num
[
succ_op
.
name
]
-=
1
if
indeg_num
[
succ_op
.
name
]
==
0
:
if
indeg_num
[
succ_op
.
name
]
==
0
:
next_que
.
put
(
succ_op
)
next_que
.
put
(
succ_op
)
...
@@ -808,52 +895,69 @@ class PyServer(object):
...
@@ -808,52 +895,69 @@ class PyServer(object):
raise
Exception
(
"DAG contains multiple output Ops"
)
raise
Exception
(
"DAG contains multiple output Ops"
)
# create channels and virtual ops
# create channels and virtual ops
virtual_op_idx
=
0
def
name_generator
(
prefix
):
channel_idx
=
0
def
number_generator
():
idx
=
0
while
True
:
yield
"{}{}"
.
format
(
prefix
,
idx
)
idx
+=
1
return
number_generator
()
virtual_op_name_gen
=
name_generator
(
"vir"
)
channel_name_gen
=
name_generator
(
"chl"
)
virtual_ops
=
[]
virtual_ops
=
[]
channels
=
[]
channels
=
[]
input_channel
=
None
input_channel
=
None
actual_view
=
None
for
v_idx
,
view
in
enumerate
(
dag_views
):
for
v_idx
,
view
in
enumerate
(
dag_views
):
if
v_idx
+
1
>=
len
(
dag_views
):
if
v_idx
+
1
>=
len
(
dag_views
):
break
break
next_view
=
dag_views
[
v_idx
+
1
]
next_view
=
dag_views
[
v_idx
+
1
]
if
actual_view
is
None
:
actual_view
=
view
actual_next_view
=
[]
actual_next_view
=
[]
pred_op_of_next_view_op
=
{}
pred_op_of_next_view_op
=
{}
for
op
in
view
:
for
op
in
actual_
view
:
# create virtual op
#
find actual succ op in next view and
create virtual op
for
succ_op
in
outdegs
[
op
.
name
]:
for
succ_op
in
outdegs
[
op
.
name
]:
if
succ_op
in
next_view
:
if
succ_op
in
next_view
:
actual_next_view
.
append
(
succ_op
)
if
succ_op
not
in
actual_next_view
:
actual_next_view
.
append
(
succ_op
)
if
succ_op
.
name
not
in
pred_op_of_next_view_op
:
if
succ_op
.
name
not
in
pred_op_of_next_view_op
:
pred_op_of_next_view_op
[
succ_op
.
name
]
=
[]
pred_op_of_next_view_op
[
succ_op
.
name
]
=
[]
pred_op_of_next_view_op
[
succ_op
.
name
].
append
(
op
)
pred_op_of_next_view_op
[
succ_op
.
name
].
append
(
op
)
else
:
else
:
vop
=
Op
(
name
=
"vir{}"
.
format
(
virtual_op_idx
),
inputs
=
[])
# create virtual op
virtual_op_idx
+=
1
virtual_op
=
None
virtual_op
=
VirtualOp
(
name
=
virtual_op_name_gen
.
next
())
virtual_ops
.
append
(
virtual_op
)
virtual_ops
.
append
(
virtual_op
)
outdegs
[
vop
.
name
]
=
[
succ_op
]
outdegs
[
virtual_op
.
name
]
=
[
succ_op
]
actual_next_view
.
append
(
vop
)
actual_next_view
.
append
(
virtual_op
)
# TODO: combine vop
pred_op_of_next_view_op
[
virtual_op
.
name
]
=
[
op
]
pred_op_of_next_view_op
[
vop
.
name
]
=
[
op
]
virtual_op
.
add_virtual_pred_op
(
op
)
actual_view
=
actual_next_view
# create channel
# create channel
processed_op
=
set
()
processed_op
=
set
()
for
o_idx
,
op
in
enumerate
(
actual_next_view
):
for
o_idx
,
op
in
enumerate
(
actual_next_view
):
op_name
=
op
.
name
if
op
.
name
in
processed_op
:
if
op_name
in
processed_op
:
continue
continue
channel
=
Channel
(
name
=
"chl{}"
.
format
(
channel_idx
))
channel
=
Channel
(
name
=
channel_name_gen
.
next
())
channel_idx
+=
1
channels
.
append
(
channel
)
channels
.
append
(
channel
)
logging
.
debug
(
"{} => {}"
.
format
(
channel
.
name
,
op
.
name
))
op
.
add_input_channel
(
channel
)
op
.
add_input_channel
(
channel
)
pred_ops
=
pred_op_of_next_view_op
[
op
_
name
]
pred_ops
=
pred_op_of_next_view_op
[
op
.
name
]
if
v_idx
==
0
:
if
v_idx
==
0
:
input_channel
=
channel
input_channel
=
channel
else
:
else
:
# if pred_op is virtual op, it will use ancestors as producers to channel
for
pred_op
in
pred_ops
:
for
pred_op
in
pred_ops
:
logging
.
debug
(
"{} => {}"
.
format
(
pred_op
.
name
,
channel
.
name
))
pred_op
.
add_output_channel
(
channel
)
pred_op
.
add_output_channel
(
channel
)
processed_op
.
add
(
op
_
name
)
processed_op
.
add
(
op
.
name
)
# combine channel
#
find same input op to
combine channel
for
other_op
in
actual_next_view
[
o_idx
:]:
for
other_op
in
actual_next_view
[
o_idx
+
1
:]:
if
other_op
.
name
in
processed_op
:
if
other_op
.
name
in
processed_op
:
continue
continue
other_pred_ops
=
pred_op_of_next_view_op
[
other_op
.
name
]
other_pred_ops
=
pred_op_of_next_view_op
[
other_op
.
name
]
...
@@ -865,19 +969,24 @@ class PyServer(object):
...
@@ -865,19 +969,24 @@ class PyServer(object):
same_flag
=
False
same_flag
=
False
break
break
if
same_flag
:
if
same_flag
:
logging
.
debug
(
"{} => {}"
.
format
(
channel
.
name
,
other_op
.
name
))
other_op
.
add_input_channel
(
channel
)
other_op
.
add_input_channel
(
channel
)
processed_op
.
add
(
other_op
.
name
)
processed_op
.
add
(
other_op
.
name
)
output_channel
=
Channel
(
name
=
"Ochl"
)
output_channel
=
Channel
(
name
=
channel_name_gen
.
next
()
)
channels
.
append
(
output_channel
)
channels
.
append
(
output_channel
)
last_op
=
dag_views
[
-
1
][
0
]
last_op
=
dag_views
[
-
1
][
0
]
last_op
.
add_output_channel
(
output_channel
)
last_op
.
add_output_channel
(
output_channel
)
self
.
_ops
=
virtual_ops
self
.
_
actual_
ops
=
virtual_ops
for
op
in
self
.
_user_ops
:
for
op
in
self
.
_user_ops
:
if
len
(
op
.
get_input_ops
())
==
0
:
if
len
(
op
.
get_input_ops
())
==
0
:
# pass read op
continue
continue
self
.
_ops
.
append
(
op
)
self
.
_
actual_
ops
.
append
(
op
)
self
.
_channels
=
channels
self
.
_channels
=
channels
for
c
in
channels
:
logging
.
debug
(
c
.
debug
())
return
input_channel
,
output_channel
return
input_channel
,
output_channel
def
prepare_server
(
self
,
port
,
worker_num
):
def
prepare_server
(
self
,
port
,
worker_num
):
...
@@ -887,7 +996,7 @@ class PyServer(object):
...
@@ -887,7 +996,7 @@ class PyServer(object):
input_channel
,
output_channel
=
self
.
_topo_sort
()
input_channel
,
output_channel
=
self
.
_topo_sort
()
self
.
_in_channel
=
input_channel
self
.
_in_channel
=
input_channel
self
.
_out_channel
=
output_channel
self
.
_out_channel
=
output_channel
for
op
in
self
.
_ops
:
for
op
in
self
.
_
actual_
ops
:
if
op
.
with_serving
():
if
op
.
with_serving
():
self
.
prepare_serving
(
op
)
self
.
prepare_serving
(
op
)
self
.
gen_desc
()
self
.
gen_desc
()
...
@@ -896,7 +1005,7 @@ class PyServer(object):
...
@@ -896,7 +1005,7 @@ class PyServer(object):
return
op
.
start
(
concurrency_idx
)
return
op
.
start
(
concurrency_idx
)
def
_run_ops
(
self
):
def
_run_ops
(
self
):
for
op
in
self
.
_ops
:
for
op
in
self
.
_
actual_
ops
:
op_concurrency
=
op
.
get_concurrency
()
op_concurrency
=
op
.
get_concurrency
()
logging
.
debug
(
"run op: {}, op_concurrency: {}"
.
format
(
logging
.
debug
(
"run op: {}, op_concurrency: {}"
.
format
(
op
.
name
,
op_concurrency
))
op
.
name
,
op_concurrency
))
...
@@ -907,7 +1016,7 @@ class PyServer(object):
...
@@ -907,7 +1016,7 @@ class PyServer(object):
self
.
_op_threads
.
append
(
th
)
self
.
_op_threads
.
append
(
th
)
def
_stop_ops
(
self
):
def
_stop_ops
(
self
):
for
op
in
self
.
_ops
:
for
op
in
self
.
_
actual_
ops
:
op
.
stop
()
op
.
stop
()
def
run_server
(
self
):
def
run_server
(
self
):
...
@@ -921,6 +1030,8 @@ class PyServer(object):
...
@@ -921,6 +1030,8 @@ class PyServer(object):
server
.
start
()
server
.
start
()
server
.
wait_for_termination
()
server
.
wait_for_termination
()
self
.
_stop_ops
()
# TODO
self
.
_stop_ops
()
# TODO
for
th
in
self
.
_op_threads
:
th
.
join
()
def
prepare_serving
(
self
,
op
):
def
prepare_serving
(
self
,
op
):
model_path
=
op
.
_server_model
model_path
=
op
.
_server_model
...
@@ -935,5 +1046,3 @@ class PyServer(object):
...
@@ -935,5 +1046,3 @@ class PyServer(object):
" --model {} --thread 4 --port {} --use_multilang &>/dev/null &"
.
format
(
model_path
,
port
)
" --model {} --thread 4 --port {} --use_multilang &>/dev/null &"
.
format
(
model_path
,
port
)
# run a server (not in PyServing)
# run a server (not in PyServing)
logging
.
info
(
"run a server (not in PyServing): {}"
.
format
(
cmd
))
logging
.
info
(
"run a server (not in PyServing): {}"
.
format
(
cmd
))
return
# os.system(cmd)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录