MegEngine 天元 / MegEngine
7aa7a09b
编写于
5月 18, 2020
作者:
M
Megvii Engine Team
feat(mge/optimizer): add optimizer adadelta
GitOrigin-RevId: 244bc0d74a0cc7d8d7274e2ff22cb24f0e95f2ca
Parent: 205291a3
Showing 3 changed files with 111 additions and 34 deletions:
python_module/megengine/optimizer/__init__.py (+1, -0)
python_module/megengine/optimizer/adadelta.py (+78, -0)
python_module/test/unit/optimizer/test_optimizer.py (+32, -34)
python_module/megengine/optimizer/__init__.py
@@ -6,6 +6,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from .adadelta import Adadelta
 from .adagrad import Adagrad
 from .adam import Adam
 from .lr_scheduler import LRScheduler
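With this one-line re-export, the new class becomes importable from the megengine.optimizer package itself, alongside Adagrad, Adam, and LRScheduler. A minimal check of what the change enables (not part of the commit; the AdadeltaDirect alias is only illustrative):

    # Both paths name the same class once __init__.py re-exports it.
    from megengine.optimizer import Adadelta
    from megengine.optimizer.adadelta import Adadelta as AdadeltaDirect

    assert Adadelta is AdadeltaDirect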
python_module/megengine/optimizer/adadelta.py
new file (mode 100644)
from typing import Iterable, Union

import numpy as np

from ..core import Buffer, Parameter
from ..functional import sqrt
from .internal import add_update_fastpath as add_update
from .optimizer import Optimizer


class Adadelta(Optimizer):
    r"""Implements Adadelta algorithm.

    It has been proposed in `"ADADELTA: An Adaptive Learning Rate Method" <https://arxiv.org/abs/1212.5701>`_.

    :param params: iterable of parameters to optimize or dicts defining
        parameter groups.
    :param lr: coefficient that scale delta before it is applied
        to the parameters (default: 1.0).
    :param rho: coefficient used for computing a running average
        of squared gradients (default: 0.9).
    :param eps: term added to the denominator to improve
        numerical stability (default: 1e-6).
    :param weight_decay: weight decay (L2 penalty) (default: 0).
    """

    def __init__(
        self,
        params: Union[Iterable[Parameter], dict],
        lr: float = 1.0,
        rho: float = 0.9,
        eps: float = 1e-6,
        weight_decay: float = 0.0,
    ):
        assert lr >= 0.0, "Invalid learning rate: {}".format(lr)
        assert rho >= 0.0 and rho <= 1.0, "Invalid rho value: {}".format(rho)
        assert eps >= 0.0, "Invalid epsilon value: {}".format(eps)
        assert weight_decay >= 0.0, "Invalid weight_decay value: {}".format(
            weight_decay
        )

        defaults = dict(lr=lr, rho=rho, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    def _create_state(self, param_group):
        for param in param_group["params"]:
            self._add_state(param, "square_avg")
            self._add_state(param, "acc_delta")
            self._add_state(param, "step", initializer=0.0)

    def _updates(self, param_group):
        lr = param_group["lr"]
        weight_decay = param_group["weight_decay"]
        rho = param_group["rho"]
        eps = param_group["eps"]

        for param in param_group["params"]:
            if not isinstance(param.grad, Buffer):
                raise TypeError(
                    "grad must be a Buffer, maybe you forget to call backward()?"
                )

            if not param.requires_grad:
                continue

            step = self._state[param]["step"]
            step = add_update(step, 1)
            grad = param.grad
            if weight_decay != 0.0:
                grad = add_update(grad, param, beta=weight_decay)

            square_avg = self._state[param]["square_avg"]
            acc_delta = self._state[param]["acc_delta"]
            square_avg = add_update(square_avg, grad ** 2, alpha=rho, beta=1 - rho)
            std = sqrt(square_avg + eps)
            delta = sqrt(acc_delta + eps) / std * grad
            add_update(param, delta, beta=-lr)
            acc_delta = add_update(acc_delta, delta ** 2, alpha=rho, beta=1 - rho)
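For reference, the per-parameter update that _updates() expresses through add_update can be written out in plain NumPy. This is a sketch for illustration only; the function name adadelta_step is hypothetical, and it assumes add_update(dest, src, alpha=..., beta=...) behaves as dest = alpha * dest + beta * src (alpha defaulting to 1), which is how it is used above:

    import numpy as np

    def adadelta_step(param, grad, square_avg, acc_delta,
                      lr=1.0, rho=0.9, eps=1e-6, weight_decay=0.0):
        # Mirrors Adadelta._updates for a single parameter and a single step.
        if weight_decay != 0.0:
            grad = grad + weight_decay * param  # L2 penalty folded into the gradient
        square_avg = rho * square_avg + (1 - rho) * grad ** 2
        delta = np.sqrt(acc_delta + eps) / np.sqrt(square_avg + eps) * grad
        param = param - lr * delta  # corresponds to add_update(param, delta, beta=-lr)
        acc_delta = rho * acc_delta + (1 - rho) * delta ** 2
        return param, square_avg, acc_delta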
python_module/test/unit/optimizer/test_optimizer.py
@@ -189,72 +189,70 @@ def test_adam():
         _test_optimizer("Adam", case, CheckValue, update_lr=True)
 
 
-def test_adam():
+def test_adagrad():
     class CheckValue:
         def __init__(self, net, **kwarg):
-            self.m_slots = TensorDict()
-            self.v_slots = TensorDict()
+            self.s_slots = TensorDict()
             for param in net.parameters():
-                self.m_slots[param] = np.zeros(param.shape).astype(np.float32)
-                self.v_slots[param] = np.zeros(param.shape).astype(np.float32)
+                self.s_slots[param] = np.zeros(param.shape).astype(np.float32)
             for k, v in kwarg.items():
                 setattr(self, k, v)
 
         def __call__(self, ori_params, new_params, step):
             for param in new_params:
                 grad = param.grad.numpy()
-                m = self.m_slots[param]
-                v = self.v_slots[param]
-                m *= self.betas[0]
-                m += (1 - self.betas[0]) * grad
-                v *= self.betas[1]
-                v += (1 - self.betas[1]) * grad * grad
-                delta = (m / (1 - self.betas[0] ** step)) / (
-                    np.sqrt(v / (1 - self.betas[1] ** step)) + self.eps
-                )
-                assertTensorClose(param.numpy(), ori_params[param] - self.lr * delta)
+                self.s_slots[param] += grad ** 2
+                delta = grad / (self.s_slots[param] + self.eps) ** 0.5
+                delta *= -(self.lr / (1 + (step - 1) * self.lr_decay))
+                assertTensorClose(param.numpy(), ori_params[param] + delta)
 
     cases = [
-        {"betas": (0.8, 0.9), "eps": 1e-04, "lr": 0.01},
+        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.01},
+        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.0},  # without lr_decay
         {
-            "betas": (0.8, 0.9),
-            "eps": 1e-04,
             "lr": 0.01,
+            "eps": 1e-06,
+            "lr_decay": 0.01,
             "weight_decay": 0.1,
         },  # with weight_decay
     ]
     for case in cases:
-        _test_optimizer("Adam", case, CheckValue)
-        _test_optimizer("Adam", case, CheckValue, update_lr=True)
+        _test_optimizer("Adagrad", case, CheckValue)
+        _test_optimizer("Adagrad", case, CheckValue, update_lr=True)
 
 
-def test_adagrad():
+def test_adadelta():
     class CheckValue:
         def __init__(self, net, **kwarg):
             self.s_slots = TensorDict()
+            self.a_slots = TensorDict()
             for param in net.parameters():
                 self.s_slots[param] = np.zeros(param.shape).astype(np.float32)
+                self.a_slots[param] = np.zeros(param.shape).astype(np.float32)
             for k, v in kwarg.items():
                 setattr(self, k, v)
 
         def __call__(self, ori_params, new_params, step):
             for param in new_params:
                 grad = param.grad.numpy()
-                self.s_slots[param] += grad ** 2
-                delta = grad / (self.s_slots[param] + self.eps) ** 0.5
-                delta *= -(self.lr / (1 + (step - 1) * self.lr_decay))
+                self.s_slots[param] = self.s_slots[param] * self.rho + grad ** 2 * (
+                    1 - self.rho
+                )
+                delta = (
+                    grad
+                    * ((self.a_slots[param] + self.eps) ** 0.5)
+                    / (self.s_slots[param] + self.eps) ** 0.5
+                )
+                self.a_slots[param] = self.a_slots[param] * self.rho + delta ** 2 * (
+                    1 - self.rho
+                )
+                delta *= -self.lr
                 assertTensorClose(param.numpy(), ori_params[param] + delta)
 
     cases = [
-        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.01},
-        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.0},  # without lr_decay
-        {
-            "lr": 0.01,
-            "eps": 1e-06,
-            "lr_decay": 0.01,
-            "weight_decay": 0.1,
-        },  # with weight_decay
+        {"lr": 1.0, "eps": 1e-06, "rho": 0.9},
+        {"lr": 1.0, "eps": 1e-06, "rho": 0.9, "weight_decay": 0.9},  # with weight_decay
     ]
     for case in cases:
-        _test_optimizer("Adagrad", case, CheckValue)
-        _test_optimizer("Adagrad", case, CheckValue, update_lr=True)
+        _test_optimizer("Adadelta", case, CheckValue)
+        _test_optimizer("Adadelta", case, CheckValue, update_lr=True)
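For context, a hypothetical end-to-end use of the new optimizer, which the tests above exercise indirectly through _test_optimizer. It assumes the 0.x-era training interface on the Optimizer base class (zero_grad / backward / step), and the Linear module, F.sum reduction, and mge.tensor constructor; none of these are introduced by this commit, so treat this purely as a sketch:

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    from megengine.module import Linear
    from megengine.optimizer import Adadelta

    net = Linear(4, 1)  # illustrative module
    opt = Adadelta(net.parameters(), lr=1.0, rho=0.9, eps=1e-6, weight_decay=0.0)

    data = mge.tensor(np.random.randn(8, 4).astype("float32"))   # tensor creation per the 0.x API (assumed)
    label = mge.tensor(np.random.randn(8, 1).astype("float32"))

    for _ in range(3):
        opt.zero_grad()                          # assumed 0.x Optimizer API
        loss = F.sum((net(data) - label) ** 2)   # F.sum assumed available
        opt.backward(loss)                       # assumed 0.x Optimizer API
        opt.step()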