From 7160cb0f322aa5a0f3478bc2957ea704a907c30b Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Tue, 19 Feb 2019 11:15:23 +0000 Subject: [PATCH] decoupled reader test=develop --- paddle/fluid/framework/reader.h | 53 +++++- paddle/fluid/operators/reader/CMakeLists.txt | 7 +- .../fluid/operators/reader/blocking_queue.h | 7 +- .../fluid/operators/reader/buffered_reader.cc | 32 ++-- .../fluid/operators/reader/buffered_reader.h | 6 +- .../fluid/operators/reader/compose_reader.cc | 39 +++++ .../fluid/operators/reader/compose_reader.h | 34 ++++ .../operators/reader/create_py_reader_op.cc | 26 +-- paddle/fluid/operators/reader/py_reader.cc | 78 +++++++++ paddle/fluid/operators/reader/py_reader.h | 62 +++++++ paddle/fluid/pybind/CMakeLists.txt | 2 +- paddle/fluid/pybind/pybind.cc | 54 +++++- paddle/fluid/pybind/reader_py.cc | 132 +++++++++++++++ paddle/fluid/pybind/reader_py.h | 25 +++ python/paddle/fluid/compiler.py | 43 +++-- python/paddle/fluid/executor.py | 4 + python/paddle/fluid/framework.py | 36 ++++ python/paddle/fluid/io.py | 4 +- python/paddle/fluid/reader.py | 141 ++++++++++++++++ .../unittests/test_decoupled_py_reader.py | 157 ++++++++++++++++++ 20 files changed, 869 insertions(+), 73 deletions(-) create mode 100644 paddle/fluid/operators/reader/compose_reader.cc create mode 100644 paddle/fluid/operators/reader/compose_reader.h create mode 100644 paddle/fluid/operators/reader/py_reader.cc create mode 100644 paddle/fluid/operators/reader/py_reader.h create mode 100644 paddle/fluid/pybind/reader_py.cc create mode 100644 paddle/fluid/pybind/reader_py.h create mode 100644 python/paddle/fluid/reader.py create mode 100644 python/paddle/fluid/tests/unittests/test_decoupled_py_reader.py diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index 82562bf883..61120dcf12 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -54,6 +54,7 @@ class ReaderBase { private: friend class DecoratedReader; + friend class MultiDecoratedReader; // These methods can be only invoked inside DecoratedReader to record the // decorating chain. void InsertDecoratedReader( @@ -62,15 +63,20 @@ class ReaderBase { std::vector> decorated_readers_; }; -class DecoratedReader : public ReaderBase, +class DecoratedReaderBase : public ReaderBase { + public: + virtual void RegisterDecorateChain() = 0; +}; + +class DecoratedReader : public DecoratedReaderBase, public std::enable_shared_from_this { public: explicit DecoratedReader(const std::shared_ptr& reader) - : ReaderBase(), reader_(reader) { + : DecoratedReaderBase(), reader_(reader) { PADDLE_ENFORCE_NOT_NULL(reader_); } - void RegisterDecorateChain() { + void RegisterDecorateChain() final { reader_->InsertDecoratedReader(shared_from_this()); } @@ -84,6 +90,41 @@ class DecoratedReader : public ReaderBase, std::shared_ptr reader_; }; +class MultiDecoratedReader + : public DecoratedReaderBase, + public std::enable_shared_from_this { + public: + explicit MultiDecoratedReader( + const std::vector>& readers) + : readers_(readers) { + PADDLE_ENFORCE(!readers_.empty()); + for (auto& r : readers_) { + PADDLE_ENFORCE_NOT_NULL(r); + } + } + + void RegisterDecorateChain() final { + for (auto& r : readers_) { + r->InsertDecoratedReader(shared_from_this()); + } + } + + protected: + void ShutdownImpl() override { + for (auto& r : readers_) { + r->Shutdown(); + } + } + + void StartImpl() override { + for (auto& r : readers_) { + r->Start(); + } + } + + std::vector> readers_; +}; + // FileReader is just a conceptual class. class FileReader : public ReaderBase {}; @@ -132,8 +173,10 @@ class ReaderHolder { }; template -inline std::shared_ptr MakeDecoratedReader(ARGS&&... args) { - std::shared_ptr reader(new T(std::forward(args)...)); +inline std::shared_ptr MakeDecoratedReader( + ARGS&&... args) { + std::shared_ptr reader( + new T(std::forward(args)...)); reader->RegisterDecorateChain(); return reader; } diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index 7c284312df..2701e10b30 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -17,7 +17,10 @@ function(reader_library TARGET_NAME) PARENT_SCOPE) endfunction() +cc_library(py_reader SRCS py_reader.cc DEPS reader) +cc_library(compose_reader SRCS compose_reader.cc DEPS reader) cc_library(buffered_reader SRCS buffered_reader.cc DEPS reader simple_threadpool) + reader_library(open_files_op SRCS open_files_op.cc DEPS buffered_reader) reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc) reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) @@ -26,7 +29,7 @@ reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_o reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc DEPS buffered_reader) reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc) reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc) -reader_library(create_py_reader_op SRCS create_py_reader_op.cc) +reader_library(create_py_reader_op SRCS create_py_reader_op.cc DEPS py_reader) if (NOT WIN32 AND NOT ON_INFER) cc_library(ctr_reader SRCS ctr_reader.cc DEPS gzstream reader zlib) @@ -38,7 +41,7 @@ cc_test(reader_blocking_queue_test SRCS reader_blocking_queue_test.cc) # Export local libraries to parent # set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) -op_library(read_op) +op_library(read_op DEPS py_reader compose_reader buffered_reader) foreach(src ${LOCAL_READER_LIBS}) set(OP_LIBRARY ${src} ${OP_LIBRARY} CACHE INTERNAL "op libs") diff --git a/paddle/fluid/operators/reader/blocking_queue.h b/paddle/fluid/operators/reader/blocking_queue.h index 51b980acb5..b76f482c57 100644 --- a/paddle/fluid/operators/reader/blocking_queue.h +++ b/paddle/fluid/operators/reader/blocking_queue.h @@ -34,7 +34,7 @@ class BlockingQueue { explicit BlockingQueue(size_t capacity, bool speed_test_mode = false) : capacity_(capacity), speed_test_mode_(speed_test_mode), closed_(false) { PADDLE_ENFORCE_GT( - capacity_, 0, + capacity_, static_cast(0), "The capacity of a reader::BlockingQueue must be greater than 0."); } @@ -114,6 +114,11 @@ class BlockingQueue { return queue_.size(); } + void Clear() { + std::lock_guard lock(mutex_); + queue_.clear(); + } + private: size_t capacity_; bool speed_test_mode_; diff --git a/paddle/fluid/operators/reader/buffered_reader.cc b/paddle/fluid/operators/reader/buffered_reader.cc index defc29b91f..b8c98ff5e7 100644 --- a/paddle/fluid/operators/reader/buffered_reader.cc +++ b/paddle/fluid/operators/reader/buffered_reader.cc @@ -28,8 +28,10 @@ BufferedReader::~BufferedReader() { #ifdef PADDLE_WITH_CUDA if (platform::is_gpu_place(place_)) { platform::SetDeviceId(boost::get(place_).device); - PADDLE_ENFORCE(cudaStreamDestroy(stream)); - for (auto &event : events) PADDLE_ENFORCE(cudaEventDestroy(event)); + PADDLE_ENFORCE(cudaStreamDestroy(stream_)); + for (auto &event : events_) { + PADDLE_ENFORCE(cudaEventDestroy(event)); + } } #endif } @@ -44,14 +46,15 @@ BufferedReader::BufferedReader( #ifdef PADDLE_WITH_CUDA if (platform::is_gpu_place(place_)) { platform::SetDeviceId(boost::get(place_).device); - compute_stream = + compute_stream_ = ((platform::CUDADeviceContext *)(platform::DeviceContextPool::Instance() .Get(place_))) ->stream(); - events.resize(buffer_size); - for (auto &event : events) + events_.resize(buffer_size); + for (auto &event : events_) { PADDLE_ENFORCE(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); - PADDLE_ENFORCE(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + } + PADDLE_ENFORCE(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); } #endif cpu_buffer_.resize(buffer_size); @@ -70,7 +73,7 @@ void BufferedReader::ReadAsync(size_t i) { #ifdef PADDLE_WITH_CUDA if (platform::is_gpu_place(place_)) { platform::SetDeviceId(boost::get(place_).device); - PADDLE_ENFORCE(cudaEventRecord(events[i], compute_stream)); + PADDLE_ENFORCE(cudaEventRecord(events_[i], compute_stream_)); } #endif position_.emplace(thread_pool_.enqueue([this, i]() -> size_t { @@ -86,7 +89,7 @@ void BufferedReader::ReadAsync(size_t i) { // TensorCopySync would block other stream if (platform::is_gpu_place(place_)) { platform::SetDeviceId(boost::get(place_).device); - PADDLE_ENFORCE(cudaStreamWaitEvent(stream, events[i], 0)); + PADDLE_ENFORCE(cudaStreamWaitEvent(stream_, events_[i], 0)); TensorVec &gpu = gpu_buffer_[i]; gpu.resize(cpu.size()); for (size_t i = 0; i < cpu.size(); ++i) { @@ -97,23 +100,24 @@ void BufferedReader::ReadAsync(size_t i) { auto gpu_ptr = gpu[i].mutable_data(place_, cpu[i].type()); auto size = cpu[i].numel() * paddle::framework::SizeOfType(cpu[i].type()); - if (platform::is_cuda_pinned_place(cpu_place)) + if (platform::is_cuda_pinned_place(cpu_place)) { memory::Copy(boost::get(place_), gpu_ptr, boost::get(cpu_place), - cpu_ptr, size, stream); - else if ((platform::is_gpu_place(cpu_place))) + cpu_ptr, size, stream_); + } else if ((platform::is_gpu_place(cpu_place))) { memory::Copy(boost::get(place_), gpu_ptr, boost::get(cpu_place), cpu_ptr, - size, stream); - else + size, stream_); + } else { // if cpu place is not pinned, async copy is slower than sync copy, // so we use sync copy instead. memory::Copy(boost::get(place_), gpu_ptr, boost::get(cpu_place), cpu_ptr, size, 0); + } gpu[i].set_lod(cpu[i].lod()); } - PADDLE_ENFORCE(cudaStreamSynchronize(stream)); + PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); } #endif return i; diff --git a/paddle/fluid/operators/reader/buffered_reader.h b/paddle/fluid/operators/reader/buffered_reader.h index 87680da01a..6b21de0949 100644 --- a/paddle/fluid/operators/reader/buffered_reader.h +++ b/paddle/fluid/operators/reader/buffered_reader.h @@ -63,9 +63,9 @@ class BufferedReader : public framework::DecoratedReader { std::vector gpu_buffer_; size_t prev_pos_{-1UL}; #ifdef PADDLE_WITH_CUDA - cudaStream_t stream; - cudaStream_t compute_stream; - std::vector events; + cudaStream_t stream_; + cudaStream_t compute_stream_; + std::vector events_; #endif }; diff --git a/paddle/fluid/operators/reader/compose_reader.cc b/paddle/fluid/operators/reader/compose_reader.cc new file mode 100644 index 0000000000..4b88b9331c --- /dev/null +++ b/paddle/fluid/operators/reader/compose_reader.cc @@ -0,0 +1,39 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reader/compose_reader.h" + +namespace paddle { +namespace operators { +namespace reader { + +ComposeReader::ComposeReader( + const std::vector> &readers) + : framework::MultiDecoratedReader(readers) {} + +void ComposeReader::ReadNext(std::vector *out) { + out->clear(); + std::vector each_ret; + for (auto &r : readers_) { + r->ReadNext(&each_ret); + out->reserve(out->size() + each_ret.size()); + for (auto &data : each_ret) { + out->emplace_back(std::move(data)); + } + } +} + +} // namespace reader +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reader/compose_reader.h b/paddle/fluid/operators/reader/compose_reader.h new file mode 100644 index 0000000000..c9e2a2d72f --- /dev/null +++ b/paddle/fluid/operators/reader/compose_reader.h @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/reader.h" + +namespace paddle { +namespace operators { +namespace reader { + +class ComposeReader : public framework::MultiDecoratedReader { + public: + explicit ComposeReader( + const std::vector> &readers); + + void ReadNext(std::vector *out) override; +}; + +} // namespace reader +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reader/create_py_reader_op.cc b/paddle/fluid/operators/reader/create_py_reader_op.cc index 901a92ab5b..4a6581bbbd 100644 --- a/paddle/fluid/operators/reader/create_py_reader_op.cc +++ b/paddle/fluid/operators/reader/create_py_reader_op.cc @@ -12,37 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" +#include "paddle/fluid/operators/reader/py_reader.h" #include "paddle/fluid/operators/reader/reader_op_registry.h" namespace paddle { namespace operators { namespace reader { -class PyReader : public framework::FileReader { - public: - explicit PyReader(const std::shared_ptr& queue) - : framework::FileReader() { - PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); - queue_ = queue; - } - - void ReadNext(std::vector* out) override { - bool success; - *out = queue_->Pop(&success); - if (!success) out->clear(); - } - - ~PyReader() { queue_->Close(); } - - void Shutdown() override { queue_->Close(); } - - void Start() override { queue_->ReOpen(); } - - private: - std::shared_ptr queue_; -}; - class CreatePyReaderOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; diff --git a/paddle/fluid/operators/reader/py_reader.cc b/paddle/fluid/operators/reader/py_reader.cc new file mode 100644 index 0000000000..dc84faa974 --- /dev/null +++ b/paddle/fluid/operators/reader/py_reader.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reader/py_reader.h" + +namespace paddle { +namespace operators { +namespace reader { + +PyReader::PyReader(const std::shared_ptr& queue) + : framework::FileReader() { + PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); + queue_ = queue; +} + +void PyReader::ReadNext(std::vector* out) { + bool success; + *out = queue_->Pop(&success); + if (!success) out->clear(); +} + +PyReader::~PyReader() { queue_->Close(); } + +void PyReader::Shutdown() { queue_->Close(); } + +void PyReader::Start() { queue_->ReOpen(); } + +MultiQueuePyReader::MultiQueuePyReader( + const std::vector>& queues) + : queues_(queues) { + PADDLE_ENFORCE(!queues_.empty()); + for (auto& q : queues_) { + PADDLE_ENFORCE_NOT_NULL(q); + } +} + +void MultiQueuePyReader::ReadNext(std::vector* out) { + auto idx = read_out_idx_.fetch_add(1) % queues_.size(); + for (size_t i = 0; i < queues_.size(); ++i) { + *out = queues_[idx]->Pop(); + if (!out->empty()) return; + idx = (idx + 1) % queues_.size(); + } +} + +MultiQueuePyReader::~MultiQueuePyReader() { + for (auto& q : queues_) { + q->Close(); + } +} + +void MultiQueuePyReader::Shutdown() { + for (auto& q : queues_) { + q->Close(); + } + read_out_idx_.store(0, std::memory_order::memory_order_seq_cst); +} + +void MultiQueuePyReader::Start() { + for (auto& q : queues_) { + q->ReOpen(); + } +} + +} // namespace reader +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reader/py_reader.h b/paddle/fluid/operators/reader/py_reader.h new file mode 100644 index 0000000000..146a2351e5 --- /dev/null +++ b/paddle/fluid/operators/reader/py_reader.h @@ -0,0 +1,62 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/framework/reader.h" +#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" + +namespace paddle { +namespace operators { +namespace reader { + +class PyReader : public framework::FileReader { + public: + explicit PyReader(const std::shared_ptr& queue); + + void ReadNext(std::vector* out) override; + + ~PyReader(); + + void Shutdown() override; + + void Start() override; + + private: + std::shared_ptr queue_; +}; + +class MultiQueuePyReader : public framework::FileReader { + public: + explicit MultiQueuePyReader( + const std::vector>& queues); + + void ReadNext(std::vector* out) override; + + ~MultiQueuePyReader(); + + void Shutdown() override; + + void Start() override; + + private: + std::vector> queues_; + std::atomic read_out_idx_{0}; +}; + +} // namespace reader +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 4ac5b83c56..84da8491d3 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -5,7 +5,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune if(WITH_PYTHON) list(APPEND PYBIND_DEPS py_func_op) endif() -set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc imperative.cc ir.cc inference_api.cc) +set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc reader_py.cc async_executor_py.cc imperative.cc ir.cc inference_api.cc) if(WITH_PYTHON) if(WITH_AMD_GPU) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index a4a01ad647..f2000cc45e 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -54,6 +54,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/ir.h" #include "paddle/fluid/pybind/protobuf.h" #include "paddle/fluid/pybind/pybind.h" // NOLINT +#include "paddle/fluid/pybind/reader_py.h" #include "paddle/fluid/pybind/recordio.h" #include "paddle/fluid/pybind/tensor_py.h" @@ -106,6 +107,16 @@ bool IsCompiledWithDIST() { #endif } +template +static inline bool IsSamePlace(const PlaceType1 &p1, const PlaceType2 &p2) { + return paddle::platform::Place(p1) == paddle::platform::Place(p2); +} + +template +static inline int PlaceIndex(const PlaceType &p) { + return static_cast(paddle::platform::Place(p).which()); +} + PYBIND11_MODULE(core, m) { // Not used, just make sure cpu_info.cc is linked. paddle::platform::CpuTotalPhysicalMemory(); @@ -452,6 +463,7 @@ PYBIND11_MODULE(core, m) { All parameter, weight, gradient are variables in Paddle. )DOC") + .def(py::init<>()) .def("is_int", [](const Variable &var) { return var.IsType(); }) .def("set_int", [](Variable &var, int val) -> void { *var.GetMutable() = val; }) @@ -493,9 +505,7 @@ All parameter, weight, gradient are variables in Paddle. }, py::return_value_policy::reference); - py::class_(m, "Reader", "") - .def("start", &framework::ReaderHolder::Start) - .def("reset", &framework::ReaderHolder::ResetAll); + BindReader(&m); using LoDTensorBlockingQueue = ::paddle::operators::reader::LoDTensorBlockingQueue; @@ -657,29 +667,65 @@ All parameter, weight, gradient are variables in Paddle. PADDLE_THROW("Cannot use CUDAPlace in CPU only version"); #endif }) + .def("_type", &PlaceIndex) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", + &IsSamePlace) + .def("gpu_device_id", + [](platform::CUDAPlace &self) { return self.device; }) .def("__str__", string::to_string); py::class_(m, "CPUPlace") .def(py::init<>()) + .def("_type", &PlaceIndex) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", + &IsSamePlace) .def("__str__", string::to_string); py::class_(m, "CUDAPinnedPlace") .def("__init__", - [](platform::CUDAPinnedPlace &) { + [](platform::CUDAPinnedPlace &self) { #ifndef PADDLE_WITH_CUDA PADDLE_THROW("Cannot use CUDAPinnedPlace in CPU only version"); #endif + new (&self) platform::CUDAPinnedPlace(); }) + .def("_type", &PlaceIndex) + .def("_equals", &IsSamePlace) + .def("_equals", + &IsSamePlace) + .def("_equals", + &IsSamePlace) + .def("_equals", + &IsSamePlace) .def("__str__", string::to_string); py::class_(m, "Place") .def(py::init<>()) + .def("_type", &PlaceIndex) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) .def("is_gpu_place", [](platform::Place &self) { return platform::is_gpu_place(self); }) + .def("is_cpu_place", + [](platform::Place &self) { return platform::is_cpu_place(self); }) + .def("is_cuda_pinned_place", + [](platform::Place &self) { + return platform::is_cuda_pinned_place(self); + }) .def("gpu_device_id", [](platform::Place &self) { return boost::get(self).device; }) + .def("set_place", [](platform::Place &self, + const platform::Place &other) { self = other; }) .def("set_place", [](platform::Place &self, const platform::CPUPlace &cpu_place) { self = cpu_place; diff --git a/paddle/fluid/pybind/reader_py.cc b/paddle/fluid/pybind/reader_py.cc new file mode 100644 index 0000000000..a09d18656f --- /dev/null +++ b/paddle/fluid/pybind/reader_py.cc @@ -0,0 +1,132 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pybind/reader_py.h" +#include +#include +#include "paddle/fluid/framework/reader.h" +#include "paddle/fluid/operators/reader/buffered_reader.h" +#include "paddle/fluid/operators/reader/compose_reader.h" +#include "paddle/fluid/operators/reader/py_reader.h" +#include "paddle/fluid/platform/place.h" +#include "pybind11/stl.h" + +namespace paddle { +namespace pybind { + +class FeedReader { + using ResultDictList = + std::vector>; + + public: + FeedReader(std::unique_ptr reader, + const std::vector &names, size_t num_places, + bool drop_last = true) + : reader_(std::move(reader)), + names_(names), + num_places_(num_places), + drop_last_(drop_last) {} + + ResultDictList ReadNext() { + std::vector tensors; + reader_->ReadNext(&tensors); + if (tensors.empty()) return ResultDictList(); + + PADDLE_ENFORCE(tensors.size() % names_.size() == 0, + "Tensor size: %d, names size: %d", tensors.size(), + names_.size()); + + size_t read_place_num = tensors.size() / names_.size(); + + if (drop_last_ && read_place_num != num_places_) { + return ResultDictList(); + } + + ResultDictList ret(read_place_num); + for (size_t i = 0; i < tensors.size(); ++i) { + ret[i / names_.size()].emplace(names_[i % names_.size()], + std::move(tensors[i])); + } + return ret; + } + + void Start() { reader_->Start(); } + + void Reset() { reader_->ResetAll(); } + + private: + std::unique_ptr reader_; + std::vector names_; + size_t num_places_; + bool drop_last_; +}; + +static std::unique_ptr CreatePyReader( + const std::vector< + std::shared_ptr> &queues, + const std::vector &dst_places) { + std::shared_ptr reader; + if (queues.size() == 1) { + reader.reset(new operators::reader::PyReader(queues[0])); + } else { + reader.reset(new operators::reader::MultiQueuePyReader(queues)); + } + std::vector> buffered_reader; + buffered_reader.reserve(dst_places.size()); + for (auto &p : dst_places) { + buffered_reader.emplace_back( + framework::MakeDecoratedReader( + reader, p, 2)); + } + reader = framework::MakeDecoratedReader( + buffered_reader); + + auto *holder = new framework::ReaderHolder(); + holder->Reset(reader); + return std::unique_ptr(holder); +} + +namespace py = pybind11; + +void BindReader(py::module *module) { + auto &m = *module; + + namespace reader = ::paddle::operators::reader; + + py::class_(m, "Reader", "") + .def("start", &framework::ReaderHolder::Start) + .def("reset", &framework::ReaderHolder::ResetAll); + + py::class_(m, "FeedReader", "") + .def("read_next", &FeedReader::ReadNext, + py::call_guard()) + .def("start", &FeedReader::Start, + py::call_guard()) + .def("reset", &FeedReader::Reset, + py::call_guard()); + + m.def("create_py_reader", + [](const std::vector< + std::shared_ptr> + queues, + const std::vector &names, + const std::vector &dst_places, bool drop_last) { + return new FeedReader(CreatePyReader(queues, dst_places), names, + dst_places.size(), drop_last); + }, + py::return_value_policy::take_ownership); +} + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/reader_py.h b/paddle/fluid/pybind/reader_py.h new file mode 100644 index 0000000000..472ff65368 --- /dev/null +++ b/paddle/fluid/pybind/reader_py.h @@ -0,0 +1,25 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "pybind11/pybind11.h" + +namespace paddle { +namespace pybind { + +void BindReader(pybind11::module *module); + +} // namespace pybind +} // namespace paddle diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py index ef02429428..523894af7b 100644 --- a/python/paddle/fluid/compiler.py +++ b/python/paddle/fluid/compiler.py @@ -17,6 +17,7 @@ import os import six import sys from .. import compat as cpt +from .framework import cuda_places, cpu_places from . import core @@ -78,7 +79,8 @@ class CompiledProgram(object): loss_name=None, build_strategy=None, exec_strategy=None, - share_vars_from=None): + share_vars_from=None, + places=None): """Configs the program to run in data parallel way. Args: @@ -97,6 +99,12 @@ class CompiledProgram(object): will share variables from `share_vars_from`. `share_vars_from` must be run by the executor before this CompiledProgram so that vars are ready. + places(list(CUDAPlace)|list(CPUPlace)|None): If provide, only compile + program in the given places. Otherwise, the places used when compiled + is determined by the Executor, and the places used are controlled + by environment variables: FLAGS_selected_gpus or CUDA_VISIBLE_DEVICES + if using GPU; or CPU_NUM if using CPU. + Returns: self """ @@ -110,6 +118,12 @@ class CompiledProgram(object): self._exec_strategy = ExecutionStrategy() if self._build_strategy is None: self._build_strategy = BuildStrategy() + if places is not None: + if not isinstance(places, (list, tuple)): + places = [places] + self._places = [_place_obj(p) for p in places] + else: + self._places = None return self def with_inference_optimize(self, config): @@ -148,19 +162,16 @@ class CompiledProgram(object): self._local_scopes = [] self._exec_strategy.use_cuda = isinstance(self._place, core.CUDAPlace) - if self._exec_strategy.use_cuda: - gpus_env = os.getenv("FLAGS_selected_gpus") - if gpus_env: - gpus = [int(s) for s in gpus_env.split(",")] - else: - gpus = [ - i for i in six.moves.range(core.get_cuda_device_count()) - ] - self._places = [core.CUDAPlace(i) for i in gpus] + has_set_place = (self._places is not None) + if has_set_place: + desire_place = _place_obj(self._place) + for p in self._places: + assert p._type() == desire_place._type(), \ + "Place type not match. You may set the wrong type of places" else: - cpu_num = int( - os.environ.get('CPU_NUM', multiprocessing.cpu_count())) - self._places = [core.CPUPlace() for _ in six.moves.range(cpu_num)] + places = cuda_places( + ) if self._exec_strategy.use_cuda else cpu_places() + self._places = [_place_obj(p) for p in places] assert self._places, "no place for execution" if self._exec_strategy.num_threads == 0: @@ -169,9 +180,7 @@ class CompiledProgram(object): # performance. Worth tunning for other models in the future. self._exec_strategy.num_threads = len(self._places) * 4 else: - cpu_num = int( - os.environ.get('CPU_NUM', multiprocessing.cpu_count())) - self._exec_strategy.num_threads = cpu_num * 2 + self._exec_strategy.num_threads = len(self._places) * 2 trainers_endpoints = self._program._trainers_endpoints @@ -217,7 +226,7 @@ class CompiledProgram(object): if self._compiled: if scope and self._scope != scope: raise ValueError("Cannot compile with different scope") - if place and self._place != place: + if place and not self._place._equals(place): raise ValueError("Cannot compile with different place") return self self._compiled = True diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 8815911eae..5454d12e2c 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -554,6 +554,10 @@ class Executor(object): if feed is None: feed = {} + elif isinstance(feed, (list, tuple)): + assert len(feed) == 1, "Not compiled with data parallel" + feed = feed[0] + if not isinstance(feed, dict): raise TypeError( "feed requires dict as its Parameter. But you passed in %s" % diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index ef304b1110..deb837d96c 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -26,6 +26,7 @@ import six import numpy as np import subprocess +import multiprocessing from .. import compat as cpt from .proto import framework_pb2 @@ -63,6 +64,9 @@ __all__ = [ 'default_main_program', 'program_guard', 'name_scope', + 'cuda_places', + 'cpu_places', + 'cuda_pinned_places', ] EMPTY_VAR_NAME = core.kEmptyVarName() @@ -87,6 +91,38 @@ def _current_expected_place(): return _imperative_current_expected_place_ +def _cpu_num(): + return int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) + + +def cuda_places(device_ids=None): + assert core.is_compiled_with_cuda(), \ + "Not compiled with CUDA" + if device_ids is None: + gpus_env = os.getenv("FLAGS_selected_gpus") + if gpus_env: + device_ids = [int(s) for s in gpus_env.split(",")] + else: + device_ids = six.moves.range(core.get_cuda_device_count()) + elif not isinstance(device_ids, (list, tuple)): + device_ids = [device_ids] + return [core.CUDAPlace(dev_id) for dev_id in device_ids] + + +def cpu_places(device_count=None): + if device_count is None: + device_count = _cpu_num() + return [core.CPUPlace()] * device_count + + +def cuda_pinned_places(device_count=None): + assert core.is_compiled_with_cuda(), \ + "Not compiled with CUDA" + if device_count is None: + device_count = _cpu_num() + return [core.cuda_pinned_places()] * device_count + + class NameScope(object): def __init__(self, name="", parent=None): self._children = dict() diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index a2abbf36c0..1e3f4f476f 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -26,12 +26,14 @@ from paddle.fluid import layers from paddle.fluid.executor import Executor from paddle.fluid.evaluator import Evaluator from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable, program_guard +from . import reader +from .reader import * from . import core __all__ = [ 'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params', 'load_persistables', 'save_inference_model', 'load_inference_model' -] +] + reader.__all__ def is_parameter(var): diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py new file mode 100644 index 0000000000..b765430622 --- /dev/null +++ b/python/paddle/fluid/reader.py @@ -0,0 +1,141 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import core +import six +import threading +from .framework import Program, Variable, program_guard +from .data_feeder import DataFeeder + +__all__ = ['PyReader'] + + +def _convert_places(places): + if not isinstance(places, (list, tuple)): + places = [places] + + ret = [] + for p in places: + if not isinstance(p, core.Place): + tmp = core.Place() + tmp.set_place(p) + p = tmp + + ret.append(p) + return ret + + +class PyReader(object): + def __init__(self, feed_list, places, capacity, multi_queue=True): + self._tensor_reader = None + self._thread = None + + # TODO(zjl): to support drop_last = False + self._drop_last = True + + self._feed_list = feed_list + self._var_names = [v.name for v in feed_list] + + self._queues = [] + + self._places = _convert_places(places) + + self._queue_capacity = capacity + + queue_num = len(self._places) if multi_queue else 1 + for _ in six.moves.range(queue_num): + self._queues.append( + core.init_lod_tensor_blocking_queue(core.Variable(), + self._queue_capacity)) + + self._reader = core.create_py_reader(self._queues, self._var_names, + self._places, self._drop_last) + self._exited = True + + def __call__(self): + assert self._tensor_reader is not None, \ + "Data source of PyReader has not set yet" + + class Iterator(object): + def __init__(self, reader): + self._reader = reader + + def __iter__(self): + return self + + def next(self): + ret = self._reader._reader.read_next() + if len(ret): + return ret + else: + self._reader._restart_reader() + self._reader._reader.reset() + raise StopIteration + + return Iterator(self) + + def _restart_reader(self): + if not self._exited: + for q in self._queues: + q.close() + + self._thread.join() + + def __thread_main__(): + queue_num = len(self._queues) + idx = 0 + for tensors in self._tensor_reader(): + array = core.LoDTensorArray() + for item in tensors: + if not isinstance(item, core.LoDTensor): + tmp = core.LoDTensor() + tmp.set(item, core.CPUPlace()) + item = tmp + + array.append(item) + + if not self._queues[idx].push(array): + break + + idx = (idx + 1) % queue_num + + for q in self._queues: + q.close() + + self._exited = True + + self._thread = threading.Thread(target=__thread_main__) + self._thread.daemon = True + self._exited = False + self._thread.start() + + def set_numpy_reader(self, reader): + assert self._tensor_reader is None, \ + "Cannot reset the data source of PyReader" + with program_guard(Program(), Program()): + feeder = DataFeeder( + feed_list=self._feed_list, place=core.CPUPlace()) + paddle_reader = feeder.decorate_reader(reader, multi_devices=False) + + def __tensor_reader_impl__(): + for slots in paddle_reader(): + yield [slots[var.name] for var in self._feed_list] + + self.set_tensor_reader(__tensor_reader_impl__) + + def set_tensor_reader(self, reader): + assert self._tensor_reader is None, \ + "Cannot reset the data source of PyReader" + self._tensor_reader = reader + self._restart_reader() diff --git a/python/paddle/fluid/tests/unittests/test_decoupled_py_reader.py b/python/paddle/fluid/tests/unittests/test_decoupled_py_reader.py new file mode 100644 index 0000000000..807cbaf39d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_decoupled_py_reader.py @@ -0,0 +1,157 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.fluid as fluid +import numpy as np +import time +import six +import unittest + +EPOCH_NUM = 60 +BATCH_SIZE = 32 +CLASS_NUM = 10 + + +def random_reader(): + for i in range(BATCH_SIZE * 40): + image = np.random.random([784]) + label = np.random.random_integers(low=0, high=CLASS_NUM - 1) + yield image, label + + +def simple_fc_net(places, use_legacy_py_reader): + startup_prog = fluid.Program() + main_prog = fluid.Program() + startup_prog.random_seed = 1 + main_prog.random_seed = 1 + reader = paddle.batch(random_reader, batch_size=BATCH_SIZE) + + with fluid.unique_name.guard(): + with fluid.program_guard(main_prog, startup_prog): + if not use_legacy_py_reader: + image = fluid.layers.data( + name='image', shape=[784], dtype='float32') + label = fluid.layers.data( + name='label', shape=[1], dtype='int64') + py_reader = fluid.io.PyReader( + feed_list=[image, label], + places=places, + capacity=4, + multi_queue=False) + py_reader.set_numpy_reader(reader) + else: + py_reader = fluid.layers.py_reader( + capacity=4, + shapes=[(-1, 784), (-1, 1)], + dtypes=['float32', 'int64']) + image, label = fluid.layers.read_file(py_reader) + py_reader.decorate_paddle_reader(reader) + + hidden = image + for hidden_size in [10, 20, 30]: + hidden = fluid.layers.fc( + hidden, + size=hidden_size, + act='tanh', + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=1.0))) + + predict_label = fluid.layers.fc(hidden, + size=CLASS_NUM, + act='softmax') + loss = fluid.layers.mean( + fluid.layers.cross_entropy( + input=predict_label, label=label)) + + optimizer = fluid.optimizer.Adam() + optimizer.minimize(loss) + return startup_prog, main_prog, py_reader, loss + + +class TestBase(unittest.TestCase): + def run_main(self, use_legacy_py_reader, with_data_parallel, places): + with fluid.scope_guard(fluid.Scope()): + startup_prog, main_prog, py_reader, loss = simple_fc_net( + places, use_legacy_py_reader) + exe = fluid.Executor(place=places[0]) + exe.run(startup_prog) + + prog = fluid.CompiledProgram(main_prog) + if with_data_parallel: + prog = prog.with_data_parallel( + loss_name=loss.name, places=places) + + step = 0 + start_t = time.time() + if use_legacy_py_reader: + for _ in six.moves.range(EPOCH_NUM): + py_reader.start() + while True: + try: + L, = exe.run(program=prog, fetch_list=[loss]) + step += 1 + except fluid.core.EOFException: + py_reader.reset() + break + else: + for _ in six.moves.range(EPOCH_NUM): + for d in py_reader(): + ''' + assert len(d) == len(places) + for i, item in enumerate(d): + image = item['image'] + label = item['label'] + assert image.shape() == [BATCH_SIZE, 784] + assert label.shape() == [BATCH_SIZE, 1] + assert image._place()._equals(places[i]) + assert label._place()._equals(places[i]) + ''' + L, = exe.run(program=prog, feed=d, fetch_list=[loss]) + step += 1 + end_t = time.time() + return {"time": end_t - start_t, "step": step} + + def prepare_places(self, with_data_parallel): + places = [[fluid.CPUPlace()], ] + if with_data_parallel: + places.append([fluid.CPUPlace()] * 2) + + if fluid.core.is_compiled_with_cuda(): + tmp = fluid.cuda_places() + assert len(tmp) > 0, "no gpu detected" + if with_data_parallel: + places.append(tmp) + places.append([tmp[0]]) + return places + + def test_main(self): + for with_data_parallel in [True, False]: + for p in self.prepare_places(with_data_parallel): + t = [] + for use_legacy_py_reader in [False, True]: + ret = self.run_main( + use_legacy_py_reader=use_legacy_py_reader, + with_data_parallel=with_data_parallel, + places=p) + ret['legacy'] = use_legacy_py_reader + ret['data_parallel'] = with_data_parallel + ret['places'] = p + t.append(ret) + + print(t) + + +if __name__ == '__main__': + unittest.main() -- GitLab