提交 7bb18433 编写于 作者: F fengjiayi

refine code

上级 53fa7cb9
......@@ -20,7 +20,8 @@ namespace paddle {
namespace operators {
namespace reader {
// Size constants used by DoubleBufferReader.
//
// kDoubleBufferSize: loop bound / channel size used by the start_thread()
//   based code path in this file.
// kChannelSize: size passed to framework::MakeChannel<Item>() when the
//   prefetcher is started.
// kCacheSize: number of slots in the per-thread CPU/GPU tensor caches; the
//   cache index wraps around modulo this value. It is derived from
//   kChannelSize so the two cannot drift apart when the channel size changes.
//   NOTE(review): the "+ 2" margin presumably covers the batch being filled
//   and the batch being consumed in addition to those queued in the channel
//   -- confirm.
static constexpr size_t kDoubleBufferSize = 2;
static constexpr size_t kChannelSize = 2;
static constexpr size_t kCacheSize = kChannelSize + 2;
class DoubleBufferReader : public framework::DecoratedReader {
public:
......@@ -34,33 +35,36 @@ class DoubleBufferReader : public framework::DecoratedReader {
explicit DoubleBufferReader(
ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
: DecoratedReader(reader), place_(target_place) {
for (size_t i = 0; i < kDoubleBufferSize; ++i) {
if (platform::is_gpu_place(place_)) {
#ifdef PADDLE_WITH_CUDA
for (size_t i = 0; i < kChannelSize + 2; ++i) {
if (platform::is_gpu_place(place_)) {
ctxs_.emplace_back(new platform::CUDADeviceContext(
boost::get<platform::CUDAPlace>(place_)));
#endif
}
}
start_thread();
}
void start_thread() {
buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
#endif
StartPrefetcher();
}
bool HasNext() const override;
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReInit() override;
~DoubleBufferReader() {
void StartPrefetcher() {
buffer_ = framework::MakeChannel<Item>(kChannelSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
}
// Shuts the prefetch pipeline down: closes the channel, waits for the
// prefetch thread (if it is joinable), then frees the channel.
void EndPrefetcher() {
  // Close before joining.
  // NOTE(review): presumably Close() is what lets a prefetch thread blocked
  // on the channel return -- confirm against the Channel implementation.
  buffer_->Close();
  // Fixed: the member is `prefetcher_` (was misspelled `prefecther_`, which
  // does not match the name used everywhere else in this class).
  if (prefetcher_.joinable()) {
    prefetcher_.join();
  }
  delete buffer_;
  buffer_ = nullptr;  // do not leave a dangling pointer behind
}
bool HasNext() const override;
// Destructor: delegates all teardown to EndPrefetcher().
~DoubleBufferReader() {
  EndPrefetcher();
}
private:
void PrefetchThreadFunc();
......@@ -123,6 +127,15 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
}
};
// Reports whether another batch is available.
// If local_buffer_ already holds payloads, answers true without touching the
// channel; otherwise tries to receive into local_buffer_ and returns the
// result of that Receive() call.
bool DoubleBufferReader::HasNext() const {
  if (!local_buffer_.payloads_.empty()) {
    return true;
  }
  const bool received = buffer_->Receive(&local_buffer_);
  return received;
}
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
......@@ -137,40 +150,36 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
// Re-initializes the decorated reader and restarts the prefetch pipeline.
//
// Teardown and restart go through EndPrefetcher()/StartPrefetcher() only.
// The previous body additionally performed a manual close/join/delete
// followed by start_thread() right before these calls, which launched a
// throwaway prefetch thread that could pull batches from the freshly
// re-initialized reader before being torn down again.
void DoubleBufferReader::ReInit() {
  // NOTE(review): the underlying reader is re-initialized while the old
  // prefetch thread may still be running -- confirm this ordering is
  // intended.
  reader_->ReInit();
  EndPrefetcher();
  StartPrefetcher();
}
void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts.";
size_t gpu_ctx_offset = 0;
std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache(4);
std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache(4);
size_t tensor_cache_id = 0;
std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache(kCacheSize);
std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache(kCacheSize);
size_t cached_tensor_id = 0;
while (reader_->HasNext()) {
Item batch;
reader_->ReadNext(&batch.payloads_);
auto& cpu_batch = cpu_tensor_cache[cached_tensor_id];
reader_->ReadNext(&cpu_batch);
if (platform::is_gpu_place(place_)) {
tensor_cache_id %= 4;
auto& gpu_batch = gpu_tensor_cache[tensor_cache_id];
auto& cpu_batch = cpu_tensor_cache[tensor_cache_id];
cpu_batch = batch.payloads_;
++tensor_cache_id;
auto& gpu_ctx = this->ctxs_[gpu_ctx_offset++];
gpu_ctx_offset %= this->ctxs_.size();
gpu_batch.resize(batch.payloads_.size());
auto& gpu_batch = gpu_tensor_cache[cached_tensor_id];
auto* gpu_ctx = ctxs_[cached_tensor_id].get();
gpu_batch.resize(cpu_batch.size());
for (size_t i = 0; i < cpu_batch.size(); ++i) {
framework::TensorCopy(cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i]);
gpu_batch[i].set_lod(batch.payloads_[i].lod());
}
batch.ctx_ = gpu_ctx.get();
batch.payloads_ = gpu_batch;
batch.payload_ = gpu_batch;
batch.ctx_ = gpu_ctx;
} else {
// CPUPlace
batch.payload_ = cpu_batch;
}
++cached_tensor_id;
cached_tensor_id %= kCacheSize;
try {
buffer_->Send(&batch);
......@@ -184,15 +193,6 @@ void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "Prefetch thread terminates.";
}
// Reports whether another batch is available.
// If local_buffer_ already holds payloads, answers true without touching the
// channel; otherwise tries to receive into local_buffer_ and returns the
// result of that Receive() call.
bool DoubleBufferReader::HasNext() const {
  if (!local_buffer_.payloads_.empty()) {
    return true;
  }
  const bool received = buffer_->Receive(&local_buffer_);
  return received;
}
} // namespace reader
} // namespace operators
} // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册