Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
ba9b708a
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ba9b708a
编写于
9月 17, 2021
作者:
C
cuicheng01
提交者:
GitHub
9月 17, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1174 from TingquanGao/dev/add_adamw
feat: add AdamW
上级
36aeefcf
079434dc
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
71 addition
and
12 deletion
+71
-12
ppcls/engine/engine.py
ppcls/engine/engine.py
+1
-1
ppcls/optimizer/__init__.py
ppcls/optimizer/__init__.py
+8
-6
ppcls/optimizer/optimizer.py
ppcls/optimizer/optimizer.py
+62
-5
未找到文件。
ppcls/engine/engine.py
浏览文件 @
ba9b708a
...
...
@@ -200,7 +200,7 @@ class Engine(object):
if
self
.
mode
==
'train'
:
self
.
optimizer
,
self
.
lr_sch
=
build_optimizer
(
self
.
config
[
"Optimizer"
],
self
.
config
[
"Global"
][
"epochs"
],
len
(
self
.
train_dataloader
),
self
.
model
.
parameters
()
)
len
(
self
.
train_dataloader
),
[
self
.
model
]
)
# for distributed
self
.
config
[
"Global"
][
...
...
ppcls/optimizer/__init__.py
浏览文件 @
ba9b708a
...
...
@@ -41,19 +41,22 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
return
lr
def
build_optimizer
(
config
,
epochs
,
step_each_epoch
,
parameters
=
None
):
def
build_optimizer
(
config
,
epochs
,
step_each_epoch
,
model_list
):
config
=
copy
.
deepcopy
(
config
)
# step1 build lr
lr
=
build_lr_scheduler
(
config
.
pop
(
'lr'
),
epochs
,
step_each_epoch
)
logger
.
debug
(
"build lr ({}) success.."
.
format
(
lr
))
# step2 build regularization
if
'regularizer'
in
config
and
config
[
'regularizer'
]
is
not
None
:
if
'weight_decay'
in
config
:
logger
.
warning
(
"ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config.
\"
weight_decay
\"
has been ignored."
)
reg_config
=
config
.
pop
(
'regularizer'
)
reg_name
=
reg_config
.
pop
(
'name'
)
+
'Decay'
reg
=
getattr
(
paddle
.
regularizer
,
reg_name
)(
**
reg_config
)
else
:
reg
=
None
logger
.
debug
(
"build regularizer ({}) success.."
.
format
(
reg
))
config
[
"weight_decay"
]
=
reg
logger
.
debug
(
"build regularizer ({}) success.."
.
format
(
reg
))
# step3 build optimizer
optim_name
=
config
.
pop
(
'name'
)
if
'clip_norm'
in
config
:
...
...
@@ -62,8 +65,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters=None):
else
:
grad_clip
=
None
optim
=
getattr
(
optimizer
,
optim_name
)(
learning_rate
=
lr
,
weight_decay
=
reg
,
grad_clip
=
grad_clip
,
**
config
)(
parameters
=
parameters
)
**
config
)(
model_list
=
model_list
)
logger
.
debug
(
"build optimizer ({}) success.."
.
format
(
optim
))
return
optim
,
lr
ppcls/optimizer/optimizer.py
浏览文件 @
ba9b708a
...
...
@@ -35,14 +35,15 @@ class Momentum(object):
weight_decay
=
None
,
grad_clip
=
None
,
multi_precision
=
False
):
super
(
Momentum
,
self
).
__init__
()
super
().
__init__
()
self
.
learning_rate
=
learning_rate
self
.
momentum
=
momentum
self
.
weight_decay
=
weight_decay
self
.
grad_clip
=
grad_clip
self
.
multi_precision
=
multi_precision
def
__call__
(
self
,
parameters
):
def
__call__
(
self
,
model_list
):
parameters
=
sum
([
m
.
parameters
()
for
m
in
model_list
],
[])
opt
=
optim
.
Momentum
(
learning_rate
=
self
.
learning_rate
,
momentum
=
self
.
momentum
,
...
...
@@ -77,7 +78,8 @@ class Adam(object):
self
.
lazy_mode
=
lazy_mode
self
.
multi_precision
=
multi_precision
def
__call__
(
self
,
parameters
):
def
__call__
(
self
,
model_list
):
parameters
=
sum
([
m
.
parameters
()
for
m
in
model_list
],
[])
opt
=
optim
.
Adam
(
learning_rate
=
self
.
learning_rate
,
beta1
=
self
.
beta1
,
...
...
@@ -112,7 +114,7 @@ class RMSProp(object):
weight_decay
=
None
,
grad_clip
=
None
,
multi_precision
=
False
):
super
(
RMSProp
,
self
).
__init__
()
super
().
__init__
()
self
.
learning_rate
=
learning_rate
self
.
momentum
=
momentum
self
.
rho
=
rho
...
...
@@ -120,7 +122,8 @@ class RMSProp(object):
self
.
weight_decay
=
weight_decay
self
.
grad_clip
=
grad_clip
def
__call__
(
self
,
parameters
):
def
__call__
(
self
,
model_list
):
parameters
=
sum
([
m
.
parameters
()
for
m
in
model_list
],
[])
opt
=
optim
.
RMSProp
(
learning_rate
=
self
.
learning_rate
,
momentum
=
self
.
momentum
,
...
...
@@ -130,3 +133,57 @@ class RMSProp(object):
grad_clip
=
self
.
grad_clip
,
parameters
=
parameters
)
return
opt
class
AdamW
(
object
):
def
__init__
(
self
,
learning_rate
=
0.001
,
beta1
=
0.9
,
beta2
=
0.999
,
epsilon
=
1e-8
,
weight_decay
=
None
,
multi_precision
=
False
,
grad_clip
=
None
,
no_weight_decay_name
=
None
,
one_dim_param_no_weight_decay
=
False
,
**
args
):
super
().
__init__
()
self
.
learning_rate
=
learning_rate
self
.
beta1
=
beta1
self
.
beta2
=
beta2
self
.
epsilon
=
epsilon
self
.
grad_clip
=
grad_clip
self
.
weight_decay
=
weight_decay
self
.
multi_precision
=
multi_precision
self
.
no_weight_decay_name_list
=
no_weight_decay_name
.
split
(
)
if
no_weight_decay_name
else
[]
self
.
one_dim_param_no_weight_decay
=
one_dim_param_no_weight_decay
def
__call__
(
self
,
model_list
):
parameters
=
sum
([
m
.
parameters
()
for
m
in
model_list
],
[])
self
.
no_weight_decay_param_name_list
=
[
p
.
name
for
model
in
model_list
for
n
,
p
in
model
.
named_parameters
()
if
any
(
nd
in
n
for
nd
in
self
.
no_weight_decay_name_list
)
]
if
self
.
one_dim_param_no_weight_decay
:
self
.
no_weight_decay_param_name_list
+=
[
p
.
name
for
model
in
model_list
for
n
,
p
in
model
.
named_parameters
()
if
len
(
p
.
shape
)
==
1
]
opt
=
optim
.
AdamW
(
learning_rate
=
self
.
learning_rate
,
beta1
=
self
.
beta1
,
beta2
=
self
.
beta2
,
epsilon
=
self
.
epsilon
,
parameters
=
parameters
,
weight_decay
=
self
.
weight_decay
,
multi_precision
=
self
.
multi_precision
,
grad_clip
=
self
.
grad_clip
,
apply_decay_param_fun
=
self
.
_apply_decay_param_fun
)
return
opt
def
_apply_decay_param_fun
(
self
,
name
):
return
name
not
in
self
.
no_weight_decay_param_name_list
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录