diff --git a/python/paddle/fluid/async_executor.py b/python/paddle/fluid/async_executor.py index a6bb4ce75ce252c5167a89ddced2c4cf676efd3b..feabfdcfa21e6dc9d52c6ff174afeb0024972a1d 100644 --- a/python/paddle/fluid/async_executor.py +++ b/python/paddle/fluid/async_executor.py @@ -30,16 +30,24 @@ g_scope = core.Scope() class DataFeedDesc(object): def __init__(self, proto_file): self.proto_desc = data_feed_pb2.DataFeedDesc() - f = open(proto_file, 'r') - text_format.Parse(f.read(), self.proto_desc) - f.close() + with open(proto_file, 'r') as f: + text_format.Parse(f.read(), self.proto_desc) + self.__name_to_index = {} + for i, slot in enumerate(self.proto_desc.multi_slot_desc.slots): + self.__name_to_index[slot.name] = i + + def set_data_feed_type(self, data_feed): + self.proto_desc.name = datafeed def set_batch_size(self, batch_size): self.proto_desc.batch = batch_size - def add_slot(self): - slot = self.proto_desc.multi_slot_desc.slots.add() - return slot + def get_slot(self, name): + return self.proto_desc.multi_slot_desc.slots[self.__name_to_index[name]] + + def set_use_slots(self, use_slots_name): + for name in use_slots_name: + self.proto_desc.multi_slot_desc.slots[self.__name_to_index[name]].use = True def desc(self): return text_format.MessageToString(self.proto_desc)