From bba740710d0d6f2088d808e692886b04f52260ac Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Wed, 25 Mar 2020 06:49:32 -0500 Subject: [PATCH] add cuda resource pool for BufferedReader, test=develop (#23152) --- .../fluid/operators/reader/buffered_reader.cc | 44 +++---- .../fluid/operators/reader/buffered_reader.h | 5 +- .../reader/create_double_buffer_reader_op.cc | 1 + paddle/fluid/platform/CMakeLists.txt | 5 + paddle/fluid/platform/cuda_resource_pool.cc | 114 ++++++++++++++++++ paddle/fluid/platform/cuda_resource_pool.h | 64 ++++++++++ paddle/fluid/platform/resource_pool.h | 100 +++++++++++++++ python/paddle/fluid/reader.py | 2 + 8 files changed, 312 insertions(+), 23 deletions(-) create mode 100644 paddle/fluid/platform/cuda_resource_pool.cc create mode 100644 paddle/fluid/platform/cuda_resource_pool.h create mode 100644 paddle/fluid/platform/resource_pool.h diff --git a/paddle/fluid/operators/reader/buffered_reader.cc b/paddle/fluid/operators/reader/buffered_reader.cc index 894d98ca992..b237df130ab 100644 --- a/paddle/fluid/operators/reader/buffered_reader.cc +++ b/paddle/fluid/operators/reader/buffered_reader.cc @@ -17,8 +17,8 @@ #include #include #include "paddle/fluid/framework/data_type.h" - #include "paddle/fluid/platform/profiler.h" + namespace paddle { namespace operators { namespace reader { @@ -32,15 +32,6 @@ BufferedReader::~BufferedReader() { } position_.pop(); } -#ifdef PADDLE_WITH_CUDA - if (platform::is_gpu_place(place_)) { - platform::SetDeviceId(boost::get(place_).device); - PADDLE_ENFORCE(cudaStreamDestroy(stream_)); - for (auto &event : events_) { - PADDLE_ENFORCE(cudaEventDestroy(event)); - } - } -#endif } BufferedReader::BufferedReader( @@ -53,16 +44,16 @@ BufferedReader::BufferedReader( VLOG(1) << "BufferedReader"; #ifdef PADDLE_WITH_CUDA if (platform::is_gpu_place(place_)) { - platform::SetDeviceId(boost::get(place_).device); + int dev_idx = boost::get(place_).device; compute_stream_ = ((platform::CUDADeviceContext *)(platform::DeviceContextPool::Instance() .Get(place_))) ->stream(); events_.resize(buffer_size); for (auto &event : events_) { - PADDLE_ENFORCE(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); + event = platform::CudaEventResourcePool::Instance().New(dev_idx); } - PADDLE_ENFORCE(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); + stream_ = platform::CudaStreamResourcePool::Instance().New(dev_idx); } #endif cpu_buffer_.resize(buffer_size); @@ -112,8 +103,14 @@ void BufferedReader::ReadAsync(size_t i) { // gpu[i].mutable_data() is called, since some ops release // gpu memory immediately without waiting gpu kernel ends platform::SetDeviceId(boost::get(place_).device); - PADDLE_ENFORCE(cudaEventRecord(events_[i], compute_stream_)); - PADDLE_ENFORCE(cudaStreamWaitEvent(stream_, events_[i], 0)); + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaEventRecord(events_[i].get(), compute_stream_), + platform::errors::Fatal( + "cudaEventRecord raises unexpected exception")); + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaStreamWaitEvent(stream_.get(), events_[i].get(), 0), + platform::errors::Fatal( + "cudaStreamWaitEvent raises unexpected exception")); platform::RecordEvent record_event("BufferedReader:MemoryCopy"); for (size_t i = 0; i < cpu.size(); ++i) { @@ -125,11 +122,11 @@ void BufferedReader::ReadAsync(size_t i) { if (platform::is_cuda_pinned_place(cpu_place)) { memory::Copy(boost::get(place_), gpu_ptr, boost::get(cpu_place), - cpu_ptr, size, stream_); + cpu_ptr, size, stream_.get()); } else if ((platform::is_gpu_place(cpu_place))) { memory::Copy(boost::get(place_), gpu_ptr, boost::get(cpu_place), cpu_ptr, - size, stream_); + size, stream_.get()); } else { platform::CUDAPinnedPlace cuda_pinned_place; framework::LoDTensor cuda_pinned_tensor; @@ -140,13 +137,18 @@ void BufferedReader::ReadAsync(size_t i) { boost::get(cpu_place), cpu_ptr, size); memory::Copy(boost::get(place_), gpu_ptr, - cuda_pinned_place, cuda_pinned_ptr, size, stream_); - PADDLE_ENFORCE(cudaStreamSynchronize(stream_), - "cuda stream sync error."); + cuda_pinned_place, cuda_pinned_ptr, size, stream_.get()); + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaStreamSynchronize(stream_.get()), + platform::errors::Fatal( + "cudaStreamSynchronize raises unexpected exception")); } gpu[i].set_lod(cpu[i].lod()); } - PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaStreamSynchronize(stream_.get()), + platform::errors::Fatal( + "cudaStreamSynchronize raises unexpected exception")); } #endif return i; diff --git a/paddle/fluid/operators/reader/buffered_reader.h b/paddle/fluid/operators/reader/buffered_reader.h index 5f8b2d47c22..89ecea95835 100644 --- a/paddle/fluid/operators/reader/buffered_reader.h +++ b/paddle/fluid/operators/reader/buffered_reader.h @@ -21,6 +21,7 @@ #include "ThreadPool.h" #include "paddle/fluid/framework/reader.h" #ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cuda_resource_pool.h" #include "paddle/fluid/platform/gpu_info.h" #endif @@ -64,9 +65,9 @@ class BufferedReader : public framework::DecoratedReader { std::vector gpu_buffer_; size_t prev_pos_{-1UL}; #ifdef PADDLE_WITH_CUDA - cudaStream_t stream_; cudaStream_t compute_stream_; - std::vector events_; + std::shared_ptr stream_; + std::vector> events_; #endif }; diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 49983b3fa4f..e39919947c2 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -57,6 +57,7 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { VLOG(10) << "Create new double buffer reader on " << place; + out->Clear(); out->Reset(framework::MakeDecoratedReader(underlying_reader, place, 2)); } diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 810b9e86b0c..b954d1b7658 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -80,6 +80,11 @@ cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool cc_library(collective_helper SRCS collective_helper.cc DEPS framework_proto device_context enforce) +if(WITH_GPU) + cc_library(cuda_resource_pool SRCS cuda_resource_pool.cc DEPS gpu_info) + target_link_libraries(device_context cuda_resource_pool) +endif() + if(WIN32) if(WITH_GPU AND NOT WITH_DSO) get_property(cuda_modules GLOBAL PROPERTY CUDA_MODULES) diff --git a/paddle/fluid/platform/cuda_resource_pool.cc b/paddle/fluid/platform/cuda_resource_pool.cc new file mode 100644 index 00000000000..1828f0760a7 --- /dev/null +++ b/paddle/fluid/platform/cuda_resource_pool.cc @@ -0,0 +1,114 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cuda_resource_pool.h" +#include "paddle/fluid/platform/gpu_info.h" + +namespace paddle { +namespace platform { + +CudaStreamResourcePool::CudaStreamResourcePool() { + int dev_cnt = platform::GetCUDADeviceCount(); + pool_.reserve(dev_cnt); + for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) { + auto creator = [dev_idx] { + platform::SetDeviceId(dev_idx); + cudaStream_t stream; + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking), + platform::errors::Fatal( + "cudaStreamCreateWithFlags raises unexpected exception")); + return stream; + }; + + auto deleter = [dev_idx](cudaStream_t stream) { + platform::SetDeviceId(dev_idx); + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaStreamDestroy(stream), + platform::errors::Fatal( + "cudaStreamDestroy raises unexpected exception")); + }; + + pool_.emplace_back( + ResourcePool::Create(creator, deleter)); + } +} + +CudaStreamResourcePool& CudaStreamResourcePool::Instance() { + static CudaStreamResourcePool pool; + return pool; +} + +std::shared_ptr CudaStreamResourcePool::New(int dev_idx) { + PADDLE_ENFORCE_GE( + dev_idx, 0, + platform::errors::InvalidArgument( + "dev_idx should be not less than 0, but got %d", dev_idx)); + PADDLE_ENFORCE_LT( + dev_idx, pool_.size(), + platform::errors::OutOfRange( + "dev_idx should be less than device count %d, but got %d", + pool_.size(), dev_idx)); + return pool_[dev_idx]->New(); +} + +CudaEventResourcePool::CudaEventResourcePool() { + int dev_cnt = platform::GetCUDADeviceCount(); + pool_.reserve(dev_cnt); + for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) { + auto creator = [dev_idx] { + platform::SetDeviceId(dev_idx); + cudaEvent_t event; + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaEventCreateWithFlags(&event, cudaEventDisableTiming), + platform::errors::Fatal( + "cudaEventCreateWithFlags raises unexpected exception")); + return event; + }; + + auto deleter = [dev_idx](cudaEvent_t event) { + platform::SetDeviceId(dev_idx); + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaEventDestroy(event), + platform::errors::Fatal( + "cudaEventDestroy raises unexpected exception")); + }; + + pool_.emplace_back(ResourcePool::Create(creator, deleter)); + } +} + +CudaEventResourcePool& CudaEventResourcePool::Instance() { + static CudaEventResourcePool pool; + return pool; +} + +std::shared_ptr CudaEventResourcePool::New(int dev_idx) { + PADDLE_ENFORCE_GE( + dev_idx, 0, + platform::errors::InvalidArgument( + "dev_idx should be not less than 0, but got %d", dev_idx)); + PADDLE_ENFORCE_LT( + dev_idx, pool_.size(), + platform::errors::OutOfRange( + "dev_idx should be less than device count %d, but got %d", + pool_.size(), dev_idx)); + return pool_[dev_idx]->New(); +} + +} // namespace platform +} // namespace paddle + +#endif diff --git a/paddle/fluid/platform/cuda_resource_pool.h b/paddle/fluid/platform/cuda_resource_pool.h new file mode 100644 index 00000000000..22b53445d84 --- /dev/null +++ b/paddle/fluid/platform/cuda_resource_pool.h @@ -0,0 +1,64 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef PADDLE_WITH_CUDA +#include +#include +#include +#include +#include +#include "paddle/fluid/platform/resource_pool.h" + +namespace paddle { +namespace platform { + +using CudaStreamObject = std::remove_pointer::type; +using CudaEventObject = std::remove_pointer::type; + +class CudaStreamResourcePool { + public: + std::shared_ptr New(int dev_idx); + + static CudaStreamResourcePool &Instance(); + + private: + CudaStreamResourcePool(); + + DISABLE_COPY_AND_ASSIGN(CudaStreamResourcePool); + + private: + std::vector>> pool_; +}; + +class CudaEventResourcePool { + public: + std::shared_ptr New(int dev_idx); + + static CudaEventResourcePool &Instance(); + + private: + CudaEventResourcePool(); + + DISABLE_COPY_AND_ASSIGN(CudaEventResourcePool); + + private: + std::vector>> pool_; +}; + +} // namespace platform +} // namespace paddle + +#endif diff --git a/paddle/fluid/platform/resource_pool.h b/paddle/fluid/platform/resource_pool.h new file mode 100644 index 00000000000..6ac68c2da07 --- /dev/null +++ b/paddle/fluid/platform/resource_pool.h @@ -0,0 +1,100 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/macros.h" + +namespace paddle { +namespace platform { + +template +class ResourcePool : public std::enable_shared_from_this> { + private: + struct ResourceDeleter { + public: + explicit ResourceDeleter(ResourcePool *pool) + : instance_(pool->shared_from_this()) {} + + void operator()(T *ptr) const { instance_->Restore(ptr); } + + private: + std::shared_ptr> instance_; + }; + + public: + static std::shared_ptr> Create( + const std::function &creator, + const std::function &deleter) { + return std::shared_ptr>( + new ResourcePool(creator, deleter)); + } + + ~ResourcePool() { + for (auto *ptr : instances_) { + deleter_(ptr); + } + } + + std::shared_ptr New() { + std::lock_guard guard(mtx_); + T *obj = nullptr; + if (instances_.empty()) { + obj = creator_(); + PADDLE_ENFORCE_NOT_NULL(obj, + platform::errors::PermissionDenied( + "The creator should not return nullptr")); + VLOG(10) << "Create new instance " << TypePtrName(); + } else { + obj = instances_.back(); + instances_.pop_back(); + VLOG(10) << "Pop new instance " << TypePtrName() + << " from pool, size=" << instances_.size(); + } + return std::shared_ptr(obj, ResourceDeleter(this)); + } + + private: + static std::string TypePtrName() { + return platform::demangle(typeid(T *).name()); // NOLINT + } + + private: + ResourcePool(const std::function &creator, + const std::function &deleter) + : creator_(creator), deleter_(deleter) {} + + void Restore(T *ptr) { + std::lock_guard guard(mtx_); + instances_.emplace_back(ptr); + VLOG(10) << "Restore " << TypePtrName() + << " into pool, size=" << instances_.size(); + } + + private: + std::vector instances_; + std::function creator_; + std::function deleter_; + + std::mutex mtx_; +}; + +} // namespace platform +} // namespace paddle diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py index d11628557c5..fe5e9a2e045 100644 --- a/python/paddle/fluid/reader.py +++ b/python/paddle/fluid/reader.py @@ -580,6 +580,7 @@ class DygraphGeneratorLoader(DataLoaderBase): self._need_check_feed = [] self._blocking_queue = core.init_lod_tensor_blocking_queue( core.Variable(), self._capacity, False) + self._reader = None self._reader = core.create_py_reader( self.queue, self._var_names, self._shapes, self._dtypes, self._need_check_feed, self._places, self._use_double_buffer, True) @@ -832,6 +833,7 @@ class GeneratorLoader(DataLoaderBase): ] self._queue = core.init_lod_tensor_blocking_queue( core.Variable(), self._capacity, self._keep_order) + self._reader = None self._reader = core.create_py_reader( self.queue, self._var_names, self._shapes, self._dtypes, self._need_check_feed, self._places, self._use_double_buffer, -- GitLab