Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
65c776de
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
1 年多 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
65c776de
编写于
2月 13, 2023
作者:
G
Guanghua Yu
提交者:
GitHub
2月 13, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add Structural Re-parameterization implementation (#1608)
上级
44e3306b
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
2075 addition
and
0 deletion
+2075
-0
example/reparameterization/README.md
example/reparameterization/README.md
+38
-0
example/reparameterization/optimizer.py
example/reparameterization/optimizer.py
+55
-0
example/reparameterization/train.py
example/reparameterization/train.py
+323
-0
paddleslim/dygraph/__init__.py
paddleslim/dygraph/__init__.py
+2
-0
paddleslim/dygraph/rep/__init__.py
paddleslim/dygraph/rep/__init__.py
+26
-0
paddleslim/dygraph/rep/config.py
paddleslim/dygraph/rep/config.py
+168
-0
paddleslim/dygraph/rep/rep.py
paddleslim/dygraph/rep/rep.py
+184
-0
paddleslim/dygraph/rep/reper/__init__.py
paddleslim/dygraph/rep/reper/__init__.py
+30
-0
paddleslim/dygraph/rep/reper/acblock.py
paddleslim/dygraph/rep/reper/acblock.py
+145
-0
paddleslim/dygraph/rep/reper/base.py
paddleslim/dygraph/rep/reper/base.py
+76
-0
paddleslim/dygraph/rep/reper/diversebranchblock.py
paddleslim/dygraph/rep/reper/diversebranchblock.py
+324
-0
paddleslim/dygraph/rep/reper/repvgg.py
paddleslim/dygraph/rep/reper/repvgg.py
+205
-0
paddleslim/dygraph/rep/reper/slimrep.py
paddleslim/dygraph/rep/reper/slimrep.py
+252
-0
tests/dygraph/test_reparameterization.py
tests/dygraph/test_reparameterization.py
+247
-0
未找到文件。
example/reparameterization/README.md
0 → 100755
浏览文件 @
65c776de
# 重参数化
本示例介绍如何对动态图模型进行重参数化训练,示例以常用的MobileNetV1模型为例,介绍如何对其进行DBB重参数化实验,DBB参考自
[
论文
](
https://arxiv.org/abs/2103.13425
)
。
## 分类模型的重参数化训练流程
### 准备数据
在当前目录下创建
``data``
文件夹,将
``ImageNet``
数据集解压在
``data``
文件夹下,解压后
``data/ILSVRC2012``
文件夹下应包含以下文件:
-
``'train'``
文件夹,训练图片
-
``'train_list.txt'``
文件
-
``'val'``
文件夹,验证图片
-
``'val_list.txt'``
文件
### 准备需要重参数化的模型
-
对于paddle vision支持的
[
模型
](
https://github.com/PaddlePaddle/Paddle/tree/develop/python/paddle/vision/models
)
:
`[lenet, mobilenetv1, mobilenetv2, resnet, vgg]`
可以直接使用vision内置的模型定义和ImageNet预训练权重
### 训练命令
-
MobileNetV1
启动命令如下:
```
bash
# 单卡训练
python train.py
--model
=
mobilenet_v1
# 多卡训练,以0到3号卡为例
python
-m
paddle.distributed.launch
--gpus
=
"0,1,2,3"
train.py
```
### 重参数化结果
| 模型 | FP32模型准确率(Top1) | 重参数化方法 | 重参数化模型准确率(Top1) |
| ----------- | --------------------------- | ------------ | --------------------------- |
| MobileNetV1 | 70.99 | DBB | 72.01 |
example/reparameterization/optimizer.py
0 → 100644
浏览文件 @
65c776de
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
math
import
paddle
def
piecewise_decay
(
net
,
device_num
,
args
):
step
=
int
(
math
.
ceil
(
float
(
args
.
total_images
)
/
(
args
.
batch_size
*
device_num
)))
bd
=
[
step
*
e
for
e
in
args
.
step_epochs
]
lr
=
[
args
.
lr
*
(
0.1
**
i
)
for
i
in
range
(
len
(
bd
)
+
1
)]
learning_rate
=
paddle
.
optimizer
.
lr
.
PiecewiseDecay
(
boundaries
=
bd
,
values
=
lr
,
verbose
=
False
)
optimizer
=
paddle
.
optimizer
.
Momentum
(
parameters
=
net
.
parameters
(),
learning_rate
=
learning_rate
,
momentum
=
args
.
momentum_rate
,
weight_decay
=
paddle
.
regularizer
.
L2Decay
(
args
.
l2_decay
))
return
optimizer
,
learning_rate
def
cosine_decay
(
net
,
device_num
,
args
):
step
=
int
(
math
.
ceil
(
float
(
args
.
total_images
)
/
(
args
.
batch_size
*
device_num
)))
learning_rate
=
paddle
.
optimizer
.
lr
.
CosineAnnealingDecay
(
learning_rate
=
args
.
lr
,
T_max
=
step
*
args
.
num_epochs
,
verbose
=
False
)
optimizer
=
paddle
.
optimizer
.
Momentum
(
parameters
=
net
.
parameters
(),
learning_rate
=
learning_rate
,
momentum
=
args
.
momentum_rate
,
weight_decay
=
paddle
.
regularizer
.
L2Decay
(
args
.
l2_decay
))
return
optimizer
,
learning_rate
def
create_optimizer
(
net
,
device_num
,
args
):
if
args
.
lr_strategy
==
"piecewise_decay"
:
return
piecewise_decay
(
net
,
device_num
,
args
)
elif
args
.
lr_strategy
==
"cosine_decay"
:
return
cosine_decay
(
net
,
device_num
,
args
)
example/reparameterization/train.py
0 → 100644
浏览文件 @
65c776de
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
sys
import
logging
import
paddle
import
argparse
import
functools
import
math
import
time
import
random
import
numpy
as
np
from
paddle.distributed
import
ParallelEnv
from
paddle.static
import
load_program_state
from
paddle.vision.models
import
mobilenet_v1
import
paddle.vision.transforms
as
T
from
paddleslim.common
import
get_logger
from
paddleslim.dygraph.rep
import
Reparameter
,
DBBRepConfig
,
ACBRepConfig
sys
.
path
.
append
(
os
.
path
.
join
(
os
.
path
.
dirname
(
"__file__"
)))
from
optimizer
import
create_optimizer
sys
.
path
.
append
(
os
.
path
.
join
(
os
.
path
.
dirname
(
"__file__"
),
os
.
path
.
pardir
,
os
.
path
.
pardir
))
from
utility
import
add_arguments
,
print_arguments
_logger
=
get_logger
(
__name__
,
level
=
logging
.
INFO
)
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
add_arg
=
functools
.
partial
(
add_arguments
,
argparser
=
parser
)
# yapf: disable
add_arg
(
'batch_size'
,
int
,
64
,
"Single Card Minibatch size."
)
add_arg
(
'use_gpu'
,
bool
,
True
,
"Whether to use GPU or not."
)
add_arg
(
'lr'
,
float
,
0.1
,
"The learning rate used to fine-tune pruned model."
)
add_arg
(
'lr_strategy'
,
str
,
"piecewise_decay"
,
"The learning rate decay strategy."
)
add_arg
(
'l2_decay'
,
float
,
0.00003
,
"The l2_decay parameter."
)
add_arg
(
'ls_epsilon'
,
float
,
0.0
,
"Label smooth epsilon."
)
add_arg
(
'use_pact'
,
bool
,
False
,
"Whether to use PACT method."
)
add_arg
(
'ce_test'
,
bool
,
False
,
"Whether to CE test."
)
add_arg
(
'momentum_rate'
,
float
,
0.9
,
"The value of momentum_rate."
)
add_arg
(
'num_epochs'
,
int
,
120
,
"The number of total epochs."
)
add_arg
(
'total_images'
,
int
,
1281167
,
"The number of total training images."
)
add_arg
(
'data'
,
str
,
"imagenet"
,
"Which data to use. 'cifar10' or 'imagenet'"
)
add_arg
(
'log_period'
,
int
,
10
,
"Log period in batches."
)
add_arg
(
'model_save_dir'
,
str
,
"./output_models"
,
"model save directory."
)
parser
.
add_argument
(
'--step_epochs'
,
nargs
=
'+'
,
type
=
int
,
default
=
[
30
,
60
,
90
],
help
=
"piecewise decay step"
)
# yapf: enable
def
load_dygraph_pretrain
(
model
,
path
=
None
,
load_static_weights
=
False
):
if
not
(
os
.
path
.
isdir
(
path
)
or
os
.
path
.
exists
(
path
+
'.pdparams'
)):
raise
ValueError
(
"Model pretrain path {} does not "
"exists."
.
format
(
path
))
if
load_static_weights
:
pre_state_dict
=
load_program_state
(
path
)
param_state_dict
=
{}
model_dict
=
model
.
state_dict
()
for
key
in
model_dict
.
keys
():
weight_name
=
model_dict
[
key
].
name
if
weight_name
in
pre_state_dict
.
keys
():
print
(
'Load weight: {}, shape: {}'
.
format
(
weight_name
,
pre_state_dict
[
weight_name
].
shape
))
param_state_dict
[
key
]
=
pre_state_dict
[
weight_name
]
else
:
param_state_dict
[
key
]
=
model_dict
[
key
]
model
.
set_dict
(
param_state_dict
)
return
param_state_dict
=
paddle
.
load
(
path
+
".pdparams"
)
model
.
set_dict
(
param_state_dict
)
return
def
train
(
args
):
num_workers
=
4
shuffle
=
True
if
args
.
ce_test
:
# set seed
seed
=
111
paddle
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
random
.
seed
(
seed
)
num_workers
=
0
shuffle
=
False
if
args
.
data
==
"cifar10"
:
transform
=
T
.
Compose
([
T
.
Transpose
(),
T
.
Normalize
([
127.5
],
[
127.5
])])
train_dataset
=
paddle
.
vision
.
datasets
.
Cifar10
(
mode
=
"train"
,
backend
=
"cv2"
,
transform
=
transform
)
val_dataset
=
paddle
.
vision
.
datasets
.
Cifar10
(
mode
=
"test"
,
backend
=
"cv2"
,
transform
=
transform
)
class_dim
=
10
image_shape
=
[
3
,
32
,
32
]
pretrain
=
False
args
.
total_images
=
50000
elif
args
.
data
==
"imagenet"
:
import
imagenet_reader
as
reader
train_dataset
=
reader
.
ImageNetDataset
(
mode
=
'train'
)
val_dataset
=
reader
.
ImageNetDataset
(
mode
=
'val'
)
class_dim
=
1000
image_shape
=
"3,224,224"
else
:
raise
ValueError
(
"{} is not supported."
.
format
(
args
.
data
))
trainer_num
=
paddle
.
distributed
.
get_world_size
()
use_data_parallel
=
trainer_num
!=
1
place
=
paddle
.
set_device
(
'gpu'
if
args
.
use_gpu
else
'cpu'
)
# model definition
if
use_data_parallel
:
paddle
.
distributed
.
init_parallel_env
()
pretrain
=
True
if
args
.
data
==
"imagenet"
else
False
net
=
mobilenet_v1
(
pretrained
=
pretrain
,
num_classes
=
class_dim
)
rep_config
=
DBBRepConfig
()
reper
=
Reparameter
(
rep_config
)
reper
.
prepare
(
net
)
paddle
.
summary
(
net
,
(
1
,
3
,
224
,
224
))
opt
,
lr
=
create_optimizer
(
net
,
trainer_num
,
args
)
if
use_data_parallel
:
net
=
paddle
.
DataParallel
(
net
)
train_batch_sampler
=
paddle
.
io
.
DistributedBatchSampler
(
train_dataset
,
batch_size
=
args
.
batch_size
,
shuffle
=
shuffle
,
drop_last
=
True
)
train_loader
=
paddle
.
io
.
DataLoader
(
train_dataset
,
batch_sampler
=
train_batch_sampler
,
places
=
place
,
return_list
=
True
,
num_workers
=
num_workers
)
valid_loader
=
paddle
.
io
.
DataLoader
(
val_dataset
,
places
=
place
,
batch_size
=
args
.
batch_size
,
shuffle
=
False
,
drop_last
=
False
,
return_list
=
True
,
num_workers
=
num_workers
)
@
paddle
.
no_grad
()
def
test
(
epoch
,
net
):
net
.
eval
()
batch_id
=
0
acc_top1_ns
=
[]
acc_top5_ns
=
[]
eval_reader_cost
=
0.0
eval_run_cost
=
0.0
total_samples
=
0
reader_start
=
time
.
time
()
for
data
in
valid_loader
():
eval_reader_cost
+=
time
.
time
()
-
reader_start
image
=
data
[
0
]
label
=
data
[
1
]
if
args
.
data
==
"cifar10"
:
label
=
paddle
.
reshape
(
label
,
[
-
1
,
1
])
eval_start
=
time
.
time
()
out
=
net
(
image
)
acc_top1
=
paddle
.
metric
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
1
)
acc_top5
=
paddle
.
metric
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
5
)
eval_run_cost
+=
time
.
time
()
-
eval_start
batch_size
=
image
.
shape
[
0
]
total_samples
+=
batch_size
if
batch_id
%
args
.
log_period
==
0
:
log_period
=
1
if
batch_id
==
0
else
args
.
log_period
_logger
.
info
(
"Eval epoch[{}] batch[{}] - top1: {:.6f}; top5: {:.6f}; avg_reader_cost: {:.6f} s, avg_batch_cost: {:.6f} s, avg_samples: {}, avg_ips: {:.3f} images/s"
.
format
(
epoch
,
batch_id
,
np
.
mean
(
acc_top1
.
numpy
()),
np
.
mean
(
acc_top5
.
numpy
()),
eval_reader_cost
/
log_period
,
(
eval_reader_cost
+
eval_run_cost
)
/
log_period
,
total_samples
/
log_period
,
total_samples
/
(
eval_reader_cost
+
eval_run_cost
)))
eval_reader_cost
=
0.0
eval_run_cost
=
0.0
total_samples
=
0
acc_top1_ns
.
append
(
np
.
mean
(
acc_top1
.
numpy
()))
acc_top5_ns
.
append
(
np
.
mean
(
acc_top5
.
numpy
()))
batch_id
+=
1
reader_start
=
time
.
time
()
_logger
.
info
(
"Final eval epoch[{}] - acc_top1: {:.6f}; acc_top5: {:.6f}"
.
format
(
epoch
,
np
.
mean
(
np
.
array
(
acc_top1_ns
)),
np
.
mean
(
np
.
array
(
acc_top5_ns
))))
return
np
.
mean
(
np
.
array
(
acc_top1_ns
))
def
cross_entropy
(
input
,
target
,
ls_epsilon
):
if
ls_epsilon
>
0
:
if
target
.
shape
[
-
1
]
!=
class_dim
:
target
=
paddle
.
nn
.
functional
.
one_hot
(
target
,
class_dim
)
target
=
paddle
.
nn
.
functional
.
label_smooth
(
target
,
epsilon
=
ls_epsilon
)
target
=
paddle
.
reshape
(
target
,
shape
=
[
-
1
,
class_dim
])
input
=
-
paddle
.
nn
.
functional
.
log_softmax
(
input
,
axis
=-
1
)
cost
=
paddle
.
sum
(
target
*
input
,
axis
=-
1
)
else
:
cost
=
paddle
.
nn
.
functional
.
cross_entropy
(
input
=
input
,
label
=
target
)
avg_cost
=
paddle
.
mean
(
cost
)
return
avg_cost
def
train
(
epoch
,
net
):
net
.
train
()
batch_id
=
0
train_reader_cost
=
0.0
train_run_cost
=
0.0
total_samples
=
0
reader_start
=
time
.
time
()
for
data
in
train_loader
():
train_reader_cost
+=
time
.
time
()
-
reader_start
image
=
data
[
0
]
label
=
data
[
1
]
if
args
.
data
==
"cifar10"
:
label
=
paddle
.
reshape
(
label
,
[
-
1
,
1
])
train_start
=
time
.
time
()
out
=
net
(
image
)
avg_cost
=
cross_entropy
(
out
,
label
,
args
.
ls_epsilon
)
acc_top1
=
paddle
.
metric
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
1
)
acc_top5
=
paddle
.
metric
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
5
)
avg_cost
.
backward
()
opt
.
step
()
opt
.
clear_grad
()
lr
.
step
()
loss_n
=
np
.
mean
(
avg_cost
.
numpy
())
acc_top1_n
=
np
.
mean
(
acc_top1
.
numpy
())
acc_top5_n
=
np
.
mean
(
acc_top5
.
numpy
())
train_run_cost
+=
time
.
time
()
-
train_start
batch_size
=
image
.
shape
[
0
]
total_samples
+=
batch_size
if
batch_id
%
args
.
log_period
==
0
:
log_period
=
1
if
batch_id
==
0
else
args
.
log_period
_logger
.
info
(
"epoch[{}]-batch[{}] lr: {:.6f} - loss: {:.6f}; top1: {:.6f}; top5: {:.6f}; avg_reader_cost: {:.6f} s, avg_batch_cost: {:.6f} s, avg_samples: {}, avg_ips: {:.3f} images/s"
.
format
(
epoch
,
batch_id
,
lr
.
get_lr
(),
loss_n
,
acc_top1_n
,
acc_top5_n
,
train_reader_cost
/
log_period
,
(
train_reader_cost
+
train_run_cost
)
/
log_period
,
total_samples
/
log_period
,
total_samples
/
(
train_reader_cost
+
train_run_cost
)))
train_reader_cost
=
0.0
train_run_cost
=
0.0
total_samples
=
0
batch_id
+=
1
reader_start
=
time
.
time
()
# train loop
best_acc1
=
0.0
best_epoch
=
0
for
i
in
range
(
args
.
num_epochs
):
train
(
i
,
net
)
acc1
=
test
(
i
,
net
)
if
paddle
.
distributed
.
get_rank
()
==
0
:
model_prefix
=
os
.
path
.
join
(
args
.
model_save_dir
,
"epoch_"
+
str
(
i
))
paddle
.
save
(
net
.
state_dict
(),
model_prefix
+
".pdparams"
)
paddle
.
save
(
opt
.
state_dict
(),
model_prefix
+
".pdopt"
)
if
acc1
>
best_acc1
:
best_acc1
=
acc1
best_epoch
=
i
if
paddle
.
distributed
.
get_rank
()
==
0
:
model_prefix
=
os
.
path
.
join
(
args
.
model_save_dir
,
"best_model"
)
paddle
.
save
(
net
.
state_dict
(),
model_prefix
+
".pdparams"
)
paddle
.
save
(
opt
.
state_dict
(),
model_prefix
+
".pdopt"
)
# Save model
reper
.
convert
(
net
)
if
paddle
.
distributed
.
get_rank
()
==
0
:
# load best model
load_dygraph_pretrain
(
net
,
os
.
path
.
join
(
args
.
model_save_dir
,
"best_model"
))
path
=
os
.
path
.
join
(
args
.
model_save_dir
,
"inference_model"
,
'rep_model'
)
paddle
.
jit
.
save
(
net
,
path
,
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
3
,
224
,
224
],
dtype
=
'float32'
)
])
def
main
():
args
=
parser
.
parse_args
()
print_arguments
(
args
)
train
(
args
)
if
__name__
==
'__main__'
:
main
()
paddleslim/dygraph/__init__.py
浏览文件 @
65c776de
...
...
@@ -5,3 +5,5 @@ from .prune import *
__all__
+=
prune
.
__all__
from
.dist
import
*
__all__
+=
dist
.
__all__
from
.rep
import
*
__all__
+=
rep
.
__all__
\ No newline at end of file
paddleslim/dygraph/rep/__init__.py
0 → 100644
浏览文件 @
65c776de
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.
import
rep
from
.
import
config
from
.
import
reper
from
.rep
import
Reparameter
from
.config
import
*
from
.reper
import
*
__all__
=
[]
__all__
+=
rep
.
__all__
__all__
+=
config
.
__all__
__all__
+=
reper
.
__all__
paddleslim/dygraph/rep/config.py
0 → 100644
浏览文件 @
65c776de
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Dict
,
Union
import
paddle.nn
as
nn
from
.reper
import
DiverseBranchBlock
,
ACBlock
,
RepVGGBlock
,
SlimRepBlock
SUPPORT_REP_TYPE_LAYERS
=
[
nn
.
Conv2D
,
nn
.
Linear
]
__all__
=
[
"BaseRepConfig"
,
"DBBRepConfig"
,
"ACBRepConfig"
,
"RepVGGConfig"
,
"SlimRepConfig"
]
class
BaseRepConfig
:
"""
Basic reparameterization configuration class.
Args:
type_config(dict): Set the reper by the type of layer. The key of `type_config`
should be subclass of `paddle.nn.Layer`. Its priority is lower than `layer_config`.
Default is: `{nn.Conv2D: ACBlock}`.
layer_config(dict): Set the reper by layer. It has the highest priority among
all the setting methods. Such as: `{model.conv1: ACBlock}`. Default is None.
"""
def
__init__
(
self
,
type_config
:
Dict
=
{
nn
.
Conv2D
:
ACBlock
},
layer_config
:
Dict
=
None
,
):
self
.
_type_config
=
self
.
_set_type_config
(
type_config
)
self
.
_layer_config
=
self
.
_set_layer_config
(
layer_config
)
def
add_config
(
self
,
type_config
:
Dict
=
None
,
layer_config
:
Dict
=
None
,
):
self
.
_type_config
.
update
(
self
.
_set_type_config
(
type_config
))
self
.
_layer_config
.
update
(
self
.
_set_layer_config
(
layer_config
))
@
property
def
all_config
(
self
):
return
{
'type_config'
:
self
.
_type_config
,
'layer_config'
:
self
.
_layer_config
,
}
def
_set_type_config
(
self
,
type_config
):
_type_config
=
{}
if
type_config
:
for
_layer
in
type_config
:
assert
isinstance
(
_layer
,
type
)
and
issubclass
(
_layer
,
nn
.
Layer
),
"Expect to get subclasses under nn.Layer, but got {}."
.
format
(
_layer
)
assert
_layer
in
SUPPORT_REP_TYPE_LAYERS
,
"Expect to get one of `{}`, but got {}."
.
format
(
SUPPORT_REP_TYPE_LAYERS
,
_layer
)
_type_config
[
_layer
]
=
type_config
[
_layer
]
return
_type_config
def
_set_layer_config
(
self
,
layer_config
):
_layer_config
=
{}
if
layer_config
:
for
_layer
in
layer_config
:
is_support
=
False
for
support_type
in
SUPPORT_REP_TYPE_LAYERS
:
if
isinstance
(
_layer
,
support_type
):
is_support
=
True
assert
is_support
,
"Expect layer to get one of `{}`."
.
format
(
SUPPORT_REP_LAYERS
)
_layer_config
[
_layer
.
full_name
()]
=
layer_config
[
_layer
]
return
_layer_config
def
__str__
(
self
):
result
=
""
if
len
(
self
.
_type_config
)
>
0
:
result
+=
f
"Type config:
\n
{
self
.
_type_config
}
\n
"
if
len
(
self
.
_layer_config
)
>
0
:
result
+=
f
"Layer config:
\n
{
self
.
_layer_config
}
\n
"
return
result
class
DBBRepConfig
(
BaseRepConfig
):
"""
DBB reparameterization configuration class.
Args:
type_config(dict): Set the reper by the type of layer. The key of `type_config`
should be subclass of `paddle.nn.Layer`. Its priority is lower than `layer_config`.
Default is: `{nn.Conv2D: ACBlock}`.
layer_config(dict): Set the reper by layer. It has the highest priority among
all the setting methods. Such as: `{model.conv1: ACBlock}`. Default is None.
"""
def
__init__
(
self
,
type_config
:
Dict
=
{
nn
.
Conv2D
:
DiverseBranchBlock
},
layer_config
:
Dict
=
None
,
):
self
.
_type_config
=
self
.
_set_type_config
(
type_config
)
self
.
_layer_config
=
self
.
_set_layer_config
(
layer_config
)
class
ACBRepConfig
(
BaseRepConfig
):
"""
ACBlock reparameterization configuration class.
Args:
type_config(dict): Set the reper by the type of layer. The key of `type_config`
should be subclass of `paddle.nn.Layer`. Its priority is lower than `layer_config`.
Default is: `{nn.Conv2D: ACBlock}`.
layer_config(dict): Set the reper by layer. It has the highest priority among
all the setting methods. Such as: `{model.conv1: ACBlock}`. Default is None.
"""
def
__init__
(
self
,
type_config
:
Dict
=
{
nn
.
Conv2D
:
ACBlock
},
layer_config
:
Dict
=
None
,
):
self
.
_type_config
=
self
.
_set_type_config
(
type_config
)
self
.
_layer_config
=
self
.
_set_layer_config
(
layer_config
)
class
RepVGGConfig
(
BaseRepConfig
):
"""
RepVGG reparameterization configuration class.
Args:
type_config(dict): Set the reper by the type of layer. The key of `type_config`
should be subclass of `paddle.nn.Layer`. Its priority is lower than `layer_config`.
Default is: `{nn.Conv2D: ACBlock}`.
layer_config(dict): Set the reper by layer. It has the highest priority among
all the setting methods. Such as: `{model.conv1: ACBlock}`. Default is None.
"""
def
__init__
(
self
,
type_config
:
Dict
=
{
nn
.
Conv2D
:
RepVGGBlock
},
layer_config
:
Dict
=
None
,
):
self
.
_type_config
=
self
.
_set_type_config
(
type_config
)
self
.
_layer_config
=
self
.
_set_layer_config
(
layer_config
)
class
SlimRepConfig
(
BaseRepConfig
):
"""
SlimRepBlock reparameterization configuration class.
Args:
type_config(dict): Set the reper by the type of layer. The key of `type_config`
should be subclass of `paddle.nn.Layer`. Its priority is lower than `layer_config`.
Default is: `{nn.Conv2D: ACBlock}`.
layer_config(dict): Set the reper by layer. It has the highest priority among
all the setting methods. Such as: `{model.conv1: ACBlock}`. Default is None.
"""
def
__init__
(
self
,
type_config
:
Dict
=
{
nn
.
Conv2D
:
SlimRepBlock
},
layer_config
:
Dict
=
None
,
):
self
.
_type_config
=
self
.
_set_type_config
(
type_config
)
self
.
_layer_config
=
self
.
_set_layer_config
(
layer_config
)
paddleslim/dygraph/rep/rep.py
0 → 100644
浏览文件 @
65c776de
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
logging
import
paddle.nn
as
nn
from
...common
import
get_logger
from
.config
import
BaseRepConfig
,
SUPPORT_REP_TYPE_LAYERS
_logger
=
get_logger
(
__name__
,
level
=
logging
.
INFO
)
__all__
=
[
"Reparameter"
]
class
Reparameter
:
"""
Re-parameterization interface of dygraph model.
Args:
model(nn.Layer): Model of networks.
config(instance): Reparameterization config, default is `BaseRepConfig`.
"""
def
__init__
(
self
,
config
=
BaseRepConfig
):
assert
config
!=
None
,
"config cannot be None."
self
.
_config
=
config
.
all_config
self
.
_layer2reper_config
=
{}
def
prepare
(
self
,
model
):
"""
Re-parameterization prepare model callback interface.
Args:
model(nn.Layer): The model to be reparameterized.
"""
self
.
_layer2reper_config
=
self
.
_parser_rep_config
(
model
)
# Conv2D
if
"Conv2D"
in
self
.
_layer2reper_config
:
conv2d2reper_config
=
self
.
_layer2reper_config
[
"Conv2D"
]
conv_bn_pairs
=
self
.
_get_conv_bn_pair
(
model
)
if
not
conv_bn_pairs
:
_logger
.
info
(
"No conv-bn layer found, so skip the reparameterization."
)
return
model
for
layer_name
in
conv2d2reper_config
:
if
layer_name
in
list
(
conv_bn_pairs
.
keys
()):
per_conv_bn_pair
=
[
layer_name
,
conv_bn_pairs
[
layer_name
]]
self
.
_replace_conv_bn_with_reper
(
model
,
conv2d2reper_config
[
layer_name
],
per_conv_bn_pair
)
return
model
def
convert
(
self
,
model
):
"""
Re-parameterization export interface, it will run fusion operation.
Args:
model(nn.Layer): The model that has been reparameterized.
"""
for
layer
in
model
.
sublayers
():
if
hasattr
(
layer
,
'convert_to_deploy'
):
layer
.
convert_to_deploy
()
def
_parser_rep_config
(
self
,
model
):
_layer2reper_config
=
{}
for
name
,
layer
in
model
.
named_sublayers
():
support_type_layers
=
list
(
self
.
_config
[
'type_config'
].
keys
())
refine_layer_full_names
=
list
(
self
.
_config
[
'layer_config'
].
keys
())
cur_layer_reper
=
None
# Firstly, parser type layer in model.
for
layer_type
in
support_type_layers
:
if
isinstance
(
layer
,
layer_type
):
cur_layer_reper
=
self
.
_config
[
'type_config'
][
layer_type
]
# Secondly, parser layer full name in model.
if
name
in
refine_layer_full_names
:
cur_layer_reper
=
self
.
_config
[
'layer_config'
][
name
]
# Conv2d
if
cur_layer_reper
and
isinstance
(
layer
,
nn
.
Conv2D
):
if
"Conv2D"
in
_layer2reper_config
:
_layer2reper_config
[
"Conv2D"
].
update
({
name
:
cur_layer_reper
})
else
:
_layer2reper_config
[
"Conv2D"
]
=
{
name
:
cur_layer_reper
}
# Linear
elif
cur_layer_reper
and
isinstance
(
layer
,
nn
.
Linear
):
if
"Linear"
in
_layer2reper_config
:
_layer2reper_config
[
"Linear"
].
update
({
name
:
cur_layer_reper
})
else
:
_layer2reper_config
[
"Linear"
]
=
{
name
:
cur_layer_reper
}
elif
cur_layer_reper
:
_logger
.
info
(
"{} not support reparameterization, please choose one of {}"
.
format
(
name
,
SUPPORT_REP_TYPE_LAYERS
))
return
_layer2reper_config
def
_get_conv_bn_pair
(
self
,
model
):
"""
Get the combination of Conv2D and BatchNorm2D.
Args:
model(nn.Layer): The model that has been reparameterized.
"""
conv_bn_pairs
=
{}
tmp_pair
=
[
None
,
None
]
for
name
,
layer
in
model
.
named_sublayers
():
if
isinstance
(
layer
,
nn
.
Conv2D
):
tmp_pair
[
0
]
=
name
if
isinstance
(
layer
,
nn
.
BatchNorm2D
)
or
isinstance
(
layer
,
nn
.
BatchNorm
):
tmp_pair
[
1
]
=
name
if
tmp_pair
[
0
]
and
tmp_pair
[
1
]
and
len
(
tmp_pair
)
==
2
:
conv_bn_pairs
[
tmp_pair
[
0
]]
=
tmp_pair
[
1
]
tmp_pair
=
[
None
,
None
]
return
conv_bn_pairs
def
_replace_conv_bn_with_reper
(
self
,
model
,
reper
,
conv_bn_pair
):
"""
Replace Conv2D and BatchNorm2D with reper.
Args:
model(nn.Layer): The model that has been reparameterized.
reper(nn.Layer): The reper used by the current layer.
conv_bn_pairs(list[str, str]): List of combination of Conv2D and BatchNorm2D.
"""
for
layer_name
in
conv_bn_pair
:
parent_layer
,
sub_name
=
self
.
_find_parent_layer_and_sub_name
(
model
,
layer_name
)
module
=
getattr
(
parent_layer
,
sub_name
)
if
isinstance
(
module
,
nn
.
Conv2D
):
new_layer
=
reper
(
in_channels
=
module
.
_in_channels
,
out_channels
=
module
.
_out_channels
,
kernel_size
=
module
.
_kernel_size
[
0
],
stride
=
module
.
_stride
[
0
],
groups
=
module
.
_groups
,
padding
=
module
.
_padding
)
setattr
(
parent_layer
,
sub_name
,
new_layer
)
if
isinstance
(
module
,
nn
.
BatchNorm2D
)
or
isinstance
(
module
,
nn
.
BatchNorm
):
new_layer
=
nn
.
Identity
()
setattr
(
parent_layer
,
sub_name
,
new_layer
)
def
_find_parent_layer_and_sub_name
(
self
,
model
,
name
):
"""
Given the model and the name of a layer, find the parent layer and
the sub_name of the layer.
For example, if name is 'block_1/convbn_1/conv_1', the parent layer is
'block_1/convbn_1' and the sub_name is `conv_1`.
Args:
model(paddle.nn.Layer): the model to be reparameterized.
name(string): the name of a layer.
Returns:
parent_layer, subname
"""
assert
isinstance
(
model
,
nn
.
Layer
),
\
"The model must be the instance of paddle.nn.Layer."
assert
len
(
name
)
>
0
,
"The input (name) should not be empty."
last_idx
=
0
idx
=
0
parent_layer
=
model
while
idx
<
len
(
name
):
if
name
[
idx
]
==
'.'
:
sub_name
=
name
[
last_idx
:
idx
]
if
hasattr
(
parent_layer
,
sub_name
):
parent_layer
=
getattr
(
parent_layer
,
sub_name
)
last_idx
=
idx
+
1
idx
+=
1
sub_name
=
name
[
last_idx
:
idx
]
return
parent_layer
,
sub_name
paddleslim/dygraph/rep/reper/__init__.py
0 → 100644
浏览文件 @
65c776de
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.
import
diversebranchblock
from
.
import
acblock
from
.
import
repvgg
from
.
import
slimrep
from
.
import
base
from
.diversebranchblock
import
DiverseBranchBlock
from
.acblock
import
ACBlock
from
.repvgg
import
RepVGGBlock
from
.slimrep
import
SlimRepBlock
__all__
=
[]
__all__
+=
diversebranchblock
.
__all__
__all__
+=
acblock
.
__all__
__all__
+=
repvgg
.
__all__
__all__
+=
slimrep
.
__all__
paddleslim/dygraph/rep/reper/acblock.py
0 → 100644
浏览文件 @
65c776de
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
paddle
import
paddle.nn
as
nn
from
.base
import
BaseConv2DReper
,
ConvBNLayer
__all__
=
[
"ACBlock"
]
class
ACBlock
(
BaseConv2DReper
):
"""
An instance of the ACBlock module, which replaces the conv-bn layer in the network.
Refer from Paper: https://arxiv.org/abs/1908.03930.
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
groups
=
1
,
padding
=
None
):
super
(
ACBlock
,
self
).
__init__
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
kernel_size
,
stride
=
stride
,
groups
=
groups
,
padding
=
padding
)
if
self
.
padding
-
self
.
kernel_size
//
2
>=
0
:
self
.
crop
=
0
# Compared to the KxK layer, the padding of the 1xK layer and Kx1 layer should be adjust to align the sliding windows (Fig 2 in the paper)
hor_padding
=
[
self
.
padding
-
self
.
kernel_size
//
2
,
self
.
padding
]
ver_padding
=
[
self
.
padding
,
self
.
padding
-
self
.
kernel_size
//
2
]
else
:
# A negative "padding" (padding - kernel_size//2 < 0, which is not a common use case) is cropping.
# Since nn.Conv2D does not support negative padding, we implement it manually
self
.
crop
=
self
.
kernel_size
//
2
-
self
.
padding
hor_padding
=
[
0
,
self
.
padding
]
ver_padding
=
[
self
.
padding
,
0
]
# kxk square branch
self
.
square_branch
=
ConvBNLayer
(
self
.
in_channels
,
self
.
out_channels
,
self
.
kernel_size
,
self
.
stride
,
groups
=
self
.
groups
,
padding
=
self
.
padding
)
# kx1 vertical branch
self
.
ver_branch
=
ConvBNLayer
(
self
.
in_channels
,
self
.
out_channels
,
(
self
.
kernel_size
,
1
),
self
.
stride
,
groups
=
self
.
groups
,
padding
=
ver_padding
)
# 1xk horizontal branch
self
.
hor_branch
=
ConvBNLayer
(
self
.
in_channels
,
self
.
out_channels
,
(
1
,
self
.
kernel_size
),
self
.
stride
,
groups
=
self
.
groups
,
padding
=
hor_padding
)
def
_add_to_square_kernel
(
self
,
square_kernel
,
asym_kernel
):
asym_h
=
asym_kernel
.
shape
[
2
]
asym_w
=
asym_kernel
.
shape
[
3
]
square_h
=
square_kernel
.
shape
[
2
]
square_w
=
square_kernel
.
shape
[
3
]
square_kernel
[:,
:,
square_h
//
2
-
asym_h
//
2
:
square_h
//
2
-
asym_h
//
2
+
asym_h
,
square_w
//
2
-
asym_w
//
2
:
square_w
//
2
-
asym_w
//
2
+
asym_w
]
+=
asym_kernel
def
_fuse_bn
(
self
,
kernel
,
bn
):
running_mean
=
bn
.
_mean
running_var
=
bn
.
_variance
gamma
=
bn
.
weight
beta
=
bn
.
bias
eps
=
bn
.
_epsilon
std
=
(
running_var
+
eps
).
sqrt
()
t
=
(
gamma
/
std
).
reshape
((
-
1
,
1
,
1
,
1
))
return
kernel
*
t
,
beta
-
running_mean
*
gamma
/
std
def
_get_equivalent_kernel_bias
(
self
):
hor_k
,
hor_b
=
self
.
_fuse_bn
(
self
.
hor_branch
.
conv
.
weight
,
self
.
hor_branch
.
bn
)
ver_k
,
ver_b
=
self
.
_fuse_bn
(
self
.
ver_branch
.
conv
.
weight
,
self
.
ver_branch
.
bn
)
square_k
,
square_b
=
self
.
_fuse_bn
(
self
.
square_branch
.
conv
.
weight
,
self
.
square_branch
.
bn
)
self
.
_add_to_square_kernel
(
square_k
,
hor_k
)
self
.
_add_to_square_kernel
(
square_k
,
ver_k
)
return
square_k
,
hor_b
+
ver_b
+
square_b
def
convert_to_deploy
(
self
):
if
hasattr
(
self
,
'fused_branch'
):
return
kernel
,
bias
=
self
.
_get_equivalent_kernel_bias
()
self
.
fused_branch
=
nn
.
Conv2D
(
in_channels
=
self
.
in_channels
,
out_channels
=
self
.
out_channels
,
kernel_size
=
self
.
kernel_size
,
stride
=
self
.
stride
,
padding
=
self
.
padding
,
groups
=
self
.
groups
,
bias_attr
=
True
)
self
.
fused_branch
.
weight
.
set_value
(
kernel
)
self
.
fused_branch
.
bias
.
set_value
(
bias
)
self
.
__delattr__
(
'ver_branch'
)
self
.
__delattr__
(
'hor_branch'
)
self
.
__delattr__
(
'square_branch'
)
def
forward
(
self
,
input
):
if
hasattr
(
self
,
'fused_branch'
):
return
self
.
fused_branch
(
input
)
out
=
self
.
square_branch
(
input
)
if
self
.
crop
>
0
:
ver_input
=
input
[:,
:,
:,
self
.
crop
:
-
self
.
crop
]
hor_input
=
input
[:,
:,
self
.
crop
:
-
self
.
crop
,
:]
else
:
ver_input
=
input
hor_input
=
input
out
+=
self
.
ver_branch
(
ver_input
)
out
+=
self
.
hor_branch
(
hor_input
)
return
out
paddleslim/dygraph/rep/reper/base.py
0 → 100644
浏览文件 @
65c776de
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle.nn
as
nn
from
paddle
import
ParamAttr
from
paddle.regularizer
import
L2Decay
from
paddle.nn.initializer
import
KaimingNormal
class
BaseConv2DReper
(
nn
.
Layer
):
"""
An Base instance of the Reparameterization module based on Conv2D.
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
groups
=
1
,
padding
=
None
):
super
(
BaseConv2DReper
,
self
).
__init__
()
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
kernel_size
=
kernel_size
self
.
stride
=
stride
self
.
groups
=
groups
self
.
padding
=
padding
def
convert_to_deploy
(
self
):
pass
def
forward
(
self
,
input
):
pass
class
ConvBNLayer
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
filter_size
,
stride
,
groups
=
1
,
padding
=
None
):
super
().
__init__
()
if
not
padding
:
padding
=
filter_size
//
2
self
.
conv
=
nn
.
Conv2D
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
groups
=
groups
,
weight_attr
=
ParamAttr
(
initializer
=
KaimingNormal
()),
bias_attr
=
False
)
self
.
bn
=
nn
.
BatchNorm2D
(
out_channels
,
weight_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)))
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
x
=
self
.
bn
(
x
)
return
x
paddleslim/dygraph/rep/reper/diversebranchblock.py
0 → 100644
浏览文件 @
65c776de
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is referenced from: https://github.com/DingXiaoH/DiverseBranchBlock/blob/main/diversebranchblock.py
import
numpy
as
np
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle
import
ParamAttr
from
.base
import
BaseConv2DReper
,
ConvBNLayer
__all__
=
[
"DiverseBranchBlock"
]
class
IdentityBasedConv1x1
(
nn
.
Conv2D
):
def
__init__
(
self
,
channels
,
groups
=
1
,
weight_attr
=
ParamAttr
(
initializer
=
nn
.
initializer
.
Constant
(
0.0
))):
super
(
IdentityBasedConv1x1
,
self
).
__init__
(
in_channels
=
channels
,
out_channels
=
channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
groups
=
groups
,
weight_attr
=
weight_attr
,
bias_attr
=
False
)
assert
channels
%
groups
==
0
input_dim
=
channels
//
groups
id_value
=
np
.
zeros
((
channels
,
input_dim
,
1
,
1
))
for
i
in
range
(
channels
):
id_value
[
i
,
i
%
input_dim
,
0
,
0
]
=
1
self
.
id_tensor
=
paddle
.
to_tensor
(
id_value
)
self
.
groups
=
groups
def
forward
(
self
,
input
):
kernel
=
self
.
weight
+
self
.
id_tensor
result
=
F
.
conv2d
(
input
,
kernel
,
None
,
stride
=
1
,
padding
=
0
,
groups
=
self
.
groups
)
return
result
def
get_actual_kernel
(
self
):
return
self
.
weight
+
self
.
id_tensor
class
BNAndPadLayer
(
nn
.
Layer
):
def
__init__
(
self
,
pad_pixels
,
num_features
,
eps
=
1e-5
,
momentum
=
0.1
):
super
(
BNAndPadLayer
,
self
).
__init__
()
self
.
bn
=
nn
.
BatchNorm2D
(
num_features
,
momentum
,
eps
)
self
.
pad_pixels
=
pad_pixels
def
forward
(
self
,
input
):
output
=
self
.
bn
(
input
)
if
self
.
pad_pixels
>
0
:
pad_values
=
self
.
bn
.
bias
-
self
.
bn
.
_mean
*
self
.
bn
.
weight
/
paddle
.
sqrt
(
self
.
bn
.
_variance
+
self
.
bn
.
_epsilon
)
output
=
F
.
pad
(
output
,
[
self
.
pad_pixels
]
*
4
)
pad_values
=
pad_values
.
reshape
((
1
,
-
1
,
1
,
1
))
output
[:,
:,
0
:
self
.
pad_pixels
,
:]
=
pad_values
output
[:,
:,
-
self
.
pad_pixels
:,
:]
=
pad_values
output
[:,
:,
:,
0
:
self
.
pad_pixels
]
=
pad_values
output
[:,
:,
:,
-
self
.
pad_pixels
:]
=
pad_values
return
output
@
property
def
weight
(
self
):
return
self
.
bn
.
weight
@
property
def
bias
(
self
):
return
self
.
bn
.
bias
@
property
def
_mean
(
self
):
return
self
.
bn
.
_mean
@
property
def
_variance
(
self
):
return
self
.
bn
.
_variance
@
property
def
_epsilon
(
self
):
return
self
.
bn
.
_epsilon
class
DiverseBranchBlock
(
BaseConv2DReper
):
"""
An instance of the DBB module, which replaces the conv-bn layer in the network.
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
groups
=
1
,
padding
=
None
,
internal_channels_1x1_3x3
=
None
):
super
(
DiverseBranchBlock
,
self
).
__init__
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
kernel_size
,
stride
=
stride
,
groups
=
groups
,
padding
=
padding
)
# kxk branch
self
.
dbb_origin
=
ConvBNLayer
(
in_channels
,
out_channels
,
kernel_size
,
stride
,
groups
=
groups
)
# 1x1-avg branch
self
.
dbb_avg
=
nn
.
Sequential
()
if
groups
<
out_channels
:
self
.
dbb_avg
.
add_sublayer
(
'conv'
,
nn
.
Conv2D
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
groups
=
groups
,
bias_attr
=
False
))
self
.
dbb_avg
.
add_sublayer
(
'bn'
,
BNAndPadLayer
(
pad_pixels
=
self
.
padding
,
num_features
=
out_channels
))
self
.
dbb_avg
.
add_sublayer
(
'avg'
,
nn
.
AvgPool2D
(
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
0
))
else
:
self
.
dbb_avg
.
add_sublayer
(
'avg'
,
nn
.
AvgPool2D
(
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
self
.
padding
))
self
.
dbb_avg
.
add_sublayer
(
'avgbn'
,
nn
.
BatchNorm2D
(
out_channels
))
# 1x1 branch
if
groups
<
out_channels
:
self
.
dbb_1x1
=
ConvBNLayer
(
in_channels
,
out_channels
,
1
,
stride
,
groups
=
groups
)
# 1x1-kxk branch
if
internal_channels_1x1_3x3
is
None
:
# For mobilenet, it is better to have 2X internal channels
internal_channels_1x1_3x3
=
in_channels
if
groups
<
out_channels
else
2
*
in_channels
self
.
dbb_1x1_kxk
=
nn
.
Sequential
()
if
internal_channels_1x1_3x3
==
in_channels
:
self
.
dbb_1x1_kxk
.
add_sublayer
(
'idconv1'
,
IdentityBasedConv1x1
(
channels
=
in_channels
,
groups
=
groups
))
else
:
self
.
dbb_1x1_kxk
.
add_sublayer
(
'conv1'
,
nn
.
Conv2D
(
in_channels
=
in_channels
,
out_channels
=
internal_channels_1x1_3x3
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
groups
=
groups
,
bias_attr
=
False
))
self
.
dbb_1x1_kxk
.
add_sublayer
(
'bn1'
,
BNAndPadLayer
(
pad_pixels
=
self
.
padding
,
num_features
=
internal_channels_1x1_3x3
))
self
.
dbb_1x1_kxk
.
add_sublayer
(
'conv2'
,
nn
.
Conv2D
(
in_channels
=
internal_channels_1x1_3x3
,
out_channels
=
out_channels
,
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
0
,
groups
=
groups
,
bias_attr
=
False
))
self
.
dbb_1x1_kxk
.
add_sublayer
(
'bn2'
,
nn
.
BatchNorm2D
(
out_channels
))
def
_fuse_bn
(
self
,
kernel
,
bn
):
running_mean
=
bn
.
_mean
running_var
=
bn
.
_variance
gamma
=
bn
.
weight
beta
=
bn
.
bias
eps
=
bn
.
_epsilon
std
=
(
running_var
+
eps
).
sqrt
()
t
=
(
gamma
/
std
).
reshape
((
-
1
,
1
,
1
,
1
))
return
kernel
*
t
,
beta
-
running_mean
*
gamma
/
std
def
_fuse_1x1_kxk
(
self
,
k1
,
b1
,
k2
,
b2
,
groups
):
if
groups
==
1
:
k
=
F
.
conv2d
(
k2
,
k1
.
transpose
((
1
,
0
,
2
,
3
)))
b_hat
=
(
k2
*
b1
.
reshape
((
1
,
-
1
,
1
,
1
))).
sum
((
1
,
2
,
3
))
else
:
k_slices
=
[]
b_slices
=
[]
k1_T
=
k1
.
transpose
((
1
,
0
,
2
,
3
))
k1_group_width
=
k1
.
shape
[
0
]
//
groups
k2_group_width
=
k2
.
shape
[
0
]
//
groups
for
g
in
range
(
groups
):
k1_T_slice
=
k1_T
[:,
g
*
k1_group_width
:(
g
+
1
)
*
k1_group_width
,
:,
:]
k2_slice
=
k2
[
g
*
k2_group_width
:(
g
+
1
)
*
k2_group_width
,
:,
:,
:]
k_slices
.
append
(
F
.
conv2d
(
k2_slice
,
k1_T_slice
))
b_slices
.
append
(
(
k2_slice
*
b1
[
g
*
k1_group_width
:(
g
+
1
)
*
k1_group_width
].
reshape
(
(
1
,
-
1
,
1
,
1
))).
sum
((
1
,
2
,
3
)))
k
=
paddle
.
concat
(
k_slices
)
b_hat
=
paddle
.
concat
(
b_slices
)
return
k
,
b_hat
+
b2
def
_fuse_avg
(
self
,
channels
,
kernel_size
,
groups
):
input_dim
=
channels
//
groups
k
=
paddle
.
zeros
((
channels
,
input_dim
,
kernel_size
,
kernel_size
))
k
[
np
.
arange
(
channels
),
np
.
tile
(
np
.
arange
(
input_dim
),
groups
),
:,
:]
=
1.0
/
kernel_size
**
2
return
k
# This has not been tested with non-square kernels (kernel.size(2) != kernel.size(3)) nor even-size kernels
def
_fuse_multiscale
(
self
,
kernel
,
target_kernel_size
):
H_pixels_to_pad
=
(
target_kernel_size
-
kernel
.
shape
[
2
])
//
2
W_pixels_to_pad
=
(
target_kernel_size
-
kernel
.
shape
[
3
])
//
2
return
F
.
pad
(
kernel
,
[
H_pixels_to_pad
,
H_pixels_to_pad
,
W_pixels_to_pad
,
W_pixels_to_pad
])
def
_get_equivalent_kernel_bias
(
self
):
k_origin
,
b_origin
=
self
.
_fuse_bn
(
self
.
dbb_origin
.
conv
.
weight
,
self
.
dbb_origin
.
bn
)
if
hasattr
(
self
,
'dbb_1x1'
):
k_1x1
,
b_1x1
=
self
.
_fuse_bn
(
self
.
dbb_1x1
.
conv
.
weight
,
self
.
dbb_1x1
.
bn
)
k_1x1
=
self
.
_fuse_multiscale
(
k_1x1
,
self
.
kernel_size
)
else
:
k_1x1
,
b_1x1
=
0
,
0
if
hasattr
(
self
.
dbb_1x1_kxk
,
'idconv1'
):
k_1x1_kxk_first
=
self
.
dbb_1x1_kxk
.
idconv1
.
get_actual_kernel
()
else
:
k_1x1_kxk_first
=
self
.
dbb_1x1_kxk
.
conv1
.
weight
k_1x1_kxk_first
,
b_1x1_kxk_first
=
self
.
_fuse_bn
(
k_1x1_kxk_first
,
self
.
dbb_1x1_kxk
.
bn1
)
k_1x1_kxk_second
,
b_1x1_kxk_second
=
self
.
_fuse_bn
(
self
.
dbb_1x1_kxk
.
conv2
.
weight
,
self
.
dbb_1x1_kxk
.
bn2
)
k_1x1_kxk_merged
,
b_1x1_kxk_merged
=
self
.
_fuse_1x1_kxk
(
k_1x1_kxk_first
,
b_1x1_kxk_first
,
k_1x1_kxk_second
,
b_1x1_kxk_second
,
groups
=
self
.
groups
)
k_avg
=
self
.
_fuse_avg
(
self
.
out_channels
,
self
.
kernel_size
,
self
.
groups
)
k_1x1_avg_second
,
b_1x1_avg_second
=
self
.
_fuse_bn
(
k_avg
,
self
.
dbb_avg
.
avgbn
)
if
hasattr
(
self
.
dbb_avg
,
'conv'
):
k_1x1_avg_first
,
b_1x1_avg_first
=
self
.
_fuse_bn
(
self
.
dbb_avg
.
conv
.
weight
,
self
.
dbb_avg
.
bn
)
k_1x1_avg_merged
,
b_1x1_avg_merged
=
self
.
_fuse_1x1_kxk
(
k_1x1_avg_first
,
b_1x1_avg_first
,
k_1x1_avg_second
,
b_1x1_avg_second
,
groups
=
self
.
groups
)
else
:
k_1x1_avg_merged
,
b_1x1_avg_merged
=
k_1x1_avg_second
,
b_1x1_avg_second
return
sum
([
k_origin
,
k_1x1
,
k_1x1_kxk_merged
,
k_1x1_avg_merged
]),
sum
(
[
b_origin
,
b_1x1
,
b_1x1_kxk_merged
,
b_1x1_avg_merged
])
def
convert_to_deploy
(
self
):
if
hasattr
(
self
,
'dbb_reparam'
):
return
kernel
,
bias
=
self
.
_get_equivalent_kernel_bias
()
self
.
dbb_reparam
=
nn
.
Conv2D
(
in_channels
=
self
.
in_channels
,
out_channels
=
self
.
out_channels
,
kernel_size
=
self
.
kernel_size
,
stride
=
self
.
stride
,
padding
=
self
.
padding
,
groups
=
self
.
groups
,
bias_attr
=
True
)
self
.
dbb_reparam
.
weight
.
set_value
(
kernel
)
self
.
dbb_reparam
.
bias
.
set_value
(
bias
)
self
.
__delattr__
(
'dbb_origin'
)
self
.
__delattr__
(
'dbb_avg'
)
if
hasattr
(
self
,
'dbb_1x1'
):
self
.
__delattr__
(
'dbb_1x1'
)
self
.
__delattr__
(
'dbb_1x1_kxk'
)
def
forward
(
self
,
inputs
):
if
hasattr
(
self
,
'dbb_reparam'
):
return
self
.
dbb_reparam
(
inputs
)
out
=
self
.
dbb_origin
(
inputs
)
if
hasattr
(
self
,
'dbb_1x1'
):
out
+=
self
.
dbb_1x1
(
inputs
)
out
+=
self
.
dbb_avg
(
inputs
)
out
+=
self
.
dbb_1x1_kxk
(
inputs
)
return
out
paddleslim/dygraph/rep/reper/repvgg.py
0 → 100644
浏览文件 @
65c776de
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
paddle
from
paddle
import
ParamAttr
from
paddle.regularizer
import
L2Decay
import
paddle.nn
as
nn
from
.base
import
BaseConv2DReper
,
ConvBNLayer
__all__
=
[
"RepVGGBlock"
]
class
RepVGGBlock
(
BaseConv2DReper
):
"""
An instance of the RepVGGBlock module, which replaces the conv-bn layer in the network.
Refer from Paper: https://arxiv.org/abs/2101.03697.
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
groups
=
1
,
padding
=
None
):
super
(
RepVGGBlock
,
self
).
__init__
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
kernel_size
,
stride
=
stride
,
groups
=
groups
,
padding
=
padding
)
# Re-parameterizable skip connection
self
.
rbr_skip
=
nn
.
BatchNorm2D
(
num_features
=
in_channels
,
weight_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
))
)
if
in_channels
==
out_channels
and
self
.
stride
==
1
else
None
# Re-parameterizable conv branches
self
.
rbr_conv
=
ConvBNLayer
(
self
.
in_channels
,
self
.
out_channels
,
self
.
kernel_size
,
stride
=
self
.
stride
,
groups
=
self
.
groups
)
# Re-parameterizable scale branch
self
.
rbr_scale
=
None
if
kernel_size
>
1
:
self
.
rbr_scale
=
ConvBNLayer
(
self
.
in_channels
,
self
.
out_channels
,
1
,
stride
=
self
.
stride
,
groups
=
self
.
groups
)
def
forward
(
self
,
x
):
# Inference mode forward pass.
if
hasattr
(
self
,
"reparam_conv"
):
return
self
.
reparam_conv
(
x
)
# Multi-branched train-time forward pass.
# Skip branch output
identity_out
=
0
if
self
.
rbr_skip
is
not
None
:
identity_out
=
self
.
rbr_skip
(
x
)
# Scale branch output
scale_out
=
0
if
self
.
rbr_scale
is
not
None
:
scale_out
=
self
.
rbr_scale
(
x
)
# Other branches
out
=
scale_out
+
identity_out
out
+=
self
.
rbr_conv
(
x
)
return
out
def
convert_to_deploy
(
self
):
"""
Re-parameterize multi-branched architecture used at training
time to obtain a plain CNN-like structure for inference.
"""
if
hasattr
(
self
,
'reparam_conv'
):
return
kernel
,
bias
=
self
.
_get_kernel_bias
()
self
.
reparam_conv
=
nn
.
Conv2D
(
in_channels
=
self
.
in_channels
,
out_channels
=
self
.
out_channels
,
kernel_size
=
self
.
kernel_size
,
stride
=
self
.
stride
,
padding
=
(
self
.
kernel_size
-
1
)
//
2
,
groups
=
self
.
groups
)
self
.
reparam_conv
.
weight
.
set_value
(
kernel
)
self
.
reparam_conv
.
bias
.
set_value
(
bias
)
# Delete un-used branches
self
.
__delattr__
(
'rbr_conv'
)
if
hasattr
(
self
,
'rbr_scale'
):
self
.
__delattr__
(
'rbr_scale'
)
if
hasattr
(
self
,
'rbr_skip'
):
self
.
__delattr__
(
'rbr_skip'
)
def
_get_kernel_bias
(
self
):
"""
Method to obtain re-parameterized kernel and bias.
"""
# get weights and bias of scale branch
kernel_scale
=
0
bias_scale
=
0
if
self
.
rbr_scale
is
not
None
:
kernel_scale
,
bias_scale
=
self
.
_fuse_bn_tensor
(
self
.
rbr_scale
)
# Pad scale branch kernel to match conv branch kernel size. 1x1->3x3
padding_size
=
self
.
kernel_size
//
2
kernel_scale
=
paddle
.
nn
.
functional
.
pad
(
kernel_scale
,
[
padding_size
,
padding_size
,
padding_size
,
padding_size
])
# get weights and bias of skip branch
kernel_identity
=
0
bias_identity
=
0
if
self
.
rbr_skip
is
not
None
:
kernel_identity
,
bias_identity
=
self
.
_fuse_bn_tensor
(
self
.
rbr_skip
)
# get weights and bias of conv branches
kernel_conv
,
bias_conv
=
self
.
_fuse_bn_tensor
(
self
.
rbr_conv
)
kernel_final
=
kernel_conv
+
kernel_scale
+
kernel_identity
bias_final
=
bias_conv
+
bias_scale
+
bias_identity
return
kernel_final
,
bias_final
def
_fuse_bn_tensor
(
self
,
branch
):
if
branch
is
None
:
return
0
,
0
if
isinstance
(
branch
,
nn
.
LayerList
):
fused_kernels
=
[]
fused_bias
=
[]
for
block
in
branch
:
kernel
=
block
.
conv
.
weight
running_mean
=
block
.
bn
.
_mean
running_var
=
block
.
bn
.
_variance
gamma
=
block
.
bn
.
weight
beta
=
block
.
bn
.
bias
eps
=
block
.
bn
.
_epsilon
std
=
(
running_var
+
eps
).
sqrt
()
t
=
(
gamma
/
std
).
reshape
((
-
1
,
1
,
1
,
1
))
fused_kernels
.
append
(
kernel
*
t
)
fused_bias
.
append
(
beta
-
running_mean
*
gamma
/
std
)
return
sum
(
fused_kernels
),
sum
(
fused_bias
)
elif
isinstance
(
branch
,
ConvBNLayer
):
kernel
=
branch
.
conv
.
weight
running_mean
=
branch
.
bn
.
_mean
running_var
=
branch
.
bn
.
_variance
gamma
=
branch
.
bn
.
weight
beta
=
branch
.
bn
.
bias
eps
=
branch
.
bn
.
_epsilon
else
:
assert
isinstance
(
branch
,
nn
.
BatchNorm2D
)
input_dim
=
self
.
in_channels
if
self
.
kernel_size
==
1
else
1
kernel_value
=
paddle
.
zeros
(
shape
=
[
self
.
in_channels
,
input_dim
,
self
.
kernel_size
,
self
.
kernel_size
],
dtype
=
'float32'
)
if
self
.
kernel_size
>
1
:
for
i
in
range
(
self
.
in_channels
):
kernel_value
[
i
,
i
%
input_dim
,
(
self
.
kernel_size
-
1
)
//
2
,
(
self
.
kernel_size
-
1
)
//
2
]
=
1
elif
self
.
kernel_size
==
1
:
for
i
in
range
(
self
.
in_channels
):
kernel_value
[
i
,
i
%
input_dim
,
0
,
0
]
=
1
else
:
raise
ValueError
(
"Invalid kernel size recieved!"
)
kernel
=
paddle
.
to_tensor
(
kernel_value
,
place
=
branch
.
weight
.
place
)
running_mean
=
branch
.
_mean
running_var
=
branch
.
_variance
gamma
=
branch
.
weight
beta
=
branch
.
bias
eps
=
branch
.
_epsilon
std
=
(
running_var
+
eps
).
sqrt
()
t
=
(
gamma
/
std
).
reshape
((
-
1
,
1
,
1
,
1
))
return
kernel
*
t
,
beta
-
running_mean
*
gamma
/
std
paddleslim/dygraph/rep/reper/slimrep.py
0 → 100644
浏览文件 @
65c776de
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
paddle
from
paddle
import
ParamAttr
from
paddle.regularizer
import
L2Decay
import
paddle.nn
as
nn
from
.base
import
BaseConv2DReper
,
ConvBNLayer
__all__
=
[
"SlimRepBlock"
]
class
SlimRepBlock
(
BaseConv2DReper
):
"""
An instance of the SlimRepBlock module, which replaces the conv-bn layer in the network.
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
groups
=
1
,
padding
=
None
):
super
(
SlimRepBlock
,
self
).
__init__
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
kernel_size
,
stride
=
stride
,
groups
=
groups
,
padding
=
padding
)
self
.
num_conv_branches
=
1
if
not
self
.
padding
:
self
.
padding
=
self
.
kernel_size
//
2
if
self
.
padding
-
self
.
kernel_size
//
2
>=
0
:
self
.
crop
=
0
# Compared to the KxK layer, the padding of the 1xK layer and Kx1 layer should be adjust to align the sliding windows (Fig 2 in the paper)
hor_padding
=
[
self
.
padding
-
self
.
kernel_size
//
2
,
self
.
padding
]
ver_padding
=
[
self
.
padding
,
self
.
padding
-
self
.
kernel_size
//
2
]
else
:
# A negative "padding" (padding - kernel_size//2 < 0, which is not a common use case) is cropping.
# Since nn.Conv2D does not support negative padding, we implement it manually
self
.
crop
=
self
.
kernel_size
//
2
-
self
.
padding
hor_padding
=
[
0
,
self
.
padding
]
ver_padding
=
[
self
.
padding
,
0
]
# Re-parameterizable skip connection
self
.
rbr_skip
=
nn
.
BatchNorm2D
(
num_features
=
in_channels
,
weight_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
))
)
if
in_channels
==
out_channels
and
self
.
stride
==
1
else
None
# Re-parameterizable conv branches
self
.
rbr_conv
=
nn
.
LayerList
()
for
_
in
range
(
self
.
num_conv_branches
):
for
kernel_size
in
range
(
self
.
kernel_size
,
0
,
-
2
):
self
.
rbr_conv
.
append
(
ConvBNLayer
(
self
.
in_channels
,
self
.
out_channels
,
kernel_size
,
stride
=
self
.
stride
,
groups
=
self
.
groups
))
# kx1 vertical branch
self
.
ver_branch
=
ConvBNLayer
(
self
.
in_channels
,
self
.
out_channels
,
(
self
.
kernel_size
,
1
),
self
.
stride
,
groups
=
self
.
groups
,
padding
=
ver_padding
)
# 1xk horizontal branch
self
.
hor_branch
=
ConvBNLayer
(
self
.
in_channels
,
self
.
out_channels
,
(
1
,
self
.
kernel_size
),
self
.
stride
,
groups
=
self
.
groups
,
padding
=
hor_padding
)
def
forward
(
self
,
x
):
# Inference mode forward pass.
if
hasattr
(
self
,
"reparam_conv"
):
return
self
.
reparam_conv
(
x
)
# Multi-branched train-time forward pass.
out
=
0
for
rbr_conv
in
self
.
rbr_conv
:
out
+=
rbr_conv
(
x
)
# Skip branch output
if
self
.
rbr_skip
is
not
None
:
out
+=
self
.
rbr_skip
(
x
)
if
self
.
crop
>
0
:
ver_input
=
x
[:,
:,
:,
self
.
crop
:
-
self
.
crop
]
hor_input
=
x
[:,
:,
self
.
crop
:
-
self
.
crop
,
:]
else
:
ver_input
=
x
hor_input
=
x
out
+=
self
.
ver_branch
(
ver_input
)
out
+=
self
.
hor_branch
(
hor_input
)
return
out
def
convert_to_deploy
(
self
):
"""
Re-parameterize multi-branched architecture used at training
time to obtain a plain CNN-like structure for inference.
"""
if
hasattr
(
self
,
'reparam_conv'
):
return
kernel
,
bias
=
self
.
_get_kernel_bias
()
self
.
reparam_conv
=
nn
.
Conv2D
(
in_channels
=
self
.
in_channels
,
out_channels
=
self
.
out_channels
,
kernel_size
=
self
.
kernel_size
,
stride
=
self
.
stride
,
padding
=
(
self
.
kernel_size
-
1
)
//
2
,
groups
=
self
.
groups
)
self
.
reparam_conv
.
weight
.
set_value
(
kernel
)
self
.
reparam_conv
.
bias
.
set_value
(
bias
)
# Delete un-used branches
self
.
__delattr__
(
'rbr_conv'
)
if
hasattr
(
self
,
'rbr_skip'
):
self
.
__delattr__
(
'rbr_skip'
)
self
.
__delattr__
(
'ver_branch'
)
self
.
__delattr__
(
'hor_branch'
)
def
_get_kernel_bias
(
self
):
"""
Method to obtain re-parameterized kernel and bias.
"""
# get weights and bias of conv branches
kernel_conv
=
0
bias_conv
=
0
for
ix
in
range
(
self
.
num_conv_branches
):
_kernel
,
_bias
=
self
.
_fuse_bn_tensor
(
self
.
rbr_conv
[
ix
])
_kernel
=
self
.
_pad_tensor
(
_kernel
,
to_size
=
self
.
kernel_size
)
kernel_conv
+=
_kernel
bias_conv
+=
_bias
# get weights and bias of skip branch
kernel_identity
=
0
bias_identity
=
0
if
self
.
rbr_skip
is
not
None
:
kernel_identity
,
bias_identity
=
self
.
_fuse_bn_tensor
(
self
.
rbr_skip
)
kernel_final
=
kernel_conv
+
kernel_identity
bias_final
=
bias_conv
+
bias_identity
# get kx1 1xk branch
hor_k
,
hor_b
=
self
.
_fuse_bn_tensor
(
self
.
hor_branch
)
ver_k
,
ver_b
=
self
.
_fuse_bn_tensor
(
self
.
ver_branch
)
self
.
_add_to_square_kernel
(
kernel_final
,
hor_k
)
self
.
_add_to_square_kernel
(
kernel_final
,
ver_k
)
bias_final
+=
hor_b
+
ver_b
return
kernel_final
,
bias_final
def
_add_to_square_kernel
(
self
,
square_kernel
,
asym_kernel
):
asym_h
=
asym_kernel
.
shape
[
2
]
asym_w
=
asym_kernel
.
shape
[
3
]
square_h
=
square_kernel
.
shape
[
2
]
square_w
=
square_kernel
.
shape
[
3
]
square_kernel
[:,
:,
square_h
//
2
-
asym_h
//
2
:
square_h
//
2
-
asym_h
//
2
+
asym_h
,
square_w
//
2
-
asym_w
//
2
:
square_w
//
2
-
asym_w
//
2
+
asym_w
]
+=
asym_kernel
def
_pad_tensor
(
self
,
tensor
,
to_size
):
from_size
=
tensor
.
shape
[
-
1
]
if
from_size
==
to_size
:
return
tensor
pad
=
(
to_size
-
from_size
)
//
2
return
paddle
.
nn
.
functional
.
pad
(
tensor
,
[
pad
,
pad
,
pad
,
pad
])
def
_fuse_bn_tensor
(
self
,
branch
):
if
branch
is
None
:
return
0
,
0
if
isinstance
(
branch
,
nn
.
LayerList
):
fused_kernels
=
[]
fused_bias
=
[]
for
block
in
branch
:
kernel
=
block
.
conv
.
weight
running_mean
=
block
.
bn
.
_mean
running_var
=
block
.
bn
.
_variance
gamma
=
block
.
bn
.
weight
beta
=
block
.
bn
.
bias
eps
=
block
.
bn
.
_epsilon
std
=
(
running_var
+
eps
).
sqrt
()
t
=
(
gamma
/
std
).
reshape
((
-
1
,
1
,
1
,
1
))
fused_kernels
.
append
(
kernel
*
t
)
fused_bias
.
append
(
beta
-
running_mean
*
gamma
/
std
)
return
sum
(
fused_kernels
),
sum
(
fused_bias
)
elif
isinstance
(
branch
,
ConvBNLayer
):
kernel
=
branch
.
conv
.
weight
running_mean
=
branch
.
bn
.
_mean
running_var
=
branch
.
bn
.
_variance
gamma
=
branch
.
bn
.
weight
beta
=
branch
.
bn
.
bias
eps
=
branch
.
bn
.
_epsilon
else
:
assert
isinstance
(
branch
,
nn
.
BatchNorm2D
)
input_dim
=
self
.
in_channels
if
self
.
kernel_size
==
1
else
1
kernel_value
=
paddle
.
zeros
(
shape
=
[
self
.
in_channels
,
input_dim
,
self
.
kernel_size
,
self
.
kernel_size
],
dtype
=
'float32'
)
if
self
.
kernel_size
>
1
:
for
i
in
range
(
self
.
in_channels
):
kernel_value
[
i
,
i
%
input_dim
,
(
self
.
kernel_size
-
1
)
//
2
,
(
self
.
kernel_size
-
1
)
//
2
]
=
1
elif
self
.
kernel_size
==
1
:
for
i
in
range
(
self
.
in_channels
):
kernel_value
[
i
,
i
%
input_dim
,
0
,
0
]
=
1
else
:
raise
ValueError
(
"Invalid kernel size recieved!"
)
kernel
=
paddle
.
to_tensor
(
kernel_value
,
place
=
branch
.
weight
.
place
)
running_mean
=
branch
.
_mean
running_var
=
branch
.
_variance
gamma
=
branch
.
weight
beta
=
branch
.
bias
eps
=
branch
.
_epsilon
std
=
(
running_var
+
eps
).
sqrt
()
t
=
(
gamma
/
std
).
reshape
((
-
1
,
1
,
1
,
1
))
return
kernel
*
t
,
beta
-
running_mean
*
gamma
/
std
tests/dygraph/test_reparameterization.py
0 → 100644
浏览文件 @
65c776de
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
sys
sys
.
path
.
append
(
"../../"
)
import
unittest
import
logging
import
paddle
from
paddleslim.common
import
get_logger
from
paddleslim.dygraph.rep
import
Reparameter
,
DBBRepConfig
,
ACBRepConfig
_logger
=
get_logger
(
__name__
,
level
=
logging
.
INFO
)
class
ImperativeLenet
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
num_classes
=
10
):
super
(
ImperativeLenet
,
self
).
__init__
()
self
.
features
=
paddle
.
nn
.
Sequential
(
paddle
.
nn
.
Conv2D
(
in_channels
=
1
,
out_channels
=
6
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias_attr
=
False
),
paddle
.
nn
.
BatchNorm2D
(
6
),
paddle
.
nn
.
ReLU
(),
paddle
.
nn
.
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
),
paddle
.
nn
.
Conv2D
(
in_channels
=
6
,
out_channels
=
16
,
kernel_size
=
5
,
stride
=
1
,
padding
=
2
,
bias_attr
=
False
),
paddle
.
nn
.
BatchNorm2D
(
16
),
paddle
.
nn
.
PReLU
(),
paddle
.
nn
.
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
))
self
.
fc
=
paddle
.
nn
.
Sequential
(
paddle
.
nn
.
Linear
(
in_features
=
784
,
out_features
=
120
),
paddle
.
nn
.
LeakyReLU
(),
paddle
.
nn
.
Linear
(
in_features
=
120
,
out_features
=
84
),
paddle
.
nn
.
Sigmoid
(),
paddle
.
nn
.
Linear
(
in_features
=
84
,
out_features
=
num_classes
),
paddle
.
nn
.
Softmax
())
def
forward
(
self
,
inputs
):
x
=
self
.
features
(
inputs
)
x
=
paddle
.
flatten
(
x
,
1
)
x
=
self
.
fc
(
x
)
return
x
class
TestRep
(
unittest
.
TestCase
):
"""
Test dygraph reparameterization.
"""
def
model_test
(
self
,
model
,
test_reader
):
model
.
eval
()
avg_acc
=
[[],
[]]
for
batch_id
,
data
in
enumerate
(
test_reader
):
img
=
paddle
.
to_tensor
(
data
[
0
])
img
=
paddle
.
reshape
(
img
,
[
-
1
,
1
,
28
,
28
])
label
=
paddle
.
to_tensor
(
data
[
1
])
label
=
paddle
.
reshape
(
label
,
[
-
1
,
1
])
out
=
model
(
img
)
acc_top1
=
paddle
.
metric
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
1
)
acc_top5
=
paddle
.
metric
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
5
)
avg_acc
[
0
].
append
(
acc_top1
.
numpy
())
avg_acc
[
1
].
append
(
acc_top5
.
numpy
())
if
batch_id
%
100
==
0
:
_logger
.
info
(
"Test | step {}: acc1 = {:}, acc5 = {:}"
.
format
(
batch_id
,
acc_top1
.
numpy
(),
acc_top5
.
numpy
()))
_logger
.
info
(
"Test |Average: acc_top1 {}, acc_top5 {}"
.
format
(
np
.
mean
(
avg_acc
[
0
]),
np
.
mean
(
avg_acc
[
1
])))
return
np
.
mean
(
avg_acc
[
0
]),
np
.
mean
(
avg_acc
[
1
])
def
model_train
(
self
,
model
,
train_reader
):
adam
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
0.0001
,
parameters
=
model
.
parameters
())
epoch_num
=
1
for
epoch
in
range
(
epoch_num
):
model
.
train
()
for
batch_id
,
data
in
enumerate
(
train_reader
):
img
=
paddle
.
to_tensor
(
data
[
0
])
label
=
paddle
.
to_tensor
(
data
[
1
])
img
=
paddle
.
reshape
(
img
,
[
-
1
,
1
,
28
,
28
])
label
=
paddle
.
reshape
(
label
,
[
-
1
,
1
])
out
=
model
(
img
)
acc
=
paddle
.
metric
.
accuracy
(
out
,
label
)
loss
=
paddle
.
nn
.
functional
.
loss
.
cross_entropy
(
out
,
label
)
avg_loss
=
paddle
.
mean
(
loss
)
avg_loss
.
backward
()
adam
.
minimize
(
avg_loss
)
model
.
clear_gradients
()
if
batch_id
%
100
==
0
:
_logger
.
info
(
"Train | At epoch {} step {}: loss = {:}, acc= {:}"
.
format
(
epoch
,
batch_id
,
avg_loss
.
numpy
(),
acc
.
numpy
()))
def
test_dbb
(
self
):
seed
=
1
np
.
random
.
seed
(
seed
)
paddle
.
static
.
default_main_program
().
random_seed
=
seed
paddle
.
static
.
default_startup_program
().
random_seed
=
seed
_logger
.
info
(
"create the fp32 model"
)
fp32_lenet
=
ImperativeLenet
()
_logger
.
info
(
"prepare data"
)
batch_size
=
64
transform
=
paddle
.
vision
.
transforms
.
Compose
([
paddle
.
vision
.
transforms
.
Transpose
(),
paddle
.
vision
.
transforms
.
Normalize
([
127.5
],
[
127.5
])
])
train_dataset
=
paddle
.
vision
.
datasets
.
MNIST
(
mode
=
'train'
,
backend
=
'cv2'
,
transform
=
transform
)
val_dataset
=
paddle
.
vision
.
datasets
.
MNIST
(
mode
=
'test'
,
backend
=
'cv2'
,
transform
=
transform
)
place
=
paddle
.
CUDAPlace
(
0
)
\
if
paddle
.
is_compiled_with_cuda
()
else
paddle
.
CPUPlace
()
train_reader
=
paddle
.
io
.
DataLoader
(
train_dataset
,
drop_last
=
True
,
places
=
place
,
batch_size
=
batch_size
,
return_list
=
True
)
test_reader
=
paddle
.
io
.
DataLoader
(
val_dataset
,
places
=
place
,
batch_size
=
batch_size
,
return_list
=
True
)
_logger
.
info
(
"train the fp32 model"
)
self
.
model_train
(
fp32_lenet
,
train_reader
)
_logger
.
info
(
"test fp32 model"
)
fp32_top1
,
fp32_top5
=
self
.
model_test
(
fp32_lenet
,
test_reader
)
rep_config
=
DBBRepConfig
()
reper
=
Reparameter
(
rep_config
)
reper
.
prepare
(
fp32_lenet
)
_logger
.
info
(
"train the DBB reparameterization model"
)
self
.
model_train
(
fp32_lenet
,
train_reader
)
rep_top1
,
rep_top5
=
self
.
model_test
(
fp32_lenet
,
test_reader
)
_logger
.
info
(
"save and test the DBB reparameterization model"
)
reper
.
convert
(
fp32_lenet
)
save_path
=
"./tmp/model"
input_spec
=
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
1
,
28
,
28
],
dtype
=
'float32'
)
paddle
.
jit
.
save
(
fp32_lenet
,
save_path
,
input_spec
=
[
input_spec
])
_logger
.
info
(
"FP32 acc: top1: {}, top5: {}"
.
format
(
fp32_top1
,
fp32_top5
))
_logger
.
info
(
"Int acc: top1: {}, top5: {}"
.
format
(
rep_top1
,
rep_top5
))
diff
=
0.005
self
.
assertTrue
(
fp32_top1
-
rep_top1
<
diff
,
msg
=
"The acc of rep model is too lower than fp32 model"
)
def
test_acb
(
self
):
seed
=
1
np
.
random
.
seed
(
seed
)
paddle
.
static
.
default_main_program
().
random_seed
=
seed
paddle
.
static
.
default_startup_program
().
random_seed
=
seed
_logger
.
info
(
"create the fp32 model"
)
fp32_lenet
=
ImperativeLenet
()
_logger
.
info
(
"prepare data"
)
batch_size
=
64
transform
=
paddle
.
vision
.
transforms
.
Compose
([
paddle
.
vision
.
transforms
.
Transpose
(),
paddle
.
vision
.
transforms
.
Normalize
([
127.5
],
[
127.5
])
])
train_dataset
=
paddle
.
vision
.
datasets
.
MNIST
(
mode
=
'train'
,
backend
=
'cv2'
,
transform
=
transform
)
val_dataset
=
paddle
.
vision
.
datasets
.
MNIST
(
mode
=
'test'
,
backend
=
'cv2'
,
transform
=
transform
)
place
=
paddle
.
CUDAPlace
(
0
)
\
if
paddle
.
is_compiled_with_cuda
()
else
paddle
.
CPUPlace
()
train_reader
=
paddle
.
io
.
DataLoader
(
train_dataset
,
drop_last
=
True
,
places
=
place
,
batch_size
=
batch_size
,
return_list
=
True
)
test_reader
=
paddle
.
io
.
DataLoader
(
val_dataset
,
places
=
place
,
batch_size
=
batch_size
,
return_list
=
True
)
_logger
.
info
(
"train the fp32 model"
)
self
.
model_train
(
fp32_lenet
,
train_reader
)
_logger
.
info
(
"test fp32 model"
)
fp32_top1
,
fp32_top5
=
self
.
model_test
(
fp32_lenet
,
test_reader
)
rep_config
=
ACBRepConfig
()
reper
=
Reparameter
(
rep_config
)
reper
.
prepare
(
fp32_lenet
)
_logger
.
info
(
"train the ACB reparameterization model"
)
self
.
model_train
(
fp32_lenet
,
train_reader
)
rep_top1
,
rep_top5
=
self
.
model_test
(
fp32_lenet
,
test_reader
)
_logger
.
info
(
"save and test the ACB reparameterization model"
)
reper
.
convert
(
fp32_lenet
)
save_path
=
"./tmp/model"
input_spec
=
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
1
,
28
,
28
],
dtype
=
'float32'
)
paddle
.
jit
.
save
(
fp32_lenet
,
save_path
,
input_spec
=
[
input_spec
])
_logger
.
info
(
"FP32 acc: top1: {}, top5: {}"
.
format
(
fp32_top1
,
fp32_top5
))
_logger
.
info
(
"Int acc: top1: {}, top5: {}"
.
format
(
rep_top1
,
rep_top5
))
diff
=
0.005
self
.
assertTrue
(
fp32_top1
-
rep_top1
<
diff
,
msg
=
"The acc of rep model is too lower than fp32 model"
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录