diff --git a/paddleslim/dist/mp_distiller.py b/paddleslim/dist/mp_distiller.py index ab42035346446971d2e41227665de6d099694401..f8d3222e3232749ef5268f12d476c59bd93c0119 100755 --- a/paddleslim/dist/mp_distiller.py +++ b/paddleslim/dist/mp_distiller.py @@ -27,7 +27,6 @@ from paddle.fluid.reader import DataLoaderBase from paddle.fluid.core import EOFException from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient - logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -151,7 +150,6 @@ class Knowledge(object): compiled_teacher_program = fluid.compiler.CompiledProgram( teacher_program) - teacher_knowledge = [] self.file_cnt = 0 if isinstance(reader, Variable) or (isinstance(reader, DataLoaderBase) and @@ -168,16 +166,11 @@ class Knowledge(object): knowledge = dict() for index, array in enumerate(logits): knowledge[self.items[index]] = array - teacher_knowledge.append(knowledge) if batch_id % 1 == 0: logger.info('infer finish iter {}'.format(batch_id)) - if len(teacher_knowledge) >= 4: - self._write(teacher_knowledge) - teacher_knowledge = [] + self._write(knowledge) except EOFException: reader.reset() - if len(teacher_knowledge) > 0: - self._write(teacher_knowledge) else: feeder = fluid.DataFeeder( @@ -192,14 +185,9 @@ class Knowledge(object): knowledge = dict() for index, array in enumerate(logits): knowledge[self.items[index]] = array - teacher_knowledge.append(knowledge) if batch_id % 1 == 0: logger.info('infer finish iter {}'.format(batch_id)) - if len(teacher_knowledge) >= 4: - self._write(teacher_knowledge) - teacher_knowledge = [] - if len(teacher_knowledge) > 0: - self._write(teacher_knowledge) + self._write(knowledge) return True def dist(self, student_program, losses):