diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
index ca947fff4358d28dcda8895e1c31b80e82099422..706f6fd592f88ceb9728ff75df955bbdd75c4c32 100644
--- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
+++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
@@ -24,15 +24,31 @@ static constexpr size_t kDoubleBufferSize = 2;
 class DoubleBufferReader : public framework::DecoratedReader {
  public:
+  struct Item {
+    Item() : ctx_(nullptr) {}
+
+    std::vector<framework::LoDTensor> payloads_;
+    platform::DeviceContext* ctx_;
+  };
+
   explicit DoubleBufferReader(
       ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
       : DecoratedReader(reader), place_(target_place) {
+    for (size_t i = 0; i < kDoubleBufferSize; ++i) {
+      if (platform::is_gpu_place(place_)) {
+#ifdef PADDLE_WITH_CUDA
+        ctxs_.emplace_back(new platform::CUDADeviceContext(
+            boost::get<platform::CUDAPlace>(place_)));
+#else
+#endif
+      }
+    }
+
     start_thread();
   }
 
   void start_thread() {
-    buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>(
-        kDoubleBufferSize);
+    buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize);
     std::thread prefetch([this] { PrefetchThreadFunc(); });
     prefetch.detach();
   }
 
@@ -47,9 +63,10 @@ class DoubleBufferReader : public framework::DecoratedReader {
  private:
   void PrefetchThreadFunc();
 
-  framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
+  framework::Channel<Item>* buffer_;
   platform::Place place_;
-  mutable std::vector<framework::LoDTensor> local_buffer_;
+  std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
+  mutable Item local_buffer_;
 };
 
 class CreateDoubleBufferReaderOp : public framework::OperatorBase {
@@ -104,12 +121,14 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
 };
 
 void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
-  out->clear();
-  if (local_buffer_.empty()) {
-    buffer_->Receive(out);
-  } else {
-    *out = local_buffer_;
-    local_buffer_.clear();
+  if (local_buffer_.payloads_.empty()) {
+    buffer_->Receive(&local_buffer_);
+  }
+
+  *out = local_buffer_.payloads_;
+  local_buffer_.payloads_.clear();
+  if (local_buffer_.ctx_) {
+    local_buffer_.ctx_->Wait();
   }
 }
 
@@ -121,16 +140,22 @@ void DoubleBufferReader::ReInit() {
 
 void DoubleBufferReader::PrefetchThreadFunc() {
   VLOG(5) << "A new prefetch thread starts.";
+  size_t gpu_ctx_offset = 0;
   while (reader_->HasNext()) {
-    std::vector<framework::LoDTensor> batch;
-    reader_->ReadNext(&batch);
+    Item batch;
+    reader_->ReadNext(&batch.payloads_);
     if (platform::is_gpu_place(place_)) {
       std::vector<framework::LoDTensor> gpu_batch;
-      gpu_batch.resize(batch.size());
-      for (size_t i = 0; i < batch.size(); ++i) {
-        framework::TensorCopy(batch[i], place_, &gpu_batch[i]);
-        gpu_batch[i].set_lod(batch[i].lod());
+      auto& gpu_ctx = this->ctxs_[gpu_ctx_offset++];
+      gpu_ctx_offset %= this->ctxs_.size();
+      gpu_batch.resize(batch.payloads_.size());
+      for (size_t i = 0; i < batch.payloads_.size(); ++i) {
+        framework::TensorCopy(batch.payloads_[i], place_, *gpu_ctx,
+                              &gpu_batch[i]);
+        gpu_batch[i].set_lod(batch.payloads_[i].lod());
       }
+      batch.ctx_ = gpu_ctx.get();
+      std::swap(gpu_batch, batch.payloads_);
     }
 
     if (!buffer_->Send(&batch)) {
@@ -143,7 +168,7 @@ void DoubleBufferReader::PrefetchThreadFunc() {
 }
 
 bool DoubleBufferReader::HasNext() const {
-  if (local_buffer_.empty()) {
+  if (local_buffer_.payloads_.empty()) {
     bool ok = buffer_->Receive(&local_buffer_);
     return ok;
   } else {