Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
d6db4fea
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
d6db4fea
编写于
7月 01, 2021
作者:
M
Megvii Engine Team
提交者:
huangxinda
7月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/module): set no_cache=true when loading state dict
GitOrigin-RevId: 83281a3d4756bb257a991454ccc4c3b477a21b4c
上级
fea1bba2
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
40 addition
and
1 deletion
+40
-1
imperative/python/megengine/module/module.py
imperative/python/megengine/module/module.py
+5
-1
imperative/python/test/integration/test_save_load.py
imperative/python/test/integration/test_save_load.py
+32
-0
imperative/src/impl/ops/elemwise.cpp
imperative/src/impl/ops/elemwise.cpp
+3
-0
未找到文件。
imperative/python/megengine/module/module.py
浏览文件 @
d6db4fea
...
...
@@ -600,7 +600,11 @@ class Module(metaclass=ABCMeta):
k
,
var_shape
,
to_be_load_shape
)
)
var
.
_reset
(
type
(
var
)(
to_be_load
,
dtype
=
to_be_load
.
dtype
,
device
=
var
.
device
))
var
.
_reset
(
type
(
var
)(
to_be_load
,
dtype
=
to_be_load
.
dtype
,
device
=
var
.
device
,
no_cache
=
True
)
)
loaded
.
append
(
k
)
return
set
(
loaded
),
set
(
skipped
)
...
...
imperative/python/test/integration/test_save_load.py
浏览文件 @
d6db4fea
...
...
@@ -11,6 +11,7 @@ import numpy as np
import
megengine
as
mge
import
megengine.autodiff
as
ad
import
megengine.module
as
M
import
megengine.optimizer
as
optimizer
from
megengine
import
Parameter
,
tensor
from
megengine.module
import
Module
...
...
@@ -26,6 +27,37 @@ class Simple(Module):
return
x
class
Net
(
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
fc
=
M
.
Linear
(
1
,
1
)
def
forward
(
self
,
images
):
x
=
self
.
fc
(
images
)
loss
=
x
.
mean
()
*
10000
return
loss
def
test_load_state_dict_no_cache
(
monkeypatch
):
with
monkeypatch
.
context
()
as
mk
:
mk
.
setenv
(
"MEGENGINE_INPLACE_UPDATE"
,
"1"
)
net
=
Net
()
optim
=
optimizer
.
SGD
(
net
.
parameters
(),
lr
=
0.1
)
gm
=
ad
.
GradManager
().
attach
(
net
.
parameters
())
state
=
{
"fc.weight"
:
np
.
array
([[
0
]],
dtype
=
np
.
float32
),
"fc.bias"
:
np
.
array
([
0.0
],
dtype
=
np
.
float32
),
}
net
.
load_state_dict
(
state
)
images
=
mge
.
tensor
([[
0
]],
dtype
=
np
.
float32
)
with
gm
:
loss
=
net
(
images
)
gm
.
backward
(
loss
)
optim
.
step
()
optim
.
clear_grad
()
def
test_save_load
():
net
=
Simple
()
...
...
imperative/src/impl/ops/elemwise.cpp
浏览文件 @
d6db4fea
...
...
@@ -224,6 +224,9 @@ cg::OperatorNodeBase* apply_inplace_add_on_var_node(
SmallVector
<
TensorPtr
>
apply_inplace_add_on_physical_tensor
(
const
OpDef
&
def
,
const
SmallVector
<
TensorPtr
>&
inputs
){
mgb_assert
(
inputs
[
0
]
->
blob
().
unique
()
&&
inputs
[
0
]
->
blob
()
->
storage
().
unique
(),
"This inplace modification may change the elements of other tensors. "
"Please set MEGENGINE_INPLACE_UPDATE to 0 to ensure the program runs correctly."
);
auto
dest
=
inputs
[
0
],
delta
=
inputs
[
1
],
alpha
=
inputs
[
2
],
beta
=
inputs
[
3
];
auto
tensor_to_scalar
=
[](
const
TensorPtr
&
tensor
)
->
float
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录