Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
65432d3b
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
65432d3b
编写于
3月 25, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/module): fix torch subgraph under jit.trace with symbolic=False
GitOrigin-RevId: a208ba79d964baf78bdd9d10264dcb9166bb8506
上级
862de28a
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
70 addition
and
0 deletion
+70
-0
python_module/megengine/module/pytorch/pytorch.py
python_module/megengine/module/pytorch/pytorch.py
+2
-0
python_module/test/unit/module/test_pytorch.py
python_module/test/unit/module/test_pytorch.py
+68
-0
未找到文件。
python_module/megengine/module/pytorch/pytorch.py
浏览文件 @
65432d3b
...
...
@@ -305,6 +305,8 @@ class PyTorchSubgraphImplOpr(mgb.craniotome.CraniotomeBase):
ret
.
__dict__
[
"_last_forward_inputs"
]
=
d0
.
pop
(
"_last_forward_inputs"
)
ret
.
__dict__
[
"_last_forward_outputs"
]
=
d0
.
pop
(
"_last_forward_outputs"
)
ret
.
__dict__
[
"_last_forward_params"
]
=
d0
.
pop
(
"_last_forward_params"
)
ret
.
__dict__
[
"_func"
]
=
d0
.
pop
(
"_func"
)
d0
.
pop
(
"_grad_opr"
)
later_copy
=
self
.
_grad_opr
in
_copy_dict
...
...
python_module/test/unit/module/test_pytorch.py
浏览文件 @
65432d3b
...
...
@@ -13,8 +13,11 @@ from helpers import randomTorch
import
megengine
as
mge
import
megengine._internal
as
mgb
import
megengine.functional
import
megengine.optimizer
as
optimizer
from
megengine
import
get_default_device
,
set_default_device
from
megengine.core
import
Parameter
,
tensor
from
megengine.jit
import
trace
from
megengine.module
import
Module
as
MGEModule
from
megengine.module.pytorch
import
PyTorchModule
from
megengine.test
import
assertTensorClose
...
...
@@ -72,3 +75,68 @@ def test_pytorch_backward():
return
mge
.
functional
.
grad
(
mge_e
,
mge_a
,
use_virtual_grad
=
False
)
assertTensorClose
(
get_pytorch_backward
().
numpy
(),
get_mge_backward
().
numpy
())
def
test_pytorch_mixed
():
init_param
=
(
np
.
array
([
2.0
],
dtype
=
np
.
float32
),
np
.
array
([
3.0
],
dtype
=
np
.
float32
))
lr
=
1.0
class
Mixed
(
MGEModule
):
class
SubModule
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
multiplier
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
init_param
[
0
]))
def
forward
(
self
,
inp
):
return
inp
*
self
.
multiplier
def
__init__
(
self
):
super
().
__init__
()
self
.
torch_module
=
PyTorchModule
(
self
.
SubModule
())
a
=
list
(
self
.
SubModule
().
named_parameters
(
recurse
=
True
))
a
=
list
(
self
.
SubModule
().
parameters
())
self
.
multiplier
=
Parameter
(
np
.
array
(
init_param
[
1
]),
dtype
=
np
.
float32
)
def
forward
(
self
,
inp
):
return
self
.
torch_module
(
inp
)
*
self
.
multiplier
def
run
(
step
,
enable_trace
,
use_symbolic
):
def
train_func
(
data
,
net
=
None
,
opt
=
None
):
pred
=
net
(
data
)
opt
.
backward
(
pred
)
return
pred
if
enable_trace
:
train_func
=
trace
(
train_func
,
symbolic
=
use_symbolic
)
net
=
Mixed
()
data
=
tensor
()
opt
=
optimizer
.
SGD
(
net
.
parameters
(),
lr
=
lr
)
saved_param
=
init_param
for
i
in
range
(
step
):
opt
.
zero_grad
()
data
.
set_value
([
i
+
1.0
])
output
=
train_func
(
data
,
net
=
net
,
opt
=
opt
)
opt
.
step
()
expect_param
=
(
saved_param
[
0
]
-
lr
*
saved_param
[
1
]
*
data
.
numpy
(),
saved_param
[
1
]
-
lr
*
saved_param
[
0
]
*
data
.
numpy
(),
)
assertTensorClose
(
output
.
numpy
(),
saved_param
[
0
]
*
saved_param
[
1
]
*
data
.
numpy
()
)
torch_param
=
net
.
torch_module
.
_torch_params
[
0
].
detach
().
cpu
()
assertTensorClose
(
torch_param
.
numpy
(),
expect_param
[
0
])
assertTensorClose
(
net
.
multiplier
.
numpy
(),
expect_param
[
1
])
saved_param
=
expect_param
run
(
1
,
False
,
False
)
run
(
1
,
True
,
True
)
run
(
1
,
True
,
False
)
run
(
2
,
False
,
False
)
run
(
2
,
True
,
True
)
run
(
2
,
True
,
False
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录