未验证 提交 fa7bb0c7 编写于 作者: C Chen Weihang 提交者: GitHub

Speeding up dygraph DataLoader with multiprocessing (#21762) (#22354)

* add multiprocess for dygraph data loader, test=develop

* polish code & add safe gurad, test=develop

* refactor dygraph dataloader & add signal handler, test=develop

* fix member initializer compile error on ci, test=develop

* fix member initializer compile error one more, test=develop

* remove useless config, test=develop

* skip windows incompatible problem, test=develop

* add unittest for coverage, test=coverage

* add more exception unittest case, test=develop

* deal with signal handler coverage, test=develop

* polish code & add signal handler tests, test=develop

* deal with coverage ci problem, test=develop

* split data loader test & coverage ci fix, test=develop

* remove test_imperative_data_loader_with_exception, test=develop

* remove singal process except test case, test=develop

* add exception tests again & remove sample list test, test=develop

* split normal and exception unittests to diff class, test=develop

* polish doc for use_multiprocess effect in static mode, test=develop
上级 ecc52688
......@@ -10,6 +10,7 @@ cc_library(engine SRCS engine.cc DEPS layer gradient_accumulator)
cc_library(imperative_profiler SRCS profiler.cc)
if(NOT WIN32)
cc_library(nccl_context SRCS nccl_context.cc DEPS device_context)
cc_library(data_loader SRCS data_loader.cc DEPS enforce)
endif(NOT WIN32)
add_subdirectory(tests)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef _WIN32
#include "paddle/fluid/imperative/data_loader.h"
#include <string.h>
#include <sys/wait.h>
#include <atomic>
#include <csignal>
#include <map>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace imperative {
static std::map<int64_t, pid_t> load_process_pids;
void SetLoadProcessPID(int64_t key, pid_t pid) {
VLOG(3) << "Dygraph Data Loader: set loader child process PID (" << key
<< ", " << pid << ")";
load_process_pids[key] = pid;
}
void EraseLoadProcessPID(int64_t key) {
auto it = load_process_pids.find(key);
// Note: Can not find key also possible
if (it != load_process_pids.end()) {
VLOG(3) << "Dygraph Data Loader: erase loader child process PID (" << key
<< ")";
load_process_pids.erase(it);
} else {
VLOG(3) << "Dygraph Data Loader: The dygrph loader (id: " << key
<< ") you want erase does not exist.";
}
}
// sigaction doc: http://man7.org/linux/man-pages/man2/sigaction.2.html
// sigemptyset doc: https://linux.die.net/man/3/sigemptyset
// siginfo_t doc: https://www.mkssoftware.com/docs/man5/siginfo_t.5.asp
// waitid doc: https://linux.die.net/man/2/waitid
#define SIGNAL_HANDLE(SIGNAL) \
do { \
struct sigaction sa; \
sa.sa_handler = SIG_DFL; \
sa.sa_flags = 0; \
if (sigemptyset(&sa.sa_mask) != 0 || \
sigaction(SIGNAL, &sa, nullptr) != 0) { \
_exit(EXIT_FAILURE); \
} else { \
raise(SIGNAL); \
} \
} while (0)
#define REGISTER_SIGNAL_HANDLER(SIGNAL, HANDLER_NAME) \
static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) { \
SIGNAL_HANDLE(SIGNAL); \
}
#define REGISTER_SPEC_SIGNAL_HANDLER(SIGNAL, HANDLER_NAME) \
static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) { \
if (info->si_pid == getppid()) { \
_exit(EXIT_SUCCESS); \
} \
SIGNAL_HANDLE(SIGNAL); \
}
REGISTER_SIGNAL_HANDLER(SIGSEGV, SIGSEGV_handler);
REGISTER_SIGNAL_HANDLER(SIGBUS, SIGBUS_handler);
REGISTER_SPEC_SIGNAL_HANDLER(SIGTERM, SIGTERM_handler);
static inline void setSignalHandler(int signal,
void (*handler)(int, siginfo_t *, void *),
struct sigaction *old_sa_ptr) {
struct sigaction sa;
sa.sa_sigaction = handler;
sa.sa_flags = SA_RESTART | SA_SIGINFO | SA_NOCLDSTOP | SA_NODEFER;
if (sigemptyset(&sa.sa_mask) != 0 ||
sigaction(signal, &sa, old_sa_ptr) != 0) {
PADDLE_THROW(platform::errors::Fatal(
"An error occurred while setting handler for %s.", strsignal(signal)));
}
}
// Note: maybe need to add other signal handler
void SetLoadProcessSignalHandler() {
setSignalHandler(SIGSEGV, &SIGSEGV_handler, nullptr);
setSignalHandler(SIGBUS, &SIGBUS_handler, nullptr);
setSignalHandler(SIGTERM, &SIGTERM_handler, nullptr);
}
void ThrowErrorIfLoadProcessFailed() {
int error;
pid_t process_pid;
siginfo_t infop;
for (auto &w : load_process_pids) {
process_pid = w.second;
// Use waitid rather than waitpid so that we can set NOWAIT, and that Python
// and other handlers can get whatever info they want about the child.
infop.si_pid = 0;
VLOG(3) << "Dygraph Data Loader: monitor loader child process "
<< process_pid;
error = waitid(P_PID, process_pid, &infop, WEXITED | WNOHANG | WNOWAIT);
// ignore errors and case with no waitable child
if (error < 0 || infop.si_pid == 0) continue;
if (infop.si_code == CLD_EXITED &&
infop.si_status != EXIT_SUCCESS) { // exit with error
PADDLE_THROW(platform::errors::Fatal(
"DataLoader process (pid %ld) exited unexpectedly with code %d. "
"Error detailed are lost due to multiprocessing. Rerunning with "
"DataLoader.from_generator(..., use_multiprocess=False) may give "
"better error trace.",
process_pid, infop.si_status));
} else if (infop.si_code == CLD_KILLED ||
infop.si_code == CLD_DUMPED) { // killed by signal
PADDLE_THROW(platform::errors::Fatal(
"DataLoader process (pid %ld) exited is killed by signal: %s.",
process_pid, strsignal(infop.si_status)));
}
}
}
} // namespace imperative
} // namespace paddle
#endif
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifndef _WIN32
#include <unistd.h>
#include <cstdint>
namespace paddle {
namespace imperative {
extern void SetLoadProcessPID(int64_t key, pid_t pid);
extern void EraseLoadProcessPID(int64_t key);
extern void SetLoadProcessSignalHandler();
extern void ThrowErrorIfLoadProcessFailed();
} // namespace imperative
} // namespace paddle
#endif
......@@ -3,7 +3,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapp
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context)
if(NOT WIN32)
set(PYBIND_DEPS ${PYBIND_DEPS} nccl_context)
set(PYBIND_DEPS ${PYBIND_DEPS} nccl_context data_loader)
endif(NOT WIN32)
if(WITH_PYTHON)
......
......@@ -25,6 +25,7 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/imperative/backward_strategy.h"
#include "paddle/fluid/imperative/data_loader.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/nccl_context.h"
#include "paddle/fluid/imperative/profiler.h"
......@@ -276,6 +277,19 @@ void BindImperative(py::module *m_ptr) {
imperative::SetCurrentTracer(tracer);
});
#ifndef _WIN32
// Dygraph DataLoader signal handler
m.def("_set_process_pid", [](int64_t key, pid_t pid) {
imperative::SetLoadProcessPID(key, pid);
});
m.def("_erase_process_pid",
[](int64_t key) { imperative::EraseLoadProcessPID(key); });
m.def("_set_process_signal_handler",
[]() { imperative::SetLoadProcessSignalHandler(); });
m.def("_throw_error_if_process_failed",
[]() { imperative::ThrowErrorIfLoadProcessFailed(); });
#endif
py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
m, "VarBase",
R"DOC()DOC")
......
......@@ -184,6 +184,11 @@ if avx_supported():
from .core_avx import _save_dygraph_dict
from .core_avx import _load_dygraph_dict
from .core_avx import _create_loaded_parameter
if sys.platform != 'win32':
from .core_avx import _set_process_pid
from .core_avx import _erase_process_pid
from .core_avx import _set_process_signal_handler
from .core_avx import _throw_error_if_process_failed
except Exception as e:
if has_avx_core:
raise e
......@@ -220,6 +225,11 @@ if load_noavx:
from .core_noavx import _save_dygraph_dict
from .core_noavx import _load_dygraph_dict
from .core_noavx import _create_loaded_parameter
if sys.platform != 'win32':
from .core_noavx import _set_process_pid
from .core_noavx import _erase_process_pid
from .core_noavx import _set_process_signal_handler
from .core_noavx import _throw_error_if_process_failed
except Exception as e:
if has_noavx_core:
sys.stderr.write(
......
......@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import core, dygraph
from . import core
import sys
import six
import warnings
import numpy as np
import threading
import paddle
......@@ -27,6 +26,17 @@ from .unique_name import UniqueNameGenerator
import logging
from .dataset import DatasetBase, InMemoryDataset
### Dygraph DataLoader configs ###
import multiprocessing
import signal
# NOTE: queue has a different name in python2 and python3
if sys.version_info[0] == 2:
import Queue as queue
else:
import queue
# NOTE: [ avoid hanging ] This value is used in getting data from another process
MP_CHECK_TIMEOUT = 10
__all__ = ['PyReader', 'DataLoader']
data_loader_unique_name_generator = UniqueNameGenerator()
......@@ -76,7 +86,8 @@ class DataLoader(object):
capacity=None,
use_double_buffer=True,
iterable=True,
return_list=False):
return_list=False,
use_multiprocess=False):
"""
Create a DataLoader object for loading data from Python generator.
Data would be prefetched using Python thread and be pushed
......@@ -117,6 +128,11 @@ class DataLoader(object):
return value on each device would be a list(LoDTensor). It is
recommended to use return_list=False in static graph mode and
use return_list=True in dygraph mode.
use_multiprocess (bool): whether to use multi-process to speed up
the data loading process in dygraph. Note: this parameter only
can be used in the dygraph mode. In the static graph mode,
whether this parameter is set or not has no effect.
The Default value is False.
Returns:
loader (DataLoader): the created DataLoader object.
......@@ -254,8 +270,13 @@ class DataLoader(object):
assert label.shape == [BATCH_SIZE, 1]
assert relu.shape == [BATCH_SIZE, 784]
"""
return GeneratorLoader(feed_list, capacity, use_double_buffer, iterable,
return_list)
if in_dygraph_mode():
return DygraphGeneratorLoader(feed_list, capacity,
use_double_buffer, iterable,
return_list, use_multiprocess)
else:
return GeneratorLoader(feed_list, capacity, use_double_buffer,
iterable, return_list)
@staticmethod
def from_dataset(dataset, places, drop_last=True):
......@@ -295,32 +316,312 @@ class DataLoader(object):
return DatasetLoader(dataset, places, drop_last)
class GeneratorLoader(DataLoaderBase):
class DygraphGeneratorLoader(DataLoaderBase):
"""
The GeneratorLoader of dygraph
The multiprocess dygraph GeneratorLoader's most functions are different from
static graph GeneratorLoader, Separate implementation to keep code readable.
"""
def __init__(self,
feed_list=None,
capacity=None,
use_double_buffer=True,
iterable=True,
return_list=False):
self._tensor_reader = None
return_list=True,
use_multiprocess=False):
self._batch_reader = None
self._places = None
self._thread = None
self._feed_list = feed_list
if not capacity:
raise ValueError("Please give value to capacity.")
# force to use iterable mode under dygraph mode
if in_dygraph_mode():
self._capacity = capacity
self._use_double_buffer = use_double_buffer
if not iterable:
warnings.warn(
logging.warning(
"Please NOTE: dygraph can support iterable mode only. Change to iterable mode."
)
self._iterable = True
if not return_list:
warnings.warn(
logging.warning(
"Please NOTE: dygraph can support return as list only. Change to return as list."
)
self._return_list = True
# NOTE: the multiprocessing in different platform is incompatible, we will solve it later
self._use_multiprocess = use_multiprocess
if self._use_multiprocess and (sys.platform == 'darwin' or
sys.platform == 'win32'):
logging.warning(
"NOTE: The multiprocess mode does not currently support MacOs and Windows."
)
self._use_multiprocess = False
if self._use_multiprocess:
# NOTE: the multiprocessing.Queue used to save loading data in self._process
self._data_queue = None
# NOTE: this process is used to load data asynchronously from self._batch_reader
self._process = None
# NOTE: the C++ LoDTensorBlockingQueue instance
self._blocking_queue = None
# NOTE: 1. In multiprocess mode, this thread is used to get next batch data from
# self._data_queue, then push it into self._blocking_queue; 2. In singleprocess
# mode, this thread is used to get next batch data from self._batch_reader, then
# push it into self._blocking_queue
self._thread = None
@property
def queue(self):
return self._blocking_queue
@property
def iterable(self):
return self._iterable
def _wait_thread_ends(self):
thread = self._thread
if thread is not None:
self._blocking_queue.close()
thread.join()
def _wait_process_ends(self):
process = self._process
if process is not None:
self._data_queue.cancel_join_thread()
self._data_queue.close()
process.join()
# erase process id
core._erase_process_pid(id(self))
def _init_iterable(self):
self._wait_thread_ends()
if self._use_multiprocess:
self._wait_process_ends()
self._var_names = []
self._shapes = []
self._dtypes = []
self._need_check_feed = []
self._blocking_queue = core.init_lod_tensor_blocking_queue(
core.Variable(), self._capacity)
self._reader = core.create_py_reader(
self.queue, self._var_names, self._shapes, self._dtypes,
self._need_check_feed, self._places, self._use_double_buffer)
def _start(self):
if self._use_multiprocess:
# Set data_queue and process
self._data_queue = multiprocessing.Queue(self._capacity)
self._process = multiprocessing.Process(
target=self._reader_process_loop)
self._process.daemon = True
self._process.start()
# Set child process signal handler
# NOTE: [ avoiding hang ] 1. if the child process dies due to bus error/segfault
# or just hang, the main process will hang waiting for data, so here need to deal
# with SIGSEGV and SIGBUS of child process; 2. if the main process end before child
# process, it shuts the all its daemonic children down with a SIGTERM (instead of
# joining them without a timeout), so here nedd to deal with SIGTERM.
self._set_child_signal_handler()
# Set reader_thread
self._thread_done_event = threading.Event()
self._thread = threading.Thread(
target=self._reader_thread_loop_with_process)
self._thread.daemon = True
self._thread.start()
else:
self._thread = threading.Thread(target=self._reader_thread_loop)
self._thread.daemon = True
self._thread.start()
def _reset(self):
self._reader.reset()
self._wait_thread_ends()
if self._use_multiprocess:
self._wait_process_ends()
def __iter__(self):
assert self.iterable, "DataLoader is not iterable"
assert self._batch_reader is not None, \
"Data source of DataLoader has not set yet"
self._init_iterable()
self._start()
return self
def __next__(self):
try:
return self._reader.read_next_var_list()
except StopIteration:
self._reset()
six.reraise(*sys.exc_info())
@classmethod
def _check_input_array(cls, item):
arr = np.array(item)
if arr.dtype == np.object:
raise TypeError(
"\n\tFaild to convert input data to a regular ndarray :\n\t* Usually "
"this means the input data contains nested lists with different lengths. "
"\n\t* Check the reader function passed to 'decorate_batch_generator'"
" to locate the data causes this issue.\n\t* Please consider using "
"'fluid.create_lod_tensor' to convert it to a LoD-Tensor.")
def _set_child_signal_handler(self):
core._set_process_pid(id(self), self._process.pid)
current_handler = signal.getsignal(signal.SIGCHLD)
if not callable(current_handler):
current_handler = None
def __handler__(signum, frame):
core._throw_error_if_process_failed()
if current_handler is not None:
current_handler(signum, frame)
signal.signal(signal.SIGCHLD, __handler__)
def _reader_process_loop(self):
try:
# set signal handler
core._set_process_signal_handler()
for sample in self._batch_reader():
if sample is None:
raise ValueError(
"Sample in reader is None. Please check whether your dataset is valid."
)
self._data_queue.put(sample)
self._data_queue.put(None)
except KeyboardInterrupt:
# NOTE: Main process will raise KeyboardInterrupt anyways, ignore it in child process
pass
except:
self._data_queue.cancel_join_thread()
self._data_queue.close()
six.reraise(*sys.exc_info())
def _reader_thread_loop_with_process(self):
while not self._thread_done_event.is_set():
try:
# NOTE: [ avoid hanging ] Even with carefully designed data dependencies
# (i.e., a put() always corresponding to a get()), hanging on get() can
# still happen when data in queue is corrupted (e.g., due to
# Queue.cancel_join_thread or unexpected exit). So we set a timeout whenever
# we try to get data from `data_queue`
sample = self._data_queue.get(timeout=MP_CHECK_TIMEOUT)
except queue.Empty:
self._thread_done_event.set()
logging.error("The reader has not read data for a long time.")
if not self._thread_done_event.is_set():
if sample is not None:
try:
array = core.LoDTensorArray()
for item in sample:
if not isinstance(item, core.LoDTensor):
self._check_input_array(item)
tmp = core.LoDTensor()
tmp.set(item, core.CPUPlace())
item = tmp
array.append(item)
if not self._blocking_queue.push(array):
self._blocking_queue.close()
except:
self._thread_done_event.set()
self._blocking_queue.kill()
self._data_queue.close()
logging.warning(
"DygraphDataLoader reader thread raised an exception."
)
six.reraise(*sys.exc_info())
else:
self._thread_done_event.set()
self._blocking_queue.close()
self._data_queue.close()
else:
self._blocking_queue.kill()
self._data_queue.close()
def _reader_thread_loop(self):
try:
for sample in self._batch_reader():
array = core.LoDTensorArray()
for item in sample:
if not isinstance(item, core.LoDTensor):
self._check_input_array(item)
tmp = core.LoDTensor()
tmp.set(item, core.CPUPlace())
item = tmp
array.append(item)
if not self._blocking_queue.push(array):
break
self._blocking_queue.close()
self._thread = None
except Exception:
self._blocking_queue.kill()
self._thread = None
logging.warning(
"DygraphDataLoader reader thread raised an exception.")
six.reraise(*sys.exc_info())
def set_sample_generator(self,
reader,
batch_size,
drop_last=True,
places=None):
assert batch_size > 0, "batch_size must be larger than 0"
self.set_sample_list_generator(
paddle.batch(
reader, batch_size=batch_size, drop_last=drop_last),
places=places)
return self
def set_sample_list_generator(self, reader, places=None):
def __batch_reader_impl__():
for batch in reader():
slots = []
for items in batch:
for i, item in enumerate(items):
if len(slots) < len(items):
slots.append([item])
else:
slots[i].append(item)
yield slots
self.set_batch_generator(__batch_reader_impl__, places)
return self
def set_batch_generator(self, reader, places=None):
self._batch_reader = reader
assert places is not None, "Places cannot be None when DataLoader is iterable"
self._places = _convert_places(places)
assert len(self._places) == 1, \
"Number of places must be 1 in dygraph mode"
return self
class GeneratorLoader(DataLoaderBase):
def __init__(self,
feed_list=None,
capacity=None,
use_double_buffer=True,
iterable=True,
return_list=False):
self._tensor_reader = None
self._places = None
self._thread = None
self._queue = None
self._feed_list = feed_list
if not capacity:
raise ValueError("Please give value to capacity.")
self._iterable = iterable
self._return_list = return_list
if not self._feed_list:
......@@ -340,12 +641,6 @@ class GeneratorLoader(DataLoaderBase):
def _init_iterable(self):
self._wait_thread_ends()
if in_dygraph_mode():
self._var_names = []
self._shapes = []
self._dtypes = []
self._need_check_feed = []
else:
self._var_names = [v.name for v in self._feed_list]
self._shapes = [v.shape for v in self._feed_list]
self._dtypes = [v.dtype for v in self._feed_list]
......@@ -442,9 +737,6 @@ class GeneratorLoader(DataLoaderBase):
def __next__(self):
try:
if in_dygraph_mode():
return self._reader.read_next_var_list()
else:
if self._return_list:
return self._reader.read_next_list()
else:
......@@ -455,12 +747,10 @@ class GeneratorLoader(DataLoaderBase):
six.reraise(*sys.exc_info())
def start(self):
if not in_dygraph_mode():
assert not self._iterable, "start() cannot be called when DataLoader is iterable"
self._start()
def reset(self):
if not in_dygraph_mode():
assert not self._iterable, "reset() cannot be called when DataLoader is iterable"
self._reset()
......@@ -518,12 +808,6 @@ class GeneratorLoader(DataLoaderBase):
drop_last=True,
places=None):
assert batch_size > 0, "batch_size must be larger than 0"
if in_dygraph_mode():
self.set_sample_list_generator(
paddle.batch(
reader, batch_size=batch_size, drop_last=drop_last),
places=places)
else:
has_lod = False
for f in self._feed_list:
if f.lod_level != 0:
......@@ -546,24 +830,10 @@ class GeneratorLoader(DataLoaderBase):
return self
def set_sample_list_generator(self, reader, places=None):
if in_dygraph_mode():
def __tensor_reader_impl__():
for batch in reader():
slots = []
for items in batch:
for i, item in enumerate(items):
if len(slots) < len(items):
slots.append([item])
else:
slots[i].append(item)
yield slots
else:
with program_guard(Program(), Program()):
feeder = DataFeeder(
feed_list=self._feed_list, place=core.CPUPlace())
paddle_reader = feeder.decorate_reader(
reader, multi_devices=False)
paddle_reader = feeder.decorate_reader(reader, multi_devices=False)
def __tensor_reader_impl__():
for slots in paddle_reader():
......@@ -577,9 +847,6 @@ class GeneratorLoader(DataLoaderBase):
if self._iterable:
assert places is not None, "Places cannot be None when DataLoader is iterable"
self._places = _convert_places(places)
if in_dygraph_mode():
assert len(self._places) == 1, \
"Number of places must be 1 in dygraph mode"
else:
if places is not None:
logging.info(
......
......@@ -187,6 +187,9 @@ list(REMOVE_ITEM TEST_OPS test_fuse_bn_act_pass)
if (APPLE OR WIN32)
list(REMOVE_ITEM TEST_OPS test_dataset)
list(REMOVE_ITEM TEST_OPS test_dataset_dataloader)
list(REMOVE_ITEM TEST_OPS test_imperative_data_loader)
list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_process)
list(REMOVE_ITEM TEST_OPS test_imperative_signal_handler)
endif()
# Some ops need to check results when gc is enabled
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import core
import paddle.compat as cpt
def get_random_images_and_labels(image_shape, label_shape):
image = np.random.random(size=image_shape).astype('float32')
label = np.random.random(size=label_shape).astype('int64')
return image, label
def sample_generator_creator(batch_size, batch_num):
def __reader__():
for _ in range(batch_num * batch_size):
image, label = get_random_images_and_labels([784], [1])
yield image, label
return __reader__
def sample_list_generator_creator(batch_size, batch_num):
def __reader__():
for _ in range(batch_num):
sample_list = []
for _ in range(batch_size):
image, label = get_random_images_and_labels([784], [1])
sample_list.append([image, label])
yield sample_list
return __reader__
def batch_generator_creator(batch_size, batch_num):
def __reader__():
for _ in range(batch_num):
batch_image, batch_label = get_random_images_and_labels(
[batch_size, 784], [batch_size, 1])
yield batch_image, batch_label
return __reader__
class TestDygraphhDataLoader(unittest.TestCase):
def setUp(self):
self.batch_size = 8
self.batch_num = 4
self.epoch_num = 2
self.capacity = 2
def test_single_process_reader(self):
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, iterable=False, use_multiprocess=False)
loader.set_sample_generator(
sample_generator_creator(self.batch_size, self.batch_num),
batch_size=self.batch_size,
places=fluid.CPUPlace())
for _ in range(self.epoch_num):
for image, label in loader():
relu = fluid.layers.relu(image)
self.assertEqual(image.shape, [self.batch_size, 784])
self.assertEqual(label.shape, [self.batch_size, 1])
self.assertEqual(relu.shape, [self.batch_size, 784])
def test_sample_genarator(self):
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, use_multiprocess=True)
loader.set_sample_generator(
sample_generator_creator(self.batch_size, self.batch_num),
batch_size=self.batch_size,
places=fluid.CPUPlace())
for _ in range(self.epoch_num):
for image, label in loader():
relu = fluid.layers.relu(image)
self.assertEqual(image.shape, [self.batch_size, 784])
self.assertEqual(label.shape, [self.batch_size, 1])
self.assertEqual(relu.shape, [self.batch_size, 784])
def test_sample_list_generator(self):
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, use_multiprocess=True)
loader.set_sample_list_generator(
sample_list_generator_creator(self.batch_size, self.batch_num),
places=fluid.CPUPlace())
for _ in range(self.epoch_num):
for image, label in loader():
relu = fluid.layers.relu(image)
self.assertEqual(image.shape, [self.batch_size, 784])
self.assertEqual(label.shape, [self.batch_size, 1])
self.assertEqual(relu.shape, [self.batch_size, 784])
def test_batch_genarator(self):
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, use_multiprocess=True)
loader.set_batch_generator(
batch_generator_creator(self.batch_size, self.batch_num),
places=fluid.CPUPlace())
for _ in range(self.epoch_num):
for image, label in loader():
relu = fluid.layers.relu(image)
self.assertEqual(image.shape, [self.batch_size, 784])
self.assertEqual(label.shape, [self.batch_size, 1])
self.assertEqual(relu.shape, [self.batch_size, 784])
class TestDygraphhDataLoaderWithException(unittest.TestCase):
def setUp(self):
self.batch_num = 4
self.capacity = 2
def test_not_capacity(self):
with fluid.dygraph.guard():
with self.assertRaisesRegexp(ValueError,
"Please give value to capacity."):
fluid.io.DataLoader.from_generator()
def test_single_process_with_thread_expection(self):
def error_sample_genarator(batch_num):
def __reader__():
for _ in range(batch_num):
yield [[[1, 2], [1]]]
return __reader__
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, iterable=False, use_multiprocess=False)
loader.set_batch_generator(
error_sample_genarator(self.batch_num), places=fluid.CPUPlace())
exception = None
try:
for _ in loader():
print("test_single_process_with_thread_expection")
except core.EnforceNotMet as ex:
self.assertIn("Blocking queue is killed",
cpt.get_exception_message(ex))
exception = ex
self.assertIsNotNone(exception)
def test_multi_process_with_thread_expection(self):
def error_sample_genarator(batch_num):
def __reader__():
for _ in range(batch_num):
yield [[[1, 2], [1]]]
return __reader__
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, use_multiprocess=True)
loader.set_batch_generator(
error_sample_genarator(self.batch_num), places=fluid.CPUPlace())
exception = None
try:
for _ in loader():
print("test_multi_process_with_thread_expection")
except core.EnforceNotMet as ex:
self.assertIn("Blocking queue is killed",
cpt.get_exception_message(ex))
exception = ex
self.assertIsNotNone(exception)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
import paddle.fluid as fluid
if sys.version_info[0] == 2:
import Queue as queue
else:
import queue
def get_random_images_and_labels(image_shape, label_shape):
image = np.random.random(size=image_shape).astype('float32')
label = np.random.random(size=label_shape).astype('int64')
return image, label
def batch_generator_creator(batch_size, batch_num):
def __reader__():
for _ in range(batch_num):
batch_image, batch_label = get_random_images_and_labels(
[batch_size, 784], [batch_size, 1])
yield batch_image, batch_label
return __reader__
# NOTE: coverage CI can't cover child process code, so need these test.
# Here test child process loop function in main process
class TestDygraphhDataLoaderProcess(unittest.TestCase):
def setUp(self):
self.batch_size = 8
self.batch_num = 4
self.epoch_num = 2
self.capacity = 2
def test_reader_process_loop(self):
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.batch_num + 1, use_multiprocess=True)
loader.set_batch_generator(
batch_generator_creator(self.batch_size, self.batch_num),
places=fluid.CPUPlace())
loader._data_queue = queue.Queue(self.batch_num + 1)
loader._reader_process_loop()
for _ in range(self.batch_num):
loader._data_queue.get(timeout=10)
def test_reader_process_loop_simple_none(self):
def none_sample_genarator(batch_num):
def __reader__():
for _ in range(batch_num):
yield None
return __reader__
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.batch_num + 1, use_multiprocess=True)
loader.set_batch_generator(
none_sample_genarator(self.batch_num), places=fluid.CPUPlace())
loader._data_queue = queue.Queue(self.batch_num + 1)
exception = None
try:
loader._reader_process_loop()
except AttributeError as ex:
exception = ex
self.assertIsNotNone(exception)
if __name__ == '__main__':
unittest.main()
......@@ -242,8 +242,6 @@ class TestDygraphResnet(unittest.TestCase):
optimizer = optimizer_setting(
train_parameters, parameter_list=resnet.parameters())
np.random.seed(seed)
import random
random.seed = seed
batch_py_reader = fluid.io.PyReader(capacity=1)
batch_py_reader.decorate_sample_list_generator(
......@@ -330,8 +328,6 @@ class TestDygraphResnet(unittest.TestCase):
optimizer = optimizer_setting(train_parameters)
np.random.seed(seed)
import random
random.seed = seed
train_reader = paddle.batch(
paddle.dataset.flowers.train(use_xmap=False),
batch_size=batch_size)
......
......@@ -316,8 +316,6 @@ class TestImperativeResneXt(unittest.TestCase):
optimizer = optimizer_setting(
train_parameters, parameter_list=se_resnext.parameters())
np.random.seed(seed)
import random
random.seed = seed
batch_py_reader = fluid.io.PyReader(capacity=1)
batch_py_reader.decorate_sample_list_generator(
......@@ -379,8 +377,6 @@ class TestImperativeResneXt(unittest.TestCase):
optimizer = optimizer_setting(train_parameters)
np.random.seed(seed)
import random
random.seed = seed
train_reader = paddle.batch(
paddle.dataset.flowers.train(use_xmap=False),
batch_size=batch_size,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import signal
import unittest
import multiprocessing
import time
import paddle.compat as cpt
from paddle.fluid import core
def set_child_signal_handler(self, child_pid):
core._set_process_pid(id(self), child_pid)
current_handler = signal.getsignal(signal.SIGCHLD)
if not callable(current_handler):
current_handler = None
def __handler__(signum, frame):
core._throw_error_if_process_failed()
if current_handler is not None:
current_handler(signum, frame)
signal.signal(signal.SIGCHLD, __handler__)
class TestDygraphDataLoaderSingalHandler(unittest.TestCase):
def test_child_process_exit_will_error(self):
def __test_process__():
core._set_process_signal_handler()
sys.exit(1)
exception = None
try:
test_process = multiprocessing.Process(target=__test_process__)
test_process.start()
set_child_signal_handler(id(self), test_process.pid)
time.sleep(1)
except core.EnforceNotMet as ex:
self.assertIn("FatalError", cpt.get_exception_message(ex))
exception = ex
self.assertIsNotNone(exception)
def test_child_process_killed_by_sigsegv(self):
def __test_process__():
core._set_process_signal_handler()
os.kill(os.getpid(), signal.SIGSEGV)
exception = None
try:
test_process = multiprocessing.Process(target=__test_process__)
test_process.start()
set_child_signal_handler(id(self), test_process.pid)
time.sleep(1)
except core.EnforceNotMet as ex:
self.assertIn("FatalError", cpt.get_exception_message(ex))
exception = ex
self.assertIsNotNone(exception)
def test_child_process_killed_by_sigterm(self):
def __test_process__():
core._set_process_signal_handler()
time.sleep(10)
test_process = multiprocessing.Process(target=__test_process__)
test_process.daemon = True
test_process.start()
set_child_signal_handler(id(self), test_process.pid)
time.sleep(1)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册