Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
7e12c73e
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7e12c73e
编写于
1月 03, 2023
作者:
H
HydrogenSulfate
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish progressive training code
上级
8b8e0431
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
46 addition
and
26 deletion
+46
-26
ppcls/arch/backbone/model_zoo/efficientnet_v2.py
ppcls/arch/backbone/model_zoo/efficientnet_v2.py
+6
-4
ppcls/configs/ImageNet/EfficientNetV2/EfficientNetV2_S.yaml
ppcls/configs/ImageNet/EfficientNetV2/EfficientNetV2_S.yaml
+6
-4
ppcls/data/preprocess/ops/operators.py
ppcls/data/preprocess/ops/operators.py
+2
-0
ppcls/data/preprocess/ops/randaugment.py
ppcls/data/preprocess/ops/randaugment.py
+7
-2
ppcls/engine/train/__init__.py
ppcls/engine/train/__init__.py
+1
-1
ppcls/engine/train/train_progressive.py
ppcls/engine/train/train_progressive.py
+23
-15
ppcls/utils/save_load.py
ppcls/utils/save_load.py
+1
-0
未找到文件。
ppcls/arch/backbone/model_zoo/efficientnet_v2.py
浏览文件 @
7e12c73e
...
...
@@ -268,7 +268,7 @@ v2_xl_block = [ # only for 21k pretraining.
]
efficientnetv2_params
=
{
# params: (block, width, depth, dropout)
"efficientnetv2-s"
:
(
v2_s_block
,
1.0
,
1.0
,
0.2
),
"efficientnetv2-s"
:
(
v2_s_block
,
1.0
,
1.0
,
np
.
linspace
(
0.1
,
0.3
,
4
)
),
"efficientnetv2-m"
:
(
v2_m_block
,
1.0
,
1.0
,
0.3
),
"efficientnetv2-l"
:
(
v2_l_block
,
1.0
,
1.0
,
0.4
),
"efficientnetv2-xl"
:
(
v2_xl_block
,
1.0
,
1.0
,
0.4
),
...
...
@@ -293,7 +293,7 @@ def efficientnetv2_config(model_name: str):
act_fn
=
"silu"
,
survival_prob
=
0.8
,
local_pooling
=
False
,
conv_dropout
=
None
,
conv_dropout
=
0
,
num_classes
=
1000
))
return
cfg
...
...
@@ -756,8 +756,10 @@ class Head(nn.Layer):
self
.
_avg_pooling
=
nn
.
AdaptiveAvgPool2D
(
output_size
=
1
)
if
self
.
dropout_rate
>
0
:
self
.
_dropout
=
nn
.
Dropout
(
self
.
dropout_rate
)
if
isinstance
(
self
.
dropout_rate
,
(
list
,
tuple
))
or
self
.
dropout_rate
>
0
:
self
.
_dropout
=
nn
.
Dropout
(
self
.
dropout_rate
[
0
]
if
isinstance
(
self
.
dropout_rate
,
(
list
,
tuple
))
else
self
.
dropout_rate
)
else
:
self
.
_dropout
=
None
...
...
ppcls/configs/ImageNet/EfficientNetV2/EfficientNetV2_S.yaml
浏览文件 @
7e12c73e
...
...
@@ -4,16 +4,16 @@ Global:
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
save_interval
:
1
00
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
350
print_batch_step
:
20
use_visualdl
:
False
train_mode
:
progressive
# progressive training
# used for static mode and model export
image_shape
:
[
3
,
384
,
384
]
save_inference_dir
:
./inference
train_mode
:
efficientnetv2
# progressive training
AMP
:
scale_loss
:
65536
...
...
@@ -63,13 +63,15 @@ DataLoader:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
171
progress_size
:
[
171
,
214
,
257
,
300
]
scale
:
[
0.05
,
1.0
]
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
RandAugmentV2
:
num_layers
:
2
magnitude
:
5
magnitude
:
5.0
progress_magnitude
:
[
5.0
,
8.3333333333
,
11.66666666667
,
15.0
]
-
NormalizeImage
:
scale
:
1.0
mean
:
[
128.0
,
128.0
,
128.0
]
...
...
ppcls/data/preprocess/ops/operators.py
浏览文件 @
7e12c73e
...
...
@@ -439,6 +439,7 @@ class RandCropImage(object):
def
__init__
(
self
,
size
,
progress_size
=
None
,
scale
=
None
,
ratio
=
None
,
interpolation
=
None
,
...
...
@@ -448,6 +449,7 @@ class RandCropImage(object):
else
:
self
.
size
=
size
self
.
progress_size
=
progress_size
self
.
scale
=
[
0.08
,
1.0
]
if
scale
is
None
else
scale
self
.
ratio
=
[
3.
/
4.
,
4.
/
3.
]
if
ratio
is
None
else
ratio
...
...
ppcls/data/preprocess/ops/randaugment.py
浏览文件 @
7e12c73e
...
...
@@ -176,9 +176,14 @@ class RandomApply(object):
class
RandAugmentV2
(
RandAugment
):
"""Customed RandAugment for EfficientNetV2"""
def
__init__
(
self
,
num_layers
=
2
,
magnitude
=
5
,
fillcolor
=
(
128
,
128
,
128
)):
def
__init__
(
self
,
num_layers
=
2
,
magnitude
=
5
,
progress_magnitude
=
None
,
fillcolor
=
(
128
,
128
,
128
)):
super
().
__init__
(
num_layers
,
magnitude
,
fillcolor
)
abso_level
=
self
.
magnitude
/
self
.
max_level
# [5.0~10.0/10.0]=[0.5, 1.0]
self
.
progress_magnitude
=
progress_magnitude
abso_level
=
self
.
magnitude
/
self
.
max_level
self
.
level_map
=
{
"shearX"
:
0.3
*
abso_level
,
"shearY"
:
0.3
*
abso_level
,
...
...
ppcls/engine/train/__init__.py
浏览文件 @
7e12c73e
...
...
@@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
ppcls.engine.train.train
import
train_epoch
from
ppcls.engine.train.train_efficientnetv2
import
train_epoch_efficientnetv2
from
ppcls.engine.train.train_fixmatch
import
train_epoch_fixmatch
from
ppcls.engine.train.train_fixmatch_ccssl
import
train_epoch_fixmatch_ccssl
from
ppcls.engine.train.train_progressive
import
train_epoch_progressive
ppcls/engine/train/train_
efficientnetv2
.py
→
ppcls/engine/train/train_
progressive
.py
浏览文件 @
7e12c73e
...
...
@@ -13,29 +13,21 @@
# limitations under the License.
from
__future__
import
absolute_import
,
division
,
print_function
import
time
import
numpy
as
np
from
ppcls.data
import
build_dataloader
from
ppcls.engine.train.utils
import
type_name
from
ppcls.utils
import
logger
from
.train
import
train_epoch
def
train_epoch_
efficientnetv2
(
engine
,
epoch_id
,
print_batch_step
):
def
train_epoch_
progressive
(
engine
,
epoch_id
,
print_batch_step
):
# 1. Build training hyper-parameters for different training stage
num_stage
=
4
ratio_list
=
[(
i
+
1
)
/
num_stage
for
i
in
range
(
num_stage
)]
ram_list
=
np
.
linspace
(
5
,
10
,
num_stage
)
# dropout_rate_list = np.linspace(0.0, 0.2, num_stage)
stones
=
[
int
(
engine
.
config
[
"Global"
][
"epochs"
]
*
ratio_list
[
i
])
for
i
in
range
(
num_stage
)
]
image_size_list
=
[
int
(
128
+
(
300
-
128
)
*
ratio_list
[
i
])
for
i
in
range
(
num_stage
)
]
stage_id
=
0
for
i
in
range
(
num_stage
):
if
epoch_id
>
stones
[
i
]:
...
...
@@ -43,10 +35,24 @@ def train_epoch_efficientnetv2(engine, epoch_id, print_batch_step):
# 2. Adjust training hyper-parameters for different training stage
if
not
hasattr
(
engine
,
'last_stage'
)
or
engine
.
last_stage
<
stage_id
:
cur_dropout_rate
=
0.0
def
_change_dp_func
(
m
):
global
cur_dropout_rate
if
type_name
(
m
)
==
"Head"
and
hasattr
(
m
,
"_dropout"
):
m
.
_dropout
.
p
=
m
.
dropout_rate
[
stage_id
]
cur_dropout_rate
=
m
.
dropout_rate
[
stage_id
]
engine
.
model
.
apply
(
_change_dp_func
)
cur_image_size
=
engine
.
config
[
"DataLoader"
][
"Train"
][
"dataset"
][
"transform_ops"
][
1
][
"RandCropImage"
][
"progress_size"
][
stage_id
]
cur_magnitude
=
engine
.
config
[
"DataLoader"
][
"Train"
][
"dataset"
][
"transform_ops"
][
3
][
"RandAugment"
][
"progress_magnitude"
][
stage_id
]
engine
.
config
[
"DataLoader"
][
"Train"
][
"dataset"
][
"transform_ops"
][
1
][
"RandCropImage"
][
"size"
]
=
image_size_list
[
stage_id
]
"RandCropImage"
][
"size"
]
=
cur_image_size
engine
.
config
[
"DataLoader"
][
"Train"
][
"dataset"
][
"transform_ops"
][
3
][
"RandAugment"
][
"magnitude"
]
=
ram_list
[
stage_id
]
"RandAugment"
][
"magnitude"
]
=
cur_magnitude
engine
.
train_dataloader
=
build_dataloader
(
engine
.
config
[
"DataLoader"
],
"Train"
,
...
...
@@ -55,9 +61,11 @@ def train_epoch_efficientnetv2(engine, epoch_id, print_batch_step):
seed
=
epoch_id
)
engine
.
train_dataloader_iter
=
iter
(
engine
.
train_dataloader
)
engine
.
last_stage
=
stage_id
logger
.
info
(
f
"Training stage: [
{
stage_id
+
1
}
/
{
num_stage
}
](random_aug_magnitude=
{
ram_list
[
stage_id
]
}
, train_image_size=
{
image_size_list
[
stage_id
]
}
)"
)
logger
.
info
(
f
"Training stage: [
{
stage_id
+
1
}
/
{
num_stage
}
]("
f
"random_aug_magnitude=
{
cur_magnitude
}
, "
f
"train_image_size=
{
cur_image_size
}
, "
f
"dropout_rate=
{
cur_dropout_rate
}
"
f
")"
)
# 3. Train one epoch as usual at current stage
train_epoch
(
engine
,
epoch_id
,
print_batch_step
)
ppcls/utils/save_load.py
浏览文件 @
7e12c73e
...
...
@@ -61,6 +61,7 @@ def load_dygraph_pretrain(model, path=None):
m
.
set_dict
(
param_state_dict
)
else
:
model
.
set_dict
(
param_state_dict
)
logger
.
info
(
"Finish load pretrained model from {}"
.
format
(
path
))
return
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录