未验证 提交 45faf9bd 编写于 作者: Y Yibing Liu 提交者: GitHub

Optimize knowledge receiving in pantheon student (#182)

上级 41e40edb
......@@ -16,8 +16,10 @@ import six
import time
if six.PY2:
import cPickle as pickle
import Queue
else:
import pickle
import queue as Queue
import numpy as np
from collections import OrderedDict
......@@ -54,7 +56,7 @@ class Student(object):
self._common_schema = merge_strategy.keys() if merge_strategy else []
self._knowledge_desc = OrderedDict()
self._knowledge_queue = Manager().Queue(100)
self._knowledge_queue = Queue.Queue(100)
self._teacher_knowledge_queues = []
self._t2s_queues = []
self._s2t_queues = []
......@@ -180,9 +182,9 @@ class Student(object):
out_queue.put(EndSignal())
out_queue.join()
knowledge_queue = Manager().Queue(100)
cmd_queue = Manager().Queue(5)
p = Process(
knowledge_queue = Queue.Queue(100)
cmd_queue = Queue.Queue(5)
p = Thread(
target=read_offline,
args=(in_path, cmd_queue, knowledge_queue))
p.daemon = True
......@@ -372,6 +374,8 @@ class Student(object):
return first, second
def concat_batches(batches):
if len(batches) == 1:
return batches[0]
keys = batches[0].keys()
ret_batch = {}
for key in keys:
......@@ -379,52 +383,65 @@ class Student(object):
[batches[i][key] for i in range(len(batches))])
return ret_batch
def listen(queues, out_queue):
def data_receiver(queue, batch_size):
def wrapper():
# The batch size of the teacher and student model may be
# not the same, make a new batch in the batch size of the
# student model.
batches, num_samples = [], 0
while True:
batch_samples = queue.get()
queue.task_done()
if not isinstance(batch_samples, EndSignal):
cur_num_samples = list(batch_samples.values())[
0].shape[0]
if num_samples + cur_num_samples < batch_size:
batches.append(batch_samples)
num_samples += cur_num_samples
elif num_samples + cur_num_samples == batch_size:
batches.append(batch_samples)
yield concat_batches(batches)
batches, num_samples = [], 0
else:
num_splited = batch_size - num_samples
first, second = split_batch(batch_samples,
num_splited)
batches.append(first)
yield concat_batches(batches)
num_left = cur_num_samples - num_splited
while num_left > batch_size:
first, second = split_batch(second,
batch_size)
yield first
num_left -= batch_size
batches, num_samples = [second], num_left
def listen(in_queue, out_queue, batch_size):
"""
listen on the knowledge queue for one teacher, get knowledge
data and make a new batch data in the batch size of student,
then put it into the intermediate queue (out_queue).
"""
batches, num_samples = [], 0
while True:
batch_samples = in_queue.get()
in_queue.task_done()
if not isinstance(batch_samples, EndSignal):
cur_num_samples = list(batch_samples.values())[0].shape[0]
if num_samples + cur_num_samples < batch_size:
batches.append(batch_samples)
num_samples += cur_num_samples
elif num_samples + cur_num_samples == batch_size:
batches.append(batch_samples)
out_queue.put(concat_batches(batches))
batches, num_samples = [], 0
else:
num_splited = batch_size - num_samples
first, second = split_batch(batch_samples, num_splited)
batches.append(first)
out_queue.put(concat_batches(batches))
num_left = cur_num_samples - num_splited
while num_left > batch_size:
first, second = split_batch(second, batch_size)
out_queue.put(first)
num_left -= batch_size
if num_left == batch_size:
out_queue.put(second)
batches, num_samples = [], 0
else:
if len(batches) > 0:
yield concat_batches(batches)
yield EndSignal()
break
batches, num_samples = [second], num_left
else:
if len(batches) > 0:
out_queue.put(concat_batches(batches))
out_queue.put(EndSignal())
break
return wrapper
def gather_and_merge(in_queues, out_queue):
"""
Gather knowledge from all intermediate queues, merge them
and put the final knowledge into the knowledge queue to
student (out_queue).
"""
data_receivers = [
data_receiver(queue, self._batch_size)() for queue in queues
]
def data_receiver(queue):
while True:
batch = queue.get()
queue.task_done()
yield batch
if isinstance(batch, EndSignal):
break
data_receivers = [data_receiver(queue) for queue in in_queues]
end_received = [0] * len(queues)
end_received = [0] * len(in_queues)
while True:
knowledge = OrderedDict(
[(k, []) for k, v in list(self._knowledge_desc.items())])
......@@ -437,7 +454,7 @@ class Student(object):
knowledge[k].append(v)
else:
end_received[idx] = 1
if sum(end_received) == len(queues):
if sum(end_received) == len(in_queues):
break
knowledge = self._merge_knowledge(knowledge)
out_queue.put(knowledge)
......@@ -450,15 +467,24 @@ class Student(object):
queue.put(StartSignal())
queue.join()
self._listen_thread = Thread(
target=listen,
args=(self._teacher_knowledge_queues, self._knowledge_queue))
self._listen_thread.dameon = True
self._listen_thread.start()
# launch multiple threads to listen on all knowledge queues
med_queues = [Queue.Queue(100) for i in range(self._num_teachers)]
for i in range(self._num_teachers):
listen_thread = Thread(
target=listen,
args=(self._teacher_knowledge_queues[i], med_queues[i],
self._batch_size))
listen_thread.dameon = True
listen_thread.start()
# launch another thread to merge knowledge
merge_thread = Thread(
target=gather_and_merge, args=(med_queues, self._knowledge_queue))
merge_thread.dameon = True
merge_thread.start()
# yield knowledge data
def wrapper():
samples = []
while True:
knowledge = self._knowledge_queue.get()
self._knowledge_queue.task_done()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册