Unverified commit 36a9f6f0, authored by ceci3, committed by GitHub

update distill (#892)

* polish distill

Parent 5ec362f2
# TinyBERT: Distilling BERT for Natural Language Understanding
The directory structure of this example is as follows:
```
.
├── task_distill.py # Task-specific distillation script
└── README.md # This file
```
## Introduction
The experiments in this directory are implemented mainly following the paper [TinyBERT: Distilling BERT for Natural Language Understanding](https://arxiv.org/abs/1909.10351).
The overall TinyBERT distillation pipeline has two stages: general distillation first, followed by task-specific distillation on (optionally augmented) data. This example covers the second stage; the student is initialized from `tinybert-6l-768d-v2`, the general small model produced by the first stage.
<p align="center">
<img src="./imgs/tinybert.png" width="950"/><br />
TinyBERT distillation pipeline
</p>
In model distillation, the larger model (here, BERT-base) is usually called the teacher, and the smaller model (here, a 6-layer BERT, referred to below as TinyBERT6) is called the student.
Knowledge distillation is implemented by training the student against distillation loss functions. In this experiment, the objective has two parts: an intermediate-layer distillation loss and a prediction-layer distillation loss. Intermediate-layer distillation covers the embedding output, the output of every Transformer layer, and the attention matrices (before softmax) of every Transformer layer, all using mean squared error. Prediction-layer distillation uses the cross entropy between the student's logits and the teacher's logits.
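The two objectives can be summarized with the following sketch. This is illustrative Paddle code, not the exact implementation in `task_distill.py`; tensor names, shapes, and the way the mapped layers are collected are assumptions.

```python
import paddle.nn.functional as F

def intermediate_distill_loss(stu_emb, tea_emb, stu_hiddens, tea_hiddens,
                              stu_attns, tea_attns):
    """MSE over the embedding output, every mapped Transformer-layer output,
    and every mapped pre-softmax attention matrix."""
    loss = F.mse_loss(stu_emb, tea_emb)
    for s, t in zip(stu_hiddens, tea_hiddens):
        loss += F.mse_loss(s, t)
    for s, t in zip(stu_attns, tea_attns):
        loss += F.mse_loss(s, t)
    return loss

def prediction_distill_loss(stu_logits, tea_logits, T=1.0):
    """Cross entropy between the student logits and the teacher's softened distribution."""
    soft_targets = F.softmax(tea_logits / T, axis=-1)
    log_probs = F.log_softmax(stu_logits / T, axis=-1)
    return -(soft_targets * log_probs).sum(axis=-1).mean()
```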
Because the teacher has 12 layers and the student has fewer, a layer mapping has to be chosen. The paper uses a fixed mapping: when the student has half as many layers as the teacher, the attention matrix of student layer i learns from the attention matrix of teacher layer 2i+1, and the same holds for the Transformer layer outputs.
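For a 12-layer teacher and a 6-layer student this gives the mapping below, which matches the layer pairs that appear in the distillation config later in this commit:

```python
# Fixed uniform mapping: student layer i learns from teacher layer 2*i + 1.
num_student_layers, num_teacher_layers = 6, 12
stride = num_teacher_layers // num_student_layers  # 2
layer_map = [(i, stride * i + 1) for i in range(num_student_layers)]
print(layer_map)  # [(0, 1), (1, 3), (2, 5), (3, 7), (4, 9), (5, 11)]
```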
The experiment consists of two major training phases: BERT-base is first fine-tuned to obtain the teacher, and distillation training is then run. Distillation itself has two steps: the intermediate layers are distilled for several epochs (10, 20, or 30 in the paper, depending on the task), followed by 3 epochs of prediction-layer distillation.
Note that in the v2 pretrained models `tinybert-6l-768d-v2` and `tinybert-4l-312d-v2`, the transformation matrices that map the student's embedding output and intermediate Transformer outputs onto the corresponding teacher outputs are independent for each layer, whereas `tinybert-6l-768d`, `tinybert-4l-312d`, `tinybert-6l-768d-zh`, and `tinybert-4l-312d-zh` share one transformation matrix across layers.
### Install PaddleNLP and Paddle
This tutorial compresses the BERT model from PaddleNLP and therefore depends on PaddleNLP and Paddle.
```shell
pip install paddlenlp
pip install paddlepaddle_gpu
```
## Data and Pretrained Models
The experiments use the training split of the GLUE datasets as training data and the dev split for evaluation.
When running the experiments in this directory, datasets are downloaded automatically to `paddlenlp.utils.env.DATA_HOME`. On Linux, for example, the GLUE QQP dataset is stored under `~/.paddlenlp/datasets/Glue/QQP` by default.
For BERT fine-tuning, the pretrained model `bert-base-uncased` is used. Pretrained models are likewise downloaded automatically to `paddlenlp.utils.env.MODEL_HOME` during training; on Linux, `bert-base-uncased` ends up under `~/.paddlenlp/models/bert-base-uncased`.
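As a rough illustration, the following snippet triggers both downloads using the same PaddleNLP APIs that `task_distill.py` calls below; the task name `sst-2` is only an example.

```python
from paddlenlp.datasets import load_dataset
from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer

# Downloads the GLUE splits to DATA_HOME on first use.
train_ds = load_dataset('glue', 'sst-2', splits='train')
dev_ds = load_dataset('glue', 'sst-2', splits='dev')

# Downloads the checkpoint and vocab to MODEL_HOME on first use.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
teacher = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_classes=2)
```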
## Distillation Procedure
### Fine-tune BERT to obtain the teacher model
First, fine-tune the pretrained model on the actual downstream task to obtain the model to be compressed. See the [fine-tuning tutorial](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/examples/language_model/bert/README.md) for the procedure.
After training, save the best model under `pretrained_models/$TASK_NAME/` in this project. The model directory contains `model_config.json`, `model_state.pdparams`, `tokenizer_config.json`, and `vocab.txt`.
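Once saved, that directory can be loaded directly as the teacher, which is what `task_distill.py` does with `--teacher_path`. The path below matches the example commands and is only illustrative.

```python
from paddlenlp.transformers import BertForSequenceClassification

teacher = BertForSequenceClassification.from_pretrained(
    './pretrained_models/SST-2/best_model_610', num_classes=2)
teacher.eval()  # the teacher stays in eval mode during distillation
```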
### Task-specific distillation of TinyBERT
First distill the intermediate layers:
```shell
export CUDA_VISIBLE_DEVICES=0
export TASK_NAME=SST-2
export TEACHER_DIR=./pretrained_models/SST-2/best_model_610
python task_distill.py \
--model_type tinybert \
--student_model_name_or_path tinybert-6l-768d-v2 \
--task_name $TASK_NAME \
--intermediate_distill \
--max_seq_length 64 \
--batch_size 32 \
--T 1 \
--teacher_model_type bert \
--teacher_path $TEACHER_DIR \
--learning_rate 5e-5 \
--num_train_epochs 20 \
--logging_steps 10 \
--save_steps 10 \
--output_dir ./tmp/$TASK_NAME/ \
--distill_config ./distill_stage1.yaml \
--device gpu
```
Parameter descriptions:
- `model_type` Student model type; `tinybert` is the default and currently the only supported value.
- `student_model_name_or_path` Name of, or path to, the student model (for the prediction-layer stage this is the directory saved after intermediate-layer distillation).
- `distill_config` Path to the distillation config file.
- `max_seq_length` Maximum sequence length; longer sequences are truncated. Default: 128.
- `T` Softmax temperature used to smooth the softmax distribution, which amplifies the contribution of non-target (negative) labels during training; see the short sketch after this list. Default: 1.
- `teacher_model_type` Teacher model type; `bert` is the default and currently the only supported value.
- `teacher_path` Directory of the fine-tuned teacher model.
- `output_dir` Directory where the student model is saved.
- `device` Device to run on. Default: gpu.
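A quick illustration of what the temperature does (plain Paddle; the printed values are only indicative):

```python
import paddle
import paddle.nn.functional as F

logits = paddle.to_tensor([[4.0, 1.0, 0.5]])
print(F.softmax(logits, axis=-1))        # sharp: ~[0.93, 0.05, 0.03]
print(F.softmax(logits / 4.0, axis=-1))  # smoother with T=4: ~[0.53, 0.25, 0.22]
```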
Then distill the prediction layer:
```shell
export TEACHER_DIR=../pretrained_models/SST-2/best_model_610
python task_distill.py \
--model_type tinybert \
--student_model_name_or_path tmp/$TASK_NAME/best_inter_model \
--task_name $TASK_NAME \
--max_seq_length 64 \
--batch_size 32 \
--T 1 \
--teacher_model_type bert \
--teacher_path $TEACHER_DIR \
--learning_rate 3e-5 \
--num_train_epochs 3 \
--logging_steps 10 \
--save_steps 10 \
--output_dir ./tmp/$TASK_NAME/ \
--distill_config ./distill_stage2.yaml \
--device gpu
```
The parameters have the same meanings as described above.
### Hyperparameters used in the experiments
| | SST-2 | QQP | MRPC | CoLA | RTE | MNLI | QNLI |
| -------------------------------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- |
| batch_size | 32 | 32 | 32 | 32 | 32 | 32 | 32 |
| max_seq_length | 64 | 128 | 128 | 64 | 128 | 128 | 128 |
| max_epochs_of_intermediate_layer | 20 | 10 | 20 | 50 | 20 | 10 | 10 |
| max_epochs_of_prediction_layer | 3 | 3 | 3 | 3 | 3 | 3 | 3 |
| learning_rate(inter/pred) | 5e-5/3e-5 | 5e-5/3e-5 | 5e-5/3e-5 | 5e-5/3e-5 | 5e-5/3e-5 | 5e-5/3e-5 | 5e-5/3e-5 |
## Distillation Results
The experiments start from the 6-layer, hidden-size-768 model produced by general distillation, train on the original datasets without data augmentation, and evaluate on the dev sets. The results are as follows:
| | SST-2 | QQP(acc/f1) | MRPC(acc/f1) | CoLA | RTE | MNLI-m | MNLI-mm | QNLI |
| ----------------- | ----- | ----------- | ------------ | ----- | ----- | ------ | ------- | ----- |
| BERT-base | 93.00 | 90.58/87.35 | 88.23/91.67 | 59.56 | 73.65 | 84.42 | 84.83 | 91.78 |
| TinyBERT(6l-768d) | 93.00 | 91.13/88.20 | 88.48/91.91 | 52.64 | 72.94 | 84.57 | 84.63 | 91.36 |
## References
Jiao X, Yin Y, Shang L, et al. [TinyBERT: Distilling BERT for Natural Language Understanding](https://arxiv.org/abs/1909.10351)[J]. arXiv preprint arXiv:1909.10351v5, 2020.
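The commit also adds the distillation configs referenced by `--distill_config` above: the first block corresponds to intermediate-layer (MSE) distillation, the second to prediction-layer (CE) distillation. The YAML below is re-indented from the flattened diff, so treat the exact nesting as a reconstruction rather than the literal files.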
```yaml
- DistillConfig:
  - loss_function: MSELoss
    model_name_pairs:
    - - student_0
      - teacher_0
    weight: 1.0
    layers:
    - layers_name: ['tinybert.embeddings', 'bert.embeddings']
    - layers_name: ['tinybert.encoder.layers.0', 'bert.encoder.layers.1']
    - layers_name: ['tinybert.encoder.layers.1', 'bert.encoder.layers.3']
    - layers_name: ['tinybert.encoder.layers.2', 'bert.encoder.layers.5']
    - layers_name: ['tinybert.encoder.layers.3', 'bert.encoder.layers.7']
    - layers_name: ['tinybert.encoder.layers.4', 'bert.encoder.layers.9']
    - layers_name: ['tinybert.encoder.layers.5', 'bert.encoder.layers.11']
    - layers_name: ['tinybert.encoder.layers.0.self_attn', 'bert.encoder.layers.1.self_attn']
    - layers_name: ['tinybert.encoder.layers.1.self_attn', 'bert.encoder.layers.3.self_attn']
    - layers_name: ['tinybert.encoder.layers.2.self_attn', 'bert.encoder.layers.5.self_attn']
    - layers_name: ['tinybert.encoder.layers.3.self_attn', 'bert.encoder.layers.7.self_attn']
    - layers_name: ['tinybert.encoder.layers.4.self_attn', 'bert.encoder.layers.9.self_attn']
    - layers_name: ['tinybert.encoder.layers.5.self_attn', 'bert.encoder.layers.11.self_attn']
```

```yaml
- DistillConfig:
  - loss_function: CELoss
    model_name_pairs:
    - - student_0
      - teacher_0
    weight: 1.0
    layers:
    - layers_name: ['classifier', 'classifier']
      temperature: 1.0
```
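The diff also contains a standalone run command; the hard-coded teacher path comes from the author's environment.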
```shell
export CUDA_VISIBLE_DEVICES=0
export TASK_NAME=SST-2
export TEACHER_DIR=/root/work/Distill_PaddleSlim/PaddleNLP/examples/model_compression/tinybert/best_model_610
python3.7 task_distill.py \
--model_type tinybert \
--student_model_name_or_path tinybert-6l-768d-v2 \
--task_name $TASK_NAME \
--intermediate_distill \
--max_seq_length 64 \
--batch_size 32 \
--T 1 \
--teacher_model_type bert \
--teacher_path $TEACHER_DIR \
--learning_rate 5e-5 \
--num_train_epochs 20 \
--logging_steps 10 \
--save_steps 10 \
--output_dir ./tmp/$TASK_NAME/ \
--device gpu
```
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import sys
import random
import time
import math
from functools import partial
import numpy as np
import paddle
from paddle.io import DataLoader
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.metric import Accuracy
from paddlenlp.datasets import load_dataset
from paddlenlp.data import Stack, Tuple, Pad, Dict
from paddlenlp.data.sampler import SamplerHelper
from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman
import paddlenlp.transformers as T
from paddleslim import Distill
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
METRIC_CLASSES = {
"cola": Mcc,
"sst-2": Accuracy,
"mrpc": AccuracyAndF1,
"sts-b": PearsonAndSpearman,
"qqp": AccuracyAndF1,
"mnli": Accuracy,
"qnli": Accuracy,
"rte": Accuracy,
}
MODEL_CLASSES = {
"bert": (T.BertForSequenceClassification, T.BertTokenizer),
"tinybert": (T.TinyBertForSequenceClassification, T.TinyBertTokenizer),
}
def parse_args():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--task_name",
default=None,
type=str,
required=True,
help="The name of the task to train selected in the list: " +
", ".join(METRIC_CLASSES.keys()), )
parser.add_argument(
"--model_type",
default="tinybert",
type=str,
required=True,
help="Model type selected in the list: " +
", ".join(MODEL_CLASSES.keys()), )
parser.add_argument(
"--teacher_model_type",
default="bert",
type=str,
required=True,
help="Model type selected in the list: " +
", ".join(MODEL_CLASSES.keys()), )
parser.add_argument(
"--student_model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: "
+ ", ".join(
sum([
list(classes[-1].pretrained_init_configuration.keys())
for classes in MODEL_CLASSES.values()
], [])), )
parser.add_argument(
"--distill_config",
default=None,
type=str,
help="distill config file path")
parser.add_argument(
"--teacher_path",
default=None,
type=str,
required=True,
help="Path to pre-trained model.")
parser.add_argument(
"--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--glue_dir",
default="/root/.paddlenlp/datasets/Glue/",
type=str,
required=False,
help="The Glue directory.", )
parser.add_argument(
"--max_seq_length",
default=128,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.", )
parser.add_argument(
"--learning_rate",
default=1e-4,
type=float,
help="The initial learning rate for Adam.")
parser.add_argument(
"--num_train_epochs",
default=3,
type=int,
help="Total number of training epochs to perform.", )
parser.add_argument(
"--logging_steps",
type=int,
default=100,
help="Log every X updates steps.")
parser.add_argument(
"--save_steps",
type=int,
default=100,
help="Save checkpoint every X updates steps.")
parser.add_argument(
"--batch_size",
default=32,
type=int,
help="Batch size per GPU/CPU for training.", )
parser.add_argument(
"--T",
default=1,
type=int,
help="Temperature for softmax", )
parser.add_argument(
"--use_aug",
action="store_true",
help="Whether to use augmentation data to train.", )
parser.add_argument(
"--intermediate_distill",
action="store_true",
help="Whether distilling intermediate layers. If False, it means prediction layer distillation.",
)
parser.add_argument(
"--weight_decay",
default=0.0,
type=float,
help="Weight decay if we apply some.")
parser.add_argument(
"--warmup_steps",
default=0,
type=int,
help="Linear warmup over warmup_steps. If > 0: Override warmup_proportion"
)
parser.add_argument(
"--warmup_proportion",
default=0.1,
type=float,
help="Linear warmup proportion over total steps.")
parser.add_argument(
"--adam_epsilon",
default=1e-6,
type=float,
help="Epsilon for Adam optimizer.")
parser.add_argument(
"--max_steps",
default=-1,
type=int,
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
)
parser.add_argument(
"--seed", default=42, type=int, help="random seed for initialization")
parser.add_argument(
"--device",
default="gpu",
type=str,
help="The device to select to train the model, is must be cpu/gpu/xpu.")
args = parser.parse_args()
return args
def set_seed(args):
# Use the same data seed(for data shuffle) for all procs to guarantee data
# consistency after sharding.
random.seed(args.seed)
np.random.seed(args.seed)
# Maybe different op seeds(for dropout) for different procs is better. By:
# `paddle.seed(args.seed + paddle.distributed.get_rank())`
paddle.seed(args.seed)
@paddle.no_grad()
def evaluate(model, metric, data_loader):
model.eval()
metric.reset()
for batch in data_loader:
input_ids, segment_ids, labels = batch
logits = model(input_ids, segment_ids)
correct = metric.compute(logits, labels)
metric.update(correct)
res = metric.accumulate()
if isinstance(metric, AccuracyAndF1):
print(
"acc: %s, precision: %s, recall: %s, f1: %s, acc and f1: %s, " % (
res[0],
res[1],
res[2],
res[3],
res[4], ),
end='')
elif isinstance(metric, Mcc):
print("mcc: %s, " % (res[0]), end='')
elif isinstance(metric, PearsonAndSpearman):
print(
"pearson: %s, spearman: %s, pearson and spearman: %s, " %
(res[0], res[1], res[2]),
end='')
else:
print("acc: %s, " % (res), end='')
model.train()
return res[0] if isinstance(metric, (AccuracyAndF1, Mcc,
PearsonAndSpearman)) else res
def convert_example(example,
tokenizer,
label_list,
max_seq_length=512,
is_test=False):
"""convert a glue example into necessary features"""
if not is_test:
# `label_list == None` is for regression task
label_dtype = "int64" if label_list else "float32"
# Get the label
label = example['labels']
label = np.array([label], dtype=label_dtype)
# Convert raw text to feature
if (int(is_test) + len(example)) == 2:
example = tokenizer(example['sentence'], max_seq_len=max_seq_length)
else:
example = tokenizer(
example['sentence1'],
text_pair=example['sentence2'],
max_seq_len=max_seq_length)
if not is_test:
return example['input_ids'], example['token_type_ids'], label
else:
return example['input_ids'], example['token_type_ids']
def do_train(args):
paddle.set_device(args.device)
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()
set_seed(args)
args.task_name = args.task_name.lower()
metric_class = METRIC_CLASSES[args.task_name]
args.model_type = args.model_type.lower()
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
if args.use_aug:
        aug_data_file = os.path.join(
            os.path.join(args.glue_dir, args.task_name), "train_aug.tsv")
train_ds = load_dataset(
'glue', args.task_name, data_files=aug_data_file)
else:
train_ds = load_dataset('glue', args.task_name, splits='train')
tokenizer = tokenizer_class.from_pretrained(args.student_model_name_or_path)
trans_func = partial(
convert_example,
tokenizer=tokenizer,
label_list=train_ds.label_list,
max_seq_length=args.max_seq_length)
train_ds = train_ds.map(trans_func, lazy=True)
train_batch_sampler = paddle.io.DistributedBatchSampler(
train_ds, batch_size=args.batch_size, shuffle=True)
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment
Stack(dtype="int64" if train_ds.label_list else "float32") # label
): fn(samples)
train_data_loader = DataLoader(
dataset=train_ds,
batch_sampler=train_batch_sampler,
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
if args.task_name == "mnli":
dev_ds_matched, dev_ds_mismatched = load_dataset(
'glue', args.task_name, splits=["dev_matched", "dev_mismatched"])
dev_ds_matched = dev_ds_matched.map(trans_func, lazy=True)
dev_ds_mismatched = dev_ds_mismatched.map(trans_func, lazy=True)
dev_batch_sampler_matched = paddle.io.BatchSampler(
dev_ds_matched, batch_size=args.batch_size, shuffle=False)
dev_data_loader_matched = DataLoader(
dataset=dev_ds_matched,
batch_sampler=dev_batch_sampler_matched,
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
dev_batch_sampler_mismatched = paddle.io.BatchSampler(
dev_ds_mismatched, batch_size=args.batch_size, shuffle=False)
dev_data_loader_mismatched = DataLoader(
dataset=dev_ds_mismatched,
batch_sampler=dev_batch_sampler_mismatched,
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
else:
dev_ds = load_dataset('glue', args.task_name, splits='dev')
dev_ds = dev_ds.map(trans_func, lazy=True)
dev_batch_sampler = paddle.io.BatchSampler(
dev_ds, batch_size=args.batch_size, shuffle=False)
dev_data_loader = DataLoader(
dataset=dev_ds,
batch_sampler=dev_batch_sampler,
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
num_classes = 1 if train_ds.label_list == None else len(train_ds.label_list)
student = model_class.from_pretrained(
args.student_model_name_or_path, num_classes=num_classes)
teacher_model_class, _ = MODEL_CLASSES[args.teacher_model_type]
teacher = teacher_model_class.from_pretrained(
args.teacher_path, num_classes=num_classes)
teacher.eval()
if paddle.distributed.get_world_size() > 1:
student = paddle.DataParallel(student, find_unused_parameters=True)
teacher = paddle.DataParallel(teacher, find_unused_parameters=True)
if args.max_steps > 0:
num_training_steps = args.max_steps
num_train_epochs = math.ceil(num_training_steps /
len(train_data_loader))
else:
num_training_steps = len(train_data_loader) * args.num_train_epochs
num_train_epochs = args.num_train_epochs
warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion
lr_scheduler = T.LinearDecayWithWarmup(args.learning_rate,
num_training_steps, warmup)
# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [
p.name for n, p in student.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
learning_rate=lr_scheduler,
beta1=0.9,
beta2=0.999,
epsilon=args.adam_epsilon,
parameters=student.parameters(),
weight_decay=args.weight_decay,
apply_decay_param_fun=lambda x: x in decay_params)
metric = metric_class()
pad_token_id = 0
global_step = 0
tic_train = time.time()
best_res = 0.0
assert os.path.exists(
args.distill_config), "distill file {} not exist.".format(
args.distill_config)
distill_model = Distill(
args.distill_config, student_models=[student],
teacher_models=[teacher])
for epoch in range(num_train_epochs):
for step, batch in enumerate(train_data_loader):
global_step += 1
input_ids, segment_ids, labels = batch
loss, _, _ = distill_model(input_ids, segment_ids)
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.clear_grad()
if global_step % args.logging_steps == 0:
print(
"global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s"
% (global_step, num_training_steps, epoch, step,
paddle.distributed.get_rank(), loss, optimizer.get_lr(),
args.logging_steps / (time.time() - tic_train)))
tic_train = time.time()
if global_step % args.save_steps == 0 or global_step == num_training_steps:
tic_eval = time.time()
if args.task_name == "mnli":
res = evaluate(student, metric, dev_data_loader_matched)
evaluate(student, metric, dev_data_loader_mismatched)
print("eval done total : %s s" % (time.time() - tic_eval))
else:
res = evaluate(student, metric, dev_data_loader)
print("eval done total : %s s" % (time.time() - tic_eval))
if (best_res < res and global_step < num_training_steps or
global_step == num_training_steps
) and paddle.distributed.get_rank() == 0:
if global_step < num_training_steps:
output_dir = os.path.join(args.output_dir,
"distill_model_%d.pdparams" %
(global_step))
else:
output_dir = os.path.join(
args.output_dir, "distill_model_final.pdparams")
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Need better way to get inner model of DataParallel
model_to_save = student._layers if isinstance(
student, paddle.DataParallel) else student
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
best_res = res
if global_step >= num_training_steps:
return
def print_arguments(args):
"""print arguments"""
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).items()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
if __name__ == "__main__":
args = parse_args()
print_arguments(args)
do_train(args)
......@@ -14,7 +14,9 @@
from . import distill
from .distill import *
from .distill_helpers import *
__all__ = []
__all__ += distill.__all__
__all__ += distill_helpers.__all__
......@@ -17,207 +17,200 @@ import collections
from collections import namedtuple
import paddle.nn as nn
from . import losses
from .losses.basic_loss import BASIC_LOSS
from .distill_helpers import yaml2config
__all__ = ['Distill', 'AdaptorBase']
__all__ = ['Distill']
class LayerConfig:
""" The key of config can be set"""
def __init__(self,
s_feature_idx,
t_feature_idx,
feature_type,
model_name_pairs,
layers_name,
loss_function,
weight=1.0,
align=False,
align_shape=None):
self.s_feature_idx = s_feature_idx
self.t_feature_idx = t_feature_idx
self.feature_type = feature_type
if loss_function in ['l1', 'l2', 'smooth_l1']:
self.loss_function = 'DistillationDistanceLoss'
elif loss_function in ['dml']:
self.loss_function = 'DistillationDMLLoss'
elif loss_function in ['rkl']:
self.loss_function = 'DistillationRKDLoss'
elif hasattr(losses, loss_function):
self.loss_function = loss_function
else:
raise NotImplementedError("loss function is not support!!!")
temperature=1.0,
align_params=None,
**loss_params):
self.model_name_pairs = model_name_pairs
self.layers_name = layers_name
if loss_function not in BASIC_LOSS.module_dict:
raise NotImplementedError("loss function {} is not support. "
"Support loss including {}".format(
loss_function,
BASIC_LOSS.module_dict.keys()))
self.loss_function = loss_function
self.weight = weight
self.align = align
self.align_shape = align_shape
class AdaptorBase:
def __init__(self, model):
self.model = model
self.add_tensor = False
def _get_activation(self, outs, name):
self.temperature = temperature
self.align_params = align_params
for k, v in loss_params.items():
setattr(self, k, v)
def _add_hooks(model, outs, hook_layers_name):
"""
Get output by layer name.
models(nn.Layer): model need to be add hook.
outs(dict): save the middle outputs of model according to the name.
hook_layers_name(list): name of middle layers.
"""
def _get_activation(outs, name):
### TODO: need to support get input tensor
#outs[name] = {}
def get_output_hook(layer, input, output):
#outs[name]["output"] = output
#outs[name]["input"] = input
outs[name] = output
return get_output_hook
def _add_distill_hook(self, outs, mapping_layers_name, layers_type):
"""
Get output by layer name.
outs(dict): save the middle outputs of model according to the name.
mapping_layers(list): name of middle layers.
layers_type(list): type of the middle layers to calculate distill loss.
"""
### TODO: support DP model
for idx, (n, m) in enumerate(self.model.named_sublayers()):
if n in mapping_layers_name:
midx = mapping_layers_name.index(n)
m.register_forward_post_hook(
self._get_activation(outs, layers_type[midx]))
def mapping_layers(self):
raise NotImplementedError("function mapping_layers is not implemented")
### TODO: support DP model
for idx, (n, m) in enumerate(model.named_sublayers()):
if n in hook_layers_name:
m.register_forward_post_hook(_get_activation(outs, n))
class Distill(nn.Layer):
### TODO: support list of student model and teacher model
def __init__(self, distill_configs, student_models, teacher_models,
adaptors_S, adaptors_T):
super(Distill, self).__init__()
assert student_models.training, "The student model should be eval mode."
"""
Distill API.
distill_configs(list(dict) | path): the list of distill config.
student_models(list(nn.Layer)): the list of student model, the state of student model must be training mode.
teacher_models(list(nn.Layer)): the list of teacher model, the state of student model must be evaluate mode.
return_model_outputs(bool): whether to return model output. Default: True.
"""
self._distill_configs = distill_configs
def __init__(self,
distill_configs,
student_models,
teacher_models,
return_model_outputs=True):
super(Distill, self).__init__()
if isinstance(student_models, nn.Layer):
student_models = [student_models]
if isinstance(teacher_models, nn.Layer):
teacher_models = [teacher_models]
for student_model in student_models:
assert student_model.training, "The student model should not be eval mode."
for teacher_model in teacher_models:
assert teacher_model.training is False, "The teacher model should be eval mode."
if isinstance(distill_configs, list):
self._distill_configs = distill_configs
elif os.path.exists(distill_configs):
if distill_configs.endswith(".yaml"):
self._distill_configs = yaml2config(distill_configs)
else:
raise NotImplementedError("distill config file type error!")
else:
raise NotImplementedError("distill config error!")
self._student_models = student_models
self._teacher_models = teacher_models
self._adaptors_S = adaptors_S(self._student_models)
self._adaptors_T = adaptors_T(self._teacher_models)
self._return_model_outputs = return_model_outputs
self.stu_outs_dict, self.tea_outs_dict = self._prepare_outputs()
self.configs = []
self._loss_config_list = []
for c in self._distill_configs:
self.configs.append(LayerConfig(**c).__dict__)
self._transpose_config(c)
self.distill_idx = self._get_distill_idx()
self._loss_config_list = []
for c in self.configs:
loss_config = {}
loss_config[str(c['loss_function'])] = {}
loss_config[str(c['loss_function'])]['weight'] = c['weight']
loss_config[str(c['loss_function'])]['key'] = c[
'feature_type'] + '_' + str(c['s_feature_idx']) + '_' + str(c[
't_feature_idx'])
### TODO: support list of student models and teacher_models
loss_config[str(c['loss_function'])][
'model_name_pairs'] = [['student', 'teacher']]
self._loss_config_list.append(loss_config)
self._hook_layers = self._extract_hook_position()
# use self._loss_config_list to create all loss object
self.distill_loss = losses.CombinedLoss(self._loss_config_list)
self._output_tensor_dict = self._prepare_outputs()
def parameters(self):
params = []
for s_model in self._student_models:
params.extend(s_model.parameters())
return params
def _extract_hook_position(self):
""" extrat hook position according to config"""
model_hook_layers = {}
for config in self._loss_config_list:
model_name_pairs = config['model_name_pairs']
layers_name = config['layers_name']
for model_name_pair in model_name_pairs:
for idx, model_name in enumerate(model_name_pair):
if model_name not in model_hook_layers:
model_hook_layers[model_name] = [layers_name[idx]]
else:
model_hook_layers[model_name].append(layers_name[idx])
for model_name, hook_layers in model_hook_layers.items():
model_hook_layers[model_name] = list(set(hook_layers))
return model_hook_layers
def _transpose_config(self, config):
""" Transpose config to loss needed """
global_config = {}
if 'model_name_pairs' not in config:
global_config['model_name_pairs'] = [['student_0', 'teacher_0']]
else:
if isinstance(config['model_name_pairs'][0], str):
config['model_name_pairs'] = [config['model_name_pairs']]
global_config['model_name_pairs'] = config['model_name_pairs']
config.pop('model_name_pairs')
for key in config.keys():
if key != 'layers':
global_config[key] = config[key]
for per_layer_config in config['layers']:
per_layer_config.update(global_config)
self._loss_config_list.append(
LayerConfig(**per_layer_config).__dict__)
def _prepare_outputs(self):
"""
Add hook to get the output tensor of target layer.
Returns:
stu_outs_dict(dict): the name and tensor for the student model,
such as {'hidden_0': tensor_0, ..}
            tea_outs_dict(dict): the name and tensor for the teacher model,
such as {'hidden_0': tensor_0, ..}
"""
stu_outs_dict = collections.OrderedDict()
tea_outs_dict = collections.OrderedDict()
stu_outs_dict = self._prepare_hook(self._adaptors_S, stu_outs_dict)
tea_outs_dict = self._prepare_hook(self._adaptors_T, tea_outs_dict)
return stu_outs_dict, tea_outs_dict
def _prepare_hook(self, adaptors, outs_dict):
outputs_tensor = {}
for idx, m in enumerate(self._student_models):
hook_layers = self._hook_layers['student_{}'.format(idx)]
stu_outs = collections.OrderedDict()
outputs_tensor['student_{}'.format(idx)] = self._prepare_hook(
m, hook_layers, stu_outs)
for idx, m in enumerate(self._teacher_models):
hook_layers = self._hook_layers['teacher_{}'.format(idx)]
tea_outs = collections.OrderedDict()
outputs_tensor['teacher_{}'.format(idx)] = self._prepare_hook(
m, hook_layers, tea_outs)
return outputs_tensor
def _prepare_hook(self, model, hook_layers, outs_dict):
"""
Add hook.
"""
mapping_layers = adaptors.mapping_layers()
for layer_type, layer in mapping_layers.items():
for layer in hook_layers:
if isinstance(layer, str):
adaptors._add_distill_hook(outs_dict, [layer], [layer_type])
_add_hooks(model, outs_dict, layer)
return outs_dict
def _get_distill_idx(self):
"""
For each feature_type, get the feature index in the student and teacher models.
Returns:
distill_idx(dict): the feature index for each feature_type,
such as {'hidden': [[0, 0], [1, 1]], 'out': [[0, 0]]}
"""
distill_idx = {}
for config in self._distill_configs:
if config['feature_type'] not in distill_idx:
distill_idx[config['feature_type']] = [[
int(config['s_feature_idx']), int(config['t_feature_idx'])
]]
else:
distill_idx[config['feature_type']].append([
int(config['s_feature_idx']), int(config['t_feature_idx'])
])
return distill_idx
def forward(self, *inputs, **kwargs):
stu_batch_outs = self._student_models.forward(*inputs, **kwargs)
tea_batch_outs = self._teacher_models.forward(*inputs, **kwargs)
if not self._teacher_models.training:
tea_batch_outs = [i.detach() for i in tea_batch_outs]
# get all target tensor
if self._adaptors_S.add_tensor == False:
self._adaptors_S.add_tensor = True
if self._adaptors_T.add_tensor == False:
self._adaptors_T.add_tensor = True
self.stu_outs_dict = self._get_model_intermediate_output(
self._adaptors_S, self.stu_outs_dict)
self.tea_outs_dict = self._get_model_intermediate_output(
self._adaptors_T, self.tea_outs_dict)
distill_inputs = self._process_outputs()
students_batch_outs = []
teachers_batch_outs = []
for idx, student_model in enumerate(self._student_models):
stu_batch_outs = student_model.forward(*inputs, **kwargs)
students_batch_outs.append(stu_batch_outs)
for idx, teacher_model in enumerate(self._teacher_models):
tea_batch_outs = teacher_model.forward(*inputs, **kwargs)
if not teacher_model.training:
tea_batch_outs = [i.detach() for i in tea_batch_outs]
teachers_batch_outs.extend(tea_batch_outs)
if len(self._student_models) == 1:
students_batch_outs = students_batch_outs[0]
if len(self._teacher_models) == 1:
teachers_batch_outs = teachers_batch_outs[0]
### batch is None just for now
distill_outputs = self.distill_loss(distill_inputs, None)
distill_outputs = self.distill_loss(self._output_tensor_dict, None)
distill_loss = distill_outputs['loss']
return stu_batch_outs, tea_batch_outs, distill_loss
def _get_model_intermediate_output(self, adaptors, outs_dict):
"""
Use the adaptor get the target tensor.
Returns:
outs_dict(dict): the name and tensor for the target model,
such as {'hidden_0': tensor_0, ..}
"""
mapping_layers = adaptors.mapping_layers()
for layer_type, layer in mapping_layers.items():
if isinstance(layer, str):
continue
outs_dict[layer_type] = layer
return outs_dict
def _process_outputs(self):
"""
Process the target tensor to adapt for loss.
"""
### TODO: support list of student models and teacher_models
final_distill_dict = {
"student": collections.OrderedDict(),
"teacher": collections.OrderedDict()
}
for feature_type, dist_idx in self.distill_idx.items():
for idx, idx_list in enumerate(dist_idx):
sidx, tidx = idx_list[0], idx_list[1]
stu_out = self.stu_outs_dict[feature_type + '_' + str(sidx)]
tea_out = self.tea_outs_dict[feature_type + '_' + str(tidx)]
if not self._student_models.training:
stu_out = stu_out.detach()
if not self._teacher_models.training:
tea_out = tea_out.detach()
name_str = feature_type + '_' + str(sidx) + '_' + str(tidx)
final_distill_dict['student'][name_str] = stu_out
final_distill_dict['teacher'][name_str] = tea_out
return final_distill_dict
if self._return_model_outputs:
return distill_loss, students_batch_outs, teachers_batch_outs
else:
return distill_loss
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
__all__ = ['config2yaml']
def yaml2config(yaml_path):
"""
convert yaml to dict config.
"""
final_configs = []
f = open(yaml_path, 'r')
origin_configs = yaml.load(f, Loader=yaml.FullLoader)
f.close()
for configs in origin_configs:
configs = configs['DistillConfig']
final_configs.extend(configs)
return final_configs
def config2yaml(configs, yaml_path):
"""
convert dict config to yaml.
"""
final_yaml = dict()
final_yaml['DistillConfig'] = configs
f = open(yaml_path, "w")
yaml.dump([final_yaml], f)
f.close()
......@@ -19,18 +19,7 @@ import paddle.nn as nn
from . import basic_loss
from . import distillation_loss
from .basic_loss import L1Loss
from .basic_loss import L2Loss
from .basic_loss import SmoothL1Loss
from .basic_loss import CELoss
from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss
from .basic_loss import RKdAngle, RkdDistance
from .distillation_loss import DistillationDistanceLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationRKDLoss
from .distillation_loss import SegPairWiseLoss, SegChannelwiseLoss
from .distillation_loss import DistillationLoss
class CombinedLoss(nn.Layer):
......@@ -40,13 +29,12 @@ class CombinedLoss(nn.Layer):
loss_config_list: a config list used to build loss function. A demo is as follows,
which is used to calculate dml loss between Student output and
Teacher output. Parameter weight is needed for the loss weight.
- DistillationDMLLoss:
{ loss_function: DMLLoss
weight: 1.0
act: "softmax"
model_name_pairs:
- ["Student", "Teacher"]
Another example is {'DistillationDistanceLoss': {'weight': 1.0,
'key': 'hidden_0_0', 'model_name_pairs': [['student', 'teacher']]}
model_name_pairs:["student_0", "teacher_0"]}
Another example is {loss_function: "MSELoss", 'weight': 1.0,
'layers_name': ['conv0', 'conv0'], 'model_name_pairs': [['student', 'teacher']]}
"""
def __init__(self, loss_config_list=None):
......@@ -56,18 +44,14 @@ class CombinedLoss(nn.Layer):
self.loss_weight = []
assert isinstance(loss_config_list, list), (
'operator config should be a list')
supported_loss_list = basic_loss.__all__ + distillation_loss.__all__
for config in loss_config_list:
assert isinstance(config,
dict) and len(config) == 1, "yaml format error"
name = list(config)[0]
assert name in supported_loss_list, \
"loss name must be in {} but got: {}".format(name, supported_loss_list)
param = config[name]
assert "weight" in param, "weight must be in param, but param just contains {}".format(
param.keys())
self.loss_weight.append(param.pop("weight"))
self.loss_func.append(eval(name)(**param))
assert isinstance(
config, dict), "config must be a dict, but now is {}".format(
type(config))
assert "weight" in config, "weight must be in param, but param just contains {}".format(
config.keys())
self.loss_weight.append(config.pop("weight"))
self.loss_func.append(DistillationLoss(**config))
def forward(self, input, batch, **kargs):
loss_dict = {}
......@@ -82,6 +66,7 @@ class CombinedLoss(nn.Layer):
for key in loss
}
loss_dict.update(loss)
if loss_dict == {}:
loss_dict["loss"] = paddle.to_tensor(0.)
else:
......
......@@ -20,11 +20,13 @@ from paddle.nn import L1Loss
from paddle.nn import MSELoss as L2Loss
from paddle.nn import SmoothL1Loss
__all__ = [
"CELoss", "DMLLoss", "DistanceLoss", "RKdAngle", "RkdDistance", "KLLoss"
]
from ....core import Registry
__all__ = ["BASIC_LOSS"]
BASIC_LOSS = Registry("basicloss")
@BASIC_LOSS.register
class CELoss(nn.Layer):
"""
CELoss: cross entropy loss
......@@ -78,6 +80,7 @@ class CELoss(nn.Layer):
return loss
@BASIC_LOSS.register
class DMLLoss(nn.Layer):
"""
DMLLoss
......@@ -110,6 +113,7 @@ class DMLLoss(nn.Layer):
return loss
@BASIC_LOSS.register
class KLLoss(nn.Layer):
"""
KLLoss.
......@@ -153,6 +157,7 @@ class KLLoss(nn.Layer):
return loss
@BASIC_LOSS.register
class DistanceLoss(nn.Layer):
"""
DistanceLoss
......@@ -191,6 +196,7 @@ def pdist(e, squared=False, eps=1e-12):
return res
@BASIC_LOSS.register
class RKdAngle(nn.Layer):
"""
RKdAngle loss, see https://arxiv.org/abs/1904.05068
......@@ -218,6 +224,7 @@ class RKdAngle(nn.Layer):
return loss
@BASIC_LOSS.register
class RkdDistance(nn.Layer):
"""
RkdDistance loss, see https://arxiv.org/abs/1904.05068
......@@ -244,3 +251,50 @@ class RkdDistance(nn.Layer):
loss = F.smooth_l1_loss(d, t_d, reduction="mean")
return loss
@BASIC_LOSS.register
class MSELoss(DistanceLoss):
"""
MSELoss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/MSELoss_cn.html#mseloss
"""
def __init__(self, **kargs):
super().__init__(mode='l2', **kargs)
@BASIC_LOSS.register
class L1Loss(DistanceLoss):
"""
L1loss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/L1Loss_cn.html#l1loss
"""
def __init__(self, **kargs):
super().__init__(mode='l1', **kargs)
@BASIC_LOSS.register
class SmoothL1Loss(DistanceLoss):
"""
SmoothL1Loss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/SmoothL1Loss_cn.html#smoothl1loss
"""
def __init__(self, **kargs):
super().__init__(mode='smooth_l1', **kargs)
@BASIC_LOSS.register
class RKDLoss(nn.Layer):
"""
RKDLoss
"""
def __init__(self, eps=1e-12):
super().__init__()
self.rkd_angle_loss_func = RKdAngle()
self.rkd_dist_func = RkdDistance(eps=eps)
def forward(self, student, teacher):
angle_loss = self.rkd_angle_loss_func(student, teacher)
dist_loss = self.rkd_dist_func(student, teacher)
return angle_loss + dist_loss
......@@ -15,210 +15,54 @@
import paddle
import paddle.nn as nn
from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss
from .basic_loss import RkdDistance
from .basic_loss import RKdAngle
from .basic_loss import KLLoss
from .basic_loss import BASIC_LOSS
__all__ = [
"DistillationDMLLoss",
"DistillationDistanceLoss",
"DistillationRKDLoss",
"SegPairWiseLoss",
"SegChannelwiseLoss",
]
__all__ = ["DistillationLoss"]
class DistillationDMLLoss(DMLLoss):
class DistillationLoss(nn.Layer):
"""
DistillationDMLLoss
DistillationLoss
Args:
model_name_pairs(list | tuple): model name pairs to extract submodel output.
act(string | None): activation function used to build dml loss.
axis(int): axis used to build activation function.
key(string | None): key of the tensor used to calculate loss if the submodel
output type is dict.
name(string): loss name.
"""
def __init__(self, model_name_pairs=[], act=None, key=None,
name="loss_dml"):
super().__init__(act=act)
assert isinstance(model_name_pairs, list)
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
idx)] = super().forward(out1, out2)
return loss_dict
class DistillationDistanceLoss(DistanceLoss):
"""
DistillationDistanceLoss
Args:
mode: loss mode
model_name_pairs(list | tuple): model name pairs to extract submodel output.
such as [['student', 'teacher']]
key(string | None): key of the tensor used to calculate loss if the submodel.
such as 'hidden_0_0'
name(string): loss name.
kargs(dict): used to build corresponding loss function.
layers_name(list(string)): keys of the tensor used to calculate loss if the submodel.
loss_function(string): the name of loss function.
temperature(float): the temperature to compute distill loss.
"""
def __init__(self,
mode="l2",
model_name_pairs=[],
key=None,
name="loss_distance",
**kargs):
super().__init__(mode=mode, **kargs)
assert isinstance(model_name_pairs, list)
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name + "_" + mode
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
loss = super().forward(out1, out2)
loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
idx)] = loss
return loss_dict
class DistillationRKDLoss(nn.Layer):
"""
DistillationRKDLoss
Args:
model_name_pairs(list | tuple): model name pairs to extract submodel output.
key(string | None): key of the tensor used to calculate loss if the submodel.
eps(float): epsilon for the pdist function for RkdDistance loss.
name(string): loss name.
"""
def __init__(self,
model_name_pairs=[],
key=None,
eps=1e-12,
name="loss_rkd"):
layers_name=None,
loss_function=None,
temperature=1.0,
**params):
super().__init__()
self.model_name_pairs = model_name_pairs
self.key = key
self.layers_name = layers_name
self.loss_function = loss_function
self.temperature = temperature
self.align_params = params.pop(
'align_params') if 'align_params' in params else None
if self.align_params is not None:
for attr, value in self.align_params.items():
setattr(self, attr, value)
self.rkd_angle_loss_func = RKdAngle()
self.rkd_dist_func = RkdDistance(eps=eps)
self.name = name
self.loss_func = BASIC_LOSS.get(loss_function)(**params)
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
loss_dict["{}_{}_{}_angle_{}".format(self.name, pair[0], pair[
1], idx)] = self.rkd_angle_loss_func(out1, out2)
loss_dict["{}_{}_{}_dist_{}".format(self.name, pair[0], pair[
1], idx)] = self.rkd_dist_func(out1, out2)
return loss_dict
class SegPairWiseLoss(DistanceLoss):
"""
Segmentation pairwise loss, see https://arxiv.org/pdf/1903.04197.pdf
Args:
model_name_pairs(list | tuple): model name pairs to extract submodel output.
key(string): key of the tensor used to calculate loss if the submodel
output type is dict.
mode(string, optional): loss mode. It supports l1, l2 and smooth_l1. Default: l2.
reduction(string, optional): the reduction params for F.kl_div. Default: mean.
name(string, optional): loss name. Default: seg_pair_wise_loss.
"""
def __init__(self,
model_name_pairs=[],
key=None,
mode="l2",
reduction="mean",
name="seg_pair_wise_loss"):
super().__init__(mode=mode, reduction=reduction)
assert isinstance(model_name_pairs, list)
assert key is not None
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name
self.pool1 = nn.AdaptiveAvgPool2D(output_size=[2, 2])
self.pool2 = nn.AdaptiveAvgPool2D(output_size=[2, 2])
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]][self.key]
out2 = predicts[pair[1]][self.key]
pool1 = self.pool1(out1)
pool2 = self.pool2(out2)
loss_name = "{}_{}_{}_{}".format(self.name, pair[0], pair[1], idx)
loss_dict[loss_name] = super().forward(pool1, pool2)
return loss_dict
class SegChannelwiseLoss(KLLoss):
"""
Segmentation channel wise loss, see `Channel-wise Distillation for Semantic Segmentation`.
Args:
model_name_pairs(list | tuple): model name pairs to extract submodel output.
key(string): key of the tensor used to calculate loss if the submodel
output type is dict.
act(string, optional): activation function used for the input and label tensor.
Default: softmax.
axis(int, optional): the axis for the act. Default: -1.
reduction(str, optional): the reduction params for F.kl_div. Default: mean.
name(string, optional): loss name. Default: seg_ch_wise_loss.
"""
def __init__(self,
model_name_pairs=[],
key=None,
act='softmax',
axis=-1,
reduction="mean",
name="seg_ch_wise_loss"):
super().__init__(act, axis, reduction)
assert isinstance(model_name_pairs, list)
assert key is not None
self.model_name_pairs = model_name_pairs
self.key = key
self.name = name
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]][self.key]
out2 = predicts[pair[1]][self.key]
loss_name = "{}_{}_{}_{}".format(self.name, pair[0], pair[1], idx)
loss_dict[loss_name] = super().forward(out1, out2)
if self.layers_name != None:
assert len(self.layers_name
) == 2, "length of layers_name must be equal to 2."
out1 = out1[self.layers_name[0]]
out2 = out2[self.layers_name[1]]
if self.temperature != 1.0:
out1 = out1 / self.temperature
out2 = out2 / self.temperature
loss_dict["{}_{}_{}_{}_{}".format(self.loss_function, pair[0], pair[
1], self.layers_name[0] if self.layers_name != None else "0", \
self.layers_name[1] if self.layers_name != None else "0")] = self.loss_func(out1, out2)
return loss_dict
......@@ -7,7 +7,7 @@ import paddle
import paddle.nn as nn
from paddle.vision.models import MobileNetV1
import paddle.vision.transforms as T
from paddleslim.dygraph.dist import Distill, AdaptorBase
from paddleslim.dygraph.dist import Distill, config2yaml
from paddleslim.common.log_helper import get_logger
_logger = get_logger(
......@@ -19,42 +19,30 @@ class TestImperativeDistill(unittest.TestCase):
self.s_model, self.t_model = self.prepare_model()
self.t_model.eval()
self.distill_configs = self.prepare_config()
self.adaptor = self.prepare_adaptor()
def prepare_model(self):
return MobileNetV1(), MobileNetV1()
def prepare_config(self):
distill_configs = [{
's_feature_idx': 0,
't_feature_idx': 0,
'feature_type': 'hidden',
'loss_function': 'l2'
'loss_function': 'MSELoss',
'layers': [
{
"layers_name": ["conv1", "conv1"]
},
{
"layers_name": ["conv2_2", "conv2_2"]
},
]
}, {
's_feature_idx': 1,
't_feature_idx': 1,
'feature_type': 'hidden',
'loss_function': 'l2'
}, {
's_feature_idx': 0,
't_feature_idx': 0,
'feature_type': 'logits',
'loss_function': 'l2'
'loss_function': 'CELoss',
'temperature': 1.0,
'layers': [{
"layers_name": ["fc", "fc"]
}, ]
}]
return distill_configs
def prepare_adaptor(self):
class Adaptor(AdaptorBase):
def mapping_layers(self):
mapping_layers = {}
mapping_layers['hidden_0'] = 'conv1'
mapping_layers['hidden_1'] = 'conv2_2'
mapping_layers['hidden_2'] = 'conv3_2'
mapping_layers['logits_0'] = 'fc'
return mapping_layers
return Adaptor
def test_distill(self):
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
......@@ -97,7 +85,7 @@ class TestImperativeDistill(unittest.TestCase):
for batch_id, data in enumerate(train_reader):
img = paddle.to_tensor(data[0])
label = paddle.to_tensor(data[1])
student_out, teacher_out, distill_loss = model(img)
distill_loss, student_out, teacher_out = model(img)
loss = paddle.nn.functional.loss.cross_entropy(student_out,
label)
avg_loss = paddle.mean(loss)
......@@ -112,7 +100,7 @@ class TestImperativeDistill(unittest.TestCase):
self.s_model.train()
distill_model = Distill(self.distill_configs, self.s_model,
self.t_model, self.adaptor, self.adaptor)
self.t_model)
train(distill_model)
......@@ -136,31 +124,26 @@ class TestImperativeDistillCase1(TestImperativeDistill):
return Model(), Model()
def prepare_adaptor(self):
class Adaptor(AdaptorBase):
def mapping_layers(self):
mapping_layers = {}
mapping_layers['hidden_1'] = 'conv2'
if self.add_tensor:
mapping_layers['hidden_0'] = self.model.conv1_out
mapping_layers['hidden_2'] = self.model.conv3_out
return mapping_layers
return Adaptor
def prepare_config(self):
distill_configs = [{
's_feature_idx': 0,
't_feature_idx': 0,
'feature_type': 'hidden',
'loss_function': 'l2'
'loss_function': 'MSELoss',
'layers': [
{
"layers_name": ["conv1", "conv1"]
},
{
"layers_name": ["conv2", "conv3"]
},
]
}, {
's_feature_idx': 1,
't_feature_idx': 2,
'feature_type': 'hidden',
'loss_function': 'l2'
'loss_function': 'CELoss',
'temperature': 1.0,
'layers': [{
"layers_name": ["fc", "fc"]
}, ]
}]
return distill_configs
config2yaml(distill_configs, 'test.yaml')
return './test.yaml'
if __name__ == '__main__':
......
......@@ -24,18 +24,14 @@ import paddle.nn.functional as F
from paddleslim.dygraph.dist.losses import CombinedLoss
# basic loss
from paddleslim.dygraph.dist.losses import DistanceLoss
from paddleslim.dygraph.dist.losses import CELoss
from paddleslim.dygraph.dist.losses import DMLLoss
from paddleslim.dygraph.dist.losses import RkdDistance
from paddleslim.dygraph.dist.losses import RKdAngle
from paddleslim.dygraph.dist.losses.basic_loss import DistanceLoss
from paddleslim.dygraph.dist.losses.basic_loss import CELoss
from paddleslim.dygraph.dist.losses.basic_loss import DMLLoss
from paddleslim.dygraph.dist.losses.basic_loss import RkdDistance
from paddleslim.dygraph.dist.losses.basic_loss import RKdAngle
# distillation loss
from paddleslim.dygraph.dist.losses import DistillationDistanceLoss
from paddleslim.dygraph.dist.losses import DistillationRKDLoss
from paddleslim.dygraph.dist.losses import DistillationDMLLoss
from paddleslim.dygraph.dist.losses import SegPairWiseLoss
from paddleslim.dygraph.dist.losses import SegChannelwiseLoss
from paddleslim.dygraph.dist.losses import DistillationLoss
import numpy as np
......@@ -70,14 +66,13 @@ class TestDistanceLoss(unittest.TestCase):
out = np.sum(diff)
return out
def dist_np_distance_loss(
self,
predicts,
mode="l2",
reduction="none",
model_name_pairs=(["", ""]),
key=None,
name="loss_distance", ):
def dist_np_distance_loss(self,
predicts,
loss_function=None,
mode="l2",
reduction="none",
model_name_pairs=(["", ""]),
key=None):
loss_dict = dict()
for idx, pair in enumerate(model_name_pairs):
out1 = predicts[pair[0]]
......@@ -85,10 +80,12 @@ class TestDistanceLoss(unittest.TestCase):
if key is not None:
out1 = out1[key]
out2 = out2[key]
else:
key = 0
loss = self.np_distance_loss(
out1, out2, mode=mode, reduction=reduction)
loss_dict["{}_{}_{}_{}_{}".format(name, mode, pair[0], pair[1],
idx)] = loss
loss_dict["{}_{}_{}_{}_{}".format(
str(loss_function), pair[0], pair[1], key, key)] = loss
return loss_dict
......@@ -120,7 +117,7 @@ class TestDistanceLoss(unittest.TestCase):
"student": paddle.rand(shape),
"teacher": paddle.rand(shape),
}
self.calc_distillation_distance_loss(predicts, pairs, key=None)
self.calc_distillation_distance_loss(predicts, pairs)
predicts = {
"student": {
......@@ -143,13 +140,15 @@ class TestDistanceLoss(unittest.TestCase):
paddle.set_device(device)
for reduction in reductions:
for mode in modes:
loss_func = DistillationDistanceLoss(
loss_func = DistillationLoss(
mode=mode,
loss_function='DistanceLoss',
model_name_pairs=pairs,
key=key,
layers_name=[key, key] if key != None else None,
reduction=reduction)
np_result_dict = self.dist_np_distance_loss(
predicts,
loss_function='DistanceLoss',
mode=mode,
reduction=reduction,
model_name_pairs=pairs,
......@@ -358,12 +357,11 @@ class TestDMLLoss(unittest.TestCase):
np_loss = self.np_dml_loss(x, target)
self.assertTrue(np.allclose(np_loss, pd_loss))
def dist_np_dml_loss(
self,
predicts,
model_name_pairs=(["", ""]),
key=None,
name="loss_dml", ):
def dist_np_dml_loss(self,
predicts,
loss_function=None,
model_name_pairs=(["", ""]),
key=None):
loss_dict = dict()
for idx, pair in enumerate(model_name_pairs):
out1 = predicts[pair[0]]
......@@ -371,8 +369,11 @@ class TestDMLLoss(unittest.TestCase):
if key is not None:
out1 = out1[key]
out2 = out2[key]
loss_dict["{}_{}_{}_{}".format(name, pair[0], pair[1],
idx)] = self.np_dml_loss(out1, out2)
else:
key = 0
loss_dict["{}_{}_{}_{}_{}".format(
str(loss_function), pair[0], pair[1], key,
key)] = self.np_dml_loss(out1, out2)
return loss_dict
def calc_distillation_dml_loss(self, predicts, pairs, key=None):
......@@ -382,11 +383,19 @@ class TestDMLLoss(unittest.TestCase):
for device in devices:
paddle.set_device(device)
loss_func = DistillationDMLLoss(
act="softmax", model_name_pairs=pairs, key=key)
loss_func = DistillationLoss(
act="softmax",
model_name_pairs=pairs,
loss_function='DMLLoss',
layers_name=[key, key] if key != None else None)
np_result_dict = self.dist_np_dml_loss(
predicts, model_name_pairs=pairs, key=key)
predicts,
model_name_pairs=pairs,
loss_function='DMLLoss',
key=key)
pd_result_dict = loss_func(predicts, None)
print(pd_result_dict.keys())
print(np_result_dict.keys())
for k in np_result_dict:
pd_result = pd_result_dict[k].numpy()
np_result = np_result_dict[k]
......@@ -526,7 +535,7 @@ class TestRKDLoss(unittest.TestCase):
predicts,
model_name_pairs=(["", ""]),
key=None,
name="loss_rkd", ):
name="RKDLoss", ):
loss_dict = dict()
for idx, pair in enumerate(model_name_pairs):
out1 = predicts[pair[0]]
......@@ -534,11 +543,12 @@ class TestRKDLoss(unittest.TestCase):
if key is not None:
out1 = out1[key]
out2 = out2[key]
loss_dict["{}_{}_{}_angle_{}".format(name, pair[0], pair[
1], idx)] = self.np_rkd_angle(out1, out2)
else:
key = 0
loss_dict["{}_{}_{}_{}_{}".format(name, pair[0], pair[
1], key, key)] = self.np_rkd_angle(
out1, out2) + self.np_rkd_distance(out1, out2)
loss_dict["{}_{}_{}_dist_{}".format(name, pair[0], pair[
1], idx)] = self.np_rkd_distance(out1, out2)
return loss_dict
def calc_distillation_rkd_loss(self, predicts, pairs, key=None):
......@@ -548,7 +558,10 @@ class TestRKDLoss(unittest.TestCase):
for device in devices:
paddle.set_device(device)
loss_func = DistillationRKDLoss(model_name_pairs=pairs, key=key)
loss_func = DistillationLoss(
model_name_pairs=pairs,
loss_function='RKDLoss',
layers_name=[key, key] if key != None else None)
np_result_dict = self.dist_np_rkd_loss(
predicts, model_name_pairs=pairs, key=key)
pd_result_dict = loss_func(predicts, None)
......@@ -623,13 +636,12 @@ class TestCombinedLoss(unittest.TestCase):
log_soft_target, soft_x)) / 2.0
return loss
def dist_np_dml_loss(
self,
predicts,
model_name_pairs=(["", ""]),
key=None,
act="softmax",
name="loss_dml", ):
def dist_np_dml_loss(self,
predicts,
model_name_pairs=(["", ""]),
loss_function=None,
key=None,
act="softmax"):
loss_dict = dict()
for idx, pair in enumerate(model_name_pairs):
out1 = predicts[pair[0]]
......@@ -637,20 +649,24 @@ class TestCombinedLoss(unittest.TestCase):
if key is not None:
out1 = out1[key]
out2 = out2[key]
loss_dict["{}_{}_{}_{}".format(name, pair[0], pair[1],
idx)] = self.np_dml_loss(out1, out2)
loss_dict["{}_{}_{}_{}_0".format(
str(loss_function), pair[0], pair[1], idx)] = self.np_dml_loss(
out1, out2)
return loss_dict
def np_combined_loss(self, predicts, loss_cfg_list):
# NOTE, dml is set as the list for combined loss
loss_dict = dict()
for idx, loss_func in enumerate(loss_cfg_list):
cfg = copy.deepcopy(loss_func["DistillationDMLLoss"])
cfg = copy.deepcopy(loss_func)
weight = cfg.pop("weight")
loss = self.dist_np_dml_loss(predicts, **cfg)
if isinstance(loss, np.ndarray):
loss = {"loss_{}_{}".format(str(loss), idx): loss}
loss = {
"{}_{}_{}".format(loss_func['loss_function'],
str(loss), idx): loss
}
else:
loss = {
"{}_{}".format(key, idx): loss[key] * weight
......@@ -677,12 +693,10 @@ class TestCombinedLoss(unittest.TestCase):
devices.append("gpu")
loss_cfg_list = [{
"DistillationDMLLoss": {
"weight": 1.0,
"act": "softmax",
"model_name_pairs": pairs,
"key": None
}
"loss_function": "DMLLoss",
"weight": 1.0,
"act": "softmax",
"model_name_pairs": pairs
}, ]
for device in devices:
......@@ -696,95 +710,5 @@ class TestCombinedLoss(unittest.TestCase):
self.assertTrue(np.allclose(np_result, pd_result))
class TestSegPairWiseLoss(unittest.TestCase):
def calculate_gt_loss(self, x, y):
pool_x = F.adaptive_avg_pool2d(x, [2, 2])
pool_y = F.adaptive_avg_pool2d(y, [2, 2])
loss = F.mse_loss(pool_x, pool_y)
return loss
def test_seg_pair_wise_loss(self):
shape = [1, 3, 10, 10]
x = paddle.rand(shape)
y = paddle.rand(shape)
model_name_pairs = [['student', 'teacher']]
key = 'hidden_0_0'
inputs = {
model_name_pairs[0][0]: {
key: x
},
model_name_pairs[0][1]: {
key: y
}
}
devices = ["cpu"]
if paddle.is_compiled_with_cuda():
devices.append("gpu")
for device in devices:
paddle.set_device(device)
loss_func = SegPairWiseLoss(model_name_pairs, key)
pd_loss_dict = loss_func(inputs, None)
pd_loss = pd_loss_dict['seg_pair_wise_loss_student_teacher_0']
gt_loss = self.calculate_gt_loss(x, y)
self.assertTrue(np.allclose(pd_loss.numpy(), gt_loss.numpy()))
class TestSegChannelWiseLoss(unittest.TestCase):
def init(self):
self.act_name = None
self.act_func = None
def calculate_gt_loss(self, x, y, act=None):
if act is not None:
x = act(x)
y = act(y)
x = paddle.log(x)
loss = F.kl_div(x, y)
return loss
def test_seg_pair_wise_loss(self):
self.init()
shape = [1, 3, 10, 10]
x = paddle.rand(shape)
y = paddle.rand(shape)
model_name_pairs = [['student', 'teacher']]
key = 'hidden_0_0'
inputs = {
model_name_pairs[0][0]: {
key: x
},
model_name_pairs[0][1]: {
key: y
}
}
devices = ["cpu"]
if paddle.is_compiled_with_cuda():
devices.append("gpu")
for device in devices:
paddle.set_device(device)
loss_func = SegChannelwiseLoss(model_name_pairs, key, self.act_name)
pd_loss_dict = loss_func(inputs, None)
pd_loss = pd_loss_dict['seg_ch_wise_loss_student_teacher_0']
gt_loss = self.calculate_gt_loss(x, y, self.act_func)
self.assertTrue(np.allclose(pd_loss.numpy(), gt_loss.numpy()))
class TestSegChannelWiseLoss1(TestSegChannelWiseLoss):
def init(self):
self.act_name = "softmax"
self.act_func = F.softmax
class TestSegChannelWiseLoss1(TestSegChannelWiseLoss):
def init(self):
self.act_name = "sigmoid"
self.act_func = F.sigmoid
if __name__ == '__main__':
unittest.main()