open_files_op.cc 6.9 KB
Newer Older
F
fengjiayi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

F
fengjiayi 已提交
15 16
#include <memory>
#include <numeric>
#include <string>
#include <thread>  // NOLINT
#include <vector>

#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

F
fengjiayi 已提交
24
class MultiFileReader : public framework::ReaderBase {
F
fengjiayi 已提交
25
 public:
26
  MultiFileReader(const std::vector<std::string>& file_names, size_t thread_num,
F
fengjiayi 已提交
27
                  size_t buffer_size)
28
      : buffer_size_(buffer_size) {
F
fengjiayi 已提交
29
    readers_.reserve(file_names.size());
30
    for (const std::string& f_name : file_names) {
31
      readers_.emplace_back(CreateReaderByFileName(f_name));
32
    }
F
fengjiayi 已提交
33
    prefetchers_.resize(thread_num);
F
fengjiayi 已提交
34
    StartNewScheduler();
F
fengjiayi 已提交
35 36
  }

37
  void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
F
fengjiayi 已提交
38

F
fengjiayi 已提交
39
  ~MultiFileReader() { EndScheduler(); }
F
fengjiayi 已提交
40

F
fengjiayi 已提交
41
 private:
F
fengjiayi 已提交
42 43 44 45 46 47
  void ShutdownImpl() override { EndScheduler(); }

  void StartImpl() override { StartNewScheduler(); }

  void StartNewScheduler();
  void EndScheduler();
F
fengjiayi 已提交
48
  void ScheduleThreadFunc();
49
  void PrefetchThreadFunc(size_t reader_idx, size_t thread_idx);
F
fengjiayi 已提交
50

51
  std::vector<std::unique_ptr<framework::ReaderBase>> readers_;
F
fengjiayi 已提交
52 53
  std::thread scheduler_;
  std::vector<std::thread> prefetchers_;
54
  size_t buffer_size_;
55
  reader::BlockingQueue<size_t>* waiting_reader_idx_;
56 57
  reader::BlockingQueue<size_t>* available_thread_idx_;
  reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
F
fengjiayi 已提交
58 59
};

60
void MultiFileReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
  // Take the next batch produced by the prefetch threads. A failed receive
  // means the buffer has been closed, so signal end-of-data with an empty
  // batch.
  const bool got_batch = buffer_->Receive(out);
  if (got_batch) {
    return;
  }
  out->clear();
}

F
fengjiayi 已提交
66
void MultiFileReader::StartNewScheduler() {
F
fengjiayi 已提交
67
  size_t thread_num = prefetchers_.size();
68
  waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size());
69 70 71
  available_thread_idx_ = new reader::BlockingQueue<size_t>(thread_num);
  buffer_ = new reader::BlockingQueue<std::vector<framework::LoDTensor>>(
      buffer_size_);
F
fengjiayi 已提交
72

73 74
  for (size_t i = 0; i < readers_.size(); ++i) {
    waiting_reader_idx_->Send(i);
F
fengjiayi 已提交
75
  }
76
  waiting_reader_idx_->Close();
F
fengjiayi 已提交
77
  for (size_t i = 0; i < thread_num; ++i) {
78
    available_thread_idx_->Send(i);
F
fengjiayi 已提交
79 80
  }

F
fengjiayi 已提交
81 82 83
  scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
}

F
fengjiayi 已提交
84
void MultiFileReader::EndScheduler() {
F
fengjiayi 已提交
85 86
  available_thread_idx_->Close();
  buffer_->Close();
87
  waiting_reader_idx_->Close();
F
fengjiayi 已提交
88 89 90
  if (scheduler_.joinable()) {
    scheduler_.join();
  }
F
fengjiayi 已提交
91 92
  delete buffer_;
  delete available_thread_idx_;
93
  delete waiting_reader_idx_;
F
fengjiayi 已提交
94 95
}

