Unverified commit 36a9f6f0, authored by ceci3, committed by GitHub

update distill (#892)

* polish distill
Parent: 5ec362f2
# TinyBERT: Distilling BERT for Natural Language Understanding
The directory structure of this example is outlined below:
```
.
├── task_distill.py        # task-specific distillation script
└── README.md              # this document
```
## Introduction
The experiments in this directory follow the paper [TinyBERT: Distilling BERT for Natural Language Understanding](https://arxiv.org/abs/1909.10351).
TinyBERT's overall distillation process has two stages: general distillation first, followed by task-specific distillation on augmented data. This example performs the second stage; the student model is initialized from `tinybert-6l-768d-v2`, the general small model produced by the first stage.
<p align="center">
<img src="./imgs/tinybert.png" width="950"/><br />
TinyBERT distillation pipeline
</p>
In model distillation, the larger model (here, BERT base) is usually called the teacher, and the smaller model (here, a 6-layer BERT, referred to as TinyBERT6 below) is called the student.
Distillation is carried out by having the student minimize a set of distillation losses. In this experiment the objective has two parts: intermediate-layer losses and a prediction-layer loss. The intermediate-layer part distills the embedding layer, the output of every Transformer layer, and the attention matrices (the values before softmax) of every Transformer layer, all with mean squared error. The prediction-layer objective is the cross entropy between the student's logits and the teacher's logits.
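The sketch below illustrates the two kinds of objectives. It is a minimal, self-contained illustration only; in this example the actual losses are assembled by PaddleSlim's `Distill` API from the YAML config, and its soft cross entropy may differ in detail.

```python
import paddle
import paddle.nn.functional as F

mse = paddle.nn.MSELoss()

def intermediate_loss(stu_hidden, tea_hidden, stu_attn, tea_attn):
    # MSE over embedding / Transformer-layer outputs and over the
    # (pre-softmax) attention matrices of the mapped layer pairs.
    return mse(stu_hidden, tea_hidden) + mse(stu_attn, tea_attn)

def prediction_loss(stu_logits, tea_logits, T=1.0):
    # Soft cross entropy between temperature-scaled teacher and student logits.
    tea_prob = F.softmax(tea_logits / T, axis=-1)
    stu_log_prob = F.log_softmax(stu_logits / T, axis=-1)
    return -(tea_prob * stu_log_prob).sum(axis=-1).mean()
```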
Because the teacher has 12 layers and the student has fewer, a layer-mapping scheme is needed. The paper uses a fixed mapping: when the student has half as many layers as the teacher, the attention matrices of student layer i learn from teacher layer 2i+1 (0-indexed), and the Transformer-layer outputs follow the same mapping.
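For instance, with a 6-layer student and a 12-layer teacher, the mapping (0-indexed, matching the layer pairs listed in the distillation config later in this page) can be computed as:

```python
num_student_layers, num_teacher_layers = 6, 12
step = num_teacher_layers // num_student_layers  # 2
# Student layer i is supervised by teacher layer step * (i + 1) - 1 = 2i + 1.
layer_map = [(i, step * (i + 1) - 1) for i in range(num_student_layers)]
print(layer_map)  # [(0, 1), (1, 3), (2, 5), (3, 7), (4, 9), (5, 11)]
```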
The experiment consists of two major training runs: BERT-base is first fine-tuned to obtain the teacher model, and then the distillation training is run. Distillation itself has two steps: the intermediate layers are distilled for several epochs (10, 20, or 30 in the paper, depending on the task), and the prediction layer is then distilled for 3 epochs.
Note that, when pairing the student with a teacher model, the v2 pre-trained models `tinybert-6l-768d-v2` and `tinybert-4l-312d-v2` expose transformation matrices, mapping the student's embedding output and intermediate Transformer-layer outputs to the corresponding teacher outputs, that are separate for each layer, whereas `tinybert-6l-768d`, `tinybert-4l-312d`, `tinybert-6l-768d-zh`, and `tinybert-4l-312d-zh` share a single transformation matrix across layers.
### Install PaddleNLP and Paddle
This example compresses the BERT model from PaddleNLP and depends on both PaddleNLP and Paddle.
```shell
pip install paddlenlp
pip install paddlepaddle_gpu
```
## Data and Pre-trained Models
The experiments train on the training sets of the GLUE tasks and evaluate on the corresponding dev sets.
When you run the experiments in this directory, the datasets are downloaded automatically to `paddlenlp.utils.env.DATA_HOME`. On Linux, for example, the GLUE QQP dataset is stored under `~/.paddlenlp/datasets/Glue/QQP` by default.
For BERT fine-tuning, this example uses the pre-trained model `bert-base-uncased`. Pre-trained models are likewise downloaded automatically to `paddlenlp.utils.env.MODEL_HOME` during training; on Linux, `bert-base-uncased` is placed under `~/.paddlenlp/models/bert-base-uncased`.
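For example, the snippet below (a standalone sketch, not part of this example's scripts) triggers both downloads; `num_classes=2` is the SST-2 label count:

```python
from paddlenlp.datasets import load_dataset
from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer

# Downloaded to ~/.paddlenlp/datasets/Glue/SST-2 on first use.
train_ds = load_dataset('glue', 'sst-2', splits='train')

# Downloaded to ~/.paddlenlp/models/bert-base-uncased on first use.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_classes=2)
```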
## Distillation Procedure
### Fine-tune BERT to obtain the teacher model
First, fine-tune the pre-trained model on the actual downstream task to obtain the model to be compressed. For the fine-tuning procedure, see the [fine-tuning tutorial](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/examples/language_model/bert/README.md).
After training, save the best-performing model under `pretrained_models/$TASK_NAME/` in this project. The model directory contains `model_config.json`, `model_state.pdparams`, `tokenizer_config.json`, and `vocab.txt`.
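If you fine-tune with your own script, a directory in the expected layout can be produced with `save_pretrained`. This is only a sketch: `model` and `tokenizer` stand for the fine-tuned `BertForSequenceClassification` and its `BertTokenizer`, and the path is illustrative.

```python
import os

output_dir = './pretrained_models/SST-2/best_model_610'  # illustrative path
os.makedirs(output_dir, exist_ok=True)
model.save_pretrained(output_dir)      # writes model_config.json, model_state.pdparams
tokenizer.save_pretrained(output_dir)  # writes tokenizer_config.json, vocab.txt
```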
### Task-specific distillation of TinyBERT
First, distill the intermediate layers:
```shell
export CUDA_VISIBLE_DEVICES=0
export TASK_NAME=SST-2
export TEACHER_DIR=./pretrained_models/SST-2/best_model_610
python task_distill.py \
--model_type tinybert \
--student_model_name_or_path tinybert-6l-768d-v2 \
--task_name $TASK_NAME \
--intermediate_distill \
--max_seq_length 64 \
--batch_size 32 \
--T 1 \
--teacher_model_type bert \
--teacher_path $TEACHER_DIR \
--learning_rate 5e-5 \
--num_train_epochs 20 \
--logging_steps 10 \
--save_steps 10 \
--output_dir ./tmp/$TASK_NAME/ \
--distill_config ./distill_stage1.yaml \
--device gpu
```
The parameters are described below:
- `model_type`: student model type; currently only tinybert is supported (default).
- `student_model_name_or_path`: name of, or path to, the student model; for prediction-layer distillation, this is the directory where the student was saved after intermediate-layer distillation.
- `distill_config`: distillation configuration file.
- `max_seq_length`: maximum sequence length; longer sequences are truncated. Default: 128.
- `T`: softmax temperature, used to smooth the softmax outputs; during training it amplifies the contribution of the negative (non-target) labels (see the short example after this list). Default: 1.
- `teacher_model_type`: teacher model type; currently only bert is supported (default).
- `teacher_path`: directory of the fine-tuned teacher model.
- `output_dir`: directory where the student model is saved.
- `device`: device to run on. Default: gpu.
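As a quick illustration of the temperature, using assumed 3-class logits that are not tied to any task here:

```python
import paddle
import paddle.nn.functional as F

logits = paddle.to_tensor([4.0, 1.0, 0.5])
print(F.softmax(logits / 1.0))  # ~[0.93, 0.05, 0.03]: nearly one-hot
print(F.softmax(logits / 4.0))  # ~[0.53, 0.25, 0.22]: negative labels carry more signal
```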
Then distill the prediction layer:
```shell
export TEACHER_DIR=./pretrained_models/SST-2/best_model_610
python task_distill.py \
--model_type tinybert \
--student_model_name_or_path tmp/$TASK_NAME/best_inter_model \
--task_name $TASK_NAME \
--max_seq_length 64 \
--batch_size 32 \
--T 1 \
--teacher_model_type bert \
--teacher_path $TEACHER_DIR \
--learning_rate 3e-5 \
--num_train_epochs 3 \
--logging_steps 10 \
--save_steps 10 \
--output_dir ./tmp/$TASK_NAME/ \
--distill_config ./distill_stage2.yaml \
--device gpu
```
All parameters have the same meanings as in the intermediate-layer distillation command above.
### Hyperparameters used in the experiments
| | SST-2 | QQP | MRPC | CoLA | RTE | MNLI | QNLI |
| -------------------------------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- |
| batch_size | 32 | 32 | 32 | 32 | 32 | 32 | 32 |
| max_seq_length | 64 | 128 | 128 | 64 | 128 | 128 | 128 |
| max_epochs_of_intermediate_layer | 20 | 10 | 20 | 50 | 20 | 10 | 10 |
| max_epochs_of_prediction_layer | 3 | 3 | 3 | 3 | 3 | 3 | 3 |
| learning_rate(inter/pred) | 5e-5/3e-5 | 5e-5/3e-5 | 5e-5/3e-5 | 5e-5/3e-5 | 5e-5/3e-5 | 5e-5/3e-5 | 5e-5/3e-5 |
## Distillation Results
The experiments below start from the general-distilled TinyBERT model with 6 layers and hidden size 768, train on the original datasets without data augmentation, and evaluate on the dev sets, yielding the following results:
| | SST-2 (acc) | QQP (acc/f1) | MRPC (acc/f1) | CoLA (mcc) | RTE (acc) | MNLI-m (acc) | MNLI-mm (acc) | QNLI (acc) |
| ----------------- | ----- | ----------- | ------------ | ----- | ----- | ------ | ------- | ----- |
| BERT-base | 93.00 | 90.58/87.35 | 88.23/91.67 | 59.56 | 73.65 | 84.42 | 84.83 | 91.78 |
| TinyBERT(6l-768d) | 93.00 | 91.13/88.20 | 88.48/91.91 | 52.64 | 72.94 | 84.57 | 84.63 | 91.36 |
## References
Jiao X, Yin Y, Shang L, et al. [TinyBERT: Distilling BERT for Natural Language Understanding](https://arxiv.org/abs/1909.10351)[J]. arXiv preprint arXiv:1909.10351v5, 2020.
- DistillConfig:
  - loss_function: MSELoss
    model_name_pairs:
    - - student_0
      - teacher_0
    weight: 1.0
    layers:
    - layers_name: ['tinybert.embeddings', 'bert.embeddings']
    - layers_name: ['tinybert.encoder.layers.0', 'bert.encoder.layers.1']
    - layers_name: ['tinybert.encoder.layers.1', 'bert.encoder.layers.3']
    - layers_name: ['tinybert.encoder.layers.2', 'bert.encoder.layers.5']
    - layers_name: ['tinybert.encoder.layers.3', 'bert.encoder.layers.7']
    - layers_name: ['tinybert.encoder.layers.4', 'bert.encoder.layers.9']
    - layers_name: ['tinybert.encoder.layers.5', 'bert.encoder.layers.11']
    - layers_name: ['tinybert.encoder.layers.0.self_attn', 'bert.encoder.layers.1.self_attn']
    - layers_name: ['tinybert.encoder.layers.1.self_attn', 'bert.encoder.layers.3.self_attn']
    - layers_name: ['tinybert.encoder.layers.2.self_attn', 'bert.encoder.layers.5.self_attn']
    - layers_name: ['tinybert.encoder.layers.3.self_attn', 'bert.encoder.layers.7.self_attn']
    - layers_name: ['tinybert.encoder.layers.4.self_attn', 'bert.encoder.layers.9.self_attn']
    - layers_name: ['tinybert.encoder.layers.5.self_attn', 'bert.encoder.layers.11.self_attn']
- DistillConfig:
  - loss_function: CELoss
    model_name_pairs:
    - - student_0
      - teacher_0
    weight: 1.0
    temperature: 1.0
    layers:
    - layers_name: ['classifier', 'classifier']
export CUDA_VISIBLE_DEVICES=0
export TASK_NAME=SST-2
export TEACHER_DIR=/root/work/Distill_PaddleSlim/PaddleNLP/examples/model_compression/tinybert/best_model_610
python3.7 task_distill.py \
--model_type tinybert \
--student_model_name_or_path tinybert-6l-768d-v2 \
--task_name $TASK_NAME \
--intermediate_distill \
--max_seq_length 64 \
--batch_size 32 \
--T 1 \
--teacher_model_type bert \
--teacher_path $TEACHER_DIR \
--learning_rate 5e-5 \
--num_train_epochs 20 \
--logging_steps 10 \
--save_steps 10 \
--output_dir ./tmp/$TASK_NAME/ \
--device gpu
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import sys
import random
import time
import math
from functools import partial
import numpy as np
import paddle
from paddle.io import DataLoader
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.metric import Accuracy
from paddlenlp.datasets import load_dataset
from paddlenlp.data import Stack, Tuple, Pad, Dict
from paddlenlp.data.sampler import SamplerHelper
from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman
import paddlenlp.transformers as T
from paddleslim import Distill
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
METRIC_CLASSES = {
"cola": Mcc,
"sst-2": Accuracy,
"mrpc": AccuracyAndF1,
"sts-b": PearsonAndSpearman,
"qqp": AccuracyAndF1,
"mnli": Accuracy,
"qnli": Accuracy,
"rte": Accuracy,
}
MODEL_CLASSES = {
"bert": (T.BertForSequenceClassification, T.BertTokenizer),
"tinybert": (T.TinyBertForSequenceClassification, T.TinyBertTokenizer),
}
def parse_args():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--task_name",
default=None,
type=str,
required=True,
help="The name of the task to train selected in the list: " +
", ".join(METRIC_CLASSES.keys()), )
parser.add_argument(
"--model_type",
default="tinybert",
type=str,
required=True,
help="Model type selected in the list: " +
", ".join(MODEL_CLASSES.keys()), )
parser.add_argument(
"--teacher_model_type",
default="bert",
type=str,
required=True,
help="Model type selected in the list: " +
", ".join(MODEL_CLASSES.keys()), )
parser.add_argument(
"--student_model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: "
+ ", ".join(
sum([
list(classes[-1].pretrained_init_configuration.keys())
for classes in MODEL_CLASSES.values()
], [])), )
parser.add_argument(
"--distill_config",
default=None,
type=str,
help="distill config file path")
parser.add_argument(
"--teacher_path",
default=None,
type=str,
required=True,
help="Path to pre-trained model.")
parser.add_argument(
"--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--glue_dir",
default="/root/.paddlenlp/datasets/Glue/",
type=str,
required=False,
help="The Glue directory.", )
parser.add_argument(
"--max_seq_length",
default=128,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.", )
parser.add_argument(
"--learning_rate",
default=1e-4,
type=float,
help="The initial learning rate for Adam.")
parser.add_argument(
"--num_train_epochs",
default=3,
type=int,
help="Total number of training epochs to perform.", )
parser.add_argument(
"--logging_steps",
type=int,
default=100,
help="Log every X updates steps.")
parser.add_argument(
"--save_steps",
type=int,
default=100,
help="Save checkpoint every X updates steps.")
parser.add_argument(
"--batch_size",
default=32,
type=int,
help="Batch size per GPU/CPU for training.", )
parser.add_argument(
"--T",
default=1,
type=int,
help="Temperature for softmax", )
parser.add_argument(
"--use_aug",
action="store_true",
help="Whether to use augmentation data to train.", )
parser.add_argument(
"--intermediate_distill",
action="store_true",
help="Whether distilling intermediate layers. If False, it means prediction layer distillation.",
)
parser.add_argument(
"--weight_decay",
default=0.0,
type=float,
help="Weight decay if we apply some.")
parser.add_argument(
"--warmup_steps",
default=0,
type=int,
help="Linear warmup over warmup_steps. If > 0: Override warmup_proportion"
)
parser.add_argument(
"--warmup_proportion",
default=0.1,
type=float,
help="Linear warmup proportion over total steps.")
parser.add_argument(
"--adam_epsilon",
default=1e-6,
type=float,
help="Epsilon for Adam optimizer.")
parser.add_argument(
"--max_steps",
default=-1,
type=int,
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
)
parser.add_argument(
"--seed", default=42, type=int, help="random seed for initialization")
parser.add_argument(
"--device",
default="gpu",
type=str,
help="The device to select to train the model, is must be cpu/gpu/xpu.")
args = parser.parse_args()
return args
def set_seed(args):
# Use the same data seed(for data shuffle) for all procs to guarantee data
# consistency after sharding.
random.seed(args.seed)
np.random.seed(args.seed)
# Maybe different op seeds(for dropout) for different procs is better. By:
# `paddle.seed(args.seed + paddle.distributed.get_rank())`
paddle.seed(args.seed)
@paddle.no_grad()
def evaluate(model, metric, data_loader):
model.eval()
metric.reset()
for batch in data_loader:
input_ids, segment_ids, labels = batch
logits = model(input_ids, segment_ids)
correct = metric.compute(logits, labels)
metric.update(correct)
res = metric.accumulate()
if isinstance(metric, AccuracyAndF1):
print(
"acc: %s, precision: %s, recall: %s, f1: %s, acc and f1: %s, " % (
res[0],
res[1],
res[2],
res[3],
res[4], ),
end='')
elif isinstance(metric, Mcc):
print("mcc: %s, " % (res[0]), end='')
elif isinstance(metric, PearsonAndSpearman):
print(
"pearson: %s, spearman: %s, pearson and spearman: %s, " %
(res[0], res[1], res[2]),
end='')
else:
print("acc: %s, " % (res), end='')
model.train()
return res[0] if isinstance(metric, (AccuracyAndF1, Mcc,
PearsonAndSpearman)) else res
def convert_example(example,
tokenizer,
label_list,
max_seq_length=512,
is_test=False):
"""convert a glue example into necessary features"""
if not is_test:
# `label_list == None` is for regression task
label_dtype = "int64" if label_list else "float32"
# Get the label
label = example['labels']
label = np.array([label], dtype=label_dtype)
# Convert raw text to feature
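    # `int(is_test) + len(example) == 2` holds only for single-sentence tasks
    # (e.g. SST-2, CoLA): the example then contains one text field plus, at
    # train/dev time, a label. Sentence-pair tasks fall through to the branch below.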
if (int(is_test) + len(example)) == 2:
example = tokenizer(example['sentence'], max_seq_len=max_seq_length)
else:
example = tokenizer(
example['sentence1'],
text_pair=example['sentence2'],
max_seq_len=max_seq_length)
if not is_test:
return example['input_ids'], example['token_type_ids'], label
else:
return example['input_ids'], example['token_type_ids']
def do_train(args):
paddle.set_device(args.device)
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()
set_seed(args)
args.task_name = args.task_name.lower()
metric_class = METRIC_CLASSES[args.task_name]
args.model_type = args.model_type.lower()
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
if args.use_aug:
        aug_data_file = os.path.join(args.glue_dir, args.task_name,
                                     "train_aug.tsv")
train_ds = load_dataset(
'glue', args.task_name, data_files=aug_data_file)
else:
train_ds = load_dataset('glue', args.task_name, splits='train')
tokenizer = tokenizer_class.from_pretrained(args.student_model_name_or_path)
trans_func = partial(
convert_example,
tokenizer=tokenizer,
label_list=train_ds.label_list,
max_seq_length=args.max_seq_length)
train_ds = train_ds.map(trans_func, lazy=True)
train_batch_sampler = paddle.io.DistributedBatchSampler(
train_ds, batch_size=args.batch_size, shuffle=True)
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment
Stack(dtype="int64" if train_ds.label_list else "float32") # label
): fn(samples)
train_data_loader = DataLoader(
dataset=train_ds,
batch_sampler=train_batch_sampler,
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
if args.task_name == "mnli":
dev_ds_matched, dev_ds_mismatched = load_dataset(
'glue', args.task_name, splits=["dev_matched", "dev_mismatched"])
dev_ds_matched = dev_ds_matched.map(trans_func, lazy=True)
dev_ds_mismatched = dev_ds_mismatched.map(trans_func, lazy=True)
dev_batch_sampler_matched = paddle.io.BatchSampler(
dev_ds_matched, batch_size=args.batch_size, shuffle=False)
dev_data_loader_matched = DataLoader(
dataset=dev_ds_matched,
batch_sampler=dev_batch_sampler_matched,
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
dev_batch_sampler_mismatched = paddle.io.BatchSampler(
dev_ds_mismatched, batch_size=args.batch_size, shuffle=False)
dev_data_loader_mismatched = DataLoader(
dataset=dev_ds_mismatched,
batch_sampler=dev_batch_sampler_mismatched,
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
else:
dev_ds = load_dataset('glue', args.task_name, splits='dev')
dev_ds = dev_ds.map(trans_func, lazy=True)
dev_batch_sampler = paddle.io.BatchSampler(
dev_ds, batch_size=args.batch_size, shuffle=False)
dev_data_loader = DataLoader(
dataset=dev_ds,
batch_sampler=dev_batch_sampler,
collate_fn=batchify_fn,
num_workers=0,
return_list=True)
num_classes = 1 if train_ds.label_list == None else len(train_ds.label_list)
student = model_class.from_pretrained(
args.student_model_name_or_path, num_classes=num_classes)
teacher_model_class, _ = MODEL_CLASSES[args.teacher_model_type]
teacher = teacher_model_class.from_pretrained(
args.teacher_path, num_classes=num_classes)
teacher.eval()
if paddle.distributed.get_world_size() > 1:
student = paddle.DataParallel(student, find_unused_parameters=True)
teacher = paddle.DataParallel(teacher, find_unused_parameters=True)
if args.max_steps > 0:
num_training_steps = args.max_steps
num_train_epochs = math.ceil(num_training_steps /
len(train_data_loader))
else:
num_training_steps = len(train_data_loader) * args.num_train_epochs
num_train_epochs = args.num_train_epochs
warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion
lr_scheduler = T.LinearDecayWithWarmup(args.learning_rate,
num_training_steps, warmup)
# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [
p.name for n, p in student.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
learning_rate=lr_scheduler,
beta1=0.9,
beta2=0.999,
epsilon=args.adam_epsilon,
parameters=student.parameters(),
weight_decay=args.weight_decay,
apply_decay_param_fun=lambda x: x in decay_params)
metric = metric_class()
pad_token_id = 0
global_step = 0
tic_train = time.time()
best_res = 0.0
assert os.path.exists(
args.distill_config), "distill file {} not exist.".format(
args.distill_config)
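    # Distill (from PaddleSlim) wraps the student and the teacher, registers
    # forward hooks on the layer pairs named in the YAML config, and returns
    # the combined distillation loss together with both models' raw outputs.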
distill_model = Distill(
args.distill_config, student_models=[student],
teacher_models=[teacher])
for epoch in range(num_train_epochs):
for step, batch in enumerate(train_data_loader):
global_step += 1
input_ids, segment_ids, labels = batch
loss, _, _ = distill_model(input_ids, segment_ids)
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.clear_grad()
if global_step % args.logging_steps == 0:
print(
"global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s"
% (global_step, num_training_steps, epoch, step,
paddle.distributed.get_rank(), loss, optimizer.get_lr(),
args.logging_steps / (time.time() - tic_train)))
tic_train = time.time()
if global_step % args.save_steps == 0 or global_step == num_training_steps:
tic_eval = time.time()
if args.task_name == "mnli":
res = evaluate(student, metric, dev_data_loader_matched)
evaluate(student, metric, dev_data_loader_mismatched)
print("eval done total : %s s" % (time.time() - tic_eval))
else:
res = evaluate(student, metric, dev_data_loader)
print("eval done total : %s s" % (time.time() - tic_eval))
if (best_res < res and global_step < num_training_steps or
global_step == num_training_steps
) and paddle.distributed.get_rank() == 0:
if global_step < num_training_steps:
output_dir = os.path.join(args.output_dir,
"distill_model_%d.pdparams" %
(global_step))
else:
output_dir = os.path.join(
args.output_dir, "distill_model_final.pdparams")
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Need better way to get inner model of DataParallel
model_to_save = student._layers if isinstance(
student, paddle.DataParallel) else student
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
best_res = res
if global_step >= num_training_steps:
return
def print_arguments(args):
"""print arguments"""
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).items()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
if __name__ == "__main__":
args = parse_args()
print_arguments(args)
do_train(args)
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
from . import distill from . import distill
from .distill import * from .distill import *
from .distill_helpers import *
__all__ = [] __all__ = []
__all__ += distill.__all__ __all__ += distill.__all__
__all__ += distill_helpers.__all__
...@@ -17,207 +17,200 @@ import collections ...@@ -17,207 +17,200 @@ import collections
from collections import namedtuple from collections import namedtuple
import paddle.nn as nn import paddle.nn as nn
from . import losses from . import losses
from .losses.basic_loss import BASIC_LOSS
from .distill_helpers import yaml2config
__all__ = ['Distill', 'AdaptorBase'] __all__ = ['Distill']
class LayerConfig: class LayerConfig:
""" The key of config can be set"""
def __init__(self, def __init__(self,
s_feature_idx, model_name_pairs,
t_feature_idx, layers_name,
feature_type,
loss_function, loss_function,
weight=1.0, weight=1.0,
align=False, temperature=1.0,
align_shape=None): align_params=None,
self.s_feature_idx = s_feature_idx **loss_params):
self.t_feature_idx = t_feature_idx self.model_name_pairs = model_name_pairs
self.feature_type = feature_type self.layers_name = layers_name
if loss_function in ['l1', 'l2', 'smooth_l1']: if loss_function not in BASIC_LOSS.module_dict:
self.loss_function = 'DistillationDistanceLoss' raise NotImplementedError("loss function {} is not support. "
elif loss_function in ['dml']: "Support loss including {}".format(
self.loss_function = 'DistillationDMLLoss' loss_function,
elif loss_function in ['rkl']: BASIC_LOSS.module_dict.keys()))
self.loss_function = 'DistillationRKDLoss' self.loss_function = loss_function
elif hasattr(losses, loss_function):
self.loss_function = loss_function
else:
raise NotImplementedError("loss function is not support!!!")
self.weight = weight self.weight = weight
self.align = align self.temperature = temperature
self.align_shape = align_shape self.align_params = align_params
for k, v in loss_params.items():
setattr(self, k, v)
class AdaptorBase:
def __init__(self, model):
self.model = model def _add_hooks(model, outs, hook_layers_name):
self.add_tensor = False """
Get output by layer name.
def _get_activation(self, outs, name): models(nn.Layer): model need to be add hook.
outs(dict): save the middle outputs of model according to the name.
hook_layers_name(list): name of middle layers.
"""
def _get_activation(outs, name):
### TODO: need to support get input tensor
#outs[name] = {}
def get_output_hook(layer, input, output): def get_output_hook(layer, input, output):
#outs[name]["output"] = output
#outs[name]["input"] = input
outs[name] = output outs[name] = output
return get_output_hook return get_output_hook
def _add_distill_hook(self, outs, mapping_layers_name, layers_type): ### TODO: support DP model
""" for idx, (n, m) in enumerate(model.named_sublayers()):
Get output by layer name. if n in hook_layers_name:
outs(dict): save the middle outputs of model according to the name. m.register_forward_post_hook(_get_activation(outs, n))
mapping_layers(list): name of middle layers.
layers_type(list): type of the middle layers to calculate distill loss.
"""
### TODO: support DP model
for idx, (n, m) in enumerate(self.model.named_sublayers()):
if n in mapping_layers_name:
midx = mapping_layers_name.index(n)
m.register_forward_post_hook(
self._get_activation(outs, layers_type[midx]))
def mapping_layers(self):
raise NotImplementedError("function mapping_layers is not implemented")
class Distill(nn.Layer): class Distill(nn.Layer):
### TODO: support list of student model and teacher model """
def __init__(self, distill_configs, student_models, teacher_models, Distill API.
adaptors_S, adaptors_T): distill_configs(list(dict) | path): the list of distill config.
super(Distill, self).__init__() student_models(list(nn.Layer)): the list of student model, the state of student model must be training mode.
assert student_models.training, "The student model should be eval mode." teacher_models(list(nn.Layer)): the list of teacher model, the state of student model must be evaluate mode.
return_model_outputs(bool): whether to return model output. Default: True.
"""
self._distill_configs = distill_configs def __init__(self,
distill_configs,
student_models,
teacher_models,
return_model_outputs=True):
super(Distill, self).__init__()
if isinstance(student_models, nn.Layer):
student_models = [student_models]
if isinstance(teacher_models, nn.Layer):
teacher_models = [teacher_models]
for student_model in student_models:
assert student_model.training, "The student model should not be eval mode."
for teacher_model in teacher_models:
assert teacher_model.training is False, "The teacher model should be eval mode."
if isinstance(distill_configs, list):
self._distill_configs = distill_configs
elif os.path.exists(distill_configs):
if distill_configs.endswith(".yaml"):
self._distill_configs = yaml2config(distill_configs)
else:
raise NotImplementedError("distill config file type error!")
else:
raise NotImplementedError("distill config error!")
self._student_models = student_models self._student_models = student_models
self._teacher_models = teacher_models self._teacher_models = teacher_models
self._adaptors_S = adaptors_S(self._student_models) self._return_model_outputs = return_model_outputs
self._adaptors_T = adaptors_T(self._teacher_models)
self.stu_outs_dict, self.tea_outs_dict = self._prepare_outputs() self._loss_config_list = []
self.configs = []
for c in self._distill_configs: for c in self._distill_configs:
self.configs.append(LayerConfig(**c).__dict__) self._transpose_config(c)
self.distill_idx = self._get_distill_idx() self._hook_layers = self._extract_hook_position()
self._loss_config_list = []
for c in self.configs:
loss_config = {}
loss_config[str(c['loss_function'])] = {}
loss_config[str(c['loss_function'])]['weight'] = c['weight']
loss_config[str(c['loss_function'])]['key'] = c[
'feature_type'] + '_' + str(c['s_feature_idx']) + '_' + str(c[
't_feature_idx'])
### TODO: support list of student models and teacher_models
loss_config[str(c['loss_function'])][
'model_name_pairs'] = [['student', 'teacher']]
self._loss_config_list.append(loss_config)
# use self._loss_config_list to create all loss object # use self._loss_config_list to create all loss object
self.distill_loss = losses.CombinedLoss(self._loss_config_list) self.distill_loss = losses.CombinedLoss(self._loss_config_list)
self._output_tensor_dict = self._prepare_outputs()
def parameters(self):
params = []
for s_model in self._student_models:
params.extend(s_model.parameters())
return params
def _extract_hook_position(self):
""" extrat hook position according to config"""
model_hook_layers = {}
for config in self._loss_config_list:
model_name_pairs = config['model_name_pairs']
layers_name = config['layers_name']
for model_name_pair in model_name_pairs:
for idx, model_name in enumerate(model_name_pair):
if model_name not in model_hook_layers:
model_hook_layers[model_name] = [layers_name[idx]]
else:
model_hook_layers[model_name].append(layers_name[idx])
for model_name, hook_layers in model_hook_layers.items():
model_hook_layers[model_name] = list(set(hook_layers))
return model_hook_layers
def _transpose_config(self, config):
""" Transpose config to loss needed """
global_config = {}
if 'model_name_pairs' not in config:
global_config['model_name_pairs'] = [['student_0', 'teacher_0']]
else:
if isinstance(config['model_name_pairs'][0], str):
config['model_name_pairs'] = [config['model_name_pairs']]
global_config['model_name_pairs'] = config['model_name_pairs']
config.pop('model_name_pairs')
for key in config.keys():
if key != 'layers':
global_config[key] = config[key]
for per_layer_config in config['layers']:
per_layer_config.update(global_config)
self._loss_config_list.append(
LayerConfig(**per_layer_config).__dict__)
def _prepare_outputs(self): def _prepare_outputs(self):
""" """
Add hook to get the output tensor of target layer. Add hook to get the output tensor of target layer.
Returns:
stu_outs_dict(dict): the name and tensor for the student model,
such as {'hidden_0': tensor_0, ..}
tea_outs_dict(dict): the name and tensor for the teather model,
such as {'hidden_0': tensor_0, ..}
""" """
stu_outs_dict = collections.OrderedDict() outputs_tensor = {}
tea_outs_dict = collections.OrderedDict() for idx, m in enumerate(self._student_models):
stu_outs_dict = self._prepare_hook(self._adaptors_S, stu_outs_dict) hook_layers = self._hook_layers['student_{}'.format(idx)]
tea_outs_dict = self._prepare_hook(self._adaptors_T, tea_outs_dict) stu_outs = collections.OrderedDict()
return stu_outs_dict, tea_outs_dict outputs_tensor['student_{}'.format(idx)] = self._prepare_hook(
m, hook_layers, stu_outs)
def _prepare_hook(self, adaptors, outs_dict): for idx, m in enumerate(self._teacher_models):
hook_layers = self._hook_layers['teacher_{}'.format(idx)]
tea_outs = collections.OrderedDict()
outputs_tensor['teacher_{}'.format(idx)] = self._prepare_hook(
m, hook_layers, tea_outs)
return outputs_tensor
def _prepare_hook(self, model, hook_layers, outs_dict):
""" """
Add hook. Add hook.
""" """
mapping_layers = adaptors.mapping_layers() for layer in hook_layers:
for layer_type, layer in mapping_layers.items():
if isinstance(layer, str): if isinstance(layer, str):
adaptors._add_distill_hook(outs_dict, [layer], [layer_type]) _add_hooks(model, outs_dict, layer)
return outs_dict return outs_dict
def _get_distill_idx(self):
"""
For each feature_type, get the feature index in the student and teacher models.
Returns:
distill_idx(dict): the feature index for each feature_type,
such as {'hidden': [[0, 0], [1, 1]], 'out': [[0, 0]]}
"""
distill_idx = {}
for config in self._distill_configs:
if config['feature_type'] not in distill_idx:
distill_idx[config['feature_type']] = [[
int(config['s_feature_idx']), int(config['t_feature_idx'])
]]
else:
distill_idx[config['feature_type']].append([
int(config['s_feature_idx']), int(config['t_feature_idx'])
])
return distill_idx
def forward(self, *inputs, **kwargs): def forward(self, *inputs, **kwargs):
stu_batch_outs = self._student_models.forward(*inputs, **kwargs) students_batch_outs = []
tea_batch_outs = self._teacher_models.forward(*inputs, **kwargs) teachers_batch_outs = []
if not self._teacher_models.training: for idx, student_model in enumerate(self._student_models):
tea_batch_outs = [i.detach() for i in tea_batch_outs] stu_batch_outs = student_model.forward(*inputs, **kwargs)
students_batch_outs.append(stu_batch_outs)
# get all target tensor for idx, teacher_model in enumerate(self._teacher_models):
if self._adaptors_S.add_tensor == False: tea_batch_outs = teacher_model.forward(*inputs, **kwargs)
self._adaptors_S.add_tensor = True if not teacher_model.training:
if self._adaptors_T.add_tensor == False: tea_batch_outs = [i.detach() for i in tea_batch_outs]
self._adaptors_T.add_tensor = True teachers_batch_outs.extend(tea_batch_outs)
self.stu_outs_dict = self._get_model_intermediate_output(
self._adaptors_S, self.stu_outs_dict) if len(self._student_models) == 1:
self.tea_outs_dict = self._get_model_intermediate_output( students_batch_outs = students_batch_outs[0]
self._adaptors_T, self.tea_outs_dict) if len(self._teacher_models) == 1:
teachers_batch_outs = teachers_batch_outs[0]
distill_inputs = self._process_outputs()
### batch is None just for now ### batch is None just for now
distill_outputs = self.distill_loss(distill_inputs, None) distill_outputs = self.distill_loss(self._output_tensor_dict, None)
distill_loss = distill_outputs['loss'] distill_loss = distill_outputs['loss']
return stu_batch_outs, tea_batch_outs, distill_loss if self._return_model_outputs:
return distill_loss, students_batch_outs, teachers_batch_outs
def _get_model_intermediate_output(self, adaptors, outs_dict): else:
""" return distill_loss
Use the adaptor get the target tensor.
Returns:
outs_dict(dict): the name and tensor for the target model,
such as {'hidden_0': tensor_0, ..}
"""
mapping_layers = adaptors.mapping_layers()
for layer_type, layer in mapping_layers.items():
if isinstance(layer, str):
continue
outs_dict[layer_type] = layer
return outs_dict
def _process_outputs(self):
"""
Process the target tensor to adapt for loss.
"""
### TODO: support list of student models and teacher_models
final_distill_dict = {
"student": collections.OrderedDict(),
"teacher": collections.OrderedDict()
}
for feature_type, dist_idx in self.distill_idx.items():
for idx, idx_list in enumerate(dist_idx):
sidx, tidx = idx_list[0], idx_list[1]
stu_out = self.stu_outs_dict[feature_type + '_' + str(sidx)]
tea_out = self.tea_outs_dict[feature_type + '_' + str(tidx)]
if not self._student_models.training:
stu_out = stu_out.detach()
if not self._teacher_models.training:
tea_out = tea_out.detach()
name_str = feature_type + '_' + str(sidx) + '_' + str(tidx)
final_distill_dict['student'][name_str] = stu_out
final_distill_dict['teacher'][name_str] = tea_out
return final_distill_dict
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
__all__ = ['config2yaml']
def yaml2config(yaml_path):
"""
convert yaml to dict config.
"""
final_configs = []
f = open(yaml_path, 'r')
origin_configs = yaml.load(f, Loader=yaml.FullLoader)
f.close()
for configs in origin_configs:
configs = configs['DistillConfig']
final_configs.extend(configs)
return final_configs
def config2yaml(configs, yaml_path):
"""
convert dict config to yaml.
"""
final_yaml = dict()
final_yaml['DistillConfig'] = configs
f = open(yaml_path, "w")
yaml.dump([final_yaml], f)
f.close()
...@@ -19,18 +19,7 @@ import paddle.nn as nn ...@@ -19,18 +19,7 @@ import paddle.nn as nn
from . import basic_loss from . import basic_loss
from . import distillation_loss from . import distillation_loss
from .basic_loss import L1Loss from .distillation_loss import DistillationLoss
from .basic_loss import L2Loss
from .basic_loss import SmoothL1Loss
from .basic_loss import CELoss
from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss
from .basic_loss import RKdAngle, RkdDistance
from .distillation_loss import DistillationDistanceLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationRKDLoss
from .distillation_loss import SegPairWiseLoss, SegChannelwiseLoss
class CombinedLoss(nn.Layer): class CombinedLoss(nn.Layer):
...@@ -40,13 +29,12 @@ class CombinedLoss(nn.Layer): ...@@ -40,13 +29,12 @@ class CombinedLoss(nn.Layer):
loss_config_list: a config list used to build loss function. A demo is as follows, loss_config_list: a config list used to build loss function. A demo is as follows,
which is used to calculate dml loss between Student output and which is used to calculate dml loss between Student output and
Teacher output. Parameter weight is needed for the loss weight. Teacher output. Parameter weight is needed for the loss weight.
- DistillationDMLLoss: { loss_function: DMLLoss
weight: 1.0 weight: 1.0
act: "softmax" act: "softmax"
model_name_pairs: model_name_pairs:["student_0", "teacher_0"]}
- ["Student", "Teacher"] Another example is {loss_function: "MSELoss", 'weight': 1.0,
Another example is {'DistillationDistanceLoss': {'weight': 1.0, 'layers_name': ['conv0', 'conv0'], 'model_name_pairs': [['student', 'teacher']]}
'key': 'hidden_0_0', 'model_name_pairs': [['student', 'teacher']]}
""" """
def __init__(self, loss_config_list=None): def __init__(self, loss_config_list=None):
...@@ -56,18 +44,14 @@ class CombinedLoss(nn.Layer): ...@@ -56,18 +44,14 @@ class CombinedLoss(nn.Layer):
self.loss_weight = [] self.loss_weight = []
assert isinstance(loss_config_list, list), ( assert isinstance(loss_config_list, list), (
'operator config should be a list') 'operator config should be a list')
supported_loss_list = basic_loss.__all__ + distillation_loss.__all__
for config in loss_config_list: for config in loss_config_list:
assert isinstance(config, assert isinstance(
dict) and len(config) == 1, "yaml format error" config, dict), "config must be a dict, but now is {}".format(
name = list(config)[0] type(config))
assert name in supported_loss_list, \ assert "weight" in config, "weight must be in param, but param just contains {}".format(
"loss name must be in {} but got: {}".format(name, supported_loss_list) config.keys())
param = config[name] self.loss_weight.append(config.pop("weight"))
assert "weight" in param, "weight must be in param, but param just contains {}".format( self.loss_func.append(DistillationLoss(**config))
param.keys())
self.loss_weight.append(param.pop("weight"))
self.loss_func.append(eval(name)(**param))
def forward(self, input, batch, **kargs): def forward(self, input, batch, **kargs):
loss_dict = {} loss_dict = {}
...@@ -82,6 +66,7 @@ class CombinedLoss(nn.Layer): ...@@ -82,6 +66,7 @@ class CombinedLoss(nn.Layer):
for key in loss for key in loss
} }
loss_dict.update(loss) loss_dict.update(loss)
if loss_dict == {}: if loss_dict == {}:
loss_dict["loss"] = paddle.to_tensor(0.) loss_dict["loss"] = paddle.to_tensor(0.)
else: else:
......
...@@ -20,11 +20,13 @@ from paddle.nn import L1Loss ...@@ -20,11 +20,13 @@ from paddle.nn import L1Loss
from paddle.nn import MSELoss as L2Loss from paddle.nn import MSELoss as L2Loss
from paddle.nn import SmoothL1Loss from paddle.nn import SmoothL1Loss
__all__ = [ from ....core import Registry
"CELoss", "DMLLoss", "DistanceLoss", "RKdAngle", "RkdDistance", "KLLoss"
]
__all__ = ["BASIC_LOSS"]
BASIC_LOSS = Registry("basicloss")
@BASIC_LOSS.register
class CELoss(nn.Layer): class CELoss(nn.Layer):
""" """
CELoss: cross entropy loss CELoss: cross entropy loss
...@@ -78,6 +80,7 @@ class CELoss(nn.Layer): ...@@ -78,6 +80,7 @@ class CELoss(nn.Layer):
return loss return loss
@BASIC_LOSS.register
class DMLLoss(nn.Layer): class DMLLoss(nn.Layer):
""" """
DMLLoss DMLLoss
...@@ -110,6 +113,7 @@ class DMLLoss(nn.Layer): ...@@ -110,6 +113,7 @@ class DMLLoss(nn.Layer):
return loss return loss
@BASIC_LOSS.register
class KLLoss(nn.Layer): class KLLoss(nn.Layer):
""" """
KLLoss. KLLoss.
...@@ -153,6 +157,7 @@ class KLLoss(nn.Layer): ...@@ -153,6 +157,7 @@ class KLLoss(nn.Layer):
return loss return loss
@BASIC_LOSS.register
class DistanceLoss(nn.Layer): class DistanceLoss(nn.Layer):
""" """
DistanceLoss DistanceLoss
...@@ -191,6 +196,7 @@ def pdist(e, squared=False, eps=1e-12): ...@@ -191,6 +196,7 @@ def pdist(e, squared=False, eps=1e-12):
return res return res
@BASIC_LOSS.register
class RKdAngle(nn.Layer): class RKdAngle(nn.Layer):
""" """
RKdAngle loss, see https://arxiv.org/abs/1904.05068 RKdAngle loss, see https://arxiv.org/abs/1904.05068
...@@ -218,6 +224,7 @@ class RKdAngle(nn.Layer): ...@@ -218,6 +224,7 @@ class RKdAngle(nn.Layer):
return loss return loss
@BASIC_LOSS.register
class RkdDistance(nn.Layer): class RkdDistance(nn.Layer):
""" """
RkdDistance loss, see https://arxiv.org/abs/1904.05068 RkdDistance loss, see https://arxiv.org/abs/1904.05068
...@@ -244,3 +251,50 @@ class RkdDistance(nn.Layer): ...@@ -244,3 +251,50 @@ class RkdDistance(nn.Layer):
loss = F.smooth_l1_loss(d, t_d, reduction="mean") loss = F.smooth_l1_loss(d, t_d, reduction="mean")
return loss return loss
@BASIC_LOSS.register
class MSELoss(DistanceLoss):
"""
MSELoss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/MSELoss_cn.html#mseloss
"""
def __init__(self, **kargs):
super().__init__(mode='l2', **kargs)
@BASIC_LOSS.register
class L1Loss(DistanceLoss):
"""
L1loss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/L1Loss_cn.html#l1loss
"""
def __init__(self, **kargs):
super().__init__(mode='l1', **kargs)
@BASIC_LOSS.register
class SmoothL1Loss(DistanceLoss):
"""
SmoothL1Loss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/SmoothL1Loss_cn.html#smoothl1loss
"""
def __init__(self, **kargs):
super().__init__(mode='smooth_l1', **kargs)
@BASIC_LOSS.register
class RKDLoss(nn.Layer):
"""
RKDLoss
"""
def __init__(self, eps=1e-12):
super().__init__()
self.rkd_angle_loss_func = RKdAngle()
self.rkd_dist_func = RkdDistance(eps=eps)
def forward(self, student, teacher):
angle_loss = self.rkd_angle_loss_func(student, teacher)
dist_loss = self.rkd_dist_func(student, teacher)
return angle_loss + dist_loss
...@@ -15,210 +15,54 @@ ...@@ -15,210 +15,54 @@
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
from .basic_loss import DMLLoss from .basic_loss import BASIC_LOSS
from .basic_loss import DistanceLoss
from .basic_loss import RkdDistance
from .basic_loss import RKdAngle
from .basic_loss import KLLoss
__all__ = [ __all__ = ["DistillationLoss"]
"DistillationDMLLoss",
"DistillationDistanceLoss",
"DistillationRKDLoss",
"SegPairWiseLoss",
"SegChannelwiseLoss",
]
class DistillationDMLLoss(DMLLoss): class DistillationLoss(nn.Layer):
""" """
DistillationDMLLoss DistillationLoss
Args: Args:
model_name_pairs(list | tuple): model name pairs to extract submodel output. model_name_pairs(list | tuple): model name pairs to extract submodel output.
act(string | None): activation function used to build dml loss. layers_name(list(string)): keys of the tensor used to calculate loss if the submodel.
axis(int): axis used to build activation function. loss_function(string): the name of loss function.
key(string | None): key of the tensor used to calculate loss if the submodel temperature(float): the temperature to compute distill loss.
output type is dict.
name(string): loss name.
"""
def __init__(self, model_name_pairs=[], act=None, key=None,
name="loss_dml"):
super().__init__(act=act)
assert isinstance(model_name_pairs, list)
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
idx)] = super().forward(out1, out2)
return loss_dict
class DistillationDistanceLoss(DistanceLoss):
"""
DistillationDistanceLoss
Args:
mode: loss mode
model_name_pairs(list | tuple): model name pairs to extract submodel output.
such as [['student', 'teacher']]
key(string | None): key of the tensor used to calculate loss if the submodel.
such as 'hidden_0_0'
name(string): loss name.
kargs(dict): used to build corresponding loss function.
""" """
def __init__(self, def __init__(self,
mode="l2",
model_name_pairs=[], model_name_pairs=[],
key=None, layers_name=None,
name="loss_distance", loss_function=None,
**kargs): temperature=1.0,
super().__init__(mode=mode, **kargs) **params):
assert isinstance(model_name_pairs, list)
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name + "_" + mode
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
loss = super().forward(out1, out2)
loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
idx)] = loss
return loss_dict
class DistillationRKDLoss(nn.Layer):
"""
DistillationRKDLoss
Args:
model_name_pairs(list | tuple): model name pairs to extract submodel output.
key(string | None): key of the tensor used to calculate loss if the submodel.
eps(float): epsilon for the pdist function for RkdDistance loss.
name(string): loss name.
"""
def __init__(self,
model_name_pairs=[],
key=None,
eps=1e-12,
name="loss_rkd"):
super().__init__() super().__init__()
self.model_name_pairs = model_name_pairs self.model_name_pairs = model_name_pairs
self.key = key self.layers_name = layers_name
self.loss_function = loss_function
self.temperature = temperature
self.align_params = params.pop(
'align_params') if 'align_params' in params else None
if self.align_params is not None:
for attr, value in self.align_params.items():
setattr(self, attr, value)
self.rkd_angle_loss_func = RKdAngle() self.loss_func = BASIC_LOSS.get(loss_function)(**params)
self.rkd_dist_func = RkdDistance(eps=eps)
self.name = name
def forward(self, predicts, batch): def forward(self, predicts, batch):
loss_dict = dict() loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs): for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]] out1 = predicts[pair[0]]
out2 = predicts[pair[1]] out2 = predicts[pair[1]]
if self.key is not None: if self.layers_name != None:
out1 = out1[self.key] assert len(self.layers_name
out2 = out2[self.key] ) == 2, "length of layers_name must be equal to 2."
loss_dict["{}_{}_{}_angle_{}".format(self.name, pair[0], pair[ out1 = out1[self.layers_name[0]]
1], idx)] = self.rkd_angle_loss_func(out1, out2) out2 = out2[self.layers_name[1]]
if self.temperature != 1.0:
loss_dict["{}_{}_{}_dist_{}".format(self.name, pair[0], pair[ out1 = out1 / self.temperature
1], idx)] = self.rkd_dist_func(out1, out2) out2 = out2 / self.temperature
return loss_dict loss_dict["{}_{}_{}_{}_{}".format(self.loss_function, pair[0], pair[
1], self.layers_name[0] if self.layers_name != None else "0", \
self.layers_name[1] if self.layers_name != None else "0")] = self.loss_func(out1, out2)
class SegPairWiseLoss(DistanceLoss):
"""
Segmentation pairwise loss, see https://arxiv.org/pdf/1903.04197.pdf
Args:
model_name_pairs(list | tuple): model name pairs to extract submodel output.
key(string): key of the tensor used to calculate loss if the submodel
output type is dict.
mode(string, optional): loss mode. It supports l1, l2 and smooth_l1. Default: l2.
reduction(string, optional): the reduction params for F.kl_div. Default: mean.
name(string, optional): loss name. Default: seg_pair_wise_loss.
"""
def __init__(self,
model_name_pairs=[],
key=None,
mode="l2",
reduction="mean",
name="seg_pair_wise_loss"):
super().__init__(mode=mode, reduction=reduction)
assert isinstance(model_name_pairs, list)
assert key is not None
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name
self.pool1 = nn.AdaptiveAvgPool2D(output_size=[2, 2])
self.pool2 = nn.AdaptiveAvgPool2D(output_size=[2, 2])
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]][self.key]
out2 = predicts[pair[1]][self.key]
pool1 = self.pool1(out1)
pool2 = self.pool2(out2)
loss_name = "{}_{}_{}_{}".format(self.name, pair[0], pair[1], idx)
loss_dict[loss_name] = super().forward(pool1, pool2)
return loss_dict
class SegChannelwiseLoss(KLLoss):
"""
Segmentation channel wise loss, see `Channel-wise Distillation for Semantic Segmentation`.
Args:
model_name_pairs(list | tuple): model name pairs to extract submodel output.
key(string): key of the tensor used to calculate loss if the submodel
output type is dict.
act(string, optional): activation function used for the input and label tensor.
Default: softmax.
axis(int, optional): the axis for the act. Default: -1.
reduction(str, optional): the reduction params for F.kl_div. Default: mean.
name(string, optional): loss name. Default: seg_ch_wise_loss.
"""
def __init__(self,
model_name_pairs=[],
key=None,
act='softmax',
axis=-1,
reduction="mean",
name="seg_ch_wise_loss"):
super().__init__(act, axis, reduction)
assert isinstance(model_name_pairs, list)
assert key is not None
self.model_name_pairs = model_name_pairs
self.key = key
self.name = name
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]][self.key]
out2 = predicts[pair[1]][self.key]
loss_name = "{}_{}_{}_{}".format(self.name, pair[0], pair[1], idx)
loss_dict[loss_name] = super().forward(out1, out2)
return loss_dict return loss_dict
...@@ -7,7 +7,7 @@ import paddle ...@@ -7,7 +7,7 @@ import paddle
import paddle.nn as nn import paddle.nn as nn
from paddle.vision.models import MobileNetV1 from paddle.vision.models import MobileNetV1
import paddle.vision.transforms as T import paddle.vision.transforms as T
from paddleslim.dygraph.dist import Distill, AdaptorBase from paddleslim.dygraph.dist import Distill, config2yaml
from paddleslim.common.log_helper import get_logger from paddleslim.common.log_helper import get_logger
_logger = get_logger( _logger = get_logger(
...@@ -19,42 +19,30 @@ class TestImperativeDistill(unittest.TestCase): ...@@ -19,42 +19,30 @@ class TestImperativeDistill(unittest.TestCase):
self.s_model, self.t_model = self.prepare_model() self.s_model, self.t_model = self.prepare_model()
self.t_model.eval() self.t_model.eval()
self.distill_configs = self.prepare_config() self.distill_configs = self.prepare_config()
self.adaptor = self.prepare_adaptor()
def prepare_model(self): def prepare_model(self):
return MobileNetV1(), MobileNetV1() return MobileNetV1(), MobileNetV1()
def prepare_config(self): def prepare_config(self):
distill_configs = [{ distill_configs = [{
's_feature_idx': 0, 'loss_function': 'MSELoss',
't_feature_idx': 0, 'layers': [
'feature_type': 'hidden', {
'loss_function': 'l2' "layers_name": ["conv1", "conv1"]
},
{
"layers_name": ["conv2_2", "conv2_2"]
},
]
}, { }, {
's_feature_idx': 1, 'loss_function': 'CELoss',
't_feature_idx': 1, 'temperature': 1.0,
'feature_type': 'hidden', 'layers': [{
'loss_function': 'l2' "layers_name": ["fc", "fc"]
}, { }, ]
's_feature_idx': 0,
't_feature_idx': 0,
'feature_type': 'logits',
'loss_function': 'l2'
}] }]
return distill_configs return distill_configs
def prepare_adaptor(self):
class Adaptor(AdaptorBase):
def mapping_layers(self):
mapping_layers = {}
mapping_layers['hidden_0'] = 'conv1'
mapping_layers['hidden_1'] = 'conv2_2'
mapping_layers['hidden_2'] = 'conv3_2'
mapping_layers['logits_0'] = 'fc'
return mapping_layers
return Adaptor
def test_distill(self): def test_distill(self):
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])]) transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
...@@ -97,7 +85,7 @@ class TestImperativeDistill(unittest.TestCase): ...@@ -97,7 +85,7 @@ class TestImperativeDistill(unittest.TestCase):
for batch_id, data in enumerate(train_reader): for batch_id, data in enumerate(train_reader):
img = paddle.to_tensor(data[0]) img = paddle.to_tensor(data[0])
label = paddle.to_tensor(data[1]) label = paddle.to_tensor(data[1])
student_out, teacher_out, distill_loss = model(img) distill_loss, student_out, teacher_out = model(img)
loss = paddle.nn.functional.loss.cross_entropy(student_out, loss = paddle.nn.functional.loss.cross_entropy(student_out,
label) label)
avg_loss = paddle.mean(loss) avg_loss = paddle.mean(loss)
...@@ -112,7 +100,7 @@ class TestImperativeDistill(unittest.TestCase): ...@@ -112,7 +100,7 @@ class TestImperativeDistill(unittest.TestCase):
self.s_model.train() self.s_model.train()
distill_model = Distill(self.distill_configs, self.s_model, distill_model = Distill(self.distill_configs, self.s_model,
self.t_model, self.adaptor, self.adaptor) self.t_model)
train(distill_model) train(distill_model)
...@@ -136,31 +124,26 @@ class TestImperativeDistillCase1(TestImperativeDistill): ...@@ -136,31 +124,26 @@ class TestImperativeDistillCase1(TestImperativeDistill):
return Model(), Model() return Model(), Model()
def prepare_adaptor(self):
class Adaptor(AdaptorBase):
def mapping_layers(self):
mapping_layers = {}
mapping_layers['hidden_1'] = 'conv2'
if self.add_tensor:
mapping_layers['hidden_0'] = self.model.conv1_out
mapping_layers['hidden_2'] = self.model.conv3_out
return mapping_layers
return Adaptor
def prepare_config(self): def prepare_config(self):
distill_configs = [{ distill_configs = [{
's_feature_idx': 0, 'loss_function': 'MSELoss',
't_feature_idx': 0, 'layers': [
'feature_type': 'hidden', {
'loss_function': 'l2' "layers_name": ["conv1", "conv1"]
},
{
"layers_name": ["conv2", "conv3"]
},
]
}, { }, {
's_feature_idx': 1, 'loss_function': 'CELoss',
't_feature_idx': 2, 'temperature': 1.0,
'feature_type': 'hidden', 'layers': [{
'loss_function': 'l2' "layers_name": ["fc", "fc"]
}, ]
}] }]
return distill_configs config2yaml(distill_configs, 'test.yaml')
return './test.yaml'
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -24,18 +24,14 @@ import paddle.nn.functional as F ...@@ -24,18 +24,14 @@ import paddle.nn.functional as F
from paddleslim.dygraph.dist.losses import CombinedLoss from paddleslim.dygraph.dist.losses import CombinedLoss
# basic loss # basic loss
from paddleslim.dygraph.dist.losses import DistanceLoss from paddleslim.dygraph.dist.losses.basic_loss import DistanceLoss
from paddleslim.dygraph.dist.losses import CELoss from paddleslim.dygraph.dist.losses.basic_loss import CELoss
from paddleslim.dygraph.dist.losses import DMLLoss from paddleslim.dygraph.dist.losses.basic_loss import DMLLoss
from paddleslim.dygraph.dist.losses import RkdDistance from paddleslim.dygraph.dist.losses.basic_loss import RkdDistance
from paddleslim.dygraph.dist.losses import RKdAngle from paddleslim.dygraph.dist.losses.basic_loss import RKdAngle
# distillation loss # distillation loss
from paddleslim.dygraph.dist.losses import DistillationDistanceLoss from paddleslim.dygraph.dist.losses import DistillationLoss
from paddleslim.dygraph.dist.losses import DistillationRKDLoss
from paddleslim.dygraph.dist.losses import DistillationDMLLoss
from paddleslim.dygraph.dist.losses import SegPairWiseLoss
from paddleslim.dygraph.dist.losses import SegChannelwiseLoss
import numpy as np import numpy as np
...@@ -70,14 +66,13 @@ class TestDistanceLoss(unittest.TestCase): ...@@ -70,14 +66,13 @@ class TestDistanceLoss(unittest.TestCase):
out = np.sum(diff) out = np.sum(diff)
return out return out
def dist_np_distance_loss( def dist_np_distance_loss(self,
self, predicts,
predicts, loss_function=None,
mode="l2", mode="l2",
reduction="none", reduction="none",
model_name_pairs=(["", ""]), model_name_pairs=(["", ""]),
key=None, key=None):
name="loss_distance", ):
loss_dict = dict() loss_dict = dict()
for idx, pair in enumerate(model_name_pairs): for idx, pair in enumerate(model_name_pairs):
out1 = predicts[pair[0]] out1 = predicts[pair[0]]
...@@ -85,10 +80,12 @@ class TestDistanceLoss(unittest.TestCase): ...@@ -85,10 +80,12 @@ class TestDistanceLoss(unittest.TestCase):
if key is not None: if key is not None:
out1 = out1[key] out1 = out1[key]
out2 = out2[key] out2 = out2[key]
else:
key = 0
loss = self.np_distance_loss( loss = self.np_distance_loss(
out1, out2, mode=mode, reduction=reduction) out1, out2, mode=mode, reduction=reduction)
loss_dict["{}_{}_{}_{}_{}".format(name, mode, pair[0], pair[1], loss_dict["{}_{}_{}_{}_{}".format(
idx)] = loss str(loss_function), pair[0], pair[1], key, key)] = loss
return loss_dict return loss_dict
...@@ -120,7 +117,7 @@ class TestDistanceLoss(unittest.TestCase): ...@@ -120,7 +117,7 @@ class TestDistanceLoss(unittest.TestCase):
"student": paddle.rand(shape), "student": paddle.rand(shape),
"teacher": paddle.rand(shape), "teacher": paddle.rand(shape),
} }
self.calc_distillation_distance_loss(predicts, pairs, key=None) self.calc_distillation_distance_loss(predicts, pairs)
predicts = { predicts = {
"student": { "student": {
...@@ -143,13 +140,15 @@ class TestDistanceLoss(unittest.TestCase): ...@@ -143,13 +140,15 @@ class TestDistanceLoss(unittest.TestCase):
paddle.set_device(device) paddle.set_device(device)
for reduction in reductions: for reduction in reductions:
for mode in modes: for mode in modes:
loss_func = DistillationDistanceLoss( loss_func = DistillationLoss(
mode=mode, mode=mode,
loss_function='DistanceLoss',
model_name_pairs=pairs, model_name_pairs=pairs,
key=key, layers_name=[key, key] if key != None else None,
reduction=reduction) reduction=reduction)
np_result_dict = self.dist_np_distance_loss( np_result_dict = self.dist_np_distance_loss(
predicts, predicts,
loss_function='DistanceLoss',
mode=mode, mode=mode,
reduction=reduction, reduction=reduction,
model_name_pairs=pairs, model_name_pairs=pairs,
@@ -358,12 +357,11 @@ class TestDMLLoss(unittest.TestCase):
         np_loss = self.np_dml_loss(x, target)
         self.assertTrue(np.allclose(np_loss, pd_loss))

-    def dist_np_dml_loss(
-            self,
-            predicts,
-            model_name_pairs=(["", ""]),
-            key=None,
-            name="loss_dml", ):
+    def dist_np_dml_loss(self,
+                         predicts,
+                         loss_function=None,
+                         model_name_pairs=(["", ""]),
+                         key=None):
         loss_dict = dict()
         for idx, pair in enumerate(model_name_pairs):
             out1 = predicts[pair[0]]
@@ -371,8 +369,11 @@ class TestDMLLoss(unittest.TestCase):
             if key is not None:
                 out1 = out1[key]
                 out2 = out2[key]
-            loss_dict["{}_{}_{}_{}".format(name, pair[0], pair[1],
-                                           idx)] = self.np_dml_loss(out1, out2)
+            else:
+                key = 0
+            loss_dict["{}_{}_{}_{}_{}".format(
+                str(loss_function), pair[0], pair[1], key,
+                key)] = self.np_dml_loss(out1, out2)
         return loss_dict

     def calc_distillation_dml_loss(self, predicts, pairs, key=None):
@@ -382,11 +383,19 @@ class TestDMLLoss(unittest.TestCase):
         for device in devices:
             paddle.set_device(device)
-            loss_func = DistillationDMLLoss(
-                act="softmax", model_name_pairs=pairs, key=key)
+            loss_func = DistillationLoss(
+                act="softmax",
+                model_name_pairs=pairs,
+                loss_function='DMLLoss',
+                layers_name=[key, key] if key != None else None)
             np_result_dict = self.dist_np_dml_loss(
-                predicts, model_name_pairs=pairs, key=key)
+                predicts,
+                model_name_pairs=pairs,
+                loss_function='DMLLoss',
+                key=key)
             pd_result_dict = loss_func(predicts, None)
+            print(pd_result_dict.keys())
+            print(np_result_dict.keys())
             for k in np_result_dict:
                 pd_result = pd_result_dict[k].numpy()
                 np_result = np_result_dict[k]
@@ -526,7 +535,7 @@ class TestRKDLoss(unittest.TestCase):
             predicts,
             model_name_pairs=(["", ""]),
             key=None,
-            name="loss_rkd", ):
+            name="RKDLoss", ):
         loss_dict = dict()
         for idx, pair in enumerate(model_name_pairs):
             out1 = predicts[pair[0]]
@@ -534,11 +543,12 @@ class TestRKDLoss(unittest.TestCase):
             if key is not None:
                 out1 = out1[key]
                 out2 = out2[key]
-            loss_dict["{}_{}_{}_angle_{}".format(name, pair[0], pair[
-                1], idx)] = self.np_rkd_angle(out1, out2)
-            loss_dict["{}_{}_{}_dist_{}".format(name, pair[0], pair[
-                1], idx)] = self.np_rkd_distance(out1, out2)
+            else:
+                key = 0
+            loss_dict["{}_{}_{}_{}_{}".format(name, pair[0], pair[
+                1], key, key)] = self.np_rkd_angle(
+                    out1, out2) + self.np_rkd_distance(out1, out2)
         return loss_dict

     def calc_distillation_rkd_loss(self, predicts, pairs, key=None):
@@ -548,7 +558,10 @@ class TestRKDLoss(unittest.TestCase):
         for device in devices:
             paddle.set_device(device)
-            loss_func = DistillationRKDLoss(model_name_pairs=pairs, key=key)
+            loss_func = DistillationLoss(
+                model_name_pairs=pairs,
+                loss_function='RKDLoss',
+                layers_name=[key, key] if key != None else None)
             np_result_dict = self.dist_np_rkd_loss(
                 predicts, model_name_pairs=pairs, key=key)
             pd_result_dict = loss_func(predicts, None)
@@ -623,13 +636,12 @@ class TestCombinedLoss(unittest.TestCase):
             log_soft_target, soft_x)) / 2.0
         return loss

-    def dist_np_dml_loss(
-            self,
-            predicts,
-            model_name_pairs=(["", ""]),
-            key=None,
-            act="softmax",
-            name="loss_dml", ):
+    def dist_np_dml_loss(self,
+                         predicts,
+                         model_name_pairs=(["", ""]),
+                         loss_function=None,
+                         key=None,
+                         act="softmax"):
         loss_dict = dict()
         for idx, pair in enumerate(model_name_pairs):
             out1 = predicts[pair[0]]
@@ -637,20 +649,24 @@ class TestCombinedLoss(unittest.TestCase):
             if key is not None:
                 out1 = out1[key]
                 out2 = out2[key]
-            loss_dict["{}_{}_{}_{}".format(name, pair[0], pair[1],
-                                           idx)] = self.np_dml_loss(out1, out2)
+            loss_dict["{}_{}_{}_{}_0".format(
+                str(loss_function), pair[0], pair[1], idx)] = self.np_dml_loss(
+                    out1, out2)
         return loss_dict

     def np_combined_loss(self, predicts, loss_cfg_list):
         # NOTE, dml is set as the list for combined loss
         loss_dict = dict()
         for idx, loss_func in enumerate(loss_cfg_list):
-            cfg = copy.deepcopy(loss_func["DistillationDMLLoss"])
+            cfg = copy.deepcopy(loss_func)
             weight = cfg.pop("weight")
             loss = self.dist_np_dml_loss(predicts, **cfg)
             if isinstance(loss, np.ndarray):
-                loss = {"loss_{}_{}".format(str(loss), idx): loss}
+                loss = {
+                    "{}_{}_{}".format(loss_func['loss_function'],
+                                      str(loss), idx): loss
+                }
             else:
                 loss = {
                     "{}_{}".format(key, idx): loss[key] * weight
@@ -677,12 +693,10 @@ class TestCombinedLoss(unittest.TestCase):
             devices.append("gpu")
         loss_cfg_list = [{
-            "DistillationDMLLoss": {
-                "weight": 1.0,
-                "act": "softmax",
-                "model_name_pairs": pairs,
-                "key": None
-            }
+            "loss_function": "DMLLoss",
+            "weight": 1.0,
+            "act": "softmax",
+            "model_name_pairs": pairs
         }, ]
         for device in devices:
@@ -696,95 +710,5 @@ class TestCombinedLoss(unittest.TestCase):
                 self.assertTrue(np.allclose(np_result, pd_result))

-
-class TestSegPairWiseLoss(unittest.TestCase):
-    def calculate_gt_loss(self, x, y):
-        pool_x = F.adaptive_avg_pool2d(x, [2, 2])
-        pool_y = F.adaptive_avg_pool2d(y, [2, 2])
-        loss = F.mse_loss(pool_x, pool_y)
-        return loss
-
-    def test_seg_pair_wise_loss(self):
-        shape = [1, 3, 10, 10]
-        x = paddle.rand(shape)
-        y = paddle.rand(shape)
-        model_name_pairs = [['student', 'teacher']]
-        key = 'hidden_0_0'
-        inputs = {
-            model_name_pairs[0][0]: {
-                key: x
-            },
-            model_name_pairs[0][1]: {
-                key: y
-            }
-        }
-
-        devices = ["cpu"]
-        if paddle.is_compiled_with_cuda():
-            devices.append("gpu")
-        for device in devices:
-            paddle.set_device(device)
-            loss_func = SegPairWiseLoss(model_name_pairs, key)
-            pd_loss_dict = loss_func(inputs, None)
-            pd_loss = pd_loss_dict['seg_pair_wise_loss_student_teacher_0']
-            gt_loss = self.calculate_gt_loss(x, y)
-            self.assertTrue(np.allclose(pd_loss.numpy(), gt_loss.numpy()))
-
-
-class TestSegChannelWiseLoss(unittest.TestCase):
-    def init(self):
-        self.act_name = None
-        self.act_func = None
-
-    def calculate_gt_loss(self, x, y, act=None):
-        if act is not None:
-            x = act(x)
-            y = act(y)
-        x = paddle.log(x)
-        loss = F.kl_div(x, y)
-        return loss
-
-    def test_seg_pair_wise_loss(self):
-        self.init()
-        shape = [1, 3, 10, 10]
-        x = paddle.rand(shape)
-        y = paddle.rand(shape)
-        model_name_pairs = [['student', 'teacher']]
-        key = 'hidden_0_0'
-        inputs = {
-            model_name_pairs[0][0]: {
-                key: x
-            },
-            model_name_pairs[0][1]: {
-                key: y
-            }
-        }
-
-        devices = ["cpu"]
-        if paddle.is_compiled_with_cuda():
-            devices.append("gpu")
-        for device in devices:
-            paddle.set_device(device)
-            loss_func = SegChannelwiseLoss(model_name_pairs, key, self.act_name)
-            pd_loss_dict = loss_func(inputs, None)
-            pd_loss = pd_loss_dict['seg_ch_wise_loss_student_teacher_0']
-            gt_loss = self.calculate_gt_loss(x, y, self.act_func)
-            self.assertTrue(np.allclose(pd_loss.numpy(), gt_loss.numpy()))
-
-
-class TestSegChannelWiseLoss1(TestSegChannelWiseLoss):
-    def init(self):
-        self.act_name = "softmax"
-        self.act_func = F.softmax
-
-
-class TestSegChannelWiseLoss1(TestSegChannelWiseLoss):
-    def init(self):
-        self.act_name = "sigmoid"
-        self.act_func = F.sigmoid
-
 if __name__ == '__main__':
     unittest.main()
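Taken together, the hunks above replace the per-algorithm wrappers (`DistillationDistanceLoss`, `DistillationDMLLoss`, `DistillationRKDLoss`) with a single `DistillationLoss` class: the algorithm is selected with `loss_function`, and the old `key` argument becomes a `layers_name` pair. The sketch below mirrors the calls these tests make; the import path and the exact keys of the returned dict are assumptions inferred from this diff, not a verified API reference.

```python
# Minimal sketch of the unified DistillationLoss API exercised by the tests
# above (assumed import path; verify against the installed paddleslim version).
import paddle
from paddleslim.dygraph.dist.losses import DistillationLoss

# Dummy student/teacher logits, mirroring the test fixtures.
predicts = {
    "student": paddle.rand([8, 10]),
    "teacher": paddle.rand([8, 10]),
}
pairs = [["student", "teacher"]]

# Previously: DistillationDMLLoss(act="softmax", model_name_pairs=pairs, key=None)
# Now the algorithm is chosen by name, and per-layer outputs are selected with
# layers_name (a [student_layer, teacher_layer] pair) instead of key.
loss_func = DistillationLoss(
    loss_function='DMLLoss',
    model_name_pairs=pairs,
    act="softmax",
    layers_name=None)

# The second argument (the batch) is unused by these losses in the tests.
loss_dict = loss_func(predicts, None)
# Keys are assumed to follow "{loss_function}_{model1}_{model2}_{layer}_{layer}",
# e.g. "DMLLoss_student_teacher_0_0" when layers_name is not given.
print(loss_dict)
```

The same pattern applies to `'DistanceLoss'` and `'RKDLoss'`, and the combined-loss configuration becomes a flat dict per entry (`{"loss_function": "DMLLoss", "weight": 1.0, ...}`) instead of nesting the options under the old class name.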