Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
c7e730bc
MegEngine
项目概览
MegEngine 天元
/
MegEngine
接近 2 年 前同步成功
通知
414
Star
4708
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
c7e730bc
编写于
7月 26, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(traced_module): add some functions of graph modification
GitOrigin-RevId: 09691ebd334072f822226125acb11cebdc218618
上级
f88bd3ae
变更
6
展开全部
隐藏空白更改
内联
并排
Showing
6 changed file
with
612 addition
and
124 deletion
+612
-124
imperative/python/megengine/experimental/traced_module/__init__.py
...e/python/megengine/experimental/traced_module/__init__.py
+2
-0
imperative/python/megengine/experimental/traced_module/expr.py
...ative/python/megengine/experimental/traced_module/expr.py
+68
-13
imperative/python/megengine/experimental/traced_module/node.py
...ative/python/megengine/experimental/traced_module/node.py
+22
-8
imperative/python/megengine/experimental/traced_module/pytree.py
...ive/python/megengine/experimental/traced_module/pytree.py
+23
-0
imperative/python/megengine/experimental/traced_module/traced_module.py
...hon/megengine/experimental/traced_module/traced_module.py
+492
-98
imperative/python/test/unit/traced_module/test_modification.py
...ative/python/test/unit/traced_module/test_modification.py
+5
-5
未找到文件。
imperative/python/megengine/experimental/traced_module/__init__.py
浏览文件 @
c7e730bc
...
@@ -13,6 +13,8 @@ from .traced_module import (
...
@@ -13,6 +13,8 @@ from .traced_module import (
cpp_apply_module_trace
,
cpp_apply_module_trace
,
register_as_builtin
,
register_as_builtin
,
trace_module
,
trace_module
,
wrap
,
wrap_tensors
,
)
)
_register_all_builtin_module
()
_register_all_builtin_module
()
...
...
imperative/python/megengine/experimental/traced_module/expr.py
浏览文件 @
c7e730bc
...
@@ -11,7 +11,7 @@ import builtins
...
@@ -11,7 +11,7 @@ import builtins
import
collections
import
collections
import
copy
import
copy
import
inspect
import
inspect
from
typing
import
Callable
,
List
from
typing
import
Callable
,
Dict
,
List
from
...core._imperative_rt
import
OpDef
from
...core._imperative_rt
import
OpDef
from
...core._imperative_rt.core2
import
Tensor
as
RawTensor
from
...core._imperative_rt.core2
import
Tensor
as
RawTensor
...
@@ -29,10 +29,24 @@ class Expr:
...
@@ -29,10 +29,24 @@ class Expr:
``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
"""
"""
__total_id
=
0
inputs
=
None
# type: List[Node]
inputs
=
None
# type: List[Node]
outputs
=
None
# type: List[Node]
outputs
=
None
# type: List[Node]
const_val
=
None
# type: List[Any]
const_val
=
None
# type: List[Any]
arg_def
=
None
# type: TreeDef
arg_def
=
None
# type: TreeDef
out_def
=
None
# type: TreeDef
_top_graph
=
None
# type: weakref.ReferenceType
def
__init__
(
self
)
->
None
:
self
.
_id
=
Expr
.
__total_id
Expr
.
__total_id
+=
1
self
.
_disable_remove
=
False
def
enable_remove
(
self
):
self
.
_disable_remove
=
False
def
disable_remove
(
self
):
self
.
_disable_remove
=
True
def
add_inputs
(
self
,
vals
):
def
add_inputs
(
self
,
vals
):
if
not
isinstance
(
vals
,
collections
.
abc
.
Sequence
):
if
not
isinstance
(
vals
,
collections
.
abc
.
Sequence
):
...
@@ -70,6 +84,22 @@ class Expr:
...
@@ -70,6 +84,22 @@ class Expr:
else
:
else
:
return
inputs
,
{}
return
inputs
,
{}
def
_replace_nodes
(
self
,
repl_dict
:
Dict
[
Node
,
Node
],
nodes
:
List
[
Node
]):
while
repl_dict
:
node
,
repl_node
=
repl_dict
.
popitem
()
assert
type
(
node
)
==
type
(
repl_node
)
assert
node
in
nodes
index
=
nodes
.
index
(
node
)
nodes
[
index
]
=
repl_node
repl_node
.
users
.
append
(
self
)
node
.
users
.
pop
(
self
)
def
replace_inputs
(
self
,
repl_dict
:
Dict
[
Node
,
Node
]):
self
.
_replace_nodes
(
repl_dict
,
self
.
inputs
)
def
replace_outputs
(
self
,
repl_dict
:
Dict
[
Node
,
Node
]):
self
.
_replace_nodes
(
repl_dict
,
self
.
outputs
)
@
property
@
property
def
kwargs
(
self
):
def
kwargs
(
self
):
_
,
kwargs
=
self
.
unflatten_args
(
self
.
inputs
)
_
,
kwargs
=
self
.
unflatten_args
(
self
.
inputs
)
...
@@ -80,12 +110,19 @@ class Expr:
...
@@ -80,12 +110,19 @@ class Expr:
args
,
_
=
self
.
unflatten_args
(
self
.
inputs
)
args
,
_
=
self
.
unflatten_args
(
self
.
inputs
)
return
args
return
args
@
property
def
top_graph
(
self
):
if
self
.
_top_graph
:
return
self
.
_top_graph
()
return
None
# expr: None (i.e. fake expression which is used to mark input)
# expr: None (i.e. fake expression which is used to mark input)
class
Input
(
Expr
):
class
Input
(
Expr
):
name
=
None
name
=
None
def
__init__
(
self
,
name
=
None
,
type
=
None
):
def
__init__
(
self
,
name
=
None
,
type
=
None
):
super
().
__init__
()
self
.
inputs
=
[]
self
.
inputs
=
[]
node_cls
=
type
if
type
else
Node
node_cls
=
type
if
type
else
Node
self
.
outputs
=
[
self
.
outputs
=
[
...
@@ -100,7 +137,7 @@ class Input(Expr):
...
@@ -100,7 +137,7 @@ class Input(Expr):
return
expr
.
outputs
[
0
]
return
expr
.
outputs
[
0
]
def
__repr__
(
self
):
def
__repr__
(
self
):
return
"
{} = Input({})"
.
format
(
self
.
outputs
[
0
],
self
.
name
)
return
"
%{}: {} = Input({})"
.
format
(
self
.
_id
,
self
.
outputs
[
0
],
self
.
name
)
# expr: outputs = getattr(inputs[0], self.name)
# expr: outputs = getattr(inputs[0], self.name)
...
@@ -108,6 +145,7 @@ class GetAttr(Expr):
...
@@ -108,6 +145,7 @@ class GetAttr(Expr):
name
=
None
name
=
None
def
__init__
(
self
,
module
,
name
,
type
=
None
):
def
__init__
(
self
,
module
,
name
,
type
=
None
):
super
().
__init__
()
assert
isinstance
(
module
,
ModuleNode
)
assert
isinstance
(
module
,
ModuleNode
)
self
.
inputs
=
[
self
.
inputs
=
[
module
,
module
,
...
@@ -130,14 +168,15 @@ class GetAttr(Expr):
...
@@ -130,14 +168,15 @@ class GetAttr(Expr):
return
(
getattr
(
inputs
[
0
],
self
.
name
),)
return
(
getattr
(
inputs
[
0
],
self
.
name
),)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
'{} = GetAttr({}, "{}")'
.
format
(
return
'
%{}:
{} = GetAttr({}, "{}")'
.
format
(
self
.
outputs
[
0
],
self
.
inputs
[
0
],
self
.
name
self
.
_id
,
self
.
outputs
[
0
],
self
.
inputs
[
0
],
self
.
name
)
)
# expr: outputs = inputs[0].__call__(*inputs[1:])
# expr: outputs = inputs[0].__call__(*inputs[1:])
class
CallMethod
(
Expr
):
class
CallMethod
(
Expr
):
def
__init__
(
self
,
node
,
method
=
"__call__"
):
def
__init__
(
self
,
node
,
method
=
"__call__"
):
super
().
__init__
()
if
isinstance
(
node
,
type
):
if
isinstance
(
node
,
type
):
assert
issubclass
(
node
,
Tensor
)
assert
issubclass
(
node
,
Tensor
)
cls
=
Parameter
if
issubclass
(
node
,
Parameter
)
else
Tensor
cls
=
Parameter
if
issubclass
(
node
,
Parameter
)
else
Tensor
...
@@ -178,6 +217,8 @@ class CallMethod(Expr):
...
@@ -178,6 +217,8 @@ class CallMethod(Expr):
if
inspect
.
ismethod
(
meth
):
if
inspect
.
ismethod
(
meth
):
args
=
args
[
1
:]
args
=
args
[
1
:]
outputs
=
getattr
(
obj
,
self
.
method
)(
*
args
,
**
kwargs
)
outputs
=
getattr
(
obj
,
self
.
method
)(
*
args
,
**
kwargs
)
if
self
.
method
==
"__setitem__"
:
outputs
=
obj
if
outputs
is
None
:
if
outputs
is
None
:
return
outputs
return
outputs
outputs
,
_
=
tree_flatten
(
outputs
,
is_leaf
=
lambda
x
:
isinstance
(
x
,
RawTensor
))
outputs
,
_
=
tree_flatten
(
outputs
,
is_leaf
=
lambda
x
:
isinstance
(
x
,
RawTensor
))
...
@@ -186,8 +227,12 @@ class CallMethod(Expr):
...
@@ -186,8 +227,12 @@ class CallMethod(Expr):
def
__repr__
(
self
):
def
__repr__
(
self
):
args
=
", "
.
join
(
str
(
i
)
for
i
in
self
.
args
[
1
:])
args
=
", "
.
join
(
str
(
i
)
for
i
in
self
.
args
[
1
:])
kwargs
=
", "
.
join
(
"{}={}"
.
format
(
k
,
v
)
for
k
,
v
in
self
.
kwargs
.
items
())
kwargs
=
", "
.
join
(
"{}={}"
.
format
(
k
,
v
)
for
k
,
v
in
self
.
kwargs
.
items
())
return
"{} = {}.{}({})"
.
format
(
outputs
=
self
.
outputs
", "
.
join
(
str
(
i
)
for
i
in
self
.
outputs
),
if
self
.
out_def
:
outputs
=
self
.
out_def
.
unflatten
(
outputs
)
return
"%{}: {}{}.{}({})"
.
format
(
self
.
_id
,
str
(
outputs
)
+
" = "
if
outputs
else
""
,
self
.
args
[
0
],
self
.
args
[
0
],
self
.
method
,
self
.
method
,
", "
.
join
([
args
,
kwargs
]),
", "
.
join
([
args
,
kwargs
]),
...
@@ -199,6 +244,7 @@ class Apply(Expr):
...
@@ -199,6 +244,7 @@ class Apply(Expr):
opdef
=
None
opdef
=
None
def
__init__
(
self
,
opdef
):
def
__init__
(
self
,
opdef
):
super
().
__init__
()
assert
isinstance
(
opdef
,
OpDef
)
assert
isinstance
(
opdef
,
OpDef
)
self
.
opdef
=
opdef
self
.
opdef
=
opdef
self
.
inputs
=
[]
self
.
inputs
=
[]
...
@@ -213,7 +259,8 @@ class Apply(Expr):
...
@@ -213,7 +259,8 @@ class Apply(Expr):
return
apply
(
self
.
opdef
,
*
inputs
)
return
apply
(
self
.
opdef
,
*
inputs
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
"{} = {}({})"
.
format
(
return
"%{}: {} = {}({})"
.
format
(
self
.
_id
,
", "
.
join
(
str
(
i
)
for
i
in
self
.
outputs
),
", "
.
join
(
str
(
i
)
for
i
in
self
.
outputs
),
self
.
opdef
,
self
.
opdef
,
", "
.
join
(
str
(
i
)
for
i
in
self
.
inputs
),
", "
.
join
(
str
(
i
)
for
i
in
self
.
inputs
),
...
@@ -241,6 +288,7 @@ class Apply(Expr):
...
@@ -241,6 +288,7 @@ class Apply(Expr):
class
CallFunction
(
Expr
):
class
CallFunction
(
Expr
):
def
__init__
(
self
,
func
):
def
__init__
(
self
,
func
):
super
().
__init__
()
assert
isinstance
(
func
,
Callable
)
assert
isinstance
(
func
,
Callable
)
self
.
func
=
func
self
.
func
=
func
self
.
const_val
=
[]
self
.
const_val
=
[]
...
@@ -255,16 +303,20 @@ class CallFunction(Expr):
...
@@ -255,16 +303,20 @@ class CallFunction(Expr):
def
interpret
(
self
,
*
inputs
):
def
interpret
(
self
,
*
inputs
):
args
,
kwargs
=
self
.
unflatten_args
(
inputs
)
args
,
kwargs
=
self
.
unflatten_args
(
inputs
)
outputs
=
self
.
func
(
*
args
,
**
kwargs
)
outputs
=
self
.
func
(
*
args
,
**
kwargs
)
outputs
=
(
if
outputs
is
None
:
outputs
if
isinstance
(
outputs
,
collections
.
abc
.
Sequence
)
else
(
outputs
,)
return
outputs
)
outputs
,
_
=
tree_flatten
(
outputs
,
is_leaf
=
lambda
x
:
isinstance
(
x
,
RawTensor
)
)
return
outputs
return
outputs
def
__repr__
(
self
):
def
__repr__
(
self
):
args
=
", "
.
join
(
str
(
i
)
for
i
in
self
.
args
)
args
=
", "
.
join
(
str
(
i
)
for
i
in
self
.
args
)
kwargs
=
", "
.
join
(
"{}={}"
.
format
(
k
,
v
)
for
k
,
v
in
self
.
kwargs
.
items
())
kwargs
=
", "
.
join
(
"{}={}"
.
format
(
k
,
v
)
for
k
,
v
in
self
.
kwargs
.
items
())
return
"{} = {}({})"
.
format
(
outputs
=
self
.
outputs
", "
.
join
(
str
(
i
)
for
i
in
self
.
outputs
),
if
self
.
out_def
:
outputs
=
self
.
out_def
.
unflatten
(
outputs
)
return
"%{}: {}{}({})"
.
format
(
self
.
_id
,
str
(
outputs
)
+
" = "
if
outputs
else
""
,
self
.
func
.
__module__
+
"."
+
self
.
func
.
__name__
,
self
.
func
.
__module__
+
"."
+
self
.
func
.
__name__
,
", "
.
join
([
args
,
kwargs
]),
", "
.
join
([
args
,
kwargs
]),
)
)
...
@@ -277,6 +329,7 @@ class Constant(Expr):
...
@@ -277,6 +329,7 @@ class Constant(Expr):
_constant_cache
=
{}
_constant_cache
=
{}
def
__init__
(
self
,
c
):
def
__init__
(
self
,
c
):
super
().
__init__
()
assert
isinstance
(
c
,
(
RawTensor
,
Module
))
assert
isinstance
(
c
,
(
RawTensor
,
Module
))
if
isinstance
(
c
,
Module
):
if
isinstance
(
c
,
Module
):
assert
module_tracer
.
is_builtin
(
c
)
assert
module_tracer
.
is_builtin
(
c
)
...
@@ -299,7 +352,9 @@ class Constant(Expr):
...
@@ -299,7 +352,9 @@ class Constant(Expr):
return
(
self
.
value
,)
return
(
self
.
value
,)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
"{} = Constant({})"
.
format
(
self
.
outputs
[
0
],
type
(
self
.
value
))
return
"%{}: {} = Constant({})"
.
format
(
self
.
_id
,
self
.
outputs
[
0
],
type
(
self
.
value
)
)
def
__getstate__
(
self
):
def
__getstate__
(
self
):
state
=
self
.
__dict__
.
copy
()
state
=
self
.
__dict__
.
copy
()
...
...
imperative/python/megengine/experimental/traced_module/node.py
浏览文件 @
c7e730bc
...
@@ -30,6 +30,7 @@ class Node:
...
@@ -30,6 +30,7 @@ class Node:
__total_id
=
0
__total_id
=
0
_id
=
None
_id
=
None
_name
=
None
_name
=
None
_top_graph
=
None
# type: weakref.ReferenceType
def
__init__
(
self
,
expr
:
"Expr"
,
name
:
str
=
None
):
def
__init__
(
self
,
expr
:
"Expr"
,
name
:
str
=
None
):
self
.
expr
=
expr
self
.
expr
=
expr
...
@@ -48,6 +49,12 @@ class Node:
...
@@ -48,6 +49,12 @@ class Node:
else
:
else
:
return
"%{}"
.
format
(
self
.
_name
)
return
"%{}"
.
format
(
self
.
_name
)
@
property
def
top_graph
(
self
):
if
self
.
_top_graph
:
return
self
.
_top_graph
()
return
None
class
ModuleNode
(
Node
):
class
ModuleNode
(
Node
):
"""
"""
...
@@ -64,21 +71,28 @@ class ModuleNode(Node):
...
@@ -64,21 +71,28 @@ class ModuleNode(Node):
def
__init__
(
self
,
expr
:
"Expr"
,
name
:
str
=
None
):
def
__init__
(
self
,
expr
:
"Expr"
,
name
:
str
=
None
):
super
().
__init__
(
expr
,
name
)
super
().
__init__
(
expr
,
name
)
self
.
actual_mnode
=
[]
def
__repr__
(
self
):
def
__repr__
(
self
):
if
self
.
_name
is
None
:
if
self
.
_name
is
None
:
return
"%{}({})"
.
format
(
self
.
_id
,
self
.
module_type
.
__name__
)
return
"%{}
_
({})"
.
format
(
self
.
_id
,
self
.
module_type
.
__name__
)
else
:
else
:
return
"%{}
({})"
.
format
(
self
.
_name
,
self
.
module_type
.
__name__
)
return
"%{}
_{}({})"
.
format
(
self
.
_id
,
self
.
_name
,
self
.
module_type
.
__name__
)
def
__getstate__
(
self
):
def
__getstate__
(
self
):
d
=
self
.
__dict__
return
{
d
.
pop
(
"_owner"
,
None
)
"expr"
:
self
.
expr
,
return
d
"users"
:
self
.
users
,
"_id"
:
self
.
_id
,
"_name"
:
self
.
_name
,
"module_type"
:
self
.
module_type
,
}
@
property
@
property
def
owner
(
self
):
def
owner
(
self
):
return
self
.
_owner
()
if
self
.
_owner
:
return
self
.
_owner
()
return
None
class
TensorNode
(
Node
):
class
TensorNode
(
Node
):
...
@@ -91,9 +105,9 @@ class TensorNode(Node):
...
@@ -91,9 +105,9 @@ class TensorNode(Node):
def
__repr__
(
self
):
def
__repr__
(
self
):
if
self
.
_name
is
None
:
if
self
.
_name
is
None
:
return
"%{}(Tensor)"
.
format
(
self
.
_id
)
return
"%{}
_
(Tensor)"
.
format
(
self
.
_id
)
else
:
else
:
return
"%{}
(Tensor)"
.
format
(
self
.
_name
)
return
"%{}
_{}(Tensor)"
.
format
(
self
.
_id
,
self
.
_name
)
class
NodeMixin
(
abc
.
ABC
):
class
NodeMixin
(
abc
.
ABC
):
...
...
imperative/python/megengine/experimental/traced_module/pytree.py
浏览文件 @
c7e730bc
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
collections
import
collections
from
collections
import
OrderedDict
from
typing
import
Callable
,
NamedTuple
from
typing
import
Callable
,
NamedTuple
import
numpy
as
np
import
numpy
as
np
...
@@ -34,9 +35,25 @@ def _dict_unflatten(inps, aux_data):
...
@@ -34,9 +35,25 @@ def _dict_unflatten(inps, aux_data):
return
dict
(
zip
(
aux_data
,
inps
))
return
dict
(
zip
(
aux_data
,
inps
))
def
_ordereddict_flatten
(
inp
):
aux_data
=
[]
results
=
[]
for
key
,
value
in
inp
.
items
():
results
.
append
(
value
)
aux_data
.
append
(
key
)
return
results
,
tuple
(
aux_data
)
def
_ordereddict_unflatten
(
inps
,
aux_data
):
return
OrderedDict
(
zip
(
aux_data
,
inps
))
register_supported_type
(
list
,
lambda
x
:
(
x
,
None
),
lambda
x
,
aux_data
:
list
(
x
))
register_supported_type
(
list
,
lambda
x
:
(
x
,
None
),
lambda
x
,
aux_data
:
list
(
x
))
register_supported_type
(
tuple
,
lambda
x
:
(
x
,
None
),
lambda
x
,
aux_data
:
tuple
(
x
))
register_supported_type
(
tuple
,
lambda
x
:
(
x
,
None
),
lambda
x
,
aux_data
:
tuple
(
x
))
register_supported_type
(
dict
,
_dict_flatten
,
_dict_unflatten
)
register_supported_type
(
dict
,
_dict_flatten
,
_dict_unflatten
)
register_supported_type
(
collections
.
OrderedDict
,
_ordereddict_flatten
,
_ordereddict_unflatten
)
register_supported_type
(
register_supported_type
(
slice
,
slice
,
lambda
x
:
([
x
.
start
,
x
.
stop
,
x
.
step
],
None
),
lambda
x
:
([
x
.
start
,
x
.
stop
,
x
.
step
],
None
),
...
@@ -99,6 +116,12 @@ class TreeDef:
...
@@ -99,6 +116,12 @@ class TreeDef:
)
)
)
)
def
__lt__
(
self
,
other
):
return
self
.
__hash__
()
<
other
.
__hash__
()
def
__gt__
(
self
,
other
):
return
self
.
__hash__
()
>
other
.
__hash__
()
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
(
return
(
self
.
type
==
other
.
type
self
.
type
==
other
.
type
...
...
imperative/python/megengine/experimental/traced_module/traced_module.py
浏览文件 @
c7e730bc
此差异已折叠。
点击以展开。
imperative/python/test/unit/traced_module/test_modification.py
浏览文件 @
c7e730bc
...
@@ -57,16 +57,16 @@ def _init_module():
...
@@ -57,16 +57,16 @@ def _init_module():
def
test_search
():
def
test_search
():
traced_module
,
*
_
=
_init_block
()
traced_module
,
*
_
=
_init_block
()
graph
=
traced_module
.
graph
graph
=
traced_module
.
graph
relu_expr
=
graph
.
get_
call_function
(
F
.
relu
).
as_unique
()
relu_expr
=
graph
.
get_
function_by_type
(
F
.
relu
).
as_unique
()
assert
isinstance
(
relu_expr
,
CallFunction
)
and
relu_expr
.
func
==
F
.
relu
assert
isinstance
(
relu_expr
,
CallFunction
)
and
relu_expr
.
func
==
F
.
relu
def
test_insert
():
def
test_insert
():
traced_module
,
x
,
expect
=
_init_block
()
traced_module
,
x
,
expect
=
_init_block
()
graph
=
traced_module
.
graph
graph
=
traced_module
.
graph
relu_node
=
graph
.
get_
call_function
(
F
.
relu
).
as_unique
().
outputs
relu_node
=
graph
.
get_
function_by_type
(
F
.
relu
).
as_unique
().
outputs
neg_node
=
graph
.
insert_
call_function
(
F
.
neg
,
relu_node
)
neg_node
=
graph
.
insert_
function
(
lambda
x
:
F
.
neg
(
x
),
*
relu_node
)
graph
.
replace_node
({
relu_node
[
0
]:
neg_node
[
0
]
})
graph
.
replace_node
({
relu_node
[
0
]:
neg_node
})
graph
.
compile
()
graph
.
compile
()
np
.
testing
.
assert_allclose
(
expect
-
1
,
1
-
traced_module
(
x
),
atol
=
1e-6
)
np
.
testing
.
assert_allclose
(
expect
-
1
,
1
-
traced_module
(
x
),
atol
=
1e-6
)
...
@@ -74,7 +74,7 @@ def test_insert():
...
@@ -74,7 +74,7 @@ def test_insert():
def
test_delete
():
def
test_delete
():
traced_module
,
x
,
expect
=
_init_block
()
traced_module
,
x
,
expect
=
_init_block
()
graph
=
traced_module
.
graph
graph
=
traced_module
.
graph
relu_expr
=
graph
.
get_
call_function
(
F
.
relu
).
as_unique
()
relu_expr
=
graph
.
get_
function_by_type
(
F
.
relu
).
as_unique
()
node
=
relu_expr
.
outputs
node
=
relu_expr
.
outputs
repl_node
=
relu_expr
.
inputs
repl_node
=
relu_expr
.
inputs
graph
.
replace_node
({
node
[
0
]:
repl_node
[
0
]})
graph
.
replace_node
({
node
[
0
]:
repl_node
[
0
]})
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录