diff --git a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc index 0126ff7271b9ae6ce4c88831780dd3486e9dea8b..986e1b7a21a8e268baa84f56df5ee12af150fae1 100644 --- a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc +++ b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc @@ -18,6 +18,7 @@ namespace paddle { namespace operators { namespace reader { +template class RecordIOFileReader : public framework::FileReader { public: RecordIOFileReader(const std::string& filename, @@ -26,11 +27,19 @@ class RecordIOFileReader : public framework::FileReader { scanner_(filename), dev_ctx_(*platform::DeviceContextPool::Instance().Get( platform::CPUPlace())) { + if (ThreadSafe) { + mutex_.reset(new std::mutex()); + } LOG(INFO) << "Creating file reader" << filename; } void ReadNext(std::vector* out) override { - *out = framework::ReadFromRecordIO(scanner_, dev_ctx_); + if (ThreadSafe) { + std::lock_guard guard(*mutex_); + *out = framework::ReadFromRecordIO(scanner_, dev_ctx_); + } else { + *out = framework::ReadFromRecordIO(scanner_, dev_ctx_); + } } bool HasNext() const override { return scanner_.HasNext(); } @@ -38,6 +47,7 @@ class RecordIOFileReader : public framework::FileReader { void ReInit() override { scanner_.Reset(); } private: + std::unique_ptr mutex_; recordio::Scanner scanner_; const platform::DeviceContext& dev_ctx_; }; @@ -61,7 +71,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new RecordIOFileReader(filename, shapes)); + out->Reset(new RecordIOFileReader(filename, shapes)); } };