Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
1fed5929
MegEngine
项目概览
MegEngine 天元
/
MegEngine
9 个月 前同步成功
通知
392
Star
4702
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
1fed5929
编写于
10月 28, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(mge/optimizer): close convert_inputs for optimizer step
GitOrigin-RevId: c710530d934e1be29e611322b75837c9b72a610c
上级
1f75c7ad
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
94 addition
and
29 deletion
+94
-29
imperative/python/megengine/core/tensor/utils.py
imperative/python/megengine/core/tensor/utils.py
+22
-0
imperative/python/megengine/optimizer/adadelta.py
imperative/python/megengine/optimizer/adadelta.py
+19
-9
imperative/python/megengine/optimizer/adagrad.py
imperative/python/megengine/optimizer/adagrad.py
+17
-7
imperative/python/megengine/optimizer/adam.py
imperative/python/megengine/optimizer/adam.py
+18
-8
imperative/python/megengine/optimizer/optimizer.py
imperative/python/megengine/optimizer/optimizer.py
+6
-0
imperative/python/megengine/optimizer/sgd.py
imperative/python/megengine/optimizer/sgd.py
+12
-5
未找到文件。
imperative/python/megengine/core/tensor/utils.py
浏览文件 @
1fed5929
...
...
@@ -16,6 +16,25 @@ from ..ops.special import Const
from
..tensor.core
import
OpBase
,
TensorBase
,
TensorWrapperBase
,
apply
from
.dtype
import
is_equal
,
is_quantize
# Global switch controlling whether `convert_inputs` performs dtype/device
# conversion. Temporarily set to False during optimizer param updates to
# reduce per-operator overhead (see `set_convert_inputs`).
_enable_convert_inputs = True


def get_convert_inputs():
    """Return the current state of ``_enable_convert_inputs``."""
    return _enable_convert_inputs
def set_convert_inputs(flag):
    """ This function is a temporary workaround for reducing the overhead of operator
    invocations. The function `convert_inputs` is disabled if the global state
    `_enable_convert_inputs` is set to `False`, otherwise enabled. This function is for
    internal use only, and should be removed when the tensor-like system is refactored.
    """
    global _enable_convert_inputs
    # Remember the previous state so the caller can restore it later.
    previous = _enable_convert_inputs
    _enable_convert_inputs = flag
    return previous
