PaddlePaddle / PaddleClas

Commit 3a8b5680
Authored on Dec 25, 2022 by HydrogenSulfate
Parent: 7a0c7965

feat(model): add EfficientNetV2 code and fix AttrDict BUG
Showing 10 changed files with 1,364 additions and 9 deletions (+1364 -9)
deploy/utils/config.py                                        +1    -1
ppcls/arch/backbone/__init__.py                               +1    -0
ppcls/arch/backbone/model_zoo/efficientnet_v2.py              +991  -0
ppcls/configs/ImageNet/EfficientNetV2/EfficientNetV2_S.yaml   +142  -0
ppcls/data/preprocess/__init__.py                             +9    -0
ppcls/data/preprocess/ops/operators.py                        +19   -0
ppcls/data/preprocess/ops/randaugment.py                      +135  -6
ppcls/engine/train/__init__.py                                +2    -1
ppcls/engine/train/train_efficientnetv2.py                    +63   -0
ppcls/utils/config.py                                         +1    -1
deploy/utils/config.py
@@ -33,7 +33,7 @@ class AttrDict(dict):
         self[key] = value
 
     def __deepcopy__(self, content):
-        return copy.deepcopy(dict(self))
+        return AttrDict(copy.deepcopy(dict(self)))
 
 
 def create_attr_dict(yaml_config):
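This hunk is the "fix AttrDict BUG" half of the commit: the old __deepcopy__ returned a plain dict, so a deep-copied config node silently lost attribute-style access, which the EfficientNetV2 code below relies on. A minimal standalone sketch of the behaviour after the fix (the AttrDict here is a simplified stand-in declared only for illustration, not the repo's class):

import copy


class AttrDict(dict):
    """Simplified stand-in for ppcls' AttrDict: keys readable as attributes."""

    def __getattr__(self, key):
        return self[key]

    def __setattr__(self, key, value):
        self[key] = value

    def __deepcopy__(self, content):
        # Behaviour introduced by this commit: preserve the AttrDict type.
        return AttrDict(copy.deepcopy(dict(self)))


cfg = AttrDict(lr=0.65, epochs=350)
cfg_copy = copy.deepcopy(cfg)
print(type(cfg_copy).__name__)  # AttrDict, so attribute access survives the copy
print(cfg_copy.lr)              # 0.65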
ppcls/arch/backbone/__init__.py
@@ -38,6 +38,7 @@ from .model_zoo.dpn import DPN68, DPN92, DPN98, DPN107, DPN131
 from .model_zoo.dsnet import DSNet_tiny, DSNet_small, DSNet_base
 from .model_zoo.densenet import DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseNet264
 from .model_zoo.efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7, EfficientNetB0_small
+from .model_zoo.efficientnet_v2 import EfficientNetV2_S
 from .model_zoo.resnest import ResNeSt50_fast_1s1x64d, ResNeSt50, ResNeSt101, ResNeSt200, ResNeSt269
 from .model_zoo.googlenet import GoogLeNet
 from .model_zoo.mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75, MobileNetV2, MobileNetV2_x1_5, MobileNetV2_x2_0
ppcls/arch/backbone/model_zoo/efficientnet_v2.py
0 → 100644
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Code was based on https://github.com/lukemelas/EfficientNet-PyTorch
# reference: https://arxiv.org/abs/1905.11946
import math
import re

import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Constant, Normal, Uniform
from paddle.regularizer import L2Decay

from ppcls.utils.config import AttrDict

from ....utils.save_load import (load_dygraph_pretrain,
                                 load_dygraph_pretrain_from_url)

MODEL_URLS = {
    "EfficientNetV2_S":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetV2_S_pretrained.pdparams",
    "EfficientNetV2_M":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetV2_M_pretrained.pdparams",
    "EfficientNetV2_L":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetV2_L_pretrained.pdparams",
    "EfficientNetV2_XL":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetV2_XL_pretrained.pdparams",
}

__all__ = list(MODEL_URLS.keys())

inp_shape = {
    "efficientnetv2-s": [384, 192, 192, 96, 48, 24, 24, 12],
    "efficientnetv2-m": [384, 192, 192, 96, 48, 24, 24, 12],
    "efficientnetv2-l": [384, 192, 192, 96, 48, 24, 24, 12],
    "efficientnetv2-xl": [384, 192, 192, 96, 48, 24, 24, 12],
}
def cal_padding(img_size, stride, kernel_size):
    """Calculate padding size."""
    if img_size % stride == 0:
        out_size = max(kernel_size - stride, 0)
    else:
        out_size = max(kernel_size - (img_size % stride), 0)
    return out_size // 2, out_size - out_size // 2


class Conv2ds(nn.Layer):
    """Customed Conv2D with tensorflow's padding style

    Args:
        input_channels (int): input channels
        output_channels (int): output channels
        kernel_size (int): filter size
        stride (int, optional): stride. Defaults to 1.
        padding (int, optional): padding. Defaults to 0.
        groups (int, optional): groups. Defaults to None.
        act (str, optional): act. Defaults to None.
        use_bias (bool, optional): use_bias. Defaults to None.
        padding_type (str, optional): padding_type. Defaults to None.
        model_name (str, optional): model name. Defaults to None.
        cur_stage (int, optional): current stage. Defaults to None.

    Returns:
        nn.Layer: Customed Conv2D instance
    """

    def __init__(self,
                 input_channels: int,
                 output_channels: int,
                 kernel_size: int,
                 stride=1,
                 padding=0,
                 groups=None,
                 act=None,
                 use_bias=None,
                 padding_type=None,
                 model_name=None,
                 cur_stage=None):
        super(Conv2ds, self).__init__()
        assert act in [None, "swish", "sigmoid"]
        self._act = act

        def get_padding(kernel_size, stride=1, dilation=1):
            padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
            return padding

        inps = inp_shape[model_name][cur_stage]
        self.need_crop = False
        if padding_type == "SAME":
            top_padding, bottom_padding = cal_padding(inps, stride,
                                                      kernel_size)
            left_padding, right_padding = cal_padding(inps, stride,
                                                      kernel_size)
            height_padding = bottom_padding
            width_padding = right_padding
            if top_padding != bottom_padding or left_padding != right_padding:
                height_padding = top_padding + stride
                width_padding = left_padding + stride
                self.need_crop = True
            padding = [height_padding, width_padding]
        elif padding_type == "VALID":
            height_padding = 0
            width_padding = 0
            padding = [height_padding, width_padding]
        elif padding_type == "DYNAMIC":
            padding = get_padding(kernel_size, stride)
        else:
            padding = padding_type

        groups = 1 if groups is None else groups
        self._conv = nn.Conv2D(
            input_channels,
            output_channels,
            kernel_size,
            groups=groups,
            stride=stride,
            padding=padding,
            weight_attr=None,
            bias_attr=use_bias
            if not use_bias else ParamAttr(regularizer=L2Decay(0.0)))

    def forward(self, inputs):
        x = self._conv(inputs)
        if self._act == "swish":
            x = F.swish(x)
        elif self._act == "sigmoid":
            x = F.sigmoid(x)

        if self.need_crop:
            x = x[:, :, 1:, 1:]
        return x
class BlockDecoder(object):
    """Block Decoder for readability."""

    def _decode_block_string(self, block_string):
        """Gets a block through a string notation of arguments."""
        assert isinstance(block_string, str)
        ops = block_string.split('_')
        options = AttrDict()
        for op in ops:
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        t = AttrDict(
            kernel_size=int(options['k']),
            num_repeat=int(options['r']),
            in_channels=int(options['i']),
            out_channels=int(options['o']),
            expand_ratio=int(options['e']),
            se_ratio=float(options['se']) if 'se' in options else None,
            strides=int(options['s']),
            conv_type=int(options['c']) if 'c' in options else 0, )
        return t

    def _encode_block_string(self, block):
        """Encodes a block to a string."""
        args = [
            'r%d' % block.num_repeat,
            'k%d' % block.kernel_size,
            's%d' % block.strides,
            'e%s' % block.expand_ratio,
            'i%d' % block.in_channels,
            'o%d' % block.out_channels,
            'c%d' % block.conv_type,
            'f%d' % block.fused_conv,
        ]
        if block.se_ratio > 0 and block.se_ratio <= 1:
            args.append('se%s' % block.se_ratio)
        return '_'.join(args)

    def decode(self, string_list):
        """Decodes a list of string notations to specify blocks inside the network.

        Args:
            string_list: a list of strings, each string is a notation of block.

        Returns:
            A list of namedtuples to represent blocks arguments.
        """
        assert isinstance(string_list, list)
        blocks_args = []
        for block_string in string_list:
            blocks_args.append(self._decode_block_string(block_string))
        return blocks_args

    def encode(self, blocks_args):
        """Encodes a list of Blocks to a list of strings.

        Args:
            blocks_args: A list of namedtuples to represent blocks arguments.
        Returns:
            a list of strings, each string is a notation of block.
        """
        block_strings = []
        for block in blocks_args:
            block_strings.append(self._encode_block_string(block))
        return block_strings
#################### EfficientNet V2 configs ####################
v2_base_block = [  # The baseline config for v2 models.
    "r1_k3_s1_e1_i32_o16_c1",
    "r2_k3_s2_e4_i16_o32_c1",
    "r2_k3_s2_e4_i32_o48_c1",
    "r3_k3_s2_e4_i48_o96_se0.25",
    "r5_k3_s1_e6_i96_o112_se0.25",
    "r8_k3_s2_e6_i112_o192_se0.25",
]

v2_s_block = [  # about base * (width1.4, depth1.8)
    "r2_k3_s1_e1_i24_o24_c1",
    "r4_k3_s2_e4_i24_o48_c1",
    "r4_k3_s2_e4_i48_o64_c1",
    "r6_k3_s2_e4_i64_o128_se0.25",
    "r9_k3_s1_e6_i128_o160_se0.25",
    "r15_k3_s2_e6_i160_o256_se0.25",
]

v2_m_block = [  # about base * (width1.6, depth2.2)
    "r3_k3_s1_e1_i24_o24_c1",
    "r5_k3_s2_e4_i24_o48_c1",
    "r5_k3_s2_e4_i48_o80_c1",
    "r7_k3_s2_e4_i80_o160_se0.25",
    "r14_k3_s1_e6_i160_o176_se0.25",
    "r18_k3_s2_e6_i176_o304_se0.25",
    "r5_k3_s1_e6_i304_o512_se0.25",
]

v2_l_block = [  # about base * (width2.0, depth3.1)
    "r4_k3_s1_e1_i32_o32_c1",
    "r7_k3_s2_e4_i32_o64_c1",
    "r7_k3_s2_e4_i64_o96_c1",
    "r10_k3_s2_e4_i96_o192_se0.25",
    "r19_k3_s1_e6_i192_o224_se0.25",
    "r25_k3_s2_e6_i224_o384_se0.25",
    "r7_k3_s1_e6_i384_o640_se0.25",
]

v2_xl_block = [  # only for 21k pretraining.
    "r4_k3_s1_e1_i32_o32_c1",
    "r8_k3_s2_e4_i32_o64_c1",
    "r8_k3_s2_e4_i64_o96_c1",
    "r16_k3_s2_e4_i96_o192_se0.25",
    "r24_k3_s1_e6_i192_o256_se0.25",
    "r32_k3_s2_e6_i256_o512_se0.25",
    "r8_k3_s1_e6_i512_o640_se0.25",
]

efficientnetv2_params = {
    # params: (block, width, depth, dropout)
    "efficientnetv2-s": (v2_s_block, 1.0, 1.0, 0.2),
    "efficientnetv2-m": (v2_m_block, 1.0, 1.0, 0.3),
    "efficientnetv2-l": (v2_l_block, 1.0, 1.0, 0.4),
    "efficientnetv2-xl": (v2_xl_block, 1.0, 1.0, 0.4),
}
def efficientnetv2_config(model_name: str):
    """EfficientNetV2 model config."""
    block, width, depth, dropout = efficientnetv2_params[model_name]

    cfg = AttrDict(model=AttrDict(
        model_name=model_name,
        blocks_args=BlockDecoder().decode(block),
        width_coefficient=width,
        depth_coefficient=depth,
        dropout_rate=dropout,
        feature_size=1280,
        bn_momentum=0.9,
        bn_epsilon=1e-3,
        depth_divisor=8,
        min_depth=8,
        act_fn="silu",
        survival_prob=0.8,
        local_pooling=False,
        conv_dropout=None,
        num_classes=1000))
    return cfg


def get_model_config(model_name: str):
    """Main entry for model name to config."""
    if model_name.startswith("efficientnetv2-"):
        return efficientnetv2_config(model_name)
    raise ValueError(f"Unknown model_name {model_name}")
################################################################################


def round_filters(filters,
                  width_coefficient,
                  depth_divisor,
                  min_depth,
                  skip=False):
    """Round number of filters based on depth multiplier."""
    multiplier = width_coefficient
    divisor = depth_divisor
    min_depth = min_depth
    if skip or not multiplier:
        return filters

    filters *= multiplier
    min_depth = min_depth or divisor
    new_filters = max(min_depth,
                      int(filters + divisor / 2) // divisor * divisor)
    return int(new_filters)


def round_repeats(repeats, multiplier, skip=False):
    """Round number of filters based on depth multiplier."""
    if skip or not multiplier:
        return repeats
    return int(math.ceil(multiplier * repeats))


def activation_fn(act_fn: str):
    """Customized non-linear activation type."""
    if not act_fn:
        return nn.Silu()
    elif act_fn in ("silu", "swish"):
        return nn.Swish()
    elif act_fn == "relu":
        return nn.ReLU()
    elif act_fn == "relu6":
        return nn.ReLU6()
    elif act_fn == "elu":
        return nn.ELU()
    elif act_fn == "leaky_relu":
        return nn.LeakyReLU()
    elif act_fn == "selu":
        return nn.SELU()
    elif act_fn == "mish":
        return nn.Mish()
    else:
        raise ValueError("Unsupported act_fn {}".format(act_fn))


def drop_path(x, training=False, survival_prob=1.0):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    """
    if not training:
        return x
    shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
    keep_prob = paddle.to_tensor(survival_prob)
    random_tensor = keep_prob + paddle.rand(shape).astype(x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output
class SE(nn.Layer):
    """Squeeze-and-excitation layer.

    Args:
        local_pooling (bool): local_pooling
        act_fn (str): act_fn
        in_channels (int): in_channels
        se_channels (int): se_channels
        out_channels (int): out_channels
        cur_stage (int): cur_stage
        padding_type (str): padding_type
        model_name (str): model_name
    """

    def __init__(self, local_pooling: bool, act_fn: str, in_channels: int,
                 se_channels: int, out_channels: int, cur_stage: int,
                 padding_type: str, model_name: str):
        super(SE, self).__init__()

        self._local_pooling = local_pooling
        self._act = activation_fn(act_fn)

        # Squeeze and Excitation layer.
        self._se_reduce = Conv2ds(
            in_channels,
            se_channels,
            1,
            stride=1,
            padding_type=padding_type,
            model_name=model_name,
            cur_stage=cur_stage)
        self._se_expand = Conv2ds(
            se_channels,
            out_channels,
            1,
            stride=1,
            padding_type=padding_type,
            model_name=model_name,
            cur_stage=cur_stage)

    def forward(self, x):
        if self._local_pooling:
            se_tensor = F.adaptive_avg_pool2d(x, output_size=1)
        else:
            se_tensor = paddle.mean(x, axis=[2, 3], keepdim=True)
        se_tensor = self._se_expand(self._act(self._se_reduce(se_tensor)))
        return F.sigmoid(se_tensor) * x
class MBConvBlock(nn.Layer):
    """A class of MBConv: Mobile Inverted Residual Bottleneck.

    Args:
        se_ratio (int): se_ratio
        in_channels (int): in_channels
        expand_ratio (int): expand_ratio
        kernel_size (int): kernel_size
        strides (int): strides
        out_channels (int): out_channels
        bn_momentum (float): bn_momentum
        bn_epsilon (float): bn_epsilon
        local_pooling (bool): local_pooling
        conv_dropout (float): conv_dropout
        cur_stage (int): cur_stage
        padding_type (str): padding_type
        model_name (str): model_name
    """

    def __init__(self, se_ratio: int, in_channels: int, expand_ratio: int,
                 kernel_size: int, strides: int, out_channels: int,
                 bn_momentum: float, bn_epsilon: float, local_pooling: bool,
                 conv_dropout: float, cur_stage: int, padding_type: str,
                 model_name: str):
        super(MBConvBlock, self).__init__()

        self.se_ratio = se_ratio
        self.in_channels = in_channels
        self.expand_ratio = expand_ratio
        self.kernel_size = kernel_size
        self.strides = strides
        self.out_channels = out_channels
        self.bn_momentum = bn_momentum
        self.bn_epsilon = bn_epsilon
        self._local_pooling = local_pooling
        self.act_fn = None
        self.conv_dropout = conv_dropout

        self._act = activation_fn(None)
        self._has_se = (self.se_ratio is not None and
                        0 < self.se_ratio <= 1)

        """Builds block according to the arguments."""
        expand_channels = self.in_channels * self.expand_ratio
        kernel_size = self.kernel_size

        # Expansion phase. Called if not using fused convolutions and expansion
        # phase is necessary.
        if self.expand_ratio != 1:
            self._expand_conv = Conv2ds(
                self.in_channels,
                expand_channels,
                1,
                stride=1,
                use_bias=False,
                padding_type=padding_type,
                model_name=model_name,
                cur_stage=cur_stage)
            self._norm0 = nn.BatchNorm2D(
                expand_channels,
                self.bn_momentum,
                self.bn_epsilon,
                weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
                bias_attr=ParamAttr(regularizer=L2Decay(0.0)))

        # Depth-wise convolution phase. Called if not using fused convolutions.
        self._depthwise_conv = Conv2ds(
            expand_channels,
            expand_channels,
            kernel_size,
            padding=kernel_size // 2,
            stride=self.strides,
            groups=expand_channels,
            use_bias=False,
            padding_type=padding_type,
            model_name=model_name,
            cur_stage=cur_stage)
        self._norm1 = nn.BatchNorm2D(
            expand_channels,
            self.bn_momentum,
            self.bn_epsilon,
            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))

        if self._has_se:
            num_reduced_filters = max(1, int(self.in_channels * self.se_ratio))
            self._se = SE(self._local_pooling, None, expand_channels,
                          num_reduced_filters, expand_channels, cur_stage,
                          padding_type, model_name)
        else:
            self._se = None

        # Output phase.
        self._project_conv = Conv2ds(
            expand_channels,
            self.out_channels,
            1,
            stride=1,
            use_bias=False,
            padding_type=padding_type,
            model_name=model_name,
            cur_stage=cur_stage)
        self._norm2 = nn.BatchNorm2D(
            self.out_channels,
            self.bn_momentum,
            self.bn_epsilon,
            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))

        self.drop_out = nn.Dropout(self.conv_dropout)

    def residual(self, inputs, x, survival_prob):
        if (self.strides == 1 and self.in_channels == self.out_channels):
            # Apply only if skip connection presents.
            if survival_prob:
                x = drop_path(x, self.training, survival_prob)
            x = paddle.add(x, inputs)
        return x

    def forward(self, inputs, survival_prob=None):
        """Implementation of call().

        Args:
            inputs: the inputs tensor.
            survival_prob: float, between 0 to 1, drop connect rate.

        Returns:
            A output tensor.
        """
        x = inputs
        if self.expand_ratio != 1:
            x = self._act(self._norm0(self._expand_conv(x)))
        x = self._act(self._norm1(self._depthwise_conv(x)))

        if self.conv_dropout and self.expand_ratio > 1:
            x = self.drop_out(x)

        if self._se:
            x = self._se(x)

        x = self._norm2(self._project_conv(x))
        x = self.residual(inputs, x, survival_prob)

        return x
class FusedMBConvBlock(MBConvBlock):
    """Fusing the proj conv1x1 and depthwise_conv into a conv2d."""

    def __init__(self, se_ratio, in_channels, expand_ratio, kernel_size,
                 strides, out_channels, bn_momentum, bn_epsilon, local_pooling,
                 conv_dropout, cur_stage, padding_type, model_name):
        """Builds block according to the arguments."""
        super(MBConvBlock, self).__init__()

        self.se_ratio = se_ratio
        self.in_channels = in_channels
        self.expand_ratio = expand_ratio
        self.kernel_size = kernel_size
        self.strides = strides
        self.out_channels = out_channels
        self.bn_momentum = bn_momentum
        self.bn_epsilon = bn_epsilon
        self._local_pooling = local_pooling
        self.act_fn = None
        self.conv_dropout = conv_dropout

        self._act = activation_fn(None)
        self._has_se = (self.se_ratio is not None and
                        0 < self.se_ratio <= 1)

        expand_channels = self.in_channels * self.expand_ratio
        kernel_size = self.kernel_size
        if self.expand_ratio != 1:
            # Expansion phase:
            self._expand_conv = Conv2ds(
                self.in_channels,
                expand_channels,
                kernel_size,
                padding=kernel_size // 2,
                stride=self.strides,
                use_bias=False,
                padding_type=padding_type,
                model_name=model_name,
                cur_stage=cur_stage)
            self._norm0 = nn.BatchNorm2D(
                expand_channels,
                self.bn_momentum,
                self.bn_epsilon,
                weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
                bias_attr=ParamAttr(regularizer=L2Decay(0.0)))

        if self._has_se:
            num_reduced_filters = max(1, int(self.in_channels * self.se_ratio))
            self._se = SE(self._local_pooling, None, expand_channels,
                          num_reduced_filters, expand_channels, cur_stage,
                          padding_type, model_name)
        else:
            self._se = None

        # Output phase:
        self._project_conv = Conv2ds(
            expand_channels,
            self.out_channels,
            1 if (self.expand_ratio != 1) else kernel_size,
            padding=(1 if (self.expand_ratio != 1) else kernel_size) // 2,
            stride=1 if (self.expand_ratio != 1) else self.strides,
            use_bias=False,
            padding_type=padding_type,
            model_name=model_name,
            cur_stage=cur_stage)
        self._norm1 = nn.BatchNorm2D(
            self.out_channels,
            self.bn_momentum,
            self.bn_epsilon,
            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))

        self.drop_out = nn.Dropout(conv_dropout)

    def forward(self, inputs, survival_prob=None):
        """Implementation of call().

        Args:
            inputs: the inputs tensor.
            training: boolean, whether the model is constructed for training.
            survival_prob: float, between 0 to 1, drop connect rate.

        Returns:
            A output tensor.
        """
        x = inputs
        if self.expand_ratio != 1:
            x = self._act(self._norm0(self._expand_conv(x)))

        if self.conv_dropout and self.expand_ratio > 1:
            x = self.drop_out(x)

        if self._se:
            x = self._se(x)

        x = self._norm1(self._project_conv(x))
        if self.expand_ratio == 1:
            x = self._act(x)  # add act if no expansion.

        x = self.residual(inputs, x, survival_prob)
        return x
class Stem(nn.Layer):
    """Stem layer at the begining of the network."""

    def __init__(self, width_coefficient, depth_divisor, min_depth, skip,
                 bn_momentum, bn_epsilon, act_fn, stem_channels, cur_stage,
                 padding_type, model_name):
        super(Stem, self).__init__()
        self._conv_stem = Conv2ds(
            3,
            round_filters(stem_channels, width_coefficient, depth_divisor,
                          min_depth, skip),
            3,
            padding=1,
            stride=2,
            use_bias=False,
            padding_type=padding_type,
            model_name=model_name,
            cur_stage=cur_stage)
        self._norm = nn.BatchNorm2D(
            round_filters(stem_channels, width_coefficient, depth_divisor,
                          min_depth, skip),
            bn_momentum,
            bn_epsilon,
            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
        self._act = activation_fn(act_fn)

    def forward(self, inputs):
        return self._act(self._norm(self._conv_stem(inputs)))
class Head(nn.Layer):
    """Head layer for network outputs."""

    def __init__(self,
                 in_channels,
                 feature_size,
                 bn_momentum,
                 bn_epsilon,
                 act_fn,
                 dropout_rate,
                 local_pooling,
                 width_coefficient,
                 depth_divisor,
                 min_depth,
                 skip=False):
        super(Head, self).__init__()

        self.in_channels = in_channels
        self.feature_size = feature_size
        self.bn_momentum = bn_momentum
        self.bn_epsilon = bn_epsilon
        self.dropout_rate = dropout_rate
        self._local_pooling = local_pooling

        self._conv_head = nn.Conv2D(
            in_channels,
            round_filters(self.feature_size or 1280, width_coefficient,
                          depth_divisor, min_depth, skip),
            kernel_size=1,
            stride=1,
            bias_attr=False)
        self._norm = nn.BatchNorm2D(
            round_filters(self.feature_size or 1280, width_coefficient,
                          depth_divisor, min_depth, skip),
            self.bn_momentum,
            self.bn_epsilon,
            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
        self._act = activation_fn(act_fn)

        self._avg_pooling = nn.AdaptiveAvgPool2D(output_size=1)

        if self.dropout_rate > 0:
            self._dropout = nn.Dropout(self.dropout_rate)
        else:
            self._dropout = None

    def forward(self, x):
        """Call the layer."""
        outputs = self._act(self._norm(self._conv_head(x)))

        if self._local_pooling:
            outputs = F.adaptive_avg_pool2d(outputs, output_size=1)
            if self._dropout:
                outputs = self._dropout(outputs)
            if self._fc:
                outputs = paddle.squeeze(outputs, axis=[2, 3])
                outputs = self._fc(outputs)
        else:
            outputs = self._avg_pooling(outputs)
            if self._dropout:
                outputs = self._dropout(outputs)
        return paddle.flatten(outputs, start_axis=1)
class EfficientNetV2(nn.Layer):
    """A class implements tf.keras.Model.

    Reference: https://arxiv.org/abs/1807.11626
    """

    def __init__(self,
                 model_name,
                 blocks_args=None,
                 mconfig=None,
                 include_top=True,
                 class_num=1000,
                 padding_type="SAME"):
        """Initializes an `Model` instance.

        Args:
            model_name: A string of model name.
            model_config: A dict of model configurations or a string of hparams.

        Raises:
            ValueError: when blocks_args is not specified as a list.
        """
        super(EfficientNetV2, self).__init__()
        self.blocks_args = blocks_args
        self.mconfig = mconfig

        """Builds a model."""
        self._blocks = nn.LayerList()

        cur_stage = 0
        # Stem part.
        self._stem = Stem(
            self.mconfig.width_coefficient,
            self.mconfig.depth_divisor,
            self.mconfig.min_depth,
            False,
            self.mconfig.bn_momentum,
            self.mconfig.bn_epsilon,
            self.mconfig.act_fn,
            stem_channels=self.blocks_args[0].in_channels,
            cur_stage=cur_stage,
            padding_type=padding_type,
            model_name=model_name)
        cur_stage += 1

        # Builds blocks.
        for block_args in self.blocks_args:
            assert block_args.num_repeat > 0
            # Update block input and output filters based on depth multiplier.
            in_channels = round_filters(
                block_args.in_channels, self.mconfig.width_coefficient,
                self.mconfig.depth_divisor, self.mconfig.min_depth, False)
            out_channels = round_filters(
                block_args.out_channels, self.mconfig.width_coefficient,
                self.mconfig.depth_divisor, self.mconfig.min_depth, False)

            repeats = round_repeats(block_args.num_repeat,
                                    self.mconfig.depth_coefficient)
            block_args.update(
                dict(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    num_repeat=repeats))

            # The first block needs to take care of stride and filter size increase.
            conv_block = {0: MBConvBlock,
                          1: FusedMBConvBlock}[block_args.conv_type]
            self._blocks.append(
                conv_block(block_args.se_ratio, block_args.in_channels,
                           block_args.expand_ratio, block_args.kernel_size,
                           block_args.strides, block_args.out_channels,
                           self.mconfig.bn_momentum, self.mconfig.bn_epsilon,
                           self.mconfig.local_pooling,
                           self.mconfig.conv_dropout, cur_stage, padding_type,
                           model_name))
            if block_args.num_repeat > 1:  # rest of blocks with the same block_arg
                block_args.in_channels = block_args.out_channels
                block_args.strides = 1
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(
                    conv_block(block_args.se_ratio, block_args.in_channels,
                               block_args.expand_ratio,
                               block_args.kernel_size, block_args.strides,
                               block_args.out_channels,
                               self.mconfig.bn_momentum,
                               self.mconfig.bn_epsilon,
                               self.mconfig.local_pooling,
                               self.mconfig.conv_dropout, cur_stage,
                               padding_type, model_name))
            cur_stage += 1

        # Head part.
        self._head = Head(self.blocks_args[-1].out_channels,
                          self.mconfig.feature_size, self.mconfig.bn_momentum,
                          self.mconfig.bn_epsilon, self.mconfig.act_fn,
                          self.mconfig.dropout_rate,
                          self.mconfig.local_pooling,
                          self.mconfig.width_coefficient,
                          self.mconfig.depth_divisor, self.mconfig.min_depth,
                          False)

        # top part for classification
        if include_top and class_num:
            self._fc = nn.Linear(
                self.mconfig.feature_size,
                class_num,
                bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
        else:
            self._fc = None

        # initialize weight
        def _init_weights(m):
            if isinstance(m, nn.Conv2D):
                out_filters, in_channels, kernel_height, kernel_width = m.weight.shape
                if in_channels == 1 and out_filters > in_channels:
                    out_filters = in_channels
                fan_out = int(kernel_height * kernel_width * out_filters)
                Normal(mean=0.0, std=np.sqrt(2.0 / fan_out))(m.weight)
            elif isinstance(m, nn.Linear):
                init_range = 1.0 / np.sqrt(m.weight.shape[1])
                Uniform(-init_range, init_range)(m.weight)
                Constant(0.0)(m.bias)

        self.apply(_init_weights)

    def forward(self, inputs):
        # Calls Stem layers
        outputs = self._stem(inputs)
        # print(f"stem: {outputs.mean().item():.10f}")

        # Calls blocks.
        for idx, block in enumerate(self._blocks):
            survival_prob = self.mconfig.survival_prob
            if survival_prob:
                drop_rate = 1.0 - survival_prob
                survival_prob = 1.0 - drop_rate * float(idx) / len(self._blocks)
            outputs = block(outputs, survival_prob=survival_prob)

        # Head to obtain the final feature.
        outputs = self._head(outputs)
        # Calls final dense layers and returns logits.
        if self._fc:
            outputs = self._fc(outputs)

        return outputs
def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    if pretrained is False:
        pass
    elif pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
    elif isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )
def EfficientNetV2_S(include_top=True, pretrained=False, **kwargs):
    """Get a V2 model instance.

    Returns:
        nn.Layer: A single model instantce
    """
    model_name = "efficientnetv2-s"
    model_config = efficientnetv2_config(model_name)
    model = EfficientNetV2(model_name, model_config.model.blocks_args,
                           model_config.model, include_top, **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetV2_S"])
    return model


def EfficientNetV2_M(include_top=True, pretrained=False, **kwargs):
    """Get a V2 model instance.

    Returns:
        nn.Layer: A single model instantce
    """
    model_name = "efficientnetv2-m"
    model_config = efficientnetv2_config(model_name)
    model = EfficientNetV2(model_name, model_config.model.blocks_args,
                           model_config.model, include_top, **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetV2_M"])
    return model


def EfficientNetV2_L(include_top=True, pretrained=False, **kwargs):
    """Get a V2 model instance.

    Returns:
        nn.Layer: A single model instantce
    """
    model_name = "efficientnetv2-l"
    model_config = efficientnetv2_config(model_name)
    model = EfficientNetV2(model_name, model_config.model.blocks_args,
                           model_config.model, include_top, **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetV2_L"])
    return model


def EfficientNetV2_XL(include_top=True, pretrained=False, **kwargs):
    """Get a V2 model instance.

    Returns:
        nn.Layer: A single model instantce
    """
    model_name = "efficientnetv2-xl"
    model_config = efficientnetv2_config(model_name)
    model = EfficientNetV2(model_name, model_config.model.blocks_args,
                           model_config.model, include_top, **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetV2_XL"])
    return model
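For reference, a minimal usage sketch of the new backbone entry point; it assumes PaddlePaddle is installed and the snippet is run from the PaddleClas repository root so that ppcls is importable. The 384x384 input matches the eval/export image_shape in the config below.

import paddle

from ppcls.arch.backbone.model_zoo.efficientnet_v2 import EfficientNetV2_S

# Build the S variant without downloading the released weights.
model = EfficientNetV2_S(pretrained=False)
model.eval()

# One 384x384 RGB image, the evaluation resolution used in the YAML config below.
x = paddle.randn([1, 3, 384, 384])
with paddle.no_grad():
    logits = model(x)
print(logits.shape)  # [1, 1000]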
ppcls/configs/ImageNet/EfficientNetV2/EfficientNetV2_S.yaml
0 → 100644
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 100
  eval_during_train: True
  eval_interval: 1
  epochs: 350
  print_batch_step: 20
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 384, 384]
  save_inference_dir: ./inference
  train_mode: efficientnetv2  # progressive training

AMP:
  scale_loss: 65536
  use_dynamic_loss_scaling: True
  # O1: mixed fp16
  level: O1

EMA:
  decay: 0.9999

# model architecture
Arch:
  name: EfficientNetV2_S
  class_num: 1000
  use_sync_bn: True

# loss function config for traing/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
        epsilon: 0.1
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: Cosine
    learning_rate: 0.65  # 8gpux128bs
    warmup_epoch: 5
  regularizer:
    name: L2
    coeff: 0.00001

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            scale: [0.05, 1.0]
            size: 224
        - RandFlipImage:
            flip_code: 1
        - RandAugmentV2:
            num_layers: 2
            magnitude: 5
        - NormalizeImage:
            scale: 1.0
            mean: [128.0, 128.0, 128.0]
            std: [128.0, 128.0, 128.0]
            order: ""
    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: True
      shuffle: True
    loader:
      num_workers: 8
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - CropImageAtRatio:
            size: 384
            pad: 32
            interpolation: bilinear
        - NormalizeImage:
            scale: 1.0
            mean: [128.0, 128.0, 128.0]
            std: [128.0, 128.0, 128.0]
            order: ""
    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: False
      shuffle: False
    loader:
      num_workers: 8
      use_shared_memory: True

Infer:
  infer_imgs: docs/images/inference_deployment/whl_demo.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - CropImageAtRatio:
        size: 384
        pad: 32
        interpolation: bilinear
    - NormalizeImage:
        scale: 1.0
        mean: [128.0, 128.0, 128.0]
        std: [128.0, 128.0, 128.0]
        order: ""
  PostProcess:
    name: Topk
    topk: 5
    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt

Metric:
  Train:
    - TopkAcc:
        topk: [1, 5]
  Eval:
    - TopkAcc:
        topk: [1, 5]
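Note the NormalizeImage settings: with scale 1.0 and mean/std of 128, raw pixel values in [0, 255] map to roughly [-1, 1] (the TensorFlow-style EfficientNet preprocessing) instead of the usual 1/255 scaling with ImageNet channel statistics. A small numpy sketch of the same arithmetic, assuming the op computes (img * scale - mean) / std:

import numpy as np

img = np.array([[0.0, 128.0, 255.0]], dtype=np.float32)  # a few example pixel values
scale, mean, std = 1.0, 128.0, 128.0
out = (img * scale - mean) / std
print(out)  # [[-1.  0.  0.9921875]]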
ppcls/data/preprocess/__init__.py
@@ -15,6 +15,7 @@
 from ppcls.data.preprocess.ops.autoaugment import ImageNetPolicy as RawImageNetPolicy
 from ppcls.data.preprocess.ops.randaugment import RandAugment as RawRandAugment
 from ppcls.data.preprocess.ops.randaugment import RandomApply
+from ppcls.data.preprocess.ops.randaugment import RandAugmentV2 as RawRandAugmentV2
 from ppcls.data.preprocess.ops.timm_autoaugment import RawTimmAutoAugment
 from ppcls.data.preprocess.ops.cutout import Cutout

@@ -25,6 +26,7 @@ from ppcls.data.preprocess.ops.grid import GridMask
 from ppcls.data.preprocess.ops.operators import DecodeImage
 from ppcls.data.preprocess.ops.operators import ResizeImage
 from ppcls.data.preprocess.ops.operators import CropImage
+from ppcls.data.preprocess.ops.operators import CropImageAtRatio
 from ppcls.data.preprocess.ops.operators import CenterCrop, Resize
 from ppcls.data.preprocess.ops.operators import RandCropImage
 from ppcls.data.preprocess.ops.operators import RandCropImageV2

@@ -101,6 +103,13 @@ class RandAugment(RawRandAugment):
         return img
 
 
+class RandAugmentV2(RawRandAugmentV2):
+    """ RandAugmentV2 wrapper to auto fit different img types """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
 class TimmAutoAugment(RawTimmAutoAugment):
     """ TimmAutoAugment wrapper to auto fit different img tyeps. """
ppcls/data/preprocess/ops/operators.py
@@ -319,6 +319,25 @@ class CropImage(object):
         return img[h_start:h_end, w_start:w_end, :]
 
 
+class CropImageAtRatio(object):
+    """ crop image with specified size and padding"""
+
+    def __init__(self, size: int, pad: int, interpolation="bilinear"):
+        self.size = size
+        self.ratio = size / (size + pad)
+        self.interpolation = interpolation
+
+    def __call__(self, img):
+        height, width = img.shape[:2]
+        crop_size = int(self.ratio * min(height, width))
+
+        y = (height - crop_size) // 2
+        x = (width - crop_size) // 2
+        crop_img = img[y:y + crop_size, x:x + crop_size, :]
+
+        return F.resize(crop_img, [self.size, self.size], self.interpolation)
+
+
 class Padv2(object):
     def __init__(self,
                  size=None,
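The new CropImageAtRatio op centre-crops a fraction size / (size + pad) of the image's shorter side and then resizes the crop to size x size; with size=384 and pad=32 as in the config above, the crop keeps 384/416 ≈ 92.3% of the shorter side. A standalone sketch of just the cropping geometry (plain numpy, omitting the final resize the repo delegates to its F.resize helper):

import numpy as np


def crop_at_ratio(img: np.ndarray, size: int, pad: int) -> np.ndarray:
    """Centre-crop at ratio size/(size+pad) of the shorter side (resize omitted)."""
    ratio = size / (size + pad)
    h, w = img.shape[:2]
    crop = int(ratio * min(h, w))
    y, x = (h - crop) // 2, (w - crop) // 2
    return img[y:y + crop, x:x + crop, :]


img = np.zeros([500, 600, 3], dtype=np.uint8)
print(crop_at_ratio(img, size=384, pad=32).shape)  # (461, 461, 3): 500 * 384/416 ≈ 461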
ppcls/data/preprocess/ops/randaugment.py
@@ -15,12 +15,60 @@
 # This code is based on https://github.com/heartInsert/randaugment
 # reference: https://arxiv.org/abs/1909.13719
-from PIL import Image, ImageEnhance, ImageOps
-import numpy as np
 import random
 from .operators import RawColorJitter
 from paddle.vision.transforms import transforms as T
+import numpy as np
+from PIL import Image, ImageEnhance, ImageOps
+
+
+def solarize_add(img, add, thresh=128, **__):
+    lut = []
+    for i in range(256):
+        if i < thresh:
+            lut.append(min(255, i + add))
+        else:
+            lut.append(i)
+    if img.mode in ("L", "RGB"):
+        if img.mode == "RGB" and len(lut) == 256:
+            lut = lut + lut + lut
+        return img.point(lut)
+    else:
+        return img
+
+
+def cutout(image, pad_size, replace=0):
+    image_np = np.array(image)
+    image_height, image_width, _ = image_np.shape
+
+    # Sample the center location in the image where the zero mask will be applied.
+    cutout_center_height = np.random.randint(0, image_height + 1)
+    cutout_center_width = np.random.randint(0, image_width + 1)
+
+    lower_pad = np.maximum(0, cutout_center_height - pad_size)
+    upper_pad = np.maximum(0, image_height - cutout_center_height - pad_size)
+    left_pad = np.maximum(0, cutout_center_width - pad_size)
+    right_pad = np.maximum(0, image_width - cutout_center_width - pad_size)
+
+    cutout_shape = [
+        image_height - (lower_pad + upper_pad),
+        image_width - (left_pad + right_pad)
+    ]
+    padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]]
+    mask = np.pad(np.zeros(cutout_shape, dtype=image_np.dtype),
+                  padding_dims,
+                  constant_values=1)
+    mask = np.expand_dims(mask, -1)
+    mask = np.tile(mask, [1, 1, 3])
+
+    image_np = np.where(
+        np.equal(mask, 0),
+        np.full_like(image_np, fill_value=replace, dtype=image_np.dtype),
+        image_np)
+    return Image.fromarray(image_np)
+
+
 class RandAugment(object):
     def __init__(self, num_layers=2, magnitude=5, fillcolor=(128, 128, 128)):
@@ -95,10 +143,10 @@ class RandAugment(object):
             "brightness": lambda img, magnitude: ImageEnhance.Brightness(
                 img).enhance(1 + magnitude * rnd_ch_op([-1, 1])),
-            "autocontrast": lambda img, magnitude:
+            "autocontrast": lambda img, _:
             ImageOps.autocontrast(img),
-            "equalize": lambda img, magnitude: ImageOps.equalize(img),
-            "invert": lambda img, magnitude: ImageOps.invert(img)
+            "equalize": lambda img, _: ImageOps.equalize(img),
+            "invert": lambda img, _: ImageOps.invert(img)
         }

     def __call__(self, img):
@@ -121,4 +169,85 @@ class RandomApply(object):

     def __call__(self, img):
         timg = self.trans(img)
-        return timg
\ No newline at end of file
+        return timg
+
+
+## RandAugment_EfficientNetV2 code below ##
+class RandAugmentV2(RandAugment):
+    """Customed RandAugment for EfficientNetV2"""
+
+    def __init__(self, num_layers=2, magnitude=5, fillcolor=(128, 128, 128)):
+        super().__init__(num_layers, magnitude, fillcolor)
+        abso_level = self.magnitude / self.max_level  # [5.0~10.0/10.0]=[0.5, 1.0]
+        self.level_map = {
+            "shearX": 0.3 * abso_level,
+            "shearY": 0.3 * abso_level,
+            "translateX": 100.0 * abso_level,
+            "translateY": 100.0 * abso_level,
+            "rotate": 30 * abso_level,
+            "color": 1.8 * abso_level + 0.1,
+            "posterize": int(4.0 * abso_level),
+            "solarize": int(256.0 * abso_level),
+            "solarize_add": int(110.0 * abso_level),
+            "contrast": 1.8 * abso_level + 0.1,
+            "sharpness": 1.8 * abso_level + 0.1,
+            "brightness": 1.8 * abso_level + 0.1,
+            "autocontrast": 0,
+            "equalize": 0,
+            "invert": 0,
+            "cutout": int(40 * abso_level)
+        }
+
+        def rotate_with_fill(img, magnitude):
+            rot = img.convert("RGBA").rotate(magnitude)
+            return Image.composite(rot,
+                                   Image.new("RGBA", rot.size, (128, ) * 4),
+                                   rot).convert(img.mode)
+
+        rnd_ch_op = random.choice
+
+        self.func = {
+            "shearX": lambda img, magnitude: img.transform(
+                img.size,
+                Image.AFFINE,
+                (1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0),
+                Image.NEAREST,
+                fillcolor=fillcolor),
+            "shearY": lambda img, magnitude: img.transform(
+                img.size,
+                Image.AFFINE,
+                (1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0),
+                Image.NEAREST,
+                fillcolor=fillcolor),
+            "translateX": lambda img, magnitude: img.transform(
+                img.size,
+                Image.AFFINE,
+                (1, 0, magnitude * rnd_ch_op([-1, 1]), 0, 1, 0),
+                Image.NEAREST,
+                fillcolor=fillcolor),
+            "translateY": lambda img, magnitude: img.transform(
+                img.size,
+                Image.AFFINE,
+                (1, 0, 0, 0, 1, magnitude * rnd_ch_op([-1, 1])),
+                Image.NEAREST,
+                fillcolor=fillcolor),
+            "rotate": lambda img, magnitude: rotate_with_fill(
+                img, magnitude * rnd_ch_op([-1, 1])),
+            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
+                magnitude),
+            "posterize": lambda img, magnitude: ImageOps.posterize(
+                img, magnitude),
+            "solarize": lambda img, magnitude: ImageOps.solarize(
+                img, magnitude),
+            "solarize_add": lambda img, magnitude: solarize_add(
+                img, magnitude),
+            "contrast": lambda img, magnitude: ImageEnhance.Contrast(
+                img).enhance(magnitude),
+            "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(
+                img).enhance(magnitude),
+            "brightness": lambda img, magnitude: ImageEnhance.Brightness(
+                img).enhance(magnitude),
+            "autocontrast": lambda img, _: ImageOps.autocontrast(img),
+            "equalize": lambda img, _: ImageOps.equalize(img),
+            "invert": lambda img, _: ImageOps.invert(img),
+            "cutout": lambda img, magnitude: cutout(
+                img, magnitude, replace=fillcolor[0])
+        }
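RandAugmentV2 rescales every op's strength by abso_level = magnitude / max_level (max_level appears to be 10 in the parent RandAugment class, per the "[5.0~10.0/10.0]" comment above), so when the progressive-training schedule below raises magnitude from 5 to 10, every op sweeps from roughly half to full strength. A quick sketch re-computing a few of the same formulas outside the class:

max_level = 10  # assumed from the "[5.0~10.0/10.0]" comment above
for magnitude in (5, 10):
    abso_level = magnitude / max_level
    print(magnitude,
          round(30 * abso_level, 1),         # rotate, degrees
          round(1.8 * abso_level + 0.1, 2),  # color/contrast/brightness enhance factor
          int(40 * abso_level))              # cutout pad size, pixels
# 5 15.0 1.0 20
# 10 30.0 1.9 40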
ppcls/engine/train/__init__.py
@@ -12,5 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from ppcls.engine.train.train import train_epoch
+from ppcls.engine.train.train_efficientnetv2 import train_epoch_efficientnetv2
 from ppcls.engine.train.train_fixmatch import train_epoch_fixmatch
-from ppcls.engine.train.train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
\ No newline at end of file
+from ppcls.engine.train.train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
ppcls/engine/train/train_efficientnetv2.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function

import time

import numpy as np

from ppcls.data import build_dataloader
from ppcls.utils import logger

from .train import train_epoch


def train_epoch_efficientnetv2(engine, epoch_id, print_batch_step):
    # 1. Build training hyper-parameters for different training stage
    num_stage = 4
    ratio_list = [(i + 1) / num_stage for i in range(num_stage)]
    ram_list = np.linspace(5, 10, num_stage)
    # dropout_rate_list = np.linspace(0.0, 0.2, num_stage)
    stones = [
        int(engine.config["Global"]["epochs"] * ratio_list[i])
        for i in range(num_stage)
    ]
    image_size_list = [
        int(128 + (300 - 128) * ratio_list[i]) for i in range(num_stage)
    ]
    stage_id = 0
    for i in range(num_stage):
        if epoch_id > stones[i]:
            stage_id = i + 1

    # 2. Adjust training hyper-parameters for different training stage
    if not hasattr(engine, 'last_stage') or engine.last_stage < stage_id:
        engine.config["DataLoader"]["Train"]["dataset"]["transform_ops"][1][
            "RandCropImage"]["size"] = image_size_list[stage_id]
        engine.config["DataLoader"]["Train"]["dataset"]["transform_ops"][3][
            "RandAugment"]["magnitude"] = ram_list[stage_id]
        engine.train_dataloader = build_dataloader(
            engine.config["DataLoader"],
            "Train",
            engine.device,
            engine.use_dali,
            seed=epoch_id)
        engine.train_dataloader_iter = iter(engine.train_dataloader)
        engine.last_stage = stage_id
    logger.info(
        f"Training stage: [{stage_id + 1}/{num_stage}](random_aug_magnitude={ram_list[stage_id]}, train_image_size={image_size_list[stage_id]})"
    )

    # 3. Train one epoch as usual at current stage
    train_epoch(engine, epoch_id, print_batch_step)
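With the EfficientNetV2_S config above (Global.epochs: 350), the four-stage progressive schedule works out to the stage boundaries, image sizes, and RandAugment magnitudes below; a small sketch reproducing the same arithmetic outside the engine:

import numpy as np

epochs, num_stage = 350, 4
ratio_list = [(i + 1) / num_stage for i in range(num_stage)]
stones = [int(epochs * r) for r in ratio_list]                   # stage boundaries, in epochs
image_sizes = [int(128 + (300 - 128) * r) for r in ratio_list]   # training crop size per stage
magnitudes = np.linspace(5, 10, num_stage)                       # RandAugment magnitude per stage

print(stones)       # [87, 175, 262, 350]
print(image_sizes)  # [171, 214, 257, 300]
print(magnitudes)   # [ 5.          6.66666667  8.33333333 10.        ]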
ppcls/utils/config.py
@@ -33,7 +33,7 @@ class AttrDict(dict):
         self[key] = value
 
     def __deepcopy__(self, content):
-        return copy.deepcopy(dict(self))
+        return AttrDict(copy.deepcopy(dict(self)))
 
 
 def create_attr_dict(yaml_config):