提交 e2ca4240 编写于 作者: F fengjiayi

Replace Channel in DoubleBufferReader with BlockingQueue

上级 a7866113
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <thread> // NOLINT #include <thread> // NOLINT
#include "paddle/fluid/framework/channel.h" #include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h" #include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle { namespace paddle {
...@@ -23,13 +23,13 @@ namespace reader { ...@@ -23,13 +23,13 @@ namespace reader {
// 'Double buffer' means we shall maintain two batches of input data at the same // 'Double buffer' means we shall maintain two batches of input data at the same
// time. So the kCacheSize shoul be at least 2. // time. So the kCacheSize shoul be at least 2.
static constexpr size_t kCacheSize = 2; static constexpr size_t kCacheSize = 3;
// There will be two bacthes out of the channel during training: // There will be two bacthes out of the channel during training:
// 1. the one waiting to be sent to the channel // 1. the one waiting to be sent to the channel
// 2. the one just be received from the channel, which is also being used by // 2. the one just be received from the channel, which is also being used by
// subsequent operators. // subsequent operators.
// So the channel size should be kChacheSize - 2 // So the channel size should be kChacheSize - 2
static constexpr size_t kChannelSize = 0; // kCacheSize - 2 static constexpr size_t kChannelSize = 1; // kCacheSize - 2
class DoubleBufferReader : public framework::DecoratedReader { class DoubleBufferReader : public framework::DecoratedReader {
public: public:
...@@ -58,7 +58,7 @@ class DoubleBufferReader : public framework::DecoratedReader { ...@@ -58,7 +58,7 @@ class DoubleBufferReader : public framework::DecoratedReader {
bool HasNext() const; bool HasNext() const;
void StartPrefetcher() { void StartPrefetcher() {
channel_ = framework::MakeChannel<size_t>(kChannelSize); channel_ = new reader::BlockingQueue<size_t>(kChannelSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); }); prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
} }
...@@ -74,7 +74,7 @@ class DoubleBufferReader : public framework::DecoratedReader { ...@@ -74,7 +74,7 @@ class DoubleBufferReader : public framework::DecoratedReader {
void PrefetchThreadFunc(); void PrefetchThreadFunc();
std::thread prefetcher_; std::thread prefetcher_;
framework::Channel<size_t>* channel_; reader::BlockingQueue<size_t>* channel_;
platform::Place place_; platform::Place place_;
std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache_; std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache_;
std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache_; std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache_;
...@@ -185,10 +185,7 @@ void DoubleBufferReader::PrefetchThreadFunc() { ...@@ -185,10 +185,7 @@ void DoubleBufferReader::PrefetchThreadFunc() {
gpu_batch[i].set_lod(cpu_batch[i].lod()); gpu_batch[i].set_lod(cpu_batch[i].lod());
} }
} }
try { if (!channel_->Send(cached_tensor_id)) {
size_t tmp = cached_tensor_id;
channel_->Send(&tmp);
} catch (paddle::platform::EnforceNotMet e) {
VLOG(5) << "WARNING: The double buffer channel has been closed. The " VLOG(5) << "WARNING: The double buffer channel has been closed. The "
"prefetch thread will terminate."; "prefetch thread will terminate.";
break; break;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册