Commit f31752d5 authored by Megvii Engine Team

feat(mge/module): add __repr__ method for qat and quantized module

GitOrigin-RevId: 0d78cbab9a93801c62f0c5d079ec20ccc705ff17
Parent 9451a961
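The change follows one simple pattern: each QAT/Quantized module prefixes the repr inherited from the base Module class with "QAT." or "Quantized.", so converted layers stand out when a network is printed. A minimal, self-contained sketch of the idea (plain-Python stand-ins; the class names below are illustrative, not the real megengine.module classes):

class Module:
    # Stand-in for MegEngine's Module; the real __repr__ also renders child modules.
    def __repr__(self):
        return type(self).__name__ + "()"

class QATLinear(Module):
    # Same pattern as this commit: prefix the inherited repr.
    def __repr__(self):
        return "QAT." + super().__repr__()

print(repr(QATLinear()))  # prints: QAT.QATLinear()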
......
@@ -39,6 +39,9 @@ class QATModule(Module):
         self.weight_fake_quant = None  # type: FakeQuantize
         self.act_fake_quant = None  # type: FakeQuantize
 
+    def __repr__(self):
+        return "QAT." + super().__repr__()
+
     def set_qconfig(self, qconfig: QConfig):
         r"""
         Set quantization related configs with ``qconfig``, including
......
......
@@ -22,6 +22,9 @@ class QuantizedModule(Module):
             raise ValueError("quantized module only support inference.")
         return super().__call__(*inputs, **kwargs)
 
+    def __repr__(self):
+        return "Quantized." + super().__repr__()
+
     @classmethod
     @abstractmethod
     def from_qat_module(cls, qat_module: QATModule):
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import partial
import numpy as np
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import platform
import numpy as np
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import megengine.module as M
from megengine.quantization import quantize, quantize_qat


def test_repr():
    class Net(M.Module):
        def __init__(self):
            super().__init__()
            self.conv_bn = M.ConvBnRelu2d(3, 3, 3)
            self.linear = M.Linear(3, 3)

        def forward(self, x):
            return x

    net = Net()
    ground_truth = (
        "Net(\n"
        "  (conv_bn): ConvBnRelu2d(\n"
        "    (conv): Conv2d(3, 3, kernel_size=(3, 3))\n"
        "    (bn): BatchNorm2d(3, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
        "  )\n"
        "  (linear): Linear(in_features=3, out_features=3, bias=True)\n"
        ")"
    )
    assert net.__repr__() == ground_truth
    quantize_qat(net)
    ground_truth = (
        "Net(\n"
        "  (conv_bn): QAT.ConvBnRelu2d(\n"
        "    (conv): Conv2d(3, 3, kernel_size=(3, 3))\n"
        "    (bn): BatchNorm2d(3, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
        "    (act_observer): ExponentialMovingAverageObserver()\n"
        "    (act_fake_quant): FakeQuantize()\n"
        "    (weight_observer): MinMaxObserver()\n"
        "    (weight_fake_quant): FakeQuantize()\n"
        "  )\n"
        "  (linear): QAT.Linear(\n"
        "    in_features=3, out_features=3, bias=True\n"
        "    (act_observer): ExponentialMovingAverageObserver()\n"
        "    (act_fake_quant): FakeQuantize()\n"
        "    (weight_observer): MinMaxObserver()\n"
        "    (weight_fake_quant): FakeQuantize()\n"
        "  )\n"
        ")"
    )
    assert net.__repr__() == ground_truth
    quantize(net)
    ground_truth = (
        "Net(\n"
        "  (conv_bn): Quantized.ConvBnRelu2d(3, 3, kernel_size=(3, 3))\n"
        "  (linear): Quantized.Linear()\n"
        ")"
    )
    assert net.__repr__() == ground_truth