open_files_op.cc 7.3 KB
Newer Older
F
fengjiayi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include <cmath>
F
fengjiayi 已提交
16
#include <thread>  // NOLINT
17 18
#include "ThreadPool.h"
#include "paddle/fluid/framework/blocking_queue.h"
19
#include "paddle/fluid/operators/reader/blocking_queue.h"
F
fengjiayi 已提交
20 21 22 23 24 25
#include "paddle/fluid/operators/reader/reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

26
// Abstract interface for an object that owns a group of readers and decides
// how batches are pulled from them. The concrete policies in this file are
// OrderedReaderContainer and PreemptiveReaderContainer.
class IReaderContainer {
 public:
  virtual ~IReaderContainer() = default;

  // Transfers ownership of one reader into the container.
  virtual void AppendReader(
      std::unique_ptr<framework::ReaderBase>&& readers) = 0;

  // Ends the current reading pass. The implementations in this file shut
  // down and restart every reader that is still pending and park it until
  // the next Start().
  virtual void Stop() = 0;

  // Begins a new reading pass over the readers parked by a previous pass.
  virtual void Start() = 0;

  // Fetches the next batch into `out`. The implementations in this file
  // leave `out` empty once no reader is pending any more.
  virtual void ReadNext(std::vector<framework::LoDTensor>* out) = 0;
};

class OrderedReaderContainer : public IReaderContainer {
 public:
  void AppendReader(std::unique_ptr<framework::ReaderBase>&& reader) override {
    pending_.emplace(std::move(reader));
  }

  void Stop() override {
    while (!pending_.empty()) {
      MoveFrontPendingToDone();
45
    }
F
fengjiayi 已提交
46 47
  }

48
  void Start() override { std::swap(done_, pending_); }
F
fengjiayi 已提交
49

50 51 52 53 54 55 56 57 58 59 60
  void ReadNext(std::vector<framework::LoDTensor>* out) override {
    if (!pending_.empty()) {
      pending_.front()->ReadNext(out);
      if (out->empty()) {
        MoveFrontPendingToDone();
        ReadNext(out);
      }
    } else {
      out->clear();
    }
  }
F
fengjiayi 已提交
61

F
fengjiayi 已提交
62
 private:
63 64 65 66 67 68 69 70 71
  void MoveFrontPendingToDone() {
    pending_.front()->Shutdown();
    pending_.front()->Start();
    done_.emplace(move(pending_.front()));
    pending_.pop();
  }

  std::queue<std::unique_ptr<framework::ReaderBase>> pending_;
  std::queue<std::unique_ptr<framework::ReaderBase>> done_;
F
fengjiayi 已提交
72 73
};

74 75
class PreemptiveReaderContainer : public IReaderContainer {
  using ReaderList = std::list<std::unique_ptr<framework::ReaderBase>>;
F
fengjiayi 已提交
76

77 78 79 80
  struct FutureItem {
    std::vector<framework::LoDTensor> data_;
    ReaderList::iterator reader_it_;
  };
F
fengjiayi 已提交
81

82
  using FutureList = std::list<std::future<FutureItem>>;
F
fengjiayi 已提交
83

84 85
 public:
  explicit PreemptiveReaderContainer(size_t thread_num) : pool_(thread_num) {}
F
fengjiayi 已提交
86

87 88 89 90 91 92 93 94 95 96 97 98
  void Stop() override {
    if (!pending_.empty()) {
      for (auto& reader : pending_) {
        reader->Shutdown();
      }
      for (auto& fu : futures_) {
        fu.wait();
      }
      futures_.clear();
      for (auto& reader : pending_) {
        reader->Start();
        done_.emplace_back(std::move(reader));
F
fengjiayi 已提交
99
      }
100 101 102 103
      pending_.clear();
      bool timeout;
      complete_queue_.PopAll(1000, &timeout);
      PADDLE_ENFORCE(!timeout);
F
fengjiayi 已提交
104 105
    }
  }
106 107 108 109

  void Start() override {
    for (auto& reader : done_) {
      AppendReader(std::move(reader));
F
fengjiayi 已提交
110
    }
111
    done_.clear();
F
fengjiayi 已提交
112
  }
113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129

  void ReadNext(std::vector<framework::LoDTensor>* out) override {
    if (!pending_.empty()) {
      auto future_it = complete_queue_.Pop();
      FutureItem item = future_it->get();
      if (item.data_.empty()) {  // reader done.
        done_.emplace_back(std::move(*item.reader_it_));
        pending_.erase(item.reader_it_);
        futures_.erase(future_it);
        ReadNext(out);
      } else {
        *out = item.data_;
        // continue read async
        AsyncRead(item.reader_it_, &future_it);
      }
    } else {
      out->clear();
F
fengjiayi 已提交
130
    }
131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175
  }

