提交 855c49ca 编写于 作者: M Megvii Engine Team

feat(mge/module): add linear quantization module

GitOrigin-RevId: d0c96a94112724afd263b56f866986872f48cf3d
上级 90107b6d
......@@ -8,13 +8,13 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
from .. import functional as F
from ..core import Parameter
from ..functional import linear
from . import init
from .module import Module
from .module import QATModule
class Linear(Module):
class Linear(QATModule):
r"""Applies a linear transformation to the input. For instance, if input
is x, then output y is:
......@@ -55,5 +55,18 @@ class Linear(Module):
if self.bias is not None:
init.zeros_(self.bias)
def _calc_linear(self, x, weight, bias):
return F.linear(x, weight, bias)
def forward(self, x):
return linear(x, self.weight, self.bias)
return self._calc_linear(x, self.weight, self.bias)
def forward_qat(self, x):
w_qat = self.apply_fakequant_with_observer(
self.weight, self.weight_fake_quant, self.weight_observer
)
return self.apply_fakequant_with_observer(
self._calc_linear(x, w_qat, self.bias),
self.act_fake_quant,
self.act_observer,
)
......@@ -8,4 +8,5 @@
from .concat import Concat
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .elemwise import Elemwise
from .linear import Linear
from .quant_dequant import DequantStub, QuantStub
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import megengine._internal as mgb
from ... import functional as F
from ... import module as Float
from ...core import Parameter
from ...quantization.utils import register_method_to_class
from ..module import Module
class Linear(Module):
r"""Applies a quantized linear transformation to the input. The module
usually convert from QAT module by to_quantized method.
:param dtype: output data type.
"""
def __init__(
self, dtype: np.dtype = None,
):
super().__init__()
self.weight = None
self.bias = None
self.output_dtype = dtype
def forward(self, inp):
if self.training:
raise ValueError("quantized module only support inference.")
inp_scale = mgb.dtype.get_scale(inp.dtype)
w_scale = mgb.dtype.get_scale(self.weight.dtype)
bias_dtype = mgb.dtype.qint32(inp_scale * w_scale)
return F.linear(
inp,
self.weight,
None if self.bias is None else self.bias.astype(bias_dtype),
).astype(self.output_dtype)
@register_method_to_class(Float.Linear)
def to_quantized(float_module):
r"""
Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
implemented here to avoid circular import.
"""
output_dtype = float_module.act_observer.get_dtype()
qmod = Linear(dtype=output_dtype,)
weight = float_module.weight.astype(float_module.weight_observer.get_dtype())
qmod.weight = Parameter(weight.numpy())
if float_module.bias is not None:
qmod.bias = Parameter(float_module.bias.numpy())
return qmod
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册