From 8b397d16024f1d5a985e0cbc6c88c6560d7e7661 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 16 Mar 2018 14:48:17 +0800 Subject: [PATCH] Make recordio file reader thread-safe by default --- .../reader/create_recordio_file_reader_op.cc | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc index 0126ff7271b..986e1b7a21a 100644 --- a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc +++ b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc @@ -18,6 +18,7 @@ namespace paddle { namespace operators { namespace reader { +template class RecordIOFileReader : public framework::FileReader { public: RecordIOFileReader(const std::string& filename, @@ -26,11 +27,19 @@ class RecordIOFileReader : public framework::FileReader { scanner_(filename), dev_ctx_(*platform::DeviceContextPool::Instance().Get( platform::CPUPlace())) { + if (ThreadSafe) { + mutex_.reset(new std::mutex()); + } LOG(INFO) << "Creating file reader" << filename; } void ReadNext(std::vector* out) override { - *out = framework::ReadFromRecordIO(scanner_, dev_ctx_); + if (ThreadSafe) { + std::lock_guard guard(*mutex_); + *out = framework::ReadFromRecordIO(scanner_, dev_ctx_); + } else { + *out = framework::ReadFromRecordIO(scanner_, dev_ctx_); + } } bool HasNext() const override { return scanner_.HasNext(); } @@ -38,6 +47,7 @@ class RecordIOFileReader : public framework::FileReader { void ReInit() override { scanner_.Reset(); } private: + std::unique_ptr mutex_; recordio::Scanner scanner_; const platform::DeviceContext& dev_ctx_; }; @@ -61,7 +71,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new RecordIOFileReader(filename, shapes)); + out->Reset(new RecordIOFileReader(filename, shapes)); } }; -- GitLab