提交 35e1e0d5 编写于 作者: F fengjiayi

uses channel to replace the traditional buffer

上级 b3a11fdf
......@@ -12,42 +12,35 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <condition_variable>
#include <mutex>
#include <thread>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
namespace operators {
namespace reader {
static constexpr size_t kDoubleBufferSize = 3;
static constexpr size_t kDoubleBufferSize = 2;
class DoubleBufferReader : public framework::DecoratedReader {
public:
explicit DoubleBufferReader(ReaderBase* reader)
: DecoratedReader(reader),
buffer_(kDoubleBufferSize),
write_pos_(0),
read_pos_(0) {
std::thread prefetch(
std::bind(&DoubleBufferReader::PrefetchThreadFunc, this));
buffer_(framework::MakeChannel<std::vector<framework::LoDTensor>>(
kDoubleBufferSize)) {
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
prefetch.detach();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
bool HasNext() const override;
void ReInit() override;
~DoubleBufferReader() { buffer_->Close(); }
private:
void PrefetchThreadFunc();
std::vector<std::vector<framework::LoDTensor>> buffer_;
size_t write_pos_;
size_t read_pos_;
std::mutex mtx_;
std::condition_variable buffer_not_full_;
std::condition_variable buffer_not_empty_;
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
};
class CreateDoubleBufferReaderOp : public framework::OperatorBase {
......@@ -80,44 +73,36 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
};
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
std::unique_lock<std::mutex> lck(mtx_);
while (write_pos_ == read_pos_) {
buffer_not_empty_.wait(lck);
}
out->clear();
out->reserve(buffer_[read_pos_].size());
// TODO(fengjiayi): This copy shall be reduced.
for (size_t i = 0; i < buffer_[read_pos_].size(); ++i) {
framework::LoDTensor dst;
TensorCopy(buffer_[read_pos_][i], platform::CPUPlace(), &dst);
dst.set_lod(buffer_[read_pos_][i].lod());
out->push_back(dst);
}
++read_pos_;
if (read_pos_ >= kDoubleBufferSize) {
read_pos_ = 0;
}
buffer_not_full_.notify_all();
buffer_->Receive(out);
}
bool DoubleBufferReader::HasNext() const {
return reader_->HasNext() || !buffer_.empty();
void DoubleBufferReader::ReInit() {
reader_->ReInit();
buffer_->Close();
// The existing prefetch thread will terminate for the buffer_ is closed.
buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>(
kDoubleBufferSize);
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
prefetch.detach();
}
void DoubleBufferReader::PrefetchThreadFunc() {
while (reader_->HasNext()) {
std::unique_lock<std::mutex> lck(mtx_);
while (((write_pos_ + 1) % kDoubleBufferSize) == read_pos_) {
buffer_not_full_.wait(lck);
VLOG(5) << "A new prefetch thread starts.";
while (true) {
std::vector<framework::LoDTensor> batch;
reader_->ReadNext(&batch);
if (batch.empty()) {
// EOF
buffer_->Close();
VLOG(5) << "Reached the end of the file. The prefetch thread terminates.";
break;
}
reader_->ReadNext(&buffer_[write_pos_]);
++write_pos_;
if (write_pos_ >= kDoubleBufferSize) {
write_pos_ = 0;
if (!buffer_->Send(&batch)) {
VLOG(5) << "WARNING: The double buffer channel has been closed. The "
"prefetch thread terminates.";
break;
}
buffer_not_empty_.notify_all();
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册