提交 8b397d16 编写于 作者: Y Yu Yang

Make recordio file reader thread-safe by default

上级 8c9cd369
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
template <bool ThreadSafe>
class RecordIOFileReader : public framework::FileReader { class RecordIOFileReader : public framework::FileReader {
public: public:
RecordIOFileReader(const std::string& filename, RecordIOFileReader(const std::string& filename,
...@@ -26,11 +27,19 @@ class RecordIOFileReader : public framework::FileReader { ...@@ -26,11 +27,19 @@ class RecordIOFileReader : public framework::FileReader {
scanner_(filename), scanner_(filename),
dev_ctx_(*platform::DeviceContextPool::Instance().Get( dev_ctx_(*platform::DeviceContextPool::Instance().Get(
platform::CPUPlace())) { platform::CPUPlace())) {
if (ThreadSafe) {
mutex_.reset(new std::mutex());
}
LOG(INFO) << "Creating file reader" << filename; LOG(INFO) << "Creating file reader" << filename;
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNext(std::vector<framework::LoDTensor>* out) override {
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_); if (ThreadSafe) {
std::lock_guard<std::mutex> guard(*mutex_);
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
} else {
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
}
} }
bool HasNext() const override { return scanner_.HasNext(); } bool HasNext() const override { return scanner_.HasNext(); }
...@@ -38,6 +47,7 @@ class RecordIOFileReader : public framework::FileReader { ...@@ -38,6 +47,7 @@ class RecordIOFileReader : public framework::FileReader {
void ReInit() override { scanner_.Reset(); } void ReInit() override { scanner_.Reset(); }
private: private:
std::unique_ptr<std::mutex> mutex_;
recordio::Scanner scanner_; recordio::Scanner scanner_;
const platform::DeviceContext& dev_ctx_; const platform::DeviceContext& dev_ctx_;
}; };
...@@ -61,7 +71,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { ...@@ -61,7 +71,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new RecordIOFileReader(filename, shapes)); out->Reset(new RecordIOFileReader<true>(filename, shapes));
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册