Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
829f0907
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
829f0907
编写于
11月 09, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/traced_module): fix insert qat module
GitOrigin-RevId: 35849bc1a26b10fbbba4a6ef72593e82c10a2b6d
上级
8b764934
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
32 addition
and
6 deletion
+32
-6
imperative/python/megengine/traced_module/traced_module.py
imperative/python/megengine/traced_module/traced_module.py
+4
-3
imperative/python/test/unit/traced_module/test_modification.py
...ative/python/test/unit/traced_module/test_modification.py
+26
-0
imperative/python/test/unit/traced_module/test_qat_module.py
imperative/python/test/unit/traced_module/test_qat_module.py
+2
-3
未找到文件。
imperative/python/megengine/traced_module/traced_module.py
浏览文件 @
829f0907
...
...
@@ -281,11 +281,8 @@ class _InsertExprs:
def
__exit__
(
self
,
ty
,
va
,
tr
):
if
va
is
not
None
:
return
False
set_symbolic_shape
(
self
.
use_sym_shape
)
active_module_tracer
().
patcher
.
__exit__
(
ty
,
va
,
tr
)
_set_convert_node_flag
(
False
)
set_active_module_tracer
(
None
)
unset_module_tracing
()
while
self
.
_tensor_method_patch
:
pf
=
self
.
_tensor_method_patch
.
pop
()
...
...
@@ -298,6 +295,10 @@ class _InsertExprs:
v
=
v
.
build
()
setattr
(
module
,
k
,
v
)
set_symbolic_shape
(
self
.
use_sym_shape
)
set_active_module_tracer
(
None
)
unset_module_tracing
()
extra_inp_nodes
=
set
(
self
.
global_scope
.
inputs
)
max_inp_expr_idx
=
-
1
for
node
in
extra_inp_nodes
:
...
...
imperative/python/test/unit/traced_module/test_modification.py
浏览文件 @
829f0907
...
...
@@ -13,6 +13,7 @@ import numpy as np
import
megengine.functional
as
F
import
megengine.module
as
M
import
megengine.module.qat
as
qat
from
megengine.module.identity
import
Identity
from
megengine.traced_module
import
trace_module
from
megengine.traced_module.expr
import
CallFunction
,
CallMethod
,
Expr
,
GetAttr
,
Input
...
...
@@ -199,6 +200,31 @@ def test_insert_module():
assert
n
.
value
is
None
def
test_insert_qat_module
():
class
concat
(
qat
.
Concat
):
pass
traced_module
,
x
,
expect
=
_init_block
()
graph
=
traced_module
.
graph
self
=
graph
.
inputs
[
0
]
out
=
graph
.
outputs
[
0
]
setattr
(
traced_module
,
"cat_0"
,
qat
.
Concat
())
setattr
(
traced_module
,
"cat_1"
,
concat
())
with
graph
.
insert_exprs
():
x_0
=
self
.
cat_0
([
out
,
out
])
x_1
=
self
.
cat_1
([
out
,
x_0
])
graph
.
replace_node
({
out
:
x_1
})
graph
.
compile
()
x
=
F
.
copy
(
x
)
np
.
testing
.
assert_allclose
(
F
.
concat
([
expect
,
expect
,
expect
]),
traced_module
(
x
),
atol
=
1e-6
)
assert
not
hasattr
(
traced_module
.
cat_0
,
"graph"
)
assert
traced_module
.
cat_1
.
graph
is
not
None
def
test_add_input_and_output
():
traced_module
,
x
,
y
=
_init_module
()
...
...
imperative/python/test/unit/traced_module/test_qat_module.py
浏览文件 @
829f0907
...
...
@@ -108,7 +108,6 @@ def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams):
def
build_observered_net
(
net
:
M
.
Module
,
observer_cls
):
qat_net
=
Q
.
quantize_qat
(
net
,
qconfig
=
get_observer_config
(
observer_cls
))
Q
.
enable_observer
(
qat_net
)
for
_
in
range
(
5
):
inp
=
Tensor
(
np
.
random
.
random
(
size
=
(
5
,
3
,
32
,
32
)))
qat_net
(
inp
)
Q
.
disable_observer
(
qat_net
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录