未验证 提交 826c607e 编写于 作者: K Kaipeng Deng 提交者: GitHub

Fix test multiprocess dataloader static (#25287)

* Fix random failure of test_multiprocess_dataloader_static. test=develop
上级 4474fc10
...@@ -128,8 +128,10 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): ...@@ -128,8 +128,10 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
self._need_check_feed = [ self._need_check_feed = [
v.desc.need_check_feed() for v in self._feed_list v.desc.need_check_feed() for v in self._feed_list
] ]
# with only one place, the queue does not need to keep order
self._blocking_queue = core.init_lod_tensor_blocking_queue( self._blocking_queue = core.init_lod_tensor_blocking_queue(
core.Variable(), self._blocking_queue_capacity, True) core.Variable(), self._blocking_queue_capacity,
len(self._places) > 1)
self._reader = core.create_py_reader( self._reader = core.create_py_reader(
self._blocking_queue, self._var_names, self._shapes, self._dtypes, self._blocking_queue, self._var_names, self._shapes, self._dtypes,
self._need_check_feed, self._places, self._use_buffer_reader, True) self._need_check_feed, self._places, self._use_buffer_reader, True)
...@@ -280,8 +282,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -280,8 +282,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self._need_check_feed = [ self._need_check_feed = [
v.desc.need_check_feed() for v in self._feed_list v.desc.need_check_feed() for v in self._feed_list
] ]
# with only one place, the queue does not need to keep order
self._blocking_queue = core.init_lod_tensor_blocking_queue( self._blocking_queue = core.init_lod_tensor_blocking_queue(
core.Variable(), self._outstanding_capacity, True) core.Variable(), self._outstanding_capacity, len(self._places) > 1)
self._reader = core.create_py_reader( self._reader = core.create_py_reader(
self._blocking_queue, self._var_names, self._shapes, self._dtypes, self._blocking_queue, self._var_names, self._shapes, self._dtypes,
self._need_check_feed, self._places, self._use_buffer_reader, True) self._need_check_feed, self._places, self._use_buffer_reader, True)
...@@ -442,6 +445,11 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -442,6 +445,11 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# get data again # get data again
data = self._data_queue.get(timeout=self._timeout) data = self._data_queue.get(timeout=self._timeout)
except Exception as e: except Exception as e:
# check whether the thread-done event was set while waiting for data
if self._thread_done_event.is_set():
continue
# check failed workers
failed_workers = [] failed_workers = []
for i, w in enumerate(self._workers): for i, w in enumerate(self._workers):
if self._worker_status[i] and not w.is_alive(): if self._worker_status[i] and not w.is_alive():
......
...@@ -28,12 +28,7 @@ from paddle.fluid.dygraph.nn import Linear ...@@ -28,12 +28,7 @@ from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
from test_multiprocess_dataloader_static import RandomDataset, prepare_places from test_multiprocess_dataloader_static import RandomDataset, prepare_places
from test_multiprocess_dataloader_static import EPOCH_NUM, BATCH_SIZE, IMAGE_SIZE, SAMPLE_NUM, CLASS_NUM
EPOCH_NUM = 5
BATCH_SIZE = 16
IMAGE_SIZE = 784
SAMPLE_NUM = 400
CLASS_NUM = 10
class SimpleFCNet(fluid.dygraph.Layer): class SimpleFCNet(fluid.dygraph.Layer):
......
...@@ -25,10 +25,10 @@ import numpy as np ...@@ -25,10 +25,10 @@ import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.io import Dataset, BatchSampler, DataLoader from paddle.io import Dataset, BatchSampler, DataLoader
EPOCH_NUM = 5 EPOCH_NUM = 3
BATCH_SIZE = 16 BATCH_SIZE = 8
IMAGE_SIZE = 784 IMAGE_SIZE = 32
SAMPLE_NUM = 400 SAMPLE_NUM = 100
CLASS_NUM = 10 CLASS_NUM = 10
...@@ -157,10 +157,6 @@ class TestStaticDataLoader(unittest.TestCase): ...@@ -157,10 +157,6 @@ class TestStaticDataLoader(unittest.TestCase):
return ret return ret
def test_main(self): def test_main(self):
# FIXME(dkp): disable for random fail in Py35 cloud,
# should be fixed ASAP
if sys.version[:3] == '3.5':
return
for p in prepare_places(True): for p in prepare_places(True):
results = [] results = []
for num_workers in [0, 2]: for num_workers in [0, 2]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册