diff --git a/ppcls/arch/__init__.py b/ppcls/arch/__init__.py index 2d5e29db865118a145924729e81ff423de5f174b..da21e101a27eb0db2c05b658346148bda3139c80 100644 --- a/ppcls/arch/__init__.py +++ b/ppcls/arch/__init__.py @@ -27,8 +27,9 @@ from ppcls.arch.backbone.base.theseus_layer import TheseusLayer from ppcls.utils import logger from ppcls.utils.save_load import load_dygraph_pretrain from ppcls.arch.slim import prune_model, quantize_model +from ppcls.arch.distill.afd_attention import LinearTransformStudent, LinearTransformTeacher -__all__ = ["build_model", "RecModel", "DistillationModel"] +__all__ = ["build_model", "RecModel", "DistillationModel", "AttentionModel"] def build_model(config): @@ -132,3 +133,24 @@ class DistillationModel(nn.Layer): else: result_dict[model_name] = self.model_list[idx](x, label) return result_dict + + +class AttentionModel(DistillationModel): + def __init__(self, + models=None, + pretrained_list=None, + freeze_params_list=None, + **kargs): + super().__init__(models, pretrained_list, freeze_params_list, **kargs) + + def forward(self, x, label=None): + result_dict = dict() + out = x + for idx, model_name in enumerate(self.model_name_list): + if label is None: + out = self.model_list[idx](out) + result_dict.update(out) + else: + out = self.model_list[idx](out, label) + result_dict.update(out) + return result_dict diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py index 908d94445e035394f9b7b3f1fc72a8431435b223..9f3f596a85a4ca7336779b88e77053bf19458c7d 100644 --- a/ppcls/arch/backbone/base/theseus_layer.py +++ b/ppcls/arch/backbone/base/theseus_layer.py @@ -35,7 +35,7 @@ class TheseusLayer(nn.Layer): self.quanter = None def _return_dict_hook(self, layer, input, output): - res_dict = {"output": output} + res_dict = {"logits": output} # 'list' is needed to avoid error raised by popping self.res_dict for res_key in list(self.res_dict): # clear the res_dict because the forward process may change according to input diff --git a/ppcls/arch/distill/afd_attention.py b/ppcls/arch/distill/afd_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..63b094f313c08f0963a6afbe72361430cbaf789e --- /dev/null +++ b/ppcls/arch/distill/afd_attention.py @@ -0,0 +1,123 @@ +#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle.nn as nn +import paddle.nn.functional as F +import paddle +import numpy as np + + +class LinearBNReLU(nn.Layer): + def __init__(self, nin, nout): + super().__init__() + self.linear = nn.Linear(nin, nout) + self.bn = nn.BatchNorm1D(nout) + self.relu = nn.ReLU() + + def forward(self, x, relu=True): + if relu: + return self.relu(self.bn(self.linear(x))) + return self.bn(self.linear(x)) + + +def unique_shape(s_shapes): + n_s = [] + unique_shapes = [] + n = -1 + for s_shape in s_shapes: + if s_shape not in unique_shapes: + unique_shapes.append(s_shape) + n += 1 + n_s.append(n) + return n_s, unique_shapes + + +class LinearTransformTeacher(nn.Layer): + def __init__(self, qk_dim, t_shapes, keys): + super().__init__() + self.teacher_keys = keys + self.t_shapes = [[1] + t_i for t_i in t_shapes] + self.query_layer = nn.LayerList( + [LinearBNReLU(t_shape[1], qk_dim) for t_shape in self.t_shapes]) + + def forward(self, t_features_dict): + g_t = [t_features_dict[key] for key in self.teacher_keys] + bs = g_t[0].shape[0] + channel_mean = [f_t.mean(3).mean(2) for f_t in g_t] + spatial_mean = [] + for i in range(len(g_t)): + c, h, w = g_t[i].shape[1:] + spatial_mean.append(g_t[i].pow(2).mean(1).reshape([bs, h * w])) + query = paddle.stack( + [ + query_layer( + f_t, relu=False) + for f_t, query_layer in zip(channel_mean, self.query_layer) + ], + axis=1) + value = [F.normalize(f_s, axis=1) for f_s in spatial_mean] + return {"query": query, "value": value} + + +class LinearTransformStudent(nn.Layer): + def __init__(self, qk_dim, t_shapes, s_shapes, keys): + super().__init__() + self.student_keys = keys + self.t_shapes = [[1] + t_i for t_i in t_shapes] + self.s_shapes = [[1] + s_i for s_i in s_shapes] + self.t = len(self.t_shapes) + self.s = len(self.s_shapes) + self.qk_dim = qk_dim + self.n_t, self.unique_t_shapes = unique_shape(self.t_shapes) + self.relu = nn.ReLU() + self.samplers = nn.LayerList( + [Sample(t_shape) for t_shape in self.unique_t_shapes]) + self.key_layer = nn.LayerList([ + LinearBNReLU(s_shape[1], self.qk_dim) for s_shape in self.s_shapes + ]) + self.bilinear = LinearBNReLU(qk_dim, qk_dim * len(self.t_shapes)) + + def forward(self, s_features_dict): + g_s = [s_features_dict[key] for key in self.student_keys] + bs = g_s[0].shape[0] + channel_mean = [f_s.mean(3).mean(2) for f_s in g_s] + spatial_mean = [sampler(g_s, bs) for sampler in self.samplers] + + key = paddle.stack( + [ + key_layer(f_s) + for key_layer, f_s in zip(self.key_layer, channel_mean) + ], + axis=1).reshape([-1, self.qk_dim]) # Bs x h + bilinear_key = self.bilinear( + key, relu=False).reshape([bs, self.s, self.t, self.qk_dim]) + value = [F.normalize(s_m, axis=2) for s_m in spatial_mean] + return {"bilinear_key": bilinear_key, "value": value} + + +class Sample(nn.Layer): + def __init__(self, t_shape): + super().__init__() + self.t_N, self.t_C, self.t_H, self.t_W = t_shape + self.sample = nn.AdaptiveAvgPool2D((self.t_H, self.t_W)) + + def forward(self, g_s, bs): + g_s = paddle.stack( + [ + self.sample(f_s.pow(2).mean( + 1, keepdim=True)).reshape([bs, self.t_H * self.t_W]) + for f_s in g_s + ], + axis=1) + return g_s diff --git a/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_afd.yaml b/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_afd.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e5b8b716222316c0fca80a69154b0c937e6c52da --- /dev/null +++ b/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_afd.yaml @@ -0,0 +1,202 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: "./output/" + device: "gpu" + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 100 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: "./inference" + +# model architecture +Arch: + name: "DistillationModel" + # if not null, its lengths should be same as models + pretrained_list: + # if not null, its lengths should be same as models + freeze_params_list: + models: + - Teacher: + name: AttentionModel + pretrained_list: + freeze_params_list: + - True + - False + models: + - ResNet34: + name: ResNet34 + pretrained: True + return_patterns: &t_keys ["blocks[0]", "blocks[1]", "blocks[2]", "blocks[3]", + "blocks[4]", "blocks[5]", "blocks[6]", "blocks[7]", + "blocks[8]", "blocks[9]", "blocks[10]", "blocks[11]", + "blocks[12]", "blocks[13]", "blocks[14]", "blocks[15]"] + - LinearTransformTeacher: + name: LinearTransformTeacher + qk_dim: 128 + keys: *t_keys + t_shapes: &t_shapes [[64, 56, 56], [64, 56, 56], [64, 56, 56], [128, 28, 28], + [128, 28, 28], [128, 28, 28], [128, 28, 28], [256, 14, 14], + [256, 14, 14], [256, 14, 14], [256, 14, 14], [256, 14, 14], + [256, 14, 14], [512, 7, 7], [512, 7, 7], [512, 7, 7]] + + - Student: + name: AttentionModel + pretrained_list: + freeze_params_list: + - False + - False + models: + - ResNet18: + name: ResNet18 + pretrained: False + return_patterns: &s_keys ["blocks[0]", "blocks[1]", "blocks[2]", "blocks[3]", + "blocks[4]", "blocks[5]", "blocks[6]", "blocks[7]"] + - LinearTransformStudent: + name: LinearTransformStudent + qk_dim: 128 + keys: *s_keys + s_shapes: &s_shapes [[64, 56, 56], [64, 56, 56], [128, 28, 28], [128, 28, 28], + [256, 14, 14], [256, 14, 14], [512, 7, 7], [512, 7, 7]] + t_shapes: *t_shapes + + infer_model_name: "Student" + + +# loss function config for traing/eval process +Loss: + Train: + - DistillationGTCELoss: + weight: 1.0 + model_names: ["Student"] + key: logits + - DistillationKLDivLoss: + weight: 0.9 + model_name_pairs: [["Student", "Teacher"]] + temperature: 4 + key: logits + - AFDLoss: + weight: 50.0 + model_name_pair: ["Student", "Teacher"] + student_keys: ["bilinear_key", "value"] + teacher_keys: ["query", "value"] + s_shapes: *s_shapes + t_shapes: *t_shapes + Eval: + - DistillationGTCELoss: + weight: 1.0 + model_names: ["Student"] + + +Optimizer: + name: Momentum + momentum: 0.9 + weight_decay: 1e-4 + lr: + name: MultiStepDecay + learning_rate: 0.1 + milestones: [30, 60, 90] + step_each_epoch: 1 + gamma: 0.1 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: "./dataset/ILSVRC2012/" + cls_label_path: "./dataset/ILSVRC2012/train_list.txt" + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + interpolation: bicubic + backend: pil + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: True + loader: + num_workers: 8 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: "./dataset/ILSVRC2012/" + cls_label_path: "./dataset/ILSVRC2012/val_list.txt" + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: "docs/images/inference_deployment/whl_demo.jpg" + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: DistillationPostProcess + func: Topk + topk: 5 + class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt" + +Metric: + Train: + - DistillationTopkAcc: + model_key: "Student" + topk: [1, 5] + Eval: + - DistillationTopkAcc: + model_key: "Student" + topk: [1, 5] diff --git a/ppcls/data/postprocess/topk.py b/ppcls/data/postprocess/topk.py index 9c1371bfd11f4c93f06c82436e88e0ff20a57b35..0dde72aa11274996087156b2e026447c340f421a 100644 --- a/ppcls/data/postprocess/topk.py +++ b/ppcls/data/postprocess/topk.py @@ -46,6 +46,8 @@ class Topk(object): return class_id_map def __call__(self, x, file_names=None, multilabel=False): + if isinstance(x, dict): + x = x['logits'] assert isinstance(x, paddle.Tensor) if file_names is not None: assert x.shape[0] == len(file_names) diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 7f04221c8664e9b8a7a7c3624d550c044409ddce..c991b6e0504c3f103ec487cf90a21c461d05ed6a 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -459,5 +459,7 @@ class ExportModel(TheseusLayer): if self.infer_output_key is not None: x = x[self.infer_output_key] if self.out_act is not None: + if isinstance(x, dict): + x = x["logits"] x = self.out_act(x) return x diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py index d7b5c47620bcd03b9ef8ddd44deeea0621ca041d..994eeb5ee3c814ccaf3f23331d19259e66b18f08 100644 --- a/ppcls/engine/evaluation/classification.py +++ b/ppcls/engine/evaluation/classification.py @@ -99,6 +99,8 @@ def classification_eval(engine, epoch_id=0): if isinstance(out, dict): if "Student" in out: out = out["Student"] + if isinstance(out, dict): + out = out["logits"] elif "logits" in out: out = out["logits"] else: diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py index d15dab9da36c02c077b2b7f871d98f57582dbd85..0f4893f6ade35a4a2f9d61d0cefef5f7dd107461 100644 --- a/ppcls/loss/__init__.py +++ b/ppcls/loss/__init__.py @@ -22,7 +22,9 @@ from .distillationloss import DistillationGTCELoss from .distillationloss import DistillationDMLLoss from .distillationloss import DistillationDistanceLoss from .distillationloss import DistillationRKDLoss +from .distillationloss import DistillationKLDivLoss from .multilabelloss import MultiLabelLoss +from .afdloss import AFDLoss from .deephashloss import DSHSDLoss, LCDSHLoss diff --git a/ppcls/loss/afdloss.py b/ppcls/loss/afdloss.py new file mode 100644 index 0000000000000000000000000000000000000000..3e67e30b98df61576e40449015cc67a13dd6da60 --- /dev/null +++ b/ppcls/loss/afdloss.py @@ -0,0 +1,132 @@ +#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle.nn as nn +import paddle.nn.functional as F +import paddle +import numpy as np +import matplotlib.pyplot as plt +import cv2 +import warnings +warnings.filterwarnings('ignore') + + +class LinearBNReLU(nn.Layer): + def __init__(self, nin, nout): + super().__init__() + self.linear = nn.Linear(nin, nout) + self.bn = nn.BatchNorm1D(nout) + self.relu = nn.ReLU() + + def forward(self, x, relu=True): + if relu: + return self.relu(self.bn(self.linear(x))) + return self.bn(self.linear(x)) + + +def unique_shape(s_shapes): + n_s = [] + unique_shapes = [] + n = -1 + for s_shape in s_shapes: + if s_shape not in unique_shapes: + unique_shapes.append(s_shape) + n += 1 + n_s.append(n) + return n_s, unique_shapes + + +class AFDLoss(nn.Layer): + """ + AFDLoss + https://www.aaai.org/AAAI21Papers/AAAI-9785.JiM.pdf + https://github.com/clovaai/attention-feature-distillation + """ + + def __init__(self, + model_name_pair=["Student", "Teacher"], + student_keys=["bilinear_key", "value"], + teacher_keys=["query", "value"], + s_shapes=[[64, 16, 160], [128, 8, 160], [256, 4, 160], + [512, 2, 160]], + t_shapes=[[640, 48], [320, 96], [160, 192]], + qk_dim=128, + name="loss_afd"): + super().__init__() + assert isinstance(model_name_pair, list) + self.model_name_pair = model_name_pair + self.student_keys = student_keys + self.teacher_keys = teacher_keys + self.s_shapes = [[1] + s_i for s_i in s_shapes] + self.t_shapes = [[1] + t_i for t_i in t_shapes] + self.qk_dim = qk_dim + self.n_t, self.unique_t_shapes = unique_shape(self.t_shapes) + self.attention = Attention(self.qk_dim, self.t_shapes, self.s_shapes, + self.n_t, self.unique_t_shapes) + self.name = name + + def forward(self, predicts, batch): + s_features_dict = predicts[self.model_name_pair[0]] + t_features_dict = predicts[self.model_name_pair[1]] + + g_s = [s_features_dict[key] for key in self.student_keys] + g_t = [t_features_dict[key] for key in self.teacher_keys] + + loss = self.attention(g_s, g_t) + sum_loss = sum(loss) + + loss_dict = dict() + loss_dict[self.name] = sum_loss + + return loss_dict + + +class Attention(nn.Layer): + def __init__(self, qk_dim, t_shapes, s_shapes, n_t, unique_t_shapes): + super().__init__() + self.qk_dim = qk_dim + self.n_t = n_t + # self.linear_trans_s = LinearTransformStudent(qk_dim, t_shapes, s_shapes, unique_t_shapes) + # self.linear_trans_t = LinearTransformTeacher(qk_dim, t_shapes) + + self.p_t = self.create_parameter( + shape=[len(t_shapes), qk_dim], + default_initializer=nn.initializer.XavierNormal()) + self.p_s = self.create_parameter( + shape=[len(s_shapes), qk_dim], + default_initializer=nn.initializer.XavierNormal()) + + def forward(self, g_s, g_t): + bilinear_key, h_hat_s_all = g_s + query, h_t_all = g_t + + p_logit = paddle.matmul(self.p_t, self.p_s.t()) + + logit = paddle.add( + paddle.einsum('bstq,btq->bts', bilinear_key, query), + p_logit) / np.sqrt(self.qk_dim) + atts = F.softmax(logit, axis=2) # b x t x s + + loss = [] + + for i, (n, h_t) in enumerate(zip(self.n_t, h_t_all)): + h_hat_s = h_hat_s_all[n] + diff = self.cal_diff(h_hat_s, h_t, atts[:, i]) + loss.append(diff) + return loss + + def cal_diff(self, v_s, v_t, att): + diff = (v_s - v_t.unsqueeze(1)).pow(2).mean(2) + diff = paddle.multiply(diff, att).sum(1).mean() + return diff diff --git a/ppcls/loss/distillationloss.py b/ppcls/loss/distillationloss.py index 0340234b9d2374adc97aa4be8b8c2e50ce297e6f..21e5ef371d380e2df50cb70b2e4796725bf277e4 100644 --- a/ppcls/loss/distillationloss.py +++ b/ppcls/loss/distillationloss.py @@ -14,11 +14,13 @@ import paddle import paddle.nn as nn +import paddle.nn.functional as F from .celoss import CELoss from .dmlloss import DMLLoss from .distanceloss import DistanceLoss from .rkdloss import RKdAngle, RkdDistance +from .kldivloss import KLDivLoss class DistillationCELoss(CELoss): @@ -172,3 +174,33 @@ class DistillationRKDLoss(nn.Layer): student_out, teacher_out) return loss_dict + + +class DistillationKLDivLoss(KLDivLoss): + """ + DistillationKLDivLoss + """ + + def __init__(self, + model_name_pairs=[], + temperature=4, + key=None, + name="loss_kl"): + super().__init__(temperature=temperature) + assert isinstance(model_name_pairs, list) + self.key = key + self.model_name_pairs = model_name_pairs + self.name = name + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if self.key is not None: + out1 = out1[self.key] + out2 = out2[self.key] + loss = super().forward(out1, out2) + for key in loss: + loss_dict["{}_{}_{}".format(key, pair[0], pair[1])] = loss[key] + return loss_dict diff --git a/ppcls/loss/kldivloss.py b/ppcls/loss/kldivloss.py new file mode 100644 index 0000000000000000000000000000000000000000..da6ab02fbb06874e0b54a6985f3c82f95db68979 --- /dev/null +++ b/ppcls/loss/kldivloss.py @@ -0,0 +1,33 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class KLDivLoss(nn.Layer): + """ + Distilling the Knowledge in a Neural Network + """ + + def __init__(self, temperature=4): + super(KLDivLoss, self).__init__() + self.T = temperature + + def forward(self, y_s, y_t): + p_s = F.log_softmax(y_s / self.T, axis=1) + p_t = F.softmax(y_t / self.T, axis=1) + loss = F.kl_div(p_s, p_t, reduction='sum') * (self.T**2) / y_s.shape[0] + return {"loss_kldiv": loss}