From df0dff849a0f035e5e0eefafe483ee098b9dc59a Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 14 Apr 2020 16:50:57 +0800 Subject: [PATCH] Stand out toy example & fix bugs in child threads (#219) * Stand out toy example & fix bugs in child threads * Refine comments --- demo/pantheon/README.md | 2 - demo/pantheon/toy/README.md | 54 ++++++++++++++++++++++++ demo/pantheon/{ => toy}/run_student.py | 0 demo/pantheon/{ => toy}/run_teacher1.py | 0 demo/pantheon/{ => toy}/run_teacher2.py | 0 demo/pantheon/{ => toy}/utils.py | 0 paddleslim/pantheon/README.md | 55 ++----------------------- paddleslim/pantheon/student.py | 46 ++++++++++++--------- paddleslim/pantheon/teacher.py | 33 ++++++++------- 9 files changed, 101 insertions(+), 89 deletions(-) delete mode 100644 demo/pantheon/README.md create mode 100644 demo/pantheon/toy/README.md rename demo/pantheon/{ => toy}/run_student.py (100%) rename demo/pantheon/{ => toy}/run_teacher1.py (100%) rename demo/pantheon/{ => toy}/run_teacher2.py (100%) rename demo/pantheon/{ => toy}/utils.py (100%) diff --git a/demo/pantheon/README.md b/demo/pantheon/README.md deleted file mode 100644 index 3cc55c33..00000000 --- a/demo/pantheon/README.md +++ /dev/null @@ -1,2 +0,0 @@ - -The toy examples for Pantheon, see details in [PaddleSlim/Pantheon](../../paddleslim/pantheon). diff --git a/demo/pantheon/toy/README.md b/demo/pantheon/toy/README.md new file mode 100644 index 00000000..3cb561b4 --- /dev/null +++ b/demo/pantheon/toy/README.md @@ -0,0 +1,54 @@ +## Toy example for Pantheon + +See more details about Pantheon in [PaddleSlim/Pantheon](../../../paddleslim/pantheon). + +Here implements two teacher models (not trainable, just for demo): teacher1 takes an integer **x** as input and predicts value **2x-1**, see in [run_teacher1.py](run_teacher1.py); teacher2 also takes **x** as input and predicts **2x+1**, see in [run_teacher2.py](run_teacher2.py). They two share a data reader to read a sequence of increasing natural numbers from zero to some positive inter **max_n** as input and generate different knowledge. And the schema keys for knowledge in teacher1 is [**"x", "2x-1", "result"**], and [**"2x+1", "result"**] for knowledge in teacher2, in which **"result"** is the common schema and the copy of two predictions respectively. On instantiating the **Student** object, the merging strategy for the common schema **"result"** should be specified, and the schema keys for the merged knowledge will be [**"x", "2x-1", "2x+1", "result"**], with the merged **"result"** equal to **"2x"** when the merging strategy is **"mean"** and **"4x"** when merging strategy is **"sum"**. The student model gets merged knowledge from teachers and prints them out, see in [run_student.py](run_student.py). + +The toy "knowledge distillation" system can be launched in three different modes, i.e., offline, online and their hybrid. All three modes should have the same outputs, and the correctness of results can be verified by checking the order and values of outputs. + +### Offline + + The two teachers work in offline mode, and start them with given local file paths. + + ```shell +export PYTHONPATH=../../../:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0,1 +export NUM_POSTPROCESS_THREADS=10 # default 8 +nohup python -u run_teacher1.py --use_cuda true --out_path teacher1_offline.dat > teacher1_offline.log 2>&1& +export CUDA_VISIBLE_DEVICES=2 +nohup python -u run_teacher2.py --use_cuda true --out_path teacher2_offline.dat > teacher2_offline.log 2>&1& + ``` + After the two executions both finished, start the student model with the two generated knowledge files. + + ```shell +export PYTHONPATH=../../../:$PYTHONPATH + python -u run_student.py \ + --in_path0 teacher1_offline.dat \ + --in_path1 teacher2_offline.dat + ``` + + +### Online + +The two teachers work in online mode, and start them with given TCP/IP ports. Please make sure that the ICP/IP ports are available. + +```shell +export PYTHONPATH=../../../:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0 +nohup python -u run_teacher1.py --use_cuda true --out_port 8080 > teacher1_online.log 2>&1& +export CUDA_VISIBLE_DEVICES=1,2 +nohup python -u run_teacher2.py --use_cuda true --out_port 8081 > teacher2_online.log 2>&1& +``` +Start the student model with the IP addresses that can reach the ports of the two teacher models, e.g., in the same node + +```shell +export PYTHONPATH=../../../:$PYTHONPATH +python -u run_student.py \ + --in_address0 127.0.0.1:8080 \ + --in_address1 127.0.0.1:8081 \ +``` +**Note:** in online mode, the starting order of teachers and the sudent doesn't matter, and they will wait for each other to establish connection. + +### Hybrid of offline and online + +One teacher works in offline mode and another one works in online mode. This time, start the offline teacher first. After the offline knowledge file gets well prepared, start the online teacher and the student at the same time. diff --git a/demo/pantheon/run_student.py b/demo/pantheon/toy/run_student.py similarity index 100% rename from demo/pantheon/run_student.py rename to demo/pantheon/toy/run_student.py diff --git a/demo/pantheon/run_teacher1.py b/demo/pantheon/toy/run_teacher1.py similarity index 100% rename from demo/pantheon/run_teacher1.py rename to demo/pantheon/toy/run_teacher1.py diff --git a/demo/pantheon/run_teacher2.py b/demo/pantheon/toy/run_teacher2.py similarity index 100% rename from demo/pantheon/run_teacher2.py rename to demo/pantheon/toy/run_teacher2.py diff --git a/demo/pantheon/utils.py b/demo/pantheon/toy/utils.py similarity index 100% rename from demo/pantheon/utils.py rename to demo/pantheon/toy/utils.py diff --git a/paddleslim/pantheon/README.md b/paddleslim/pantheon/README.md index 3be5ecfa..7cd10928 100644 --- a/paddleslim/pantheon/README.md +++ b/paddleslim/pantheon/README.md @@ -199,57 +199,8 @@ data_generator = student.get_knowledge_generator( batch_size=32, drop_last=False) ``` -### Example +## Examples -Here provide a toy example to show how the knowledge data is transferred from teachers to the student model and merged. +### Toy Example -In the directory [demo/pantheon/](../../demo/pantheon/), there implement two teacher models (not trainable, just for demo): teacher1 takes an integer **x** as input and predicts value **2x-1**, see in [run_teacher1.py](../../demo/pantheon/run_teacher1.py); teacher2 also takes **x** as input and predicts **2x+1**, see in [run_teacher2.py](../../demo/pantheon/run_teacher2.py). They two share a data reader to read a sequence of increasing natural numbers from zero to some positive inter **max_n** as input and generate different knowledge. And the schema keys for knowledge in teacher1 is [**"x", "2x-1", "result"**], and [**"2x+1", "result"**] for knowledge in teacher2, in which **"result"** is the common schema and the copy of two predictions respectively. On instantiating the **Student** object, the merging strategy for the common schema **"result"** should be specified, and the schema keys for the merged knowledge will be [**"x", "2x-1", "2x+1", "result"**], with the merged **"result"** equal to **"2x"** when the merging strategy is **"mean"** and **"4x"** when merging strategy is **"sum"**. The student model gets merged knowledge from teachers and prints them out, see in [run_student.py](../../demo/pantheon/run_student.py). - -The toy "knowledge distillation" system can be launched in three different modes, i.e., offline, online and their hybrid. All three modes should have the same outputs, and the correctness of results can be verified by checking the order and values of outputs. - -1) **Offline** - - The two teachers work in offline mode, and start them with given local file paths. - - ```shell -export PYTHONPATH=../../:$PYTHONPATH -export CUDA_VISIBLE_DEVICES=0,1 -export NUM_POSTPROCESS_THREADS=10 # default 8 -nohup python -u run_teacher1.py --use_cuda true --out_path teacher1_offline.dat > teacher1_offline.log 2>&1& -export CUDA_VISIBLE_DEVICES=2 -nohup python -u run_teacher2.py --use_cuda true --out_path teacher2_offline.dat > teacher2_offline.log 2>&1& - ``` - After the two executions both finished, start the student model with the two generated knowledge files. - - ```shell -export PYTHONPATH=../../:$PYTHONPATH - python -u run_student.py \ - --in_path0 teacher1_offline.dat \ - --in_path1 teacher2_offline.dat - ``` - - -2) **Online** - -The two teachers work in online mode, and start them with given TCP/IP ports. Please make sure that the ICP/IP ports are available. - -```shell -export PYTHONPATH=../../:$PYTHONPATH -export CUDA_VISIBLE_DEVICES=0 -nohup python -u run_teacher1.py --use_cuda true --out_port 8080 > teacher1_online.log 2>&1& -export CUDA_VISIBLE_DEVICES=1,2 -nohup python -u run_teacher2.py --use_cuda true --out_port 8081 > teacher2_online.log 2>&1& -``` -Start the student model with the IP addresses that can reach the ports of the two teacher models, e.g., in the same node - -```shell -export PYTHONPATH=../../:$PYTHONPATH -python -u run_student.py \ - --in_address0 127.0.0.1:8080 \ - --in_address1 127.0.0.1:8081 \ -``` -**Note:** in online mode, the starting order of teachers and the sudent doesn't matter, and they will wait for each other to establish connection. - -3) **Hybrid of offline and online** - -One teacher works in offline mode and another one works in online mode. This time, start the offline teacher first. After the offline knowledge file gets well prepared, start the online teacher and the student at the same time. +A toy example is provied to show how the knowledge data is transferred from teachers to the student model and merged, including offline, online modes and their hybrid. See [demo/pantheon/toy](../../demo/pantheon/toy). diff --git a/paddleslim/pantheon/student.py b/paddleslim/pantheon/student.py index 3f522b60..72bdd5e1 100644 --- a/paddleslim/pantheon/student.py +++ b/paddleslim/pantheon/student.py @@ -104,9 +104,9 @@ class Student(object): manager = BaseManager( address=(ip, int(port)), authkey=public_authkey.encode()) - # Wait for teacher model started to establish connection print("Connecting to {}, with public key {} ...".format( in_address, public_authkey)) + # Wait for teacher model started to establish connection while True: try: manager.connect() @@ -122,27 +122,37 @@ class Student(object): def receive(queue, local_queue): while True: - data = queue.get() - queue.task_done() - local_queue.put(data) - if isinstance(data, EndSignal): + try: + data = queue.get() + queue.task_done() + local_queue.put(data) + except EOFError: break knowledge_queue = Queue.Queue(100) def gather(local_queues, knowledge_queue): num = len(local_queues) - end_received = False + end_received = [0] * num while True: - for i in range(num): - data = local_queues[i].get() - local_queues[i].task_done() - if isinstance(data, SyncSignal) and i > 0: - continue - elif isinstance(data, EndSignal): - end_received = True - knowledge_queue.put(data) - if end_received: + try: + for i in range(num): + data = local_queues[i].get() + local_queues[i].task_done() + + if isinstance(data, SyncSignal): + if i == 0: + knowledge_queue.put(data) + elif isinstance(data, EndSignal): + end_received[i] = 1 + if i == 0: + knowledge_queue.put(data) + if sum(end_received) == num: + end_received = [0] * num + break + else: + knowledge_queue.put(data) + except EOFError: break # threads to receive knowledge from the online teacher @@ -419,7 +429,6 @@ class Student(object): "Return None.") return None self._is_knowledge_gen_locked = True - self.get_knowledge_desc() def split_batch(batch, num): @@ -536,8 +545,8 @@ class Student(object): queue.put(StartSignal()) queue.join() - # launch threads to listen on all knowledge queues local_queues = [Queue.Queue(100) for i in range(self._num_teachers)] + # launch threads to listen on all knowledge queues for i in range(self._num_teachers): listen_thread = Thread( target=listen, @@ -545,8 +554,8 @@ class Student(object): listen_thread.dameon = True listen_thread.start() - # launch threads to make new batch for student med_queues = [Queue.Queue(100) for i in range(self._num_teachers)] + # launch threads to make new batch for student for i in range(self._num_teachers): listen_thread = Thread( target=make_new_batch, @@ -560,7 +569,6 @@ class Student(object): merge_thread.dameon = True merge_thread.start() - # yield knowledge data def wrapper(): while True: knowledge = self._knowledge_queue.get() diff --git a/paddleslim/pantheon/teacher.py b/paddleslim/pantheon/teacher.py index 9a1d5ae2..281fe23d 100644 --- a/paddleslim/pantheon/teacher.py +++ b/paddleslim/pantheon/teacher.py @@ -122,6 +122,7 @@ class WorkerParallel(object): else: for q in self._local_in_queues: q.put(EndSignal()) + break t = Thread(target=func) t.daemon = True @@ -138,13 +139,18 @@ class WorkerParallel(object): def _gather(self): def func(): + end_received = False while True: for idx, q in enumerate(self._local_out_queues): data = q.get() q.task_done() - if isinstance(data, EndSignal) and idx > 0: - continue + if isinstance(data, EndSignal): + end_received = True + if idx > 0: + continue self._out_queue.put(data) + if end_received: + break t = Thread(target=func) t.daemon = True @@ -539,6 +545,8 @@ class Teacher(object): else: # forward other types of data directly (maybe knowledge desc or EndSignal) out_queue.put(data) + if isinstance(data, EndSignal): + break know_make_queue = Queue.Queue(self._buf_size) if self._out_file: @@ -569,8 +577,6 @@ class Teacher(object): know_make_queue, self._knowledge_queues) - make_knowledge(worker=know_maker, args=(self._use_fp16, )) - compiled_program = fluid.compiler.CompiledProgram( self._program).with_data_parallel() @@ -579,14 +585,15 @@ class Teacher(object): " Teacher begins to serve ...") data_reader = MixedDataReader(data_loader, dev_count) - # For online mode, send knowledge description every time for repeated in range(self._times): + make_knowledge(worker=know_maker, args=(self._use_fp16, )) if self._knowledge_queues: # wait for the accessing of knowledge desc and data while True: if self._sync_required: for q in self._knowledge_queues: q.put(SyncSignal()) + # For online mode, send knowledge description every sync know_make_queue.put(self._knowledge_desc) self._sync_required = False if self._data_required: @@ -601,17 +608,11 @@ class Teacher(object): data_reader.multi_dev_generator()): if self._sync_required: break - tic = time.time() outputs = self._exe.run(compiled_program, feed=dev_batches, fetch_list=schema_in_fetch_vars) - toc = time.time() - print("teacher predict time = {}".format(toc - tic)) know_make_queue.put((dev_batches, outputs)) - #out_buf_queue.put(know) - tic = time.time() - print("teacher out time = {}".format(tic - toc)) num_batches_sent += dev_count if num_batches_sent % (100 * dev_count) == 0: log = "Processed {} batch samples.".format( @@ -641,18 +642,18 @@ class Teacher(object): outputs = copy.deepcopy(output) if dev_batches or outputs: know_make_queue.put((dev_batches, outputs)) - #out_buf_queue.put(know) num_batches_sent += (index + 1) print("Processed {} batch samples in total.".format( num_batches_sent)) - know_make_queue.put(EndSignal()) know_make_queue.join() - if self._knowledge_queues: - for q in self._knowledge_queues: - q.join() + if self._knowledge_queues: + for q in self._knowledge_queues: + q.join() + if self._out_file: + offline_write_queue.join() print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) + " Teacher ends serving.") -- GitLab