From 7bb18433fd34a43ac46b0b134284b8d516c6ece0 Mon Sep 17 00:00:00 2001
From: fengjiayi
Date: Sat, 31 Mar 2018 01:08:32 +0800
Subject: [PATCH] refine code

---
 .../reader/create_double_buffer_reader_op.cc  | 88 +++++++++----------
 1 file changed, 44 insertions(+), 44 deletions(-)

diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
index f4b10cb0326..1b7df87b355 100644
--- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
+++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
@@ -20,7 +20,8 @@
 namespace paddle {
 namespace operators {
 namespace reader {
-static constexpr size_t kDoubleBufferSize = 2;
+static constexpr size_t kChannelSize = 2;
+static constexpr size_t kCacheSize = 4;  // kChannelSize + 2
 
 class DoubleBufferReader : public framework::DecoratedReader {
  public:
@@ -34,33 +35,36 @@ class DoubleBufferReader : public framework::DecoratedReader {
   explicit DoubleBufferReader(
       ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
       : DecoratedReader(reader), place_(target_place) {
-    for (size_t i = 0; i < kDoubleBufferSize; ++i) {
-      if (platform::is_gpu_place(place_)) {
 #ifdef PADDLE_WITH_CUDA
+    for (size_t i = 0; i < kChannelSize + 2; ++i) {
+      if (platform::is_gpu_place(place_)) {
         ctxs_.emplace_back(new platform::CUDADeviceContext(
             boost::get<platform::CUDAPlace>(place_)));
-#endif
       }
     }
-
-    start_thread();
-  }
-
-  void start_thread() {
-    buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize);
-    prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
+#endif
+    StartPrefetcher();
   }
 
+  bool HasNext() const override;
   void ReadNext(std::vector<framework::LoDTensor>* out) override;
   void ReInit() override;
 
-  ~DoubleBufferReader() {
+  void StartPrefetcher() {
+    buffer_ = framework::MakeChannel<Item>(kChannelSize);
+    prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
+  }
+
+  void EndPrefetcher() {
     buffer_->Close();
-    prefetcher_.join();
+    if (prefetcher_.joinable()) {
+      prefetcher_.join();
+    }
     delete buffer_;
+    buffer_ = nullptr;
   }
 
-  bool HasNext() const override;
+  ~DoubleBufferReader() { EndPrefetcher(); }
 
  private:
   void PrefetchThreadFunc();
@@ -123,6 +127,15 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
   }
 };
 
+bool DoubleBufferReader::HasNext() const {
+  if (local_buffer_.payloads_.empty()) {
+    bool ok = buffer_->Receive(&local_buffer_);
+    return ok;
+  } else {
+    return true;
+  }
+}
+
 void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
   if (!HasNext()) {
     PADDLE_THROW("There is no next data!");
@@ -137,40 +150,36 @@
 
 void DoubleBufferReader::ReInit() {
   reader_->ReInit();
-  buffer_->Close();
-  prefetcher_.join();
-  delete buffer_;
-  start_thread();
+  EndPrefetcher();
+  StartPrefetcher();
 }
 
 void DoubleBufferReader::PrefetchThreadFunc() {
   VLOG(5) << "A new prefetch thread starts.";
-  size_t gpu_ctx_offset = 0;
-  std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache(4);
-  std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache(4);
-  size_t tensor_cache_id = 0;
+  std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache(kCacheSize);
+  std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache(kCacheSize);
+  size_t cached_tensor_id = 0;
 
   while (reader_->HasNext()) {
     Item batch;
-    reader_->ReadNext(&batch.payloads_);
+    auto& cpu_batch = cpu_tensor_cache[cached_tensor_id];
+    reader_->ReadNext(&cpu_batch);
     if (platform::is_gpu_place(place_)) {
-      tensor_cache_id %= 4;
-      auto& gpu_batch = gpu_tensor_cache[tensor_cache_id];
-      auto& cpu_batch = cpu_tensor_cache[tensor_cache_id];
-      cpu_batch = batch.payloads_;
-      ++tensor_cache_id;
-
-      auto& gpu_ctx = this->ctxs_[gpu_ctx_offset++];
-      gpu_ctx_offset %= this->ctxs_.size();
-
-      gpu_batch.resize(batch.payloads_.size());
+      auto& gpu_batch = gpu_tensor_cache[cached_tensor_id];
+      auto* gpu_ctx = ctxs_[cached_tensor_id].get();
+      gpu_batch.resize(cpu_batch.size());
       for (size_t i = 0; i < cpu_batch.size(); ++i) {
         framework::TensorCopy(cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i]);
         gpu_batch[i].set_lod(batch.payloads_[i].lod());
       }
-      batch.ctx_ = gpu_ctx.get();
-      batch.payloads_ = gpu_batch;
+      batch.payloads_ = gpu_batch;
+      batch.ctx_ = gpu_ctx;
+    } else {
+      // CPUPlace
+      batch.payloads_ = cpu_batch;
     }
+    ++cached_tensor_id;
+    cached_tensor_id %= kCacheSize;
 
     try {
       buffer_->Send(&batch);
@@ -184,15 +193,6 @@ void DoubleBufferReader::PrefetchThreadFunc() {
   VLOG(5) << "Prefetch thread terminates.";
 }
 
-bool DoubleBufferReader::HasNext() const {
-  if (local_buffer_.payloads_.empty()) {
-    bool ok = buffer_->Receive(&local_buffer_);
-    return ok;
-  } else {
-    return true;
-  }
-}
-
 }  // namespace reader
 }  // namespace operators
 }  // namespace paddle
--
GitLab
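For readers unfamiliar with the pattern this patch refactors, below is a minimal, self-contained sketch of the same producer-consumer double buffering: a prefetch thread pushes batches into a bounded, closable channel, and the consumer blocks on Receive until data arrives or the channel is closed and drained. BoundedChannel and DoubleBufferedSource are hypothetical stand-ins for framework::MakeChannel<Item> and DoubleBufferReader, not Paddle APIs; the capacity of 2 mirrors kChannelSize. (The patch's kCacheSize = kChannelSize + 2 is presumably sized so one batch can be in production and one in consumption while kChannelSize batches wait in the channel, without a cache slot being overwritten while still referenced.)

// A minimal sketch of the double-buffering pattern above, using only the
// C++11 standard library. BoundedChannel and DoubleBufferedSource are
// illustrative stand-ins, not Paddle APIs.
#include <condition_variable>
#include <cstddef>
#include <cstdio>
#include <mutex>
#include <queue>
#include <thread>

// A bounded, closable channel: Send blocks while full, Receive blocks while
// empty, and Close wakes both sides so they can shut down cleanly.
template <typename T>
class BoundedChannel {
 public:
  explicit BoundedChannel(std::size_t cap) : cap_(cap) {}

  bool Send(T item) {
    std::unique_lock<std::mutex> lock(mu_);
    not_full_.wait(lock, [this] { return closed_ || queue_.size() < cap_; });
    if (closed_) return false;  // Sending to a closed channel fails.
    queue_.push(std::move(item));
    not_empty_.notify_one();
    return true;
  }

  bool Receive(T* item) {
    std::unique_lock<std::mutex> lock(mu_);
    not_empty_.wait(lock, [this] { return closed_ || !queue_.empty(); });
    if (queue_.empty()) return false;  // Closed and fully drained.
    *item = std::move(queue_.front());
    queue_.pop();
    not_full_.notify_one();
    return true;
  }

  void Close() {
    std::lock_guard<std::mutex> lock(mu_);
    closed_ = true;
    not_empty_.notify_all();
    not_full_.notify_all();
  }

 private:
  const std::size_t cap_;
  std::mutex mu_;
  std::condition_variable not_empty_, not_full_;
  std::queue<T> queue_;
  bool closed_ = false;
};

// Wraps a trivial "reader" (a counter) with a prefetch thread, echoing the
// StartPrefetcher/EndPrefetcher/ReadNext structure of the patch.
class DoubleBufferedSource {
 public:
  DoubleBufferedSource() : channel_(2 /* kChannelSize */) {
    prefetcher_ = std::thread([this] { PrefetchLoop(); });
  }

  ~DoubleBufferedSource() {
    channel_.Close();
    if (prefetcher_.joinable()) {
      prefetcher_.join();
    }
  }

  bool ReadNext(int* out) { return channel_.Receive(out); }

 private:
  void PrefetchLoop() {
    for (int batch = 0; batch < 10; ++batch) {  // The underlying "reader".
      if (!channel_.Send(batch)) break;         // Stop if the channel closed.
    }
    channel_.Close();  // Signal end of data to the consumer.
  }

  BoundedChannel<int> channel_;
  std::thread prefetcher_;
};

int main() {
  DoubleBufferedSource source;
  int batch;
  while (source.ReadNext(&batch)) {
    std::printf("consumed batch %d\n", batch);
  }
  return 0;
}

Closing the channel from both the destructor and the end of the prefetch loop is deliberate: whichever side finishes first unblocks the other, mirroring the role buffer_->Close() plays in EndPrefetcher() and at the prefetch thread's exit in the patched code.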