diff --git a/paddleslim/pantheon/student.py b/paddleslim/pantheon/student.py
index f501bb1f3b47177d79ac9d0c75e2a8311cd169ba..2964c4864933fc384857645f9a83d85220003b70 100644
--- a/paddleslim/pantheon/student.py
+++ b/paddleslim/pantheon/student.py
@@ -16,8 +16,10 @@ import six
 import time
 if six.PY2:
     import cPickle as pickle
+    import Queue
 else:
     import pickle
+    import queue as Queue
 import numpy as np
 
 from collections import OrderedDict
@@ -54,7 +56,7 @@ class Student(object):
         self._common_schema = merge_strategy.keys() if merge_strategy else []
 
         self._knowledge_desc = OrderedDict()
-        self._knowledge_queue = Manager().Queue(100)
+        self._knowledge_queue = Queue.Queue(100)
         self._teacher_knowledge_queues = []
         self._t2s_queues = []
         self._s2t_queues = []
@@ -180,9 +182,9 @@ class Student(object):
                 out_queue.put(EndSignal())
                 out_queue.join()
 
-            knowledge_queue = Manager().Queue(100)
-            cmd_queue = Manager().Queue(5)
-            p = Process(
+            knowledge_queue = Queue.Queue(100)
+            cmd_queue = Queue.Queue(5)
+            p = Thread(
                 target=read_offline,
                 args=(in_path, cmd_queue, knowledge_queue))
             p.daemon = True
@@ -372,6 +374,8 @@ class Student(object):
             return first, second
 
         def concat_batches(batches):
+            if len(batches) == 1:
+                return batches[0]
             keys = batches[0].keys()
             ret_batch = {}
             for key in keys:
@@ -379,52 +383,65 @@ class Student(object):
                     [batches[i][key] for i in range(len(batches))])
             return ret_batch
 
-        def listen(queues, out_queue):
-            def data_receiver(queue, batch_size):
-                def wrapper():
-                    # The batch size of the teacher and student model may be
-                    # not the same, make a new batch in the batch size of the
-                    # student model.
-                    batches, num_samples = [], 0
-                    while True:
-                        batch_samples = queue.get()
-                        queue.task_done()
-                        if not isinstance(batch_samples, EndSignal):
-                            cur_num_samples = list(batch_samples.values())[
-                                0].shape[0]
-                            if num_samples + cur_num_samples < batch_size:
-                                batches.append(batch_samples)
-                                num_samples += cur_num_samples
-                            elif num_samples + cur_num_samples == batch_size:
-                                batches.append(batch_samples)
-                                yield concat_batches(batches)
-                                batches, num_samples = [], 0
-                            else:
-                                num_splited = batch_size - num_samples
-                                first, second = split_batch(batch_samples,
-                                                            num_splited)
-                                batches.append(first)
-                                yield concat_batches(batches)
-                                num_left = cur_num_samples - num_splited
-                                while num_left > batch_size:
-                                    first, second = split_batch(second,
-                                                                batch_size)
-                                    yield first
-                                    num_left -= batch_size
-                                batches, num_samples = [second], num_left
+        def listen(in_queue, out_queue, batch_size):
+            """
+            Listen to the knowledge queue of one teacher, regroup the
+            knowledge data into batches of the student's batch size, and
+            put them into the intermediate queue (out_queue).
+            """
+            batches, num_samples = [], 0
+            while True:
+                batch_samples = in_queue.get()
+                in_queue.task_done()
+                if not isinstance(batch_samples, EndSignal):
+                    cur_num_samples = list(batch_samples.values())[0].shape[0]
+                    if num_samples + cur_num_samples < batch_size:
+                        batches.append(batch_samples)
+                        num_samples += cur_num_samples
+                    elif num_samples + cur_num_samples == batch_size:
+                        batches.append(batch_samples)
+                        out_queue.put(concat_batches(batches))
+                        batches, num_samples = [], 0
+                    else:
+                        num_splited = batch_size - num_samples
+                        first, second = split_batch(batch_samples, num_splited)
+                        batches.append(first)
+                        out_queue.put(concat_batches(batches))
+                        num_left = cur_num_samples - num_splited
+                        while num_left > batch_size:
+                            first, second = split_batch(second, batch_size)
+                            out_queue.put(first)
+                            num_left -= batch_size
+
+                        if num_left == batch_size:
+                            out_queue.put(second)
+                            batches, num_samples = [], 0
                         else:
-                            if len(batches) > 0:
-                                yield concat_batches(batches)
-                            yield EndSignal()
-                            break
+                            batches, num_samples = [second], num_left
+                else:
+                    if len(batches) > 0:
+                        out_queue.put(concat_batches(batches))
+                    out_queue.put(EndSignal())
+                    break
 
-                return wrapper
+        def gather_and_merge(in_queues, out_queue):
+            """
+            Gather knowledge from all intermediate queues, merge it, and
+            put the merged knowledge into the final knowledge queue of
+            the student (out_queue).
+            """
 
-            data_receivers = [
-                data_receiver(queue, self._batch_size)() for queue in queues
-            ]
+            def data_receiver(queue):
+                while True:
+                    batch = queue.get()
+                    queue.task_done()
+                    yield batch
+                    if isinstance(batch, EndSignal):
+                        break
+
+            data_receivers = [data_receiver(queue) for queue in in_queues]
 
-            end_received = [0] * len(queues)
+            end_received = [0] * len(in_queues)
             while True:
                 knowledge = OrderedDict(
                     [(k, []) for k, v in list(self._knowledge_desc.items())])
@@ -437,7 +454,7 @@ class Student(object):
                         knowledge[k].append(v)
                 else:
                     end_received[idx] = 1
-                if sum(end_received) == len(queues):
+                if sum(end_received) == len(in_queues):
                     break
             knowledge = self._merge_knowledge(knowledge)
             out_queue.put(knowledge)
@@ -450,15 +467,24 @@ class Student(object):
             queue.put(StartSignal())
             queue.join()
 
-        self._listen_thread = Thread(
-            target=listen,
-            args=(self._teacher_knowledge_queues, self._knowledge_queue))
-        self._listen_thread.dameon = True
-        self._listen_thread.start()
-
+        # launch multiple threads to listen on all knowledge queues
+        med_queues = [Queue.Queue(100) for i in range(self._num_teachers)]
+        for i in range(self._num_teachers):
+            listen_thread = Thread(
+                target=listen,
+                args=(self._teacher_knowledge_queues[i], med_queues[i],
+                      self._batch_size))
+            listen_thread.daemon = True
+            listen_thread.start()
+
+        # launch another thread to merge knowledge
+        merge_thread = Thread(
+            target=gather_and_merge, args=(med_queues, self._knowledge_queue))
+        merge_thread.daemon = True
+        merge_thread.start()
+
+        # yield knowledge data
         def wrapper():
-            samples = []
-
             while True:
                 knowledge = self._knowledge_queue.get()
                 self._knowledge_queue.task_done()