提交 c46cecd7 编写于 作者: W wangguibao

Create DataFeedDesc from a .proto file, then manipulate it (add/delete fields, etc.) from the Python side

上级 91fc8f35
...@@ -146,7 +146,7 @@ cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator ...@@ -146,7 +146,7 @@ cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
if (NOT WIN32) if (NOT WIN32)
py_proto_compile(framework_py_proto SRCS framework.proto) py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module. # Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init) add_dependencies(framework_py_proto framework_py_proto_init)
......
...@@ -139,12 +139,14 @@ void AsyncExecutor::SetModelPrefix(const std::string& model_prefix) { ...@@ -139,12 +139,14 @@ void AsyncExecutor::SetModelPrefix(const std::string& model_prefix) {
std::vector<float> AsyncExecutor::RunFromFile( std::vector<float> AsyncExecutor::RunFromFile(
const ProgramDesc& main_program, const ProgramDesc& main_program,
const DataFeedDesc& data_feed_desc, const std::string& data_feed_desc_str,
const std::vector<std::string>& filelist, const std::vector<std::string>& filelist,
const int thread_num, const int thread_num,
const std::vector<std::string>& fetch_var_names) { const std::vector<std::string>& fetch_var_names) {
std::vector<std::thread> threads; std::vector<std::thread> threads;
DataFeedDesc data_feed_desc;
google::protobuf::TextFormat::ParseFromString(data_feed_desc_str, &data_feed_desc);
/* /*
readerDesc: protobuf description for reader initialization readerDesc: protobuf description for reader initialization
argument: class_name, batch_size, use_slot, queue_size, buffer_size, padding_index argument: class_name, batch_size, use_slot, queue_size, buffer_size, padding_index
......
...@@ -55,7 +55,7 @@ class AsyncExecutor { ...@@ -55,7 +55,7 @@ class AsyncExecutor {
void SetModelPrefix(const std::string& model_prefix); void SetModelPrefix(const std::string& model_prefix);
void RunStartupProgram(const ProgramDesc& program, Scope* scope); void RunStartupProgram(const ProgramDesc& program, Scope* scope);
std::vector<float> RunFromFile(const ProgramDesc& main_program, std::vector<float> RunFromFile(const ProgramDesc& main_program,
const DataFeedDesc& data_feed_desc, const std::string& data_feed_desc_str,
const std::vector<std::string>& filelist, const std::vector<std::string>& filelist,
const int thread_num, const int thread_num,
const std::vector<std::string>& fetch_names); const std::vector<std::string>& fetch_names);
......
...@@ -17,6 +17,17 @@ package paddle.framework; ...@@ -17,6 +17,17 @@ package paddle.framework;
message DataFeedDesc { message DataFeedDesc {
optional string name = 1; optional string name = 1;
optional int32 batch = 2 [default = 32]; optional int32 batch = 2 [default = 32];
repeated string field_names = 3; optional MultiSlotDesc multi_slot_desc = 3;
}
message MultiSlotDesc {
repeated Slot slots = 1;
}
message Slot {
required string name = 1;
required string type = 2;
optional bool dense = 3 [default = false];
optional bool use = 4 [default = true];
} }
...@@ -41,17 +41,6 @@ namespace paddle { ...@@ -41,17 +41,6 @@ namespace paddle {
namespace pybind { namespace pybind {
using set_name_func = void (pd::DataFeedDesc::*)(const std::string&); using set_name_func = void (pd::DataFeedDesc::*)(const std::string&);
void BindAsyncExecutor(py::module* m) { void BindAsyncExecutor(py::module* m) {
py::class_<pd::DataFeedDesc>(*m, "DataFeedDesc")
.def(pybind11::init<>())
.def("set_name", (set_name_func)&pd::DataFeedDesc::set_name)
.def("set_batch", &pd::DataFeedDesc::set_batch)
.def("set_field_names",
[] (pd::DataFeedDesc& self, const std::vector<std::string> &fields) {
for (auto field : fields) {
self.add_field_names(field);
}
});
py::class_<framework::AsyncExecutor>(*m, "AsyncExecutor") py::class_<framework::AsyncExecutor>(*m, "AsyncExecutor")
.def(py::init<pd::Scope&, const platform::Place&>()) .def(py::init<pd::Scope&, const platform::Place&>())
.def("run_from_files", &framework::AsyncExecutor::RunFromFile) .def("run_from_files", &framework::AsyncExecutor::RunFromFile)
......
...@@ -20,25 +20,29 @@ import six ...@@ -20,25 +20,29 @@ import six
from .framework import Program, default_main_program, Variable from .framework import Program, default_main_program, Variable
from . import core from . import core
from .executor import global_scope from .executor import global_scope
from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format
__all__ = ['MultiSlotDataFeed', 'AsyncExecutor'] __all__ = ['DataFeedDesc', 'AsyncExecutor']
g_scope = core.Scope() g_scope = core.Scope()
class DataFeedDesc(object): class DataFeedDesc(object):
def __init__(self): def __init__(self, proto_file):
self.desc = core.DataFeedDesc() self.proto_desc = data_feed_pb2.DataFeedDesc()
f = open(proto_file, 'r')
text_format.Parse(f.read(), self.proto_desc)
f.close()
def set_batch_size(self, batch_size): def set_batch_size(self, batch_size):
self.desc.set_batch(batch_size) self.proto_desc.batch = batch_size
def set_field_name(self, field_names):
if isinstance(field_names, str): def add_slot(self):
field_names = [field_names] slot = self.proto_desc.multi_slot_desc.slots.add()
self.desc.set_field_names(field_names) return slot
class MultiSlotDataFeed(DataFeedDesc): def desc(self):
def __init__(self): return text_format.MessageToString(self.proto_desc)
super(MultiSlotDataFeed, self).__init__()
self.desc.set_name("MultiSlotDataFeed")
class AsyncExecutor(object): class AsyncExecutor(object):
""" """
...@@ -127,6 +131,6 @@ class AsyncExecutor(object): ...@@ -127,6 +131,6 @@ class AsyncExecutor(object):
fetch = [fetch] fetch = [fetch]
fetch_var_names = [var.name for var in fetch] fetch_var_names = [var.name for var in fetch]
evaluation = self.executor.run_from_files(program_desc, data_feed.desc, filelist, thread_num, fetch_var_names) evaluation = self.executor.run_from_files(program_desc, data_feed, filelist, thread_num, fetch_var_names)
return evaluation return evaluation
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册