提交 8b397d16 编写于 作者: Y Yu Yang

Make recordio file reader thread-safe by default

上级 8c9cd369
......@@ -18,6 +18,7 @@
namespace paddle {
namespace operators {
namespace reader {
template <bool ThreadSafe>
class RecordIOFileReader : public framework::FileReader {
public:
RecordIOFileReader(const std::string& filename,
......@@ -26,18 +27,27 @@ class RecordIOFileReader : public framework::FileReader {
scanner_(filename),
dev_ctx_(*platform::DeviceContextPool::Instance().Get(
platform::CPUPlace())) {
if (ThreadSafe) {
mutex_.reset(new std::mutex());
}
LOG(INFO) << "Creating file reader" << filename;
}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (ThreadSafe) {
std::lock_guard<std::mutex> guard(*mutex_);
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
} else {
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
}
}
bool HasNext() const override { return scanner_.HasNext(); }
void ReInit() override { scanner_.Reset(); }
private:
std::unique_ptr<std::mutex> mutex_;
recordio::Scanner scanner_;
const platform::DeviceContext& dev_ctx_;
};
......@@ -61,7 +71,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new RecordIOFileReader(filename, shapes));
out->Reset(new RecordIOFileReader<true>(filename, shapes));
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册