Unverified commit 3d015f1c authored by Leo Chen, committed by GitHub


Set expected place in child thread for dataloader to avoid costing cuda memory on other card (#30338)

* set expected place in child thread for dataloader

* set device id when set tensor from numpy

* revert tensor_py change

* add compile guard

* fix ci

* fix bug
Parent 2c1bba02
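
Before the diff, a minimal sketch (not part of the commit) of the pattern these changes adopt: the parent thread captures its current expected place before spawning a reader thread, and the child thread restores that place first, so later CUDA calls target the intended card instead of card 0. `_current_expected_place` and `_set_expected_place` are the helpers imported in the hunks below; the `_worker` name and the surrounding scaffolding are illustrative only.

import threading

from paddle.fluid.framework import _current_expected_place, _set_expected_place


def _worker(legacy_expected_place):
    # Restore the parent thread's place; on the C++ side this also calls
    # platform::SetDeviceId(), so this thread no longer defaults to GPU 0.
    _set_expected_place(legacy_expected_place)
    # ... read batches / build tensors on the expected place ...


# Capture the place in the parent thread and hand it to the child thread.
thread = threading.Thread(target=_worker, args=(_current_expected_place(), ))
thread.daemon = True
thread.start()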
@@ -161,7 +161,7 @@ static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
   }
   VLOG(5) << "Init Tensor as: / name: " << name
           << " / persistable: " << persistable << " / zero_copy: " << zero_copy
-          << " / stop_gradient: " << stop_gradient;
+          << " / stop_gradient: " << stop_gradient << " / at " << place;
   new (self) imperative::VarBase(name);
   self->SetPersistable(persistable);
   auto *tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
@@ -175,8 +175,8 @@ static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
 static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
                                                const py::array &array) {
-  VLOG(4) << "Init VarBase from numpy: ";
   auto place = imperative::GetCurrentTracer()->ExpectedPlace();
+  VLOG(4) << "Init VarBase from numpy at " << place;
   InitTensorForVarBase(self, array, place);
 }
@@ -1206,15 +1206,44 @@ void BindImperative(py::module *m_ptr) {
             if (py::isinstance<platform::CUDAPlace>(obj)) {
               auto p = obj.cast<platform::CUDAPlace *>();
               self.SetExpectedPlace(*p);
+              // NOTE(zhiqiu): When switching cuda place, we need to set the
+              // cuda device id.
+              // Otherwise, some cuda API may be launched at other cuda place,
+              // which may cost hundreds of MB of GPU memory due to the cuda
+              // lib.
+#ifdef PADDLE_WITH_CUDA
+              platform::SetDeviceId(p->device);
+#endif
+              VLOG(4) << "Tracer(" << &self << ")"
+                      << " set expected place " << *p;
             } else if (py::isinstance<platform::XPUPlace>(obj)) {
               auto p = obj.cast<platform::XPUPlace *>();
               self.SetExpectedPlace(*p);
+              VLOG(4) << "Tracer(" << &self << ")"
+                      << " set expected place " << *p;
             } else if (py::isinstance<platform::CPUPlace>(obj)) {
               auto p = obj.cast<platform::CPUPlace *>();
               self.SetExpectedPlace(*p);
+              VLOG(4) << "Tracer(" << &self << ")"
+                      << " set expected place " << *p;
             } else if (py::isinstance<platform::CUDAPinnedPlace>(obj)) {
               auto p = obj.cast<platform::CUDAPinnedPlace *>();
               self.SetExpectedPlace(*p);
+              VLOG(4) << "Tracer(" << &self << ")"
+                      << " set expected place " << *p;
+            } else if (py::isinstance<platform::Place>(obj)) {
+              auto p = obj.cast<platform::Place *>();
+              self.SetExpectedPlace(*p);
+              if (platform::is_gpu_place(*p)) {
+                // NOTE(zhiqiu): same as obj is CUDAPlace.
+#ifdef PADDLE_WITH_CUDA
+                platform::SetDeviceId(
+                    BOOST_GET_CONST(platform::CUDAPlace, *p).device);
+#endif
+              }
+              VLOG(4) << "Tracer(" << &self << ")"
+                      << " set expected place " << *p;
             } else {
               PADDLE_THROW(platform::errors::InvalidArgument(
                   "Incompatible Place Type: supports XPUPlace, CUDAPlace, "
...
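A hedged sketch of what the binding above means from Python (assumes a CUDA build with at least two GPUs; touching the private `_expected_place` attribute mirrors the old code removed from dygraph `guard()` below and is shown only for illustration): assigning a CUDAPlace to the tracer's expected place now also switches the calling thread's CUDA device.

import paddle.fluid as fluid
from paddle.fluid.framework import _dygraph_tracer

with fluid.dygraph.guard(fluid.CUDAPlace(1)):
    tracer = _dygraph_tracer()
    # The property setter bound above calls platform::SetDeviceId(1) as a
    # side effect, so CUDA work in this thread lands on card 1.
    tracer._expected_place = fluid.CUDAPlace(1)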
@@ -288,12 +288,14 @@ void SetTensorFromPyArrayT(
 #endif
   } else {
 #ifdef PADDLE_WITH_CUDA
+    if (paddle::platform::is_gpu_place(place)) {
+      // TODO(zhiqiu): set SetDeviceId before calling cuda APIs.
       auto dst = self->mutable_data<T>(place);
-    if (paddle::platform::is_cuda_pinned_place(place)) {
-      std::memcpy(dst, array.data(), array.nbytes());
-    } else if (paddle::platform::is_gpu_place(place)) {
       paddle::platform::GpuMemcpySync(dst, array.data(), array.nbytes(),
                                       cudaMemcpyHostToDevice);
+    } else if (paddle::platform::is_cuda_pinned_place(place)) {
+      auto dst = self->mutable_data<T>(place);
+      std::memcpy(dst, array.data(), array.nbytes());
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Incompatible place type: Tensor.set() supports "
...
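For context, a hedged sketch of the Python entry point that reaches `SetTensorFromPyArrayT` above (assumes a CUDA build; the array contents are arbitrary): `Tensor.set()` with a GPU place takes the `GpuMemcpySync` branch, while a pinned place takes the `std::memcpy` branch.

import numpy as np
import paddle.fluid as fluid

t = fluid.core.LoDTensor()
arr = np.random.rand(2, 3).astype('float32')
t.set(arr, fluid.CUDAPlace(0))        # host-to-device copy via GpuMemcpySync
t.set(arr, fluid.CUDAPinnedPlace())   # plain std::memcpy into pinned memory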
@@ -24,6 +24,7 @@ import threading
 import numpy as np
 import multiprocessing
 from collections import namedtuple
+from paddle.fluid.framework import _set_expected_place, _current_expected_place
 
 # NOTE: queue has a different name in python2 and python3
 if six.PY2:
@@ -297,12 +298,20 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
             self._need_check_feed, self._places, self._use_buffer_reader, True,
             self._pin_memory)
 
-        self._thread = threading.Thread(target=self._thread_loop)
+        self._thread = threading.Thread(
+            target=self._thread_loop, args=(_current_expected_place(), ))
         self._thread.daemon = True
         self._thread.start()
 
-    def _thread_loop(self):
+    def _thread_loop(self, legacy_expected_place):
         try:
+            #NOTE(zhiqiu): Set the expected place for new thread as the same as father thread,
+            # and it will call platform::SetDeviceId() in c++ internally.
+            # If we do not set cudaDeviceId in new thread, the default cudaDeviceId will be 0,
+            # Which may cost hundreds of MB of GPU memory on CUDAPlace(0) if calling some cuda
+            # APIs in this thread.
+            _set_expected_place(legacy_expected_place)
+
             for indices in self._sampler_iter:
                 # read data from dataset in mini-batch
                 batch = self._dataset_fetcher.fetch(indices)
@@ -563,7 +572,8 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
             self._pin_memory)
 
         self._thread_done_event = threading.Event()
-        self._thread = threading.Thread(target=self._thread_loop)
+        self._thread = threading.Thread(
+            target=self._thread_loop, args=(_current_expected_place(), ))
         self._thread.daemon = True
         self._thread.start()
@@ -603,7 +613,14 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
             self._blocking_queue.kill()
             logging.error("DataLoader reader thread raised an exception!")
 
-    def _thread_loop(self):
+    def _thread_loop(self, legacy_expected_place):
+        #NOTE(zhiqiu): Set the expected place for new thread as the same as father thread,
+        # and it will call platform::SetDeviceId() in c++ internally.
+        # If we do not set cudaDeviceId in new thread, the default cudaDeviceId will be 0,
+        # Which may cost hundreds of MB of GPU memory on CUDAPlace(0) if calling some cuda
+        # APIs in this thread.
+        _set_expected_place(legacy_expected_place)
+
         while not self._thread_done_event.is_set():
             batch = self._get_data()
             if not self._thread_done_event.is_set():
...
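For the scenario the commit title describes, a hedged usage sketch (the dataset and the `paddle.set_device` call are illustrative, not from the diff; it assumes at least two GPUs): with the single-process iterator above, the reader thread now inherits the parent's place, so iterating on card 1 no longer leaves a CUDA context's worth of memory on card 0.

import numpy as np
import paddle


class RandomDataset(paddle.io.Dataset):
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return np.random.rand(3, 32, 32).astype('float32')


paddle.set_device('gpu:1')                      # train on the second card
# num_workers defaults to 0, so this uses _DataLoaderIterSingleProcess above
loader = paddle.io.DataLoader(RandomDataset(), batch_size=4)
for batch in loader:                            # reader thread uses gpu:1, not gpu:0
    pass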
@@ -379,7 +379,6 @@ def guard(place=None):
         expected_place = _get_paddle_place(place)
     else:
         expected_place = framework._current_expected_place()
-    tracer._expected_place = expected_place
 
     with framework.program_guard(train, startup):
         with framework.unique_name.guard():
...
@@ -5664,15 +5664,15 @@ def _get_var(name, program=None):
 @signature_safe_contextmanager
 def _dygraph_guard(tracer):
     global _dygraph_tracer_
-    tmp_trace = _dygraph_tracer_
+    tmp_tracer = _dygraph_tracer_
     _dygraph_tracer_ = tracer
     core._switch_tracer(tracer)
 
     try:
         yield
     finally:
-        core._switch_tracer(tmp_trace)
-        _dygraph_tracer_ = tmp_trace
+        core._switch_tracer(tmp_tracer)
+        _dygraph_tracer_ = tmp_tracer
 
 
 @signature_safe_contextmanager
@@ -5681,10 +5681,13 @@ def _dygraph_place_guard(place):
     tmp_place = _global_expected_place_
     _global_expected_place_ = place
+    _set_dygraph_tracer_expected_place(place)
 
     try:
         yield
     finally:
         _global_expected_place_ = tmp_place
+        _set_dygraph_tracer_expected_place(tmp_place)
 
 
 def load_op_library(lib_filename):
...
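A hedged sketch of the user-visible effect of the two framework.py hunks above (requires a CUDA build with a second card; `to_variable` is used only as a convenient way to create a tensor): entering a dygraph place guard now routes through `_set_dygraph_tracer_expected_place`, keeping the tracer and the CUDA device id in sync instead of relying on the direct assignment removed from `guard()` above.

import numpy as np
import paddle.fluid as fluid

# guard() enters _dygraph_place_guard(CUDAPlace(1)) internally, which now also
# updates the tracer's expected place and sets the CUDA device id.
with fluid.dygraph.guard(fluid.CUDAPlace(1)):
    x = fluid.dygraph.to_variable(np.ones([2, 2], dtype='float32'))
    # x lives on card 1; this thread creates no CUDA context on card 0.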
@@ -32,7 +32,7 @@ from ..unique_name import generate as unique_name
 import logging
 from ..data_feeder import check_dtype, check_type
 from paddle.fluid.framework import static_only
-from ..framework import _get_paddle_place
+from ..framework import _get_paddle_place, _current_expected_place, _set_expected_place
 
 __all__ = [
     'data', 'read_file', 'double_buffer', 'py_reader',
@@ -475,8 +475,11 @@ def _py_reader(capacity,
     reader.exited = False
 
     def start_provide_thread(func):
-        def __provider_thread__():
+        def __provider_thread__(legacy_expected_place):
             try:
+                # See _DataLoaderIterSingleProcess._thread_loop() for why set expected place here.
+                _set_expected_place(legacy_expected_place)
+
                 for tensors in func():
                     array = core.LoDTensorArray()
                     for item in tensors:
@@ -498,7 +501,8 @@ def _py_reader(capacity,
                 logging.warn('Your decorated reader has raised an exception!')
                 six.reraise(*sys.exc_info())
 
-        reader.thread = threading.Thread(target=__provider_thread__)
+        reader.thread = threading.Thread(
+            target=__provider_thread__, args=(_current_expected_place(), ))
         reader.thread.daemon = True
         reader.thread.start()
...
@@ -28,6 +28,7 @@ from .dataloader.batch_sampler import _InfiniteIterableSampler
 from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer
 from .unique_name import UniqueNameGenerator
 from .framework import _get_paddle_place, _get_paddle_place_list
+from paddle.fluid.framework import _set_expected_place, _current_expected_place
 import logging
 import warnings
@@ -928,12 +929,14 @@ class DygraphGeneratorLoader(DataLoaderBase):
             # Set reader_thread
             self._thread_done_event = threading.Event()
             self._thread = threading.Thread(
-                target=self._reader_thread_loop_for_multiprocess)
+                target=self._reader_thread_loop_for_multiprocess,
+                args=(_current_expected_place(), ))
             self._thread.daemon = True
             self._thread.start()
         else:
             self._thread = threading.Thread(
-                target=self._reader_thread_loop_for_singleprocess)
+                target=self._reader_thread_loop_for_singleprocess,
+                args=(_current_expected_place(), ))
             self._thread.daemon = True
             self._thread.start()
@@ -968,7 +971,10 @@ class DygraphGeneratorLoader(DataLoaderBase):
             self._blocking_queue.kill()
             logging.error("DataLoader reader thread raised an exception!")
 
-    def _reader_thread_loop_for_multiprocess(self):
+    def _reader_thread_loop_for_multiprocess(self, legacy_expected_place):
+        # See _DataLoaderIterSingleProcess._thread_loop() for why set expected place here.
+        _set_expected_place(legacy_expected_place)
+
         while not self._thread_done_event.is_set():
             try:
                 # NOTE: [ avoid hanging ] Even with carefully designed data dependencies
@@ -1007,8 +1013,11 @@ class DygraphGeneratorLoader(DataLoaderBase):
             else:
                 self._exit_thread_expectedly()
 
-    def _reader_thread_loop_for_singleprocess(self):
+    def _reader_thread_loop_for_singleprocess(self, legacy_expected_place):
         try:
+            # See _DataLoaderIterSingleProcess._thread_loop() for why set expected place here.
+            _set_expected_place(legacy_expected_place)
+
             for sample in self._batch_reader():
                 array = core.LoDTensorArray()
                 for item in sample:
@@ -1248,8 +1257,11 @@ class GeneratorLoader(DataLoaderBase):
             self._reset()
 
     def _start(self):
-        def __thread_main__():
+        def __thread_main__(legacy_expected_place):
             try:
+                # See _DataLoaderIterSingleProcess._thread_loop() for why set expected place here.
+                _set_expected_place(legacy_expected_place)
+
                 while not self._queue.wait_for_inited(1):
                     if self._exited:
                         return
@@ -1276,7 +1288,8 @@ class GeneratorLoader(DataLoaderBase):
                 logging.warn('Your reader has raised an exception!')
                 six.reraise(*sys.exc_info())
 
-        self._thread = threading.Thread(target=__thread_main__)
+        self._thread = threading.Thread(
+            target=__thread_main__, args=(_current_expected_place(), ))
         self._thread.daemon = True
         self._thread.start()
...