未验证 提交 12116675 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #9894 from JiayiFeng/refine_double_buffer_code

Refine double_buffer code
......@@ -33,28 +33,14 @@ static constexpr size_t kChannelSize = 0; // kCacheSize - 2
class DoubleBufferReader : public framework::DecoratedReader {
public:
struct Item {
Item() : ctx_(nullptr) {}
Item(Item&& b) {
payloads_ = std::move(b.payloads_);
ctx_ = std::move(b.ctx_);
}
Item& operator=(Item&& b) {
payloads_ = std::move(b.payloads_);
ctx_ = std::move(b.ctx_);
return *this;
}
std::vector<framework::LoDTensor> payloads_;
platform::DeviceContext* ctx_;
};
explicit DoubleBufferReader(
ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
: DecoratedReader(reader), place_(target_place) {
cpu_tensor_cache_.resize(kCacheSize);
gpu_tensor_cache_.resize(kCacheSize);
#ifdef PADDLE_WITH_CUDA
for (size_t i = 0; i < kCacheSize; ++i) {
if (platform::is_gpu_place(place_)) {
if (platform::is_gpu_place(place_)) {
for (size_t i = 0; i < kCacheSize; ++i) {
ctxs_.emplace_back(new platform::CUDADeviceContext(
boost::get<platform::CUDAPlace>(place_)));
}
......@@ -72,7 +58,7 @@ class DoubleBufferReader : public framework::DecoratedReader {
bool HasNext() const;
void StartPrefetcher() {
channel_ = framework::MakeChannel<Item>(kChannelSize);
channel_ = framework::MakeChannel<size_t>(kChannelSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
}
......@@ -88,8 +74,10 @@ class DoubleBufferReader : public framework::DecoratedReader {
void PrefetchThreadFunc();
std::thread prefetcher_;
framework::Channel<Item>* channel_;
framework::Channel<size_t>* channel_;
platform::Place place_;
std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache_;
std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache_;
std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
};
......@@ -153,11 +141,14 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
out->clear();
if (HasNext()) {
Item batch;
channel_->Receive(&batch);
*out = batch.payloads_;
if (batch.ctx_) {
batch.ctx_->Wait();
size_t cached_tensor_id;
channel_->Receive(&cached_tensor_id);
if (platform::is_gpu_place(place_)) {
*out = gpu_tensor_cache_[cached_tensor_id];
ctxs_[cached_tensor_id]->Wait();
} else {
// CPU place
*out = cpu_tensor_cache_[cached_tensor_id];
}
}
}
......@@ -176,42 +167,33 @@ bool DoubleBufferReader::HasNext() const {
void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts.";
std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache(kCacheSize);
std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache(kCacheSize);
size_t cached_tensor_id = 0;
while (true) {
Item batch;
auto& cpu_batch = cpu_tensor_cache[cached_tensor_id];
auto& cpu_batch = cpu_tensor_cache_[cached_tensor_id];
reader_->ReadNext(&cpu_batch);
if (cpu_batch.empty()) {
// The underlying reader have no next data.
break;
}
if (platform::is_gpu_place(place_)) {
auto& gpu_batch = gpu_tensor_cache[cached_tensor_id];
auto& gpu_batch = gpu_tensor_cache_[cached_tensor_id];
auto* gpu_ctx = ctxs_[cached_tensor_id].get();
gpu_batch.resize(cpu_batch.size());
for (size_t i = 0; i < cpu_batch.size(); ++i) {
framework::TensorCopy(cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i]);
gpu_batch[i].set_lod(cpu_batch[i].lod());
}
batch.payloads_ = gpu_batch;
batch.ctx_ = gpu_ctx;
} else {
// CPUPlace
batch.payloads_ = cpu_batch;
}
++cached_tensor_id;
cached_tensor_id %= kCacheSize;
try {
channel_->Send(&batch);
size_t tmp = cached_tensor_id;
channel_->Send(&tmp);
} catch (paddle::platform::EnforceNotMet e) {
VLOG(5) << "WARNING: The double buffer channel has been closed. The "
"prefetch thread will terminate.";
break;
}
++cached_tensor_id;
cached_tensor_id %= kCacheSize;
}
channel_->Close();
VLOG(5) << "Prefetch thread terminates.";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册