未验证 提交 35efbe6d 编写于 作者: C Chen Weihang 提交者: GitHub

Speeding up dygraph DataLoader with multiprocessing (#21762)

* add multiprocess for dygraph data loader, test=develop

* polish code & add safe gurad, test=develop

* refactor dygraph dataloader & add signal handler, test=develop

* fix member initializer compile error on ci, test=develop

* fix member initializer compile error one more, test=develop

* remove useless config, test=develop

* skip windows incompatible problem, test=develop

* add unittest for coverage, test=coverage

* add more exception unittest case, test=develop

* deal with signal handler coverage, test=develop

* polish code & add signal handler tests, test=develop

* deal with coverage ci problem, test=develop

* split data loader test & coverage ci fix, test=develop

* remove test_imperative_data_loader_with_exception, test=develop

* remove singal process except test case, test=develop

* add exception tests again & remove sample list test, test=develop

* split normal and exception unittests to diff class, test=develop

* polish doc for use_multiprocess effect in static mode, test=develop
上级 5751509e
......@@ -10,6 +10,7 @@ cc_library(engine SRCS engine.cc DEPS layer gradient_accumulator)
cc_library(imperative_profiler SRCS profiler.cc)
if(NOT WIN32)
cc_library(nccl_context SRCS nccl_context.cc DEPS device_context)
cc_library(data_loader SRCS data_loader.cc DEPS enforce)
endif(NOT WIN32)
add_subdirectory(tests)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef _WIN32
#include "paddle/fluid/imperative/data_loader.h"
#include <string.h>
#include <sys/wait.h>
#include <atomic>
#include <csignal>
#include <map>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace imperative {
static std::map<int64_t, pid_t> load_process_pids;
void SetLoadProcessPID(int64_t key, pid_t pid) {
VLOG(3) << "Dygraph Data Loader: set loader child process PID (" << key
<< ", " << pid << ")";
load_process_pids[key] = pid;
}
void EraseLoadProcessPID(int64_t key) {
auto it = load_process_pids.find(key);
// Note: Can not find key also possible
if (it != load_process_pids.end()) {
VLOG(3) << "Dygraph Data Loader: erase loader child process PID (" << key
<< ")";
load_process_pids.erase(it);
} else {
VLOG(3) << "Dygraph Data Loader: The dygrph loader (id: " << key
<< ") you want erase does not exist.";
}
}
// sigaction doc: http://man7.org/linux/man-pages/man2/sigaction.2.html
// sigemptyset doc: https://linux.die.net/man/3/sigemptyset
// siginfo_t doc: https://www.mkssoftware.com/docs/man5/siginfo_t.5.asp
// waitid doc: https://linux.die.net/man/2/waitid
#define SIGNAL_HANDLE(SIGNAL) \
do { \
struct sigaction sa; \
sa.sa_handler = SIG_DFL; \
sa.sa_flags = 0; \
if (sigemptyset(&sa.sa_mask) != 0 || \
sigaction(SIGNAL, &sa, nullptr) != 0) { \
_exit(EXIT_FAILURE); \
} else { \
raise(SIGNAL); \
} \
} while (0)
#define REGISTER_SIGNAL_HANDLER(SIGNAL, HANDLER_NAME) \
static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) { \
SIGNAL_HANDLE(SIGNAL); \
}
#define REGISTER_SPEC_SIGNAL_HANDLER(SIGNAL, HANDLER_NAME) \
static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) { \
if (info->si_pid == getppid()) { \
_exit(EXIT_SUCCESS); \
} \
SIGNAL_HANDLE(SIGNAL); \
}
REGISTER_SIGNAL_HANDLER(SIGSEGV, SIGSEGV_handler);
REGISTER_SIGNAL_HANDLER(SIGBUS, SIGBUS_handler);
REGISTER_SPEC_SIGNAL_HANDLER(SIGTERM, SIGTERM_handler);
static inline void setSignalHandler(int signal,
void (*handler)(int, siginfo_t *, void *),
struct sigaction *old_sa_ptr) {
struct sigaction sa;
sa.sa_sigaction = handler;
sa.sa_flags = SA_RESTART | SA_SIGINFO | SA_NOCLDSTOP | SA_NODEFER;
if (sigemptyset(&sa.sa_mask) != 0 ||
sigaction(signal, &sa, old_sa_ptr) != 0) {
PADDLE_THROW(platform::errors::Fatal(
"An error occurred while setting handler for %s.", strsignal(signal)));
}
}
// Note: maybe need to add other signal handler
void SetLoadProcessSignalHandler() {
setSignalHandler(SIGSEGV, &SIGSEGV_handler, nullptr);
setSignalHandler(SIGBUS, &SIGBUS_handler, nullptr);
setSignalHandler(SIGTERM, &SIGTERM_handler, nullptr);
}
void ThrowErrorIfLoadProcessFailed() {
int error;
pid_t process_pid;
siginfo_t infop;
for (auto &w : load_process_pids) {
process_pid = w.second;
// Use waitid rather than waitpid so that we can set NOWAIT, and that Python
// and other handlers can get whatever info they want about the child.
infop.si_pid = 0;
VLOG(3) << "Dygraph Data Loader: monitor loader child process "
<< process_pid;
error = waitid(P_PID, process_pid, &infop, WEXITED | WNOHANG | WNOWAIT);
// ignore errors and case with no waitable child
if (error < 0 || infop.si_pid == 0) continue;
if (infop.si_code == CLD_EXITED &&
infop.si_status != EXIT_SUCCESS) { // exit with error
PADDLE_THROW(platform::errors::Fatal(
"DataLoader process (pid %ld) exited unexpectedly with code %d. "
"Error detailed are lost due to multiprocessing. Rerunning with "
"DataLoader.from_generator(..., use_multiprocess=False) may give "
"better error trace.",
process_pid, infop.si_status));
} else if (infop.si_code == CLD_KILLED ||
infop.si_code == CLD_DUMPED) { // killed by signal
PADDLE_THROW(platform::errors::Fatal(
"DataLoader process (pid %ld) exited is killed by signal: %s.",
process_pid, strsignal(infop.si_status)));
}
}
}
} // namespace imperative
} // namespace paddle
#endif
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifndef _WIN32
#include <unistd.h>
#include <cstdint>
namespace paddle {
namespace imperative {
extern void SetLoadProcessPID(int64_t key, pid_t pid);
extern void EraseLoadProcessPID(int64_t key);
extern void SetLoadProcessSignalHandler();
extern void ThrowErrorIfLoadProcessFailed();
} // namespace imperative
} // namespace paddle
#endif
......@@ -4,7 +4,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapp
gloo_wrapper)
if(NOT WIN32)
set(PYBIND_DEPS ${PYBIND_DEPS} nccl_context)
set(PYBIND_DEPS ${PYBIND_DEPS} nccl_context data_loader)
endif(NOT WIN32)
if(WITH_PYTHON)
......
......@@ -25,6 +25,7 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/imperative/backward_strategy.h"
#include "paddle/fluid/imperative/data_loader.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/nccl_context.h"
#include "paddle/fluid/imperative/profiler.h"
......@@ -276,6 +277,19 @@ void BindImperative(py::module *m_ptr) {
imperative::SetCurrentTracer(tracer);
});
#ifndef _WIN32
// Dygraph DataLoader signal handler
m.def("_set_process_pid", [](int64_t key, pid_t pid) {
imperative::SetLoadProcessPID(key, pid);
});
m.def("_erase_process_pid",
[](int64_t key) { imperative::EraseLoadProcessPID(key); });
m.def("_set_process_signal_handler",
[]() { imperative::SetLoadProcessSignalHandler(); });
m.def("_throw_error_if_process_failed",
[]() { imperative::ThrowErrorIfLoadProcessFailed(); });
#endif
py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
m, "VarBase",
R"DOC()DOC")
......
......@@ -184,6 +184,11 @@ if avx_supported():
from .core_avx import _save_dygraph_dict
from .core_avx import _load_dygraph_dict
from .core_avx import _create_loaded_parameter
if sys.platform != 'win32':
from .core_avx import _set_process_pid
from .core_avx import _erase_process_pid
from .core_avx import _set_process_signal_handler
from .core_avx import _throw_error_if_process_failed
except Exception as e:
if has_avx_core:
raise e
......@@ -220,6 +225,11 @@ if load_noavx:
from .core_noavx import _save_dygraph_dict
from .core_noavx import _load_dygraph_dict
from .core_noavx import _create_loaded_parameter
if sys.platform != 'win32':
from .core_noavx import _set_process_pid
from .core_noavx import _erase_process_pid
from .core_noavx import _set_process_signal_handler
from .core_noavx import _throw_error_if_process_failed
except Exception as e:
if has_noavx_core:
sys.stderr.write(
......
此差异已折叠。
......@@ -188,6 +188,9 @@ list(REMOVE_ITEM TEST_OPS test_fuse_bn_act_pass)
if (APPLE OR WIN32)
list(REMOVE_ITEM TEST_OPS test_dataset)
list(REMOVE_ITEM TEST_OPS test_dataset_dataloader)
list(REMOVE_ITEM TEST_OPS test_imperative_data_loader)
list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_process)
list(REMOVE_ITEM TEST_OPS test_imperative_signal_handler)
endif()
# Some ops need to check results when gc is enabled
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import core
import paddle.compat as cpt
def get_random_images_and_labels(image_shape, label_shape):
image = np.random.random(size=image_shape).astype('float32')
label = np.random.random(size=label_shape).astype('int64')
return image, label
def sample_generator_creator(batch_size, batch_num):
def __reader__():
for _ in range(batch_num * batch_size):
image, label = get_random_images_and_labels([784], [1])
yield image, label
return __reader__
def sample_list_generator_creator(batch_size, batch_num):
def __reader__():
for _ in range(batch_num):
sample_list = []
for _ in range(batch_size):
image, label = get_random_images_and_labels([784], [1])
sample_list.append([image, label])
yield sample_list
return __reader__
def batch_generator_creator(batch_size, batch_num):
def __reader__():
for _ in range(batch_num):
batch_image, batch_label = get_random_images_and_labels(
[batch_size, 784], [batch_size, 1])
yield batch_image, batch_label
return __reader__
class TestDygraphhDataLoader(unittest.TestCase):
def setUp(self):
self.batch_size = 8
self.batch_num = 4
self.epoch_num = 2
self.capacity = 2
def test_single_process_reader(self):
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, iterable=False, use_multiprocess=False)
loader.set_sample_generator(
sample_generator_creator(self.batch_size, self.batch_num),
batch_size=self.batch_size,
places=fluid.CPUPlace())
for _ in range(self.epoch_num):
for image, label in loader():
relu = fluid.layers.relu(image)
self.assertEqual(image.shape, [self.batch_size, 784])
self.assertEqual(label.shape, [self.batch_size, 1])
self.assertEqual(relu.shape, [self.batch_size, 784])
def test_sample_genarator(self):
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, use_multiprocess=True)
loader.set_sample_generator(
sample_generator_creator(self.batch_size, self.batch_num),
batch_size=self.batch_size,
places=fluid.CPUPlace())
for _ in range(self.epoch_num):
for image, label in loader():
relu = fluid.layers.relu(image)
self.assertEqual(image.shape, [self.batch_size, 784])
self.assertEqual(label.shape, [self.batch_size, 1])
self.assertEqual(relu.shape, [self.batch_size, 784])
def test_sample_list_generator(self):
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, use_multiprocess=True)
loader.set_sample_list_generator(
sample_list_generator_creator(self.batch_size, self.batch_num),
places=fluid.CPUPlace())
for _ in range(self.epoch_num):
for image, label in loader():
relu = fluid.layers.relu(image)
self.assertEqual(image.shape, [self.batch_size, 784])
self.assertEqual(label.shape, [self.batch_size, 1])
self.assertEqual(relu.shape, [self.batch_size, 784])
def test_batch_genarator(self):
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, use_multiprocess=True)
loader.set_batch_generator(
batch_generator_creator(self.batch_size, self.batch_num),
places=fluid.CPUPlace())
for _ in range(self.epoch_num):
for image, label in loader():
relu = fluid.layers.relu(image)
self.assertEqual(image.shape, [self.batch_size, 784])
self.assertEqual(label.shape, [self.batch_size, 1])
self.assertEqual(relu.shape, [self.batch_size, 784])
class TestDygraphhDataLoaderWithException(unittest.TestCase):
def setUp(self):
self.batch_num = 4
self.capacity = 2
def test_not_capacity(self):
with fluid.dygraph.guard():
with self.assertRaisesRegexp(ValueError,
"Please give value to capacity."):
fluid.io.DataLoader.from_generator()
def test_single_process_with_thread_expection(self):
def error_sample_genarator(batch_num):
def __reader__():
for _ in range(batch_num):
yield [[[1, 2], [1]]]
return __reader__
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, iterable=False, use_multiprocess=False)
loader.set_batch_generator(
error_sample_genarator(self.batch_num), places=fluid.CPUPlace())
exception = None
try:
for _ in loader():
print("test_single_process_with_thread_expection")
except core.EnforceNotMet as ex:
self.assertIn("Blocking queue is killed",
cpt.get_exception_message(ex))
exception = ex
self.assertIsNotNone(exception)
def test_multi_process_with_thread_expection(self):
def error_sample_genarator(batch_num):
def __reader__():
for _ in range(batch_num):
yield [[[1, 2], [1]]]
return __reader__
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, use_multiprocess=True)
loader.set_batch_generator(
error_sample_genarator(self.batch_num), places=fluid.CPUPlace())
exception = None
try:
for _ in loader():
print("test_multi_process_with_thread_expection")
except core.EnforceNotMet as ex:
self.assertIn("Blocking queue is killed",
cpt.get_exception_message(ex))
exception = ex
self.assertIsNotNone(exception)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
import paddle.fluid as fluid
if sys.version_info[0] == 2:
import Queue as queue
else:
import queue
def get_random_images_and_labels(image_shape, label_shape):
image = np.random.random(size=image_shape).astype('float32')
label = np.random.random(size=label_shape).astype('int64')
return image, label
def batch_generator_creator(batch_size, batch_num):
def __reader__():
for _ in range(batch_num):
batch_image, batch_label = get_random_images_and_labels(
[batch_size, 784], [batch_size, 1])
yield batch_image, batch_label
return __reader__
# NOTE: coverage CI can't cover child process code, so need these test.
# Here test child process loop function in main process
class TestDygraphhDataLoaderProcess(unittest.TestCase):
def setUp(self):
self.batch_size = 8
self.batch_num = 4
self.epoch_num = 2
self.capacity = 2
def test_reader_process_loop(self):
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.batch_num + 1, use_multiprocess=True)
loader.set_batch_generator(
batch_generator_creator(self.batch_size, self.batch_num),
places=fluid.CPUPlace())
loader._data_queue = queue.Queue(self.batch_num + 1)
loader._reader_process_loop()
for _ in range(self.batch_num):
loader._data_queue.get(timeout=10)
def test_reader_process_loop_simple_none(self):
def none_sample_genarator(batch_num):
def __reader__():
for _ in range(batch_num):
yield None
return __reader__
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.batch_num + 1, use_multiprocess=True)
loader.set_batch_generator(
none_sample_genarator(self.batch_num), places=fluid.CPUPlace())
loader._data_queue = queue.Queue(self.batch_num + 1)
exception = None
try:
loader._reader_process_loop()
except AttributeError as ex:
exception = ex
self.assertIsNotNone(exception)
if __name__ == '__main__':
unittest.main()
......@@ -242,8 +242,6 @@ class TestDygraphResnet(unittest.TestCase):
optimizer = optimizer_setting(
train_parameters, parameter_list=resnet.parameters())
np.random.seed(seed)
import random
random.seed = seed
batch_py_reader = fluid.io.PyReader(capacity=1)
batch_py_reader.decorate_sample_list_generator(
......@@ -330,8 +328,6 @@ class TestDygraphResnet(unittest.TestCase):
optimizer = optimizer_setting(train_parameters)
np.random.seed(seed)
import random
random.seed = seed
train_reader = paddle.batch(
paddle.dataset.flowers.train(use_xmap=False),
batch_size=batch_size)
......
......@@ -316,8 +316,6 @@ class TestImperativeResneXt(unittest.TestCase):
optimizer = optimizer_setting(
train_parameters, parameter_list=se_resnext.parameters())
np.random.seed(seed)
import random
random.seed = seed
batch_py_reader = fluid.io.PyReader(capacity=1)
batch_py_reader.decorate_sample_list_generator(
......@@ -379,8 +377,6 @@ class TestImperativeResneXt(unittest.TestCase):
optimizer = optimizer_setting(train_parameters)
np.random.seed(seed)
import random
random.seed = seed
train_reader = paddle.batch(
paddle.dataset.flowers.train(use_xmap=False),
batch_size=batch_size,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import signal
import unittest
import multiprocessing
import time
import paddle.compat as cpt
from paddle.fluid import core
def set_child_signal_handler(self, child_pid):
core._set_process_pid(id(self), child_pid)
current_handler = signal.getsignal(signal.SIGCHLD)
if not callable(current_handler):
current_handler = None
def __handler__(signum, frame):
core._throw_error_if_process_failed()
if current_handler is not None:
current_handler(signum, frame)
signal.signal(signal.SIGCHLD, __handler__)
class TestDygraphDataLoaderSingalHandler(unittest.TestCase):
def test_child_process_exit_will_error(self):
def __test_process__():
core._set_process_signal_handler()
sys.exit(1)
exception = None
try:
test_process = multiprocessing.Process(target=__test_process__)
test_process.start()
set_child_signal_handler(id(self), test_process.pid)
time.sleep(1)
except core.EnforceNotMet as ex:
self.assertIn("FatalError", cpt.get_exception_message(ex))
exception = ex
self.assertIsNotNone(exception)
def test_child_process_killed_by_sigsegv(self):
def __test_process__():
core._set_process_signal_handler()
os.kill(os.getpid(), signal.SIGSEGV)
exception = None
try:
test_process = multiprocessing.Process(target=__test_process__)
test_process.start()
set_child_signal_handler(id(self), test_process.pid)
time.sleep(1)
except core.EnforceNotMet as ex:
self.assertIn("FatalError", cpt.get_exception_message(ex))
exception = ex
self.assertIsNotNone(exception)
def test_child_process_killed_by_sigterm(self):
def __test_process__():
core._set_process_signal_handler()
time.sleep(10)
test_process = multiprocessing.Process(target=__test_process__)
test_process.daemon = True
test_process.start()
set_child_signal_handler(id(self), test_process.pid)
time.sleep(1)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册