def
dtype_promotion
(
inputs
):
"""
...
...
@@ -129,6 +148,9 @@ def convert_single_value(v, inputs, *, dtype=None, device=None):
def
convert_inputs
(
*
args
:
TensorBase
):
if
not
_enable_convert_inputs
:
return
args
dtype
=
dtype_promotion
(
args
)
device
=
get_device
(
args
)
...
...
imperative/python/megengine/optimizer/adadelta.py
浏览文件 @
1fed5929
...
...
@@ -10,8 +10,8 @@ from typing import Iterable, Union
import
numpy
as
np
from
..
functional
import
sqrt
from
..tensor
import
Parameter
from
..
core.tensor.tensor
import
Tensor
from
..tensor
import
Parameter
,
tensor
from
.optimizer
import
Optimizer
...
...
@@ -62,6 +62,16 @@ class Adadelta(Optimizer):
rho
=
param_group
[
"rho"
]
eps
=
param_group
[
"eps"
]
# since `convert_inputs` is disabled for param updates,
# scalar should be explicitly transferred to tensor
_lr
=
tensor
([
lr
])
_weight_decay
=
tensor
([
weight_decay
])
_rho
=
tensor
([
rho
])
_eps
=
tensor
([
eps
])
c05
=
tensor
([
0.5
])
c1
=
tensor
([
1.0
])
c2
=
tensor
([
2.0
])
for
param
in
param_group
[
"params"
]:
if
param
.
grad
is
None
:
...
...
@@ -69,17 +79,17 @@ class Adadelta(Optimizer):
states
=
self
.
_state
[
param
]
step
=
states
[
"step"
]
step
+=
1.0
step
+=
c1
grad
=
param
.
grad
if
weight_decay
!=
0.0
:
grad
+=
param
*
weight_decay
grad
+=
param
*
_
weight_decay
square_avg
=
states
[
"square_avg"
]
acc_delta
=
states
[
"acc_delta"
]
square_avg
=
rho
*
square_avg
+
(
1
-
rho
)
*
grad
**
2
std
=
sqrt
(
square_avg
+
eps
)
delta
=
sqrt
(
acc_delta
+
eps
)
/
std
*
grad
param
-=
lr
*
delta
acc_delta
=
rho
*
acc_delta
+
(
1
-
rho
)
*
delta
**
2
square_avg
=
_rho
*
square_avg
+
(
c1
-
_rho
)
*
grad
**
c
2
std
=
(
square_avg
+
_eps
)
**
c05
delta
=
(
acc_delta
+
_eps
)
**
c05
/
std
*
grad
param
-=
_
lr
*
delta
acc_delta
=
_rho
*
acc_delta
+
(
c1
-
_rho
)
*
delta
**
c
2
states
[
"square_avg"
].
_reset
(
square_avg
)
states
[
"acc_delta"
].
_reset
(
acc_delta
)
imperative/python/megengine/optimizer/adagrad.py
浏览文件 @
1fed5929
...
...
@@ -10,8 +10,8 @@ from typing import Iterable, Union
import
numpy
as
np
from
..
functional
import
sqrt
from
..tensor
import
Parameter
from
..
core.tensor.tensor
import
Tensor
from
..tensor
import
Parameter
,
tensor
from
.optimizer
import
Optimizer
...
...
@@ -61,6 +61,16 @@ class Adagrad(Optimizer):
weight_decay
=
param_group
[
"weight_decay"
]
eps
=
param_group
[
"eps"
]
# since `convert_inputs` is disabled for param updates,
# scalar should be explicitly transferred to tensor
_lr
=
tensor
([
lr
])
_lr_decay
=
tensor
([
lr_decay
])
_weight_decay
=
tensor
([
weight_decay
])
_eps
=
tensor
([
eps
])
c05
=
tensor
([
0.5
])
c1
=
tensor
([
1.0
])
c2
=
tensor
([
2.0
])
for
param
in
param_group
[
"params"
]:
if
param
.
grad
is
None
:
...
...
@@ -68,14 +78,14 @@ class Adagrad(Optimizer):
states
=
self
.
_state
[
param
]
step
=
states
[
"step"
]
step
+=
1.0
step
+=
c1
grad
=
param
.
grad
if
weight_decay
!=
0.0
:
grad
+=
param
*
weight_decay
grad
+=
param
*
_
weight_decay
square_avg
=
states
[
"square_avg"
]
square_avg
+=
grad
**
2
delta
=
grad
/
sqrt
(
square_avg
+
eps
)
clr
=
lr
/
(
1
+
(
step
-
1
)
*
lr_decay
)
square_avg
+=
grad
**
c
2
delta
=
grad
/
(
square_avg
+
_eps
)
**
c05
clr
=
_lr
/
(
c1
+
(
step
-
c1
)
*
_
lr_decay
)
param
-=
clr
*
delta
imperative/python/megengine/optimizer/adam.py
浏览文件 @
1fed5929
...
...
@@ -8,7 +8,8 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
typing
import
Iterable
,
Tuple
,
Union
from
..tensor
import
Parameter
from
..core.tensor.tensor
import
Tensor
from
..tensor
import
Parameter
,
tensor
from
.optimizer
import
Optimizer
...
...
@@ -58,6 +59,15 @@ class Adam(Optimizer):
eps
=
param_group
[
"eps"
]
beta0
,
beta1
=
param_group
[
"betas"
]
# since `convert_inputs` is disabled for param updates,
# scalar should be explicitly transferred to tensor
_lr
=
tensor
([
lr
])
_weight_decay
=
tensor
([
weight_decay
])
_eps
=
tensor
([
eps
])
_beta0
,
_beta1
=
tensor
([
beta0
]),
tensor
([
beta1
])
c1
=
tensor
([
1.0
])
c05
=
tensor
([
0.5
])
for
param
in
param_group
[
"params"
]:
if
param
.
grad
is
None
:
...
...
@@ -65,20 +75,20 @@ class Adam(Optimizer):
grad
=
param
.
grad
if
weight_decay
!=
0.0
:
grad
+=
param
*
weight_decay
grad
+=
param
*
_
weight_decay
states
=
self
.
_state
[
param
]
step
=
states
[
"step"
]
step
+=
1.0
step
+=
c1
exp_avg
=
states
[
"exp_avg"
]
exp_avg_sq
=
states
[
"exp_avg_sq"
]
exp_avg
=
beta0
*
exp_avg
+
grad
*
(
1
-
beta0
)
exp_avg_sq
=
beta1
*
exp_avg_sq
+
(
1
-
beta1
)
*
(
grad
*
grad
)
exp_avg
=
_beta0
*
exp_avg
+
grad
*
(
c1
-
_
beta0
)
exp_avg_sq
=
_beta1
*
exp_avg_sq
+
(
c1
-
_
beta1
)
*
(
grad
*
grad
)
delta
=
(
exp_avg
/
(
1
-
beta0
**
step
))
/
(
(
exp_avg_sq
/
(
1
-
beta1
**
step
))
**
0.5
+
eps
delta
=
(
exp_avg
/
(
c1
-
_
beta0
**
step
))
/
(
(
exp_avg_sq
/
(
c1
-
_beta1
**
step
))
**
c05
+
_
eps
)
param
-=
lr
*
delta
param
-=
_
lr
*
delta
# not inplace change, need to update underlying tensor handler in state
states
[
"exp_avg"
].
_reset
(
exp_avg
)
...
...
imperative/python/megengine/optimizer/optimizer.py
浏览文件 @
1fed5929
...
...
@@ -15,6 +15,7 @@ from typing import Union
import
numpy
as
np
from
..core.tensor.utils
import
set_convert_inputs
from
..tensor
import
Parameter
,
Tensor
from
..utils.deprecation
import
deprecated
...
...
@@ -143,6 +144,9 @@ class Optimizer(metaclass=ABCMeta):
Performs a single optimization step.
"""
# set the global state `_enable_convert_inputs` to `False` to disable
# the `convert_inputs` for param updates
backup
=
set_convert_inputs
(
False
)
for
group
in
self
.
param_groups
:
if
isinstance
(
group
[
"params"
],
set
):
raise
TypeError
(
...
...
@@ -151,6 +155,8 @@ class Optimizer(metaclass=ABCMeta):
"Please use a list instead."
)
self
.
_updates
(
group
)
# restore the global state `_enable_convert_inputs`
set_convert_inputs
(
backup
)
return
self
@
deprecated
(
version
=
"1.0"
,
reason
=
"use clear_grad instead"
)
...
...
imperative/python/megengine/optimizer/sgd.py
浏览文件 @
1fed5929
...
...
@@ -8,7 +8,8 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
typing
import
Iterable
,
Union
from
..tensor
import
Parameter
from
..core.tensor.tensor
import
Tensor
from
..tensor
import
Parameter
,
tensor
from
.optimizer
import
Optimizer
...
...
@@ -52,18 +53,24 @@ class SGD(Optimizer):
weight_decay
=
param_group
[
"weight_decay"
]
momentum
=
param_group
[
"momentum"
]
# since `convert_inputs` is disabled for param updates,
# scalar should be explicitly transferred to tensor
_lr
=
tensor
([
lr
])
_weight_decay
=
tensor
([
weight_decay
])
_momentum
=
tensor
([
momentum
])
for
param
in
param_group
[
"params"
]:
if
param
.
grad
is
None
:
continue
grad
=
param
.
grad
if
weight_decay
!=
0.0
:
grad
+=
param
*
weight_decay
grad
+=
param
*
_
weight_decay
if
momentum
:
v
=
self
.
_state
[
param
][
"momentum_buffer"
]
v
=
momentum
*
v
+
grad
param
-=
lr
*
v
v
=
_
momentum
*
v
+
grad
param
-=
_
lr
*
v
self
.
_state
[
param
][
"momentum_buffer"
].
_reset
(
v
)
else
:
param
-=
lr
*
grad
param
-=
_
lr
*
grad
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录