 private:
  void AppendReader(std::unique_ptr<framework::ReaderBase>&& readers) override {
    pending_.emplace_back();
    auto reader_it = pending_.end();
    --reader_it;

    futures_.emplace_back();
    auto future_it = futures_.end();
    --future_it;

    AsyncRead(reader_it, &future_it);
  }

  void AsyncRead(const ReaderList::iterator& reader_it,
                 FutureList::iterator* future_it_ptr) {
    auto& future_it = *future_it_ptr;
    *future_it = pool_.enqueue([reader_it, future_it, this] {
      FutureItem item;
      item.reader_it_ = reader_it;
      (*reader_it)->ReadNext(&item.data_);
      if (item.data_.empty()) {
        (*reader_it)->Shutdown();
        (*reader_it)->Start();
      }
      complete_queue_.Push(future_it);
      return item;
    });
  }

  FutureList futures_;
  ThreadPool pool_;
  framework::BlockingQueue<FutureList::iterator> complete_queue_;
  std::list<std::unique_ptr<framework::ReaderBase>> pending_;
  std::list<std::unique_ptr<framework::ReaderBase>> done_;
};

class MultiFileReader : public framework::ReaderBase {
 public:
  MultiFileReader(const std::vector<std::string>& file_names,
                  std::unique_ptr<IReaderContainer>&& container)
      : container_(std::move(container)) {
    for (auto& fn : file_names) {
      container_->AppendReader(CreateReaderByFileName(fn));
F
fengjiayi 已提交
176 177
    }
  }
178

179 180 181 182 183
  ~MultiFileReader() { container_->Stop(); }

 protected:
  void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
    container_->ReadNext(out);
F
fengjiayi 已提交
184
  }
185 186 187 188 189 190
  void ShutdownImpl() override { container_->Stop(); }
  void StartImpl() override { container_->Start(); }

 private:
  std::unique_ptr<IReaderContainer> container_;
};
F
fengjiayi 已提交
191 192 193 194 195 196 197 198 199 200 201 202

// Operator that creates a MultiFileReader over the files named in the
// "file_names" attribute and stores it in the ReaderHolder of output "Out".
class OpenFilesOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  // Validates the attributes, selects the reader container and resets the
  // output ReaderHolder to a new MultiFileReader.
  //
  // With "is_test" set, files are read one after another through an
  // OrderedReaderContainer; otherwise a PreemptiveReaderContainer reads them
  // concurrently.
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
    const auto& ranks = Attr<std::vector<int>>("ranks");
    PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
    PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
                      static_cast<int>(shape_concat.size()),
                      "The accumulate of all ranks should be equal to the "
                      "shape concat's length.");
    const auto& file_names = Attr<std::vector<std::string>>("file_names");
    PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
    bool is_test = Attr<bool>("is_test");

    auto* out = scope.FindVar(Output("Out"))
                    ->template GetMutable<framework::ReaderHolder>();
    std::unique_ptr<IReaderContainer> container;

    if (is_test) {
      container.reset(new OrderedReaderContainer());
    } else {
      // hardware_concurrency() may return 0 when the value cannot be
      // determined. A pool without workers would never complete a read, so
      // fall back to a single thread in that case.
      const size_t hw_threads = std::max<size_t>(
          1, static_cast<size_t>(std::thread::hardware_concurrency()));
      // `file_names` is non-empty (enforced above), so the result is >= 1.
      container.reset(new PreemptiveReaderContainer(
          std::min(file_names.size(), hw_threads)));
    }

    out->Reset(
        std::make_shared<MultiFileReader>(file_names, std::move(container)));
  }
};

227
class OpenFilesOpMaker : public FileReaderMakerBase {
Y
Yu Yang 已提交
228 229
 protected:
  void Apply() override {
230
    AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
231
    AddAttr<bool>("is_test", "Used for testing data.").SetDefault(false);
232

F
fengjiayi 已提交
233 234 235
    AddComment(R"DOC(
      OpenFiles Operator

Y
Yu Yang 已提交
236
      An OpenFilesOp creates a MultiFileReader, which is able to
F
fengjiayi 已提交
237 238 239 240 241 242 243 244 245 246 247 248
      read data multi-threaded from multiple files.
    )DOC");
  }
};

}  // namespace reader
}  // namespace operators
}  // namespace paddle

// Short alias so that the registration below does not need the full
// namespace path.
namespace reader = paddle::operators::reader;

// Hooks OpenFilesOp and its maker up under the operator name `open_files`
// (the macro itself is defined outside this file).
REGISTER_FILE_READER_OPERATOR(
    open_files, reader::OpenFilesOp, reader::OpenFilesOpMaker);