// reader_py.cc
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pybind/reader_py.h"
#include <future>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/operators/reader/buffered_reader.h"
#include "paddle/fluid/operators/reader/py_reader.h"
#include "paddle/fluid/platform/place.h"
#include "pybind11/stl.h"

namespace paddle {
namespace pybind {

class MultiDeviceFeedReader {
 public:
S
sneaxiy 已提交
29 30 31
  using ResultDictList =
      std::vector<std::unordered_map<std::string, framework::LoDTensor>>;

S
sneaxiy 已提交
32 33 34 35 36
  MultiDeviceFeedReader(
      const std::shared_ptr<operators::reader::LoDTensorBlockingQueue> &queue,
      const std::vector<std::string> &names,
      const std::vector<platform::Place> &dst_places, bool use_double_buffer)
      : queue_(queue),
S
sneaxiy 已提交
37
        names_(names),
S
sneaxiy 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
        pool_(new ::ThreadPool(dst_places.size())) {
    std::shared_ptr<framework::ReaderBase> reader(
        new operators::reader::PyReader(queue));

    readers_.reserve(dst_places.size());
    for (auto &p : dst_places) {
      auto *holder = new framework::ReaderHolder();
      if (use_double_buffer) {
        holder->Reset(
            framework::MakeDecoratedReader<operators::reader::BufferedReader>(
                reader, p, 2));
      } else {
        if (platform::is_gpu_place(p)) {
          PADDLE_THROW(
              "Place cannot be CUDAPlace when use_double_buffer is False");
        }
        holder->Reset(reader);
      }
      readers_.emplace_back(holder);
    }
S
sneaxiy 已提交
58

S
sneaxiy 已提交
59 60 61 62
    futures_.resize(dst_places.size());
    ret_.resize(dst_places.size());
    ReadAsync();
  }
S
sneaxiy 已提交
63

S
sneaxiy 已提交
64 65
  ResultDictList ReadNext() {
    bool success = WaitFutures();
S
sneaxiy 已提交
66

S
sneaxiy 已提交
67 68
    if (!success) {
      return {};
S
sneaxiy 已提交
69 70
    }

S
sneaxiy 已提交
71 72 73 74 75
    ResultDictList result(ret_.size());
    for (size_t i = 0; i < ret_.size(); ++i) {
      for (size_t j = 0; j < names_.size(); ++j) {
        result[i].emplace(names_[j], std::move(ret_[i][j]));
      }
S
sneaxiy 已提交
76
    }
S
sneaxiy 已提交
77 78
    ReadAsync();
    return result;
S
sneaxiy 已提交
79 80
  }

S
sneaxiy 已提交
81 82 83 84 85 86 87 88 89 90
  void Reset() {
    Shutdown();
    Start();
    ReadAsync();
  }

  ~MultiDeviceFeedReader() {
    queue_->Close();
    pool_.reset();
  }
S
sneaxiy 已提交
91 92

 private:
S
sneaxiy 已提交
93 94 95 96 97 98 99
  bool WaitFutures() {
    bool success = true;
    for (auto &f : futures_) {
      success &= f.get();
    }
    return success;
  }
S
sneaxiy 已提交
100

S
sneaxiy 已提交
101 102
  void Shutdown() {
    for (auto &r : readers_) r->Shutdown();
S
sneaxiy 已提交
103
  }
S
sneaxiy 已提交
104 105 106

  void Start() {
    for (auto &r : readers_) r->Start();
S
sneaxiy 已提交
107 108
  }

S
sneaxiy 已提交
109 110 111 112 113 114 115 116 117
  void ReadAsync() {
    for (size_t i = 0; i < readers_.size(); ++i) {
      futures_[i] = pool_->enqueue([this, i] {
        readers_[i]->ReadNext(&ret_[i]);
        return !ret_[i].empty();
      });
    }
  }

S
sneaxiy 已提交
118
  std::shared_ptr<operators::reader::LoDTensorBlockingQueue> queue_;
S
sneaxiy 已提交
119 120 121 122
  std::vector<std::string> names_;
  std::unique_ptr<::ThreadPool> pool_;

  std::vector<std::unique_ptr<framework::ReaderHolder>> readers_;
S
sneaxiy 已提交
123

S
sneaxiy 已提交
124 125 126
  std::vector<std::future<bool>> futures_;
  std::vector<std::vector<framework::LoDTensor>> ret_;
};

// Short alias for the pybind11 namespace, used by the binding code below.
namespace py = pybind11;

// Registers the reader bindings on `module`: the "Reader" and
// "MultiDeviceFeedReader" classes plus the "create_py_reader" factory.
void BindReader(py::module *module) {
  namespace reader = ::paddle::operators::reader;
  using QueuePtr = std::shared_ptr<reader::LoDTensorBlockingQueue>;

  py::module &m = *module;

  py::class_<framework::ReaderHolder> holder_cls(m, "Reader", "");
  holder_cls.def("start", &framework::ReaderHolder::Start);
  holder_cls.def("reset", &framework::ReaderHolder::ResetAll);

  // Both methods run with the GIL released.
  py::class_<MultiDeviceFeedReader> feed_reader_cls(
      m, "MultiDeviceFeedReader", "");
  feed_reader_cls.def("read_next", &MultiDeviceFeedReader::ReadNext,
                      py::call_guard<py::gil_scoped_release>());
  feed_reader_cls.def("reset", &MultiDeviceFeedReader::Reset,
                      py::call_guard<py::gil_scoped_release>());

  // The returned raw pointer is handed over to Python (take_ownership).
  auto make_feed_reader = [](const QueuePtr &queue,
                             const std::vector<std::string> &names,
                             const std::vector<platform::Place> &dst_places,
                             bool use_double_buffer) {
    return new MultiDeviceFeedReader(queue, names, dst_places,
                                     use_double_buffer);
  };
  m.def("create_py_reader", make_feed_reader,
        py::return_value_policy::take_ownership);
}

}  // namespace pybind
}  // namespace paddle