Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
31031804
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
31031804
编写于
8月 12, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/optimizer): only disable convert inputs in build-in optimizers
GitOrigin-RevId: 1a48fe318dc29c80c8f244923d64de79d9abd6b8
上级
fd24dc8e
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
11 addition
and
3 deletion
+11
-3
imperative/python/megengine/optimizer/adadelta.py
imperative/python/megengine/optimizer/adadelta.py
+1
-0
imperative/python/megengine/optimizer/adagrad.py
imperative/python/megengine/optimizer/adagrad.py
+1
-0
imperative/python/megengine/optimizer/adam.py
imperative/python/megengine/optimizer/adam.py
+1
-0
imperative/python/megengine/optimizer/adamw.py
imperative/python/megengine/optimizer/adamw.py
+1
-0
imperative/python/megengine/optimizer/optimizer.py
imperative/python/megengine/optimizer/optimizer.py
+6
-3
imperative/python/megengine/optimizer/sgd.py
imperative/python/megengine/optimizer/sgd.py
+1
-0
未找到文件。
imperative/python/megengine/optimizer/adadelta.py
浏览文件 @
31031804
...
...
@@ -48,6 +48,7 @@ class Adadelta(Optimizer):
defaults
=
dict
(
lr
=
lr
,
rho
=
rho
,
eps
=
eps
,
weight_decay
=
weight_decay
)
super
().
__init__
(
params
,
defaults
)
self
.
_disable_type_convert
=
True
def
_create_state
(
self
,
param_group
):
for
param
in
param_group
[
"params"
]:
...
...
imperative/python/megengine/optimizer/adagrad.py
浏览文件 @
31031804
...
...
@@ -48,6 +48,7 @@ class Adagrad(Optimizer):
defaults
=
dict
(
lr
=
lr
,
lr_decay
=
lr_decay
,
eps
=
eps
,
weight_decay
=
weight_decay
)
super
().
__init__
(
params
,
defaults
)
self
.
_disable_type_convert
=
True
def
_create_state
(
self
,
param_group
):
for
param
in
param_group
[
"params"
]:
...
...
imperative/python/megengine/optimizer/adam.py
浏览文件 @
31031804
...
...
@@ -47,6 +47,7 @@ class Adam(Optimizer):
defaults
=
dict
(
lr
=
lr
,
weight_decay
=
weight_decay
,
betas
=
betas
,
eps
=
eps
)
super
().
__init__
(
params
,
defaults
)
self
.
_disable_type_convert
=
True
def
_create_state
(
self
,
param_group
):
for
param
in
param_group
[
"params"
]:
...
...
imperative/python/megengine/optimizer/adamw.py
浏览文件 @
31031804
...
...
@@ -47,6 +47,7 @@ class AdamW(Optimizer):
defaults
=
dict
(
lr
=
lr
,
weight_decay
=
weight_decay
,
betas
=
betas
,
eps
=
eps
)
super
().
__init__
(
params
,
defaults
)
self
.
_disable_type_convert
=
True
def
_create_state
(
self
,
param_group
):
for
param
in
param_group
[
"params"
]:
...
...
imperative/python/megengine/optimizer/optimizer.py
浏览文件 @
31031804
...
...
@@ -42,6 +42,7 @@ class Optimizer(metaclass=ABCMeta):
):
self
.
_state
=
dict
()
self
.
_defaults
=
defaults
self
.
_disable_type_convert
=
False
if
isinstance
(
params
,
(
Parameter
,
dict
)):
params
=
[
params
]
...
...
@@ -149,6 +150,7 @@ class Optimizer(metaclass=ABCMeta):
# set the globle state `_enable_convert_inputs` to `False` to disable
# the `convert_inputs` for param updates
set_option
(
"record_computing_path"
,
0
)
if
self
.
_disable_type_convert
:
backup
=
set_convert_inputs
(
False
)
for
group
in
self
.
param_groups
:
if
isinstance
(
group
[
"params"
],
set
):
...
...
@@ -160,6 +162,7 @@ class Optimizer(metaclass=ABCMeta):
push_scope
(
"step"
)
self
.
_updates
(
group
)
pop_scope
(
"step"
)
if
self
.
_disable_type_convert
:
# restore the globle state `_enable_convert_inputs`
set_convert_inputs
(
backup
)
set_option
(
"record_computing_path"
,
1
)
...
...
imperative/python/megengine/optimizer/sgd.py
浏览文件 @
31031804
...
...
@@ -43,6 +43,7 @@ class SGD(Optimizer):
defaults
=
dict
(
lr
=
lr
,
momentum
=
momentum
,
weight_decay
=
weight_decay
)
super
().
__init__
(
params
,
defaults
)
self
.
_disable_type_convert
=
True
def
_create_state
(
self
,
param_group
):
if
param_group
[
"momentum"
]
!=
0.0
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录