Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Serving
提交
0f18e403
S
Serving
项目概览
PaddlePaddle
/
Serving
大约 1 年 前同步成功
通知
186
Star
833
Fork
253
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
105
列表
看板
标记
里程碑
合并请求
10
Wiki
2
Wiki
分析
仓库
DevOps
项目成员
Pages
S
Serving
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
105
Issue
105
列表
看板
标记
里程碑
合并请求
10
合并请求
10
Pages
分析
分析
仓库分析
DevOps
Wiki
2
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0f18e403
编写于
6月 11, 2020
作者:
B
barrierye
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[WIP] remove channel def in user side
上级
26eda7a0
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
203 addition
and
94 deletion
+203
-94
python/examples/fit_a_line/test_py_server.py
python/examples/fit_a_line/test_py_server.py
+5
-23
python/paddle_serving_server/pyserver.py
python/paddle_serving_server/pyserver.py
+198
-71
未找到文件。
python/examples/fit_a_line/test_py_server.py
浏览文件 @
0f18e403
...
@@ -25,8 +25,6 @@ logging.basicConfig(
...
@@ -25,8 +25,6 @@ logging.basicConfig(
#level=logging.DEBUG)
#level=logging.DEBUG)
level
=
logging
.
INFO
)
level
=
logging
.
INFO
)
# channel data: {name(str): data(narray)}
class
CombineOp
(
Op
):
class
CombineOp
(
Op
):
def
preprocess
(
self
,
input_data
):
def
preprocess
(
self
,
input_data
):
...
@@ -39,13 +37,9 @@ class CombineOp(Op):
...
@@ -39,13 +37,9 @@ class CombineOp(Op):
return
data
return
data
read_channel
=
Channel
(
name
=
"read_channel"
)
read_op
=
Op
(
name
=
"read"
,
input
=
None
)
combine_channel
=
Channel
(
name
=
"combine_channel"
)
out_channel
=
Channel
(
name
=
"out_channel"
)
uci1_op
=
Op
(
name
=
"uci1"
,
uci1_op
=
Op
(
name
=
"uci1"
,
input
=
read_channel
,
inputs
=
[
read_op
],
outputs
=
[
combine_channel
],
server_model
=
"./uci_housing_model"
,
server_model
=
"./uci_housing_model"
,
server_port
=
"9393"
,
server_port
=
"9393"
,
device
=
"cpu"
,
device
=
"cpu"
,
...
@@ -55,10 +49,8 @@ uci1_op = Op(name="uci1",
...
@@ -55,10 +49,8 @@ uci1_op = Op(name="uci1",
concurrency
=
1
,
concurrency
=
1
,
timeout
=
0.1
,
timeout
=
0.1
,
retry
=
2
)
retry
=
2
)
uci2_op
=
Op
(
name
=
"uci2"
,
uci2_op
=
Op
(
name
=
"uci2"
,
input
=
read_channel
,
inputs
=
[
read_op
],
outputs
=
[
combine_channel
],
server_model
=
"./uci_housing_model"
,
server_model
=
"./uci_housing_model"
,
server_port
=
"9292"
,
server_port
=
"9292"
,
device
=
"cpu"
,
device
=
"cpu"
,
...
@@ -68,24 +60,14 @@ uci2_op = Op(name="uci2",
...
@@ -68,24 +60,14 @@ uci2_op = Op(name="uci2",
concurrency
=
1
,
concurrency
=
1
,
timeout
=-
1
,
timeout
=-
1
,
retry
=
1
)
retry
=
1
)
combine_op
=
CombineOp
(
combine_op
=
CombineOp
(
name
=
"combine"
,
name
=
"combine"
,
input
=
combine_channel
,
inputs
=
[
uci1_op
,
uci2_op
],
outputs
=
[
out_channel
],
concurrency
=
1
,
concurrency
=
1
,
timeout
=-
1
,
timeout
=-
1
,
retry
=
1
)
retry
=
1
)
logging
.
info
(
read_channel
.
debug
())
logging
.
info
(
combine_channel
.
debug
())
logging
.
info
(
out_channel
.
debug
())
pyserver
=
PyServer
(
profile
=
False
,
retry
=
1
)
pyserver
=
PyServer
(
profile
=
False
,
retry
=
1
)
pyserver
.
add_channel
(
read_channel
)
pyserver
.
add_ops
([
read_op
,
uci1_op
,
uci2_op
,
combine_op
])
pyserver
.
add_channel
(
combine_channel
)
pyserver
.
add_channel
(
out_channel
)
pyserver
.
add_op
(
uci1_op
)
pyserver
.
add_op
(
uci2_op
)
pyserver
.
add_op
(
combine_op
)
pyserver
.
prepare_server
(
port
=
8080
,
worker_num
=
2
)
pyserver
.
prepare_server
(
port
=
8080
,
worker_num
=
2
)
pyserver
.
run_server
()
pyserver
.
run_server
()
python/paddle_serving_server/pyserver.py
浏览文件 @
0f18e403
...
@@ -31,6 +31,7 @@ import random
...
@@ -31,6 +31,7 @@ import random
import
time
import
time
import
func_timeout
import
func_timeout
import
enum
import
enum
import
collections
class
_TimeProfiler
(
object
):
class
_TimeProfiler
(
object
):
...
@@ -140,7 +141,7 @@ class Channel(Queue.Queue):
...
@@ -140,7 +141,7 @@ class Channel(Queue.Queue):
Queue
.
Queue
.
__init__
(
self
,
maxsize
=
maxsize
)
Queue
.
Queue
.
__init__
(
self
,
maxsize
=
maxsize
)
self
.
_maxsize
=
maxsize
self
.
_maxsize
=
maxsize
self
.
_timeout
=
timeout
self
.
_timeout
=
timeout
self
.
_
name
=
name
self
.
name
=
name
self
.
_stop
=
False
self
.
_stop
=
False
self
.
_cv
=
threading
.
Condition
()
self
.
_cv
=
threading
.
Condition
()
...
@@ -161,7 +162,7 @@ class Channel(Queue.Queue):
...
@@ -161,7 +162,7 @@ class Channel(Queue.Queue):
return
self
.
_consumers
.
keys
()
return
self
.
_consumers
.
keys
()
def
_log
(
self
,
info_str
):
def
_log
(
self
,
info_str
):
return
"[{}] {}"
.
format
(
self
.
_
name
,
info_str
)
return
"[{}] {}"
.
format
(
self
.
name
,
info_str
)
def
debug
(
self
):
def
debug
(
self
):
return
self
.
_log
(
"p: {}, c: {}"
.
format
(
self
.
get_producers
(),
return
self
.
_log
(
"p: {}, c: {}"
.
format
(
self
.
get_producers
(),
...
@@ -313,8 +314,7 @@ class Channel(Queue.Queue):
...
@@ -313,8 +314,7 @@ class Channel(Queue.Queue):
class
Op
(
object
):
class
Op
(
object
):
def
__init__
(
self
,
def
__init__
(
self
,
name
,
name
,
input
,
inputs
,
outputs
,
server_model
=
None
,
server_model
=
None
,
server_port
=
None
,
server_port
=
None
,
device
=
None
,
device
=
None
,
...
@@ -325,23 +325,24 @@ class Op(object):
...
@@ -325,23 +325,24 @@ class Op(object):
timeout
=-
1
,
timeout
=-
1
,
retry
=
2
):
retry
=
2
):
self
.
_run
=
False
self
.
_run
=
False
# TODO: globally unique check
self
.
name
=
name
# to identify the type of OP, it must be globally unique
self
.
_name
=
name
# to identify the type of OP, it must be globally unique
self
.
_concurrency
=
concurrency
# amount of concurrency
self
.
_concurrency
=
concurrency
# amount of concurrency
self
.
set_input
(
input
)
self
.
set_input_ops
(
inputs
)
self
.
set_outputs
(
outputs
)
self
.
set_client
(
client_config
,
server_name
,
fetch_names
)
self
.
_client
=
None
if
client_config
is
not
None
and
\
server_name
is
not
None
and
\
fetch_names
is
not
None
:
self
.
set_client
(
client_config
,
server_name
,
fetch_names
)
self
.
_server_model
=
server_model
self
.
_server_model
=
server_model
self
.
_server_port
=
server_port
self
.
_server_port
=
server_port
self
.
_device
=
device
self
.
_device
=
device
self
.
_timeout
=
timeout
self
.
_timeout
=
timeout
self
.
_retry
=
retry
self
.
_retry
=
retry
self
.
_input
=
None
self
.
_outputs
=
[]
def
set_client
(
self
,
client_config
,
server_name
,
fetch_names
):
def
set_client
(
self
,
client_config
,
server_name
,
fetch_names
):
self
.
_client
=
None
if
client_config
is
None
or
\
server_name
is
None
or
\
fetch_names
is
None
:
return
self
.
_client
=
Client
()
self
.
_client
=
Client
()
self
.
_client
.
load_client_config
(
client_config
)
self
.
_client
.
load_client_config
(
client_config
)
self
.
_client
.
connect
([
server_name
])
self
.
_client
.
connect
([
server_name
])
...
@@ -350,28 +351,41 @@ class Op(object):
...
@@ -350,28 +351,41 @@ class Op(object):
def
with_serving
(
self
):
def
with_serving
(
self
):
return
self
.
_client
is
not
None
return
self
.
_client
is
not
None
def
get_input
(
self
):
def
get_input
_channel
(
self
):
return
self
.
_input
return
self
.
_input
def
set_input
(
self
,
channel
):
def
get_input_ops
(
self
):
return
self
.
_input_ops
def
set_input_ops
(
self
,
ops
):
if
not
isinstance
(
ops
,
list
):
ops
=
[]
if
ops
is
None
else
[
ops
]
self
.
_input_ops
=
[]
for
op
in
ops
:
if
not
isinstance
(
op
,
Op
):
raise
TypeError
(
self
.
_log
(
'input op must be Op type, not {}'
.
format
(
type
(
op
))))
self
.
_input_ops
.
append
(
op
)
def
add_input_channel
(
self
,
channel
):
if
not
isinstance
(
channel
,
Channel
):
if
not
isinstance
(
channel
,
Channel
):
raise
TypeError
(
raise
TypeError
(
self
.
_log
(
'input channel must be Channel type, not {}'
.
format
(
self
.
_log
(
'input channel must be Channel type, not {}'
.
format
(
type
(
channel
))))
type
(
channel
))))
channel
.
add_consumer
(
self
.
_
name
)
channel
.
add_consumer
(
self
.
name
)
self
.
_input
=
channel
self
.
_input
=
channel
def
get_outputs
(
self
):
def
get_output
_channel
s
(
self
):
return
self
.
_outputs
return
self
.
_outputs
def
set_outputs
(
self
,
channels
):
def
add_output_channel
(
self
,
channel
):
if
not
isinstance
(
channel
s
,
list
):
if
not
isinstance
(
channel
,
Channel
):
raise
TypeError
(
raise
TypeError
(
self
.
_log
(
'output channels must be list type, not {}'
.
format
(
self
.
_log
(
'output channel must be Channel type, not {}'
.
format
(
type
(
channels
))))
type
(
channel
))))
for
channel
in
channels
:
channel
.
add_producer
(
self
.
name
)
channel
.
add_producer
(
self
.
_name
)
self
.
_outputs
.
append
(
channel
)
self
.
_outputs
=
channels
def
preprocess
(
self
,
channeldata
):
def
preprocess
(
self
,
channeldata
):
if
isinstance
(
channeldata
,
dict
):
if
isinstance
(
channeldata
,
dict
):
...
@@ -430,26 +444,26 @@ class Op(object):
...
@@ -430,26 +444,26 @@ class Op(object):
def
start
(
self
,
concurrency_idx
):
def
start
(
self
,
concurrency_idx
):
self
.
_run
=
True
self
.
_run
=
True
while
self
.
_run
:
while
self
.
_run
:
_profiler
.
record
(
"{}{}-get_0"
.
format
(
self
.
_
name
,
concurrency_idx
))
_profiler
.
record
(
"{}{}-get_0"
.
format
(
self
.
name
,
concurrency_idx
))
input_data
=
self
.
_input
.
front
(
self
.
_
name
)
input_data
=
self
.
_input
.
front
(
self
.
name
)
_profiler
.
record
(
"{}{}-get_1"
.
format
(
self
.
_
name
,
concurrency_idx
))
_profiler
.
record
(
"{}{}-get_1"
.
format
(
self
.
name
,
concurrency_idx
))
logging
.
debug
(
self
.
_log
(
"input_data: {}"
.
format
(
input_data
)))
logging
.
debug
(
self
.
_log
(
"input_data: {}"
.
format
(
input_data
)))
data_id
,
error_data
=
self
.
_parse_channeldata
(
input_data
)
data_id
,
error_data
=
self
.
_parse_channeldata
(
input_data
)
output_data
=
None
output_data
=
None
if
error_data
is
None
:
if
error_data
is
None
:
_profiler
.
record
(
"{}{}-prep_0"
.
format
(
self
.
_
name
,
_profiler
.
record
(
"{}{}-prep_0"
.
format
(
self
.
name
,
concurrency_idx
))
concurrency_idx
))
data
=
self
.
preprocess
(
input_data
)
data
=
self
.
preprocess
(
input_data
)
_profiler
.
record
(
"{}{}-prep_1"
.
format
(
self
.
_
name
,
_profiler
.
record
(
"{}{}-prep_1"
.
format
(
self
.
name
,
concurrency_idx
))
concurrency_idx
))
call_future
=
None
call_future
=
None
error_info
=
None
error_info
=
None
if
self
.
with_serving
():
if
self
.
with_serving
():
for
i
in
range
(
self
.
_retry
):
for
i
in
range
(
self
.
_retry
):
_profiler
.
record
(
"{}{}-midp_0"
.
format
(
self
.
_
name
,
_profiler
.
record
(
"{}{}-midp_0"
.
format
(
self
.
name
,
concurrency_idx
))
concurrency_idx
))
if
self
.
_timeout
>
0
:
if
self
.
_timeout
>
0
:
try
:
try
:
...
@@ -460,21 +474,21 @@ class Op(object):
...
@@ -460,21 +474,21 @@ class Op(object):
except
func_timeout
.
FunctionTimedOut
:
except
func_timeout
.
FunctionTimedOut
:
logging
.
error
(
"error: timeout"
)
logging
.
error
(
"error: timeout"
)
error_info
=
"{}({}): timeout"
.
format
(
error_info
=
"{}({}): timeout"
.
format
(
self
.
_
name
,
concurrency_idx
)
self
.
name
,
concurrency_idx
)
except
Exception
as
e
:
except
Exception
as
e
:
logging
.
error
(
"error: {}"
.
format
(
e
))
logging
.
error
(
"error: {}"
.
format
(
e
))
error_info
=
"{}({}): {}"
.
format
(
error_info
=
"{}({}): {}"
.
format
(
self
.
_
name
,
concurrency_idx
,
e
)
self
.
name
,
concurrency_idx
,
e
)
else
:
else
:
call_future
=
self
.
midprocess
(
data
)
call_future
=
self
.
midprocess
(
data
)
_profiler
.
record
(
"{}{}-midp_1"
.
format
(
self
.
_
name
,
_profiler
.
record
(
"{}{}-midp_1"
.
format
(
self
.
name
,
concurrency_idx
))
concurrency_idx
))
if
i
+
1
<
self
.
_retry
:
if
i
+
1
<
self
.
_retry
:
error_info
=
None
error_info
=
None
logging
.
warn
(
logging
.
warn
(
self
.
_log
(
"warn: timeout, retry({})"
.
format
(
i
+
self
.
_log
(
"warn: timeout, retry({})"
.
format
(
i
+
1
)))
1
)))
_profiler
.
record
(
"{}{}-postp_0"
.
format
(
self
.
_
name
,
_profiler
.
record
(
"{}{}-postp_0"
.
format
(
self
.
name
,
concurrency_idx
))
concurrency_idx
))
if
error_info
is
not
None
:
if
error_info
is
not
None
:
error_data
=
self
.
errorprocess
(
error_info
,
data_id
)
error_data
=
self
.
errorprocess
(
error_info
,
data_id
)
...
@@ -504,18 +518,18 @@ class Op(object):
...
@@ -504,18 +518,18 @@ class Op(object):
pbdata
.
ecode
=
0
pbdata
.
ecode
=
0
pbdata
.
id
=
data_id
pbdata
.
id
=
data_id
output_data
=
ChannelData
(
pbdata
=
pbdata
)
output_data
=
ChannelData
(
pbdata
=
pbdata
)
_profiler
.
record
(
"{}{}-postp_1"
.
format
(
self
.
_
name
,
_profiler
.
record
(
"{}{}-postp_1"
.
format
(
self
.
name
,
concurrency_idx
))
concurrency_idx
))
else
:
else
:
output_data
=
ChannelData
(
pbdata
=
error_data
)
output_data
=
ChannelData
(
pbdata
=
error_data
)
_profiler
.
record
(
"{}{}-push_0"
.
format
(
self
.
_
name
,
concurrency_idx
))
_profiler
.
record
(
"{}{}-push_0"
.
format
(
self
.
name
,
concurrency_idx
))
for
channel
in
self
.
_outputs
:
for
channel
in
self
.
_outputs
:
channel
.
push
(
output_data
,
self
.
_
name
)
channel
.
push
(
output_data
,
self
.
name
)
_profiler
.
record
(
"{}{}-push_1"
.
format
(
self
.
_
name
,
concurrency_idx
))
_profiler
.
record
(
"{}{}-push_1"
.
format
(
self
.
name
,
concurrency_idx
))
def
_log
(
self
,
info_str
):
def
_log
(
self
,
info_str
):
return
"[{}] {}"
.
format
(
self
.
_
name
,
info_str
)
return
"[{}] {}"
.
format
(
self
.
name
,
info_str
)
def
get_concurrency
(
self
):
def
get_concurrency
(
self
):
return
self
.
_concurrency
return
self
.
_concurrency
...
@@ -525,7 +539,7 @@ class GeneralPythonService(
...
@@ -525,7 +539,7 @@ class GeneralPythonService(
general_python_service_pb2_grpc
.
GeneralPythonService
):
general_python_service_pb2_grpc
.
GeneralPythonService
):
def
__init__
(
self
,
in_channel
,
out_channel
,
retry
=
2
):
def
__init__
(
self
,
in_channel
,
out_channel
,
retry
=
2
):
super
(
GeneralPythonService
,
self
).
__init__
()
super
(
GeneralPythonService
,
self
).
__init__
()
self
.
_
name
=
"#G"
self
.
name
=
"#G"
self
.
set_in_channel
(
in_channel
)
self
.
set_in_channel
(
in_channel
)
self
.
set_out_channel
(
out_channel
)
self
.
set_out_channel
(
out_channel
)
logging
.
debug
(
self
.
_log
(
in_channel
.
debug
()))
logging
.
debug
(
self
.
_log
(
in_channel
.
debug
()))
...
@@ -543,14 +557,14 @@ class GeneralPythonService(
...
@@ -543,14 +557,14 @@ class GeneralPythonService(
self
.
_recive_func
.
start
()
self
.
_recive_func
.
start
()
def
_log
(
self
,
info_str
):
def
_log
(
self
,
info_str
):
return
"[{}] {}"
.
format
(
self
.
_
name
,
info_str
)
return
"[{}] {}"
.
format
(
self
.
name
,
info_str
)
def
set_in_channel
(
self
,
in_channel
):
def
set_in_channel
(
self
,
in_channel
):
if
not
isinstance
(
in_channel
,
Channel
):
if
not
isinstance
(
in_channel
,
Channel
):
raise
TypeError
(
raise
TypeError
(
self
.
_log
(
'in_channel must be Channel type, but get {}'
.
format
(
self
.
_log
(
'in_channel must be Channel type, but get {}'
.
format
(
type
(
in_channel
))))
type
(
in_channel
))))
in_channel
.
add_producer
(
self
.
_
name
)
in_channel
.
add_producer
(
self
.
name
)
self
.
_in_channel
=
in_channel
self
.
_in_channel
=
in_channel
def
set_out_channel
(
self
,
out_channel
):
def
set_out_channel
(
self
,
out_channel
):
...
@@ -558,12 +572,12 @@ class GeneralPythonService(
...
@@ -558,12 +572,12 @@ class GeneralPythonService(
raise
TypeError
(
raise
TypeError
(
self
.
_log
(
'out_channel must be Channel type, but get {}'
.
format
(
self
.
_log
(
'out_channel must be Channel type, but get {}'
.
format
(
type
(
out_channel
))))
type
(
out_channel
))))
out_channel
.
add_consumer
(
self
.
_
name
)
out_channel
.
add_consumer
(
self
.
name
)
self
.
_out_channel
=
out_channel
self
.
_out_channel
=
out_channel
def
_recive_out_channel_func
(
self
):
def
_recive_out_channel_func
(
self
):
while
True
:
while
True
:
channeldata
=
self
.
_out_channel
.
front
(
self
.
_
name
)
channeldata
=
self
.
_out_channel
.
front
(
self
.
name
)
if
not
isinstance
(
channeldata
,
ChannelData
):
if
not
isinstance
(
channeldata
,
ChannelData
):
raise
TypeError
(
raise
TypeError
(
self
.
_log
(
'data must be ChannelData type, but get {}'
.
self
.
_log
(
'data must be ChannelData type, but get {}'
.
...
@@ -644,38 +658,43 @@ class GeneralPythonService(
...
@@ -644,38 +658,43 @@ class GeneralPythonService(
return
resp
return
resp
def
inference
(
self
,
request
,
context
):
def
inference
(
self
,
request
,
context
):
_profiler
.
record
(
"{}-prepack_0"
.
format
(
self
.
_
name
))
_profiler
.
record
(
"{}-prepack_0"
.
format
(
self
.
name
))
data
,
data_id
=
self
.
_pack_data_for_infer
(
request
)
data
,
data_id
=
self
.
_pack_data_for_infer
(
request
)
_profiler
.
record
(
"{}-prepack_1"
.
format
(
self
.
_
name
))
_profiler
.
record
(
"{}-prepack_1"
.
format
(
self
.
name
))
resp_channeldata
=
None
resp_channeldata
=
None
for
i
in
range
(
self
.
_retry
):
for
i
in
range
(
self
.
_retry
):
logging
.
debug
(
self
.
_log
(
'push data'
))
logging
.
debug
(
self
.
_log
(
'push data'
))
_profiler
.
record
(
"{}-push_0"
.
format
(
self
.
_
name
))
_profiler
.
record
(
"{}-push_0"
.
format
(
self
.
name
))
self
.
_in_channel
.
push
(
data
,
self
.
_
name
)
self
.
_in_channel
.
push
(
data
,
self
.
name
)
_profiler
.
record
(
"{}-push_1"
.
format
(
self
.
_
name
))
_profiler
.
record
(
"{}-push_1"
.
format
(
self
.
name
))
logging
.
debug
(
self
.
_log
(
'wait for infer'
))
logging
.
debug
(
self
.
_log
(
'wait for infer'
))
_profiler
.
record
(
"{}-fetch_0"
.
format
(
self
.
_
name
))
_profiler
.
record
(
"{}-fetch_0"
.
format
(
self
.
name
))
resp_channeldata
=
self
.
_get_data_in_globel_resp_dict
(
data_id
)
resp_channeldata
=
self
.
_get_data_in_globel_resp_dict
(
data_id
)
_profiler
.
record
(
"{}-fetch_1"
.
format
(
self
.
_
name
))
_profiler
.
record
(
"{}-fetch_1"
.
format
(
self
.
name
))
if
resp_channeldata
.
pbdata
.
ecode
==
0
:
if
resp_channeldata
.
pbdata
.
ecode
==
0
:
break
break
logging
.
warn
(
"retry({}): {}"
.
format
(
logging
.
warn
(
"retry({}): {}"
.
format
(
i
+
1
,
resp_channeldata
.
pbdata
.
error_info
))
i
+
1
,
resp_channeldata
.
pbdata
.
error_info
))
_profiler
.
record
(
"{}-postpack_0"
.
format
(
self
.
_
name
))
_profiler
.
record
(
"{}-postpack_0"
.
format
(
self
.
name
))
resp
=
self
.
_pack_data_for_resp
(
resp_channeldata
)
resp
=
self
.
_pack_data_for_resp
(
resp_channeldata
)
_profiler
.
record
(
"{}-postpack_1"
.
format
(
self
.
_
name
))
_profiler
.
record
(
"{}-postpack_1"
.
format
(
self
.
name
))
_profiler
.
print_profile
()
_profiler
.
print_profile
()
return
resp
return
resp
class
VirtualOp
(
Op
):
pass
class
PyServer
(
object
):
class
PyServer
(
object
):
def
__init__
(
self
,
retry
=
2
,
profile
=
False
):
def
__init__
(
self
,
retry
=
2
,
profile
=
False
):
self
.
_channels
=
[]
self
.
_channels
=
[]
self
.
_ops
=
[]
self
.
_user_ops
=
[]
self
.
_total_ops
=
[]
self
.
_op_threads
=
[]
self
.
_op_threads
=
[]
self
.
_port
=
None
self
.
_port
=
None
self
.
_worker_num
=
None
self
.
_worker_num
=
None
...
@@ -688,40 +707,147 @@ class PyServer(object):
...
@@ -688,40 +707,147 @@ class PyServer(object):
self
.
_channels
.
append
(
channel
)
self
.
_channels
.
append
(
channel
)
def
add_op
(
self
,
op
):
def
add_op
(
self
,
op
):
self
.
_ops
.
append
(
op
)
self
.
_user_ops
.
append
(
op
)
def
add_ops
(
self
,
ops
):
self
.
_user_ops
.
expand
(
ops
)
def
gen_desc
(
self
):
def
gen_desc
(
self
):
logging
.
info
(
'here will generate desc for PAAS'
)
logging
.
info
(
'here will generate desc for PAAS'
)
pass
pass
def
_topo_sort
(
self
):
indeg_num
=
{}
outdegs
=
{}
que_idx
=
0
# scroll queue
ques
=
[
Queue
.
SimpleQueue
()
for
_
in
range
(
2
)]
for
idx
,
op
in
enumerate
(
self
.
_user_ops
):
# check the name of op is globally unique
if
op
.
name
in
indeg_num
:
raise
Exception
(
"the name of Op must be unique"
)
indeg_num
[
op
.
name
]
=
len
(
op
.
get_input_ops
())
if
indeg_num
[
op
.
name
]
==
0
:
ques
[
que_idx
].
put
(
op
)
for
pred_op
in
op
.
get_input_ops
():
if
op
.
name
in
outdegs
:
outdegs
[
op
.
name
].
append
(
op
)
else
:
outdegs
[
op
.
name
]
=
[
op
]
# get dag_views
dag_views
=
[]
sorted_op_num
=
0
while
True
:
que
=
ques
[
que_idx
]
next_que
=
ques
[(
que_idx
+
1
)
%
2
]
dag_view
=
[]
while
que
.
qsize
()
!=
0
:
op
=
que
.
get
()
dag_view
.
append
(
op
)
op_name
=
op
.
name
sorted_op_num
+=
1
for
succ_op
in
outdegs
[
op_name
]:
indeg_num
[
op_name
]
-=
1
if
indeg_num
[
succ_op
.
name
]
==
0
:
next_que
.
put
(
succ_op
)
dag_views
.
append
(
dag_view
)
if
next_que
.
qsize
()
==
0
:
break
que_idx
=
(
que_idx
+
1
)
%
2
if
sorted_op_num
<
len
(
self
.
_user_ops
):
raise
Exception
(
"not legal DAG"
)
if
len
(
dag_views
[
0
])
!=
1
:
raise
Exception
(
"DAG contains multiple input Ops"
)
if
len
(
dag_views
[
-
1
])
!=
1
:
raise
Exception
(
"DAG contains multiple output Ops"
)
# create channels and virtual ops
virtual_op_idx
=
0
channel_idx
=
0
virtual_ops
=
[]
channels
=
[]
input_channel
=
None
for
v_idx
,
view
in
enumerate
(
dag_views
):
if
v_idx
+
1
>=
len
(
dag_views
):
break
next_view
=
dag_views
[
v_idx
+
1
]
actual_next_view
=
[]
pred_op_of_next_view_op
=
{}
for
op
in
view
:
# create virtual op
for
succ_op
in
outdegs
[
op
.
name
]:
if
succ_op
in
next_view
:
actual_next_view
.
append
(
succ_op
)
if
succ_op
.
name
not
in
pred_op_of_next_view_op
:
pred_op_of_next_view_op
[
succ_op
.
name
]
=
[]
pred_op_of_next_view_op
[
succ_op
.
name
].
append
(
op
)
else
:
vop
=
VirtualOp
(
name
=
"vir{}"
.
format
(
virtual_op_idx
))
virtual_op_idx
+=
1
virtual_ops
.
append
(
virtual_op
)
outdegs
[
vop
.
name
]
=
[
succ_op
]
actual_next_view
.
append
(
vop
)
# TODO: combine vop
pred_op_of_next_view_op
[
vop
.
name
]
=
[
op
]
# create channel
processed_op
=
set
()
for
o_idx
,
op
in
enumerate
(
actual_next_view
):
op_name
=
op
.
name
if
op_name
in
processed_op
:
continue
channel
=
Channel
(
name
=
"chl{}"
.
format
(
channel_idx
))
channel_idx
+=
1
channels
.
append
(
channel
)
op
.
add_input_channel
(
channel
)
pred_ops
=
pred_op_of_next_view_op
[
op_name
]
if
v_idx
==
0
:
input_channel
=
channel
else
:
for
pred_op
in
pred_ops
:
pred_op
.
add_output_channel
(
channel
)
processed_op
.
add
(
op_name
)
# combine channel
for
other_op
in
actual_next_view
[
o_idx
:]:
if
other_op
.
name
in
processed_op
:
continue
other_pred_ops
=
pred_op_of_next_view_op
[
other_op
.
name
]
if
len
(
other_pred_ops
)
!=
len
(
pred_ops
):
continue
same_flag
=
True
for
pred_op
in
pred_ops
:
if
pred_op
not
in
other_pred_ops
:
same_flag
=
False
break
if
same_flag
:
other_op
.
add_input_channel
(
channel
)
processed_op
.
add
(
other_op
.
name
)
output_channel
=
Channel
(
name
=
"Ochl"
)
channels
.
append
(
output_channel
)
last_op
=
dag_views
[
-
1
][
0
]
last_op
.
add_output_channel
(
output_channel
)
self
.
_ops
=
self
.
_user_ops
+
virtual_ops
self
.
_channels
=
channels
return
input_channel
,
output_channel
def
prepare_server
(
self
,
port
,
worker_num
):
def
prepare_server
(
self
,
port
,
worker_num
):
self
.
_port
=
port
self
.
_port
=
port
self
.
_worker_num
=
worker_num
self
.
_worker_num
=
worker_num
inputs
=
set
()
outputs
=
set
()
input_channel
,
output_channel
=
self
.
_topo_sort
()
for
op
in
self
.
_ops
:
self
.
_in_channel
=
input_channel
inputs
|=
set
([
op
.
get_input
()])
self
.
out_channel
=
output_channel
outputs
|=
set
(
op
.
get_outputs
())
if
op
.
with_serving
():
self
.
prepare_serving
(
op
)
in_channel
=
inputs
-
outputs
out_channel
=
outputs
-
inputs
if
len
(
in_channel
)
!=
1
or
len
(
out_channel
)
!=
1
:
raise
Exception
(
"in_channel(out_channel) more than 1 or no in_channel(out_channel)"
)
self
.
_in_channel
=
in_channel
.
pop
()
self
.
_out_channel
=
out_channel
.
pop
()
self
.
gen_desc
()
self
.
gen_desc
()
def
_op_start_wrapper
(
self
,
op
,
concurrency_idx
):
def
_op_start_wrapper
(
self
,
op
,
concurrency_idx
):
return
op
.
start
(
concurrency_idx
)
return
op
.
start
(
concurrency_idx
)
def
_run_ops
(
self
):
def
_run_ops
(
self
):
#TODO
for
op
in
self
.
_ops
:
for
op
in
self
.
_ops
:
op_concurrency
=
op
.
get_concurrency
()
op_concurrency
=
op
.
get_concurrency
()
logging
.
debug
(
"run op: {}, op_concurrency: {}"
.
format
(
logging
.
debug
(
"run op: {}, op_concurrency: {}"
.
format
(
op
.
_
name
,
op_concurrency
))
op
.
name
,
op_concurrency
))
for
c
in
range
(
op_concurrency
):
for
c
in
range
(
op_concurrency
):
# th = multiprocessing.Process(
# th = multiprocessing.Process(
th
=
threading
.
Thread
(
th
=
threading
.
Thread
(
...
@@ -730,6 +856,7 @@ class PyServer(object):
...
@@ -730,6 +856,7 @@ class PyServer(object):
self
.
_op_threads
.
append
(
th
)
self
.
_op_threads
.
append
(
th
)
def
_stop_ops
(
self
):
def
_stop_ops
(
self
):
# TODO
for
op
in
self
.
_ops
:
for
op
in
self
.
_ops
:
op
.
stop
()
op
.
stop
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录