未验证 提交 88fa9c2e 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #11267 from JiayiFeng/fix_reader_bug

Fix a multi-thread bug in readers
......@@ -35,14 +35,15 @@ class ReaderBase {
class DecoratedReader : public ReaderBase {
public:
explicit DecoratedReader(ReaderBase* reader) : ReaderBase(), reader_(reader) {
explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
: ReaderBase(), reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_);
}
void ReInit() override { reader_->ReInit(); }
protected:
ReaderBase* reader_;
std::shared_ptr<ReaderBase> reader_;
};
class FileReader : public ReaderBase {
......@@ -64,7 +65,7 @@ class ReaderHolder {
public:
void Reset(ReaderBase* reader) { reader_.reset(reader); }
ReaderBase* Get() const { return reader_.get(); }
std::shared_ptr<ReaderBase> Get() const { return reader_; }
void ReadNext(std::vector<LoDTensor>* out) {
PADDLE_ENFORCE_NOT_NULL(reader_);
......@@ -76,7 +77,7 @@ class ReaderHolder {
}
private:
std::unique_ptr<ReaderBase> reader_;
std::shared_ptr<ReaderBase> reader_;
};
} // namespace framework
......
......@@ -20,7 +20,7 @@ namespace reader {
class BatchReader : public framework::DecoratedReader {
public:
BatchReader(ReaderBase* reader, int batch_size)
BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size)
: DecoratedReader(reader), batch_size_(batch_size) {
buffer_.reserve(batch_size_);
}
......
......@@ -22,7 +22,8 @@ namespace reader {
class CustomReader : public framework::DecoratedReader {
public:
CustomReader(ReaderBase* reader, const framework::BlockDesc& sub_block,
CustomReader(const std::shared_ptr<ReaderBase>& reader,
const framework::BlockDesc& sub_block,
const std::vector<std::string>& source_var_names,
const std::vector<std::string>& sink_var_names)
: DecoratedReader(reader),
......
......@@ -34,7 +34,8 @@ static constexpr size_t kChannelSize = 1; // kCacheSize - 2
class DoubleBufferReader : public framework::DecoratedReader {
public:
explicit DoubleBufferReader(
ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
const std::shared_ptr<ReaderBase>& reader,
platform::Place target_place = platform::CPUPlace())
: DecoratedReader(reader), place_(target_place) {
cpu_tensor_cache_.resize(kCacheSize);
gpu_tensor_cache_.resize(kCacheSize);
......
......@@ -21,7 +21,7 @@ namespace reader {
class MultiPassReader : public framework::DecoratedReader {
public:
MultiPassReader(ReaderBase* reader, int pass_num)
MultiPassReader(const std::shared_ptr<ReaderBase>& reader, int pass_num)
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
......
......@@ -23,7 +23,8 @@ namespace reader {
class ShuffleReader : public framework::DecoratedReader {
public:
ShuffleReader(ReaderBase* reader, size_t buffer_size, size_t seed = 0)
ShuffleReader(const std::shared_ptr<ReaderBase>& reader, size_t buffer_size,
size_t seed = 0)
: DecoratedReader(reader), buffer_size_(buffer_size), seed_(seed) {
VLOG(10) << "Create shuffle reader of " << reader_;
if (seed_ == 0) {
......
......@@ -21,7 +21,8 @@ namespace reader {
class ThreadedReader : public framework::DecoratedReader {
public:
explicit ThreadedReader(ReaderBase* reader) : DecoratedReader(reader) {}
explicit ThreadedReader(const std::shared_ptr<ReaderBase>& reader)
: DecoratedReader(reader) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
std::lock_guard<std::mutex> lock(mutex_);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册