未验证 提交 eac4f3b2 编写于 作者: Y Yibing Liu 提交者: GitHub

Fix offline file close in pantheon (#114)

上级 a4f4298d
...@@ -222,7 +222,7 @@ class Student(object): ...@@ -222,7 +222,7 @@ class Student(object):
self._started = True self._started = True
def _merge_knowledge(self, knowledge): def _merge_knowledge(self, knowledge):
for k, tensors in knowledge.items(): for k, tensors in list(knowledge.items()):
if len(tensors) == 0: if len(tensors) == 0:
del knowledge[k] del knowledge[k]
elif len(tensors) == 1: elif len(tensors) == 1:
...@@ -308,7 +308,7 @@ class Student(object): ...@@ -308,7 +308,7 @@ class Student(object):
print("Knowledge merging strategy: {}".format( print("Knowledge merging strategy: {}".format(
self._merge_strategy)) self._merge_strategy))
print("Knowledge description after merging:") print("Knowledge description after merging:")
for schema, desc in knowledge_desc.items(): for schema, desc in list(knowledge_desc.items()):
print("{}: {}".format(schema, desc)) print("{}: {}".format(schema, desc))
self._knowledge_desc = knowledge_desc self._knowledge_desc = knowledge_desc
...@@ -426,13 +426,13 @@ class Student(object): ...@@ -426,13 +426,13 @@ class Student(object):
end_received = [0] * len(queues) end_received = [0] * len(queues)
while True: while True:
knowledge = OrderedDict( knowledge = OrderedDict(
[(k, []) for k, v in self._knowledge_desc.items()]) [(k, []) for k, v in list(self._knowledge_desc.items())])
for idx, receiver in enumerate(data_receivers): for idx, receiver in enumerate(data_receivers):
if not end_received[idx]: if not end_received[idx]:
batch_samples = receiver.next( batch_samples = receiver.next(
) if six.PY2 else receiver.__next__() ) if six.PY2 else receiver.__next__()
if not isinstance(batch_samples, EndSignal): if not isinstance(batch_samples, EndSignal):
for k, v in batch_samples.items(): for k, v in list(batch_samples.items()):
knowledge[k].append(v) knowledge[k].append(v)
else: else:
end_received[idx] = 1 end_received[idx] = 1
......
...@@ -231,7 +231,7 @@ class Teacher(object): ...@@ -231,7 +231,7 @@ class Teacher(object):
"The knowledge data should be a dict or OrderedDict!") "The knowledge data should be a dict or OrderedDict!")
knowledge_desc = {} knowledge_desc = {}
for name, value in knowledge.items(): for name, value in list(knowledge.items()):
knowledge_desc[name] = { knowledge_desc[name] = {
"shape": [-1] + list(value.shape[1:]), "shape": [-1] + list(value.shape[1:]),
"dtype": str(value.dtype), "dtype": str(value.dtype),
...@@ -294,7 +294,8 @@ class Teacher(object): ...@@ -294,7 +294,8 @@ class Teacher(object):
times (int): The maximum repeated serving times. Default 1. Whenever times (int): The maximum repeated serving times. Default 1. Whenever
the public method 'get_knowledge_generator()' in Student the public method 'get_knowledge_generator()' in Student
object called once, the serving times will be added one, object called once, the serving times will be added one,
until reaching the maximum and ending the service. until reaching the maximum and ending the service. Only
valid in online mode, and will be ignored in offline mode.
""" """
if not self._started: if not self._started:
raise ValueError("The method start() should be called first!") raise ValueError("The method start() should be called first!")
...@@ -339,9 +340,12 @@ class Teacher(object): ...@@ -339,9 +340,12 @@ class Teacher(object):
if not times > 0: if not times > 0:
raise ValueError("Repeated serving times should be positive!") raise ValueError("Repeated serving times should be positive!")
self._times = times self._times = times
if self._times > 1 and self._out_file:
self._times = 1
print("WARNING: args 'times' will be ignored in offline mode")
desc = {} desc = {}
for name, var in schema.items(): for name, var in list(schema.items()):
if not isinstance(var, fluid.framework.Variable): if not isinstance(var, fluid.framework.Variable):
raise ValueError( raise ValueError(
"The member of schema must be fluid Variable.") "The member of schema must be fluid Variable.")
...@@ -412,10 +416,14 @@ class Teacher(object): ...@@ -412,10 +416,14 @@ class Teacher(object):
else: else:
if self._knowledge_queue: if self._knowledge_queue:
self._knowledge_queue.put(EndSignal()) self._knowledge_queue.put(EndSignal())
# should close file in child thread to wait for all
# writing finished
if self._out_file:
self._out_file.close()
# Asynchronous output # Asynchronous output
out_buf_queue = Queue.Queue(self._buf_size) out_buf_queue = Queue.Queue(self._buf_size)
schema_keys, schema_vars = zip(*self._schema.items()) schema_keys, schema_vars = zip(*list(self._schema.items()))
out_thread = Thread(target=writer, args=(out_buf_queue, schema_keys)) out_thread = Thread(target=writer, args=(out_buf_queue, schema_keys))
out_thread.daemon = True out_thread.daemon = True
out_thread.start() out_thread.start()
...@@ -424,7 +432,8 @@ class Teacher(object): ...@@ -424,7 +432,8 @@ class Teacher(object):
self._program).with_data_parallel() self._program).with_data_parallel()
print("Knowledge description {}".format(self._knowledge_desc)) print("Knowledge description {}".format(self._knowledge_desc))
print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) + print(
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
" Teacher begins to serve ...") " Teacher begins to serve ...")
# For offline dump, write the knowledge description to the head of file # For offline dump, write the knowledge description to the head of file
if self._out_file: if self._out_file:
...@@ -491,11 +500,10 @@ class Teacher(object): ...@@ -491,11 +500,10 @@ class Teacher(object):
if self._knowledge_queue: if self._knowledge_queue:
self._knowledge_queue.join() self._knowledge_queue.join()
print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) + print(
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
" Teacher ends serving.") " Teacher ends serving.")
def __del__(self): def __del__(self):
if self._manager: if self._manager:
self._manager.shutdown() self._manager.shutdown()
if self._out_file:
self._out_file.close()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册