提交 593094fd 编写于 作者: B baiyfbupt

fix np.save strategy

上级 47441ff4
...@@ -27,7 +27,6 @@ from paddle.fluid.reader import DataLoaderBase ...@@ -27,7 +27,6 @@ from paddle.fluid.reader import DataLoaderBase
from paddle.fluid.core import EOFException from paddle.fluid.core import EOFException
from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
...@@ -151,7 +150,6 @@ class Knowledge(object): ...@@ -151,7 +150,6 @@ class Knowledge(object):
compiled_teacher_program = fluid.compiler.CompiledProgram( compiled_teacher_program = fluid.compiler.CompiledProgram(
teacher_program) teacher_program)
teacher_knowledge = []
self.file_cnt = 0 self.file_cnt = 0
if isinstance(reader, if isinstance(reader,
Variable) or (isinstance(reader, DataLoaderBase) and Variable) or (isinstance(reader, DataLoaderBase) and
...@@ -168,16 +166,11 @@ class Knowledge(object): ...@@ -168,16 +166,11 @@ class Knowledge(object):
knowledge = dict() knowledge = dict()
for index, array in enumerate(logits): for index, array in enumerate(logits):
knowledge[self.items[index]] = array knowledge[self.items[index]] = array
teacher_knowledge.append(knowledge)
if batch_id % 1 == 0: if batch_id % 1 == 0:
logger.info('infer finish iter {}'.format(batch_id)) logger.info('infer finish iter {}'.format(batch_id))
if len(teacher_knowledge) >= 4: self._write(knowledge)
self._write(teacher_knowledge)
teacher_knowledge = []
except EOFException: except EOFException:
reader.reset() reader.reset()
if len(teacher_knowledge) > 0:
self._write(teacher_knowledge)
else: else:
feeder = fluid.DataFeeder( feeder = fluid.DataFeeder(
...@@ -192,14 +185,9 @@ class Knowledge(object): ...@@ -192,14 +185,9 @@ class Knowledge(object):
knowledge = dict() knowledge = dict()
for index, array in enumerate(logits): for index, array in enumerate(logits):
knowledge[self.items[index]] = array knowledge[self.items[index]] = array
teacher_knowledge.append(knowledge)
if batch_id % 1 == 0: if batch_id % 1 == 0:
logger.info('infer finish iter {}'.format(batch_id)) logger.info('infer finish iter {}'.format(batch_id))
if len(teacher_knowledge) >= 4: self._write(knowledge)
self._write(teacher_knowledge)
teacher_knowledge = []
if len(teacher_knowledge) > 0:
self._write(teacher_knowledge)
return True return True
def dist(self, student_program, losses): def dist(self, student_program, losses):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册