提交 4fb7b967 编写于 作者: F fengjiayi

Add basic double buffer reader

上级 77200a70
......@@ -16,13 +16,10 @@
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/threadpool.h"
namespace paddle {
namespace framework {
static constexpr size_t kDoubleBufferSize = 3;
class ReaderBase {
public:
explicit ReaderBase(const std::vector<DDim>& shapes) : shapes_(shapes) {
......
......@@ -2,4 +2,5 @@ cc_library(reader_op_registry SRCS reader_op_registry.cc DEPS operator op_regist
op_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc DEPS reader_op_registry)
op_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc DEPS reader_op_registry)
op_library(create_batch_reader_op SRCS create_batch_reader_op.cc DEPS reader_op_registry)
set(READER_LIBRARY create_random_data_generator_op create_shuffle_reader_op create_batch_reader_op PARENT_SCOPE)
op_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc DEPS reader_op_registry threadpool)
set(READER_LIBRARY create_random_data_generator_op create_shuffle_reader_op create_batch_reader_op create_double_buffer_reader_op PARENT_SCOPE)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
namespace operators {
namespace reader {
static constexpr size_t kDoubleBufferSize = 3;
class DoubleBufferReader : public framework::DecoratedReader {
public:
explicit DoubleBufferReader(ReaderBase* reader)
: DecoratedReader(reader),
buffer_(kDoubleBufferSize),
write_pos_(0),
read_pos_(0) {
std::thread prefetch(
std::bind(&DoubleBufferReader::PrefetchThreadFunc, this));
prefetch.detach();
// framework::Async(
// std::bind(&DoubleBufferReader::PrefetchThreadFunc, this));
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
bool HasNext() const override;
private:
void PrefetchThreadFunc();
std::vector<std::vector<framework::LoDTensor>> buffer_;
size_t write_pos_;
size_t read_pos_;
std::mutex mtx_;
std::condition_variable buffer_not_full_;
std::condition_variable buffer_not_empty_;
};
class CreateDoubleBufferReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new DoubleBufferReader(underlying_reader.Get()));
}
};
class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateDoubleBufferReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) {
AddComment(R"DOC(
CreateDoubleBufferReader Operator
A double buffer reader takes another reader as its 'underlying reader'.
It launches another thread to execute the 'underlying reader' asynchronously,
which prevents reading process from blocking subsequent training.
)DOC");
}
};
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
std::unique_lock<std::mutex> lck(mtx_);
while (write_pos_ == read_pos_) {
buffer_not_empty_.wait(lck);
}
out->clear();
out->reserve(buffer_[read_pos_].size());
// TODO(fengjiayi): This copy shall be reduced.
for (size_t i = 0; i < buffer_[read_pos_].size(); ++i) {
framework::LoDTensor dst;
TensorCopy(buffer_[read_pos_][i], platform::CPUPlace(), &dst);
dst.set_lod(buffer_[read_pos_][i].lod());
out->push_back(dst);
}
++read_pos_;
if (read_pos_ >= kDoubleBufferSize) {
read_pos_ = 0;
}
buffer_not_full_.notify_all();
}
bool DoubleBufferReader::HasNext() const {
return reader_->HasNext() || !buffer_.empty();
}
void DoubleBufferReader::PrefetchThreadFunc() {
while (reader_->HasNext()) {
std::unique_lock<std::mutex> lck(mtx_);
while (((write_pos_ + 1) % kDoubleBufferSize) == read_pos_) {
buffer_not_full_.wait(lck);
}
reader_->ReadNext(&buffer_[write_pos_]);
++write_pos_;
if (write_pos_ >= kDoubleBufferSize) {
write_pos_ = 0;
}
buffer_not_empty_.notify_all();
}
}
} // namespace reader
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators::reader;
REGISTER_DECORATED_READER_OPERATOR(create_double_buffer_reader,
ops::CreateDoubleBufferReaderOp,
ops::CreateDoubleBufferReaderOpMaker);
......@@ -15,16 +15,30 @@
import paddle.v2 as paddle
import paddle.fluid as fluid
import numpy as np
import sys
prog = fluid.framework.Program()
block = prog.current_block()
startup_prog = fluid.framework.Program()
startup_block = startup_prog.current_block()
random_reader = block.create_var(
random_reader = startup_block.create_var(
type=fluid.core.VarDesc.VarType.READER, name="RandomDataGenerator")
random_reader.desc.set_dtypes(
[fluid.core.VarDesc.VarType.FP32, fluid.core.VarDesc.VarType.FP32])
random_reader.persistable = True
shuffle_reader = startup_block.create_var(
type=fluid.core.VarDesc.VarType.READER, name="ShuffleReader")
shuffle_reader.persistable = True
batch_reader = startup_block.create_var(
type=fluid.core.VarDesc.VarType.READER, name="BatchReader")
batch_reader.persistable = True
double_buffer = startup_block.create_var(
type=fluid.core.VarDesc.VarType.READER, name="DoubleBuffer")
double_buffer.persistable = True
main_prog = startup_prog.clone()
main_block = main_prog.current_block()
create_random_data_generator_op = block.append_op(
create_random_data_generator_op = startup_block.append_op(
type="create_random_data_generator",
outputs={"Out": random_reader},
attrs={
......@@ -34,37 +48,45 @@ create_random_data_generator_op = block.append_op(
"max": 1.0,
'lod_levels': [0, 0]
})
shuffle_reader = block.create_var(
type=fluid.core.VarDesc.VarType.READER, name="ShuffleReader")
create_shuffle_reader_op = block.append_op(
create_shuffle_reader_op = startup_block.append_op(
type="create_shuffle_reader",
inputs={"UnderlyingReader": random_reader},
outputs={"Out": shuffle_reader},
attrs={"buffer_size": 7})
batch_reader = block.create_var(
type=fluid.core.VarDesc.VarType.READER, name="BatchReader")
create_batch_reader_op = block.append_op(
create_batch_reader_op = startup_block.append_op(
type="create_batch_reader",
inputs={"UnderlyingReader": shuffle_reader},
outputs={"Out": batch_reader},
attrs={"batch_size": 10})
out1 = block.create_var(type=fluid.core.VarDesc.VarType.LOD_TENSOR, name="Out1")
out2 = block.create_var(type=fluid.core.VarDesc.VarType.LOD_TENSOR, name="Out2")
create_double_buffer_reader_op = startup_block.append_op(
type="create_double_buffer_reader",
inputs={"UnderlyingReader": batch_reader},
outputs={"Out": double_buffer})
out1 = main_block.create_var(
type=fluid.core.VarDesc.VarType.LOD_TENSOR, name="Out1")
out2 = main_block.create_var(
type=fluid.core.VarDesc.VarType.LOD_TENSOR, name="Out2")
read_op = block.append_op(
type="read", inputs={"Reader": batch_reader},
main_block.var("DoubleBuffer").desc.set_shapes(double_buffer.desc.shapes())
main_block.var("DoubleBuffer").desc.set_dtypes(double_buffer.desc.dtypes())
main_block.var("DoubleBuffer").desc.set_lod_levels(
double_buffer.desc.lod_levels())
read_op = main_block.append_op(
type="read",
inputs={"Reader": double_buffer},
outputs={"Out": [out1, out2]})
place = fluid.CPUPlace()
exe = fluid.Executor(place)
[res1, res2] = exe.run(prog, fetch_list=[out1, out2])
exe.run(startup_prog)
if not (res1.shape == (10, 2) and res2.shape == (10, 1)):
for i in range(1, 100):
[res1, res2] = exe.run(main_prog, fetch_list=[out1, out2])
if not (res1.shape == (10, 2) and res2.shape == (10, 1)):
exit(1)
exit(0)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册