Commit 3fa0ef3d authored by: F fengjiayi

Refine double_buffer code

Parent 5a4d9328
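For context on the diff below: the change drops the per-batch Item struct that used to travel through the channel (tensor payloads plus a DeviceContext pointer) in favor of fixed tensor caches owned by the reader. The prefetch thread writes each batch into one of kCacheSize reusable slots and sends only the slot index through the channel; ReadNext receives the index and reads the batch out of the corresponding cache. What follows is a minimal, self-contained sketch of that slot-index pattern, not Paddle's actual API: BoundedQueue is a hand-rolled stand-in for framework::Channel, and the string payloads, slot count, and batch loop are made up for illustration.

#include <condition_variable>
#include <cstddef>
#include <iostream>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <vector>

// Hand-rolled bounded queue standing in for framework::Channel (illustrative).
template <typename T>
class BoundedQueue {
 public:
  explicit BoundedQueue(size_t cap) : cap_(cap) {}

  // Blocks while full; returns false once the queue has been closed.
  bool Send(T v) {
    std::unique_lock<std::mutex> lk(mu_);
    not_full_.wait(lk, [this] { return closed_ || q_.size() < cap_; });
    if (closed_) return false;
    q_.push(std::move(v));
    not_empty_.notify_one();
    return true;
  }

  // Blocks while empty; drains remaining items after Close(), then returns false.
  bool Receive(T* v) {
    std::unique_lock<std::mutex> lk(mu_);
    not_empty_.wait(lk, [this] { return closed_ || !q_.empty(); });
    if (q_.empty()) return false;
    *v = std::move(q_.front());
    q_.pop();
    not_full_.notify_one();
    return true;
  }

  void Close() {
    std::lock_guard<std::mutex> lk(mu_);
    closed_ = true;
    not_full_.notify_all();
    not_empty_.notify_all();
  }

 private:
  std::mutex mu_;
  std::condition_variable not_full_, not_empty_;
  std::queue<T> q_;
  const size_t cap_;
  bool closed_ = false;
};

int main() {
  constexpr size_t kCacheSize = 4;  // number of reusable batch slots
  // Capacity kCacheSize - 2 leaves one slot being written by the producer and
  // one still held by the consumer, so an index never wraps onto a live slot.
  BoundedQueue<size_t> channel(kCacheSize - 2);
  std::vector<std::string> cache(kCacheSize);  // stand-in for the tensor caches

  std::thread prefetcher([&] {
    size_t slot = 0;
    for (int batch = 0; batch < 10; ++batch) {
      cache[slot] = "batch-" + std::to_string(batch);  // fill the slot in place
      if (!channel.Send(slot)) break;                  // ship only the index
      slot = (slot + 1) % kCacheSize;                  // advance round-robin
    }
    channel.Close();  // tell the consumer no more batches are coming
  });

  size_t slot = 0;
  while (channel.Receive(&slot)) {
    std::cout << cache[slot] << "\n";  // consume the batch out of its slot
  }
  prefetcher.join();
  return 0;
}

Built with any C++11 compiler, this prints batch-0 through batch-9 in order while the producer and consumer overlap, which is the behavior the refactored reader relies on.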
@@ -33,28 +33,14 @@ static constexpr size_t kChannelSize = 0;  // kCacheSize - 2
 class DoubleBufferReader : public framework::DecoratedReader {
  public:
-  struct Item {
-    Item() : ctx_(nullptr) {}
-    Item(Item&& b) {
-      payloads_ = std::move(b.payloads_);
-      ctx_ = std::move(b.ctx_);
-    }
-    Item& operator=(Item&& b) {
-      payloads_ = std::move(b.payloads_);
-      ctx_ = std::move(b.ctx_);
-      return *this;
-    }
-    std::vector<framework::LoDTensor> payloads_;
-    platform::DeviceContext* ctx_;
-  };
-
   explicit DoubleBufferReader(
       ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
       : DecoratedReader(reader), place_(target_place) {
+    cpu_tensor_cache_.resize(kCacheSize);
+    gpu_tensor_cache_.resize(kCacheSize);
 #ifdef PADDLE_WITH_CUDA
-    for (size_t i = 0; i < kCacheSize; ++i) {
-      if (platform::is_gpu_place(place_)) {
+    if (platform::is_gpu_place(place_)) {
+      for (size_t i = 0; i < kCacheSize; ++i) {
         ctxs_.emplace_back(new platform::CUDADeviceContext(
             boost::get<platform::CUDAPlace>(place_)));
       }
@@ -72,7 +58,7 @@ class DoubleBufferReader : public framework::DecoratedReader {
   bool HasNext() const;

   void StartPrefetcher() {
-    channel_ = framework::MakeChannel<Item>(kChannelSize);
+    channel_ = framework::MakeChannel<size_t>(kChannelSize);
     prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
   }
@@ -88,8 +74,10 @@ class DoubleBufferReader : public framework::DecoratedReader {
   void PrefetchThreadFunc();

   std::thread prefetcher_;
-  framework::Channel<Item>* channel_;
+  framework::Channel<size_t>* channel_;
   platform::Place place_;
+  std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache_;
+  std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache_;
   std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
 };
@@ -153,11 +141,14 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
 void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
   out->clear();
   if (HasNext()) {
-    Item batch;
-    channel_->Receive(&batch);
-    *out = batch.payloads_;
-    if (batch.ctx_) {
-      batch.ctx_->Wait();
+    size_t cached_tensor_id;
+    channel_->Receive(&cached_tensor_id);
+    if (platform::is_gpu_place(place_)) {
+      *out = gpu_tensor_cache_[cached_tensor_id];
+      ctxs_[cached_tensor_id]->Wait();
+    } else {
+      // CPU place
+      *out = cpu_tensor_cache_[cached_tensor_id];
     }
   }
 }
@@ -176,42 +167,33 @@ bool DoubleBufferReader::HasNext() const {
 void DoubleBufferReader::PrefetchThreadFunc() {
   VLOG(5) << "A new prefetch thread starts.";
-  std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache(kCacheSize);
-  std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache(kCacheSize);
   size_t cached_tensor_id = 0;
   while (true) {
-    Item batch;
-    auto& cpu_batch = cpu_tensor_cache[cached_tensor_id];
+    auto& cpu_batch = cpu_tensor_cache_[cached_tensor_id];
     reader_->ReadNext(&cpu_batch);
     if (cpu_batch.empty()) {
       // The underlying reader have no next data.
       break;
     }
     if (platform::is_gpu_place(place_)) {
-      auto& gpu_batch = gpu_tensor_cache[cached_tensor_id];
+      auto& gpu_batch = gpu_tensor_cache_[cached_tensor_id];
       auto* gpu_ctx = ctxs_[cached_tensor_id].get();
       gpu_batch.resize(cpu_batch.size());
       for (size_t i = 0; i < cpu_batch.size(); ++i) {
         framework::TensorCopy(cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i]);
         gpu_batch[i].set_lod(cpu_batch[i].lod());
       }
-      batch.payloads_ = gpu_batch;
-      batch.ctx_ = gpu_ctx;
-    } else {
-      // CPUPlace
-      batch.payloads_ = cpu_batch;
     }
-    ++cached_tensor_id;
-    cached_tensor_id %= kCacheSize;
     try {
-      channel_->Send(&batch);
+      size_t tmp = cached_tensor_id;
+      channel_->Send(&tmp);
     } catch (paddle::platform::EnforceNotMet e) {
       VLOG(5) << "WARNING: The double buffer channel has been closed. The "
                  "prefetch thread will terminate.";
       break;
     }
+    ++cached_tensor_id;
+    cached_tensor_id %= kCacheSize;
   }
   channel_->Close();
   VLOG(5) << "Prefetch thread terminates.";
...
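Two design points worth noting, going only by this diff: the batch caches are now member state reused round-robin, so no tensor data ever crosses the channel, only a slot index; and the header comment kChannelSize = 0;  // kCacheSize - 2 suggests the channel capacity is meant to stay at least two below the cache size (one slot for the batch the prefetcher is currently writing, one for the batch the consumer still holds), so the producer can never wrap around onto a live slot. On the GPU path, ReadNext waits on the per-slot device context before handing the batch out, presumably because TensorCopy enqueues the host-to-device copy asynchronously on that context's stream.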