未验证 提交 72b78154 编写于 作者: Y yuyang18

Polish reader speed

上级 e576345f
...@@ -312,19 +312,22 @@ void WriteToRecordIO(recordio::Writer *writer, ...@@ -312,19 +312,22 @@ void WriteToRecordIO(recordio::Writer *writer,
writer->Write(buffer.str()); writer->Write(buffer.str());
} }
std::vector<LoDTensor> ReadFromRecordIO( bool ReadFromRecordIO(recordio::Scanner *scanner,
recordio::Scanner *scanner, const platform::DeviceContext &dev_ctx) { const platform::DeviceContext &dev_ctx,
std::vector<LoDTensor> result; std::vector<LoDTensor> *result_ptr) {
if (scanner->HasNext()) { if (!scanner->HasNext()) {
return false;
}
std::istringstream sin(scanner->Next()); std::istringstream sin(scanner->Next());
uint32_t sz; uint32_t sz;
sin.read(reinterpret_cast<char *>(&sz), sizeof(uint32_t)); sin.read(reinterpret_cast<char *>(&sz), sizeof(uint32_t));
auto &result = *result_ptr;
result.resize(sz); result.resize(sz);
for (uint32_t i = 0; i < sz; ++i) { for (uint32_t i = 0; i < sz; ++i) {
DeserializeFromStream(sin, &result[i], dev_ctx); DeserializeFromStream(sin, &result[i], dev_ctx);
} }
}
return result; return true;
} }
std::vector<LoDTensor> LoDTensor::SplitLoDTensor( std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
......
...@@ -223,8 +223,9 @@ extern void WriteToRecordIO(recordio::Writer* writer, ...@@ -223,8 +223,9 @@ extern void WriteToRecordIO(recordio::Writer* writer,
const std::vector<LoDTensor>& tensor, const std::vector<LoDTensor>& tensor,
const platform::DeviceContext& dev_ctx); const platform::DeviceContext& dev_ctx);
extern std::vector<LoDTensor> ReadFromRecordIO( extern bool ReadFromRecordIO(recordio::Scanner* scanner,
recordio::Scanner* scanner, const platform::DeviceContext& dev_ctx); const platform::DeviceContext& dev_ctx,
std::vector<LoDTensor>* result_ptr);
/* /*
* Convert between length-based LoD and offset-based LoD. * Convert between length-based LoD and offset-based LoD.
......
...@@ -301,11 +301,12 @@ static void TestRecordIO() { ...@@ -301,11 +301,12 @@ static void TestRecordIO() {
{ {
std::unique_ptr<std::istream> stream_ptr(stream); std::unique_ptr<std::istream> stream_ptr(stream);
recordio::Scanner scanner(std::move(stream_ptr)); recordio::Scanner scanner(std::move(stream_ptr));
auto tensors = ReadFromRecordIO(&scanner, ctx); std::vector<framework::LoDTensor> tensors;
ASSERT_TRUE(ReadFromRecordIO(&scanner, ctx, &tensors));
ASSERT_EQ(tensors.size(), static_cast<size_t>(2)); ASSERT_EQ(tensors.size(), static_cast<size_t>(2));
assert_tensor_ok(tensors[0]); assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]); assert_tensor_ok(tensors[1]);
tensors = ReadFromRecordIO(&scanner, ctx); ASSERT_TRUE(ReadFromRecordIO(&scanner, ctx, &tensors));
ASSERT_EQ(tensors.size(), static_cast<size_t>(2)); ASSERT_EQ(tensors.size(), static_cast<size_t>(2));
assert_tensor_ok(tensors[0]); assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]); assert_tensor_ok(tensors[1]);
......
...@@ -33,11 +33,14 @@ class RecordIOFileReader : public framework::FileReader { ...@@ -33,11 +33,14 @@ class RecordIOFileReader : public framework::FileReader {
protected: protected:
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
std::unique_ptr<std::lock_guard<std::mutex>> guard;
if (ThreadSafe) { if (ThreadSafe) {
std::lock_guard<std::mutex> guard(*mutex_); guard.reset(new std::lock_guard<std::mutex>(*mutex_));
*out = framework::ReadFromRecordIO(&scanner_, dev_ctx_); }
} else {
*out = framework::ReadFromRecordIO(&scanner_, dev_ctx_); bool ok = framework::ReadFromRecordIO(&scanner_, dev_ctx_, out);
if (!ok) {
out->clear();
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册