// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pybind/reader_py.h"

#include <exception>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "Python.h"

#include "gflags/gflags.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/operators/reader/buffered_reader.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/operators/reader/py_reader.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/flags.h"
#include "pybind11/stl.h"

PHI_DECLARE_bool(reader_queue_speed_test_mode);

// disable auto conversion to list in Python
PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray);

namespace paddle {
namespace pybind {

namespace py = pybind11;
namespace reader = operators::reader;

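// Compare the runtime tensor shape with the expected shape, accounting for
// the batch dimension being split evenly across `num_places` devices.
// Returns the offending shape if a mismatch exists, or paddle::none when the
// shapes are compatible. Dim 0 of a tensor carrying LoD is treated as
// unknown and is not checked.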
static paddle::optional<std::vector<int64_t>> DiffTensorShape(
    const phi::DenseTensor &tensor,
    const std::vector<int64_t> &target_shape,
    size_t num_places) {
  auto tensor_shape = tensor.dims();

  int64_t rank = tensor_shape.size();

  if (UNLIKELY(rank == 0)) {
    if (!target_shape.empty()) {  // Tensor rank = 0 but desc does not match
      return phi::vectorize<int64_t>(tensor_shape);
    } else {
      return paddle::none;
    }
  }

  PADDLE_ENFORCE_GE(tensor_shape[0],
                    0,
                    platform::errors::InvalidArgument(
                        "Tensor shape at dim 0 must not be less than 0"));

  if (!tensor.lod().empty()) {
    tensor_shape[0] = -1;  // unknown shape
  } else {
    int64_t split_size = (tensor_shape[0] + num_places - 1) / num_places;
    int64_t remainder = (split_size == 0 ? 0 : tensor_shape[0] % split_size);
    tensor_shape[0] = split_size;
    if (target_shape[0] >= 0) {  // need check dim 0
      if (tensor_shape[0] != target_shape[0]) {
        return phi::vectorize<int64_t>(tensor_shape);
      }

      if (remainder > 0) {
        tensor_shape[0] = remainder;
        return phi::vectorize<int64_t>(tensor_shape);
      }
    }
  }

  for (int64_t idx = 1; idx < rank; ++idx) {
    PADDLE_ENFORCE_GE(
        tensor_shape[idx],
        0,
        platform::errors::InvalidArgument(
            "Tensor shape at dim %d must not be less than 0", idx));
    if (target_shape[idx] >= 0 && tensor_shape[idx] != target_shape[idx]) {
      return phi::vectorize<int64_t>(tensor_shape);
    }
  }

  return paddle::none;
}

// Check whether the tensor shape matches the VarDesc shape
// Return the different shape if exists
static paddle::optional<std::vector<int64_t>> DiffTensorShapeWithVarDesc(
    const phi::DenseTensor &tensor,
    const framework::VarDesc &var_desc,
    size_t num_places) {
  auto desc_shape = var_desc.GetShape();
  return DiffTensorShape(tensor, desc_shape, num_places);
}

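// GetQueue() is overloaded so that the templated reader code below can fetch
// the idx-th underlying queue uniformly: a plain LoDTensorBlockingQueue is
// shared by every device, while the ordered multi-device queue owns one
// queue per device.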
static const std::shared_ptr<reader::LoDTensorBlockingQueue> &GetQueue(
    const std::shared_ptr<reader::LoDTensorBlockingQueue> &queue, size_t idx) {
  return queue;
}

static const std::shared_ptr<reader::LoDTensorBlockingQueue> &GetQueue(
    const std::shared_ptr<reader::OrderedMultiDeviceLoDTensorBlockingQueue>
        &queue,
    size_t idx) {
  return queue->GetQueue(idx);
}

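// Feeds data from a Python-side blocking queue to multiple devices. Each
// destination place gets its own PyReader (optionally wrapped in a
// BufferedReader for double buffering), and reads run asynchronously on a
// thread pool. When QueueType is the ordered multi-device queue, batch order
// across devices is preserved (kKeepOrder is true).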
template <typename QueueType>
class MultiDeviceFeedReader {
 public:
  using ResultDictList =
      std::vector<std::unordered_map<std::string, phi::DenseTensor>>;
  using ResultList = std::vector<paddle::framework::LoDTensorArray>;

  static constexpr bool kKeepOrder =
      std::is_same<QueueType,
                   reader::OrderedMultiDeviceLoDTensorBlockingQueue>::value;

  MultiDeviceFeedReader(
      const std::shared_ptr<QueueType> &queue,
      const std::vector<std::string> &names,
      const std::vector<std::vector<int>> &shapes,
      const std::vector<framework::proto::VarType::Type> &dtypes,
      const std::vector<bool> &need_check_feed,
      const std::vector<platform::Place> &dst_places,
      bool use_double_buffer,
      bool drop_last,
      bool pin_memory = false)
      : queue_(queue),
        names_(names),
        pool_(new ::ThreadPool(dst_places.size())),
        drop_last_(drop_last),
        pin_memory_(pin_memory) {
    std::vector<framework::DDim> dims;
    for (auto &shape : shapes) {
      dims.push_back(phi::make_ddim(shape));
    }

    auto first_reader = std::make_shared<reader::PyReader>(
        GetQueue(queue, 0), dims, dtypes, need_check_feed);

    auto create_or_get_reader = [&](size_t idx) {
      if (idx == 0 ||
          std::is_same<QueueType, reader::LoDTensorBlockingQueue>::value) {
        return first_reader;
      } else {
        return std::make_shared<reader::PyReader>(
            GetQueue(queue, idx), dims, dtypes, need_check_feed);
      }
    };

    readers_.reserve(dst_places.size());
    for (size_t i = 0; i < dst_places.size(); ++i) {
      auto &p = dst_places[i];
      auto *holder = new framework::ReaderHolder();
      auto reader = create_or_get_reader(i);
      if (use_double_buffer) {
        VLOG(10) << "Creating " << i << "-th BufferedReader";
        holder->Reset(
            framework::MakeDecoratedReader<operators::reader::BufferedReader>(
                reader, p, 2, pin_memory_));
      } else {
        if (platform::is_gpu_place(p)) {
          PADDLE_THROW(platform::errors::PermissionDenied(
              "Place cannot be CUDAPlace when use_double_buffer is False"));
        }
        holder->Reset(reader);
      }
      readers_.emplace_back(holder);
    }

    futures_.resize(dst_places.size());
    ret_.resize(dst_places.size());
    exceptions_.assign(dst_places.size(), nullptr);
    ReadAsync();
  }

  bool DropLast() const { return drop_last_; }

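  // Blocks until the in-flight reads finish, converts each device's batch
  // into a {feed name -> tensor} map, then schedules the next round of
  // asynchronous reads.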
  ResultDictList ReadNext() {
    CheckNextStatus();
    ResultDictList result;
    result.reserve(ret_.size());
    for (auto &item : ret_) {
      if (item.empty()) {
        if (!kKeepOrder) result.emplace_back();
        continue;
      }

      result.emplace_back();
      auto &ret = result.back();
      PADDLE_ENFORCE_EQ(names_.size(),
                        item.size(),
                        platform::errors::InvalidArgument(
                            "The sample number of reader's input data and the "
                            "input number of feed list are not equal.\n"
                            "Possible reasons are:\n"
                            "  The generator is decorated by `paddle.batch` "
                            "and configured by `set_batch_generator`, but "
                            "`set_sample_list_generator` should be used "
                            "here."));
      for (size_t j = 0; j < names_.size(); ++j) {
        ret.emplace(names_[j], std::move(item[j]));
      }
    }
    ReadAsync();
    return result;
  }

  ResultList ReadNextList() {
    CheckNextStatus();
    ResultList result;
    result.reserve(ret_.size());
    for (auto &item : ret_) {
      if (kKeepOrder && item.empty()) continue;
      result.emplace_back(std::move(item));
    }
    ReadAsync();
    return result;
  }

  void Reset() {
    Shutdown();
    Start();
    ReadAsync();
  }

  void Shutdown() {
    for (auto &r : readers_) r->Shutdown();
  }

  ~MultiDeviceFeedReader() {
    queue_->Close();
    pool_.reset();
  }

 private:
  enum Status {
    kSuccess = 0,   // Read next data successfully
    kEOF = 1,       // Reach EOF
    kException = 2  // Exception raised while reading
  };

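  // Waits for all pending per-device reads and folds their statuses into
  // one. The first captured exception (if any) is handed back through
  // *excep. With drop_last_ set, every device must succeed to report
  // kSuccess; otherwise one successful device is enough.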
  Status WaitFutures(std::exception_ptr *excep) {
    *excep = nullptr;
    size_t success_num = 0;
    for (size_t i = 0; i < futures_.size(); ++i) {
      auto each_status = futures_[i].get();
      if (UNLIKELY(each_status != Status::kSuccess)) {
        if (UNLIKELY(each_status == Status::kException)) {
          PADDLE_ENFORCE_NOT_NULL(
              exceptions_[i],
              platform::errors::NotFound("exceptions_[%d] is NULL, but the "
                                         "result status is Status::kException",
                                         i));
          *excep = exceptions_[i];
          exceptions_[i] = nullptr;
        }
      } else {
        ++success_num;
      }
    }

    if (UNLIKELY(*excep)) {
      return Status::kException;
    }

    if (drop_last_) {
      return success_num == futures_.size() ? Status::kSuccess : Status::kEOF;
    } else {
      return success_num > 0 ? Status::kSuccess : Status::kEOF;
    }
  }
  void Start() {
    for (auto &r : readers_) r->Start();
  }

  void ReadAsync() {
    for (size_t i = 0; i < readers_.size(); ++i) {
      futures_[i] = pool_->enqueue([this, i] {
        try {
          readers_[i]->ReadNext(&ret_[i]);
          return ret_[i].empty() ? Status::kEOF : Status::kSuccess;
        } catch (...) {
          exceptions_[i] = std::current_exception();
          return Status::kException;
        }
      });
    }
  }

  void CheckNextStatus() {
    std::exception_ptr excep;
    Status status = WaitFutures(&excep);

    if (UNLIKELY(excep)) {
      PADDLE_ENFORCE_EQ(status,
                        Status::kException,
                        platform::errors::NotFound(
                            "The exception raised is not NULL, but "
                            "the result status is not Status::kException"));
      std::rethrow_exception(excep);
    }

    if (UNLIKELY(status == Status::kEOF)) {
      VLOG(2) << "Raise StopIteration Exception in Python";
      py::gil_scoped_acquire guard;
      throw py::stop_iteration();
    }

    PADDLE_ENFORCE_EQ(status,
                      Status::kSuccess,
                      platform::errors::NotFound(
                          "The function executed successfully, but "
                          "the result status is not Status::kSuccess"));
  }

  std::shared_ptr<QueueType> queue_;
  std::vector<std::string> names_;
  std::unique_ptr<::ThreadPool> pool_;

  std::vector<std::unique_ptr<framework::ReaderHolder>> readers_;

  std::vector<std::future<Status>> futures_;
  std::vector<std::exception_ptr> exceptions_;

  std::vector<paddle::framework::LoDTensorArray> ret_;
  bool drop_last_;
  bool pin_memory_;
};

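// Registers one MultiDeviceFeedReader instantiation with Python under
// reader_name. Every blocking method releases the GIL so Python threads can
// keep producing data while C++ waits.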
template <typename QueueType>
void BindMultiDeviceReader(py::module *module, const char *reader_name) {
  auto &m = *module;

  using ReaderType = MultiDeviceFeedReader<QueueType>;
  py::class_<ReaderType>(m, reader_name, "")
      .def("read_next",
           &ReaderType::ReadNext,
           py::call_guard<py::gil_scoped_release>())
      .def("read_next_list",
           &ReaderType::ReadNextList,
           py::call_guard<py::gil_scoped_release>())
      .def(
          "read_next_var_list",
          [](ReaderType &self) {
            auto result_list = self.ReadNextList();
            auto &tensor_list = result_list[0];
            std::vector<std::shared_ptr<imperative::VarBase>> var_list;
            var_list.reserve(tensor_list.size());
            auto func = [](phi::DenseTensor &lod_tensor) {
              std::string act_name =
                  imperative::GetCurrentTracer()->GenerateUniqueName(
                      "generated_var");
              auto new_var = std::make_shared<imperative::VarBase>(act_name);
              new_var->SetPersistable(false);
              new_var->SetType(framework::proto::VarType::LOD_TENSOR);
              new_var->SetDataType(
                  framework::TransToProtoVarType(lod_tensor.dtype()));
              auto *tensor =
                  new_var->MutableVar()->GetMutable<phi::DenseTensor>();
              *tensor = std::move(lod_tensor);
              return new_var;
            };
            for (auto &tensor : tensor_list) {
              var_list.emplace_back(func(tensor));
            }
            return var_list;
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "reset", &ReaderType::Reset, py::call_guard<py::gil_scoped_release>())
      .def("shutdown",
           &ReaderType::Shutdown,
           py::call_guard<py::gil_scoped_release>());
}

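// A rough sketch of how the bindings below are typically driven from Python
// (the `core` module name and call sites are illustrative, not exact;
// arguments are positional because no py::arg names are declared here):
//
//   queue = core.init_lod_tensor_blocking_queue(var, 64, False)
//   reader = core.create_py_reader(queue, names, shapes, dtypes,
//                                  need_check_feed, places,
//                                  True,    # use_double_buffer
//                                  True,    # drop_last
//                                  False)   # pin_memory
//   queue.push(lod_tensor_array)  # producer side
//   batch = reader.read_next()    # raises StopIteration at EOF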
void BindReader(py::module *module) {
  auto &m = *module;

  m.def("diff_tensor_shape",
        [](const phi::DenseTensor &tensor,
           const framework::VarDesc &var_desc,
           size_t num_places) -> py::object {
          auto diff = DiffTensorShapeWithVarDesc(tensor, var_desc, num_places);
          if (diff) {
            return py::cast(std::move(diff.get()));
          } else {
            return py::cast(nullptr);
          }
        });

  m.def("diff_tensor_shape",
        [](const phi::DenseTensor &tensor,
           const std::vector<int64_t> &target_shape,
           size_t num_places) -> py::object {
          auto diff = DiffTensorShape(tensor, target_shape, num_places);
          if (diff) {
            return py::cast(std::move(diff.get()));
          } else {
            return py::cast(nullptr);
          }
        });

  m.def(
      "init_lod_tensor_blocking_queue",
      [](framework::Variable &var,
         size_t capacity,
         bool is_ordered) -> py::object {
        VLOG(1) << "init_lod_tensor_blocking_queue";
        if (is_ordered) {
          auto *holder = var.GetMutable<
              reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder>();
          holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
          return py::cast(holder->GetQueue());
        } else {
          auto *holder = var.GetMutable<reader::LoDTensorBlockingQueueHolder>();
          holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
          return py::cast(holder->GetQueue());
        }
      },
      py::return_value_policy::copy);

  py::class_<framework::ReaderHolder>(m, "Reader", "")
      .def("start", &framework::ReaderHolder::Start)
      .def("reset", &framework::ReaderHolder::ResetAll);

  py::class_<reader::LoDTensorBlockingQueue,
             std::shared_ptr<reader::LoDTensorBlockingQueue>>(
      m, "LoDTensorBlockingQueue", "")
      .def(
          "push",
          [](reader::LoDTensorBlockingQueue &self,
             const paddle::framework::LoDTensorArray &lod_tensor_vec) {
            return self.Push(lod_tensor_vec);
          },
          py::call_guard<py::gil_scoped_release>())
      .def("size", &reader::LoDTensorBlockingQueue::Size)
      .def("capacity", &reader::LoDTensorBlockingQueue::Cap)
      .def("close", &reader::LoDTensorBlockingQueue::Close)
      .def("kill", &reader::LoDTensorBlockingQueue::Kill)
      .def("wait_for_inited",
           &reader::LoDTensorBlockingQueue::WaitForInited,
           py::call_guard<py::gil_scoped_release>());

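  // Same Python surface as LoDTensorBlockingQueue, plus reset(), which the
  // ordered variant needs so its per-device queues can be reused.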
  py::class_<reader::OrderedMultiDeviceLoDTensorBlockingQueue,
             std::shared_ptr<reader::OrderedMultiDeviceLoDTensorBlockingQueue>>(
      m, "OrderedMultiDeviceLoDTensorBlockingQueue", "")
      .def(
          "push",
          [](reader::OrderedMultiDeviceLoDTensorBlockingQueue &self,
             const paddle::framework::LoDTensorArray &lod_tensor_vec) {
            return self.Push(lod_tensor_vec);
          },
          py::call_guard<py::gil_scoped_release>())
      .def("size", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Size)
      .def("capacity", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Cap)
      .def("close", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Close)
      .def("kill", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Kill)
      .def("wait_for_inited",
           &reader::OrderedMultiDeviceLoDTensorBlockingQueue::WaitForInited,
           py::call_guard<py::gil_scoped_release>())
      .def("reset", &reader::OrderedMultiDeviceLoDTensorBlockingQueue::Reset);

  BindMultiDeviceReader<reader::LoDTensorBlockingQueue>(
      module, "MultiDeviceFeedReader");
  BindMultiDeviceReader<reader::OrderedMultiDeviceLoDTensorBlockingQueue>(
      module, "OrderedMultiDeviceFeedReader");

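  // create_py_reader is overloaded on the queue type. The ordered overload
  // must learn the device count before reading starts, hence the
  // SetDeviceCount() call below.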
  m.def(
      "create_py_reader",
      [](const std::shared_ptr<reader::LoDTensorBlockingQueue> &queue,
         const std::vector<std::string> &names,
         const std::vector<std::vector<int>> &shapes,
         const std::vector<framework::proto::VarType::Type> &dtypes,
         const std::vector<bool> &need_check_feed,
         const std::vector<platform::Place> &dst_places,
         bool use_double_buffer,
         bool drop_last,
         bool pin_memory) {
        return new MultiDeviceFeedReader<reader::LoDTensorBlockingQueue>(
            queue,
            names,
            shapes,
            dtypes,
            need_check_feed,
            dst_places,
            use_double_buffer,
            drop_last,
            pin_memory);
      },
      py::return_value_policy::take_ownership);

  m.def(
      "create_py_reader",
      [](const std::shared_ptr<reader::OrderedMultiDeviceLoDTensorBlockingQueue>
             &queue,
         const std::vector<std::string> &names,
         const std::vector<std::vector<int>> &shapes,
         const std::vector<framework::proto::VarType::Type> &dtypes,
         const std::vector<bool> &need_check_feed,
         const std::vector<platform::Place> &dst_places,
         bool use_double_buffer,
         bool drop_last,
         bool pin_memory) {
        queue->SetDeviceCount(dst_places.size());
        return new MultiDeviceFeedReader<
            reader::OrderedMultiDeviceLoDTensorBlockingQueue>(queue,
                                                              names,
                                                              shapes,
                                                              dtypes,
                                                              need_check_feed,
                                                              dst_places,
                                                              use_double_buffer,
                                                              drop_last,
                                                              pin_memory);
      },
      py::return_value_policy::take_ownership);
}

}  // namespace pybind
}  // namespace paddle