提交 9c90ce8c 作者: Megvii Engine Team

fix(mge/quantization): fix `quantize` and `quantize_qat`'s `set_expand_structure` arguments

GitOrigin-RevId: c61633095d7371728be14fd260c1a4de7f3bbd92
上级 c8697a70
...@@ -80,7 +80,7 @@ def quantize(module: Module, inplace: bool = True, mapping: dict = None): ...@@ -80,7 +80,7 @@ def quantize(module: Module, inplace: bool = True, mapping: dict = None):
module._flatten(with_key=True, with_parent=True, predicate=is_qat) module._flatten(with_key=True, with_parent=True, predicate=is_qat)
): ):
new_mod = convert_dict[type(submodule)].from_qat_module(submodule) new_mod = convert_dict[type(submodule)].from_qat_module(submodule)
set_expand_structure(parent, key, new_mod) set_expand_structure(module, key, new_mod)
return module return module
...@@ -123,7 +123,7 @@ def quantize_qat( ...@@ -123,7 +123,7 @@ def quantize_qat(
continue continue
new_mod = convert_dict[type(submodule)].from_float_module(submodule) new_mod = convert_dict[type(submodule)].from_float_module(submodule)
set_expand_structure(parent, key, new_mod) set_expand_structure(module, key, new_mod)
propagate_qconfig(module, qconfig) propagate_qconfig(module, qconfig)
return module return module
......
...@@ -37,9 +37,10 @@ class FloatNet(Float.Module): ...@@ -37,9 +37,10 @@ class FloatNet(Float.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.quant = Float.QuantStub() self.quant = Float.QuantStub()
self.linear = Float.Linear(3, 3) self.linear = Float.Sequential(Float.Linear(3, 3), Float.Linear(3, 3))
self.dequant = Float.DequantStub() self.dequant = Float.DequantStub()
self.linear.bias[...] = Parameter(np.random.rand(3)) self.linear[0].bias[...] = Parameter(np.random.rand(3))
self.linear[1].bias[...] = Parameter(np.random.rand(3))
def forward(self, x): def forward(self, x):
x = self.quant(x) x = self.quant(x)
...@@ -52,9 +53,10 @@ class QATNet(Float.Module): ...@@ -52,9 +53,10 @@ class QATNet(Float.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.quant = QAT.QuantStub() self.quant = QAT.QuantStub()
self.linear = QAT.Linear(3, 3) self.linear = Float.Sequential(QAT.Linear(3, 3), QAT.Linear(3, 3))
self.dequant = QAT.DequantStub() self.dequant = QAT.DequantStub()
self.linear.bias[...] = Parameter(np.random.rand(3)) self.linear[0].bias[...] = Parameter(np.random.rand(3))
self.linear[1].bias[...] = Parameter(np.random.rand(3))
def forward(self, x): def forward(self, x):
x = self.quant(x) x = self.quant(x)
...@@ -72,10 +74,14 @@ def test_propagate_qconfig(): ...@@ -72,10 +74,14 @@ def test_propagate_qconfig():
net.quant.weight_fake_quant is None, net.quant.weight_fake_quant is None,
isinstance(net.quant.act_observer, MinMaxObserver), isinstance(net.quant.act_observer, MinMaxObserver),
isinstance(net.quant.act_fake_quant, FakeQuantize), isinstance(net.quant.act_fake_quant, FakeQuantize),
isinstance(net.linear.weight_observer, MinMaxObserver), isinstance(net.linear[0].weight_observer, MinMaxObserver),
isinstance(net.linear.weight_fake_quant, FakeQuantize), isinstance(net.linear[0].weight_fake_quant, FakeQuantize),
isinstance(net.linear.act_observer, MinMaxObserver), isinstance(net.linear[0].act_observer, MinMaxObserver),
isinstance(net.linear.act_fake_quant, FakeQuantize), isinstance(net.linear[0].act_fake_quant, FakeQuantize),
isinstance(net.linear[1].weight_observer, MinMaxObserver),
isinstance(net.linear[1].weight_fake_quant, FakeQuantize),
isinstance(net.linear[1].act_observer, MinMaxObserver),
isinstance(net.linear[1].act_fake_quant, FakeQuantize),
net.dequant.weight_observer is None, net.dequant.weight_observer is None,
net.dequant.weight_fake_quant is None, net.dequant.weight_fake_quant is None,
net.dequant.act_observer is None, net.dequant.act_observer is None,
...@@ -91,10 +97,14 @@ def init_qat_net(): ...@@ -91,10 +97,14 @@ def init_qat_net():
max_val = np.random.randint(1, 127, size=(3,)) max_val = np.random.randint(1, 127, size=(3,))
net.quant.act_observer.min_val[...] = Parameter(min_val[0]) net.quant.act_observer.min_val[...] = Parameter(min_val[0])
net.quant.act_observer.max_val[...] = Parameter(max_val[0]) net.quant.act_observer.max_val[...] = Parameter(max_val[0])
net.linear.weight_observer.min_val[...] = Parameter(min_val[1]) net.linear[0].weight_observer.min_val[...] = Parameter(min_val[1])
net.linear.weight_observer.max_val[...] = Parameter(max_val[1]) net.linear[0].weight_observer.max_val[...] = Parameter(max_val[1])
net.linear.act_observer.min_val[...] = Parameter(min_val[2]) net.linear[0].act_observer.min_val[...] = Parameter(min_val[2])
net.linear.act_observer.max_val[...] = Parameter(max_val[2]) net.linear[0].act_observer.max_val[...] = Parameter(max_val[2])
net.linear[1].weight_observer.min_val[...] = Parameter(min_val[1])
net.linear[1].weight_observer.max_val[...] = Parameter(max_val[1])
net.linear[1].act_observer.min_val[...] = Parameter(min_val[2])
net.linear[1].act_observer.max_val[...] = Parameter(max_val[2])
return net return net
...@@ -102,11 +112,20 @@ def test_reset_qconfig(): ...@@ -102,11 +112,20 @@ def test_reset_qconfig():
qat_net = init_qat_net() qat_net = init_qat_net()
new_qat_net = reset_qconfig(qat_net, passive_qconfig) new_qat_net = reset_qconfig(qat_net, passive_qconfig)
assert ( assert (
new_qat_net.linear.get_weight_qparams() == qat_net.linear.get_weight_qparams() new_qat_net.linear[0].get_weight_qparams()
== qat_net.linear[0].get_weight_qparams()
) )
assert ( assert (
new_qat_net.linear.get_activation_qparams() new_qat_net.linear[0].get_activation_qparams()
== qat_net.linear.get_activation_qparams() == qat_net.linear[0].get_activation_qparams()
)
assert (
new_qat_net.linear[1].get_weight_qparams()
== qat_net.linear[1].get_weight_qparams()
)
assert (
new_qat_net.linear[1].get_activation_qparams()
== qat_net.linear[1].get_activation_qparams()
) )
...@@ -114,24 +133,32 @@ def test_enable_and_disable_observer(): ...@@ -114,24 +133,32 @@ def test_enable_and_disable_observer():
net = init_qat_net() net = init_qat_net()
enable_observer(net) enable_observer(net)
assert net.quant.act_observer.enabled is True assert net.quant.act_observer.enabled is True
assert net.linear.weight_observer.enabled is True assert net.linear[0].weight_observer.enabled is True
assert net.linear.act_observer.enabled is True assert net.linear[0].act_observer.enabled is True
assert net.linear[1].weight_observer.enabled is True
assert net.linear[1].act_observer.enabled is True
disable_observer(net) disable_observer(net)
assert net.quant.act_observer.enabled is False assert net.quant.act_observer.enabled is False
assert net.linear.weight_observer.enabled is False assert net.linear[0].weight_observer.enabled is False
assert net.linear.act_observer.enabled is False assert net.linear[0].act_observer.enabled is False
assert net.linear[1].weight_observer.enabled is False
assert net.linear[1].act_observer.enabled is False
def test_enable_and_disable_fake_quant(): def test_enable_and_disable_fake_quant():
net = init_qat_net() net = init_qat_net()
disable_fake_quant(net) disable_fake_quant(net)
assert net.quant.act_fake_quant.enabled is False assert net.quant.act_fake_quant.enabled is False
assert net.linear.weight_fake_quant.enabled is False assert net.linear[0].weight_fake_quant.enabled is False
assert net.linear.act_fake_quant.enabled is False assert net.linear[0].act_fake_quant.enabled is False
assert net.linear[1].weight_fake_quant.enabled is False
assert net.linear[1].act_fake_quant.enabled is False
enable_fake_quant(net) enable_fake_quant(net)
assert net.quant.act_fake_quant.enabled is True assert net.quant.act_fake_quant.enabled is True
assert net.linear.weight_fake_quant.enabled is True assert net.linear[0].weight_fake_quant.enabled is True
assert net.linear.act_fake_quant.enabled is True assert net.linear[0].act_fake_quant.enabled is True
assert net.linear[1].weight_fake_quant.enabled is True
assert net.linear[1].act_fake_quant.enabled is True
def init_observer(module, data): def init_observer(module, data):
...@@ -165,7 +192,8 @@ def test_quantize_qat(): ...@@ -165,7 +192,8 @@ def test_quantize_qat():
net = FloatNet() net = FloatNet()
qat_net = quantize_qat(net, inplace=False, qconfig=min_max_fakequant_qconfig) qat_net = quantize_qat(net, inplace=False, qconfig=min_max_fakequant_qconfig)
assert isinstance(qat_net.quant, QAT.QuantStub) assert isinstance(qat_net.quant, QAT.QuantStub)
assert isinstance(qat_net.linear, QAT.Linear) assert isinstance(qat_net.linear[0], QAT.Linear)
assert isinstance(qat_net.linear[1], QAT.Linear)
assert isinstance(qat_net.dequant, QAT.DequantStub) assert isinstance(qat_net.dequant, QAT.DequantStub)
...@@ -173,7 +201,8 @@ def test_quantize(): ...@@ -173,7 +201,8 @@ def test_quantize():
qat_net = init_qat_net() qat_net = init_qat_net()
q_net = quantize(qat_net, inplace=False) q_net = quantize(qat_net, inplace=False)
assert isinstance(q_net.quant, Q.QuantStub) assert isinstance(q_net.quant, Q.QuantStub)
assert isinstance(q_net.linear, Q.Linear) assert isinstance(q_net.linear[0], Q.Linear)
assert isinstance(q_net.linear[1], Q.Linear)
assert isinstance(q_net.dequant, Q.DequantStub) assert isinstance(q_net.dequant, Q.DequantStub)
...@@ -183,8 +212,10 @@ def test_apply_easy_quant(): ...@@ -183,8 +212,10 @@ def test_apply_easy_quant():
eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False) eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False)
apply_easy_quant(eq_net, data, 0.9, 1.1, 10) apply_easy_quant(eq_net, data, 0.9, 1.1, 10)
assert isinstance(eq_net.quant.act_observer, PassiveObserver) assert isinstance(eq_net.quant.act_observer, PassiveObserver)
assert isinstance(eq_net.linear.weight_observer, PassiveObserver) assert isinstance(eq_net.linear[0].weight_observer, PassiveObserver)
assert isinstance(eq_net.linear.act_observer, PassiveObserver) assert isinstance(eq_net.linear[0].act_observer, PassiveObserver)
assert isinstance(eq_net.linear[1].weight_observer, PassiveObserver)
assert isinstance(eq_net.linear[1].act_observer, PassiveObserver)
assert eq_net.dequant.act_observer is None assert eq_net.dequant.act_observer is None
...@@ -192,8 +223,10 @@ def test_apply_tqt(): ...@@ -192,8 +223,10 @@ def test_apply_tqt():
qat_net = init_qat_net() qat_net = init_qat_net()
tqt_net = reset_qconfig(qat_net, tqt_qconfig, inplace=False) tqt_net = reset_qconfig(qat_net, tqt_qconfig, inplace=False)
assert isinstance(tqt_net.quant.act_fake_quant, TQT) assert isinstance(tqt_net.quant.act_fake_quant, TQT)
assert isinstance(tqt_net.linear.weight_fake_quant, TQT) assert isinstance(tqt_net.linear[0].weight_fake_quant, TQT)
assert isinstance(tqt_net.linear.act_fake_quant, TQT) assert isinstance(tqt_net.linear[0].act_fake_quant, TQT)
assert isinstance(tqt_net.linear[1].weight_fake_quant, TQT)
assert isinstance(tqt_net.linear[1].act_fake_quant, TQT)
assert tqt_net.dequant.act_fake_quant is None assert tqt_net.dequant.act_fake_quant is None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册