Unverified commit 3c8daa9b authored by Chen Weihang, committed by GitHub

Add pin memory control for BufferedReader (#26026)

* add pin memory control

* fix buffered reader init problem

* fix unittest error

* add unittest for coverage
Parent ad4a0466
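
For context on the user-facing side of this change: the new switch is exposed in Python as `paddle.fluid.reader.use_pinned_memory()`; called with no arguments it returns the current setting, and called with a single bool it sets it. The dygraph `DataLoader` reads this flag when it is constructed. Below is a minimal usage sketch modeled on the unit test added in this commit; the sample generator, shapes, and batch size are illustrative only and not part of the patch.

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.reader import use_pinned_memory

def sample_generator():
    # Illustrative data only: one float feature vector and one int64 label.
    for _ in range(16):
        yield np.random.random([784]).astype('float32'), \
              np.random.randint(0, 10, [1]).astype('int64')

with fluid.dygraph.guard():
    # Opt out of pinned-memory staging before the loader is built;
    # the loader reads the flag at construction time.
    use_pinned_memory(False)
    loader = fluid.io.DataLoader.from_generator(
        capacity=4, iterable=False, use_multiprocess=False)
    loader.set_sample_generator(
        sample_generator, batch_size=4, places=fluid.CPUPlace())
    for data in loader():
        pass  # consume the batches as usual
    # Restore the default behaviour (pinned-memory staging enabled).
    use_pinned_memory(True)
```

With a `CUDAPlace` destination, disabling the flag makes `BufferedReader` copy batches to device memory asynchronously on its own CUDA stream instead of staging them in a CUDAPinned buffer, as implemented in the C++ changes below.
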
......@@ -36,15 +36,30 @@ BufferedReader::~BufferedReader() {
BufferedReader::BufferedReader(
const std::shared_ptr<framework::ReaderBase> &reader,
const platform::Place &place, size_t buffer_size)
const platform::Place &place, size_t buffer_size, bool pin_memory)
: framework::DecoratedReader(reader),
thread_pool_(1),
place_(place),
buffer_size_(buffer_size) {
buffer_size_(buffer_size),
pin_memory_(pin_memory) {
VLOG(1) << "BufferedReader";
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_) && !pin_memory) {
int dev_idx = BOOST_GET_CONST(platform::CUDAPlace, place_).device;
compute_stream_ =
((platform::CUDADeviceContext *)(platform::DeviceContextPool::Instance()
.Get(place_)))
->stream();
events_.resize(buffer_size);
for (auto &event : events_) {
event = platform::CudaEventResourcePool::Instance().New(dev_idx);
}
stream_ = platform::CudaStreamResourcePool::Instance().New(dev_idx);
}
#endif
is_same_place_ = false;
cpu_buffer_.resize(buffer_size);
cuda_pinned_buffer_.resize(buffer_size);
cuda_buffer_.resize(buffer_size);
ReadTillBufferFullAsync();
}
......@@ -65,47 +80,103 @@ void BufferedReader::ReadAsync(size_t i) {
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) {
// NOTE: [Copy processing of different input devices]
// We may accept input tensors on three different devices:
// - CPUPlace
// - CUDAPinnedPlace
// - CUDAPlace
// CUDA stream synchronization is slow, so to avoid synchronizing in the
// BufferedReader thread, we copy data as follows:
// - If the src Tensor is in CPU memory, we copy it to CUDAPinned memory
// - If the src Tensor is in CUDAPinned memory, we use it directly
// - If the src Tensor is in CUDA memory, we use it directly
platform::CUDAPinnedPlace cuda_pinned_place;
TensorVec &cuda_pinned = cuda_pinned_buffer_[i];
if (cuda_pinned.empty()) {
cuda_pinned.resize(cpu.size());
TensorVec &cuda = cuda_buffer_[i];
if (cuda.empty()) {
cuda.resize(cpu.size());
} else {
PADDLE_ENFORCE_EQ(
cuda_pinned.size(), cpu.size(),
cuda.size(), cpu.size(),
platform::errors::InvalidArgument(
"Input tensor number on GPU and CPU devices are not matched."));
}
if (pin_memory_) {
// NOTE: [Copy processing of different input devices]
// We may accept input tensors on three different devices:
// - CPUPlace
// - CUDAPinnedPlace
// - CUDAPlace
// CUDA stream synchronization is slow, so to avoid synchronizing in the
// BufferedReader thread, we copy data as follows:
// - If the src Tensor is in CPU memory, we copy it to CUDAPinned memory
// - If the src Tensor is in CUDAPinned memory, we use it directly
// - If the src Tensor is in CUDA memory, we use it directly
platform::CUDAPinnedPlace cuda_pinned_place;
std::vector<void *> cuda_pinned_ptrs;
cuda_pinned_ptrs.reserve(cpu.size());
platform::RecordEvent record_event("BufferedReader:MemoryCopy");
for (size_t i = 0; i < cpu.size(); ++i) {
if (platform::is_cpu_place(cpu[i].place())) {
cuda[i].Resize(cpu[i].dims());
cuda[i].set_layout(cpu[i].layout());
cuda_pinned_ptrs.emplace_back(
cuda[i].mutable_data(cuda_pinned_place, cpu[i].type()));
auto size =
cpu[i].numel() * paddle::framework::SizeOfType(cpu[i].type());
memory::Copy(cuda_pinned_place, cuda_pinned_ptrs[i],
BOOST_GET_CONST(platform::CPUPlace, cpu[i].place()),
cpu[i].data<void>(), size);
cuda[i].set_lod(cpu[i].lod());
} else {
// we set same place flag & use cpu[i] directly
is_same_place_ = true;
}
}
} else {
// NOTE(liangdun): use async copy instead of TensorCopySync.
// TensorCopySync would block other streams: because it issues the copy
// command to the default stream, commands from different streams could
// not run concurrently.
std::vector<void *> gpu_ptrs;
gpu_ptrs.reserve(cpu.size());
for (size_t i = 0; i < cpu.size(); ++i) {
cuda[i].Resize(cpu[i].dims());
cuda[i].set_layout(cpu[i].layout());
gpu_ptrs.emplace_back(cuda[i].mutable_data(place_, cpu[i].type()));
}
std::vector<void *> cuda_pinned_ptrs;
cuda_pinned_ptrs.reserve(cpu.size());
platform::RecordEvent record_event("BufferedReader:MemoryCopy");
for (size_t i = 0; i < cpu.size(); ++i) {
if (platform::is_cpu_place(cpu[i].place())) {
cuda_pinned[i].Resize(cpu[i].dims());
cuda_pinned[i].set_layout(cpu[i].layout());
cuda_pinned_ptrs.emplace_back(
cuda_pinned[i].mutable_data(cuda_pinned_place, cpu[i].type()));
// NOTE(zjl): cudaStreamWaitEvent() must be called after all
// cuda[i].mutable_data() calls, since some ops release CUDA
// memory immediately without waiting for the CUDA kernel to end
platform::SetDeviceId(
BOOST_GET_CONST(platform::CUDAPlace, place_).device);
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaEventRecord(events_[i].get(), compute_stream_));
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamWaitEvent(stream_.get(), events_[i].get(), 0));
platform::RecordEvent record_event("BufferedReader:MemoryCopy");
for (size_t i = 0; i < cpu.size(); ++i) {
auto cpu_place = cpu[i].place();
auto cpu_ptr = cpu[i].data<void>();
auto gpu_ptr = gpu_ptrs[i];
auto size =
cpu[i].numel() * paddle::framework::SizeOfType(cpu[i].type());
memory::Copy(cuda_pinned_place, cuda_pinned_ptrs[i],
BOOST_GET_CONST(platform::CPUPlace, cpu[i].place()),
cpu[i].data<void>(), size);
cuda_pinned[i].set_lod(cpu[i].lod());
} else {
// we set same place flag & use cpu[i] directly
is_same_place_ = true;
if (platform::is_cuda_pinned_place(cpu_place)) {
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place_), gpu_ptr,
BOOST_GET_CONST(platform::CUDAPinnedPlace, cpu_place),
cpu_ptr, size, stream_.get());
} else if ((platform::is_gpu_place(cpu_place))) {
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place_), gpu_ptr,
BOOST_GET_CONST(platform::CUDAPlace, cpu_place),
cpu_ptr, size, stream_.get());
} else {
platform::CUDAPinnedPlace cuda_pinned_place;
framework::LoDTensor cuda_pinned_tensor;
cuda_pinned_tensor.Resize(cpu[i].dims());
auto cuda_pinned_ptr = cuda_pinned_tensor.mutable_data(
cuda_pinned_place, cpu[i].type());
memory::Copy(cuda_pinned_place, cuda_pinned_ptr,
BOOST_GET_CONST(platform::CPUPlace, cpu_place),
cpu_ptr, size);
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place_), gpu_ptr,
cuda_pinned_place, cuda_pinned_ptr, size,
stream_.get());
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream_.get()));
}
cuda[i].set_lod(cpu[i].lod());
}
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream_.get()));
}
}
#endif
......@@ -141,7 +212,7 @@ void BufferedReader::ReadNextImpl(std::vector<framework::LoDTensor> *out) {
}
*out = std::move((platform::is_gpu_place(place_) && !is_same_place_)
? cuda_pinned_buffer_[i]
? cuda_buffer_[i]
: cpu_buffer_[i]);
// Do not push current position into ReadAsync. Push the previous position
......
......@@ -35,7 +35,8 @@ class BufferedReader : public framework::DecoratedReader {
public:
BufferedReader(const std::shared_ptr<framework::ReaderBase>& reader,
const platform::Place& place, size_t buffer_size);
const platform::Place& place, size_t buffer_size,
bool pin_memory = false);
~BufferedReader() override;
......@@ -53,6 +54,7 @@ class BufferedReader : public framework::DecoratedReader {
ThreadPool thread_pool_;
platform::Place place_;
const size_t buffer_size_;
bool pin_memory_;
std::queue<std::future<size_t>> position_;
......@@ -63,8 +65,13 @@ class BufferedReader : public framework::DecoratedReader {
// buffers and prevent alloc every time.
bool is_same_place_;
std::vector<TensorVec> cpu_buffer_;
std::vector<TensorVec> cuda_pinned_buffer_;
std::vector<TensorVec> cuda_buffer_;
size_t prev_pos_{-1UL};
#ifdef PADDLE_WITH_CUDA
cudaStream_t compute_stream_;
std::shared_ptr<platform::CudaStreamObject> stream_;
std::vector<std::shared_ptr<platform::CudaEventObject>> events_;
#endif
};
} // namespace reader
......
......@@ -125,11 +125,12 @@ class MultiDeviceFeedReader {
const std::vector<framework::proto::VarType::Type> &dtypes,
const std::vector<bool> &need_check_feed,
const std::vector<platform::Place> &dst_places, bool use_double_buffer,
bool drop_last)
bool drop_last, bool pin_memory = false)
: queue_(queue),
names_(names),
pool_(new ::ThreadPool(dst_places.size())),
drop_last_(drop_last) {
drop_last_(drop_last),
pin_memory_(pin_memory) {
std::vector<framework::DDim> dims;
for (auto &shape : shapes) {
dims.push_back(framework::make_ddim(shape));
......@@ -157,7 +158,7 @@ class MultiDeviceFeedReader {
VLOG(10) << "Creating " << i << "-th BufferedReader";
holder->Reset(
framework::MakeDecoratedReader<operators::reader::BufferedReader>(
reader, p, 2));
reader, p, 2, pin_memory_));
} else {
if (platform::is_gpu_place(p)) {
PADDLE_THROW(platform::errors::PermissionDenied(
......@@ -322,6 +323,7 @@ class MultiDeviceFeedReader {
std::vector<std::vector<framework::LoDTensor>> ret_;
bool drop_last_;
bool pin_memory_;
};
template <typename QueueType>
......@@ -445,10 +447,10 @@ void BindReader(py::module *module) {
const std::vector<framework::proto::VarType::Type> &dtypes,
const std::vector<bool> &need_check_feed,
const std::vector<platform::Place> &dst_places,
bool use_double_buffer, bool drop_last) {
bool use_double_buffer, bool drop_last, bool pin_memory) {
return new MultiDeviceFeedReader<reader::LoDTensorBlockingQueue>(
queue, names, shapes, dtypes, need_check_feed, dst_places,
use_double_buffer, drop_last);
use_double_buffer, drop_last, pin_memory);
},
py::return_value_policy::take_ownership);
......@@ -461,12 +463,12 @@ void BindReader(py::module *module) {
const std::vector<framework::proto::VarType::Type> &dtypes,
const std::vector<bool> &need_check_feed,
const std::vector<platform::Place> &dst_places, bool use_double_buffer,
bool drop_last) {
bool drop_last, bool pin_memory) {
queue->SetDeviceCount(dst_places.size());
return new MultiDeviceFeedReader<
reader::OrderedMultiDeviceLoDTensorBlockingQueue>(
queue, names, shapes, dtypes, need_check_feed, dst_places,
use_double_buffer, drop_last);
use_double_buffer, drop_last, pin_memory);
},
py::return_value_policy::take_ownership);
}
......
......@@ -108,6 +108,7 @@ class _DataLoaderIterBase(object):
self._use_shared_memory = loader.use_shared_memory
self._timeout = loader.timeout if loader.timeout > 0 else MP_INDICES_CHECK_INTERVAL
self._worker_init_fn = loader.worker_init_fn
self._pin_memory = loader.pin_memory
# LoDTensorBlockingQueue instance for create_py_reader and a thread
# to put mini-batch data to self._blocking_queue, mini-batch data
......@@ -154,7 +155,8 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
len(self._places) > 1)
self._reader = core.create_py_reader(
self._blocking_queue, self._var_names, self._shapes, self._dtypes,
self._need_check_feed, self._places, self._use_buffer_reader, True)
self._need_check_feed, self._places, self._use_buffer_reader, True,
self._pin_memory)
self._thread = threading.Thread(target=self._thread_loop)
self._thread.daemon = True
......@@ -307,7 +309,8 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
core.Variable(), self._outstanding_capacity, len(self._places) > 1)
self._reader = core.create_py_reader(
self._blocking_queue, self._var_names, self._shapes, self._dtypes,
self._need_check_feed, self._places, self._use_buffer_reader, True)
self._need_check_feed, self._places, self._use_buffer_reader, True,
self._pin_memory)
self._thread_done_event = threading.Event()
self._thread = threading.Thread(target=self._thread_loop)
......
......@@ -48,6 +48,7 @@ __all__ = ['PyReader', 'DataLoader', 'default_collate_fn']
data_loader_unique_name_generator = UniqueNameGenerator()
KEEP_DATA_LOADER_ORDER = True
USE_PINNED_MEMORY = None
def keep_data_loader_order(*args):
......@@ -59,6 +60,15 @@ def keep_data_loader_order(*args):
KEEP_DATA_LOADER_ORDER = args[0]
def use_pinned_memory(*args):
global USE_PINNED_MEMORY
if len(args) == 0:
return USE_PINNED_MEMORY
else:
assert len(args) == 1 and isinstance(args[0], bool)
USE_PINNED_MEMORY = args[0]
def _convert_places(places):
if not isinstance(places, (list, tuple)):
places = [places]
......@@ -356,6 +366,11 @@ class DataLoader(object):
shuffle=shuffle,
drop_last=drop_last)
self.pin_memory = False
if in_dygraph_mode():
self.pin_memory = True if use_pinned_memory(
) is None else use_pinned_memory()
def __len__(self):
return len(self.batch_sampler)
......@@ -714,6 +729,8 @@ class DygraphGeneratorLoader(DataLoaderBase):
# mode, this thread is used to get next batch data from self._batch_reader, then
# push it into self._blocking_queue
self._thread = None
self._pin_memory = True if use_pinned_memory(
) is None else use_pinned_memory()
@property
def queue(self):
......@@ -759,7 +776,8 @@ class DygraphGeneratorLoader(DataLoaderBase):
self._reader = None
self._reader = core.create_py_reader(
self.queue, self._var_names, self._shapes, self._dtypes,
self._need_check_feed, self._places, self._use_double_buffer, True)
self._need_check_feed, self._places, self._use_double_buffer, True,
self._pin_memory)
def _start(self):
if self._use_multiprocess:
......@@ -999,7 +1017,7 @@ class GeneratorLoader(DataLoaderBase):
self._reader = core.create_py_reader(
self.queue, self._var_names, self._shapes, self._dtypes,
self._need_check_feed, self._places, self._use_double_buffer,
self._drop_last)
self._drop_last, True)
def _init_non_iterable(self):
lod_levels = []
......
......@@ -17,6 +17,7 @@ import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.reader import use_pinned_memory
def get_random_images_and_labels(image_shape, label_shape):
......@@ -77,6 +78,18 @@ class TestDygraphDataLoader(unittest.TestCase):
batch_size=self.batch_size)
self.iter_loader_data(loader)
def test_set_pin_memory(self):
with fluid.dygraph.guard():
use_pinned_memory(False)
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, iterable=False, use_multiprocess=False)
loader.set_sample_generator(
sample_generator_creator(self.batch_size, self.batch_num),
batch_size=self.batch_size,
places=fluid.CPUPlace())
self.iter_loader_data(loader)
use_pinned_memory(True)
if __name__ == '__main__':
unittest.main()
......@@ -137,14 +137,8 @@ class TestStaticDataLoader(unittest.TestCase):
label = item['label']
assert image.shape() == [BATCH_SIZE, IMAGE_SIZE]
assert label.shape() == [BATCH_SIZE, 1]
if places[i]._equals(fluid.CPUPlace()):
assert image._place()._equals(fluid.CPUPlace())
assert label._place()._equals(fluid.CPUPlace())
else:
assert image._place()._equals(fluid.CUDAPinnedPlace(
))
assert label._place()._equals(fluid.CUDAPinnedPlace(
))
assert image._place()._equals(places[i])
assert label._place()._equals(places[i])
L, = exe.run(program=prog,
feed=d,
fetch_list=[loss],
......