From d2cc96636be47918176b282621c425f59f918c64 Mon Sep 17 00:00:00 2001
From: ceci3
Date: Wed, 28 Jul 2021 10:22:59 +0800
Subject: [PATCH] [distill] how to get feature map (#799)

---
 paddleslim/dygraph/dist/__init__.py |   5 +
 paddleslim/dygraph/dist/distill.py  | 186 ++++++++++++++++++++++++++++
 tests/dygraph/test_distill.py       | 167 +++++++++++++++++++++++++
 3 files changed, 358 insertions(+)
 create mode 100644 paddleslim/dygraph/dist/distill.py
 create mode 100644 tests/dygraph/test_distill.py

diff --git a/paddleslim/dygraph/dist/__init__.py b/paddleslim/dygraph/dist/__init__.py
index 46b14270..e98e53df 100644
--- a/paddleslim/dygraph/dist/__init__.py
+++ b/paddleslim/dygraph/dist/__init__.py
@@ -12,4 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from . import distill
+from .distill import *
+
 __all__ = []
+
+__all__ += distill.__all__

diff --git a/paddleslim/dygraph/dist/distill.py b/paddleslim/dygraph/dist/distill.py
new file mode 100644
index 00000000..2571293b
--- /dev/null
+++ b/paddleslim/dygraph/dist/distill.py
@@ -0,0 +1,186 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import numpy as np
+import collections
+from collections import namedtuple
+import paddle.nn as nn
+from .losses import *
+
+__all__ = ['Distill', 'AdaptorBase']
+
+
+class LayerConfig:
+    def __init__(self,
+                 s_feature_idx,
+                 t_feature_idx,
+                 feature_type,
+                 loss_function,
+                 weight=1.0,
+                 align=False,
+                 align_shape=None):
+        self.s_feature_idx = s_feature_idx
+        self.t_feature_idx = t_feature_idx
+        self.feature_type = feature_type
+        if loss_function in ['l1', 'l2', 'smooth_l1']:
+            self.loss_function = 'DistillationDistanceLoss'
+        elif loss_function in ['dml']:
+            self.loss_function = 'DistillationDMLLoss'
+        elif loss_function in ['rkd']:
+            self.loss_function = 'DistillationRKDLoss'
+        else:
+            raise NotImplementedError("loss function is not supported!")
+        self.weight = weight
+        self.align = align
+        self.align_shape = align_shape
+
+
+class AdaptorBase:
+    def __init__(self, model):
+        self.model = model
+        self.add_tensor = False
+
+    def _get_activation(self, outs, name):
+        def get_output_hook(layer, input, output):
+            outs[name] = output
+
+        return get_output_hook
+
+    def _add_distill_hook(self, outs, mapping_layers_name, layers_type):
+        """
+        Get layer outputs by name.
+        outs(dict): saves the intermediate outputs of the model by name.
+        mapping_layers_name(list): names of the intermediate layers.
+        layers_type(list): types of the intermediate layers used to calculate the distill loss.
+        """
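+        # Note: a forward_post_hook registered on a sublayer runs after every
+        # call of that sublayer, so `outs[name]` always holds the feature map
+        # produced by the most recent forward pass.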
+        ### TODO: support DP model
+        for idx, (n, m) in enumerate(self.model.named_sublayers()):
+            if n in mapping_layers_name:
+                midx = mapping_layers_name.index(n)
+                m.register_forward_post_hook(
+                    self._get_activation(outs, layers_type[midx]))
+
+    def mapping_layers(self):
+        raise NotImplementedError("function mapping_layers is not implemented")
+
+
+class Distill(nn.Layer):
+    ### TODO: support list of student model and teacher model
+    def __init__(self, distill_configs, student_models, teacher_models,
+                 adaptors_S, adaptors_T):
+        super(Distill, self).__init__()
+        self._distill_configs = distill_configs
+        self._student_models = student_models
+        self._teacher_models = teacher_models
+        self._adaptors_S = adaptors_S(self._student_models)
+        self._adaptors_T = adaptors_T(self._teacher_models)
+
+        self.stu_outs_dict, self.tea_outs_dict = self._prepare_outputs()
+
+        self.configs = []
+        for c in self._distill_configs:
+            self.configs.append(LayerConfig(**c).__dict__)
+
+        self.distill_idx = self._get_distill_idx()
+        self._loss_config_list = []
+        for c in self.configs:
+            loss_config = {}
+            loss_config[str(c['loss_function'])] = {}
+            loss_config[str(c['loss_function'])]['weight'] = c['weight']
+            loss_config[str(c['loss_function'])]['key'] = c[
+                'feature_type'] + '_' + str(c['s_feature_idx']) + '_' + str(c[
+                    't_feature_idx'])
+            ### TODO: support list of student models and teacher_models
+            loss_config[str(c['loss_function'])][
+                'model_name_pairs'] = [['student', 'teacher']]
+            self._loss_config_list.append(loss_config)
+        self._prepare_loss()
+
+    def _prepare_hook(self, adaptors, outs_dict):
+        mapping_layers = adaptors.mapping_layers()
+        for layer_type, layer in mapping_layers.items():
+            if isinstance(layer, str):
+                adaptors._add_distill_hook(outs_dict, [layer], [layer_type])
+        return outs_dict
+
+    def _get_model_intermediate_output(self, adaptors, outs_dict):
+        mapping_layers = adaptors.mapping_layers()
+        for layer_type, layer in mapping_layers.items():
+            if isinstance(layer, str):
+                continue
+            outs_dict[layer_type] = layer
+        return outs_dict
+
+    def _get_distill_idx(self):
+        distill_idx = {}
+        for config in self._distill_configs:
+            if config['feature_type'] not in distill_idx:
+                distill_idx[config['feature_type']] = [[
+                    int(config['s_feature_idx']), int(config['t_feature_idx'])
+                ]]
+            else:
+                distill_idx[config['feature_type']].append([
+                    int(config['s_feature_idx']), int(config['t_feature_idx'])
+                ])
+        return distill_idx
+
+    def _prepare_loss(self):
+        self.distill_loss = CombinedLoss(self._loss_config_list)
+
+    def _prepare_outputs(self):
+        stu_outs_dict = collections.OrderedDict()
+        tea_outs_dict = collections.OrderedDict()
+        stu_outs_dict = self._prepare_hook(self._adaptors_S, stu_outs_dict)
+        tea_outs_dict = self._prepare_hook(self._adaptors_T, tea_outs_dict)
+        return stu_outs_dict, tea_outs_dict
+
+    def _post_outputs(self):
+        final_keys = []
+        for key, value in self.stu_outs_dict.items():
+            if len(key.split('_')) == 1:
+                final_keys.append(key)
+
+        ### TODO: support list of student models and teacher_models
+        final_distill_dict = {
+            "student": collections.OrderedDict(),
+            "teacher": collections.OrderedDict()
+        }
+
+        for feature_type, dist_idx in self.distill_idx.items():
+            for idx, idx_list in enumerate(dist_idx):
+                sidx, tidx = idx_list[0], idx_list[1]
+                final_distill_dict['student'][feature_type + '_' + str(
+                    sidx) + '_' + str(tidx)] = self.stu_outs_dict[
+                        feature_type + '_' + str(sidx)]
+                final_distill_dict['teacher'][feature_type + '_' + str(
+                    sidx) + '_' + str(tidx)] = self.tea_outs_dict[
+                        feature_type + '_' + str(tidx)]
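+        # The loop above renames '<feature_type>_<idx>' entries to
+        # '<feature_type>_<s_idx>_<t_idx>' so that each student/teacher pair
+        # shares the key expected by the loss configs built in __init__.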
+        return final_distill_dict
+
+    def forward(self, *inputs, **kwargs):
+        stu_batch_outs = self._student_models.forward(*inputs, **kwargs)
+        tea_batch_outs = self._teacher_models.forward(*inputs, **kwargs)
+        if not self._adaptors_S.add_tensor:
+            self._adaptors_S.add_tensor = True
+        if not self._adaptors_T.add_tensor:
+            self._adaptors_T.add_tensor = True
+        self.stu_outs_dict = self._get_model_intermediate_output(
+            self._adaptors_S, self.stu_outs_dict)
+        self.tea_outs_dict = self._get_model_intermediate_output(
+            self._adaptors_T, self.tea_outs_dict)
+        distill_inputs = self._post_outputs()
+        ### batch is None for now
+        distill_outputs = self.distill_loss(distill_inputs, None)
+        distill_loss = distill_outputs['loss']
+        return stu_batch_outs, tea_batch_outs, distill_loss

diff --git a/tests/dygraph/test_distill.py b/tests/dygraph/test_distill.py
new file mode 100644
index 00000000..84e270ba
--- /dev/null
+++ b/tests/dygraph/test_distill.py
@@ -0,0 +1,167 @@
+import sys
+sys.path.append("../../")
+import logging
+import numpy as np
+import unittest
+import paddle
+import paddle.nn as nn
+from paddle.vision.models import MobileNetV1
+import paddle.vision.transforms as T
+from paddleslim.dygraph.dist import Distill, AdaptorBase
+from paddleslim.common.log_helper import get_logger
+
+_logger = get_logger(
+    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
+
+
+class TestImperativeDistill(unittest.TestCase):
+    def setUp(self):
+        self.s_model, self.t_model = self.prepare_model()
+        self.t_model.eval()
+        self.distill_configs = self.prepare_config()
+        self.adaptor = self.prepare_adaptor()
+
+    def prepare_model(self):
+        return MobileNetV1(), MobileNetV1()
+
+    def prepare_config(self):
+        distill_configs = [{
+            's_feature_idx': 0,
+            't_feature_idx': 0,
+            'feature_type': 'hidden',
+            'loss_function': 'l2'
+        }, {
+            's_feature_idx': 1,
+            't_feature_idx': 1,
+            'feature_type': 'hidden',
+            'loss_function': 'l2'
+        }, {
+            's_feature_idx': 0,
+            't_feature_idx': 0,
+            'feature_type': 'logits',
+            'loss_function': 'l2'
+        }]
+        return distill_configs
+
+    def prepare_adaptor(self):
+        class Adaptor(AdaptorBase):
+            def mapping_layers(self):
+                mapping_layers = {}
+                mapping_layers['hidden_0'] = 'conv1'
+                mapping_layers['hidden_1'] = 'conv2_2'
+                mapping_layers['hidden_2'] = 'conv3_2'
+                mapping_layers['logits_0'] = 'fc'
+                return mapping_layers
+
+        return Adaptor
+
+    def test_distill(self):
+        transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
+
+        train_dataset = paddle.vision.datasets.Cifar10(
+            mode='train', backend='cv2', transform=transform)
+        val_dataset = paddle.vision.datasets.Cifar10(
+            mode='test', backend='cv2', transform=transform)
+
+        place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
+        ) else paddle.CPUPlace()
+        train_reader = paddle.io.DataLoader(
+            train_dataset, drop_last=True, places=place, batch_size=64)
+        test_reader = paddle.io.DataLoader(
+            val_dataset, places=place, batch_size=64)
+
+        def test(model):
+            model.eval()
+            avg_acc = [[], []]
+            for batch_id, data in enumerate(test_reader):
+                img = paddle.to_tensor(data[0])
+                label = paddle.to_tensor(data[1])
+                label = paddle.reshape(label, [-1, 1])
+                out = model(img)
+                acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
+                acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
+                avg_acc[0].append(acc_top1.numpy())
+                avg_acc[1].append(acc_top5.numpy())
+                if batch_id % 100 == 0:
+                    _logger.info(
+                        "Test | step {}: acc1 = {:}, acc5 = {:}".format(
+                            batch_id, acc_top1.numpy(), acc_top5.numpy()))
+
+            _logger.info("Test | Average: acc_top1 {}, acc_top5 {}".format(
+                np.mean(avg_acc[0]), np.mean(avg_acc[1])))
+
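+        # The distill wrapper's forward returns (student_logits,
+        # teacher_logits, distill_loss); train() below combines the usual
+        # cross-entropy loss on the student output with the distill loss.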
+        def train(model):
+            adam = paddle.optimizer.Adam(
+                learning_rate=0.001, parameters=model.parameters())
+
+            for batch_id, data in enumerate(train_reader):
+                img = paddle.to_tensor(data[0])
+                label = paddle.to_tensor(data[1])
+                student_out, teacher_out, distill_loss = model(img)
+                loss = paddle.nn.functional.loss.cross_entropy(student_out,
+                                                               label)
+                avg_loss = paddle.mean(loss)
+                all_loss = avg_loss + distill_loss
+                all_loss.backward()
+                adam.step()
+                adam.clear_grad()
+                if batch_id % 100 == 0:
+                    _logger.info("Train | At epoch {} step {}: loss = {:}".
+                                 format(str(0), batch_id, all_loss.numpy()))
+            test(self.s_model)
+            self.s_model.train()
+
+        distill_model = Distill(self.distill_configs, self.s_model,
+                                self.t_model, self.adaptor, self.adaptor)
+        train(distill_model)
+
+
+class TestImperativeDistillCase1(TestImperativeDistill):
+    def prepare_model(self):
+        class Model(nn.Layer):
+            def __init__(self):
+                super(Model, self).__init__()
+                self.conv1 = nn.Conv2D(3, 3, 3, padding=1)
+                self.conv2 = nn.Conv2D(3, 3, 3, padding=1)
+                self.conv3 = nn.Conv2D(3, 3, 3, padding=1)
+                self.fc = nn.Linear(3072, 10)
+
+            def forward(self, x):
+                self.conv1_out = self.conv1(x)
+                conv2_out = self.conv2(self.conv1_out)
+                self.conv3_out = self.conv3(conv2_out)
+                out = paddle.reshape(self.conv3_out, shape=[x.shape[0], -1])
+                out = self.fc(out)
+                return out
+
+        return Model(), Model()
+
+    def prepare_adaptor(self):
+        class Adaptor(AdaptorBase):
+            def mapping_layers(self):
+                mapping_layers = {}
+                mapping_layers['hidden_1'] = 'conv2'
+                if self.add_tensor:
+                    mapping_layers['hidden_0'] = self.model.conv1_out
+                    mapping_layers['hidden_2'] = self.model.conv3_out
+                return mapping_layers
+
+        return Adaptor
+
+    def prepare_config(self):
+        distill_configs = [{
+            's_feature_idx': 0,
+            't_feature_idx': 0,
+            'feature_type': 'hidden',
+            'loss_function': 'l2'
+        }, {
+            's_feature_idx': 1,
+            't_feature_idx': 2,
+            'feature_type': 'hidden',
+            'loss_function': 'l2'
+        }]
+        return distill_configs
+
+
+if __name__ == '__main__':
+    unittest.main()
--
GitLab
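For reference, a minimal usage sketch of the API this patch adds. It mirrors
the pattern in tests/dygraph/test_distill.py above; `MyAdaptor`, `student`,
`teacher`, and `image_batch` are illustrative placeholders, not part of the
patch:

    import paddle
    from paddleslim.dygraph.dist import Distill, AdaptorBase

    class MyAdaptor(AdaptorBase):
        def mapping_layers(self):
            # Keys follow '<feature_type>_<idx>'; string values name
            # sublayers whose outputs are captured via forward hooks.
            return {'hidden_0': 'conv1', 'logits_0': 'fc'}

    configs = [{
        's_feature_idx': 0,
        't_feature_idx': 0,
        'feature_type': 'hidden',
        'loss_function': 'l2'
    }]

    # student and teacher are paddle.nn.Layer models defined elsewhere;
    # the teacher is typically set to eval(). The adaptor classes are
    # instantiated inside Distill with the corresponding model.
    distill_model = Distill(configs, student, teacher, MyAdaptor, MyAdaptor)
    student_out, teacher_out, distill_loss = distill_model(image_batch)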