Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
3541a80d
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
1 年多 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3541a80d
编写于
9月 08, 2022
作者:
W
Wei Shengyu
提交者:
GitHub
9月 08, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2269 from zengshao0622/merge_CAE
Merge CAE
上级
14d2c23d
8b0e643f
变更
8
展开全部
隐藏空白更改
内联
并排
Showing
8 changed file
with
1458 addition
and
0 deletion
+1458
-0
ppcls/arch/backbone/__init__.py
ppcls/arch/backbone/__init__.py
+1
-0
ppcls/arch/backbone/model_zoo/cae.py
ppcls/arch/backbone/model_zoo/cae.py
+860
-0
ppcls/configs/CAE/cae_base_patch16_224_finetune.yaml
ppcls/configs/CAE/cae_base_patch16_224_finetune.yaml
+167
-0
ppcls/data/preprocess/__init__.py
ppcls/data/preprocess/__init__.py
+1
-0
ppcls/data/preprocess/batch_ops/batch_operators.py
ppcls/data/preprocess/batch_ops/batch_operators.py
+270
-0
ppcls/loss/__init__.py
ppcls/loss/__init__.py
+1
-0
ppcls/loss/softtargetceloss.py
ppcls/loss/softtargetceloss.py
+16
-0
ppcls/optimizer/optimizer.py
ppcls/optimizer/optimizer.py
+142
-0
未找到文件。
ppcls/arch/backbone/__init__.py
浏览文件 @
3541a80d
...
...
@@ -69,6 +69,7 @@ from .model_zoo.repvgg import RepVGG_A0, RepVGG_A1, RepVGG_A2, RepVGG_B0, RepVGG
from
.model_zoo.van
import
VAN_tiny
from
.model_zoo.peleenet
import
PeleeNet
from
.model_zoo.convnext
import
ConvNeXt_tiny
from
.model_zoo.cae
import
cae_base_patch16_224
,
cae_large_patch16_224
from
.variant_models.resnet_variant
import
ResNet50_last_stage_stride1
from
.variant_models.vgg_variant
import
VGG19Sigmoid
...
...
ppcls/arch/backbone/model_zoo/cae.py
0 → 100644
浏览文件 @
3541a80d
此差异已折叠。
点击以展开。
ppcls/configs/CAE/cae_base_patch16_224_finetune.yaml
0 → 100644
浏览文件 @
3541a80d
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
save_interval
:
20
eval_during_train
:
True
eval_interval
:
1
epochs
:
100
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# model architecture
Arch
:
name
:
cae_base_patch16_224
class_num
:
102
drop_rate
:
0.0
drop_path_rate
:
0.1
attn_drop_rate
:
0.0
use_mean_pooling
:
True
init_scale
:
0.001
use_rel_pos_bias
:
True
use_abs_pos_emb
:
False
init_values
:
0.1
lin_probe
:
False
sin_pos_emb
:
True
abs_pos_emb
:
False
enable_linear_eval
:
False
model_key
:
model|module|state_dict
rel_pos_bias
:
True
model_ema
:
enable_model_ema
:
False
model_ema_decay
:
0.9999
model_ema_force_cpu
:
False
pretrained
:
True
# loss function config for traing/eval process
Loss
:
Train
:
-
SoftTargetCrossEntropy
:
weight
:
1.0
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
AdamWDL
beta1
:
0.9
beta2
:
0.999
epsilon
:
1e-8
weight_decay
:
0.05
layerwise_decay
:
0.65
lr
:
name
:
Cosine
learning_rate
:
0.001
eta_min
:
1e-6
warmup_epoch
:
10
warmup_start_lr
:
1e-6
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/flowers102/
cls_label_path
:
./dataset/flowers102/train_list.txt
batch_transform_ops
:
-
MixupCutmixHybrid
:
mixup_alpha
:
0.8
cutmix_alpha
:
1.0
switch_prob
:
0.5
num_classes
:
102
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
interpolation
:
bilinear
-
RandFlipImage
:
flip_code
:
1
-
RandAugment
:
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
RandomErasing
:
EPSILON
:
0.5
sl
:
0.02
sh
:
0.3
r1
:
0.3
sampler
:
name
:
DistributedBatchSampler
batch_size
:
16
drop_last
:
True
shuffle
:
True
loader
:
num_workers
:
4
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/flowers102/
cls_label_path
:
./dataset/flowers102/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
16
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/inference_deployment/whl_demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.5
,
0.5
,
0.5
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
5
]
Eval
:
-
TopkAcc
:
topk
:
[
1
,
5
]
ppcls/data/preprocess/__init__.py
浏览文件 @
3541a80d
...
...
@@ -42,6 +42,7 @@ from ppcls.data.preprocess.ops.operators import RandomRotation
from
ppcls.data.preprocess.ops.operators
import
Padv2
from
ppcls.data.preprocess.batch_ops.batch_operators
import
MixupOperator
,
CutmixOperator
,
OpSampler
,
FmixOperator
from
ppcls.data.preprocess.batch_ops.batch_operators
import
MixupCutmixHybrid
import
numpy
as
np
from
PIL
import
Image
...
...
ppcls/data/preprocess/batch_ops/batch_operators.py
浏览文件 @
3541a80d
...
...
@@ -23,6 +23,9 @@ import numpy as np
from
ppcls.utils
import
logger
from
ppcls.data.preprocess.ops.fmix
import
sample_mask
import
paddle
import
paddle.nn.functional
as
F
class
BatchOperator
(
object
):
""" BatchOperator """
...
...
@@ -229,3 +232,270 @@ class OpSampler(object):
list
(
self
.
ops
.
keys
()),
weights
=
list
(
self
.
ops
.
values
()),
k
=
1
)[
0
]
# return batch directly when None Op
return
op
(
batch
)
if
op
else
batch
class
MixupCutmixHybrid
(
object
):
""" Mixup/Cutmix that applies different params to each element or whole batch
Args:
mixup_alpha (float): mixup alpha value, mixup is active if > 0.
cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
prob (float): probability of applying mixup or cutmix per batch or element
switch_prob (float): probability of switching to cutmix instead of mixup when both are active
mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)
correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
label_smoothing (float): apply label smoothing to the mixed target tensor
num_classes (int): number of classes for target
"""
def
__init__
(
self
,
mixup_alpha
=
1.
,
cutmix_alpha
=
0.
,
cutmix_minmax
=
None
,
prob
=
1.0
,
switch_prob
=
0.5
,
mode
=
'batch'
,
correct_lam
=
True
,
label_smoothing
=
0.1
,
num_classes
=
4
):
self
.
mixup_alpha
=
mixup_alpha
self
.
cutmix_alpha
=
cutmix_alpha
self
.
cutmix_minmax
=
cutmix_minmax
if
self
.
cutmix_minmax
is
not
None
:
assert
len
(
self
.
cutmix_minmax
)
==
2
# force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
self
.
cutmix_alpha
=
1.0
self
.
mix_prob
=
prob
self
.
switch_prob
=
switch_prob
self
.
label_smoothing
=
label_smoothing
self
.
num_classes
=
num_classes
self
.
mode
=
mode
self
.
correct_lam
=
correct_lam
# correct lambda based on clipped area for cutmix
self
.
mixup_enabled
=
True
# set to false to disable mixing (intended tp be set by train loop)
def
_one_hot
(
self
,
x
,
num_classes
,
on_value
=
1.
,
off_value
=
0.
):
x
=
paddle
.
cast
(
x
,
dtype
=
'int64'
)
on_value
=
paddle
.
full
([
x
.
shape
[
0
],
num_classes
],
on_value
)
off_value
=
paddle
.
full
([
x
.
shape
[
0
],
num_classes
],
off_value
)
return
paddle
.
where
(
F
.
one_hot
(
x
,
num_classes
)
==
1
,
on_value
,
off_value
)
def
_mixup_target
(
self
,
target
,
num_classes
,
lam
=
1.
,
smoothing
=
0.0
):
off_value
=
smoothing
/
num_classes
on_value
=
1.
-
smoothing
+
off_value
y1
=
self
.
_one_hot
(
target
,
num_classes
,
on_value
=
on_value
,
off_value
=
off_value
,
)
y2
=
self
.
_one_hot
(
target
.
flip
(
0
),
num_classes
,
on_value
=
on_value
,
off_value
=
off_value
)
return
y1
*
lam
+
y2
*
(
1.
-
lam
)
def
_rand_bbox
(
self
,
img_shape
,
lam
,
margin
=
0.
,
count
=
None
):
""" Standard CutMix bounding-box
Generates a random square bbox based on lambda value. This impl includes
support for enforcing a border margin as percent of bbox dimensions.
Args:
img_shape (tuple): Image shape as tuple
lam (float): Cutmix lambda value
margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
count (int): Number of bbox to generate
"""
ratio
=
np
.
sqrt
(
1
-
lam
)
img_h
,
img_w
=
img_shape
[
-
2
:]
cut_h
,
cut_w
=
int
(
img_h
*
ratio
),
int
(
img_w
*
ratio
)
margin_y
,
margin_x
=
int
(
margin
*
cut_h
),
int
(
margin
*
cut_w
)
cy
=
np
.
random
.
randint
(
0
+
margin_y
,
img_h
-
margin_y
,
size
=
count
)
cx
=
np
.
random
.
randint
(
0
+
margin_x
,
img_w
-
margin_x
,
size
=
count
)
yl
=
np
.
clip
(
cy
-
cut_h
//
2
,
0
,
img_h
)
yh
=
np
.
clip
(
cy
+
cut_h
//
2
,
0
,
img_h
)
xl
=
np
.
clip
(
cx
-
cut_w
//
2
,
0
,
img_w
)
xh
=
np
.
clip
(
cx
+
cut_w
//
2
,
0
,
img_w
)
return
yl
,
yh
,
xl
,
xh
def
_rand_bbox_minmax
(
self
,
img_shape
,
minmax
,
count
=
None
):
""" Min-Max CutMix bounding-box
Inspired by Darknet cutmix impl, generates a random rectangular bbox
based on min/max percent values applied to each dimension of the input image.
Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.
Args:
img_shape (tuple): Image shape as tuple
minmax (tuple or list): Min and max bbox ratios (as percent of image size)
count (int): Number of bbox to generate
"""
assert
len
(
minmax
)
==
2
img_h
,
img_w
=
img_shape
[
-
2
:]
cut_h
=
np
.
random
.
randint
(
int
(
img_h
*
minmax
[
0
]),
int
(
img_h
*
minmax
[
1
]),
size
=
count
)
cut_w
=
np
.
random
.
randint
(
int
(
img_w
*
minmax
[
0
]),
int
(
img_w
*
minmax
[
1
]),
size
=
count
)
yl
=
np
.
random
.
randint
(
0
,
img_h
-
cut_h
,
size
=
count
)
xl
=
np
.
random
.
randint
(
0
,
img_w
-
cut_w
,
size
=
count
)
yu
=
yl
+
cut_h
xu
=
xl
+
cut_w
return
yl
,
yu
,
xl
,
xu
def
_cutmix_bbox_and_lam
(
self
,
img_shape
,
lam
,
ratio_minmax
=
None
,
correct_lam
=
True
,
count
=
None
):
""" Generate bbox and apply lambda correction.
"""
if
ratio_minmax
is
not
None
:
yl
,
yu
,
xl
,
xu
=
self
.
_rand_bbox_minmax
(
img_shape
,
ratio_minmax
,
count
=
count
)
else
:
yl
,
yu
,
xl
,
xu
=
self
.
_rand_bbox
(
img_shape
,
lam
,
count
=
count
)
if
correct_lam
or
ratio_minmax
is
not
None
:
bbox_area
=
(
yu
-
yl
)
*
(
xu
-
xl
)
lam
=
1.
-
bbox_area
/
float
(
img_shape
[
-
2
]
*
img_shape
[
-
1
])
return
(
yl
,
yu
,
xl
,
xu
),
lam
def
_params_per_elem
(
self
,
batch_size
):
lam
=
np
.
ones
(
batch_size
,
dtype
=
np
.
float32
)
use_cutmix
=
np
.
zeros
(
batch_size
,
dtype
=
np
.
bool
)
if
self
.
mixup_enabled
:
if
self
.
mixup_alpha
>
0.
and
self
.
cutmix_alpha
>
0.
:
use_cutmix
=
np
.
random
.
rand
(
batch_size
)
<
self
.
switch_prob
lam_mix
=
np
.
where
(
use_cutmix
,
np
.
random
.
beta
(
self
.
cutmix_alpha
,
self
.
cutmix_alpha
,
size
=
batch_size
),
np
.
random
.
beta
(
self
.
mixup_alpha
,
self
.
mixup_alpha
,
size
=
batch_size
))
elif
self
.
mixup_alpha
>
0.
:
lam_mix
=
np
.
random
.
beta
(
self
.
mixup_alpha
,
self
.
mixup_alpha
,
size
=
batch_size
)
elif
self
.
cutmix_alpha
>
0.
:
use_cutmix
=
np
.
ones
(
batch_size
,
dtype
=
np
.
bool
)
lam_mix
=
np
.
random
.
beta
(
self
.
cutmix_alpha
,
self
.
cutmix_alpha
,
size
=
batch_size
)
else
:
assert
False
,
"One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
lam
=
np
.
where
(
np
.
random
.
rand
(
batch_size
)
<
self
.
mix_prob
,
lam_mix
.
astype
(
np
.
float32
),
lam
)
return
lam
,
use_cutmix
def
_params_per_batch
(
self
):
lam
=
1.
use_cutmix
=
False
if
self
.
mixup_enabled
and
np
.
random
.
rand
()
<
self
.
mix_prob
:
if
self
.
mixup_alpha
>
0.
and
self
.
cutmix_alpha
>
0.
:
use_cutmix
=
np
.
random
.
rand
()
<
self
.
switch_prob
lam_mix
=
np
.
random
.
beta
(
self
.
cutmix_alpha
,
self
.
cutmix_alpha
)
if
use_cutmix
else
\
np
.
random
.
beta
(
self
.
mixup_alpha
,
self
.
mixup_alpha
)
elif
self
.
mixup_alpha
>
0.
:
lam_mix
=
np
.
random
.
beta
(
self
.
mixup_alpha
,
self
.
mixup_alpha
)
elif
self
.
cutmix_alpha
>
0.
:
use_cutmix
=
True
lam_mix
=
np
.
random
.
beta
(
self
.
cutmix_alpha
,
self
.
cutmix_alpha
)
else
:
assert
False
,
"One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
lam
=
float
(
lam_mix
)
return
lam
,
use_cutmix
def
_mix_elem
(
self
,
x
):
batch_size
=
len
(
x
)
lam_batch
,
use_cutmix
=
self
.
_params_per_elem
(
batch_size
)
x_orig
=
x
.
clone
(
)
# need to keep an unmodified original for mixing source
for
i
in
range
(
batch_size
):
j
=
batch_size
-
i
-
1
lam
=
lam_batch
[
i
]
if
lam
!=
1.
:
if
use_cutmix
[
i
]:
(
yl
,
yh
,
xl
,
xh
),
lam
=
self
.
_cutmix_bbox_and_lam
(
x
[
i
].
shape
,
lam
,
ratio_minmax
=
self
.
cutmix_minmax
,
correct_lam
=
self
.
correct_lam
)
if
yl
<
yh
and
xl
<
xh
:
x
[
i
][:,
yl
:
yh
,
xl
:
xh
]
=
x_orig
[
j
][:,
yl
:
yh
,
xl
:
xh
]
lam_batch
[
i
]
=
lam
else
:
x
[
i
]
=
x
[
i
]
*
lam
+
x_orig
[
j
]
*
(
1
-
lam
)
return
paddle
.
to_tensor
(
lam_batch
,
dtype
=
x
.
dtype
).
unsqueeze
(
1
)
def
_mix_pair
(
self
,
x
):
batch_size
=
len
(
x
)
lam_batch
,
use_cutmix
=
self
.
_params_per_elem
(
batch_size
//
2
)
x_orig
=
x
.
clone
(
)
# need to keep an unmodified original for mixing source
for
i
in
range
(
batch_size
//
2
):
j
=
batch_size
-
i
-
1
lam
=
lam_batch
[
i
]
if
lam
!=
1.
:
if
use_cutmix
[
i
]:
(
yl
,
yh
,
xl
,
xh
),
lam
=
self
.
_cutmix_bbox_and_lam
(
x
[
i
].
shape
,
lam
,
ratio_minmax
=
self
.
cutmix_minmax
,
correct_lam
=
self
.
correct_lam
)
if
yl
<
yh
and
xl
<
xh
:
x
[
i
][:,
yl
:
yh
,
xl
:
xh
]
=
x_orig
[
j
][:,
yl
:
yh
,
xl
:
xh
]
x
[
j
][:,
yl
:
yh
,
xl
:
xh
]
=
x_orig
[
i
][:,
yl
:
yh
,
xl
:
xh
]
lam_batch
[
i
]
=
lam
else
:
x
[
i
]
=
x
[
i
]
*
lam
+
x_orig
[
j
]
*
(
1
-
lam
)
x
[
j
]
=
x
[
j
]
*
lam
+
x_orig
[
i
]
*
(
1
-
lam
)
lam_batch
=
np
.
concatenate
((
lam_batch
,
lam_batch
[::
-
1
]))
return
paddle
.
to_tensor
(
lam_batch
,
dtype
=
x
.
dtype
).
unsqueeze
(
1
)
def
_mix_batch
(
self
,
x
):
lam
,
use_cutmix
=
self
.
_params_per_batch
()
if
lam
==
1.
:
return
1.
if
use_cutmix
:
(
yl
,
yh
,
xl
,
xh
),
lam
=
self
.
_cutmix_bbox_and_lam
(
x
.
shape
,
lam
,
ratio_minmax
=
self
.
cutmix_minmax
,
correct_lam
=
self
.
correct_lam
)
if
yl
<
yh
and
xl
<
xh
:
x
[:,
:,
yl
:
yh
,
xl
:
xh
]
=
x
.
flip
(
0
)[:,
:,
yl
:
yh
,
xl
:
xh
]
else
:
x_flipped
=
x
.
flip
(
0
)
*
(
1.
-
lam
)
x
[:]
=
x
*
lam
+
x_flipped
return
lam
def
_unpack
(
self
,
batch
):
""" _unpack """
assert
isinstance
(
batch
,
list
),
\
'batch should be a list filled with tuples (img, label)'
bs
=
len
(
batch
)
assert
bs
>
0
,
'size of the batch data should > 0'
#imgs, labels = list(zip(*batch))
imgs
=
[]
labels
=
[]
for
item
in
batch
:
imgs
.
append
(
item
[
0
])
labels
.
append
(
item
[
1
])
return
np
.
array
(
imgs
),
np
.
array
(
labels
),
bs
def
__call__
(
self
,
batch
):
x
,
target
,
bs
=
self
.
_unpack
(
batch
)
x
=
paddle
.
to_tensor
(
x
)
target
=
paddle
.
to_tensor
(
target
)
assert
len
(
x
)
%
2
==
0
,
'Batch size should be even when using this'
if
self
.
mode
==
'elem'
:
lam
=
self
.
_mix_elem
(
x
)
elif
self
.
mode
==
'pair'
:
lam
=
self
.
_mix_pair
(
x
)
else
:
lam
=
self
.
_mix_batch
(
x
)
target
=
self
.
_mixup_target
(
target
,
self
.
num_classes
,
lam
,
self
.
label_smoothing
)
return
list
(
zip
(
x
.
numpy
(),
target
.
numpy
()))
ppcls/loss/__init__.py
浏览文件 @
3541a80d
...
...
@@ -17,6 +17,7 @@ from .supconloss import SupConLoss
from
.pairwisecosface
import
PairwiseCosface
from
.dmlloss
import
DMLLoss
from
.distanceloss
import
DistanceLoss
from
.softtargetceloss
import
SoftTargetCrossEntropy
from
.distillationloss
import
DistillationCELoss
from
.distillationloss
import
DistillationGTCELoss
...
...
ppcls/loss/softtargetceloss.py
0 → 100644
浏览文件 @
3541a80d
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
class
SoftTargetCrossEntropy
(
nn
.
Layer
):
def
__init__
(
self
):
super
().
__init__
()
def
forward
(
self
,
x
,
target
):
loss
=
paddle
.
sum
(
-
target
*
F
.
log_softmax
(
x
,
axis
=-
1
),
axis
=-
1
)
loss
=
loss
.
mean
()
return
{
"SoftTargetCELoss"
:
loss
}
def
__str__
(
self
,
):
return
type
(
self
).
__name__
ppcls/optimizer/optimizer.py
浏览文件 @
3541a80d
...
...
@@ -272,3 +272,145 @@ class AdamW(object):
def
_apply_decay_param_fun
(
self
,
name
):
return
name
not
in
self
.
no_weight_decay_param_name_list
class
AdamWDL
(
object
):
"""
The AdamWDL optimizer is implemented based on the AdamW Optimization with dynamic lr setting.
Generally it's used for transformer model.
"""
def
__init__
(
self
,
learning_rate
=
0.001
,
beta1
=
0.9
,
beta2
=
0.999
,
epsilon
=
1e-8
,
weight_decay
=
None
,
multi_precision
=
False
,
grad_clip
=
None
,
layerwise_decay
=
None
,
filter_bias_and_bn
=
True
,
**
args
):
self
.
learning_rate
=
learning_rate
self
.
beta1
=
beta1
self
.
beta2
=
beta2
self
.
epsilon
=
epsilon
self
.
grad_clip
=
grad_clip
self
.
weight_decay
=
weight_decay
self
.
multi_precision
=
multi_precision
self
.
layerwise_decay
=
layerwise_decay
self
.
filter_bias_and_bn
=
filter_bias_and_bn
class
AdamWDLImpl
(
optim
.
AdamW
):
def
__init__
(
self
,
learning_rate
=
0.001
,
beta1
=
0.9
,
beta2
=
0.999
,
epsilon
=
1e-8
,
parameters
=
None
,
weight_decay
=
0.01
,
apply_decay_param_fun
=
None
,
grad_clip
=
None
,
lazy_mode
=
False
,
multi_precision
=
False
,
layerwise_decay
=
1.0
,
n_layers
=
12
,
name_dict
=
None
,
name
=
None
):
if
not
isinstance
(
layerwise_decay
,
float
)
and
\
not
isinstance
(
layerwise_decay
,
fluid
.
framework
.
Variable
):
raise
TypeError
(
"coeff should be float or Tensor."
)
self
.
layerwise_decay
=
layerwise_decay
self
.
name_dict
=
name_dict
self
.
n_layers
=
n_layers
self
.
set_param_lr_fun
=
self
.
_layerwise_lr_decay
super
().
__init__
(
learning_rate
=
learning_rate
,
parameters
=
parameters
,
beta1
=
beta1
,
beta2
=
beta2
,
epsilon
=
epsilon
,
grad_clip
=
grad_clip
,
name
=
name
,
apply_decay_param_fun
=
apply_decay_param_fun
,
weight_decay
=
weight_decay
,
lazy_mode
=
lazy_mode
,
multi_precision
=
multi_precision
)
def
_append_optimize_op
(
self
,
block
,
param_and_grad
):
if
self
.
set_param_lr_fun
is
None
:
return
super
(
AdamLW
,
self
).
_append_optimize_op
(
block
,
param_and_grad
)
self
.
_append_decoupled_weight_decay
(
block
,
param_and_grad
)
prev_lr
=
param_and_grad
[
0
].
optimize_attr
[
"learning_rate"
]
self
.
set_param_lr_fun
(
self
.
layerwise_decay
,
self
.
name_dict
,
self
.
n_layers
,
param_and_grad
[
0
])
# excute Adam op
res
=
super
(
optim
.
AdamW
,
self
).
_append_optimize_op
(
block
,
param_and_grad
)
param_and_grad
[
0
].
optimize_attr
[
"learning_rate"
]
=
prev_lr
return
res
# Layerwise decay
def
_layerwise_lr_decay
(
self
,
decay_rate
,
name_dict
,
n_layers
,
param
):
"""
Args:
decay_rate (float):
The layer-wise decay ratio.
name_dict (dict):
The keys of name_dict is dynamic name of model while the value
of name_dict is static name.
Use model.named_parameters() to get name_dict.
n_layers (int):
Total number of layers in the transformer encoder.
"""
ratio
=
1.0
static_name
=
name_dict
[
param
.
name
]
if
"blocks"
in
static_name
:
idx
=
static_name
.
find
(
"blocks."
)
layer
=
int
(
static_name
[
idx
:].
split
(
"."
)[
1
])
ratio
=
decay_rate
**
(
n_layers
-
layer
)
elif
"embed"
in
static_name
:
ratio
=
decay_rate
**
(
n_layers
+
1
)
param
.
optimize_attr
[
"learning_rate"
]
*=
ratio
def
__call__
(
self
,
model_list
):
model
=
model_list
[
0
]
if
self
.
weight_decay
and
self
.
filter_bias_and_bn
:
skip
=
{}
if
hasattr
(
model
,
'no_weight_decay'
):
skip
=
model
.
no_weight_decay
()
decay_dict
=
{
param
.
name
:
not
(
len
(
param
.
shape
)
==
1
or
name
.
endswith
(
".bias"
)
or
name
in
skip
)
for
name
,
param
in
model
.
named_parameters
()
if
not
'teacher'
in
name
}
parameters
=
[
param
for
param
in
model
.
parameters
()
if
'teacher'
not
in
param
.
name
]
weight_decay
=
0.
else
:
parameters
=
model
.
parameters
()
opt_args
=
dict
(
learning_rate
=
self
.
learning_rate
,
weight_decay
=
self
.
weight_decay
)
opt_args
[
'parameters'
]
=
parameters
if
decay_dict
is
not
None
:
opt_args
[
'apply_decay_param_fun'
]
=
lambda
n
:
decay_dict
[
n
]
opt_args
[
'epsilon'
]
=
self
.
epsilon
opt_args
[
'beta1'
]
=
self
.
beta1
opt_args
[
'beta2'
]
=
self
.
beta2
if
self
.
layerwise_decay
and
self
.
layerwise_decay
<
1.0
:
opt_args
[
'layerwise_decay'
]
=
self
.
layerwise_decay
name_dict
=
dict
()
for
n
,
p
in
model
.
named_parameters
():
name_dict
[
p
.
name
]
=
n
opt_args
[
'name_dict'
]
=
name_dict
opt_args
[
'n_layers'
]
=
model
.
get_num_layers
()
optimizer
=
self
.
AdamWDLImpl
(
**
opt_args
)
return
optimizer
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录