From 72b78154b257a746f09c5fccc3a1d495787fce8f Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Tue, 17 Jul 2018 13:57:31 +0800 Subject: [PATCH] Polish reader speed --- paddle/fluid/framework/lod_tensor.cc | 27 ++++++++++--------- paddle/fluid/framework/lod_tensor.h | 5 ++-- paddle/fluid/framework/lod_tensor_test.cc | 5 ++-- .../reader/create_recordio_file_reader_op.cc | 11 +++++--- 4 files changed, 28 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc index cba0064f38..919029c38f 100644 --- a/paddle/fluid/framework/lod_tensor.cc +++ b/paddle/fluid/framework/lod_tensor.cc @@ -312,19 +312,22 @@ void WriteToRecordIO(recordio::Writer *writer, writer->Write(buffer.str()); } -std::vector ReadFromRecordIO( - recordio::Scanner *scanner, const platform::DeviceContext &dev_ctx) { - std::vector result; - if (scanner->HasNext()) { - std::istringstream sin(scanner->Next()); - uint32_t sz; - sin.read(reinterpret_cast(&sz), sizeof(uint32_t)); - result.resize(sz); - for (uint32_t i = 0; i < sz; ++i) { - DeserializeFromStream(sin, &result[i], dev_ctx); - } +bool ReadFromRecordIO(recordio::Scanner *scanner, + const platform::DeviceContext &dev_ctx, + std::vector *result_ptr) { + if (!scanner->HasNext()) { + return false; } - return result; + std::istringstream sin(scanner->Next()); + uint32_t sz; + sin.read(reinterpret_cast(&sz), sizeof(uint32_t)); + auto &result = *result_ptr; + result.resize(sz); + for (uint32_t i = 0; i < sz; ++i) { + DeserializeFromStream(sin, &result[i], dev_ctx); + } + + return true; } std::vector LoDTensor::SplitLoDTensor( diff --git a/paddle/fluid/framework/lod_tensor.h b/paddle/fluid/framework/lod_tensor.h index 4a2729373b..e9b473d547 100644 --- a/paddle/fluid/framework/lod_tensor.h +++ b/paddle/fluid/framework/lod_tensor.h @@ -223,8 +223,9 @@ extern void WriteToRecordIO(recordio::Writer* writer, const std::vector& tensor, const platform::DeviceContext& dev_ctx); -extern std::vector ReadFromRecordIO( - recordio::Scanner* scanner, const platform::DeviceContext& dev_ctx); +extern bool ReadFromRecordIO(recordio::Scanner* scanner, + const platform::DeviceContext& dev_ctx, + std::vector* result_ptr); /* * Convert between length-based LoD and offset-based LoD. diff --git a/paddle/fluid/framework/lod_tensor_test.cc b/paddle/fluid/framework/lod_tensor_test.cc index 38d3cd96d6..cd50aaa260 100644 --- a/paddle/fluid/framework/lod_tensor_test.cc +++ b/paddle/fluid/framework/lod_tensor_test.cc @@ -301,11 +301,12 @@ static void TestRecordIO() { { std::unique_ptr stream_ptr(stream); recordio::Scanner scanner(std::move(stream_ptr)); - auto tensors = ReadFromRecordIO(&scanner, ctx); + std::vector tensors; + ASSERT_TRUE(ReadFromRecordIO(&scanner, ctx, &tensors)); ASSERT_EQ(tensors.size(), static_cast(2)); assert_tensor_ok(tensors[0]); assert_tensor_ok(tensors[1]); - tensors = ReadFromRecordIO(&scanner, ctx); + ASSERT_TRUE(ReadFromRecordIO(&scanner, ctx, &tensors)); ASSERT_EQ(tensors.size(), static_cast(2)); assert_tensor_ok(tensors[0]); assert_tensor_ok(tensors[1]); diff --git a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc index b32f09b225..a08a9dbd0d 100644 --- a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc +++ b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc @@ -33,11 +33,14 @@ class RecordIOFileReader : public framework::FileReader { protected: void ReadNextImpl(std::vector* out) override { + std::unique_ptr> guard; if (ThreadSafe) { - std::lock_guard guard(*mutex_); - *out = framework::ReadFromRecordIO(&scanner_, dev_ctx_); - } else { - *out = framework::ReadFromRecordIO(&scanner_, dev_ctx_); + guard.reset(new std::lock_guard(*mutex_)); + } + + bool ok = framework::ReadFromRecordIO(&scanner_, dev_ctx_, out); + if (!ok) { + out->clear(); } } -- GitLab