reader_py.cc 17.6 KB
Newer Older
S
sneaxiy 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pybind/reader_py.h"
16

Z
Zeng Jinle 已提交
17
#include <exception>
S
sneaxiy 已提交
18
#include <memory>
S
sneaxiy 已提交
19
#include <string>
S
sneaxiy 已提交
20 21
#include <unordered_map>
#include <utility>
S
sneaxiy 已提交
22
#include <vector>
23

Z
Zeng Jinle 已提交
24
#include "Python.h"
25

26
#include "gflags/gflags.h"
S
sneaxiy 已提交
27
#include "paddle/fluid/framework/reader.h"
28 29
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/tracer.h"
S
sneaxiy 已提交
30
#include "paddle/fluid/operators/reader/buffered_reader.h"
31
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
S
sneaxiy 已提交
32 33
#include "paddle/fluid/operators/reader/py_reader.h"
#include "paddle/fluid/platform/place.h"
34
#include "paddle/phi/core/ddim.h"
35
#include "paddle/phi/core/flags.h"
S
sneaxiy 已提交
36 37
#include "pybind11/stl.h"

38
PHI_DECLARE_bool(reader_queue_speed_test_mode);
39

40 41 42
// disable auto conversion to list in Python
PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray);

S
sneaxiy 已提交
43 44 45
namespace paddle {
namespace pybind {

Z
Zeng Jinle 已提交
46
namespace py = pybind11;
47 48 49 50
namespace reader = operators::reader;

// Check whether the tensor shape matches the VarDesc shape
// Return the different shape if exists
51
static paddle::optional<std::vector<int64_t>> DiffTensorShapeWithVarDesc(
52
    const phi::DenseTensor &tensor,
53
    const framework::VarDesc &var_desc,
54 55 56 57 58 59 60
    size_t num_places) {
  auto tensor_shape = tensor.dims();
  auto desc_shape = var_desc.GetShape();

  int64_t rank = tensor_shape.size();

  if (UNLIKELY(rank == 0)) {
61
    if (!desc_shape.empty()) {  // Tensor rank = 0 but desc does not match
62
      return phi::vectorize<int64_t>(tensor_shape);
63
    } else {
64
      return paddle::none;
65 66 67
    }
  }

68 69
  PADDLE_ENFORCE_GE(tensor_shape[0],
                    0,
70 71 72 73 74 75 76 77 78 79 80
                    platform::errors::InvalidArgument(
                        "Tensor shape at dim 0 must not be less than 0"));

  if (!tensor.lod().empty()) {
    tensor_shape[0] = -1;  // unknown shape
  } else {
    int64_t split_size = (tensor_shape[0] + num_places - 1) / num_places;
    int64_t remainder = (split_size == 0 ? 0 : tensor_shape[0] % split_size);
    tensor_shape[0] = split_size;
    if (desc_shape[0] >= 0) {  // need check dim 0
      if (tensor_shape[0] != desc_shape[0]) {
81
        return phi::vectorize<int64_t>(tensor_shape);
82 83 84 85
      }

      if (remainder > 0) {
        tensor_shape[0] = remainder;
86
        return phi::vectorize<int64_t>(tensor_shape);
87 88 89 90 91 92
      }
    }
  }

  for (int64_t idx = 1; idx < rank; ++idx) {
    PADDLE_ENFORCE_GE(
93 94
        tensor_shape[idx],
        0,
95 96 97
        platform::errors::InvalidArgument(
            "Tensor shape at dim %d must not be less than 0", idx));
    if (desc_shape[idx] >= 0 && tensor_shape[idx] != desc_shape[idx]) {
98
      return phi::vectorize<int64_t>(tensor_shape);
99 100 101
    }
  }

102
  return paddle::none;
103 104 105 106 107 108 109 110 111 112 113 114 115
}

// Overload for a single (unordered) blocking queue: every index maps to the
// same queue, so `idx` is ignored and the argument itself is returned.
// The returned reference aliases `queue`, so it is valid only as long as the
// caller's shared_ptr is.
static const std::shared_ptr<reader::LoDTensorBlockingQueue> &GetQueue(
    const std::shared_ptr<reader::LoDTensorBlockingQueue> &queue, size_t idx) {
  return queue;
}

// Overload for an ordered multi-device queue: forwards to the container's own
// GetQueue(idx) to obtain the queue for position `idx`.
// NOTE(review): the result is returned by const reference, which is only safe
// if OrderedMultiDeviceLoDTensorBlockingQueue::GetQueue itself returns a
// reference to a shared_ptr that outlives this call -- confirm against its
// declaration.
static const std::shared_ptr<reader::LoDTensorBlockingQueue> &GetQueue(
    const std::shared_ptr<reader::OrderedMultiDeviceLoDTensorBlockingQueue>
        &queue,
    size_t idx) {
  return queue->GetQueue(idx);
}
Z
Zeng Jinle 已提交
116

117
template <typename QueueType>
S
sneaxiy 已提交
118 119
class MultiDeviceFeedReader {
 public:
S
sneaxiy 已提交
120
  using ResultDictList =
121
      std::vector<std::unordered_map<std::string, phi::DenseTensor>>;
122
  using ResultList = std::vector<paddle::framework::LoDTensorArray>;
S
sneaxiy 已提交
123

124 125 126 127
  static constexpr bool kKeepOrder =
      std::is_same<QueueType,
                   reader::OrderedMultiDeviceLoDTensorBlockingQueue>::value;

S
sneaxiy 已提交
128
  MultiDeviceFeedReader(
129
      const std::shared_ptr<QueueType> &queue,
S
sneaxiy 已提交
130
      const std::vector<std::string> &names,
131 132 133
      const std::vector<std::vector<int>> &shapes,
      const std::vector<framework::proto::VarType::Type> &dtypes,
      const std::vector<bool> &need_check_feed,
134 135 136 137
      const std::vector<platform::Place> &dst_places,
      bool use_double_buffer,
      bool drop_last,
      bool pin_memory = false)
S
sneaxiy 已提交
138
      : queue_(queue),
S
sneaxiy 已提交
139
        names_(names),
140
        pool_(new ::ThreadPool(dst_places.size())),
141 142
        drop_last_(drop_last),
        pin_memory_(pin_memory) {
143 144
    std::vector<framework::DDim> dims;
    for (auto &shape : shapes) {
145
      dims.push_back(phi::make_ddim(shape));
146
    }
147 148 149 150 151 152 153 154 155

    auto first_reader = std::make_shared<reader::PyReader>(
        GetQueue(queue, 0), dims, dtypes, need_check_feed);

    auto create_or_get_reader = [&](size_t idx) {
      if (idx == 0 ||
          std::is_same<QueueType, reader::LoDTensorBlockingQueue>::value) {
        return first_reader;
      } else {
156 157
        return std::make_shared<reader::PyReader>(
            GetQueue(queue, idx), dims, dtypes, need_check_feed);
158 159
      }
    };
S
sneaxiy 已提交
160 161

    readers_.reserve(dst_places.size());
162 163
    for (size_t i = 0; i < dst_places.size(); ++i) {
      auto &p = dst_places[i];
S
sneaxiy 已提交
164
      auto *holder = new framework::ReaderHolder();
165
      auto reader = create_or_get_reader(i);
S
sneaxiy 已提交
166
      if (use_double_buffer) {
167
        VLOG(10) << "Creating " << i << "-th BufferedReader";
S
sneaxiy 已提交
168 169
        holder->Reset(
            framework::MakeDecoratedReader<operators::reader::BufferedReader>(
170
                reader, p, 2, pin_memory_));
S
sneaxiy 已提交
171 172
      } else {
        if (platform::is_gpu_place(p)) {
173 174
          PADDLE_THROW(platform::errors::PermissionDenied(
              "Place cannot be CUDAPlace when use_double_buffer is False"));
S
sneaxiy 已提交
175 176 177 178 179
        }
        holder->Reset(reader);
      }
      readers_.emplace_back(holder);
    }
S
sneaxiy 已提交
180

S
sneaxiy 已提交
181 182
    futures_.resize(dst_places.size());
    ret_.resize(dst_places.size());
Z
Zeng Jinle 已提交
183
    exceptions_.assign(dst_places.size(), nullptr);
S
sneaxiy 已提交
184 185
    ReadAsync();
  }
S
sneaxiy 已提交
186

187 188
  bool DropLast() const { return drop_last_; }

S
sneaxiy 已提交
189
  ResultDictList ReadNext() {
Z
Zeng Jinle 已提交
190
    CheckNextStatus();
191 192
    ResultDictList result;
    result.reserve(ret_.size());
193 194
    for (auto &item : ret_) {
      if (item.empty()) {
195 196 197 198 199 200
        if (!kKeepOrder) result.emplace_back();
        continue;
      }

      result.emplace_back();
      auto &ret = result.back();
201
      PADDLE_ENFORCE_EQ(names_.size(),
202
                        item.size(),
203 204 205 206 207 208 209
                        platform::errors::InvalidArgument(
                            "The sample number of reader's input data and the "
                            "input number of feed list are not equal.\n"
                            "Possible reasons are:\n"
                            "  The generator is decorated by `paddle.batch` "
                            "and configured by `set_batch_generator`, but here "
                            "need to used `set_sample_list_generator`."));
S
sneaxiy 已提交
210
      for (size_t j = 0; j < names_.size(); ++j) {
211
        ret.emplace(names_[j], std::move(item[j]));
S
sneaxiy 已提交
212
      }
S
sneaxiy 已提交
213
    }
S
sneaxiy 已提交
214 215
    ReadAsync();
    return result;
S
sneaxiy 已提交
216 217
  }

218
  ResultList ReadNextList() {
Z
Zeng Jinle 已提交
219
    CheckNextStatus();
220 221
    ResultList result;
    result.reserve(ret_.size());
222 223 224
    for (auto &item : ret_) {
      if (kKeepOrder && item.empty()) continue;
      result.emplace_back(std::move(item));
225 226 227 228 229
    }
    ReadAsync();
    return result;
  }

S
sneaxiy 已提交
230 231 232 233 234 235
  void Reset() {
    Shutdown();
    Start();
    ReadAsync();
  }

236 237 238 239
  void Shutdown() {
    for (auto &r : readers_) r->Shutdown();
  }

S
sneaxiy 已提交
240 241 242 243
  ~MultiDeviceFeedReader() {
    queue_->Close();
    pool_.reset();
  }
S
sneaxiy 已提交
244 245

