diff --git a/demo/pantheon/run_teacher1.py b/demo/pantheon/run_teacher1.py
index bbe94310b1de83ef1327426108b037a5a4c62258..1e0e089877b642a66cf5bf7dd3e171af28a62f91 100644
--- a/demo/pantheon/run_teacher1.py
+++ b/demo/pantheon/run_teacher1.py
@@ -72,6 +72,7 @@ def run(args):
         program=program,
         reader_config=reader_config,
         exe=exe,
+        use_fp16=True,
         times=args.serving_times)
diff --git a/paddleslim/pantheon/README.md b/paddleslim/pantheon/README.md
index 70180e4cede384762e277ef85fe9cf7e45b95841..3be5ecfa6740954b2ae06dcdbe81b25137a36681 100644
--- a/paddleslim/pantheon/README.md
+++ b/paddleslim/pantheon/README.md
@@ -106,6 +106,7 @@ Usually, the public methods of these two classes work in the pairwise way. Their
     reader_config,
     exe,
     buf_size=10,
+    use_fp16=False,
     times=1)
 get_knowledge_desc()
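For orientation, the student side of the pairwise usage this hunk documents looks roughly like the sketch below. This is a minimal sketch, not part of the patch: the method names follow the README's table and the demo scripts, while the teacher address is hypothetical and assumes a teacher already serving online.

```python
from paddleslim.pantheon import Student

student = Student()
# hypothetical address of an online teacher started elsewhere
student.register_teacher(in_address="127.0.0.1:8080")
student.start()

# pairs with the teacher's start_knowledge_service(...)
knowledge_desc = student.get_knowledge_desc()
data_generator = student.get_knowledge_generator(
    batch_size=32, drop_last=False)
for knowledge in data_generator():
    pass  # feed each knowledge batch to the student's training loop
```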
@@ -213,6 +214,7 @@ The toy "knowledge distillation" system can be launched in three different modes
 ```shell
 export PYTHONPATH=../../:$PYTHONPATH
 export CUDA_VISIBLE_DEVICES=0,1
+export NUM_POSTPROCESS_THREADS=10 # default 8
 nohup python -u run_teacher1.py --use_cuda true --out_path teacher1_offline.dat > teacher1_offline.log 2>&1&
 export CUDA_VISIBLE_DEVICES=2
 nohup python -u run_teacher2.py --use_cuda true --out_path teacher2_offline.dat > teacher2_offline.log 2>&1&
diff --git a/paddleslim/pantheon/student.py b/paddleslim/pantheon/student.py
index 2964c4864933fc384857645f9a83d85220003b70..3f522b607242d880c2d1aca44e0088c48b3bc876 100644
--- a/paddleslim/pantheon/student.py
+++ b/paddleslim/pantheon/student.py
@@ -28,7 +28,7 @@ from multiprocessing.managers import BaseManager
 from threading import Thread
 
-from paddleslim.pantheon.utils import EndSignal, SyncSignal, StartSignal, public_authkey
+from paddleslim.pantheon.utils import EndSignal, SyncSignal, StartSignal, public_authkey, convert_dtype
 
 __all__ = ["Student"]
 
@@ -114,7 +114,60 @@ class Student(object):
             except:
                 time.sleep(1.0)
 
-        knowledge_queue = manager.get_knowledge_queue()
+        def merge(knowledge_queues):
+            num = len(knowledge_queues)
+            if num == 1:
+                return knowledge_queues[0]
+            local_queues = [Queue.Queue(100) for _ in range(num)]
+
+            def receive(queue, local_queue):
+                while True:
+                    data = queue.get()
+                    queue.task_done()
+                    local_queue.put(data)
+                    if isinstance(data, EndSignal):
+                        break
+
+            knowledge_queue = Queue.Queue(100)
+
+            def gather(local_queues, knowledge_queue):
+                num = len(local_queues)
+                end_received = False
+                while True:
+                    for i in range(num):
+                        data = local_queues[i].get()
+                        local_queues[i].task_done()
+                        if isinstance(data, SyncSignal) and i > 0:
+                            continue
+                        elif isinstance(data, EndSignal):
+                            end_received = True
+                        knowledge_queue.put(data)
+                    if end_received:
+                        break
+
+            # threads to receive knowledge from the online teacher
+            for i in range(num):
+                p = Thread(
+                    target=receive,
+                    args=(knowledge_queues[i], local_queues[i]))
+                p.daemon = True
+                p.start()
+            # thread to gather data from different local queues
+            p = Thread(target=gather, args=(local_queues, knowledge_queue))
+            p.daemon = True
+            p.start()
+            return knowledge_queue
+
+        # get knowledge queues
+        knowledge_queues, idx = [], 0
+        while True:
+            q = manager.get_knowledge_queue(idx)
+            if hasattr(q, "get"):
+                knowledge_queues.append(q)
+                idx += 1
+            else:
+                break
+        knowledge_queue = merge(knowledge_queues)
         self._t2s_queues.append(manager.get_t2s_queue())
         self._s2t_queues.append(manager.get_s2t_queue())
         self._cmd_queues.append(manager.get_cmd_queue())
@@ -237,6 +290,10 @@ class Student(object):
                 knowledge[k] = result
             elif self._merge_strategy[k] == "mean":
                 knowledge[k] = result / len(tensors)
+            # cast back to original data type if necessary
+            tgt_dtype = self._knowledge_desc[k]["dtype"]
+            if str(knowledge[k].dtype) != tgt_dtype:
+                knowledge[k] = knowledge[k].astype(tgt_dtype)
         return knowledge
 
     def send(self, data, teacher_ids=None):
@@ -383,11 +440,23 @@ class Student(object):
                 [batches[i][key] for i in range(len(batches))])
             return ret_batch
 
-        def listen(in_queue, out_queue, batch_size):
+        def listen(knowledge_queue, out_queue):
+            """
+            Listen on the knowledge queue for one teacher, get knowledge data
+            and put it into a local queue (out_queue).
+ """ + while True: + data = knowledge_queue.get() + knowledge_queue.task_done() + out_queue.put(data) + if isinstance(data, EndSignal): + break + + def make_new_batch(in_queue, out_queue, batch_size): """ - listen on the knowledge queue for one teacher, get knowledge - data and make a new batch data in the batch size of student, - then put it into the intermediate queue (out_queue). + Get knowledge data from a local queue and make a new batch data in + the batch size of student, then put it into the intermediate + queue (out_queue). """ batches, num_samples = [], 0 while True: @@ -467,17 +536,25 @@ class Student(object): queue.put(StartSignal()) queue.join() - # launch multiple threads to listen on all knowledge queues - med_queues = [Queue.Queue(100) for i in range(self._num_teachers)] + # launch threads to listen on all knowledge queues + local_queues = [Queue.Queue(100) for i in range(self._num_teachers)] for i in range(self._num_teachers): listen_thread = Thread( target=listen, - args=(self._teacher_knowledge_queues[i], med_queues[i], - self._batch_size)) + args=(self._teacher_knowledge_queues[i], local_queues[i])) + listen_thread.dameon = True + listen_thread.start() + + # launch threads to make new batch for student + med_queues = [Queue.Queue(100) for i in range(self._num_teachers)] + for i in range(self._num_teachers): + listen_thread = Thread( + target=make_new_batch, + args=(local_queues[i], med_queues[i], self._batch_size)) listen_thread.dameon = True listen_thread.start() - # launch another thread to merge knowledge + # launch another thread to merge knowledge from different teachers. merge_thread = Thread( target=gather_and_merge, args=(med_queues, self._knowledge_queue)) merge_thread.dameon = True diff --git a/paddleslim/pantheon/teacher.py b/paddleslim/pantheon/teacher.py index 425257e18886940cac98ef5a5a2df356f2e6ae67..9a1d5ae2d8271c79cc49a76c02a7d1226b811c46 100644 --- a/paddleslim/pantheon/teacher.py +++ b/paddleslim/pantheon/teacher.py @@ -35,7 +35,11 @@ from paddleslim.pantheon.utils import convert_dtype, EndSignal, SyncSignal, Star __all__ = ["Teacher"] -knowledge_queue = Queue.Queue(100) +# Num of threads for post-processing, including generating and transferring +# knowledge data +num_postprocess_threads = int(os.getenv("NUM_POSTPROCESS_THREADS", 8)) +knowledge_queues = [Queue.Queue(100) for i in range(num_postprocess_threads)] + t2s_queue = Queue.Queue(100) s2t_queue = Queue.Queue(100) cmd_queue = Queue.Queue(5) @@ -75,6 +79,84 @@ class MixedDataReader(object): self._tail_data = [] +class WorkerParallel(object): + """ + Process data from the input queue by given worker in parallel, and put the + result into output queue in order. + + Args: + num_postprocess_threads (int): Number of threads for data processing. + in_queue (object): The input queue. + out_queue (object|list): The output queue(s). Its length should be equal + to arg 'num_postprocess_threads' when it is a list. 
+ """ + + def __init__(self, num_postprocess_threads, in_queue, out_queue): + self._num_postprocess_threads = num_postprocess_threads + self._in_queue = in_queue + self._local_in_queues = [ + Queue.Queue(5) for i in range(num_postprocess_threads) + ] + if isinstance(out_queue, list): + if len(out_queue) != num_postprocess_threads: + raise ValueError("When out_queue is a list, its length must " + "equal to num_postprocess_threads!") + self._local_out_queues = out_queue + self._out_queue = None + else: + self._local_out_queues = [ + Queue.Queue(5) for i in range(num_postprocess_threads) + ] + self._out_queue = out_queue + + def _distribute(self): + def func(): + idx = 0 + while True: + data = self._in_queue.get() + self._in_queue.task_done() + if not isinstance(data, EndSignal): + self._local_in_queues[ + idx % self._num_postprocess_threads].put(data) + idx += 1 + else: + for q in self._local_in_queues: + q.put(EndSignal()) + + t = Thread(target=func) + t.daemon = True + t.start() + + def _run(self, worker, args): + for i in range(self._num_postprocess_threads): + t = Thread( + target=worker, + args=(self._local_in_queues[i], self._local_out_queues[i]) + + args) + t.daemon = True + t.start() + + def _gather(self): + def func(): + while True: + for idx, q in enumerate(self._local_out_queues): + data = q.get() + q.task_done() + if isinstance(data, EndSignal) and idx > 0: + continue + self._out_queue.put(data) + + t = Thread(target=func) + t.daemon = True + t.start() + + def __call__(self, worker, args): + self._distribute() + self._run(worker, args) + if self._out_queue: + self._gather() + + class Teacher(object): """ The class defined for the teacher model. Generate knowledge data and @@ -102,9 +184,12 @@ class Teacher(object): self._started = False def _start_manager(self): - def get_knowledge_queue(): - global knowledge_queue - return knowledge_queue + def get_knowledge_queue(idx): + global knowledge_queues + if idx < len(knowledge_queues): + return knowledge_queues[idx] + else: + return None def get_s2t_queue(): global s2t_queue @@ -141,12 +226,17 @@ class Teacher(object): self._started = True self._manager = self._start_manager() if self._out_port else None if self._manager: - self._knowledge_queue = self._manager.get_knowledge_queue() + self._knowledge_queues = [ + self._manager.get_knowledge_queue(i) + for i in range(num_postprocess_threads) + ] + print("Num of knowledge queues: {}".format( + num_postprocess_threads)) self._s2t_queue = self._manager.get_s2t_queue() self._t2s_queue = self._manager.get_t2s_queue() self._cmd_queue = self._manager.get_cmd_queue() else: - self._knowledge_queue = None + self._knowledge_queues = None self._s2t_queue = None self._t2s_queue = None self._cmd_queue = None @@ -173,8 +263,9 @@ class Teacher(object): while True: if self._sync_required: - self._knowledge_queue.put(SyncSignal()) - self._knowledge_queue.join() + for q in self._knowledge_queues: + q.put(SyncSignal()) + q.join() self._sync_required = False break @@ -256,6 +347,7 @@ class Teacher(object): reader_config, exe, buf_size=10, + use_fp16=False, times=1): """ Start the knowledge service to generate and transfer knowledge data. @@ -291,6 +383,11 @@ class Teacher(object): exe (fluid.Executor): The executor to run the input program. buf_size (int): The size of buffers for data reader and knowledge writer on each device. + use_fp16 (bool): Whether to transfer/store knowledge data in float16 + if their data type is float32/float64. 
+                             In the offline mode it reduces the size of the
+                             dumped knowledge file, and in the online mode it
+                             speeds up the transfer with some loss of
+                             precision. Default False.
             times (int): The maximum repeated serving times. Default 1. Whenever
                          the public method 'get_knowledge_generator()' in Student
                          object called once, the serving times will be added one,
@@ -333,6 +430,8 @@ class Teacher(object):
             raise ValueError("Input argument should be a fluid Executor!")
         self._exe = exe
 
+        self._use_fp16 = use_fp16
+
         if not buf_size > 0:
             raise ValueError("The buffer size should be positive!")
         self._buf_size = buf_size
@@ -402,84 +501,136 @@ class Teacher(object):
                 "generator type, which should be one of 'sample_generator', "
                 "'sample_list_generator', and 'batch_generator'.")
 
-        def writer(buf_queue, schema_keys):
-            samples_sent, batches_sent = 0, 0
+        def cast2fp16(know):
+            # cast float32/float64 knowledge arrays to float16
+            for k, v in list(know.items()):
+                if not isinstance(v, np.ndarray):
+                    break
+                if v.dtype == np.float32 or v.dtype == np.float64:
+                    v = v.astype("float16")
+                    know[k] = v
+            return know
+
+        # split the schema into feed variables and fetch variables
+        feed_var_names = [var.name for var in self._feed_list]
+        schema_in_feed, schema_in_fetch = {}, {}
+        for k, v in list(self._schema.items()):
+            if k in feed_var_names:
+                schema_in_feed[k] = v
+            else:
+                schema_in_fetch[k] = v
+        schema_in_fetch_keys, schema_in_fetch_vars = zip(
+            *list(schema_in_fetch.items()))
+
+        def know_maker(in_queue, out_queue, use_fp16):
             while True:
-                outputs = buf_queue.get()
-                buf_queue.task_done()
-                if not isinstance(outputs, EndSignal):
-                    batch_samples = dict(zip(schema_keys, outputs))
-                    if self._knowledge_queue:
-                        self._knowledge_queue.put(batch_samples)
-                    if self._out_file:
-                        self._out_file.write(pickle.dumps(batch_samples))
+                data = in_queue.get()
+                in_queue.task_done()
+                if isinstance(data, tuple):
+                    dev_batches, outputs = data
+                    know = {}
+                    for k in schema_in_feed.keys():
+                        batch_know = [
+                            np.array(batch[k]) for batch in dev_batches
+                        ]
+                        know[k] = np.concatenate(batch_know)
+                    know.update(dict(zip(schema_in_fetch_keys, outputs)))
+                    if use_fp16:
+                        know = cast2fp16(know)
+                    out_queue.put(know)
                 else:
-                    if self._knowledge_queue:
-                        self._knowledge_queue.put(EndSignal())
-                    # should close file in child thread to wait for all
-                    # writing finished
-                    if self._out_file:
+                    # forward other types of data directly (maybe knowledge desc or EndSignal)
+                    out_queue.put(data)
+
+        know_make_queue = Queue.Queue(self._buf_size)
+        if self._out_file:
+            # For offline dump, write the knowledge description to the head of file
+            self._out_file.write(pickle.dumps(self._knowledge_desc))
+            print("output path: %s" % self._out_path)
+            offline_write_queue = Queue.Queue(self._buf_size)
+
+            def offline_write(queue):
+                while True:
+                    know = queue.get()
+                    queue.task_done()
+                    if not isinstance(know, EndSignal):
+                        self._out_file.write(pickle.dumps(know))
+                    else:
+                        # should close file in child thread to wait for all
+                        # writing finished
                         self._out_file.close()
-        # Asynchronous output
-        out_buf_queue = Queue.Queue(self._buf_size)
-        schema_keys, schema_vars = zip(*list(self._schema.items()))
-        out_thread = Thread(target=writer, args=(out_buf_queue, schema_keys))
-        out_thread.daemon = True
-        out_thread.start()
+
+            t = Thread(target=offline_write, args=(offline_write_queue, ))
+            t.daemon = True
+            t.start()
+            make_knowledge = WorkerParallel(
+                num_postprocess_threads, know_make_queue, offline_write_queue)
+
+        if self._knowledge_queues:
+            make_knowledge = WorkerParallel(num_postprocess_threads,
+                                            know_make_queue,
+                                            self._knowledge_queues)
+
+        make_knowledge(worker=know_maker, args=(self._use_fp16, ))
 
         compiled_program = fluid.compiler.CompiledProgram(
             self._program).with_data_parallel()
         print("Knowledge description {}".format(self._knowledge_desc))
-        print(
-            time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
-            " Teacher begins to serve ...")
-        # For offline dump, write the knowledge description to the head of file
-        if self._out_file:
-            self._out_file.write(pickle.dumps(self._knowledge_desc))
-            print("output path: %s" % self._out_path)
+        print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
+              " Teacher begins to serve ...")
 
         data_reader = MixedDataReader(data_loader, dev_count)
         # For online mode, send knowledge description every time
         for repeated in range(self._times):
-            if self._knowledge_queue:
+            if self._knowledge_queues:
                 # wait for the accessing of knowledge desc and data
                 while True:
                     if self._sync_required:
-                        self._knowledge_queue.put(SyncSignal())
-                        self._knowledge_queue.put(self._knowledge_desc)
+                        for q in self._knowledge_queues:
+                            q.put(SyncSignal())
+                        know_make_queue.put(self._knowledge_desc)
                         self._sync_required = False
                     if self._data_required:
                         self._data_required = False
                         break
-                self._knowledge_queue.join()
+                for q in self._knowledge_queues:
+                    q.join()
 
             print("No.{} time serving ... ".format(repeated))
             num_batches_sent = 0
-            for dev_batches in data_reader.multi_dev_generator():
+            for index, dev_batches in enumerate(
+                    data_reader.multi_dev_generator()):
                 if self._sync_required:
                     break
+                tic = time.time()
                 outputs = self._exe.run(compiled_program,
                                         feed=dev_batches,
-                                        fetch_list=schema_vars)
-                out_buf_queue.put(outputs)
+                                        fetch_list=schema_in_fetch_vars)
+                toc = time.time()
+                print("teacher predict time = {}".format(toc - tic))
+                know_make_queue.put((dev_batches, outputs))
+                tic = time.time()
+                print("teacher out time = {}".format(tic - toc))
                 num_batches_sent += dev_count
                 if num_batches_sent % (100 * dev_count) == 0:
                     log = "Processed {} batch samples.".format(
                         num_batches_sent)
-                    if self._knowledge_queue:
-                        log += " Knowledge queue size {}.".format(
-                            self._knowledge_queue.qsize())
+                    if self._knowledge_queues:
+                        qsize = 0
+                        for q in self._knowledge_queues:
+                            qsize += q.qsize()
+                        log += " Knowledge queue size {}.".format(qsize)
                     print(log)
 
-            outputs = []
+            dev_batches, outputs = [], []
             for index, batch in enumerate(data_reader.tail_generator()):
                 if self._sync_required:
                     break
+                dev_batches.append(batch)
                 output = self._exe.run(self._program,
                                        feed=batch,
-                                       fetch_list=schema_vars)
+                                       fetch_list=schema_in_fetch_vars)
                 if outputs:
                     outputs = [
                         np.concatenate(
@@ -488,21 +639,22 @@ class Teacher(object):
                     ]
                 else:
                     outputs = copy.deepcopy(output)
-            if outputs:
-                out_buf_queue.put(outputs)
+            if dev_batches or outputs:
+                know_make_queue.put((dev_batches, outputs))
             num_batches_sent += (index + 1)
             print("Processed {} batch samples in total.".format(
                 num_batches_sent))
 
-            out_buf_queue.put(EndSignal())
-            out_buf_queue.join()
+            know_make_queue.put(EndSignal())
+            know_make_queue.join()
 
-            if self._knowledge_queue:
-                self._knowledge_queue.join()
-        print(
-            time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
-            " Teacher ends serving.")
+            if self._knowledge_queues:
+                for q in self._knowledge_queues:
+                    q.join()
+        print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
+              " Teacher ends serving.")
 
     def __del__(self):
         if self._manager:
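
On the teacher side, the two knobs this patch introduces would be used roughly as in the sketch below. This is a minimal sketch, not part of the patch: the feed/schema/program/executor/reader names are hypothetical placeholders standing in for the demo's toy model setup.

```python
import os

# NUM_POSTPROCESS_THREADS is read once at module load time (module-level
# os.getenv in teacher.py), so it must be set before the import below
os.environ["NUM_POSTPROCESS_THREADS"] = "10"

from paddleslim.pantheon import Teacher

teacher = Teacher(out_path="teacher1_offline.dat")  # offline dump mode
teacher.start()
teacher.start_knowledge_service(
    feed_list=[inp_x.name],              # hypothetical feed variable
    schema={"x": inp_x, "2x": out_2x},   # hypothetical knowledge schema
    program=program,                     # hypothetical fluid Program
    reader_config={"batch_generator": batch_generator},
    exe=exe,                             # hypothetical fluid Executor
    buf_size=10,
    use_fp16=True,  # float32/float64 knowledge is cast to float16 before
                    # dump/transfer; the student casts it back using the
                    # dtype recorded in the knowledge description
    times=1)
```

This is also why the README example above exports `NUM_POSTPROCESS_THREADS` before launching the teacher scripts: the thread count cannot be changed after `paddleslim.pantheon.teacher` has been imported.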