提交 f9974a4a，作者：Yu Yang

Make double_buffer reader async

上级 a8c076e5
...@@ -24,15 +24,31 @@ static constexpr size_t kDoubleBufferSize = 2; ...@@ -24,15 +24,31 @@ static constexpr size_t kDoubleBufferSize = 2;
class DoubleBufferReader : public framework::DecoratedReader { class DoubleBufferReader : public framework::DecoratedReader {
public: public:
struct Item {
Item() : ctx_(nullptr) {}
std::vector<framework::LoDTensor> payloads_;
platform::DeviceContext* ctx_;
};
explicit DoubleBufferReader( explicit DoubleBufferReader(
ReaderBase* reader, platform::Place target_place = platform::CPUPlace()) ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
: DecoratedReader(reader), place_(target_place) { : DecoratedReader(reader), place_(target_place) {
for (size_t i = 0; i < kDoubleBufferSize; ++i) {
if (platform::is_gpu_place(place_)) {
#ifdef PADDLE_WITH_CUDA
ctxs_.emplace_back(new platform::CUDADeviceContext(
boost::get<platform::CUDAPlace>(place_)));
#else
#endif
}
}
start_thread(); start_thread();
} }
void start_thread() { void start_thread() {
buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>( buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize);
kDoubleBufferSize);
std::thread prefetch([this] { PrefetchThreadFunc(); }); std::thread prefetch([this] { PrefetchThreadFunc(); });
prefetch.detach(); prefetch.detach();
} }
...@@ -47,9 +63,10 @@ class DoubleBufferReader : public framework::DecoratedReader { ...@@ -47,9 +63,10 @@ class DoubleBufferReader : public framework::DecoratedReader {
private: private:
void PrefetchThreadFunc(); void PrefetchThreadFunc();
framework::Channel<std::vector<framework::LoDTensor>>* buffer_; framework::Channel<Item>* buffer_;
platform::Place place_; platform::Place place_;
mutable std::vector<framework::LoDTensor> local_buffer_; std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
mutable Item local_buffer_;
}; };
class CreateDoubleBufferReaderOp : public framework::OperatorBase { class CreateDoubleBufferReaderOp : public framework::OperatorBase {
...@@ -104,12 +121,14 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -104,12 +121,14 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
}; };
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) { void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
out->clear(); if (local_buffer_.payloads_.empty()) {
if (local_buffer_.empty()) { buffer_->Receive(&local_buffer_);
buffer_->Receive(out); }
} else {
*out = local_buffer_; *out = local_buffer_.payloads_;
local_buffer_.clear(); local_buffer_.payloads_.clear();
if (local_buffer_.ctx_) {
local_buffer_.ctx_->Wait();
} }
} }
...@@ -121,16 +140,22 @@ void DoubleBufferReader::ReInit() { ...@@ -121,16 +140,22 @@ void DoubleBufferReader::ReInit() {
void DoubleBufferReader::PrefetchThreadFunc() { void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts."; VLOG(5) << "A new prefetch thread starts.";
size_t gpu_ctx_offset = 0;
while (reader_->HasNext()) { while (reader_->HasNext()) {
std::vector<framework::LoDTensor> batch; Item batch;
reader_->ReadNext(&batch); reader_->ReadNext(&batch.payloads_);
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
std::vector<framework::LoDTensor> gpu_batch; std::vector<framework::LoDTensor> gpu_batch;
gpu_batch.resize(batch.size()); auto& gpu_ctx = this->ctxs_[gpu_ctx_offset++];
for (size_t i = 0; i < batch.size(); ++i) { gpu_ctx_offset %= this->ctxs_.size();
framework::TensorCopy(batch[i], place_, &gpu_batch[i]); gpu_batch.resize(batch.payloads_.size());
gpu_batch[i].set_lod(batch[i].lod()); for (size_t i = 0; i < batch.payloads_.size(); ++i) {
framework::TensorCopy(batch.payloads_[i], place_, *gpu_ctx,
&gpu_batch[i]);
gpu_batch[i].set_lod(batch.payloads_[i].lod());
} }
batch.ctx_ = gpu_ctx.get();
std::swap(gpu_batch, batch.payloads_);
} }
if (!buffer_->Send(&batch)) { if (!buffer_->Send(&batch)) {
...@@ -143,7 +168,7 @@ void DoubleBufferReader::PrefetchThreadFunc() { ...@@ -143,7 +168,7 @@ void DoubleBufferReader::PrefetchThreadFunc() {
} }
bool DoubleBufferReader::HasNext() const { bool DoubleBufferReader::HasNext() const {
if (local_buffer_.empty()) { if (local_buffer_.payloads_.empty()) {
bool ok = buffer_->Receive(&local_buffer_); bool ok = buffer_->Receive(&local_buffer_);
return ok; return ok;
} else { } else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册