// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/pybind/reader_py.h" #include #include #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/operators/reader/buffered_reader.h" #include "paddle/fluid/operators/reader/compose_reader.h" #include "paddle/fluid/operators/reader/py_reader.h" #include "paddle/fluid/platform/place.h" #include "pybind11/stl.h" namespace paddle { namespace pybind { class FeedReader { using ResultDictList = std::vector>; public: FeedReader(std::unique_ptr reader, const std::vector &names, size_t num_places, bool drop_last = true) : reader_(std::move(reader)), names_(names), num_places_(num_places), drop_last_(drop_last) {} ResultDictList ReadNext() { std::vector tensors; reader_->ReadNext(&tensors); if (tensors.empty()) return ResultDictList(); PADDLE_ENFORCE(tensors.size() % names_.size() == 0, "Tensor size: %d, names size: %d", tensors.size(), names_.size()); size_t read_place_num = tensors.size() / names_.size(); if (drop_last_ && read_place_num != num_places_) { return ResultDictList(); } ResultDictList ret(read_place_num); for (size_t i = 0; i < tensors.size(); ++i) { ret[i / names_.size()].emplace(names_[i % names_.size()], std::move(tensors[i])); } return ret; } void Start() { reader_->Start(); } void Reset() { reader_->ResetAll(); } private: std::unique_ptr reader_; std::vector names_; size_t num_places_; bool drop_last_; }; static std::unique_ptr CreatePyReader( const std::vector< std::shared_ptr> &queues, const std::vector &dst_places) { std::shared_ptr reader; if (queues.size() == 1) { reader.reset(new operators::reader::PyReader(queues[0])); } else { reader.reset(new operators::reader::MultiQueuePyReader(queues)); } std::vector> buffered_reader; buffered_reader.reserve(dst_places.size()); for (auto &p : dst_places) { buffered_reader.emplace_back( framework::MakeDecoratedReader( reader, p, 2)); } reader = framework::MakeDecoratedReader( buffered_reader); auto *holder = new framework::ReaderHolder(); holder->Reset(reader); return std::unique_ptr(holder); } namespace py = pybind11; void BindReader(py::module *module) { auto &m = *module; namespace reader = ::paddle::operators::reader; py::class_(m, "Reader", "") .def("start", &framework::ReaderHolder::Start) .def("reset", &framework::ReaderHolder::ResetAll); py::class_(m, "FeedReader", "") .def("read_next", &FeedReader::ReadNext, py::call_guard()) .def("start", &FeedReader::Start, py::call_guard()) .def("reset", &FeedReader::Reset, py::call_guard()); m.def("create_py_reader", [](const std::vector< std::shared_ptr> queues, const std::vector &names, const std::vector &dst_places, bool drop_last) { return new FeedReader(CreatePyReader(queues, dst_places), names, dst_places.size(), drop_last); }, py::return_value_policy::take_ownership); } } // namespace pybind } // namespace paddle