未验证 提交 91a3b648 编写于 作者: Y Yibing Liu 提交者: GitHub

Update pantheon in release 1.0.0 (#124)

上级 42cded62
...@@ -80,8 +80,8 @@ def run(args): ...@@ -80,8 +80,8 @@ def run(args):
student.start() student.start()
if args.test_send_recv: if args.test_send_recv:
for t in xrange(2): for t in range(2):
for i in xrange(3): for i in range(3):
print(student.recv(t)) print(student.recv(t))
student.send("message from student!") student.send("message from student!")
......
# 多进程蒸馏 # 大规模可扩展知识蒸馏框架 Pantheon
## Teacher ## Teacher
...@@ -100,7 +100,8 @@ pantheon.Teacher.start\_knowledge\_service(feed\_list, schema, program, reader\_ ...@@ -100,7 +100,8 @@ pantheon.Teacher.start\_knowledge\_service(feed\_list, schema, program, reader\_
- **times (int):** The maximum repeated serving times, default 1. Whenever - **times (int):** The maximum repeated serving times, default 1. Whenever
the public method **get\_knowledge\_generator()** in **Student** the public method **get\_knowledge\_generator()** in **Student**
object is called once, the serving times will be increased by one, object is called once, the serving times will be increased by one,
until reaching the maximum and ending the service. until reaching the maximum and ending the service. Only
valid in online mode, and will be ignored in offline mode.
**Return:** None **Return:** None
......
...@@ -19,4 +19,5 @@ from paddleslim import nas ...@@ -19,4 +19,5 @@ from paddleslim import nas
from paddleslim import analysis from paddleslim import analysis
from paddleslim import dist from paddleslim import dist
from paddleslim import quant from paddleslim import quant
__all__ = ['models', 'prune', 'nas', 'analysis', 'dist', 'quant'] from paddleslim import pantheon
__all__ = ['models', 'prune', 'nas', 'analysis', 'dist', 'quant', 'pantheon']
...@@ -13,7 +13,7 @@ The illustration below shows an application of Pantheon, where the student model ...@@ -13,7 +13,7 @@ The illustration below shows an application of Pantheon, where the student model
## Prerequisites ## Prerequisites
- Python 2.7.x or 3.x - Python 2.7.x or 3.x
- PaddlePaddle >= 1.6.0 - PaddlePaddle >= 1.7.0
## APIs ## APIs
......
...@@ -158,7 +158,7 @@ class Student(object): ...@@ -158,7 +158,7 @@ class Student(object):
if end_recved: if end_recved:
break break
with open(in_path, 'r') as fin: with open(in_path, 'rb') as fin:
# get knowledge desc # get knowledge desc
desc = pickle.load(fin) desc = pickle.load(fin)
out_queue.put(desc) out_queue.put(desc)
...@@ -222,7 +222,7 @@ class Student(object): ...@@ -222,7 +222,7 @@ class Student(object):
self._started = True self._started = True
def _merge_knowledge(self, knowledge): def _merge_knowledge(self, knowledge):
for k, tensors in knowledge.items(): for k, tensors in list(knowledge.items()):
if len(tensors) == 0: if len(tensors) == 0:
del knowledge[k] del knowledge[k]
elif len(tensors) == 1: elif len(tensors) == 1:
...@@ -308,7 +308,7 @@ class Student(object): ...@@ -308,7 +308,7 @@ class Student(object):
print("Knowledge merging strategy: {}".format( print("Knowledge merging strategy: {}".format(
self._merge_strategy)) self._merge_strategy))
print("Knowledge description after merging:") print("Knowledge description after merging:")
for schema, desc in knowledge_desc.items(): for schema, desc in list(knowledge_desc.items()):
print("{}: {}".format(schema, desc)) print("{}: {}".format(schema, desc))
self._knowledge_desc = knowledge_desc self._knowledge_desc = knowledge_desc
...@@ -426,13 +426,13 @@ class Student(object): ...@@ -426,13 +426,13 @@ class Student(object):
end_received = [0] * len(queues) end_received = [0] * len(queues)
while True: while True:
knowledge = OrderedDict( knowledge = OrderedDict(
[(k, []) for k, v in self._knowledge_desc.items()]) [(k, []) for k, v in list(self._knowledge_desc.items())])
for idx, receiver in enumerate(data_receivers): for idx, receiver in enumerate(data_receivers):
if not end_received[idx]: if not end_received[idx]:
batch_samples = receiver.next( batch_samples = receiver.next(
) if six.PY2 else receiver.__next__() ) if six.PY2 else receiver.__next__()
if not isinstance(batch_samples, EndSignal): if not isinstance(batch_samples, EndSignal):
for k, v in batch_samples.items(): for k, v in list(batch_samples.items()):
knowledge[k].append(v) knowledge[k].append(v)
else: else:
end_received[idx] = 1 end_received[idx] = 1
......
...@@ -151,7 +151,7 @@ class Teacher(object): ...@@ -151,7 +151,7 @@ class Teacher(object):
self._t2s_queue = None self._t2s_queue = None
self._cmd_queue = None self._cmd_queue = None
self._out_file = open(self._out_path, "w") if self._out_path else None self._out_file = open(self._out_path, "wb") if self._out_path else None
if self._out_file: if self._out_file:
return return
...@@ -231,7 +231,7 @@ class Teacher(object): ...@@ -231,7 +231,7 @@ class Teacher(object):
"The knowledge data should be a dict or OrderedDict!") "The knowledge data should be a dict or OrderedDict!")
knowledge_desc = {} knowledge_desc = {}
for name, value in knowledge.items(): for name, value in list(knowledge.items()):
knowledge_desc[name] = { knowledge_desc[name] = {
"shape": [-1] + list(value.shape[1:]), "shape": [-1] + list(value.shape[1:]),
"dtype": str(value.dtype), "dtype": str(value.dtype),
...@@ -294,7 +294,8 @@ class Teacher(object): ...@@ -294,7 +294,8 @@ class Teacher(object):
times (int): The maximum repeated serving times. Default 1. Whenever times (int): The maximum repeated serving times. Default 1. Whenever
the public method 'get_knowledge_generator()' in Student the public method 'get_knowledge_generator()' in Student
object called once, the serving times will be added one, object called once, the serving times will be added one,
until reaching the maximum and ending the service. until reaching the maximum and ending the service. Only
valid in online mode, and will be ignored in offline mode.
""" """
if not self._started: if not self._started:
raise ValueError("The method start() should be called first!") raise ValueError("The method start() should be called first!")
...@@ -339,9 +340,12 @@ class Teacher(object): ...@@ -339,9 +340,12 @@ class Teacher(object):
if not times > 0: if not times > 0:
raise ValueError("Repeated serving times should be positive!") raise ValueError("Repeated serving times should be positive!")
self._times = times self._times = times
if self._times > 1 and self._out_file:
self._times = 1
print("WARNING: args 'times' will be ignored in offline mode")
desc = {} desc = {}
for name, var in schema.items(): for name, var in list(schema.items()):
if not isinstance(var, fluid.framework.Variable): if not isinstance(var, fluid.framework.Variable):
raise ValueError( raise ValueError(
"The member of schema must be fluid Variable.") "The member of schema must be fluid Variable.")
...@@ -412,10 +416,14 @@ class Teacher(object): ...@@ -412,10 +416,14 @@ class Teacher(object):
else: else:
if self._knowledge_queue: if self._knowledge_queue:
self._knowledge_queue.put(EndSignal()) self._knowledge_queue.put(EndSignal())
# should close file in child thread to wait for all
# writing finished
if self._out_file:
self._out_file.close()
# Asynchronous output # Asynchronous output
out_buf_queue = Queue.Queue(self._buf_size) out_buf_queue = Queue.Queue(self._buf_size)
schema_keys, schema_vars = zip(*self._schema.items()) schema_keys, schema_vars = zip(*list(self._schema.items()))
out_thread = Thread(target=writer, args=(out_buf_queue, schema_keys)) out_thread = Thread(target=writer, args=(out_buf_queue, schema_keys))
out_thread.daemon = True out_thread.daemon = True
out_thread.start() out_thread.start()
...@@ -424,8 +432,9 @@ class Teacher(object): ...@@ -424,8 +432,9 @@ class Teacher(object):
self._program).with_data_parallel() self._program).with_data_parallel()
print("Knowledge description {}".format(self._knowledge_desc)) print("Knowledge description {}".format(self._knowledge_desc))
print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) + print(
" Teacher begins to serve ...") time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
" Teacher begins to serve ...")
# For offline dump, write the knowledge description to the head of file # For offline dump, write the knowledge description to the head of file
if self._out_file: if self._out_file:
self._out_file.write(pickle.dumps(self._knowledge_desc)) self._out_file.write(pickle.dumps(self._knowledge_desc))
...@@ -491,11 +500,10 @@ class Teacher(object): ...@@ -491,11 +500,10 @@ class Teacher(object):
if self._knowledge_queue: if self._knowledge_queue:
self._knowledge_queue.join() self._knowledge_queue.join()
print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) + print(
" Teacher ends serving.") time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
" Teacher ends serving.")
def __del__(self): def __del__(self):
if self._manager: if self._manager:
self._manager.shutdown() self._manager.shutdown()
if self._out_file:
self._out_file.close()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册