未验证 提交 7595ba6d 编写于 作者: wc晨曦's avatar wc晨曦 提交者: GitHub

add AFD (#1683)

* add AFD
上级 b27acf6a
...@@ -27,8 +27,9 @@ from ppcls.arch.backbone.base.theseus_layer import TheseusLayer ...@@ -27,8 +27,9 @@ from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils import logger from ppcls.utils import logger
from ppcls.utils.save_load import load_dygraph_pretrain from ppcls.utils.save_load import load_dygraph_pretrain
from ppcls.arch.slim import prune_model, quantize_model from ppcls.arch.slim import prune_model, quantize_model
from ppcls.arch.distill.afd_attention import LinearTransformStudent, LinearTransformTeacher
__all__ = ["build_model", "RecModel", "DistillationModel"] __all__ = ["build_model", "RecModel", "DistillationModel", "AttentionModel"]
def build_model(config): def build_model(config):
...@@ -132,3 +133,24 @@ class DistillationModel(nn.Layer): ...@@ -132,3 +133,24 @@ class DistillationModel(nn.Layer):
else: else:
result_dict[model_name] = self.model_list[idx](x, label) result_dict[model_name] = self.model_list[idx](x, label)
return result_dict return result_dict
class AttentionModel(DistillationModel):
def __init__(self,
models=None,
pretrained_list=None,
freeze_params_list=None,
**kargs):
super().__init__(models, pretrained_list, freeze_params_list, **kargs)
def forward(self, x, label=None):
result_dict = dict()
out = x
for idx, model_name in enumerate(self.model_name_list):
if label is None:
out = self.model_list[idx](out)
result_dict.update(out)
else:
out = self.model_list[idx](out, label)
result_dict.update(out)
return result_dict
...@@ -35,7 +35,7 @@ class TheseusLayer(nn.Layer): ...@@ -35,7 +35,7 @@ class TheseusLayer(nn.Layer):
self.quanter = None self.quanter = None
def _return_dict_hook(self, layer, input, output): def _return_dict_hook(self, layer, input, output):
res_dict = {"output": output} res_dict = {"logits": output}
# 'list' is needed to avoid error raised by popping self.res_dict # 'list' is needed to avoid error raised by popping self.res_dict
for res_key in list(self.res_dict): for res_key in list(self.res_dict):
# clear the res_dict because the forward process may change according to input # clear the res_dict because the forward process may change according to input
......
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.nn as nn
import paddle.nn.functional as F
import paddle
import numpy as np
class LinearBNReLU(nn.Layer):
def __init__(self, nin, nout):
super().__init__()
self.linear = nn.Linear(nin, nout)
self.bn = nn.BatchNorm1D(nout)
self.relu = nn.ReLU()
def forward(self, x, relu=True):
if relu:
return self.relu(self.bn(self.linear(x)))
return self.bn(self.linear(x))
def unique_shape(s_shapes):
n_s = []
unique_shapes = []
n = -1
for s_shape in s_shapes:
if s_shape not in unique_shapes:
unique_shapes.append(s_shape)
n += 1
n_s.append(n)
return n_s, unique_shapes
class LinearTransformTeacher(nn.Layer):
def __init__(self, qk_dim, t_shapes, keys):
super().__init__()
self.teacher_keys = keys
self.t_shapes = [[1] + t_i for t_i in t_shapes]
self.query_layer = nn.LayerList(
[LinearBNReLU(t_shape[1], qk_dim) for t_shape in self.t_shapes])
def forward(self, t_features_dict):
g_t = [t_features_dict[key] for key in self.teacher_keys]
bs = g_t[0].shape[0]
channel_mean = [f_t.mean(3).mean(2) for f_t in g_t]
spatial_mean = []
for i in range(len(g_t)):
c, h, w = g_t[i].shape[1:]
spatial_mean.append(g_t[i].pow(2).mean(1).reshape([bs, h * w]))
query = paddle.stack(
[
query_layer(
f_t, relu=False)
for f_t, query_layer in zip(channel_mean, self.query_layer)
],
axis=1)
value = [F.normalize(f_s, axis=1) for f_s in spatial_mean]
return {"query": query, "value": value}
class LinearTransformStudent(nn.Layer):
def __init__(self, qk_dim, t_shapes, s_shapes, keys):
super().__init__()
self.student_keys = keys
self.t_shapes = [[1] + t_i for t_i in t_shapes]
self.s_shapes = [[1] + s_i for s_i in s_shapes]
self.t = len(self.t_shapes)
self.s = len(self.s_shapes)
self.qk_dim = qk_dim
self.n_t, self.unique_t_shapes = unique_shape(self.t_shapes)
self.relu = nn.ReLU()
self.samplers = nn.LayerList(
[Sample(t_shape) for t_shape in self.unique_t_shapes])
self.key_layer = nn.LayerList([
LinearBNReLU(s_shape[1], self.qk_dim) for s_shape in self.s_shapes
])
self.bilinear = LinearBNReLU(qk_dim, qk_dim * len(self.t_shapes))
def forward(self, s_features_dict):
g_s = [s_features_dict[key] for key in self.student_keys]
bs = g_s[0].shape[0]
channel_mean = [f_s.mean(3).mean(2) for f_s in g_s]
spatial_mean = [sampler(g_s, bs) for sampler in self.samplers]
key = paddle.stack(
[
key_layer(f_s)
for key_layer, f_s in zip(self.key_layer, channel_mean)
],
axis=1).reshape([-1, self.qk_dim]) # Bs x h
bilinear_key = self.bilinear(
key, relu=False).reshape([bs, self.s, self.t, self.qk_dim])
value = [F.normalize(s_m, axis=2) for s_m in spatial_mean]
return {"bilinear_key": bilinear_key, "value": value}
class Sample(nn.Layer):
def __init__(self, t_shape):
super().__init__()
self.t_N, self.t_C, self.t_H, self.t_W = t_shape
self.sample = nn.AdaptiveAvgPool2D((self.t_H, self.t_W))
def forward(self, g_s, bs):
g_s = paddle.stack(
[
self.sample(f_s.pow(2).mean(
1, keepdim=True)).reshape([bs, self.t_H * self.t_W])
for f_s in g_s
],
axis=1)
return g_s
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: "./output/"
device: "gpu"
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 100
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: "./inference"
# model architecture
Arch:
name: "DistillationModel"
# if not null, its lengths should be same as models
pretrained_list:
# if not null, its lengths should be same as models
freeze_params_list:
models:
- Teacher:
name: AttentionModel
pretrained_list:
freeze_params_list:
- True
- False
models:
- ResNet34:
name: ResNet34
pretrained: True
return_patterns: &t_keys ["blocks[0]", "blocks[1]", "blocks[2]", "blocks[3]",
"blocks[4]", "blocks[5]", "blocks[6]", "blocks[7]",
"blocks[8]", "blocks[9]", "blocks[10]", "blocks[11]",
"blocks[12]", "blocks[13]", "blocks[14]", "blocks[15]"]
- LinearTransformTeacher:
name: LinearTransformTeacher
qk_dim: 128
keys: *t_keys
t_shapes: &t_shapes [[64, 56, 56], [64, 56, 56], [64, 56, 56], [128, 28, 28],
[128, 28, 28], [128, 28, 28], [128, 28, 28], [256, 14, 14],
[256, 14, 14], [256, 14, 14], [256, 14, 14], [256, 14, 14],
[256, 14, 14], [512, 7, 7], [512, 7, 7], [512, 7, 7]]
- Student:
name: AttentionModel
pretrained_list:
freeze_params_list:
- False
- False
models:
- ResNet18:
name: ResNet18
pretrained: False
return_patterns: &s_keys ["blocks[0]", "blocks[1]", "blocks[2]", "blocks[3]",
"blocks[4]", "blocks[5]", "blocks[6]", "blocks[7]"]
- LinearTransformStudent:
name: LinearTransformStudent
qk_dim: 128
keys: *s_keys
s_shapes: &s_shapes [[64, 56, 56], [64, 56, 56], [128, 28, 28], [128, 28, 28],
[256, 14, 14], [256, 14, 14], [512, 7, 7], [512, 7, 7]]
t_shapes: *t_shapes
infer_model_name: "Student"
# loss function config for traing/eval process
Loss:
Train:
- DistillationGTCELoss:
weight: 1.0
model_names: ["Student"]
key: logits
- DistillationKLDivLoss:
weight: 0.9
model_name_pairs: [["Student", "Teacher"]]
temperature: 4
key: logits
- AFDLoss:
weight: 50.0
model_name_pair: ["Student", "Teacher"]
student_keys: ["bilinear_key", "value"]
teacher_keys: ["query", "value"]
s_shapes: *s_shapes
t_shapes: *t_shapes
Eval:
- DistillationGTCELoss:
weight: 1.0
model_names: ["Student"]
Optimizer:
name: Momentum
momentum: 0.9
weight_decay: 1e-4
lr:
name: MultiStepDecay
learning_rate: 0.1
milestones: [30, 60, 90]
step_each_epoch: 1
gamma: 0.1
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: "./dataset/ILSVRC2012/"
cls_label_path: "./dataset/ILSVRC2012/train_list.txt"
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
interpolation: bicubic
backend: pil
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: "./dataset/ILSVRC2012/"
cls_label_path: "./dataset/ILSVRC2012/val_list.txt"
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
interpolation: bicubic
backend: pil
- CropImage:
size: 224
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: "docs/images/inference_deployment/whl_demo.jpg"
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
interpolation: bicubic
backend: pil
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: DistillationPostProcess
func: Topk
topk: 5
class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"
Metric:
Train:
- DistillationTopkAcc:
model_key: "Student"
topk: [1, 5]
Eval:
- DistillationTopkAcc:
model_key: "Student"
topk: [1, 5]
...@@ -46,6 +46,8 @@ class Topk(object): ...@@ -46,6 +46,8 @@ class Topk(object):
return class_id_map return class_id_map
def __call__(self, x, file_names=None, multilabel=False): def __call__(self, x, file_names=None, multilabel=False):
if isinstance(x, dict):
x = x['logits']
assert isinstance(x, paddle.Tensor) assert isinstance(x, paddle.Tensor)
if file_names is not None: if file_names is not None:
assert x.shape[0] == len(file_names) assert x.shape[0] == len(file_names)
......
...@@ -459,5 +459,7 @@ class ExportModel(TheseusLayer): ...@@ -459,5 +459,7 @@ class ExportModel(TheseusLayer):
if self.infer_output_key is not None: if self.infer_output_key is not None:
x = x[self.infer_output_key] x = x[self.infer_output_key]
if self.out_act is not None: if self.out_act is not None:
if isinstance(x, dict):
x = x["logits"]
x = self.out_act(x) x = self.out_act(x)
return x return x
...@@ -99,6 +99,8 @@ def classification_eval(engine, epoch_id=0): ...@@ -99,6 +99,8 @@ def classification_eval(engine, epoch_id=0):
if isinstance(out, dict): if isinstance(out, dict):
if "Student" in out: if "Student" in out:
out = out["Student"] out = out["Student"]
if isinstance(out, dict):
out = out["logits"]
elif "logits" in out: elif "logits" in out:
out = out["logits"] out = out["logits"]
else: else:
......
...@@ -22,7 +22,9 @@ from .distillationloss import DistillationGTCELoss ...@@ -22,7 +22,9 @@ from .distillationloss import DistillationGTCELoss
from .distillationloss import DistillationDMLLoss from .distillationloss import DistillationDMLLoss
from .distillationloss import DistillationDistanceLoss from .distillationloss import DistillationDistanceLoss
from .distillationloss import DistillationRKDLoss from .distillationloss import DistillationRKDLoss
from .distillationloss import DistillationKLDivLoss
from .multilabelloss import MultiLabelLoss from .multilabelloss import MultiLabelLoss
from .afdloss import AFDLoss
from .deephashloss import DSHSDLoss, LCDSHLoss from .deephashloss import DSHSDLoss, LCDSHLoss
......
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.nn as nn
import paddle.nn.functional as F
import paddle
import numpy as np
import matplotlib.pyplot as plt
import cv2
import warnings
warnings.filterwarnings('ignore')
class LinearBNReLU(nn.Layer):
def __init__(self, nin, nout):
super().__init__()
self.linear = nn.Linear(nin, nout)
self.bn = nn.BatchNorm1D(nout)
self.relu = nn.ReLU()
def forward(self, x, relu=True):
if relu:
return self.relu(self.bn(self.linear(x)))
return self.bn(self.linear(x))
def unique_shape(s_shapes):
n_s = []
unique_shapes = []
n = -1
for s_shape in s_shapes:
if s_shape not in unique_shapes:
unique_shapes.append(s_shape)
n += 1
n_s.append(n)
return n_s, unique_shapes
class AFDLoss(nn.Layer):
"""
AFDLoss
https://www.aaai.org/AAAI21Papers/AAAI-9785.JiM.pdf
https://github.com/clovaai/attention-feature-distillation
"""
def __init__(self,
model_name_pair=["Student", "Teacher"],
student_keys=["bilinear_key", "value"],
teacher_keys=["query", "value"],
s_shapes=[[64, 16, 160], [128, 8, 160], [256, 4, 160],
[512, 2, 160]],
t_shapes=[[640, 48], [320, 96], [160, 192]],
qk_dim=128,
name="loss_afd"):
super().__init__()
assert isinstance(model_name_pair, list)
self.model_name_pair = model_name_pair
self.student_keys = student_keys
self.teacher_keys = teacher_keys
self.s_shapes = [[1] + s_i for s_i in s_shapes]
self.t_shapes = [[1] + t_i for t_i in t_shapes]
self.qk_dim = qk_dim
self.n_t, self.unique_t_shapes = unique_shape(self.t_shapes)
self.attention = Attention(self.qk_dim, self.t_shapes, self.s_shapes,
self.n_t, self.unique_t_shapes)
self.name = name
def forward(self, predicts, batch):
s_features_dict = predicts[self.model_name_pair[0]]
t_features_dict = predicts[self.model_name_pair[1]]
g_s = [s_features_dict[key] for key in self.student_keys]
g_t = [t_features_dict[key] for key in self.teacher_keys]
loss = self.attention(g_s, g_t)
sum_loss = sum(loss)
loss_dict = dict()
loss_dict[self.name] = sum_loss
return loss_dict
class Attention(nn.Layer):
def __init__(self, qk_dim, t_shapes, s_shapes, n_t, unique_t_shapes):
super().__init__()
self.qk_dim = qk_dim
self.n_t = n_t
# self.linear_trans_s = LinearTransformStudent(qk_dim, t_shapes, s_shapes, unique_t_shapes)
# self.linear_trans_t = LinearTransformTeacher(qk_dim, t_shapes)
self.p_t = self.create_parameter(
shape=[len(t_shapes), qk_dim],
default_initializer=nn.initializer.XavierNormal())
self.p_s = self.create_parameter(
shape=[len(s_shapes), qk_dim],
default_initializer=nn.initializer.XavierNormal())
def forward(self, g_s, g_t):
bilinear_key, h_hat_s_all = g_s
query, h_t_all = g_t
p_logit = paddle.matmul(self.p_t, self.p_s.t())
logit = paddle.add(
paddle.einsum('bstq,btq->bts', bilinear_key, query),
p_logit) / np.sqrt(self.qk_dim)
atts = F.softmax(logit, axis=2) # b x t x s
loss = []
for i, (n, h_t) in enumerate(zip(self.n_t, h_t_all)):
h_hat_s = h_hat_s_all[n]
diff = self.cal_diff(h_hat_s, h_t, atts[:, i])
loss.append(diff)
return loss
def cal_diff(self, v_s, v_t, att):
diff = (v_s - v_t.unsqueeze(1)).pow(2).mean(2)
diff = paddle.multiply(diff, att).sum(1).mean()
return diff
...@@ -14,11 +14,13 @@ ...@@ -14,11 +14,13 @@
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F
from .celoss import CELoss from .celoss import CELoss
from .dmlloss import DMLLoss from .dmlloss import DMLLoss
from .distanceloss import DistanceLoss from .distanceloss import DistanceLoss
from .rkdloss import RKdAngle, RkdDistance from .rkdloss import RKdAngle, RkdDistance
from .kldivloss import KLDivLoss
class DistillationCELoss(CELoss): class DistillationCELoss(CELoss):
...@@ -172,3 +174,33 @@ class DistillationRKDLoss(nn.Layer): ...@@ -172,3 +174,33 @@ class DistillationRKDLoss(nn.Layer):
student_out, teacher_out) student_out, teacher_out)
return loss_dict return loss_dict
class DistillationKLDivLoss(KLDivLoss):
"""
DistillationKLDivLoss
"""
def __init__(self,
model_name_pairs=[],
temperature=4,
key=None,
name="loss_kl"):
super().__init__(temperature=temperature)
assert isinstance(model_name_pairs, list)
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
loss = super().forward(out1, out2)
for key in loss:
loss_dict["{}_{}_{}".format(key, pair[0], pair[1])] = loss[key]
return loss_dict
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class KLDivLoss(nn.Layer):
"""
Distilling the Knowledge in a Neural Network
"""
def __init__(self, temperature=4):
super(KLDivLoss, self).__init__()
self.T = temperature
def forward(self, y_s, y_t):
p_s = F.log_softmax(y_s / self.T, axis=1)
p_t = F.softmax(y_t / self.T, axis=1)
loss = F.kl_div(p_s, p_t, reduction='sum') * (self.T**2) / y_s.shape[0]
return {"loss_kldiv": loss}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册