MegEngine 天元 / MegEngine

Commit 69728969
Authored on May 06, 2021 by Megvii Engine Team

fix(mge/utils): fix network multiple outputs issue

GitOrigin-RevId: d22e639cd3cec9dfe09d1452b9c3c352862be911
Parent: f36e99d3
Changes: 6 changed files with 176 additions and 27 deletions (+176 / -27)
imperative/python/megengine/core/tensor/utils.py        +1   -1
imperative/python/megengine/functional/tensor.py        +3   -3
imperative/python/megengine/utils/network.py            +33  -4
imperative/python/megengine/utils/network_node.py       +29  -6
imperative/python/test/unit/functional/test_tensor.py   +30  -3
imperative/python/test/unit/utils/test_network.py       +80  -10
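Taken together, these changes let a loaded Network reliably replace the outputs of a model that was dumped with more than one output. The following is a minimal sketch of that workflow, closely mirroring the test_add_remove_output test added below; the tensor values and the output name "new_o" are illustrative only.

import io

import megengine.functional as F
from megengine.jit.tracing import trace
from megengine.tensor import Tensor
from megengine.utils.comp_graph_tools import GraphInference
from megengine.utils.network import Network as Net

a, b = Tensor([1.0, 2.0]), Tensor([3.0, 4.0])

@trace(symbolic=True, capture_as_const=True)
def fwd(a, b):
    return (a + b) * 2, (a - b)  # a graph with two outputs

fwd(a, b)
model = io.BytesIO()
fwd.dump(model, arg_names=["a", "b"], output_names=["o1", "o2"],
         optimize_for_inference=False)
model.seek(0)

net = Net.load(model)
net.remove_output(*net.output_vars)          # drop the original outputs
var_a = net.var_filter.name("a").as_unique()
var_b = net.var_filter.name("b").as_unique()
y = F.sigmoid(var_a + var_b)
y.name = "new_o"
net.add_output(y)                            # register the replacement output

redumped = io.BytesIO()
net.dump(redumped)
redumped.seek(0)
out = GraphInference(redumped).run(a.numpy(), b.numpy())
print(out["new_o"])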
imperative/python/megengine/core/tensor/utils.py

@@ -140,7 +140,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
     else:
         if ndim != 0 and ndim != 1:
             raise ValueError("ndim != 1 or 0, get : %d" % ndim)
-    if not isinstance(x, Tensor):
+    if not isinstance(x, (Tensor, SymbolVar)):
         (x,) = Const(x, dtype=dtype, device=device)(*reference)
     return x
imperative/python/megengine/functional/tensor.py

@@ -334,7 +334,7 @@ def split(inp, nsplits_or_sections, axis=0):
         x = tensor(np.random.random((10, 20)), dtype=np.float32)
         y = F.split(x, 3)
         z = F.split(x, [6, 17], axis=1)

         print([i.numpy().shape for i in y])
         print([i.numpy().shape for i in z])

@@ -686,9 +686,9 @@ def cond_take(mask: Tensor, x: Tensor) -> Tensor:
         [1. 4.] [0 3]

     """
-    if not isinstance(x, Tensor):
+    if not isinstance(x, (Tensor, SymbolVar)):
         raise TypeError("input must be a tensor")
-    if not isinstance(mask, Tensor):
+    if not isinstance(mask, (Tensor, SymbolVar)):
         raise TypeError("mask must be a tensor")
     if mask.dtype != np.bool_:
         raise ValueError("mask must be bool")
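Both hunks relax the isinstance(..., Tensor) checks to also accept SymbolVar, so functional ops such as cond_take can be applied to VarNodes of a loaded Network as well as to imperative tensors. A minimal sketch, assuming orig_model is an io.BytesIO produced by trace(...).dump(...) as in the tests below:

import megengine.functional as F
from megengine.utils.network import Network as Net

net = Net.load(orig_model)                 # orig_model: a dumped model (assumption)
var_a = net.input_vars[0]
val, idx = F.cond_take(var_a > 1, var_a)   # VarNode inputs are now accepted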
imperative/python/megengine/utils/network.py

@@ -17,6 +17,7 @@ import numpy as np
 from ..core._imperative_rt import ComputingGraph
+from ..core._imperative_rt.core2 import SymbolVar
 from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape
 from ..core.tensor import megbrain_graph as G
 from ..logger import get_logger
 from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq

@@ -182,8 +183,13 @@ class Network:
         """
+
+        def _set_var_name(var):
+            graph_var = G.VarNode(var.var)
+            graph_var.name = var.name
+            return graph_var
+
         self._compile()
-        out = [G.VarNode(var.var) for var in self.output_vars]
+        out = list(map(_set_var_name, self.output_vars))

         if kwargs.pop("arg_names", False):
             logger.warning(

@@ -231,15 +237,20 @@ class Network:
         if not all([var.owner for var in vars]):
             self.add_dep_oprs(*vars)
         for var in vars:
-            if var not in self.output_vars:
+            # use method 'is' instead of 'in' to avoid
+            # compare VarNode use elemwise equal
+            if not any(var is _ for _ in self.output_vars):
                 self.output_vars.append(var)

     def remove_output(self, *vars: VarNode):
         """Removes vars from the network output node list.
         """
         for var in vars:
-            if var in self.output_vars:
-                self.output_vars.remove(var)
+            # use list pop instead of remove to avoid
+            # compare VarNode use elemwise equal
+            for idx, out_var in enumerate(self.output_vars):
+                if var is out_var:
+                    self.output_vars.pop(idx)

     def add_dep_oprs(self, *vars):
         if len(vars) == 0:

@@ -434,6 +445,15 @@ class Network:
                 opnode.add_out_var(self._get_var(var))
             return opnode
         else:
+            # overwrite the opnode 'new' output VarNode with
+            # original one when output number larger than 1,
+            # or will cause dependence issue in _compiler step.
+            if len(opr.outputs) > 1:
+                opnode = self.all_oprs_map[opr.id]
+                for idx, output in enumerate(opnode.outputs):
+                    if output.var.id in self.all_vars_map:
+                        opnode.outputs[idx] = self.all_vars_map[output.var.id]
+
             return None

     def _get_opr(self, x):

@@ -449,6 +469,15 @@ class Network:
         return self.all_vars_map[x.id]


+def set_symbolic_shape(option: bool):
+    """
+    Set the VarNode use symbolic shape or not, return the last status.
+    Please set to True and must recover after dump if want to change the input batch size.
+
+    :param option: True for enable symbolic shape.
+    """
+    return _set_symbolic_shape(option)
+
+
 def as_varnode(obj):
     """convert a :class:`.VarNode` compatible object to :class:`.VarNode`.
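The add_output/remove_output changes avoid Python's "in" and list.remove, because both rely on ==, and VarNode (via ArrayMethodMixin) overloads == to build an elementwise-equal graph node rather than return a plain bool. A toy illustration of the pitfall with a stand-in class (not MegEngine code):

class Sym:
    """Stand-in for an object whose __eq__ is elementwise (never a plain bool)."""

    def __eq__(self, other):
        return Sym()          # always truthy, like a graph node

a, b = Sym(), Sym()
outputs = [a]

print(b in outputs)                      # True (!) -- membership trusts __eq__
print(any(b is o for o in outputs))      # False -- identity check is correct

The module-level set_symbolic_shape added at the end of the file is a thin wrapper that re-exports megengine.core._trace_option.set_symbolic_shape, so users of the Network API can toggle shape mode without importing a private module.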
imperative/python/megengine/utils/network_node.py

@@ -14,7 +14,8 @@ from typing import Callable, Sequence
 import numpy as np

 from ..core import _imperative_rt as rt
-from ..core._imperative_rt.core2 import SymbolVar
+from ..core._imperative_rt.core2 import SymbolVar, apply
 from ..core._trace_option import use_symbolic_shape
 from ..core._wrap import Device
 from ..core.ops import builtin
 from ..core.tensor.array_method import ArrayMethodMixin

@@ -53,15 +54,41 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
         obj.owner = owner_opr
         return obj

+    def _get_var_shape(self, axis=None):
+        opdef = (
+            builtin.GetVarShape() if axis is None else builtin.GetVarShape(axis=axis)
+        )
+        return apply(opdef, self)[0]
+
+    @property
+    def partial_shape(self):
+        """Return the tuple type inferred shape of VarNode
+        """
+        return tuple(self._get_var_shape().numpy())
+
+    def shapeof(self, axis):
+        """Return the symbolic shape of axis
+        """
+        return self._get_var_shape(axis=axis) if self.var else None
+
+    @property
+    def _tuple_shape(self):
+        return self.partial_shape
+
     @property
     def shape(self):
+        """Return the symbolic shape if using set_symbolic_shape(True)
+        else inferred shape
+        """
         rst = None
         if self.var:
             try:
                 rst = self.var.shape
             except:
                 rst = None
-        return rst
+        if not use_symbolic_shape():
+            return rst
+        return self._get_var_shape() if self.var else None

     @property
     def dtype(self):

@@ -78,10 +105,6 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
     def __hash__(self):
         return id(self)

-    @property
-    def _tuple_shape(self):
-        return self.var.shape
-
     def numpy(self):
         o = OutputNode(self.var)
         self.graph.compile(o.outputs).execute()
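With these additions, VarNode.shape returns a symbolic shape (itself a VarNode, computed through GetVarShape) when set_symbolic_shape(True) is active, and the statically inferred shape otherwise; partial_shape always returns the inferred tuple. A minimal sketch, assuming var is a VarNode taken from a loaded Network:

from megengine.utils.network import set_symbolic_shape
from megengine.utils.network_node import VarNode

saved = set_symbolic_shape(True)
assert isinstance(var.shape, VarNode)    # symbolic: a GetVarShape result

set_symbolic_shape(False)
assert var.shape == var.partial_shape    # inferred: a plain tuple

set_symbolic_shape(saved)                # restore the previous mode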
imperative/python/test/unit/functional/test_tensor.py

@@ -19,7 +19,7 @@ from megengine.core._trace_option import use_symbolic_shape
 from megengine.core.tensor import megbrain_graph as G
 from megengine.core.tensor.utils import astensor1d
 from megengine.distributed.helper import get_device_count_by_fork
-from megengine.utils.network import Network
+from megengine.utils.network import Network, set_symbolic_shape
 from megengine.utils.network_node import VarNode

@@ -62,6 +62,22 @@ def test_concat(is_varnode):
     opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network)


+@pytest.mark.parametrize("is_varnode", [True, False])
+def test_condtake(is_varnode):
+    if is_varnode:
+        network = Network()
+    else:
+        network = None
+
+    x = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32")
+    y = np.array([[True, False, True], [False, True, True]])
+    xx = make_tensor(x, network)
+    yy = make_tensor(y, network)
+    val, idx = F.cond_take(yy, xx)
+    np.testing.assert_equal(val.numpy(), x[y])
+    np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
+
+
 @pytest.mark.parametrize("is_varnode", [True, False])
 def test_concat_device(is_varnode):
     if is_varnode:

@@ -102,6 +118,7 @@ def test_stack(is_varnode):
 def test_split(is_varnode):
     if is_varnode:
         network = Network()
+        saved_symbolic_shape = set_symbolic_shape(False)
     else:
         network = None

@@ -134,6 +151,9 @@ def test_split(is_varnode):
     except ValueError as e:
         assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]"

+    if is_varnode:
+        set_symbolic_shape(saved_symbolic_shape)
+

 @pytest.mark.parametrize("is_varnode", [True, False])
 def test_reshape(is_varnode):

@@ -161,6 +181,7 @@ def test_reshape(is_varnode):
 def test_reshape_shape_inference(is_varnode):
     if is_varnode:
         network = Network()
+        saved_symbolic_shape = set_symbolic_shape(False)
     else:
         network = None

@@ -192,12 +213,15 @@ def test_reshape_shape_inference(is_varnode):
         {"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]},
     ]
     opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network)

+    if is_varnode:
+        set_symbolic_shape(saved_symbolic_shape)
+

 @pytest.mark.parametrize("is_varnode", [True, False])
 def test_squeeze(is_varnode):
     if is_varnode:
         network = Network()
+        saved_symbolic_shape = set_symbolic_shape(False)
     else:
         network = None

@@ -209,6 +233,9 @@ def test_squeeze(is_varnode):
         yy = F.squeeze(xx, axis)
         np.testing.assert_equal(y, yy.numpy())

+    if is_varnode:
+        set_symbolic_shape(saved_symbolic_shape)
+

 @pytest.mark.parametrize("is_varnode", [True, False])
 def test_expand_dims(is_varnode):

@@ -358,7 +385,7 @@ def test_flatten(is_varnode):
     data1 = np.random.random(data1_shape).astype(np.float32)

     def compare_fn(x, y):
-        assert x.shape[0] == y
+        assert x._tuple_shape[0] == y

     output0 = (2 * 3 * 4 * 5,)
     output1 = (4 * 5 * 6 * 7,)

@@ -420,7 +447,7 @@ def test_broadcast(is_varnode):
     data3 = np.random.random(input3_shape).astype(np.float32)

     def compare_fn(x, y):
-        assert x.shape[0] == y
+        assert x._tuple_shape[0] == y

     cases = [
         {"input": [data1, output1_shape], "output": output1_shape},
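The VarNode-based tests switch to inferred (tuple) shapes with set_symbolic_shape(False) and restore the previous mode afterwards, since the setting is global state. A try/finally variant of the same save/restore idiom (sketch):

from megengine.utils.network import set_symbolic_shape

saved_symbolic_shape = set_symbolic_shape(False)   # returns the previous setting
try:
    ...  # code that expects tuple shapes on VarNode
finally:
    set_symbolic_shape(saved_symbolic_shape)       # always restore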
imperative/python/test/unit/utils/test_network.py

@@ -10,7 +10,7 @@ from megengine.jit.tracing import trace
 from megengine.tensor import Tensor
 from megengine.utils.comp_graph_tools import GraphInference
 from megengine.utils.network import Network as Net
-from megengine.utils.network import as_oprnode
+from megengine.utils.network import as_oprnode, set_symbolic_shape
 from megengine.utils.network_node import Host2DeviceCopy, VarNode

@@ -181,19 +181,22 @@ def test_add_input():
     np.testing.assert_equal(out["o1"], ((a + b) * 2 + a).numpy())


-def test_add_output():
+def test_add_remove_output():
     a = Tensor([1.0, 2.0])
     b = Tensor([3.0, 4.0])

     @trace(symbolic=True, capture_as_const=True)
     def fwd(a, b):
-        return (a + b) * 2
+        return (a + b) * 2, (a - b)

     fwd(a, b)
     orig_model = io.BytesIO()
     fwd.dump(
-        orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
+        orig_model,
+        arg_names=["a", "b"],
+        output_names=["o1", "o2"],
+        optimize_for_inference=False,
     )
     orig_model.seek(0)

@@ -201,11 +204,13 @@ def test_add_output():
     var_a = net.var_filter.name("a").as_unique()
     var_b = net.var_filter.name("b").as_unique()

-    y = F.add(var_a, var_b)
-    y = F.sigmoid(y)
+    y1 = (var_a + var_b) * 3
+    y2 = F.sigmoid(var_a + var_b)

-    y.name = "o1"
-    net.add_output(y)
+    net.remove_output(*net.output_vars)
+    y1.name = "new_o1"
+    y2.name = "new_o2"
+    net.add_output(y1, y2)

     modified_model = io.BytesIO()
     net.dump(modified_model)

@@ -214,8 +219,8 @@ def test_add_output():
     g = GraphInference(modified_model)
     out = g.run(a.numpy(), b.numpy())
-    np.testing.assert_equal(out["o"], ((a + b) * 2).numpy())
-    np.testing.assert_equal(out["o1"], (F.sigmoid((a + b))).numpy())
+    np.testing.assert_equal(out["new_o1"], ((a + b) * 3).numpy())
+    np.testing.assert_equal(out["new_o2"], (F.sigmoid((a + b))).numpy())


 def test_query():

@@ -343,3 +348,68 @@ def test_modify_opr_name():
     net1 = Net.load(modified_model)
     assert net1.data_providers_filter.as_unique().name == "net1.net.a"


+def test_dump_cond_take():
+    a = Tensor([1.0, 2.0])
+
+    @trace(symbolic=True, capture_as_const=True)
+    def fwd(a):
+        return F.cond_take(a > 1, a)
+
+    fwd(a)
+    orig_model = io.BytesIO()
+    fwd.dump(
+        orig_model,
+        arg_names=["a"],
+        output_names=["o1", "o2"],
+        optimize_for_inference=False,
+    )
+    orig_model.seek(0)
+
+    net = Net.load(orig_model)
+    var_a = net.input_vars[0]
+
+    val, idx = F.cond_take(var_a > 1, var_a)
+
+    net.remove_output(*net.output_vars)
+    val.name = "value"
+    idx.name = "index"
+    net.add_output(val, idx)
+
+    modified_model = io.BytesIO()
+    net.dump(modified_model)
+    modified_model.seek(0)
+
+    g = GraphInference(modified_model)
+    out = g.run(a.numpy())
+
+    data = a.numpy()
+    mask = a.numpy() > 1
+    np.testing.assert_equal(out["index"], np.where(mask.reshape(-1))[0])
+    np.testing.assert_equal(out["value"], data[mask])
+
+
+def test_set_symbolic_shape():
+    a = Tensor([1.0, 2.0])
+
+    @trace(symbolic=True, capture_as_const=True)
+    def fwd(a):
+        return F.relu(a * 2)
+
+    fwd(a)
+    orig_model = io.BytesIO()
+    fwd.dump(
+        orig_model, arg_names=["a"], output_names=["o"], optimize_for_inference=False,
+    )
+    orig_model.seek(0)
+    net = Net.load(orig_model)
+    var_a = net.input_vars[0]
+
+    saved_symbolic_shape = set_symbolic_shape(True)
+    assert isinstance(var_a.shape, VarNode)
+    set_symbolic_shape(False)
+    assert var_a.shape == var_a.partial_shape
+    set_symbolic_shape(saved_symbolic_shape)
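The new and renamed tests can be run selectively from the repository root, for example with pytest's -k expression filter (shown here through pytest.main to keep the example in Python):

import pytest

pytest.main([
    "imperative/python/test/unit/utils/test_network.py",
    "-k", "add_remove_output or dump_cond_take or set_symbolic_shape",
])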