Unverified · Commit 899827f2 authored by fengjiayi, committed by GitHub

Merge pull request #9535 from reyoung/feature/fix_double_buffer

Add local cache of double buffer reader
@@ -20,12 +20,29 @@ namespace paddle {
 namespace operators {
 namespace reader {
 
-static constexpr size_t kDoubleBufferSize = 2;
+// 'Double buffer' means we shall maintain two batches of input data at the same
+// time. So kCacheSize should be at least 2.
+static constexpr size_t kCacheSize = 2;
+// There will be two batches out of the channel during training:
+// 1. the one waiting to be sent to the channel
+// 2. the one just received from the channel, which is also being used by
+//    subsequent operators.
+// So the channel size should be kCacheSize - 2.
+static constexpr size_t kChannelSize = 0;  // kCacheSize - 2
 
 class DoubleBufferReader : public framework::DecoratedReader {
  public:
   struct Item {
     Item() : ctx_(nullptr) {}
+    Item(Item&& b) {
+      payloads_ = std::move(b.payloads_);
+      ctx_ = std::move(b.ctx_);
+    }
+    Item& operator=(Item&& b) {
+      payloads_ = std::move(b.payloads_);
+      ctx_ = std::move(b.ctx_);
+      return *this;
+    }
 
     std::vector<framework::LoDTensor> payloads_;
     platform::DeviceContext* ctx_;
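
Note on the sizing above: with kCacheSize = 2 and kChannelSize = kCacheSize - 2 = 0, the channel itself buffers no extra batches; one batch is in use by downstream operators while the next is being prepared by the prefetch thread, so a send effectively blocks until the consumer is ready to take it. The sketch below is a minimal, self-contained illustration of such a zero-capacity (rendezvous) hand-off. It is not Paddle's framework::Channel; all names in it are hypothetical.

    // Hypothetical stand-in for a zero-capacity channel; NOT Paddle's
    // framework::Channel. Send() blocks until a Receive() takes the item.
    #include <condition_variable>
    #include <iostream>
    #include <mutex>
    #include <thread>
    #include <utility>

    template <typename T>
    class RendezvousChannel {
     public:
      // Blocks until a receiver has taken the item; returns false if closed.
      bool Send(T item) {
        std::unique_lock<std::mutex> lk(mu_);
        send_cv_.wait(lk, [this] { return !has_item_ || closed_; });
        if (closed_) return false;
        slot_ = std::move(item);
        has_item_ = true;
        recv_cv_.notify_one();
        // Wait for the hand-off to finish, so at most one item is in flight.
        send_cv_.wait(lk, [this] { return !has_item_ || closed_; });
        return !closed_;
      }

      // Blocks until an item is available; returns false once closed and drained.
      bool Receive(T* out) {
        std::unique_lock<std::mutex> lk(mu_);
        recv_cv_.wait(lk, [this] { return has_item_ || closed_; });
        if (!has_item_) return false;
        *out = std::move(slot_);
        has_item_ = false;
        send_cv_.notify_one();
        return true;
      }

      void Close() {
        std::lock_guard<std::mutex> lk(mu_);
        closed_ = true;
        send_cv_.notify_all();
        recv_cv_.notify_all();
      }

     private:
      std::mutex mu_;
      std::condition_variable send_cv_, recv_cv_;
      T slot_;
      bool has_item_ = false;
      bool closed_ = false;
    };

    int main() {
      RendezvousChannel<int> ch;
      std::thread producer([&ch] {
        for (int i = 0; i < 3; ++i) ch.Send(i);  // each Send waits for a Receive
        ch.Close();
      });
      int v;
      while (ch.Receive(&v)) std::cout << "received batch " << v << "\n";
      producer.join();
    }

The real reader sends whole Item structs (the payload tensors plus the device context used for the copy) through the channel in the same blocking fashion.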
@@ -34,42 +51,44 @@ class DoubleBufferReader : public framework::DecoratedReader {
   explicit DoubleBufferReader(
       ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
       : DecoratedReader(reader), place_(target_place) {
-    for (size_t i = 0; i < kDoubleBufferSize; ++i) {
-      if (platform::is_gpu_place(place_)) {
 #ifdef PADDLE_WITH_CUDA
+    for (size_t i = 0; i < kCacheSize; ++i) {
+      if (platform::is_gpu_place(place_)) {
         ctxs_.emplace_back(new platform::CUDADeviceContext(
             boost::get<platform::CUDAPlace>(place_)));
-#endif
       }
     }
+#endif
-    start_thread();
+    StartPrefetcher();
   }
 
-  void start_thread() {
-    buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize);
-    prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
-  }
-
+  bool HasNext() const override;
   void ReadNext(std::vector<framework::LoDTensor>* out) override;
   void ReInit() override;
 
-  ~DoubleBufferReader() {
-    buffer_->Close();
-    prefetcher_.join();
-    delete buffer_;
-  }
-
-  bool HasNext() const override;
+  ~DoubleBufferReader() { EndPrefetcher(); }
 
  private:
+  void StartPrefetcher() {
+    channel_ = framework::MakeChannel<Item>(kChannelSize);
+    prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
+  }
+
+  void EndPrefetcher() {
+    channel_->Close();
+    if (prefetcher_.joinable()) {
+      prefetcher_.join();
+    }
+    delete channel_;
+    channel_ = nullptr;
+  }
+
   void PrefetchThreadFunc();
 
   std::thread prefetcher_;
-  framework::Channel<Item>* buffer_;
+  framework::Channel<Item>* channel_;
   platform::Place place_;
   std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
-  mutable Item local_buffer_;
 };
 
 class CreateDoubleBufferReaderOp : public framework::OperatorBase {
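
The prefetch thread's lifecycle is now factored into StartPrefetcher()/EndPrefetcher(), which the constructor, the destructor, and ReInit() (in the next hunk) all reuse; the joinable() check makes EndPrefetcher() safe even when it runs twice or before any thread was started. Below is a rough, hypothetical sketch of the same start/stop/restart pattern in isolation, with an atomic stop flag standing in for closing the channel:

    // Simplified, hypothetical sketch; the real code closes channel_ to make
    // the worker exit instead of polling a stop flag.
    #include <atomic>
    #include <thread>

    class RestartableWorker {
     public:
      RestartableWorker() { Start(); }
      ~RestartableWorker() { Stop(); }  // cf. ~DoubleBufferReader()

      void Restart() {  // cf. DoubleBufferReader::ReInit()
        Stop();
        Start();
      }

     private:
      void Start() {  // cf. StartPrefetcher()
        stop_.store(false);
        worker_ = std::thread([this] {
          while (!stop_.load()) {
            // read one batch and hand it to the consumer
          }
        });
      }

      void Stop() {  // cf. EndPrefetcher()
        stop_.store(true);         // plays the role of channel_->Close()
        if (worker_.joinable()) {  // guard against stopping twice
          worker_.join();
        }
      }

      std::atomic<bool> stop_{false};
      std::thread worker_;
    };

    int main() {
      RestartableWorker w;
      w.Restart();  // join the old worker, then spawn a fresh one
    }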
@@ -123,70 +142,70 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
   }
 };
 
+bool DoubleBufferReader::HasNext() const {
+  while (!channel_->IsClosed() && !channel_->CanReceive()) {
+  }
+  return channel_->CanReceive();
+}
+
 void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
   if (!HasNext()) {
     PADDLE_THROW("There is no next data!");
   }
-  if (local_buffer_.payloads_.empty()) {
-    buffer_->Receive(&local_buffer_);
-  }
-  *out = local_buffer_.payloads_;
-  local_buffer_.payloads_.clear();
-  if (local_buffer_.ctx_) {
-    local_buffer_.ctx_->Wait();
+  Item batch;
+  channel_->Receive(&batch);
+  *out = batch.payloads_;
+  if (batch.ctx_) {
+    batch.ctx_->Wait();
   }
 }
 
 void DoubleBufferReader::ReInit() {
   reader_->ReInit();
-  buffer_->Close();
-  prefetcher_.join();
-  delete buffer_;
-  start_thread();
+  EndPrefetcher();
+  StartPrefetcher();
 }
 
 void DoubleBufferReader::PrefetchThreadFunc() {
   VLOG(5) << "A new prefetch thread starts.";
-  size_t gpu_ctx_offset = 0;
+  std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache(kCacheSize);
+  std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache(kCacheSize);
+  size_t cached_tensor_id = 0;
+
   while (reader_->HasNext()) {
     Item batch;
-    reader_->ReadNext(&batch.payloads_);
+    auto& cpu_batch = cpu_tensor_cache[cached_tensor_id];
+    reader_->ReadNext(&cpu_batch);
     if (platform::is_gpu_place(place_)) {
-      std::vector<framework::LoDTensor> gpu_batch;
-      auto& gpu_ctx = this->ctxs_[gpu_ctx_offset++];
-      gpu_ctx_offset %= this->ctxs_.size();
-      gpu_batch.resize(batch.payloads_.size());
-      for (size_t i = 0; i < batch.payloads_.size(); ++i) {
-        framework::TensorCopy(batch.payloads_[i], place_, *gpu_ctx,
-                              &gpu_batch[i]);
-        gpu_batch[i].set_lod(batch.payloads_[i].lod());
+      auto& gpu_batch = gpu_tensor_cache[cached_tensor_id];
+      auto* gpu_ctx = ctxs_[cached_tensor_id].get();
+      gpu_batch.resize(cpu_batch.size());
+      for (size_t i = 0; i < cpu_batch.size(); ++i) {
+        framework::TensorCopy(cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i]);
+        gpu_batch[i].set_lod(cpu_batch[i].lod());
       }
-      batch.ctx_ = gpu_ctx.get();
-      std::swap(gpu_batch, batch.payloads_);
+      batch.payloads_ = gpu_batch;
+      batch.ctx_ = gpu_ctx;
+    } else {
+      // CPUPlace
+      batch.payloads_ = cpu_batch;
     }
+    ++cached_tensor_id;
+    cached_tensor_id %= kCacheSize;
 
     try {
-      buffer_->Send(&batch);
+      channel_->Send(&batch);
     } catch (paddle::platform::EnforceNotMet e) {
       VLOG(5) << "WARNING: The double buffer channel has been closed. The "
                  "prefetch thread will terminate.";
       break;
     }
   }
-  buffer_->Close();
+  channel_->Close();
   VLOG(5) << "Prefetch thread terminates.";
 }
 
-bool DoubleBufferReader::HasNext() const {
-  if (local_buffer_.payloads_.empty()) {
-    bool ok = buffer_->Receive(&local_buffer_);
-    return ok;
-  } else {
-    return true;
-  }
-}
-
 }  // namespace reader
 }  // namespace operators
 }  // namespace paddle
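
One more note on PrefetchThreadFunc(): cached_tensor_id cycles through kCacheSize reusable tensor slots, so the storage backing the batch the consumer is still reading is not reused too early while the next batch is filled in. A tiny standalone sketch of just that indexing, with hypothetical names and plain vectors in place of LoDTensor:

    // Hypothetical illustration of the round-robin slot reuse; not Paddle code.
    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main() {
      constexpr std::size_t kCacheSize = 2;
      // One reusable buffer per slot, mirroring cpu_tensor_cache / gpu_tensor_cache.
      std::vector<std::vector<float>> tensor_cache(kCacheSize);
      std::size_t cached_tensor_id = 0;

      for (int batch = 0; batch < 5; ++batch) {
        auto& slot = tensor_cache[cached_tensor_id];
        slot.assign(4, static_cast<float>(batch));  // "ReadNext" fills the slot in place
        std::cout << "batch " << batch << " uses slot " << cached_tensor_id << "\n";
        // Slot N is reused only after kCacheSize further batches, so the batch the
        // consumer still holds keeps its backing storage until then.
        ++cached_tensor_id;
        cached_tensor_id %= kCacheSize;
      }
    }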