Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
b3b14fdf
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
b3b14fdf
编写于
5月 13, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/jit): fix add_update semantic
GitOrigin-RevId: f541ac7c6d2dcef2f31c9d623ec92ef3a567f4db
上级
2bbce2f9
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
53 addition
and
5 deletion
+53
-5
python_module/megengine/core/tensor.py
python_module/megengine/core/tensor.py
+33
-4
python_module/megengine/functional/graph.py
python_module/megengine/functional/graph.py
+4
-1
python_module/megengine/jit/__init__.py
python_module/megengine/jit/__init__.py
+2
-0
python_module/test/unit/jit/test_jit.py
python_module/test/unit/jit/test_jit.py
+14
-0
未找到文件。
python_module/megengine/core/tensor.py
浏览文件 @
b3b14fdf
...
...
@@ -9,6 +9,7 @@
import
collections
import
functools
import
itertools
import
weakref
from
typing
import
Union
import
numpy
as
np
...
...
@@ -100,6 +101,14 @@ class MGBIndexWrapper:
)(
wrap_idx
(
idx
))
class
Guard
:
def
__init__
(
self
,
deleter
):
self
.
deleter
=
deleter
def
__del__
(
self
):
self
.
deleter
()
class
Tensor
:
r
"""The main data container in MegEngine.
Use :func:`~.tensor` to create a Tensor with existed data.
...
...
@@ -111,6 +120,7 @@ class Tensor:
self
.
_reset
(
val
,
requires_grad
=
requires_grad
)
def
_reset
(
self
,
val
=
None
,
*
,
requires_grad
=
None
):
self
.
__sym_override
=
None
if
val
is
None
:
self
.
__val
=
None
self
.
__sym
=
None
...
...
@@ -154,17 +164,20 @@ class Tensor:
return
self
.
numpy
().
item
()
def
_attach
(
self
,
comp_graph
,
*
,
volatile
=
True
):
sym
=
self
.
__sym_override
or
self
.
__sym
if
sym
:
if
sym
.
owner_graph
!=
comp_graph
:
raise
RuntimeError
(
"internal error"
)
return
sym
if
self
.
__val
:
return
self
.
__val
.
symvar
(
comp_graph
,
volatile
=
volatile
)
if
self
.
__sym
:
if
self
.
__sym
.
owner_graph
!=
comp_graph
:
raise
RuntimeError
(
"internal error"
)
return
self
.
__sym
else
:
raise
ValueError
(
"uninitialized"
)
@
property
def
_symvar
(
self
):
if
self
.
__sym_override
:
return
self
.
__sym_override
if
self
.
__sym
:
assert
not
self
.
__val
return
self
.
__sym
...
...
@@ -174,10 +187,26 @@ class Tensor:
return
self
.
_attach
(
get_default_graph
())
def
__mgb_symvar__
(
self
,
comp_graph
=
None
,
**
_
):
if
self
.
__sym_override
:
return
self
.
__sym_override
if
self
.
__val
and
comp_graph
:
return
self
.
_attach
(
comp_graph
)
return
self
.
_symvar
# read by mgb.opr
def
_override_symvar_during_trace
(
self
,
trace
,
symvar
):
assert
self
.
__val
and
not
self
.
__sym
assert
trace
is
type
(
trace
).
_active_instance
deleters
=
trace
.
_user_cache
.
setdefault
(
Tensor
,
set
())
self_ref
=
weakref
.
ref
(
self
)
def
restore
():
self
=
self_ref
()
if
self
is
not
None
:
self
.
__sym_override
=
None
deleters
.
add
(
Guard
(
restore
))
self
.
__sym_override
=
symvar
@
property
def
dtype
(
self
):
r
"""Return the data type of the tensor.
...
...
python_module/megengine/functional/graph.py
浏览文件 @
b3b14fdf
...
...
@@ -13,7 +13,7 @@ import megengine._internal as mgb
from
..core.graph
import
get_default_graph
from
..core.tensor
import
Tensor
,
wrap_io_tensor
from
..jit
import
barrier
,
mark_impure
from
..jit
import
barrier
,
mark_impure
,
trace
@
wrap_io_tensor
...
...
@@ -112,6 +112,9 @@ def add_update(
)
mark_impure
(
u
)
if
trace
.
_active_instance
:
dest
.
_override_symvar_during_trace
(
trace
.
_active_instance
,
u
)
return
Tensor
(
u
)
...
...
python_module/megengine/jit/__init__.py
浏览文件 @
b3b14fdf
...
...
@@ -367,10 +367,12 @@ class trace:
raise
RuntimeError
(
"nested trace is unsupported"
)
self
.
_status
=
self
.
_STARTED
type
(
self
).
_active_instance
=
self
self
.
_user_cache
=
{}
try
:
yield
finally
:
self
.
_status
=
self
.
_FINISHED
self
.
_user_cache
=
None
type
(
self
).
_active_instance
=
None
def
_run_wrapped
(
self
):
...
...
python_module/test/unit/jit/test_jit.py
浏览文件 @
b3b14fdf
...
...
@@ -16,6 +16,7 @@ import pytest
import
megengine
as
mge
import
megengine._internal
as
mgb
import
megengine.module
as
M
from
megengine
import
functional
as
F
from
megengine
import
jit
,
tensor
from
megengine.core.tensor
import
Tensor
from
megengine.jit
import
SublinearMemoryConfig
...
...
@@ -57,6 +58,19 @@ def test_symbolic():
f
.
trace
(
0
)
def
test_add_update_semantic
():
for
symbolic
in
[
False
,
True
]:
x
=
tensor
(
0
)
@
jit
.
trace
(
symbolic
=
symbolic
)
def
f
():
F
.
add_update
(
x
,
1
)
return
x
+
1
np
.
testing
.
assert_equal
(
f
().
numpy
(),
[
2
])
np
.
testing
.
assert_equal
(
f
().
numpy
(),
[
3
])
def
test_dump
():
@
jit
.
trace
(
symbolic
=
True
)
def
f
(
x
,
y
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录