提交 9598ebf8 编写于 作者: B barrierye

update async_executor.py for support set_use_slots

上级 623f1d46
...@@ -30,16 +30,24 @@ g_scope = core.Scope() ...@@ -30,16 +30,24 @@ g_scope = core.Scope()
class DataFeedDesc(object): class DataFeedDesc(object):
def __init__(self, proto_file): def __init__(self, proto_file):
self.proto_desc = data_feed_pb2.DataFeedDesc() self.proto_desc = data_feed_pb2.DataFeedDesc()
f = open(proto_file, 'r') with open(proto_file, 'r') as f:
text_format.Parse(f.read(), self.proto_desc) text_format.Parse(f.read(), self.proto_desc)
f.close() self.__name_to_index = {}
for i, slot in enumerate(self.proto_desc.multi_slot_desc.slots):
self.__name_to_index[slot.name] = i
def set_data_feed_type(self, data_feed):
self.proto_desc.name = datafeed
def set_batch_size(self, batch_size): def set_batch_size(self, batch_size):
self.proto_desc.batch = batch_size self.proto_desc.batch = batch_size
def add_slot(self): def get_slot(self, name):
slot = self.proto_desc.multi_slot_desc.slots.add() return self.proto_desc.multi_slot_desc.slots[self.__name_to_index[name]]
return slot
def set_use_slots(self, use_slots_name):
for name in use_slots_name:
self.proto_desc.multi_slot_desc.slots[self.__name_to_index[name]].use = True
def desc(self): def desc(self):
return text_format.MessageToString(self.proto_desc) return text_format.MessageToString(self.proto_desc)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册