提交 69d18ea3 编写于 作者: B baiyfbupt

fix train logit

上级 593094fd
...@@ -67,8 +67,7 @@ class Knowledge(object): ...@@ -67,8 +67,7 @@ class Knowledge(object):
assert ( assert (
len(self.path) == 4 and isinstance(self.path[0], str) and len(self.path) == 4 and isinstance(self.path[0], str) and
isinstance(self.path[1], str) and isinstance(self.path[1], str) and
isinstance(self.path[2], str) and isinstance(self.path[2], str) and isinstance(self.path[3], str)
isinstance(self.path[3], str)
), "path should contains four str, ['local hadoop home', 'fs.default.name', 'hadoop.job.ugi', 'FS path']" ), "path should contains four str, ['local hadoop home', 'fs.default.name', 'hadoop.job.ugi', 'FS path']"
hadoop_home = self.path[0] hadoop_home = self.path[0]
...@@ -104,8 +103,8 @@ class Knowledge(object): ...@@ -104,8 +103,8 @@ class Knowledge(object):
np.save(file_path, data) np.save(file_path, data)
self.file_cnt += 1 self.file_cnt += 1
self.client.upload(self.path[3], file_path) self.client.upload(self.path[3], file_path)
logger.info('{}.npy pushed to HDFS/AFS: {}'.format( logger.info('{}.npy pushed to HDFS/AFS: {}'.format(file_name,
file_name, self.path[3])) self.path[3]))
elif self.write_type == 'LocalFS': elif self.write_type == 'LocalFS':
file_name = 'knowledge_' + str(self.file_cnt) file_name = 'knowledge_' + str(self.file_cnt)
...@@ -116,6 +115,7 @@ class Knowledge(object): ...@@ -116,6 +115,7 @@ class Knowledge(object):
else: else:
self.knowledge_queue.put(data) self.knowledge_queue.put(data)
logger.info('{} pushed to Queue'.format(file_name))
def run(self, teacher_program, exe, place, scope, reader, inputs, outputs, def run(self, teacher_program, exe, place, scope, reader, inputs, outputs,
call_back): call_back):
...@@ -151,42 +151,36 @@ class Knowledge(object): ...@@ -151,42 +151,36 @@ class Knowledge(object):
compiled_teacher_program = fluid.compiler.CompiledProgram( compiled_teacher_program = fluid.compiler.CompiledProgram(
teacher_program) teacher_program)
self.file_cnt = 0 self.file_cnt = 0
if isinstance(reader, if isinstance(reader, Variable) or (
Variable) or (isinstance(reader, DataLoaderBase) and isinstance(reader, DataLoaderBase) and (not reader.iterable)):
(not reader.iterable)):
reader.start() reader.start()
try: try:
batch_id = 0
while True: while True:
logits = exe.run( logits = exe.run(compiled_teacher_program,
compiled_teacher_program, scope=scope,
scope=scope, fetch_list=outputs,
fetch_list=outputs, feed=None)
feed=None)
knowledge = dict() knowledge = dict()
for index, array in enumerate(logits): for index, array in enumerate(logits):
knowledge[self.items[index]] = array knowledge[self.items[index]] = array
if batch_id % 1 == 0:
logger.info('infer finish iter {}'.format(batch_id))
self._write(knowledge) self._write(knowledge)
except EOFException: except EOFException:
reader.reset() reader.reset()
else: else:
feeder = fluid.DataFeeder( if not isinstance(reader, DataLoaderBase):
feed_list=inputs, place=place, program=teacher_program) feeder = fluid.DataFeeder(
feed_list=inputs, place=place, program=teacher_program)
for batch_id, data in enumerate(reader()): for batch_id, data in enumerate(reader()):
feed = feeder.feed(data) if not isinstance(reader, DataLoaderBase):
logits = exe.run( data = feeder.feed(data)
compiled_teacher_program, logits = exe.run(compiled_teacher_program,
scope=scope, scope=scope,
fetch_list=outputs, fetch_list=outputs,
feed=feed) feed=data)
knowledge = dict() knowledge = dict()
for index, array in enumerate(logits): for index, array in enumerate(logits):
knowledge[self.items[index]] = array knowledge[self.items[index]] = array
if batch_id % 1 == 0:
logger.info('infer finish iter {}'.format(batch_id))
self._write(knowledge) self._write(knowledge)
return True return True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册