提交 c46cecd7 编写于 作者: W wangguibao

Create DataFeedDesc from a .proto file, then manipulate it (add/delete fields, etc.) from the Python side

上级 91fc8f35
......@@ -146,7 +146,7 @@ cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
if (NOT WIN32)
py_proto_compile(framework_py_proto SRCS framework.proto)
py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init)
......
......@@ -139,12 +139,14 @@ void AsyncExecutor::SetModelPrefix(const std::string& model_prefix) {
std::vector<float> AsyncExecutor::RunFromFile(
const ProgramDesc& main_program,
const DataFeedDesc& data_feed_desc,
const std::string& data_feed_desc_str,
const std::vector<std::string>& filelist,
const int thread_num,
const std::vector<std::string>& fetch_var_names) {
std::vector<std::thread> threads;
DataFeedDesc data_feed_desc;
google::protobuf::TextFormat::ParseFromString(data_feed_desc_str, &data_feed_desc);
/*
readerDesc: protobuf description for reader initialization
argument: class_name, batch_size, use_slot, queue_size, buffer_size, padding_index
......
......@@ -55,7 +55,7 @@ class AsyncExecutor {
void SetModelPrefix(const std::string& model_prefix);
void RunStartupProgram(const ProgramDesc& program, Scope* scope);
std::vector<float> RunFromFile(const ProgramDesc& main_program,
const DataFeedDesc& data_feed_desc,
const std::string& data_feed_desc_str,
const std::vector<std::string>& filelist,
const int thread_num,
const std::vector<std::string>& fetch_names);
......
......@@ -17,6 +17,17 @@ package paddle.framework;
message DataFeedDesc {
optional string name = 1;
optional int32 batch = 2 [default = 32];
repeated string field_names = 3;
optional MultiSlotDesc multi_slot_desc = 3;
}
message MultiSlotDesc {
repeated Slot slots = 1;
}
message Slot {
required string name = 1;
required string type = 2;
optional bool dense = 3 [default = false];
optional bool use = 4 [default = true];
}
......@@ -41,17 +41,6 @@ namespace paddle {
namespace pybind {
using set_name_func = void (pd::DataFeedDesc::*)(const std::string&);
void BindAsyncExecutor(py::module* m) {
py::class_<pd::DataFeedDesc>(*m, "DataFeedDesc")
.def(pybind11::init<>())
.def("set_name", (set_name_func)&pd::DataFeedDesc::set_name)
.def("set_batch", &pd::DataFeedDesc::set_batch)
.def("set_field_names",
[] (pd::DataFeedDesc& self, const std::vector<std::string> &fields) {
for (auto field : fields) {
self.add_field_names(field);
}
});
py::class_<framework::AsyncExecutor>(*m, "AsyncExecutor")
.def(py::init<pd::Scope&, const platform::Place&>())
.def("run_from_files", &framework::AsyncExecutor::RunFromFile)
......
......@@ -20,25 +20,29 @@ import six
from .framework import Program, default_main_program, Variable
from . import core
from .executor import global_scope
from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format
__all__ = ['MultiSlotDataFeed', 'AsyncExecutor']
__all__ = ['DataFeedDesc', 'AsyncExecutor']
g_scope = core.Scope()
class DataFeedDesc(object):
def __init__(self):
self.desc = core.DataFeedDesc()
def __init__(self, proto_file):
self.proto_desc = data_feed_pb2.DataFeedDesc()
f = open(proto_file, 'r')
text_format.Parse(f.read(), self.proto_desc)
f.close()
def set_batch_size(self, batch_size):
self.desc.set_batch(batch_size)
def set_field_name(self, field_names):
if isinstance(field_names, str):
field_names = [field_names]
self.desc.set_field_names(field_names)
self.proto_desc.batch = batch_size
def add_slot(self):
slot = self.proto_desc.multi_slot_desc.slots.add()
return slot
class MultiSlotDataFeed(DataFeedDesc):
def __init__(self):
super(MultiSlotDataFeed, self).__init__()
self.desc.set_name("MultiSlotDataFeed")
def desc(self):
return text_format.MessageToString(self.proto_desc)
class AsyncExecutor(object):
"""
......@@ -127,6 +131,6 @@ class AsyncExecutor(object):
fetch = [fetch]
fetch_var_names = [var.name for var in fetch]
evaluation = self.executor.run_from_files(program_desc, data_feed.desc, filelist, thread_num, fetch_var_names)
evaluation = self.executor.run_from_files(program_desc, data_feed, filelist, thread_num, fetch_var_names)
return evaluation
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册