Unverified commit dc34effd, authored by Y yuyang18

Extract buffered reader

Move the BufferedReader class out of create_double_buffer_reader_op.cc into its own buffered_reader.{h,cc} files, build it as a separate buffered_reader library, and make create_double_buffer_reader_op depend on it.

Parent: ab1a6cb8
paddle/fluid/operators/reader/CMakeLists.txt

@@ -15,12 +15,13 @@ function(reader_library TARGET_NAME)
       PARENT_SCOPE)
 endfunction()
 
+cc_library(buffered_reader SRCS buffered_reader.cc DEPS reader simple_threadpool)
 reader_library(open_files_op SRCS open_files_op.cc)
 reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
 reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
 reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
 reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
-reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
+reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc DEPS buffered_reader)
 reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
 reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc)
 reader_library(create_py_reader_op SRCS create_py_reader_op.cc)
...
paddle/fluid/operators/reader/buffered_reader.cc (new file)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reader/buffered_reader.h"
#include <vector>
namespace paddle {
namespace operators {
namespace reader {
BufferedReader::~BufferedReader() {
reader_->Shutdown();
buffer_.clear();
}
BufferedReader::BufferedReader(
const std::shared_ptr<framework::ReaderBase> &reader,
const platform::Place &place, size_t buffer_size)
: framework::DecoratedReader(reader),
thread_pool_(1),
place_(place),
buffer_size_(buffer_size) {
AppendFutureToBatchSize();
}
void BufferedReader::AppendFutureToBatchSize() {
while (buffer_.size() < buffer_size_) {
AppendFuture();
}
}
void BufferedReader::AppendFuture() {
buffer_.emplace_back(thread_pool_.enqueue([this] {
TensorVec cpu_buffer;
reader_->ReadNext(&cpu_buffer);
if (platform::is_gpu_place(place_)) {
TensorVec gpu_buffer;
for (size_t i = 0; i < cpu_buffer.size(); ++i) {
gpu_buffer.emplace_back();
framework::TensorCopySync(cpu_buffer[i], place_, &gpu_buffer.back());
}
cpu_buffer = gpu_buffer;
}
return cpu_buffer;
}));
}
void BufferedReader::ShutdownImpl() {
reader_->Shutdown();
buffer_.clear();
}
void BufferedReader::StartImpl() {
reader_->Start();
AppendFutureToBatchSize();
}
void BufferedReader::ReadNextImpl(std::vector<framework::LoDTensor> *out) {
PADDLE_ENFORCE_EQ(buffer_.size(), buffer_size_);
*out = buffer_.front().get();
buffer_.pop_front();
AppendFuture();
}
} // namespace reader
} // namespace operators
} // namespace paddle
paddle/fluid/operators/reader/buffered_reader.h (new file)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <future>
#include <list>
#include <memory>
#include <vector>

#include "ThreadPool.h"
#include "paddle/fluid/framework/reader.h"

namespace paddle {
namespace operators {
namespace reader {

// A decorated reader that prefetches batches from the underlying reader on a
// background thread and, if requested, copies them to a GPU place.
class BufferedReader : public framework::DecoratedReader {
  using TensorVec = std::vector<framework::LoDTensor>;
  using VecFuture = std::future<TensorVec>;

 public:
  BufferedReader(const std::shared_ptr<framework::ReaderBase>& reader,
                 const platform::Place& place, size_t buffer_size);

  ~BufferedReader() override;

 private:
  // Schedule asynchronous reads until buffer_ holds buffer_size_ futures.
  void AppendFutureToBatchSize();

  // Schedule a single asynchronous read on the worker thread.
  void AppendFuture();

 protected:
  void ShutdownImpl() override;
  void StartImpl() override;
  void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;

 private:
  ThreadPool thread_pool_;
  platform::Place place_;
  const size_t buffer_size_;
  std::list<VecFuture> buffer_;
};
} // namespace reader
} // namespace operators
} // namespace paddle
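For orientation, here is a standalone sketch (not part of the commit) of the prefetching pattern that BufferedReader implements: keep buffer_size reads in flight as futures, hand out the oldest one on each read, and immediately schedule a replacement. Every name in it (FakeSource, BufferedFakeReader, and so on) is hypothetical and it uses only the C++ standard library; unlike the real class, which serializes reads through a one-thread ThreadPool so batches come back in order, the sketch lets std::async tasks overlap and simply makes the fake source thread-safe.

#include <cstddef>
#include <deque>
#include <future>
#include <iostream>
#include <mutex>
#include <vector>

// Stand-in for framework::ReaderBase: produces small integer "batches".
class FakeSource {
 public:
  std::vector<int> ReadNext() {
    std::lock_guard<std::mutex> lock(mu_);  // async tasks may call this concurrently
    int base = next_;
    next_ += 2;
    return {base, base + 1};
  }

 private:
  std::mutex mu_;
  int next_ = 0;
};

// Keeps buffer_size asynchronous reads pending, analogous to BufferedReader.
class BufferedFakeReader {
 public:
  BufferedFakeReader(FakeSource* source, std::size_t buffer_size)
      : source_(source), buffer_size_(buffer_size) {
    // Analogous to AppendFutureToBatchSize(): pre-fill the buffer of futures.
    while (buffer_.size() < buffer_size_) AppendFuture();
  }

  std::vector<int> ReadNext() {
    // Analogous to ReadNextImpl(): take the oldest pending batch, then refill.
    std::vector<int> out = buffer_.front().get();
    buffer_.pop_front();
    AppendFuture();
    return out;
  }

 private:
  void AppendFuture() {
    // Analogous to AppendFuture(): schedule one asynchronous read.
    buffer_.emplace_back(
        std::async(std::launch::async, [this] { return source_->ReadNext(); }));
  }

  FakeSource* source_;
  const std::size_t buffer_size_;
  std::deque<std::future<std::vector<int>>> buffer_;
};

int main() {
  FakeSource source;
  BufferedFakeReader reader(&source, /*buffer_size=*/2);
  for (int i = 0; i < 3; ++i) {
    for (int v : reader.ReadNext()) std::cout << v << ' ';
    std::cout << '\n';
  }
  return 0;
}

While the consumer works on one batch, further reads are already running in the background, which is the latency-hiding effect the double-buffer reader op gets from BufferedReader.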
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc

@@ -12,83 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <thread>  // NOLINT
-#include "ThreadPool.h"
-#include "paddle/fluid/operators/reader/blocking_queue.h"
+#include "paddle/fluid/operators/reader/buffered_reader.h"
 #include "paddle/fluid/operators/reader/reader_op_registry.h"
 
 namespace paddle {
 namespace operators {
 namespace reader {
-
-class BufferedReader : public framework::DecoratedReader {
-  using TensorVec = std::vector<framework::LoDTensor>;
-  using VecFuture = std::future<TensorVec>;
-
- public:
-  BufferedReader(const std::shared_ptr<framework::ReaderBase>& reader,
-                 const platform::Place& place, size_t buffer_size)
-      : framework::DecoratedReader(reader),
-        thread_pool_(1),
-        place_(place),
-        buffer_size_(buffer_size) {
-    AppendFutureToBatchSize();
-  }
-
-  ~BufferedReader() override {
-    reader_->Shutdown();
-    buffer_.clear();
-  }
-
- private:
-  void AppendFutureToBatchSize() {
-    while (buffer_.size() < buffer_size_) {
-      AppendFuture();
-    }
-  }
-
-  void AppendFuture() {
-    buffer_.emplace_back(thread_pool_.enqueue([this] {
-      TensorVec cpu_buffer;
-      reader_->ReadNext(&cpu_buffer);
-      if (platform::is_gpu_place(place_)) {
-        TensorVec gpu_buffer;
-        for (size_t i = 0; i < cpu_buffer.size(); ++i) {
-          gpu_buffer.emplace_back();
-          framework::TensorCopySync(cpu_buffer[i], place_, &gpu_buffer.back());
-        }
-        cpu_buffer = gpu_buffer;
-      }
-      return cpu_buffer;
-    }));
-  }
-
- protected:
-  void ShutdownImpl() override {
-    reader_->Shutdown();
-    buffer_.clear();
-  }
-
-  void StartImpl() override {
-    reader_->Start();
-    AppendFutureToBatchSize();
-  }
-
-  void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
-    std::cerr << "Read" << std::endl;
-    PADDLE_ENFORCE_EQ(buffer_.size(), buffer_size_);
-    *out = buffer_.front().get();
-    buffer_.pop_front();
-    AppendFuture();
-  }
-
- private:
-  ThreadPool thread_pool_;
-  platform::Place place_;
-  const size_t buffer_size_;
-  std::list<VecFuture> buffer_;
-};
-
 class CreateDoubleBufferReaderOp : public framework::OperatorBase {
  public:
   using framework::OperatorBase::OperatorBase;
...