Commit 29caa6d2 - PaddlePaddle / Serving

Authored Jul 31, 2020 by barriery
Parent: 94ea6590

    fix bug in dag-executor and update log

Showing 5 changed files with 325 additions and 174 deletions (+325 / -174):
    python/pipeline/channel.py           +38   -21
    python/pipeline/dag.py               +70   -75
    python/pipeline/operator.py          +77   -49
    python/pipeline/pipeline_server.py   +138  -29
    python/pipeline/profiler.py          +2    -0
python/pipeline/channel.py

@@ -26,6 +26,7 @@ else:
 import numpy as np
 import logging
 import enum
+import os
 import copy

 _LOGGER = logging.getLogger()
@@ -69,7 +70,8 @@ class ChannelData(object):
         '''
         if ecode is not None:
             if data_id is None or error_info is None:
-                raise ValueError("data_id and error_info cannot be None")
+                _LOGGER.critical("data_id and error_info cannot be None")
+                os._exit(-1)
             datatype = ChannelDataType.ERROR.value
         else:
             if datatype == ChannelDataType.CHANNEL_NPDATA.value:
@@ -83,7 +85,8 @@ class ChannelData(object):
                     datatype = ChannelDataType.ERROR.value
                     _LOGGER.error(error_info)
             else:
-                raise ValueError("datatype not match")
+                _LOGGER.critical("datatype not match")
+                os._exit(-1)
         self.datatype = datatype
         self.npdata = npdata
         self.dictdata = dictdata
@@ -168,7 +171,9 @@ class ChannelData(object):
             # return dict
             feed = self.dictdata
         else:
-            raise TypeError("Error type({}) in datatype.".format(self.datatype))
+            _LOGGER.critical("Error type({}) in datatype.".format(
+                self.datatype))
+            os._exit(-1)
         return feed

     def __str__(self):
@@ -241,30 +246,35 @@ class ProcessChannel(object):
     def add_producer(self, op_name):
         """ not thread safe, and can only be called during initialization. """
         if op_name in self._producers:
-            raise ValueError(
+            _LOGGER.critical(
                 self._log("producer({}) is already in channel".format(op_name)))
+            os._exit(-1)
         self._producers.append(op_name)
+        _LOGGER.debug(self._log("add a producer: {}".format(op_name)))

     def add_consumer(self, op_name):
         """ not thread safe, and can only be called during initialization. """
         if op_name in self._consumer_cursors:
-            raise ValueError(
+            _LOGGER.critical(
                 self._log("consumer({}) is already in channel".format(op_name)))
+            os._exit(-1)
         self._consumer_cursors[op_name] = 0

         if self._cursor_count.get(0) is None:
             self._cursor_count[0] = 0
         self._cursor_count[0] += 1
+        _LOGGER.debug(self._log("add a consumer: {}".format(op_name)))

     def push(self, channeldata, op_name=None):
         _LOGGER.debug(
             self._log("{} try to push data[{}]".format(op_name,
                                                        channeldata.id)))
         if len(self._producers) == 0:
-            raise Exception(
+            _LOGGER.critical(
                 self._log(
                     "expected number of producers to be greater than 0, but the it is 0."
                 ))
+            os._exit(-1)
         elif len(self._producers) == 1:
             with self._cv:
                 while self._stop.value == 0:
@@ -281,9 +291,10 @@ class ProcessChannel(object):
                             op_name, channeldata.id)))
             return True
         elif op_name is None:
-            raise Exception(
+            _LOGGER.critical(
                 self._log(
                     "There are multiple producers, so op_name cannot be None."))
+            os._exit(-1)

         producer_num = len(self._producers)
         data_id = channeldata.id
@@ -340,10 +351,11 @@ class ProcessChannel(object):
             endtime = _time() + timeout

         if len(self._consumer_cursors) == 0:
-            raise Exception(
+            _LOGGER.critical(
                 self._log(
                     "expected number of consumers to be greater than 0, but the it is 0."
                 ))
+            os._exit(-1)
         elif len(self._consumer_cursors) == 1:
             resp = None
             with self._cv:
@@ -369,9 +381,10 @@ class ProcessChannel(object):
                         resp.values()[0].id)))
             return resp
         elif op_name is None:
-            raise Exception(
+            _LOGGER.critical(
                 self._log(
                     "There are multiple consumers, so op_name cannot be None."))
+            os._exit(-1)

         # In output_buf, different Ops (according to op_name) have different
         # cursors. In addition, there is a base_cursor. Their difference is
@@ -450,7 +463,7 @@ class ProcessChannel(object):
         return resp

     def stop(self):
-        _LOGGER.debug(self._log("stop."))
+        _LOGGER.info(self._log("stop."))
         self._stop.value = 1
         with self._cv:
             self._cv.notify_all()
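A note for readers of the two channel classes in this file: ProcessChannel and ThreadChannel coordinate producer and consumer Ops through a single condition variable, and the stop() change (debug log raised to info, followed by notify_all) is what lets blocked workers wake up and exit. A minimal, hypothetical sketch of that push/front/stop pattern, stripped of the multi-producer cursor bookkeeping:

import threading

class TinyChannel(object):
    # Hypothetical single-producer/single-consumer reduction of
    # ThreadChannel; names and sizes are illustrative only.
    def __init__(self, maxsize=2):
        self._buf = []
        self._maxsize = maxsize
        self._stop = False
        self._cv = threading.Condition()

    def push(self, data):
        with self._cv:
            while not self._stop and len(self._buf) >= self._maxsize:
                self._cv.wait()  # wait for the consumer to free a slot
            if self._stop:
                raise RuntimeError("channel stopped")
            self._buf.append(data)
            self._cv.notify_all()

    def front(self):
        with self._cv:
            while not self._stop and len(self._buf) == 0:
                self._cv.wait()  # wait for the producer to push
            if self._stop:
                raise RuntimeError("channel stopped")
            data = self._buf.pop(0)
            self._cv.notify_all()
            return data

    def stop(self):
        with self._cv:
            self._stop = True
            self._cv.notify_all()  # wake every blocked push/front

In the real classes the RuntimeError is a ChannelStopError and the buffer is shared across threads or processes, but the wake-on-stop contract is the same.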
@@ -512,37 +525,38 @@ class ThreadChannel(Queue.Queue):
     def _log(self, info_str):
         return "[{}] {}".format(self.name, info_str)

-    def debug(self):
-        return self._log("p: {}, c: {}".format(self.get_producers(),
-                                               self.get_consumers()))
-
     def add_producer(self, op_name):
         """ not thread safe, and can only be called during initialization. """
         if op_name in self._producers:
-            raise ValueError(
+            _LOGGER.critical(
                 self._log("producer({}) is already in channel".format(op_name)))
+            os._exit(-1)
         self._producers.append(op_name)
+        _LOGGER.debug(self._log("add a producer: {}".format(op_name)))

     def add_consumer(self, op_name):
         """ not thread safe, and can only be called during initialization. """
         if op_name in self._consumer_cursors:
-            raise ValueError(
+            _LOGGER.critical(
                 self._log("consumer({}) is already in channel".format(op_name)))
+            os._exit(-1)
         self._consumer_cursors[op_name] = 0

         if self._cursor_count.get(0) is None:
             self._cursor_count[0] = 0
         self._cursor_count[0] += 1
+        _LOGGER.debug(self._log("add a consumer: {}".format(op_name)))

     def push(self, channeldata, op_name=None):
         _LOGGER.debug(
             self._log("{} try to push data[{}]".format(op_name,
                                                        channeldata.id)))
         if len(self._producers) == 0:
-            raise Exception(
+            _LOGGER.critical(
                 self._log(
                     "expected number of producers to be greater than 0, but the it is 0."
                 ))
+            os._exit(-1)
         elif len(self._producers) == 1:
             with self._cv:
                 while self._stop is False:
@@ -559,9 +573,10 @@ class ThreadChannel(Queue.Queue):
                             op_name, channeldata.id)))
             return True
         elif op_name is None:
-            raise Exception(
+            _LOGGER.critical(
                 self._log(
                     "There are multiple producers, so op_name cannot be None."))
+            os._exit(-1)

         producer_num = len(self._producers)
         data_id = channeldata.id
@@ -613,10 +628,11 @@ class ThreadChannel(Queue.Queue):
             endtime = _time() + timeout

         if len(self._consumer_cursors) == 0:
-            raise Exception(
+            _LOGGER.critical(
                 self._log(
                     "expected number of consumers to be greater than 0, but the it is 0."
                 ))
+            os._exit(-1)
         elif len(self._consumer_cursors) == 1:
             resp = None
             with self._cv:
@@ -642,9 +658,10 @@ class ThreadChannel(Queue.Queue):
                         resp.values()[0].id)))
             return resp
         elif op_name is None:
-            raise Exception(
+            _LOGGER.critical(
                 self._log(
                     "There are multiple consumers, so op_name cannot be None."))
+            os._exit(-1)

         # In output_buf, different Ops (according to op_name) have different
         # cursors. In addition, there is a base_cursor. Their difference is
@@ -723,7 +740,7 @@ class ThreadChannel(Queue.Queue):
         return resp

     def stop(self):
-        _LOGGER.debug(self._log("stop."))
+        _LOGGER.info(self._log("stop."))
         self._stop = True
         with self._cv:
             self._cv.notify_all()
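The recurring edit across this file (and the rest of the commit) replaces raise statements for unrecoverable setup errors with a critical log plus os._exit(-1), presumably because an exception raised inside a daemon thread or forked worker can leave the server wedged rather than stopping it, while os._exit kills the worker immediately. A sketch of the pattern as used here:

import logging
import os

_LOGGER = logging.getLogger()

def fail_fast(msg):
    # Log at CRITICAL, then terminate without unwinding the stack.
    # os._exit skips finally blocks and atexit handlers, which is
    # deliberate: the pipeline is in an unusable state anyway.
    _LOGGER.critical(msg)
    os._exit(-1)

The trade-off is that os._exit also skips flushing buffered log handlers, so the critical message should go to an unbuffered or line-buffered sink.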
python/pipeline/dag.py

@@ -35,33 +35,13 @@ _LOGGER = logging.getLogger()
 class DAGExecutor(object):
-    def __init__(self, response_op, dag_config, show_info):
-        default_conf = {
-            "retry": 1,
-            "client_type": "brpc",
-            "use_profile": False,
-            "channel_size": 0,
-            "is_thread_op": True
-        }
-
-        for key, val in default_conf.items():
-            if dag_config.get(key) is None:
-                _LOGGER.warning("[CONF] {} not set, use default: {}"
-                                .format(key, val))
-                dag_config[key] = val
-
-        self._retry = dag_config["retry"]
-        client_type = dag_config["client_type"]
-        self._server_use_profile = dag_config["use_profile"]
-        channel_size = dag_config["channel_size"]
-        self._is_thread_op = dag_config["is_thread_op"]
-        build_dag_each_worker = dag_config["build_dag_each_worker"]
-
-        if show_info:
-            _LOGGER.info("=============== DAGExecutor ===============")
-            for key in default_conf.keys():
-                _LOGGER.info("{}: {}".format(key, dag_config[key]))
-            _LOGGER.info("-------------------------------------------")
+    def __init__(self, response_op, dag_conf):
+        self._retry = dag_conf["retry"]
+        client_type = dag_conf["client_type"]
+        self._server_use_profile = dag_conf["use_profile"]
+        channel_size = dag_conf["channel_size"]
+        self._is_thread_op = dag_conf["is_thread_op"]
+        build_dag_each_worker = dag_conf["build_dag_each_worker"]

         self.name = "@G"
         self._profiler = TimeProfiler()
@@ -69,7 +49,7 @@ class DAGExecutor(object):
         self._dag = DAG(self.name, response_op, self._server_use_profile,
                         self._is_thread_op, client_type, channel_size,
-                        show_info, build_dag_each_worker)
+                        build_dag_each_worker)
         (in_channel, out_channel, pack_rpc_func,
          unpack_rpc_func) = self._dag.build()
         self._dag.start()
@@ -84,7 +64,7 @@ class DAGExecutor(object):
         self._reset_max_id = 1000000000000000000
         self._cv_pool = {}
         self._cv_for_cv_pool = threading.Condition()
-        self._fetch_buffer = None
+        self._fetch_buffer = {}
         self._recive_func = None

         self._client_profile_key = "pipeline.profile"
@@ -111,19 +91,22 @@ class DAGExecutor(object):
         cond_v = threading.Condition()
         with self._cv_for_cv_pool:
             self._cv_pool[data_id] = cond_v
+            self._fetch_buffer[data_id] = None
         return data_id, cond_v

     def _set_in_channel(self, in_channel):
         if not isinstance(in_channel, (ThreadChannel, ProcessChannel)):
-            raise TypeError("in_channel must be Channel type, but get {}".
-                            format(type(in_channel)))
+            _LOGGER.critical("[DAG Executor] in_channel must be Channel"
+                             " type, but get {}".format(type(in_channel)))
+            os._exit(-1)
         in_channel.add_producer(self.name)
         self._in_channel = in_channel

     def _set_out_channel(self, out_channel):
         if not isinstance(out_channel, (ThreadChannel, ProcessChannel)):
-            raise TypeError("iout_channel must be Channel type, but get {}".
-                            format(type(out_channel)))
+            _LOGGER.critical("[DAG Executor]iout_channel must be Channel"
+                             " type, but get {}".format(type(out_channel)))
+            os._exit(-1)
         out_channel.add_consumer(self.name)
         self._out_channel = out_channel
@@ -133,7 +116,7 @@ class DAGExecutor(object):
             try:
                 channeldata_dict = self._out_channel.front(self.name)
             except ChannelStopError:
-                _LOGGER.debug("[DAG Executor] channel stop.")
+                _LOGGER.info("[DAG Executor] channel stop.")
                 with self._cv_for_cv_pool:
                     for data_id, cv in self._cv_pool.items():
                         closed_errror_data = ChannelData(
@@ -141,17 +124,17 @@ class DAGExecutor(object):
                             error_info="dag closed.",
                             data_id=data_id)
                         with cv:
-                            self._fetch_buffer = closed_errror_data
+                            self._fetch_buffer[data_id] = closed_errror_data
                             cv.notify_all()
                 break

             if len(channeldata_dict) != 1:
-                _LOGGER.error(
+                _LOGGER.critical(
                     "[DAG Executor] out_channel cannot have multiple input ops")
                 os._exit(-1)
             (_, channeldata), = channeldata_dict.items()
             if not isinstance(channeldata, ChannelData):
-                _LOGGER.error(
+                _LOGGER.critical(
                     '[DAG Executor] data must be ChannelData type, but get {}'
                     .format(type(channeldata)))
                 os._exit(-1)
@@ -159,20 +142,30 @@ class DAGExecutor(object):
             data_id = channeldata.id
             _LOGGER.debug("recive thread fetch data[{}]".format(data_id))
             with self._cv_for_cv_pool:
-                cv = self._cv_pool[data_id]
-                with cv:
-                    self._fetch_buffer = channeldata
-                    cv.notify_all()
+                cond_v = self._cv_pool[data_id]
+                with cond_v:
+                    self._fetch_buffer[data_id] = channeldata
+                    cond_v.notify_all()

     def _get_channeldata_from_fetch_buffer(self, data_id, cond_v):
-        resp = None
-
+        ready_data = None
         with cond_v:
-            cond_v.wait()
-            with self._cv_for_cv_pool:
-                resp = copy.deepcopy(self._fetch_buffer)
-                _LOGGER.debug("resp thread get resp data[{}]".format(data_id))
-                self._cv_pool.pop(data_id)
-        return resp
+            with self._cv_for_cv_pool:
+                if self._fetch_buffer[data_id] is not None:
+                    # The requested data is already ready
+                    ready_data = self._fetch_buffer[data_id]
+                    self._cv_pool.pop(data_id)
+                    self._fetch_buffer.pop(data_id)
+            if ready_data is None:
+                # Wait for data ready
+                cond_v.wait()
+                with self._cv_for_cv_pool:
+                    ready_data = self._fetch_buffer[data_id]
+                    self._cv_pool.pop(data_id)
+                    self._fetch_buffer.pop(data_id)
+        _LOGGER.debug("resp thread get resp data[{}]".format(data_id))
+        return ready_data

     def _pack_channeldata(self, rpc_request, data_id):
         dictdata = None
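This is the dag-executor bug the commit message refers to: _fetch_buffer used to be a single shared slot, so with several in-flight requests one response could overwrite another, and a response arriving before its caller reached cond_v.wait() would leave that caller waiting forever. Keying the buffer by data_id and checking the slot before waiting closes both holes. A reduced, hypothetical sketch of the corrected handshake:

import threading

class ResponseMailbox(object):
    """Hypothetical reduction of the per-request fetch buffer."""
    def __init__(self):
        self._lock = threading.Lock()
        self._slots = {}   # data_id -> response (None = not ready)
        self._conds = {}   # data_id -> per-request Condition

    def register(self, data_id):
        cond_v = threading.Condition()
        with self._lock:
            self._slots[data_id] = None
            self._conds[data_id] = cond_v
        return cond_v

    def deliver(self, data_id, resp):
        with self._lock:
            cond_v = self._conds[data_id]
        with cond_v:
            with self._lock:
                self._slots[data_id] = resp
            cond_v.notify_all()

    def fetch(self, data_id, cond_v):
        with cond_v:
            with self._lock:
                ready = self._slots[data_id]
            if ready is None:
                cond_v.wait()  # safe: deliver() needs cond_v too
                with self._lock:
                    ready = self._slots[data_id]
            with self._lock:
                self._slots.pop(data_id)
                self._conds.pop(data_id)
            return ready

The check-before-wait is safe because deliver() must acquire the same per-request cond_v before notifying, so a response cannot slip in between fetch()'s check and its wait().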
@@ -204,14 +197,14 @@ class DAGExecutor(object):
     def call(self, rpc_request):
         data_id, cond_v = self._get_next_data_id()
-        _LOGGER.debug("generate id: {}".format(data_id))
+        _LOGGER.debug("generate Request id: {}".format(data_id))

         if not self._is_thread_op:
             self._profiler.record("call_{}#DAG-{}_0".format(data_id, data_id))
         else:
             self._profiler.record("call_{}#DAG_0".format(data_id))

-        _LOGGER.debug("try parse RPC package to channeldata[{}]".format(
+        _LOGGER.debug("try parse RPC request to channeldata[{}]".format(
             data_id))
         self._profiler.record("prepack_{}#{}_0".format(data_id, self.name))
         req_channeldata = self._pack_channeldata(rpc_request, data_id)
@@ -232,26 +225,24 @@ class DAGExecutor(object):
                         error_info="dag closed.",
                         data_id=data_id))

-            _LOGGER.debug("wait for Graph engine for data[{}]...".format(
-                data_id))
+            _LOGGER.debug("wait Graph engine for data[{}]...".format(data_id))
             resp_channeldata = self._get_channeldata_from_fetch_buffer(data_id,
                                                                        cond_v)

             if resp_channeldata.ecode == ChannelDataEcode.OK.value:
-                _LOGGER.debug("Graph engine predict data[{}] succ".format(
-                    data_id))
+                _LOGGER.debug("request[{}] succ predict".format(data_id))
                 break
             else:
-                _LOGGER.warn("Graph engine predict data[{}] failed: {}"
-                             .format(data_id, resp_channeldata.error_info))
+                _LOGGER.warning("request[{}] predict failed: {}"
+                                .format(data_id, resp_channeldata.error_info))
                 if resp_channeldata.ecode != ChannelDataEcode.TIMEOUT.value:
                     break

             if i + 1 < self._retry:
-                _LOGGER.warn("retry({}/{}) data[{}]".format(i + 1, self._retry,
-                                                            data_id))
+                _LOGGER.warning("retry({}/{}) data[{}]".format(
+                    i + 1, self._retry, data_id))

-        _LOGGER.debug("unpack channeldata[{}] into RPC resp package".format(
+        _LOGGER.debug("unpack channeldata[{}] into RPC response".format(
             data_id))
         self._profiler.record("postpack_{}#{}_0".format(data_id, self.name))
         rpc_resp = self._pack_for_rpc_resp(resp_channeldata)
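The call() hunks above also tidy the retry loop: only TIMEOUT error codes trigger a resend, while success or any other failure breaks out immediately. Condensed into a standalone sketch, with hypothetical ecode constants standing in for ChannelDataEcode values:

ECODE_OK, ECODE_TIMEOUT = 0, 2  # illustrative values only

def call_with_retry(send_request, wait_response, max_retry):
    resp = None
    for i in range(max_retry):
        send_request()
        resp = wait_response()
        if resp.ecode == ECODE_OK:
            break  # success, stop retrying
        if resp.ecode != ECODE_TIMEOUT:
            break  # non-retryable failure
        if i + 1 < max_retry:
            print("retry({}/{})".format(i + 1, max_retry))
    return resp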
@@ -282,14 +273,13 @@ class DAGExecutor(object):
 class DAG(object):
-    def __init__(self, request_name, response_op, use_profile, is_thread_op,
-                 client_type, channel_size, show_info, build_dag_each_worker):
+    def __init__(self, request_name, response_op, use_profile, is_thread_op,
+                 client_type, channel_size, build_dag_each_worker):
         self._request_name = request_name
         self._response_op = response_op
         self._use_profile = use_profile
         self._is_thread_op = is_thread_op
         self._channel_size = channel_size
         self._client_type = client_type
-        self._show_info = show_info
         self._build_dag_each_worker = build_dag_each_worker
         if not self._is_thread_op:
             self._manager = multiprocessing.Manager()
@@ -313,8 +303,9 @@ class DAG(object):
                 used_ops.add(pred_op)
                 # check the name of op is globally unique
                 if pred_op.name in unique_names:
-                    raise Exception("the name of Op must be unique: {}".
-                                    format(pred_op.name))
+                    _LOGGER.critical("the name of Op must be unique: {}".
+                                     format(pred_op.name))
+                    os._exit(-1)
                 unique_names.add(pred_op.name)
         return used_ops, succ_ops_of_use_op
@@ -346,7 +337,8 @@ class DAG(object):
             if len(op.get_input_ops()) == 0:
                 zero_indegree_num += 1
         if zero_indegree_num != 1:
-            raise Exception("DAG contains multiple input Ops")
+            _LOGGER.critical("DAG contains multiple RequestOps")
+            os._exit(-1)
         last_op = response_op.get_input_ops()[0]
         ques[que_idx].put(last_op)
@@ -370,24 +362,27 @@ class DAG(object):
                 break
             que_idx = (que_idx + 1) % 2
         if sorted_op_num < len(used_ops):
-            raise Exception("not legal DAG")
+            _LOGGER.critical("not legal DAG")
+            os._exit(-1)

         return dag_views, last_op

     def _build_dag(self, response_op):
         if response_op is None:
-            raise Exception("response_op has not been set.")
+            _LOGGER.critical("ResponseOp has not been set.")
+            os._exit(-1)
         used_ops, out_degree_ops = self.get_use_ops(response_op)
-        if self._show_info:
+        if not self._build_dag_each_worker:
             _LOGGER.info("================= USED OP =================")
             for op in used_ops:
                 if op.name != self._request_name:
                     _LOGGER.info(op.name)
             _LOGGER.info("-------------------------------------------")
         if len(used_ops) <= 1:
-            raise Exception(
+            _LOGGER.critical(
                 "Besides RequestOp and ResponseOp, there should be at least one Op in DAG."
             )
+            os._exit(-1)
         if self._build_dag_each_worker:
             _LOGGER.info("Because `build_dag_each_worker` mode is used, "
                          "Auto-batching is set to the default config: "
@@ -398,15 +393,15 @@ class DAG(object):
         dag_views, last_op = self._topo_sort(used_ops, response_op,
                                              out_degree_ops)
         dag_views = list(reversed(dag_views))
-        if self._show_info:
-            _LOGGER.info("================== DAG ====================")
+        if not self._build_dag_each_worker:
+            _LOGGER.debug("================== DAG ====================")
             for idx, view in enumerate(dag_views):
-                _LOGGER.info("(VIEW {})".format(idx))
+                _LOGGER.debug("(VIEW {})".format(idx))
                 for op in view:
-                    _LOGGER.info("  [{}]".format(op.name))
+                    _LOGGER.debug("  [{}]".format(op.name))
                     for out_op in out_degree_ops[op.name]:
-                        _LOGGER.info("    - {}".format(out_op.name))
+                        _LOGGER.debug("    - {}".format(out_op.name))
-            _LOGGER.info("-------------------------------------------")
+            _LOGGER.debug("-------------------------------------------")

         # create channels and virtual ops
         virtual_op_name_gen = NameGenerator("vir")
@@ -493,7 +488,7 @@ class DAG(object):
             actual_ops.append(op)

         for c in channels:
-            _LOGGER.debug("Channel({}):\n -producers: {}\n -consumers: {}"
+            _LOGGER.debug("Channel({}):\n\t-producers: {}\n\t-consumers: {}"
                           .format(c.name, c.get_producers(), c.get_consumers()))

         return (actual_ops, channels, input_channel, output_channel, pack_func,
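_topo_sort (called from _build_dag above) flattens the graph into dag_views, one list of Ops per dependency level, and the new critical-and-exit branches reject graphs with more than one entry Op or with cycles. A hedged sketch of the same checks written as plain Kahn's algorithm over hypothetical op objects (the real code uses two ping-pong queues and os._exit instead of SystemExit):

def topo_views(ops, out_degree):
    """Group ops into dependency levels; reject non-DAGs and
    graphs with more than one entry op (RequestOp)."""
    indeg = {op.name: len(op.get_input_ops()) for op in ops}
    by_name = {op.name: op for op in ops}
    frontier = [n for n, d in indeg.items() if d == 0]
    if len(frontier) != 1:
        raise SystemExit("DAG contains multiple RequestOps")
    views, sorted_num = [], 0
    while frontier:
        views.append([by_name[n] for n in frontier])
        sorted_num += len(frontier)
        nxt = []
        for n in frontier:
            for succ in out_degree.get(n, []):
                indeg[succ.name] -= 1
                if indeg[succ.name] == 0:
                    nxt.append(succ.name)
        frontier = nxt
    if sorted_num < len(ops):
        raise SystemExit("not legal DAG")  # a cycle left ops unsorted
    return views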
python/pipeline/operator.py

@@ -60,7 +60,10 @@ class Op(object):
         self._client_config = client_config
         self._fetch_names = fetch_list
-        self._timeout = timeout
+        if timeout > 0:
+            self._timeout = timeout / 1000.0
+        else:
+            self._timeout = -1
         self._retry = max(1, retry)
         self._input = None
         self._outputs = []
@@ -69,13 +72,32 @@ class Op(object):
         self._auto_batching_timeout = auto_batching_timeout
         if self._auto_batching_timeout is not None:
             if self._auto_batching_timeout <= 0 or self._batch_size == 1:
+                _LOGGER.warning(
+                    "Because auto_batching_timeout <= 0 or batch_size == 1,"
+                    " set auto_batching_timeout to None.")
                 self._auto_batching_timeout = None
             else:
                 self._auto_batching_timeout = self._auto_batching_timeout / 1000.0
+
+        if not isinstance(self, RequestOp) and not isinstance(self, ResponseOp):
+            _LOGGER.info(
+                self._log("\n\tinput_ops: {},"
+                          "\n\tserver_endpoints: {}"
+                          "\n\tfetch_list: {}"
+                          "\n\tclient_config: {}"
+                          "\n\tconcurrency: {},"
+                          "\n\ttimeout(s): {},"
+                          "\n\tretry: {},"
+                          "\n\tbatch_size: {},"
+                          "\n\tauto_batching_timeout(s): {}".format(
+                              ", ".join([op.name for op in input_ops]),
+                              self._server_endpoints, self._fetch_names,
+                              self._client_config, self.concurrency,
+                              self._timeout, self._retry, self._batch_size,
+                              self._auto_batching_timeout)))

         self._server_use_profile = False

-        # only for multithread
+        # only for thread op
         self._for_init_op_lock = threading.Lock()
         self._for_close_op_lock = threading.Lock()
         self._succ_init_op = False
@@ -83,11 +105,11 @@ class Op(object):
     def use_default_auto_batching_config(self):
         if self._batch_size != 1:
-            _LOGGER.warn("Op({}) reset batch_size=1 (original: {})"
-                         .format(self.name, self._batch_size))
+            _LOGGER.warning("Op({}) reset batch_size=1 (original: {})"
+                            .format(self.name, self._batch_size))
             self._batch_size = 1
         if self._auto_batching_timeout != None:
-            _LOGGER.warn(
+            _LOGGER.warning(
                 "Op({}) reset auto_batching_timeout=None (original: {})"
                 .format(self.name, self._auto_batching_timeout))
             self._auto_batching_timeout = None
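One behavioral change buried in the __init__ hunk: timeout and auto_batching_timeout now appear to be interpreted as milliseconds in the config and are stored as seconds internally (Python's waits take seconds), with non-positive values mapped to -1 meaning no timeout. The rule, as a small sketch:

def ms_to_timeout_seconds(timeout_ms):
    # Non-positive values mean "no timeout"; the pipeline encodes
    # that as -1 rather than None so comparisons stay numeric.
    if timeout_ms > 0:
        return timeout_ms / 1000.0
    return -1

assert ms_to_timeout_seconds(500) == 0.5
assert ms_to_timeout_seconds(0) == -1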
@@ -100,12 +122,7 @@ class Op(object):
         if self.with_serving == False:
             _LOGGER.info("Op({}) no client".format(self.name))
             return None
-        _LOGGER.info("Op({}) service endpoints: {}".format(self.name,
-                                                           server_endpoints))
-        _LOGGER.debug("Op({}) fetch_names: {}".format(self.name, fetch_names))
         if client_type == 'brpc':
-            _LOGGER.debug("Op({}) client_config: {}".format(self.name,
-                                                            client_config))
             client = Client()
             client.load_client_config(client_config)
         elif client_type == 'grpc':
@@ -125,16 +142,18 @@ class Op(object):
         self._input_ops = []
         for op in ops:
             if not isinstance(op, Op):
-                raise TypeError(
-                    self._log('input op must be Op type, not {}'.format(
-                        type(op))))
+                _LOGGER.critical(
+                    self._log("input op must be Op type, not {}"
+                              .format(type(op))))
+                os._exit(-1)
             self._input_ops.append(op)

     def add_input_channel(self, channel):
         if not isinstance(channel, (ThreadChannel, ProcessChannel)):
-            raise TypeError(
-                self._log('input channel must be Channel type, not {}'.format(
-                    type(channel))))
+            _LOGGER.critical(
+                self._log("input channel must be Channel type, not {}"
+                          .format(type(channel))))
+            os._exit(-1)
         channel.add_consumer(self.name)
         self._input = channel
@@ -146,9 +165,10 @@ class Op(object):
     def add_output_channel(self, channel):
         if not isinstance(channel, (ThreadChannel, ProcessChannel)):
-            raise TypeError(
-                self._log('output channel must be Channel type, not {}'.format(
-                    type(channel))))
+            _LOGGER.critical(
+                self._log("output channel must be Channel type, not {}"
+                          .format(type(channel))))
+            os._exit(-1)
         channel.add_producer(self.name)
         self._outputs.append(channel)
@@ -161,9 +181,11 @@ class Op(object):
     def preprocess(self, input_dicts):
         # multiple previous Op
         if len(input_dicts) != 1:
-            raise NotImplementedError(
-                'this Op has multiple previous inputs. Please override this func.'
-            )
+            _LOGGER.critical(
+                self._log(
+                    "this Op has multiple previous inputs. Please override this func."
+                ))
+            os._exit(-1)

         (_, input_dict), = input_dicts.items()
         return input_dict
@@ -171,8 +193,10 @@ class Op(object):
     def process(self, feed_batch):
         err, err_info = ChannelData.check_batch_npdata(feed_batch)
         if err != 0:
-            raise NotImplementedError(
-                "{} Please override preprocess func.".format(err_info))
+            _LOGGER.critical(
+                self._log("{}, Please override preprocess func.".format(
+                    err_info)))
+            os._exit(-1)
         call_result = self.client.predict(
             feed=feed_batch, fetch=self._fetch_names)
         if isinstance(self.client, MultiLangClient):
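preprocess() now hard-exits when an Op has several upstream inputs and the user has not overridden it. What an override might look like on a user-defined Op, with a hypothetical CombineOp and key-prefixing scheme:

class CombineOp(Op):
    def preprocess(self, input_dicts):
        # input_dicts maps each upstream op name to its output dict;
        # merge them into a single feed dict, prefixing keys by origin.
        combined = {}
        for op_name, d in input_dicts.items():
            for key, value in d.items():
                combined["{}.{}".format(op_name, key)] = value
        return combined

This mirrors the contract the default implementation enforces: one dict in, one dict out.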
@@ -258,26 +282,18 @@ class Op(object):
         preped_data, error_channeldata = None, None
         try:
             preped_data = self.preprocess(parsed_data)
-        except NotImplementedError as e:
-            # preprocess function not implemented
-            error_info = log_func("preprocess data[{}] failed: {}".format(
-                data_id, e))
-            error_channeldata = ChannelData(
-                ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
-                error_info=error_info,
-                data_id=data_id)
         except TypeError as e:
             # Error type in channeldata.datatype
-            error_info = log_func("preprocess data[{}] failed: {}".format(
-                data_id, e))
+            error_info = log_func("preprocess data[{}] failed: {}"
+                                  .format(data_id, e))
             _LOGGER.error(error_info)
             error_channeldata = ChannelData(
                 ecode=ChannelDataEcode.TYPE_ERROR.value,
                 error_info=error_info,
                 data_id=data_id)
         except Exception as e:
-            error_info = log_func("preprocess data[{}] failed: {}".format(
-                data_id, e))
+            error_info = log_func("preprocess data[{}] failed: {}"
+                                  .format(data_id, e))
             _LOGGER.error(error_info)
             error_channeldata = ChannelData(
                 ecode=ChannelDataEcode.UNKNOW.value,
@@ -317,7 +333,7 @@ class Op(object):
                     error_info = log_func(e)
                     _LOGGER.error(error_info)
                 else:
-                    _LOGGER.warn(
+                    _LOGGER.warning(
                         log_func("PaddleService timeout, retry({}/{})"
                                  .format(i + 1, self._retry)))
             except Exception as e:
@@ -376,7 +392,8 @@ class Op(object):
                 continue
             else:
                 if not isinstance(postped_data, dict):
-                    error_info = log_func("output of postprocess funticon must be " \
-                        "dict type, but get {}".format(type(postped_data)))
+                    error_info = log_func("output of postprocess funticon must be "
+                                          "dict type, but get {}"
+                                          .format(type(postped_data)))
                     _LOGGER.error(error_info)
                     err_channeldata = ChannelData(
@@ -471,7 +488,7 @@ class Op(object):
             profiler = self._initialize(is_thread_op, client_type,
                                         concurrency_idx)
         except Exception as e:
-            _LOGGER.error(log("init op failed: {}".format(e)))
+            _LOGGER.critical(log("init op failed: {}".format(e)))
             os._exit(-1)
         _LOGGER.info(log("succ init"))
@@ -629,7 +646,7 @@ class RequestOp(Op):
         try:
             self.init_op()
         except Exception as e:
-            _LOGGER.error("Op(Request) init op failed: {}".format(e))
+            _LOGGER.critical("Op(Request) init op failed: {}".format(e))
             os._exit(-1)

     def unpack_request_package(self, request):
@@ -653,7 +670,7 @@ class ResponseOp(Op):
         try:
             self.init_op()
         except Exception as e:
-            _LOGGER.error("Op(ResponseOp) init op failed: {}".format(e))
+            _LOGGER.critical("Op(ResponseOp) init op failed: {}".format(e))
             os._exit(-1)

     def pack_response_package(self, channeldata):
@@ -710,9 +727,10 @@ class VirtualOp(Op):
     def add_output_channel(self, channel):
         if not isinstance(channel, (ThreadChannel, ProcessChannel)):
-            raise TypeError(
-                self._log('output channel must be Channel type, not {}'.format(
-                    type(channel))))
+            _LOGGER.critical(
+                self._log("output channel must be Channel type, not {}"
+                          .format(type(channel))))
+            os._exit(-1)
         for op in self._virtual_pred_ops:
             for op_name in self._actual_pred_op_names(op):
                 channel.add_producer(op_name)
@@ -730,17 +748,27 @@ class VirtualOp(Op):
         log = get_log_func(op_info_prefix)
         tid = threading.current_thread().ident

+        batch_generator = self._auto_batching_generator(
+            input_channel=input_channel,
+            op_name=self.name,
+            batch_size=1,
+            timeout=None,
+            log_func=log)
+
         while True:
             try:
-                channeldata_dict = input_channel.front(self.name)
+                channeldata_dict_batch = next(batch_generator)
             except ChannelStopError:
-                _LOGGER.debug(log("Channel stop."))
+                _LOGGER.debug(log("channel stop."))
+                self._finalize(is_thread_op)
                 break

             try:
-                for name, data in channeldata_dict.items():
-                    self._push_to_output_channels(
-                        data, channels=output_channels, name=name)
+                for channeldata_dict in channeldata_dict_batch:
+                    for name, data in channeldata_dict.items():
+                        self._push_to_output_channels(
+                            data, channels=output_channels, name=name)
             except ChannelStopError:
                 _LOGGER.debug(log("Channel stop."))
                 self._finalize(is_thread_op)
                 break
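With this hunk, VirtualOp consumes through the same _auto_batching_generator as regular Ops (here with batch_size=1 and no timeout) instead of calling input_channel.front directly, so there is one draining code path for every consumer. A standalone sketch of what such a generator does, assuming a fetch(timeout) callable that raises TimeoutError when its wait expires (the real generator reads from the channel and logs through log_func):

import time

def auto_batching_generator(fetch, batch_size, timeout):
    # Yield lists of up to batch_size items; if timeout (seconds) is
    # set, yield whatever arrived once the batching window expires.
    while True:
        batch = []
        deadline = None if timeout is None else time.time() + timeout
        while len(batch) < batch_size:
            try:
                if deadline is None:
                    batch.append(fetch(None))  # block indefinitely
                else:
                    remaining = deadline - time.time()
                    if remaining <= 0:
                        break
                    batch.append(fetch(remaining))
            except TimeoutError:
                break
        if batch:
            yield batch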
python/pipeline/pipeline_server.py

@@ -15,6 +15,7 @@
 from concurrent import futures
 import grpc
 import logging
+import json
 import socket
 import contextlib
 from contextlib import closing
@@ -29,11 +30,10 @@ _LOGGER = logging.getLogger()
 class PipelineServicer(pipeline_service_pb2_grpc.PipelineServiceServicer):
-    def __init__(self, response_op, dag_config, show_info):
+    def __init__(self, response_op, dag_conf):
         super(PipelineServicer, self).__init__()
         # init dag executor
-        self._dag_executor = DAGExecutor(
-            response_op, dag_config, show_info=show_info)
+        self._dag_executor = DAGExecutor(response_op, dag_conf)
         self._dag_executor.start()
         _LOGGER.info("[PipelineServicer] succ init")
@@ -79,36 +79,25 @@ class PipelineServer(object):
         return result != 0

     def prepare_server(self, yml_file):
-        with open(yml_file) as f:
-            yml_config = yaml.load(f.read())
-        default_config = {
-            "port": 9292,
-            "worker_num": 1,
-            "build_dag_each_worker": False,
-        }
-
-        for key, val in default_config.items():
-            if yml_config.get(key) is None:
-                _LOGGER.warning("[CONF] {} not set, use default: {}"
-                                .format(key, val))
-                yml_config[key] = val
-
-        self._port = yml_config["port"]
+        conf = ServerYamlConfChecker.load_server_yaml_conf(yml_file)
+
+        self._port = conf["port"]
         if not self._port_is_available(self._port):
             raise SystemExit("Prot {} is already used".format(self._port))
-        self._worker_num = yml_config["worker_num"]
-        self._build_dag_each_worker = yml_config["build_dag_each_worker"]
+        self._worker_num = conf["worker_num"]
+        self._build_dag_each_worker = conf["build_dag_each_worker"]

         _LOGGER.info("============= PIPELINE SERVER =============")
-        for key in default_config.keys():
-            _LOGGER.info("{}: {}".format(key, yml_config[key]))
+        _LOGGER.info("\n{}".format(
+            json.dumps(conf, indent=4, separators=(',', ':'))))
         if self._build_dag_each_worker is True:
             _LOGGER.info(
                 "(Make sure that install grpcio whl with --no-binary flag)")
         _LOGGER.info("-------------------------------------------")

-        self._dag_config = yml_config.get("dag", {})
-        self._dag_config["build_dag_each_worker"] = self._build_dag_each_worker
+        self._dag_conf = conf["dag"]
+        self._dag_conf["build_dag_each_worker"] = self._build_dag_each_worker

     def run_server(self):
         if self._build_dag_each_worker:
@@ -119,8 +108,7 @@ class PipelineServer(object):
                 show_info = (i == 0)
                 worker = multiprocessing.Process(
                     target=self._run_server_func,
-                    args=(bind_address, self._response_op, self._dag_config))
+                    args=(bind_address, self._response_op, self._dag_conf))
                 worker.start()
                 workers.append(worker)
             for worker in workers:
@@ -129,19 +117,140 @@ class PipelineServer(object):
             server = grpc.server(
                 futures.ThreadPoolExecutor(max_workers=self._worker_num))
             pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
-                PipelineServicer(self._response_op, self._dag_config, True),
-                server)
+                PipelineServicer(self._response_op, self._dag_conf), server)
             server.add_insecure_port('[::]:{}'.format(self._port))
             server.start()
             server.wait_for_termination()

-    def _run_server_func(self, bind_address, response_op, dag_config):
+    def _run_server_func(self, bind_address, response_op, dag_conf):
         options = (('grpc.so_reuseport', 1), )
         server = grpc.server(
             futures.ThreadPoolExecutor(
                 max_workers=1, ), options=options)
         pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
-            PipelineServicer(response_op, dag_config, False), server)
+            PipelineServicer(response_op, dag_conf), server)
         server.add_insecure_port(bind_address)
         server.start()
         server.wait_for_termination()
+
+
+class ServerYamlConfChecker(object):
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def load_server_yaml_conf(yml_file):
+        with open(yml_file) as f:
+            conf = yaml.load(f.read())
+        ServerYamlConfChecker.check_server_conf(conf)
+        ServerYamlConfChecker.check_dag_conf(conf["dag"])
+        return conf
+
+    @staticmethod
+    def check_server_conf(conf):
+        default_conf = {
+            "port": 9292,
+            "worker_num": 1,
+            "build_dag_each_worker": False,
+            "dag": {},
+        }
+        ServerYamlConfChecker.fill_with_default_conf(conf, default_conf)
+
+        conf_type = {
+            "port": int,
+            "worker_num": int,
+            "build_dag_each_worker": bool,
+        }
+        ServerYamlConfChecker.check_conf_type(conf, conf_type)
+
+        conf_qualification = {
+            "port": [(">=", 1024), ("<=", 65535)],
+            "worker_num": (">=", 1),
+        }
+        ServerYamlConfChecker.check_conf_qualification(conf, conf_qualification)
+
+    @staticmethod
+    def check_dag_conf(conf):
+        default_conf = {
+            "retry": 1,
+            "client_type": "brpc",
+            "use_profile": False,
+            "channel_size": 0,
+            "is_thread_op": True
+        }
+        ServerYamlConfChecker.fill_with_default_conf(conf, default_conf)
+
+        conf_type = {
+            "retry": int,
+            "client_type": str,
+            "use_profile": bool,
+            "channel_size": int,
+            "is_thread_op": bool,
+        }
+        ServerYamlConfChecker.check_conf_type(conf, conf_type)
+
+        conf_qualification = {
+            "retry": (">=", 1),
+            "client_type": ("in", ["brpc", "grpc"]),
+            "channel_size": (">=", 0),
+        }
+        ServerYamlConfChecker.check_conf_qualification(conf, conf_qualification)
+
+    @staticmethod
+    def fill_with_default_conf(conf, default_conf):
+        for key, val in default_conf.items():
+            if conf.get(key) is None:
+                _LOGGER.warning("[CONF] {} not set, use default: {}"
+                                .format(key, val))
+                conf[key] = val
+
+    @staticmethod
+    def check_conf_type(conf, conf_type):
+        for key, val in conf_type.items():
+            if not isinstance(conf[key], val):
+                raise SystemExit("[CONF] {} must be {} type, but get {}."
+                                 .format(key, val, type(conf[key])))
+
+    @staticmethod
+    def check_conf_qualification(conf, conf_qualification):
+        for key, qualification in conf_qualification.items():
+            if not isinstance(qualification, list):
+                qualification = [qualification]
+            if not ServerYamlConfChecker.qualification_check(conf[key],
+                                                             qualification):
+                raise SystemExit("[CONF] {} must be {}, but get {}.".format(
+                    key, ", ".join(["{} {}".format(q[0], q[1])
+                                    for q in qualification]), conf[key]))
+
+    @staticmethod
+    def qualification_check(value, qualifications):
+        if not isinstance(qualifications, list):
+            qualifications = [qualifications]
+        ok = True
+        for q in qualifications:
+            operator, limit = q
+            if operator == "<":
+                ok = value < limit
+            elif operator == "==":
+                ok = value == limit
+            elif operator == ">":
+                ok = value > limit
+            elif operator == "<=":
+                ok = value <= limit
+            elif operator == ">=":
+                ok = value >= limit
+            elif operator == "in":
+                ok = value in limit
+            else:
+                raise SystemExit("unknow operator: {}".format(operator))
+            if ok == False:
+                break
+        return ok
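The new ServerYamlConfChecker centralizes the defaults, type checks, and range checks that were previously scattered across prepare_server and DAGExecutor.__init__. Given the defaults and qualifications above, a minimal config.yml that passes the checker would look roughly like this (the values are illustrative, not required):

port: 18080
worker_num: 4
build_dag_each_worker: false
dag:
    retry: 2
    client_type: brpc
    use_profile: false
    channel_size: 0
    is_thread_op: true

Missing keys are filled from default_conf with a "[CONF] ... not set" warning; a wrong type or an out-of-range value (for example port: 80, which fails the (">=", 1024) qualification) aborts startup with SystemExit.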
python/pipeline/profiler.py

@@ -29,6 +29,8 @@ _LOGGER = logging.getLogger()
 class UnsafeTimeProfiler(object):
+    """ thread unsafe profiler """
+
     def __init__(self):
         self.pid = os.getpid()
         self.print_head = 'PROFILE\tpid:{}\t'.format(self.pid)