提交 f9974a4a 编写于 作者: Y Yu Yang

Make double_buffer reader async

上级 a8c076e5
......@@ -24,15 +24,31 @@ static constexpr size_t kDoubleBufferSize = 2;
class DoubleBufferReader : public framework::DecoratedReader {
public:
struct Item {
Item() : ctx_(nullptr) {}
std::vector<framework::LoDTensor> payloads_;
platform::DeviceContext* ctx_;
};
explicit DoubleBufferReader(
ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
: DecoratedReader(reader), place_(target_place) {
for (size_t i = 0; i < kDoubleBufferSize; ++i) {
if (platform::is_gpu_place(place_)) {
#ifdef PADDLE_WITH_CUDA
ctxs_.emplace_back(new platform::CUDADeviceContext(
boost::get<platform::CUDAPlace>(place_)));
#else
#endif
}
}
start_thread();
}
void start_thread() {
buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>(
kDoubleBufferSize);
buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize);
std::thread prefetch([this] { PrefetchThreadFunc(); });
prefetch.detach();
}
......@@ -47,9 +63,10 @@ class DoubleBufferReader : public framework::DecoratedReader {
private:
void PrefetchThreadFunc();
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
framework::Channel<Item>* buffer_;
platform::Place place_;
mutable std::vector<framework::LoDTensor> local_buffer_;
std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
mutable Item local_buffer_;
};
class CreateDoubleBufferReaderOp : public framework::OperatorBase {
......@@ -104,12 +121,14 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
};
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
out->clear();
if (local_buffer_.empty()) {
buffer_->Receive(out);
} else {
*out = local_buffer_;
local_buffer_.clear();
if (local_buffer_.payloads_.empty()) {
buffer_->Receive(&local_buffer_);
}
*out = local_buffer_.payloads_;
local_buffer_.payloads_.clear();
if (local_buffer_.ctx_) {
local_buffer_.ctx_->Wait();
}
}
......@@ -121,16 +140,22 @@ void DoubleBufferReader::ReInit() {
void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts.";
size_t gpu_ctx_offset = 0;
while (reader_->HasNext()) {
std::vector<framework::LoDTensor> batch;
reader_->ReadNext(&batch);
Item batch;
reader_->ReadNext(&batch.payloads_);
if (platform::is_gpu_place(place_)) {
std::vector<framework::LoDTensor> gpu_batch;
gpu_batch.resize(batch.size());
for (size_t i = 0; i < batch.size(); ++i) {
framework::TensorCopy(batch[i], place_, &gpu_batch[i]);
gpu_batch[i].set_lod(batch[i].lod());
auto& gpu_ctx = this->ctxs_[gpu_ctx_offset++];
gpu_ctx_offset %= this->ctxs_.size();
gpu_batch.resize(batch.payloads_.size());
for (size_t i = 0; i < batch.payloads_.size(); ++i) {
framework::TensorCopy(batch.payloads_[i], place_, *gpu_ctx,
&gpu_batch[i]);
gpu_batch[i].set_lod(batch.payloads_[i].lod());
}
batch.ctx_ = gpu_ctx.get();
std::swap(gpu_batch, batch.payloads_);
}
if (!buffer_->Send(&batch)) {
......@@ -143,7 +168,7 @@ void DoubleBufferReader::PrefetchThreadFunc() {
}
bool DoubleBufferReader::HasNext() const {
if (local_buffer_.empty()) {
if (local_buffer_.payloads_.empty()) {
bool ok = buffer_->Receive(&local_buffer_);
return ok;
} else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册