diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 4372f23fc1dbd85e43b04a9d644977392316c2e9..504e069b65def012da72f1547e6bbc2043d3709f 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -14,7 +14,7 @@ #include // NOLINT -#include "paddle/fluid/framework/channel.h" +#include "paddle/fluid/operators/reader/blocking_queue.h" #include "paddle/fluid/operators/reader/reader_op_registry.h" namespace paddle { @@ -23,13 +23,13 @@ namespace reader { // 'Double buffer' means we shall maintain two batches of input data at the same // time. So the kCacheSize shoul be at least 2. -static constexpr size_t kCacheSize = 2; +static constexpr size_t kCacheSize = 3; // There will be two bacthes out of the channel during training: // 1. the one waiting to be sent to the channel // 2. the one just be received from the channel, which is also being used by // subsequent operators. // So the channel size should be kChacheSize - 2 -static constexpr size_t kChannelSize = 0; // kCacheSize - 2 +static constexpr size_t kChannelSize = 1; // kCacheSize - 2 class DoubleBufferReader : public framework::DecoratedReader { public: @@ -58,7 +58,7 @@ class DoubleBufferReader : public framework::DecoratedReader { bool HasNext() const; void StartPrefetcher() { - channel_ = framework::MakeChannel(kChannelSize); + channel_ = new reader::BlockingQueue(kChannelSize); prefetcher_ = std::thread([this] { PrefetchThreadFunc(); }); } @@ -74,7 +74,7 @@ class DoubleBufferReader : public framework::DecoratedReader { void PrefetchThreadFunc(); std::thread prefetcher_; - framework::Channel* channel_; + reader::BlockingQueue* channel_; platform::Place place_; std::vector> cpu_tensor_cache_; std::vector> gpu_tensor_cache_; @@ -185,10 +185,7 @@ void DoubleBufferReader::PrefetchThreadFunc() { gpu_batch[i].set_lod(cpu_batch[i].lod()); } } - try { - size_t tmp = cached_tensor_id; - channel_->Send(&tmp); - } catch (paddle::platform::EnforceNotMet e) { + if (!channel_->Send(cached_tensor_id)) { VLOG(5) << "WARNING: The double buffer channel has been closed. The " "prefetch thread will terminate."; break;