diff --git a/demo/pantheon/run_teacher1.py b/demo/pantheon/run_teacher1.py
index bbe94310b1de83ef1327426108b037a5a4c62258..1e0e089877b642a66cf5bf7dd3e171af28a62f91 100644
--- a/demo/pantheon/run_teacher1.py
+++ b/demo/pantheon/run_teacher1.py
@@ -72,6 +72,7 @@ def run(args):
program=program,
reader_config=reader_config,
exe=exe,
+ use_fp16=True,
times=args.serving_times)
diff --git a/paddleslim/pantheon/README.md b/paddleslim/pantheon/README.md
index 70180e4cede384762e277ef85fe9cf7e45b95841..3be5ecfa6740954b2ae06dcdbe81b25137a36681 100644
--- a/paddleslim/pantheon/README.md
+++ b/paddleslim/pantheon/README.md
@@ -106,6 +106,7 @@ Usually, the public methods of these two classes work in the pairwise way. Their
reader_config,
exe,
buf_size=10,
+use_fp16=False,
times=1)
get_knowledge_desc() | ✅ |
@@ -213,6 +214,7 @@ The toy "knowledge distillation" system can be launched in three different modes
```shell
export PYTHONPATH=../../:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0,1
+export NUM_POSTPROCESS_THREADS=10 # default 8
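+# one knowledge queue is created per postprocessing thread on the teacher side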
nohup python -u run_teacher1.py --use_cuda true --out_path teacher1_offline.dat > teacher1_offline.log 2>&1&
export CUDA_VISIBLE_DEVICES=2
nohup python -u run_teacher2.py --use_cuda true --out_path teacher2_offline.dat > teacher2_offline.log 2>&1&
diff --git a/paddleslim/pantheon/student.py b/paddleslim/pantheon/student.py
index 2964c4864933fc384857645f9a83d85220003b70..3f522b607242d880c2d1aca44e0088c48b3bc876 100644
--- a/paddleslim/pantheon/student.py
+++ b/paddleslim/pantheon/student.py
@@ -28,7 +28,7 @@ from multiprocessing.managers import BaseManager
from threading import Thread
-from paddleslim.pantheon.utils import EndSignal, SyncSignal, StartSignal, public_authkey
+from paddleslim.pantheon.utils import EndSignal, SyncSignal, StartSignal, public_authkey, convert_dtype
__all__ = ["Student"]
@@ -114,7 +114,60 @@ class Student(object):
except:
time.sleep(1.0)
- knowledge_queue = manager.get_knowledge_queue()
+ def merge(knowledge_queues):
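+                # the teacher distributes knowledge batches round-robin
+                # across its queues; merge them back into one queue, keeping
+                # batch order by polling the queues in the same fixed order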
+ num = len(knowledge_queues)
+ if num == 1:
+ return knowledge_queues[0]
+ local_queues = [Queue.Queue(100) for _ in range(num)]
+
+ def receive(queue, local_queue):
+ while True:
+ data = queue.get()
+ queue.task_done()
+ local_queue.put(data)
+ if isinstance(data, EndSignal):
+ break
+
+ knowledge_queue = Queue.Queue(100)
+
+ def gather(local_queues, knowledge_queue):
+ num = len(local_queues)
+ end_received = False
+ while True:
+ for i in range(num):
+ data = local_queues[i].get()
+ local_queues[i].task_done()
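+                            # the teacher pushes a SyncSignal into every
+                            # queue; forward only the copy from queue 0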
+ if isinstance(data, SyncSignal) and i > 0:
+ continue
+ elif isinstance(data, EndSignal):
+ end_received = True
+ knowledge_queue.put(data)
+ if end_received:
+ break
+
+ # threads to receive knowledge from the online teacher
+ for i in range(num):
+ p = Thread(
+ target=receive,
+ args=(knowledge_queues[i], local_queues[i]))
+ p.daemon = True
+ p.start()
+ # thread to gather data from different local queues
+ p = Thread(target=gather, args=(local_queues, knowledge_queue))
+ p.daemon = True
+ p.start()
+ return knowledge_queue
+
+ # get knowledge queues
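+            # probe with increasing idx: get_knowledge_queue(idx) returns a
+            # queue proxy (it has a 'get' method) while idx is valid, and a
+            # non-queue object once idx runs past the teacher's last queue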
+ knowledge_queues, idx = [], 0
+ while True:
+ q = manager.get_knowledge_queue(idx)
+ if hasattr(q, "get"):
+ knowledge_queues.append(q)
+ idx += 1
+ else:
+ break
+ knowledge_queue = merge(knowledge_queues)
self._t2s_queues.append(manager.get_t2s_queue())
self._s2t_queues.append(manager.get_s2t_queue())
self._cmd_queues.append(manager.get_cmd_queue())
@@ -237,6 +290,10 @@ class Student(object):
knowledge[k] = result
elif self._merge_strategy[k] == "mean":
knowledge[k] = result / len(tensors)
+ # cast back to original data type if necessary
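+            # (e.g. fp16 knowledge from a teacher running with use_fp16=True
+            # is restored to the dtype recorded in the knowledge desc)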
+ tgt_dtype = self._knowledge_desc[k]["dtype"]
+ if str(knowledge[k].dtype) != tgt_dtype:
+ knowledge[k] = knowledge[k].astype(tgt_dtype)
return knowledge
def send(self, data, teacher_ids=None):
@@ -383,11 +440,23 @@ class Student(object):
[batches[i][key] for i in range(len(batches))])
return ret_batch
- def listen(in_queue, out_queue, batch_size):
+ def listen(knowledge_queue, out_queue):
+ """
+            Listen on the knowledge queue of one teacher, fetch knowledge
+            data and put it into a local queue (out_queue).
+ """
+ while True:
+ data = knowledge_queue.get()
+ knowledge_queue.task_done()
+ out_queue.put(data)
+ if isinstance(data, EndSignal):
+ break
+
+ def make_new_batch(in_queue, out_queue, batch_size):
"""
- listen on the knowledge queue for one teacher, get knowledge
- data and make a new batch data in the batch size of student,
- then put it into the intermediate queue (out_queue).
+            Get knowledge data from a local queue, repack it into batches
+            of the student's batch size, and put them into the intermediate
+            queue (out_queue).
"""
batches, num_samples = [], 0
while True:
@@ -467,17 +536,25 @@ class Student(object):
queue.put(StartSignal())
queue.join()
- # launch multiple threads to listen on all knowledge queues
- med_queues = [Queue.Queue(100) for i in range(self._num_teachers)]
+ # launch threads to listen on all knowledge queues
+ local_queues = [Queue.Queue(100) for i in range(self._num_teachers)]
for i in range(self._num_teachers):
listen_thread = Thread(
target=listen,
- args=(self._teacher_knowledge_queues[i], med_queues[i],
- self._batch_size))
+ args=(self._teacher_knowledge_queues[i], local_queues[i]))
+            listen_thread.daemon = True
+ listen_thread.start()
+
+ # launch threads to make new batch for student
+ med_queues = [Queue.Queue(100) for i in range(self._num_teachers)]
+ for i in range(self._num_teachers):
+ listen_thread = Thread(
+ target=make_new_batch,
+ args=(local_queues[i], med_queues[i], self._batch_size))
listen_thread.daemon = True
listen_thread.start()
- # launch another thread to merge knowledge
+ # launch another thread to merge knowledge from different teachers.
merge_thread = Thread(
target=gather_and_merge, args=(med_queues, self._knowledge_queue))
merge_thread.daemon = True
diff --git a/paddleslim/pantheon/teacher.py b/paddleslim/pantheon/teacher.py
index 425257e18886940cac98ef5a5a2df356f2e6ae67..9a1d5ae2d8271c79cc49a76c02a7d1226b811c46 100644
--- a/paddleslim/pantheon/teacher.py
+++ b/paddleslim/pantheon/teacher.py
@@ -35,7 +35,11 @@ from paddleslim.pantheon.utils import convert_dtype, EndSignal, SyncSignal, Star
__all__ = ["Teacher"]
-knowledge_queue = Queue.Queue(100)
+# Num of threads for post-processing, including generating and transferring
+# knowledge data
+num_postprocess_threads = int(os.getenv("NUM_POSTPROCESS_THREADS", 8))
+knowledge_queues = [Queue.Queue(100) for i in range(num_postprocess_threads)]
+
t2s_queue = Queue.Queue(100)
s2t_queue = Queue.Queue(100)
cmd_queue = Queue.Queue(5)
@@ -75,6 +79,84 @@ class MixedDataReader(object):
self._tail_data = []
+class WorkerParallel(object):
+ """
+    Process data from the input queue with the given worker in parallel, and
+    put the results into the output queue(s) in order.
+
+ Args:
+ num_postprocess_threads (int): Number of threads for data processing.
+ in_queue (object): The input queue.
+        out_queue (object|list): The output queue(s). When it is a list, its
+            length must equal 'num_postprocess_threads'.
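+
+    Example (a minimal usage sketch; 'double_worker' is illustrative and not
+    part of this module):
+
+        in_q, out_q = Queue.Queue(5), Queue.Queue(5)
+
+        def double_worker(in_queue, out_queue):
+            while True:
+                data = in_queue.get()
+                in_queue.task_done()
+                if isinstance(data, EndSignal):
+                    out_queue.put(data)
+                    break
+                out_queue.put(2 * data)
+
+        parallel = WorkerParallel(2, in_q, out_q)
+        parallel(worker=double_worker, args=())
+        in_q.put(1)
+        in_q.put(2)
+        in_q.put(EndSignal())
+        # out_q now yields 2, 4 and a single EndSignal, in order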
+ """
+
+ def __init__(self, num_postprocess_threads, in_queue, out_queue):
+ self._num_postprocess_threads = num_postprocess_threads
+ self._in_queue = in_queue
+ self._local_in_queues = [
+ Queue.Queue(5) for i in range(num_postprocess_threads)
+ ]
+ if isinstance(out_queue, list):
+ if len(out_queue) != num_postprocess_threads:
+                raise ValueError("When out_queue is a list, its length must "
+                                 "equal num_postprocess_threads!")
+ self._local_out_queues = out_queue
+ self._out_queue = None
+ else:
+ self._local_out_queues = [
+ Queue.Queue(5) for i in range(num_postprocess_threads)
+ ]
+ self._out_queue = out_queue
+
+ def _distribute(self):
+ def func():
+ idx = 0
+ while True:
+ data = self._in_queue.get()
+ self._in_queue.task_done()
+ if not isinstance(data, EndSignal):
+ self._local_in_queues[
+ idx % self._num_postprocess_threads].put(data)
+ idx += 1
+ else:
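+                    # broadcast the EndSignal so that every worker
+                    # thread receives one and can exit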
+ for q in self._local_in_queues:
+ q.put(EndSignal())
+
+ t = Thread(target=func)
+ t.daemon = True
+ t.start()
+
+ def _run(self, worker, args):
+ for i in range(self._num_postprocess_threads):
+ t = Thread(
+ target=worker,
+ args=(self._local_in_queues[i], self._local_out_queues[i]) +
+ args)
+ t.daemon = True
+ t.start()
+
+ def _gather(self):
+ def func():
+ while True:
+ for idx, q in enumerate(self._local_out_queues):
+ data = q.get()
+ q.task_done()
+ if isinstance(data, EndSignal) and idx > 0:
+ continue
+ self._out_queue.put(data)
+
+ t = Thread(target=func)
+ t.daemon = True
+ t.start()
+
+ def __call__(self, worker, args):
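+        """
+        Start the pipeline: launch the distributing thread, the parallel
+        worker threads and, when a single out_queue is used, the gathering
+        thread. The 'worker' callable receives (in_queue, out_queue, *args)
+        and is expected to forward the EndSignal it receives to out_queue.
+        """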
+ self._distribute()
+ self._run(worker, args)
+ if self._out_queue:
+ self._gather()
+
+
class Teacher(object):
"""
The class defined for the teacher model. Generate knowledge data and
@@ -102,9 +184,12 @@ class Teacher(object):
self._started = False
def _start_manager(self):
- def get_knowledge_queue():
- global knowledge_queue
- return knowledge_queue
+ def get_knowledge_queue(idx):
+ global knowledge_queues
+ if idx < len(knowledge_queues):
+ return knowledge_queues[idx]
+ else:
+ return None
def get_s2t_queue():
global s2t_queue
@@ -141,12 +226,17 @@ class Teacher(object):
self._started = True
self._manager = self._start_manager() if self._out_port else None
if self._manager:
- self._knowledge_queue = self._manager.get_knowledge_queue()
+ self._knowledge_queues = [
+ self._manager.get_knowledge_queue(i)
+ for i in range(num_postprocess_threads)
+ ]
+ print("Num of knowledge queues: {}".format(
+ num_postprocess_threads))
self._s2t_queue = self._manager.get_s2t_queue()
self._t2s_queue = self._manager.get_t2s_queue()
self._cmd_queue = self._manager.get_cmd_queue()
else:
- self._knowledge_queue = None
+ self._knowledge_queues = None
self._s2t_queue = None
self._t2s_queue = None
self._cmd_queue = None
@@ -173,8 +263,9 @@ class Teacher(object):
while True:
if self._sync_required:
- self._knowledge_queue.put(SyncSignal())
- self._knowledge_queue.join()
+ for q in self._knowledge_queues:
+ q.put(SyncSignal())
+ q.join()
self._sync_required = False
break
@@ -256,6 +347,7 @@ class Teacher(object):
reader_config,
exe,
buf_size=10,
+ use_fp16=False,
times=1):
"""
Start the knowledge service to generate and transfer knowledge data.
@@ -291,6 +383,11 @@ class Teacher(object):
exe (fluid.Executor): The executor to run the input program.
buf_size (int): The size of buffers for data reader and knowledge
writer on each device.
+            use_fp16 (bool): Whether to transfer/store knowledge data as
+                         float16 if their data type is float32/float64. In
+                         the offline mode, it reduces the size of the dumped
+                         knowledge file, and in the online mode, it speeds up
+                         the online transfer at some cost in precision.
+                         Default False.
times (int): The maximum repeated serving times. Default 1. Whenever
the public method 'get_knowledge_generator()' in Student
object called once, the serving times will be added one,
@@ -333,6 +430,8 @@ class Teacher(object):
raise ValueError("Input argument should be a fluid Executor!")
self._exe = exe
+ self._use_fp16 = use_fp16
+
if not buf_size > 0:
raise ValueError("The buffer size should be positive!")
self._buf_size = buf_size
@@ -402,84 +501,136 @@ class Teacher(object):
"generator type, which should be one of 'sample_generator', "
"'sample_list_generator', and 'batch_generator'.")
- def writer(buf_queue, schema_keys):
- samples_sent, batches_sent = 0, 0
+ def cast2fp16(know):
+ for k, v in list(know.items()):
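+                # a non-ndarray value means this is not a plain knowledge
+                # batch; leave the dict untouched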
+ if not isinstance(v, np.ndarray):
+ break
+ if v.dtype == np.float32 or v.dtype == np.float64:
+ v = v.astype("float16")
+ know[k] = v
+ return know
+
+ feed_var_names = [var.name for var in self._feed_list]
+ schema_in_feed, schema_in_fetch = {}, {}
+ for k, v in list(self._schema.items()):
+ if k in feed_var_names:
+ schema_in_feed[k] = v
+ else:
+ schema_in_fetch[k] = v
+ schema_in_fetch_keys, schema_in_fetch_vars = zip(
+ *list(schema_in_fetch.items()))
+
+ def know_maker(in_queue, out_queue, use_fp16):
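+            """
+            Turn one (dev_batches, outputs) tuple into a knowledge batch:
+            concatenate feed-side variables across the device batches, zip
+            the fetched outputs with their schema keys, and optionally cast
+            the result to float16.
+            """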
while True:
- outputs = buf_queue.get()
- buf_queue.task_done()
- if not isinstance(outputs, EndSignal):
- batch_samples = dict(zip(schema_keys, outputs))
- if self._knowledge_queue:
- self._knowledge_queue.put(batch_samples)
- if self._out_file:
- self._out_file.write(pickle.dumps(batch_samples))
+ data = in_queue.get()
+ in_queue.task_done()
+ if isinstance(data, tuple):
+ dev_batches, outputs = data
+ know = {}
+ for k in schema_in_feed.keys():
+ batch_know = [
+ np.array(batch[k]) for batch in dev_batches
+ ]
+ know[k] = np.concatenate(batch_know)
+ know.update(dict(zip(schema_in_fetch_keys, outputs)))
+ if use_fp16:
+ know = cast2fp16(know)
+ out_queue.put(know)
else:
- if self._knowledge_queue:
- self._knowledge_queue.put(EndSignal())
- # should close file in child thread to wait for all
- # writing finished
- if self._out_file:
+ # forward other types of data directly (maybe knowledge desc or EndSignal)
+ out_queue.put(data)
+
+ know_make_queue = Queue.Queue(self._buf_size)
+ if self._out_file:
+            # For offline dump, write the knowledge description to the head of the file
+ self._out_file.write(pickle.dumps(self._knowledge_desc))
+ print("output path: %s" % self._out_path)
+ offline_write_queue = Queue.Queue(self._buf_size)
+
+ def offline_write(queue):
+ while True:
+ know = queue.get()
+ queue.task_done()
+ if not isinstance(know, EndSignal):
+ self._out_file.write(pickle.dumps(know))
+ else:
+                    # close the file in this child thread to make sure all
+                    # writing has finished
self._out_file.close()
- # Asynchronous output
- out_buf_queue = Queue.Queue(self._buf_size)
- schema_keys, schema_vars = zip(*list(self._schema.items()))
- out_thread = Thread(target=writer, args=(out_buf_queue, schema_keys))
- out_thread.daemon = True
- out_thread.start()
+ t = Thread(target=offline_write, args=(offline_write_queue, ))
+ t.daemon = True
+ t.start()
+ make_knowledge = WorkerParallel(
+ num_postprocess_threads, know_make_queue, offline_write_queue)
+
+ if self._knowledge_queues:
+ make_knowledge = WorkerParallel(num_postprocess_threads,
+ know_make_queue,
+ self._knowledge_queues)
+
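+        # every (dev_batches, outputs) tuple put into know_make_queue is
+        # turned into a knowledge batch by the parallel know_maker workers
+        # and delivered, in order, to the offline writer or the online
+        # knowledge queues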
+ make_knowledge(worker=know_maker, args=(self._use_fp16, ))
compiled_program = fluid.compiler.CompiledProgram(
self._program).with_data_parallel()
print("Knowledge description {}".format(self._knowledge_desc))
- print(
- time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
- " Teacher begins to serve ...")
- # For offline dump, write the knowledge description to the head of file
- if self._out_file:
- self._out_file.write(pickle.dumps(self._knowledge_desc))
- print("output path: %s" % self._out_path)
+ print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
+ " Teacher begins to serve ...")
data_reader = MixedDataReader(data_loader, dev_count)
# For online mode, send knowledge description every time
for repeated in range(self._times):
- if self._knowledge_queue:
+ if self._knowledge_queues:
# wait for the accessing of knowledge desc and data
while True:
if self._sync_required:
- self._knowledge_queue.put(SyncSignal())
- self._knowledge_queue.put(self._knowledge_desc)
+ for q in self._knowledge_queues:
+ q.put(SyncSignal())
+ know_make_queue.put(self._knowledge_desc)
self._sync_required = False
if self._data_required:
self._data_required = False
break
- self._knowledge_queue.join()
+ for q in self._knowledge_queues:
+ q.join()
print("No.{} time serving ... ".format(repeated))
num_batches_sent = 0
for dev_batches in data_reader.multi_dev_generator():
if self._sync_required:
break
outputs = self._exe.run(compiled_program,
feed=dev_batches,
- fetch_list=schema_vars)
- out_buf_queue.put(outputs)
+ fetch_list=schema_in_fetch_vars)
+                know_make_queue.put((dev_batches, outputs))
num_batches_sent += dev_count
if num_batches_sent % (100 * dev_count) == 0:
log = "Processed {} batch samples.".format(
num_batches_sent)
- if self._knowledge_queue:
- log += " Knowledge queue size {}.".format(
- self._knowledge_queue.qsize())
+ if self._knowledge_queues:
+ qsize = 0
+ for q in self._knowledge_queues:
+ qsize += q.qsize()
+ log += " Knowledge queue size {}.".format(qsize)
print(log)
- outputs = []
+ dev_batches, outputs = [], []
for index, batch in enumerate(data_reader.tail_generator()):
if self._sync_required:
break
+ dev_batches.append(batch)
output = self._exe.run(self._program,
feed=batch,
- fetch_list=schema_vars)
+ fetch_list=schema_in_fetch_vars)
if outputs:
outputs = [
np.concatenate(
@@ -488,21 +639,22 @@ class Teacher(object):
]
else:
outputs = copy.deepcopy(output)
- if outputs:
- out_buf_queue.put(outputs)
+ if dev_batches or outputs:
+ know_make_queue.put((dev_batches, outputs))
num_batches_sent += (index + 1)
print("Processed {} batch samples in total.".format(
num_batches_sent))
- out_buf_queue.put(EndSignal())
- out_buf_queue.join()
+ know_make_queue.put(EndSignal())
+ know_make_queue.join()
- if self._knowledge_queue:
- self._knowledge_queue.join()
- print(
- time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
- " Teacher ends serving.")
+ if self._knowledge_queues:
+ for q in self._knowledge_queues:
+ q.join()
+ print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
+ " Teacher ends serving.")
def __del__(self):
if self._manager: