未验证 提交 eac4f3b2 编写于 作者: Y Yibing Liu 提交者: GitHub

Fix offline file close in pantheon (#114)

上级 a4f4298d
......@@ -222,7 +222,7 @@ class Student(object):
self._started = True
def _merge_knowledge(self, knowledge):
for k, tensors in knowledge.items():
for k, tensors in list(knowledge.items()):
if len(tensors) == 0:
del knowledge[k]
elif len(tensors) == 1:
......@@ -308,7 +308,7 @@ class Student(object):
print("Knowledge merging strategy: {}".format(
self._merge_strategy))
print("Knowledge description after merging:")
for schema, desc in knowledge_desc.items():
for schema, desc in list(knowledge_desc.items()):
print("{}: {}".format(schema, desc))
self._knowledge_desc = knowledge_desc
......@@ -426,13 +426,13 @@ class Student(object):
end_received = [0] * len(queues)
while True:
knowledge = OrderedDict(
[(k, []) for k, v in self._knowledge_desc.items()])
[(k, []) for k, v in list(self._knowledge_desc.items())])
for idx, receiver in enumerate(data_receivers):
if not end_received[idx]:
batch_samples = receiver.next(
) if six.PY2 else receiver.__next__()
if not isinstance(batch_samples, EndSignal):
for k, v in batch_samples.items():
for k, v in list(batch_samples.items()):
knowledge[k].append(v)
else:
end_received[idx] = 1
......
......@@ -231,7 +231,7 @@ class Teacher(object):
"The knowledge data should be a dict or OrderedDict!")
knowledge_desc = {}
for name, value in knowledge.items():
for name, value in list(knowledge.items()):
knowledge_desc[name] = {
"shape": [-1] + list(value.shape[1:]),
"dtype": str(value.dtype),
......@@ -294,7 +294,8 @@ class Teacher(object):
times (int): The maximum repeated serving times. Default 1. Whenever
the public method 'get_knowledge_generator()' in Student
object called once, the serving times will be added one,
until reaching the maximum and ending the service.
until reaching the maximum and ending the service. Only
valid in online mode, and will be ignored in offline mode.
"""
if not self._started:
raise ValueError("The method start() should be called first!")
......@@ -339,9 +340,12 @@ class Teacher(object):
if not times > 0:
raise ValueError("Repeated serving times should be positive!")
self._times = times
if self._times > 1 and self._out_file:
self._times = 1
print("WARNING: args 'times' will be ignored in offline mode")
desc = {}
for name, var in schema.items():
for name, var in list(schema.items()):
if not isinstance(var, fluid.framework.Variable):
raise ValueError(
"The member of schema must be fluid Variable.")
......@@ -412,10 +416,14 @@ class Teacher(object):
else:
if self._knowledge_queue:
self._knowledge_queue.put(EndSignal())
# should close file in child thread to wait for all
# writing finished
if self._out_file:
self._out_file.close()
# Asynchronous output
out_buf_queue = Queue.Queue(self._buf_size)
schema_keys, schema_vars = zip(*self._schema.items())
schema_keys, schema_vars = zip(*list(self._schema.items()))
out_thread = Thread(target=writer, args=(out_buf_queue, schema_keys))
out_thread.daemon = True
out_thread.start()
......@@ -424,8 +432,9 @@ class Teacher(object):
self._program).with_data_parallel()
print("Knowledge description {}".format(self._knowledge_desc))
print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
" Teacher begins to serve ...")
print(
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
" Teacher begins to serve ...")
# For offline dump, write the knowledge description to the head of file
if self._out_file:
self._out_file.write(pickle.dumps(self._knowledge_desc))
......@@ -491,11 +500,10 @@ class Teacher(object):
if self._knowledge_queue:
self._knowledge_queue.join()
print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
" Teacher ends serving.")
print(
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) +
" Teacher ends serving.")
def __del__(self):
if self._manager:
self._manager.shutdown()
if self._out_file:
self._out_file.close()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册