 private:
Z
Zeng Jinle 已提交
246 247 248 249 250 251 252 253
  enum Status {
    kSuccess = 0,   // Read next data successfully
    kEOF = 1,       // Reach EOF
    kException = 2  // Exception raises when reading
  };

  Status WaitFutures(std::exception_ptr *excep) {
    *excep = nullptr;
254
    size_t success_num = 0;
Z
Zeng Jinle 已提交
255 256 257 258
    for (size_t i = 0; i < futures_.size(); ++i) {
      auto each_status = futures_[i].get();
      if (UNLIKELY(each_status != Status::kSuccess)) {
        if (UNLIKELY(each_status == Status::kException)) {
259 260 261 262 263
          PADDLE_ENFORCE_NOT_NULL(
              exceptions_[i],
              platform::errors::NotFound("exceptions_[%d] is NULL, but the "
                                         "result status is Status::kException",
                                         i));
Z
Zeng Jinle 已提交
264 265 266
          *excep = exceptions_[i];
          exceptions_[i] = nullptr;
        }
267 268
      } else {
        ++success_num;
Z
Zeng Jinle 已提交
269 270 271 272 273
      }
    }

    if (UNLIKELY(*excep)) {
      return Status::kException;
274 275 276 277
    }

    if (drop_last_) {
      return success_num == futures_.size() ? Status::kSuccess : Status::kEOF;
Z
Zeng Jinle 已提交
278
    } else {
279
      return success_num > 0 ? Status::kSuccess : Status::kEOF;
S
sneaxiy 已提交
280 281
    }
  }
S
sneaxiy 已提交
282

S
sneaxiy 已提交
283 284
  void Start() {
    for (auto &r : readers_) r->Start();
S
sneaxiy 已提交
285 286
  }

S
sneaxiy 已提交
287 288 289
  void ReadAsync() {
    for (size_t i = 0; i < readers_.size(); ++i) {
      futures_[i] = pool_->enqueue([this, i] {
Z
Zeng Jinle 已提交
290 291 292 293 294 295 296
        try {
          readers_[i]->ReadNext(&ret_[i]);
          return ret_[i].empty() ? Status::kEOF : Status::kSuccess;
        } catch (...) {
          exceptions_[i] = std::current_exception();
          return Status::kException;
        }
S
sneaxiy 已提交
297 298 299 300
      });
    }
  }

Z
Zeng Jinle 已提交
301 302 303 304 305
  void CheckNextStatus() {
    std::exception_ptr excep;
    Status status = WaitFutures(&excep);

    if (UNLIKELY(excep)) {
306 307
      PADDLE_ENFORCE_EQ(status,
                        Status::kException,
308 309 310
                        platform::errors::NotFound(
                            "The exception raised is not NULL, but "
                            "the result status is not Status::kException"));
Z
Zeng Jinle 已提交
311 312 313 314 315 316 317 318 319
      std::rethrow_exception(excep);
    }

    if (UNLIKELY(status == Status::kEOF)) {
      VLOG(2) << "Raise StopIteration Exception in Python";
      py::gil_scoped_acquire guard;
      throw py::stop_iteration();
    }

320 321
    PADDLE_ENFORCE_EQ(status,
                      Status::kSuccess,
322
                      platform::errors::NotFound(
C
co63oc 已提交
323
                          "The function executed successfully, but "
324
                          "the result status is not Status::kSuccess"));
Z
Zeng Jinle 已提交
325 326
  }

327
  std::shared_ptr<QueueType> queue_;
S
sneaxiy 已提交
328 329 330 331
  std::vector<std::string> names_;
  std::unique_ptr<::ThreadPool> pool_;

