提交 2ea4a5d9 编写于 作者: Y Yu Yang

Polish double buffer reader

上级 46ae4075
...@@ -24,11 +24,16 @@ static constexpr size_t kDoubleBufferSize = 2; ...@@ -24,11 +24,16 @@ static constexpr size_t kDoubleBufferSize = 2;
class DoubleBufferReader : public framework::DecoratedReader { class DoubleBufferReader : public framework::DecoratedReader {
public: public:
explicit DoubleBufferReader(ReaderBase* reader) explicit DoubleBufferReader(
: DecoratedReader(reader), ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
buffer_(framework::MakeChannel<std::vector<framework::LoDTensor>>( : DecoratedReader(reader), place_(target_place) {
kDoubleBufferSize)) { start_thread();
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this); }
// Creates a new channel (sized with kDoubleBufferSize), stores it in buffer_,
// and launches a detached thread that runs PrefetchThreadFunc() on this
// reader.
void start_thread() {
  buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>(
      kDoubleBufferSize);
  // The thread is detached, so no handle is kept.
  std::thread([this] { PrefetchThreadFunc(); }).detach();
}
...@@ -43,6 +48,8 @@ class DoubleBufferReader : public framework::DecoratedReader { ...@@ -43,6 +48,8 @@ class DoubleBufferReader : public framework::DecoratedReader {
void PrefetchThreadFunc(); void PrefetchThreadFunc();
framework::Channel<std::vector<framework::LoDTensor>>* buffer_; framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
platform::Place place_;
mutable std::vector<framework::LoDTensor> local_buffer_;
}; };
class CreateDoubleBufferReaderOp : public framework::OperatorBase { class CreateDoubleBufferReaderOp : public framework::OperatorBase {
...@@ -56,7 +63,20 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { ...@@ -56,7 +63,20 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new DoubleBufferReader(underlying_reader.Get()));
auto place_str = Attr<std::string>("place");
platform::Place place;
if (place_str == "CPU") {
place = platform::CPUPlace();
} else {
std::istringstream sin(place_str);
sin.seekg(std::string("CUDA:").size(), std::ios::beg);
size_t num;
sin >> num;
place = platform::CUDAPlace(static_cast<int>(num));
}
out->Reset(new DoubleBufferReader(underlying_reader.Get(), place));
} }
}; };
...@@ -71,44 +91,65 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -71,44 +91,65 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
It launches another thread to execute the 'underlying reader' asynchronously, It launches another thread to execute the 'underlying reader' asynchronously,
which prevents reading process from blocking subsequent training. which prevents reading process from blocking subsequent training.
)DOC"); )DOC");
std::unordered_set<std::string> enum_range;
constexpr size_t kMaxCUDADevs = 128;
for (size_t i = 0; i < kMaxCUDADevs; ++i) {
enum_range.insert(string::Sprintf("CUDA:%d", i));
}
enum_range.insert("CPU");
AddAttr<std::string>("place", "The double buffer place, default is CPU")
.SetDefault("CPU")
.InEnum({enum_range});
} }
}; };
// Writes the next batch into `out` (which is cleared first).
//
// If HasNext() has already pulled a batch into local_buffer_, that batch is
// handed out and the cache is emptied; otherwise the batch is received
// directly from buffer_.
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
  out->clear();
  if (!local_buffer_.empty()) {
    *out = local_buffer_;
    local_buffer_.clear();
    return;
  }
  buffer_->Receive(out);
}
void DoubleBufferReader::ReInit() { void DoubleBufferReader::ReInit() {
reader_->ReInit(); reader_->ReInit();
buffer_->Close(); buffer_->Close();
// The existing prefetch thread will terminate for the buffer_ is closed. start_thread();
buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>(
kDoubleBufferSize);
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
prefetch.detach();
} }
void DoubleBufferReader::PrefetchThreadFunc() { void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts."; VLOG(5) << "A new prefetch thread starts.";
while (true) { while (reader_->HasNext()) {
std::vector<framework::LoDTensor> batch; std::vector<framework::LoDTensor> batch;
reader_->ReadNext(&batch); reader_->ReadNext(&batch);
if (batch.empty()) { if (platform::is_gpu_place(place_)) {
// EOF std::vector<framework::LoDTensor> gpu_batch;
buffer_->Close(); gpu_batch.resize(batch.size());
VLOG(5) << "Reached the end of the file. The prefetch thread terminates."; for (size_t i = 0; i < batch.size(); ++i) {
break; framework::TensorCopy(batch[i], place_, &gpu_batch[i]);
gpu_batch[i].set_lod(batch[i].lod());
}
} }
if (!buffer_->Send(&batch)) { if (!buffer_->Send(&batch)) {
VLOG(5) << "WARNING: The double buffer channel has been closed. The " VLOG(5) << "WARNING: The double buffer channel has been closed. The "
"prefetch thread terminates."; "prefetch thread terminates.";
break; break;
} }
} }
buffer_->Close();
} }
// Reports whether another batch is available.
//
// If no batch is cached, one is received from buffer_ into local_buffer_
// (a mutable member, hence usable from this const method) and the result of
// that Receive() call is returned; a later ReadNext() hands the cached batch
// out. If a batch is already cached, returns true without touching buffer_.
bool DoubleBufferReader::HasNext() const {
  if (!local_buffer_.empty()) {
    return true;
  }
  const bool received = buffer_->Receive(&local_buffer_);
  return received;
}
} // namespace reader } // namespace reader
} // namespace operators } // namespace operators
......
...@@ -21,7 +21,7 @@ from ..executor import global_scope ...@@ -21,7 +21,7 @@ from ..executor import global_scope
__all__ = [ __all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'read_file', 'create_shuffle_reader' 'read_file', 'create_shuffle_reader', 'create_double_buffer_reader'
] ]
...@@ -306,6 +306,14 @@ def create_shuffle_reader(reader, buffer_size): ...@@ -306,6 +306,14 @@ def create_shuffle_reader(reader, buffer_size):
{'buffer_size': int(buffer_size)}) {'buffer_size': int(buffer_size)})
def create_double_buffer_reader(reader, place=None):
    """Decorate `reader` with a 'create_double_buffer_reader' op.

    Args:
        reader: the reader to decorate.
        place: optional; when given, its upper-cased string form is passed as
            the op's 'place' attribute. When None, no attribute is set.

    Returns:
        Whatever `__create_decorated_reader__` returns for the decorated reader.
    """
    attrs = dict() if place is None else {'place': str(place).upper()}
    return __create_decorated_reader__('create_double_buffer_reader', reader,
                                       attrs)
