reader_py.cc 8.2 KB
Newer Older
S
sneaxiy 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pybind/reader_py.h"
Z
Zeng Jinle 已提交
16
#include <exception>
#include <future>
S
sneaxiy 已提交
17
#include <memory>
S
sneaxiy 已提交
18
#include <string>
S
sneaxiy 已提交
19 20
#include <unordered_map>
#include <utility>
S
sneaxiy 已提交
21
#include <vector>
Z
Zeng Jinle 已提交
22
#include "Python.h"
23
#include "paddle/fluid/framework/ddim.h"
S
sneaxiy 已提交
24
#include "paddle/fluid/framework/reader.h"
25 26
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/tracer.h"
S
sneaxiy 已提交
27 28 29 30 31 32 33 34
#include "paddle/fluid/operators/reader/buffered_reader.h"
#include "paddle/fluid/operators/reader/py_reader.h"
#include "paddle/fluid/platform/place.h"
#include "pybind11/stl.h"

namespace paddle {
namespace pybind {

Z
Zeng Jinle 已提交
35 36
namespace py = pybind11;

S
sneaxiy 已提交
37 38
class MultiDeviceFeedReader {
 public:
S
sneaxiy 已提交
39 40
  using ResultDictList =
      std::vector<std::unordered_map<std::string, framework::LoDTensor>>;
41
  using ResultList = std::vector<std::vector<framework::LoDTensor>>;
S
sneaxiy 已提交
42

S
sneaxiy 已提交
43 44 45
  MultiDeviceFeedReader(
      const std::shared_ptr<operators::reader::LoDTensorBlockingQueue> &queue,
      const std::vector<std::string> &names,
46 47 48
      const std::vector<std::vector<int>> &shapes,
      const std::vector<framework::proto::VarType::Type> &dtypes,
      const std::vector<bool> &need_check_feed,
S
sneaxiy 已提交
49 50
      const std::vector<platform::Place> &dst_places, bool use_double_buffer)
      : queue_(queue),
S
sneaxiy 已提交
51
        names_(names),
S
sneaxiy 已提交
52
        pool_(new ::ThreadPool(dst_places.size())) {
53 54 55 56
    std::vector<framework::DDim> dims;
    for (auto &shape : shapes) {
      dims.push_back(framework::make_ddim(shape));
    }
S
sneaxiy 已提交
57
    std::shared_ptr<framework::ReaderBase> reader(
58
        new operators::reader::PyReader(queue, dims, dtypes, need_check_feed));
S
sneaxiy 已提交
59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75

    readers_.reserve(dst_places.size());
    for (auto &p : dst_places) {
      auto *holder = new framework::ReaderHolder();
      if (use_double_buffer) {
        holder->Reset(
            framework::MakeDecoratedReader<operators::reader::BufferedReader>(
                reader, p, 2));
      } else {
        if (platform::is_gpu_place(p)) {
          PADDLE_THROW(
              "Place cannot be CUDAPlace when use_double_buffer is False");
        }
        holder->Reset(reader);
      }
      readers_.emplace_back(holder);
    }
S
sneaxiy 已提交
76

S
sneaxiy 已提交
77 78
    futures_.resize(dst_places.size());
    ret_.resize(dst_places.size());
Z
Zeng Jinle 已提交
79
    exceptions_.assign(dst_places.size(), nullptr);
S
sneaxiy 已提交
80 81
    ReadAsync();
  }
S
sneaxiy 已提交
82

S
sneaxiy 已提交
83
  ResultDictList ReadNext() {
Z
Zeng Jinle 已提交
84
    CheckNextStatus();
S
sneaxiy 已提交
85 86 87 88 89
    ResultDictList result(ret_.size());
    for (size_t i = 0; i < ret_.size(); ++i) {
      for (size_t j = 0; j < names_.size(); ++j) {
        result[i].emplace(names_[j], std::move(ret_[i][j]));
      }
S
sneaxiy 已提交
90
    }
S
sneaxiy 已提交
91 92
    ReadAsync();
    return result;
S
sneaxiy 已提交
93 94
  }

95
  ResultList ReadNextList() {
Z
Zeng Jinle 已提交
96
    CheckNextStatus();
97 98 99 100 101 102 103 104 105
    ResultList result;
    result.reserve(ret_.size());
    for (size_t i = 0; i < ret_.size(); ++i) {
      result.emplace_back(std::move(ret_[i]));
    }
    ReadAsync();
    return result;
  }

S
sneaxiy 已提交
106 107 108 109 110 111 112 113 114 115
  void Reset() {
    Shutdown();
    Start();
    ReadAsync();
  }

  ~MultiDeviceFeedReader() {
    queue_->Close();
    pool_.reset();
  }
S
sneaxiy 已提交
116 117

 private:
Z
Zeng Jinle 已提交
118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
  enum Status {
    kSuccess = 0,   // Read next data successfully
    kEOF = 1,       // Reach EOF
    kException = 2  // Exception raises when reading
  };

  Status WaitFutures(std::exception_ptr *excep) {
    bool is_success = true;
    *excep = nullptr;
    for (size_t i = 0; i < futures_.size(); ++i) {
      auto each_status = futures_[i].get();
      if (UNLIKELY(each_status != Status::kSuccess)) {
        is_success = false;
        if (UNLIKELY(each_status == Status::kException)) {
          PADDLE_ENFORCE_NOT_NULL(exceptions_[i]);
          *excep = exceptions_[i];
          exceptions_[i] = nullptr;
        }
      }
    }

    if (UNLIKELY(*excep)) {
      return Status::kException;
    } else {
      return is_success ? Status::kSuccess : Status::kEOF;
S
sneaxiy 已提交
143 144
    }
  }
S
sneaxiy 已提交
145

S
sneaxiy 已提交
146 147
  void Shutdown() {
    for (auto &r : readers_) r->Shutdown();
S
sneaxiy 已提交
148
  }
S
sneaxiy 已提交
149 150 151

