提交 2e72ebfa 编写于 作者: B barrierye

add MultiSlotDesc

上级 c083c4bc
...@@ -23,35 +23,37 @@ from .executor import global_scope, Executor ...@@ -23,35 +23,37 @@ from .executor import global_scope, Executor
from paddle.fluid.proto import data_feed_pb2 from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format from google.protobuf import text_format
__all__ = ['DataFeedDesc', 'AsyncExecutor'] __all__ = ['MultiSlotDesc', 'AsyncExecutor']
g_scope = core.Scope() g_scope = core.Scope()
class DataFeedDesc(object): class DataFeedDesc(object):
def __init__(self, proto_file): def __init__(self, proto_file):
self.proto_desc = data_feed_pb2.DataFeedDesc() self._proto_desc = data_feed_pb2.DataFeedDesc()
with open(proto_file, 'r') as f: with open(proto_file, 'r') as f:
text_format.Parse(f.read(), self.proto_desc) text_format.Parse(f.read(), self._proto_desc)
self.__name_to_index = {}
for i, slot in enumerate(self.proto_desc.multi_slot_desc.slots):
self.__name_to_index[slot.name] = i
def set_data_feed_type(self, data_feed):
self.proto_desc.name = datafeed
def set_batch_size(self, batch_size): def set_batch_size(self, batch_size):
self.proto_desc.batch = batch_size self._proto_desc.batch = batch_size
def desc(self):
return text_format.MessageToString(self._proto_desc)
class MultiSlotDesc(DataFeedDesc):
def __init__(self, proto_file):
super(MultiSlotDesc, self).__init__(proto_file)
if self._proto_desc.name != "MultiSlotDataFeed":
raise ValueError("The DataFeed name in proto is %s, not MultiSlotDataFeed" % self._proto_desc.name)
self.__name_to_index = {slot.name: i for i, slot in enumerate(self._proto_desc.multi_slot_desc.slots)}
def set_dense_slots(self, dense_slots_name): def set_dense_slots(self, dense_slots_name):
for name in dense_slots_name: for name in dense_slots_name:
self.proto_desc.multi_slot_desc.slots[self.__name_to_index[name]].dense = True self._proto_desc.multi_slot_desc.slots[self.__name_to_index[name]].dense = True
def set_use_slots(self, use_slots_name): def set_use_slots(self, use_slots_name):
for name in use_slots_name: for name in use_slots_name:
self.proto_desc.multi_slot_desc.slots[self.__name_to_index[name]].use = True self._proto_desc.multi_slot_desc.slots[self.__name_to_index[name]].use = True
def desc(self):
return text_format.MessageToString(self.proto_desc)
class AsyncExecutor(object): class AsyncExecutor(object):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册