//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <mutex>   // NOLINT
#include <numeric>
#include <string>
#include <thread>  // NOLINT
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

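// MultipleReader reads data from multiple files concurrently. A scheduler
// thread dispatches file indices to a pool of prefetch threads, and the
// prefetched instances are handed to consumers through a bounded channel.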
class MultipleReader : public framework::ReaderBase {
 public:
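  // A mutex-protected map from a consumer thread's id to that thread's own
  // staging buffer, so concurrent consumers never share a read buffer.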
  class ThreadBufferMap {
   public:
    std::vector<framework::LoDTensor>& operator[](
        const std::thread::id& thread_id) {
      std::lock_guard<std::mutex> lock(mutex_);
      return buffer_[thread_id];
    }

    void Clear() { buffer_.clear(); }

   private:
    std::mutex mutex_;
    std::unordered_map<std::thread::id, std::vector<framework::LoDTensor>>
        buffer_;
  };

  MultipleReader(const std::vector<std::string>& file_names,
                 const std::vector<framework::DDim>& dims, size_t thread_num,
                 size_t buffer_size)
      : file_names_(file_names), dims_(dims), buffer_size_(buffer_size) {
    prefetchers_.resize(thread_num);
    StartNewScheduler();
  }

  void ReadNext(std::vector<framework::LoDTensor>* out) override;
  bool HasNext() const override;
  void ReInit() override;

  ~MultipleReader() { EndScheduler(); }

 private:
  void StartNewScheduler();
  void EndScheduler();
  void ScheduleThreadFunc();
  void PrefetchThreadFunc(std::string file_name, size_t thread_idx);

  std::vector<std::string> file_names_;
  std::vector<framework::DDim> dims_;
  std::thread scheduler_;
  std::vector<std::thread> prefetchers_;
  size_t buffer_size_;
  // Indices of files waiting to be read.
  framework::Channel<size_t>* waiting_file_idx_;
  // Indices of prefetch threads that are currently idle.
  framework::Channel<size_t>* available_thread_idx_;
  // Prefetched instances, shared by all prefetch threads and consumers.
  framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
  // Per-consumer-thread staging buffers, filled lazily by HasNext().
  mutable ThreadBufferMap thread_buffer_map_;
};

void MultipleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
  if (!HasNext()) {
    PADDLE_THROW("There is no next data!");
  }
  // HasNext() has already filled the calling thread's local buffer; hand it
  // over to the output and clear it for the next read.
  auto& thread_local_buffer = thread_buffer_map_[std::this_thread::get_id()];
  *out = thread_local_buffer;
  thread_local_buffer.clear();
}

bool MultipleReader::HasNext() const {
  // Refill the calling thread's local buffer from the shared channel when it
  // is empty. Receive() returns false once the channel is closed and drained,
  // i.e. there is no more data.
  auto& thread_local_buffer = thread_buffer_map_[std::this_thread::get_id()];
  return thread_local_buffer.empty() ? buffer_->Receive(&thread_local_buffer)
                                     : true;
}

void MultipleReader::ReInit() {
  EndScheduler();
  thread_buffer_map_.Clear();
  StartNewScheduler();
}

void MultipleReader::StartNewScheduler() {
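  // Create fresh channels: the indices of files still to be read, the indices
  // of idle prefetch threads, and the shared prefetch buffer.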
  size_t thread_num = prefetchers_.size();
  waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
  available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
  buffer_ =
      framework::MakeChannel<std::vector<framework::LoDTensor>>(buffer_size_);

  for (size_t i = 0; i < file_names_.size(); ++i) {
    waiting_file_idx_->Send(&i);
  }
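  // Closing marks that no more file indices will be sent, so the scheduler's
  // Receive() on this channel returns false once every file is handed out.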
  waiting_file_idx_->Close();
  for (size_t i = 0; i < thread_num; ++i) {
    available_thread_idx_->Send(&i);
  }

  scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
}

void MultipleReader::EndScheduler() {
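  // Closing all channels unblocks the scheduler and any prefetch threads
  // waiting on them, so they can terminate before the channels are deleted.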
  available_thread_idx_->Close();
  buffer_->Close();
  waiting_file_idx_->Close();
  if (scheduler_.joinable()) {
    scheduler_.join();
  }
  delete buffer_;
  delete available_thread_idx_;
  delete waiting_file_idx_;
}

void MultipleReader::ScheduleThreadFunc() {
  VLOG(5) << "MultipleReader schedule thread starts.";
  size_t completed_thread_num = 0;
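  // Wait for an idle prefetch thread; assign it a new file if any remains,
  // otherwise count it as completed.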
  size_t thread_idx;
  while (available_thread_idx_->Receive(&thread_idx)) {
    std::thread& prefetcher = prefetchers_[thread_idx];
    if (prefetcher.joinable()) {
      prefetcher.join();
    }
    size_t file_idx;
    if (waiting_file_idx_->Receive(&file_idx)) {
      // Still have files to read. Start a new prefetch thread.
      std::string file_name = file_names_[file_idx];
      prefetcher = std::thread([this, file_name, thread_idx] {
        PrefetchThreadFunc(file_name, thread_idx);
      });
    } else {
      // No more file to read.
      ++completed_thread_num;
      if (completed_thread_num == prefetchers_.size()) {
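        // All prefetch threads have finished. Close the buffer so consumers'
        // Receive() returns false and HasNext() reports the end of data.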
        buffer_->Close();
        break;
      }
    }
  }
  // If users invoke ReInit() while the scheduler is running, it will close
  // 'available_thread_idx_' and the prefetcher threads will have no way to
  // tell the scheduler to release their resources. So a check is needed
  // before the scheduler ends.
  for (auto& p : prefetchers_) {
    if (p.joinable()) {
      p.join();
    }
  }
  VLOG(5) << "MultipleReader schedule thread terminates.";
}

void MultipleReader::PrefetchThreadFunc(std::string file_name,
                                        size_t thread_idx) {
  VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
  std::unique_ptr<framework::ReaderBase> reader =
      CreateReaderByFileName(file_name, dims_);
  while (reader->HasNext()) {
    std::vector<framework::LoDTensor> ins;
    reader->ReadNext(&ins);
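    // Send() throws if the buffer channel has already been closed, e.g. by
    // EndScheduler() during ReInit() or destruction.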
    try {
      buffer_->Send(&ins);
    } catch (const paddle::platform::EnforceNotMet& e) {
      VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
                 "thread of file '"
              << file_name << "' will terminate.";
      break;
    }
  }

  // Give the thread index back to the scheduler. This fails only if the
  // scheduler has already closed 'available_thread_idx_'.
  try {
    available_thread_idx_->Send(&thread_idx);
  } catch (const paddle::platform::EnforceNotMet& e) {
    VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
               "Fail to send thread_idx.";
  }
  VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates.";
}

class OpenFilesOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
    const auto& ranks = Attr<std::vector<int>>("ranks");
    PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
    PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
                      static_cast<int>(shape_concat.size()),
                      "The accumulated sum of all ranks should be equal to "
                      "the length of 'shape_concat'.");
    const auto& file_names = Attr<std::vector<std::string>>("file_names");
    PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
    const size_t thread_num = Attr<int>("thread_num");
    const size_t buffer_size = Attr<int>("buffer_size");

    auto* out = scope.FindVar(Output("Out"))
                    ->template GetMutable<framework::ReaderHolder>();
    out->Reset(new MultipleReader(file_names,
                                  RestoreShapes(shape_concat, ranks),
                                  thread_num, buffer_size));
  }
};

class OpenFilesOpMaker : public FileReaderMakerBase {
 public:
  OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : FileReaderMakerBase(op_proto, op_checker) {
    AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
    AddAttr<int>("thread_num",
                 "The maximum number of concurrent prefetch threads.")
        .GreaterThan(0);
    AddAttr<int>("buffer_size", "The size of prefetch buffer.").GreaterThan(0);

    AddComment(R"DOC(
      OpenFiles Operator

      An OpenFilesOp creates a MultipleReader, which reads data from multiple
      files concurrently with multiple prefetch threads.
    )DOC");
  }
};

}  // namespace reader
}  // namespace operators
}  // namespace paddle

namespace reader = paddle::operators::reader;

REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp,
                              reader::OpenFilesOpMaker);