  void Start() {
    for (auto &r : readers_) r->Start();
S
sneaxiy 已提交
152 153
  }

S
sneaxiy 已提交
154 155 156
  void ReadAsync() {
    for (size_t i = 0; i < readers_.size(); ++i) {
      futures_[i] = pool_->enqueue([this, i] {
Z
Zeng Jinle 已提交
157 158 159 160 161 162 163
        try {
          readers_[i]->ReadNext(&ret_[i]);
          return ret_[i].empty() ? Status::kEOF : Status::kSuccess;
        } catch (...) {
          exceptions_[i] = std::current_exception();
          return Status::kException;
        }
S
sneaxiy 已提交
164 165 166 167
      });
    }
  }

Z
Zeng Jinle 已提交
168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185
  void CheckNextStatus() {
    std::exception_ptr excep;
    Status status = WaitFutures(&excep);

    if (UNLIKELY(excep)) {
      PADDLE_ENFORCE_EQ(status, Status::kException);
      std::rethrow_exception(excep);
    }

    if (UNLIKELY(status == Status::kEOF)) {
      VLOG(2) << "Raise StopIteration Exception in Python";
      py::gil_scoped_acquire guard;
      throw py::stop_iteration();
    }

    PADDLE_ENFORCE_EQ(status, Status::kSuccess);
  }

S
sneaxiy 已提交
186
  std::shared_ptr<operators::reader::LoDTensorBlockingQueue> queue_;
S
sneaxiy 已提交
187 188 189 190
  std::vector<std::string> names_;
  std::unique_ptr<::ThreadPool> pool_;

  std::vector<std::unique_ptr<framework::ReaderHolder>> readers_;
S
sneaxiy 已提交
191

Z
Zeng Jinle 已提交
192 193 194
  std::vector<std::future<Status>> futures_;
  std::vector<std::exception_ptr> exceptions_;

S
sneaxiy 已提交
195 196
  std::vector<std::vector<framework::LoDTensor>> ret_;
};
S
sneaxiy 已提交
197 198 199 200 201 202 203 204 205 206

// Registers the reader bindings on `module`:
//   - class "Reader" wrapping framework::ReaderHolder (start / reset),
//   - class "MultiDeviceFeedReader" (read_next, read_next_list,
//     read_next_var_list, reset), every method running with the GIL released,
//   - function "create_py_reader" which builds a MultiDeviceFeedReader.
void BindReader(py::module *module) {
  auto &m = *module;

  namespace reader = ::paddle::operators::reader;

  py::class_<framework::ReaderHolder>(m, "Reader", "")
      .def("start", &framework::ReaderHolder::Start)
      .def("reset", &framework::ReaderHolder::ResetAll);

  py::class_<MultiDeviceFeedReader>(m, "MultiDeviceFeedReader", "")
      .def("read_next", &MultiDeviceFeedReader::ReadNext,
           py::call_guard<py::gil_scoped_release>())
      .def("read_next_list", &MultiDeviceFeedReader::ReadNextList,
           py::call_guard<py::gil_scoped_release>())
      .def("read_next_var_list",
           // Reads the next batch and wraps the tensors of the first place
           // into freshly named imperative::VarBase objects.
           [](MultiDeviceFeedReader &self) {
             auto result_list = self.ReadNextList();
             // A reader created without any place yields an empty list;
             // indexing element 0 below would then be out of bounds.
             if (UNLIKELY(result_list.empty())) {
               PADDLE_THROW(
                   "read_next_var_list requires a reader with at least one "
                   "place");
             }
             auto &tensor_list = result_list[0];

             // Moves `lod_tensor` into a new non-persistable LOD_TENSOR
             // variable whose name is generated by the current tracer.
             auto make_var = [](framework::LoDTensor &lod_tensor) {
               std::string act_name =
                   imperative::GetCurrentTracer()->GenerateUniqueName(
                       "generated_var");
               auto new_var = std::make_shared<imperative::VarBase>(act_name);
               new_var->SetPersistable(false);
               new_var->SetType(framework::proto::VarType::LOD_TENSOR);
               // Read the data type before the tensor is moved away.
               new_var->SetDataType(lod_tensor.type());
               auto *tensor =
                   new_var->MutableVar()->GetMutable<framework::LoDTensor>();
               *tensor = std::move(lod_tensor);
               return new_var;
             };

             std::vector<std::shared_ptr<imperative::VarBase>> var_list;
             var_list.reserve(tensor_list.size());
             for (auto &tensor : tensor_list) {
               var_list.emplace_back(make_var(tensor));
             }
             return var_list;
           },
           py::call_guard<py::gil_scoped_release>())
      .def("reset", &MultiDeviceFeedReader::Reset,
           py::call_guard<py::gil_scoped_release>());

  // The returned raw pointer is handed over to Python (take_ownership).
  m.def("create_py_reader",
        [](const std::shared_ptr<reader::LoDTensorBlockingQueue> &queue,
           const std::vector<std::string> &names,
           const std::vector<std::vector<int>> &shapes,
           const std::vector<framework::proto::VarType::Type> &dtypes,
           const std::vector<bool> &need_check_feed,
           const std::vector<platform::Place> &dst_places,
           bool use_double_buffer) {
          return new MultiDeviceFeedReader(queue, names, shapes, dtypes,
                                           need_check_feed, dst_places,
                                           use_double_buffer);
        },
        py::return_value_policy::take_ownership);
}

}  // namespace pybind
}  // namespace paddle