From fea43077f6a0c2aca7915ed86a5cf56549d9369b Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 12 Mar 2018 10:49:24 +0800 Subject: [PATCH] Refine --- paddle/fluid/operators/detail/safe_ref.h | 2 ++ paddle/fluid/operators/reader/CMakeLists.txt | 34 +++++++++++++------ paddle/fluid/pybind/pybind.cc | 11 ++++++ paddle/fluid/pybind/recordio.cc | 1 + python/paddle/fluid/io.py | 2 +- python/paddle/fluid/layers/io.py | 20 ++++++++++- .../tests/unittests/test_recordio_reader.py | 15 ++++++-- 7 files changed, 70 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/detail/safe_ref.h b/paddle/fluid/operators/detail/safe_ref.h index 9cb5851deba..48bdce74087 100644 --- a/paddle/fluid/operators/detail/safe_ref.h +++ b/paddle/fluid/operators/detail/safe_ref.h @@ -14,6 +14,8 @@ limitations under the License. */ #pragma once +#include "paddle/fluid/platform/enforce.h" + namespace paddle { namespace operators { namespace detail { diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index 88a0beb46b5..9dded87a5d9 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -1,11 +1,25 @@ cc_library(reader_op_registry SRCS reader_op_registry.cc DEPS operator op_registry reader) -op_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc DEPS reader_op_registry) -op_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc DEPS reader_op_registry) -op_library(create_batch_reader_op SRCS create_batch_reader_op.cc DEPS reader_op_registry) -op_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc DEPS reader_op_registry) -set(READER_LIBRARY - create_recordio_file_reader_op - create_random_data_generator_op - create_shuffle_reader_op - create_batch_reader_op - PARENT_SCOPE) + +set(LOCAL_READER_LIBS) + +function(reader_library TARGET_NAME) + set(oneValueArgs "") + set(multiValueArgs SRCS DEPS) + set(options "") + set(common_deps reader_op_registry) + cmake_parse_arguments(reader_library "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN}) + op_library(${TARGET_NAME} SRCS ${reader_library_SRCS} DEPS ${common_deps} ${reader_library_DEPS}) + set(LOCAL_READER_LIBS + ${TARGET_NAME} + ${LOCAL_READER_LIBS} + PARENT_SCOPE) +endfunction() + +reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc) +reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) +reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) +reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc) + +# Export local libraries to parent +set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 15b99c6bd0b..d2e883caccd 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -26,6 +26,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/prune.h" +#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/operators/cond_op.h" #include "paddle/fluid/operators/net_op.h" @@ -219,8 +220,18 @@ All parameter, weight, gradient are variables in Paddle. [](Variable &self) -> operators::NetOp * { return self.GetMutable(); }, + py::return_value_policy::reference) + .def("get_reader", + [](Variable &self) -> framework::ReaderHolder * { + PADDLE_ENFORCE(self.IsType()); + return self.GetMutable(); + }, py::return_value_policy::reference); + py::class_(m, "Reader", "") + .def("has_next", &framework::ReaderHolder::HasNext) + .def("reset", &framework::ReaderHolder::ReInit); + py::class_(m, "Scope", "") .def("var", [](Scope &self, const std::string &name) -> Variable * { diff --git a/paddle/fluid/pybind/recordio.cc b/paddle/fluid/pybind/recordio.cc index 06e149787eb..16f8bfb1a2e 100644 --- a/paddle/fluid/pybind/recordio.cc +++ b/paddle/fluid/pybind/recordio.cc @@ -16,6 +16,7 @@ #include #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/recordio/writer.h" + namespace paddle { namespace pybind { diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 1817caa9427..5b888143ad3 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -47,7 +47,7 @@ def is_parameter(var): def is_persistable(var): if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ - var.desc.type() == core.VarDesc.VarType.FETCH_LIST: + var.desc.type() == core.VarDesc.VarType.FETCH_LIST: return False return var.persistable diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 641cee3bd92..f1b2af70205 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -17,6 +17,7 @@ from ..framework import convert_np_dtype_to_dtype_, default_main_program, defaul from ..unique_name import generate as unique_name from control_flow import BlockGuard from ..layer_helper import LayerHelper +from ..executor import global_scope __all__ = [ 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', @@ -230,12 +231,29 @@ def Recv(endpoints, get_vars): "epmap": epmap}) +def monkey_patch_reader_methods(reader): + def __get_reader__(): + scope = global_scope() + var = scope.find_var(reader.name) + return var.get_reader() + + def eof(): + return not __get_reader__().has_next() + + def reset(): + return __get_reader__().reset() + + reader.eof = eof + reader.reset = reset + return reader + + def _copy_reader_var_(block, var): new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER) new_var.desc.set_shapes(var.desc.shapes()) new_var.desc.set_dtypes(var.desc.dtypes()) new_var.persistable = True - return new_var + return monkey_patch_reader_methods(new_var) def open_recordio_file(filename, shapes, lod_levels, dtypes): diff --git a/python/paddle/fluid/tests/unittests/test_recordio_reader.py b/python/paddle/fluid/tests/unittests/test_recordio_reader.py index 6ec833f6c1a..7844d46320e 100644 --- a/python/paddle/fluid/tests/unittests/test_recordio_reader.py +++ b/python/paddle/fluid/tests/unittests/test_recordio_reader.py @@ -54,7 +54,16 @@ class TestRecordIO(unittest.TestCase): exe = fluid.Executor(fluid.CPUPlace()) exe.run(fluid.default_startup_program()) avg_loss_np = [] - for i in xrange(100): # train 100 mini-batch - tmp, = exe.run(fetch_list=[avg_loss]) - avg_loss_np.append(tmp) + + for i in xrange(2): # 2 pass + batch_id = 0 + while not data_file.eof(): + try: + batch_id += 1 + tmp, = exe.run(fetch_list=[avg_loss]) + avg_loss_np.append(tmp) + except: + print batch_id + break + data_file.reset() self.assertLess(avg_loss_np[-1], avg_loss_np[0]) -- GitLab