Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a32e8bf1
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a32e8bf1
编写于
3月 15, 2021
作者:
K
Kaipeng Deng
提交者:
GitHub
3月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
DataLoader supprot dict str (#31481)
* add dict/str/list supprot for DataLoader. test=develop
上级
30a627aa
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
646 addition
and
297 deletion
+646
-297
paddle/fluid/imperative/data_loader.cc
paddle/fluid/imperative/data_loader.cc
+19
-5
paddle/fluid/operators/reader/blocking_queue.h
paddle/fluid/operators/reader/blocking_queue.h
+10
-2
paddle/fluid/pybind/reader_py.cc
paddle/fluid/pybind/reader_py.cc
+6
-4
python/paddle/fluid/dataloader/collate.py
python/paddle/fluid/dataloader/collate.py
+87
-0
python/paddle/fluid/dataloader/dataloader_iter.py
python/paddle/fluid/dataloader/dataloader_iter.py
+58
-284
python/paddle/fluid/dataloader/flat.py
python/paddle/fluid/dataloader/flat.py
+150
-0
python/paddle/fluid/dataloader/worker.py
python/paddle/fluid/dataloader/worker.py
+253
-0
python/paddle/fluid/multiprocess_utils.py
python/paddle/fluid/multiprocess_utils.py
+4
-0
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py
...d/tests/unittests/test_multiprocess_dataloader_dataset.py
+57
-0
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_split.py
...ts/test_multiprocess_dataloader_iterable_dataset_split.py
+2
-2
未找到文件。
paddle/fluid/imperative/data_loader.cc
浏览文件 @
a32e8bf1
...
@@ -71,8 +71,11 @@ void EraseLoadProcessPIDs(int64_t key) {
...
@@ -71,8 +71,11 @@ void EraseLoadProcessPIDs(int64_t key) {
} \
} \
} while (0)
} while (0)
#define REGISTER_SIGNAL_HANDLER(SIGNAL, HANDLER_NAME
)
\
#define REGISTER_SIGNAL_HANDLER(SIGNAL, HANDLER_NAME
, ERROR_MSG)
\
static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) { \
static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) { \
auto _w = \
write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) / sizeof(char)); \
(void)_w; \
SIGNAL_HANDLE(SIGNAL); \
SIGNAL_HANDLE(SIGNAL); \
}
}
...
@@ -84,8 +87,18 @@ void EraseLoadProcessPIDs(int64_t key) {
...
@@ -84,8 +87,18 @@ void EraseLoadProcessPIDs(int64_t key) {
SIGNAL_HANDLE(SIGNAL); \
SIGNAL_HANDLE(SIGNAL); \
}
}
REGISTER_SIGNAL_HANDLER
(
SIGSEGV
,
SIGSEGV_handler
);
REGISTER_SIGNAL_HANDLER
(
SIGSEGV
,
SIGSEGV_handler
,
REGISTER_SIGNAL_HANDLER
(
SIGBUS
,
SIGBUS_handler
);
"ERROR: Unexpected segmentation fault encountered in "
"DataLoader workers.
\n
"
);
REGISTER_SIGNAL_HANDLER
(
SIGBUS
,
SIGBUS_handler
,
"ERROR: Unexpected BUS error encountered in DataLoader worker. "
"This might be caused by insufficient shared memory (shm), "
"please check whether use_shared_memory is set and storage space "
"in /dev/shm is enough
\n
"
);
REGISTER_SIGNAL_HANDLER
(
SIGFPE
,
SIGFPE_handler
,
"ERROR: Unexpected floating-point exception "
"encountered in DataLoader worker.
\n
"
)
REGISTER_SPEC_SIGNAL_HANDLER
(
SIGTERM
,
SIGTERM_handler
);
REGISTER_SPEC_SIGNAL_HANDLER
(
SIGTERM
,
SIGTERM_handler
);
static
inline
void
setSignalHandler
(
int
signal
,
static
inline
void
setSignalHandler
(
int
signal
,
...
@@ -105,6 +118,7 @@ static inline void setSignalHandler(int signal,
...
@@ -105,6 +118,7 @@ static inline void setSignalHandler(int signal,
void
SetLoadProcessSignalHandler
()
{
void
SetLoadProcessSignalHandler
()
{
setSignalHandler
(
SIGSEGV
,
&
SIGSEGV_handler
,
nullptr
);
setSignalHandler
(
SIGSEGV
,
&
SIGSEGV_handler
,
nullptr
);
setSignalHandler
(
SIGBUS
,
&
SIGBUS_handler
,
nullptr
);
setSignalHandler
(
SIGBUS
,
&
SIGBUS_handler
,
nullptr
);
setSignalHandler
(
SIGFPE
,
&
SIGFPE_handler
,
nullptr
);
setSignalHandler
(
SIGTERM
,
&
SIGTERM_handler
,
nullptr
);
setSignalHandler
(
SIGTERM
,
&
SIGTERM_handler
,
nullptr
);
}
}
...
...
paddle/fluid/operators/reader/blocking_queue.h
浏览文件 @
a32e8bf1
...
@@ -45,7 +45,11 @@ class BlockingQueue {
...
@@ -45,7 +45,11 @@ class BlockingQueue {
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
send_cv_
.
wait
(
send_cv_
.
wait
(
lock
,
[
&
]
{
return
queue_
.
size
()
<
capacity_
||
closed_
||
killed_
;
});
lock
,
[
&
]
{
return
queue_
.
size
()
<
capacity_
||
closed_
||
killed_
;
});
EnforceNotKilled
();
if
(
killed_
)
{
VLOG
(
3
)
<<
"WARNING:: Sending an element to a killed reader::BlokcingQueue"
;
return
false
;
}
if
(
closed_
)
{
if
(
closed_
)
{
VLOG
(
5
)
VLOG
(
5
)
<<
"WARNING: Sending an element to a closed reader::BlokcingQueue."
;
<<
"WARNING: Sending an element to a closed reader::BlokcingQueue."
;
...
@@ -66,7 +70,11 @@ class BlockingQueue {
...
@@ -66,7 +70,11 @@ class BlockingQueue {
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
send_cv_
.
wait
(
send_cv_
.
wait
(
lock
,
[
&
]
{
return
queue_
.
size
()
<
capacity_
||
closed_
||
killed_
;
});
lock
,
[
&
]
{
return
queue_
.
size
()
<
capacity_
||
closed_
||
killed_
;
});
EnforceNotKilled
();
if
(
killed_
)
{
VLOG
(
3
)
<<
"WARNING:: Sending an element to a killed reader::BlokcingQueue"
;
return
false
;
}
if
(
closed_
)
{
if
(
closed_
)
{
VLOG
(
5
)
VLOG
(
5
)
<<
"WARNING: Sending an element to a closed reader::BlokcingQueue."
;
<<
"WARNING: Sending an element to a closed reader::BlokcingQueue."
;
...
...
paddle/fluid/pybind/reader_py.cc
浏览文件 @
a32e8bf1
...
@@ -223,6 +223,10 @@ class MultiDeviceFeedReader {
...
@@ -223,6 +223,10 @@ class MultiDeviceFeedReader {
ReadAsync
();
ReadAsync
();
}
}
void
Shutdown
()
{
for
(
auto
&
r
:
readers_
)
r
->
Shutdown
();
}
~
MultiDeviceFeedReader
()
{
~
MultiDeviceFeedReader
()
{
queue_
->
Close
();
queue_
->
Close
();
pool_
.
reset
();
pool_
.
reset
();
...
@@ -266,10 +270,6 @@ class MultiDeviceFeedReader {
...
@@ -266,10 +270,6 @@ class MultiDeviceFeedReader {
}
}
}
}
void
Shutdown
()
{
for
(
auto
&
r
:
readers_
)
r
->
Shutdown
();
}
void
Start
()
{
void
Start
()
{
for
(
auto
&
r
:
readers_
)
r
->
Start
();
for
(
auto
&
r
:
readers_
)
r
->
Start
();
}
}
...
@@ -362,6 +362,8 @@ void BindMultiDeviceReader(py::module *module, const char *reader_name) {
...
@@ -362,6 +362,8 @@ void BindMultiDeviceReader(py::module *module, const char *reader_name) {
},
},
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"reset"
,
&
ReaderType
::
Reset
,
.
def
(
"reset"
,
&
ReaderType
::
Reset
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"shutdown"
,
&
ReaderType
::
Shutdown
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
py
::
call_guard
<
py
::
gil_scoped_release
>
());
}
}
...
...
python/paddle/fluid/dataloader/collate.py
0 → 100644
浏览文件 @
a32e8bf1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
numbers
import
numpy
as
np
from
..framework
import
in_dygraph_mode
from
..
import
core
,
layers
try
:
from
collections.abc
import
Sequence
,
Mapping
except
:
from
collections
import
Sequence
,
Mapping
def
default_collate_fn
(
batch
):
"""
Default batch collating function for :code:`paddle.io.DataLoader`,
batch should be a list of samples, and each sample should be a list
of fields as follows:
[[filed1, filed2, ...], [filed1, filed2, ...], ...]
This default collate function zipped each filed together and stack
each filed as the batch field as follows:
[batch_filed1, batch_filed2, ...]
Args:
batch(list of list of numpy array|paddle.Tensor): the batch data, each fields
should be a numpy array, each sample should be a list of
fileds, and batch should be a list of sample.
Returns:
a list of numpy array|Paddle.Tensor: collated batch of input batch data,
fields data type as same as fields in each sample.
"""
sample
=
batch
[
0
]
if
isinstance
(
sample
,
np
.
ndarray
):
batch
=
np
.
stack
(
batch
,
axis
=
0
)
return
batch
elif
isinstance
(
sample
,
paddle
.
Tensor
):
return
layers
.
stack
(
batch
,
axis
=
0
)
elif
isinstance
(
sample
,
numbers
.
Number
):
batch
=
np
.
array
(
batch
)
return
batch
elif
isinstance
(
sample
,
(
str
,
bytes
)):
return
batch
elif
isinstance
(
sample
,
Mapping
):
return
{
key
:
default_collate_fn
([
d
[
key
]
for
d
in
batch
])
for
key
in
sample
}
elif
isinstance
(
sample
,
Sequence
):
sample_fields_num
=
len
(
sample
)
if
not
all
(
len
(
sample
)
==
sample_fields_num
for
sample
in
iter
(
batch
)):
raise
RuntimeError
(
"fileds number not same among samples in a batch"
)
return
[
default_collate_fn
(
fields
)
for
fields
in
zip
(
*
batch
)]
raise
TypeError
(
"batch data con only contains: tensor, numpy.ndarray, "
"dict, list, number, but got {}"
.
format
(
type
(
sample
)))
return
outputs
def
default_convert_fn
(
batch
):
if
isinstance
(
batch
,
(
paddle
.
Tensor
,
np
.
ndarray
)):
return
batch
elif
isinstance
(
batch
,
(
str
,
bytes
)):
return
batch
elif
isinstance
(
batch
,
Mapping
):
return
{
key
:
default_convert_fn
(
batch
[
key
])
for
key
in
batch
}
elif
isinstance
(
batch
,
Sequence
):
return
[
default_convert_fn
(
d
)
for
d
in
batch
]
else
:
return
batch
python/paddle/fluid/dataloader/dataloader_iter.py
浏览文件 @
a32e8bf1
...
@@ -35,181 +35,16 @@ else:
...
@@ -35,181 +35,16 @@ else:
import
paddle
import
paddle
from
..
import
core
,
layers
from
..
import
core
,
layers
from
..framework
import
in_dygraph_mode
from
..framework
import
in_dygraph_mode
from
..multiprocess_utils
import
CleanupFuncRegistrar
,
_cleanup_mmap
,
_set_SIGCHLD_handler
from
..multiprocess_utils
import
_set_SIGCHLD_handler
,
MP_STATUS_CHECK_INTERVAL
from
.fetcher
import
_IterableDatasetFetcher
,
_MapDatasetFetcher
from
.fetcher
import
_IterableDatasetFetcher
,
_MapDatasetFetcher
from
.batch_sampler
import
_InfiniteIterableSampler
from
.batch_sampler
import
_InfiniteIterableSampler
from
.collate
import
default_collate_fn
,
default_convert_fn
from
.worker
import
ParentWatchDog
,
get_worker_info
,
_worker_loop
,
\
_DatasetKind
,
_IterableDatasetStopIteration
,
_WorkerException
from
.flat
import
_flatten_batch
,
_restore_batch
__all__
=
[
'get_worker_info'
]
__all__
=
[
'get_worker_info'
]
# multi-process worker check indices queue interval, avoid
# hanging in subprocess data loading
MP_INDICES_CHECK_INTERVAL
=
5
_IterableDatasetStopIteration
=
namedtuple
(
'_IterableDatasetStopIteration'
,
[
'worker_id'
])
def
default_collate_fn
(
batch
):
"""
Default batch collating function for :code:`fluid.io.DataLoader`,
batch should be a list of samples, and each sample should be a list
of fields as follows:
[[filed1, filed2, ...], [filed1, filed2, ...], ...]
This default collate function zipped each filed together and stack
each filed as the batch field as follows:
[batch_filed1, batch_filed2, ...]
Args:
batch(list of list of numpy array): the batch data, each fields
should be a numpy array, each sample should be a list of
fileds, and batch should be a list of sample.
Returns:
a list of numpy array: collated batch
"""
sample
=
batch
[
0
]
# dataset has only 1 field
if
isinstance
(
sample
,
np
.
ndarray
):
return
[
np
.
stack
(
batch
,
axis
=
0
)]
# batch each field
slots
=
[]
for
items
in
batch
:
for
i
,
item
in
enumerate
(
items
):
if
len
(
slots
)
<
len
(
items
):
slots
.
append
([
item
])
else
:
slots
[
i
].
append
(
item
)
outputs
=
[]
for
slot
in
slots
:
if
isinstance
(
slot
[
0
],
(
np
.
ndarray
,
np
.
bool
,
numbers
.
Number
)):
tmp
=
np
.
stack
(
slot
,
axis
=
0
)
outputs
.
append
(
tmp
)
elif
isinstance
(
slot
[
0
],
paddle
.
Tensor
):
tmp
=
layers
.
stack
(
slot
,
axis
=
0
)
outputs
.
append
(
tmp
)
else
:
raise
RuntimeError
(
"Unknown data type {}"
.
format
(
type
(
slot
[
0
])))
return
outputs
class
_DatasetKind
(
object
):
MAP
=
0
ITER
=
1
@
staticmethod
def
create_fetcher
(
kind
,
dataset
,
auto_collate_batch
,
collate_fn
,
drop_last
):
if
kind
==
_DatasetKind
.
MAP
:
return
_MapDatasetFetcher
(
dataset
,
auto_collate_batch
,
collate_fn
,
drop_last
)
elif
kind
==
_DatasetKind
.
ITER
:
return
_IterableDatasetFetcher
(
dataset
,
auto_collate_batch
,
collate_fn
,
drop_last
)
else
:
raise
NotImplementedError
(
"unknown Dataset kind {}"
.
format
(
kind
))
class
ParentWatchDog
(
object
):
def
__init__
(
self
):
self
.
_parent_pid
=
os
.
getppid
()
self
.
_parent_alive
=
True
def
is_alive
(
self
):
if
self
.
_parent_alive
:
self
.
_parent_alive
=
os
.
getppid
()
==
self
.
_parent_pid
return
self
.
_parent_alive
# worker information for each workers, used for splitting data copy
# for IteratorDataset in worker processes.
_worker_info
=
None
def
get_worker_info
():
"""
Get DataLoader worker process information function, this function is
used to split data copy in worker process for IterableDataset
(see :code:`paddle.io.IterableDataset`), worker information contains
following fields:
:attr:`num_workers`: total worker process number, see `paddle.io.DataLoader`
:attr:`id`: the worker processs id, count from 0 to :attr:`num_workers - 1`
:attr:`dataset`: the dataset object in this worker process
Returns:
WorkerInfo: an instance of WorkerInfo which contains fields above.
.. note::
For mode usage and exampls, please see :code:`paddle.io.IterableDataset`
Example:
.. code-block:: python
import math
import paddle
import numpy as np
from paddle.io import IterableDataset, DataLoader, get_worker_info
class SplitedIterableDataset(IterableDataset):
def __init__(self, start, end):
self.start = start
self.end = end
def __iter__(self):
worker_info = get_worker_info()
if worker_info is None:
iter_start = self.start
iter_end = self.end
else:
per_worker = int(
math.ceil((self.end - self.start) / float(
worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
for i in range(iter_start, iter_end):
yield np.array([i])
place = paddle.CPUPlace()
dataset = SplitedIterableDataset(start=2, end=9)
dataloader = DataLoader(
dataset,
places=place,
num_workers=2,
batch_size=1,
drop_last=True)
for data in dataloader:
print(data)
# outputs: [2, 5, 3, 6, 4, 7]
"""
return
_worker_info
class
WorkerInfo
(
object
):
__initialized
=
False
def
__init__
(
self
,
**
kwargs
):
for
k
,
v
in
kwargs
.
items
():
setattr
(
self
,
k
,
v
)
self
.
__initialized
=
True
def
__setattr__
(
self
,
key
,
val
):
if
self
.
__initialized
:
raise
RuntimeError
(
"Cannot assign attributes to {} objects"
.
format
(
self
.
__class__
.
__name__
))
return
super
(
WorkerInfo
,
self
).
__setattr__
(
key
,
val
)
class
_DataLoaderIterBase
(
object
):
class
_DataLoaderIterBase
(
object
):
"""
"""
...
@@ -230,7 +65,7 @@ class _DataLoaderIterBase(object):
...
@@ -230,7 +65,7 @@ class _DataLoaderIterBase(object):
self
.
_num_workers
=
loader
.
num_workers
self
.
_num_workers
=
loader
.
num_workers
self
.
_use_buffer_reader
=
loader
.
use_buffer_reader
self
.
_use_buffer_reader
=
loader
.
use_buffer_reader
self
.
_use_shared_memory
=
loader
.
use_shared_memory
self
.
_use_shared_memory
=
loader
.
use_shared_memory
self
.
_timeout
=
loader
.
timeout
if
loader
.
timeout
>
0
else
MP_
INDICE
S_CHECK_INTERVAL
self
.
_timeout
=
loader
.
timeout
if
loader
.
timeout
>
0
else
MP_
STATU
S_CHECK_INTERVAL
self
.
_worker_init_fn
=
loader
.
worker_init_fn
self
.
_worker_init_fn
=
loader
.
worker_init_fn
self
.
_dataset_kind
=
loader
.
dataset_kind
self
.
_dataset_kind
=
loader
.
dataset_kind
self
.
_pin_memory
=
loader
.
pin_memory
self
.
_pin_memory
=
loader
.
pin_memory
...
@@ -244,7 +79,7 @@ class _DataLoaderIterBase(object):
...
@@ -244,7 +79,7 @@ class _DataLoaderIterBase(object):
else
:
else
:
self
.
_sampler_iter
=
iter
(
self
.
_sampler_iter
=
iter
(
_InfiniteIterableSampler
(
self
.
_dataset
,
1
))
_InfiniteIterableSampler
(
self
.
_dataset
,
1
))
self
.
_collate_fn
=
loader
.
collate_fn
self
.
_collate_fn
=
loader
.
collate_fn
or
default_convert_fn
# LoDTensorBlockingQueue instance for create_py_reader and a thread
# LoDTensorBlockingQueue instance for create_py_reader and a thread
# to put mini-batch data to self._blocking_queue, mini-batch data
# to put mini-batch data to self._blocking_queue, mini-batch data
...
@@ -275,6 +110,14 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
...
@@ -275,6 +110,14 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
self
.
_dataset_kind
,
self
.
_dataset
,
self
.
_auto_collate_batch
,
self
.
_dataset_kind
,
self
.
_dataset
,
self
.
_auto_collate_batch
,
self
.
_collate_fn
,
True
)
self
.
_collate_fn
,
True
)
# NOTE: _structrue_infos used to record the data structure of
# batch to restore batch structure after reading Tensor
# from blocking_queue in single-process mode. Note that
# only single process is used in single-process mode, we
# can record the data structure sequencely in a list without
# recording the send and recv index
self
.
_structure_infos
=
[]
# NOTE: len(self._places) batch data compose as an output
# NOTE: len(self._places) batch data compose as an output
# iteration, set blocking_queue can cache 2 iteration datas
# iteration, set blocking_queue can cache 2 iteration datas
# at most here
# at most here
...
@@ -316,16 +159,14 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
...
@@ -316,16 +159,14 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
# read data from dataset in mini-batch
# read data from dataset in mini-batch
batch
=
self
.
_dataset_fetcher
.
fetch
(
indices
)
batch
=
self
.
_dataset_fetcher
.
fetch
(
indices
)
# flat batch and record structure infos
batch
,
structure
=
_flatten_batch
(
batch
)
self
.
_structure_infos
.
append
(
structure
)
# pack as LoDTensorArray
# pack as LoDTensorArray
array
=
core
.
LoDTensorArray
()
array
=
core
.
LoDTensorArray
()
for
slot
in
batch
:
for
slot
in
batch
:
if
not
isinstance
(
slot
,
core
.
LoDTensor
):
if
not
isinstance
(
slot
,
core
.
LoDTensor
):
# FIXME(dkp): blocking_queue only support
# core.LoDTensorArray as input now, read
# numpy data into a LoDTensorArray here,
# should support paddle.Tensor list later
if
isinstance
(
slot
,
paddle
.
Tensor
):
slot
=
slot
.
numpy
()
tmp
=
core
.
LoDTensor
()
tmp
=
core
.
LoDTensor
()
tmp
.
set
(
slot
,
core
.
CPUPlace
())
tmp
.
set
(
slot
,
core
.
CPUPlace
())
slot
=
tmp
slot
=
tmp
...
@@ -348,20 +189,29 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
...
@@ -348,20 +189,29 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
def
__next__
(
self
):
def
__next__
(
self
):
try
:
try
:
if
in_dygraph_mode
():
if
in_dygraph_mode
():
return
self
.
_reader
.
read_next_var_list
()
data
=
self
.
_reader
.
read_next_var_list
()
data
=
_restore_batch
(
data
,
self
.
_structure_infos
.
pop
(
0
))
else
:
else
:
if
self
.
_return_list
:
if
self
.
_return_list
:
data
=
self
.
_reader
.
read_next_list
()
data
=
[
_restore_batch
(
d
,
s
)
for
d
,
s
in
zip
(
data
,
self
.
_structure_infos
[:
len
(
self
.
_places
)])
]
self
.
_structure_infos
=
self
.
_structure_infos
[
len
(
self
.
_places
):]
# static graph organized data on multi-device with list, if
# static graph organized data on multi-device with list, if
# place number is 1, there is only 1 device, extra the data
# place number is 1, there is only 1 device, extra the data
# from list for devices to be compatible with dygraph mode
# from list for devices to be compatible with dygraph mode
if
len
(
self
.
_places
)
==
1
:
if
len
(
self
.
_places
)
==
1
:
return
self
.
_reader
.
read_next_list
()[
0
]
data
=
data
[
0
]
else
:
return
self
.
_reader
.
read_next_list
()
else
:
else
:
return
self
.
_reader
.
read_next
()
data
=
self
.
_reader
.
read_next
()
return
data
except
StopIteration
:
except
StopIteration
:
self
.
_reader
.
reset
()
self
.
_reader
.
shutdown
()
six
.
reraise
(
*
sys
.
exc_info
())
six
.
reraise
(
*
sys
.
exc_info
())
# python2 compatibility
# python2 compatibility
...
@@ -375,97 +225,6 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
...
@@ -375,97 +225,6 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
self
.
_blocking_queue
.
close
()
self
.
_blocking_queue
.
close
()
# NOTE(chenweihang): _worker_loop must be top level method to be pickled
def
_worker_loop
(
dataset
,
dataset_kind
,
indices_queue
,
out_queue
,
done_event
,
auto_collate_batch
,
collate_fn
,
init_fn
,
worker_id
,
num_workers
,
use_shared_memory
):
try
:
# NOTE: [ mmap files clear ] When the child process exits unexpectedly,
# some shared memory objects may have been applied for but have not yet
# been put into the inter-process Queue. This part of the object needs
# to be cleaned up when the process ends.
CleanupFuncRegistrar
.
register
(
_cleanup_mmap
)
# set signal handler
core
.
_set_process_signal_handler
()
global
_worker_info
_worker_info
=
WorkerInfo
(
id
=
worker_id
,
num_workers
=
num_workers
,
dataset
=
dataset
)
init_exception
=
None
try
:
if
init_fn
is
not
None
:
init_fn
(
worker_id
)
fetcher
=
_DatasetKind
.
create_fetcher
(
dataset_kind
,
dataset
,
auto_collate_batch
,
collate_fn
,
True
)
except
:
init_exception
=
Exception
(
"init_fn failed in worker {}: "
\
"{}"
.
format
(
worker_id
,
sys
.
exc_info
()))
iterator_drained
=
False
parent_watch_dog
=
ParentWatchDog
()
while
parent_watch_dog
.
is_alive
():
try
:
data
=
indices_queue
.
get
(
MP_INDICES_CHECK_INTERVAL
)
except
queue
.
Empty
:
continue
# None as poison piil, so worker event should be set
if
data
is
None
:
assert
done_event
.
is_set
()
or
iterator_drained
,
\
"get None when worker done_event set"
break
# If worker done event is set but get still get data in
# indices_queue, remaining data should be get and skipped.
if
done_event
.
is_set
()
or
iterator_drained
:
continue
idx
,
indices
=
data
try
:
if
init_exception
is
not
None
:
batch
=
init_exception
init_exception
=
None
else
:
batch
=
fetcher
.
fetch
(
indices
)
except
Exception
as
e
:
if
isinstance
(
e
,
StopIteration
)
and
dataset_kind
==
_DatasetKind
.
ITER
:
out_queue
.
put
(
_IterableDatasetStopIteration
(
worker_id
))
iterator_drained
=
True
else
:
out_queue
.
put
((
idx
,
e
))
else
:
if
use_shared_memory
:
# FIXME(dkp): _convert_to_tensor_list only support np.array
# list now, should support paddle.Tensor list
new_batch
=
[]
for
sample
in
batch
:
new_sample
=
[]
for
s
in
sample
:
if
isinstance
(
s
,
paddle
.
Tensor
):
new_sample
.
append
(
s
.
numpy
())
else
:
new_sample
.
append
(
s
)
new_batch
.
append
(
new_sample
)
batch
=
new_batch
tensor_list
=
core
.
_convert_to_tensor_list
(
batch
)
out_queue
.
put
((
idx
,
tensor_list
))
core
.
_remove_tensor_list_mmap_fds
(
tensor_list
)
else
:
out_queue
.
put
((
idx
,
batch
))
except
KeyboardInterrupt
:
# NOTE: Main process will raise KeyboardInterrupt anyways, ignore it in child process
pass
except
:
six
.
reraise
(
*
sys
.
exc_info
())
finally
:
if
use_shared_memory
:
_cleanup_mmap
()
class
_DataLoaderIterMultiProcess
(
_DataLoaderIterBase
):
class
_DataLoaderIterMultiProcess
(
_DataLoaderIterBase
):
def
__init__
(
self
,
loader
):
def
__init__
(
self
,
loader
):
super
(
_DataLoaderIterMultiProcess
,
self
).
__init__
(
loader
)
super
(
_DataLoaderIterMultiProcess
,
self
).
__init__
(
loader
)
...
@@ -483,6 +242,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
...
@@ -483,6 +242,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self
.
_rcvd_idx
=
0
self
.
_rcvd_idx
=
0
self
.
_batches_outstanding
=
0
self
.
_batches_outstanding
=
0
self
.
_task_infos
=
{}
self
.
_task_infos
=
{}
self
.
_structure_infos
=
[]
# indices outstand as _outstanding_capacity at first, and
# indices outstand as _outstanding_capacity at first, and
# blocking_queue capacity is also _outstanding_capacity.
# blocking_queue capacity is also _outstanding_capacity.
...
@@ -617,8 +377,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
...
@@ -617,8 +377,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
if
not
self
.
_thread_done_event
.
is_set
():
if
not
self
.
_thread_done_event
.
is_set
():
if
batch
is
None
:
if
batch
is
None
:
self
.
_exit_thread_expectedly
()
self
.
_exit_thread_expectedly
()
elif
isinstance
(
batch
,
Exception
):
self
.
_exit_thread_unexpectedly
()
else
:
else
:
try
:
try
:
# pack as LoDTensorArray
# pack as LoDTensorArray
...
@@ -654,8 +412,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
...
@@ -654,8 +412,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# batch indices and increase _rcvd_idx
# batch indices and increase _rcvd_idx
if
self
.
_dataset_kind
==
_DatasetKind
.
ITER
:
if
self
.
_dataset_kind
==
_DatasetKind
.
ITER
:
while
self
.
_rcvd_idx
<
self
.
_send_idx
:
while
self
.
_rcvd_idx
<
self
.
_send_idx
:
sys
.
stdout
.
flush
()
info
=
self
.
_task_infos
[
self
.
_rcvd_idx
]
info
=
self
.
_task_infos
[
self
.
_rcvd_idx
]
if
len
(
info
)
==
2
or
self
.
_worker_status
[
info
[
0
]]:
if
len
(
info
)
==
3
or
self
.
_worker_status
[
info
[
0
]]:
break
break
del
self
.
_task_infos
[
self
.
_rcvd_idx
]
del
self
.
_task_infos
[
self
.
_rcvd_idx
]
self
.
_rcvd_idx
+=
1
self
.
_rcvd_idx
+=
1
...
@@ -669,13 +428,15 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
...
@@ -669,13 +428,15 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
continue
continue
if
self
.
_rcvd_idx
in
self
.
_task_infos
and
\
if
self
.
_rcvd_idx
in
self
.
_task_infos
and
\
len
(
self
.
_task_infos
[
self
.
_rcvd_idx
])
==
2
:
len
(
self
.
_task_infos
[
self
.
_rcvd_idx
])
==
3
:
return
self
.
_task_infos
.
pop
(
self
.
_rcvd_idx
)[
1
]
info
=
self
.
_task_infos
.
pop
(
self
.
_rcvd_idx
)
self
.
_structure_infos
.
append
(
info
[
2
])
return
info
[
1
]
try
:
try
:
# [ avoid hang ]: main process may blocking at _reader.read_next when
# [ avoid hang ]: main process may blocking at _reader.read_next when
# KeyboardInterrupt, we do following tradeoff:
# KeyboardInterrupt, we do following tradeoff:
# 1. get data with timeout, MP_
INDICE
S_CHECK_INTERVAL(5s) as timeout
# 1. get data with timeout, MP_
STATU
S_CHECK_INTERVAL(5s) as timeout
# default, if KeyboardInterrupt blocking, failed workers will be
# default, if KeyboardInterrupt blocking, failed workers will be
# checked and raise RuntimeError to quit DataLoader in timeout
# checked and raise RuntimeError to quit DataLoader in timeout
# exception handling.
# exception handling.
...
@@ -721,12 +482,17 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
...
@@ -721,12 +482,17 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self
.
_try_put_indices
()
self
.
_try_put_indices
()
continue
continue
idx
,
batch
=
data
idx
,
batch
,
structure
=
data
if
isinstance
(
batch
,
_WorkerException
):
self
.
_exit_thread_unexpectedly
()
batch
.
reraise
()
if
idx
==
self
.
_rcvd_idx
:
if
idx
==
self
.
_rcvd_idx
:
del
self
.
_task_infos
[
idx
]
del
self
.
_task_infos
[
idx
]
self
.
_structure_infos
.
append
(
structure
)
return
batch
return
batch
else
:
else
:
self
.
_task_infos
[
idx
]
+=
(
batch
,
)
self
.
_task_infos
[
idx
]
+=
(
batch
,
structure
)
continue
continue
def
_try_put_indices
(
self
):
def
_try_put_indices
(
self
):
...
@@ -777,9 +543,17 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
...
@@ -777,9 +543,17 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
if
in_dygraph_mode
():
if
in_dygraph_mode
():
data
=
self
.
_reader
.
read_next_var_list
()
data
=
self
.
_reader
.
read_next_var_list
()
data
=
_restore_batch
(
data
,
self
.
_structure_infos
.
pop
(
0
))
else
:
else
:
if
self
.
_return_list
:
if
self
.
_return_list
:
data
=
self
.
_reader
.
read_next_list
()
data
=
self
.
_reader
.
read_next_list
()
data
=
[
_restore_batch
(
d
,
s
)
for
d
,
s
in
zip
(
data
,
self
.
_structure_infos
[:
len
(
self
.
_places
)])
]
self
.
_structure_infos
=
self
.
_structure_infos
[
len
(
self
.
_places
):]
# static graph organized data on multi-device with list, if
# static graph organized data on multi-device with list, if
# place number is 1, there is only 1 device, extra the data
# place number is 1, there is only 1 device, extra the data
# from list for devices to be compatible with dygraph mode
# from list for devices to be compatible with dygraph mode
...
@@ -790,7 +564,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
...
@@ -790,7 +564,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self
.
_on_output_batch
()
self
.
_on_output_batch
()
return
data
return
data
except
StopIteration
:
except
StopIteration
:
self
.
_reader
.
reset
()
self
.
_reader
.
shutdown
()
self
.
_try_shutdown_all
()
self
.
_try_shutdown_all
()
six
.
reraise
(
*
sys
.
exc_info
())
six
.
reraise
(
*
sys
.
exc_info
())
...
...
python/paddle/fluid/dataloader/flat.py
0 → 100644
浏览文件 @
a32e8bf1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
numbers
import
numpy
as
np
try
:
from
collections.abc
import
Sequence
,
Mapping
except
:
from
collections
import
Sequence
,
Mapping
FIELD_PREFIX
=
"_paddle_field_"
def
_flatten_batch
(
batch
):
"""
For lod_blocking_queue only receive tensor array, flatten batch
data, extract numpy.array data out as a list of numpy.array to
send to lod_blocking_queue, and save the batch data structure
such as fields in other types (str, int, etc) or key-value map
of dictionaries
"""
def
_flatten
(
batch
,
flat_batch
,
structure
,
field_idx
):
if
isinstance
(
batch
,
Sequence
):
for
field
in
batch
:
if
isinstance
(
field
,
np
.
ndarray
):
structure
.
append
(
'{}{}'
.
format
(
FIELD_PREFIX
,
field_idx
))
flat_batch
.
append
(
field
)
field_idx
+=
1
elif
isinstance
(
field
,
paddle
.
Tensor
):
structure
.
append
(
'{}{}'
.
format
(
FIELD_PREFIX
,
field_idx
))
flat_batch
.
append
(
field
.
numpy
())
field_idx
+=
1
elif
isinstance
(
field
,
(
str
,
bytes
,
numbers
.
Number
)):
structure
.
append
(
field
)
elif
isinstance
(
field
,
Sequence
):
field_struct
,
field_idx
=
_flatten
(
field
,
flat_batch
,
[],
field_idx
)
structure
.
append
(
field_struct
)
elif
isinstance
(
field
,
Mapping
):
field_struct
,
field_idx
=
_flatten
(
field
,
flat_batch
,
{},
field_idx
)
structure
.
append
(
field_struct
)
else
:
structure
.
append
(
field
)
elif
isinstance
(
batch
,
Mapping
):
for
k
,
field
in
batch
.
items
():
if
isinstance
(
field
,
np
.
ndarray
):
structure
[
k
]
=
'{}{}'
.
format
(
FIELD_PREFIX
,
field_idx
)
flat_batch
.
append
(
field
)
field_idx
+=
1
elif
isinstance
(
field
,
paddle
.
Tensor
):
structure
[
k
]
=
'{}{}'
.
format
(
FIELD_PREFIX
,
field_idx
)
flat_batch
.
append
(
field
.
numpy
())
field_idx
+=
1
elif
isinstance
(
field
,
(
str
,
bytes
,
numbers
.
Number
)):
structure
[
k
]
=
field
elif
isinstance
(
field
,
Sequence
):
field_struct
,
field_idx
=
_flatten
(
field
,
flat_batch
,
[],
field_idx
)
structure
[
k
]
=
field_struct
elif
isinstance
(
field
,
Mapping
):
field_struct
,
field_idx
=
_flatten
(
field
,
flat_batch
,
{},
field_idx
)
structure
[
k
]
=
field_struct
else
:
structure
[
k
]
=
field
else
:
raise
TypeError
(
"wrong flat data type: {}"
.
format
(
type
(
batch
)))
return
structure
,
field_idx
# sample only contains single fields
if
not
isinstance
(
batch
,
Sequence
):
flat_batch
=
[]
structure
,
_
=
_flatten
([
batch
],
flat_batch
,
[],
0
)
return
flat_batch
,
structure
[
0
]
flat_batch
=
[]
structure
,
_
=
_flatten
(
batch
,
flat_batch
,
[],
0
)
return
flat_batch
,
structure
def
_restore_batch
(
flat_batch
,
structure
):
"""
After reading list of Tensor data from lod_blocking_queue outputs,
use this function to restore the batch data structrue, replace
:attr:`_paddle_field_x` with data from flat_batch
"""
def
_restore
(
structure
,
field_idx
):
if
isinstance
(
structure
,
Sequence
):
for
i
,
field
in
enumerate
(
structure
):
if
isinstance
(
field
,
str
)
and
field
.
startswith
(
FIELD_PREFIX
):
cur_field_idx
=
int
(
field
.
replace
(
FIELD_PREFIX
,
''
))
field_idx
=
max
(
field_idx
,
cur_field_idx
)
assert
flat_batch
[
cur_field_idx
]
is
not
None
,
\
"flat_batch[{}] parsed repeatly"
structure
[
i
]
=
flat_batch
[
cur_field_idx
]
flat_batch
[
cur_field_idx
]
=
None
elif
isinstance
(
field
,
(
str
,
bytes
,
numbers
.
Number
)):
continue
elif
isinstance
(
field
,
(
Sequence
,
Mapping
)):
field_idx
=
_restore
(
structure
[
i
],
field_idx
)
elif
isinstance
(
structure
,
Mapping
):
for
k
,
field
in
structure
.
items
():
if
isinstance
(
field
,
str
)
and
field
.
startswith
(
FIELD_PREFIX
):
cur_field_idx
=
int
(
field
.
replace
(
FIELD_PREFIX
,
''
))
field_idx
=
max
(
field_idx
,
cur_field_idx
)
assert
flat_batch
[
cur_field_idx
]
is
not
None
,
\
"flat_batch[{}] parsed repeatly"
structure
[
k
]
=
flat_batch
[
cur_field_idx
]
flat_batch
[
cur_field_idx
]
=
None
elif
isinstance
(
field
,
(
str
,
bytes
,
numbers
.
Number
)):
continue
elif
isinstance
(
field
,
(
Sequence
,
Mapping
)):
field_idx
=
_restore
(
structure
[
k
],
field_idx
)
else
:
raise
TypeError
(
"wrong flat data type: {}"
.
format
(
type
(
batch
)))
return
field_idx
assert
isinstance
(
flat_batch
,
Sequence
),
\
"flat_batch is not a list or tuple"
# no np.array in dataset, no output tensor from blocking queue
# simply return structure
if
len
(
flat_batch
)
==
0
:
return
structure
# sample only contains single fields
if
isinstance
(
structure
,
(
str
,
bytes
)):
assert
structure
==
'{}{}'
.
format
(
FIELD_PREFIX
,
0
),
\
"invalid structure: {}"
.
format
(
structure
)
return
flat_batch
[
0
]
field_idx
=
_restore
(
structure
,
0
)
assert
field_idx
+
1
==
len
(
flat_batch
),
"Tensor parse incomplete"
return
structure
python/paddle/fluid/dataloader/worker.py
0 → 100644
浏览文件 @
a32e8bf1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
six
import
sys
import
paddle
import
numpy
as
np
import
traceback
from
collections
import
namedtuple
from
..
import
core
from
.fetcher
import
_IterableDatasetFetcher
,
_MapDatasetFetcher
from
..multiprocess_utils
import
_cleanup_mmap
,
CleanupFuncRegistrar
,
MP_STATUS_CHECK_INTERVAL
from
..framework
import
in_dygraph_mode
from
.flat
import
_flatten_batch
# NOTE: queue has a different name in python2 and python3
if
six
.
PY2
:
import
Queue
as
queue
else
:
import
queue
__all__
=
[
'get_worker_info'
]
class
_IterableDatasetStopIteration
(
object
):
def
__init__
(
self
,
worker_id
):
self
.
worker_id
=
worker_id
class
_DatasetKind
(
object
):
MAP
=
0
ITER
=
1
@
staticmethod
def
create_fetcher
(
kind
,
dataset
,
auto_collate_batch
,
collate_fn
,
drop_last
):
if
kind
==
_DatasetKind
.
MAP
:
return
_MapDatasetFetcher
(
dataset
,
auto_collate_batch
,
collate_fn
,
drop_last
)
elif
kind
==
_DatasetKind
.
ITER
:
return
_IterableDatasetFetcher
(
dataset
,
auto_collate_batch
,
collate_fn
,
drop_last
)
else
:
raise
NotImplementedError
(
"unknown Dataset kind {}"
.
format
(
kind
))
class
ParentWatchDog
(
object
):
def
__init__
(
self
):
self
.
_parent_pid
=
os
.
getppid
()
self
.
_parent_alive
=
True
def
is_alive
(
self
):
if
self
.
_parent_alive
:
self
.
_parent_alive
=
os
.
getppid
()
==
self
.
_parent_pid
return
self
.
_parent_alive
# worker information for each workers, used for splitting data copy
# for IteratorDataset in worker processes.
_worker_info
=
None
def
get_worker_info
():
"""
Get DataLoader worker process information function, this function is
used to split data copy in worker process for IterableDataset
(see :code:`paddle.io.IterableDataset`), worker information contains
following fields:
:attr:`num_workers`: total worker process number, see `paddle.io.DataLoader`
:attr:`id`: the worker processs id, count from 0 to :attr:`num_workers - 1`
:attr:`dataset`: the dataset object in this worker process
Returns:
WorkerInfo: an instance of WorkerInfo which contains fields above.
.. note::
For more usage and examples, please see :code:`paddle.io.IterableDataset`
Example:
.. code-block:: python
import math
import paddle
import numpy as np
from paddle.io import IterableDataset, DataLoader, get_worker_info
class SplitedIterableDataset(IterableDataset):
def __init__(self, start, end):
self.start = start
self.end = end
def __iter__(self):
worker_info = get_worker_info()
if worker_info is None:
iter_start = self.start
iter_end = self.end
else:
per_worker = int(
math.ceil((self.end - self.start) / float(
worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
for i in range(iter_start, iter_end):
yield np.array([i])
place = paddle.CPUPlace()
dataset = SplitedIterableDataset(start=2, end=9)
dataloader = DataLoader(
dataset,
places=place,
num_workers=2,
batch_size=1,
drop_last=True)
for data in dataloader:
print(data)
# outputs: [2, 5, 3, 6, 4, 7]
"""
return
_worker_info
class
WorkerInfo
(
object
):
__initialized
=
False
def
__init__
(
self
,
**
kwargs
):
for
k
,
v
in
kwargs
.
items
():
setattr
(
self
,
k
,
v
)
self
.
__initialized
=
True
def
__setattr__
(
self
,
key
,
val
):
if
self
.
__initialized
:
raise
RuntimeError
(
"Cannot assign attributes to {} objects"
.
format
(
self
.
__class__
.
__name__
))
return
super
(
WorkerInfo
,
self
).
__setattr__
(
key
,
val
)
class
_WorkerException
(
object
):
def
__init__
(
self
,
worker_id
,
exc_info
=
None
):
self
.
worker_id
=
worker_id
exc_info
=
exc_info
or
sys
.
exc_info
()
self
.
exc_type
=
exc_info
[
0
]
self
.
exc_msg
=
""
.
join
(
traceback
.
format_exception
(
*
exc_info
))
def
reraise
(
self
):
msg
=
"DataLoader worker({}) caught {} with message:
\n
{}"
.
format
(
self
.
worker_id
,
self
.
exc_type
.
__name__
,
self
.
exc_msg
)
if
getattr
(
self
.
exc_type
,
"message"
,
None
):
raise
self
.
exc_type
(
message
=
msg
)
raise
self
.
exc_type
(
msg
)
def
_worker_loop
(
dataset
,
dataset_kind
,
indices_queue
,
out_queue
,
done_event
,
auto_collate_batch
,
collate_fn
,
init_fn
,
worker_id
,
num_workers
,
use_shared_memory
):
try
:
# NOTE: [ mmap files clear ] When the child process exits unexpectedly,
# some shared memory objects may have been applied for but have not yet
# been put into the inter-process Queue. This part of the object needs
# to be cleaned up when the process ends.
CleanupFuncRegistrar
.
register
(
_cleanup_mmap
)
# set signal handler
core
.
_set_process_signal_handler
()
global
_worker_info
_worker_info
=
WorkerInfo
(
id
=
worker_id
,
num_workers
=
num_workers
,
dataset
=
dataset
)
init_exception
=
None
try
:
if
init_fn
is
not
None
:
init_fn
(
worker_id
)
fetcher
=
_DatasetKind
.
create_fetcher
(
dataset_kind
,
dataset
,
auto_collate_batch
,
collate_fn
,
True
)
except
:
init_exception
=
_WorkerException
(
worker_id
)
iterator_drained
=
False
parent_watch_dog
=
ParentWatchDog
()
while
parent_watch_dog
.
is_alive
():
try
:
data
=
indices_queue
.
get
(
MP_STATUS_CHECK_INTERVAL
)
except
queue
.
Empty
:
continue
# None as poison piil, so worker event should be set
if
data
is
None
:
assert
done_event
.
is_set
()
or
iterator_drained
,
\
"get None when worker done_event set"
break
# If worker done event is set but get still get data in
# indices_queue, remaining data should be get and skipped.
if
done_event
.
is_set
()
or
iterator_drained
:
continue
idx
,
indices
=
data
try
:
if
init_exception
is
not
None
:
batch
=
init_exception
init_exception
=
None
else
:
# NOTE: GPU tensor operation is not supported in sub-process
# but default device is GPU in paddle-gpu version, which
# may copy CPU tensor to GPU even if users want to use
# CPU tensor operation, so we add CPUPlace guard here
# to make sure tensor will be operated only on CPU
with
paddle
.
fluid
.
dygraph
.
guard
(
place
=
paddle
.
CPUPlace
()):
batch
=
fetcher
.
fetch
(
indices
)
except
Exception
as
e
:
if
isinstance
(
e
,
StopIteration
)
and
dataset_kind
==
_DatasetKind
.
ITER
:
out_queue
.
put
(
_IterableDatasetStopIteration
(
worker_id
))
iterator_drained
=
True
else
:
out_queue
.
put
((
idx
,
_WorkerException
(
worker_id
),
None
))
else
:
if
isinstance
(
batch
,
_WorkerException
):
out_queue
.
put
((
idx
,
batch
,
None
))
batch
,
structure
=
_flatten_batch
(
batch
)
if
use_shared_memory
:
tensor_list
=
core
.
_convert_to_tensor_list
(
batch
)
out_queue
.
put
((
idx
,
tensor_list
,
structure
))
core
.
_remove_tensor_list_mmap_fds
(
tensor_list
)
else
:
out_queue
.
put
((
idx
,
batch
,
structure
))
except
KeyboardInterrupt
:
# NOTE: Main process will raise KeyboardInterrupt anyways, ignore it in child process
pass
except
:
six
.
reraise
(
*
sys
.
exc_info
())
finally
:
if
use_shared_memory
:
_cleanup_mmap
()
python/paddle/fluid/multiprocess_utils.py
浏览文件 @
a32e8bf1
...
@@ -25,6 +25,10 @@ if six.PY2:
...
@@ -25,6 +25,10 @@ if six.PY2:
else
:
else
:
import
queue
import
queue
# multi-process worker check indices queue interval, avoid
# hanging in subprocess data loading
MP_STATUS_CHECK_INTERVAL
=
5.
# NOTE: [ mmap files clear ] If there is still data in the multiprocess queue when the main process finishes reading,
# NOTE: [ mmap files clear ] If there is still data in the multiprocess queue when the main process finishes reading,
# the data in the queue needs to be popped. Then the LoDTensor read by the main process
# the data in the queue needs to be popped. Then the LoDTensor read by the main process
# from the child process will automatically clear the memory-mapped file.
# from the child process will automatically clear the memory-mapped file.
...
...
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py
浏览文件 @
a32e8bf1
...
@@ -273,5 +273,62 @@ class TestNumpyMixTensorDataset(TestTensorDataset):
...
@@ -273,5 +273,62 @@ class TestNumpyMixTensorDataset(TestTensorDataset):
assert
isinstance
(
label
,
paddle
.
Tensor
)
assert
isinstance
(
label
,
paddle
.
Tensor
)
class
ComplextDataset
(
Dataset
):
def
__init__
(
self
,
sample_num
):
self
.
sample_num
=
sample_num
def
__len__
(
self
):
return
self
.
sample_num
def
__getitem__
(
self
,
idx
):
return
(
3.1
,
'abc'
,
paddle
.
to_tensor
(
np
.
random
.
random
([
IMAGE_SIZE
]).
astype
(
'float32'
),
place
=
paddle
.
CPUPlace
()),
[
1
,
np
.
random
.
random
([
2
]).
astype
(
'float32'
)],
{
'a'
:
2.0
,
'b'
:
np
.
random
.
random
([
2
]).
astype
(
'float32'
)
})
class
TestComplextDataset
(
unittest
.
TestCase
):
def
run_main
(
self
,
num_workers
):
paddle
.
static
.
default_startup_program
().
random_seed
=
1
paddle
.
static
.
default_main_program
().
random_seed
=
1
place
=
paddle
.
CPUPlace
()
with
fluid
.
dygraph
.
guard
(
place
):
dataset
=
ComplextDataset
(
16
)
assert
len
(
dataset
)
==
16
dataloader
=
DataLoader
(
dataset
,
places
=
place
,
num_workers
=
num_workers
,
batch_size
=
2
,
drop_last
=
True
)
for
i
,
data
in
enumerate
(
dataloader
()):
assert
len
(
data
)
==
5
# data[0]: collate 3.1
assert
data
[
0
].
shape
==
[
2
]
assert
isinstance
(
data
[
1
],
list
)
# data[1]: collate 'abc'
assert
len
(
data
[
1
])
==
2
assert
isinstance
(
data
[
1
][
0
],
str
)
assert
isinstance
(
data
[
1
][
1
],
str
)
# data[2]: collate tensor
assert
data
[
2
].
shape
==
[
2
,
IMAGE_SIZE
]
# data[3]: collate list
assert
isinstance
(
data
[
3
],
list
)
assert
data
[
3
][
0
].
shape
==
[
2
]
assert
data
[
3
][
1
].
shape
==
[
2
,
2
]
# data[4]: collate dict
assert
isinstance
(
data
[
4
],
dict
)
assert
data
[
4
][
'a'
].
shape
==
[
2
]
assert
data
[
4
][
'b'
].
shape
==
[
2
,
2
]
def
test_main
(
self
):
for
num_workers
in
[
0
,
2
]:
self
.
run_main
(
num_workers
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_split.py
浏览文件 @
a32e8bf1
...
@@ -58,7 +58,7 @@ class TestDynamicDataLoaderIterSplit(unittest.TestCase):
...
@@ -58,7 +58,7 @@ class TestDynamicDataLoaderIterSplit(unittest.TestCase):
rets
=
[]
rets
=
[]
for
d
in
dataloader
:
for
d
in
dataloader
:
rets
.
append
(
d
[
0
]
.
numpy
()[
0
][
0
])
rets
.
append
(
d
.
numpy
()[
0
][
0
])
assert
tuple
(
sorted
(
rets
))
==
tuple
(
range
(
0
,
10
))
assert
tuple
(
sorted
(
rets
))
==
tuple
(
range
(
0
,
10
))
...
@@ -102,7 +102,7 @@ class TestDynamicDataLoaderIterInitFuncSplit(unittest.TestCase):
...
@@ -102,7 +102,7 @@ class TestDynamicDataLoaderIterInitFuncSplit(unittest.TestCase):
rets
=
[]
rets
=
[]
for
d
in
dataloader
:
for
d
in
dataloader
:
rets
.
append
(
d
[
0
]
.
numpy
()[
0
][
0
])
rets
.
append
(
d
.
numpy
()[
0
][
0
])
assert
tuple
(
sorted
(
rets
))
==
tuple
(
range
(
0
,
10
))
assert
tuple
(
sorted
(
rets
))
==
tuple
(
range
(
0
,
10
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录