Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a0b3a3c0
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
a0b3a3c0
编写于
10月 28, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(imperative): add TracedModule checker
GitOrigin-RevId: 12de7b278e28b7a3e37eb129c7f73c6660e8f300
上级
19993070
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
289 addition
and
22 deletion
+289
-22
imperative/python/megengine/traced_module/__init__.py
imperative/python/megengine/traced_module/__init__.py
+3
-0
imperative/python/megengine/traced_module/checker.py
imperative/python/megengine/traced_module/checker.py
+142
-0
imperative/python/megengine/traced_module/expr.py
imperative/python/megengine/traced_module/expr.py
+9
-0
imperative/python/megengine/traced_module/module_tracer.py
imperative/python/megengine/traced_module/module_tracer.py
+4
-0
imperative/python/megengine/traced_module/node.py
imperative/python/megengine/traced_module/node.py
+17
-0
imperative/python/megengine/traced_module/pytree.py
imperative/python/megengine/traced_module/pytree.py
+5
-1
imperative/python/megengine/traced_module/tm_config.py
imperative/python/megengine/traced_module/tm_config.py
+55
-0
imperative/python/megengine/traced_module/traced_module.py
imperative/python/megengine/traced_module/traced_module.py
+48
-18
imperative/python/megengine/traced_module/utils.py
imperative/python/megengine/traced_module/utils.py
+2
-3
imperative/python/test/unit/traced_module/test_qat_module.py
imperative/python/test/unit/traced_module/test_qat_module.py
+4
-0
未找到文件。
imperative/python/megengine/traced_module/__init__.py
浏览文件 @
a0b3a3c0
...
...
@@ -10,6 +10,7 @@ from ..core._imperative_rt.core2 import set_cpp_apply_module_trace
from
.
import
compat
from
._passes
import
optimize
from
.pytree
import
register_supported_type
from
.tm_config
import
disable_default_checker
,
enable_expr_checker
from
.traced_module
import
(
TracedModule
,
_register_all_builtin_module
,
...
...
@@ -29,4 +30,6 @@ __all__ = [
"wrap"
,
"TracedModule"
,
"optimize"
,
"enable_expr_checker"
,
"disable_default_checker"
,
]
imperative/python/megengine/traced_module/checker.py
0 → 100644
浏览文件 @
a0b3a3c0
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
traceback
from
typing
import
Sequence
import
numpy
as
np
from
..core._imperative_rt.core2
import
apply
from
..core._imperative_rt.ops
import
ROIAlign
,
ROIPooling
from
..core.ops.builtin
import
Copy
from
..core.tensor.utils
import
isscalar
,
setscalar
from
..tensor
import
Tensor
from
.tm_config
import
_exclude_from_trace
class
TracedModuleChecker
:
def
__init__
(
self
,
tracer
):
self
.
_active_node2values
=
[]
self
.
tracer
=
tracer
self
.
node_without_tensor_info
=
{}
def
push_scope
(
self
):
self
.
_active_node2values
.
append
({})
def
pop_scope
(
self
):
self
.
_active_node2values
.
pop
()
def
current_node2values
(
self
):
return
self
.
_active_node2values
[
-
1
]
def
reset_checker
(
self
):
self
.
_active_node2values
=
[]
def
check_node_not_in_scope
(
self
):
if
self
.
node_without_tensor_info
:
for
node
,
info
in
self
.
node_without_tensor_info
.
items
():
for
expr
in
info
[
0
].
_exprs
:
if
node
in
expr
.
inputs
or
node
in
expr
.
outputs
:
traceback
.
print_list
(
info
[
1
])
raise
ValueError
(
"node({}) not in the graph:
\n
{}"
.
format
(
node
,
info
[
0
])
)
return
True
else
:
return
False
def
check_net_outputs
(
self
,
tm_res
,
gt_res
):
if
isinstance
(
tm_res
,
Tensor
):
np
.
testing
.
assert_allclose
(
tm_res
.
numpy
(),
gt_res
.
numpy
())
elif
isinstance
(
tm_res
,
Sequence
):
for
i
,
j
in
zip
(
tm_res
,
gt_res
):
np
.
testing
.
assert_allclose
(
i
.
numpy
(),
j
.
numpy
())
else
:
for
k
in
tm_res
.
__dict__
.
keys
():
np
.
testing
.
assert_allclose
(
getattr
(
tm_res
,
k
).
numpy
(),
getattr
(
gt_res
,
k
).
numpy
()
)
def
record_nodemixin
(
self
,
node
,
value
):
self
.
current_node2values
()[
node
]
=
value
def
record_node2value
(
self
,
node
,
value
):
with
_exclude_from_trace
():
self
.
current_node2values
()[
node
]
=
apply
(
Copy
(
comp_node
=
value
.
device
),
value
)[
0
]
if
isscalar
(
value
):
setscalar
(
self
.
current_node2values
()[
node
])
def
check_apply_special_cases
(
self
,
opdef
,
num_outputs
):
indexs
=
list
(
range
(
num_outputs
))
if
isinstance
(
opdef
,
ROIAlign
)
and
opdef
.
mode
==
ROIAlign
.
Mode
.
AVERAGE
:
indexs
.
pop
(
-
1
)
if
isinstance
(
opdef
,
ROIPooling
)
and
opdef
.
mode
==
ROIPooling
.
Mode
.
AVERAGE
:
indexs
.
pop
(
-
1
)
return
indexs
def
check_expr_results
(
self
,
expr_outputs
,
gt_outputs
,
indexs
=
None
):
expr_outputs
=
(
(
expr_outputs
,)
if
not
isinstance
(
expr_outputs
,
Sequence
)
else
expr_outputs
)
gt_outputs
=
(
(
gt_outputs
,)
if
not
isinstance
(
gt_outputs
,
Sequence
)
else
gt_outputs
)
if
indexs
is
not
None
:
for
i
in
indexs
:
np
.
testing
.
assert_allclose
(
expr_outputs
[
i
].
numpy
(),
gt_outputs
[
i
].
numpy
()
)
else
:
np
.
testing
.
assert_allclose
(
expr_outputs
,
gt_outputs
)
def
get_node2value
(
self
,
inputs
,
start_idx
=
0
):
inp_values
=
[]
has_node_not_in_scope
=
False
for
i
in
range
(
start_idx
,
len
(
inputs
)):
try
:
inp_values
.
append
(
self
.
current_node2values
()[
inputs
[
i
]])
except
:
has_node_not_in_scope
=
True
self
.
node_without_tensor_info
[
inputs
[
i
]]
=
[
self
.
tracer
.
current_scope
(),
traceback
.
extract_stack
(),
]
return
inp_values
,
has_node_not_in_scope
def
check_expr_interpret
(
self
,
expr
,
gt_outputs
):
ori_in
,
has_node_not_in_scope
=
self
.
get_node2value
(
expr
.
inputs
)
if
not
has_node_not_in_scope
:
expr_res
=
expr
.
interpret
(
*
ori_in
)
try
:
self
.
check_expr_results
(
expr_res
,
gt_outputs
)
except
:
raise
ValueError
(
"Error occurred when checking expr: {}"
.
format
(
expr
))
def
check_apply
(
self
,
expr
,
gt_outputs
,
opdef
):
ori_in
,
has_node_not_in_scope
=
self
.
get_node2value
(
expr
.
inputs
)
if
not
has_node_not_in_scope
:
expr_res
=
expr
.
interpret
(
*
ori_in
)
indexs
=
self
.
check_apply_special_cases
(
opdef
,
len
(
gt_outputs
))
try
:
self
.
check_expr_results
(
expr_res
,
gt_outputs
,
indexs
=
indexs
)
except
:
raise
ValueError
(
"Error occurred when checking expr: {}"
.
format
(
expr
))
def
check_builtin_module
(
self
,
module
,
expr
,
gt_outputs
):
ori_in
,
has_node_not_in_scope
=
self
.
get_node2value
(
expr
.
inputs
,
start_idx
=
1
)
if
not
has_node_not_in_scope
:
ori_in
.
insert
(
0
,
module
)
expr_res
=
expr
.
interpret
(
*
ori_in
)
try
:
self
.
check_expr_results
(
expr_res
,
gt_outputs
)
except
:
raise
ValueError
(
"{}, Error occurred when checking expr: {}"
.
format
(
expr
)
)
imperative/python/megengine/traced_module/expr.py
浏览文件 @
a0b3a3c0
...
...
@@ -32,6 +32,7 @@ from .module_tracer import active_module_tracer, module_tracer
from
.node
import
ModuleNode
,
Node
,
NodeMixin
,
TensorNode
from
.pytree
import
ArgsIndex
,
TreeDef
,
_is_const_leaf
,
_is_leaf
,
tree_flatten
from
.serialization
import
_ModuleState
from
.tm_config
import
_exclude_from_trace
,
_get_expr_checker
from
.utils
import
_check_builtin_module_attr
,
_check_obj_attr
,
_convert_kwargs_to_args
...
...
@@ -611,6 +612,8 @@ class Apply(Expr):
inp_nodes
=
[
NodeMixin
.
get
(
inputs
[
0
])]
for
i
in
inputs
[
1
:]:
node
=
Constant
.
make
(
i
)
if
_get_expr_checker
():
active_module_tracer
().
checker
.
record_node2value
(
node
,
Tensor
(
i
))
inp_nodes
.
append
(
node
)
apply_node
=
cls
.
make
(
opdef
)
for
n
in
inp_nodes
:
...
...
@@ -624,11 +627,17 @@ class Apply(Expr):
unset_module_tracing
()
outputs
=
apply
(
opdef
,
*
inputs
)
outputs
=
list
(
map
(
Tensor
,
outputs
))
set_module_tracing
()
apply_node
.
add_outputs
(
outputs
)
for
n
,
v
in
zip
(
apply_node
.
outputs
,
outputs
):
NodeMixin
.
wrap_safe
(
v
,
n
)
if
_get_expr_checker
():
with
_exclude_from_trace
():
active_module_tracer
().
checker
.
check_apply
(
apply_node
,
outputs
,
opdef
)
return
list
(
outputs
)
...
...
imperative/python/megengine/traced_module/module_tracer.py
浏览文件 @
a0b3a3c0
...
...
@@ -12,6 +12,7 @@ from .. import functional as F
from
..core.tensor.array_method
import
ArrayMethodMixin
from
..module
import
Module
from
..module.qat
import
QATModule
from
.checker
import
TracedModuleChecker
_active_module_tracer
=
None
...
...
@@ -128,6 +129,7 @@ class module_tracer:
def
__init__
(
self
,
wrap_fn
):
self
.
_active_scopes
=
[]
self
.
checker
=
TracedModuleChecker
(
self
)
self
.
patcher
=
Patcher
(
wrap_fn
)
@
classmethod
...
...
@@ -142,9 +144,11 @@ class module_tracer:
def
push_scope
(
self
,
scope
):
self
.
_active_scopes
.
append
(
scope
)
self
.
checker
.
push_scope
()
def
pop_scope
(
self
):
self
.
_active_scopes
.
pop
()
self
.
checker
.
pop_scope
()
def
current_scope
(
self
):
if
self
.
_active_scopes
:
...
...
imperative/python/megengine/traced_module/node.py
浏览文件 @
a0b3a3c0
...
...
@@ -18,6 +18,8 @@ from ..core._imperative_rt.core2 import Tensor as RawTensor
from
..module
import
Module
from
..quantization.utils
import
QParams
from
..tensor
import
Tensor
from
.module_tracer
import
active_module_tracer
from
.tm_config
import
_get_expr_checker
from
.utils
import
_check_obj_attr
logger
=
get_logger
(
__name__
)
...
...
@@ -343,6 +345,11 @@ class NodeMixin(abc.ABC):
if
isinstance
(
value
,
NodeMixin
):
value
.
_record_wrapped_nodes
(
node
)
setattr
(
value
,
"_NodeMixin__node"
,
node
)
if
_get_expr_checker
():
if
isinstance
(
value
,
RawTensor
):
active_module_tracer
().
checker
.
record_node2value
(
node
,
value
)
if
isinstance
(
value
,
NodeMixin
):
active_module_tracer
().
checker
.
record_nodemixin
(
node
,
value
)
else
:
assert
callable
(
node
)
n
=
node
()
...
...
@@ -352,6 +359,11 @@ class NodeMixin(abc.ABC):
if
isinstance
(
value
,
NodeMixin
):
value
.
_record_wrapped_nodes
(
n
)
setattr
(
value
,
"_NodeMixin__node"
,
n
)
if
_get_expr_checker
():
if
isinstance
(
value
,
RawTensor
):
active_module_tracer
().
checker
.
record_node2value
(
n
,
value
)
if
isinstance
(
value
,
NodeMixin
):
active_module_tracer
().
checker
.
record_nodemixin
(
n
,
value
)
@
classmethod
def
wrap_safe
(
cls
,
value
,
node
):
...
...
@@ -359,6 +371,11 @@ class NodeMixin(abc.ABC):
if
isinstance
(
value
,
RawTensor
):
cls
.
_record_tensornode_property
(
node
,
value
)
setattr
(
value
,
"_NodeMixin__node"
,
node
)
if
_get_expr_checker
():
if
isinstance
(
value
,
RawTensor
):
active_module_tracer
().
checker
.
record_node2value
(
node
,
value
)
if
isinstance
(
value
,
NodeMixin
):
active_module_tracer
().
checker
.
record_nodemixin
(
node
,
value
)
if
isinstance
(
value
,
NodeMixin
):
value
.
_record_wrapped_nodes
(
node
)
...
...
imperative/python/megengine/traced_module/pytree.py
浏览文件 @
a0b3a3c0
...
...
@@ -212,7 +212,11 @@ def tree_flatten(
to reconstruct the pytree.
"""
if
type
(
values
)
not
in
SUPPORTED_TYPE
:
assert
is_leaf
(
values
),
values
assert
is_leaf
(
values
),
'doesn
\'
t support {} type, MUST use "register_supported_type" method to register self-defined type'
.
format
(
values
)
node
=
LeafDef
(
leaf_type
(
values
))
if
is_const_leaf
(
values
):
node
.
const_val
=
values
...
...
imperative/python/megengine/traced_module/tm_config.py
0 → 100644
浏览文件 @
a0b3a3c0
import
contextlib
from
..core._imperative_rt.core2
import
(
is_tracing_module
,
set_module_tracing
,
unset_module_tracing
,
)
_enable_expr_checker
=
False
_enable_default_checker
=
True
def
_get_expr_checker
():
return
_enable_expr_checker
def
_get_default_checker
():
return
_enable_default_checker
def
enable_expr_checker
():
r
"""Call this function to check the result of each expr during tracing."""
global
_enable_expr_checker
_enable_expr_checker
=
True
_enable_default_checker
=
False
def
disable_default_checker
():
r
"""Call this function to disable checking the final output of the model after tracing."""
global
_enable_default_checker
_enable_default_checker
=
False
_enable_graph_surgery_mode
=
False
def
_graph_surgery_mode
():
return
_enable_graph_surgery_mode
def
_set_graph_surgery_mode
(
mode
:
bool
):
global
_enable_graph_surgery_mode
pre_mode
=
_enable_graph_surgery_mode
_enable_graph_surgery_mode
=
mode
return
pre_mode
@
contextlib
.
contextmanager
def
_exclude_from_trace
():
is_tracing
=
is_tracing_module
()
if
is_tracing
:
unset_module_tracing
()
yield
if
is_tracing
:
set_module_tracing
()
imperative/python/megengine/traced_module/traced_module.py
浏览文件 @
a0b3a3c0
...
...
@@ -36,11 +36,14 @@ from .. import get_logger
from
..
import
module
as
M
from
..core._imperative_rt.core2
import
Tensor
as
RawTensor
from
..core._imperative_rt.core2
import
(
apply
,
is_tracing_module
,
set_module_tracing
,
unset_module_tracing
,
)
from
..core._trace_option
import
set_symbolic_shape
from
..core.ops.builtin
import
Copy
from
..core.tensor.utils
import
isscalar
,
setscalar
from
..module
import
Module
from
..module
import
external
as
MExternal
from
..module.qat
import
QATModule
...
...
@@ -98,6 +101,13 @@ from .serialization import (
load_call_tensor_method_expr
,
load_functional
,
)
from
.tm_config
import
(
_exclude_from_trace
,
_get_default_checker
,
_get_expr_checker
,
_graph_surgery_mode
,
_set_graph_surgery_mode
,
)
from
.utils
import
(
_check_builtin_module_attr
,
_check_obj_attr
,
...
...
@@ -117,26 +127,14 @@ def _is_builtin_name(name: str) -> bool:
def
_is_leaf
(
node
):
assert
isinstance
(
node
,
RawTensor
),
"doesn't support {} in return values"
.
format
(
assert
isinstance
(
node
,
RawTensor
),
'doesn
\'
t support {} in return values, MUST use Tensor or use "register_supported_type" method to register self-defined type'
.
format
(
type
(
node
)
)
return
isinstance
(
node
,
RawTensor
)
_enable_graph_surgery_mode
=
False
def
_graph_surgery_mode
():
return
_enable_graph_surgery_mode
def
_set_graph_surgery_mode
(
mode
:
bool
):
global
_enable_graph_surgery_mode
pre_mode
=
_enable_graph_surgery_mode
_enable_graph_surgery_mode
=
mode
return
pre_mode
def
_node_to_tensor
(
*
args
,
**
kwargs
):
tensors
=
[]
nodes
,
tree_def
=
tree_flatten
((
args
,
kwargs
))
...
...
@@ -1295,7 +1293,12 @@ def _wrapped_function(orig_func):
return
orig_func
(
*
args
,
**
kwargs
)
if
isinstance
(
args
[
1
],
RawTensor
):
node
=
NodeMixin
.
get
(
inputs
[
1
])
inputs
[
1
]
=
copy
.
copy
(
inputs
[
1
])
is_scalar
=
isscalar
(
inputs
[
1
])
inputs
[
1
]
=
apply
(
Copy
(
comp_node
=
inputs
[
1
].
device
),
Tensor
(
inputs
[
1
])
)[
0
]
if
is_scalar
:
setscalar
(
inputs
[
1
])
# copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor,
# which will cause they have same _NodeMixin__node in tracing.
NodeMixin
.
wrap_safe
(
inputs
[
1
],
node
)
...
...
@@ -1319,6 +1322,13 @@ def _wrapped_function(orig_func):
else
:
outputs
=
None
call_node
.
add_outputs
(
outputs
)
if
_get_expr_checker
():
with
_exclude_from_trace
():
active_module_tracer
().
checker
.
check_expr_interpret
(
call_node
,
outputs
)
set_module_tracing
()
return
rst
return
orig_func
(
*
args
,
**
kwargs
)
...
...
@@ -1500,6 +1510,12 @@ class TracedModuleBuilder(NodeMixin):
unset_module_tracing
()
rst
=
self
.
_mod
(
*
args
,
**
kwargs
)
outputs
,
out_def
=
tree_flatten
(
rst
,
is_leaf
=
_is_leaf
)
if
_get_expr_checker
():
with
_exclude_from_trace
():
tmp
=
self
.
build
()
active_module_tracer
().
checker
.
check_builtin_module
(
tmp
,
callnode
,
outputs
)
set_module_tracing
()
if
self
.
_is_builtin
:
self
.
_body
=
None
...
...
@@ -1674,7 +1690,9 @@ class TracedModuleBuilder(NodeMixin):
if
not
isinstance
(
mod_attr
,
(
List
,
Dict
,
QATModule
)):
assert
mod_attr
is
wrapped
.
_mod
else
:
assert
mod_attr
is
wrapped
assert
(
mod_attr
is
wrapped
),
"TracedModule do not support modify attributes, please check your code."
if
isinstance
(
wrapped
,
(
NodeMixin
,
RawTensor
)):
NodeMixin
.
wrap
(
...
...
@@ -2469,11 +2487,23 @@ def trace_module(
qualname
=
"{}.[{}]"
.
format
(
net_name
,
"arg_{}"
.
format
(
_
)),
),
)
builder
(
*
args
,
**
kwargs
)
rst
=
builder
(
*
copy
.
deepcopy
(
args
),
**
copy
.
deepcopy
(
kwargs
)
)
active_module_tracer
().
pop_scope
()
traced_mod
=
builder
.
build
()
traced_mod
.
argspec
=
forward_argspec
traced_mod
.
graph
.
_reset_ids
()
has_expr_not_check
=
False
if
_get_expr_checker
():
has_expr_not_check
=
(
active_module_tracer
().
checker
.
check_node_not_in_scope
()
)
if
_get_default_checker
()
or
has_expr_not_check
:
with
_exclude_from_trace
():
tm_res
=
traced_mod
(
*
args
,
**
kwargs
)
tm_res
,
_
=
tree_flatten
(
tm_res
,
is_leaf
=
_is_leaf
)
rst
,
_
=
tree_flatten
(
rst
,
is_leaf
=
_is_leaf
)
active_module_tracer
().
checker
.
check_net_outputs
(
tm_res
,
rst
)
return
traced_mod
finally
:
set_symbolic_shape
(
use_sym_shape
)
...
...
imperative/python/megengine/traced_module/utils.py
浏览文件 @
a0b3a3c0
...
...
@@ -5,16 +5,15 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
collections
import
copy
import
inspect
from
collections.abc
import
MutableMapping
,
MutableSequence
from
inspect
import
FullArgSpec
from
typing
import
Callable
,
Dict
,
Iterable
,
List
,
Optional
,
Sequence
,
Type
,
Union
from
typing
import
Callable
,
Dict
,
Iterable
,
List
,
Optional
,
Sequence
,
Union
from
..
import
get_logger
from
..module
import
Module
from
..tensor
import
Parameter
,
Tensor
from
..tensor
import
Tensor
logger
=
get_logger
(
__name__
)
...
...
imperative/python/test/unit/traced_module/test_qat_module.py
浏览文件 @
a0b3a3c0
...
...
@@ -109,6 +109,7 @@ def build_observered_net(net: M.Module, observer_cls):
)
Q
.
enable_observer
(
qat_net
)
inp
=
Tensor
(
np
.
random
.
random
(
size
=
(
5
,
3
,
32
,
32
)))
qat_net
.
eval
()
qat_net
(
inp
)
Q
.
disable_observer
(
qat_net
)
return
qat_net
...
...
@@ -116,6 +117,7 @@ def build_observered_net(net: M.Module, observer_cls):
def
build_fakequanted_net
(
net
:
QATModule
,
fakequant_cls
):
qat_net
=
Q
.
reset_qconfig
(
net
,
get_lsq_config
(
fakequant_cls
))
qat_net
.
eval
()
return
qat_net
...
...
@@ -162,6 +164,7 @@ def test_load_param():
def
_check_module
(
build_func
:
Callable
):
net
=
build_func
()
net
.
eval
()
buffer
=
io
.
BytesIO
()
mge
.
save
(
net
.
state_dict
(),
buffer
)
buffer
.
seek
(
0
)
...
...
@@ -185,6 +188,7 @@ def test_load_param():
def
test_qualname
():
def
_check_qualname
(
net
):
inp
=
Tensor
(
np
.
random
.
random
(
size
=
(
5
,
3
,
32
,
32
)))
net
.
eval
()
traced_net
=
trace_module
(
net
,
inp
)
base_qualname
=
traced_net
.
graph
.
qualname
for
node
in
traced_net
.
graph
.
nodes
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录