diff --git a/python/paddle/fluid/async_executor.py b/python/paddle/fluid/async_executor.py index db55e4d070ea6c36251a7f77609db3391092c6b2..bfd29f1ab7744fad38d95be89a0384aed61c9ab0 100644 --- a/python/paddle/fluid/async_executor.py +++ b/python/paddle/fluid/async_executor.py @@ -23,35 +23,37 @@ from .executor import global_scope, Executor from paddle.fluid.proto import data_feed_pb2 from google.protobuf import text_format -__all__ = ['DataFeedDesc', 'AsyncExecutor'] +__all__ = ['MultiSlotDesc', 'AsyncExecutor'] g_scope = core.Scope() class DataFeedDesc(object): def __init__(self, proto_file): - self.proto_desc = data_feed_pb2.DataFeedDesc() + self._proto_desc = data_feed_pb2.DataFeedDesc() with open(proto_file, 'r') as f: - text_format.Parse(f.read(), self.proto_desc) - self.__name_to_index = {} - for i, slot in enumerate(self.proto_desc.multi_slot_desc.slots): - self.__name_to_index[slot.name] = i - - def set_data_feed_type(self, data_feed): - self.proto_desc.name = datafeed + text_format.Parse(f.read(), self._proto_desc) def set_batch_size(self, batch_size): - self.proto_desc.batch = batch_size + self._proto_desc.batch = batch_size + + def desc(self): + return text_format.MessageToString(self._proto_desc) +class MultiSlotDesc(DataFeedDesc): + def __init__(self, proto_file): + super(MultiSlotDesc, self).__init__(proto_file) + if self._proto_desc.name != "MultiSlotDataFeed": + raise ValueError("The DataFeed name in proto is %s, not MultiSlotDataFeed" % self._proto_desc.name) + self.__name_to_index = {slot.name: i for i, slot in enumerate(self._proto_desc.multi_slot_desc.slots)} + def set_dense_slots(self, dense_slots_name): for name in dense_slots_name: - self.proto_desc.multi_slot_desc.slots[self.__name_to_index[name]].dense = True + self._proto_desc.multi_slot_desc.slots[self.__name_to_index[name]].dense = True def set_use_slots(self, use_slots_name): for name in use_slots_name: - self.proto_desc.multi_slot_desc.slots[self.__name_to_index[name]].use = True + self._proto_desc.multi_slot_desc.slots[self.__name_to_index[name]].use = True - def desc(self): - return text_format.MessageToString(self.proto_desc) class AsyncExecutor(object): """