提交 fea43077 编写于 作者: Y Yu Yang

Refine

上级 a305cb21
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace detail {
......
cc_library(reader_op_registry SRCS reader_op_registry.cc DEPS operator op_registry reader) cc_library(reader_op_registry SRCS reader_op_registry.cc DEPS operator op_registry reader)
op_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc DEPS reader_op_registry)
op_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc DEPS reader_op_registry) set(LOCAL_READER_LIBS)
op_library(create_batch_reader_op SRCS create_batch_reader_op.cc DEPS reader_op_registry)
op_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc DEPS reader_op_registry) function(reader_library TARGET_NAME)
set(READER_LIBRARY set(oneValueArgs "")
create_recordio_file_reader_op set(multiValueArgs SRCS DEPS)
create_random_data_generator_op set(options "")
create_shuffle_reader_op set(common_deps reader_op_registry)
create_batch_reader_op cmake_parse_arguments(reader_library "${options}" "${oneValueArgs}"
PARENT_SCOPE) "${multiValueArgs}" ${ARGN})
op_library(${TARGET_NAME} SRCS ${reader_library_SRCS} DEPS ${common_deps} ${reader_library_DEPS})
set(LOCAL_READER_LIBS
${TARGET_NAME}
${LOCAL_READER_LIBS}
PARENT_SCOPE)
endfunction()
reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
# Export local libraries to parent
set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
...@@ -26,6 +26,7 @@ limitations under the License. */ ...@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/prune.h" #include "paddle/fluid/framework/prune.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/cond_op.h" #include "paddle/fluid/operators/cond_op.h"
#include "paddle/fluid/operators/net_op.h" #include "paddle/fluid/operators/net_op.h"
...@@ -219,8 +220,18 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -219,8 +220,18 @@ All parameter, weight, gradient are variables in Paddle.
[](Variable &self) -> operators::NetOp * { [](Variable &self) -> operators::NetOp * {
return self.GetMutable<operators::NetOp>(); return self.GetMutable<operators::NetOp>();
}, },
py::return_value_policy::reference)
.def("get_reader",
[](Variable &self) -> framework::ReaderHolder * {
PADDLE_ENFORCE(self.IsType<framework::ReaderHolder>());
return self.GetMutable<framework::ReaderHolder>();
},
py::return_value_policy::reference); py::return_value_policy::reference);
py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("has_next", &framework::ReaderHolder::HasNext)
.def("reset", &framework::ReaderHolder::ReInit);
py::class_<Scope>(m, "Scope", "") py::class_<Scope>(m, "Scope", "")
.def("var", .def("var",
[](Scope &self, const std::string &name) -> Variable * { [](Scope &self, const std::string &name) -> Variable * {
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <fstream> #include <fstream>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/writer.h" #include "paddle/fluid/recordio/writer.h"
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
......
...@@ -47,7 +47,7 @@ def is_parameter(var): ...@@ -47,7 +47,7 @@ def is_parameter(var):
def is_persistable(var): def is_persistable(var):
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST: var.desc.type() == core.VarDesc.VarType.FETCH_LIST:
return False return False
return var.persistable return var.persistable
......
...@@ -17,6 +17,7 @@ from ..framework import convert_np_dtype_to_dtype_, default_main_program, defaul ...@@ -17,6 +17,7 @@ from ..framework import convert_np_dtype_to_dtype_, default_main_program, defaul
from ..unique_name import generate as unique_name from ..unique_name import generate as unique_name
from control_flow import BlockGuard from control_flow import BlockGuard
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..executor import global_scope
__all__ = [ __all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
...@@ -230,12 +231,29 @@ def Recv(endpoints, get_vars): ...@@ -230,12 +231,29 @@ def Recv(endpoints, get_vars):
"epmap": epmap}) "epmap": epmap})
def monkey_patch_reader_methods(reader):
def __get_reader__():
scope = global_scope()
var = scope.find_var(reader.name)
return var.get_reader()
def eof():
return not __get_reader__().has_next()
def reset():
return __get_reader__().reset()
reader.eof = eof
reader.reset = reset
return reader
def _copy_reader_var_(block, var): def _copy_reader_var_(block, var):
new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER) new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER)
new_var.desc.set_shapes(var.desc.shapes()) new_var.desc.set_shapes(var.desc.shapes())
new_var.desc.set_dtypes(var.desc.dtypes()) new_var.desc.set_dtypes(var.desc.dtypes())
new_var.persistable = True new_var.persistable = True
return new_var return monkey_patch_reader_methods(new_var)
def open_recordio_file(filename, shapes, lod_levels, dtypes): def open_recordio_file(filename, shapes, lod_levels, dtypes):
......
...@@ -54,7 +54,16 @@ class TestRecordIO(unittest.TestCase): ...@@ -54,7 +54,16 @@ class TestRecordIO(unittest.TestCase):
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
avg_loss_np = [] avg_loss_np = []
for i in xrange(100): # train 100 mini-batch
tmp, = exe.run(fetch_list=[avg_loss]) for i in xrange(2): # 2 pass
avg_loss_np.append(tmp) batch_id = 0
while not data_file.eof():
try:
batch_id += 1
tmp, = exe.run(fetch_list=[avg_loss])
avg_loss_np.append(tmp)
except:
print batch_id
break
data_file.reset()
self.assertLess(avg_loss_np[-1], avg_loss_np[0]) self.assertLess(avg_loss_np[-1], avg_loss_np[0])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册