F
fengjiayi 已提交
96 97
void MultiFileReader::ScheduleThreadFunc() {
  VLOG(5) << "MultiFileReader schedule thread starts.";
F
fengjiayi 已提交
98
  size_t completed_thread_num = 0;
F
fengjiayi 已提交
99 100 101 102 103 104
  size_t thread_idx;
  while (available_thread_idx_->Receive(&thread_idx)) {
    std::thread& prefetcher = prefetchers_[thread_idx];
    if (prefetcher.joinable()) {
      prefetcher.join();
    }
105 106
    size_t reader_idx;
    if (waiting_reader_idx_->Receive(&reader_idx)) {
F
fengjiayi 已提交
107
      // Still have files to read. Start a new prefetch thread.
108 109
      prefetcher = std::thread([this, reader_idx, thread_idx] {
        PrefetchThreadFunc(reader_idx, thread_idx);
F
fengjiayi 已提交
110
      });
F
fengjiayi 已提交
111 112 113
    } else {
      // No more file to read.
      ++completed_thread_num;
F
fengjiayi 已提交
114
      if (completed_thread_num == prefetchers_.size()) {
F
fengjiayi 已提交
115
        buffer_->Close();
F
fengjiayi 已提交
116 117 118 119
        break;
      }
    }
  }
120
  // If users invoke Shutdown() when scheduler is running, it will close the
F
fengjiayi 已提交
121 122 123 124 125 126 127
  // 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
  // to release their resource. So a check is needed before scheduler ends.
  for (auto& p : prefetchers_) {
    if (p.joinable()) {
      p.join();
    }
  }
F
fengjiayi 已提交
128
  VLOG(5) << "MultiFileReader schedule thread terminates.";
F
fengjiayi 已提交
129 130
}

131 132 133
void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) {
  VLOG(5) << "The prefetch thread of file idx '" << reader_idx << "' starts.";
  std::unique_ptr<framework::ReaderBase>& reader = readers_[reader_idx];
F
fengjiayi 已提交
134
  while (true) {
F
fengjiayi 已提交
135 136
    std::vector<framework::LoDTensor> ins;
    reader->ReadNext(&ins);
F
fengjiayi 已提交
137
    if (ins.empty()) {
138 139
      reader->Shutdown();
      reader->Start();
F
fengjiayi 已提交
140 141
      break;
    }
142
    try {
143
      buffer_->Send(std::move(ins));
144
    } catch (paddle::platform::EnforceNotMet e) {
F
fengjiayi 已提交
145
      VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
146 147
                 "thread of file idx '"
              << reader_idx << "' will terminate.";
F
fengjiayi 已提交
148 149 150
      break;
    }
  }
151

152
  if (!available_thread_idx_->Send(thread_idx)) {
F
fengjiayi 已提交
153 154 155
    VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
               "Fail to send thread_idx.";
  }
156 157
  VLOG(5) << "The prefetch thread of file idx '" << reader_idx
          << "' terminates.";
F
fengjiayi 已提交
158 159 160 161 162 163 164 165 166 167 168 169 170
}

// Operator that validates its attributes and installs a MultiFileReader over
// `file_names` into the ReaderHolder of its "Out" variable.
class OpenFilesOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
    const auto& ranks = Attr<std::vector<int>>("ranks");
    PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
    PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
                      static_cast<int>(shape_concat.size()),
                      "The accumulate of all ranks should be equal to the "
                      "shape concat's length.");
    const auto& file_names = Attr<std::vector<std::string>>("file_names");
    PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
    // The maker constrains both attributes to be > 0, so converting them to
    // size_t is safe. The conversion is spelled out explicitly.
    const size_t thread_num = static_cast<size_t>(Attr<int>("thread_num"));
    const size_t buffer_size = static_cast<size_t>(Attr<int>("buffer_size"));

    // Check the output variable before dereferencing it, so a missing
    // variable gives a clear error instead of a null-pointer crash.
    auto* out_var = scope.FindVar(Output("Out"));
    PADDLE_ENFORCE(out_var != nullptr,
                   "Cannot find the output variable '%s' of open_files op.",
                   Output("Out"));
    auto* out = out_var->template GetMutable<framework::ReaderHolder>();
    out->Reset(
        std::make_shared<MultiFileReader>(file_names, thread_num, buffer_size));
  }
};

186
class OpenFilesOpMaker : public FileReaderMakerBase {
Y
Yu Yang 已提交
187 188
 protected:
  void Apply() override {
189 190 191
    AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
    AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
        .GreaterThan(0);
192
    AddAttr<int>("buffer_size", "The size of prefetch buffer.").GreaterThan(0);
193

F
fengjiayi 已提交
194 195 196
    AddComment(R"DOC(
      OpenFiles Operator

Y
Yu Yang 已提交
197
      An OpenFilesOp creates a MultiFileReader, which is able to
F
fengjiayi 已提交
198 199 200 201 202 203 204 205 206 207 208 209
      read data multi-threaded from multiple files.
    )DOC");
  }
};

}  // namespace reader
}  // namespace operators
}  // namespace paddle

namespace reader = paddle::operators::reader;

// Register `open_files`: OpenFilesOp does the work, while OpenFilesOpMaker
// declares its attributes and documentation.
REGISTER_FILE_READER_OPERATOR(open_files,
                              reader::OpenFilesOp,
                              reader::OpenFilesOpMaker);