Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
dedecf69
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
396
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
dedecf69
编写于
6月 23, 2021
作者:
M
Megvii Engine Team
提交者:
huangxinda
7月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(imperative/utils): fix logical error of replace var
GitOrigin-RevId: 614302552cbeaa66cbc977ee81e5492b6023c1c4
上级
ea70d99b
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
96 addition
and
19 deletion
+96
-19
imperative/python/megengine/core/tensor/megbrain_graph.py
imperative/python/megengine/core/tensor/megbrain_graph.py
+1
-2
imperative/python/megengine/utils/network.py
imperative/python/megengine/utils/network.py
+32
-13
imperative/python/megengine/utils/network_node.py
imperative/python/megengine/utils/network_node.py
+17
-4
imperative/python/test/unit/utils/test_network.py
imperative/python/test/unit/utils/test_network.py
+46
-0
未找到文件。
imperative/python/megengine/core/tensor/megbrain_graph.py
浏览文件 @
dedecf69
...
...
@@ -519,8 +519,7 @@ def _unwrap(x):
return
type
(
x
)(
map
(
_unwrap
,
x
))
if
isinstance
(
x
,
VarNode
):
return
x
.
_node
else
:
return
x
return
x
def
apply_normal_varnode
(
op
:
OpDef
,
*
args
:
VarNode
):
...
...
imperative/python/megengine/utils/network.py
浏览文件 @
dedecf69
...
...
@@ -12,14 +12,16 @@ import itertools
import
pickle
import
re
from
collections
import
OrderedDict
from
typing
import
Any
,
Dict
,
List
,
Sequence
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Sequence
from
..core
import
_imperative_rt
from
..core._imperative_rt
import
ComputingGraph
,
SerializationMetadata
from
..core._trace_option
import
set_symbolic_shape
as
_set_symbolic_shape
from
..core.tensor
import
megbrain_graph
as
G
from
..logger
import
get_logger
from
.comp_graph_tools
import
get_dep_vars
,
get_opr_type
,
get_oprs_seq
from
.network_node
import
(
ConstOpBase
,
Host2DeviceCopy
,
ImmutableTensor
,
NetworkNode
,
...
...
@@ -37,8 +39,10 @@ class Network:
self
.
_orig_inputs
=
[]
self
.
output_vars
=
[]
# output var of graph
self
.
_orig_outputs
=
[]
self
.
all_oprs_map
=
OrderedDict
()
self
.
all_vars_map
=
OrderedDict
()
self
.
all_oprs_map
=
OrderedDict
()
# _imperative_rt.graph.VarNode.id: VarNode
self
.
all_vars_map
=
(
OrderedDict
()
)
# _imperative_rt.graph.OperatorNode.id: OpNode
self
.
graph
=
ComputingGraph
()
self
.
_metadata
=
None
...
...
@@ -101,7 +105,7 @@ class Network:
self
.
all_oprs_map
=
{}
self
.
all_vars_map
=
{}
for
opr
in
self
.
all_oprs
:
if
isinstance
(
opr
,
(
ImmutableTensor
,
Host2DeviceCopy
)):
if
isinstance
(
opr
,
(
ConstOpBase
,
Host2DeviceCopy
)):
opr
.
compile
(
self
.
graph
)
else
:
opr
.
compile
()
...
...
@@ -295,6 +299,9 @@ class Network:
def
add_dep_oprs
(
self
,
*
vars
):
if
len
(
vars
)
==
0
:
vars
=
self
.
output_vars
assert
all
(
isinstance
(
var
,
VarNode
)
for
var
in
vars
),
"Only support add VarNode"
q
=
list
(
vars
)
while
len
(
q
)
>
0
:
cur
=
q
.
pop
(
0
)
...
...
@@ -368,11 +375,14 @@ class Network:
for
var
in
self
.
all_vars
:
if
var
in
repl_dict
:
repl_var
=
repl_dict
[
var
]
owner
=
repl_var
.
owner
idx
=
owner
.
outputs
.
index
(
repl_var
)
owner
.
outputs
[
idx
]
=
var
var
.
__dict__
.
update
(
repl_var
.
__dict__
)
var
.
var
=
repl_var
.
var
if
repl_var
is
var
:
continue
for
opnode
in
var
.
users
:
assert
var
in
opnode
.
inputs
opnode
.
inputs
=
[
repl_var
if
var
is
i
else
i
for
i
in
opnode
.
inputs
]
if
opnode
not
in
repl_var
.
users
:
repl_var
.
users
.
append
(
opnode
)
var
.
users
.
clear
()
self
.
_compile
()
def
replace_oprs
(
self
,
repl_dict
:
Dict
[
OpNode
,
OpNode
]):
...
...
@@ -473,14 +483,20 @@ class Network:
def
all_oprs_dict
(
self
):
return
self
.
opr_filter
.
as_dict
()
# used for loading and building graph
def
_add_opr
(
self
,
opr
):
def
_add_opr
(
self
,
opr
)
->
Optional
[
OpNode
]:
"""
Used for loading and building graph.
"""
assert
isinstance
(
opr
,
_imperative_rt
.
graph
.
OperatorNode
)
# TODO: use megbrain C++ RTTI to replace type string
if
opr
.
id
not
in
self
.
all_oprs_map
:
opnode
=
str_to_mge_class
(
get_opr_type
(
opr
)).
load
(
opr
)
self
.
all_oprs_map
[
opr
.
id
]
=
opnode
for
var
in
opr
.
inputs
:
opnode
.
add_inp_var
(
self
.
_get_var
(
var
))
varnode
=
self
.
_get_var
(
var
)
opnode
.
add_inp_var
(
varnode
)
varnode
.
users
.
append
(
opnode
)
for
var
in
opr
.
outputs
:
opnode
.
add_out_var
(
self
.
_get_var
(
var
))
return
opnode
...
...
@@ -503,7 +519,10 @@ class Network:
return
None
def
_get_var
(
self
,
x
):
# auto convert to VarNode of Network
"""
Convert :class:`~._imperative_rt.graph.VarNode` to :class:`~.VarNode`.
"""
assert
isinstance
(
x
,
_imperative_rt
.
graph
.
VarNode
)
if
x
.
id
not
in
self
.
all_vars_map
or
self
.
all_vars_map
[
x
.
id
].
var
!=
x
:
self
.
all_vars_map
[
x
.
id
]
=
VarNode
.
load
(
x
,
self
.
_get_opr
(
x
.
owner
))
return
self
.
all_vars_map
[
x
.
id
]
...
...
imperative/python/megengine/utils/network_node.py
浏览文件 @
dedecf69
...
...
@@ -37,6 +37,7 @@ class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
class
VarNode
(
NetworkNode
,
SymbolVar
,
ArrayMethodMixin
,
metaclass
=
VarNodeMeta
):
def
__init__
(
self
,
var
=
None
,
*
,
owner_opr
=
None
,
name
=
None
):
SymbolVar
.
__init__
(
self
,
var
)
self
.
users
=
[]
# List[OpNode]
self
.
owner
=
owner_opr
self
.
name
=
name
self
.
id
=
id
(
self
)
...
...
@@ -214,6 +215,7 @@ class Host2DeviceCopy(OpNode):
def
compile
(
self
,
graph
):
if
(
self
.
_opr
is
None
or
self
.
_opr
.
graph
!=
graph
or
self
.
_opr
.
outputs
[
0
].
comp_node
!=
self
.
device
or
self
.
_opr
.
outputs
[
0
].
shape
!=
self
.
shape
or
self
.
_opr
.
outputs
[
0
].
dtype
!=
self
.
dtype
...
...
@@ -226,10 +228,11 @@ class Host2DeviceCopy(OpNode):
assert
self
.
outputs
[
0
].
owner
is
self
class
ImmutableTensor
(
OpNode
):
type
=
"
ImmutableTensor
"
class
ConstOpBase
(
OpNode
):
type
=
"
ConstOpBase
"
def
__init__
(
self
,
data
=
None
,
name
=
None
,
device
=
None
,
graph
=
None
):
assert
type
(
self
)
is
not
ConstOpBase
,
"ConstOpBase cannot be instantiated"
super
().
__init__
()
self
.
name
=
name
self
.
outputs
=
[]
...
...
@@ -254,7 +257,7 @@ class ImmutableTensor(OpNode):
return
self
.
_opr
.
outputs
[
0
].
dtype
if
self
.
_opr
else
None
def
numpy
(
self
):
return
self
.
_opr
.
outputs
[
0
].
value
if
self
.
_opr
else
None
return
self
.
outputs
[
0
].
numpy
()
def
set_value
(
self
,
data
,
device
=
None
):
assert
self
.
graph
is
not
None
...
...
@@ -266,7 +269,7 @@ class ImmutableTensor(OpNode):
data
=
data
.
astype
(
np
.
float32
)
elif
data
.
dtype
==
np
.
int64
:
data
=
data
.
astype
(
np
.
int32
)
varnode
=
rt
.
make_const
(
self
.
graph
,
data
,
cn
,
data
.
dtype
,
self
.
name
)
varnode
=
type
(
self
).
rt_fun
(
self
.
graph
,
data
,
cn
,
data
.
dtype
,
self
.
name
)
if
len
(
self
.
outputs
)
==
0
:
self
.
outputs
.
append
(
VarNode
(
owner_opr
=
self
,
name
=
self
.
name
))
self
.
outputs
[
0
].
var
=
varnode
...
...
@@ -291,6 +294,16 @@ class ImmutableTensor(OpNode):
self
.
outputs
[
0
].
var
.
name
=
self
.
name
class
ImmutableTensor
(
ConstOpBase
):
type
=
"ImmutableTensor"
rt_fun
=
rt
.
make_const
class
SharedDeviceTensor
(
ConstOpBase
):
type
=
"SharedDeviceTensor"
rt_fun
=
rt
.
make_shared
class
ReadOnlyOpNode
(
OpNode
):
@
classmethod
def
load
(
cls
,
opr
):
...
...
imperative/python/test/unit/utils/test_network.py
浏览文件 @
dedecf69
...
...
@@ -130,6 +130,52 @@ def test_replace_opr():
np
.
testing
.
assert_equal
(
out
[
"o"
],
[
0
,
0
])
def
test_splice_network
():
x
=
F
.
ones
((
2
,))
y
=
F
.
ones
((
2
,))
@
trace
(
symbolic
=
True
,
capture_as_const
=
True
)
def
fun1
(
a
,
b
):
return
(
a
+
b
)
*
2
@
trace
(
symbolic
=
True
,
capture_as_const
=
True
)
def
fun2
(
a
):
return
a
*
2
-
1
model
=
io
.
BytesIO
()
fun1
(
x
,
y
)
fun2
(
x
)
fun1
.
dump
(
model
,
arg_names
=
[
"net1_i0"
,
"net1_i1"
],
output_names
=
[
"net1_o0"
],
optimize_for_inference
=
False
,
)
model
.
seek
(
0
)
net1
=
Net
.
load
(
model
)
model
.
seek
(
0
)
fun2
.
dump
(
model
,
arg_names
=
[
"net2_i0"
],
output_names
=
[
"net2_o0"
],
optimize_for_inference
=
False
,
)
model
.
seek
(
0
)
net2
=
Net
.
load
(
model
)
net1
.
add_output
(
*
net2
.
output_vars
)
var
=
net1
.
var_filter
.
name
(
"net1_i0"
).
as_unique
()
repl_var
=
net2
.
var_filter
.
name
(
"net2_o0"
).
as_unique
()
net1
.
replace_vars
({
var
:
repl_var
})
assert
"net1_i0"
not
in
[
var
.
name
for
var
in
net1
.
all_vars
]
assert
"net2_i0"
in
[
var
.
name
for
var
in
net1
.
all_vars
]
model
.
seek
(
0
)
net1
.
dump
(
model
,
keep_var_name
=
2
,
optimize_for_inference
=
False
)
model
.
seek
(
0
)
net
=
Net
.
load
(
model
)
assert
"net1_i0"
not
in
[
var
.
name
for
var
in
net
.
all_vars
]
assert
"net2_i0"
in
[
var
.
name
for
var
in
net
.
all_vars
]
def
test_modify_params
():
a
=
Tensor
([
1
,
2
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录