Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
8e5bf948
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
8e5bf948
编写于
5月 12, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/utils): fix bug of VarNode inplace operations
GitOrigin-RevId: fa9eec7079671a117809c3da8ae7338e12f345f0
上级
13b15fb0
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
94 addition
and
30 deletion
+94
-30
imperative/python/megengine/utils/network_node.py
imperative/python/megengine/utils/network_node.py
+9
-13
imperative/python/test/unit/core/test_tensor_wrapper.py
imperative/python/test/unit/core/test_tensor_wrapper.py
+85
-17
未找到文件。
imperative/python/megengine/utils/network_node.py
浏览文件 @
8e5bf948
...
...
@@ -6,10 +6,9 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
abc
import
json
import
sys
from
typing
import
Callable
,
Sequence
from
typing
import
Sequence
import
numpy
as
np
...
...
@@ -19,10 +18,7 @@ from ..core._trace_option import use_symbolic_shape
from
..core._wrap
import
Device
from
..core.ops
import
builtin
from
..core.tensor.array_method
import
ArrayMethodMixin
from
..core.tensor.indexing
import
getitem
as
_getitem
from
..core.tensor.indexing
import
setitem
as
_setitem
from
..core.tensor.megbrain_graph
import
InputNode
,
OutputNode
from
..tensor
import
Tensor
from
..core.tensor.megbrain_graph
import
OutputNode
from
.comp_graph_tools
import
replace_vars
from
.module_stats
import
(
preprocess_receptive_field
,
...
...
@@ -110,18 +106,18 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
self
.
graph
.
compile
(
o
.
outputs
).
execute
()
return
o
.
get_value
().
numpy
()
def
_
_getitem__
(
self
,
index
):
return
_getitem
(
self
,
index
)
def
__setitem__
(
self
,
index
,
value
):
if
index
is
not
Ellipsis
:
value
=
_setitem
(
self
,
index
,
value
)
def
_
reset
(
self
,
other
):
if
not
isinstance
(
other
,
VarNode
):
assert
self
.
graph
,
"VarNode _reset must have graph"
node
=
ImmutableTensor
(
other
,
graph
=
self
.
graph
)
node
.
compile
(
self
.
graph
)
other
=
node
.
outputs
[
0
]
if
self
.
owner
is
not
None
:
idx
=
self
.
owner
.
outputs
.
index
(
self
)
self
.
owner
.
outputs
[
idx
]
=
VarNode
(
self
.
var
,
owner_opr
=
self
.
owner
,
name
=
self
.
var
.
name
)
self
.
var
=
value
.
var
self
.
var
=
other
.
var
self
.
owner
=
None
def
set_owner_opr
(
self
,
owner_opr
):
...
...
imperative/python/test/unit/core/test_tensor_wrapper.py
浏览文件 @
8e5bf948
...
...
@@ -9,38 +9,81 @@
import
copy
import
numpy
as
np
import
pytest
from
utils
import
make_tensor
from
megengine.core.tensor.dtype
import
get_scale
,
get_zero_point
,
qint8
,
quint8
from
megengine.tensor
import
Tensor
from
megengine.utils.network
import
Network
def
test_basic
():
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
def
test_basic
(
is_varnode
):
if
is_varnode
:
network
=
Network
()
else
:
network
=
None
x_np
=
np
.
random
.
rand
(
10
).
astype
(
"float32"
)
x
=
Tensor
(
x_np
)
x
=
make_tensor
(
x_np
,
network
)
y
=
x
*
x
y_np
=
y
.
numpy
()
np
.
testing
.
assert_almost_equal
(
y_np
,
x_np
*
x_np
)
def
test_literal_arith
():
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
def
test_literal_arith
(
is_varnode
):
if
is_varnode
:
network
=
Network
()
else
:
network
=
None
x_np
=
np
.
random
.
rand
(
10
).
astype
(
"float32"
)
x
=
Tensor
(
x_np
)
x
=
make_tensor
(
x_np
,
network
)
y
=
x
*
2
y_np
=
y
.
numpy
()
np
.
testing
.
assert_almost_equal
(
y_np
,
x_np
*
2
)
def
test_matmul
():
A
=
Tensor
(
np
.
random
.
rand
(
5
,
7
).
astype
(
"float32"
))
B
=
Tensor
(
np
.
random
.
rand
(
7
,
10
).
astype
(
"float32"
))
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
def
test_matmul
(
is_varnode
):
if
is_varnode
:
network
=
Network
()
else
:
network
=
None
A
=
make_tensor
(
np
.
random
.
rand
(
5
,
7
).
astype
(
"float32"
),
network
)
B
=
make_tensor
(
np
.
random
.
rand
(
7
,
10
).
astype
(
"float32"
),
network
)
C
=
A
@
B
np
.
testing
.
assert_almost_equal
(
C
.
numpy
(),
A
.
numpy
()
@
B
.
numpy
(),
decimal
=
6
)
def
test_reduce
():
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
def
test_inplace_add
(
is_varnode
):
if
is_varnode
:
network
=
Network
()
else
:
network
=
None
x_np
=
np
.
random
.
rand
(
10
).
astype
(
"float32"
)
y_np
=
np
.
random
.
rand
(
10
).
astype
(
"float32"
)
x
=
make_tensor
(
x_np
,
network
)
y
=
make_tensor
(
y_np
,
network
)
y
+=
x
out_np
=
y
.
numpy
()
np
.
testing
.
assert_almost_equal
(
out_np
,
x_np
+
y_np
)
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
def
test_reduce
(
is_varnode
):
if
is_varnode
:
network
=
Network
()
else
:
network
=
None
def
test_x
(
x_np
):
for
m
in
[
"sum"
,
"prod"
,
"min"
,
"max"
,
"mean"
]:
x
=
Tensor
(
x_np
)
x
=
make_tensor
(
x_np
,
network
)
y
=
getattr
(
x
,
m
)(
axis
=-
1
,
keepdims
=
True
)
np
.
testing
.
assert_almost_equal
(
y
.
numpy
(),
getattr
(
x_np
,
m
)(
-
1
),
decimal
=
6
)
...
...
@@ -50,16 +93,28 @@ def test_reduce():
test_x
(
np
.
array
([
True
,
False
,
True
]))
def
test_set_value
():
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
def
test_set_value
(
is_varnode
):
if
is_varnode
:
network
=
Network
()
else
:
network
=
None
v0
=
np
.
random
.
random
((
2
,
3
)).
astype
(
np
.
float32
)
param
=
Tensor
(
v0
)
param
=
make_tensor
(
v0
,
network
)
v1
=
np
.
random
.
random
((
2
,
3
)).
astype
(
np
.
float32
)
param
[...]
=
v1
np
.
testing
.
assert_allclose
(
param
.
numpy
(),
v1
,
atol
=
5e-6
)
def
test_set_subtensor
():
x
=
Tensor
([
1
,
2
,
3
])
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
def
test_set_subtensor
(
is_varnode
):
if
is_varnode
:
network
=
Network
()
else
:
network
=
None
x
=
make_tensor
([
1
,
2
,
3
],
network
)
x
[:]
=
[
1
,
1
,
1
]
np
.
testing
.
assert_almost_equal
(
x
.
numpy
(),
[
1
,
1
,
1
],
decimal
=
6
)
x
[[
0
,
2
]]
=
[
3
,
2
]
...
...
@@ -78,14 +133,27 @@ def test_computing_with_numpy_array():
np
.
testing
.
assert_equal
(
np
.
equal
(
xx
,
xx
).
numpy
(),
np
.
equal
(
x
,
x
))
def
test_transpose
():
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
def
test_transpose
(
is_varnode
):
if
is_varnode
:
network
=
Network
()
else
:
network
=
None
x
=
np
.
random
.
rand
(
2
,
5
).
astype
(
"float32"
)
xx
=
Tensor
(
x
)
xx
=
make_tensor
(
x
,
network
)
np
.
testing
.
assert_almost_equal
(
xx
.
T
.
numpy
(),
x
.
T
)
def
test_as_type
():
x
=
Tensor
([
1
,
2
,
3
],
dtype
=
np
.
float32
)
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
def
test_as_type
(
is_varnode
):
if
is_varnode
:
network
=
Network
()
else
:
network
=
None
x_np
=
np
.
array
([
1
,
2
,
3
],
dtype
=
np
.
float32
)
x
=
make_tensor
(
x_np
,
network
)
y
=
x
.
astype
(
qint8
(
0.1
))
np
.
testing
.
assert_almost_equal
(
get_scale
(
y
.
dtype
),
0.1
)
z
=
y
.
astype
(
qint8
(
0.2
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录