def read_file(file_obj): def read_file(file_obj):
helper = LayerHelper('read_file') helper = LayerHelper('read_file')
out = [ out = [
......
...@@ -13,9 +13,10 @@ ...@@ -13,9 +13,10 @@
# limitations under the License. # limitations under the License.
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.v2.dataset.mnist as mnist
import paddle.v2 as paddle import paddle.v2 as paddle
import paddle.v2.dataset.mnist as mnist
class TestRecordIO(unittest.TestCase): class TestRecordIO(unittest.TestCase):
...@@ -53,7 +54,12 @@ class TestRecordIO(unittest.TestCase): ...@@ -53,7 +54,12 @@ class TestRecordIO(unittest.TestCase):
fluid.optimizer.Adam(learning_rate=1e-3).minimize(avg_loss) fluid.optimizer.Adam(learning_rate=1e-3).minimize(avg_loss)
exe = fluid.Executor(fluid.CPUPlace()) if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
avg_loss_np = [] avg_loss_np = []
...@@ -69,3 +75,7 @@ class TestRecordIO(unittest.TestCase): ...@@ -69,3 +75,7 @@ class TestRecordIO(unittest.TestCase):
def test_shuffle_reader(self):
    # Run the main test with the reader wrapped in a shuffle reader
    # (buffer_size=200).
    def wrap_with_shuffle(reader):
        return fluid.layers.create_shuffle_reader(reader, buffer_size=200)

    self.test_main(decorator_callback=wrap_with_shuffle)
def test_double_buffer_reader(self):
    # Run the main test with the reader wrapped in a double buffer reader,
    # targeting 'cuda:0' on CUDA builds and 'cpu' otherwise.
    def wrap_with_double_buffer(reader):
        # Evaluated when the callback runs, like the lambda it replaces.
        if fluid.core.is_compiled_with_cuda():
            target = 'cuda:0'
        else:
            target = 'cpu'
        return fluid.layers.create_double_buffer_reader(reader, place=target)

    self.test_main(decorator_callback=wrap_with_double_buffer)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册