diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 9247c85a5229d891feef3c0bf3d199597e5f9ffa..69b7056169aa0110980744ce282ed8ef2a6da679 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -146,7 +146,7 @@ cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator
 nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
 
 if (NOT WIN32)
-py_proto_compile(framework_py_proto SRCS framework.proto)
+py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto)
 # Generate an empty __init__.py to make framework_py_proto as a valid python module.
 add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
 add_dependencies(framework_py_proto framework_py_proto_init)
diff --git a/paddle/fluid/framework/async_executor.cc b/paddle/fluid/framework/async_executor.cc
index f49290f1aa7e24c21713e0fcf171b7ad235f9a5a..639b546ff394c3394f6c2d27396a208fde573eea 100644
--- a/paddle/fluid/framework/async_executor.cc
+++ b/paddle/fluid/framework/async_executor.cc
@@ -139,12 +139,14 @@ void AsyncExecutor::SetModelPrefix(const std::string& model_prefix) {
 
 std::vector<float> AsyncExecutor::RunFromFile(
     const ProgramDesc& main_program,
-    const DataFeedDesc& data_feed_desc,
+    const std::string& data_feed_desc_str,
     const std::vector<std::string>& filelist,
    const int thread_num,
     const std::vector<std::string>& fetch_var_names) {
   std::vector<std::thread> threads;
 
+  DataFeedDesc data_feed_desc;
+  google::protobuf::TextFormat::ParseFromString(data_feed_desc_str, &data_feed_desc);
   /*
     readerDesc: protobuf description for reader initlization
     argument: class_name, batch_size, use_slot, queue_size, buffer_size, padding_index
diff --git a/paddle/fluid/framework/async_executor.h b/paddle/fluid/framework/async_executor.h
index 89aa8efd9c15d7766fef3b5bf2e2effdd3d4462e..6bbfad98d3bbf83efd7263046409bec3d40e4f9d 100644
--- a/paddle/fluid/framework/async_executor.h
+++ b/paddle/fluid/framework/async_executor.h
@@ -55,7 +55,7 @@ class AsyncExecutor {
   void SetModelPrefix(const std::string& model_prefix);
   void RunStartupProgram(const ProgramDesc& program, Scope* scope);
   std::vector<float> RunFromFile(const ProgramDesc& main_program,
-                                 const DataFeedDesc& data_feed_desc,
+                                 const std::string& data_feed_desc_str,
                                  const std::vector<std::string>& filelist,
                                  const int thread_num,
                                  const std::vector<std::string>& fetch_names);
diff --git a/paddle/fluid/framework/data_feed.proto b/paddle/fluid/framework/data_feed.proto
index 88e576b1907b8b4eb9c062fea2d01ed448ea0aa0..d6ae0002dfce925ec1bf054a935291cfa614ad8b 100644
--- a/paddle/fluid/framework/data_feed.proto
+++ b/paddle/fluid/framework/data_feed.proto
@@ -17,6 +17,17 @@ package paddle.framework;
 message DataFeedDesc {
   optional string name = 1;
   optional int32 batch = 2 [default = 32];
-  repeated string field_names = 3;
+  optional MultiSlotDesc multi_slot_desc = 3;
+}
+
+message MultiSlotDesc {
+  repeated Slot slots = 1;
+}
+
+message Slot {
+  required string name = 1;
+  required string type = 2;
+  optional bool dense = 3 [default = false];
+  optional bool use = 4 [default = true];
 }
diff --git a/paddle/fluid/pybind/async_executor_py.cc b/paddle/fluid/pybind/async_executor_py.cc
index 2f7bae06289210769fccef39672e5cf08d5d8482..729b5aa1d0331f87d0ccad0534752b1f6c0c291d 100644
--- a/paddle/fluid/pybind/async_executor_py.cc
+++ b/paddle/fluid/pybind/async_executor_py.cc
@@ -41,17 +41,6 @@ namespace paddle {
 namespace pybind {
 using set_name_func = void (pd::DataFeedDesc::*)(const std::string&);
 void BindAsyncExecutor(py::module* m) {
-  py::class_<pd::DataFeedDesc>(*m, "DataFeedDesc")
-      .def(pybind11::init<>())
-      .def("set_name", (set_name_func)&pd::DataFeedDesc::set_name)
-      .def("set_batch", &pd::DataFeedDesc::set_batch)
-      .def("set_field_names",
-           [] (pd::DataFeedDesc& self, const std::vector<std::string>& fields) {
-             for (auto field : fields) {
-               self.add_field_names(field);
-             }
-           });
-
   py::class_<framework::AsyncExecutor>(*m, "AsyncExecutor")
       .def(py::init())
       .def("run_from_files", &framework::AsyncExecutor::RunFromFile)
diff --git a/python/paddle/fluid/async_executor.py b/python/paddle/fluid/async_executor.py
index 9af7f97f18b873a6053de00080cee39be784d963..a6bb4ce75ce252c5167a89ddced2c4cf676efd3b 100644
--- a/python/paddle/fluid/async_executor.py
+++ b/python/paddle/fluid/async_executor.py
@@ -20,25 +20,29 @@ import six
 from .framework import Program, default_main_program, Variable
 from . import core
 from .executor import global_scope
+from paddle.fluid.proto import data_feed_pb2
+from google.protobuf import text_format
 
-__all__ = ['MultiSlotDataFeed', 'AsyncExecutor']
+__all__ = ['DataFeedDesc', 'AsyncExecutor']
 
 g_scope = core.Scope()
 
 
 class DataFeedDesc(object):
-    def __init__(self):
-        self.desc = core.DataFeedDesc()
+    def __init__(self, proto_file):
+        self.proto_desc = data_feed_pb2.DataFeedDesc()
+        f = open(proto_file, 'r')
+        text_format.Parse(f.read(), self.proto_desc)
+        f.close()
+
     def set_batch_size(self, batch_size):
-        self.desc.set_batch(batch_size)
-    def set_field_name(self, field_names):
-        if isinstance(field_names, str):
-            field_names = [field_names]
-        self.desc.set_field_names(field_names)
+        self.proto_desc.batch = batch_size
+
+    def add_slot(self):
+        slot = self.proto_desc.multi_slot_desc.slots.add()
+        return slot
 
-class MultiSlotDataFeed(DataFeedDesc):
-    def __init__(self):
-        super(MultiSlotDataFeed, self).__init__()
-        self.desc.set_name("MultiSlotDataFeed")
+    def desc(self):
+        return text_format.MessageToString(self.proto_desc)
 
 
 class AsyncExecutor(object):
     """
@@ -127,6 +131,6 @@ class AsyncExecutor(object):
             fetch = [fetch]
         fetch_var_names = [var.name for var in fetch]
 
-        evaluation = self.executor.run_from_files(program_desc, data_feed.desc, filelist, thread_num, fetch_var_names)
+        evaluation = self.executor.run_from_files(program_desc, data_feed.desc(), filelist, thread_num, fetch_var_names)
 
         return evaluation
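Note (illustrative, not part of the patch): a minimal sketch of how the reworked Python-side DataFeedDesc is meant to be driven after this change. The file name data_feed.prototxt, the slot names and types, and the direct import from paddle.fluid.async_executor are assumptions made for the example; DataFeedDesc(proto_file), set_batch_size(), desc() and the field names come from the updated data_feed.proto and async_executor.py in the diff above.

    # Sketch only -- the module path, file name and slot definitions below are
    # assumed, not prescribed by this patch; the field names (name, batch,
    # multi_slot_desc, slots, type, dense, use) come from the new data_feed.proto.
    from paddle.fluid.async_executor import DataFeedDesc

    # Hypothetical text-format DataFeedDesc, written to disk so DataFeedDesc can
    # parse it the same way its new __init__ does (text_format.Parse on the file).
    example_proto = '''
    name: "MultiSlotDataFeed"
    batch: 32
    multi_slot_desc {
      slots {
        name: "words"
        type: "uint64"
        dense: false
        use: true
      }
      slots {
        name: "label"
        type: "uint64"
        dense: false
        use: true
      }
    }
    '''

    with open("data_feed.prototxt", "w") as f:
        f.write(example_proto)

    data_feed = DataFeedDesc("data_feed.prototxt")  # parses the file into data_feed_pb2.DataFeedDesc
    data_feed.set_batch_size(128)                   # overrides the 'batch' field read from the file
    print(data_feed.desc())                         # serialized text-format string

On the C++ side this string arrives in RunFromFile as data_feed_desc_str and is re-parsed with google::protobuf::TextFormat::ParseFromString, which is why the Python wrapper passes data_feed.desc() (the serialized text) rather than the DataFeedDesc object itself.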