未验证 提交 35efbe6d 编写于 作者: C Chen Weihang 提交者: GitHub

Speeding up dygraph DataLoader with multiprocessing (#21762)

* add multiprocess for dygraph data loader, test=develop

* polish code & add safe gurad, test=develop

* refactor dygraph dataloader & add signal handler, test=develop

* fix member initializer compile error on ci, test=develop

* fix member initializer compile error one more, test=develop

* remove useless config, test=develop

* skip windows incompatible problem, test=develop

* add unittest for coverage, test=coverage

* add more exception unittest case, test=develop

* deal with signal handler coverage, test=develop

* polish code & add signal handler tests, test=develop

* deal with coverage ci problem, test=develop

* split data loader test & coverage ci fix, test=develop

* remove test_imperative_data_loader_with_exception, test=develop

* remove singal process except test case, test=develop

* add exception tests again & remove sample list test, test=develop

* split normal and exception unittests to diff class, test=develop

* polish doc for use_multiprocess effect in static mode, test=develop
上级 5751509e
develop 2.0.1-rocm-post Ligoml-patch-1 OliverLPH-patch-1 OliverLPH-patch-2 PaddlePM-patch-1 PaddlePM-patch-2 ZHUI-patch-1 add_default_att add_model_benchmark_ci add_some_yaml_config addfile all_new_design_exec ascendrc ascendrelease cherry_undefined_var compile_windows delete_2.0.1-rocm-post delete_add_default_att delete_all_new_design_exec delete_ascendrc delete_compile_windows delete_delete_addfile delete_disable_iterable_dataset_unittest delete_fix_dataloader_memory_leak delete_fix_imperative_dygraph_error delete_fix_retry_ci delete_fix_undefined_var delete_improve_sccache delete_paralleltest delete_prv-disable-more-cache delete_revert-31068-fix_conv3d_windows delete_revert-31562-mean delete_revert-33630-bug-fix delete_revert-34159-add_npu_bce_logical_dev delete_revert-34910-spinlocks_for_allocator delete_revert-35069-revert-34910-spinlocks_for_allocator delete_revert-36057-dev/read_flags_in_ut dingjiaweiww-patch-1 disable_iterable_dataset_unittest dy2static enable_eager_model_test final_state_gen_python_c final_state_intermediate fix-numpy-issue fix_concat_slice fix_dataloader_memory_leak fix_imperative_dygraph_error fix_npu_ci fix_op_flops fix_retry_ci fix_rnn_docs fix_tensor_type fix_undefined_var fixiscan fixiscan1 fixiscan2 fixiscan3 github/fork/123malin/netifaces github/fork/123malin/tdm_abacus github/fork/AshburnLee/dev_unique github/fork/ForFishes/fix_memory_matmul github/fork/ForFishes/rm_fluid github/fork/LielinJiang/move-2.0-api github/fork/LielinJiang/visual-dl-cb github/fork/LiuChiachi/add-transformer-generate-square-subsequent-mask-api github/fork/LiuChiachi/fix-example-code-for-hapi-Model github/fork/LiuChiachi/remove-input-requirment-in-dygraph-Model github/fork/MrChengmo/fix_ps_profiler github/fork/MrChengmo/update_ps_heter github/fork/PWhiddy/patch-1 github/fork/Shixiaowei02/dev/save_load_upgrade github/fork/TCChenlong/fix_hapi github/fork/TCChenlong/fix_inden github/fork/Thunderbrook/xpu_slice github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_2 github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_3 github/fork/XieYunshen/timeout_20S_ut github/fork/ZeyuChen/remove-nltk github/fork/arlesniak/arlesniak/selective__mkldnn_flags github/fork/baiyfbupt/code_doc_mig github/fork/chalsliu/set_timeout github/fork/chen-zhiyu/develop github/fork/chenwhql/ci/try_to_find_test_buffer_shared_memory_reuse_pass_error github/fork/chenwhql/dygraph/remove_scale_loss_and_apply_collective_grads github/fork/chenwhql/saveload/add_get_inference_program github/fork/chenwhql/saveload/remove_save_load_config github/fork/cryoco/pass-compatibility-trt github/fork/danleifeng/isempty_api2.0 github/fork/frankwhzhang/api_transfer github/fork/hbwx24/error_msg/cuda_kernel_error_msg github/fork/heavengate/cherry_yolo_box github/fork/heavengate/update_yolo_box github/fork/iclementine/rnn_fix github/fork/iducn/testestse github/fork/jczaja/prv-25537-fix github/fork/jeff41404/release/1.8 github/fork/jiweibo/api_2.0 github/fork/jiweibo/fix_lite_resnet50_test github/fork/juncaipeng/fix_doc_1 github/fork/lfchener/sample_code github/fork/littletomatodonkey/fix_reg_doc github/fork/liym27/dy2stat_update_assign_to_rc20 github/fork/luotao1/profiler_ut github/fork/mapingshuo/add_wait github/fork/mapingshuo/doc_2.0 github/fork/mapingshuo/zero-0.5 github/fork/miraiwk/dev github/fork/pangyoki/add-Categorical-class-branch github/fork/pangyoki/add-multinomial-op-branch github/fork/pangyoki/fix-test_distritbution-CI github/fork/qjing666/doublegrad github/fork/qjing666/fix_hdfs_download github/fork/sandyhouse/add_gather_etc github/fork/sandyhouse/add_send_recv_alltoall_etc github/fork/sandyhouse/pipeline_exe_run github/fork/seiriosPlus/feature/large_scale_kv_save_delta github/fork/seiriosPlus/fix/paddle_errors_fix github/fork/seiriosPlus/fix/paddle_op_errors github/fork/shangzhizhou/fix_test_activation_op_random_bug github/fork/smallv0221/yxp0924 github/fork/smallv0221/yxp0925 github/fork/swtkiwi/del-matplotlib github/fork/tianshuo78520a/kunlun_test github/fork/tianshuo78520a/update_dockerfile github/fork/wanghaoshuang/bert_fuse github/fork/wanghaoshuang/label_smooth github/fork/wanghuancoder/develop_CUDASynchronize github/fork/wanghuancoder/develop_Layer_doc github/fork/wanghuancoder/develop_ParameterList_doc github/fork/wanghuancoder/develop_Sequential_doc github/fork/wanghuancoder/develop_bilinear_tensor_product github/fork/wanghuancoder/develop_coverage_build_sh github/fork/wanghuancoder/develop_in_dynamic_mode_doc github/fork/wanghuancoder/develop_unique_name_doc github/fork/wangxicoding/fleet_meta_combine github/fork/wawltor/error_message_fix_5 github/fork/willthefrog/remove_l2_norm github/fork/windstamp/momentum_op github/fork/windstamp/mv_op_5 github/fork/windstamp/normal_api github/fork/wojtuss/wojtuss/fusion_gru_quantization github/fork/wojtuss/wojtuss/quantization-with-shift github/fork/wzzju/fix_err_info github/fork/wzzju/pure_fp16 github/fork/xiemoyuan/op_error_message github/fork/xiemoyuan/optimize_error_message github/fork/yaoxuefeng6/fix_doc github/fork/yaoxuefeng6/mod_dataset_v2 github/fork/yongqiangma/lod github/fork/ysh329/fix-clip-by-norm-error github/fork/ysh329/fix-error-clip-by-value github/fork/yukavio/error_info github/fork/zhangting2020/conv_filter_grad github/fork/zhangting2020/is_compile_with_cuda github/fork/zhangting2020/place_doc github/fork/zhangting2020/program github/fork/zhhsplendid/fix_any github/fork/zhhsplendid/refine_api2 github/fork/zhhsplendid/refine_api2_test github/fork/zhhsplendid/refine_api_test_ptb_lm github/fork/zhhsplendid/refine_api_test_resnet github/fork/zhhsplendid/refine_api_test_simnet github/fork/zhiqiu/dev/refine_initializer github/fork/zhiqiu/dev/remove_inplace_argument github/fork/zlsh80826/nvinfer_plugin_var_len_cuda11 improve_sccache incubate/infrt inplace_addto make_flag_adding_easier move_embedding_to_phi move_histogram_to_pten move_sgd_to_phi move_slice_to_pten move_temporal_shift_to_phi move_yolo_box_to_phi npu_fix_alloc numel paralleltest preln_ernie prv-disable-more-cache prv-md-even-more prv-onednn-2.5 pten_tensor_refactor release/1.8 release/2.0 release/2.0-alpha release/2.0-beta release/2.0-rc release/2.0-rc1 release/2.1 release/2.2 release/2.3 release/2.3-fc-ernie-fix release/2.4 revert-24981-add_device_attr_for_regulization revert-26856-strategy_example2 revert-27520-disable_pr revert-31068-fix_conv3d_windows revert-31562-mean revert-32290-develop-hardlabel revert-33037-forci revert-33475-fix_cifar_label_dimension revert-33630-bug-fix revert-34159-add_npu_bce_logical_dev revert-34406-add_copy_from_tensor revert-34910-spinlocks_for_allocator revert-35069-revert-34910-spinlocks_for_allocator revert-36057-dev/read_flags_in_ut revert-36201-refine_fast_threaded_ssa_graph_executor revert-36985-add_license revert-37318-refactor_dygraph_to_eager revert-37926-eager_coreops_500 revert-37956-revert-37727-pylayer_support_tuple revert-38100-mingdong revert-38301-allocation_rearrange_pr revert-38703-numpy_bf16_package_reupload revert-38732-remove_useless_header_in_elementwise_mul_grad revert-38959-Reduce_Grad revert-39143-adjust_empty revert-39227-move_trace_op_to_pten revert-39268-dev/remove_concat_fluid_kernel revert-40170-support_partial_grad revert-41056-revert-40727-move_some_activaion_to_phi revert-41065-revert-40993-mv_ele_floordiv_pow revert-41068-revert-40790-phi_new revert-41944-smaller_inference_api_test revert-42149-do-not-reset-default-stream-for-stream-safe-cuda-allocator revert-43155-fix_ut_tempfile revert-43882-revert-41944-smaller_inference_api_test revert-45808-phi/simplify_size_op revert-46827-deform_comment rocm_dev_0217 support_weight_transpose test_benchmark_ci test_feature_precision_test_c test_model_benchmark test_model_benchmark_ci zhiqiu-patch-1 v2.4.0-rc0 v2.3.2 v2.3.1 v2.3.0 v2.3.0-rc0 v2.2.2 v2.2.1 v2.2.0 v2.2.0-rc0 v2.2.0-bak0 v2.1.3 v2.1.2 v2.1.1 v2.1.0 v2.1.0-rc0 v2.0.2 v2.0.1 v2.0.0 v2.0.0-rc1 v2.0.0-rc0 v2.0.0-beta0 v2.0.0-alpha0 v1.8.5 v1.8.4 v1.8.3 v1.8.2 v1.8.1 v1.8.0
无相关合并请求
......@@ -10,6 +10,7 @@ cc_library(engine SRCS engine.cc DEPS layer gradient_accumulator)
cc_library(imperative_profiler SRCS profiler.cc)
if(NOT WIN32)
cc_library(nccl_context SRCS nccl_context.cc DEPS device_context)
cc_library(data_loader SRCS data_loader.cc DEPS enforce)
endif(NOT WIN32)
add_subdirectory(tests)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef _WIN32
#include "paddle/fluid/imperative/data_loader.h"
#include <string.h>
#include <sys/wait.h>
#include <atomic>
#include <csignal>
#include <map>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace imperative {
static std::map<int64_t, pid_t> load_process_pids;
void SetLoadProcessPID(int64_t key, pid_t pid) {
VLOG(3) << "Dygraph Data Loader: set loader child process PID (" << key
<< ", " << pid << ")";
load_process_pids[key] = pid;
}
void EraseLoadProcessPID(int64_t key) {
auto it = load_process_pids.find(key);
// Note: Can not find key also possible
if (it != load_process_pids.end()) {
VLOG(3) << "Dygraph Data Loader: erase loader child process PID (" << key
<< ")";
load_process_pids.erase(it);
} else {
VLOG(3) << "Dygraph Data Loader: The dygrph loader (id: " << key
<< ") you want erase does not exist.";
}
}
// sigaction doc: http://man7.org/linux/man-pages/man2/sigaction.2.html
// sigemptyset doc: https://linux.die.net/man/3/sigemptyset
// siginfo_t doc: https://www.mkssoftware.com/docs/man5/siginfo_t.5.asp
// waitid doc: https://linux.die.net/man/2/waitid
#define SIGNAL_HANDLE(SIGNAL) \
do { \
struct sigaction sa; \
sa.sa_handler = SIG_DFL; \
sa.sa_flags = 0; \
if (sigemptyset(&sa.sa_mask) != 0 || \
sigaction(SIGNAL, &sa, nullptr) != 0) { \
_exit(EXIT_FAILURE); \
} else { \
raise(SIGNAL); \
} \
} while (0)
#define REGISTER_SIGNAL_HANDLER(SIGNAL, HANDLER_NAME) \
static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) { \
SIGNAL_HANDLE(SIGNAL); \
}
#define REGISTER_SPEC_SIGNAL_HANDLER(SIGNAL, HANDLER_NAME) \
static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) { \
if (info->si_pid == getppid()) { \
_exit(EXIT_SUCCESS); \
} \
SIGNAL_HANDLE(SIGNAL); \
}
REGISTER_SIGNAL_HANDLER(SIGSEGV, SIGSEGV_handler);
REGISTER_SIGNAL_HANDLER(SIGBUS, SIGBUS_handler);
REGISTER_SPEC_SIGNAL_HANDLER(SIGTERM, SIGTERM_handler);
static inline void setSignalHandler(int signal,
void (*handler)(int, siginfo_t *, void *),
struct sigaction *old_sa_ptr) {
struct sigaction sa;
sa.sa_sigaction = handler;
sa.sa_flags = SA_RESTART | SA_SIGINFO | SA_NOCLDSTOP | SA_NODEFER;
if (sigemptyset(&sa.sa_mask) != 0 ||
sigaction(signal, &sa, old_sa_ptr) != 0) {
PADDLE_THROW(platform::errors::Fatal(
"An error occurred while setting handler for %s.", strsignal(signal)));
}
}
// Note: maybe need to add other signal handler
void SetLoadProcessSignalHandler() {
setSignalHandler(SIGSEGV, &SIGSEGV_handler, nullptr);
setSignalHandler(SIGBUS, &SIGBUS_handler, nullptr);
setSignalHandler(SIGTERM, &SIGTERM_handler, nullptr);
}
void ThrowErrorIfLoadProcessFailed() {
int error;
pid_t process_pid;
siginfo_t infop;
for (auto &w : load_process_pids) {
process_pid = w.second;
// Use waitid rather than waitpid so that we can set NOWAIT, and that Python
// and other handlers can get whatever info they want about the child.
infop.si_pid = 0;
VLOG(3) << "Dygraph Data Loader: monitor loader child process "
<< process_pid;
error = waitid(P_PID, process_pid, &infop, WEXITED | WNOHANG | WNOWAIT);
// ignore errors and case with no waitable child
if (error < 0 || infop.si_pid == 0) continue;
if (infop.si_code == CLD_EXITED &&
infop.si_status != EXIT_SUCCESS) { // exit with error
PADDLE_THROW(platform::errors::Fatal(
"DataLoader process (pid %ld) exited unexpectedly with code %d. "
"Error detailed are lost due to multiprocessing. Rerunning with "
"DataLoader.from_generator(..., use_multiprocess=False) may give "
"better error trace.",
process_pid, infop.si_status));
} else if (infop.si_code == CLD_KILLED ||
infop.si_code == CLD_DUMPED) { // killed by signal
PADDLE_THROW(platform::errors::Fatal(
"DataLoader process (pid %ld) exited is killed by signal: %s.",
process_pid, strsignal(infop.si_status)));
}
}
}
} // namespace imperative
} // namespace paddle
#endif
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifndef _WIN32
#include <unistd.h>
#include <cstdint>
namespace paddle {
namespace imperative {
extern void SetLoadProcessPID(int64_t key, pid_t pid);
extern void EraseLoadProcessPID(int64_t key);
extern void SetLoadProcessSignalHandler();
extern void ThrowErrorIfLoadProcessFailed();
} // namespace imperative
} // namespace paddle
#endif
......@@ -4,7 +4,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapp
gloo_wrapper)
if(NOT WIN32)
set(PYBIND_DEPS ${PYBIND_DEPS} nccl_context)
set(PYBIND_DEPS ${PYBIND_DEPS} nccl_context data_loader)
endif(NOT WIN32)
if(WITH_PYTHON)
......
......@@ -25,6 +25,7 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/imperative/backward_strategy.h"
#include "paddle/fluid/imperative/data_loader.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/nccl_context.h"
#include "paddle/fluid/imperative/profiler.h"
......@@ -276,6 +277,19 @@ void BindImperative(py::module *m_ptr) {
imperative::SetCurrentTracer(tracer);
});
#ifndef _WIN32
// Dygraph DataLoader signal handler
m.def("_set_process_pid", [](int64_t key, pid_t pid) {
imperative::SetLoadProcessPID(key, pid);
});
m.def("_erase_process_pid",
[](int64_t key) { imperative::EraseLoadProcessPID(key); });
m.def("_set_process_signal_handler",
[]() { imperative::SetLoadProcessSignalHandler(); });
m.def("_throw_error_if_process_failed",
[]() { imperative::ThrowErrorIfLoadProcessFailed(); });
#endif
py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
m, "VarBase",
R"DOC()DOC")
......
......@@ -184,6 +184,11 @@ if avx_supported():
from .core_avx import _save_dygraph_dict
from .core_avx import _load_dygraph_dict
from .core_avx import _create_loaded_parameter
if sys.platform != 'win32':
from .core_avx import _set_process_pid
from .core_avx import _erase_process_pid
from .core_avx import _set_process_signal_handler
from .core_avx import _throw_error_if_process_failed
except Exception as e:
if has_avx_core:
raise e
......@@ -220,6 +225,11 @@ if load_noavx:
from .core_noavx import _save_dygraph_dict
from .core_noavx import _load_dygraph_dict
from .core_noavx import _create_loaded_parameter
if sys.platform != 'win32':
from .core_noavx import _set_process_pid
from .core_noavx import _erase_process_pid
from .core_noavx import _set_process_signal_handler
from .core_noavx import _throw_error_if_process_failed
except Exception as e:
if has_noavx_core:
sys.stderr.write(
......
此差异已折叠。
......@@ -188,6 +188,9 @@ list(REMOVE_ITEM TEST_OPS test_fuse_bn_act_pass)
if (APPLE OR WIN32)
list(REMOVE_ITEM TEST_OPS test_dataset)
list(REMOVE_ITEM TEST_OPS test_dataset_dataloader)
list(REMOVE_ITEM TEST_OPS test_imperative_data_loader)
list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_process)
list(REMOVE_ITEM TEST_OPS test_imperative_signal_handler)
endif()
# Some ops need to check results when gc is enabled
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import core
import paddle.compat as cpt
def get_random_images_and_labels(image_shape, label_shape):
image = np.random.random(size=image_shape).astype('float32')
label = np.random.random(size=label_shape).astype('int64')
return image, label
def sample_generator_creator(batch_size, batch_num):
def __reader__():
for _ in range(batch_num * batch_size):
image, label = get_random_images_and_labels([784], [1])
yield image, label
return __reader__
def sample_list_generator_creator(batch_size, batch_num):
def __reader__():
for _ in range(batch_num):
sample_list = []
for _ in range(batch_size):
image, label = get_random_images_and_labels([784], [1])
sample_list.append([image, label])
yield sample_list
return __reader__
def batch_generator_creator(batch_size, batch_num):
def __reader__():
for _ in range(batch_num):
batch_image, batch_label = get_random_images_and_labels(
[batch_size, 784], [batch_size, 1])
yield batch_image, batch_label
return __reader__
class TestDygraphhDataLoader(unittest.TestCase):
def setUp(self):
self.batch_size = 8
self.batch_num = 4
self.epoch_num = 2
self.capacity = 2
def test_single_process_reader(self):
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, iterable=False, use_multiprocess=False)
loader.set_sample_generator(
sample_generator_creator(self.batch_size, self.batch_num),
batch_size=self.batch_size,
places=fluid.CPUPlace())
for _ in range(self.epoch_num):
for image, label in loader():
relu = fluid.layers.relu(image)
self.assertEqual(image.shape, [self.batch_size, 784])
self.assertEqual(label.shape, [self.batch_size, 1])
self.assertEqual(relu.shape, [self.batch_size, 784])
def test_sample_genarator(self):
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, use_multiprocess=True)
loader.set_sample_generator(
sample_generator_creator(self.batch_size, self.batch_num),
batch_size=self.batch_size,
places=fluid.CPUPlace())
for _ in range(self.epoch_num):
for image, label in loader():
relu = fluid.layers.relu(image)
self.assertEqual(image.shape, [self.batch_size, 784])
self.assertEqual(label.shape, [self.batch_size, 1])
self.assertEqual(relu.shape, [self.batch_size, 784])
def test_sample_list_generator(self):
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, use_multiprocess=True)
loader.set_sample_list_generator(
sample_list_generator_creator(self.batch_size, self.batch_num),
places=fluid.CPUPlace())
for _ in range(self.epoch_num):
for image, label in loader():
relu = fluid.layers.relu(image)
self.assertEqual(image.shape, [self.batch_size, 784])
self.assertEqual(label.shape, [self.batch_size, 1])
self.assertEqual(relu.shape, [self.batch_size, 784])
def test_batch_genarator(self):
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, use_multiprocess=True)
loader.set_batch_generator(
batch_generator_creator(self.batch_size, self.batch_num),
places=fluid.CPUPlace())
for _ in range(self.epoch_num):
for image, label in loader():
relu = fluid.layers.relu(image)
self.assertEqual(image.shape, [self.batch_size, 784])
self.assertEqual(label.shape, [self.batch_size, 1])
self.assertEqual(relu.shape, [self.batch_size, 784])
class TestDygraphhDataLoaderWithException(unittest.TestCase):
def setUp(self):
self.batch_num = 4
self.capacity = 2
def test_not_capacity(self):
with fluid.dygraph.guard():
with self.assertRaisesRegexp(ValueError,
"Please give value to capacity."):
fluid.io.DataLoader.from_generator()
def test_single_process_with_thread_expection(self):
def error_sample_genarator(batch_num):
def __reader__():
for _ in range(batch_num):
yield [[[1, 2], [1]]]
return __reader__
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, iterable=False, use_multiprocess=False)
loader.set_batch_generator(
error_sample_genarator(self.batch_num), places=fluid.CPUPlace())
exception = None
try:
for _ in loader():
print("test_single_process_with_thread_expection")
except core.EnforceNotMet as ex:
self.assertIn("Blocking queue is killed",
cpt.get_exception_message(ex))
exception = ex
self.assertIsNotNone(exception)
def test_multi_process_with_thread_expection(self):
def error_sample_genarator(batch_num):
def __reader__():
for _ in range(batch_num):
yield [[[1, 2], [1]]]
return __reader__
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, use_multiprocess=True)
loader.set_batch_generator(
error_sample_genarator(self.batch_num), places=fluid.CPUPlace())
exception = None
try:
for _ in loader():
print("test_multi_process_with_thread_expection")
except core.EnforceNotMet as ex:
self.assertIn("Blocking queue is killed",
cpt.get_exception_message(ex))
exception = ex
self.assertIsNotNone(exception)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
import paddle.fluid as fluid
if sys.version_info[0] == 2:
import Queue as queue
else:
import queue
def get_random_images_and_labels(image_shape, label_shape):
image = np.random.random(size=image_shape).astype('float32')
label = np.random.random(size=label_shape).astype('int64')
return image, label
def batch_generator_creator(batch_size, batch_num):
def __reader__():
for _ in range(batch_num):
batch_image, batch_label = get_random_images_and_labels(
[batch_size, 784], [batch_size, 1])
yield batch_image, batch_label
return __reader__
# NOTE: coverage CI can't cover child process code, so need these test.
# Here test child process loop function in main process
class TestDygraphhDataLoaderProcess(unittest.TestCase):
def setUp(self):
self.batch_size = 8
self.batch_num = 4
self.epoch_num = 2
self.capacity = 2
def test_reader_process_loop(self):
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.batch_num + 1, use_multiprocess=True)
loader.set_batch_generator(
batch_generator_creator(self.batch_size, self.batch_num),
places=fluid.CPUPlace())
loader._data_queue = queue.Queue(self.batch_num + 1)
loader._reader_process_loop()
for _ in range(self.batch_num):
loader._data_queue.get(timeout=10)
def test_reader_process_loop_simple_none(self):
def none_sample_genarator(batch_num):
def __reader__():
for _ in range(batch_num):
yield None
return __reader__
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.batch_num + 1, use_multiprocess=True)
loader.set_batch_generator(
none_sample_genarator(self.batch_num), places=fluid.CPUPlace())
loader._data_queue = queue.Queue(self.batch_num + 1)
exception = None
try:
loader._reader_process_loop()
except AttributeError as ex:
exception = ex
self.assertIsNotNone(exception)
if __name__ == '__main__':
unittest.main()
......@@ -242,8 +242,6 @@ class TestDygraphResnet(unittest.TestCase):
optimizer = optimizer_setting(
train_parameters, parameter_list=resnet.parameters())
np.random.seed(seed)
import random
random.seed = seed
batch_py_reader = fluid.io.PyReader(capacity=1)
batch_py_reader.decorate_sample_list_generator(
......@@ -330,8 +328,6 @@ class TestDygraphResnet(unittest.TestCase):
optimizer = optimizer_setting(train_parameters)
np.random.seed(seed)
import random
random.seed = seed
train_reader = paddle.batch(
paddle.dataset.flowers.train(use_xmap=False),
batch_size=batch_size)
......
......@@ -316,8 +316,6 @@ class TestImperativeResneXt(unittest.TestCase):
optimizer = optimizer_setting(
train_parameters, parameter_list=se_resnext.parameters())
np.random.seed(seed)
import random
random.seed = seed
batch_py_reader = fluid.io.PyReader(capacity=1)
batch_py_reader.decorate_sample_list_generator(
......@@ -379,8 +377,6 @@ class TestImperativeResneXt(unittest.TestCase):
optimizer = optimizer_setting(train_parameters)
np.random.seed(seed)
import random
random.seed = seed
train_reader = paddle.batch(
paddle.dataset.flowers.train(use_xmap=False),
batch_size=batch_size,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import signal
import unittest
import multiprocessing
import time
import paddle.compat as cpt
from paddle.fluid import core
def set_child_signal_handler(self, child_pid):
core._set_process_pid(id(self), child_pid)
current_handler = signal.getsignal(signal.SIGCHLD)
if not callable(current_handler):
current_handler = None
def __handler__(signum, frame):
core._throw_error_if_process_failed()
if current_handler is not None:
current_handler(signum, frame)
signal.signal(signal.SIGCHLD, __handler__)
class TestDygraphDataLoaderSingalHandler(unittest.TestCase):
def test_child_process_exit_will_error(self):
def __test_process__():
core._set_process_signal_handler()
sys.exit(1)
exception = None
try:
test_process = multiprocessing.Process(target=__test_process__)
test_process.start()
set_child_signal_handler(id(self), test_process.pid)
time.sleep(1)
except core.EnforceNotMet as ex:
self.assertIn("FatalError", cpt.get_exception_message(ex))
exception = ex
self.assertIsNotNone(exception)
def test_child_process_killed_by_sigsegv(self):
def __test_process__():
core._set_process_signal_handler()
os.kill(os.getpid(), signal.SIGSEGV)
exception = None
try:
test_process = multiprocessing.Process(target=__test_process__)
test_process.start()
set_child_signal_handler(id(self), test_process.pid)
time.sleep(1)
except core.EnforceNotMet as ex:
self.assertIn("FatalError", cpt.get_exception_message(ex))
exception = ex
self.assertIsNotNone(exception)
def test_child_process_killed_by_sigterm(self):
def __test_process__():
core._set_process_signal_handler()
time.sleep(10)
test_process = multiprocessing.Process(target=__test_process__)
test_process.daemon = True
test_process.start()
set_child_signal_handler(id(self), test_process.pid)
time.sleep(1)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部