Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
dbc88bb9
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
dbc88bb9
编写于
8月 12, 2020
作者:
K
Kaipeng Deng
提交者:
GitHub
8月 12, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add iterable dataset support for multiprocess DataLoader (#25558)
* add IterableDataset support in multiprocess DataLoader. test=develop
上级
54003b87
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
932 addition
and
58 deletion
+932
-58
python/paddle/fluid/dataloader/__init__.py
python/paddle/fluid/dataloader/__init__.py
+5
-1
python/paddle/fluid/dataloader/batch_sampler.py
python/paddle/fluid/dataloader/batch_sampler.py
+31
-6
python/paddle/fluid/dataloader/dataloader_iter.py
python/paddle/fluid/dataloader/dataloader_iter.py
+196
-33
python/paddle/fluid/dataloader/dataset.py
python/paddle/fluid/dataloader/dataset.py
+153
-2
python/paddle/fluid/dataloader/fetcher.py
python/paddle/fluid/dataloader/fetcher.py
+53
-0
python/paddle/fluid/reader.py
python/paddle/fluid/reader.py
+31
-9
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+3
-0
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py
...tests/unittests/test_multiprocess_dataloader_exception.py
+51
-6
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py
.../test_multiprocess_dataloader_iterable_dataset_dynamic.py
+124
-0
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_split.py
...ts/test_multiprocess_dataloader_iterable_dataset_split.py
+111
-0
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py
...s/test_multiprocess_dataloader_iterable_dataset_static.py
+171
-0
python/paddle/io/__init__.py
python/paddle/io/__init__.py
+3
-1
未找到文件。
python/paddle/fluid/dataloader/__init__.py
浏览文件 @
dbc88bb9
...
...
@@ -20,5 +20,9 @@ from .dataset import *
from
.
import
batch_sampler
from
.batch_sampler
import
*
from
.
import
dataloader_iter
from
.dataloader_iter
import
*
__all__
=
dataset
.
__all__
\
+
batch_sampler
.
__all__
+
batch_sampler
.
__all__
\
+
dataloader_iter
.
__all__
python/paddle/fluid/dataloader/batch_sampler.py
浏览文件 @
dbc88bb9
...
...
@@ -16,7 +16,7 @@ from __future__ import print_function
from
__future__
import
division
import
numpy
as
np
from
.dataset
import
Dataset
from
.dataset
import
Dataset
,
IterableDataset
__all__
=
[
"BatchSampler"
]
...
...
@@ -106,12 +106,18 @@ class BatchSampler(object):
assert
isinstance
(
indices
,
list
)
or
isinstance
(
indices
,
tuple
),
\
"indices should be a list or tuple, but got {}"
.
format
(
type
(
indices
))
self
.
indices
=
indices
self
.
sampler_iter
=
None
else
:
assert
isinstance
(
dataset
,
Dataset
),
\
"dataset should be an instance of paddle.io.Dataset"
assert
indices
is
None
,
\
"should not set both dataset and indices"
self
.
indices
=
list
(
range
(
len
(
dataset
)))
if
isinstance
(
dataset
,
IterableDataset
):
self
.
sampler_iter
=
iter
(
_InfiniteIterableSampler
(
dataset
,
batch_size
))
else
:
self
.
sampler_iter
=
None
assert
isinstance
(
dataset
,
Dataset
),
\
"dataset should be an instance of paddle.io.Dataset"
assert
indices
is
None
,
\
"should not set both dataset and indices"
self
.
indices
=
list
(
range
(
len
(
dataset
)))
assert
isinstance
(
batch_size
,
int
)
and
batch_size
>
0
,
\
"batch_size should be a positive integer, but got {}"
.
format
(
batch_size
)
...
...
@@ -124,6 +130,9 @@ class BatchSampler(object):
self
.
drop_last
=
drop_last
def
__iter__
(
self
):
if
self
.
sampler_iter
:
yield
next
(
self
.
sampler_iter
)
if
self
.
shuffle
:
np
.
random
.
shuffle
(
self
.
indices
)
_iter
=
iter
(
self
.
indices
)
...
...
@@ -138,6 +147,22 @@ class BatchSampler(object):
yield
batch_indices
def
__len__
(
self
):
if
self
.
sampler_iter
:
raise
RuntimeError
(
"'{}' should not be called for IterableDataset"
.
format
(
'__len__'
))
num_samples
=
len
(
self
.
indices
)
num_samples
+=
int
(
not
self
.
drop_last
)
*
(
self
.
batch_size
-
1
)
return
num_samples
//
self
.
batch_size
class
_InfiniteIterableSampler
(
object
):
def
__init__
(
self
,
dataset
,
batch_size
=
1
):
assert
isinstance
(
dataset
,
IterableDataset
),
"dataset should be an instance of paddle.io.IterableDataset"
self
.
dataset
=
dataset
self
.
batch_size
=
batch_size
def
__iter__
(
self
):
while
True
:
yield
[
None
]
*
self
.
batch_size
python/paddle/fluid/dataloader/dataloader_iter.py
浏览文件 @
dbc88bb9
...
...
@@ -22,6 +22,7 @@ import itertools
import
threading
import
numpy
as
np
import
multiprocessing
from
collections
import
namedtuple
# NOTE: queue has a different name in python2 and python3
if
six
.
PY2
:
...
...
@@ -32,11 +33,17 @@ else:
from
..
import
core
from
..framework
import
in_dygraph_mode
from
..multiprocess_utils
import
CleanupFuncRegistrar
,
_cleanup_mmap
,
_set_SIGCHLD_handler
from
.fetcher
import
_IterableDatasetFetcher
,
_MapDatasetFetcher
__all__
=
[
'get_worker_info'
]
# multi-process worker check indices queue interval, avoid
# hanging in subprocess data loading
MP_INDICES_CHECK_INTERVAL
=
5
_IterableDatasetStopIteration
=
namedtuple
(
'_IterableDatasetStopIteration'
,
[
'worker_id'
])
def
default_collate_fn
(
batch
):
"""
...
...
@@ -75,6 +82,20 @@ def default_collate_fn(batch):
return
[
np
.
stack
(
slot
,
axis
=
0
)
for
slot
in
slots
]
class
_DatasetKind
(
object
):
MAP
=
0
ITER
=
1
@
staticmethod
def
create_fetcher
(
kind
,
dataset
,
collate_fn
,
drop_last
):
if
kind
==
_DatasetKind
.
MAP
:
return
_MapDatasetFetcher
(
dataset
,
collate_fn
,
drop_last
)
elif
kind
==
_DatasetKind
.
ITER
:
return
_IterableDatasetFetcher
(
dataset
,
collate_fn
,
drop_last
)
else
:
raise
NotImplementedError
(
"unknown Dataset kind {}"
.
format
(
kind
))
class
ParentWatchDog
(
object
):
def
__init__
(
self
):
self
.
_parent_pid
=
os
.
getppid
()
...
...
@@ -86,6 +107,92 @@ class ParentWatchDog(object):
return
self
.
_parent_alive
# worker information for each workers, used for splitting data copy
# for IteratorDataset in worker processes.
_worker_info
=
None
def
get_worker_info
():
"""
Get DataLoader worker process information function, this function is
used to split data copy in worker process for IterableDataset
(see :code:`paddle.io.IterableDataset`), worker information contains
following fields:
:attr:`num_workers`: total worker process number, see `paddle.io.DataLoader`
:attr:`id`: the worker processs id, count from 0 to :attr:`num_workers - 1`
:attr:`dataset`: the dataset object in this worker process
Returns:
WorkerInfo: an instance of WorkerInfo which contains fields above.
.. note::
For mode usage and exampls, please see :code:`paddle.io.IterableDataset`
Example:
.. code-block:: python
import math
import numpy as np
import paddle.fluid as fluid
from paddle.io import IterableDataset, DataLoader, get_worker_info
class SplitedIterableDataset(IterableDataset):
def __init__(self, start, end):
self.start = start
self.end = end
def __iter__(self):
worker_info = get_worker_info()
if worker_info is None:
iter_start = self.start
iter_end = self.end
else:
per_worker = int(
math.ceil((self.end - self.start) / float(
worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
for i in range(iter_start, iter_end):
yield np.array([i])
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
dataset = SplitedIterableDataset(start=2, end=9)
dataloader = DataLoader(
dataset,
places=place,
num_workers=2,
batch_size=1,
drop_last=True)
print(list(dataloader))
# outputs: [2, 5, 3, 6, 4, 7]
"""
return
_worker_info
class
WorkerInfo
(
object
):
__initialized
=
False
def
__init__
(
self
,
**
kwargs
):
for
k
,
v
in
kwargs
.
items
():
setattr
(
self
,
k
,
v
)
self
.
__initialized
=
True
def
__setattr__
(
self
,
key
,
val
):
if
self
.
__initialized
:
raise
RuntimeError
(
"Cannot assign attributes to {} objects"
.
format
(
self
.
__class__
.
__name__
))
return
super
(
WorkerInfo
,
self
).
__setattr__
(
key
,
val
)
class
_DataLoaderIterBase
(
object
):
"""
Iterator implement of DataLoader, will load and feed mini-batch
...
...
@@ -108,6 +215,7 @@ class _DataLoaderIterBase(object):
self
.
_use_shared_memory
=
loader
.
use_shared_memory
self
.
_timeout
=
loader
.
timeout
if
loader
.
timeout
>
0
else
MP_INDICES_CHECK_INTERVAL
self
.
_worker_init_fn
=
loader
.
worker_init_fn
self
.
_dataset_kind
=
loader
.
dataset_kind
self
.
_pin_memory
=
loader
.
pin_memory
# LoDTensorBlockingQueue instance for create_py_reader and a thread
...
...
@@ -135,6 +243,9 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
def
__init__
(
self
,
loader
):
super
(
_DataLoaderIterSingleProcess
,
self
).
__init__
(
loader
)
self
.
_dataset_fetcher
=
_DatasetKind
.
create_fetcher
(
self
.
_dataset_kind
,
self
.
_dataset
,
self
.
_collate_fn
,
True
)
# NOTE: len(self._places) batch data compose as an output
# iteration, set blocking_queue can cache 2 iteration datas
# at most here
...
...
@@ -166,9 +277,7 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
try
:
for
indices
in
self
.
_sampler_iter
:
# read data from dataset in mini-batch
batch
=
[
self
.
_dataset
[
i
]
for
i
in
indices
]
if
self
.
_collate_fn
is
not
None
:
batch
=
self
.
_collate_fn
(
batch
)
batch
=
self
.
_dataset_fetcher
.
fetch
(
indices
)
# pack as LoDTensorArray
array
=
core
.
LoDTensorArray
()
...
...
@@ -186,6 +295,8 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
self
.
_blocking_queue
.
close
()
self
.
_thread
=
None
except
StopIteration
:
self
.
_blocking_queue
.
close
()
except
Exception
:
self
.
_blocking_queue
.
kill
()
self
.
_thread
=
None
...
...
@@ -233,11 +344,11 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# data get from _data_queue will be reordered by _rcvd_idx
# for data order keeping, data index not equal _rcvd_idx
# will be cached in _
reorder_dict
# will be cached in _
task_infos
self
.
_send_idx
=
0
self
.
_rcvd_idx
=
0
self
.
_batches_outstanding
=
0
self
.
_
reorder_dict
=
{}
self
.
_
task_infos
=
{}
# indices outstand as _outstanding_capacity at first, and
# blocking_queue capacity is also _outstanding_capacity.
...
...
@@ -248,14 +359,14 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self
.
_outstanding_capacity
=
2
*
max
(
self
.
_num_workers
,
len
(
self
.
_places
))
# init workers and indices queues and put 2 indices in each indices queue
self
.
_init_workers
()
self
.
_init_thread
()
self
.
_shutdown
=
False
for
_
in
range
(
self
.
_outstanding_capacity
):
self
.
_try_put_indices
()
self
.
_init_thread
()
self
.
_shutdown
=
False
def
_init_workers
(
self
):
# multiprocess worker and indice queue list initial as empty
self
.
_workers
=
[]
...
...
@@ -276,9 +387,10 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self
.
_indices_queues
.
append
(
indices_queue
)
worker
=
multiprocessing
.
Process
(
target
=
self
.
_worker_loop
,
args
=
(
self
.
_dataset
,
indices_queue
,
self
.
_data_queue
,
self
.
_workers_done_event
,
self
.
_collate_fn
,
self
.
_worker_init_fn
,
i
))
args
=
(
self
.
_dataset
,
self
.
_dataset_kind
,
indices_queue
,
self
.
_data_queue
,
self
.
_workers_done_event
,
self
.
_collate_fn
,
self
.
_worker_init_fn
,
i
,
self
.
_num_workers
))
worker
.
daemon
=
True
worker
.
start
()
self
.
_workers
.
append
(
worker
)
...
...
@@ -353,8 +465,8 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self
.
_blocking_queue
.
kill
()
logging
.
error
(
"DataLoader reader thread raised an exception!"
)
def
_worker_loop
(
self
,
dataset
,
indices_queue
,
out_queue
,
done_event
,
collate_fn
,
init_fn
,
worker_id
):
def
_worker_loop
(
self
,
dataset
,
dataset_kind
,
indices_queue
,
out_queue
,
done_event
,
collate_fn
,
init_fn
,
worker_id
,
num_workers
):
try
:
# NOTE: [ mmap files clear ] When the child process exits unexpectedly,
# some shared memory objects may have been applied for but have not yet
...
...
@@ -365,14 +477,21 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# set signal handler
core
.
_set_process_signal_handler
()
global
_worker_info
_worker_info
=
WorkerInfo
(
id
=
worker_id
,
num_workers
=
num_workers
,
dataset
=
dataset
)
init_exception
=
None
if
init_fn
is
not
None
:
try
:
try
:
if
init_fn
is
not
None
:
init_fn
(
worker_id
)
except
:
init_exception
=
Exception
(
"init_fn failed in worker {}: "
\
"{}"
.
format
(
worker_id
,
sys
.
exc_info
()))
fetcher
=
_DatasetKind
.
create_fetcher
(
dataset_kind
,
dataset
,
collate_fn
,
True
)
except
:
init_exception
=
Exception
(
"init_fn failed in worker {}: "
\
"{}"
.
format
(
worker_id
,
sys
.
exc_info
()))
iterator_drained
=
False
parent_watch_dog
=
ParentWatchDog
()
while
parent_watch_dog
.
is_alive
():
...
...
@@ -383,12 +502,12 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# None as poison piil, so worker event should be set
if
data
is
None
:
assert
done_event
.
is_set
(
),
"get None when worker done_event set"
assert
done_event
.
is_set
(
)
or
iterator_drained
,
\
"get None when worker done_event set"
break
# If worker done event is set but get still get data in
# indices_queue, remaining data should be get and skipped.
if
done_event
.
is_set
():
if
done_event
.
is_set
()
or
iterator_drained
:
continue
idx
,
indices
=
data
...
...
@@ -397,11 +516,15 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
batch
=
init_exception
init_exception
=
None
else
:
batch
=
[
dataset
[
i
]
for
i
in
indices
]
if
self
.
_collate_fn
is
not
None
:
batch
=
self
.
_collate_fn
(
batch
)
batch
=
fetcher
.
fetch
(
indices
)
except
Exception
as
e
:
out_queue
.
put
((
idx
,
e
))
if
isinstance
(
e
,
StopIteration
)
and
dataset_kind
==
_DatasetKind
.
ITER
:
out_queue
.
put
(
_IterableDatasetStopIteration
(
worker_id
))
iterator_drained
=
True
else
:
out_queue
.
put
((
idx
,
e
))
else
:
if
self
.
_use_shared_memory
:
tensor_list
=
core
.
_convert_to_tensor_list
(
batch
)
...
...
@@ -438,7 +561,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# serializable, cannot be create in workers
for
slot
in
batch
:
if
not
isinstance
(
slot
,
core
.
LoDTensor
):
# self._check_input_array(slot)
tmp
=
core
.
LoDTensor
()
tmp
.
set
(
slot
,
core
.
CPUPlace
())
slot
=
tmp
...
...
@@ -453,10 +575,31 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self
.
_rcvd_idx
+=
1
def
_get_data
(
self
):
if
self
.
_rcvd_idx
in
self
.
_reorder_dict
.
keys
():
return
self
.
_reorder_dict
.
pop
(
self
.
_rcvd_idx
)
while
not
self
.
_thread_done_event
.
is_set
():
# For IterableDataset, batch indices is generated infinitely
# for each worker to raise StopIteration, but a StopIteration
# raising process will discard a batch indices which is count
# in _send_idx but will not increase _rcvd_idx, so we check
# whether the worker is still alive here to skip the discarded
# batch indices and increase _rcvd_idx
while
self
.
_rcvd_idx
<
self
.
_send_idx
:
info
=
self
.
_task_infos
[
self
.
_rcvd_idx
]
if
len
(
info
)
==
2
or
self
.
_worker_status
[
info
[
0
]]:
break
del
self
.
_task_infos
[
self
.
_rcvd_idx
]
self
.
_rcvd_idx
+=
1
self
.
_batches_outstanding
-=
1
else
:
# NOTE: _rcvd_idx and _send_idx only record batches among
# workers, if batches among workers drained, there
# may also be data in blocking queue
if
self
.
_batches_outstanding
<
len
(
self
.
_places
):
return
None
continue
if
len
(
self
.
_task_infos
[
self
.
_rcvd_idx
])
==
2
:
return
self
.
_task_infos
.
pop
(
self
.
_rcvd_idx
)[
1
]
try
:
# [ avoid hang ]: main process may blocking at _reader.read_next when
# KeyboardInterrupt, we do following tradeoff:
...
...
@@ -494,23 +637,43 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
"workers' result queue."
.
format
(
e
))
six
.
reraise
(
*
sys
.
exc_info
())
else
:
if
self
.
_dataset_kind
==
_DatasetKind
.
ITER
and
isinstance
(
data
,
_IterableDatasetStopIteration
):
# if a worker get StopIteraion, we shutdown this worker,
# note that this batch indices to trigger StopIteration
# is discard, outstanding batch number should be decrease
# and another indices should be put for other workers
# may still working.
self
.
_shutdown_worker
(
data
.
worker_id
)
self
.
_batches_outstanding
-=
1
self
.
_try_put_indices
()
continue
idx
,
batch
=
data
if
idx
==
self
.
_rcvd_idx
:
del
self
.
_task_infos
[
idx
]
return
batch
else
:
self
.
_
reorder_dict
[
idx
]
=
batch
self
.
_
task_infos
[
idx
]
+=
(
batch
,
)
continue
def
_try_put_indices
(
self
):
assert
self
.
_
send_idx
-
self
.
_rcvd_idx
<=
self
.
_outstanding_capacity
,
\
assert
self
.
_
batches_outstanding
<=
self
.
_outstanding_capacity
,
\
"too many indices have been put to queue"
try
:
indices
=
next
(
self
.
_sampler_iter
)
except
StopIteration
:
return
worker_idx
=
next
(
self
.
_workers_idx_cycle
)
for
i
in
range
(
self
.
_num_workers
):
worker_idx
=
next
(
self
.
_workers_idx_cycle
)
if
self
.
_worker_status
[
worker_idx
]:
break
else
:
return
self
.
_indices_queues
[
worker_idx
].
put
((
self
.
_send_idx
,
indices
))
self
.
_task_infos
[
self
.
_send_idx
]
=
(
worker_idx
,
)
self
.
_batches_outstanding
+=
1
self
.
_send_idx
+=
1
...
...
python/paddle/fluid/dataloader/dataset.py
浏览文件 @
dbc88bb9
...
...
@@ -16,12 +16,12 @@ from __future__ import print_function
import
paddle.dataset.common
__all__
=
[
"Dataset"
]
__all__
=
[
"Dataset"
,
"IterableDataset"
]
class
Dataset
(
object
):
"""
An abstract class to encapsulate
s
methods and behaviors of datasets.
An abstract class to encapsulate methods and behaviors of datasets.
All datasets in map-style(dataset samples can be get by a given key)
should be a subclass of `paddle.io.Dataset`. All subclasses should
...
...
@@ -71,3 +71,154 @@ class Dataset(object):
def
__len__
(
self
):
raise
NotImplementedError
(
"'{}' not implement in class "
\
"{}"
.
format
(
'__len__'
,
self
.
__class__
.
__name__
))
class
IterableDataset
(
Dataset
):
"""
An abstract class to encapsulate methods and behaviors of iterable datasets.
All datasets in iterable-style (can only get sample one by one sequentially, like
a Python iterator) should be a subclass of `paddle.io.IterableDataset`. All subclasses should
implement following methods:
:code:`__iter__`: yield sample sequentially. This method is required by reading dataset sample in :code:`paddle.io.DataLoader`.
.. note::
do not implement :code:`__getitem__` and :code:`__len__` in IterableDataset, should not be called either.
see :code:`paddle.io.DataLoader`.
Examples:
.. code-block:: python
import numpy as np
from paddle.io import Dataset
# define a random dataset
class RandomDataset(Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __iter__(self):
for i in range(self.num_samples):
image = np.random.random([784]).astype('float32')
label = np.random.randint(0, 9, (1, )).astype('int64')
yield image, label
dataset = RandomDataset(10)
for img, lbl in dataset:
print(img, lbl)
When :attr:`num_workers > 0`, each worker has a different copy of the dataset object and
will yield whole dataset samples, which means samples in dataset will be repeated in
:attr:`num_workers` times. If it is required for each sample to yield only once, there
are two methods to configure different copy in each worker process to avoid duplicate data
among workers as follows. In both the methods, worker information that can be getted in
a worker process by `paddle.io.get_worker_info` will be needed.
Example 1: splitting data copy in each worker in :code:`__iter__`
.. code-block:: python
import math
import numpy as np
import paddle.fluid as fluid
from paddle.io import IterableDataset, DataLoader, get_worker_info
class SplitedIterableDataset(IterableDataset):
def __init__(self, start, end):
self.start = start
self.end = end
def __iter__(self):
worker_info = get_worker_info()
if worker_info is None:
iter_start = self.start
iter_end = self.end
else:
per_worker = int(
math.ceil((self.end - self.start) / float(
worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
for i in range(iter_start, iter_end):
yield np.array([i])
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
dataset = SplitedIterableDataset(start=2, end=9)
dataloader = DataLoader(
dataset,
places=place,
num_workers=2,
batch_size=1,
drop_last=True)
print(list(dataloader))
# outputs: [2, 5, 3, 6, 4, 7]
Example 2: splitting data copy in each worker by :code:`worker_init_fn`
.. code-block:: python
import math
import numpy as np
import paddle.fluid as fluid
from paddle.io import IterableDataset, DataLoader, get_worker_info
class RangeIterableDataset(IterableDataset):
def __init__(self, start, end):
self.start = start
self.end = end
def __iter__(self):
for i in range(self.start, self.end):
yield np.array([i])
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
dataset = RangeIterableDataset(start=2, end=9)
def worker_init_fn(worker_id):
worker_info = get_worker_info()
dataset = worker_info.dataset
start = dataset.start
end = dataset.end
num_per_worker = int(
math.ceil((end - start) / float(worker_info.num_workers)))
worker_id = worker_info.id
dataset.start = start + worker_id * num_per_worker
dataset.end = min(dataset.start + num_per_worker, end)
dataloader = DataLoader(
dataset,
places=place,
num_workers=2,
batch_size=1,
drop_last=True,
worker_init_fn=worker_init_fn)
print(list(dataloader))
# outputs: [2, 5, 3, 6, 4, 7]
"""
def
__init__
(
self
):
pass
def
__iter__
(
self
):
raise
NotImplementedError
(
"'{}' not implement in class "
\
"{}"
.
format
(
'__iter__'
,
self
.
__class__
.
__name__
))
def
__getitem__
(
self
,
idx
):
raise
RuntimeError
(
"'{}' should not be called for IterableDataset"
\
"{}"
.
format
(
'__getitem__'
,
self
.
__class__
.
__name__
))
def
__len__
(
self
):
raise
RuntimeError
(
"'{}' should not be called for IterableDataset"
\
"{}"
.
format
(
'__len__'
,
self
.
__class__
.
__name__
))
python/paddle/fluid/dataloader/fetcher.py
0 → 100644
浏览文件 @
dbc88bb9
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class
_DatasetFetcher
(
object
):
def
__init__
(
self
,
dataset
,
collate_fn
,
drop_last
):
self
.
dataset
=
dataset
self
.
collate_fn
=
collate_fn
self
.
drop_last
=
drop_last
def
fetch
(
self
,
batch_indices
):
raise
NotImplementedError
(
"'fetch' not implement for class {}"
.
format
(
self
.
__class__
.
__name__
))
class
_IterableDatasetFetcher
(
_DatasetFetcher
):
def
__init__
(
self
,
dataset
,
collate_fn
,
drop_last
):
super
(
_IterableDatasetFetcher
,
self
).
__init__
(
dataset
,
collate_fn
,
drop_last
)
self
.
dataset_iter
=
iter
(
dataset
)
def
fetch
(
self
,
batch_indices
):
data
=
[]
for
_
in
batch_indices
:
try
:
data
.
append
(
next
(
self
.
dataset_iter
))
except
StopIteration
:
break
if
len
(
data
)
==
0
or
(
self
.
drop_last
and
len
(
data
)
<
len
(
batch_indices
)):
raise
StopIteration
return
self
.
collate_fn
(
data
)
class
_MapDatasetFetcher
(
_DatasetFetcher
):
def
__init__
(
self
,
dataset
,
collate_fn
,
drop_last
):
super
(
_MapDatasetFetcher
,
self
).
__init__
(
dataset
,
collate_fn
,
drop_last
)
def
fetch
(
self
,
batch_indices
):
data
=
[
self
.
dataset
[
idx
]
for
idx
in
batch_indices
]
return
self
.
collate_fn
(
data
)
python/paddle/fluid/reader.py
浏览文件 @
dbc88bb9
...
...
@@ -22,8 +22,9 @@ from .framework import Program, Variable, program_guard, default_main_program, d
from
.executor
import
global_scope
from
.data_feeder
import
DataFeeder
,
BatchedTensorProvider
from
.multiprocess_utils
import
multiprocess_queue_set
,
CleanupFuncRegistrar
,
_cleanup_mmap
,
_cleanup
,
_set_SIGCHLD_handler
from
.dataloader
import
BatchSampler
,
Dataset
from
.dataloader.dataloader_iter
import
_DataLoaderIterSingleProcess
,
_DataLoaderIterMultiProcess
,
default_collate_fn
from
.dataloader
import
BatchSampler
,
Dataset
,
IterableDataset
from
.dataloader.dataloader_iter
import
_DataLoaderIterSingleProcess
,
_DataLoaderIterMultiProcess
,
_DatasetKind
,
default_collate_fn
from
.dataloader.batch_sampler
import
_InfiniteIterableSampler
from
.layers.io
import
monkey_patch_reader_methods
,
_copy_reader_var_
,
double_buffer
from
.unique_name
import
UniqueNameGenerator
import
logging
...
...
@@ -136,8 +137,9 @@ class DataLoader(object):
Args:
dataset(Dataset): the dataset to load data from, should be an
instance of subclass of :code:`paddle.io.Dataset`.
feed_list (list(Variable)|tuple(Variable)): feed variable list.
instance of subclass of :code:`paddle.io.Dataset` or
:code:`paddle.io.IterableDataset`.
feed_list (list(Tensor)|tuple(Tensor)): feed variable list.
The variables should be created by :code:`fluid.data()`.
:attr:`feed_list` must be set if :attr:`return_list` is
False. Default None.
...
...
@@ -295,6 +297,10 @@ class DataLoader(object):
# -------------------------------------------------------
.. note::
For reading iterable dataset with multiprocess Dataloader,
please see :code:`paddle.io.IterableDataset`
"""
def
__init__
(
self
,
...
...
@@ -348,6 +354,18 @@ class DataLoader(object):
assert
timeout
>=
0
,
"timeout should be a non-negative value"
self
.
timeout
=
timeout
if
isinstance
(
dataset
,
IterableDataset
):
self
.
dataset_kind
=
_DatasetKind
.
ITER
if
shuffle
:
raise
ValueError
(
"IterableDataset not support shuffle, but got shuffle={}"
.
format
(
shuffle
))
if
batch_sampler
is
not
None
:
raise
ValueError
(
"IterableDataset expect unspecified batch_sampler"
)
else
:
self
.
dataset_kind
=
_DatasetKind
.
MAP
if
batch_sampler
is
not
None
:
assert
isinstance
(
batch_sampler
,
BatchSampler
),
\
"batch_sampler should be None or subclass instance "
\
...
...
@@ -360,11 +378,15 @@ class DataLoader(object):
assert
batch_size
is
not
None
and
batch_size
>
0
,
\
"batch_size should be a positive value when "
\
"batch_sampler is not given"
self
.
batch_sampler
=
BatchSampler
(
dataset
=
dataset
,
batch_size
=
batch_size
,
shuffle
=
shuffle
,
drop_last
=
drop_last
)
if
isinstance
(
dataset
,
IterableDataset
):
self
.
batch_sampler
=
_InfiniteIterableSampler
(
dataset
,
batch_size
)
else
:
self
.
batch_sampler
=
BatchSampler
(
dataset
=
dataset
,
batch_size
=
batch_size
,
shuffle
=
shuffle
,
drop_last
=
drop_last
)
self
.
pin_memory
=
False
if
in_dygraph_mode
():
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
dbc88bb9
...
...
@@ -278,6 +278,7 @@ if (APPLE OR WIN32)
list
(
REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_static
)
list
(
REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_dynamic
)
list
(
REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_exception
)
list
(
REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_iterable_dataset
)
endif
()
if
(
NOT WITH_GPU OR WIN32 OR APPLE
)
...
...
@@ -496,4 +497,6 @@ if(NOT WIN32 AND NOT APPLE)
set_tests_properties
(
test_multiprocess_dataloader_static PROPERTIES LABELS
"RUN_TYPE=EXCLUSIVE"
)
set_tests_properties
(
test_multiprocess_dataloader_dynamic PROPERTIES LABELS
"RUN_TYPE=EXCLUSIVE"
)
set_tests_properties
(
test_multiprocess_dataloader_exception PROPERTIES LABELS
"RUN_TYPE=EXCLUSIVE"
)
set_tests_properties
(
test_multiprocess_dataloader_iterable_dataset_static PROPERTIES LABELS
"RUN_TYPE=EXCLUSIVE"
)
set_tests_properties
(
test_multiprocess_dataloader_iterable_dataset_dynamic PROPERTIES LABELS
"RUN_TYPE=EXCLUSIVE"
)
endif
()
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py
浏览文件 @
dbc88bb9
...
...
@@ -24,7 +24,7 @@ import numpy as np
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.io
import
Dataset
,
BatchSampler
,
DataLoader
from
paddle.io
import
Dataset
,
IterableDataset
,
BatchSampler
,
DataLoader
from
paddle.fluid.dygraph.nn
import
Linear
from
paddle.fluid.dygraph.base
import
to_variable
...
...
@@ -108,6 +108,48 @@ class TestDataLoaderAssert(unittest.TestCase):
self
.
assertTrue
(
False
)
class
TestDatasetRuntimeError
(
unittest
.
TestCase
):
def
test_main
(
self
):
dataset
=
Dataset
()
# __getitem__ not implement
try
:
d
=
dataset
[
0
]
self
.
assertTrue
(
False
)
except
NotImplementedError
:
pass
# __len__ not implement
try
:
l
=
len
(
dataset
)
self
.
assertTrue
(
False
)
except
NotImplementedError
:
pass
dataset
=
IterableDataset
()
# __iter__ not implement
try
:
d
=
iter
(
dataset
)
self
.
assertTrue
(
False
)
except
NotImplementedError
:
pass
# __getitem__ runtime error
try
:
d
=
dataset
[
0
]
self
.
assertTrue
(
False
)
except
RuntimeError
:
pass
# __len__ runtime error
try
:
l
=
len
(
dataset
)
self
.
assertTrue
(
False
)
except
RuntimeError
:
pass
# CI Converage cannot record stub in subprocess,
# HACK a _worker_loop in main process call here
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
...
...
@@ -144,12 +186,15 @@ class TestDataLoaderWorkerLoop(unittest.TestCase):
indices_queue
.
put
([
i
,
i
+
10
])
indices_queue
.
put
(
None
)
loader
.
_worker_loop
(
loader
.
_dataset
,
indices_queue
,
loader
.
_data_queue
,
loader
.
_workers_done_event
,
_collate_fn
,
_init_fn
,
0
)
loader
.
_dataset
,
0
,
indices_queue
,
loader
.
_data_queue
,
loader
.
_workers_done_event
,
_collate_fn
,
_init_fn
,
0
,
1
)
self
.
assertTrue
(
False
)
except
AssertionError
:
pass
except
Exception
:
except
Exception
as
e
:
print
(
"Exception"
,
e
)
import
sys
sys
.
stdout
.
flush
()
self
.
assertTrue
(
False
)
def
run_with_worker_done
(
self
,
use_shared_memory
=
True
):
...
...
@@ -184,8 +229,8 @@ class TestDataLoaderWorkerLoop(unittest.TestCase):
indices_queue
.
put
(
None
)
loader
.
_workers_done_event
.
set
()
loader
.
_worker_loop
(
loader
.
_dataset
,
indices_queue
,
loader
.
_data_queue
,
loader
.
_workers_done_event
,
_collate_fn
,
_init_fn
,
0
)
loader
.
_dataset
,
0
,
indices_queue
,
loader
.
_data_queue
,
loader
.
_workers_done_event
,
_collate_fn
,
_init_fn
,
0
,
1
)
self
.
assertTrue
(
True
)
except
AssertionError
:
pass
...
...
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py
0 → 100644
浏览文件 @
dbc88bb9
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
import
os
import
sys
import
six
import
time
import
unittest
import
multiprocessing
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.io
import
Dataset
,
BatchSampler
,
DataLoader
from
paddle.fluid.dygraph.nn
import
Linear
from
paddle.fluid.dygraph.base
import
to_variable
from
test_multiprocess_dataloader_iterable_dataset_static
import
RandomDataset
,
prepare_places
from
test_multiprocess_dataloader_iterable_dataset_static
import
EPOCH_NUM
,
BATCH_SIZE
,
IMAGE_SIZE
,
SAMPLE_NUM
,
CLASS_NUM
class
SimpleFCNet
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
):
super
(
SimpleFCNet
,
self
).
__init__
()
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.8
))
bias_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.5
))
self
.
_fcs
=
[]
in_channel
=
IMAGE_SIZE
for
hidden_size
in
[
10
,
20
,
30
]:
self
.
_fcs
.
append
(
Linear
(
in_channel
,
hidden_size
,
act
=
'tanh'
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
))
in_channel
=
hidden_size
self
.
_fcs
.
append
(
Linear
(
in_channel
,
CLASS_NUM
,
act
=
'softmax'
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
))
def
forward
(
self
,
image
):
out
=
image
for
fc
in
self
.
_fcs
:
out
=
fc
(
out
)
return
out
class
TestDygraphDataLoader
(
unittest
.
TestCase
):
def
run_main
(
self
,
num_workers
,
places
):
fluid
.
default_startup_program
().
random_seed
=
1
fluid
.
default_main_program
().
random_seed
=
1
with
fluid
.
dygraph
.
guard
(
places
[
0
]):
fc_net
=
SimpleFCNet
()
optimizer
=
fluid
.
optimizer
.
Adam
(
parameter_list
=
fc_net
.
parameters
())
dataset
=
RandomDataset
(
SAMPLE_NUM
,
CLASS_NUM
)
dataloader
=
DataLoader
(
dataset
,
places
=
places
,
num_workers
=
num_workers
,
batch_size
=
BATCH_SIZE
,
drop_last
=
True
)
step_list
=
[]
loss_list
=
[]
start_t
=
time
.
time
()
for
_
in
six
.
moves
.
range
(
EPOCH_NUM
):
step
=
0
for
image
,
label
in
dataloader
():
out
=
fc_net
(
image
)
loss
=
fluid
.
layers
.
cross_entropy
(
out
,
label
)
avg_loss
=
fluid
.
layers
.
reduce_mean
(
loss
)
avg_loss
.
backward
()
optimizer
.
minimize
(
avg_loss
)
fc_net
.
clear_gradients
()
loss_list
.
append
(
np
.
mean
(
avg_loss
.
numpy
()))
step
+=
1
step_list
.
append
(
step
)
end_t
=
time
.
time
()
ret
=
{
"time"
:
end_t
-
start_t
,
"step"
:
step_list
,
"loss"
:
np
.
array
(
loss_list
)
}
print
(
"time cost"
,
ret
[
'time'
],
'step_list'
,
ret
[
'step'
])
return
ret
def
test_main
(
self
):
# dynamic graph do not run with_data_parallel
for
p
in
prepare_places
(
False
):
results
=
[]
for
num_workers
in
[
0
,
2
]:
print
(
self
.
__class__
.
__name__
,
p
,
num_workers
)
sys
.
stdout
.
flush
()
ret
=
self
.
run_main
(
num_workers
=
num_workers
,
places
=
p
)
results
.
append
(
ret
)
assert
results
[
0
][
'loss'
].
shape
[
0
]
*
2
==
results
[
1
][
'loss'
].
shape
[
0
]
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_split.py
0 → 100644
浏览文件 @
dbc88bb9
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
import
math
import
unittest
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.io
import
IterableDataset
,
BatchSampler
,
DataLoader
,
get_worker_info
class
RangeIterableDatasetSplit
(
IterableDataset
):
def
__init__
(
self
,
start
,
end
):
self
.
start
=
start
self
.
end
=
end
def
__iter__
(
self
):
worker_info
=
get_worker_info
()
if
worker_info
is
None
:
iter_start
=
self
.
start
iter_end
=
self
.
end
else
:
per_worker
=
int
(
math
.
ceil
((
self
.
end
-
self
.
start
)
/
float
(
worker_info
.
num_workers
)))
worker_id
=
worker_info
.
id
iter_start
=
self
.
start
+
worker_id
*
per_worker
iter_end
=
min
(
iter_start
+
per_worker
,
self
.
end
)
for
i
in
range
(
iter_start
,
iter_end
):
yield
np
.
array
([
i
])
class
TestDynamicDataLoaderIterSplit
(
unittest
.
TestCase
):
def
test_main
(
self
):
place
=
fluid
.
CPUPlace
()
with
fluid
.
dygraph
.
guard
(
place
):
dataset
=
RangeIterableDatasetSplit
(
0
,
10
)
dataloader
=
DataLoader
(
dataset
,
places
=
place
,
num_workers
=
2
,
batch_size
=
1
,
drop_last
=
True
)
rets
=
[]
for
d
in
dataloader
:
rets
.
append
(
d
[
0
].
numpy
()[
0
][
0
])
assert
tuple
(
sorted
(
rets
))
==
tuple
(
range
(
0
,
10
))
class
RangeIterableDataset
(
IterableDataset
):
def
__init__
(
self
,
start
,
end
):
self
.
start
=
start
self
.
end
=
end
def
__iter__
(
self
):
for
i
in
range
(
self
.
start
,
self
.
end
):
yield
np
.
array
([
i
])
class
TestDynamicDataLoaderIterInitFuncSplit
(
unittest
.
TestCase
):
def
test_main
(
self
):
place
=
fluid
.
CPUPlace
()
with
fluid
.
dygraph
.
guard
(
place
):
dataset
=
RangeIterableDataset
(
0
,
10
)
def
worker_spliter
(
worker_id
):
worker_info
=
get_worker_info
()
dataset
=
worker_info
.
dataset
start
=
dataset
.
start
end
=
dataset
.
end
num_per_worker
=
int
(
math
.
ceil
((
end
-
start
)
/
float
(
worker_info
.
num_workers
)))
worker_id
=
worker_info
.
id
dataset
.
start
=
start
+
worker_id
*
num_per_worker
dataset
.
end
=
min
(
dataset
.
start
+
num_per_worker
,
end
)
dataloader
=
DataLoader
(
dataset
,
places
=
place
,
num_workers
=
1
,
batch_size
=
1
,
drop_last
=
True
,
worker_init_fn
=
worker_spliter
)
rets
=
[]
for
d
in
dataloader
:
rets
.
append
(
d
[
0
].
numpy
()[
0
][
0
])
assert
tuple
(
sorted
(
rets
))
==
tuple
(
range
(
0
,
10
))
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py
0 → 100644
浏览文件 @
dbc88bb9
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
import
os
import
sys
import
six
import
time
import
unittest
import
multiprocessing
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.io
import
IterableDataset
,
BatchSampler
,
DataLoader
,
get_worker_info
EPOCH_NUM
=
2
BATCH_SIZE
=
8
IMAGE_SIZE
=
32
SAMPLE_NUM
=
80
CLASS_NUM
=
10
class
RandomDataset
(
IterableDataset
):
def
__init__
(
self
,
sample_num
,
class_num
):
self
.
sample_num
=
sample_num
self
.
class_num
=
class_num
def
__iter__
(
self
):
for
i
in
range
(
self
.
sample_num
):
np
.
random
.
seed
(
i
)
image
=
np
.
random
.
random
([
IMAGE_SIZE
]).
astype
(
'float32'
)
label
=
np
.
random
.
randint
(
0
,
self
.
class_num
-
1
,
(
1
,
)).
astype
(
'int64'
)
yield
image
,
label
def
simple_fc_net_static
():
startup_prog
=
fluid
.
Program
()
main_prog
=
fluid
.
Program
()
startup_prog
.
random_seed
=
1
main_prog
.
random_seed
=
1
with
fluid
.
unique_name
.
guard
():
with
fluid
.
program_guard
(
main_prog
,
startup_prog
):
image
=
fluid
.
data
(
name
=
'image'
,
shape
=
[
None
,
IMAGE_SIZE
],
dtype
=
'float32'
)
label
=
fluid
.
data
(
name
=
'label'
,
shape
=
[
None
,
1
],
dtype
=
'int64'
)
hidden
=
image
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.8
))
bias_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.5
))
for
hidden_size
in
[
10
,
20
,
30
]:
hidden
=
fluid
.
layers
.
fc
(
hidden
,
size
=
hidden_size
,
act
=
'tanh'
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
)
predict_label
=
fluid
.
layers
.
fc
(
hidden
,
size
=
CLASS_NUM
,
act
=
'softmax'
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
)
loss
=
fluid
.
layers
.
reduce_mean
(
fluid
.
layers
.
cross_entropy
(
input
=
predict_label
,
label
=
label
))
optimizer
=
fluid
.
optimizer
.
Adam
()
optimizer
.
minimize
(
loss
)
return
startup_prog
,
main_prog
,
image
,
label
,
loss
def
prepare_places
(
with_data_parallel
,
with_cpu
=
False
,
with_gpu
=
True
):
places
=
[]
if
with_cpu
:
places
.
append
([
fluid
.
CPUPlace
()])
if
with_data_parallel
:
places
.
append
([
fluid
.
CPUPlace
()]
*
2
)
if
with_gpu
and
fluid
.
core
.
is_compiled_with_cuda
():
tmp
=
fluid
.
cuda_places
()[:
2
]
assert
len
(
tmp
)
>
0
,
"no gpu detected"
if
with_data_parallel
:
places
.
append
(
tmp
)
places
.
append
([
tmp
[
0
]])
return
places
class
TestStaticDataLoader
(
unittest
.
TestCase
):
def
run_main
(
self
,
num_workers
,
places
):
scope
=
fluid
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
startup_prog
,
main_prog
,
image
,
label
,
loss
=
simple_fc_net_static
()
dataset
=
RandomDataset
(
SAMPLE_NUM
,
CLASS_NUM
)
dataloader
=
DataLoader
(
dataset
,
feed_list
=
[
image
,
label
],
places
=
places
,
num_workers
=
num_workers
,
batch_size
=
BATCH_SIZE
,
drop_last
=
True
)
# assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
exe
=
fluid
.
Executor
(
place
=
places
[
0
])
exe
.
run
(
startup_prog
)
prog
=
fluid
.
CompiledProgram
(
main_prog
)
if
len
(
places
)
>
1
:
prog
=
prog
.
with_data_parallel
(
loss_name
=
loss
.
name
,
places
=
places
)
step_list
=
[]
loss_list
=
[]
start_t
=
time
.
time
()
for
i
in
six
.
moves
.
range
(
EPOCH_NUM
):
step
=
0
for
d
in
dataloader
:
assert
len
(
d
)
==
len
(
places
),
"{} != {}"
.
format
(
len
(
d
),
len
(
places
))
for
i
,
item
in
enumerate
(
d
):
image
=
item
[
'image'
]
label
=
item
[
'label'
]
assert
image
.
shape
()
==
[
BATCH_SIZE
,
IMAGE_SIZE
]
assert
label
.
shape
()
==
[
BATCH_SIZE
,
1
]
assert
image
.
_place
().
_equals
(
places
[
i
])
assert
label
.
_place
().
_equals
(
places
[
i
])
L
,
=
exe
.
run
(
program
=
prog
,
feed
=
d
,
fetch_list
=
[
loss
],
use_program_cache
=
True
)
loss_list
.
append
(
np
.
mean
(
L
))
step
+=
1
step_list
.
append
(
step
)
end_t
=
time
.
time
()
ret
=
{
"time"
:
end_t
-
start_t
,
"step"
:
step_list
,
"loss"
:
np
.
array
(
loss_list
)
}
print
(
"time cost"
,
ret
[
'time'
],
'step_list'
,
ret
[
'step'
])
return
ret
def
test_main
(
self
):
for
p
in
prepare_places
(
True
):
results
=
[]
for
num_workers
in
[
0
,
2
]:
print
(
self
.
__class__
.
__name__
,
p
,
num_workers
)
sys
.
stdout
.
flush
()
ret
=
self
.
run_main
(
num_workers
=
num_workers
,
places
=
p
)
results
.
append
(
ret
)
assert
results
[
0
][
'loss'
].
shape
[
0
]
*
2
==
results
[
1
][
'loss'
].
shape
[
0
]
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/io/__init__.py
浏览文件 @
dbc88bb9
...
...
@@ -15,9 +15,11 @@
# TODO: define all functions about input & output in this directory
__all__
=
[
'Dataset'
,
'IterableDataset'
,
'BatchSampler'
,
# 'Transform',
'DataLoader'
,
'get_worker_info'
,
'load'
,
'save'
,
'load_program_state'
,
...
...
@@ -36,7 +38,7 @@ __all__ = [
]
from
..fluid.io
import
DataLoader
from
..fluid.dataloader
import
Dataset
,
BatchSampler
from
..fluid.dataloader
import
Dataset
,
IterableDataset
,
BatchSampler
,
get_worker_info
from
..fluid.io
import
load
,
save
,
load_program_state
,
set_program_state
,
\
load_inference_model
,
save_inference_model
,
batch
from
..reader
import
shuffle
,
buffered
,
cache
,
chain
,
firstn
,
compose
,
map_readers
,
xmap_readers
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录