diff --git a/paddleslim/pantheon/student.py b/paddleslim/pantheon/student.py index 54b75ab20afa3008b849966993b6f5b07f2af343..e4f6b3e16886eecdc1ece841da283b948b6b4e8b 100644 --- a/paddleslim/pantheon/student.py +++ b/paddleslim/pantheon/student.py @@ -222,7 +222,7 @@ class Student(object): self._started = True def _merge_knowledge(self, knowledge): - for k, tensors in knowledge.items(): + for k, tensors in list(knowledge.items()): if len(tensors) == 0: del knowledge[k] elif len(tensors) == 1: @@ -308,7 +308,7 @@ class Student(object): print("Knowledge merging strategy: {}".format( self._merge_strategy)) print("Knowledge description after merging:") - for schema, desc in knowledge_desc.items(): + for schema, desc in list(knowledge_desc.items()): print("{}: {}".format(schema, desc)) self._knowledge_desc = knowledge_desc @@ -426,13 +426,13 @@ class Student(object): end_received = [0] * len(queues) while True: knowledge = OrderedDict( - [(k, []) for k, v in self._knowledge_desc.items()]) + [(k, []) for k, v in list(self._knowledge_desc.items())]) for idx, receiver in enumerate(data_receivers): if not end_received[idx]: batch_samples = receiver.next( ) if six.PY2 else receiver.__next__() if not isinstance(batch_samples, EndSignal): - for k, v in batch_samples.items(): + for k, v in list(batch_samples.items()): knowledge[k].append(v) else: end_received[idx] = 1 diff --git a/paddleslim/pantheon/teacher.py b/paddleslim/pantheon/teacher.py index 9a17f6c788790a82e1a865817144769dcb1cfb4f..425257e18886940cac98ef5a5a2df356f2e6ae67 100644 --- a/paddleslim/pantheon/teacher.py +++ b/paddleslim/pantheon/teacher.py @@ -231,7 +231,7 @@ class Teacher(object): "The knowledge data should be a dict or OrderedDict!") knowledge_desc = {} - for name, value in knowledge.items(): + for name, value in list(knowledge.items()): knowledge_desc[name] = { "shape": [-1] + list(value.shape[1:]), "dtype": str(value.dtype), @@ -294,7 +294,8 @@ class Teacher(object): times (int): The maximum repeated serving times. Default 1. Whenever the public method 'get_knowledge_generator()' in Student object called once, the serving times will be added one, - until reaching the maximum and ending the service. + until reaching the maximum and ending the service. Only + valid in online mode, and will be ignored in offline mode. """ if not self._started: raise ValueError("The method start() should be called first!") @@ -339,9 +340,12 @@ class Teacher(object): if not times > 0: raise ValueError("Repeated serving times should be positive!") self._times = times + if self._times > 1 and self._out_file: + self._times = 1 + print("WARNING: args 'times' will be ignored in offline mode") desc = {} - for name, var in schema.items(): + for name, var in list(schema.items()): if not isinstance(var, fluid.framework.Variable): raise ValueError( "The member of schema must be fluid Variable.") @@ -412,10 +416,14 @@ class Teacher(object): else: if self._knowledge_queue: self._knowledge_queue.put(EndSignal()) + # should close file in child thread to wait for all + # writing finished + if self._out_file: + self._out_file.close() # Asynchronous output out_buf_queue = Queue.Queue(self._buf_size) - schema_keys, schema_vars = zip(*self._schema.items()) + schema_keys, schema_vars = zip(*list(self._schema.items())) out_thread = Thread(target=writer, args=(out_buf_queue, schema_keys)) out_thread.daemon = True out_thread.start() @@ -424,8 +432,9 @@ class Teacher(object): self._program).with_data_parallel() print("Knowledge description {}".format(self._knowledge_desc)) - print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) + - " Teacher begins to serve ...") + print( + time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) + + " Teacher begins to serve ...") # For offline dump, write the knowledge description to the head of file if self._out_file: self._out_file.write(pickle.dumps(self._knowledge_desc)) @@ -491,11 +500,10 @@ class Teacher(object): if self._knowledge_queue: self._knowledge_queue.join() - print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) + - " Teacher ends serving.") + print( + time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) + + " Teacher ends serving.") def __del__(self): if self._manager: self._manager.shutdown() - if self._out_file: - self._out_file.close()