Commit 7e22e9f0 (MegEngine)
Authored on Apr 14, 2021 by Megvii Engine Team

feat(optimzer): add AdamW

GitOrigin-RevId: e608b5d5b95a843694fb6a251c0662c080a7b240
Parent: 2d18074a

Showing 3 changed files with 243 additions and 42 deletions (+243 −42):

  imperative/python/megengine/optimizer/__init__.py      +1   −0
  imperative/python/megengine/optimizer/adamw.py         +128 −0
  imperative/python/test/integration/test_optimizer.py   +114 −42
imperative/python/megengine/optimizer/__init__.py

@@ -9,6 +9,7 @@
 from .adadelta import Adadelta
 from .adagrad import Adagrad
 from .adam import Adam
+from .adamw import AdamW
 from .lr_scheduler import LRScheduler
 from .multi_step_lr import MultiStepLR
 from .optimizer import Optimizer
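Since `AdamW` is now exported from `megengine.optimizer`, it plugs into the usual training loop like the existing optimizers. The following is a minimal usage sketch, not part of this commit, assuming the MegEngine 1.x `GradManager` workflow; the toy `Linear` model, random data, and MSE loss are invented purely for illustration.

# Hypothetical usage sketch (not part of this commit): one training step with the new AdamW.
import numpy as np

import megengine as mge
import megengine.module as M
from megengine.autodiff import GradManager
from megengine.optimizer import AdamW

net = M.Linear(4, 2)                                       # toy model
opt = AdamW(net.parameters(), lr=1e-3, weight_decay=1e-2)  # decoupled weight decay
gm = GradManager().attach(net.parameters())

x = mge.tensor(np.random.randn(8, 4).astype("float32"))
y = mge.tensor(np.random.randn(8, 2).astype("float32"))

with gm:                                 # record the forward pass for autodiff
    loss = ((net(x) - y) ** 2).mean()    # simple MSE loss
    gm.backward(loss)                    # populate param.grad for each parameter
opt.step()                               # apply the AdamW update
opt.clear_grad()                         # reset gradients for the next iteration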
imperative/python/megengine/optimizer/adamw.py  (new file, mode 100644)

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
from typing import Iterable, Tuple, Union

from ..functional.inplace import _inplace_add_
from ..tensor import Parameter, tensor
from .optimizer import Optimizer


class AdamW(Optimizer):
    r"""
    Implements AdamW algorithm proposed in `"Decoupled Weight Decay Regularization" <https://arxiv.org/abs/1711.05101>`_.

    :param params: iterable of parameters to optimize or dicts defining
        parameter groups.
    :param lr: learning rate.
    :param betas: coefficients used for computing running averages of gradient
        and its square. Default: (0.9, 0.999)
    :param eps: term added to the denominator to improve numerical stability.
        Default: 1e-8
    :param weight_decay: weight decay (L2 penalty). Default: 1e-2
    """

    def __init__(
        self,
        params: Union[Iterable[Parameter], dict],
        lr: float,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps)
        super().__init__(params, defaults)

    def _create_state(self, param_group):
        for param in param_group["params"]:
            self._add_state(param, "exp_avg")
            self._add_state(param, "exp_avg_sq")
            self._add_state(param, "step", initializer=0.0)

    def _updates(self, param_group):
        lr = param_group["lr"]
        weight_decay = param_group["weight_decay"]
        eps = param_group["eps"]
        beta0, beta1 = param_group["betas"]

        def make_scalar(val):
            return tensor(val)

        # since `convert_inputs` is disabled for param updates,
        # scalars should be explicitly converted to tensors
        _lr, _neg_lr = map(make_scalar, (lr, -lr))
        _weight_decay = make_scalar(weight_decay)
        _eps = make_scalar(eps)
        _beta0, _beta1 = map(make_scalar, (beta0, beta1))

        c1, c05 = map(make_scalar, (1.0, 0.5))

        inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0"))
        if inplace_mode:
            # reduce device sync
            c1_sub_beta0, c1_sub_beta1 = map(make_scalar, (1 - beta0, 1 - beta1))

        for param in param_group["params"]:
            if param.grad is None:
                continue

            grad = param.grad
            states = self._state[param]
            step, exp_avg, exp_avg_sq = (
                states["step"],
                states["exp_avg"],
                states["exp_avg_sq"],
            )
            if inplace_mode:
                _inplace_add_(step, c1, alpha=c1, beta=c1)
                _inplace_add_(exp_avg, grad, alpha=_beta0, beta=c1_sub_beta0)
                _inplace_add_(
                    exp_avg_sq, grad * grad, alpha=_beta1, beta=c1_sub_beta1,
                )
                delta = (exp_avg / (c1 - _beta0 ** step)) / (
                    (exp_avg_sq / (c1 - _beta1 ** step)) ** c05 + _eps
                )
                if weight_decay != 0.0:
                    delta += param * _weight_decay
                _inplace_add_(param, delta, alpha=c1, beta=_neg_lr)
                continue

            # step = step + c1
            step += c1
            # exp_avg = _beta0 * exp_avg + grad * (c1 - _beta0)
            exp_avg *= _beta0
            exp_avg += grad * (c1 - _beta0)
            # exp_avg_sq = _beta1 * exp_avg_sq + (c1 - _beta1) * (grad * grad)
            exp_avg_sq *= _beta1
            exp_avg_sq += (c1 - _beta1) * (grad * grad)

            delta = (exp_avg / (c1 - _beta0 ** step)) / (
                (exp_avg_sq / (c1 - _beta1 ** step)) ** c05 + _eps
            )
            if weight_decay != 0.0:
                delta += param * _weight_decay
            param -= _lr * delta
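For reference, the per-parameter update that `_updates` performs can be written out, with `exp_avg` as m, `exp_avg_sq` as v, `betas` as (\beta_1, \beta_2), `lr` as \eta, and `weight_decay` as \lambda:

m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t
v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2
\hat{m}_t = m_t / (1 - \beta_1^{\,t}), \qquad \hat{v}_t = v_t / (1 - \beta_2^{\,t})
\theta_t = \theta_{t-1} - \eta \left( \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon) + \lambda\, \theta_{t-1} \right)

Following the decoupled formulation of the referenced paper, the weight-decay term bypasses the adaptive denominator; in this implementation it is added to the step after bias correction and is therefore still scaled by the learning rate, in both the in-place and the default branches above.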
imperative/python/test/integration/test_optimizer.py

@@ -6,7 +6,10 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import os
+
 import numpy as np
+import pytest
 
 import megengine.autodiff as ad
 import megengine.functional as F
@@ -110,7 +113,17 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
     }
 
 
-def test_sgd():
+@pytest.mark.parametrize(
+    "case",
+    [
+        {"momentum": 0.9, "lr": 0.01},  # SGD with momentum
+        {"lr": 0.01},  # simple SGD
+        {"weight_decay": 0.1, "lr": 0.01},  # with weight_decay
+    ],
+)
+@pytest.mark.parametrize("update_lr", [False, True])
+@pytest.mark.parametrize("inplace_mode", [False, True])
+def test_sgd(monkeypatch, case, update_lr, inplace_mode):
     class CheckValue:
         def __init__(self, net, **kwarg):
             self.slots = {}
@@ -131,17 +144,26 @@ def test_sgd():
                     param.numpy(), ori_params[param] + delta, decimal=6
                 )
 
-    cases = [
-        {"momentum": 0.9, "lr": 0.01},  # SGD with momentum
-        {"lr": 0.01},  # simple SGD
-        {"weight_decay": 0.1, "lr": 0.01},  # with weight_decay
-    ]
-    for case in cases:
-        _test_optimizer("SGD", case, CheckValue)
-        _test_optimizer("SGD", case, CheckValue, update_lr=True)
+    with monkeypatch.context() as mk:
+        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
+        _test_optimizer("SGD", case, CheckValue, update_lr=update_lr)
 
 
-def test_adam():
+@pytest.mark.parametrize(
+    "case",
+    [
+        {"betas": (0.8, 0.9), "eps": 1e-04, "lr": 0.01},
+        {
+            "betas": (0.8, 0.9),
+            "eps": 1e-04,
+            "lr": 0.01,
+            "weight_decay": 0.1,
+        },  # with weight_decay
+    ],
+)
+@pytest.mark.parametrize("update_lr", [False, True])
+@pytest.mark.parametrize("inplace_mode", [False, True])
+def test_adam(monkeypatch, case, update_lr, inplace_mode):
     class CheckValue:
         def __init__(self, net, **kwarg):
             self.m_slots = {}
@@ -168,21 +190,27 @@ def test_adam():
                     param.numpy(), ori_params[param] - self.lr * delta, decimal=6
                 )
 
-    cases = [
-        {"betas": (0.8, 0.9), "eps": 1e-04, "lr": 0.01},
-        {
-            "betas": (0.8, 0.9),
-            "eps": 1e-04,
-            "lr": 0.01,
-            "weight_decay": 0.1,
-        },  # with weight_decay
-    ]
-    for case in cases:
-        _test_optimizer("Adam", case, CheckValue)
-        _test_optimizer("Adam", case, CheckValue, update_lr=True)
+    with monkeypatch.context() as mk:
+        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
+        _test_optimizer("Adam", case, CheckValue, update_lr=update_lr)
 
 
-def test_adagrad():
+@pytest.mark.parametrize(
+    "case",
+    [
+        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.01},
+        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.0},  # without lr_decay
+        {
+            "lr": 0.01,
+            "eps": 1e-06,
+            "lr_decay": 0.01,
+            "weight_decay": 0.1,
+        },  # with weight_decay
+    ],
+)
+@pytest.mark.parametrize("update_lr", [False, True])
+@pytest.mark.parametrize("inplace_mode", [False, True])
+def test_adagrad(monkeypatch, case, update_lr, inplace_mode):
     class CheckValue:
         def __init__(self, net, **kwarg):
             self.s_slots = {}
@@ -201,22 +229,21 @@ def test_adagrad():
                     param.numpy(), ori_params[param] + delta, decimal=6
                 )
 
-    cases = [
-        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.01},
-        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.0},  # without lr_decay
-        {
-            "lr": 0.01,
-            "eps": 1e-06,
-            "lr_decay": 0.01,
-            "weight_decay": 0.1,
-        },  # with weight_decay
-    ]
-    for case in cases:
-        _test_optimizer("Adagrad", case, CheckValue)
-        _test_optimizer("Adagrad", case, CheckValue, update_lr=True)
+    with monkeypatch.context() as mk:
+        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
+        _test_optimizer("Adagrad", case, CheckValue, update_lr=update_lr)
 
 
-def test_adadelta():
+@pytest.mark.parametrize(
+    "case",
+    [
+        {"lr": 1.0, "eps": 1e-06, "rho": 0.9},
+        {"lr": 1.0, "eps": 1e-06, "rho": 0.9, "weight_decay": 0.9},  # with weight_decay
+    ],
+)
+@pytest.mark.parametrize("update_lr", [False, True])
+@pytest.mark.parametrize("inplace_mode", [False, True])
+def test_adadelta(monkeypatch, case, update_lr, inplace_mode):
     class CheckValue:
         def __init__(self, net, **kwarg):
             self.s_slots = {}
@@ -246,10 +273,55 @@ def test_adadelta():
                     param.numpy(), ori_params[param] + delta, decimal=6
                 )
 
-    cases = [
-        {"lr": 1.0, "eps": 1e-06, "rho": 0.9},
-        {"lr": 1.0, "eps": 1e-06, "rho": 0.9, "weight_decay": 0.9},  # with weight_decay
-    ]
-    for case in cases:
-        _test_optimizer("Adadelta", case, CheckValue)
-        _test_optimizer("Adadelta", case, CheckValue, update_lr=True)
+    with monkeypatch.context() as mk:
+        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
+        _test_optimizer("Adadelta", case, CheckValue, update_lr=update_lr)
+
+
+@pytest.mark.parametrize(
+    "case",
+    [
+        {"betas": (0.8, 0.9), "eps": 1e-08, "lr": 0.01},
+        {
+            "betas": (0.8, 0.9),
+            "eps": 1e-08,
+            "lr": 0.01,
+            "weight_decay": 0.1,
+        },  # with weight_decay
+    ],
+)
+@pytest.mark.parametrize("update_lr", [False, True])
+@pytest.mark.parametrize("inplace_mode", [False, True])
+def test_adamw(monkeypatch, case, update_lr, inplace_mode):
+    class CheckValue:
+        def __init__(self, net, **kwarg):
+            self.m_slots = {}
+            self.v_slots = {}
+            for param in net.parameters():
+                self.m_slots[param] = np.zeros(param.shape).astype(np.float32)
+                self.v_slots[param] = np.zeros(param.shape).astype(np.float32)
+            self.weight_decay = 0.01
+            for k, v in kwarg.items():
+                setattr(self, k, v)
+
+        def __call__(self, ori_params, new_params, step):
+            step = np.array(step).astype(np.float32)
+            for param in new_params:
+                grad = param.grad.numpy()
+                m = self.m_slots[param]
+                v = self.v_slots[param]
+                m *= self.betas[0]
+                m += (1 - self.betas[0]) * grad
+                v *= self.betas[1]
+                v += (1 - self.betas[1]) * grad * grad
+                delta = (m / (1 - self.betas[0] ** step)) / (
+                    np.sqrt(v / (1 - self.betas[1] ** step)) + self.eps
+                )
+                delta += ori_params[param] * self.weight_decay
+                np.testing.assert_almost_equal(
+                    param.numpy(), ori_params[param] - self.lr * delta, decimal=6
+                )
+
+    with monkeypatch.context() as mk:
+        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
+        _test_optimizer("AdamW", case, CheckValue, update_lr=update_lr)
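The new `test_adamw` follows the same parametrized pattern as the other optimizer tests: each case runs with and without `update_lr`, and with `MEGENGINE_INPLACE_UPDATE` set to 0 and 1 via `monkeypatch`, checking the updated parameters against the NumPy reference in `CheckValue`. Assuming a standard pytest setup for the repository, the AdamW cases can be run on their own with, for example, `pytest imperative/python/test/integration/test_optimizer.py -k adamw`.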