未验证 提交 45faf9bd 作者: Yibing Liu 提交者: GitHub

Optimize knowledge receiving in pantheon student (#182)

上级 41e40edb
...@@ -16,8 +16,10 @@ import six ...@@ -16,8 +16,10 @@ import six
import time import time
if six.PY2: if six.PY2:
import cPickle as pickle import cPickle as pickle
import Queue
else: else:
import pickle import pickle
import queue as Queue
import numpy as np import numpy as np
from collections import OrderedDict from collections import OrderedDict
...@@ -54,7 +56,7 @@ class Student(object): ...@@ -54,7 +56,7 @@ class Student(object):
self._common_schema = merge_strategy.keys() if merge_strategy else [] self._common_schema = merge_strategy.keys() if merge_strategy else []
self._knowledge_desc = OrderedDict() self._knowledge_desc = OrderedDict()
self._knowledge_queue = Manager().Queue(100) self._knowledge_queue = Queue.Queue(100)
self._teacher_knowledge_queues = [] self._teacher_knowledge_queues = []
self._t2s_queues = [] self._t2s_queues = []
self._s2t_queues = [] self._s2t_queues = []
...@@ -180,9 +182,9 @@ class Student(object): ...@@ -180,9 +182,9 @@ class Student(object):
out_queue.put(EndSignal()) out_queue.put(EndSignal())
out_queue.join() out_queue.join()
knowledge_queue = Manager().Queue(100) knowledge_queue = Queue.Queue(100)
cmd_queue = Manager().Queue(5) cmd_queue = Queue.Queue(5)
p = Process( p = Thread(
target=read_offline, target=read_offline,
args=(in_path, cmd_queue, knowledge_queue)) args=(in_path, cmd_queue, knowledge_queue))
p.daemon = True p.daemon = True
...@@ -372,6 +374,8 @@ class Student(object): ...@@ -372,6 +374,8 @@ class Student(object):
return first, second return first, second
def concat_batches(batches): def concat_batches(batches):
if len(batches) == 1:
return batches[0]
keys = batches[0].keys() keys = batches[0].keys()
ret_batch = {} ret_batch = {}
for key in keys: for key in keys:
...@@ -379,52 +383,65 @@ class Student(object): ...@@ -379,52 +383,65 @@ class Student(object):
[batches[i][key] for i in range(len(batches))]) [batches[i][key] for i in range(len(batches))])
return ret_batch return ret_batch
def listen(queues, out_queue): def listen(in_queue, out_queue, batch_size):
def data_receiver(queue, batch_size): """
def wrapper(): listen on the knowledge queue for one teacher, get knowledge
# The batch size of the teacher and student model may be data and make a new batch data in the batch size of student,
# not the same, make a new batch in the batch size of the then put it into the intermediate queue (out_queue).
# student model. """
batches, num_samples = [], 0 batches, num_samples = [], 0
while True: while True:
batch_samples = queue.get() batch_samples = in_queue.get()
queue.task_done() in_queue.task_done()
if not isinstance(batch_samples, EndSignal): if not isinstance(batch_samples, EndSignal):
cur_num_samples = list(batch_samples.values())[ cur_num_samples = list(batch_samples.values())[0].shape[0]
0].shape[0] if num_samples + cur_num_samples < batch_size:
if num_samples + cur_num_samples < batch_size: batches.append(batch_samples)
batches.append(batch_samples) num_samples += cur_num_samples
num_samples += cur_num_samples elif num_samples + cur_num_samples == batch_size:
elif num_samples + cur_num_samples == batch_size: batches.append(batch_samples)
batches.append(batch_samples) out_queue.put(concat_batches(batches))
yield concat_batches(batches) batches, num_samples = [], 0
batches, num_samples = [], 0 else:
else: num_splited = batch_size - num_samples
num_splited = batch_size - num_samples first, second = split_batch(batch_samples, num_splited)
first, second = split_batch(batch_samples, batches.append(first)
num_splited) out_queue.put(concat_batches(batches))
batches.append(first) num_left = cur_num_samples - num_splited
yield concat_batches(batches) while num_left > batch_size:
num_left = cur_num_samples - num_splited first, second = split_batch(second, batch_size)
while num_left > batch_size: out_queue.put(first)
first, second = split_batch(second, num_left -= batch_size
batch_size)
yield first if num_left == batch_size:
num_left -= batch_size out_queue.put(second)
batches, num_samples = [second], num_left batches, num_samples = [], 0
else: else:
if len(batches) > 0: batches, num_samples = [second], num_left
yield concat_batches(batches) else:
yield EndSignal() if len(batches) > 0:
break out_queue.put(concat_batches(batches))
out_queue.put(EndSignal())
break
return wrapper def gather_and_merge(in_queues, out_queue):
"""
Gather knowledge from all intermediate queues, merge them
and put the final knowledge into the knowledge queue to
student (out_queue).
"""
data_receivers = [ def data_receiver(queue):
data_receiver(queue, self._batch_size)() for queue in queues while True:
] batch = queue.get()
queue.task_done()
yield batch
if isinstance(batch, EndSignal):
break
data_receivers = [data_receiver(queue) for queue in in_queues]
end_received = [0] * len(queues) end_received = [0] * len(in_queues)
while True: while True:
knowledge = OrderedDict( knowledge = OrderedDict(
[(k, []) for k, v in list(self._knowledge_desc.items())]) [(k, []) for k, v in list(self._knowledge_desc.items())])
...@@ -437,7 +454,7 @@ class Student(object): ...@@ -437,7 +454,7 @@ class Student(object):
knowledge[k].append(v) knowledge[k].append(v)
else: else:
end_received[idx] = 1 end_received[idx] = 1
if sum(end_received) == len(queues): if sum(end_received) == len(in_queues):
break break
knowledge = self._merge_knowledge(knowledge) knowledge = self._merge_knowledge(knowledge)
out_queue.put(knowledge) out_queue.put(knowledge)
...@@ -450,15 +467,24 @@ class Student(object): ...@@ -450,15 +467,24 @@ class Student(object):
queue.put(StartSignal()) queue.put(StartSignal())
queue.join() queue.join()
self._listen_thread = Thread( # launch multiple threads to listen on all knowledge queues
target=listen, med_queues = [Queue.Queue(100) for i in range(self._num_teachers)]
args=(self._teacher_knowledge_queues, self._knowledge_queue)) for i in range(self._num_teachers):
self._listen_thread.dameon = True listen_thread = Thread(
self._listen_thread.start() target=listen,
args=(self._teacher_knowledge_queues[i], med_queues[i],
self._batch_size))
listen_thread.dameon = True
listen_thread.start()
# launch another thread to merge knowledge
merge_thread = Thread(
target=gather_and_merge, args=(med_queues, self._knowledge_queue))
merge_thread.dameon = True
merge_thread.start()
# yield knowledge data
def wrapper(): def wrapper():
samples = []
while True: while True:
knowledge = self._knowledge_queue.get() knowledge = self._knowledge_queue.get()
self._knowledge_queue.task_done() self._knowledge_queue.task_done()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册