Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
f3863810
MegEngine
项目概览
MegEngine 天元
/
MegEngine
接近 2 年 前同步成功
通知
414
Star
4708
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
f3863810
编写于
7月 01, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(imperative): fix inplace operation of optim
GitOrigin-RevId: 2aaa71eb66c1096d117ed70d2cadae3f85e32ab6
上级
9330929f
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
43 addition
and
1 deletion
+43
-1
imperative/python/megengine/optimizer/optimizer.py
imperative/python/megengine/optimizer/optimizer.py
+1
-1
imperative/python/test/unit/amp/test_convert_format.py
imperative/python/test/unit/amp/test_convert_format.py
+42
-0
未找到文件。
imperative/python/megengine/optimizer/optimizer.py
浏览文件 @
f3863810
...
@@ -97,7 +97,7 @@ class Optimizer(metaclass=ABCMeta):
...
@@ -97,7 +97,7 @@ class Optimizer(metaclass=ABCMeta):
"optimizer can only optimize Parameters, but one of the params is "
"optimizer can only optimize Parameters, but one of the params is "
+
str
(
type
(
param
))
+
str
(
type
(
param
))
)
)
param
[...]
=
Tensor
(
param
.
numpy
()
,
no_cache
=
True
)
param
[...]
=
Tensor
(
param
,
no_cache
=
True
)
for
name
,
default
in
self
.
_defaults
.
items
():
for
name
,
default
in
self
.
_defaults
.
items
():
if
default
is
required
and
name
not
in
param_group
:
if
default
is
required
and
name
not
in
param_group
:
...
...
imperative/python/test/unit/amp/test_convert_format.py
浏览文件 @
f3863810
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
megengine
as
mge
import
megengine.autodiff
as
autodiff
import
megengine.functional
as
F
import
megengine.functional
as
F
import
megengine.module
as
M
import
megengine.module
as
M
import
megengine.optimizer
as
optim
from
megengine
import
Parameter
,
Tensor
,
amp
from
megengine
import
Parameter
,
Tensor
,
amp
from
megengine.core._config
import
set_auto_format_convert
from
megengine.core._config
import
set_auto_format_convert
from
megengine.core._trace_option
import
use_symbolic_shape
from
megengine.core._trace_option
import
use_symbolic_shape
...
@@ -57,3 +60,42 @@ def test_convert_module(is_inplace):
...
@@ -57,3 +60,42 @@ def test_convert_module(is_inplace):
)
)
else
:
else
:
assert
param
.
shape
==
expected_shape
[
name
],
name
assert
param
.
shape
==
expected_shape
[
name
],
name
class
Module
(
M
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
M
.
Conv2d
(
3
,
16
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
)
self
.
bn
=
M
.
BatchNorm2d
(
16
)
def
forward
(
self
,
x
):
out
=
F
.
relu
(
self
.
bn
(
self
.
conv
(
x
)))
return
out
def
test_format_remained
():
m
=
Module
()
m
=
amp
.
convert_module_format
(
m
)
gm
=
autodiff
.
GradManager
().
attach
(
m
.
parameters
())
opt
=
optim
.
SGD
(
m
.
parameters
(),
lr
=
0.01
)
scaler
=
amp
.
GradScaler
()
image
=
mge
.
tensor
(
np
.
random
.
normal
(
size
=
(
1
,
3
,
224
,
224
)),
dtype
=
"float32"
)
label
=
mge
.
tensor
(
np
.
ones
((
1
,
224
,
224
)),
dtype
=
"int32"
)
image
=
amp
.
convert_tensor_format
(
image
)
@
amp
.
autocast
(
enabled
=
True
)
def
train_step
(
image
):
with
gm
:
logits
=
m
(
image
)
loss
=
F
.
nn
.
cross_entropy
(
logits
,
label
)
scaler
.
backward
(
gm
,
loss
)
opt
.
step
().
clear_grad
()
return
logits
for
_
in
range
(
5
):
res
=
train_step
(
image
)
assert
res
.
format
==
"nhwc"
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录