Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
5d96b6e0
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5d96b6e0
编写于
2月 18, 2020
作者:
C
Chen Weihang
提交者:
GitHub
2月 18, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add Queue.get delay for multiprocess data loader (#22604) (#22640)
上级
750c6f42
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
69 addition
and
18 deletion
+69
-18
python/paddle/fluid/reader.py
python/paddle/fluid/reader.py
+31
-17
python/paddle/fluid/tests/unittests/test_imperative_data_loader_exception.py
.../tests/unittests/test_imperative_data_loader_exception.py
+38
-1
未找到文件。
python/paddle/fluid/reader.py
浏览文件 @
5d96b6e0
...
@@ -34,8 +34,9 @@ if sys.version_info[0] == 2:
...
@@ -34,8 +34,9 @@ if sys.version_info[0] == 2:
import
Queue
as
queue
import
Queue
as
queue
else
:
else
:
import
queue
import
queue
# NOTE: [ avoid hanging ] This value is used in getting data from another process
# NOTE: [ avoid hanging ] These value is used in getting data from another process
MP_CHECK_TIMEOUT
=
10
QUEUE_GET_TIMEOUT
=
5
MAX_GET_FAILED_TIME
=
12
__all__
=
[
'PyReader'
,
'DataLoader'
]
__all__
=
[
'PyReader'
,
'DataLoader'
]
...
@@ -485,6 +486,17 @@ class DygraphGeneratorLoader(DataLoaderBase):
...
@@ -485,6 +486,17 @@ class DygraphGeneratorLoader(DataLoaderBase):
signal
.
signal
(
signal
.
SIGCHLD
,
__handler__
)
signal
.
signal
(
signal
.
SIGCHLD
,
__handler__
)
def
_exit_thread_expectedly
(
self
):
self
.
_thread_done_event
.
set
()
self
.
_blocking_queue
.
close
()
self
.
_data_queue
.
close
()
def
_exit_thread_unexpectedly
(
self
):
self
.
_thread_done_event
.
set
()
self
.
_blocking_queue
.
kill
()
self
.
_data_queue
.
close
()
logging
.
error
(
"DataLoader reader thread raised an exception!"
)
def
_reader_process_loop
(
self
):
def
_reader_process_loop
(
self
):
try
:
try
:
# set signal handler
# set signal handler
...
@@ -506,6 +518,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
...
@@ -506,6 +518,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
six
.
reraise
(
*
sys
.
exc_info
())
six
.
reraise
(
*
sys
.
exc_info
())
def
_reader_thread_loop_with_process
(
self
):
def
_reader_thread_loop_with_process
(
self
):
get_sample_try_time
=
0
while
not
self
.
_thread_done_event
.
is_set
():
while
not
self
.
_thread_done_event
.
is_set
():
try
:
try
:
# NOTE: [ avoid hanging ] Even with carefully designed data dependencies
# NOTE: [ avoid hanging ] Even with carefully designed data dependencies
...
@@ -513,10 +526,21 @@ class DygraphGeneratorLoader(DataLoaderBase):
...
@@ -513,10 +526,21 @@ class DygraphGeneratorLoader(DataLoaderBase):
# still happen when data in queue is corrupted (e.g., due to
# still happen when data in queue is corrupted (e.g., due to
# Queue.cancel_join_thread or unexpected exit). So we set a timeout whenever
# Queue.cancel_join_thread or unexpected exit). So we set a timeout whenever
# we try to get data from `data_queue`
# we try to get data from `data_queue`
sample
=
self
.
_data_queue
.
get
(
timeout
=
MP_CHECK_TIMEOUT
)
sample
=
self
.
_data_queue
.
get
(
timeout
=
QUEUE_GET_TIMEOUT
)
get_sample_try_time
=
0
except
queue
.
Empty
:
except
queue
.
Empty
:
self
.
_thread_done_event
.
set
()
get_sample_try_time
+=
1
logging
.
error
(
"The reader has not read data for a long time."
)
if
get_sample_try_time
>
MAX_GET_FAILED_TIME
:
self
.
_exit_thread_unexpectedly
()
raise
RuntimeError
(
"DataLoader reader thread has not read data for a long time (60s)."
)
else
:
# NOTE: [ avoid failed quickly ] Sometimes if the reader child process has a heavy burden,
# the child process has no enough time to put the data in the queue when the main process
# start trying to get data from queue. At this time, failure to read data should not be
# counted as a fatal error, there should be a certain number of attempts.
continue
if
not
self
.
_thread_done_event
.
is_set
():
if
not
self
.
_thread_done_event
.
is_set
():
if
sample
is
not
None
:
if
sample
is
not
None
:
...
@@ -532,20 +556,10 @@ class DygraphGeneratorLoader(DataLoaderBase):
...
@@ -532,20 +556,10 @@ class DygraphGeneratorLoader(DataLoaderBase):
if
not
self
.
_blocking_queue
.
push
(
array
):
if
not
self
.
_blocking_queue
.
push
(
array
):
self
.
_blocking_queue
.
close
()
self
.
_blocking_queue
.
close
()
except
:
except
:
self
.
_thread_done_event
.
set
()
self
.
_exit_thread_unexpectedly
()
self
.
_blocking_queue
.
kill
()
self
.
_data_queue
.
close
()
logging
.
warning
(
"DygraphDataLoader reader thread raised an exception."
)
six
.
reraise
(
*
sys
.
exc_info
())
six
.
reraise
(
*
sys
.
exc_info
())
else
:
else
:
self
.
_thread_done_event
.
set
()
self
.
_exit_thread_expectedly
()
self
.
_blocking_queue
.
close
()
self
.
_data_queue
.
close
()
else
:
self
.
_blocking_queue
.
kill
()
self
.
_data_queue
.
close
()
def
_reader_thread_loop
(
self
):
def
_reader_thread_loop
(
self
):
try
:
try
:
...
...
python/paddle/fluid/tests/unittests/test_imperative_data_loader_exception.py
浏览文件 @
5d96b6e0
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
import
sys
import
sys
import
time
import
unittest
import
unittest
import
numpy
as
np
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
...
@@ -20,10 +21,18 @@ from paddle.fluid import core
...
@@ -20,10 +21,18 @@ from paddle.fluid import core
import
paddle.compat
as
cpt
import
paddle.compat
as
cpt
def
get_random_images_and_labels
(
image_shape
,
label_shape
):
image
=
np
.
random
.
random
(
size
=
image_shape
).
astype
(
'float32'
)
label
=
np
.
random
.
random
(
size
=
label_shape
).
astype
(
'int64'
)
return
image
,
label
class
TestDygraphhDataLoaderWithException
(
unittest
.
TestCase
):
class
TestDygraphhDataLoaderWithException
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
batch_size
=
8
self
.
batch_num
=
4
self
.
batch_num
=
4
self
.
capacity
=
2
self
.
epoch_num
=
1
self
.
capacity
=
5
def
test_not_capacity
(
self
):
def
test_not_capacity
(
self
):
with
fluid
.
dygraph
.
guard
():
with
fluid
.
dygraph
.
guard
():
...
@@ -77,6 +86,34 @@ class TestDygraphhDataLoaderWithException(unittest.TestCase):
...
@@ -77,6 +86,34 @@ class TestDygraphhDataLoaderWithException(unittest.TestCase):
exception
=
ex
exception
=
ex
self
.
assertIsNotNone
(
exception
)
self
.
assertIsNotNone
(
exception
)
def
test_multi_process_with_get_timeout
(
self
):
def
slow_batch_generator_creator
(
batch_size
,
batch_num
):
def
__reader__
():
for
_
in
range
(
batch_num
):
time
.
sleep
(
80
)
batch_image
,
batch_label
=
get_random_images_and_labels
(
[
batch_size
,
784
],
[
batch_size
,
1
])
yield
batch_image
,
batch_label
return
__reader__
with
fluid
.
dygraph
.
guard
():
loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
capacity
=
self
.
capacity
,
use_multiprocess
=
True
)
loader
.
set_batch_generator
(
slow_batch_generator_creator
(
self
.
batch_size
,
self
.
batch_num
),
places
=
fluid
.
CPUPlace
())
exception
=
None
try
:
for
_
in
range
(
self
.
epoch_num
):
for
image
,
_
in
loader
():
fluid
.
layers
.
relu
(
image
)
except
core
.
EnforceNotMet
as
ex
:
self
.
assertIn
(
"Blocking queue is killed"
,
cpt
.
get_exception_message
(
ex
))
exception
=
ex
self
.
assertIsNotNone
(
exception
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录