reader_py.cc 17.6 KB
Newer Older
S
sneaxiy 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pybind/reader_py.h"
16

Z
Zeng Jinle 已提交
17
#include <exception>
S
sneaxiy 已提交
18
#include <memory>
S
sneaxiy 已提交
19
#include <string>
S
sneaxiy 已提交
20 21
#include <unordered_map>
#include <utility>
S
sneaxiy 已提交
22
#include <vector>
23

Z
Zeng Jinle 已提交
24
#include "Python.h"
25 26
#include "boost/optional.hpp"
#include "gflags/gflags.h"
S
sneaxiy 已提交
27
#include "paddle/fluid/framework/reader.h"
28 29
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/tracer.h"
S
sneaxiy 已提交
30
#include "paddle/fluid/operators/reader/buffered_reader.h"
31
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
S
sneaxiy 已提交
32 33
#include "paddle/fluid/operators/reader/py_reader.h"
#include "paddle/fluid/platform/place.h"
34
#include "paddle/phi/core/ddim.h"
S
sneaxiy 已提交
35 36
#include "pybind11/stl.h"

37
DECLARE_bool(reader_queue_speed_test_mode);
38

39 40 41
// disable auto conversion to list in Python
PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray);

S
sneaxiy 已提交
42 43 44
namespace paddle {
namespace pybind {

Z
Zeng Jinle 已提交
45
namespace py = pybind11;
46 47 48 49
namespace reader = operators::reader;

// Check whether the tensor shape matches the VarDesc shape
// Return the different shape if exists
50
static paddle::optional<std::vector<int64_t>> DiffTensorShapeWithVarDesc(
51 52
    const framework::LoDTensor &tensor,
    const framework::VarDesc &var_desc,
53 54 55 56 57 58 59 60
    size_t num_places) {
  auto tensor_shape = tensor.dims();
  auto desc_shape = var_desc.GetShape();

  int64_t rank = tensor_shape.size();

  if (UNLIKELY(rank == 0)) {
    if (desc_shape.size() != 0) {  // Tensor rank = 0 but desc does not match
61
      return phi::vectorize<int64_t>(tensor_shape);
62
    } else {
63
      return paddle::none;
64 65 66
    }
  }

67 68
  PADDLE_ENFORCE_GE(tensor_shape[0],
                    0,
69 70 71 72 73 74 75 76 77 78 79
                    platform::errors::InvalidArgument(
                        "Tensor shape at dim 0 must not be less than 0"));

  if (!tensor.lod().empty()) {
    tensor_shape[0] = -1;  // unknown shape
  } else {
    int64_t split_size = (tensor_shape[0] + num_places - 1) / num_places;
    int64_t remainder = (split_size == 0 ? 0 : tensor_shape[0] % split_size);
    tensor_shape[0] = split_size;
    if (desc_shape[0] >= 0) {  // need check dim 0
      if (tensor_shape[0] != desc_shape[0]) {
80
        return phi::vectorize<int64_t>(tensor_shape);
81 82 83 84
      }

      if (remainder > 0) {
        tensor_shape[0] = remainder;
85
        return phi::vectorize<int64_t>(tensor_shape);
86 87 88 89 90 91
      }
    }
  }

  for (int64_t idx = 1; idx < rank; ++idx) {
    PADDLE_ENFORCE_GE(
92 93
        tensor_shape[idx],
        0,
94 95 96
        platform::errors::InvalidArgument(
            "Tensor shape at dim %d must not be less than 0", idx));
    if (desc_shape[idx] >= 0 && tensor_shape[idx] != desc_shape[idx]) {
97
      return phi::vectorize<int64_t>(tensor_shape);
98 99 100
    }
  }

101
  return paddle::none;
102 103 104 105 106 107 108 109 110 111 112 113 114
}

static const std::shared_ptr<reader::LoDTensorBlockingQueue> &GetQueue(
    const std::shared_ptr<reader::LoDTensorBlockingQueue> &queue, size_t idx) {
  return queue;
}

static const std::shared_ptr<reader::LoDTensorBlockingQueue> &GetQueue(
    const std::shared_ptr<reader::OrderedMultiDeviceLoDTensorBlockingQueue>
        &queue,
    size_t idx) {
  return queue->GetQueue(idx);
}
Z
Zeng Jinle 已提交
115

116
template <typename QueueType>
S
sneaxiy 已提交
117 118
class MultiDeviceFeedReader {
 public:
S
sneaxiy 已提交
119 120
  using ResultDictList =
      std::vector<std::unordered_map<std::string, framework::LoDTensor>>;
121
  using ResultList = std::vector<std::vector<framework::LoDTensor>>;
S
sneaxiy 已提交
122

123 124 125 126
  static constexpr bool kKeepOrder =
      std::is_same<QueueType,
                   reader::OrderedMultiDeviceLoDTensorBlockingQueue>::value;

S
sneaxiy 已提交
127
  MultiDeviceFeedReader(
128
      const std::shared_ptr<QueueType> &queue,
S
sneaxiy 已提交
129
      const std::vector<std::string> &names,
130 131 132
      const std::vector<std::vector<int>> &shapes,
      const std::vector<framework::proto::VarType::Type> &dtypes,
      const std::vector<bool> &need_check_feed,
133 134 135 136
      const std::vector<platform::Place> &dst_places,
      bool use_double_buffer,
      bool drop_last,
      bool pin_memory = false)
S
sneaxiy 已提交
137
      : queue_(queue),
S
sneaxiy 已提交
138
        names_(names),
139
        pool_(new ::ThreadPool(dst_places.size())),
140 141
        drop_last_(drop_last),
        pin_memory_(pin_memory) {
142 143
    std::vector<framework::DDim> dims;
    for (auto &shape : shapes) {
144
      dims.push_back(phi::make_ddim(shape));
145
    }
146 147 148 149 150 151 152 153 154

    auto first_reader = std::make_shared<reader::PyReader>(
        GetQueue(queue, 0), dims, dtypes, need_check_feed);

    auto create_or_get_reader = [&](size_t idx) {
      if (idx == 0 ||
          std::is_same<QueueType, reader::LoDTensorBlockingQueue>::value) {
        return first_reader;
      } else {
155 156
        return std::make_shared<reader::PyReader>(
            GetQueue(queue, idx), dims, dtypes, need_check_feed);
157 158
      }
    };
S
sneaxiy 已提交
159 160

    readers_.reserve(dst_places.size());
161 162
    for (size_t i = 0; i < dst_places.size(); ++i) {
      auto &p = dst_places[i];
S
sneaxiy 已提交
163
      auto *holder = new framework::ReaderHolder();
164
      auto reader = create_or_get_reader(i);
S
sneaxiy 已提交
165
      if (use_double_buffer) {
166
        VLOG(10) << "Creating " << i << "-th BufferedReader";
S
sneaxiy 已提交
167 168
        holder->Reset(
            framework::MakeDecoratedReader<operators::reader::BufferedReader>(
169
                reader, p, 2, pin_memory_));
S
sneaxiy 已提交
170 171
      } else {
        if (platform::is_gpu_place(p)) {
172 173
          PADDLE_THROW(platform::errors::PermissionDenied(
              "Place cannot be CUDAPlace when use_double_buffer is False"));
S
sneaxiy 已提交
174 175 176 177 178
        }
        holder->Reset(reader);
      }
      readers_.emplace_back(holder);
    }
S
sneaxiy 已提交
179

S
sneaxiy 已提交
180 181
    futures_.resize(dst_places.size());
    ret_.resize(dst_places.size());
Z
Zeng Jinle 已提交
182
    exceptions_.assign(dst_places.size(), nullptr);
S
sneaxiy 已提交
183 184
    ReadAsync();
  }
S
sneaxiy 已提交
185

186 187
  bool DropLast() const { return drop_last_; }

S
sneaxiy 已提交
188
  ResultDictList ReadNext() {
Z
Zeng Jinle 已提交
189
    CheckNextStatus();
190 191
    ResultDictList result;
    result.reserve(ret_.size());
S
sneaxiy 已提交
192
    for (size_t i = 0; i < ret_.size(); ++i) {
193 194 195 196 197 198 199
      if (ret_[i].empty()) {
        if (!kKeepOrder) result.emplace_back();
        continue;
      }

      result.emplace_back();
      auto &ret = result.back();
200 201
      PADDLE_ENFORCE_EQ(names_.size(),
                        ret_[i].size(),
202 203 204 205 206 207 208
                        platform::errors::InvalidArgument(
                            "The sample number of reader's input data and the "
                            "input number of feed list are not equal.\n"
                            "Possible reasons are:\n"
                            "  The generator is decorated by `paddle.batch` "
                            "and configured by `set_batch_generator`, but here "
                            "need to used `set_sample_list_generator`."));
S
sneaxiy 已提交
209
      for (size_t j = 0; j < names_.size(); ++j) {
210
        ret.emplace(names_[j], std::move(ret_[i][j]));
S
sneaxiy 已提交
211
      }
S
sneaxiy 已提交
212
    }
S
sneaxiy 已提交
213 214
    ReadAsync();
    return result;
S
sneaxiy 已提交
215 216
  }

217
  ResultList ReadNextList() {
Z
Zeng Jinle 已提交
218
    CheckNextStatus();
219 220 221
    ResultList result;
    result.reserve(ret_.size());
    for (size_t i = 0; i < ret_.size(); ++i) {
222
      if (kKeepOrder && ret_[i].empty()) continue;
223 224 225 226 227 228
      result.emplace_back(std::move(ret_[i]));
    }
    ReadAsync();
    return result;
  }

S
sneaxiy 已提交
229 230 231 232 233 234
  void Reset() {
    Shutdown();
    Start();
    ReadAsync();
  }

235 236 237 238
  void Shutdown() {
    for (auto &r : readers_) r->Shutdown();
  }

S
sneaxiy 已提交
239 240 241 242
  ~MultiDeviceFeedReader() {
    queue_->Close();
    pool_.reset();
  }
S
sneaxiy 已提交
243 244

