# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np
from six.moves.queue import Queue

import paddle.fluid as fluid
from paddle.fluid.framework import Variable
from paddle.fluid.reader import DataLoaderBase
from paddle.fluid.core import EOFException
from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient

__all__ = ['Knowledge']


class Knowledge(object):
    """
    The knowledge class describes how to extract and store the dark knowledge
    of the teacher model, and how the student model learns this dark knowledge.
    """

    def __init__(self,
                 path,
                 items,
                 reduce_strategy={'type': 'sum',
                                  'key': 'image'}):
        """Init a knowledge instance.

        Args:
            path(list|str|None): Where the knowledge is stored.
                * list  -> HDFS/AFS: four strings, ['local hadoop home',
                  'fs.default.name', 'hadoop.job.ugi', 'FS path'].
                * str   -> an existing local file-system directory.
                * None  -> an in-memory queue.
            items(list): Names of the tensors to save.
            reduce_strategy(dict|None, optional): Policy for the reduce
                operation. ``None`` disables the reduce operation.
                reduce_strategy['type'](str): type of the reduce operation.
                reduce_strategy['key'](str): key of the reduce operation;
                    it is an element of ``items``.

        Raises:
            AssertionError: If any argument has an unexpected type, or the
                HDFS path does not exist.
            ValueError: If a local path is given but does not exist.
        """
        assert (isinstance(path, list) or isinstance(path, str) or
                (path is None)), "path type should be list or str or None"
        assert (isinstance(items, list)), "items should be a list"
        # The docstring allows reduce_strategy=None ("reduce disabled"), so
        # None must be accepted here; a dict is copied so that mutating the
        # caller's dict (or the shared default) cannot affect this instance.
        assert (reduce_strategy is None or isinstance(reduce_strategy, dict)
                ), "reduce_strategy should be a dict or None"

        self.path = path
        # Counter for numbering knowledge files; _write relies on it, so it
        # must exist even before run() is called.
        self.file_cnt = 0

        if isinstance(self.path, list):
            # HDFS/AFS storage: files are first written locally, then pushed.
            self.write_type = 'HDFS/AFS'
            assert (len(self.path) == 4 and isinstance(self.path[0], str) and
                    isinstance(self.path[1], str) and
                    isinstance(self.path[2], str) and
                    isinstance(self.path[3], str)), "path should contains four \
                str, ['local hadoop home', 'fs.default.name', 'hadoop.job.ugi', 'FS path']"

            hadoop_home = self.path[0]
            configs = {
                "fs.default.name": self.path[1],
                "hadoop.job.ugi": self.path[2]
            }
            self.client = HDFSClient(hadoop_home, configs)
            assert self.client.is_exist(self.path[
                3]), "Please make sure your hadoop \
                configuration is correct and FS path exists"

            # Local staging directory for .npy files before upload.
            self.hdfs_local_path = "./teacher_knowledge"
            if not os.path.exists(self.hdfs_local_path):
                os.mkdir(self.hdfs_local_path)
        elif isinstance(self.path, str):
            self.write_type = 'LocalFS'
            if not os.path.exists(path):
                raise ValueError("The local path [%s] does not exist." %
                                 (path))
        else:
            # In-memory storage: a bounded queue so the producer blocks
            # instead of growing without limit.
            self.write_type = 'MEM'
            self.knowledge_queue = Queue(64)

        self.items = items
        self.reduce_strategy = (dict(reduce_strategy)
                                if reduce_strategy is not None else None)

    def _write(self, data):
        """Persist one batch of knowledge according to ``self.write_type``.

        Args:
            data(list): A list of {item_name: ndarray} dicts to store.
        """
        if self.write_type == 'HDFS/AFS':
            file_name = 'knowledge_' + str(self.file_cnt)
            # Join with the staging dir; np.save appends the '.npy' suffix.
            file_path = os.path.join(self.hdfs_local_path, file_name)
            np.save(file_path, data)
            self.file_cnt += 1
            # Upload the file np.save actually produced (with '.npy').
            self.client.upload(self.path[3], file_path + '.npy')
            print('{}.npy pushed to HDFS/AFS: {}'.format(file_name, self.path[
                3]))

        elif self.write_type == 'LocalFS':
            file_name = 'knowledge_' + str(self.file_cnt)
            file_path = os.path.join(self.path, file_name)
            np.save(file_path, data)
            print('{}.npy saved'.format(file_name))
            self.file_cnt += 1

        else:
            # MEM: hand the batch to the in-memory queue (blocks when full).
            self.knowledge_queue.put(data)

    def run(self, teacher_program, exe, place, scope, reader, inputs, outputs,
            call_back):
        """Run the teacher model and store its outputs as knowledge.

        Args:
            teacher_program(Program): teacher program.
            exe(Executor): The executor used to run the teacher program.
            place(Place): Device on which the teacher runs.
            scope(Scope): The scope used to execute the teacher,
                          which contains the initialized variables.
            reader(reader): The data reader used by the teacher.
            inputs(list): Names of the variables fed to the teacher program.
            outputs(list): Names of the variables to fetch; must correspond
                           one-to-one with ``self.items``.
            call_back(func, optional): Callback that handles the outputs of
                           the teacher; None by default, in which case the
                           outputs are stored directly.

        Return:
            (bool): Whether the teacher task was successfully run.
        """
        assert (isinstance(
            teacher_program,
            fluid.Program)), "teacher_program should be a fluid.Program"
        assert (isinstance(inputs, list)), "inputs should be a list"
        if len(inputs) > 0:
            assert (isinstance(inputs[0], str)), "inputs shoud be list"
        assert (isinstance(outputs, list)), "outputs should be a list"
        if len(outputs) > 0:
            assert (isinstance(outputs[0], str)), "outputs should be list"
        assert (len(self.items) == len(outputs)
                ), "the length of outputs list should be equal with items list"
        assert (callable(call_back) or (call_back is None)
                ), "call_back should be a callable function or NoneType."

        # Inference only: freeze every variable of the teacher.
        for var in teacher_program.list_vars():
            var.stop_gradient = True

        compiled_teacher_program = fluid.compiler.CompiledProgram(
            teacher_program)
        # Batches are buffered and flushed to storage in groups of this size.
        flush_size = 4
        teacher_knowledge = []
        self.file_cnt = 0

        if isinstance(reader, Variable) or (
                isinstance(reader, DataLoaderBase) and (not reader.iterable)):
            # Non-iterable reader: data is fed by the reader itself; loop
            # until it signals exhaustion with EOFException.
            reader.start()
            try:
                batch_id = 0
                while True:
                    logits = exe.run(compiled_teacher_program,
                                     scope=scope,
                                     fetch_list=outputs,
                                     feed=None)
                    knowledge = dict()
                    for index, array in enumerate(logits):
                        knowledge[self.items[index]] = array
                    teacher_knowledge.append(knowledge)
                    print('infer finish iter {}'.format(batch_id))
                    batch_id += 1
                    if len(teacher_knowledge) >= flush_size:
                        self._write(teacher_knowledge)
                        teacher_knowledge = []
            except EOFException:
                reader.reset()
                # Flush the remaining, partially-filled buffer.
                if len(teacher_knowledge) > 0:
                    self._write(teacher_knowledge)

        else:
            # Iterable reader: feed each batch explicitly through a DataFeeder.
            feeder = fluid.DataFeeder(
                feed_list=inputs, place=place, program=teacher_program)
            for batch_id, data in enumerate(reader()):
                feed = feeder.feed(data)
                logits = exe.run(compiled_teacher_program,
                                 scope=scope,
                                 fetch_list=outputs,
                                 feed=feed)
                knowledge = dict()
                for index, array in enumerate(logits):
                    knowledge[self.items[index]] = array
                teacher_knowledge.append(knowledge)
                print('infer finish iter {}'.format(batch_id))
                if len(teacher_knowledge) >= flush_size:
                    self._write(teacher_knowledge)
                    teacher_knowledge = []
            # Flush the remaining, partially-filled buffer.
            if len(teacher_knowledge) > 0:
                self._write(teacher_knowledge)
        return True

    def dist(self, student_program, losses):
        """Build the distillation network.

        Args:
            student_program(Program): student program.
            losses(list, optional): The losses to add. If set to None,
                                    no loss is added.

        Return:
            (Program): Program for distillation.
            (startup_program): Program for initializing the distillation
                               network.
            (reader): Data reader for distillation training.
            (Variable): Loss of distillation training.
        """
        # NOTE(review): not implemented yet in this patch.
        pass

    def loss(self, loss_func, *variables):
        """User-defined loss.

        Args:
            loss_func(func): Function used to define the loss.
            *variables(list): Variable name list.

        Return:
            (Variable): Distillation loss.
        """
        # NOTE(review): not implemented yet in this patch.
        pass

    def fsp_loss(self):
        """FSP (flow of solution procedure) loss. Not implemented yet."""
        pass

    def l2_loss(self):
        """L2 loss. Not implemented yet."""
        pass

    def softlabel_loss(self):
        """Soft-label loss. Not implemented yet."""
        pass