diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index cfe6730e0ca96020932880fd11b292469349fdf4..032da0cad85ce43ab2630123f9f2cfd8dee4224e 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -10,6 +10,9 @@ paddle.fluid.default_startup_program (ArgSpec(args=[], varargs=None, keywords=No paddle.fluid.default_main_program (ArgSpec(args=[], varargs=None, keywords=None, defaults=None), ('document', '5430f54ab4895f9f47db6bebbaf71659')) paddle.fluid.program_guard (ArgSpec(args=['main_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,)), ('document', 'b54f403e57825a1592aece03afe3afb6')) paddle.fluid.name_scope (ArgSpec(args=['prefix'], varargs=None, keywords=None, defaults=(None,)), ('document', '0ef753f5cec69fef9ae6ad8b867b33a2')) +paddle.fluid.cuda_places (ArgSpec(args=['device_ids'], varargs=None, keywords=None, defaults=(None,)), ('document', '7d9a51fc9cf3c5245b5227080a8064c3')) +paddle.fluid.cpu_places (ArgSpec(args=['device_count'], varargs=None, keywords=None, defaults=(None,)), ('document', '4c0cd83f0b401fc2ff84c70974e5d210')) +paddle.fluid.cuda_pinned_places (ArgSpec(args=['device_count'], varargs=None, keywords=None, defaults=(None,)), ('document', 'd0c3ebd813c39958c92b78e3eef7e912')) paddle.fluid.Executor.__init__ (ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.Executor.close (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'f5369953dd0c443961cf79f7a00e1a03')) paddle.fluid.Executor.run (ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False)), ('document', 'f482e93b38b4018796969a2e1dde479d')) @@ -44,7 +47,7 @@ paddle.fluid.AsyncExecutor.run (ArgSpec(args=['self', 'program', 'data_feed', 'f 
paddle.fluid.AsyncExecutor.save_model (ArgSpec(args=['self', 'save_path'], varargs=None, keywords=None, defaults=None), ('document', 'c8ac0dfcb3b187aba25d03af7fea56b2')) paddle.fluid.AsyncExecutor.stop (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '5f23d043607bb5d55e466ec3f578e093')) paddle.fluid.CompiledProgram.__init__ (ArgSpec(args=['self', 'program_or_graph'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.CompiledProgram.with_data_parallel (ArgSpec(args=['self', 'loss_name', 'build_strategy', 'exec_strategy', 'share_vars_from'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'e1af7fd53cf868554f312779fc803864')) +paddle.fluid.CompiledProgram.with_data_parallel (ArgSpec(args=['self', 'loss_name', 'build_strategy', 'exec_strategy', 'share_vars_from', 'places'], varargs=None, keywords=None, defaults=(None, None, None, None, None)), ('document', 'a8c7793803cf976680d9478e378fa356')) paddle.fluid.CompiledProgram.with_inference_optimize (ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=None), ('document', '9e5b009d850191a010e859189c127fd8')) paddle.fluid.ExecutionStrategy.__init__ __init__(self: paddle.fluid.core.ParallelExecutor.ExecutionStrategy) -> None paddle.fluid.BuildStrategy.GradientScaleStrategy.__init__ __init__(self: paddle.fluid.core.ParallelExecutor.BuildStrategy.GradientScaleStrategy, arg0: int) -> None @@ -58,6 +61,12 @@ paddle.fluid.io.load_params (ArgSpec(args=['executor', 'dirname', 'main_program' paddle.fluid.io.load_persistables (ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None)), ('document', '28df5bfe26ca7a077f91156abb0fe6d2')) paddle.fluid.io.save_inference_model (ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename', 'export_for_deployment'], varargs=None, 
keywords=None, defaults=(None, None, None, True)), ('document', '70f4f53f13572436ac72d1c8b5efeb9d')) paddle.fluid.io.load_inference_model (ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename', 'pserver_endpoints'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '7a5255386075dac3c75b7058254fcdcb')) +paddle.fluid.io.PyReader.__init__ (ArgSpec(args=['self', 'feed_list', 'capacity', 'use_double_buffer', 'iterable'], varargs=None, keywords=None, defaults=(True, False)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) +paddle.fluid.io.PyReader.decorate_batch_generator (ArgSpec(args=['self', 'reader', 'places'], varargs=None, keywords=None, defaults=(None,)), ('document', 'a3fefec8bacd6ce83f49906a9d05e779')) +paddle.fluid.io.PyReader.decorate_sample_generator (ArgSpec(args=['self', 'sample_generator', 'batch_size', 'drop_last', 'places'], varargs=None, keywords=None, defaults=(True, None)), ('document', '7abd9cf7d695bab5bb6cf7ded5903cb2')) +paddle.fluid.io.PyReader.decorate_sample_list_generator (ArgSpec(args=['self', 'reader', 'places'], varargs=None, keywords=None, defaults=(None,)), ('document', 'faef298f73e91aedcfaf5d184f3109b7')) +paddle.fluid.io.PyReader.reset (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'ff1cc1e2beb8824d453656c72c28ddfb')) +paddle.fluid.io.PyReader.start (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'b7ea0a548991924e4cfe61a577b8e56d')) paddle.fluid.initializer.ConstantInitializer.__init__ (ArgSpec(args=['self', 'value', 'force_cpu'], varargs=None, keywords=None, defaults=(0.0, False)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.initializer.UniformInitializer.__init__ (ArgSpec(args=['self', 'low', 'high', 'seed'], varargs=None, keywords=None, defaults=(-1.0, 1.0, 0)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.initializer.NormalInitializer.__init__ (ArgSpec(args=['self', 'loc', 'scale', 
'seed'], varargs=None, keywords=None, defaults=(0.0, 1.0, 0)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) @@ -230,7 +239,7 @@ paddle.fluid.layers.shuffle (ArgSpec(args=['reader', 'buffer_size'], varargs=Non paddle.fluid.layers.batch (ArgSpec(args=['reader', 'batch_size'], varargs=None, keywords=None, defaults=None), ('document', 'f563d376d35e1a4c4db100fd11b381a0')) paddle.fluid.layers.double_buffer (ArgSpec(args=['reader', 'place', 'name'], varargs=None, keywords=None, defaults=(None, None)), ('document', '07e5b796674796eb1ef3fee9c10d24e3')) paddle.fluid.layers.random_data_generator (ArgSpec(args=['low', 'high', 'shapes', 'lod_levels', 'for_parallel'], varargs=None, keywords=None, defaults=(True,)), ('document', '9b7f0f86ec24bbc97643cadcb6499cff')) -paddle.fluid.layers.py_reader (ArgSpec(args=['capacity', 'shapes', 'dtypes', 'lod_levels', 'name', 'use_double_buffer'], varargs=None, keywords=None, defaults=(None, None, True)), ('document', '13dabc57863f62ab3141586784ee356b')) +paddle.fluid.layers.py_reader (ArgSpec(args=['capacity', 'shapes', 'dtypes', 'lod_levels', 'name', 'use_double_buffer'], varargs=None, keywords=None, defaults=(None, None, True)), ('document', '4357643685cfd65454ba5a15f0151709')) paddle.fluid.layers.create_py_reader_by_data (ArgSpec(args=['capacity', 'feed_list', 'name', 'use_double_buffer'], varargs=None, keywords=None, defaults=(None, True)), ('document', '350f74d93fab9adb2ac4950f1c26416b')) paddle.fluid.layers.Preprocessor.__init__ (ArgSpec(args=['self', 'reader', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.layers.Preprocessor.block (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) @@ -511,6 +520,7 @@ paddle.fluid.unique_name.guard (ArgSpec(args=['new_generator'], varargs=None, ke paddle.fluid.recordio_writer.convert_reader_to_recordio_file (ArgSpec(args=['filename', 'reader_creator', 
'feeder', 'compressor', 'max_num_records', 'feed_order'], varargs=None, keywords=None, defaults=(Compressor.Snappy, 1000, None)), ('document', '65c7523e86f0c50bb729b01667f36310')) paddle.fluid.recordio_writer.convert_reader_to_recordio_files (ArgSpec(args=['filename', 'batch_per_file', 'reader_creator', 'feeder', 'compressor', 'max_num_records', 'feed_order'], varargs=None, keywords=None, defaults=(Compressor.Snappy, 1000, None)), ('document', 'bc643f0f5f1b9db57ff0d8a57d379bd7')) paddle.fluid.Scope Scope() -> paddle.fluid.core._Scope +paddle.reader.cache (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', '1676886070eb607cb608f7ba47be0d3c')) paddle.reader.map_readers (ArgSpec(args=['func'], varargs='readers', keywords=None, defaults=None), ('document', '77cbadb09df588e21e5cc0819b69c87d')) paddle.reader.buffered (ArgSpec(args=['reader', 'size'], varargs=None, keywords=None, defaults=None), ('document', '0d6186f109feceb99f60ec50a0a624cb')) paddle.reader.compose (ArgSpec(args=[], varargs='readers', keywords='kwargs', defaults=None), ('document', '884291104e1c3f37f33aae44b7deeb0d')) diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index 7c284312df912ad758f6fffc44f111dfe765feb8..5ee1206175600cd668ccbbf5b98053708a4406d3 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -17,7 +17,9 @@ function(reader_library TARGET_NAME) PARENT_SCOPE) endfunction() +cc_library(py_reader SRCS py_reader.cc DEPS reader) cc_library(buffered_reader SRCS buffered_reader.cc DEPS reader simple_threadpool) + reader_library(open_files_op SRCS open_files_op.cc DEPS buffered_reader) reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc) reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) @@ -26,7 +28,7 @@ reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_o 
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc DEPS buffered_reader) reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc) reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc) -reader_library(create_py_reader_op SRCS create_py_reader_op.cc) +reader_library(create_py_reader_op SRCS create_py_reader_op.cc DEPS py_reader) if (NOT WIN32 AND NOT ON_INFER) cc_library(ctr_reader SRCS ctr_reader.cc DEPS gzstream reader zlib) @@ -38,7 +40,7 @@ cc_test(reader_blocking_queue_test SRCS reader_blocking_queue_test.cc) # Export local libraries to parent # set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) -op_library(read_op) +op_library(read_op DEPS py_reader buffered_reader) foreach(src ${LOCAL_READER_LIBS}) set(OP_LIBRARY ${src} ${OP_LIBRARY} CACHE INTERNAL "op libs") diff --git a/paddle/fluid/operators/reader/blocking_queue.h b/paddle/fluid/operators/reader/blocking_queue.h index 51b980acb5a08d431d96a3a92479dec09119c27e..78d238aa6115265023d5d87c01048a87180448d0 100644 --- a/paddle/fluid/operators/reader/blocking_queue.h +++ b/paddle/fluid/operators/reader/blocking_queue.h @@ -16,6 +16,7 @@ #include // NOLINT #include +#include #include "paddle/fluid/platform/enforce.h" @@ -34,7 +35,7 @@ class BlockingQueue { explicit BlockingQueue(size_t capacity, bool speed_test_mode = false) : capacity_(capacity), speed_test_mode_(speed_test_mode), closed_(false) { PADDLE_ENFORCE_GT( - capacity_, 0, + capacity_, static_cast(0), "The capacity of a reader::BlockingQueue must be greater than 0."); } diff --git a/paddle/fluid/operators/reader/buffered_reader.cc b/paddle/fluid/operators/reader/buffered_reader.cc index 134807092d59329ce93381da67a98b8230db5767..c24e9aedc4ebd8f4fa9e483b1c1cc71fe0bf0aa7 100644 --- a/paddle/fluid/operators/reader/buffered_reader.cc +++ b/paddle/fluid/operators/reader/buffered_reader.cc @@ -30,8 +30,10 @@ BufferedReader::~BufferedReader() { #ifdef PADDLE_WITH_CUDA if 
(platform::is_gpu_place(place_)) { platform::SetDeviceId(boost::get(place_).device); - PADDLE_ENFORCE(cudaStreamDestroy(stream)); - for (auto &event : events) PADDLE_ENFORCE(cudaEventDestroy(event)); + PADDLE_ENFORCE(cudaStreamDestroy(stream_)); + for (auto &event : events_) { + PADDLE_ENFORCE(cudaEventDestroy(event)); + } } #endif } @@ -46,15 +48,15 @@ BufferedReader::BufferedReader( #ifdef PADDLE_WITH_CUDA if (platform::is_gpu_place(place_)) { platform::SetDeviceId(boost::get(place_).device); - compute_stream = + compute_stream_ = ((platform::CUDADeviceContext *)(platform::DeviceContextPool::Instance() .Get(place_))) ->stream(); - events.resize(buffer_size); - for (auto &event : events) { + events_.resize(buffer_size); + for (auto &event : events_) { PADDLE_ENFORCE(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); } - PADDLE_ENFORCE(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + PADDLE_ENFORCE(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); } #endif cpu_buffer_.resize(buffer_size); @@ -73,7 +75,7 @@ void BufferedReader::ReadAsync(size_t i) { #ifdef PADDLE_WITH_CUDA if (platform::is_gpu_place(place_)) { platform::SetDeviceId(boost::get(place_).device); - PADDLE_ENFORCE(cudaEventRecord(events[i], compute_stream)); + PADDLE_ENFORCE(cudaEventRecord(events_[i], compute_stream_)); } #endif position_.emplace(thread_pool_.enqueue([this, i]() -> size_t { @@ -91,7 +93,7 @@ void BufferedReader::ReadAsync(size_t i) { // commands from different streams cannot run concurrently. 
if (platform::is_gpu_place(place_)) { platform::SetDeviceId(boost::get(place_).device); - PADDLE_ENFORCE(cudaStreamWaitEvent(stream, events[i], 0)); + PADDLE_ENFORCE(cudaStreamWaitEvent(stream_, events_[i], 0)); TensorVec &gpu = gpu_buffer_[i]; gpu.resize(cpu.size()); platform::RecordEvent record_event("BufferedReader:MemoryCopy"); @@ -106,12 +108,14 @@ void BufferedReader::ReadAsync(size_t i) { if (platform::is_cuda_pinned_place(cpu_place)) { memory::Copy(boost::get(place_), gpu_ptr, boost::get(cpu_place), - cpu_ptr, size, stream); + cpu_ptr, size, stream_); } else if ((platform::is_gpu_place(cpu_place))) { memory::Copy(boost::get(place_), gpu_ptr, boost::get(cpu_place), cpu_ptr, - size, stream); + size, stream_); } else { + // if cpu place is not pinned, async copy is slower than sync copy, + // so we use sync copy instead. // TODO(zcd): The default stream should not be used here. memory::Copy(boost::get(place_), gpu_ptr, boost::get(cpu_place), cpu_ptr, size, @@ -119,7 +123,7 @@ void BufferedReader::ReadAsync(size_t i) { } gpu[i].set_lod(cpu[i].lod()); } - PADDLE_ENFORCE(cudaStreamSynchronize(stream)); + PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); } #endif return i; diff --git a/paddle/fluid/operators/reader/buffered_reader.h b/paddle/fluid/operators/reader/buffered_reader.h index 87680da01a1f51cfdfe4d100508440eda9d1877f..5f8b2d47c22d0a15d53c8d30d39608fd64d4bddd 100644 --- a/paddle/fluid/operators/reader/buffered_reader.h +++ b/paddle/fluid/operators/reader/buffered_reader.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include "ThreadPool.h" @@ -63,9 +64,9 @@ class BufferedReader : public framework::DecoratedReader { std::vector gpu_buffer_; size_t prev_pos_{-1UL}; #ifdef PADDLE_WITH_CUDA - cudaStream_t stream; - cudaStream_t compute_stream; - std::vector events; + cudaStream_t stream_; + cudaStream_t compute_stream_; + std::vector events_; #endif }; diff --git a/paddle/fluid/operators/reader/create_py_reader_op.cc 
b/paddle/fluid/operators/reader/create_py_reader_op.cc index 901a92ab5b5c74b071be8b57a7653d90e2a4fb29..4a6581bbbd00019db33896371adac6d4e420e48c 100644 --- a/paddle/fluid/operators/reader/create_py_reader_op.cc +++ b/paddle/fluid/operators/reader/create_py_reader_op.cc @@ -12,37 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" +#include "paddle/fluid/operators/reader/py_reader.h" #include "paddle/fluid/operators/reader/reader_op_registry.h" namespace paddle { namespace operators { namespace reader { -class PyReader : public framework::FileReader { - public: - explicit PyReader(const std::shared_ptr& queue) - : framework::FileReader() { - PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); - queue_ = queue; - } - - void ReadNext(std::vector* out) override { - bool success; - *out = queue_->Pop(&success); - if (!success) out->clear(); - } - - ~PyReader() { queue_->Close(); } - - void Shutdown() override { queue_->Close(); } - - void Start() override { queue_->ReOpen(); } - - private: - std::shared_ptr queue_; -}; - class CreatePyReaderOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; diff --git a/paddle/fluid/operators/reader/py_reader.cc b/paddle/fluid/operators/reader/py_reader.cc new file mode 100644 index 0000000000000000000000000000000000000000..155ae859defcf20a5e226a4abfb99dc308dfb23c --- /dev/null +++ b/paddle/fluid/operators/reader/py_reader.cc @@ -0,0 +1,42 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reader/py_reader.h" +#include + +namespace paddle { +namespace operators { +namespace reader { + +PyReader::PyReader(const std::shared_ptr& queue) + : framework::FileReader() { + PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); + queue_ = queue; +} + +void PyReader::ReadNext(std::vector* out) { + bool success; + *out = queue_->Pop(&success); + if (!success) out->clear(); +} + +PyReader::~PyReader() { queue_->Close(); } + +void PyReader::Shutdown() { queue_->Close(); } + +void PyReader::Start() { queue_->ReOpen(); } + +} // namespace reader +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reader/py_reader.h b/paddle/fluid/operators/reader/py_reader.h new file mode 100644 index 0000000000000000000000000000000000000000..43079075142e8db22c0e3b7c86de4249d447f961 --- /dev/null +++ b/paddle/fluid/operators/reader/py_reader.h @@ -0,0 +1,45 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "paddle/fluid/framework/reader.h" +#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" + +namespace paddle { +namespace operators { +namespace reader { + +class PyReader : public framework::FileReader { + public: + explicit PyReader(const std::shared_ptr& queue); + + void ReadNext(std::vector* out) override; + + ~PyReader(); + + void Shutdown() override; + + void Start() override; + + private: + std::shared_ptr queue_; +}; + +} // namespace reader +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index f1385f57184eceec49b791cf6c89641b098f036a..0991eff0fdaaca80ada2d8dd3c68eba72fd3f6e6 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -5,7 +5,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune if(WITH_PYTHON) list(APPEND PYBIND_DEPS py_func_op) endif() -set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc imperative.cc ir.cc inference_api.cc) +set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc reader_py.cc async_executor_py.cc imperative.cc ir.cc inference_api.cc) if(WITH_PYTHON) if(WITH_AMD_GPU) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index a57083a1444a164cdeecf7e3e6eff6dc0e1e7be7..cef95de2ef675e417b5a2c49d01e3c85e23f9718 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -55,6 +55,7 @@ limitations under the License. 
*/ #include "paddle/fluid/pybind/ir.h" #include "paddle/fluid/pybind/protobuf.h" #include "paddle/fluid/pybind/pybind.h" // NOLINT +#include "paddle/fluid/pybind/reader_py.h" #include "paddle/fluid/pybind/recordio.h" #include "paddle/fluid/pybind/tensor_py.h" @@ -128,6 +129,11 @@ static inline bool IsSamePlace(const PlaceType1 &p1, const PlaceType2 &p2) { return paddle::platform::Place(p1) == paddle::platform::Place(p2); } +template +static inline int PlaceIndex(const PlaceType &p) { + return static_cast(paddle::platform::Place(p).which()); +} + PYBIND11_MODULE(core, m) { // Not used, just make sure cpu_info.cc is linked. paddle::platform::CpuTotalPhysicalMemory(); @@ -531,6 +537,7 @@ PYBIND11_MODULE(core, m) { All parameter, weight, gradient are variables in Paddle. )DOC") + .def(py::init<>()) .def("is_int", [](const Variable &var) { return var.IsType(); }) .def("set_int", [](Variable &var, int val) -> void { *var.GetMutable() = val; }) @@ -572,14 +579,13 @@ All parameter, weight, gradient are variables in Paddle. }, py::return_value_policy::reference); - py::class_(m, "Reader", "") - .def("start", &framework::ReaderHolder::Start) - .def("reset", &framework::ReaderHolder::ResetAll); + BindReader(&m); using LoDTensorBlockingQueue = ::paddle::operators::reader::LoDTensorBlockingQueue; using LoDTensorBlockingQueueHolder = ::paddle::operators::reader::LoDTensorBlockingQueueHolder; + py::class_>( m, "LoDTensorBlockingQueue", "") .def("push", @@ -776,6 +782,7 @@ All parameter, weight, gradient are variables in Paddle. PADDLE_THROW("Cannot use CUDAPlace in CPU only version"); #endif }) + .def("_type", &PlaceIndex) .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) @@ -785,6 +792,7 @@ All parameter, weight, gradient are variables in Paddle. 
py::class_(m, "CPUPlace") .def(py::init<>()) + .def("_type", &PlaceIndex) .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) @@ -800,6 +808,7 @@ All parameter, weight, gradient are variables in Paddle. #endif new (&self) platform::CUDAPinnedPlace(); }) + .def("_type", &PlaceIndex) .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) @@ -811,16 +820,25 @@ All parameter, weight, gradient are variables in Paddle. py::class_(m, "Place") .def(py::init<>()) + .def("_type", &PlaceIndex) .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) .def("is_gpu_place", [](platform::Place &self) { return platform::is_gpu_place(self); }) + .def("is_cpu_place", + [](platform::Place &self) { return platform::is_cpu_place(self); }) + .def("is_cuda_pinned_place", + [](platform::Place &self) { + return platform::is_cuda_pinned_place(self); + }) .def("gpu_device_id", [](platform::Place &self) { return boost::get(self).device; }) + .def("set_place", [](platform::Place &self, + const platform::Place &other) { self = other; }) .def("set_place", [](platform::Place &self, const platform::CPUPlace &cpu_place) { self = cpu_place; diff --git a/paddle/fluid/pybind/reader_py.cc b/paddle/fluid/pybind/reader_py.cc new file mode 100644 index 0000000000000000000000000000000000000000..af7d30552ed47c0fbe26090b328cc7128b90f84d --- /dev/null +++ b/paddle/fluid/pybind/reader_py.cc @@ -0,0 +1,161 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pybind/reader_py.h" +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/reader.h" +#include "paddle/fluid/operators/reader/buffered_reader.h" +#include "paddle/fluid/operators/reader/py_reader.h" +#include "paddle/fluid/platform/place.h" +#include "pybind11/stl.h" + +namespace paddle { +namespace pybind { + +class MultiDeviceFeedReader { + public: + using ResultDictList = + std::vector>; + + MultiDeviceFeedReader( + const std::shared_ptr &queue, + const std::vector &names, + const std::vector &dst_places, bool use_double_buffer) + : queue_(queue), + names_(names), + pool_(new ::ThreadPool(dst_places.size())) { + std::shared_ptr reader( + new operators::reader::PyReader(queue)); + + readers_.reserve(dst_places.size()); + for (auto &p : dst_places) { + auto *holder = new framework::ReaderHolder(); + if (use_double_buffer) { + holder->Reset( + framework::MakeDecoratedReader( + reader, p, 2)); + } else { + if (platform::is_gpu_place(p)) { + PADDLE_THROW( + "Place cannot be CUDAPlace when use_double_buffer is False"); + } + holder->Reset(reader); + } + readers_.emplace_back(holder); + } + + futures_.resize(dst_places.size()); + ret_.resize(dst_places.size()); + ReadAsync(); + } + + ResultDictList ReadNext() { + bool success = WaitFutures(); + + if (!success) { + return {}; + } + + ResultDictList result(ret_.size()); + for (size_t i = 0; i < ret_.size(); ++i) { + for (size_t j = 0; j < names_.size(); ++j) { + result[i].emplace(names_[j], std::move(ret_[i][j])); + } + } + ReadAsync(); + return 
result; + } + + void Reset() { + Shutdown(); + Start(); + ReadAsync(); + } + + ~MultiDeviceFeedReader() { + queue_->Close(); + pool_.reset(); + } + + private: + bool WaitFutures() { + bool success = true; + for (auto &f : futures_) { + success &= f.get(); + } + return success; + } + + void Shutdown() { + for (auto &r : readers_) r->Shutdown(); + } + + void Start() { + for (auto &r : readers_) r->Start(); + } + + void ReadAsync() { + for (size_t i = 0; i < readers_.size(); ++i) { + futures_[i] = pool_->enqueue([this, i] { + readers_[i]->ReadNext(&ret_[i]); + return !ret_[i].empty(); + }); + } + } + + std::shared_ptr queue_; + std::vector names_; + std::unique_ptr<::ThreadPool> pool_; + + std::vector> readers_; + + std::vector> futures_; + std::vector> ret_; +}; + +namespace py = pybind11; + +void BindReader(py::module *module) { + auto &m = *module; + + namespace reader = ::paddle::operators::reader; + + py::class_(m, "Reader", "") + .def("start", &framework::ReaderHolder::Start) + .def("reset", &framework::ReaderHolder::ResetAll); + + py::class_(m, "MultiDeviceFeedReader", "") + .def("read_next", &MultiDeviceFeedReader::ReadNext, + py::call_guard()) + .def("reset", &MultiDeviceFeedReader::Reset, + py::call_guard()); + + m.def("create_py_reader", + [](const std::shared_ptr + &queue, + const std::vector &names, + const std::vector &dst_places, + bool use_double_buffer) { + return new MultiDeviceFeedReader(queue, names, dst_places, + use_double_buffer); + }, + py::return_value_policy::take_ownership); +} + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/reader_py.h b/paddle/fluid/pybind/reader_py.h new file mode 100644 index 0000000000000000000000000000000000000000..472ff65368f3fb206ae599ae5d9d11e9ae8195ae --- /dev/null +++ b/paddle/fluid/pybind/reader_py.h @@ -0,0 +1,25 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "pybind11/pybind11.h" + +namespace paddle { +namespace pybind { + +void BindReader(pybind11::module *module); + +} // namespace pybind +} // namespace paddle diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py index 5732377bd60f849494ae7e463f40d4843ffa2c23..ac2a40a7c25f7c3ff0cc103647355da55d27fec3 100644 --- a/python/paddle/fluid/compiler.py +++ b/python/paddle/fluid/compiler.py @@ -17,9 +17,10 @@ import os import six import sys from .. import compat as cpt +from . import framework +from .framework import cuda_places, cpu_places from . import core -from . import framework __all__ = ['CompiledProgram', 'ExecutionStrategy', 'BuildStrategy'] @@ -44,21 +45,6 @@ def _is_pserver_mode(main_program): return False -def get_available_places(use_cuda): - if use_cuda: - gpus_env = os.getenv("FLAGS_selected_gpus") - if gpus_env: - gpus = [int(s) for s in gpus_env.split(",")] - else: - gpus = [i for i in six.moves.range(core.get_cuda_device_count())] - places = [core.CUDAPlace(i) for i in gpus] - else: - cpu_num = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) - places = [core.CPUPlace() for _ in six.moves.range(cpu_num)] - assert places, "no place for execution" - return places - - class CompiledProgram(object): """ Compiles to Graph for execution. 
@@ -117,7 +103,8 @@ class CompiledProgram(object): loss_name=None, build_strategy=None, exec_strategy=None, - share_vars_from=None): + share_vars_from=None, + places=None): """Configs the program to run in data parallel way. Args: @@ -132,10 +119,18 @@ class CompiledProgram(object): threads are used, how many iterations to clean up the temp variables. For more information, please refer to fluid.ExecutionStrategy. Default None. - share_vars_from(CompiledProgram): If provide, this CompiledProgram + share_vars_from(CompiledProgram): If provided, this CompiledProgram will share variables from `share_vars_from`. `share_vars_from` must be run by the executor before this CompiledProgram so that vars are ready. + places(list(CUDAPlace)|list(CPUPlace)|None): If provided, only compile + program in the given places. Otherwise, the places used when compiled + is determined by the Executor, and the places used are controlled + by environment variables: FLAGS_selected_gpus or CUDA_VISIBLE_DEVICES + if using GPU; or CPU_NUM if using CPU. For example, if you want to + run on GPU 0 and 1, set places=[fluid.CUDAPlace(0), fluid.CUDAPlace(1)]. + If you want to run on 2 CPU cores, set places=[fluid.CPUPlace()]*2. + Returns: self """ @@ -150,6 +145,12 @@ class CompiledProgram(object): self._exec_strategy = ExecutionStrategy() if self._build_strategy is None: self._build_strategy = BuildStrategy() + if places is not None: + if not isinstance(places, (list, tuple)): + places = [places] + self._places = places + else: + self._places = None self._build_strategy.is_distribution = _is_pserver_mode(self._program) return self @@ -192,7 +193,15 @@ class CompiledProgram(object): self._local_scopes = [] self._exec_strategy.use_cuda = use_cuda - self._places = get_available_places(self._exec_strategy.use_cuda) + has_set_place = (self._places is not None) + if has_set_place: + for p in self._places: + assert p._type() == self._place._type(), \ + "Place type not match. 
You may set the wrong type of places" + else: + self._places = cuda_places( + ) if self._exec_strategy.use_cuda else cpu_places() + assert self._places, "no place for execution" if self._exec_strategy.num_threads == 0: if self._exec_strategy.use_cuda: @@ -200,9 +209,7 @@ class CompiledProgram(object): # performance. Worth tunning for other models in the future. self._exec_strategy.num_threads = len(self._places) * 4 else: - cpu_num = int( - os.environ.get('CPU_NUM', multiprocessing.cpu_count())) - self._exec_strategy.num_threads = cpu_num * 2 + self._exec_strategy.num_threads = len(self._places) * 2 # FIXME(dzhwinter): enable_inplace should be after memory_optimize # if turn on python memory optimize, turn off the inplace_pass. diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py index 3dac41ce43d61c02f3e11087aef98e2fc454556b..00c4e5691a23a9864ed3e8964f4cafaf9588c665 100644 --- a/python/paddle/fluid/data_feeder.py +++ b/python/paddle/fluid/data_feeder.py @@ -26,6 +26,24 @@ from .framework import Variable, default_main_program __all__ = ['DataFeeder'] +def convert_dtype(dtype): + if dtype == core.VarDesc.VarType.FP32: + return 'float32' + elif dtype == core.VarDesc.VarType.INT64: + return 'int64' + elif dtype == core.VarDesc.VarType.FP64: + return 'float64' + elif dtype == core.VarDesc.VarType.FP16: + return 'float16' + elif dtype == core.VarDesc.VarType.INT32: + return 'int32' + elif dtype == core.VarDesc.VarType.UINT8: + return 'uint8' + else: + raise ValueError("dtype must be any of [int32, float32, int64, " + "float64, uint8]") + + class DataToLoDTensorConverter(object): def __init__(self, place, lod_level, shape, dtype): self.place = place @@ -38,27 +56,12 @@ class DataToLoDTensorConverter(object): if negtive_count > 1: self.shape = None break - if dtype == core.VarDesc.VarType.FP32: - self.dtype = 'float32' - elif dtype == core.VarDesc.VarType.INT64: - self.dtype = 'int64' - elif dtype == core.VarDesc.VarType.FP64: - self.dtype 
= 'float64' - elif dtype == core.VarDesc.VarType.FP16: - self.dtype = 'float16' - elif dtype == core.VarDesc.VarType.INT32: - self.dtype = 'int32' - elif dtype == core.VarDesc.VarType.UINT8: - self.dtype = 'uint8' - else: - raise ValueError("dtype must be any of [int32, float32, int64, " - "float64, uint8]") + self.dtype = convert_dtype(dtype) + self._reset() + def _reset(self): self.data = [] - self.lod = [] - - for i in six.moves.range(lod_level): - self.lod.append([]) + self.lod = [[] for _ in six.moves.range(self.lod_level)] def feed(self, data): self._feed_impl_(data, self.lod, self.lod_level) @@ -88,15 +91,52 @@ class DataToLoDTensorConverter(object): raise ValueError( "Reshape error. What is defined in data layer is {}, but receive {}" .format(self.shape, arr.shape)) - #else: - # self._check_shape(arr.shape) t = core.LoDTensor() t.set(arr, self.place) if self.lod_level > 0: t.set_recursive_sequence_lengths(self.lod) + self._reset() return t +class BatchedTensorProvider(object): + def __init__(self, feed_list, place, batch_size, generator, drop_last): + self.place = place + self.batch_size = batch_size + self.generator = generator + self.converters = [] + self.drop_last = drop_last + + for var in feed_list: + assert var.lod_level == 0, "lod_level must be 0" + self.converters.append( + DataToLoDTensorConverter( + place=self.place, + lod_level=0, + shape=var.shape, + dtype=var.dtype)) + + def _done(self): + return [c.done() for c in self.converters] + + def __call__(self): + idx = 0 + for each_sample in self.generator(): + for each_slot, each_converter in six.moves.zip(each_sample, + self.converters): + each_converter.data.append(each_slot) + + idx += 1 + if idx == self.batch_size: + idx = 0 + yield self._done() + + if not self.drop_last and idx > 0: + yield self._done() + else: + [c._reset() for c in self.converters] + + class DataFeeder(object): """ DataFeeder converts the data that returned by a reader into a data diff --git a/python/paddle/fluid/executor.py 
b/python/paddle/fluid/executor.py index 03aa9917f3201e690a7072442cf11ac2284b03c5..018e38cbb3f2676ac05f1a27e9e92b6e0f16efb0 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -564,6 +564,10 @@ class Executor(object): if feed is None: feed = {} + elif isinstance(feed, (list, tuple)): + assert len(feed) == 1, "Not compiled with data parallel" + feed = feed[0] + if not isinstance(feed, dict): raise TypeError( "feed requires dict as its Parameter. But you passed in %s" % diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index f3d876f141763beec940899e8ab5ed464328b06e..ee366026f1c49044b7c7040f140abaf66b9c006a 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -26,6 +26,7 @@ import six import numpy as np import subprocess +import multiprocessing from .. import compat as cpt from .proto import framework_pb2 @@ -63,6 +64,9 @@ __all__ = [ 'default_main_program', 'program_guard', 'name_scope', + 'cuda_places', + 'cpu_places', + 'cuda_pinned_places', ] EMPTY_VAR_NAME = core.kEmptyVarName() @@ -87,6 +91,87 @@ def _current_expected_place(): return _imperative_current_expected_place_ +def _cpu_num(): + return int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) + + +def cuda_places(device_ids=None): + ''' + Create a list of :code:`fluid.CUDAPlace` objects. + + If :code:`device_ids` is None, environment variable of + :code:`FLAGS_selected_gpus` would be checked first. If + :code:`FLAGS_selected_gpus=0,1,2`, the returned list would + be [fluid.CUDAPlace(0), fluid.CUDAPlace(1), fluid.CUDAPlace(2)]. + If :code:`FLAGS_selected_gpus` is not set, all visible + gpu places would be returned. + + If :code:`device_ids` is not None, it should be the device + ids of gpus. For example, if :code:`device_ids=[0,1,2]`, + the returned list would be + [fluid.CUDAPlace(0), fluid.CUDAPlace(1), fluid.CUDAPlace(2)]. + + Args: + device_ids (None|list(int)|tuple(int)): gpu device id list. 
+ + Returns: + out (list(fluid.CUDAPlace)): gpu place list. + ''' + assert core.is_compiled_with_cuda(), \ + "Not compiled with CUDA" + if device_ids is None: + gpus_env = os.getenv("FLAGS_selected_gpus") + if gpus_env: + device_ids = [int(s) for s in gpus_env.split(",")] + else: + device_ids = six.moves.range(core.get_cuda_device_count()) + elif not isinstance(device_ids, (list, tuple)): + device_ids = [device_ids] + return [core.CUDAPlace(dev_id) for dev_id in device_ids] + + +def cpu_places(device_count=None): + ''' + Create a list of :code:`fluid.CPUPlace` objects. + + If :code:`device_count` is None, the device count would + be determined by environment variable :code:`CPU_NUM`. + If :code:`CPU_NUM` is not set, the device count would + be determined by :code:`multiprocessing.cpu_count()`. + + Args: + device_count (None|int): device number. + + Returns: + out (list(fluid.CPUPlace)): cpu place list. + ''' + if device_count is None: + device_count = _cpu_num() + return [core.CPUPlace()] * device_count + + +def cuda_pinned_places(device_count=None): + ''' + Create a list of :code:`fluid.CUDAPinnedPlace` objects. + + If :code:`device_count` is None, the device count would + be determined by environment variable :code:`CPU_NUM`. + If :code:`CPU_NUM` is not set, the device count would + be determined by :code:`multiprocessing.cpu_count()`. + + Args: + device_count (None|int): device number. + + Returns: + out (list(fluid.CUDAPinnedPlace)): cuda pinned place list. 
+ ''' + assert core.is_compiled_with_cuda(), \ + "Not compiled with CUDA" + if device_count is None: + device_count = _cpu_num() + return [core.cuda_pinned_places()] * device_count + + class NameScope(object): def __init__(self, name="", parent=None): self._children = dict() diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 326a84d82b5718dad898620a6d9e0490f7519448..4d5523627218601d00021c72a8777b4b6413880e 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -26,12 +26,14 @@ from paddle.fluid import layers from paddle.fluid.executor import Executor from paddle.fluid.evaluator import Evaluator from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable, program_guard +from . import reader +from .reader import * from . import core __all__ = [ 'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params', 'load_persistables', 'save_inference_model', 'load_inference_model' -] +] + reader.__all__ def is_parameter(var): diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index a9b391fd53a98dc05ee2d909a38dcf82cd5880ea..94fd9f3ea5a41a542da0115a66a52a5cd7f26748 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -563,22 +563,26 @@ def _py_reader(capacity, def start_provide_thread(func): def __provider_thread__(): - for tensors in func(): - array = core.LoDTensorArray() - for item in tensors: - if not isinstance(item, core.LoDTensor): - tmp = core.LoDTensor() - tmp.set(item, core.CPUPlace()) - item = tmp - - array.append(item) - - if reader.exited: - break - feed_queue.push(array) - if reader.exited: - break - feed_queue.close() + try: + for tensors in func(): + array = core.LoDTensorArray() + for item in tensors: + if not isinstance(item, core.LoDTensor): + tmp = core.LoDTensor() + tmp.set(item, core.CPUPlace()) + item = tmp + + array.append(item) + + if reader.exited: + break + feed_queue.push(array) + 
if reader.exited: + break + feed_queue.close() + except Exception as ex: + feed_queue.close() + raise ex reader.thread = threading.Thread(target=__provider_thread__) reader.thread.daemon = True @@ -628,6 +632,9 @@ def _py_reader(capacity, reader.reset = __reset__ reader.decorate_tensor_provider = __set_tensor_provider__ reader.decorate_paddle_reader = __set_paddle_reader__ + + reader.decorate_batch_generator = __set_tensor_provider__ + reader.decorate_sample_list_generator = __set_paddle_reader__ reader.start = __start__ return reader @@ -692,6 +699,11 @@ def py_reader(capacity, >>> exe.run(fetch_list=[loss.name]) >>> except fluid.core.EOFException: >>> reader.reset() + >>> + >>> ... + >>> + >>> fluid.io.save_inference_model(dirname='./model', feeded_var_names=[img, label], + >>> target_vars=[loss], executor=fluid.Executor(fluid.CUDAPlace(0))) 2. When training and testing are both performed, two different :code:`py_reader` should be created with different names, e.g.: diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 517418da1cf2f745ee5578e3c2b118394db7fae7..6702fc808b121d80fe555412e2cc7f673d6d8389 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -99,7 +99,8 @@ class ParallelExecutor(object): build_strategy.num_trainers = num_trainers build_strategy.trainer_id = trainer_id - self._places = compiler.get_available_places(use_cuda) + self._places = framework.cuda_places( + ) if use_cuda else framework.cpu_places() self._scope = scope if scope is not None else executor.global_scope() main_program = main_program if main_program is not None \ diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py new file mode 100644 index 0000000000000000000000000000000000000000..74ee2828deb6ecd51ff36b878e97254a62ad1cb6 --- /dev/null +++ b/python/paddle/fluid/reader.py @@ -0,0 +1,373 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import core +import six +import threading +from .framework import Program, Variable, program_guard, default_main_program, default_startup_program +from .executor import global_scope +from .data_feeder import DataFeeder, BatchedTensorProvider +from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer +from .unique_name import UniqueNameGenerator + +__all__ = ['PyReader'] + + +def _convert_places(places): + if not isinstance(places, (list, tuple)): + places = [places] + + ret = [] + for p in places: + if not isinstance(p, core.Place): + tmp = core.Place() + tmp.set_place(p) + p = tmp + + ret.append(p) + return ret + + +class PyReader(object): + """ + Create a reader object for data feeding in Python. + Data would be prefetched using Python thread and be pushed + into a queue asynchronously. Data in the queue would be extracted + automatically when `Executor.run(...)` is called. + + Args: + feed_list (list(Variable)|tuple(Variable)): feed variable list. + The variables should be created by :code:`fluid.layers.data()`. + capacity (int): capacity of the queue maintained in PyReader object. + use_double_buffer (bool): whether to use double_buffer_reader to + speed up data feeding. + iterable (bool): whether the created reader object is iterable. + + Returns: + reader (Reader): the created reader object. + + Examples: + 1. 
If iterable = False, the created PyReader object is almost the + same as :code:`fluid.layers.py_reader()`. Operators would be + inserted into the program. User should call :code:`start()` + before each epoch and catch :code:`fluid.core.EOFException` + thrown by :code:`Executor.run()` when epoch ends. Once the + exception is caught, user should call :code:`reset()` to reset + the reader manually. + + .. code-block:: python + + image = fluid.layers.data( + name='image', shape=[784], dtype='float32') + label = fluid.layers.data( + name='label', shape=[1], dtype='int64') + + reader = fluid.io.PyReader(feed_list=[image, label], + capacity=4, iterable=False) + reader.decorate_sample_list_generator(user_defined_reader) + ... # definition of network is omitted + executor.run(fluid.default_main_program()) + for _ in range(EPOCH_NUM): + reader.start() + while True: + try: + executor.run(feed=None, ...) + except fluid.core.EOFException: + reader.reset() + break + + 2. If iterable=True, the created PyReader object is decoupled with + the program. No operator would be inserted into the program. + In this case, the created reader is a Python generator, which + is iterable. User should feed the data yielded from PyReader + object into :code:`Executor.run(feed=...)`. + + .. code-block:: python + + image = fluid.layers.data( + name='image', shape=[784], dtype='float32') + label = fluid.layers.data( + name='label', shape=[1], dtype='int64') + + reader = fluid.io.PyReader(feed_list=[image, label], + capacity=4, iterable=True) + reader.decorate_sample_list_generator(user_defined_reader, + places=fluid.cuda_places()) + ... # definition of network is omitted + executor.run(fluid.default_main_program()) + for _ in range(EPOCH_NUM): + for data in reader(): + executor.run(feed=data, ...) 
+ """ + + unique_name_generator = UniqueNameGenerator() + + def __init__(self, + feed_list, + capacity, + use_double_buffer=True, + iterable=False): + self._tensor_reader = None + self._thread = None + self._iterable = iterable + self._use_double_buffer = use_double_buffer + self._capacity = capacity + self._feed_list = feed_list + if not self._iterable: + self._init_non_iterable() + + def _init_iterable(self, places): + self._var_names = [v.name for v in self._feed_list] + self._places = _convert_places(places) + self._queue = core.init_lod_tensor_blocking_queue(core.Variable(), + self._capacity) + self._reader = core.create_py_reader( + self.queue, self._var_names, self._places, self._use_double_buffer) + + def _init_non_iterable(self): + lod_levels = [] + dtypes = [] + shape_concat = [] + ranks = [] + shapes = [] + + for feed_data in self._feed_list: + dtypes.append(feed_data.dtype) + shape_concat.extend(feed_data.shape) + ranks.append(len(feed_data.shape)) + shapes.append(feed_data.shape) + lod_levels.append(feed_data.lod_level) + + queue_name = PyReader.unique_name_generator('lod_tensor_blocking_queue') + reader_name = PyReader.unique_name_generator('create_py_reader') + double_buffer_name = PyReader.unique_name_generator('double_buffer') + + var = global_scope().var(queue_name) + self._queue = core.init_lod_tensor_blocking_queue(var, self._capacity) + + startup_blk = default_startup_program().current_block() + startup_var = startup_blk.create_var(name=reader_name) + + startup_blk.append_op( + type='create_py_reader', + inputs={'blocking_queue': [queue_name]}, + outputs={'Out': [startup_var]}, + attrs={ + 'shape_concat': shape_concat, + 'lod_levels': lod_levels, + 'ranks': ranks + }) + + startup_var.desc.set_dtypes(dtypes) + startup_var.persistable = True + + main_prog_var = _copy_reader_var_( + default_main_program().current_block(), startup_var) + + main_prog_var.stop_gradient = True + main_prog_var.persistable = True + + reader = 
monkey_patch_reader_methods(main_prog_var) + if self._use_double_buffer: + double_buffer_reader = double_buffer( + reader, name=double_buffer_name) + # we return a double buffer reader. However, the reset method comes from + # py_reader. + double_buffer_reader.reset = reader.reset + reader = double_buffer_reader + + self._reader = reader + + default_main_program().current_block().append_op( + type='read', + inputs={'Reader': [self._reader]}, + outputs={'Out': self._feed_list}) + + @property + def queue(self): + return self._queue + + @property + def iterable(self): + return self._iterable + + def __call__(self): + assert self.iterable, "PyReader is not iterable" + assert self._tensor_reader is not None, \ + "Data source of PyReader has not set yet" + + class Iterator(object): + def __init__(self, reader): + self._reader = reader._reader + self._reset = reader._reset + + def __iter__(self): + return self + + def __next__(self): + return self.next() + + def next(self): + ret = self._reader.read_next() + if ret: + return ret + else: + self._reset() + raise StopIteration + + self._start() + return Iterator(self) + + def _reset(self): + self._reader.reset() + self._thread.join() + + def start(self): + ''' + Start the data feeding thread. + Can only call when the reader object is not iterable. + ''' + assert not self._iterable, "start() cannot be called when PyReader is iterable" + self._start() + + def reset(self): + ''' + Reset the reader object when :code:`fluid.core.EOFException` raises. + Can only call when the reader object is not iterable. 
+ ''' + assert not self._iterable, "reset() cannot be called when PyReader is iterable" + self._reset() + + def _start(self): + def __thread_main__(): + try: + for tensors in self._tensor_reader(): + array = core.LoDTensorArray() + for item in tensors: + if not isinstance(item, core.LoDTensor): + tmp = core.LoDTensor() + tmp.set(item, core.CPUPlace()) + item = tmp + + array.append(item) + + if not self._queue.push(array): + break + + self._queue.close() + except Exception as ex: + self._queue.close() + raise ex + + self._thread = threading.Thread(target=__thread_main__) + self._thread.daemon = True + self._thread.start() + + def decorate_sample_generator(self, + sample_generator, + batch_size, + drop_last=True, + places=None): + ''' + Set the data source of the PyReader object. + + The provided :code:`sample_generator` should be a Python generator, + which yields numpy.ndarray typed data of each sample. + + :code:`places` must be set when the PyReader object is iterable. + + If all inputs have no lods, this method is faster than + :code:`decorate_sample_list_generator(paddle.batch(sample_generator, ...))` . + + Args: + sample_generator (generator): Python generator that yields + numpy.ndarray-typed sample data. + batch_size (int): batch size. Must be larger than 0. + drop_last (bool): Whether to drop the last batch when sample number + is less than batch_size. + places (None|list(CUDAPlace)|list(CPUPlace)): place list. Must + be provided when PyReader is iterable. 
+ ''' + assert batch_size > 0, "batch_size must be larger than 0" + has_lod = False + for f in self._feed_list: + if f.lod_level != 0: + has_lod = True + break + + if has_lod: + self.decorate_sample_list_generator( + paddle.batch( + sample_generator, + batch_size=batch_size, + drop_last=drop_last), + places=places) + else: + reader = BatchedTensorProvider( + feed_list=self._feed_list, + place=core.CPUPlace(), + batch_size=batch_size, + generator=sample_generator, + drop_last=drop_last) + self.decorate_batch_generator(reader, places=places) + + def decorate_sample_list_generator(self, reader, places=None): + ''' + Set the data source of the PyReader object. + + The provided :code:`reader` should be a Python generator, + which yields list(numpy.ndarray) typed batched data. + + :code:`places` must be set when the PyReader object is iterable. + + Args: + reader (generator): Python generator that yields + list(numpy.ndarray)-typed batched data. + places (None|list(CUDAPlace)|list(CPUPlace)): place list. Must + be provided when PyReader is iterable. + ''' + assert self._tensor_reader is None, \ + "Cannot reset the data source of PyReader" + with program_guard(Program(), Program()): + feeder = DataFeeder( + feed_list=self._feed_list, place=core.CPUPlace()) + paddle_reader = feeder.decorate_reader(reader, multi_devices=False) + + def __tensor_reader_impl__(): + for slots in paddle_reader(): + yield [slots[var.name] for var in self._feed_list] + + self.decorate_batch_generator(__tensor_reader_impl__, places) + + def decorate_batch_generator(self, reader, places=None): + ''' + Set the data source of the PyReader object. + + The provided :code:`reader` should be a Python generator, + which yields numpy.ndarray-typed or LoDTensor-typed batched data. + + :code:`places` must be set when the PyReader object is iterable. + + Args: + reader (generator): Python generator that yields LoDTensor-typed + batched data. + places (None|list(CUDAPlace)|list(CPUPlace)): place list. 
Must + be provided when PyReader is iterable. + ''' + assert self._tensor_reader is None, \ + "Cannot reset the data source of PyReader" + self._tensor_reader = reader + if self._iterable: + assert places is not None, "Places cannot be None when py_reader is iterable" + self._init_iterable(places) diff --git a/python/paddle/fluid/tests/unittests/test_decoupled_py_reader.py b/python/paddle/fluid/tests/unittests/test_decoupled_py_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..377014510b55633f697ef7bf2f5f597281e5f5a5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_decoupled_py_reader.py @@ -0,0 +1,175 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +import paddle.fluid as fluid +import numpy as np +import time +import six +import unittest + +EPOCH_NUM = 60 +BATCH_SIZE = 32 +CLASS_NUM = 10 + + +def random_reader(): + np.random.seed(1) + for i in range(BATCH_SIZE * 40): + image = np.random.random([784]) + label = np.random.random_integers(low=0, high=CLASS_NUM - 1) + yield image, label + + +def simple_fc_net(places, use_legacy_py_reader, use_double_buffer): + startup_prog = fluid.Program() + main_prog = fluid.Program() + startup_prog.random_seed = 1 + main_prog.random_seed = 1 + + with fluid.unique_name.guard(): + with fluid.program_guard(main_prog, startup_prog): + image = fluid.layers.data( + name='image', shape=[784], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + py_reader = fluid.io.PyReader( + feed_list=[image, label], + capacity=4, + iterable=not use_legacy_py_reader, + use_double_buffer=use_double_buffer) + hidden = image + for hidden_size in [10, 20, 30]: + hidden = fluid.layers.fc( + hidden, + size=hidden_size, + act='tanh', + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=1.0))) + + predict_label = fluid.layers.fc(hidden, + size=CLASS_NUM, + act='softmax') + loss = fluid.layers.mean( + fluid.layers.cross_entropy( + input=predict_label, label=label)) + + optimizer = fluid.optimizer.Adam() + optimizer.minimize(loss) + return startup_prog, main_prog, py_reader, loss + + +class TestBase(unittest.TestCase): + def run_main(self, use_legacy_py_reader, with_data_parallel, places, + use_double_buffer): + scope = fluid.Scope() + with fluid.scope_guard(scope): + startup_prog, main_prog, py_reader, loss = simple_fc_net( + places, use_legacy_py_reader, use_double_buffer) + + reader = paddle.batch(random_reader, batch_size=BATCH_SIZE) + + ps = places if use_double_buffer else fluid.cpu_places(len(places)) + + py_reader.decorate_sample_list_generator( + reader, places=ps if py_reader.iterable else None) + + exe = 
fluid.Executor(place=places[0]) + exe.run(startup_prog) + + prog = fluid.CompiledProgram(main_prog) + if with_data_parallel: + prog = prog.with_data_parallel( + loss_name=loss.name, places=places) + + step = 0 + step_list = [] + loss_list = [] + start_t = time.time() + if not py_reader.iterable: + for _ in six.moves.range(EPOCH_NUM): + step = 0 + py_reader.start() + while True: + try: + L, = exe.run(program=prog, + fetch_list=[loss], + use_program_cache=True) + loss_list.append(np.mean(L)) + step += 1 + except fluid.core.EOFException: + py_reader.reset() + break + step_list.append(step) + else: + for _ in six.moves.range(EPOCH_NUM): + step = 0 + for d in py_reader(): + assert len(d) == len(places) + for i, item in enumerate(d): + image = item['image'] + label = item['label'] + assert image.shape() == [BATCH_SIZE, 784] + assert label.shape() == [BATCH_SIZE, 1] + assert image._place()._equals(ps[i]) + assert label._place()._equals(ps[i]) + L, = exe.run(program=prog, + feed=d, + fetch_list=[loss], + use_program_cache=True) + loss_list.append(np.mean(L)) + step += 1 + step_list.append(step) + end_t = time.time() + ret = { + "time": end_t - start_t, + "step": step_list, + "loss": np.array(loss_list) + } + return ret + + def prepare_places(self, with_data_parallel, with_cpu=True, with_gpu=True): + places = [] + if with_cpu: + places.append([fluid.CPUPlace()]) + if with_data_parallel: + places.append([fluid.CPUPlace()] * 2) + + if with_gpu and fluid.core.is_compiled_with_cuda(): + tmp = fluid.cuda_places() + assert len(tmp) > 0, "no gpu detected" + if with_data_parallel: + places.append(tmp) + places.append([tmp[0]]) + return places + + def test_main(self): + for with_data_parallel in [True, False]: + for p in self.prepare_places(with_data_parallel): + for use_double_buffer in [False, True]: + results = [] + for use_legacy_py_reader in [False, True]: + ret = self.run_main( + use_legacy_py_reader=use_legacy_py_reader, + with_data_parallel=with_data_parallel, + places=p, + 
use_double_buffer=use_double_buffer) + results.append(ret) + if not use_double_buffer: + diff = np.max( + np.abs(results[0]['loss'] - results[1]['loss'])) + self.assertLess(diff, 1e-3) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_py_reader_sample_generator.py b/python/paddle/fluid/tests/unittests/test_py_reader_sample_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..4efca5e2aafd9c370ccc37791a9900b18f2705f6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_py_reader_sample_generator.py @@ -0,0 +1,137 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import paddle
import paddle.fluid as fluid
import math
import unittest
import numpy as np
import os

# NOTE(review): presumably read by fluid.cpu_places() to decide how many CPU
# places to create -- confirm against the fluid implementation.
os.environ['CPU_NUM'] = '1'


def random_reader(sample_num):
    """Build a cached reader of random (image, label) samples.

    Args:
        sample_num (int): number of samples the reader yields per pass.

    Returns:
        The result of ``paddle.reader.cache`` applied to a generator function
        that yields ``sample_num`` tuples of a float32 array of size 784 and
        an int64 array of size 1 with values in [0, 9].
    """

    def __impl__():
        for _ in range(sample_num):
            image = np.random.random(size=[784]).astype('float32')
            # randint excludes its upper bound, so high=10 draws from [0, 9],
            # matching the deprecated random_integers(low=0, high=9).
            label = np.random.randint(
                low=0, high=10, size=[1]).astype('int64')
            yield image, label

    return paddle.reader.cache(__impl__)


class TestCaseBase(unittest.TestCase):
    """Checks that PyReader feeds exactly the batches produced by paddle.batch.

    Subclasses override ``setUp`` to vary batch size, epoch count and sample
    count.
    """

    def setUp(self):
        self.batch_size = 32
        self.epoch_num = 2
        self.sample_num = 165

    def generate_all_data(self, reader):
        """Materialize a batched reader into per-slot numpy arrays.

        Args:
            reader: callable returning an iterable of batches, each batch
                being a sequence of (image, label) samples.

        Returns:
            list: one ``[images, labels]`` pair of numpy arrays per batch.
        """
        ret = []
        for d in reader():
            slots = [[], []]
            for item in d:
                slots[0].append(item[0])
                slots[1].append(item[1])
            slots = [np.array(slot) for slot in slots]
            ret.append(slots)
        return ret

    def run_main(self, reader, use_sample_generator, iterable, drop_last):
        """Run every epoch through a PyReader and compare against the source.

        Args:
            reader: sample-level reader (callable returning an iterable).
            use_sample_generator (bool): if True, feed through
                ``decorate_sample_generator``; otherwise batch first and use
                ``decorate_sample_list_generator``.
            iterable (bool): forwarded to ``fluid.io.PyReader``.
            drop_last (bool): whether an incomplete final batch is dropped.
        """
        image = fluid.layers.data(name='image', dtype='float32', shape=[784])
        label = fluid.layers.data(name='label', dtype='int64', shape=[1])
        py_reader = fluid.io.PyReader(
            feed_list=[image, label],
            capacity=16,
            iterable=iterable,
            use_double_buffer=False)

        batch_reader = paddle.batch(reader, self.batch_size, drop_last)
        all_datas = self.generate_all_data(batch_reader)

        # Sanity-check the reference data itself before comparing against it.
        if drop_last:
            batch_num = self.sample_num // self.batch_size
        else:
            batch_num = int(
                math.ceil(float(self.sample_num) / self.batch_size))
        self.assertEqual(len(all_datas), batch_num)

        if not use_sample_generator:
            py_reader.decorate_sample_list_generator(
                batch_reader, places=fluid.cpu_places())
        else:
            py_reader.decorate_sample_generator(
                reader, self.batch_size, drop_last, places=fluid.cpu_places())

        exe = fluid.Executor(fluid.CPUPlace())
        exe.run(fluid.default_startup_program())
        for _ in range(self.epoch_num):
            if py_reader.iterable:
                step = 0
                for data in py_reader():
                    img, lbl = exe.run(feed=data, fetch_list=[image, label])
                    self.assertArrayEqual(img, all_datas[step][0])
                    self.assertArrayEqual(lbl, all_datas[step][1])
                    step += 1
                self.assertEqual(step, len(all_datas))
            else:
                step = 0
                try:
                    py_reader.start()
                    while True:
                        img, lbl = exe.run(fetch_list=[image, label])
                        self.assertArrayEqual(img, all_datas[step][0])
                        self.assertArrayEqual(lbl, all_datas[step][1])
                        step += 1
                except fluid.core.EOFException:
                    # End of one pass: reset so the next epoch can start().
                    # No `break` here -- it would leave the epoch loop and
                    # skip every epoch after the first.
                    py_reader.reset()
                    self.assertEqual(step, len(all_datas))

    def assertArrayEqual(self, arr1, arr2):
        """Assert two arrays have identical shape and identical elements."""
        self.assertEqual(arr1.shape, arr2.shape)
        self.assertTrue((arr1 == arr2).all())

    def test_main(self):
        """Exercise every combination of feeding mode, iterable and drop_last."""
        reader = random_reader(self.sample_num)
        for use_sample_generator in [False, True]:
            for iterable in [False, True]:
                for drop_last in [False, True]:
                    # Fresh programs per combination so layers do not collide.
                    with fluid.program_guard(fluid.Program(), fluid.Program()):
                        self.run_main(reader, use_sample_generator, iterable,
                                      drop_last)


class TestCase1(TestCaseBase):
    """Sample count is an exact multiple of the batch size; more epochs."""

    def setUp(self):
        self.batch_size = 32
        self.epoch_num = 10
        self.sample_num = 160


class TestCase2(TestCaseBase):
    """Sample count leaves a partial final batch of 8."""

    def setUp(self):
        self.batch_size = 32
        self.epoch_num = 2
        self.sample_num = 200


class TestCase3(TestCaseBase):
    """Sample count is one short of a multiple of the batch size."""

    def setUp(self):
        self.batch_size = 32
        self.epoch_num = 2
        self.sample_num = 159


if __name__ == '__main__':
    unittest.main()
def cache(reader):
    """
    Read a reader fully into memory and replay the stored samples.

    :code:`reader()` is called exactly once, at the moment this function
    is called, and everything it yields is kept in memory. This can be
    slow and memory-hungry for large datasets.

    Args:
        reader (callable): a reader creator; calling it returns an
            iterable of samples.

    Returns:
        callable: a reader creator which, on each call, returns a new
        generator over the cached samples.
    """
    cached_samples = tuple(reader())

    def __impl__():
        return (sample for sample in cached_samples)

    return __impl__