提交 fea43077 编写于 作者: Y Yu Yang

Refine

上级 a305cb21
......@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace detail {
......
cc_library(reader_op_registry SRCS reader_op_registry.cc DEPS operator op_registry reader)
op_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc DEPS reader_op_registry)
op_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc DEPS reader_op_registry)
op_library(create_batch_reader_op SRCS create_batch_reader_op.cc DEPS reader_op_registry)
op_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc DEPS reader_op_registry)
set(READER_LIBRARY
create_recordio_file_reader_op
create_random_data_generator_op
create_shuffle_reader_op
create_batch_reader_op
PARENT_SCOPE)
set(LOCAL_READER_LIBS)
function(reader_library TARGET_NAME)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
set(options "")
set(common_deps reader_op_registry)
cmake_parse_arguments(reader_library "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN})
op_library(${TARGET_NAME} SRCS ${reader_library_SRCS} DEPS ${common_deps} ${reader_library_DEPS})
set(LOCAL_READER_LIBS
${TARGET_NAME}
${LOCAL_READER_LIBS}
PARENT_SCOPE)
endfunction()
reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
# Export local libraries to parent
set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
......@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/prune.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/cond_op.h"
#include "paddle/fluid/operators/net_op.h"
......@@ -219,8 +220,18 @@ All parameter, weight, gradient are variables in Paddle.
[](Variable &self) -> operators::NetOp * {
return self.GetMutable<operators::NetOp>();
},
py::return_value_policy::reference)
.def("get_reader",
[](Variable &self) -> framework::ReaderHolder * {
PADDLE_ENFORCE(self.IsType<framework::ReaderHolder>());
return self.GetMutable<framework::ReaderHolder>();
},
py::return_value_policy::reference);
py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("has_next", &framework::ReaderHolder::HasNext)
.def("reset", &framework::ReaderHolder::ReInit);
py::class_<Scope>(m, "Scope", "")
.def("var",
[](Scope &self, const std::string &name) -> Variable * {
......
......@@ -16,6 +16,7 @@
#include <fstream>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/writer.h"
namespace paddle {
namespace pybind {
......
......@@ -47,7 +47,7 @@ def is_parameter(var):
def is_persistable(var):
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST:
var.desc.type() == core.VarDesc.VarType.FETCH_LIST:
return False
return var.persistable
......
......@@ -17,6 +17,7 @@ from ..framework import convert_np_dtype_to_dtype_, default_main_program, defaul
from ..unique_name import generate as unique_name
from control_flow import BlockGuard
from ..layer_helper import LayerHelper
from ..executor import global_scope
__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
......@@ -230,12 +231,29 @@ def Recv(endpoints, get_vars):
"epmap": epmap})
def monkey_patch_reader_methods(reader):
def __get_reader__():
scope = global_scope()
var = scope.find_var(reader.name)
return var.get_reader()
def eof():
return not __get_reader__().has_next()
def reset():
return __get_reader__().reset()
reader.eof = eof
reader.reset = reset
return reader
def _copy_reader_var_(block, var):
new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER)
new_var.desc.set_shapes(var.desc.shapes())
new_var.desc.set_dtypes(var.desc.dtypes())
new_var.persistable = True
return new_var
return monkey_patch_reader_methods(new_var)
def open_recordio_file(filename, shapes, lod_levels, dtypes):
......
......@@ -54,7 +54,16 @@ class TestRecordIO(unittest.TestCase):
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
avg_loss_np = []
for i in xrange(100): # train 100 mini-batch
tmp, = exe.run(fetch_list=[avg_loss])
avg_loss_np.append(tmp)
for i in xrange(2): # 2 pass
batch_id = 0
while not data_file.eof():
try:
batch_id += 1
tmp, = exe.run(fetch_list=[avg_loss])
avg_loss_np.append(tmp)
except:
print batch_id
break
data_file.reset()
self.assertLess(avg_loss_np[-1], avg_loss_np[0])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册