Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
7225b0f0
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
7225b0f0
编写于
6月 10, 2021
作者:
M
Megvii Engine Team
提交者:
huangxinda
7月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/utils): use static infer manager to get value of network.varnode
GitOrigin-RevId: ecc47edab8334e3f41a409020db2a9090db62147
上级
ffe2bb2e
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
39 addition
and
17 deletion
+39
-17
imperative/python/megengine/utils/network_node.py
imperative/python/megengine/utils/network_node.py
+3
-8
imperative/python/test/helpers/utils.py
imperative/python/test/helpers/utils.py
+13
-1
imperative/python/test/unit/core/test_tensor_wrapper.py
imperative/python/test/unit/core/test_tensor_wrapper.py
+16
-5
imperative/python/test/unit/functional/test_tensor.py
imperative/python/test/unit/functional/test_tensor.py
+7
-3
未找到文件。
imperative/python/megengine/utils/network_node.py
浏览文件 @
7225b0f0
...
...
@@ -18,7 +18,6 @@ from ..core._trace_option import use_symbolic_shape
from
..core._wrap
import
Device
from
..core.ops
import
builtin
from
..core.tensor.array_method
import
ArrayMethodMixin
from
..core.tensor.megbrain_graph
import
OutputNode
from
.comp_graph_tools
import
replace_vars
from
.module_stats
import
(
preprocess_receptive_field
,
...
...
@@ -106,9 +105,7 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
return
id
(
self
)
def
numpy
(
self
):
o
=
OutputNode
(
self
.
var
)
self
.
graph
.
compile
(
o
.
outputs
).
execute
()
return
o
.
get_value
().
numpy
()
return
super
().
numpy
()
def
_reset
(
self
,
other
):
if
not
isinstance
(
other
,
VarNode
):
...
...
@@ -141,15 +138,13 @@ class OpNode(NetworkNode):
@
property
def
id
(
self
):
if
self
.
_opr
is
not
None
:
return
self
.
_opr
.
id
return
id
(
self
)
@
property
def
priority
(
self
):
if
self
.
_opr
is
not
None
:
return
self
.
_opr
.
priority
return
0
return
(
self
.
_opr
.
priority
,
self
.
_opr
.
id
)
return
(
0
,
0
)
@
classmethod
def
load
(
cls
,
opr
):
...
...
imperative/python/test/helpers/utils.py
浏览文件 @
7225b0f0
...
...
@@ -5,6 +5,7 @@ import numpy as np
import
megengine.core.tensor.megbrain_graph
as
G
import
megengine.utils.comp_graph_tools
as
cgtools
from
megengine
import
tensor
from
megengine.core.tensor.megbrain_graph
import
OutputNode
from
megengine.jit
import
trace
from
megengine.utils.network_node
import
VarNode
...
...
@@ -12,8 +13,10 @@ from megengine.utils.network_node import VarNode
def
_default_compare_fn
(
x
,
y
):
if
isinstance
(
x
,
np
.
ndarray
):
np
.
testing
.
assert_allclose
(
x
,
y
,
rtol
=
1e-6
)
el
se
:
el
if
isinstance
(
x
,
tensor
)
:
np
.
testing
.
assert_allclose
(
x
.
numpy
(),
y
,
rtol
=
1e-6
)
else
:
np
.
testing
.
assert_allclose
(
get_var_value
(
x
),
y
,
rtol
=
1e-6
)
def
make_tensor
(
x
,
network
=
None
,
device
=
None
):
...
...
@@ -25,6 +28,15 @@ def make_tensor(x, network=None, device=None):
return
tensor
(
x
,
device
=
device
)
def
get_var_value
(
x
):
try
:
o
=
OutputNode
(
x
.
var
)
o
.
graph
.
compile
(
o
.
outputs
).
execute
()
return
o
.
get_value
().
numpy
()
except
RuntimeError
:
raise
ValueError
(
"value invalid!"
)
def
opr_test
(
cases
,
func
,
...
...
imperative/python/test/unit/core/test_tensor_wrapper.py
浏览文件 @
7225b0f0
...
...
@@ -10,7 +10,7 @@ import copy
import
numpy
as
np
import
pytest
from
utils
import
make_tensor
from
utils
import
get_var_value
,
make_tensor
from
megengine.core.tensor.dtype
import
get_scale
,
get_zero_point
,
qint8
,
quint8
from
megengine.tensor
import
Parameter
,
Tensor
...
...
@@ -55,7 +55,12 @@ def test_matmul(is_varnode):
A
=
make_tensor
(
np
.
random
.
rand
(
5
,
7
).
astype
(
"float32"
),
network
)
B
=
make_tensor
(
np
.
random
.
rand
(
7
,
10
).
astype
(
"float32"
),
network
)
C
=
A
@
B
np
.
testing
.
assert_almost_equal
(
C
.
numpy
(),
A
.
numpy
()
@
B
.
numpy
(),
decimal
=
6
)
if
is_varnode
:
np
.
testing
.
assert_almost_equal
(
get_var_value
(
C
),
get_var_value
(
A
)
@
get_var_value
(
B
),
decimal
=
6
)
else
:
np
.
testing
.
assert_almost_equal
(
C
.
numpy
(),
A
.
numpy
()
@
B
.
numpy
(),
decimal
=
6
)
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
...
...
@@ -116,11 +121,17 @@ def test_set_subtensor(is_varnode):
x
=
make_tensor
([
1
,
2
,
3
],
network
)
x
[:]
=
[
1
,
1
,
1
]
np
.
testing
.
assert_almost_equal
(
x
.
numpy
(),
[
1
,
1
,
1
],
decimal
=
6
)
np
.
testing
.
assert_almost_equal
(
get_var_value
(
x
)
if
is_varnode
else
x
.
numpy
(),
[
1
,
1
,
1
],
decimal
=
6
)
x
[[
0
,
2
]]
=
[
3
,
2
]
np
.
testing
.
assert_almost_equal
(
x
.
numpy
(),
[
3
,
1
,
2
],
decimal
=
6
)
np
.
testing
.
assert_almost_equal
(
get_var_value
(
x
)
if
is_varnode
else
x
.
numpy
(),
[
3
,
1
,
2
],
decimal
=
6
)
x
[
1
:
3
]
=
[
4
,
5
]
np
.
testing
.
assert_almost_equal
(
x
.
numpy
(),
[
3
,
4
,
5
],
decimal
=
6
)
np
.
testing
.
assert_almost_equal
(
get_var_value
(
x
)
if
is_varnode
else
x
.
numpy
(),
[
3
,
4
,
5
],
decimal
=
6
)
def
test_computing_with_numpy_array
():
...
...
imperative/python/test/unit/functional/test_tensor.py
浏览文件 @
7225b0f0
...
...
@@ -11,7 +11,7 @@ import platform
import
numpy
as
np
import
pytest
from
utils
import
make_tensor
,
opr_test
from
utils
import
get_var_value
,
make_tensor
,
opr_test
import
megengine.functional
as
F
from
megengine
import
tensor
...
...
@@ -75,8 +75,12 @@ def test_condtake(is_varnode):
xx
=
make_tensor
(
x
,
network
)
yy
=
make_tensor
(
y
,
network
)
val
,
idx
=
F
.
cond_take
(
yy
,
xx
)
np
.
testing
.
assert_equal
(
val
.
numpy
(),
x
[
y
])
np
.
testing
.
assert_equal
(
idx
.
numpy
(),
np
.
where
(
y
.
reshape
(
-
1
))[
0
])
if
is_varnode
:
np
.
testing
.
assert_equal
(
get_var_value
(
val
),
x
[
y
])
np
.
testing
.
assert_equal
(
get_var_value
(
idx
),
np
.
where
(
y
.
reshape
(
-
1
))[
0
])
else
:
np
.
testing
.
assert_equal
(
val
.
numpy
(),
x
[
y
])
np
.
testing
.
assert_equal
(
idx
.
numpy
(),
np
.
where
(
y
.
reshape
(
-
1
))[
0
])
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录