From 3fa0ef3d7102615848f8793c1151c6ec069cd296 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 13 Apr 2018 06:40:50 +0000 Subject: [PATCH] Refine double_buffer code --- .../reader/create_double_buffer_reader_op.cc | 62 +++++++------------ 1 file changed, 22 insertions(+), 40 deletions(-) diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 33a50b5cebc..0b7c1d6af71 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -33,28 +33,14 @@ static constexpr size_t kChannelSize = 0; // kCacheSize - 2 class DoubleBufferReader : public framework::DecoratedReader { public: - struct Item { - Item() : ctx_(nullptr) {} - Item(Item&& b) { - payloads_ = std::move(b.payloads_); - ctx_ = std::move(b.ctx_); - } - Item& operator=(Item&& b) { - payloads_ = std::move(b.payloads_); - ctx_ = std::move(b.ctx_); - return *this; - } - - std::vector payloads_; - platform::DeviceContext* ctx_; - }; - explicit DoubleBufferReader( ReaderBase* reader, platform::Place target_place = platform::CPUPlace()) : DecoratedReader(reader), place_(target_place) { + cpu_tensor_cache_.resize(kCacheSize); + gpu_tensor_cache_.resize(kCacheSize); #ifdef PADDLE_WITH_CUDA - for (size_t i = 0; i < kCacheSize; ++i) { - if (platform::is_gpu_place(place_)) { + if (platform::is_gpu_place(place_)) { + for (size_t i = 0; i < kCacheSize; ++i) { ctxs_.emplace_back(new platform::CUDADeviceContext( boost::get(place_))); } @@ -72,7 +58,7 @@ class DoubleBufferReader : public framework::DecoratedReader { bool HasNext() const; void StartPrefetcher() { - channel_ = framework::MakeChannel(kChannelSize); + channel_ = framework::MakeChannel(kChannelSize); prefetcher_ = std::thread([this] { PrefetchThreadFunc(); }); } @@ -88,8 +74,10 @@ class DoubleBufferReader : public framework::DecoratedReader { void PrefetchThreadFunc(); std::thread prefetcher_; - framework::Channel* channel_; + framework::Channel* channel_; platform::Place place_; + std::vector> cpu_tensor_cache_; + std::vector> gpu_tensor_cache_; std::vector> ctxs_; }; @@ -153,11 +141,14 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { void DoubleBufferReader::ReadNext(std::vector* out) { out->clear(); if (HasNext()) { - Item batch; - channel_->Receive(&batch); - *out = batch.payloads_; - if (batch.ctx_) { - batch.ctx_->Wait(); + size_t cached_tensor_id; + channel_->Receive(&cached_tensor_id); + if (platform::is_gpu_place(place_)) { + *out = gpu_tensor_cache_[cached_tensor_id]; + ctxs_[cached_tensor_id]->Wait(); + } else { + // CPU place + *out = cpu_tensor_cache_[cached_tensor_id]; } } } @@ -176,42 +167,33 @@ bool DoubleBufferReader::HasNext() const { void DoubleBufferReader::PrefetchThreadFunc() { VLOG(5) << "A new prefetch thread starts."; - std::vector> cpu_tensor_cache(kCacheSize); - std::vector> gpu_tensor_cache(kCacheSize); size_t cached_tensor_id = 0; - while (true) { - Item batch; - auto& cpu_batch = cpu_tensor_cache[cached_tensor_id]; + auto& cpu_batch = cpu_tensor_cache_[cached_tensor_id]; reader_->ReadNext(&cpu_batch); if (cpu_batch.empty()) { // The underlying reader have no next data. break; } if (platform::is_gpu_place(place_)) { - auto& gpu_batch = gpu_tensor_cache[cached_tensor_id]; + auto& gpu_batch = gpu_tensor_cache_[cached_tensor_id]; auto* gpu_ctx = ctxs_[cached_tensor_id].get(); gpu_batch.resize(cpu_batch.size()); for (size_t i = 0; i < cpu_batch.size(); ++i) { framework::TensorCopy(cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i]); gpu_batch[i].set_lod(cpu_batch[i].lod()); } - batch.payloads_ = gpu_batch; - batch.ctx_ = gpu_ctx; - } else { - // CPUPlace - batch.payloads_ = cpu_batch; } - ++cached_tensor_id; - cached_tensor_id %= kCacheSize; - try { - channel_->Send(&batch); + size_t tmp = cached_tensor_id; + channel_->Send(&tmp); } catch (paddle::platform::EnforceNotMet e) { VLOG(5) << "WARNING: The double buffer channel has been closed. The " "prefetch thread will terminate."; break; } + ++cached_tensor_id; + cached_tensor_id %= kCacheSize; } channel_->Close(); VLOG(5) << "Prefetch thread terminates."; -- GitLab