Unverified · Commit 98a4359f authored by Yibing Liu, committed by GitHub

Optimize knowledge transfer in pantheon (#210)

Parent: 77c64ef4
@@ -72,6 +72,7 @@ def run(args):
        program=program,
        reader_config=reader_config,
        exe=exe,
        use_fp16=True,
        times=args.serving_times)
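This hunk turns on the new `use_fp16` switch in the demo teacher. Below is a minimal, self-contained sketch of the whole teacher-side call, assuming PaddlePaddle 1.x (fluid) and PaddleSlim's pantheon module are installed; the toy network, output file name, and `batch_generator` are illustrative stand-ins, not part of this commit. Since knowledge tensors default to float32, float16 transfer roughly halves the payload size.

```python
import numpy as np
import paddle.fluid as fluid
from paddleslim.pantheon import Teacher

# Toy network: its input and output serve as the "knowledge" schema.
startup, main = fluid.Program(), fluid.Program()
with fluid.program_guard(main, startup):
    images = fluid.data(name="images", shape=[None, 784], dtype="float32")
    logits = fluid.layers.fc(input=images, size=10)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup)

def batch_generator():
    # Yield one list of ndarrays per batch, in feed_list order.
    for _ in range(8):
        yield [np.random.rand(32, 784).astype("float32")]

teacher = Teacher(out_path="teacher_offline.dat", out_port=None)  # offline mode
teacher.start()
teacher.start_knowledge_service(
    feed_list=[images],
    schema={"images": images, "logits": logits},
    program=main,
    reader_config={"batch_generator": batch_generator},
    exe=exe,
    buf_size=10,
    use_fp16=True,  # new in this commit: dump/transfer knowledge as float16
    times=1)
```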
@@ -106,6 +106,7 @@ Usually, the public methods of these two classes work in the pairwise way. Their
<br>&nbsp;&nbsp;&nbsp;&nbsp;reader_config,
<br>&nbsp;&nbsp;&nbsp;&nbsp;exe,
<br>&nbsp;&nbsp;&nbsp;&nbsp;buf_size=10,
<br>&nbsp;&nbsp;&nbsp;&nbsp;use_fp16=False,
<br>&nbsp;&nbsp;&nbsp;&nbsp;times=1)</td>
<td><strong>get_knowledge_desc</strong>()</td>
<td><center></center></td>
@@ -213,6 +214,7 @@ The toy "knowledge distillation" system can be launched in three different modes
```shell
export PYTHONPATH=../../:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0,1
export NUM_POSTPROCESS_THREADS=10 # default 8
nohup python -u run_teacher1.py --use_cuda true --out_path teacher1_offline.dat > teacher1_offline.log 2>&1&
export CUDA_VISIBLE_DEVICES=2
nohup python -u run_teacher2.py --use_cuda true --out_path teacher2_offline.dat > teacher2_offline.log 2>&1&
```
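In this launch script, `NUM_POSTPROCESS_THREADS` is a new knob: `teacher.py` reads it at import time (default 8) and uses it as the number of post-processing threads and parallel knowledge queues per teacher, as the changes below show.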
@@ -28,7 +28,7 @@ from multiprocessing.managers import BaseManager
from threading import Thread
from paddleslim.pantheon.utils import EndSignal, SyncSignal, StartSignal, public_authkey
from paddleslim.pantheon.utils import EndSignal, SyncSignal, StartSignal, public_authkey, convert_dtype
__all__ = ["Student"]
@@ -114,7 +114,60 @@ class Student(object):
            except:
                time.sleep(1.0)
        knowledge_queue = manager.get_knowledge_queue()

        def merge(knowledge_queues):
            num = len(knowledge_queues)
            if num == 1:
                return knowledge_queues[0]
            local_queues = [Queue.Queue(100) for _ in range(num)]

            def receive(queue, local_queue):
                while True:
                    data = queue.get()
                    queue.task_done()
                    local_queue.put(data)
                    if isinstance(data, EndSignal):
                        break

            knowledge_queue = Queue.Queue(100)

            def gather(local_queues, knowledge_queue):
                num = len(local_queues)
                end_received = False
                while True:
                    for i in range(num):
                        data = local_queues[i].get()
                        local_queues[i].task_done()
                        # only queue 0 forwards SyncSignal, to avoid duplicates
                        if isinstance(data, SyncSignal) and i > 0:
                            continue
                        elif isinstance(data, EndSignal):
                            end_received = True
                        knowledge_queue.put(data)
                    if end_received:
                        break

            # threads to receive knowledge from the online teacher
            for i in range(num):
                p = Thread(
                    target=receive,
                    args=(knowledge_queues[i], local_queues[i]))
                p.daemon = True
                p.start()

            # thread to gather data from different local queues
            p = Thread(target=gather, args=(local_queues, knowledge_queue))
            p.daemon = True
            p.start()
            return knowledge_queue

        # get knowledge queues, probing increasing indices until
        # no more queues are returned
        knowledge_queues, idx = [], 0
        while True:
            q = manager.get_knowledge_queue(idx)
            if hasattr(q, "get"):
                knowledge_queues.append(q)
                idx += 1
            else:
                break
        knowledge_queue = merge(knowledge_queues)
        self._t2s_queues.append(manager.get_t2s_queue())
        self._s2t_queues.append(manager.get_s2t_queue())
        self._cmd_queues.append(manager.get_cmd_queue())
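The probing loop above pairs with the teacher-side change to `get_knowledge_queue(idx)` further down: the student asks for queues at increasing indices and stops as soon as the returned object has no `get` method (the teacher returns None once the index runs past its queue list), after which all discovered queues are merged into a single one.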
@@ -237,6 +290,10 @@ class Student(object):
                knowledge[k] = result
            elif self._merge_strategy[k] == "mean":
                knowledge[k] = result / len(tensors)
            # cast back to original data type if necessary
            tgt_dtype = self._knowledge_desc[k]["dtype"]
            if str(knowledge[k].dtype) != tgt_dtype:
                knowledge[k] = knowledge[k].astype(tgt_dtype)
        return knowledge
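The cast-back above restores the data type recorded in the knowledge description when a teacher served knowledge with `use_fp16=True`. A minimal numpy sketch of the round trip (the description entry and tensors are illustrative):

```python
import numpy as np

# Illustrative knowledge description entry, as recorded by the teacher.
knowledge_desc = {"logits": {"dtype": "float32"}}

# Two teachers served float16 knowledge; the student merges them by "mean".
t1 = np.random.rand(4, 10).astype("float16")
t2 = np.random.rand(4, 10).astype("float16")
merged = (t1 + t2) / 2          # still float16 after the merge

tgt_dtype = knowledge_desc["logits"]["dtype"]
if str(merged.dtype) != tgt_dtype:
    merged = merged.astype(tgt_dtype)  # cast back to the declared dtype
assert str(merged.dtype) == "float32"
```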
    def send(self, data, teacher_ids=None):
@@ -383,11 +440,23 @@ class Student(object):
                    [batches[i][key] for i in range(len(batches))])
            return ret_batch

        def listen(in_queue, out_queue, batch_size):
        def listen(knowledge_queue, out_queue):
            """
            Listen on the knowledge queue of one teacher, fetch the
            knowledge data, and put it into a local queue (out_queue).
            """
            while True:
                data = knowledge_queue.get()
                knowledge_queue.task_done()
                out_queue.put(data)
                if isinstance(data, EndSignal):
                    break

        def make_new_batch(in_queue, out_queue, batch_size):
            """
            listen on the knowledge queue for one teacher, get knowledge
            data and make a new batch data in the batch size of student,
            then put it into the intermediate queue (out_queue).
            Get knowledge data from a local queue, repack it into batches
            of the student's batch size, and put the new batches into the
            intermediate queue (out_queue).
            """
            batches, num_samples = [], 0
            while True:
@@ -467,17 +536,25 @@ class Student(object):
                queue.put(StartSignal())
                queue.join()

            # launch multiple threads to listen on all knowledge queues
            med_queues = [Queue.Queue(100) for i in range(self._num_teachers)]
            # launch threads to listen on all knowledge queues
            local_queues = [Queue.Queue(100) for i in range(self._num_teachers)]
            for i in range(self._num_teachers):
                listen_thread = Thread(
                    target=listen,
                    args=(self._teacher_knowledge_queues[i], med_queues[i],
                          self._batch_size))
                    args=(self._teacher_knowledge_queues[i], local_queues[i]))
                listen_thread.daemon = True
                listen_thread.start()

            # launch threads to make new batches for the student
            med_queues = [Queue.Queue(100) for i in range(self._num_teachers)]
            for i in range(self._num_teachers):
                listen_thread = Thread(
                    target=make_new_batch,
                    args=(local_queues[i], med_queues[i], self._batch_size))
                listen_thread.daemon = True
                listen_thread.start()

            # launch another thread to merge knowledge
            # launch another thread to merge knowledge from different teachers
            merge_thread = Thread(
                target=gather_and_merge, args=(med_queues, self._knowledge_queue))
            merge_thread.daemon = True
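Taken together, the student-side pipeline after this change is: per-teacher knowledge queues → `listen` threads → `local_queues` → `make_new_batch` threads (repacking to the student's batch size) → `med_queues` → the `gather_and_merge` thread → the student's final knowledge queue.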
@@ -35,7 +35,11 @@ from paddleslim.pantheon.utils import convert_dtype, EndSignal, SyncSignal, Star
__all__ = ["Teacher"]
knowledge_queue = Queue.Queue(100)
# Number of threads for post-processing, including generating and
# transferring knowledge data
num_postprocess_threads = int(os.getenv("NUM_POSTPROCESS_THREADS", 8))
knowledge_queues = [Queue.Queue(100) for i in range(num_postprocess_threads)]
t2s_queue = Queue.Queue(100)
s2t_queue = Queue.Queue(100)
cmd_queue = Queue.Queue(5)
@@ -75,6 +79,84 @@ class MixedDataReader(object):
        self._tail_data = []


class WorkerParallel(object):
    """
    Process data from the input queue with the given worker in parallel, and
    put the results into the output queue(s) in order.

    Args:
        num_postprocess_threads (int): Number of threads for data processing.
        in_queue (object): The input queue.
        out_queue (object|list): The output queue(s). Its length must equal
            'num_postprocess_threads' when it is a list.
    """

    def __init__(self, num_postprocess_threads, in_queue, out_queue):
        self._num_postprocess_threads = num_postprocess_threads
        self._in_queue = in_queue
        self._local_in_queues = [
            Queue.Queue(5) for i in range(num_postprocess_threads)
        ]
        if isinstance(out_queue, list):
            if len(out_queue) != num_postprocess_threads:
                raise ValueError("When out_queue is a list, its length must "
                                 "equal num_postprocess_threads!")
            self._local_out_queues = out_queue
            self._out_queue = None
        else:
            self._local_out_queues = [
                Queue.Queue(5) for i in range(num_postprocess_threads)
            ]
            self._out_queue = out_queue

    def _distribute(self):
        def func():
            idx = 0
            while True:
                data = self._in_queue.get()
                self._in_queue.task_done()
                if not isinstance(data, EndSignal):
                    # round-robin dispatch to the local worker queues
                    self._local_in_queues[
                        idx % self._num_postprocess_threads].put(data)
                    idx += 1
                else:
                    # broadcast the end signal to all workers
                    for q in self._local_in_queues:
                        q.put(EndSignal())

        t = Thread(target=func)
        t.daemon = True
        t.start()

    def _run(self, worker, args):
        for i in range(self._num_postprocess_threads):
            t = Thread(
                target=worker,
                args=(self._local_in_queues[i], self._local_out_queues[i]) +
                args)
            t.daemon = True
            t.start()

    def _gather(self):
        def func():
            while True:
                # collect results round-robin to preserve the input order
                for idx, q in enumerate(self._local_out_queues):
                    data = q.get()
                    q.task_done()
                    # forward EndSignal only once (from queue 0)
                    if isinstance(data, EndSignal) and idx > 0:
                        continue
                    self._out_queue.put(data)

        t = Thread(target=func)
        t.daemon = True
        t.start()

    def __call__(self, worker, args):
        self._distribute()
        self._run(worker, args)
        if self._out_queue:
            self._gather()
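A hypothetical, self-contained usage sketch of `WorkerParallel` (an internal helper, not public API): the worker reads from its local input queue until `EndSignal`, and the round-robin distribute/gather preserves input order across threads. The worker function and queue names below are illustrative.

```python
import queue as Queue
from paddleslim.pantheon.teacher import WorkerParallel
from paddleslim.pantheon.utils import EndSignal

def square_worker(in_queue, out_queue, scale):
    # Process items until EndSignal, forwarding the sentinel downstream.
    while True:
        data = in_queue.get()
        in_queue.task_done()
        if isinstance(data, EndSignal):
            out_queue.put(data)
            break
        out_queue.put(scale * data * data)

in_q, out_q = Queue.Queue(100), Queue.Queue(100)
pool = WorkerParallel(4, in_q, out_q)   # 4 post-processing threads
pool(worker=square_worker, args=(1, ))  # starts all daemon threads

for i in range(10):
    in_q.put(i)
in_q.put(EndSignal())

results = []
while True:
    item = out_q.get()
    out_q.task_done()
    if isinstance(item, EndSignal):
        break
    results.append(item)
print(results)  # [0, 1, 4, ..., 81], order preserved by round-robin gather
```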
class Teacher(object):
    """
    The class defined for the teacher model. Generate knowledge data and
@@ -102,9 +184,12 @@ class Teacher(object):
        self._started = False

    def _start_manager(self):
        def get_knowledge_queue():
            global knowledge_queue
            return knowledge_queue
        def get_knowledge_queue(idx):
            global knowledge_queues
            if idx < len(knowledge_queues):
                return knowledge_queues[idx]
            else:
                return None

        def get_s2t_queue():
            global s2t_queue
@@ -141,12 +226,17 @@
        self._started = True
        self._manager = self._start_manager() if self._out_port else None
        if self._manager:
            self._knowledge_queue = self._manager.get_knowledge_queue()
            self._knowledge_queues = [
                self._manager.get_knowledge_queue(i)
                for i in range(num_postprocess_threads)
            ]
            print("Num of knowledge queues: {}".format(
                num_postprocess_threads))
            self._s2t_queue = self._manager.get_s2t_queue()
            self._t2s_queue = self._manager.get_t2s_queue()
            self._cmd_queue = self._manager.get_cmd_queue()
        else:
            self._knowledge_queue = None
            self._knowledge_queues = None
            self._s2t_queue = None
            self._t2s_queue = None
            self._cmd_queue = None
@@ -173,8 +263,9 @@
        while True:
            if self._sync_required:
                self._knowledge_queue.put(SyncSignal())
                self._knowledge_queue.join()
                for q in self._knowledge_queues:
                    q.put(SyncSignal())
                    q.join()
                self._sync_required = False
                break
@@ -256,6 +347,7 @@
                                reader_config,
                                exe,
                                buf_size=10,
                                use_fp16=False,
                                times=1):
        """
        Start the knowledge service to generate and transfer knowledge data.
@@ -291,6 +383,11 @@
            exe (fluid.Executor): The executor to run the input program.
            buf_size (int): The size of buffers for data reader and knowledge
                            writer on each device.
            use_fp16 (bool): Whether to transfer/store knowledge data as
                            float16 when its original data type is float32 or
                            float64. In offline mode, this reduces the size of
                            the dumped knowledge file; in online mode, it
                            speeds up knowledge transfer, at the cost of some
                            precision. Default False.
            times (int): The maximum repeated serving times. Default 1. Each
                            time the public method 'get_knowledge_generator()'
                            of a Student object is called, the serving count
                            increases by one,
@@ -333,6 +430,8 @@
raise ValueError("Input argument should be a fluid Executor!")
self._exe = exe
self._use_fp16 = use_fp16
if not buf_size > 0:
raise ValueError("The buffer size should be positive!")
self._buf_size = buf_size
@@ -402,84 +501,136 @@
"generator type, which should be one of 'sample_generator', "
"'sample_list_generator', and 'batch_generator'.")
def writer(buf_queue, schema_keys):
samples_sent, batches_sent = 0, 0
def cast2fp16(know):
for k, v in list(know.items()):
if not isinstance(v, np.ndarray):
break
if v.dtype == np.float32 or v.dtype == np.float64:
v = v.astype("float16")
know[k] = v
return know
feed_var_names = [var.name for var in self._feed_list]
schema_in_feed, schema_in_fetch = {}, {}
for k, v in list(self._schema.items()):
if k in feed_var_names:
schema_in_feed[k] = v
else:
schema_in_fetch[k] = v
schema_in_fetch_keys, schema_in_fetch_vars = zip(
*list(schema_in_fetch.items()))
        def know_maker(in_queue, out_queue, use_fp16):
            while True:
                outputs = buf_queue.get()
                buf_queue.task_done()
                if not isinstance(outputs, EndSignal):
                    batch_samples = dict(zip(schema_keys, outputs))
                    if self._knowledge_queue:
                        self._knowledge_queue.put(batch_samples)
                    if self._out_file:
                        self._out_file.write(pickle.dumps(batch_samples))
                data = in_queue.get()
                in_queue.task_done()
                if isinstance(data, tuple):
                    dev_batches, outputs = data
                    know = {}
                    # knowledge in the feed vars is taken from the inputs
                    for k in schema_in_feed.keys():
                        batch_know = [
                            np.array(batch[k]) for batch in dev_batches
                        ]
                        know[k] = np.concatenate(batch_know)
                    know.update(dict(zip(schema_in_fetch_keys, outputs)))
                    if use_fp16:
                        know = cast2fp16(know)
                    out_queue.put(know)
                else:
                    if self._knowledge_queue:
                        self._knowledge_queue.put(EndSignal())
                    # should close file in child thread to wait for all
                    # writing finished
                    if self._out_file:
                    # forward other types of data directly (maybe knowledge
                    # desc or EndSignal)
                    out_queue.put(data)
        know_make_queue = Queue.Queue(self._buf_size)
        if self._out_file:
            # For offline dump, write the knowledge description to the head of file
            self._out_file.write(pickle.dumps(self._knowledge_desc))
            print("output path: %s" % self._out_path)
            offline_write_queue = Queue.Queue(self._buf_size)

            def offline_write(queue):
                while True:
                    know = queue.get()
                    queue.task_done()
                    if not isinstance(know, EndSignal):
                        self._out_file.write(pickle.dumps(know))
                    else:
                        # should close file in child thread to wait for all
                        # writing finished
                        self._out_file.close()

            # Asynchronous output
            out_buf_queue = Queue.Queue(self._buf_size)
            schema_keys, schema_vars = zip(*list(self._schema.items()))
            out_thread = Thread(target=writer, args=(out_buf_queue, schema_keys))
            out_thread.daemon = True
            out_thread.start()
            t = Thread(target=offline_write, args=(offline_write_queue, ))
            t.daemon = True
            t.start()
            make_knowledge = WorkerParallel(
                num_postprocess_threads, know_make_queue, offline_write_queue)
        if self._knowledge_queues:
            make_knowledge = WorkerParallel(num_postprocess_threads,
                                            know_make_queue,
                                            self._knowledge_queues)
        make_knowledge(worker=know_maker, args=(self._use_fp16, ))
        compiled_program = fluid.compiler.CompiledProgram(
            self._program).with_data_parallel()

        print("Knowledge description {}".format(self._knowledge_desc))
        print(
            time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
            " Teacher begins to serve ...")
        # For offline dump, write the knowledge description to the head of file
        if self._out_file:
            self._out_file.write(pickle.dumps(self._knowledge_desc))
            print("output path: %s" % self._out_path)
        print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
              " Teacher begins to serve ...")
        data_reader = MixedDataReader(data_loader, dev_count)
        # For online mode, send knowledge description every time
        for repeated in range(self._times):
            if self._knowledge_queue:
            if self._knowledge_queues:
                # wait for the accessing of knowledge desc and data
                while True:
                    if self._sync_required:
                        self._knowledge_queue.put(SyncSignal())
                        self._knowledge_queue.put(self._knowledge_desc)
                        for q in self._knowledge_queues:
                            q.put(SyncSignal())
                        know_make_queue.put(self._knowledge_desc)
                        self._sync_required = False
                    if self._data_required:
                        self._data_required = False
                        break
                self._knowledge_queue.join()
                for q in self._knowledge_queues:
                    q.join()
print("No.{} time serving ... ".format(repeated))
num_batches_sent = 0
for dev_batches in data_reader.multi_dev_generator():
for index, dev_batches in enumerate(
data_reader.multi_dev_generator()):
if self._sync_required:
break
tic = time.time()
outputs = self._exe.run(compiled_program,
feed=dev_batches,
fetch_list=schema_vars)
out_buf_queue.put(outputs)
fetch_list=schema_in_fetch_vars)
toc = time.time()
print("teacher predict time = {}".format(toc - tic))
know_make_queue.put((dev_batches, outputs))
#out_buf_queue.put(know)
tic = time.time()
print("teacher out time = {}".format(tic - toc))
num_batches_sent += dev_count
if num_batches_sent % (100 * dev_count) == 0:
log = "Processed {} batch samples.".format(
num_batches_sent)
if self._knowledge_queue:
log += " Knowledge queue size {}.".format(
self._knowledge_queue.qsize())
if self._knowledge_queues:
qsize = 0
for q in self._knowledge_queues:
qsize += q.qsize()
log += " Knowledge queue size {}.".format(qsize)
print(log)
            outputs = []
            dev_batches, outputs = [], []
            for index, batch in enumerate(data_reader.tail_generator()):
                if self._sync_required:
                    break
                dev_batches.append(batch)
                output = self._exe.run(self._program,
                                       feed=batch,
                                       fetch_list=schema_vars)
                                       fetch_list=schema_in_fetch_vars)
                if outputs:
                    outputs = [
                        np.concatenate(
@@ -488,21 +639,22 @@
                    ]
                else:
                    outputs = copy.deepcopy(output)
            if outputs:
                out_buf_queue.put(outputs)
            if dev_batches or outputs:
                know_make_queue.put((dev_batches, outputs))
                #out_buf_queue.put(know)
                num_batches_sent += (index + 1)
            print("Processed {} batch samples in total.".format(
                num_batches_sent))

            out_buf_queue.put(EndSignal())
            out_buf_queue.join()
            know_make_queue.put(EndSignal())
            know_make_queue.join()

            if self._knowledge_queue:
                self._knowledge_queue.join()
            print(
                time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
                " Teacher ends serving.")
            if self._knowledge_queues:
                for q in self._knowledge_queues:
                    q.join()
        print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
              " Teacher ends serving.")

    def __del__(self):
        if self._manager:
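For completeness, the matching student side, whose public API is unchanged by this commit, consumes the knowledge roughly as follows. This is a hedged sketch based on pantheon's documented interface; the address, file name, and merge strategy are illustrative:

```python
from paddleslim.pantheon import Student

# Two teachers: one online (address is illustrative), one offline dump.
student = Student(merge_strategy={"logits": "mean"})
student.register_teacher(in_address="127.0.0.1:8080")
student.register_teacher(in_path="teacher_offline.dat")
student.start()

knowledge_desc = student.get_knowledge_desc()
data_generator = student.get_knowledge_generator(
    batch_size=32, drop_last=False)
for knowledge in data_generator():
    # 'knowledge' is a dict of numpy arrays; float16 data sent with
    # use_fp16=True has been cast back to its declared dtype by now.
    pass
```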