From 9598ebf83cfaef80a3178cf5d3fd02747d583dd7 Mon Sep 17 00:00:00 2001 From: barrierye Date: Wed, 21 Nov 2018 17:17:14 +0800 Subject: [PATCH] update async_executor.py for support set_use_slots --- python/paddle/fluid/async_executor.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/async_executor.py b/python/paddle/fluid/async_executor.py index a6bb4ce75ce..feabfdcfa21 100644 --- a/python/paddle/fluid/async_executor.py +++ b/python/paddle/fluid/async_executor.py @@ -30,16 +30,24 @@ g_scope = core.Scope() class DataFeedDesc(object): def __init__(self, proto_file): self.proto_desc = data_feed_pb2.DataFeedDesc() - f = open(proto_file, 'r') - text_format.Parse(f.read(), self.proto_desc) - f.close() + with open(proto_file, 'r') as f: + text_format.Parse(f.read(), self.proto_desc) + self.__name_to_index = {} + for i, slot in enumerate(self.proto_desc.multi_slot_desc.slots): + self.__name_to_index[slot.name] = i + + def set_data_feed_type(self, data_feed): + self.proto_desc.name = datafeed def set_batch_size(self, batch_size): self.proto_desc.batch = batch_size - def add_slot(self): - slot = self.proto_desc.multi_slot_desc.slots.add() - return slot + def get_slot(self, name): + return self.proto_desc.multi_slot_desc.slots[self.__name_to_index[name]] + + def set_use_slots(self, use_slots_name): + for name in use_slots_name: + self.proto_desc.multi_slot_desc.slots[self.__name_to_index[name]].use = True def desc(self): return text_format.MessageToString(self.proto_desc) -- GitLab