提交 829f0907 编写于 作者: M Megvii Engine Team

fix(mge/traced_module): fix insert qat module

GitOrigin-RevId: 35849bc1a26b10fbbba4a6ef72593e82c10a2b6d
上级 8b764934
......@@ -281,11 +281,8 @@ class _InsertExprs:
def __exit__(self, ty, va, tr):
if va is not None:
return False
set_symbolic_shape(self.use_sym_shape)
active_module_tracer().patcher.__exit__(ty, va, tr)
_set_convert_node_flag(False)
set_active_module_tracer(None)
unset_module_tracing()
while self._tensor_method_patch:
pf = self._tensor_method_patch.pop()
......@@ -298,6 +295,10 @@ class _InsertExprs:
v = v.build()
setattr(module, k, v)
set_symbolic_shape(self.use_sym_shape)
set_active_module_tracer(None)
unset_module_tracing()
extra_inp_nodes = set(self.global_scope.inputs)
max_inp_expr_idx = -1
for node in extra_inp_nodes:
......
......@@ -13,6 +13,7 @@ import numpy as np
import megengine.functional as F
import megengine.module as M
import megengine.module.qat as qat
from megengine.module.identity import Identity
from megengine.traced_module import trace_module
from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input
......@@ -199,6 +200,31 @@ def test_insert_module():
assert n.value is None
def test_insert_qat_module():
class concat(qat.Concat):
pass
traced_module, x, expect = _init_block()
graph = traced_module.graph
self = graph.inputs[0]
out = graph.outputs[0]
setattr(traced_module, "cat_0", qat.Concat())
setattr(traced_module, "cat_1", concat())
with graph.insert_exprs():
x_0 = self.cat_0([out, out])
x_1 = self.cat_1([out, x_0])
graph.replace_node({out: x_1})
graph.compile()
x = F.copy(x)
np.testing.assert_allclose(
F.concat([expect, expect, expect]), traced_module(x), atol=1e-6
)
assert not hasattr(traced_module.cat_0, "graph")
assert traced_module.cat_1.graph is not None
def test_add_input_and_output():
traced_module, x, y = _init_module()
......
......@@ -108,7 +108,6 @@ def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams):
def build_observered_net(net: M.Module, observer_cls):
qat_net = Q.quantize_qat(net, qconfig=get_observer_config(observer_cls))
Q.enable_observer(qat_net)
for _ in range(5):
inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
qat_net(inp)
Q.disable_observer(qat_net)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册