Commit 36a9f6f0 (unverified): update distill (#892)
* polish distill
Authored by ceci3 on Sep 17, 2021; committed via GitHub on Sep 17, 2021.
Parent commit: 5ec362f2
13 changed files with 1,057 additions and 578 deletions.
Changed files:
- demo/dygraph/dist/bert/README.md (+144, -0)
- demo/dygraph/dist/bert/distill_stage1.yaml (+20, -0)
- demo/dygraph/dist/bert/distill_stage2.yaml (+9, -0)
- demo/dygraph/dist/bert/run.sh (+20, -0)
- demo/dygraph/dist/bert/task_distill.py (+460, -0)
- paddleslim/dygraph/dist/__init__.py (+2, -0)
- paddleslim/dygraph/dist/distill.py (+156, -163)
- paddleslim/dygraph/dist/distill_helpers.py (+41, -0)
- paddleslim/dygraph/dist/losses/__init__.py (+13, -28)
- paddleslim/dygraph/dist/losses/basic_loss.py (+57, -3)
- paddleslim/dygraph/dist/losses/distillation_loss.py (+31, -187)
- tests/dygraph/test_distill.py (+33, -50)
- tests/dygraph/test_distillation_loss.py (+71, -147)
demo/dygraph/dist/bert/README.md (new file, 0 → 100644)
# TinyBERT: Distilling BERT for Natural Language Understanding

The directory layout of this example:

```
.
├── task_distill.py   # task-specific distillation script
└── README.md         # this document
```

## Introduction

The experiments in this directory mainly follow the paper
[TinyBERT: Distilling BERT for Natural Language Understanding](https://arxiv.org/abs/1909.10351).

The overall TinyBERT distillation process is: first perform general distillation, then distill on a specific task using augmented data. This example only carries out the second stage; the student model is initialized from `tinybert-6l-768d-v2`, the general small model obtained in the first stage.

<p align="center">
<img src="./imgs/tinybert.png" width="950"/><br />
TinyBERT distillation pipeline
</p>

In model distillation, the larger model (BERT base in this example) is usually called the teacher model, and the smaller model (a 6-layer BERT here, referred to as TinyBERT6 below) is called the student model.

Knowledge is distilled by training the student model against distillation loss functions. In this experiment the training objective has two parts: an intermediate-layer distillation loss and a prediction-layer distillation loss. The intermediate-layer part distills the embedding layer, the output of every Transformer layer, and the attention matrices of every Transformer layer (the values before the softmax), all with a mean squared error loss. The prediction-layer objective is the cross-entropy between the student logits and the teacher logits.

Because the teacher model has 12 layers and the student has fewer, a layer mapping has to be chosen. The paper uses a fixed mapping: when the student has half as many layers as the teacher, the attention matrix of student layer i learns from the attention matrix of teacher layer 2i+1, and the same rule applies to the Transformer-layer outputs.
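As an illustration only (this helper is hypothetical and not part of the commit), the fixed mapping that `distill_stage1.yaml` encodes for the 6-layer student and 12-layer teacher can be generated programmatically:

```python
# Hypothetical helper: enumerate the fixed attention-layer mapping described above,
# i.e. student layer i (0-indexed) learns from teacher layer 2*i + 1.
def attention_layer_pairs(num_student_layers=6):
    return [('tinybert.encoder.layers.%d.self_attn' % i,
             'bert.encoder.layers.%d.self_attn' % (2 * i + 1))
            for i in range(num_student_layers)]

print(attention_layer_pairs())
# [('tinybert.encoder.layers.0.self_attn', 'bert.encoder.layers.1.self_attn'),
#  ..., ('tinybert.encoder.layers.5.self_attn', 'bert.encoder.layers.11.self_attn')]
```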
The experiment consists of two large training stages: first fine-tune BERT-base to obtain the teacher model, then run the distillation training. The distillation itself is also done in two steps: distill the intermediate layers for several epochs (10, 20, or 30 depending on the task in the paper), then distill the prediction layer for 3 epochs.

Note that, when different teacher models are used, the two v2 pretrained models `tinybert-6l-768d-v2` and `tinybert-4l-312d-v2` expose a separate, per-layer projection matrix for mapping the student embedding output and each student Transformer intermediate output to the corresponding teacher output, whereas `tinybert-6l-768d`, `tinybert-4l-312d`, `tinybert-6l-768d-zh`, and `tinybert-4l-312-zh` share a single projection matrix across layers.
### Install PaddleNLP and Paddle

This tutorial compresses the BERT model from PaddleNLP, so it depends on PaddleNLP and Paddle.

```shell
pip install paddlenlp
pip install paddlepaddle_gpu
```
## Data and pretrained models

The experiments use the training sets of the GLUE datasets as training corpora and evaluate the models on the corresponding dev sets.

When running the experiments in this directory, the datasets are downloaded automatically to `paddlenlp.utils.env.DATA_HOME`. For example, on Linux the GLUE QQP dataset is stored under `~/.paddlenlp/datasets/Glue/QQP` by default.

For the BERT fine-tuning task, this experiment uses the pretrained model `bert-base-uncased`. Pretrained models are likewise downloaded automatically to `paddlenlp.utils.env.MODEL_HOME`; for example, on Linux `bert-base-uncased` is downloaded to `~/.paddlenlp/models/bert-base-uncased`.
## Distillation procedure

### Fine-tune BERT to obtain the teacher model

First fine-tune the pretrained model on the actual downstream task to obtain the model to be compressed; see the [fine-tuning tutorial](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/examples/language_model/bert/README.md).

After training, save the best model under `pretrained_models/$TASK_NAME/` in this project. The model directory contains `model_config.json`, `model_state.pdparams`, `tokenizer_config.json`, and `vocab.txt`.
### Task-specific distillation of TinyBERT

Distill the intermediate layers first:

```shell
export CUDA_VISIBLE_DEVICES=0
export TASK_NAME=SST-2
export TEACHER_DIR=./pretrained_models/SST-2/best_model_610

python task_distill.py \
    --model_type tinybert \
    --student_model_name_or_path tinybert-6l-768d-v2 \
    --task_name $TASK_NAME \
    --intermediate_distill \
    --max_seq_length 64 \
    --batch_size 32 \
    --T 1 \
    --teacher_model_type bert \
    --teacher_path $TEACHER_DIR \
    --learning_rate 5e-5 \
    --num_train_epochs 20 \
    --logging_steps 10 \
    --save_steps 10 \
    --output_dir ./tmp/$TASK_NAME/ \
    --distill_config ./distill_stage1.yaml \
    --device gpu
```
The parameters are:

- `model_type`: student model type; currently only tinybert is supported (the default).
- `student_model_name_or_path`: directory of the student model (after intermediate-layer distillation).
- `distill_config`: distillation configuration file.
- `max_seq_length`: maximum sentence length; longer inputs are truncated. Default: 128.
- `T`: softmax temperature, used to smooth the softmax and amplify the contribution of negative labels during training (see the sketch after this list). Default: 1.
- `teacher_model_type`: teacher model type; currently only bert is supported (the default).
- `teacher_path`: directory of the fine-tuned teacher model.
- `output_dir`: directory where the student model is saved.
- `device`: device to run the program on. Default: gpu.
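For intuition about `T`, here is a small illustrative sketch (plain NumPy, not part of this repo) of temperature-scaled softmax; dividing the logits by T before the softmax flattens the distribution, so the non-target ("negative") classes carry more signal during distillation:

```python
import numpy as np

def softmax_with_temperature(logits, T=1.0):
    # Divide logits by the temperature before the softmax; a larger T
    # produces a softer distribution over classes.
    scaled = np.asarray(logits, dtype=np.float64) / T
    scaled -= scaled.max()  # numerical stability
    exp = np.exp(scaled)
    return exp / exp.sum()

logits = [4.0, 1.0, 0.5]
print(softmax_with_temperature(logits, T=1.0))  # sharp, roughly [0.93, 0.05, 0.03]
print(softmax_with_temperature(logits, T=4.0))  # softer, roughly [0.53, 0.25, 0.22]
```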
Then distill the prediction layer:

```shell
export TEACHER_DIR=../pretrained_models/SST-2/best_model_610

python task_distill.py \
    --model_type tinybert \
    --student_model_name_or_path tmp/TASK_NAME/best_inter_model \
    --task_name $TASK_NAME \
    --max_seq_length 64 \
    --batch_size 32 \
    --T 1 \
    --teacher_model_type bert \
    --teacher_path $TEACHER_DIR \
    --learning_rate 3e-5 \
    --num_train_epochs 3 \
    --logging_steps 10 \
    --save_steps 10 \
    --output_dir ./tmp/$TASK_NAME/ \
    --distill_config ./distill_stage2.yaml \
    --device gpu
```
All parameters have the same meaning as described above.
### Hyperparameters used in the experiments
| | SST-2 | QQP | MRPC | CoLA | RTE | MNLI | QNLI |
| -------------------------------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- |
| batch_size | 32 | 32 | 32 | 32 | 32 | 32 | 32 |
| max_seq_length | 64 | 128 | 128 | 64 | 128 | 128 | 128 |
| max_epochs_of_intermediate_layer | 20 | 10 | 20 | 50 | 20 | 10 | 10 |
| max_epochs_of_prediction_layer | 3 | 3 | 3 | 3 | 3 | 3 | 3 |
| learning_rate(inter/pred) | 5e-5/3e-5 | 5e-5/3e-5 | 5e-5/3e-5 | 5e-5/3e-5 | 5e-5/3e-5 | 5e-5/3e-5 | 5e-5/3e-5 |
## Distillation results

The experiments in this document start from the 6-layer, hidden-size-768 model obtained by general distillation, train on the original datasets (without data augmentation), and evaluate on the dev sets, giving the following results:
| | SST-2 | QQP(acc/f1) | MRPC(acc/f1) | CoLA | RTE | MNLI-m | MNLI-mm | QNLI |
| ----------------- | ----- | ----------- | ------------ | ----- | ----- | ------ | ------- | ----- |
| BERT-base | 93.00 | 90.58/87.35 | 88.23/91.67 | 59.56 | 73.65 | 84.42 | 84.83 | 91.78 |
| TinyBERT(6l-768d) | 93.00 | 91.13/88.20 | 88.48/91.91 | 52.64 | 72.94 | 84.57 | 84.63 | 91.36 |
## References

Jiao X, Yin Y, Shang L, et al. [TinyBERT: Distilling BERT for Natural Language Understanding](https://arxiv.org/abs/1909.10351). arXiv preprint arXiv:1909.10351v5, 2020.
demo/dygraph/dist/bert/distill_stage1.yaml (new file, 0 → 100644)

```yaml
- DistillConfig:
  - loss_function: MSELoss
    model_name_pairs:
    - - student_0
      - teacher_0
    weight: 1.0
    layers:
    - layers_name: ['tinybert.embeddings', 'bert.embeddings']
    - layers_name: ['tinybert.encoder.layers.0', 'bert.encoder.layers.1']
    - layers_name: ['tinybert.encoder.layers.1', 'bert.encoder.layers.3']
    - layers_name: ['tinybert.encoder.layers.2', 'bert.encoder.layers.5']
    - layers_name: ['tinybert.encoder.layers.3', 'bert.encoder.layers.7']
    - layers_name: ['tinybert.encoder.layers.4', 'bert.encoder.layers.9']
    - layers_name: ['tinybert.encoder.layers.5', 'bert.encoder.layers.11']
    - layers_name: ['tinybert.encoder.layers.0.self_attn', 'bert.encoder.layers.1.self_attn']
    - layers_name: ['tinybert.encoder.layers.1.self_attn', 'bert.encoder.layers.3.self_attn']
    - layers_name: ['tinybert.encoder.layers.2.self_attn', 'bert.encoder.layers.5.self_attn']
    - layers_name: ['tinybert.encoder.layers.3.self_attn', 'bert.encoder.layers.7.self_attn']
    - layers_name: ['tinybert.encoder.layers.4.self_attn', 'bert.encoder.layers.9.self_attn']
    - layers_name: ['tinybert.encoder.layers.5.self_attn', 'bert.encoder.layers.11.self_attn']
```
demo/dygraph/dist/bert/distill_stage2.yaml (new file, 0 → 100644)

```yaml
- DistillConfig:
  - loss_function: CELoss
    model_name_pairs:
    - - student_0
      - teacher_0
    weight: 1.0
    layers:
    - layers_name: ['classifier', 'classifier']
      temperature: 1.0
```
demo/dygraph/dist/bert/run.sh (new file, 0 → 100644)

```shell
export CUDA_VISIBLE_DEVICES=0
export TASK_NAME=SST-2
export TEACHER_DIR=/root/work/Distill_PaddleSlim/PaddleNLP/examples/model_compression/tinybert/best_model_610

python3.7 task_distill.py \
    --model_type tinybert \
    --student_model_name_or_path tinybert-6l-768d-v2 \
    --task_name $TASK_NAME \
    --intermediate_distill \
    --max_seq_length 64 \
    --batch_size 32 \
    --T 1 \
    --teacher_model_type bert \
    --teacher_path $TEACHER_DIR \
    --learning_rate 5e-5 \
    --num_train_epochs 20 \
    --logging_steps 10 \
    --save_steps 10 \
    --output_dir ./tmp/$TASK_NAME/ \
    --device gpu
```
demo/dygraph/dist/bert/task_distill.py (new file, 0 → 100644)

```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os
import sys
import random
import time
import math
from functools import partial

import numpy as np
import paddle
from paddle.io import DataLoader
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.metric import Accuracy

from paddlenlp.datasets import load_dataset
from paddlenlp.data import Stack, Tuple, Pad, Dict
from paddlenlp.data.sampler import SamplerHelper
from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman
import paddlenlp.transformers as T

from paddleslim import Distill

FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)

METRIC_CLASSES = {
    "cola": Mcc,
    "sst-2": Accuracy,
    "mrpc": AccuracyAndF1,
    "sts-b": PearsonAndSpearman,
    "qqp": AccuracyAndF1,
    "mnli": Accuracy,
    "qnli": Accuracy,
    "rte": Accuracy,
}

MODEL_CLASSES = {
    "bert": (T.BertForSequenceClassification, T.BertTokenizer),
    "tinybert": (T.TinyBertForSequenceClassification, T.TinyBertTokenizer),
}


def parse_args():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--task_name", default=None, type=str, required=True,
        help="The name of the task to train selected in the list: " +
        ", ".join(METRIC_CLASSES.keys()))
    parser.add_argument(
        "--model_type", default="tinybert", type=str, required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument(
        "--teacher_model_type", default="bert", type=str, required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument(
        "--student_model_name_or_path", default=None, type=str, required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " +
        ", ".join(
            sum([
                list(classes[-1].pretrained_init_configuration.keys())
                for classes in MODEL_CLASSES.values()
            ], [])))
    parser.add_argument(
        "--distill_config", default=None, type=str,
        help="distill config file path")
    parser.add_argument(
        "--teacher_path", default=None, type=str, required=True,
        help="Path to pre-trained model.")
    parser.add_argument(
        "--output_dir", default=None, type=str, required=True,
        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument(
        "--glue_dir", default="/root/.paddlenlp/datasets/Glue/", type=str,
        required=False, help="The Glue directory.")
    parser.add_argument(
        "--max_seq_length", default=128, type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument(
        "--learning_rate", default=1e-4, type=float,
        help="The initial learning rate for Adam.")
    parser.add_argument(
        "--num_train_epochs", default=3, type=int,
        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--logging_steps", type=int, default=100,
        help="Log every X updates steps.")
    parser.add_argument(
        "--save_steps", type=int, default=100,
        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--batch_size", default=32, type=int,
        help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--T", default=1, type=int, help="Temperature for softmax")
    parser.add_argument(
        "--use_aug", action="store_true",
        help="Whether to use augmentation data to train.")
    parser.add_argument(
        "--intermediate_distill", action="store_true",
        help="Whether distilling intermediate layers. If False, it means prediction layer distillation.")
    parser.add_argument(
        "--weight_decay", default=0.0, type=float,
        help="Weight decay if we apply some.")
    parser.add_argument(
        "--warmup_steps", default=0, type=int,
        help="Linear warmup over warmup_steps. If > 0: Override warmup_proportion")
    parser.add_argument(
        "--warmup_proportion", default=0.1, type=float,
        help="Linear warmup proportion over total steps.")
    parser.add_argument(
        "--adam_epsilon", default=1e-6, type=float,
        help="Epsilon for Adam optimizer.")
    parser.add_argument(
        "--max_steps", default=-1, type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument(
        "--seed", default=42, type=int, help="random seed for initialization")
    parser.add_argument(
        "--device", default="gpu", type=str,
        help="The device to select to train the model, is must be cpu/gpu/xpu.")
    args = parser.parse_args()
    return args


def set_seed(args):
    # Use the same data seed(for data shuffle) for all procs to guarantee data
    # consistency after sharding.
    random.seed(args.seed)
    np.random.seed(args.seed)
    # Maybe different op seeds(for dropout) for different procs is better. By:
    # `paddle.seed(args.seed + paddle.distributed.get_rank())`
    paddle.seed(args.seed)


@paddle.no_grad()
def evaluate(model, metric, data_loader):
    model.eval()
    metric.reset()
    for batch in data_loader:
        input_ids, segment_ids, labels = batch
        logits = model(input_ids, segment_ids)
        correct = metric.compute(logits, labels)
        metric.update(correct)
    res = metric.accumulate()
    if isinstance(metric, AccuracyAndF1):
        print("acc: %s, precision: %s, recall: %s, f1: %s, acc and f1: %s, " %
              (res[0], res[1], res[2], res[3], res[4]), end='')
    elif isinstance(metric, Mcc):
        print("mcc: %s, " % (res[0]), end='')
    elif isinstance(metric, PearsonAndSpearman):
        print("pearson: %s, spearman: %s, pearson and spearman: %s, " %
              (res[0], res[1], res[2]), end='')
    else:
        print("acc: %s, " % (res), end='')
    model.train()
    return res[0] if isinstance(
        metric, (AccuracyAndF1, Mcc, PearsonAndSpearman)) else res


def convert_example(example, tokenizer, label_list, max_seq_length=512,
                    is_test=False):
    """convert a glue example into necessary features"""
    if not is_test:
        # `label_list == None` is for regression task
        label_dtype = "int64" if label_list else "float32"
        # Get the label
        label = example['labels']
        label = np.array([label], dtype=label_dtype)
    # Convert raw text to feature
    if (int(is_test) + len(example)) == 2:
        example = tokenizer(example['sentence'], max_seq_len=max_seq_length)
    else:
        example = tokenizer(
            example['sentence1'],
            text_pair=example['sentence2'],
            max_seq_len=max_seq_length)

    if not is_test:
        return example['input_ids'], example['token_type_ids'], label
    else:
        return example['input_ids'], example['token_type_ids']


def do_train(args):
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args)

    args.task_name = args.task_name.lower()
    metric_class = METRIC_CLASSES[args.task_name]
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    if args.use_aug:
        aug_data_file = os.path.join(
            os.path.join(args.glue_dir, args.task_name), "train_aug.tsv"),
        train_ds = load_dataset('glue', args.task_name, data_files=aug_data_file)
    else:
        train_ds = load_dataset('glue', args.task_name, splits='train')

    tokenizer = tokenizer_class.from_pretrained(args.student_model_name_or_path)

    trans_func = partial(
        convert_example,
        tokenizer=tokenizer,
        label_list=train_ds.label_list,
        max_seq_length=args.max_seq_length)
    train_ds = train_ds.map(trans_func, lazy=True)
    train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_ds, batch_size=args.batch_size, shuffle=True)
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),        # input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),   # segment
        Stack(dtype="int64" if train_ds.label_list else "float32")  # label
    ): fn(samples)
    train_data_loader = DataLoader(
        dataset=train_ds,
        batch_sampler=train_batch_sampler,
        collate_fn=batchify_fn,
        num_workers=0,
        return_list=True)

    if args.task_name == "mnli":
        dev_ds_matched, dev_ds_mismatched = load_dataset(
            'glue', args.task_name, splits=["dev_matched", "dev_mismatched"])
        dev_ds_matched = dev_ds_matched.map(trans_func, lazy=True)
        dev_ds_mismatched = dev_ds_mismatched.map(trans_func, lazy=True)
        dev_batch_sampler_matched = paddle.io.BatchSampler(
            dev_ds_matched, batch_size=args.batch_size, shuffle=False)
        dev_data_loader_matched = DataLoader(
            dataset=dev_ds_matched,
            batch_sampler=dev_batch_sampler_matched,
            collate_fn=batchify_fn,
            num_workers=0,
            return_list=True)
        dev_batch_sampler_mismatched = paddle.io.BatchSampler(
            dev_ds_mismatched, batch_size=args.batch_size, shuffle=False)
        dev_data_loader_mismatched = DataLoader(
            dataset=dev_ds_mismatched,
            batch_sampler=dev_batch_sampler_mismatched,
            collate_fn=batchify_fn,
            num_workers=0,
            return_list=True)
    else:
        dev_ds = load_dataset('glue', args.task_name, splits='dev')
        dev_ds = dev_ds.map(trans_func, lazy=True)
        dev_batch_sampler = paddle.io.BatchSampler(
            dev_ds, batch_size=args.batch_size, shuffle=False)
        dev_data_loader = DataLoader(
            dataset=dev_ds,
            batch_sampler=dev_batch_sampler,
            collate_fn=batchify_fn,
            num_workers=0,
            return_list=True)

    num_classes = 1 if train_ds.label_list == None else len(train_ds.label_list)

    student = model_class.from_pretrained(
        args.student_model_name_or_path, num_classes=num_classes)
    teacher_model_class, _ = MODEL_CLASSES[args.teacher_model_type]
    teacher = teacher_model_class.from_pretrained(
        args.teacher_path, num_classes=num_classes)
    teacher.eval()

    if paddle.distributed.get_world_size() > 1:
        student = paddle.DataParallel(student, find_unused_parameters=True)
        teacher = paddle.DataParallel(teacher, find_unused_parameters=True)

    if args.max_steps > 0:
        num_training_steps = args.max_steps
        num_train_epochs = math.ceil(num_training_steps / len(train_data_loader))
    else:
        num_training_steps = len(train_data_loader) * args.num_train_epochs
        num_train_epochs = args.num_train_epochs

    warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion
    lr_scheduler = T.LinearDecayWithWarmup(args.learning_rate,
                                           num_training_steps, warmup)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in student.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        beta1=0.9,
        beta2=0.999,
        epsilon=args.adam_epsilon,
        parameters=student.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    metric = metric_class()

    pad_token_id = 0
    global_step = 0
    tic_train = time.time()
    best_res = 0.0

    assert os.path.exists(args.distill_config), "distill file {} not exist.".format(
        args.distill_config)
    distill_model = Distill(
        args.distill_config, student_models=[student], teacher_models=[teacher])

    for epoch in range(num_train_epochs):
        for step, batch in enumerate(train_data_loader):
            global_step += 1
            input_ids, segment_ids, labels = batch
            loss, _, _ = distill_model(input_ids, segment_ids)

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()
            if global_step % args.logging_steps == 0:
                print(
                    "global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s"
                    % (global_step, num_training_steps, epoch, step,
                       paddle.distributed.get_rank(), loss, optimizer.get_lr(),
                       args.logging_steps / (time.time() - tic_train)))
                tic_train = time.time()

            if global_step % args.save_steps == 0 or global_step == num_training_steps:
                tic_eval = time.time()
                if args.task_name == "mnli":
                    res = evaluate(student, metric, dev_data_loader_matched)
                    evaluate(student, metric, dev_data_loader_mismatched)
                    print("eval done total : %s s" % (time.time() - tic_eval))
                else:
                    res = evaluate(student, metric, dev_data_loader)
                    print("eval done total : %s s" % (time.time() - tic_eval))

                if (best_res < res and global_step < num_training_steps or
                        global_step == num_training_steps
                    ) and paddle.distributed.get_rank() == 0:
                    if global_step < num_training_steps:
                        output_dir = os.path.join(
                            args.output_dir,
                            "distill_model_%d.pdparams" % (global_step))
                    else:
                        output_dir = os.path.join(args.output_dir,
                                                  "distill_model_final.pdparams")
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Need better way to get inner model of DataParallel
                    model_to_save = student._layers if isinstance(
                        student, paddle.DataParallel) else student
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    best_res = res

            if global_step >= num_training_steps:
                return


def print_arguments(args):
    """print arguments"""
    print('-----------  Configuration Arguments -----------')
    for arg, value in sorted(vars(args).items()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------------')


if __name__ == "__main__":
    args = parse_args()
    print_arguments(args)
    do_train(args)
```
paddleslim/dygraph/dist/__init__.py (modified, +2 -0)

```diff
@@ -14,7 +14,9 @@
 from . import distill
 from .distill import *
+from .distill_helpers import *
 __all__ = []
 __all__ += distill.__all__
+__all__ += distill_helpers.__all__
```
paddleslim/dygraph/dist/distill.py (modified, +156 -163)

This commit rewrites the `Distill` API: the old adaptor-based interface (`AdaptorBase`, `mapping_layers`, and the `s_feature_idx`/`t_feature_idx`/`feature_type` config keys, together with `_get_distill_idx` and `_process_outputs`) is removed, and distillation targets are now addressed directly by sublayer name (`layers_name`), with losses resolved through the `BASIC_LOSS` registry and configs loadable from YAML. The new version of the changed region (hunk `@@ -17,207 +17,200 @@`), reconstructed from the diff:

```python
from collections import namedtuple
import paddle.nn as nn

from . import losses
from .losses.basic_loss import BASIC_LOSS
from .distill_helpers import yaml2config

__all__ = ['Distill']


class LayerConfig:
    """ The key of config can be set"""

    def __init__(self,
                 model_name_pairs,
                 layers_name,
                 loss_function,
                 weight=1.0,
                 temperature=1.0,
                 align_params=None,
                 **loss_params):
        self.model_name_pairs = model_name_pairs
        self.layers_name = layers_name
        if loss_function not in BASIC_LOSS.module_dict:
            raise NotImplementedError(
                "loss function {} is not support. "
                "Support loss including {}".format(
                    loss_function, BASIC_LOSS.module_dict.keys()))
        self.loss_function = loss_function
        self.weight = weight
        self.temperature = temperature
        self.align_params = align_params
        for k, v in loss_params.items():
            setattr(self, k, v)


def _add_hooks(model, outs, hook_layers_name):
    """
    Get output by layer name.
    models(nn.Layer): model need to be add hook.
    outs(dict): save the middle outputs of model according to the name.
    hook_layers_name(list): name of middle layers.
    """

    ### TODO: need to support get input tensor
    def _get_activation(outs, name):
        def get_output_hook(layer, input, output):
            outs[name] = output

        return get_output_hook

    ### TODO: support DP model
    for idx, (n, m) in enumerate(model.named_sublayers()):
        if n in hook_layers_name:
            m.register_forward_post_hook(_get_activation(outs, n))


class Distill(nn.Layer):
    """
    Distill API.
    distill_configs(list(dict) | path): the list of distill config.
    student_models(list(nn.Layer)): the list of student model, the state of student model must be training mode.
    teacher_models(list(nn.Layer)): the list of teacher model, the state of student model must be evaluate mode.
    return_model_outputs(bool): whether to return model output. Default: True.
    """

    def __init__(self,
                 distill_configs,
                 student_models,
                 teacher_models,
                 return_model_outputs=True):
        super(Distill, self).__init__()
        if isinstance(student_models, nn.Layer):
            student_models = [student_models]
        if isinstance(teacher_models, nn.Layer):
            teacher_models = [teacher_models]
        for student_model in student_models:
            assert student_model.training, "The student model should not be eval mode."
        for teacher_model in teacher_models:
            assert teacher_model.training is False, "The teacher model should be eval mode."

        if isinstance(distill_configs, list):
            self._distill_configs = distill_configs
        elif os.path.exists(distill_configs):
            if distill_configs.endswith(".yaml"):
                self._distill_configs = yaml2config(distill_configs)
            else:
                raise NotImplementedError("distill config file type error!")
        else:
            raise NotImplementedError("distill config error!")

        self._student_models = student_models
        self._teacher_models = teacher_models
        self._return_model_outputs = return_model_outputs

        self._loss_config_list = []
        for c in self._distill_configs:
            self._transpose_config(c)

        self._hook_layers = self._extract_hook_position()
        # use self._loss_config_list to create all loss object
        self.distill_loss = losses.CombinedLoss(self._loss_config_list)
        self._output_tensor_dict = self._prepare_outputs()

    def parameters(self):
        params = []
        for s_model in self._student_models:
            params.extend(s_model.parameters())
        return params

    def _extract_hook_position(self):
        """ extrat hook position according to config"""
        model_hook_layers = {}
        for config in self._loss_config_list:
            model_name_pairs = config['model_name_pairs']
            layers_name = config['layers_name']
            for model_name_pair in model_name_pairs:
                for idx, model_name in enumerate(model_name_pair):
                    if model_name not in model_hook_layers:
                        model_hook_layers[model_name] = [layers_name[idx]]
                    else:
                        model_hook_layers[model_name].append(layers_name[idx])
        for model_name, hook_layers in model_hook_layers.items():
            model_hook_layers[model_name] = list(set(hook_layers))
        return model_hook_layers

    def _transpose_config(self, config):
        """ Transpose config to loss needed """
        global_config = {}
        if 'model_name_pairs' not in config:
            global_config['model_name_pairs'] = [['student_0', 'teacher_0']]
        else:
            if isinstance(config['model_name_pairs'][0], str):
                config['model_name_pairs'] = [config['model_name_pairs']]
            global_config['model_name_pairs'] = config['model_name_pairs']
            config.pop('model_name_pairs')
        for key in config.keys():
            if key != 'layers':
                global_config[key] = config[key]
        for per_layer_config in config['layers']:
            per_layer_config.update(global_config)
            self._loss_config_list.append(
                LayerConfig(**per_layer_config).__dict__)

    def _prepare_outputs(self):
        """
        Add hook to get the output tensor of target layer.
        """
        outputs_tensor = {}
        for idx, m in enumerate(self._student_models):
            hook_layers = self._hook_layers['student_{}'.format(idx)]
            stu_outs = collections.OrderedDict()
            outputs_tensor['student_{}'.format(idx)] = self._prepare_hook(
                m, hook_layers, stu_outs)
        for idx, m in enumerate(self._teacher_models):
            hook_layers = self._hook_layers['teacher_{}'.format(idx)]
            tea_outs = collections.OrderedDict()
            outputs_tensor['teacher_{}'.format(idx)] = self._prepare_hook(
                m, hook_layers, tea_outs)
        return outputs_tensor

    def _prepare_hook(self, model, hook_layers, outs_dict):
        """
        Add hook.
        """
        for layer in hook_layers:
            if isinstance(layer, str):
                _add_hooks(model, outs_dict, layer)
        return outs_dict

    def forward(self, *inputs, **kwargs):
        students_batch_outs = []
        teachers_batch_outs = []
        for idx, student_model in enumerate(self._student_models):
            stu_batch_outs = student_model.forward(*inputs, **kwargs)
            students_batch_outs.append(stu_batch_outs)
        # get all target tensor
        for idx, teacher_model in enumerate(self._teacher_models):
            tea_batch_outs = teacher_model.forward(*inputs, **kwargs)
            if not teacher_model.training:
                tea_batch_outs = [i.detach() for i in tea_batch_outs]
            teachers_batch_outs.extend(tea_batch_outs)
        if len(self._student_models) == 1:
            students_batch_outs = students_batch_outs[0]
        if len(self._teacher_models) == 1:
            teachers_batch_outs = teachers_batch_outs[0]
        ### batch is None just for now
        distill_outputs = self.distill_loss(self._output_tensor_dict, None)
        distill_loss = distill_outputs['loss']

        if self._return_model_outputs:
            return distill_loss, students_batch_outs, teachers_batch_outs
        else:
            return distill_loss
```
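To make the new entry point concrete, here is a minimal hedged usage sketch of the rewritten `Distill` API. The toy `Net` model, its sublayer names `fc1`/`fc2`, and the optimizer setup are assumptions made for illustration; the config keys mirror the docstring above and the repository's own test, not any file in this commit.

```python
import paddle
import paddle.nn as nn
from paddleslim.dygraph.dist import Distill

# Toy student/teacher; the sublayer names ("fc1", "fc2") are assumptions
# chosen so they can be referenced from the distill config below.
class Net(nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 16)
        self.fc2 = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc2(self.fc1(x))

student, teacher = Net(), Net()
teacher.eval()  # the teacher must be in eval mode

distill_configs = [{
    'loss_function': 'MSELoss',
    'layers': [{'layers_name': ['fc1', 'fc1']}],
}]

distill_model = Distill(distill_configs, student, teacher)
optimizer = paddle.optimizer.Adam(parameters=distill_model.parameters())

x = paddle.randn([8, 16])
# With return_model_outputs=True (the default), forward returns
# (distill_loss, student_outputs, teacher_outputs).
loss, student_out, teacher_out = distill_model(x)
loss.backward()
optimizer.step()
optimizer.clear_grad()
```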
paddleslim/dygraph/dist/distill_helpers.py (new file, 0 → 100644)

```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import yaml

__all__ = ['config2yaml']


def yaml2config(yaml_path):
    """
    convert yaml to dict config.
    """
    final_configs = []
    f = open(yaml_path, 'r')
    origin_configs = yaml.load(f, Loader=yaml.FullLoader)
    f.close()
    for configs in origin_configs:
        configs = configs['DistillConfig']
        final_configs.extend(configs)
    return final_configs


def config2yaml(configs, yaml_path):
    """
    convert dict config to yaml.
    """
    final_yaml = dict()
    final_yaml['DistillConfig'] = configs
    f = open(yaml_path, "w")
    yaml.dump([final_yaml], f)
    f.close()
```
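A small hedged round-trip sketch of these helpers; the in-memory config and the output path `./demo_distill.yaml` are illustrative assumptions. Note that `yaml2config` is not exported via `__all__`, so it is imported from the module directly here, which is itself an assumption about the intended usage.

```python
from paddleslim.dygraph.dist import config2yaml
from paddleslim.dygraph.dist.distill_helpers import yaml2config  # assumed import path

configs = [{
    'loss_function': 'MSELoss',
    'weight': 1.0,
    'layers': [{'layers_name': ['conv1', 'conv1']}],
}]

config2yaml(configs, './demo_distill.yaml')    # wraps configs under 'DistillConfig'
restored = yaml2config('./demo_distill.yaml')  # flattens it back into a config list
assert restored == configs
```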
paddleslim/dygraph/dist/losses/__init__.py (modified, +13 -28)

The individual loss classes are no longer re-exported here; `CombinedLoss` now consumes a flat config list and builds one `DistillationLoss` per entry. The reconstructed new imports and the changed part of `CombinedLoss`:

```python
from . import basic_loss
from . import distillation_loss
from .distillation_loss import DistillationLoss


class CombinedLoss(nn.Layer):
    """
    ...
    loss_config_list: a config list used to build loss function. A demo is as follows,
        which is used to calculate dml loss between Student output and
        Teacher output. Parameter weight is needed for the loss weight.
        { loss_function: DMLLoss
          weight: 1.0
          act: "softmax"
          model_name_pairs: ["student_0", "teacher_0"]}
        Another example is {loss_function: "MSELoss", 'weight': 1.0,
        'layers_name': ['conv0', 'conv0'], 'model_name_pairs': [['student', 'teacher']]}
    """

    def __init__(self, loss_config_list=None):
        ...
        self.loss_weight = []
        assert isinstance(loss_config_list, list), (
            'operator config should be a list')
        for config in loss_config_list:
            assert isinstance(
                config,
                dict), "config must be a dict, but now is {}".format(
                    type(config))
            assert "weight" in config, "weight must be in param, but param just contains {}".format(
                config.keys())
            self.loss_weight.append(config.pop("weight"))
            self.loss_func.append(DistillationLoss(**config))
```

The `forward` method is otherwise unchanged apart from a one-line addition in the loss-aggregation hunk (`@@ -82,6 +66,7 @@`).
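A hedged sketch of driving `CombinedLoss` directly with the flat config format from the docstring above; the tensor name (`conv0`) and the shapes are made-up illustrations.

```python
import paddle
from paddleslim.dygraph.dist.losses import CombinedLoss

loss_func = CombinedLoss([{
    'loss_function': 'MSELoss',
    'weight': 1.0,
    'layers_name': ['conv0', 'conv0'],
    'model_name_pairs': [['student', 'teacher']],
}])

# Each model name maps to a dict of named intermediate tensors.
predicts = {
    'student': {'conv0': paddle.randn([4, 8])},
    'teacher': {'conv0': paddle.randn([4, 8])},
}
out = loss_func(predicts, None)
print(out['loss'])  # weighted sum of all configured distillation losses
```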
paddleslim/dygraph/dist/losses/basic_loss.py (modified, +57 -3)

The basic losses are now collected in a `Registry` named `BASIC_LOSS` instead of being exported by name: `__all__` shrinks to `["BASIC_LOSS"]`, each existing class (`CELoss`, `DMLLoss`, `KLLoss`, `DistanceLoss`, `RKdAngle`, `RkdDistance`) gains a `@BASIC_LOSS.register` decorator, and four registered convenience classes are appended at the end of the file. The reconstructed additions:

```python
from ....core import Registry

__all__ = ["BASIC_LOSS"]

BASIC_LOSS = Registry("basicloss")


# Each existing loss class is now decorated, e.g.:
@BASIC_LOSS.register
class CELoss(nn.Layer):
    ...


# New registered wrappers appended at the end of the file:
@BASIC_LOSS.register
class MSELoss(DistanceLoss):
    """
    MSELoss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/MSELoss_cn.html#mseloss
    """

    def __init__(self, **kargs):
        super().__init__(mode='l2', **kargs)


@BASIC_LOSS.register
class L1Loss(DistanceLoss):
    """
    L1loss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/L1Loss_cn.html#l1loss
    """

    def __init__(self, **kargs):
        super().__init__(mode='l1', **kargs)


@BASIC_LOSS.register
class SmoothL1Loss(DistanceLoss):
    """
    SmoothL1Loss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/SmoothL1Loss_cn.html#smoothl1loss
    """

    def __init__(self, **kargs):
        super().__init__(mode='smooth_l1', **kargs)


@BASIC_LOSS.register
class RKDLoss(nn.Layer):
    """
    RKDLoss
    """

    def __init__(self, eps=1e-12):
        super().__init__()
        self.rkd_angle_loss_func = RKdAngle()
        self.rkd_dist_func = RkdDistance(eps=eps)

    def forward(self, student, teacher):
        angle_loss = self.rkd_angle_loss_func(student, teacher)
        dist_loss = self.rkd_dist_func(student, teacher)
        return angle_loss + dist_loss
```
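A brief hedged sketch of the registry in use, mirroring how `DistillationLoss` resolves losses through `BASIC_LOSS.get`; the tensors are random placeholders.

```python
import paddle
from paddleslim.dygraph.dist.losses.basic_loss import BASIC_LOSS

# Look up a registered loss class by name and instantiate it.
mse = BASIC_LOSS.get('MSELoss')()

x = paddle.randn([4, 8])
y = paddle.randn([4, 8])
print(mse(x, y))  # scalar MSE between the two tensors
```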
paddleslim/dygraph/dist/losses/distillation_loss.py (modified, +31 -187)

The wrapper classes `DistillationDMLLoss`, `DistillationDistanceLoss`, `DistillationRKDLoss`, `SegPairWiseLoss`, and `SegChannelwiseLoss` are removed and replaced by a single `DistillationLoss` that looks up the underlying loss in the `BASIC_LOSS` registry. The reconstructed new file body:

```python
import paddle
import paddle.nn as nn

from .basic_loss import BASIC_LOSS

__all__ = ["DistillationLoss"]


class DistillationLoss(nn.Layer):
    """
    DistillationLoss
    Args:
        model_name_pairs(list | tuple): model name pairs to extract submodel output.
        layers_name(list(string)): keys of the tensor used to calculate loss if the submodel.
        loss_function(string): the name of loss function.
        temperature(float): the temperature to compute distill loss.
    """

    def __init__(self,
                 model_name_pairs=[],
                 layers_name=None,
                 loss_function=None,
                 temperature=1.0,
                 **params):
        super().__init__()
        self.model_name_pairs = model_name_pairs
        self.layers_name = layers_name
        self.loss_function = loss_function
        self.temperature = temperature
        self.align_params = params.pop(
            'align_params') if 'align_params' in params else None
        if self.align_params is not None:
            for attr, value in self.align_params.items():
                setattr(self, attr, value)
        self.loss_func = BASIC_LOSS.get(loss_function)(**params)

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.layers_name != None:
                assert len(self.layers_name
                           ) == 2, "length of layers_name must be equal to 2."
                out1 = out1[self.layers_name[0]]
                out2 = out2[self.layers_name[1]]
            if self.temperature != 1.0:
                out1 = out1 / self.temperature
                out2 = out2 / self.temperature
            loss_dict["{}_{}_{}_{}_{}".format(
                self.loss_function, pair[0], pair[1],
                self.layers_name[0] if self.layers_name != None else "0",
                self.layers_name[1] if self.layers_name != None else
                "0")] = self.loss_func(out1, out2)
        return loss_dict
```
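A hedged sketch of the unified `DistillationLoss` on its own, using the `CELoss`/`classifier` pairing from `distill_stage2.yaml`; the tensors and batch size are illustrative, and the assumption that `CELoss` accepts raw teacher logits as a soft target follows from its use in that config rather than from this diff.

```python
import paddle
from paddleslim.dygraph.dist.losses.distillation_loss import DistillationLoss

loss_fn = DistillationLoss(
    model_name_pairs=[['student_0', 'teacher_0']],
    layers_name=['classifier', 'classifier'],
    loss_function='CELoss',
    temperature=1.0)

predicts = {
    'student_0': {'classifier': paddle.randn([4, 3])},
    'teacher_0': {'classifier': paddle.randn([4, 3])},
}
loss_dict = loss_fn(predicts, None)
# Keys look like 'CELoss_student_0_teacher_0_classifier_classifier'.
print(loss_dict)
```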
tests/dygraph/test_distill.py
View file @ 36a9f6f0

@@ -7,7 +7,7 @@ import paddle
 import paddle.nn as nn
 from paddle.vision.models import MobileNetV1
 import paddle.vision.transforms as T
-from paddleslim.dygraph.dist import Distill, AdaptorBase
+from paddleslim.dygraph.dist import Distill, config2yaml
 from paddleslim.common.log_helper import get_logger

 _logger = get_logger(
@@ -19,42 +19,30 @@ class TestImperativeDistill(unittest.TestCase):
         self.s_model, self.t_model = self.prepare_model()
         self.t_model.eval()
         self.distill_configs = self.prepare_config()
-        self.adaptor = self.prepare_adaptor()

     def prepare_model(self):
         return MobileNetV1(), MobileNetV1()

     def prepare_config(self):
         distill_configs = [{
-            's_feature_idx': 0,
-            't_feature_idx': 0,
-            'feature_type': 'hidden',
-            'loss_function': 'l2'
+            'loss_function': 'MSELoss',
+            'layers': [
+                {
+                    "layers_name": ["conv1", "conv1"]
+                },
+                {
+                    "layers_name": ["conv2_2", "conv2_2"]
+                },
+            ]
         }, {
-            's_feature_idx': 1,
-            't_feature_idx': 1,
-            'feature_type': 'hidden',
-            'loss_function': 'l2'
-        }, {
-            's_feature_idx': 0,
-            't_feature_idx': 0,
-            'feature_type': 'logits',
-            'loss_function': 'l2'
+            'loss_function': 'CELoss',
+            'temperature': 1.0,
+            'layers': [{
+                "layers_name": ["fc", "fc"]
+            }, ]
         }]
         return distill_configs

-    def prepare_adaptor(self):
-        class Adaptor(AdaptorBase):
-            def mapping_layers(self):
-                mapping_layers = {}
-                mapping_layers['hidden_0'] = 'conv1'
-                mapping_layers['hidden_1'] = 'conv2_2'
-                mapping_layers['hidden_2'] = 'conv3_2'
-                mapping_layers['logits_0'] = 'fc'
-                return mapping_layers
-        return Adaptor
-
     def test_distill(self):
         transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
@@ -97,7 +85,7 @@ class TestImperativeDistill(unittest.TestCase):
         for batch_id, data in enumerate(train_reader):
             img = paddle.to_tensor(data[0])
             label = paddle.to_tensor(data[1])
-            student_out, teacher_out, distill_loss = model(img)
+            distill_loss, student_out, teacher_out = model(img)
             loss = paddle.nn.functional.loss.cross_entropy(student_out,
                                                            label)
             avg_loss = paddle.mean(loss)
@@ -112,7 +100,7 @@ class TestImperativeDistill(unittest.TestCase):
         self.s_model.train()
         distill_model = Distill(self.distill_configs, self.s_model,
-                                self.t_model, self.adaptor, self.adaptor)
+                                self.t_model)
         train(distill_model)
@@ -136,31 +124,26 @@ class TestImperativeDistillCase1(TestImperativeDistill):
         return Model(), Model()

-    def prepare_adaptor(self):
-        class Adaptor(AdaptorBase):
-            def mapping_layers(self):
-                mapping_layers = {}
-                mapping_layers['hidden_1'] = 'conv2'
-                if self.add_tensor:
-                    mapping_layers['hidden_0'] = self.model.conv1_out
-                mapping_layers['hidden_2'] = self.model.conv3_out
-                return mapping_layers
-        return Adaptor
-
     def prepare_config(self):
         distill_configs = [{
-            's_feature_idx': 0,
-            't_feature_idx': 0,
-            'feature_type': 'hidden',
-            'loss_function': 'l2'
+            'loss_function': 'MSELoss',
+            'layers': [
+                {
+                    "layers_name": ["conv1", "conv1"]
+                },
+                {
+                    "layers_name": ["conv2", "conv3"]
+                },
+            ]
         }, {
-            's_feature_idx': 1,
-            't_feature_idx': 2,
-            'feature_type': 'hidden',
-            'loss_function': 'l2'
+            'loss_function': 'CELoss',
+            'temperature': 1.0,
+            'layers': [{
+                "layers_name": ["fc", "fc"]
+            }, ]
         }]
-        return distill_configs
+        config2yaml(distill_configs, 'test.yaml')
+        return './test.yaml'


 if __name__ == '__main__':
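Taken together, the updated test exercises the new configuration-driven API: each entry of the config list names a loss_function and the student/teacher layers_name pairs, and the list can be passed to Distill directly or serialized to YAML via config2yaml. The sketch below restates that usage outside the unittest scaffolding; it assumes the PaddleSlim revision introduced by this commit, and the input shape is an arbitrary assumption.

```python
# Minimal usage sketch of the updated Distill API (assumes this commit's
# paddleslim); the input shape is an arbitrary assumption.
import paddle
from paddle.vision.models import MobileNetV1
from paddleslim.dygraph.dist import Distill, config2yaml

distill_configs = [{
    'loss_function': 'MSELoss',
    'layers': [
        {"layers_name": ["conv1", "conv1"]},
        {"layers_name": ["conv2_2", "conv2_2"]},
    ]
}, {
    'loss_function': 'CELoss',
    'temperature': 1.0,
    'layers': [{"layers_name": ["fc", "fc"]}]
}]
# The config list can also be dumped and referenced as a yaml path:
# config2yaml(distill_configs, 'test.yaml')

s_model, t_model = MobileNetV1(), MobileNetV1()
t_model.eval()
distill_model = Distill(distill_configs, s_model, t_model)

img = paddle.rand([4, 3, 224, 224])
distill_loss, student_out, teacher_out = distill_model(img)
```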
tests/dygraph/test_distillation_loss.py
View file @ 36a9f6f0

@@ -24,18 +24,14 @@ import paddle.nn.functional as F
 from paddleslim.dygraph.dist.losses import CombinedLoss
 # basic loss
-from paddleslim.dygraph.dist.losses import DistanceLoss
-from paddleslim.dygraph.dist.losses import CELoss
-from paddleslim.dygraph.dist.losses import DMLLoss
-from paddleslim.dygraph.dist.losses import RkdDistance
-from paddleslim.dygraph.dist.losses import RKdAngle
+from paddleslim.dygraph.dist.losses.basic_loss import DistanceLoss
+from paddleslim.dygraph.dist.losses.basic_loss import CELoss
+from paddleslim.dygraph.dist.losses.basic_loss import DMLLoss
+from paddleslim.dygraph.dist.losses.basic_loss import RkdDistance
+from paddleslim.dygraph.dist.losses.basic_loss import RKdAngle

 # distillation loss
-from paddleslim.dygraph.dist.losses import DistillationDistanceLoss
-from paddleslim.dygraph.dist.losses import DistillationRKDLoss
-from paddleslim.dygraph.dist.losses import DistillationDMLLoss
-from paddleslim.dygraph.dist.losses import SegPairWiseLoss
-from paddleslim.dygraph.dist.losses import SegChannelwiseLoss
+from paddleslim.dygraph.dist.losses import DistillationLoss

 import numpy as np
@@ -70,14 +66,13 @@ class TestDistanceLoss(unittest.TestCase):
         out = np.sum(diff)
         return out

     def dist_np_distance_loss(
             self,
             predicts,
+            loss_function=None,
             mode="l2",
             reduction="none",
             model_name_pairs=(["", ""]),
-            key=None,
-            name="loss_distance", ):
+            key=None):
         loss_dict = dict()
         for idx, pair in enumerate(model_name_pairs):
             out1 = predicts[pair[0]]
@@ -85,10 +80,12 @@ class TestDistanceLoss(unittest.TestCase):
             if key is not None:
                 out1 = out1[key]
                 out2 = out2[key]
+            else:
+                key = 0
             loss = self.np_distance_loss(
                 out1, out2, mode=mode, reduction=reduction)
-            loss_dict["{}_{}_{}_{}_{}".format(name, mode, pair[0], pair[1],
-                                              idx)] = loss
+            loss_dict["{}_{}_{}_{}_{}".format(
+                str(loss_function), pair[0], pair[1], key, key)] = loss
         return loss_dict
@@ -120,7 +117,7 @@ class TestDistanceLoss(unittest.TestCase):
             "student": paddle.rand(shape),
             "teacher": paddle.rand(shape),
         }
-        self.calc_distillation_distance_loss(predicts, pairs, key=None)
+        self.calc_distillation_distance_loss(predicts, pairs)
         predicts = {
             "student": {
@@ -143,13 +140,15 @@ class TestDistanceLoss(unittest.TestCase):
             paddle.set_device(device)
             for reduction in reductions:
                 for mode in modes:
-                    loss_func = DistillationDistanceLoss(
+                    loss_func = DistillationLoss(
                         mode=mode,
+                        loss_function='DistanceLoss',
                         model_name_pairs=pairs,
-                        key=key,
+                        layers_name=[key, key] if key != None else None,
                         reduction=reduction)
                     np_result_dict = self.dist_np_distance_loss(
                         predicts,
+                        loss_function='DistanceLoss',
                         mode=mode,
                         reduction=reduction,
                         model_name_pairs=pairs,
@@ -358,12 +357,11 @@ class TestDMLLoss(unittest.TestCase):
         np_loss = self.np_dml_loss(x, target)
         self.assertTrue(np.allclose(np_loss, pd_loss))

     def dist_np_dml_loss(
             self,
             predicts,
+            loss_function=None,
             model_name_pairs=(["", ""]),
-            key=None,
-            name="loss_dml", ):
+            key=None):
         loss_dict = dict()
         for idx, pair in enumerate(model_name_pairs):
             out1 = predicts[pair[0]]
@@ -371,8 +369,11 @@ class TestDMLLoss(unittest.TestCase):
             if key is not None:
                 out1 = out1[key]
                 out2 = out2[key]
-            loss_dict["{}_{}_{}_{}".format(name, pair[0], pair[1],
-                                           idx)] = self.np_dml_loss(out1, out2)
+            else:
+                key = 0
+            loss_dict["{}_{}_{}_{}_{}".format(
+                str(loss_function), pair[0], pair[1], key,
+                key)] = self.np_dml_loss(out1, out2)
         return loss_dict

     def calc_distillation_dml_loss(self, predicts, pairs, key=None):
@@ -382,11 +383,19 @@ class TestDMLLoss(unittest.TestCase):
         for device in devices:
             paddle.set_device(device)
-            loss_func = DistillationDMLLoss(
-                act="softmax", model_name_pairs=pairs, key=key)
+            loss_func = DistillationLoss(
+                act="softmax",
+                model_name_pairs=pairs,
+                loss_function='DMLLoss',
+                layers_name=[key, key] if key != None else None)
             np_result_dict = self.dist_np_dml_loss(
-                predicts, model_name_pairs=pairs, key=key)
+                predicts,
+                model_name_pairs=pairs,
+                loss_function='DMLLoss',
+                key=key)
             pd_result_dict = loss_func(predicts, None)
+            print(pd_result_dict.keys())
+            print(np_result_dict.keys())
             for k in np_result_dict:
                 pd_result = pd_result_dict[k].numpy()
                 np_result = np_result_dict[k]
@@ -526,7 +535,7 @@ class TestRKDLoss(unittest.TestCase):
             predicts,
             model_name_pairs=(["", ""]),
             key=None,
-            name="loss_rkd", ):
+            name="RKDLoss", ):
         loss_dict = dict()
         for idx, pair in enumerate(model_name_pairs):
             out1 = predicts[pair[0]]
@@ -534,11 +543,12 @@ class TestRKDLoss(unittest.TestCase):
             if key is not None:
                 out1 = out1[key]
                 out2 = out2[key]
-            loss_dict["{}_{}_{}_angle_{}".format(name, pair[0], pair[
-                1], idx)] = self.np_rkd_angle(out1, out2)
-            loss_dict["{}_{}_{}_dist_{}".format(name, pair[0], pair[1],
-                                                idx)] = self.np_rkd_distance(out1, out2)
+            else:
+                key = 0
+            loss_dict["{}_{}_{}_{}_{}".format(name, pair[0], pair[1], key,
+                                              key)] = self.np_rkd_angle(
+                out1, out2) + self.np_rkd_distance(out1, out2)
         return loss_dict

     def calc_distillation_rkd_loss(self, predicts, pairs, key=None):
@@ -548,7 +558,10 @@ class TestRKDLoss(unittest.TestCase):
         for device in devices:
             paddle.set_device(device)
-            loss_func = DistillationRKDLoss(model_name_pairs=pairs, key=key)
+            loss_func = DistillationLoss(
+                model_name_pairs=pairs,
+                loss_function='RKDLoss',
+                layers_name=[key, key] if key != None else None)
             np_result_dict = self.dist_np_rkd_loss(
                 predicts, model_name_pairs=pairs, key=key)
             pd_result_dict = loss_func(predicts, None)
@@ -623,13 +636,12 @@ class TestCombinedLoss(unittest.TestCase):
                 log_soft_target, soft_x)) / 2.0
         return loss

     def dist_np_dml_loss(
             self,
             predicts,
             model_name_pairs=(["", ""]),
+            loss_function=None,
             key=None,
-            act="softmax",
-            name="loss_dml", ):
+            act="softmax"):
         loss_dict = dict()
         for idx, pair in enumerate(model_name_pairs):
             out1 = predicts[pair[0]]
@@ -637,20 +649,24 @@ class TestCombinedLoss(unittest.TestCase):
             if key is not None:
                 out1 = out1[key]
                 out2 = out2[key]
-            loss_dict["{}_{}_{}_{}".format(name, pair[0], pair[1],
-                                           idx)] = self.np_dml_loss(out1, out2)
+            loss_dict["{}_{}_{}_{}_0".format(
+                str(loss_function), pair[0], pair[1],
+                idx)] = self.np_dml_loss(out1, out2)
         return loss_dict

     def np_combined_loss(self, predicts, loss_cfg_list):
         # NOTE, dml is set as the list for combined loss
         loss_dict = dict()
         for idx, loss_func in enumerate(loss_cfg_list):
-            cfg = copy.deepcopy(loss_func["DistillationDMLLoss"])
+            cfg = copy.deepcopy(loss_func)
             weight = cfg.pop("weight")
             loss = self.dist_np_dml_loss(predicts, **cfg)
             if isinstance(loss, np.ndarray):
-                loss = {"loss_{}_{}".format(str(loss), idx): loss}
+                loss = {"{}_{}_{}".format(loss_func['loss_function'],
+                                          str(loss), idx): loss}
             else:
                 loss = {
                     "{}_{}".format(key, idx): loss[key] * weight
@@ -677,12 +693,10 @@ class TestCombinedLoss(unittest.TestCase):
             devices.append("gpu")

         loss_cfg_list = [{
-            "DistillationDMLLoss": {
-                "weight": 1.0,
-                "act": "softmax",
-                "model_name_pairs": pairs,
-                "key": None
-            }
+            "loss_function": "DMLLoss",
+            "weight": 1.0,
+            "act": "softmax",
+            "model_name_pairs": pairs
         }, ]
         for device in devices:
@@ -696,95 +710,5 @@ class TestCombinedLoss(unittest.TestCase):
             self.assertTrue(np.allclose(np_result, pd_result))


-class TestSegPairWiseLoss(unittest.TestCase):
-    def calculate_gt_loss(self, x, y):
-        pool_x = F.adaptive_avg_pool2d(x, [2, 2])
-        pool_y = F.adaptive_avg_pool2d(y, [2, 2])
-        loss = F.mse_loss(pool_x, pool_y)
-        return loss
-
-    def test_seg_pair_wise_loss(self):
-        shape = [1, 3, 10, 10]
-        x = paddle.rand(shape)
-        y = paddle.rand(shape)
-        model_name_pairs = [['student', 'teacher']]
-        key = 'hidden_0_0'
-        inputs = {
-            model_name_pairs[0][0]: {key: x},
-            model_name_pairs[0][1]: {key: y}
-        }
-
-        devices = ["cpu"]
-        if paddle.is_compiled_with_cuda():
-            devices.append("gpu")
-        for device in devices:
-            paddle.set_device(device)
-            loss_func = SegPairWiseLoss(model_name_pairs, key)
-            pd_loss_dict = loss_func(inputs, None)
-            pd_loss = pd_loss_dict['seg_pair_wise_loss_student_teacher_0']
-            gt_loss = self.calculate_gt_loss(x, y)
-            self.assertTrue(np.allclose(pd_loss.numpy(), gt_loss.numpy()))
-
-
-class TestSegChannelWiseLoss(unittest.TestCase):
-    def init(self):
-        self.act_name = None
-        self.act_func = None
-
-    def calculate_gt_loss(self, x, y, act=None):
-        if act is not None:
-            x = act(x)
-            y = act(y)
-        x = paddle.log(x)
-        loss = F.kl_div(x, y)
-        return loss
-
-    def test_seg_pair_wise_loss(self):
-        self.init()
-        shape = [1, 3, 10, 10]
-        x = paddle.rand(shape)
-        y = paddle.rand(shape)
-        model_name_pairs = [['student', 'teacher']]
-        key = 'hidden_0_0'
-        inputs = {
-            model_name_pairs[0][0]: {key: x},
-            model_name_pairs[0][1]: {key: y}
-        }
-
-        devices = ["cpu"]
-        if paddle.is_compiled_with_cuda():
-            devices.append("gpu")
-        for device in devices:
-            paddle.set_device(device)
-            loss_func = SegChannelwiseLoss(model_name_pairs, key,
-                                           self.act_name)
-            pd_loss_dict = loss_func(inputs, None)
-            pd_loss = pd_loss_dict['seg_ch_wise_loss_student_teacher_0']
-            gt_loss = self.calculate_gt_loss(x, y, self.act_func)
-            self.assertTrue(np.allclose(pd_loss.numpy(), gt_loss.numpy()))
-
-
-class TestSegChannelWiseLoss1(TestSegChannelWiseLoss):
-    def init(self):
-        self.act_name = "softmax"
-        self.act_func = F.softmax
-
-
-class TestSegChannelWiseLoss1(TestSegChannelWiseLoss):
-    def init(self):
-        self.act_name = "sigmoid"
-        self.act_func = F.sigmoid
-
-
 if __name__ == '__main__':
     unittest.main()
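The common thread in these test changes is that DistillationDistanceLoss, DistillationDMLLoss, and DistillationRKDLoss are folded into a single DistillationLoss selected by a loss_function argument. Below is a minimal sketch of that call pattern, mirroring the updated TestDistanceLoss case; it assumes the PaddleSlim revision introduced by this commit, and the tensor shapes are arbitrary assumptions.

```python
# Minimal sketch of the unified DistillationLoss call pattern from the
# updated tests (assumes this commit's paddleslim); shapes are assumptions.
import paddle
from paddleslim.dygraph.dist.losses import DistillationLoss

pairs = [["student", "teacher"]]
predicts = {
    "student": paddle.rand([8, 16]),
    "teacher": paddle.rand([8, 16]),
}

loss_func = DistillationLoss(
    mode="l2",
    loss_function='DistanceLoss',
    model_name_pairs=pairs,
    reduction="mean")
loss_dict = loss_func(predicts, None)  # dict of named per-pair losses
```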