Commit a7aa1452 (unverified)
Author: littletomatodonkey, committed via GitHub on Apr 14, 2021
Parent: 2e62e2e2

fix repvgg eval (#677)

* fix repvgg eval
* fix dp training
* fix single card train
Showing 3 changed files, with 165 additions and 61 deletions:

    ppcls/modeling/architectures/repvgg.py   +156  -52
    tools/eval.py                              +5   -6
    tools/train.py                             +4   -3
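Background for the diff below: a RepVGG block trains with three parallel
branches (a 3x3 conv+BN, a 1x1 conv+BN, and an identity BN when shapes allow)
whose outputs are summed, while at inference time the branches can be folded
into a single 3x3 convolution with a bias. This commit makes eval() perform
that folding and fixes how the training script interacts with DataParallel.
A minimal usage sketch (not part of the diff; the input shape and class_dim
are illustrative assumptions, and the import assumes the repo root is on
PYTHONPATH):

    import paddle
    from ppcls.modeling.architectures.repvgg import RepVGG_A0

    net = RepVGG_A0(class_dim=1000)
    net.eval()  # custom eval(): each block fuses its branches into rbr_reparam

    with paddle.no_grad():
        logits = net(paddle.rand([1, 3, 224, 224]))
    print(logits.shape)  # [1, 1000]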
ppcls/modeling/architectures/repvgg.py

@@ -4,19 +4,39 @@ import numpy as np

The export list and the ConvBN layer are re-wrapped (long lines split across
multiple lines); the resulting code:

__all__ = [
    'RepVGG', 'RepVGG_A0', 'RepVGG_A1', 'RepVGG_A2', 'RepVGG_B0', 'RepVGG_B1',
    'RepVGG_B2', 'RepVGG_B3', 'RepVGG_B1g2', 'RepVGG_B1g4', 'RepVGG_B2g2',
    'RepVGG_B2g4', 'RepVGG_B3g2', 'RepVGG_B3g4'
]


class ConvBN(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 groups=1):
        super(ConvBN, self).__init__()
        self.conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias_attr=False)
        self.bn = nn.BatchNorm2D(num_features=out_channels)

    def forward(self, x):
@@ -26,9 +46,15 @@ class ConvBN(nn.Layer):

RepVGGBlock.__init__ gets the same multi-line treatment:

class RepVGGBlock(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 padding_mode='zeros'):
        super(RepVGGBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
@@ -47,11 +73,22 @@ class RepVGGBlock(nn.Layer):

The three parallel branches of the block, and the forward entry that
switches on self.training:

        self.nonlinearity = nn.ReLU()
        self.rbr_identity = nn.BatchNorm2D(
            num_features=in_channels
        ) if out_channels == in_channels and stride == 1 else None
        self.rbr_dense = ConvBN(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups)
        self.rbr_1x1 = ConvBN(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=padding_11,
            groups=groups)

    def forward(self, inputs):
        if not self.training:
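All three branches must produce tensors of identical shape for the sum in
forward() to work; padding_11, defined in the elided context above, handles
the 1x1 branch (in the reference RepVGG implementation it is
padding - kernel_size // 2). A quick sanity-check sketch, assuming the block
class above is importable:

    import paddle

    block = RepVGGBlock(in_channels=16, out_channels=16, kernel_size=3, padding=1)
    x = paddle.rand([1, 16, 32, 32])
    # all three branch outputs should be [1, 16, 32, 32]
    print(block.rbr_dense(x).shape, block.rbr_1x1(x).shape,
          block.rbr_identity(x).shape)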
@@ -61,12 +98,20 @@ class RepVGGBlock(nn.Layer):

The training-mode sum of the three branches, followed by
RepVGGBlock.eval(), which lazily creates the fused rbr_reparam convolution
and loads the equivalent weights into it:

            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(
            self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)

    def eval(self):
        if not hasattr(self, 'rbr_reparam'):
            self.rbr_reparam = nn.Conv2D(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups,
                padding_mode=self.padding_mode)
        self.training = False
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam.weight.set_value(kernel)
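Because rbr_reparam is created lazily and its weights are refreshed on every
call, eval() appears safe to call repeatedly, including after further
training. A sketch of the intended mode toggling (assumed usage, not repo
code):

    block = RepVGGBlock(in_channels=64, out_channels=64, kernel_size=3, padding=1)
    block.eval()                     # creates rbr_reparam and loads fused weights
    assert hasattr(block, 'rbr_reparam')
    block.train()                    # restores the multi-branch training path
    block.eval()                     # re-fuses with the current branch weights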
@@ -78,7 +123,8 @@ class RepVGGBlock(nn.Layer):

        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(
            kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
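The helpers referenced here are elided from the hunk, but the math they
implement is standard BN folding: for a convolution with kernel W followed by
BatchNorm with scale gamma, shift beta, running mean mu and variance sigma^2,
the fused kernel and bias are

    W' = W * gamma / sqrt(sigma^2 + eps)
    b' = beta - mu * gamma / sqrt(sigma^2 + eps)

and a 1x1 kernel is zero-padded to 3x3 so the three branch kernels can be
summed. A self-contained sketch of both steps (function names are
illustrative, not the repo's):

    import paddle
    import paddle.nn.functional as F

    def fuse_conv_bn(weight, gamma, beta, mean, var, eps=1e-5):
        # fold BN statistics into the preceding conv's kernel and a new bias
        std = paddle.sqrt(var + eps)
        scale = (gamma / std).reshape([-1, 1, 1, 1])
        return weight * scale, beta - mean * gamma / std

    def pad_1x1_to_3x3(kernel1x1):
        # [out, in, 1, 1] -> [out, in, 3, 3], original value at the center
        return F.pad(kernel1x1, [1, 1, 1, 1])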
@@ -117,8 +163,11 @@ class RepVGGBlock(nn.Layer):

class RepVGG(nn.Layer):
    def __init__(self,
                 num_blocks,
                 width_multiplier=None,
                 override_groups_map=None,
                 class_dim=1000):
        super(RepVGG, self).__init__()
        assert len(width_multiplier) == 4
@@ -129,7 +178,11 @@ class RepVGG(nn.Layer):

        self.in_planes = min(64, int(64 * width_multiplier[0]))
        self.stage0 = RepVGGBlock(
            in_channels=3,
            out_channels=self.in_planes,
            kernel_size=3,
            stride=2,
            padding=1)
        self.cur_layer_idx = 1
        self.stage1 = self._make_stage(
            int(64 * width_multiplier[0]), num_blocks[0], stride=2)
@@ -143,16 +196,28 @@ class RepVGG(nn.Layer):

_make_stage is re-wrapped, and RepVGG's eval() walks every sublayer:

        self.linear = nn.Linear(int(512 * width_multiplier[3]), class_dim)

    def _make_stage(self, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        blocks = []
        for stride in strides:
            cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
            blocks.append(
                RepVGGBlock(
                    in_channels=self.in_planes,
                    out_channels=planes,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    groups=cur_groups))
            self.in_planes = planes
            self.cur_layer_idx += 1
        return nn.Sequential(*blocks)

    def eval(self):
        self.training = False
        for layer in self.sublayers():
            layer.training = False
            layer.eval()

    def forward(self, x):
        out = self.stage0(x)
        out = self.stage1(out)
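The override matters because a plain flag flip is not enough here:
re-parameterization only happens inside RepVGGBlock.eval(), so the model's
eval() explicitly invokes eval() on every sublayer rather than relying on the
framework default (which, at the time, presumably only toggled the training
flag). A sketch of the observable effect, using one of the factory functions
from the next hunk (assumed usage):

    net = RepVGG_A0()
    net.eval()
    fused = [l for l in net.sublayers() if hasattr(l, 'rbr_reparam')]
    print(len(fused))  # one fused 3x3 conv per RepVGGBlock in the network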
@@ -171,65 +236,104 @@ g4_map = {l: 4 for l in optional_groupwise_layers}

The factory functions are re-wrapped without functional change:

def RepVGG_A0(**kwargs):
    return RepVGG(
        num_blocks=[2, 4, 14, 1],
        width_multiplier=[0.75, 0.75, 0.75, 2.5],
        override_groups_map=None,
        **kwargs)


def RepVGG_A1(**kwargs):
    return RepVGG(
        num_blocks=[2, 4, 14, 1],
        width_multiplier=[1, 1, 1, 2.5],
        override_groups_map=None,
        **kwargs)


def RepVGG_A2(**kwargs):
    return RepVGG(
        num_blocks=[2, 4, 14, 1],
        width_multiplier=[1.5, 1.5, 1.5, 2.75],
        override_groups_map=None,
        **kwargs)


def RepVGG_B0(**kwargs):
    return RepVGG(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[1, 1, 1, 2.5],
        override_groups_map=None,
        **kwargs)


def RepVGG_B1(**kwargs):
    return RepVGG(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[2, 2, 2, 4],
        override_groups_map=None,
        **kwargs)


def RepVGG_B1g2(**kwargs):
    return RepVGG(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[2, 2, 2, 4],
        override_groups_map=g2_map,
        **kwargs)


def RepVGG_B1g4(**kwargs):
    return RepVGG(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[2, 2, 2, 4],
        override_groups_map=g4_map,
        **kwargs)


def RepVGG_B2(**kwargs):
    return RepVGG(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[2.5, 2.5, 2.5, 5],
        override_groups_map=None,
        **kwargs)


def RepVGG_B2g2(**kwargs):
    return RepVGG(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[2.5, 2.5, 2.5, 5],
        override_groups_map=g2_map,
        **kwargs)


def RepVGG_B2g4(**kwargs):
    return RepVGG(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[2.5, 2.5, 2.5, 5],
        override_groups_map=g4_map,
        **kwargs)


def RepVGG_B3(**kwargs):
    return RepVGG(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[3, 3, 3, 5],
        override_groups_map=None,
        **kwargs)


def RepVGG_B3g2(**kwargs):
    return RepVGG(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[3, 3, 3, 5],
        override_groups_map=g2_map,
        **kwargs)


def RepVGG_B3g4(**kwargs):
    return RepVGG(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[3, 3, 3, 5],
        override_groups_map=g4_map,
        **kwargs)
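For reference, the variants defined above differ only in depth, width, and
the optional groupwise-conv map:

    Variant       num_blocks     width_multiplier          groups
    RepVGG_A0     [2, 4, 14, 1]  [0.75, 0.75, 0.75, 2.5]   None
    RepVGG_A1     [2, 4, 14, 1]  [1, 1, 1, 2.5]            None
    RepVGG_A2     [2, 4, 14, 1]  [1.5, 1.5, 1.5, 2.75]     None
    RepVGG_B0     [4, 6, 16, 1]  [1, 1, 1, 2.5]            None
    RepVGG_B1     [4, 6, 16, 1]  [2, 2, 2, 4]              None
    RepVGG_B1g2   [4, 6, 16, 1]  [2, 2, 2, 4]              g2_map
    RepVGG_B1g4   [4, 6, 16, 1]  [2, 2, 2, 4]              g4_map
    RepVGG_B2     [4, 6, 16, 1]  [2.5, 2.5, 2.5, 5]        None
    RepVGG_B2g2   [4, 6, 16, 1]  [2.5, 2.5, 2.5, 5]        g2_map
    RepVGG_B2g4   [4, 6, 16, 1]  [2.5, 2.5, 2.5, 5]        g4_map
    RepVGG_B3     [4, 6, 16, 1]  [3, 3, 3, 5]              None
    RepVGG_B3g2   [4, 6, 16, 1]  [3, 3, 3, 5]              g2_map
    RepVGG_B3g4   [4, 6, 16, 1]  [3, 3, 3, 5]              g4_map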
tools/eval.py

@@ -69,24 +69,23 @@ def main(args, return_dict={}):

After the change, the evaluation path runs under paddle.no_grad() following
net.eval(), and the unused loop variable idx becomes _:

    paddle.distributed.init_parallel_env()

    net = program.create_model(config.ARCHITECTURE, config.classes_num)
    if config["use_data_parallel"]:
        net = paddle.DataParallel(net)

    init_model(config, net, optimizer=None)
    valid_dataloader = Reader(config, 'valid', places=place)()
    net.eval()
    with paddle.no_grad():
        if not multilabel:
            top1_acc = program.run(valid_dataloader, config, net, None, None,
                                   0, 'valid')
            return_dict["top1_acc"] = top1_acc
            return top1_acc
        else:
            all_outs = []
            targets = []
            for _, batch in enumerate(valid_dataloader()):
                feeds = program.create_feeds(batch, False, config.classes_num,
                                             multilabel)
                out = net(feeds["image"])
                out = F.sigmoid(out)
tools/train.py

@@ -69,9 +69,10 @@ def main(args):

A separate dp_net reference is introduced so the original net is kept even
when a DataParallel wrapper is in use:

     optimizer, lr_scheduler = program.create_optimizer(
         config, parameter_list=net.parameters())

+    dp_net = net
     if config["use_data_parallel"]:
         find_unused_parameters = config.get("find_unused_parameters", False)
-        net = paddle.DataParallel(
+        dp_net = paddle.DataParallel(
             net, find_unused_parameters=find_unused_parameters)

     # load model from checkpoint or pretrained model
@@ -96,8 +97,8 @@ def main(args):

The training step now runs through dp_net while mode switching stays on net:

     for epoch_id in range(last_epoch_id + 1, config.epochs):
         net.train()
         # 1. train with train dataset
-        program.run(train_dataloader, config, net, optimizer, lr_scheduler,
-                    epoch_id, 'train', vdl_writer)
+        program.run(train_dataloader, config, dp_net, optimizer,
+                    lr_scheduler, epoch_id, 'train', vdl_writer)

         # 2. validate with validate dataset
         if config.validate and epoch_id % config.valid_interval == 0:
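The likely motivation for the dp_net split, given the commit message's "fix
dp training" and "fix single card train": once net is rebound to the
DataParallel wrapper, calls such as net.train() or RepVGG's custom eval() go
through the wrapper instead of the underlying model, and single-card runs
carry a wrapper they do not need. Keeping both references separates the two
roles (a sketch of the pattern, with illustrative names):

    dp_net = net                  # single card: plain alias, no wrapper
    if config["use_data_parallel"]:
        # gradients sync through the wrapper; mode flags stay on `net`
        dp_net = paddle.DataParallel(net)

    net.train()                   # state management on the underlying model
    program.run(train_dataloader, config, dp_net, optimizer, lr_scheduler,
                epoch_id, 'train', vdl_writer)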