  std::vector<std::unique_ptr<framework::ReaderHolder>> readers_;
S
sneaxiy 已提交
332

Z
Zeng Jinle 已提交
333 334 335
  std::vector<std::future<Status>> futures_;
  std::vector<std::exception_ptr> exceptions_;

336
  std::vector<paddle::framework::LoDTensorArray> ret_;
337
  bool drop_last_;
338
  bool pin_memory_;
S
sneaxiy 已提交
339
};
S
sneaxiy 已提交
340

341 342
// Registers MultiDeviceFeedReader<QueueType> in `module` as a Python class
// named `reader_name`, exposing read_next, read_next_list,
// read_next_var_list, reset and shutdown. Every method is bound with
// py::gil_scoped_release as its call guard.
template <typename QueueType>
void BindMultiDeviceReader(py::module *module, const char *reader_name) {
  auto &m = *module;

  using ReaderType = MultiDeviceFeedReader<QueueType>;
  py::class_<ReaderType>(m, reader_name, "")
      .def("read_next",
           &ReaderType::ReadNext,
           py::call_guard<py::gil_scoped_release>())
      .def("read_next_list",
           &ReaderType::ReadNextList,
           py::call_guard<py::gil_scoped_release>())
      .def(
          "read_next_var_list",
          // Reads the next batch and wraps each tensor of the first place's
          // result in a freshly named imperative::VarBase.
          [](ReaderType &self) {
            auto result_list = self.ReadNextList();
            // ReadNextList() can return an empty list (e.g. a reader built
            // with no places); indexing [0] would then be out of bounds.
            PADDLE_ENFORCE_EQ(
                result_list.empty(),
                false,
                platform::errors::NotFound(
                    "No tensor list was read, so read_next_var_list has "
                    "nothing to convert"));
            auto &tensor_list = result_list[0];
            std::vector<std::shared_ptr<imperative::VarBase>> var_list;
            var_list.reserve(tensor_list.size());
            // Moves the tensor into a new non-persistable LOD_TENSOR variable.
            auto func = [](phi::DenseTensor &lod_tensor) {
              std::string act_name =
                  imperative::GetCurrentTracer()->GenerateUniqueName(
                      "generated_var");
              auto new_var = std::make_shared<imperative::VarBase>(act_name);
              new_var->SetPersistable(false);
              new_var->SetType(framework::proto::VarType::LOD_TENSOR);
              new_var->SetDataType(
                  framework::TransToProtoVarType(lod_tensor.dtype()));
              auto *tensor =
                  new_var->MutableVar()->GetMutable<phi::DenseTensor>();
              *tensor = std::move(lod_tensor);
              return new_var;
            };
            for (auto &tensor : tensor_list) {
              var_list.emplace_back(func(tensor));
            }
            return var_list;
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "reset", &ReaderType::Reset, py::call_guard<py::gil_scoped_release>())
      .def("shutdown",
           &ReaderType::Shutdown,
           py::call_guard<py::gil_scoped_release>());
}

// Registers all reader-related Python bindings in `module`:
//   * diff_tensor_shape                -- shape check helper, returns None on
//                                         match or the differing shape
//   * init_lod_tensor_blocking_queue   -- creates/initializes a queue inside a
//                                         framework::Variable
//   * Reader, LoDTensorBlockingQueue,
//     OrderedMultiDeviceLoDTensorBlockingQueue classes
//   * MultiDeviceFeedReader / OrderedMultiDeviceFeedReader classes
//   * create_py_reader                 -- two overloads, one per queue type
void BindReader(py::module *module) {
  using BlockingQueue = reader::LoDTensorBlockingQueue;
  using OrderedQueue = reader::OrderedMultiDeviceLoDTensorBlockingQueue;

  auto &m = *module;

  m.def("diff_tensor_shape",
        [](const phi::DenseTensor &tensor,
           const framework::VarDesc &var_desc,
           size_t num_places) -> py::object {
          auto mismatch =
              DiffTensorShapeWithVarDesc(tensor, var_desc, num_places);
          if (!mismatch) {
            return py::cast(nullptr);
          }
          return py::cast(std::move(mismatch.get()));
        });

  m.def(
      "init_lod_tensor_blocking_queue",
      [](framework::Variable &var,
         size_t capacity,
         bool is_ordered) -> py::object {
        VLOG(1) << "init_lod_tensor_blocking_queue";
        if (!is_ordered) {
          auto *plain_holder =
              var.GetMutable<reader::LoDTensorBlockingQueueHolder>();
          plain_holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
          return py::cast(plain_holder->GetQueue());
        }
        auto *ordered_holder = var.GetMutable<
            reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder>();
        ordered_holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
        return py::cast(ordered_holder->GetQueue());
      },
      py::return_value_policy::copy);

  // framework::ReaderHolder exposed as "Reader".
  py::class_<framework::ReaderHolder> reader_cls(m, "Reader", "");
  reader_cls.def("start", &framework::ReaderHolder::Start);
  reader_cls.def("reset", &framework::ReaderHolder::ResetAll);

  // Plain blocking queue; push and wait_for_inited release the GIL.
  py::class_<BlockingQueue, std::shared_ptr<BlockingQueue>> queue_cls(
      m, "LoDTensorBlockingQueue", "");
  queue_cls.def(
      "push",
      [](BlockingQueue &self,
         const paddle::framework::LoDTensorArray &lod_tensor_vec) {
        return self.Push(lod_tensor_vec);
      },
      py::call_guard<py::gil_scoped_release>());
  queue_cls.def("size", &BlockingQueue::Size);
  queue_cls.def("capacity", &BlockingQueue::Cap);
  queue_cls.def("close", &BlockingQueue::Close);
  queue_cls.def("kill", &BlockingQueue::Kill);
  queue_cls.def("wait_for_inited",
                &BlockingQueue::WaitForInited,
                py::call_guard<py::gil_scoped_release>());

  // Ordered multi-device queue; same surface plus "reset".
  py::class_<OrderedQueue, std::shared_ptr<OrderedQueue>> ordered_queue_cls(
      m, "OrderedMultiDeviceLoDTensorBlockingQueue", "");
  ordered_queue_cls.def(
      "push",
      [](OrderedQueue &self,
         const paddle::framework::LoDTensorArray &lod_tensor_vec) {
        return self.Push(lod_tensor_vec);
      },
      py::call_guard<py::gil_scoped_release>());
  ordered_queue_cls.def("size", &OrderedQueue::Size);
  ordered_queue_cls.def("capacity", &OrderedQueue::Cap);
  ordered_queue_cls.def("close", &OrderedQueue::Close);
  ordered_queue_cls.def("kill", &OrderedQueue::Kill);
  ordered_queue_cls.def("wait_for_inited",
                        &OrderedQueue::WaitForInited,
                        py::call_guard<py::gil_scoped_release>());
  ordered_queue_cls.def("reset", &OrderedQueue::Reset);

  BindMultiDeviceReader<BlockingQueue>(module, "MultiDeviceFeedReader");
  BindMultiDeviceReader<OrderedQueue>(module, "OrderedMultiDeviceFeedReader");

  // Overload for the plain queue. Ownership of the new reader is handed to
  // Python via take_ownership.
  m.def(
      "create_py_reader",
      [](const std::shared_ptr<BlockingQueue> &queue,
         const std::vector<std::string> &names,
         const std::vector<std::vector<int>> &shapes,
         const std::vector<framework::proto::VarType::Type> &dtypes,
         const std::vector<bool> &need_check_feed,
         const std::vector<platform::Place> &dst_places,
         bool use_double_buffer,
         bool drop_last,
         bool pin_memory) {
        using FeedReader = MultiDeviceFeedReader<BlockingQueue>;
        return new FeedReader(queue,
                              names,
                              shapes,
                              dtypes,
                              need_check_feed,
                              dst_places,
                              use_double_buffer,
                              drop_last,
                              pin_memory);
      },
      py::return_value_policy::take_ownership);

  // Overload for the ordered queue. The queue is told how many devices it
  // serves before the reader is built.
  m.def(
      "create_py_reader",
      [](const std::shared_ptr<OrderedQueue> &queue,
         const std::vector<std::string> &names,
         const std::vector<std::vector<int>> &shapes,
         const std::vector<framework::proto::VarType::Type> &dtypes,
         const std::vector<bool> &need_check_feed,
         const std::vector<platform::Place> &dst_places,
         bool use_double_buffer,
         bool drop_last,
         bool pin_memory) {
        using OrderedFeedReader = MultiDeviceFeedReader<OrderedQueue>;
        queue->SetDeviceCount(dst_places.size());
        return new OrderedFeedReader(queue,
                                     names,
                                     shapes,
                                     dtypes,
                                     need_check_feed,
                                     dst_places,
                                     use_double_buffer,
                                     drop_last,
                                     pin_memory);
      },
      py::return_value_policy::take_ownership);
}

}  // namespace pybind
}  // namespace paddle