 private:
Z
Zeng Jinle 已提交
245 246 247 248 249 250 251 252
  enum Status {
    kSuccess = 0,   // Read next data successfully
    kEOF = 1,       // Reach EOF
    kException = 2  // Exception raises when reading
  };

  Status WaitFutures(std::exception_ptr *excep) {
    *excep = nullptr;
253
    size_t success_num = 0;
Z
Zeng Jinle 已提交
254 255 256 257
    for (size_t i = 0; i < futures_.size(); ++i) {
      auto each_status = futures_[i].get();
      if (UNLIKELY(each_status != Status::kSuccess)) {
        if (UNLIKELY(each_status == Status::kException)) {
258 259 260 261 262
          PADDLE_ENFORCE_NOT_NULL(
              exceptions_[i],
              platform::errors::NotFound("exceptions_[%d] is NULL, but the "
                                         "result status is Status::kException",
                                         i));
Z
Zeng Jinle 已提交
263 264 265
          *excep = exceptions_[i];
          exceptions_[i] = nullptr;
        }
266 267
      } else {
        ++success_num;
Z
Zeng Jinle 已提交
268 269 270 271 272
      }
    }

    if (UNLIKELY(*excep)) {
      return Status::kException;
273 274 275 276
    }

    if (drop_last_) {
      return success_num == futures_.size() ? Status::kSuccess : Status::kEOF;
Z
Zeng Jinle 已提交
277
    } else {
278
      return success_num > 0 ? Status::kSuccess : Status::kEOF;
S
sneaxiy 已提交
279 280
    }
  }
S
sneaxiy 已提交
281

S
sneaxiy 已提交
282 283
  void Start() {
    for (auto &r : readers_) r->Start();
S
sneaxiy 已提交
284 285
  }

S
sneaxiy 已提交
286 287 288
  void ReadAsync() {
    for (size_t i = 0; i < readers_.size(); ++i) {
      futures_[i] = pool_->enqueue([this, i] {
Z
Zeng Jinle 已提交
289 290 291 292 293 294 295
        try {
          readers_[i]->ReadNext(&ret_[i]);
          return ret_[i].empty() ? Status::kEOF : Status::kSuccess;
        } catch (...) {
          exceptions_[i] = std::current_exception();
          return Status::kException;
        }
S
sneaxiy 已提交
296 297 298 299
      });
    }
  }

Z
Zeng Jinle 已提交
300 301 302 303 304
  void CheckNextStatus() {
    std::exception_ptr excep;
    Status status = WaitFutures(&excep);

    if (UNLIKELY(excep)) {
305 306
      PADDLE_ENFORCE_EQ(status,
                        Status::kException,
307 308 309
                        platform::errors::NotFound(
                            "The exception raised is not NULL, but "
                            "the result status is not Status::kException"));
Z
Zeng Jinle 已提交
310 311 312 313 314 315 316 317 318
      std::rethrow_exception(excep);
    }

    if (UNLIKELY(status == Status::kEOF)) {
      VLOG(2) << "Raise StopIteration Exception in Python";
      py::gil_scoped_acquire guard;
      throw py::stop_iteration();
    }

319 320
    PADDLE_ENFORCE_EQ(status,
                      Status::kSuccess,
321 322 323
                      platform::errors::NotFound(
                          "The function executed sucessfully, but "
                          "the result status is not Status::kSuccess"));
Z
Zeng Jinle 已提交
324 325
  }

326
  std::shared_ptr<QueueType> queue_;
S
sneaxiy 已提交
327 328 329 330
  std::vector<std::string> names_;
  std::unique_ptr<::ThreadPool> pool_;

