Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
80cf3c3c
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
80cf3c3c
编写于
4月 21, 2020
作者:
K
Kaipeng Deng
提交者:
GitHub
4月 21, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine DataLoader support multi-processing (#23107)
* add DataLoader, Dataset, BatchSampler
上级
76d78c63
变更
21
展开全部
显示空白变更内容
内联
并排
Showing
21 changed file
with
1958 addition
and
190 deletion
+1958
-190
paddle/fluid/imperative/data_loader.cc
paddle/fluid/imperative/data_loader.cc
+75
-51
paddle/fluid/imperative/data_loader.h
paddle/fluid/imperative/data_loader.h
+3
-2
paddle/fluid/pybind/imperative.cc
paddle/fluid/pybind/imperative.cc
+16
-4
python/paddle/distributed/utils.py
python/paddle/distributed/utils.py
+3
-1
python/paddle/fluid/core.py
python/paddle/fluid/core.py
+4
-4
python/paddle/fluid/dataloader/__init__.py
python/paddle/fluid/dataloader/__init__.py
+24
-0
python/paddle/fluid/dataloader/batch_sampler.py
python/paddle/fluid/dataloader/batch_sampler.py
+143
-0
python/paddle/fluid/dataloader/dataloader_iter.py
python/paddle/fluid/dataloader/dataloader_iter.py
+528
-0
python/paddle/fluid/dataloader/dataset.py
python/paddle/fluid/dataloader/dataset.py
+73
-0
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+2
-0
python/paddle/fluid/multiprocess_utils.py
python/paddle/fluid/multiprocess_utils.py
+139
-0
python/paddle/fluid/reader.py
python/paddle/fluid/reader.py
+268
-108
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+4
-0
python/paddle/fluid/tests/unittests/test_batch_sampler.py
python/paddle/fluid/tests/unittests/test_batch_sampler.py
+120
-0
python/paddle/fluid/tests/unittests/test_dataloader_dataset.py
...n/paddle/fluid/tests/unittests/test_dataloader_dataset.py
+41
-0
python/paddle/fluid/tests/unittests/test_imperative_data_loader_fds_clear.py
.../tests/unittests/test_imperative_data_loader_fds_clear.py
+29
-0
python/paddle/fluid/tests/unittests/test_imperative_signal_handler.py
...e/fluid/tests/unittests/test_imperative_signal_handler.py
+1
-1
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_base.py
...luid/tests/unittests/test_multiprocess_dataloader_base.py
+260
-0
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py
...tests/unittests/test_multiprocess_dataloader_exception.py
+199
-0
python/paddle/io/__init__.py
python/paddle/io/__init__.py
+24
-19
python/setup.py.in
python/setup.py.in
+2
-0
未找到文件。
paddle/fluid/imperative/data_loader.cc
浏览文件 @
80cf3c3c
...
...
@@ -22,21 +22,23 @@
#include <atomic>
#include <csignal>
#include <map>
#include <set>
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
imperative
{
static
std
::
map
<
int64_t
,
pid_t
>
load_process_pids
;
static
std
::
map
<
int64_t
,
std
::
set
<
pid_t
>
>
load_process_pids
;
void
SetLoadProcessPID
(
int64_t
key
,
pid_t
pid
)
{
VLOG
(
3
)
<<
"D
ygraph Data
Loader: set loader child process PID ("
<<
key
<<
",
"
<<
pid
<<
")"
;
load_process_pids
[
key
]
=
pid
;
void
SetLoadProcessPID
s
(
int64_t
key
,
std
::
set
<
pid_t
>
pids
)
{
VLOG
(
3
)
<<
"D
ata
Loader: set loader child process PID ("
<<
key
<<
",
pid number: "
<<
pids
.
size
()
<<
")"
;
load_process_pids
[
key
]
=
pid
s
;
}
void
EraseLoadProcessPID
(
int64_t
key
)
{
void
EraseLoadProcessPID
s
(
int64_t
key
)
{
auto
it
=
load_process_pids
.
find
(
key
);
// Note: Can not find key also possible
if
(
it
!=
load_process_pids
.
end
())
{
...
...
@@ -54,8 +56,12 @@ void EraseLoadProcessPID(int64_t key) {
// siginfo_t doc: https://www.mkssoftware.com/docs/man5/siginfo_t.5.asp
// waitid doc: https://linux.die.net/man/2/waitid
// clear mmap fds on signal handler, make sure mmap clear will be called
// on signal handling and no need to register mmap clear up handler on
// python side. If shared memory is not used Clear() will do nothing.
#define SIGNAL_HANDLE(SIGNAL) \
do { \
memory::allocation::MemoryMapFdSet::Instance().Clear(); \
struct sigaction sa; \
sa.sa_handler = SIG_DFL; \
sa.sa_flags = 0; \
...
...
@@ -106,30 +112,40 @@ void SetLoadProcessSignalHandler() {
void
ThrowErrorIfLoadProcessFailed
()
{
int
error
;
std
::
set
<
pid_t
>
*
pids_set
;
pid_t
process_pid
;
siginfo_t
infop
;
for
(
auto
&
w
:
load_process_pids
)
{
process_pid
=
w
.
second
;
// Use waitid rather than waitpid so that we can set NOWAIT, and that Python
// and other handlers can get whatever info they want about the child.
for
(
auto
&
p
:
load_process_pids
)
{
pids_set
=
&
(
p
.
second
);
for
(
auto
pid_it
=
pids_set
->
begin
();
pid_it
!=
pids_set
->
end
();
++
pid_it
)
{
process_pid
=
*
pid_it
;
// Use waitid rather than waitpid so that we can set NOWAIT, and that
// Python and other handlers can get whatever info they want about the
// child.
infop
.
si_pid
=
0
;
VLOG
(
3
)
<<
"Dygraph Data Loader: monitor loader child process "
<<
process_pid
;
VLOG
(
3
)
<<
"DataLoader: monitor loader child process "
<<
process_pid
;
error
=
waitid
(
P_PID
,
process_pid
,
&
infop
,
WEXITED
|
WNOHANG
|
WNOWAIT
);
// ignore errors and case with no waitable child
if
(
error
<
0
||
infop
.
si_pid
==
0
)
continue
;
if
(
infop
.
si_code
==
CLD_EXITED
&&
infop
.
si_status
!=
EXIT_SUCCESS
)
{
// exit with error
pids_set
->
clear
();
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"DataLoader process (pid %ld) exited unexpectedly with code %d. "
"Error detailed are lost due to multiprocessing. Rerunning with "
"Error detailed are lost due to multiprocessing. Rerunning with:
\n
"
" 1. If run DataLoader by DataLoader.from_generator(...), run "
"with "
"DataLoader.from_generator(..., use_multiprocess=False) may give "
"better error trace."
,
"better error trace.
\n
"
" 2. If run DataLoader by DataLoader(dataset, ...), run with "
"DataLoader(dataset, ..., num_workers=0) may give better error "
"trace"
,
process_pid
,
infop
.
si_status
));
}
else
if
(
infop
.
si_code
==
CLD_KILLED
||
infop
.
si_code
==
CLD_DUMPED
)
{
// killed by signal
if
(
infop
.
si_status
==
SIGBUS
)
{
pids_set
->
clear
();
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"DataLoader process (pid %ld) exited is killed by signal: %s.
\n
"
" It may be caused by insufficient shared storage space. This "
...
...
@@ -138,7 +154,14 @@ void ThrowErrorIfLoadProcessFailed() {
"space of `/dev/shm`. Shared storage space needs to be greater "
"than (DataLoader Num * DataLoader queue capacity * 1 batch data "
"size).
\n
You can solve this problem by increasing the shared "
"storage space or reducing the queue capacity appropriately."
,
"storage space or reducing the queue capacity appropriately.
\n
"
,
" 1. If run DataLoader by DataLoader.from_generator(...), queue "
"capacity is set by from_generator(..., capacity=xx, ...).
\n
"
" 2. If run DataLoader by DataLoader(dataset, ...), queue "
"capacity is set as 2 times of the max value of num_workers and "
"len(places).
\n
"
" 3. If run by DataLoader(dataset, ..., use_shared_memory=True),"
" set use_shared_memory=False for not using shared memory."
,
process_pid
,
strsignal
(
infop
.
si_status
)));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
...
...
@@ -147,6 +170,7 @@ void ThrowErrorIfLoadProcessFailed() {
}
}
}
}
}
}
// namespace imperative
...
...
paddle/fluid/imperative/data_loader.h
浏览文件 @
80cf3c3c
...
...
@@ -18,12 +18,13 @@
#include <unistd.h>
#include <cstdint>
#include <set>
namespace
paddle
{
namespace
imperative
{
extern
void
SetLoadProcessPID
(
int64_t
key
,
pid_t
pid
);
extern
void
EraseLoadProcessPID
(
int64_t
key
);
extern
void
SetLoadProcessPID
s
(
int64_t
key
,
std
::
set
<
pid_t
>
pids
);
extern
void
EraseLoadProcessPID
s
(
int64_t
key
);
extern
void
SetLoadProcessSignalHandler
();
extern
void
ThrowErrorIfLoadProcessFailed
();
...
...
paddle/fluid/pybind/imperative.cc
浏览文件 @
80cf3c3c
...
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include <pybind11/functional.h>
#include <pybind11/stl.h>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
...
...
@@ -290,11 +291,22 @@ void BindImperative(py::module *m_ptr) {
#ifndef _WIN32
// Dygraph DataLoader signal handler
m
.
def
(
"_set_process_pid"
,
[](
int64_t
key
,
pid_t
pid
)
{
imperative
::
SetLoadProcessPID
(
key
,
pid
);
m
.
def
(
"_set_process_pids"
,
[](
int64_t
key
,
py
::
object
&
obj
)
{
PADDLE_ENFORCE_EQ
(
py
::
isinstance
<
py
::
tuple
>
(
obj
)
||
py
::
isinstance
<
py
::
list
>
(
obj
),
true
,
platform
::
errors
::
InvalidArgument
(
"The subprocess ids set in DataLoader is illegal."
"Expected data type is tuple or list, but received %s"
,
obj
.
get_type
()));
py
::
list
pids
=
py
::
cast
<
py
::
list
>
(
obj
);
std
::
set
<
pid_t
>
pids_set
=
{};
for
(
size_t
i
=
0
;
i
<
pids
.
size
();
i
++
)
{
pids_set
.
insert
(
pids
[
i
].
cast
<
pid_t
>
());
}
imperative
::
SetLoadProcessPIDs
(
key
,
pids_set
);
});
m
.
def
(
"_erase_process_pid"
,
[](
int64_t
key
)
{
imperative
::
EraseLoadProcessPID
(
key
);
});
m
.
def
(
"_erase_process_pid
s
"
,
[](
int64_t
key
)
{
imperative
::
EraseLoadProcessPID
s
(
key
);
});
m
.
def
(
"_set_process_signal_handler"
,
[]()
{
imperative
::
SetLoadProcessSignalHandler
();
});
m
.
def
(
"_throw_error_if_process_failed"
,
...
...
python/paddle/distributed/utils.py
浏览文件 @
80cf3c3c
...
...
@@ -252,7 +252,9 @@ def get_cluster(node_ips, node_ip, paddle_ports, selected_gpus):
def
terminate_local_procs
(
procs
):
for
p
in
procs
:
if
p
.
proc
.
poll
()
is
None
:
p
.
proc
.
terminate
()
# subprocess need to release resource(e.g. shared memory)
# use join to wait subprocess releasing
p
.
proc
.
join
(
timeout
=
1
)
p
.
log_fn
.
close
()
logger
.
debug
(
"terminate process id:{}"
.
format
(
p
.
proc
.
pid
))
...
...
python/paddle/fluid/core.py
浏览文件 @
80cf3c3c
...
...
@@ -185,8 +185,8 @@ if avx_supported():
from
.core_avx
import
_load_dygraph_dict
from
.core_avx
import
_create_loaded_parameter
if
sys
.
platform
!=
'win32'
:
from
.core_avx
import
_set_process_pid
from
.core_avx
import
_erase_process_pid
from
.core_avx
import
_set_process_pid
s
from
.core_avx
import
_erase_process_pid
s
from
.core_avx
import
_set_process_signal_handler
from
.core_avx
import
_throw_error_if_process_failed
from
.core_avx
import
_convert_to_tensor_list
...
...
@@ -229,8 +229,8 @@ if load_noavx:
from
.core_noavx
import
_load_dygraph_dict
from
.core_noavx
import
_create_loaded_parameter
if
sys
.
platform
!=
'win32'
:
from
.core_noavx
import
_set_process_pid
from
.core_noavx
import
_erase_process_pid
from
.core_noavx
import
_set_process_pid
s
from
.core_noavx
import
_erase_process_pid
s
from
.core_noavx
import
_set_process_signal_handler
from
.core_noavx
import
_throw_error_if_process_failed
from
.core_noavx
import
_convert_to_tensor_list
...
...
python/paddle/fluid/dataloader/__init__.py
0 → 100644
浏览文件 @
80cf3c3c
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
from
.
import
dataset
from
.dataset
import
*
from
.
import
batch_sampler
from
.batch_sampler
import
*
__all__
=
dataset
.
__all__
\
+
batch_sampler
.
__all__
python/paddle/fluid/dataloader/batch_sampler.py
0 → 100644
浏览文件 @
80cf3c3c
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
from
__future__
import
division
import
numpy
as
np
from
.dataset
import
Dataset
__all__
=
[
"BatchSampler"
]
class
BatchSampler
(
object
):
"""
A base implement of batch sampler used by `paddle.io.DataLoader`
which yield mini-batch indices(a list/tuple with length as
mini-batch size and holds sample indices) iterably.
Batch sampler used by :code:`paddle.io.DataLoader` should be a subclass
of :code:`paddle.io.BatchSampler`, BatchSampler subclasses should
implement following methods:
:code:`__iter__`: return mini-batch indices iterably.
:code:`__len__`: get mini-batch number in an epoch.
Args:
dataset(Dataset): this could be a :code:`paddle.io.Dataset`
implement or other python object which implemented
:code:`__len__` for BatchSampler to get indices as the
range of :attr:`dataset` length. Default None.
indices (list|tuple): a substitution parameter for
:attr:`dataset` either :attr:`dataset` or
:attr:`indices` should be set, give the whole
indices to sampler from directly. Default None.
shuffle(bool): whether to shuffle indices order before genrating
batch indices. Default False.
batch_size(int): sample indice number in a mini-batch indices.
drop_last(bool): whether drop the last incomplete batch dataset size
is not divisible by the batch size. Default False
Returns:
BatchSampler: an iterable object for indices iterating
Examples:
.. code-block:: python
from paddle.io import BatchSampler, Dataset
# init with indices
bs = BatchSampler(indices=list(range(100)),
shuffle=True,
batch_size=8,
drop_last=True)
for batch_indices in bs:
print(batch_indices)
# init with dataset
class RandomDataset(Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __getitem__(self, idx):
image = np.random.random([784]).astype('float32')
label = np.random.randint(0, 9, (1, )).astype('int64')
return image, label
def __len__(self):
return self.num_samples
bs = BatchSampler(dataset=RandomDataset(100),
shuffle=False,
batch_size=16,
drop_last=False)
for batch_indices in bs:
print(batch_indices)
see `paddle.io.DataLoader`
"""
def
__init__
(
self
,
dataset
=
None
,
indices
=
None
,
shuffle
=
False
,
batch_size
=
1
,
drop_last
=
False
):
if
dataset
is
None
:
assert
indices
is
not
None
,
\
"either dataset or indices should be set"
assert
isinstance
(
indices
,
list
)
or
isinstance
(
indices
,
tuple
),
\
"indices should be a list or tuple, but got {}"
.
format
(
type
(
indices
))
self
.
indices
=
indices
else
:
assert
isinstance
(
dataset
,
Dataset
),
\
"dataset should be an instance of paddle.io.Dataset"
assert
indices
is
None
,
\
"should not set both dataset and indices"
self
.
indices
=
list
(
range
(
len
(
dataset
)))
assert
isinstance
(
batch_size
,
int
)
and
batch_size
>
0
,
\
"batch_size should be a positive integer, but got {}"
.
format
(
batch_size
)
self
.
batch_size
=
batch_size
assert
isinstance
(
shuffle
,
bool
),
\
"shuffle should be a boolean value, but got {}"
.
format
(
type
(
shuffle
))
self
.
shuffle
=
shuffle
assert
isinstance
(
drop_last
,
bool
),
\
"drop_last should be a boolean value, but got {}"
.
format
(
type
(
drop_last
))
self
.
drop_last
=
drop_last
def
__iter__
(
self
):
if
self
.
shuffle
:
np
.
random
.
shuffle
(
self
.
indices
)
_iter
=
iter
(
self
.
indices
)
batch_indices
=
[]
for
idx
in
_iter
:
batch_indices
.
append
(
idx
)
if
len
(
batch_indices
)
==
self
.
batch_size
:
yield
batch_indices
batch_indices
=
[]
if
not
self
.
drop_last
and
len
(
batch_indices
)
>
0
:
yield
batch_indices
def
__len__
(
self
):
num_samples
=
len
(
self
.
indices
)
num_samples
+=
int
(
not
self
.
drop_last
)
*
(
self
.
batch_size
-
1
)
return
num_samples
//
self
.
batch_size
python/paddle/fluid/dataloader/dataloader_iter.py
0 → 100644
浏览文件 @
80cf3c3c
此差异已折叠。
点击以展开。
python/paddle/fluid/dataloader/dataset.py
0 → 100644
浏览文件 @
80cf3c3c
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
paddle.dataset.common
__all__
=
[
"Dataset"
]
class
Dataset
(
object
):
"""
An abstract class to encapsulates methods and behaviors of datasets.
All datasets in map-style(dataset samples can be get by a given key)
should be a subclass of `paddle.io.Dataset`. All subclasses should
implement following methods:
:code:`__getitem__`: get sample from dataset with a given index. This
method is required by reading dataset sample in :code:`paddle.io.DataLoader`.
:code:`__len__`: return dataset sample number. This method is required
by some implements of :code:`paddle.io.BatchSampler`
see :code:`paddle.io.DataLoader`.
Examples:
.. code-block:: python
import numpy as np
from paddle.io import Dataset
# define a random dataset
class RandomDataset(Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __getitem__(self, idx):
image = np.random.random([784]).astype('float32')
label = np.random.randint(0, 9, (1, )).astype('int64')
return image, label
def __len__(self):
return self.num_samples
dataset = RandomDataset(10)
for i in range(len(dataset)):
print(dataset[i])
"""
def
__init__
(
self
):
pass
def
__getitem__
(
self
,
idx
):
raise
NotImplementedError
(
"'{}' not implement in class "
\
"{}"
.
format
(
'__getitem__'
,
self
.
__class__
.
__name__
))
def
__len__
(
self
):
raise
NotImplementedError
(
"'{}' not implement in class "
\
"{}"
.
format
(
'__len__'
,
self
.
__class__
.
__name__
))
python/paddle/fluid/io.py
浏览文件 @
80cf3c3c
...
...
@@ -37,6 +37,8 @@ from paddle.fluid.compiler import CompiledProgram
from
paddle.fluid.log_helper
import
get_logger
from
.
import
reader
from
.reader
import
*
from
.
import
dataloader
from
.dataloader
import
*
from
.
import
core
from
..
import
compat
as
cpt
...
...
python/paddle/fluid/multiprocess_utils.py
0 → 100644
浏览文件 @
80cf3c3c
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
six
import
sys
import
signal
import
atexit
from
.
import
core
# NOTE: queue has a different name in python2 and python3
if
six
.
PY2
:
import
Queue
as
queue
else
:
import
queue
# NOTE: [ mmap files clear ] If there is still data in the multiprocess queue when the main process finishes reading,
# the data in the queue needs to be popped. Then the LoDTensor read by the main process
# from the child process will automatically clear the memory-mapped file.
multiprocess_queue_set
=
set
()
def
_clear_multiprocess_queue_set
():
global
multiprocess_queue_set
for
data_queue
in
multiprocess_queue_set
:
while
True
:
try
:
data_queue
.
get_nowait
()
except
queue
.
Empty
:
break
# NOTE: main process clear function at exit
def
_cleanup
():
# NOTE: inter-process Queue shared memory objects clear function
_clear_multiprocess_queue_set
()
# NOTE: main process memory map files clear funciton
core
.
_cleanup_mmap_fds
()
# NOTE: for child process clear function at exit
def
_cleanup_mmap
():
# clear memory map files in child process
core
.
_cleanup_mmap_fds
()
# NOTE used for register a function to be executed at interpreter exit.
class
CleanupFuncRegistrar
():
# Record the cleanup functions that have been executed
_executed_func_set
=
set
()
# Record the cleanup functions that have been registered
_registered_func_set
=
set
()
@
classmethod
def
register
(
cls
,
function
,
signals
=
[]):
def
_func_exectuor
():
if
function
not
in
cls
.
_executed_func_set
:
try
:
function
()
finally
:
cls
.
_executed_func_set
.
add
(
function
)
def
_func_register
(
function
):
if
not
callable
(
function
):
raise
TypeError
(
"%s is not callable object."
%
(
function
))
# check function object whether hash-able
set
([
function
])
if
function
not
in
cls
.
_registered_func_set
:
atexit
.
register
(
_func_exectuor
)
cls
.
_registered_func_set
.
add
(
function
)
def
_signal_handler
(
signum
=
None
,
frame
=
None
):
_func_exectuor
()
if
signum
is
not
None
:
if
signum
==
signal
.
SIGINT
:
raise
KeyboardInterrupt
sys
.
exit
(
signum
)
def
_signal_register
(
signals
):
signals
=
set
(
signals
)
for
sig
in
signals
:
orig_handler
=
signal
.
signal
(
sig
,
_signal_handler
)
if
orig_handler
not
in
(
signal
.
SIG_DFL
,
signal
.
SIG_IGN
):
if
(
sig
==
signal
.
SIGINT
and
orig_handler
is
signal
.
default_int_handler
):
continue
if
orig_handler
not
in
cls
.
_registered_func_set
:
atexit
.
register
(
orig_handler
)
cls
.
_registered_func_set
.
add
(
orig_handler
)
# deal with signals
_signal_register
(
signals
)
# deal with function
_func_register
(
function
)
# NOTE: [ mmap files clear ] When the main process exits unexpectedly, the remaining
# shared memory objects in the inter-process Queue and the main process (mostly in the
# BlockingQueue) may not be completely released, resulting in the corresponding
# memory-mapped file remaining on the disk (/dev/shm), so register this function
# to clean up shared memory objects in these two queues before the python interpreter exits.
# NOTE: Currently multi-process DataLoader only supports Linux platform
if
not
(
sys
.
platform
==
'darwin'
or
sys
.
platform
==
'win32'
):
CleanupFuncRegistrar
.
register
(
_cleanup
)
# ------------ SIGCHLD handler setting --------------
_SIGCHLD_handler_set
=
False
def
_set_SIGCHLD_handler
():
global
_SIGCHLD_handler_set
if
_SIGCHLD_handler_set
:
return
current_handler
=
signal
.
getsignal
(
signal
.
SIGCHLD
)
if
not
callable
(
current_handler
):
current_handler
=
None
def
__handler__
(
signum
,
frame
):
# NOTE: Here the signum is SIGCHLD, when the child process exits,
# this handler will be called whenever the child process exits
# normally or abnormally.
core
.
_throw_error_if_process_failed
()
if
current_handler
is
not
None
:
current_handler
(
signum
,
frame
)
signal
.
signal
(
signal
.
SIGCHLD
,
__handler__
)
_SIGCHLD_handler_set
=
True
python/paddle/fluid/reader.py
浏览文件 @
80cf3c3c
...
...
@@ -21,29 +21,28 @@ import paddle
from
.framework
import
Program
,
Variable
,
program_guard
,
default_main_program
,
default_startup_program
,
in_dygraph_mode
,
cpu_places
from
.executor
import
global_scope
from
.data_feeder
import
DataFeeder
,
BatchedTensorProvider
from
.multiprocess_utils
import
multiprocess_queue_set
,
CleanupFuncRegistrar
,
_cleanup_mmap
,
_cleanup
,
_set_SIGCHLD_handler
from
.dataloader
import
BatchSampler
,
Dataset
from
.dataloader.dataloader_iter
import
_DataLoaderIterSingleProcess
,
_DataLoaderIterMultiProcess
from
.layers.io
import
monkey_patch_reader_methods
,
_copy_reader_var_
,
double_buffer
from
.unique_name
import
UniqueNameGenerator
import
logging
from
.dataset
import
DatasetBase
,
InMemoryDataset
### Dygraph DataLoader configs ###
import
atexit
import
os
import
multiprocessing
import
signal
# NOTE: queue has a different name in python2 and python3
if
s
ys
.
version_info
[
0
]
==
2
:
if
s
ix
.
PY
2
:
import
Queue
as
queue
else
:
import
queue
# NOTE: [ avoid hanging & failed quickly ] These value is used in getting data from another process
QUEUE_GET_TIMEOUT
=
60
# NOTE: [ mmap files clear ] If there is still data in the multiprocess queue when the main process finishes reading,
# the data in the queue needs to be popped. Then the LoDTensor read by the main process
# from the child process will automatically clear the memory-mapped file.
multiprocess_queue_set
=
set
()
__all__
=
[
'PyReader'
,
'DataLoader'
]
data_loader_unique_name_generator
=
UniqueNameGenerator
()
...
...
@@ -75,84 +74,6 @@ def _convert_places(places):
return
ret
def
_clear_multiprocess_queue_set
():
global
multiprocess_queue_set
for
data_queue
in
multiprocess_queue_set
:
while
True
:
try
:
data_queue
.
get_nowait
()
except
queue
.
Empty
:
break
# NOTE: main process clear function at exit
def
_cleanup
():
# NOTE: inter-process Queue shared memory objects clear function
_clear_multiprocess_queue_set
()
# NOTE: main process memory map files clear funciton
core
.
_cleanup_mmap_fds
()
# NOTE used for register a function to be executed at interpreter exit.
class
CleanupFuncRegistrar
():
# Record the cleanup functions that have been executed
_executed_func_set
=
set
()
# Record the cleanup functions that have been registered
_registered_func_set
=
set
()
@
classmethod
def
register
(
cls
,
function
,
signals
=
[
signal
.
SIGTERM
]):
def
_func_exectuor
():
if
function
not
in
cls
.
_executed_func_set
:
try
:
function
()
finally
:
cls
.
_executed_func_set
.
add
(
function
)
def
_func_register
(
function
):
if
not
callable
(
function
):
raise
TypeError
(
"%s is not callable object."
%
(
function
))
# check function object whether hash-able
set
([
function
])
if
function
not
in
cls
.
_registered_func_set
:
atexit
.
register
(
_func_exectuor
)
cls
.
_registered_func_set
.
add
(
function
)
def
_signal_handler
(
signum
=
None
,
frame
=
None
):
_func_exectuor
()
if
signum
is
not
None
:
if
signum
==
signal
.
SIGINT
:
raise
KeyboardInterrupt
sys
.
exit
(
signum
)
def
_signal_register
(
signals
):
signals
=
set
(
signals
)
for
sig
in
signals
:
orig_handler
=
signal
.
signal
(
sig
,
_signal_handler
)
if
orig_handler
not
in
(
signal
.
SIG_DFL
,
signal
.
SIG_IGN
):
if
(
sig
==
signal
.
SIGINT
and
orig_handler
is
signal
.
default_int_handler
):
continue
if
orig_handler
not
in
cls
.
_registered_func_set
:
atexit
.
register
(
orig_handler
)
cls
.
_registered_func_set
.
add
(
orig_handler
)
# deal with signals
_signal_register
(
signals
)
# deal with function
_func_register
(
function
)
# NOTE: [ mmap files clear ] When the main process exits unexpectedly, the remaining
# shared memory objects in the inter-process Queue and the main process (mostly in the
# BlockingQueue) may not be completely released, resulting in the corresponding
# memory-mapped file remaining on the disk (/dev/shm), so register this function
# to clean up shared memory objects in these two queues before the python interpreter exits.
# NOTE: Currently multi-process DataLoader only supports Linux platform
if
not
(
sys
.
platform
==
'darwin'
or
sys
.
platform
==
'win32'
):
CleanupFuncRegistrar
.
register
(
_cleanup
)
class
DataLoaderBase
(
object
):
def
__init__
(
self
):
self
.
_places
=
None
...
...
@@ -177,6 +98,264 @@ class DataLoaderBase(object):
class
DataLoader
(
object
):
"""
DataLoader prodives an iterator which iterates given dataset
once by the batch_sampler.
DataLoader supports single-process and multi-prcess data loading,
multi-process workers will be used to load data asynchronously if
:attr:`num_workers` is set as a positive number.
DataLoader only supports map-style dataset(can get a sample from
dataset with a given index) currently, for a map-style dataset,
please see :code:`paddle.io.Dataset`.
batch_sampler please see :code:`paddle.io.BatchSampler`
Args:
dataset(Dataset): the dataset to load data from, should be an
instance of subclass of :code:`paddle.io.Dataset`.
feed_list (list(Variable)|tuple(Variable)): feed variable list.
The variables should be created by :code:`fluid.data()`.
:attr:`feed_list` must be set if :attr:`return_list` is
False. Default None.
places(list(Place)|tuple(Place)): a list of Place, to put data
onto, :attr:`places` must be set in both static graph and
dynamic graph mode, in dynamic graph mode, place number must
be 1. Default None.
return_list (bool): whether the return value on each device is
presented as a list. If :attr:`return_list=False`, the return
value on each device would be a dict of str -> LoDTensor, where
the key of the dict is the name of each fed variables. If
:attr:`return_list=True`, the return value on each device would
be a list(LoDTensor). :attr:`return_list` can only be True
in dynamic graph mode. Default False.
batch_sampler(BatchSampler): an instance of `paddle.io.BatchSampler`
to generate batch indices to draw samples from :attr:`dataset`
and combine a batch. Default None.
batch_size(int): sample number in a mini-batch, a substitution
parameter for :attr:`batch_sampler`, if :attr:`batch_sampler`
is not set, a default `paddle.io.BatchSampler` will be used
and initialize by :attr:`batch_size`, :attr:`shuffle` and
:attr:`drop_last`. Default 1.
shuffle(bool): whther to shuffle indices order before genrate
batch indices, a substitution parameter for :attr:`batch_sampler`
see :attr:`batch_size`. Default False.
drop_last(bool): whether drop the last incomplete batch dataset size
is not divisible by the batch size, a substitution parameter
for :attr:`batch_sampler`, see :attr:`batch_size`. Default False
collate_fn(callable): function to generate mini-batch data by merging
the sample list, None for only stack each fields of sample in axis
0(same as :attr::`np.stack(..., axis=0)`). Default None
num_workers(int): the number of subprocess to load data, 0 for no
subprocess used and loading data in main process. Default 0
use_buffer_reader (bool): whether to use bufferred reader.
If use_buffer_reader=True, the DataLoader would prefetch next
batch data asynchronously, so it would speed up data feeding
and occupies a little more CPU or GPU memory, i.e., the memory
of one batch input data. Default True.
use_shared_memory (bool): whether to use shared memory to speed up
putting data into inter-process queue, set :attr:`use_shared_memory`
as True only when the shared memory space on your machine(e.g.
space of '/dev/shm' on Linux operating sysytem) is large enough.
Shared memory will only be enabled in multi-process mode(num_workers
> 0). Default True.
timeout(int): the timeout value for getting data form output queue
of subprocesses. Default 0.
worker_init_fn(callable): init function which will be called with
worker id on each subproces starting if not set as None. Default
None.
Returns:
DataLoader: an iterable object for data iterating
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
from paddle.io import Dataset, BatchSampler, DataLoader
BATCH_NUM = 20
BATCH_SIZE = 16
EPOCH_NUM = 4
IMAGE_SIZE = 784
CLASS_NUM = 10
USE_GPU = False # whether use GPU to run model
# define a random dataset
class RandomDataset(Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __getitem__(self, idx):
image = np.random.random([IMAGE_SIZE]).astype('float32')
label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
return image, label
def __len__(self):
return self.num_samples
# get places
places = fluid.cuda_places() if USE_GPU else fluid.cpu_places()
# -------------------- static graph ---------------------
def simple_net(image, label):
fc_tmp = fluid.layers.fc(image, size=CLASS_NUM, act='softmax')
cross_entropy = fluid.layers.softmax_with_cross_entropy(image, label)
loss = fluid.layers.reduce_mean(cross_entropy)
sgd = fluid.optimizer.SGD(learning_rate=1e-3)
sgd.minimize(loss)
return loss
image = fluid.data(name='image', shape=[None, IMAGE_SIZE], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
loss = simple_net(image, label)
exe = fluid.Executor(places[0])
exe.run(fluid.default_startup_program())
prog = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name)
dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
loader = DataLoader(dataset,
feed_list=[image, label],
places=places,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=True,
num_workers=2)
for e in range(EPOCH_NUM):
for i, data in enumerate(loader()):
l = exe.run(prog, feed=data, fetch_list=[loss], return_numpy=True)
print("Epoch {} batch {}: loss = {}".format(e, i, l[0][0]))
# -------------------------------------------------------
# --------------------- dygraph mode --------------------
class SimpleNet(fluid.dygraph.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc = fluid.dygraph.nn.Linear(IMAGE_SIZE, CLASS_NUM, act='softmax')
def forward(self, image, label=None):
return self.fc(image)
with fluid.dygraph.guard(places[0]):
simple_net = SimpleNet()
opt = fluid.optimizer.SGD(learning_rate=1e-3,
parameter_list=simple_net.parameters())
loader = DataLoader(dataset,
places=places[0],
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=True,
num_workers=2)
for e in range(EPOCH_NUM):
for i, (image, label) in enumerate(loader()):
out = simple_net(image)
loss = fluid.layers.cross_entropy(out, label)
avg_loss = fluid.layers.reduce_mean(loss)
avg_loss.backward()
opt.minimize(avg_loss)
simple_net.clear_gradients()
print("Epoch {} batch {}: loss = {}".format(e, i, np.mean(loss.numpy())))
# -------------------------------------------------------
"""
def
__init__
(
self
,
dataset
,
feed_list
=
None
,
places
=
None
,
return_list
=
False
,
batch_sampler
=
None
,
batch_size
=
1
,
shuffle
=
False
,
drop_last
=
False
,
collate_fn
=
None
,
num_workers
=
0
,
use_buffer_reader
=
True
,
use_shared_memory
=
True
,
timeout
=
0
,
worker_init_fn
=
None
):
self
.
return_list
=
return_list
self
.
collate_fn
=
collate_fn
self
.
use_buffer_reader
=
use_buffer_reader
self
.
worker_init_fn
=
worker_init_fn
assert
isinstance
(
dataset
,
Dataset
),
\
"dataset should be subclass instance of paddle.io.Dataset"
self
.
dataset
=
dataset
if
not
return_list
and
not
in_dygraph_mode
():
assert
feed_list
is
not
None
,
\
"feed_list should be set when return_list=False"
self
.
feed_list
=
feed_list
assert
places
is
not
None
,
"places cannot be None"
self
.
places
=
_convert_places
(
places
)
if
in_dygraph_mode
():
assert
len
(
self
.
places
)
==
1
,
\
"Number of places must be 1 in dygraph mode"
assert
num_workers
>=
0
,
"num_workers should be a non-negative value"
if
num_workers
>
0
and
(
sys
.
platform
==
'darwin'
or
sys
.
platform
==
'win32'
):
logging
.
warning
(
"multi-process mode not support MacOs and Windows currently."
\
" use signle-process with num_workers = 0 instead"
)
num_workers
=
0
self
.
num_workers
=
num_workers
self
.
use_shared_memory
=
use_shared_memory
if
use_shared_memory
and
num_workers
==
0
:
self
.
use_shared_memory
=
False
assert
timeout
>=
0
,
"timeout should be a non-negative value"
self
.
timeout
=
timeout
if
batch_sampler
is
not
None
:
assert
isinstance
(
batch_sampler
,
BatchSampler
),
\
"batch_sampler should be None or subclass instance "
\
"of paddle.io.BatchSampler"
assert
batch_size
==
1
and
not
shuffle
and
not
drop_last
,
\
"batch_size/shuffle/drop_last should not be set when "
\
"batch_sampler is given"
self
.
batch_sampler
=
batch_sampler
else
:
assert
batch_size
is
not
None
and
batch_size
>
0
,
\
"batch_size should be a positive value when "
\
"batch_sampler is not given"
self
.
batch_sampler
=
BatchSampler
(
dataset
=
dataset
,
batch_size
=
batch_size
,
shuffle
=
shuffle
,
drop_last
=
drop_last
)
def
__len__
(
self
):
return
len
(
self
.
batch_sampler
)
def
__iter__
(
self
):
if
self
.
num_workers
==
0
:
return
_DataLoaderIterSingleProcess
(
self
)
else
:
return
_DataLoaderIterMultiProcess
(
self
)
def
__call__
(
self
):
return
self
.
__iter__
()
@
staticmethod
def
from_generator
(
feed_list
=
None
,
capacity
=
None
,
...
...
@@ -553,22 +732,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
if
process
is
not
None
:
process
.
join
()
# erase process id
core
.
_erase_process_pid
(
id
(
self
))
def
_set_child_signal_handler
(
self
):
core
.
_set_process_pid
(
id
(
self
),
self
.
_process
.
pid
)
current_handler
=
signal
.
getsignal
(
signal
.
SIGCHLD
)
if
not
callable
(
current_handler
):
current_handler
=
None
def
__handler__
(
signum
,
frame
):
# NOTE: Here the signum is SIGDHLD, when the child process exits, this handler
# will be called whenever the child process exits normally or abnormally.
core
.
_throw_error_if_process_failed
()
if
current_handler
is
not
None
:
current_handler
(
signum
,
frame
)
signal
.
signal
(
signal
.
SIGCHLD
,
__handler__
)
core
.
_erase_process_pids
(
id
(
self
))
def
_init_iterable
(
self
):
self
.
_wait_thread_ends
()
...
...
@@ -605,7 +769,8 @@ class DygraphGeneratorLoader(DataLoaderBase):
# with SIGSEGV and SIGBUS of child process; 2. if the main process end before child
# process, it shuts the all its daemonic children down with a SIGTERM (instead of
# joining them without a timeout), so here nedd to deal with SIGTERM.
self
.
_set_child_signal_handler
()
core
.
_set_process_pids
(
id
(
self
),
[
self
.
_process
.
pid
])
_set_SIGCHLD_handler
()
# Set reader_thread
self
.
_thread_done_event
=
threading
.
Event
()
...
...
@@ -666,16 +831,11 @@ class DygraphGeneratorLoader(DataLoaderBase):
# set signal handler
core
.
_set_process_signal_handler
()
# child process clear function at exit
def
_cleanup
():
# clear memory map files in child process
core
.
_cleanup_mmap_fds
()
# NOTE: [ mmap files clear ] When the child process exits unexpectedly,
# some shared memory objects may have been applied for but have not yet
# been put into the inter-process Queue. This part of the object needs
# to be cleaned up when the process ends.
CleanupFuncRegistrar
.
register
(
_cleanup
)
CleanupFuncRegistrar
.
register
(
_cleanup
_mmap
)
for
batch
in
self
.
_batch_reader
():
tensor_list
=
core
.
_convert_to_tensor_list
(
batch
)
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
80cf3c3c
...
...
@@ -211,6 +211,8 @@ if (APPLE OR WIN32)
list
(
REMOVE_ITEM TEST_OPS test_imperative_data_loader_fds_clear
)
list
(
REMOVE_ITEM TEST_OPS test_imperative_data_loader_exit_func
)
list
(
REMOVE_ITEM TEST_OPS test_imperative_signal_handler
)
list
(
REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_base
)
list
(
REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_exception
)
endif
()
if
(
NOT WITH_GPU OR WIN32 OR APPLE
)
...
...
@@ -381,4 +383,6 @@ if(NOT WIN32 AND NOT APPLE)
set_tests_properties
(
test_imperative_data_loader_base PROPERTIES LABELS
"RUN_TYPE=EXCLUSIVE"
RUN_SERIAL TRUE
)
set_tests_properties
(
test_imperative_data_loader_fds_clear PROPERTIES LABELS
"RUN_TYPE=EXCLUSIVE"
RUN_SERIAL TRUE
)
# set_tests_properties(test_imperative_data_loader_exception PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE)
set_tests_properties
(
test_multiprocess_dataloader_base PROPERTIES LABELS
"RUN_TYPE=EXCLUSIVE"
RUN_SERIAL TRUE
)
set_tests_properties
(
test_multiprocess_dataloader_exception PROPERTIES LABELS
"RUN_TYPE=EXCLUSIVE"
RUN_SERIAL TRUE
)
endif
()
python/paddle/fluid/tests/unittests/test_batch_sampler.py
0 → 100644
浏览文件 @
80cf3c3c
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
import
unittest
import
paddle.fluid
as
fluid
from
paddle.io
import
BatchSampler
,
Dataset
class
RandomDataset
(
Dataset
):
def
__init__
(
self
,
sample_num
,
class_num
):
self
.
sample_num
=
sample_num
self
.
class_num
=
class_num
def
__getitem__
(
self
,
idx
):
np
.
random
.
seed
(
idx
)
image
=
np
.
random
.
random
([
IMAGE_SIZE
]).
astype
(
'float32'
)
label
=
np
.
random
.
randint
(
0
,
CLASS_NUM
-
1
,
(
1
,
)).
astype
(
'int64'
)
return
image
,
label
def
__len__
(
self
):
return
self
.
sample_num
class
TestBatchSampler
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
num_samples
=
1000
self
.
num_classes
=
10
self
.
batch_size
=
32
self
.
shuffle
=
False
self
.
drop_last
=
False
def
init_batch_sampler
(
self
):
dataset
=
RandomDataset
(
self
.
num_samples
,
self
.
num_classes
)
bs
=
BatchSampler
(
dataset
=
dataset
,
batch_size
=
self
.
batch_size
,
shuffle
=
self
.
shuffle
,
drop_last
=
self
.
drop_last
)
return
bs
def
test_main
(
self
):
bs
=
self
.
init_batch_sampler
()
# length check
bs_len
=
(
self
.
num_samples
+
int
(
not
self
.
drop_last
)
\
*
(
self
.
batch_size
-
1
))
//
self
.
batch_size
self
.
assertTrue
(
bs_len
==
len
(
bs
))
# output indices check
if
not
self
.
shuffle
:
index
=
0
for
indices
in
bs
:
for
idx
in
indices
:
self
.
assertTrue
(
index
==
idx
)
index
+=
1
class
TestBatchSamplerDropLast
(
TestBatchSampler
):
def
setUp
(
self
):
self
.
num_samples
=
1000
self
.
num_classes
=
10
self
.
batch_size
=
32
self
.
shuffle
=
False
self
.
drop_last
=
True
class
TestBatchSamplerShuffle
(
TestBatchSampler
):
def
setUp
(
self
):
self
.
num_samples
=
1000
self
.
num_classes
=
10
self
.
batch_size
=
32
self
.
shuffle
=
True
self
.
drop_last
=
True
class
TestBatchSamplerWithIndices
(
TestBatchSampler
):
def
init_batch_sampler
(
self
):
bs
=
BatchSampler
(
indices
=
list
(
range
(
self
.
num_samples
)),
batch_size
=
self
.
batch_size
,
drop_last
=
self
.
drop_last
)
return
bs
class
TestBatchSamplerWithIndicesAndDataSource
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
num_samples
=
1000
self
.
num_classes
=
10
self
.
batch_size
=
32
self
.
shuffle
=
False
self
.
drop_last
=
True
def
test_main
(
self
):
try
:
dataset
=
RandomDataset
(
self
.
num_samples
,
self
.
num_classes
)
bs
=
BatchSampler
(
dataset
=
dataset
,
indices
=
list
(
range
(
self
.
num_samples
)),
batch_size
=
self
.
batch_size
,
drop_last
=
self
.
drop_last
)
self
.
assertTrue
(
False
)
except
AssertionError
:
pass
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_dataloader_dataset.py
0 → 100644
浏览文件 @
80cf3c3c
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
import
unittest
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.io
import
*
class
TestDatasetAbstract
(
unittest
.
TestCase
):
def
test_main
(
self
):
dataset
=
Dataset
()
try
:
d
=
dataset
[
0
]
self
.
assertTrue
(
False
)
except
NotImplementedError
:
pass
try
:
l
=
len
(
dataset
)
self
.
assertTrue
(
False
)
except
NotImplementedError
:
pass
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_imperative_data_loader_fds_clear.py
浏览文件 @
80cf3c3c
...
...
@@ -17,6 +17,7 @@ import unittest
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.fluid
import
core
from
paddle.io
import
Dataset
,
DataLoader
def
get_random_images_and_labels
(
image_shape
,
label_shape
):
...
...
@@ -35,6 +36,20 @@ def batch_generator_creator(batch_size, batch_num):
return
__reader__
class
RandomDataset
(
Dataset
):
def
__init__
(
self
,
sample_num
):
self
.
sample_num
=
sample_num
def
__getitem__
(
self
,
idx
):
np
.
random
.
seed
(
idx
)
image
=
np
.
random
.
random
([
784
]).
astype
(
'float32'
)
label
=
np
.
random
.
randint
(
0
,
9
,
(
1
,
)).
astype
(
'int64'
)
return
image
,
label
def
__len__
(
self
):
return
self
.
sample_num
class
TestDygraphDataLoaderMmapFdsClear
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
batch_size
=
8
...
...
@@ -74,5 +89,19 @@ class TestDygraphDataLoaderMmapFdsClear(unittest.TestCase):
self
.
run_one_epoch_with_break
(
loader
)
class
TestMultiProcessDataLoaderMmapFdsClear
(
TestDygraphDataLoaderMmapFdsClear
):
def
prepare_data_loader
(
self
):
place
=
fluid
.
CPUPlace
()
with
fluid
.
dygraph
.
guard
(
place
):
dataset
=
RandomDataset
(
self
.
batch_size
*
self
.
batch_num
)
loader
=
DataLoader
(
dataset
,
places
=
place
,
batch_size
=
self
.
batch_size
,
drop_last
=
True
,
num_workers
=
2
)
return
loader
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_imperative_signal_handler.py
浏览文件 @
80cf3c3c
...
...
@@ -24,7 +24,7 @@ from paddle.fluid import core
def
set_child_signal_handler
(
self
,
child_pid
):
core
.
_set_process_pid
(
id
(
self
),
child_pid
)
core
.
_set_process_pid
s
(
id
(
self
),
tuple
([
child_pid
])
)
current_handler
=
signal
.
getsignal
(
signal
.
SIGCHLD
)
if
not
callable
(
current_handler
):
current_handler
=
None
...
...
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_base.py
0 → 100644
浏览文件 @
80cf3c3c
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
import
os
import
sys
import
six
import
time
import
unittest
import
multiprocessing
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.io
import
Dataset
,
BatchSampler
,
DataLoader
from
paddle.fluid.dygraph.nn
import
Linear
from
paddle.fluid.dygraph.base
import
to_variable
EPOCH_NUM
=
5
BATCH_SIZE
=
16
IMAGE_SIZE
=
784
SAMPLE_NUM
=
400
CLASS_NUM
=
10
class
RandomDataset
(
Dataset
):
def
__init__
(
self
,
sample_num
,
class_num
):
self
.
sample_num
=
sample_num
self
.
class_num
=
class_num
def
__getitem__
(
self
,
idx
):
np
.
random
.
seed
(
idx
)
image
=
np
.
random
.
random
([
IMAGE_SIZE
]).
astype
(
'float32'
)
label
=
np
.
random
.
randint
(
0
,
self
.
class_num
-
1
,
(
1
,
)).
astype
(
'int64'
)
return
image
,
label
def
__len__
(
self
):
return
self
.
sample_num
def
simple_fc_net_static
():
startup_prog
=
fluid
.
Program
()
main_prog
=
fluid
.
Program
()
startup_prog
.
random_seed
=
1
main_prog
.
random_seed
=
1
with
fluid
.
unique_name
.
guard
():
with
fluid
.
program_guard
(
main_prog
,
startup_prog
):
image
=
fluid
.
data
(
name
=
'image'
,
shape
=
[
None
,
IMAGE_SIZE
],
dtype
=
'float32'
)
label
=
fluid
.
data
(
name
=
'label'
,
shape
=
[
None
,
1
],
dtype
=
'int64'
)
hidden
=
image
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.8
))
bias_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.5
))
for
hidden_size
in
[
10
,
20
,
30
]:
hidden
=
fluid
.
layers
.
fc
(
hidden
,
size
=
hidden_size
,
act
=
'tanh'
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
)
predict_label
=
fluid
.
layers
.
fc
(
hidden
,
size
=
CLASS_NUM
,
act
=
'softmax'
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
)
loss
=
fluid
.
layers
.
reduce_mean
(
fluid
.
layers
.
cross_entropy
(
input
=
predict_label
,
label
=
label
))
optimizer
=
fluid
.
optimizer
.
Adam
()
optimizer
.
minimize
(
loss
)
return
startup_prog
,
main_prog
,
image
,
label
,
loss
class
SimpleFCNet
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
):
super
(
SimpleFCNet
,
self
).
__init__
()
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.8
))
bias_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.5
))
self
.
_fcs
=
[]
in_channel
=
IMAGE_SIZE
for
hidden_size
in
[
10
,
20
,
30
]:
self
.
_fcs
.
append
(
Linear
(
in_channel
,
hidden_size
,
act
=
'tanh'
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
))
in_channel
=
hidden_size
self
.
_fcs
.
append
(
Linear
(
in_channel
,
CLASS_NUM
,
act
=
'softmax'
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
))
def
forward
(
self
,
image
):
out
=
image
for
fc
in
self
.
_fcs
:
out
=
fc
(
out
)
return
out
class
TestStaticDataLoader
(
unittest
.
TestCase
):
def
run_main
(
self
,
num_workers
,
places
,
with_data_parallel
):
scope
=
fluid
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
startup_prog
,
main_prog
,
image
,
label
,
loss
=
simple_fc_net_static
()
dataset
=
RandomDataset
(
SAMPLE_NUM
,
CLASS_NUM
)
dataloader
=
DataLoader
(
dataset
,
feed_list
=
[
image
,
label
],
places
=
places
,
num_workers
=
num_workers
,
batch_size
=
BATCH_SIZE
,
drop_last
=
True
)
assert
len
(
dataloader
)
==
int
(
SAMPLE_NUM
/
BATCH_SIZE
)
exe
=
fluid
.
Executor
(
place
=
places
[
0
])
exe
.
run
(
startup_prog
)
prog
=
fluid
.
CompiledProgram
(
main_prog
)
if
with_data_parallel
:
prog
=
prog
.
with_data_parallel
(
loss_name
=
loss
.
name
,
places
=
places
)
step_list
=
[]
loss_list
=
[]
start_t
=
time
.
time
()
for
_
in
six
.
moves
.
range
(
EPOCH_NUM
):
step
=
0
for
d
in
dataloader
:
assert
len
(
d
)
==
len
(
places
),
"{} != {}"
.
format
(
len
(
d
),
len
(
places
))
for
i
,
item
in
enumerate
(
d
):
image
=
item
[
'image'
]
label
=
item
[
'label'
]
assert
image
.
shape
()
==
[
BATCH_SIZE
,
IMAGE_SIZE
]
assert
label
.
shape
()
==
[
BATCH_SIZE
,
1
]
assert
image
.
_place
().
_equals
(
places
[
i
])
assert
label
.
_place
().
_equals
(
places
[
i
])
L
,
=
exe
.
run
(
program
=
prog
,
feed
=
d
,
fetch_list
=
[
loss
],
use_program_cache
=
True
)
loss_list
.
append
(
np
.
mean
(
L
))
step
+=
1
step_list
.
append
(
step
)
end_t
=
time
.
time
()
ret
=
{
"time"
:
end_t
-
start_t
,
"step"
:
step_list
,
"loss"
:
np
.
array
(
loss_list
)
}
print
(
"time cost"
,
ret
[
'time'
],
'step_list'
,
ret
[
'step'
])
return
ret
def
prepare_places
(
self
,
with_data_parallel
,
with_cpu
=
True
,
with_gpu
=
True
):
places
=
[]
# FIXME: PR_CI_Py35 may hang on Multi-CPUs with multiprocess, but it
# works fine locally, this should be fixed. OTOH, multiprocessing
# is not recommended when running on CPU generally
if
with_cpu
and
not
sys
.
version
.
startswith
(
'3.5'
):
places
.
append
([
fluid
.
CPUPlace
()])
if
with_data_parallel
:
places
.
append
([
fluid
.
CPUPlace
()]
*
2
)
if
with_gpu
and
fluid
.
core
.
is_compiled_with_cuda
():
tmp
=
fluid
.
cuda_places
()[:
2
]
assert
len
(
tmp
)
>
0
,
"no gpu detected"
if
with_data_parallel
:
places
.
append
(
tmp
)
places
.
append
([
tmp
[
0
]])
return
places
def
test_main
(
self
):
for
with_data_parallel
in
[
False
]
if
self
.
__class__
.
__name__
\
==
"TestDygraphDataLoader"
else
[
True
,
False
]:
for
p
in
self
.
prepare_places
(
with_data_parallel
):
results
=
[]
for
num_workers
in
[
0
,
2
]:
print
(
self
.
__class__
.
__name__
,
p
,
num_workers
)
ret
=
self
.
run_main
(
num_workers
=
num_workers
,
places
=
p
,
with_data_parallel
=
with_data_parallel
)
results
.
append
(
ret
)
diff
=
np
.
max
(
np
.
abs
(
results
[
0
][
'loss'
]
-
results
[
1
][
'loss'
])
/
np
.
abs
(
results
[
0
][
'loss'
]))
self
.
assertLess
(
diff
,
1e-2
)
class
TestDygraphDataLoader
(
TestStaticDataLoader
):
def
run_main
(
self
,
num_workers
,
places
,
with_data_parallel
):
fluid
.
default_startup_program
().
random_seed
=
1
fluid
.
default_main_program
().
random_seed
=
1
with
fluid
.
dygraph
.
guard
(
places
[
0
]):
fc_net
=
SimpleFCNet
()
optimizer
=
fluid
.
optimizer
.
Adam
(
parameter_list
=
fc_net
.
parameters
())
dataset
=
RandomDataset
(
SAMPLE_NUM
,
CLASS_NUM
)
dataloader
=
DataLoader
(
dataset
,
places
=
places
,
num_workers
=
num_workers
,
batch_size
=
BATCH_SIZE
,
drop_last
=
True
)
assert
len
(
dataloader
)
==
int
(
SAMPLE_NUM
/
BATCH_SIZE
)
step_list
=
[]
loss_list
=
[]
start_t
=
time
.
time
()
for
_
in
six
.
moves
.
range
(
EPOCH_NUM
):
step
=
0
for
image
,
label
in
dataloader
():
out
=
fc_net
(
image
)
loss
=
fluid
.
layers
.
cross_entropy
(
out
,
label
)
avg_loss
=
fluid
.
layers
.
reduce_mean
(
loss
)
avg_loss
.
backward
()
optimizer
.
minimize
(
avg_loss
)
fc_net
.
clear_gradients
()
loss_list
.
append
(
np
.
mean
(
avg_loss
.
numpy
()))
step
+=
1
step_list
.
append
(
step
)
end_t
=
time
.
time
()
ret
=
{
"time"
:
end_t
-
start_t
,
"step"
:
step_list
,
"loss"
:
np
.
array
(
loss_list
)
}
print
(
"time cost"
,
ret
[
'time'
],
'step_list'
,
ret
[
'step'
])
return
ret
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py
0 → 100644
浏览文件 @
80cf3c3c
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
import
os
import
sys
import
six
import
time
import
unittest
import
multiprocessing
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.io
import
Dataset
,
BatchSampler
,
DataLoader
from
paddle.fluid.dygraph.nn
import
Linear
from
paddle.fluid.dygraph.base
import
to_variable
class
RandomDataset
(
Dataset
):
def
__init__
(
self
,
sample_num
):
self
.
sample_num
=
sample_num
def
__getitem__
(
self
,
idx
):
np
.
random
.
seed
(
idx
)
image
=
np
.
random
.
random
([
784
]).
astype
(
'float32'
)
label
=
np
.
random
.
randint
(
0
,
9
,
(
1
,
)).
astype
(
'int64'
)
return
image
,
label
def
__len__
(
self
):
return
self
.
sample_num
class
TestDataLoaderAssert
(
unittest
.
TestCase
):
def
test_main
(
self
):
place
=
fluid
.
cpu_places
()[
0
]
with
fluid
.
dygraph
.
guard
(
place
):
dataset
=
RandomDataset
(
100
)
batch_sampler
=
BatchSampler
(
dataset
=
dataset
,
batch_size
=
4
)
# dataset is not instance of Dataset
try
:
loader
=
DataLoader
(
dataset
=
batch_sampler
,
places
=
place
)
self
.
assertTrue
(
False
)
except
AssertionError
:
pass
# places is None
try
:
loader
=
DataLoader
(
dataset
=
dataset
,
places
=
None
)
self
.
assertTrue
(
False
)
except
AssertionError
:
pass
# num_workers < 0
try
:
loader
=
DataLoader
(
dataset
=
dataset
,
places
=
place
,
num_workers
=-
1
)
self
.
assertTrue
(
False
)
except
AssertionError
:
pass
# timeout < 0
try
:
loader
=
DataLoader
(
dataset
=
dataset
,
places
=
place
,
timeout
=-
1
)
self
.
assertTrue
(
False
)
except
AssertionError
:
pass
# batch_sampler is not instance of BatchSampler
try
:
loader
=
DataLoader
(
dataset
=
dataset
,
places
=
place
,
batch_sampler
=
dataset
)
self
.
assertTrue
(
False
)
except
AssertionError
:
pass
# set batch_sampler and shuffle/batch_size/drop_last
try
:
loader
=
DataLoader
(
dataset
=
dataset
,
places
=
place
,
batch_sampler
=
batch_sampler
,
shuffle
=
True
,
drop_last
=
True
)
self
.
assertTrue
(
False
)
except
AssertionError
:
pass
# set batch_sampler correctly
try
:
loader
=
DataLoader
(
dataset
=
dataset
,
places
=
place
,
batch_sampler
=
batch_sampler
)
self
.
assertTrue
(
True
)
except
AssertionError
:
self
.
assertTrue
(
False
)
# CI Converage cannot record stub in subprocess,
# HACK a _worker_loop in main process call here
class
TestDataLoaderWorkerLoop
(
unittest
.
TestCase
):
def
run_without_worker_done
(
self
,
use_shared_memory
=
True
):
try
:
place
=
fluid
.
cpu_places
()[
0
]
with
fluid
.
dygraph
.
guard
(
place
):
dataset
=
RandomDataset
(
800
)
# test init_fn
def
_init_fn
(
worker_id
):
pass
# test collate_fn
def
_collate_fn
(
sample_list
):
return
[
np
.
stack
(
s
,
axis
=
0
)
for
s
in
list
(
zip
(
*
sample_list
))
]
loader
=
DataLoader
(
dataset
,
num_workers
=
1
,
places
=
place
,
use_shared_memory
=
use_shared_memory
)
assert
loader
.
num_workers
>
0
,
\
"go to AssertionError and pass in Mac and Windows"
loader
=
iter
(
loader
)
print
(
"loader length"
,
len
(
loader
))
indices_queue
=
multiprocessing
.
Queue
()
for
i
in
range
(
10
):
indices_queue
.
put
([
i
,
i
+
10
])
indices_queue
.
put
(
None
)
loader
.
_worker_loop
(
loader
.
_dataset
,
indices_queue
,
loader
.
_data_queue
,
loader
.
_workers_done_event
,
_collate_fn
,
_init_fn
,
0
)
self
.
assertTrue
(
False
)
except
AssertionError
:
pass
except
Exception
:
self
.
assertTrue
(
False
)
def
run_with_worker_done
(
self
,
use_shared_memory
=
True
):
try
:
place
=
fluid
.
cpu_places
()[
0
]
with
fluid
.
dygraph
.
guard
(
place
):
dataset
=
RandomDataset
(
800
)
# test init_fn
def
_init_fn
(
worker_id
):
pass
# test collate_fn
def
_collate_fn
(
sample_list
):
return
[
np
.
stack
(
s
,
axis
=
0
)
for
s
in
list
(
zip
(
*
sample_list
))
]
loader
=
DataLoader
(
dataset
,
num_workers
=
1
,
places
=
place
,
use_shared_memory
=
use_shared_memory
)
assert
loader
.
num_workers
>
0
,
\
"go to AssertionError and pass in Mac and Windows"
loader
=
iter
(
loader
)
print
(
"loader length"
,
len
(
loader
))
indices_queue
=
multiprocessing
.
Queue
()
for
i
in
range
(
10
):
indices_queue
.
put
([
i
,
i
+
10
])
indices_queue
.
put
(
None
)
loader
.
_workers_done_event
.
set
()
loader
.
_worker_loop
(
loader
.
_dataset
,
indices_queue
,
loader
.
_data_queue
,
loader
.
_workers_done_event
,
_collate_fn
,
_init_fn
,
0
)
self
.
assertTrue
(
True
)
except
AssertionError
:
pass
except
Exception
:
self
.
assertTrue
(
False
)
def
test_main
(
self
):
for
use_shared_memory
in
[
True
,
False
]:
self
.
run_without_worker_done
(
use_shared_memory
)
self
.
run_with_worker_done
(
use_shared_memory
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/io/__init__.py
浏览文件 @
80cf3c3c
...
...
@@ -13,22 +13,27 @@
# limitations under the License.
# TODO: define all functions about input & output in this directory
# __all__ = ['Dataset',
# 'Sampler',
# 'Transform',
# 'DataLoader',
# 'load',
# 'save',
# 'load_program_state',
# 'set_program_state',
# 'load_inference_model',
# 'save_inference_model',
# 'batch',
# 'shuffle',
# 'buffered',
# 'cache',
# 'chain',
# 'firstn',
# 'compose',
# 'map_readers',
# 'xmap_readers']
__all__
=
[
'Dataset'
,
'BatchSampler'
,
# 'Transform',
'DataLoader'
,
# 'load',
# 'save',
# 'load_program_state',
# 'set_program_state',
# 'load_inference_model',
# 'save_inference_model',
# 'batch',
# 'shuffle',
# 'buffered',
# 'cache',
# 'chain',
# 'firstn',
# 'compose',
# 'map_readers',
# 'xmap_readers'
]
from
..fluid.io
import
DataLoader
from
..fluid.dataloader
import
Dataset
,
BatchSampler
python/setup.py.in
浏览文件 @
80cf3c3c
...
...
@@ -149,6 +149,7 @@ packages=['paddle',
'paddle.fluid.proto.profiler',
'paddle.fluid.distributed',
'paddle.fluid.layers',
'paddle.fluid.dataloader',
'paddle.fluid.contrib',
'paddle.fluid.contrib.decoder',
'paddle.fluid.contrib.quantize',
...
...
@@ -176,6 +177,7 @@ packages=['paddle',
'paddle.fluid.incubate.fleet.parameter_server.pslib',
'paddle.fluid.incubate.fleet.collective',
'paddle.fluid.incubate.fleet.utils',
'paddle.io',
'paddle.nn',
'paddle.nn.functional',
'paddle.nn.layer',
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录