未验证 提交 283ae9b3 编写于 作者: wc晨曦's avatar wc晨曦 提交者: GitHub

add dkd (#1888)

* add dkd

* update dkd

* update dkd

* update dkd

* update dkd

* update dkd

* update dkd and add tipc
上级 f5da9044
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: "./output/"
device: "gpu"
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 100
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: "./inference"
# model architecture
Arch:
name: "DistillationModel"
# if not null, its lengths should be same as models
pretrained_list:
# if not null, its lengths should be same as models
freeze_params_list:
- True
- False
models:
- Teacher:
name: ResNet34
pretrained: True
- Student:
name: ResNet18
pretrained: False
infer_model_name: "Student"
# loss function config for traing/eval process
Loss:
Train:
- DistillationGTCELoss:
weight: 1.0
model_names: ["Student"]
- DistillationDKDLoss:
weight: 1.0
model_name_pairs: [["Student", "Teacher"]]
temperature: 1
alpha: 1.0
beta: 1.0
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
weight_decay: 1e-4
lr:
name: MultiStepDecay
learning_rate: 0.2
milestones: [30, 60, 90]
step_each_epoch: 1
gamma: 0.1
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: "./dataset/ILSVRC2012/"
cls_label_path: "./dataset/ILSVRC2012/train_list.txt"
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: "./dataset/ILSVRC2012/"
cls_label_path: "./dataset/ILSVRC2012/val_list.txt"
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: "docs/images/inference_deployment/whl_demo.jpg"
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: DistillationPostProcess
func: Topk
topk: 5
class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"
Metric:
Train:
- DistillationTopkAcc:
model_key: "Student"
topk: [1, 5]
Eval:
- DistillationTopkAcc:
model_key: "Student"
topk: [1, 5]
...@@ -23,6 +23,7 @@ from .distillationloss import DistillationDMLLoss ...@@ -23,6 +23,7 @@ from .distillationloss import DistillationDMLLoss
from .distillationloss import DistillationDistanceLoss from .distillationloss import DistillationDistanceLoss
from .distillationloss import DistillationRKDLoss from .distillationloss import DistillationRKDLoss
from .distillationloss import DistillationKLDivLoss from .distillationloss import DistillationKLDivLoss
from .distillationloss import DistillationDKDLoss
from .multilabelloss import MultiLabelLoss from .multilabelloss import MultiLabelLoss
from .afdloss import AFDLoss from .afdloss import AFDLoss
......
...@@ -21,6 +21,7 @@ from .dmlloss import DMLLoss ...@@ -21,6 +21,7 @@ from .dmlloss import DMLLoss
from .distanceloss import DistanceLoss from .distanceloss import DistanceLoss
from .rkdloss import RKdAngle, RkdDistance from .rkdloss import RKdAngle, RkdDistance
from .kldivloss import KLDivLoss from .kldivloss import KLDivLoss
from .dkdloss import DKDLoss
class DistillationCELoss(CELoss): class DistillationCELoss(CELoss):
...@@ -204,3 +205,33 @@ class DistillationKLDivLoss(KLDivLoss): ...@@ -204,3 +205,33 @@ class DistillationKLDivLoss(KLDivLoss):
for key in loss: for key in loss:
loss_dict["{}_{}_{}".format(key, pair[0], pair[1])] = loss[key] loss_dict["{}_{}_{}".format(key, pair[0], pair[1])] = loss[key]
return loss_dict return loss_dict
class DistillationDKDLoss(DKDLoss):
"""
DistillationDKDLoss
"""
def __init__(self,
model_name_pairs=[],
key=None,
temperature=1.0,
alpha=1.0,
beta=1.0,
name="loss_dkd"):
super().__init__(temperature=temperature, alpha=alpha, beta=beta)
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
loss = super().forward(out1, out2, batch)
loss_dict[f"{self.name}_{pair[0]}_{pair[1]}"] = loss
return loss_dict
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class DKDLoss(nn.Layer):
"""
DKDLoss
Reference: https://arxiv.org/abs/2203.08679
Code was heavily based on https://github.com/megvii-research/mdistiller
"""
def __init__(self, temperature=1.0, alpha=1.0, beta=1.0):
super().__init__()
self.temperature = temperature
self.alpha = alpha
self.beta = beta
def forward(self, logits_student, logits_teacher, target):
gt_mask = _get_gt_mask(logits_student, target)
other_mask = 1 - gt_mask
pred_student = F.softmax(logits_student / self.temperature, axis=1)
pred_teacher = F.softmax(logits_teacher / self.temperature, axis=1)
pred_student = cat_mask(pred_student, gt_mask, other_mask)
pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)
log_pred_student = paddle.log(pred_student)
tckd_loss = (F.kl_div(
log_pred_student, pred_teacher,
reduction='sum') * (self.temperature**2) / target.shape[0])
pred_teacher_part2 = F.softmax(
logits_teacher / self.temperature - 1000.0 * gt_mask, axis=1)
log_pred_student_part2 = F.log_softmax(
logits_student / self.temperature - 1000.0 * gt_mask, axis=1)
nckd_loss = (F.kl_div(
log_pred_student_part2, pred_teacher_part2,
reduction='sum') * (self.temperature**2) / target.shape[0])
return self.alpha * tckd_loss + self.beta * nckd_loss
def _get_gt_mask(logits, target):
target = target.reshape([-1]).unsqueeze(1)
updates = paddle.ones_like(target)
mask = scatter(
paddle.zeros_like(logits), target, updates.astype('float32'))
return mask
def cat_mask(t, mask1, mask2):
t1 = (t * mask1).sum(axis=1, keepdim=True)
t2 = (t * mask2).sum(axis=1, keepdim=True)
rt = paddle.concat([t1, t2], axis=1)
return rt
def scatter(x, index, updates):
i, j = index.shape
grid_x, grid_y = paddle.meshgrid(paddle.arange(i), paddle.arange(j))
index = paddle.stack([grid_x.flatten(), index.flatten()], axis=1)
updates_index = paddle.stack([grid_x.flatten(), grid_y.flatten()], axis=1)
updates = paddle.gather_nd(updates, index=updates_index)
return paddle.scatter_nd_add(x, index, updates)
===========================train_params===========================
model_name:DistillationModel
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=100
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:amp_train
amp_train:tools/train.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o AMP.scale_loss=128 -o AMP.use_dynamic_loss_scaling=True -o AMP.level=O2
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:Fasle
inference:python/predict_cls.py -c configs/inference_cls.yaml
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:True|False
-o Global.cpu_num_threads:1|6
-o Global.batch_size:1|16
-o Global.use_tensorrt:True|False
-o Global.use_fp16:True|False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val
-o Global.save_log_path:null
-o Global.benchmark:True
null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
===========================train_params===========================
model_name:DistillationModel
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=100
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:Fasle
inference:python/predict_cls.py -c configs/inference_cls.yaml
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:True|False
-o Global.cpu_num_threads:1|6
-o Global.batch_size:1|16
-o Global.use_tensorrt:True|False
-o Global.use_fp16:True|False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val
-o Global.save_log_path:null
-o Global.benchmark:True
null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册