Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a32e8bf1
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a32e8bf1
编写于
3月 15, 2021
作者:
K
Kaipeng Deng
提交者:
GitHub
3月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
DataLoader supprot dict str (#31481)
* add dict/str/list supprot for DataLoader. test=develop
上级
30a627aa
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
646 addition
and
297 deletion
+646
-297
paddle/fluid/imperative/data_loader.cc
paddle/fluid/imperative/data_loader.cc
+19
-5
paddle/fluid/operators/reader/blocking_queue.h
paddle/fluid/operators/reader/blocking_queue.h
+10
-2
paddle/fluid/pybind/reader_py.cc
paddle/fluid/pybind/reader_py.cc
+6
-4
python/paddle/fluid/dataloader/collate.py
python/paddle/fluid/dataloader/collate.py
+87
-0
python/paddle/fluid/dataloader/dataloader_iter.py
python/paddle/fluid/dataloader/dataloader_iter.py
+58
-284
python/paddle/fluid/dataloader/flat.py
python/paddle/fluid/dataloader/flat.py
+150
-0
python/paddle/fluid/dataloader/worker.py
python/paddle/fluid/dataloader/worker.py
+253
-0
python/paddle/fluid/multiprocess_utils.py
python/paddle/fluid/multiprocess_utils.py
+4
-0
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py
...d/tests/unittests/test_multiprocess_dataloader_dataset.py
+57
-0
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_split.py
...ts/test_multiprocess_dataloader_iterable_dataset_split.py
+2
-2
未找到文件。
paddle/fluid/imperative/data_loader.cc
浏览文件 @
a32e8bf1
...
@@ -71,9 +71,12 @@ void EraseLoadProcessPIDs(int64_t key) {
...
@@ -71,9 +71,12 @@ void EraseLoadProcessPIDs(int64_t key) {
} \
} \
} while (0)
} while (0)
#define REGISTER_SIGNAL_HANDLER(SIGNAL, HANDLER_NAME) \
#define REGISTER_SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG) \
static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) { \
static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) { \
SIGNAL_HANDLE(SIGNAL); \
auto _w = \
write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) / sizeof(char)); \
(void)_w; \
SIGNAL_HANDLE(SIGNAL); \
}
}
#define REGISTER_SPEC_SIGNAL_HANDLER(SIGNAL, HANDLER_NAME) \
#define REGISTER_SPEC_SIGNAL_HANDLER(SIGNAL, HANDLER_NAME) \
...
@@ -84,8 +87,18 @@ void EraseLoadProcessPIDs(int64_t key) {
...
@@ -84,8 +87,18 @@ void EraseLoadProcessPIDs(int64_t key) {
SIGNAL_HANDLE(SIGNAL); \
SIGNAL_HANDLE(SIGNAL); \
}
}
REGISTER_SIGNAL_HANDLER
(
SIGSEGV
,
SIGSEGV_handler
);
REGISTER_SIGNAL_HANDLER
(
SIGSEGV
,
SIGSEGV_handler
,
REGISTER_SIGNAL_HANDLER
(
SIGBUS
,
SIGBUS_handler
);
"ERROR: Unexpected segmentation fault encountered in "
"DataLoader workers.
\n
"
);
REGISTER_SIGNAL_HANDLER
(
SIGBUS
,
SIGBUS_handler
,
"ERROR: Unexpected BUS error encountered in DataLoader worker. "
"This might be caused by insufficient shared memory (shm), "
"please check whether use_shared_memory is set and storage space "
"in /dev/shm is enough
\n
"
);
REGISTER_SIGNAL_HANDLER
(
SIGFPE
,
SIGFPE_handler
,
"ERROR: Unexpected floating-point exception "
"encountered in DataLoader worker.
\n
"
)
REGISTER_SPEC_SIGNAL_HANDLER
(
SIGTERM
,
SIGTERM_handler
);
REGISTER_SPEC_SIGNAL_HANDLER
(
SIGTERM
,
SIGTERM_handler
);
static
inline
void
setSignalHandler
(
int
signal
,
static
inline
void
setSignalHandler
(
int
signal
,
...
@@ -105,6 +118,7 @@ static inline void setSignalHandler(int signal,
...
@@ -105,6 +118,7 @@ static inline void setSignalHandler(int signal,
void
SetLoadProcessSignalHandler
()
{
void
SetLoadProcessSignalHandler
()
{
setSignalHandler
(
SIGSEGV
,
&
SIGSEGV_handler
,
nullptr
);
setSignalHandler
(
SIGSEGV
,
&
SIGSEGV_handler
,
nullptr
);
setSignalHandler
(
SIGBUS
,
&
SIGBUS_handler
,
nullptr
);
setSignalHandler
(
SIGBUS
,
&
SIGBUS_handler
,
nullptr
);
setSignalHandler
(
SIGFPE
,
&
SIGFPE_handler
,
nullptr
);
setSignalHandler
(
SIGTERM
,
&
SIGTERM_handler
,
nullptr
);
setSignalHandler
(
SIGTERM
,
&
SIGTERM_handler
,
nullptr
);
}
}
...
...
paddle/fluid/operators/reader/blocking_queue.h
浏览文件 @
a32e8bf1
...
@@ -45,7 +45,11 @@ class BlockingQueue {
...
@@ -45,7 +45,11 @@ class BlockingQueue {
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
send_cv_
.
wait
(
send_cv_
.
wait
(
lock
,
[
&
]
{
return
queue_
.
size
()
<
capacity_
||
closed_
||
killed_
;
});
lock
,
[
&
]
{
return
queue_
.
size
()
<
capacity_
||
closed_
||
killed_
;
});
EnforceNotKilled
();
if
(
killed_
)
{
VLOG
(
3
)
<<
"WARNING:: Sending an element to a killed reader::BlokcingQueue"
;
return
false
;
}
if
(
closed_
)
{
if
(
closed_
)
{
VLOG
(
5
)
VLOG
(
5
)
<<
"WARNING: Sending an element to a closed reader::BlokcingQueue."
;
<<
"WARNING: Sending an element to a closed reader::BlokcingQueue."
;
...
@@ -66,7 +70,11 @@ class BlockingQueue {
...
@@ -66,7 +70,11 @@ class BlockingQueue {
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
send_cv_
.
wait
(
send_cv_
.
wait
(
lock
,
[
&
]
{
return
queue_
.
size
()
<
capacity_
||
closed_
||
killed_
;
});
lock
,
[
&
]
{
return
queue_
.
size
()
<
capacity_
||
closed_
||
killed_
;
});
EnforceNotKilled
();
if
(
killed_
)
{
VLOG
(
3
)
<<
"WARNING:: Sending an element to a killed reader::BlokcingQueue"
;
return
false
;
}
if
(
closed_
)
{
if
(
closed_
)
{
VLOG
(
5
)
VLOG
(
5
)
<<
"WARNING: Sending an element to a closed reader::BlokcingQueue."
;
<<
"WARNING: Sending an element to a closed reader::BlokcingQueue."
;
...
...
paddle/fluid/pybind/reader_py.cc
浏览文件 @
a32e8bf1
...
@@ -223,6 +223,10 @@ class MultiDeviceFeedReader {
...
@@ -223,6 +223,10 @@ class MultiDeviceFeedReader {
ReadAsync
();
ReadAsync
();
}
}
void
Shutdown
()
{
for
(
auto
&
r
:
readers_
)
r
->
Shutdown
();
}
~
MultiDeviceFeedReader
()
{
~
MultiDeviceFeedReader
()
{
queue_
->
Close
();
queue_
->
Close
();
pool_
.
reset
();
pool_
.
reset
();
...
@@ -266,10 +270,6 @@ class MultiDeviceFeedReader {
...
@@ -266,10 +270,6 @@ class MultiDeviceFeedReader {
}
}
}
}
void
Shutdown
()
{
for
(
auto
&
r
:
readers_
)
r
->
Shutdown
();
}
void
Start
()
{
void
Start
()
{
for
(
auto
&
r
:
readers_
)
r
->
Start
();
for
(
auto
&
r
:
readers_
)
r
->
Start
();
}
}
...
@@ -362,6 +362,8 @@ void BindMultiDeviceReader(py::module *module, const char *reader_name) {
...
@@ -362,6 +362,8 @@ void BindMultiDeviceReader(py::module *module, const char *reader_name) {
},
},
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"reset"
,
&
ReaderType
::
Reset
,
.
def
(
"reset"
,
&
ReaderType
::
Reset
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"shutdown"
,
&
ReaderType
::
Shutdown
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
py
::
call_guard
<
py
::
gil_scoped_release
>
());
}
}
...
...
python/paddle/fluid/dataloader/collate.py
0 → 100644
浏览文件 @
a32e8bf1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
numbers
import
numpy
as
np
from
..framework
import
in_dygraph_mode
from
..
import
core
,
layers
try
:
from
collections.abc
import
Sequence
,
Mapping
except
:
from
collections
import
Sequence
,
Mapping
def
default_collate_fn
(
batch
):
"""
Default batch collating function for :code:`paddle.io.DataLoader`,
batch should be a list of samples, and each sample should be a list
of fields as follows:
[[filed1, filed2, ...], [filed1, filed2, ...], ...]
This default collate function zipped each filed together and stack
each filed as the batch field as follows:
[batch_filed1, batch_filed2, ...]
Args:
batch(list of list of numpy array|paddle.Tensor): the batch data, each fields
should be a numpy array, each sample should be a list of
fileds, and batch should be a list of sample.
Returns:
a list of numpy array|Paddle.Tensor: collated batch of input batch data,
fields data type as same as fields in each sample.
"""
sample
=
batch
[
0
]
if
isinstance
(
sample
,
np
.
ndarray
):
batch
=
np
.
stack
(
batch
,
axis
=
0
)
return
batch
elif
isinstance
(
sample
,
paddle
.
Tensor
):
return
layers
.
stack
(
batch
,
axis
=
0
)
elif
isinstance
(
sample
,
numbers
.
Number
):
batch
=
np
.
array
(
batch
)
return
batch
elif
isinstance
(
sample
,
(
str
,
bytes
)):
return
batch
elif
isinstance
(
sample
,
Mapping
):
return
{
key
:
default_collate_fn
([
d
[
key
]
for
d
in
batch
])
for
key
in
sample
}
elif
isinstance
(
sample
,
Sequence
):
sample_fields_num
=
len
(
sample
)
if
not
all
(
len
(
sample
)
==
sample_fields_num
for
sample
in
iter
(
batch
)):
raise
RuntimeError
(
"fileds number not same among samples in a batch"
)
return
[
default_collate_fn
(
fields
)
for
fields
in
zip
(
*
batch
)]
raise
TypeError
(
"batch data con only contains: tensor, numpy.ndarray, "
"dict, list, number, but got {}"
.
format
(
type
(
sample
)))
return
outputs
def
default_convert_fn
(
batch
):
if
isinstance
(
batch
,
(
paddle
.
Tensor
,
np
.
ndarray
)):
return
batch
elif
isinstance
(
batch
,
(
str
,
bytes
)):
return
batch
elif
isinstance
(
batch
,
Mapping
):
return
{
key
:
default_convert_fn
(
batch
[
key
])
for
key
in
batch
}
elif
isinstance
(
batch
,
Sequence
):
return
[
default_convert_fn
(
d
)
for
d
in
batch
]
else
:
return
batch
python/paddle/fluid/dataloader/dataloader_iter.py
浏览文件 @
a32e8bf1
...
@@ -35,181 +35,16 @@ else:
...
@@ -35,181 +35,16 @@ else:
import
paddle
import
paddle
from
..
import
core
,
layers
from
..
import
core
,
layers
from
..framework
import
in_dygraph_mode
from
..framework
import
in_dygraph_mode
from
..multiprocess_utils
import
CleanupFuncRegistrar
,
_cleanup_mmap
,
_set_SIGCHLD_handler
from
..multiprocess_utils
import
_set_SIGCHLD_handler
,
MP_STATUS_CHECK_INTERVAL
from
.fetcher
import
_IterableDatasetFetcher
,
_MapDatasetFetcher
from
.fetcher
import
_IterableDatasetFetcher
,
_MapDatasetFetcher
from
.batch_sampler
import
_InfiniteIterableSampler
from
.batch_sampler
import
_InfiniteIterableSampler
from
.collate
import
default_collate_fn
,
default_convert_fn
from
.worker
import
ParentWatchDog
,
get_worker_info
,
_worker_loop
,
\
_DatasetKind
,
_IterableDatasetStopIteration
,
_WorkerException
from
.flat
import
_flatten_batch
,
_restore_batch
__all__
=
[
'get_worker_info'
]
__all__
=
[
'get_worker_info'
]
# multi-process worker check indices queue interval, avoid
# hanging in subprocess data loading
MP_INDICES_CHECK_INTERVAL
=
5
_IterableDatasetStopIteration
=
namedtuple
(
'_IterableDatasetStopIteration'
,
[
'worker_id'
])
def
default_collate_fn
(
batch
):
"""
Default batch collating function for :code:`fluid.io.DataLoader`,
batch should be a list of samples, and each sample should be a list
of fields as follows:
[[filed1, filed2, ...], [filed1, filed2, ...], ...]
This default collate function zipped each filed together and stack
each filed as the batch field as follows:
[batch_filed1, batch_filed2, ...]
Args:
batch(list of list of numpy array): the batch data, each fields
should be a numpy array, each sample should be a list of
fileds, and batch should be a list of sample.
Returns:
a list of numpy array: collated batch
"""
sample
=
batch
[
0
]
# dataset has only 1 field
if
isinstance
(
sample
,
np
.
ndarray
):
return
[
np
.
stack
(
batch
,
axis
=
0
)]
# batch each field
slots
=
[]
for
items
in
batch
:
for
i
,
item
in
enumerate
(
items
):
if
len
(
slots
)
<
len
(
items
):
slots
.
append
([
item
])
else
:
slots
[
i
].
append
(
item
)
outputs
=
[]
for
slot
in
slots
:
if
isinstance
(
slot
[
0
],
(
np
.
ndarray
,
np
.
bool
,
numbers
.
Number
)):
tmp
=
np
.
stack
(
slot
,
axis
=
0
)
outputs
.
append
(
tmp
)
elif
isinstance
(
slot
[
0
],
paddle
.
Tensor
):
tmp
=
layers
.
stack
(
slot
,
axis
=
0
)
outputs
.
append
(
tmp
)
else
:
raise
RuntimeError
(
"Unknown data type {}"
.
format
(
type
(
slot
[
0
])))
return
outputs
class
_DatasetKind
(
object
):
MAP
=
0
ITER
=
1
@
staticmethod
def
create_fetcher
(
kind
,
dataset
,
auto_collate_batch
,
collate_fn
,
drop_last
):
if
kind
==
_DatasetKind
.
MAP
:
return
_MapDatasetFetcher
(
dataset
,
auto_collate_batch
,
collate_fn
,
drop_last
)
elif
kind
==
_DatasetKind
.
ITER
:
return
_IterableDatasetFetcher
(
dataset
,
auto_collate_batch
,
collate_fn
,
drop_last
)
else
:
raise
NotImplementedError
(
"unknown Dataset kind {}"
.
format
(
kind
))
class
ParentWatchDog
(
object
):
def
__init__
(
self
):
self
.
_parent_pid
=
os
.
getppid
()
self
.
_parent_alive
=
True
def
is_alive
(
self
):
if
self
.
_parent_alive
:
self
.
_parent_alive
=
os
.
getppid
()
==
self
.
_parent_pid
return
self
.
_parent_alive
# worker information for each workers, used for splitting data copy
# for IteratorDataset in worker processes.
_worker_info
=
None
def
get_worker_info
():
"""
Get DataLoader worker process information function, this function is
used to split data copy in worker process for IterableDataset
(see :code:`paddle.io.IterableDataset`), worker information contains
following fields:
:attr:`num_workers`: total worker process number, see `paddle.io.DataLoader`
:attr:`id`: the worker processs id, count from 0 to :attr:`num_workers - 1`
:attr:`dataset`: the dataset object in this worker process
Returns:
WorkerInfo: an instance of WorkerInfo which contains fields above.
.. note::
For mode usage and exampls, please see :code:`paddle.io.IterableDataset`
Example:
.. code-block:: python
import math
import paddle
import numpy as np
from paddle.io import IterableDataset, DataLoader, get_worker_info
class SplitedIterableDataset(IterableDataset):
def __init__(self, start, end):
self.start = start
self.end = end
def __iter__(self):
worker_info = get_worker_info()
if worker_info is None:
iter_start = self.start
iter_end = self.end
else:
per_worker = int(
math.ceil((self.end - self.start) / float(
worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
for i in range(iter_start, iter_end):
yield np.array([i])
place = paddle.CPUPlace()
dataset = SplitedIterableDataset(start=2, end=9)
dataloader = DataLoader(
dataset,
places=place,
num_workers=2,
batch_size=1,
drop_last=True)
for data in dataloader:
print(data)
# outputs: [2, 5, 3, 6, 4, 7]
"""
return
_worker_info
class
WorkerInfo
(
object
):
__initialized
=
False
def
__init__
(
self
,
**
kwargs
):
for
k
,
v
in
kwargs
.
items
():
setattr
(
self
,
k
,
v
)
self
.
__initialized
=
True
def
__setattr__
(
self
,
key
,
val
):
if
self
.
__initialized
:
raise
RuntimeError
(
"Cannot assign attributes to {} objects"
.
format
(
self
.
__class__
.
__name__
))
return
super
(
WorkerInfo
,
self
).
__setattr__
(
key
,
val
)
class
_DataLoaderIterBase
(
object
):
class
_DataLoaderIterBase
(
object
):
"""
"""
...
@@ -230,7 +65,7 @@ class _DataLoaderIterBase(object):
...
@@ -230,7 +65,7 @@ class _DataLoaderIterBase(object):
self
.
_num_workers
=
loader
.
num_workers
self
.
_num_workers
=
loader
.
num_workers
self
.
_use_buffer_reader
=
loader
.
use_buffer_reader
self
.
_use_buffer_reader
=
loader
.
use_buffer_reader
self
.
_use_shared_memory
=
loader
.
use_shared_memory
self
.
_use_shared_memory
=
loader
.
use_shared_memory
self
.
_timeout
=
loader
.
timeout
if
loader
.
timeout
>
0
else
MP_
INDICE
S_CHECK_INTERVAL
self
.
_timeout
=
loader
.
timeout
if
loader
.
timeout
>
0
else
MP_
STATU
S_CHECK_INTERVAL
self
.
_worker_init_fn
=
loader
.
worker_init_fn
self
.
_worker_init_fn
=
loader
.
worker_init_fn
self
.
_dataset_kind
=
loader
.
dataset_kind
self
.
_dataset_kind
=
loader
.
dataset_kind
self
.
_pin_memory
=
loader
.
pin_memory
self
.
_pin_memory
=
loader
.
pin_memory
...
@@ -244,7 +79,7 @@ class _DataLoaderIterBase(object):
...
@@ -244,7 +79,7 @@ class _DataLoaderIterBase(object):
else
:
else
:
self
.
_sampler_iter
=
iter
(
self
.
_sampler_iter
=
iter
(
_InfiniteIterableSampler
(
self
.
_dataset
,
1
))
_InfiniteIterableSampler
(
self
.
_dataset
,
1
))
self
.
_collate_fn
=
loader
.
collate_fn
self
.
_collate_fn
=
loader
.
collate_fn
or
default_convert_fn
# LoDTensorBlockingQueue instance for create_py_reader and a thread
# LoDTensorBlockingQueue instance for create_py_reader and a thread
# to put mini-batch data to self._blocking_queue, mini-batch data
# to put mini-batch data to self._blocking_queue, mini-batch data
...
@@ -275,6 +110,14 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
...
@@ -275,6 +110,14 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
self
.
_dataset_kind
,
self
.
_dataset
,
self
.
_auto_collate_batch
,
self
.
_dataset_kind
,
self
.
_dataset
,
self
.
_auto_collate_batch
,
self
.
_collate_fn
,
True
)
self
.
_collate_fn
,
True
)
# NOTE: _structrue_infos used to record the data structure of
# batch to restore batch structure after reading Tensor
# from blocking_queue in single-process mode. Note that
# only single process is used in single-process mode, we
# can record the data structure sequencely in a list without
# recording the send and recv index
self
.
_structure_infos
=
[]
# NOTE: len(self._places) batch data compose as an output
# NOTE: len(self._places) batch data compose as an output
# iteration, set blocking_queue can cache 2 iteration datas
# iteration, set blocking_queue can cache 2 iteration datas
# at most here
# at most here
...
@@ -316,16 +159,14 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
...
@@ -316,16 +159,14 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
# read data from dataset in mini-batch
# read data from dataset in mini-batch
batch
=
self
.
_dataset_fetcher
.
fetch
(
indices
)
batch
=
self
.
_dataset_fetcher
.
fetch
(
indices
)
# flat batch and record structure infos
batch
,
structure
=
_flatten_batch
(
batch
)
self
.
_structure_infos
.
append
(
structure
)
# pack as LoDTensorArray
# pack as LoDTensorArray
array
=
core
.
LoDTensorArray
()
array
=
core
.
LoDTensorArray
()
for
slot
in
batch
:
for
slot
in
batch
:
if
not
isinstance
(
slot
,
core
.
LoDTensor
):
if
not
isinstance
(
slot
,
core
.
LoDTensor
):
# FIXME(dkp): blocking_queue only support
# core.LoDTensorArray as input now, read
# numpy data into a LoDTensorArray here,
# should support paddle.Tensor list later
if
isinstance
(
slot
,
paddle
.
Tensor
):
slot
=
slot
.
numpy
()
tmp
=
core
.
LoDTensor
()
tmp
=
core
.
LoDTensor
()
tmp
.
set
(
slot
,
core
.
CPUPlace
())
tmp
.
set
(
slot
,
core
.
CPUPlace
())
slot
=
tmp
slot
=
tmp
...
@@ -348,20 +189,29 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
...
@@ -348,20 +189,29 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
def
__next__
(
self
):
def
__next__
(
self
):
try
:
try
:
if
in_dygraph_mode
():
if
in_dygraph_mode
():
return
self
.
_reader
.
read_next_var_list
()
data
=
self
.
_reader
.
read_next_var_list
()
data
=
_restore_batch
(
data
,
self
.
_structure_infos
.
pop
(
0
))
else
:
else
:
if
self
.
_return_list
:
if
self
.
_return_list
:
data
=
self
.
_reader
.
read_next_list
()
data
=
[
_restore_batch
(
d
,
s
)
for
d
,
s
in
zip
(
data
,
self
.
_structure_infos
[:
len
(
self
.
_places
)])
]
self
.
_structure_infos
=
self
.
_structure_infos
[
len
(
self
.
_places
):]
# static graph organized data on multi-device with list, if
# static graph organized data on multi-device with list, if
# place number is 1, there is only 1 device, extra the data
# place number is 1, there is only 1 device, extra the data
# from list for devices to be compatible with dygraph mode
# from list for devices to be compatible with dygraph mode
if
len
(
self
.
_places
)
==
1
:
if
len
(
self
.
_places
)
==
1
:
return
self
.
_reader
.
read_next_list
()[
0
]
data
=
data
[
0
]
else
:
return
self
.
_reader
.
read_next_list
()
else
:
else
:
return
self
.
_reader
.
read_next
()
data
=
self
.
_reader
.
read_next
()
return
data
except
StopIteration
:
except
StopIteration
:
self
.
_reader
.
reset
()
self
.
_reader
.
shutdown
()
six
.
reraise
(
*
sys
.
exc_info
())
six
.
reraise
(
*
sys
.
exc_info
())
# python2 compatibility
# python2 compatibility
...
@@ -375,97 +225,6 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
...
@@ -375,97 +225,6 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
self
.
_blocking_queue
.
close
()
self
.
_blocking_queue
.
close
()
# NOTE(chenweihang): _worker_loop must be top level method to be pickled
def
_worker_loop
(
dataset
,
dataset_kind
,
indices_queue
,
out_queue
,
done_event
,
auto_collate_batch
,
collate_fn
,
init_fn
,
worker_id
,
num_workers
,
use_shared_memory
):
try
:
# NOTE: [ mmap files clear ] When the child process exits unexpectedly,
# some shared memory objects may have been applied for but have not yet
# been put into the inter-process Queue. This part of the object needs
# to be cleaned up when the process ends.
CleanupFuncRegistrar
.
register
(
_cleanup_mmap
)
# set signal handler
core
.
_set_process_signal_handler
()
global
_worker_info
_worker_info
=
WorkerInfo
(
id
=
worker_id
,
num_workers
=
num_workers
,
dataset
=
dataset
)
init_exception
=
None
try
:
if
init_fn
is
not
None
:
init_fn
(
worker_id
)
fetcher
=
_DatasetKind
.
create_fetcher
(
dataset_kind
,
dataset
,
auto_collate_batch
,
collate_fn
,
True
)
except
:
init_exception
=
Exception
(
"init_fn failed in worker {}: "
\
"{}"
.
format
(
worker_id
,
sys
.
exc_info
()))
iterator_drained
=
False
parent_watch_dog
=
ParentWatchDog
()
while
parent_watch_dog
.
is_alive
():
try
:
data
=
indices_queue
.
get
(
MP_INDICES_CHECK_INTERVAL
)
except
queue
.
Empty
:
continue
# None as poison piil, so worker event should be set
if
data
is
None
:
assert
done_event
.
is_set
()
or
iterator_drained
,
\
"get None when worker done_event set"
break
# If worker done event is set but get still get data in
# indices_queue, remaining data should be get and skipped.
if
done_event
.
is_set
()
or
iterator_drained
:
continue
idx
,
indices
=
data
try
:
if
init_exception
is
not
None
:
batch
=
init_exception
init_exception
=
None
else
:
batch
=
fetcher
.
fetch
(
indices
)
except
Exception
as
e
:
if
isinstance
(
e
,
StopIteration
)
and
dataset_kind
==
_DatasetKind
.
ITER
:
out_queue
.
put
(
_IterableDatasetStopIteration
(
worker_id
))
iterator_drained
=
True
else
:
out_queue
.
put
((
idx
,
e
))
else
:
if
use_shared_memory
:
# FIXME(dkp): _convert_to_tensor_list only support np.array
# list now, should support paddle.Tensor list
new_batch
=
[]
for
sample
in
batch
:
new_sample
=
[]
for
s
in
sample
:
if
isinstance
(
s
,
paddle
.
Tensor
):
new_sample
.
append
(
s
.
numpy
())
else
:
new_sample
.
append
(
s
)
new_batch
.
append
(
new_sample
)
batch
=
new_batch
tensor_list
=
core
.
_convert_to_tensor_list
(
batch
)
out_queue
.
put
((
idx
,
tensor_list
))
core
.
_remove_tensor_list_mmap_fds
(
tensor_list
)
else
:
out_queue
.
put
((
idx
,
batch
))
except
KeyboardInterrupt
:
# NOTE: Main process will raise KeyboardInterrupt anyways, ignore it in child process
pass
except
:
six
.
reraise
(
*
sys
.
exc_info
())
finally
:
if
use_shared_memory
:
_cleanup_mmap
()
class
_DataLoaderIterMultiProcess
(
_DataLoaderIterBase
):
class
_DataLoaderIterMultiProcess
(
_DataLoaderIterBase
):
def
__init__
(
self
,
loader
):
def
__init__
(
self
,
loader
):
super
(
_DataLoaderIterMultiProcess
,
self
).
__init__
(
loader
)
super
(
_DataLoaderIterMultiProcess
,
self
).
__init__
(
loader
)
...
@@ -483,6 +242,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
...
@@ -483,6 +242,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self
.
_rcvd_idx
=
0
self
.
_rcvd_idx
=
0
self
.
_batches_outstanding
=
0
self
.
_batches_outstanding
=
0
self
.
_task_infos
=
{}
self
.
_task_infos
=
{}
self
.
_structure_infos
=
[]
# indices outstand as _outstanding_capacity at first, and
# indices outstand as _outstanding_capacity at first, and
# blocking_queue capacity is also _outstanding_capacity.
# blocking_queue capacity is also _outstanding_capacity.
...
@@ -617,8 +377,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
...
@@ -617,8 +377,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
if
not
self
.
_thread_done_event
.
is_set
():
if
not
self
.
_thread_done_event
.
is_set
():
if
batch
is
None
:
if
batch
is
None
:
self
.
_exit_thread_expectedly
()
self
.
_exit_thread_expectedly
()
elif
isinstance
(
batch
,
Exception
):
self
.
_exit_thread_unexpectedly
()
else
:
else
:
try
:
try
:
# pack as LoDTensorArray
# pack as LoDTensorArray
...
@@ -654,8 +412,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
...
@@ -654,8 +412,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# batch indices and increase _rcvd_idx
# batch indices and increase _rcvd_idx
if
self
.
_dataset_kind
==
_DatasetKind
.
ITER
:
if
self
.
_dataset_kind
==
_DatasetKind
.
ITER
:
while
self
.
_rcvd_idx
<
self
.
_send_idx
:
while
self
.
_rcvd_idx
<
self
.
_send_idx
:
sys
.
stdout
.
flush
()
info
=
self
.
_task_infos
[
self
.
_rcvd_idx
]
info
=
self
.
_task_infos
[
self
.
_rcvd_idx
]
if
len
(
info
)
==
2
or
self
.
_worker_status
[
info
[
0
]]:
if
len
(
info
)
==
3
or
self
.
_worker_status
[
info
[
0
]]:
break
break
del
self
.
_task_infos
[
self
.
_rcvd_idx
]
del
self
.
_task_infos
[
self
.
_rcvd_idx
]
self
.
_rcvd_idx
+=
1
self
.
_rcvd_idx
+=
1
...
@@ -669,13 +428,15 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
...
@@ -669,13 +428,15 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
continue
continue
if
self
.
_rcvd_idx
in
self
.
_task_infos
and
\
if
self
.
_rcvd_idx
in
self
.
_task_infos
and
\
len
(
self
.
_task_infos
[
self
.
_rcvd_idx
])
==
2
:
len
(
self
.
_task_infos
[
self
.
_rcvd_idx
])
==
3
:
return
self
.
_task_infos
.
pop
(
self
.
_rcvd_idx
)[
1
]
info
=
self
.
_task_infos
.
pop
(
self
.
_rcvd_idx
)
self
.
_structure_infos
.
append
(
info
[
2
])
return
info
[
1
]
try
:
try
:
# [ avoid hang ]: main process may blocking at _reader.read_next when
# [ avoid hang ]: main process may blocking at _reader.read_next when
# KeyboardInterrupt, we do following tradeoff:
# KeyboardInterrupt, we do following tradeoff:
# 1. get data with timeout, MP_
INDICE
S_CHECK_INTERVAL(5s) as timeout
# 1. get data with timeout, MP_
STATU
S_CHECK_INTERVAL(5s) as timeout
# default, if KeyboardInterrupt blocking, failed workers will be
# default, if KeyboardInterrupt blocking, failed workers will be
# checked and raise RuntimeError to quit DataLoader in timeout
# checked and raise RuntimeError to quit DataLoader in timeout
# exception handling.
# exception handling.
...
@@ -721,12 +482,17 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
...
@@ -721,12 +482,17 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self
.
_try_put_indices
()
self
.
_try_put_indices
()
continue
continue
idx
,
batch
=
data
idx
,
batch
,
structure
=
data
if
isinstance
(
batch
,
_WorkerException
):
self
.
_exit_thread_unexpectedly
()
batch
.
reraise
()
if
idx
==
self
.
_rcvd_idx
:
if
idx
==
self
.
_rcvd_idx
:
del
self
.
_task_infos
[
idx
]
del
self
.
_task_infos
[
idx
]
self
.
_structure_infos
.
append
(
structure
)
return
batch
return
batch
else
:
else
:
self
.
_task_infos
[
idx
]
+=
(
batch
,
)
self
.
_task_infos
[
idx
]
+=
(
batch
,
structure
)
continue
continue
def
_try_put_indices
(
self
):
def
_try_put_indices
(
self
):
...
@@ -777,9 +543,17 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
...
@@ -777,9 +543,17 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
if
in_dygraph_mode
():
if
in_dygraph_mode
():
data
=
self
.
_reader
.
read_next_var_list
()
data
=
self
.
_reader
.
read_next_var_list
()
data
=
_restore_batch
(
data
,
self
.
_structure_infos
.
pop
(
0
))
else
:
else
:
if
self
.
_return_list
:
if
self
.
_return_list
:
data
=
self
.
_reader
.
read_next_list
()
data
=
self
.
_reader
.
read_next_list
()
data
=
[
_restore_batch
(
d
,
s
)
for
d
,
s
in
zip
(
data
,
self
.
_structure_infos
[:
len
(
self
.
_places
)])
]
self
.
_structure_infos
=
self
.
_structure_infos
[
len
(
self
.
_places
):]
# static graph organized data on multi-device with list, if
# static graph organized data on multi-device with list, if
# place number is 1, there is only 1 device, extra the data
# place number is 1, there is only 1 device, extra the data
# from list for devices to be compatible with dygraph mode
# from list for devices to be compatible with dygraph mode
...
@@ -790,7 +564,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
...
@@ -790,7 +564,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self
.
_on_output_batch
()
self
.
_on_output_batch
()
return
data
return
data
except
StopIteration
:
except
StopIteration
:
self
.
_reader
.
reset
()
self
.
_reader
.
shutdown
()
self
.
_try_shutdown_all
()
self
.
_try_shutdown_all
()
six
.
reraise
(
*
sys
.
exc_info
())
six
.
reraise
(
*
sys
.
exc_info
())
...
...
python/paddle/fluid/dataloader/flat.py
0 → 100644
浏览文件 @
a32e8bf1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
numbers
import
numpy
as
np
try
:
from
collections.abc
import
Sequence
,
Mapping
except
:
from
collections
import
Sequence
,
Mapping
FIELD_PREFIX
=
"_paddle_field_"
def
_flatten_batch
(
batch
):
"""
For lod_blocking_queue only receive tensor array, flatten batch
data, extract numpy.array data out as a list of numpy.array to
send to lod_blocking_queue, and save the batch data structure
such as fields in other types (str, int, etc) or key-value map
of dictionaries
"""
def
_flatten
(
batch
,
flat_batch
,
structure
,
field_idx
):
if
isinstance
(
batch
,
Sequence
):
for
field
in
batch
:
if
isinstance
(
field
,
np
.
ndarray
):
structure
.
append
(
'{}{}'
.
format
(
FIELD_PREFIX
,
field_idx
))
flat_batch
.
append
(
field
)
field_idx
+=
1
elif
isinstance
(
field
,
paddle
.
Tensor
):
structure
.
append
(
'{}{}'
.
format
(
FIELD_PREFIX
,
field_idx
))
flat_batch
.
append
(
field
.
numpy
())
field_idx
+=
1
elif
isinstance
(
field
,
(
str
,
bytes
,
numbers
.
Number
)):
structure
.
append
(
field
)
elif
isinstance
(
field
,
Sequence
):
field_struct
,
field_idx
=
_flatten
(
field
,
flat_batch
,
[],
field_idx
)
structure
.
append
(
field_struct
)
elif
isinstance
(
field
,
Mapping
):
field_struct
,
field_idx
=
_flatten
(
field
,
flat_batch
,
{},
field_idx
)
structure
.
append
(
field_struct
)
else
:
structure
.
append
(
field
)
elif
isinstance
(
batch
,
Mapping
):
for
k
,
field
in
batch
.
items
():
if
isinstance
(
field
,
np
.
ndarray
):
structure
[
k
]
=
'{}{}'
.
format
(
FIELD_PREFIX
,
field_idx
)
flat_batch
.
append
(
field
)
field_idx
+=
1
elif
isinstance
(
field
,
paddle
.
Tensor
):
structure
[
k
]
=
'{}{}'
.
format
(
FIELD_PREFIX
,
field_idx
)
flat_batch
.
append
(
field
.
numpy
())
field_idx
+=
1
elif
isinstance
(
field
,
(
str
,
bytes
,
numbers
.
Number
)):
structure
[
k
]
=
field
elif
isinstance
(
field
,
Sequence
):
field_struct
,
field_idx
=
_flatten
(
field
,
flat_batch
,
[],
field_idx
)
structure
[
k
]
=
field_struct
elif
isinstance
(
field
,
Mapping
):
field_struct
,
field_idx
=
_flatten
(
field
,
flat_batch
,
{},
field_idx
)
structure
[
k
]
=
field_struct
else
:
structure
[
k
]
=
field
else
:
raise
TypeError
(
"wrong flat data type: {}"
.
format
(
type
(
batch
)))
return
structure
,
field_idx
# sample only contains single fields
if
not
isinstance
(
batch
,
Sequence
):
flat_batch
=
[]
structure
,
_
=
_flatten
([
batch
],
flat_batch
,
[],
0
)
return
flat_batch
,
structure
[
0
]
flat_batch
=
[]
structure
,
_
=
_flatten
(
batch
,
flat_batch
,
[],
0
)
return
flat_batch
,
structure
def
_restore_batch
(
flat_batch
,
structure
):
"""
After reading list of Tensor data from lod_blocking_queue outputs,
use this function to restore the batch data structrue, replace
:attr:`_paddle_field_x` with data from flat_batch
"""
def
_restore
(
structure
,
field_idx
):
if
isinstance
(
structure
,
Sequence
):
for
i
,
field
in
enumerate
(
structure
):
if
isinstance
(
field
,
str
)
and
field
.
startswith
(
FIELD_PREFIX
):
cur_field_idx
=
int
(
field
.
replace
(
FIELD_PREFIX
,
''
))
field_idx
=
max
(
field_idx
,
cur_field_idx
)
assert
flat_batch
[
cur_field_idx
]
is
not
None
,
\
"flat_batch[{}] parsed repeatly"
structure
[
i
]
=
flat_batch
[
cur_field_idx
]
flat_batch
[
cur_field_idx
]
=
None
elif
isinstance
(
field
,
(
str
,
bytes
,
numbers
.
Number
)):
continue
elif
isinstance
(
field
,
(
Sequence
,
Mapping
)):
field_idx
=
_restore
(
structure
[
i
],
field_idx
)
elif
isinstance
(
structure
,
Mapping
):
for
k
,
field
in
structure
.
items
():
if
isinstance
(
field
,
str
)
and
field
.
startswith
(
FIELD_PREFIX
):
cur_field_idx
=
int
(
field
.
replace
(
FIELD_PREFIX
,
''
))
field_idx
=
max
(
field_idx
,
cur_field_idx
)
assert
flat_batch
[
cur_field_idx
]
is
not
None
,
\
"flat_batch[{}] parsed repeatly"
structure
[
k
]
=
flat_batch
[
cur_field_idx
]
flat_batch
[
cur_field_idx
]
=
None
elif
isinstance
(
field
,
(
str
,
bytes
,
numbers
.
Number
)):
continue
elif
isinstance
(
field
,
(
Sequence
,
Mapping
)):
field_idx
=
_restore
(
structure
[
k
],
field_idx
)
else
:
raise
TypeError
(
"wrong flat data type: {}"
.
format
(
type
(
batch
)))
return
field_idx
assert
isinstance
(
flat_batch
,
Sequence
),
\
"flat_batch is not a list or tuple"
# no np.array in dataset, no output tensor from blocking queue
# simply return structure
if
len
(
flat_batch
)
==
0
:
return
structure
# sample only contains single fields
if
isinstance
(
structure
,
(
str
,
bytes
)):
assert
structure
==
'{}{}'
.
format
(
FIELD_PREFIX
,
0
),
\
"invalid structure: {}"
.
format
(
structure
)
return
flat_batch
[
0
]
field_idx
=
_restore
(
structure
,
0
)
assert
field_idx
+
1
==
len
(
flat_batch
),
"Tensor parse incomplete"
return
structure
python/paddle/fluid/dataloader/worker.py
0 → 100644
浏览文件 @
a32e8bf1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
six
import
sys
import
paddle
import
numpy
as
np
import
traceback
from
collections
import
namedtuple
from
..
import
core
from
.fetcher
import
_IterableDatasetFetcher
,
_MapDatasetFetcher
from
..multiprocess_utils
import
_cleanup_mmap
,
CleanupFuncRegistrar
,
MP_STATUS_CHECK_INTERVAL
from
..framework
import
in_dygraph_mode
from
.flat
import
_flatten_batch
# NOTE: queue has a different name in python2 and python3
if
six
.
PY2
:
import
Queue
as
queue
else
:
import
queue
__all__
=
[
'get_worker_info'
]
class
_IterableDatasetStopIteration
(
object
):
def
__init__
(
self
,
worker_id
):
self
.
worker_id
=
worker_id
class
_DatasetKind
(
object
):
MAP
=
0
ITER
=
1
@
staticmethod
def
create_fetcher
(
kind
,
dataset
,
auto_collate_batch
,
collate_fn
,
drop_last
):
if
kind
==
_DatasetKind
.
MAP
:
return
_MapDatasetFetcher
(
dataset
,
auto_collate_batch
,
collate_fn
,
drop_last
)
elif
kind
==
_DatasetKind
.
ITER
:
return
_IterableDatasetFetcher
(
dataset
,
auto_collate_batch
,
collate_fn
,
drop_last
)
else
:
raise
NotImplementedError
(
"unknown Dataset kind {}"
.
format
(
kind
))
class
ParentWatchDog
(
object
):
def
__init__
(
self
):
self
.
_parent_pid
=
os
.
getppid
()
self
.
_parent_alive
=
True
def
is_alive
(
self
):
if
self
.
_parent_alive
:
self
.
_parent_alive
=
os
.
getppid
()
==
self
.
_parent_pid
return
self
.
_parent_alive
# worker information for each workers, used for splitting data copy
# for IteratorDataset in worker processes.
_worker_info
=
None
def
get_worker_info
():
"""
Get DataLoader worker process information function, this function is
used to split data copy in worker process for IterableDataset
(see :code:`paddle.io.IterableDataset`), worker information contains
following fields:
:attr:`num_workers`: total worker process number, see `paddle.io.DataLoader`
:attr:`id`: the worker processs id, count from 0 to :attr:`num_workers - 1`
:attr:`dataset`: the dataset object in this worker process
Returns:
WorkerInfo: an instance of WorkerInfo which contains fields above.
.. note::
For more usage and examples, please see :code:`paddle.io.IterableDataset`
Example:
.. code-block:: python
import math
import paddle
import numpy as np
from paddle.io import IterableDataset, DataLoader, get_worker_info
class SplitedIterableDataset(IterableDataset):
def __init__(self, start, end):
self.start = start
self.end = end
def __iter__(self):
worker_info = get_worker_info()
if worker_info is None:
iter_start = self.start
iter_end = self.end
else:
per_worker = int(
math.ceil((self.end - self.start) / float(
worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
for i in range(iter_start, iter_end):
yield np.array([i])
place = paddle.CPUPlace()
dataset = SplitedIterableDataset(start=2, end=9)
dataloader = DataLoader(
dataset,
places=place,
num_workers=2,
batch_size=1,
drop_last=True)
for data in dataloader:
print(data)
# outputs: [2, 5, 3, 6, 4, 7]
"""
return
_worker_info
class
WorkerInfo
(
object
):
__initialized
=
False
def
__init__
(
self
,
**
kwargs
):
for
k
,
v
in
kwargs
.
items
():
setattr
(
self
,
k
,
v
)
self
.
__initialized
=
True
def
__setattr__
(
self
,
key
,
val
):
if
self
.
__initialized
:
raise
RuntimeError
(
"Cannot assign attributes to {} objects"
.
format
(
self
.
__class__
.
__name__
))
return
super
(
WorkerInfo
,
self
).
__setattr__
(
key
,
val
)
class
_WorkerException
(
object
):
def
__init__
(
self
,
worker_id
,
exc_info
=
None
):
self
.
worker_id
=
worker_id
exc_info
=
exc_info
or
sys
.
exc_info
()
self
.
exc_type
=
exc_info
[
0
]
self
.
exc_msg
=
""
.
join
(
traceback
.
format_exception
(
*
exc_info
))
def
reraise
(
self
):
msg
=
"DataLoader worker({}) caught {} with message:
\n
{}"
.
format
(
self
.
worker_id
,
self
.
exc_type
.
__name__
,
self
.
exc_msg
)
if
getattr
(
self
.
exc_type
,
"message"
,
None
):
raise
self
.
exc_type
(
message
=
msg
)
raise
self
.
exc_type
(
msg
)
def
_worker_loop
(
dataset
,
dataset_kind
,
indices_queue
,
out_queue
,
done_event
,
auto_collate_batch
,
collate_fn
,
init_fn
,
worker_id
,
num_workers
,
use_shared_memory
):
try
:
# NOTE: [ mmap files clear ] When the child process exits unexpectedly,
# some shared memory objects may have been applied for but have not yet
# been put into the inter-process Queue. This part of the object needs
# to be cleaned up when the process ends.
CleanupFuncRegistrar
.
register
(
_cleanup_mmap
)
# set signal handler
core
.
_set_process_signal_handler
()
global
_worker_info
_worker_info
=
WorkerInfo
(
id
=
worker_id
,
num_workers
=
num_workers
,
dataset
=
dataset
)
init_exception
=
None
try
:
if
init_fn
is
not
None
:
init_fn
(
worker_id
)
fetcher
=
_DatasetKind
.
create_fetcher
(
dataset_kind
,
dataset
,
auto_collate_batch
,
collate_fn
,
True
)
except
:
init_exception
=
_WorkerException
(
worker_id
)
iterator_drained
=
False
parent_watch_dog
=
ParentWatchDog
()
while
parent_watch_dog
.
is_alive
():
try
:
data
=
indices_queue
.
get
(
MP_STATUS_CHECK_INTERVAL
)
except
queue
.
Empty
:
continue
# None as poison piil, so worker event should be set
if
data
is
None
:
assert
done_event
.
is_set
()
or
iterator_drained
,
\
"get None when worker done_event set"
break
# If worker done event is set but get still get data in
# indices_queue, remaining data should be get and skipped.
if
done_event
.
is_set
()
or
iterator_drained
:
continue
idx
,
indices
=
data
try
:
if
init_exception
is
not
None
:
batch
=
init_exception
init_exception
=
None
else
:
# NOTE: GPU tensor operation is not supported in sub-process
# but default device is GPU in paddle-gpu version, which
# may copy CPU tensor to GPU even if users want to use
# CPU tensor operation, so we add CPUPlace guard here
# to make sure tensor will be operated only on CPU
with
paddle
.
fluid
.
dygraph
.
guard
(
place
=
paddle
.
CPUPlace
()):
batch
=
fetcher
.
fetch
(
indices
)
except
Exception
as
e
:
if
isinstance
(
e
,
StopIteration
)
and
dataset_kind
==
_DatasetKind
.
ITER
:
out_queue
.
put
(
_IterableDatasetStopIteration
(
worker_id
))
iterator_drained
=
True
else
:
out_queue
.
put
((
idx
,
_WorkerException
(
worker_id
),
None
))
else
:
if
isinstance
(
batch
,
_WorkerException
):
out_queue
.
put
((
idx
,
batch
,
None
))
batch
,
structure
=
_flatten_batch
(
batch
)
if
use_shared_memory
:
tensor_list
=
core
.
_convert_to_tensor_list
(
batch
)
out_queue
.
put
((
idx
,
tensor_list
,
structure
))
core
.
_remove_tensor_list_mmap_fds
(
tensor_list
)
else
:
out_queue
.
put
((
idx
,
batch
,
structure
))
except
KeyboardInterrupt
:
# NOTE: Main process will raise KeyboardInterrupt anyways, ignore it in child process
pass
except
:
six
.
reraise
(
*
sys
.
exc_info
())
finally
:
if
use_shared_memory
:
_cleanup_mmap
()
python/paddle/fluid/multiprocess_utils.py
浏览文件 @
a32e8bf1
...
@@ -25,6 +25,10 @@ if six.PY2:
...
@@ -25,6 +25,10 @@ if six.PY2:
else
:
else
:
import
queue
import
queue
# multi-process worker check indices queue interval, avoid
# hanging in subprocess data loading
MP_STATUS_CHECK_INTERVAL
=
5.
# NOTE: [ mmap files clear ] If there is still data in the multiprocess queue when the main process finishes reading,
# NOTE: [ mmap files clear ] If there is still data in the multiprocess queue when the main process finishes reading,
# the data in the queue needs to be popped. Then the LoDTensor read by the main process
# the data in the queue needs to be popped. Then the LoDTensor read by the main process
# from the child process will automatically clear the memory-mapped file.
# from the child process will automatically clear the memory-mapped file.
...
...
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py
浏览文件 @
a32e8bf1
...
@@ -273,5 +273,62 @@ class TestNumpyMixTensorDataset(TestTensorDataset):
...
@@ -273,5 +273,62 @@ class TestNumpyMixTensorDataset(TestTensorDataset):
assert
isinstance
(
label
,
paddle
.
Tensor
)
assert
isinstance
(
label
,
paddle
.
Tensor
)
class
ComplextDataset
(
Dataset
):
def
__init__
(
self
,
sample_num
):
self
.
sample_num
=
sample_num
def
__len__
(
self
):
return
self
.
sample_num
def
__getitem__
(
self
,
idx
):
return
(
3.1
,
'abc'
,
paddle
.
to_tensor
(
np
.
random
.
random
([
IMAGE_SIZE
]).
astype
(
'float32'
),
place
=
paddle
.
CPUPlace
()),
[
1
,
np
.
random
.
random
([
2
]).
astype
(
'float32'
)],
{
'a'
:
2.0
,
'b'
:
np
.
random
.
random
([
2
]).
astype
(
'float32'
)
})
class
TestComplextDataset
(
unittest
.
TestCase
):
def
run_main
(
self
,
num_workers
):
paddle
.
static
.
default_startup_program
().
random_seed
=
1
paddle
.
static
.
default_main_program
().
random_seed
=
1
place
=
paddle
.
CPUPlace
()
with
fluid
.
dygraph
.
guard
(
place
):
dataset
=
ComplextDataset
(
16
)
assert
len
(
dataset
)
==
16
dataloader
=
DataLoader
(
dataset
,
places
=
place
,
num_workers
=
num_workers
,
batch_size
=
2
,
drop_last
=
True
)
for
i
,
data
in
enumerate
(
dataloader
()):
assert
len
(
data
)
==
5
# data[0]: collate 3.1
assert
data
[
0
].
shape
==
[
2
]
assert
isinstance
(
data
[
1
],
list
)
# data[1]: collate 'abc'
assert
len
(
data
[
1
])
==
2
assert
isinstance
(
data
[
1
][
0
],
str
)
assert
isinstance
(
data
[
1
][
1
],
str
)
# data[2]: collate tensor
assert
data
[
2
].
shape
==
[
2
,
IMAGE_SIZE
]
# data[3]: collate list
assert
isinstance
(
data
[
3
],
list
)
assert
data
[
3
][
0
].
shape
==
[
2
]
assert
data
[
3
][
1
].
shape
==
[
2
,
2
]
# data[4]: collate dict
assert
isinstance
(
data
[
4
],
dict
)
assert
data
[
4
][
'a'
].
shape
==
[
2
]
assert
data
[
4
][
'b'
].
shape
==
[
2
,
2
]
def
test_main
(
self
):
for
num_workers
in
[
0
,
2
]:
self
.
run_main
(
num_workers
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_split.py
浏览文件 @
a32e8bf1
...
@@ -58,7 +58,7 @@ class TestDynamicDataLoaderIterSplit(unittest.TestCase):
...
@@ -58,7 +58,7 @@ class TestDynamicDataLoaderIterSplit(unittest.TestCase):
rets
=
[]
rets
=
[]
for
d
in
dataloader
:
for
d
in
dataloader
:
rets
.
append
(
d
[
0
]
.
numpy
()[
0
][
0
])
rets
.
append
(
d
.
numpy
()[
0
][
0
])
assert
tuple
(
sorted
(
rets
))
==
tuple
(
range
(
0
,
10
))
assert
tuple
(
sorted
(
rets
))
==
tuple
(
range
(
0
,
10
))
...
@@ -102,7 +102,7 @@ class TestDynamicDataLoaderIterInitFuncSplit(unittest.TestCase):
...
@@ -102,7 +102,7 @@ class TestDynamicDataLoaderIterInitFuncSplit(unittest.TestCase):
rets
=
[]
rets
=
[]
for
d
in
dataloader
:
for
d
in
dataloader
:
rets
.
append
(
d
[
0
]
.
numpy
()[
0
][
0
])
rets
.
append
(
d
.
numpy
()[
0
][
0
])
assert
tuple
(
sorted
(
rets
))
==
tuple
(
range
(
0
,
10
))
assert
tuple
(
sorted
(
rets
))
==
tuple
(
range
(
0
,
10
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录