未验证 提交 9fb5a3e5 编写于 作者: L Leo Chen 提交者: GitHub

[cherry-pick] Set expected place in child thread for dataloader #30383

* set expected place in child thread for dataloader

* set device id when set tensor from numpy

* revert tensor_py change

* add compile guard

* fix ci

* fix bug
上级 3fbc3cf4
......@@ -161,7 +161,7 @@ static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
}
VLOG(5) << "Init Tensor as: / name: " << name
<< " / persistable: " << persistable << " / zero_copy: " << zero_copy
<< " / stop_gradient: " << stop_gradient;
<< " / stop_gradient: " << stop_gradient << " / at " << place;
new (self) imperative::VarBase(name);
self->SetPersistable(persistable);
auto *tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
......@@ -175,8 +175,8 @@ static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
const py::array &array) {
VLOG(4) << "Init VarBase from numpy: ";
auto place = imperative::GetCurrentTracer()->ExpectedPlace();
VLOG(4) << "Init VarBase from numpy at " << place;
InitTensorForVarBase(self, array, place);
}
......@@ -1206,15 +1206,44 @@ void BindImperative(py::module *m_ptr) {
if (py::isinstance<platform::CUDAPlace>(obj)) {
auto p = obj.cast<platform::CUDAPlace *>();
self.SetExpectedPlace(*p);
// NOTE(zhiqiu): When switching cuda place, we need to set the
// cuda device id.
// Otherwise, some cuda API may be launched at other cuda place,
// which may cost hundreds of MB of GPU memory due to the cuda
// lib.
#ifdef PADDLE_WITH_CUDA
platform::SetDeviceId(p->device);
#endif
VLOG(4) << "Tracer(" << &self << ")"
<< " set expected place " << *p;
} else if (py::isinstance<platform::XPUPlace>(obj)) {
auto p = obj.cast<platform::XPUPlace *>();
self.SetExpectedPlace(*p);
VLOG(4) << "Tracer(" << &self << ")"
<< " set expected place " << *p;
} else if (py::isinstance<platform::CPUPlace>(obj)) {
auto p = obj.cast<platform::CPUPlace *>();
self.SetExpectedPlace(*p);
VLOG(4) << "Tracer(" << &self << ")"
<< " set expected place " << *p;
} else if (py::isinstance<platform::CUDAPinnedPlace>(obj)) {
auto p = obj.cast<platform::CUDAPinnedPlace *>();
self.SetExpectedPlace(*p);
VLOG(4) << "Tracer(" << &self << ")"
<< " set expected place " << *p;
} else if (py::isinstance<platform::Place>(obj)) {
auto p = obj.cast<platform::Place *>();
self.SetExpectedPlace(*p);
if (platform::is_gpu_place(*p)) {
// NOTE(zhiqu): same as obj is CUDAPlace.
#ifdef PADDLE_WITH_CUDA
platform::SetDeviceId(
BOOST_GET_CONST(platform::CUDAPlace, *p).device);
#endif
}
VLOG(4) << "Tracer(" << &self << ")"
<< " set expected place " << *p;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Incompatible Place Type: supports XPUPlace, CUDAPlace, "
......
......@@ -288,12 +288,14 @@ void SetTensorFromPyArrayT(
#endif
} else {
#ifdef PADDLE_WITH_CUDA
if (paddle::platform::is_gpu_place(place)) {
// TODO(zhiqiu): set SetDeviceId before calling cuda APIs.
auto dst = self->mutable_data<T>(place);
if (paddle::platform::is_cuda_pinned_place(place)) {
std::memcpy(dst, array.data(), array.nbytes());
} else if (paddle::platform::is_gpu_place(place)) {
paddle::platform::GpuMemcpySync(dst, array.data(), array.nbytes(),
cudaMemcpyHostToDevice);
} else if (paddle::platform::is_cuda_pinned_place(place)) {
auto dst = self->mutable_data<T>(place);
std::memcpy(dst, array.data(), array.nbytes());
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Incompatible place type: Tensor.set() supports "
......
......@@ -24,6 +24,7 @@ import threading
import numpy as np
import multiprocessing
from collections import namedtuple
from paddle.fluid.framework import _set_expected_place, _current_expected_place
# NOTE: queue has a different name in python2 and python3
if six.PY2:
......@@ -297,12 +298,20 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
self._need_check_feed, self._places, self._use_buffer_reader, True,
self._pin_memory)
self._thread = threading.Thread(target=self._thread_loop)
self._thread = threading.Thread(
target=self._thread_loop, args=(_current_expected_place(), ))
self._thread.daemon = True
self._thread.start()
def _thread_loop(self):
def _thread_loop(self, legacy_expected_place):
try:
#NOTE(zhiqiu): Set the expected place for new thread as the same as father thread,
# and it will call platform::SetDeviceId() in c++ internally.
# If we do not set cudaDeviceId in new thread, the default cudaDeviceId will be 0,
# Which may cost hundreds of MB of GPU memory on CUDAPlace(0) if calling some cuda
# APIs in this thread.
_set_expected_place(legacy_expected_place)
for indices in self._sampler_iter:
# read data from dataset in mini-batch
batch = self._dataset_fetcher.fetch(indices)
......@@ -563,7 +572,8 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self._pin_memory)
self._thread_done_event = threading.Event()
self._thread = threading.Thread(target=self._thread_loop)
self._thread = threading.Thread(
target=self._thread_loop, args=(_current_expected_place(), ))
self._thread.daemon = True
self._thread.start()
......@@ -603,7 +613,14 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self._blocking_queue.kill()
logging.error("DataLoader reader thread raised an exception!")
def _thread_loop(self):
def _thread_loop(self, legacy_expected_place):
#NOTE(zhiqiu): Set the expected place for new thread as the same as father thread,
# and it will call platform::SetDeviceId() in c++ internally.
# If we do not set cudaDeviceId in new thread, the default cudaDeviceId will be 0,
# Which may cost hundreds of MB of GPU memory on CUDAPlace(0) if calling some cuda
# APIs in this thread.
_set_expected_place(legacy_expected_place)
while not self._thread_done_event.is_set():
batch = self._get_data()
if not self._thread_done_event.is_set():
......
......@@ -379,7 +379,6 @@ def guard(place=None):
expected_place = _get_paddle_place(place)
else:
expected_place = framework._current_expected_place()
tracer._expected_place = expected_place
with framework.program_guard(train, startup):
with framework.unique_name.guard():
......
......@@ -5660,15 +5660,15 @@ def _get_var(name, program=None):
@signature_safe_contextmanager
def _dygraph_guard(tracer):
global _dygraph_tracer_
tmp_trace = _dygraph_tracer_
tmp_tracer = _dygraph_tracer_
_dygraph_tracer_ = tracer
core._switch_tracer(tracer)
try:
yield
finally:
core._switch_tracer(tmp_trace)
_dygraph_tracer_ = tmp_trace
core._switch_tracer(tmp_tracer)
_dygraph_tracer_ = tmp_tracer
@signature_safe_contextmanager
......@@ -5677,10 +5677,13 @@ def _dygraph_place_guard(place):
tmp_place = _global_expected_place_
_global_expected_place_ = place
_set_dygraph_tracer_expected_place(place)
try:
yield
finally:
_global_expected_place_ = tmp_place
_set_dygraph_tracer_expected_place(tmp_place)
def load_op_library(lib_filename):
......
......@@ -32,7 +32,7 @@ from ..unique_name import generate as unique_name
import logging
from ..data_feeder import check_dtype, check_type
from paddle.fluid.framework import static_only
from ..framework import _get_paddle_place
from ..framework import _get_paddle_place, _current_expected_place, _set_expected_place
__all__ = [
'data', 'read_file', 'double_buffer', 'py_reader',
......@@ -475,8 +475,11 @@ def _py_reader(capacity,
reader.exited = False
def start_provide_thread(func):
def __provider_thread__():
def __provider_thread__(legacy_expected_place):
try:
# See _DataLoaderIterSingleProcess._thread_loop() for why set expected place here.
_set_expected_place(legacy_expected_place)
for tensors in func():
array = core.LoDTensorArray()
for item in tensors:
......@@ -498,7 +501,8 @@ def _py_reader(capacity,
logging.warn('Your decorated reader has raised an exception!')
six.reraise(*sys.exc_info())
reader.thread = threading.Thread(target=__provider_thread__)
reader.thread = threading.Thread(
target=__provider_thread__, args=(_current_expected_place(), ))
reader.thread.daemon = True
reader.thread.start()
......
......@@ -28,6 +28,7 @@ from .dataloader.batch_sampler import _InfiniteIterableSampler
from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer
from .unique_name import UniqueNameGenerator
from .framework import _get_paddle_place, _get_paddle_place_list
from paddle.fluid.framework import _set_expected_place, _current_expected_place
import logging
import warnings
......@@ -928,12 +929,14 @@ class DygraphGeneratorLoader(DataLoaderBase):
# Set reader_thread
self._thread_done_event = threading.Event()
self._thread = threading.Thread(
target=self._reader_thread_loop_for_multiprocess)
target=self._reader_thread_loop_for_multiprocess,
args=(_current_expected_place(), ))
self._thread.daemon = True
self._thread.start()
else:
self._thread = threading.Thread(
target=self._reader_thread_loop_for_singleprocess)
target=self._reader_thread_loop_for_singleprocess,
args=(_current_expected_place(), ))
self._thread.daemon = True
self._thread.start()
......@@ -968,7 +971,10 @@ class DygraphGeneratorLoader(DataLoaderBase):
self._blocking_queue.kill()
logging.error("DataLoader reader thread raised an exception!")
def _reader_thread_loop_for_multiprocess(self):
def _reader_thread_loop_for_multiprocess(self, legacy_expected_place):
# See _DataLoaderIterSingleProcess._thread_loop() for why set expected place here.
_set_expected_place(legacy_expected_place)
while not self._thread_done_event.is_set():
try:
# NOTE: [ avoid hanging ] Even with carefully designed data dependencies
......@@ -1007,8 +1013,11 @@ class DygraphGeneratorLoader(DataLoaderBase):
else:
self._exit_thread_expectedly()
def _reader_thread_loop_for_singleprocess(self):
def _reader_thread_loop_for_singleprocess(self, legacy_expected_place):
try:
# See _DataLoaderIterSingleProcess._thread_loop() for why set expected place here.
_set_expected_place(legacy_expected_place)
for sample in self._batch_reader():
array = core.LoDTensorArray()
for item in sample:
......@@ -1248,8 +1257,11 @@ class GeneratorLoader(DataLoaderBase):
self._reset()
def _start(self):
def __thread_main__():
def __thread_main__(legacy_expected_place):
try:
# See _DataLoaderIterSingleProcess._thread_loop() for why set expected place here.
_set_expected_place(legacy_expected_place)
while not self._queue.wait_for_inited(1):
if self._exited:
return
......@@ -1276,7 +1288,8 @@ class GeneratorLoader(DataLoaderBase):
logging.warn('Your reader has raised an exception!')
six.reraise(*sys.exc_info())
self._thread = threading.Thread(target=__thread_main__)
self._thread = threading.Thread(
target=__thread_main__, args=(_current_expected_place(), ))
self._thread.daemon = True
self._thread.start()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册