提交 69d18ea3 编写于 作者: B baiyfbupt

fix train logit

上级 593094fd
......@@ -67,8 +67,7 @@ class Knowledge(object):
assert (
len(self.path) == 4 and isinstance(self.path[0], str) and
isinstance(self.path[1], str) and
isinstance(self.path[2], str) and
isinstance(self.path[3], str)
isinstance(self.path[2], str) and isinstance(self.path[3], str)
), "path should contains four str, ['local hadoop home', 'fs.default.name', 'hadoop.job.ugi', 'FS path']"
hadoop_home = self.path[0]
......@@ -104,8 +103,8 @@ class Knowledge(object):
np.save(file_path, data)
self.file_cnt += 1
self.client.upload(self.path[3], file_path)
logger.info('{}.npy pushed to HDFS/AFS: {}'.format(
file_name, self.path[3]))
logger.info('{}.npy pushed to HDFS/AFS: {}'.format(file_name,
self.path[3]))
elif self.write_type == 'LocalFS':
file_name = 'knowledge_' + str(self.file_cnt)
......@@ -116,6 +115,7 @@ class Knowledge(object):
else:
self.knowledge_queue.put(data)
logger.info('{} pushed to Queue'.format(file_name))
def run(self, teacher_program, exe, place, scope, reader, inputs, outputs,
call_back):
......@@ -151,42 +151,36 @@ class Knowledge(object):
compiled_teacher_program = fluid.compiler.CompiledProgram(
teacher_program)
self.file_cnt = 0
if isinstance(reader,
Variable) or (isinstance(reader, DataLoaderBase) and
(not reader.iterable)):
if isinstance(reader, Variable) or (
isinstance(reader, DataLoaderBase) and (not reader.iterable)):
reader.start()
try:
batch_id = 0
while True:
logits = exe.run(
compiled_teacher_program,
scope=scope,
fetch_list=outputs,
feed=None)
logits = exe.run(compiled_teacher_program,
scope=scope,
fetch_list=outputs,
feed=None)
knowledge = dict()
for index, array in enumerate(logits):
knowledge[self.items[index]] = array
if batch_id % 1 == 0:
logger.info('infer finish iter {}'.format(batch_id))
self._write(knowledge)
except EOFException:
reader.reset()
else:
feeder = fluid.DataFeeder(
feed_list=inputs, place=place, program=teacher_program)
if not isinstance(reader, DataLoaderBase):
feeder = fluid.DataFeeder(
feed_list=inputs, place=place, program=teacher_program)
for batch_id, data in enumerate(reader()):
feed = feeder.feed(data)
logits = exe.run(
compiled_teacher_program,
scope=scope,
fetch_list=outputs,
feed=feed)
if not isinstance(reader, DataLoaderBase):
data = feeder.feed(data)
logits = exe.run(compiled_teacher_program,
scope=scope,
fetch_list=outputs,
feed=data)
knowledge = dict()
for index, array in enumerate(logits):
knowledge[self.items[index]] = array
if batch_id % 1 == 0:
logger.info('infer finish iter {}'.format(batch_id))
self._write(knowledge)
return True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册