Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Serving
提交
4049eb8a
S
Serving
项目概览
PaddlePaddle
/
Serving
接近 2 年 前同步成功
通知
186
Star
833
Fork
253
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
105
列表
看板
标记
里程碑
合并请求
10
Wiki
2
Wiki
分析
仓库
DevOps
项目成员
Pages
S
Serving
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
105
Issue
105
列表
看板
标记
里程碑
合并请求
10
合并请求
10
Pages
分析
分析
仓库分析
DevOps
Wiki
2
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4049eb8a
编写于
7月 30, 2020
作者:
B
barriery
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix codestyle
上级
de58d900
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
159 addition
and
161 deletion
+159
-161
python/pipeline/channel.py
python/pipeline/channel.py
+52
-58
python/pipeline/dag.py
python/pipeline/dag.py
+31
-26
python/pipeline/operator.py
python/pipeline/operator.py
+71
-73
python/pipeline/pipeline_client.py
python/pipeline/pipeline_client.py
+1
-1
python/pipeline/pipeline_server.py
python/pipeline/pipeline_server.py
+4
-3
未找到文件。
python/pipeline/channel.py
浏览文件 @
4049eb8a
...
@@ -258,7 +258,8 @@ class ProcessChannel(object):
...
@@ -258,7 +258,8 @@ class ProcessChannel(object):
def
push
(
self
,
channeldata
,
op_name
=
None
):
def
push
(
self
,
channeldata
,
op_name
=
None
):
_LOGGER
.
debug
(
_LOGGER
.
debug
(
self
.
_log
(
"{} try to push data[{}]"
.
format
(
op_name
,
channeldata
.
id
)))
self
.
_log
(
"{} try to push data[{}]"
.
format
(
op_name
,
channeldata
.
id
)))
if
len
(
self
.
_producers
)
==
0
:
if
len
(
self
.
_producers
)
==
0
:
raise
Exception
(
raise
Exception
(
self
.
_log
(
self
.
_log
(
...
@@ -275,8 +276,9 @@ class ProcessChannel(object):
...
@@ -275,8 +276,9 @@ class ProcessChannel(object):
if
self
.
_stop
.
value
==
1
:
if
self
.
_stop
.
value
==
1
:
raise
ChannelStopError
()
raise
ChannelStopError
()
self
.
_cv
.
notify_all
()
self
.
_cv
.
notify_all
()
_LOGGER
.
debug
(
self
.
_log
(
"{} notify all"
.
format
(
op_name
)))
_LOGGER
.
debug
(
_LOGGER
.
debug
(
self
.
_log
(
"{} push data succ!"
.
format
(
op_name
)))
self
.
_log
(
"{} succ push data[{}] into internal queue."
.
format
(
op_name
,
channeldata
.
id
)))
return
True
return
True
elif
op_name
is
None
:
elif
op_name
is
None
:
raise
Exception
(
raise
Exception
(
...
@@ -287,7 +289,6 @@ class ProcessChannel(object):
...
@@ -287,7 +289,6 @@ class ProcessChannel(object):
data_id
=
channeldata
.
id
data_id
=
channeldata
.
id
put_data
=
None
put_data
=
None
with
self
.
_cv
:
with
self
.
_cv
:
_LOGGER
.
debug
(
self
.
_log
(
"{} get lock"
.
format
(
op_name
)))
if
data_id
not
in
self
.
_input_buf
:
if
data_id
not
in
self
.
_input_buf
:
self
.
_input_buf
[
data_id
]
=
{
self
.
_input_buf
[
data_id
]
=
{
name
:
None
name
:
None
...
@@ -309,14 +310,11 @@ class ProcessChannel(object):
...
@@ -309,14 +310,11 @@ class ProcessChannel(object):
if
put_data
is
None
:
if
put_data
is
None
:
_LOGGER
.
debug
(
_LOGGER
.
debug
(
self
.
_log
(
"{}
push data succ, but not push to queue."
.
self
.
_log
(
"{}
succ push data[{}] into input_buffer."
.
format
(
format
(
op_name
)))
op_name
,
data_id
)))
else
:
else
:
while
self
.
_stop
.
value
==
0
:
while
self
.
_stop
.
value
==
0
:
try
:
try
:
_LOGGER
.
debug
(
self
.
_log
(
"{} push data succ: {}"
.
format
(
op_name
,
put_data
.
__str__
())))
self
.
_que
.
put
(
put_data
,
timeout
=
0
)
self
.
_que
.
put
(
put_data
,
timeout
=
0
)
break
break
except
Queue
.
Empty
:
except
Queue
.
Empty
:
...
@@ -325,11 +323,15 @@ class ProcessChannel(object):
...
@@ -325,11 +323,15 @@ class ProcessChannel(object):
raise
ChannelStopError
()
raise
ChannelStopError
()
_LOGGER
.
debug
(
_LOGGER
.
debug
(
self
.
_log
(
"multi | {} push data succ!"
.
format
(
op_name
)))
self
.
_log
(
"{} succ push data[{}] into internal queue."
.
format
(
op_name
,
data_id
)))
self
.
_cv
.
notify_all
()
self
.
_cv
.
notify_all
()
return
True
return
True
def
front
(
self
,
op_name
=
None
,
timeout
=
None
):
def
front
(
self
,
op_name
=
None
,
timeout
=
None
):
_LOGGER
.
debug
(
self
.
_log
(
"{} try to get data[?]; timeout={}"
.
format
(
op_name
,
timeout
)))
endtime
=
None
endtime
=
None
if
timeout
is
not
None
:
if
timeout
is
not
None
:
if
timeout
<=
0
:
if
timeout
<=
0
:
...
@@ -337,7 +339,6 @@ class ProcessChannel(object):
...
@@ -337,7 +339,6 @@ class ProcessChannel(object):
else
:
else
:
endtime
=
_time
()
+
timeout
endtime
=
_time
()
+
timeout
_LOGGER
.
debug
(
self
.
_log
(
"{} try to get data..."
.
format
(
op_name
)))
if
len
(
self
.
_consumer_cursors
)
==
0
:
if
len
(
self
.
_consumer_cursors
)
==
0
:
raise
Exception
(
raise
Exception
(
self
.
_log
(
self
.
_log
(
...
@@ -348,21 +349,24 @@ class ProcessChannel(object):
...
@@ -348,21 +349,24 @@ class ProcessChannel(object):
with
self
.
_cv
:
with
self
.
_cv
:
while
self
.
_stop
.
value
==
0
and
resp
is
None
:
while
self
.
_stop
.
value
==
0
and
resp
is
None
:
try
:
try
:
_LOGGER
.
debug
(
self
.
_log
(
"{} try to get(with channel empty: {})"
.
format
(
op_name
,
self
.
_que
.
empty
())))
resp
=
self
.
_que
.
get
(
timeout
=
0
)
resp
=
self
.
_que
.
get
(
timeout
=
0
)
break
break
except
Queue
.
Empty
:
except
Queue
.
Empty
:
if
timeout
is
not
None
:
if
timeout
is
not
None
:
remaining
=
endtime
-
_time
()
remaining
=
endtime
-
_time
()
if
remaining
<=
0.0
:
if
remaining
<=
0.0
:
_LOGGER
.
debug
(
self
.
_log
(
"{} get data[?] timeout"
.
format
(
op_name
)))
raise
ChannelTimeoutError
()
raise
ChannelTimeoutError
()
self
.
_cv
.
wait
(
remaining
)
self
.
_cv
.
wait
(
remaining
)
else
:
else
:
self
.
_cv
.
wait
()
self
.
_cv
.
wait
()
if
self
.
_stop
.
value
==
1
:
if
self
.
_stop
.
value
==
1
:
raise
ChannelStopError
()
raise
ChannelStopError
()
_LOGGER
.
debug
(
self
.
_log
(
"{} succ get data[{}]"
.
format
(
op_name
,
resp
.
values
()[
0
].
id
)))
return
resp
return
resp
elif
op_name
is
None
:
elif
op_name
is
None
:
raise
Exception
(
raise
Exception
(
...
@@ -384,22 +388,20 @@ class ProcessChannel(object):
...
@@ -384,22 +388,20 @@ class ProcessChannel(object):
# it is necessary to obtain a data from queue and add it to output_buf.
# it is necessary to obtain a data from queue and add it to output_buf.
while
self
.
_stop
.
value
==
0
and
self
.
_consumer_cursors
[
while
self
.
_stop
.
value
==
0
and
self
.
_consumer_cursors
[
op_name
]
-
self
.
_base_cursor
.
value
>=
len
(
self
.
_output_buf
):
op_name
]
-
self
.
_base_cursor
.
value
>=
len
(
self
.
_output_buf
):
_LOGGER
.
debug
(
self
.
_log
(
"({}) B self._consumer_cursors: {}, self._base_cursor: {}, len(self._output_buf): {}"
.
format
(
op_name
,
self
.
_consumer_cursors
,
self
.
_base_cursor
.
value
,
len
(
self
.
_output_buf
))))
try
:
try
:
_LOGGER
.
debug
(
self
.
_log
(
"{} try to get(with channel size: {})"
.
format
(
op_name
,
self
.
_que
.
qsize
())))
channeldata
=
self
.
_que
.
get
(
timeout
=
0
)
channeldata
=
self
.
_que
.
get
(
timeout
=
0
)
self
.
_output_buf
.
append
(
channeldata
)
self
.
_output_buf
.
append
(
channeldata
)
_LOGGER
.
debug
(
self
.
_log
(
"pop ready item[{}] into output_buffer"
.
format
(
channeldata
.
values
()[
0
].
id
)))
break
break
except
Queue
.
Empty
:
except
Queue
.
Empty
:
if
timeout
is
not
None
:
if
timeout
is
not
None
:
remaining
=
endtime
-
_time
()
remaining
=
endtime
-
_time
()
if
remaining
<=
0.0
:
if
remaining
<=
0.0
:
_LOGGER
.
debug
(
self
.
_log
(
"{} get data[?] timeout"
.
format
(
op_name
)))
raise
ChannelTimeoutError
()
raise
ChannelTimeoutError
()
self
.
_cv
.
wait
(
remaining
)
self
.
_cv
.
wait
(
remaining
)
else
:
else
:
...
@@ -411,7 +413,6 @@ class ProcessChannel(object):
...
@@ -411,7 +413,6 @@ class ProcessChannel(object):
base_cursor
=
self
.
_base_cursor
.
value
base_cursor
=
self
.
_base_cursor
.
value
data_idx
=
consumer_cursor
-
base_cursor
data_idx
=
consumer_cursor
-
base_cursor
resp
=
self
.
_output_buf
[
data_idx
]
resp
=
self
.
_output_buf
[
data_idx
]
_LOGGER
.
debug
(
self
.
_log
(
"{} get data: {}"
.
format
(
op_name
,
resp
)))
self
.
_cursor_count
[
consumer_cursor
]
-=
1
self
.
_cursor_count
[
consumer_cursor
]
-=
1
if
consumer_cursor
==
base_cursor
and
self
.
_cursor_count
[
if
consumer_cursor
==
base_cursor
and
self
.
_cursor_count
[
...
@@ -423,6 +424,7 @@ class ProcessChannel(object):
...
@@ -423,6 +424,7 @@ class ProcessChannel(object):
self
.
_base_cursor
.
value
+=
1
self
.
_base_cursor
.
value
+=
1
# to avoid cursor overflow
# to avoid cursor overflow
if
self
.
_base_cursor
.
value
>=
self
.
_reset_max_cursor
:
if
self
.
_base_cursor
.
value
>=
self
.
_reset_max_cursor
:
_LOGGER
.
info
(
self
.
_log
(
"reset cursor in Channel"
))
self
.
_base_cursor
.
value
-=
self
.
_reset_max_cursor
self
.
_base_cursor
.
value
-=
self
.
_reset_max_cursor
for
name
in
self
.
_consumer_cursors
.
keys
():
for
name
in
self
.
_consumer_cursors
.
keys
():
self
.
_consumer_cursors
[
name
]
-=
self
.
_reset_max_cursor
self
.
_consumer_cursors
[
name
]
-=
self
.
_reset_max_cursor
...
@@ -440,16 +442,12 @@ class ProcessChannel(object):
...
@@ -440,16 +442,12 @@ class ProcessChannel(object):
self
.
_cursor_count
[
new_consumer_cursor
]
=
0
self
.
_cursor_count
[
new_consumer_cursor
]
=
0
self
.
_cursor_count
[
new_consumer_cursor
]
+=
1
self
.
_cursor_count
[
new_consumer_cursor
]
+=
1
_LOGGER
.
debug
(
self
.
_log
(
"({}) A self._consumer_cursors: {}, self._base_cursor: {}, len(self._output_buf): {}"
.
format
(
op_name
,
self
.
_consumer_cursors
,
self
.
_base_cursor
.
value
,
len
(
self
.
_output_buf
))))
_LOGGER
.
debug
(
self
.
_log
(
"{} notify all"
.
format
(
op_name
)))
self
.
_cv
.
notify_all
()
self
.
_cv
.
notify_all
()
_LOGGER
.
debug
(
self
.
_log
(
"multi | {} get data succ!"
.
format
(
op_name
)))
_LOGGER
.
debug
(
return
resp
# reference, read only
self
.
_log
(
"{} succ get data[{}] from output_buffer"
.
format
(
op_name
,
resp
.
values
()[
0
].
id
)))
return
resp
def
stop
(
self
):
def
stop
(
self
):
_LOGGER
.
debug
(
self
.
_log
(
"stop."
))
_LOGGER
.
debug
(
self
.
_log
(
"stop."
))
...
@@ -538,7 +536,8 @@ class ThreadChannel(Queue.Queue):
...
@@ -538,7 +536,8 @@ class ThreadChannel(Queue.Queue):
def
push
(
self
,
channeldata
,
op_name
=
None
):
def
push
(
self
,
channeldata
,
op_name
=
None
):
_LOGGER
.
debug
(
_LOGGER
.
debug
(
self
.
_log
(
"{} try to push data[{}]"
.
format
(
op_name
,
channeldata
.
id
)))
self
.
_log
(
"{} try to push data[{}]"
.
format
(
op_name
,
channeldata
.
id
)))
if
len
(
self
.
_producers
)
==
0
:
if
len
(
self
.
_producers
)
==
0
:
raise
Exception
(
raise
Exception
(
self
.
_log
(
self
.
_log
(
...
@@ -556,9 +555,8 @@ class ThreadChannel(Queue.Queue):
...
@@ -556,9 +555,8 @@ class ThreadChannel(Queue.Queue):
raise
ChannelStopError
()
raise
ChannelStopError
()
self
.
_cv
.
notify_all
()
self
.
_cv
.
notify_all
()
_LOGGER
.
debug
(
_LOGGER
.
debug
(
self
.
_log
(
self
.
_log
(
"{} succ push data[{}] into internal queue."
.
format
(
"{} succ push data[{}] into internal queue."
.
format
(
op_name
,
channeldata
.
id
)))
op_name
,
channeldata
.
id
)))
return
True
return
True
elif
op_name
is
None
:
elif
op_name
is
None
:
raise
Exception
(
raise
Exception
(
...
@@ -585,9 +583,8 @@ class ThreadChannel(Queue.Queue):
...
@@ -585,9 +583,8 @@ class ThreadChannel(Queue.Queue):
if
put_data
is
None
:
if
put_data
is
None
:
_LOGGER
.
debug
(
_LOGGER
.
debug
(
self
.
_log
(
self
.
_log
(
"{} succ push data[{}] into input_buffer."
.
format
(
"{} succ push data[{}] into input_buffer."
.
format
(
op_name
,
data_id
)))
op_name
,
data_id
)))
else
:
else
:
while
self
.
_stop
is
False
:
while
self
.
_stop
is
False
:
try
:
try
:
...
@@ -599,17 +596,15 @@ class ThreadChannel(Queue.Queue):
...
@@ -599,17 +596,15 @@ class ThreadChannel(Queue.Queue):
raise
ChannelStopError
()
raise
ChannelStopError
()
_LOGGER
.
debug
(
_LOGGER
.
debug
(
self
.
_log
(
self
.
_log
(
"{} succ push data[{}] into internal queue."
.
"{} succ push data[{}] into internal queue."
.
format
(
format
(
op_name
,
data_id
)))
op_name
,
data_id
)))
self
.
_cv
.
notify_all
()
self
.
_cv
.
notify_all
()
return
True
return
True
def
front
(
self
,
op_name
=
None
,
timeout
=
None
):
def
front
(
self
,
op_name
=
None
,
timeout
=
None
):
_LOGGER
.
debug
(
_LOGGER
.
debug
(
self
.
_log
(
self
.
_log
(
"{} try to get data[?]; timeout={}"
.
format
(
op_name
,
"{} try to get data[?]; timeout={}"
.
format
(
timeout
)))
op_name
,
timeout
)))
endtime
=
None
endtime
=
None
if
timeout
is
not
None
:
if
timeout
is
not
None
:
if
timeout
<=
0
:
if
timeout
<=
0
:
...
@@ -634,8 +629,8 @@ class ThreadChannel(Queue.Queue):
...
@@ -634,8 +629,8 @@ class ThreadChannel(Queue.Queue):
remaining
=
endtime
-
_time
()
remaining
=
endtime
-
_time
()
if
remaining
<=
0.0
:
if
remaining
<=
0.0
:
_LOGGER
.
debug
(
_LOGGER
.
debug
(
self
.
_log
(
self
.
_log
(
"{} get data[?] timeout"
.
format
(
"{} get data[?] timeout"
.
format
(
op_name
)))
op_name
)))
raise
ChannelTimeoutError
()
raise
ChannelTimeoutError
()
self
.
_cv
.
wait
(
remaining
)
self
.
_cv
.
wait
(
remaining
)
else
:
else
:
...
@@ -643,8 +638,8 @@ class ThreadChannel(Queue.Queue):
...
@@ -643,8 +638,8 @@ class ThreadChannel(Queue.Queue):
if
self
.
_stop
:
if
self
.
_stop
:
raise
ChannelStopError
()
raise
ChannelStopError
()
_LOGGER
.
debug
(
_LOGGER
.
debug
(
self
.
_log
(
"{} succ get data[{}]"
.
format
(
self
.
_log
(
"{} succ get data[{}]"
.
format
(
op_name
,
op_name
,
resp
.
values
()[
0
].
id
)))
resp
.
values
()[
0
].
id
)))
return
resp
return
resp
elif
op_name
is
None
:
elif
op_name
is
None
:
raise
Exception
(
raise
Exception
(
...
@@ -670,17 +665,16 @@ class ThreadChannel(Queue.Queue):
...
@@ -670,17 +665,16 @@ class ThreadChannel(Queue.Queue):
channeldata
=
self
.
get
(
timeout
=
0
)
channeldata
=
self
.
get
(
timeout
=
0
)
self
.
_output_buf
.
append
(
channeldata
)
self
.
_output_buf
.
append
(
channeldata
)
_LOGGER
.
debug
(
_LOGGER
.
debug
(
self
.
_log
(
self
.
_log
(
"pop ready item[{}] into output_buffer"
.
"pop ready item[{}] into output_buffer"
.
format
(
format
(
channeldata
.
values
()[
0
].
id
)))
channeldata
.
values
()[
0
].
id
)))
break
break
except
Queue
.
Empty
:
except
Queue
.
Empty
:
if
timeout
is
not
None
:
if
timeout
is
not
None
:
remaining
=
endtime
-
_time
()
remaining
=
endtime
-
_time
()
if
remaining
<=
0.0
:
if
remaining
<=
0.0
:
_LOGGER
.
debug
(
_LOGGER
.
debug
(
self
.
_log
(
self
.
_log
(
"{} get data[?] timeout"
.
format
(
"{} get data[?] timeout"
.
format
(
op_name
)))
op_name
)))
raise
ChannelTimeoutError
()
raise
ChannelTimeoutError
()
self
.
_cv
.
wait
(
remaining
)
self
.
_cv
.
wait
(
remaining
)
else
:
else
:
...
@@ -704,8 +698,7 @@ class ThreadChannel(Queue.Queue):
...
@@ -704,8 +698,7 @@ class ThreadChannel(Queue.Queue):
self
.
_base_cursor
+=
1
self
.
_base_cursor
+=
1
# to avoid cursor overflow
# to avoid cursor overflow
if
self
.
_base_cursor
>=
self
.
_reset_max_cursor
:
if
self
.
_base_cursor
>=
self
.
_reset_max_cursor
:
_LOGGER
.
info
(
_LOGGER
.
info
(
self
.
_log
(
"reset cursor in Channel"
))
self
.
_log
(
"reset cursor in Channel"
))
self
.
_base_cursor
-=
self
.
_reset_max_cursor
self
.
_base_cursor
-=
self
.
_reset_max_cursor
for
name
in
self
.
_consumer_cursors
:
for
name
in
self
.
_consumer_cursors
:
self
.
_consumer_cursors
[
name
]
-=
self
.
_reset_max_cursor
self
.
_consumer_cursors
[
name
]
-=
self
.
_reset_max_cursor
...
@@ -725,9 +718,8 @@ class ThreadChannel(Queue.Queue):
...
@@ -725,9 +718,8 @@ class ThreadChannel(Queue.Queue):
self
.
_cv
.
notify_all
()
self
.
_cv
.
notify_all
()
_LOGGER
.
debug
(
_LOGGER
.
debug
(
self
.
_log
(
self
.
_log
(
"{} succ get data[{}] from output_buffer"
.
format
(
"{} succ get data[{}] from output_buffer"
.
format
(
op_name
,
resp
.
values
()[
0
].
id
)))
op_name
,
resp
.
values
()[
0
].
id
)))
return
resp
return
resp
def
stop
(
self
):
def
stop
(
self
):
...
@@ -736,10 +728,12 @@ class ThreadChannel(Queue.Queue):
...
@@ -736,10 +728,12 @@ class ThreadChannel(Queue.Queue):
with
self
.
_cv
:
with
self
.
_cv
:
self
.
_cv
.
notify_all
()
self
.
_cv
.
notify_all
()
class
ChannelTimeoutError
(
RuntimeError
):
class
ChannelTimeoutError
(
RuntimeError
):
def
__init__
(
self
):
def
__init__
(
self
):
pass
pass
class
ChannelStopError
(
RuntimeError
):
class
ChannelStopError
(
RuntimeError
):
def
__init__
(
self
):
def
__init__
(
self
):
pass
pass
python/pipeline/dag.py
浏览文件 @
4049eb8a
...
@@ -47,7 +47,7 @@ class DAGExecutor(object):
...
@@ -47,7 +47,7 @@ class DAGExecutor(object):
for
key
,
val
in
default_conf
.
items
():
for
key
,
val
in
default_conf
.
items
():
if
dag_config
.
get
(
key
)
is
None
:
if
dag_config
.
get
(
key
)
is
None
:
_LOGGER
.
warning
(
"[CONF] {} not set, use default: {}"
_LOGGER
.
warning
(
"[CONF] {} not set, use default: {}"
.
format
(
key
,
val
))
.
format
(
key
,
val
))
dag_config
[
key
]
=
val
dag_config
[
key
]
=
val
self
.
_retry
=
dag_config
[
"retry"
]
self
.
_retry
=
dag_config
[
"retry"
]
...
@@ -60,7 +60,7 @@ class DAGExecutor(object):
...
@@ -60,7 +60,7 @@ class DAGExecutor(object):
if
show_info
:
if
show_info
:
_LOGGER
.
info
(
"=============== DAGExecutor ==============="
)
_LOGGER
.
info
(
"=============== DAGExecutor ==============="
)
for
key
in
default_conf
.
keys
():
for
key
in
default_conf
.
keys
():
_LOGGER
.
info
(
"{}: {}"
.
format
(
key
,
dag_config
[
key
]))
_LOGGER
.
info
(
"{}: {}"
.
format
(
key
,
dag_config
[
key
]))
_LOGGER
.
info
(
"-------------------------------------------"
)
_LOGGER
.
info
(
"-------------------------------------------"
)
self
.
name
=
"@G"
self
.
name
=
"@G"
...
@@ -110,17 +110,15 @@ class DAGExecutor(object):
...
@@ -110,17 +110,15 @@ class DAGExecutor(object):
def
_set_in_channel
(
self
,
in_channel
):
def
_set_in_channel
(
self
,
in_channel
):
if
not
isinstance
(
in_channel
,
(
ThreadChannel
,
ProcessChannel
)):
if
not
isinstance
(
in_channel
,
(
ThreadChannel
,
ProcessChannel
)):
raise
TypeError
(
raise
TypeError
(
"in_channel must be Channel type, but get {}"
.
"in_channel must be Channel type, but get {}"
.
format
(
format
(
type
(
in_channel
)))
type
(
in_channel
)))
in_channel
.
add_producer
(
self
.
name
)
in_channel
.
add_producer
(
self
.
name
)
self
.
_in_channel
=
in_channel
self
.
_in_channel
=
in_channel
def
_set_out_channel
(
self
,
out_channel
):
def
_set_out_channel
(
self
,
out_channel
):
if
not
isinstance
(
out_channel
,
(
ThreadChannel
,
ProcessChannel
)):
if
not
isinstance
(
out_channel
,
(
ThreadChannel
,
ProcessChannel
)):
raise
TypeError
(
raise
TypeError
(
"iout_channel must be Channel type, but get {}"
.
"iout_channel must be Channel type, but get {}"
.
format
(
format
(
type
(
out_channel
)))
type
(
out_channel
)))
out_channel
.
add_consumer
(
self
.
name
)
out_channel
.
add_consumer
(
self
.
name
)
self
.
_out_channel
=
out_channel
self
.
_out_channel
=
out_channel
...
@@ -143,12 +141,14 @@ class DAGExecutor(object):
...
@@ -143,12 +141,14 @@ class DAGExecutor(object):
break
break
if
len
(
channeldata_dict
)
!=
1
:
if
len
(
channeldata_dict
)
!=
1
:
_LOGGER
.
error
(
"[DAG Executor] out_channel cannot have multiple input ops"
)
_LOGGER
.
error
(
"[DAG Executor] out_channel cannot have multiple input ops"
)
os
.
_exit
(
-
1
)
os
.
_exit
(
-
1
)
(
_
,
channeldata
),
=
channeldata_dict
.
items
()
(
_
,
channeldata
),
=
channeldata_dict
.
items
()
if
not
isinstance
(
channeldata
,
ChannelData
):
if
not
isinstance
(
channeldata
,
ChannelData
):
_LOGGER
.
error
(
'[DAG Executor] data must be ChannelData type, but get {}'
_LOGGER
.
error
(
.
format
(
type
(
channeldata
)))
'[DAG Executor] data must be ChannelData type, but get {}'
.
format
(
type
(
channeldata
)))
os
.
_exit
(
-
1
)
os
.
_exit
(
-
1
)
data_id
=
channeldata
.
id
data_id
=
channeldata
.
id
...
@@ -178,7 +178,7 @@ class DAGExecutor(object):
...
@@ -178,7 +178,7 @@ class DAGExecutor(object):
dictdata
=
self
.
_unpack_rpc_func
(
rpc_request
)
dictdata
=
self
.
_unpack_rpc_func
(
rpc_request
)
except
Exception
as
e
:
except
Exception
as
e
:
_LOGGER
.
error
(
"parse RPC package to data[{}] Error: {}"
_LOGGER
.
error
(
"parse RPC package to data[{}] Error: {}"
.
format
(
data_id
,
e
))
.
format
(
data_id
,
e
))
return
ChannelData
(
return
ChannelData
(
ecode
=
ChannelDataEcode
.
RPC_PACKAGE_ERROR
.
value
,
ecode
=
ChannelDataEcode
.
RPC_PACKAGE_ERROR
.
value
,
error_info
=
"rpc package error: {}"
.
format
(
e
),
error_info
=
"rpc package error: {}"
.
format
(
e
),
...
@@ -192,7 +192,8 @@ class DAGExecutor(object):
...
@@ -192,7 +192,8 @@ class DAGExecutor(object):
profile_value
=
rpc_request
.
value
[
idx
]
profile_value
=
rpc_request
.
value
[
idx
]
break
break
client_need_profile
=
(
profile_value
==
self
.
_client_profile_value
)
client_need_profile
=
(
profile_value
==
self
.
_client_profile_value
)
_LOGGER
.
debug
(
"request[{}] need profile: {}"
.
format
(
data_id
,
client_need_profile
))
_LOGGER
.
debug
(
"request[{}] need profile: {}"
.
format
(
data_id
,
client_need_profile
))
return
ChannelData
(
return
ChannelData
(
datatype
=
ChannelDataType
.
DICT
.
value
,
datatype
=
ChannelDataType
.
DICT
.
value
,
dictdata
=
dictdata
,
dictdata
=
dictdata
,
...
@@ -208,7 +209,8 @@ class DAGExecutor(object):
...
@@ -208,7 +209,8 @@ class DAGExecutor(object):
else
:
else
:
self
.
_profiler
.
record
(
"call_{}#DAG_0"
.
format
(
data_id
))
self
.
_profiler
.
record
(
"call_{}#DAG_0"
.
format
(
data_id
))
_LOGGER
.
debug
(
"try parse RPC package to channeldata[{}]"
.
format
(
data_id
))
_LOGGER
.
debug
(
"try parse RPC package to channeldata[{}]"
.
format
(
data_id
))
self
.
_profiler
.
record
(
"prepack_{}#{}_0"
.
format
(
data_id
,
self
.
name
))
self
.
_profiler
.
record
(
"prepack_{}#{}_0"
.
format
(
data_id
,
self
.
name
))
req_channeldata
=
self
.
_pack_channeldata
(
rpc_request
,
data_id
)
req_channeldata
=
self
.
_pack_channeldata
(
rpc_request
,
data_id
)
self
.
_profiler
.
record
(
"prepack_{}#{}_1"
.
format
(
data_id
,
self
.
name
))
self
.
_profiler
.
record
(
"prepack_{}#{}_1"
.
format
(
data_id
,
self
.
name
))
...
@@ -226,23 +228,26 @@ class DAGExecutor(object):
...
@@ -226,23 +228,26 @@ class DAGExecutor(object):
error_info
=
"dag closed."
,
error_info
=
"dag closed."
,
data_id
=
data_id
))
data_id
=
data_id
))
_LOGGER
.
debug
(
"wait for Graph engine for data[{}]..."
.
format
(
data_id
))
_LOGGER
.
debug
(
"wait for Graph engine for data[{}]..."
.
format
(
data_id
))
resp_channeldata
=
self
.
_get_channeldata_from_fetch_buffer
(
data_id
)
resp_channeldata
=
self
.
_get_channeldata_from_fetch_buffer
(
data_id
)
if
resp_channeldata
.
ecode
==
ChannelDataEcode
.
OK
.
value
:
if
resp_channeldata
.
ecode
==
ChannelDataEcode
.
OK
.
value
:
_LOGGER
.
debug
(
"Graph engine predict data[{}] succ"
.
format
(
data_id
))
_LOGGER
.
debug
(
"Graph engine predict data[{}] succ"
.
format
(
data_id
))
break
break
else
:
else
:
_LOGGER
.
warn
(
"Graph engine predict data[{}] failed: {}"
_LOGGER
.
warn
(
"Graph engine predict data[{}] failed: {}"
.
format
(
data_id
,
resp_channeldata
.
error_info
))
.
format
(
data_id
,
resp_channeldata
.
error_info
))
if
resp_channeldata
.
ecode
!=
ChannelDataEcode
.
TIMEOUT
.
value
:
if
resp_channeldata
.
ecode
!=
ChannelDataEcode
.
TIMEOUT
.
value
:
break
break
if
i
+
1
<
self
.
_retry
:
if
i
+
1
<
self
.
_retry
:
_LOGGER
.
warn
(
"retry({}/{}) data[{}]"
.
format
(
_LOGGER
.
warn
(
"retry({}/{}) data[{}]"
.
format
(
i
+
1
,
self
.
_retry
,
i
+
1
,
self
.
_retry
,
data_id
))
data_id
))
_LOGGER
.
debug
(
"unpack channeldata[{}] into RPC resp package"
.
format
(
data_id
))
_LOGGER
.
debug
(
"unpack channeldata[{}] into RPC resp package"
.
format
(
data_id
))
self
.
_profiler
.
record
(
"postpack_{}#{}_0"
.
format
(
data_id
,
self
.
name
))
self
.
_profiler
.
record
(
"postpack_{}#{}_0"
.
format
(
data_id
,
self
.
name
))
rpc_resp
=
self
.
_pack_for_rpc_resp
(
resp_channeldata
)
rpc_resp
=
self
.
_pack_for_rpc_resp
(
resp_channeldata
)
self
.
_profiler
.
record
(
"postpack_{}#{}_1"
.
format
(
data_id
,
self
.
name
))
self
.
_profiler
.
record
(
"postpack_{}#{}_1"
.
format
(
data_id
,
self
.
name
))
...
@@ -380,8 +385,8 @@ class DAG(object):
...
@@ -380,8 +385,8 @@ class DAG(object):
)
)
if
self
.
_build_dag_each_worker
:
if
self
.
_build_dag_each_worker
:
_LOGGER
.
info
(
"Because `build_dag_each_worker` mode is used, "
_LOGGER
.
info
(
"Because `build_dag_each_worker` mode is used, "
"Auto-batching is set to the default config: "
"Auto-batching is set to the default config: "
"batch_size=1, auto_batching_timeout=None"
)
"batch_size=1, auto_batching_timeout=None"
)
for
op
in
used_ops
:
for
op
in
used_ops
:
op
.
use_default_auto_batching_config
()
op
.
use_default_auto_batching_config
()
...
@@ -439,7 +444,7 @@ class DAG(object):
...
@@ -439,7 +444,7 @@ class DAG(object):
channel
=
self
.
_gen_channel
(
channel_name_gen
)
channel
=
self
.
_gen_channel
(
channel_name_gen
)
channels
.
append
(
channel
)
channels
.
append
(
channel
)
_LOGGER
.
debug
(
"[DAG] Channel({}) => Op({})"
_LOGGER
.
debug
(
"[DAG] Channel({}) => Op({})"
.
format
(
channel
.
name
,
op
.
name
))
.
format
(
channel
.
name
,
op
.
name
))
op
.
add_input_channel
(
channel
)
op
.
add_input_channel
(
channel
)
pred_ops
=
pred_op_of_next_view_op
[
op
.
name
]
pred_ops
=
pred_op_of_next_view_op
[
op
.
name
]
if
v_idx
==
0
:
if
v_idx
==
0
:
...
@@ -448,7 +453,7 @@ class DAG(object):
...
@@ -448,7 +453,7 @@ class DAG(object):
# if pred_op is virtual op, it will use ancestors as producers to channel
# if pred_op is virtual op, it will use ancestors as producers to channel
for
pred_op
in
pred_ops
:
for
pred_op
in
pred_ops
:
_LOGGER
.
debug
(
"[DAG] Op({}) => Channel({})"
_LOGGER
.
debug
(
"[DAG] Op({}) => Channel({})"
.
format
(
pred_op
.
name
,
channel
.
name
))
.
format
(
pred_op
.
name
,
channel
.
name
))
pred_op
.
add_output_channel
(
channel
)
pred_op
.
add_output_channel
(
channel
)
processed_op
.
add
(
op
.
name
)
processed_op
.
add
(
op
.
name
)
# find same input op to combine channel
# find same input op to combine channel
...
@@ -465,7 +470,7 @@ class DAG(object):
...
@@ -465,7 +470,7 @@ class DAG(object):
break
break
if
same_flag
:
if
same_flag
:
_LOGGER
.
debug
(
"[DAG] Channel({}) => Op({})"
_LOGGER
.
debug
(
"[DAG] Channel({}) => Op({})"
.
format
(
channel
.
name
,
other_op
.
name
))
.
format
(
channel
.
name
,
other_op
.
name
))
other_op
.
add_input_channel
(
channel
)
other_op
.
add_input_channel
(
channel
)
processed_op
.
add
(
other_op
.
name
)
processed_op
.
add
(
other_op
.
name
)
output_channel
=
self
.
_gen_channel
(
channel_name_gen
)
output_channel
=
self
.
_gen_channel
(
channel_name_gen
)
...
@@ -484,7 +489,7 @@ class DAG(object):
...
@@ -484,7 +489,7 @@ class DAG(object):
for
c
in
channels
:
for
c
in
channels
:
_LOGGER
.
debug
(
"Channel({}):
\n
-producers: {}
\n
-consumers: {}"
_LOGGER
.
debug
(
"Channel({}):
\n
-producers: {}
\n
-consumers: {}"
.
format
(
c
.
name
,
c
.
get_producers
(),
c
.
get_consumers
()))
.
format
(
c
.
name
,
c
.
get_producers
(),
c
.
get_consumers
()))
return
(
actual_ops
,
channels
,
input_channel
,
output_channel
,
pack_func
,
return
(
actual_ops
,
channels
,
input_channel
,
output_channel
,
pack_func
,
unpack_func
)
unpack_func
)
...
...
python/pipeline/operator.py
浏览文件 @
4049eb8a
...
@@ -81,14 +81,12 @@ class Op(object):
...
@@ -81,14 +81,12 @@ class Op(object):
def
use_default_auto_batching_config
(
self
):
def
use_default_auto_batching_config
(
self
):
if
self
.
_batch_size
!=
1
:
if
self
.
_batch_size
!=
1
:
_LOGGER
.
warn
(
_LOGGER
.
warn
(
"Op({}) reset batch_size=1 (original: {})"
"Op({}) reset batch_size=1 (original: {})"
.
format
(
self
.
name
,
self
.
_batch_size
))
.
format
(
self
.
name
,
self
.
_batch_size
))
self
.
_batch_size
=
1
self
.
_batch_size
=
1
if
self
.
_auto_batching_timeout
!=
None
:
if
self
.
_auto_batching_timeout
!=
None
:
_LOGGER
.
warn
(
_LOGGER
.
warn
(
"Op({}) reset auto_batching_timeout=1 (original: {})"
"Op({}) reset auto_batching_timeout=1 (original: {})"
.
format
(
self
.
name
,
self
.
_auto_batching_timeout
))
.
format
(
self
.
name
,
self
.
_auto_batching_timeout
))
self
.
_auto_batching_timeout
=
None
self
.
_auto_batching_timeout
=
None
def
use_profiler
(
self
,
use_profile
):
def
use_profiler
(
self
,
use_profile
):
...
@@ -104,10 +102,12 @@ class Op(object):
...
@@ -104,10 +102,12 @@ class Op(object):
if
self
.
with_serving
==
False
:
if
self
.
with_serving
==
False
:
_LOGGER
.
info
(
"Op({}) no client"
.
format
(
self
.
name
))
_LOGGER
.
info
(
"Op({}) no client"
.
format
(
self
.
name
))
return
None
return
None
_LOGGER
.
info
(
"Op({}) service endpoints: {}"
.
format
(
self
.
name
,
server_endpoints
))
_LOGGER
.
info
(
"Op({}) service endpoints: {}"
.
format
(
self
.
name
,
server_endpoints
))
_LOGGER
.
debug
(
"Op({}) fetch_names: {}"
.
format
(
self
.
name
,
fetch_names
))
_LOGGER
.
debug
(
"Op({}) fetch_names: {}"
.
format
(
self
.
name
,
fetch_names
))
if
client_type
==
'brpc'
:
if
client_type
==
'brpc'
:
_LOGGER
.
debug
(
"Op({}) client_config: {}"
.
format
(
self
.
name
,
client_config
))
_LOGGER
.
debug
(
"Op({}) client_config: {}"
.
format
(
self
.
name
,
client_config
))
client
=
Client
()
client
=
Client
()
client
.
load_client_config
(
client_config
)
client
.
load_client_config
(
client_config
)
elif
client_type
==
'grpc'
:
elif
client_type
==
'grpc'
:
...
@@ -259,27 +259,24 @@ class Op(object):
...
@@ -259,27 +259,24 @@ class Op(object):
preped_data
=
self
.
preprocess
(
parsed_data
)
preped_data
=
self
.
preprocess
(
parsed_data
)
except
NotImplementedError
as
e
:
except
NotImplementedError
as
e
:
# preprocess function not implemented
# preprocess function not implemented
error_info
=
log_func
(
error_info
=
log_func
(
"preprocess data[{}] failed: {}"
.
format
(
"preprocess data[{}] failed: {}"
.
format
(
data_id
,
e
))
data_id
,
e
))
error_channeldata
=
ChannelData
(
error_channeldata
=
ChannelData
(
ecode
=
ChannelDataEcode
.
NOT_IMPLEMENTED
.
value
,
ecode
=
ChannelDataEcode
.
NOT_IMPLEMENTED
.
value
,
error_info
=
error_info
,
error_info
=
error_info
,
data_id
=
data_id
)
data_id
=
data_id
)
except
TypeError
as
e
:
except
TypeError
as
e
:
# Error type in channeldata.datatype
# Error type in channeldata.datatype
error_info
=
log_func
(
error_info
=
log_func
(
"preprocess data[{}] failed: {}"
.
format
(
"preprocess data[{}] failed: {}"
.
format
(
data_id
,
e
))
data_id
,
e
))
_LOGGER
.
error
(
error_info
)
_LOGGER
.
error
(
error_info
)
error_channeldata
=
ChannelData
(
error_channeldata
=
ChannelData
(
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
,
ecode
=
ChannelDataEcode
.
TYPE_ERROR
.
value
,
error_info
=
error_info
,
error_info
=
error_info
,
data_id
=
data_id
)
data_id
=
data_id
)
except
Exception
as
e
:
except
Exception
as
e
:
error_info
=
log_func
(
error_info
=
log_func
(
"preprocess data[{}] failed: {}"
.
format
(
"preprocess data[{}] failed: {}"
.
format
(
data_id
,
e
))
data_id
,
e
))
_LOGGER
.
error
(
error_info
)
_LOGGER
.
error
(
error_info
)
error_channeldata
=
ChannelData
(
error_channeldata
=
ChannelData
(
ecode
=
ChannelDataEcode
.
UNKNOW
.
value
,
ecode
=
ChannelDataEcode
.
UNKNOW
.
value
,
...
@@ -321,10 +318,11 @@ class Op(object):
...
@@ -321,10 +318,11 @@ class Op(object):
else
:
else
:
_LOGGER
.
warn
(
_LOGGER
.
warn
(
log_func
(
"timeout, retry({}/{})"
log_func
(
"timeout, retry({}/{})"
.
format
(
i
+
1
,
self
.
_retry
)))
.
format
(
i
+
1
,
self
.
_retry
)))
except
Exception
as
e
:
except
Exception
as
e
:
ecode
=
ChannelDataEcode
.
UNKNOW
.
value
ecode
=
ChannelDataEcode
.
UNKNOW
.
value
error_info
=
log_func
(
"process batch failed: {}"
.
format
(
e
))
error_info
=
log_func
(
"process batch failed: {}"
.
format
(
e
))
_LOGGER
.
error
(
error_info
)
_LOGGER
.
error
(
error_info
)
break
break
else
:
else
:
...
@@ -332,24 +330,23 @@ class Op(object):
...
@@ -332,24 +330,23 @@ class Op(object):
if
ecode
!=
ChannelDataEcode
.
OK
.
value
:
if
ecode
!=
ChannelDataEcode
.
OK
.
value
:
for
data_id
in
data_ids
:
for
data_id
in
data_ids
:
err_channeldata_dict
[
data_id
]
=
ChannelData
(
err_channeldata_dict
[
data_id
]
=
ChannelData
(
ecode
=
ecode
,
ecode
=
ecode
,
error_info
=
error_info
,
data_id
=
data_id
)
error_info
=
error_info
,
data_id
=
data_id
)
elif
midped_batch
is
None
:
elif
midped_batch
is
None
:
# op client return None
# op client return None
error_info
=
log_func
(
error_info
=
log_func
(
"predict failed. pls check the server side."
)
"predict failed. pls check the server side."
)
_LOGGER
.
error
(
error_info
)
_LOGGER
.
error
(
error_info
)
for
data_id
in
data_ids
:
for
data_id
in
data_ids
:
err_channeldata_dict
[
data_id
]
=
ChannelData
(
err_channeldata_dict
[
data_id
]
=
ChannelData
(
ecode
=
ChannelDataEcode
.
CLIENT_ERROR
.
value
,
ecode
=
ChannelDataEcode
.
CLIENT_ERROR
.
value
,
error_info
=
error_info
,
error_info
=
error_info
,
data_id
=
data_id
)
data_id
=
data_id
)
else
:
else
:
# transform np format to dict format
# transform np format to dict format
for
idx
,
data_id
in
enumerate
(
data_ids
):
for
idx
,
data_id
in
enumerate
(
data_ids
):
midped_data_dict
[
data_id
]
=
{
midped_data_dict
[
data_id
]
=
{
k
:
v
[
idx
]
for
k
,
v
in
midped_batch
.
items
()
k
:
v
[
idx
]
for
k
,
v
in
midped_batch
.
items
()
}
}
else
:
else
:
midped_data_dict
=
preped_data_dict
midped_data_dict
=
preped_data_dict
...
@@ -363,11 +360,11 @@ class Op(object):
...
@@ -363,11 +360,11 @@ class Op(object):
for
data_id
,
midped_data
in
midped_data_dict
.
items
():
for
data_id
,
midped_data
in
midped_data_dict
.
items
():
postped_data
,
err_channeldata
=
None
,
None
postped_data
,
err_channeldata
=
None
,
None
try
:
try
:
postped_data
=
self
.
postprocess
(
postped_data
=
self
.
postprocess
(
parsed_data_dict
[
data_id
],
parsed_data_dict
[
data_id
],
midped_data
)
midped_data
)
except
Exception
as
e
:
except
Exception
as
e
:
error_info
=
log_func
(
"postprocess data[{}] failed: {}"
error_info
=
log_func
(
"postprocess data[{}] failed: {}"
.
format
(
data_id
,
e
))
.
format
(
data_id
,
e
))
_LOGGER
.
error
(
error_info
)
_LOGGER
.
error
(
error_info
)
err_channeldata
=
ChannelData
(
err_channeldata
=
ChannelData
(
ecode
=
ChannelDataEcode
.
UNKNOW
.
value
,
ecode
=
ChannelDataEcode
.
UNKNOW
.
value
,
...
@@ -403,15 +400,14 @@ class Op(object):
...
@@ -403,15 +400,14 @@ class Op(object):
postped_data_dict
[
data_id
]
=
output_data
postped_data_dict
[
data_id
]
=
output_data
_LOGGER
.
debug
(
log_func
(
"succ run postprocess"
))
_LOGGER
.
debug
(
log_func
(
"succ run postprocess"
))
return
postped_data_dict
,
err_channeldata_dict
return
postped_data_dict
,
err_channeldata_dict
def
_auto_batching_generator
(
self
,
input_channel
,
op_name
,
def
_auto_batching_generator
(
self
,
input_channel
,
op_name
,
batch_size
,
batch_size
,
timeout
,
log_func
):
timeout
,
log_func
):
while
True
:
while
True
:
batch
=
[]
batch
=
[]
_LOGGER
.
debug
(
_LOGGER
.
debug
(
log_func
(
log_func
(
"Auto-batching expect size: {}; timeout: {}"
.
format
(
"Auto-batching expect size: {}; timeout: {}"
.
format
(
batch_size
,
timeout
)))
batch_size
,
timeout
)))
while
len
(
batch
)
==
0
:
while
len
(
batch
)
==
0
:
endtime
=
None
endtime
=
None
if
timeout
is
not
None
:
if
timeout
is
not
None
:
...
@@ -424,14 +420,16 @@ class Op(object):
...
@@ -424,14 +420,16 @@ class Op(object):
if
remaining
<=
0.0
:
if
remaining
<=
0.0
:
_LOGGER
.
debug
(
log_func
(
"Auto-batching timeout"
))
_LOGGER
.
debug
(
log_func
(
"Auto-batching timeout"
))
break
break
channeldata_dict
=
input_channel
.
front
(
op_name
,
timeout
)
channeldata_dict
=
input_channel
.
front
(
op_name
,
timeout
)
else
:
else
:
channeldata_dict
=
input_channel
.
front
(
op_name
)
channeldata_dict
=
input_channel
.
front
(
op_name
)
batch
.
append
(
channeldata_dict
)
batch
.
append
(
channeldata_dict
)
except
ChannelTimeoutError
:
except
ChannelTimeoutError
:
_LOGGER
.
debug
(
log_func
(
"Auto-batching timeout"
))
_LOGGER
.
debug
(
log_func
(
"Auto-batching timeout"
))
break
break
_LOGGER
.
debug
(
log_func
(
"Auto-batching actual size: {}"
.
format
(
len
(
batch
))))
_LOGGER
.
debug
(
log_func
(
"Auto-batching actual size: {}"
.
format
(
len
(
batch
))))
yield
batch
yield
batch
def
_parse_channeldata_batch
(
self
,
batch
,
output_channels
):
def
_parse_channeldata_batch
(
self
,
batch
,
output_channels
):
...
@@ -449,16 +447,17 @@ class Op(object):
...
@@ -449,16 +447,17 @@ class Op(object):
else
:
else
:
# error data in predecessor Op
# error data in predecessor Op
# (error_channeldata with profile info)
# (error_channeldata with profile info)
self
.
_push_to_output_channels
(
self
.
_push_to_output_channels
(
error_channeldata
,
error_channeldata
,
output_channels
)
output_channels
)
return
parsed_data_dict
,
need_profile_dict
,
profile_dict
return
parsed_data_dict
,
need_profile_dict
,
profile_dict
def
_run
(
self
,
concurrency_idx
,
input_channel
,
output_channels
,
def
_run
(
self
,
concurrency_idx
,
input_channel
,
output_channels
,
client_type
,
client_type
,
is_thread_op
):
is_thread_op
):
def
get_log_func
(
op_info_prefix
):
def
get_log_func
(
op_info_prefix
):
def
log_func
(
info_str
):
def
log_func
(
info_str
):
return
"{} {}"
.
format
(
op_info_prefix
,
info_str
)
return
"{} {}"
.
format
(
op_info_prefix
,
info_str
)
return
log_func
return
log_func
op_info_prefix
=
"[{}|{}]"
.
format
(
self
.
name
,
concurrency_idx
)
op_info_prefix
=
"[{}|{}]"
.
format
(
self
.
name
,
concurrency_idx
)
...
@@ -474,12 +473,12 @@ class Op(object):
...
@@ -474,12 +473,12 @@ class Op(object):
_LOGGER
.
info
(
log
(
"succ init"
))
_LOGGER
.
info
(
log
(
"succ init"
))
batch_generator
=
self
.
_auto_batching_generator
(
batch_generator
=
self
.
_auto_batching_generator
(
input_channel
=
input_channel
,
input_channel
=
input_channel
,
op_name
=
self
.
name
,
op_name
=
self
.
name
,
batch_size
=
self
.
_batch_size
,
batch_size
=
self
.
_batch_size
,
timeout
=
self
.
_auto_batching_timeout
,
timeout
=
self
.
_auto_batching_timeout
,
log_func
=
log
)
log_func
=
log
)
while
True
:
while
True
:
try
:
try
:
channeldata_dict_batch
=
next
(
batch_generator
)
channeldata_dict_batch
=
next
(
batch_generator
)
...
@@ -528,10 +527,10 @@ class Op(object):
...
@@ -528,10 +527,10 @@ class Op(object):
try
:
try
:
for
data_id
,
err_channeldata
in
err_channeldata_dict
.
items
():
for
data_id
,
err_channeldata
in
err_channeldata_dict
.
items
():
self
.
_push_to_output_channels
(
self
.
_push_to_output_channels
(
err_channeldata
,
err_channeldata
,
output_channels
,
output_channels
,
client_need_profile
=
need_profile_dict
[
data_id
],
client_need_profile
=
need_profile_dict
[
data_id
],
profile_set
=
profile_dict
[
data_id
])
profile_set
=
profile_dict
[
data_id
])
except
ChannelStopError
:
except
ChannelStopError
:
_LOGGER
.
debug
(
log
(
"channel stop."
))
_LOGGER
.
debug
(
log
(
"channel stop."
))
self
.
_finalize
(
is_thread_op
)
self
.
_finalize
(
is_thread_op
)
...
@@ -548,10 +547,10 @@ class Op(object):
...
@@ -548,10 +547,10 @@ class Op(object):
try
:
try
:
for
data_id
,
err_channeldata
in
err_channeldata_dict
.
items
():
for
data_id
,
err_channeldata
in
err_channeldata_dict
.
items
():
self
.
_push_to_output_channels
(
self
.
_push_to_output_channels
(
error_channeldata
,
error_channeldata
,
output_channels
,
output_channels
,
client_need_profile
=
need_profile_dict
[
data_id
],
client_need_profile
=
need_profile_dict
[
data_id
],
profile_set
=
profile_dict
[
data_id
])
profile_set
=
profile_dict
[
data_id
])
except
ChannelStopError
:
except
ChannelStopError
:
_LOGGER
.
debug
(
log
(
"channel stop."
))
_LOGGER
.
debug
(
log
(
"channel stop."
))
self
.
_finalize
(
is_thread_op
)
self
.
_finalize
(
is_thread_op
)
...
@@ -563,10 +562,10 @@ class Op(object):
...
@@ -563,10 +562,10 @@ class Op(object):
try
:
try
:
for
data_id
,
postped_data
in
postped_data_dict
.
items
():
for
data_id
,
postped_data
in
postped_data_dict
.
items
():
self
.
_push_to_output_channels
(
self
.
_push_to_output_channels
(
postped_data
,
postped_data
,
output_channels
,
output_channels
,
client_need_profile
=
need_profile_dict
[
data_id
],
client_need_profile
=
need_profile_dict
[
data_id
],
profile_set
=
profile_dict
[
data_id
])
profile_set
=
profile_dict
[
data_id
])
except
ChannelStopError
:
except
ChannelStopError
:
_LOGGER
.
debug
(
log
(
"channel stop."
))
_LOGGER
.
debug
(
log
(
"channel stop."
))
self
.
_finalize
(
is_thread_op
)
self
.
_finalize
(
is_thread_op
)
...
@@ -583,8 +582,8 @@ class Op(object):
...
@@ -583,8 +582,8 @@ class Op(object):
self
.
_profiler
.
enable
(
True
)
self
.
_profiler
.
enable
(
True
)
# init client
# init client
self
.
client
=
self
.
init_client
(
self
.
client
=
self
.
init_client
(
client_type
,
self
.
_client_config
,
client_type
,
self
.
_client_config
,
self
.
_server_endpoints
,
self
.
_fetch_names
)
self
.
_server_endpoints
,
self
.
_fetch_names
)
# user defined
# user defined
self
.
init_op
()
self
.
init_op
()
self
.
_succ_init_op
=
True
self
.
_succ_init_op
=
True
...
@@ -595,13 +594,12 @@ class Op(object):
...
@@ -595,13 +594,12 @@ class Op(object):
self
.
_profiler
=
TimeProfiler
()
self
.
_profiler
=
TimeProfiler
()
self
.
_profiler
.
enable
(
True
)
self
.
_profiler
.
enable
(
True
)
# init client
# init client
self
.
client
=
self
.
init_client
(
self
.
client
=
self
.
init_client
(
client_type
,
self
.
_client_config
,
client_type
,
self
.
_client_config
,
self
.
_server_endpoints
,
self
.
_server_endpoints
,
self
.
_fetch_names
)
self
.
_fetch_names
)
# user defined
# user defined
self
.
init_op
()
self
.
init_op
()
def
_finalize
(
self
,
is_thread_op
):
def
_finalize
(
self
,
is_thread_op
):
if
is_thread_op
:
if
is_thread_op
:
with
self
.
_for_close_op_lock
:
with
self
.
_for_close_op_lock
:
...
@@ -625,7 +623,7 @@ class RequestOp(Op):
...
@@ -625,7 +623,7 @@ class RequestOp(Op):
try
:
try
:
self
.
init_op
()
self
.
init_op
()
except
Exception
as
e
:
except
Exception
as
e
:
_LOGGER
.
error
(
e
)
_LOGGER
.
error
(
"Op(Request) init op failed: {}"
.
format
(
e
)
)
os
.
_exit
(
-
1
)
os
.
_exit
(
-
1
)
def
unpack_request_package
(
self
,
request
):
def
unpack_request_package
(
self
,
request
):
...
@@ -649,7 +647,7 @@ class ResponseOp(Op):
...
@@ -649,7 +647,7 @@ class ResponseOp(Op):
try
:
try
:
self
.
init_op
()
self
.
init_op
()
except
Exception
as
e
:
except
Exception
as
e
:
_LOGGER
.
error
(
e
)
_LOGGER
.
error
(
"Op(ResponseOp) init op failed: {}"
.
format
(
e
)
)
os
.
_exit
(
-
1
)
os
.
_exit
(
-
1
)
def
pack_response_package
(
self
,
channeldata
):
def
pack_response_package
(
self
,
channeldata
):
...
@@ -730,7 +728,7 @@ class VirtualOp(Op):
...
@@ -730,7 +728,7 @@ class VirtualOp(Op):
try
:
try
:
channeldata_dict
=
input_channel
.
front
(
self
.
name
)
channeldata_dict
=
input_channel
.
front
(
self
.
name
)
except
ChannelStopError
:
except
ChannelStopError
:
_LOGGER
.
debug
(
log
(
"stop."
))
_LOGGER
.
debug
(
log
(
"
Channel
stop."
))
break
break
try
:
try
:
...
@@ -738,5 +736,5 @@ class VirtualOp(Op):
...
@@ -738,5 +736,5 @@ class VirtualOp(Op):
self
.
_push_to_output_channels
(
self
.
_push_to_output_channels
(
data
,
channels
=
output_channels
,
name
=
name
)
data
,
channels
=
output_channels
,
name
=
name
)
except
ChannelStopError
:
except
ChannelStopError
:
_LOGGER
.
debug
(
log
(
"stop."
))
_LOGGER
.
debug
(
log
(
"
Channel
stop."
))
break
break
python/pipeline/pipeline_client.py
浏览文件 @
4049eb8a
...
@@ -61,7 +61,7 @@ class PipelineClient(object):
...
@@ -61,7 +61,7 @@ class PipelineClient(object):
def
_unpack_response_package
(
self
,
resp
,
fetch
):
def
_unpack_response_package
(
self
,
resp
,
fetch
):
if
resp
.
ecode
!=
0
:
if
resp
.
ecode
!=
0
:
return
{
return
{
"ecode"
:
resp
.
ecode
,
"ecode"
:
resp
.
ecode
,
"ecode_desc"
:
ChannelDataEcode
(
resp
.
ecode
),
"ecode_desc"
:
ChannelDataEcode
(
resp
.
ecode
),
"error_info"
:
resp
.
error_info
,
"error_info"
:
resp
.
error_info
,
}
}
...
...
python/pipeline/pipeline_server.py
浏览文件 @
4049eb8a
...
@@ -90,7 +90,7 @@ class PipelineServer(object):
...
@@ -90,7 +90,7 @@ class PipelineServer(object):
for
key
,
val
in
default_config
.
items
():
for
key
,
val
in
default_config
.
items
():
if
yml_config
.
get
(
key
)
is
None
:
if
yml_config
.
get
(
key
)
is
None
:
_LOGGER
.
warning
(
"[CONF] {} not set, use default: {}"
_LOGGER
.
warning
(
"[CONF] {} not set, use default: {}"
.
format
(
key
,
val
))
.
format
(
key
,
val
))
yml_config
[
key
]
=
val
yml_config
[
key
]
=
val
self
.
_port
=
yml_config
[
"port"
]
self
.
_port
=
yml_config
[
"port"
]
...
@@ -98,12 +98,13 @@ class PipelineServer(object):
...
@@ -98,12 +98,13 @@ class PipelineServer(object):
raise
SystemExit
(
"Prot {} is already used"
.
format
(
self
.
_port
))
raise
SystemExit
(
"Prot {} is already used"
.
format
(
self
.
_port
))
self
.
_worker_num
=
yml_config
[
"worker_num"
]
self
.
_worker_num
=
yml_config
[
"worker_num"
]
self
.
_build_dag_each_worker
=
yml_config
[
"build_dag_each_worker"
]
self
.
_build_dag_each_worker
=
yml_config
[
"build_dag_each_worker"
]
_LOGGER
.
info
(
"============= PIPELINE SERVER ============="
)
_LOGGER
.
info
(
"============= PIPELINE SERVER ============="
)
for
key
in
default_config
.
keys
():
for
key
in
default_config
.
keys
():
_LOGGER
.
info
(
"{}: {}"
.
format
(
key
,
yml_config
[
key
]))
_LOGGER
.
info
(
"{}: {}"
.
format
(
key
,
yml_config
[
key
]))
if
self
.
_build_dag_each_worker
is
True
:
if
self
.
_build_dag_each_worker
is
True
:
_LOGGER
.
info
(
"(Make sure that install grpcio whl with --no-binary flag)"
)
_LOGGER
.
info
(
"(Make sure that install grpcio whl with --no-binary flag)"
)
_LOGGER
.
info
(
"-------------------------------------------"
)
_LOGGER
.
info
(
"-------------------------------------------"
)
self
.
_dag_config
=
yml_config
.
get
(
"dag"
,
{})
self
.
_dag_config
=
yml_config
.
get
(
"dag"
,
{})
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录