  std::vector<std::unique_ptr<framework::ReaderHolder>> readers_;
S
sneaxiy 已提交
331

Z
Zeng Jinle 已提交
332 333 334
  std::vector<std::future<Status>> futures_;
  std::vector<std::exception_ptr> exceptions_;

S
sneaxiy 已提交
335
  std::vector<std::vector<framework::LoDTensor>> ret_;
336
  bool drop_last_;
337
  bool pin_memory_;
S
sneaxiy 已提交
338
};
S
sneaxiy 已提交
339

340 341
// Registers MultiDeviceFeedReader<QueueType> as the Python class
// `reader_name`. Every exposed method runs with the GIL released.
template <typename QueueType>
void BindMultiDeviceReader(py::module *module, const char *reader_name) {
  using ReaderType = MultiDeviceFeedReader<QueueType>;
  using ReleaseGIL = py::call_guard<py::gil_scoped_release>;

  // Reads the next batch and wraps the first place's tensors into new,
  // non-persistable imperative VarBases. Each tensor's storage is moved in.
  auto read_next_var_list = [](ReaderType &self) {
    auto result_list = self.ReadNextList();
    auto &tensor_list = result_list[0];

    std::vector<std::shared_ptr<imperative::VarBase>> var_list;
    var_list.reserve(tensor_list.size());
    for (auto &lod_tensor : tensor_list) {
      std::string act_name =
          imperative::GetCurrentTracer()->GenerateUniqueName("generated_var");
      auto new_var = std::make_shared<imperative::VarBase>(act_name);
      new_var->SetPersistable(false);
      new_var->SetType(framework::proto::VarType::LOD_TENSOR);
      new_var->SetDataType(framework::TransToProtoVarType(lod_tensor.dtype()));
      auto *dst_tensor =
          new_var->MutableVar()->GetMutable<framework::LoDTensor>();
      *dst_tensor = std::move(lod_tensor);
      var_list.push_back(std::move(new_var));
    }
    return var_list;
  };

  py::class_<ReaderType> reader_class(*module, reader_name, "");
  reader_class.def("read_next", &ReaderType::ReadNext, ReleaseGIL());
  reader_class.def("read_next_list", &ReaderType::ReadNextList, ReleaseGIL());
  reader_class.def("read_next_var_list", read_next_var_list, ReleaseGIL());
  reader_class.def("reset", &ReaderType::Reset, ReleaseGIL());
  reader_class.def("shutdown", &ReaderType::Shutdown, ReleaseGIL());
}

void BindReader(py::module *module) {
  auto &m = *module;

389 390 391 392 393 394 395 396 397 398 399
  m.def("diff_tensor_shape",
        [](const framework::LoDTensor &tensor,
           const framework::VarDesc &var_desc,
           size_t num_places) -> py::object {
          auto diff = DiffTensorShapeWithVarDesc(tensor, var_desc, num_places);
          if (diff) {
            return py::cast(std::move(diff.get()));
          } else {
            return py::cast(nullptr);
          }
        });
400 401 402

  m.def(
      "init_lod_tensor_blocking_queue",
403 404
      [](framework::Variable &var,
         size_t capacity,
405 406 407 408 409 410 411 412 413 414 415 416 417 418
         bool is_ordered) -> py::object {
        VLOG(1) << "init_lod_tensor_blocking_queue";
        if (is_ordered) {
          auto *holder = var.GetMutable<
              reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder>();
          holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
          return py::cast(holder->GetQueue());
        } else {
          auto *holder = var.GetMutable<reader::LoDTensorBlockingQueueHolder>();
          holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
          return py::cast(holder->GetQueue());
        }
      },
      py::return_value_policy::copy);
419 420 421 422 423 424 425 426

  py::class_<framework::ReaderHolder>(m, "Reader", "")
      .def("start", &framework::ReaderHolder::Start)
      .def("reset", &framework::ReaderHolder::ResetAll);

  py::class_<reader::LoDTensorBlockingQueue,
             std::shared_ptr<reader::LoDTensorBlockingQueue>>(
      m, "LoDTensorBlockingQueue", "")
427 428 429 430 431 432 433
      .def(
          "push",
          [](reader::LoDTensorBlockingQueue &self,
             const std::vector<framework::LoDTensor> &lod_tensor_vec) {
            return self.Push(lod_tensor_vec);
          },
          py::call_guard<py::gil_scoped_release>())
434 435 436 437
      .def("size", &reader::LoDTensorBlockingQueue::Size)
      .def("capacity", &reader::LoDTensorBlockingQueue::Cap)
      .def("close", &reader::LoDTensorBlockingQueue::Close)
      .def("kill", &reader::LoDTensorBlockingQueue::Kill)
438 439
      .def("wait_for_inited",
           &reader::LoDTensorBlockingQueue::WaitForInited,
S
sneaxiy 已提交
440 441
           py::call_guard<py::gil_scoped_release>());

442 443 444
  py::class_<reader::OrderedMultiDeviceLoDTensorBlockingQueue,
             std::shared_ptr<reader::OrderedMultiDeviceLoDTensorBlockingQueue>>(
      m, "OrderedMultiDeviceLoDTensorBlockingQueue", "")
445 446 447 448 449 450 451
      .def(
          "push",
          [](reader::OrderedMultiDeviceLoDTensorBlockingQueue &self,
             const std::vector<framework::LoDTensor> &lod_tensor_vec) {
            return self.Push(lod_tensor_vec);
          },
          py::call_guard<py::gil_scoped_release>())
452 453 454 455 456 457 458 459 460 461 462 463 464 465
      .def("size", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Size)
      .def("capacity", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Cap)
      .def("close", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Close)
      .def("kill", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Kill)
      .def("wait_for_inited",
           &reader::OrderedMultiDeviceLoDTensorBlockingQueue::WaitForInited,
           py::call_guard<py::gil_scoped_release>())
      .def("reset", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Reset);

  BindMultiDeviceReader<reader::LoDTensorBlockingQueue>(
      module, "MultiDeviceFeedReader");
  BindMultiDeviceReader<reader::OrderedMultiDeviceLoDTensorBlockingQueue>(
      module, "OrderedMultiDeviceFeedReader");

466 467 468 469 470 471 472
  m.def(
      "create_py_reader",
      [](const std::shared_ptr<reader::LoDTensorBlockingQueue> &queue,
         const std::vector<std::string> &names,
         const std::vector<std::vector<int>> &shapes,
         const std::vector<framework::proto::VarType::Type> &dtypes,
         const std::vector<bool> &need_check_feed,
473 474 475 476
         const std::vector<platform::Place> &dst_places,
         bool use_double_buffer,
         bool drop_last,
         bool pin_memory) {
477
        return new MultiDeviceFeedReader<reader::LoDTensorBlockingQueue>(
478 479 480 481 482 483 484 485 486
            queue,
            names,
            shapes,
            dtypes,
            need_check_feed,
            dst_places,
            use_double_buffer,
            drop_last,
            pin_memory);
487 488
      },
      py::return_value_policy::take_ownership);
489 490 491 492 493 494 495 496 497

  m.def(
      "create_py_reader",
      [](const std::shared_ptr<reader::OrderedMultiDeviceLoDTensorBlockingQueue>
             &queue,
         const std::vector<std::string> &names,
         const std::vector<std::vector<int>> &shapes,
         const std::vector<framework::proto::VarType::Type> &dtypes,
         const std::vector<bool> &need_check_feed,
498 499 500 501
         const std::vector<platform::Place> &dst_places,
         bool use_double_buffer,
         bool drop_last,
         bool pin_memory) {
502 503
        queue->SetDeviceCount(dst_places.size());
        return new MultiDeviceFeedReader<
504 505 506 507 508 509 510 511 512
            reader::OrderedMultiDeviceLoDTensorBlockingQueue>(queue,
                                                              names,
                                                              shapes,
                                                              dtypes,
                                                              need_check_feed,
                                                              dst_places,
                                                              use_double_buffer,
                                                              drop_last,
                                                              pin_memory);
513 514
      },
      py::return_value_policy::take_ownership);
S
sneaxiy 已提交
515 516 517 518
}

}  // namespace pybind
}  // namespace paddle