From 69d18ea3dc4685c6a54aaaf232012d2786965810 Mon Sep 17 00:00:00 2001 From: baiyfbupt Date: Mon, 4 Nov 2019 10:48:40 +0800 Subject: [PATCH] fix train logit --- paddleslim/dist/mp_distiller.py | 44 ++++++++++++++------------------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/paddleslim/dist/mp_distiller.py b/paddleslim/dist/mp_distiller.py index f8d3222e..ff15f5f1 100755 --- a/paddleslim/dist/mp_distiller.py +++ b/paddleslim/dist/mp_distiller.py @@ -67,8 +67,7 @@ class Knowledge(object): assert ( len(self.path) == 4 and isinstance(self.path[0], str) and isinstance(self.path[1], str) and - isinstance(self.path[2], str) and - isinstance(self.path[3], str) + isinstance(self.path[2], str) and isinstance(self.path[3], str) ), "path should contains four str, ['local hadoop home', 'fs.default.name', 'hadoop.job.ugi', 'FS path']" hadoop_home = self.path[0] @@ -104,8 +103,8 @@ class Knowledge(object): np.save(file_path, data) self.file_cnt += 1 self.client.upload(self.path[3], file_path) - logger.info('{}.npy pushed to HDFS/AFS: {}'.format( - file_name, self.path[3])) + logger.info('{}.npy pushed to HDFS/AFS: {}'.format(file_name, + self.path[3])) elif self.write_type == 'LocalFS': file_name = 'knowledge_' + str(self.file_cnt) @@ -116,6 +115,7 @@ class Knowledge(object): else: self.knowledge_queue.put(data) + logger.info('{} pushed to Queue'.format(file_name)) def run(self, teacher_program, exe, place, scope, reader, inputs, outputs, call_back): @@ -151,42 +151,36 @@ class Knowledge(object): compiled_teacher_program = fluid.compiler.CompiledProgram( teacher_program) self.file_cnt = 0 - if isinstance(reader, - Variable) or (isinstance(reader, DataLoaderBase) and - (not reader.iterable)): + if isinstance(reader, Variable) or ( + isinstance(reader, DataLoaderBase) and (not reader.iterable)): reader.start() try: - batch_id = 0 while True: - logits = exe.run( - compiled_teacher_program, - scope=scope, - fetch_list=outputs, - feed=None) + logits = exe.run(compiled_teacher_program, + scope=scope, + fetch_list=outputs, + feed=None) knowledge = dict() for index, array in enumerate(logits): knowledge[self.items[index]] = array - if batch_id % 1 == 0: - logger.info('infer finish iter {}'.format(batch_id)) self._write(knowledge) except EOFException: reader.reset() else: - feeder = fluid.DataFeeder( - feed_list=inputs, place=place, program=teacher_program) + if not isinstance(reader, DataLoaderBase): + feeder = fluid.DataFeeder( + feed_list=inputs, place=place, program=teacher_program) for batch_id, data in enumerate(reader()): - feed = feeder.feed(data) - logits = exe.run( - compiled_teacher_program, - scope=scope, - fetch_list=outputs, - feed=feed) + if not isinstance(reader, DataLoaderBase): + data = feeder.feed(data) + logits = exe.run(compiled_teacher_program, + scope=scope, + fetch_list=outputs, + feed=data) knowledge = dict() for index, array in enumerate(logits): knowledge[self.items[index]] = array - if batch_id % 1 == 0: - logger.info('infer finish iter {}'.format(batch_id)) self._write(knowledge) return True -- GitLab