提交 7a8a2830 编写于 作者: M Megvii Engine Team

feat(quant): support nnie quant

GitOrigin-RevId: 8ca3f828bdd55e9bfcec072831aa161e8b961fa8
上级 d49a2971
......@@ -49,6 +49,8 @@ class QATModule(Module):
def _apply_fakequant_with_observer(
self, target: Tensor, fake_quant: FakeQuantize, observer: Observer
):
if observer is None:
return target
oup = observer(target)
if fake_quant is None:
return oup
......@@ -76,7 +78,7 @@ class QATModule(Module):
r"""
Get the weight's quantization dtype, as determined by the method from ``qconfig``.
"""
if hasattr(self.act_fake_quant, "get_dtype"):
if hasattr(self.weight_fake_quant, "get_dtype"):
return self.weight_fake_quant.get_dtype()
else:
return self.weight_observer.get_dtype()
......
......@@ -5,7 +5,9 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .fake_quant import FakeQuantize
from .internal_fake_quant import *
from .observer import HistogramObserver, Observer, ObserverMode
from .qconfig import (
QConfig,
......
......@@ -19,6 +19,15 @@ from .observer import ObserverMode, Round
class _FakeQuantize(Module):
r"""
A Basic Fake Quant module.
:param dtype: A string indicating the target quantization type of input.
:param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``,
instead of 1 greater. Usually True for weight and False for activation.
:param enable: Whether to do ``normal_forward`` or ``fake_quant_forward``.
"""
def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True):
super().__init__()
if not dtype in _metadata_dict.keys():
......@@ -92,9 +101,9 @@ class TQT_Function(Function):
class TQT(_FakeQuantize):
"""
r"""
TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds
for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks
for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks.
"""
def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True):
......@@ -119,11 +128,6 @@ class TQT(_FakeQuantize):
class FakeQuantize(_FakeQuantize):
r"""
A module to do quant and dequant according to observer's scale and zero_point.
:param dtype: A string indicating the target quantization type of input.
:param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``,
instead of 1 greater. Usually True for weight and False for activation.
:param enable: Whether to do ``normal_forward`` or ``fake_quant_forward``.
"""
def fake_quant_forward(self, inp, q_dict):
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
import math
from functools import partial
import numpy as np
from .. import functional as F
from ..core import Function
from .fake_quant import _FakeQuantize
from .observer import MinMaxObserver
from .qconfig import QConfig
......@@ -13,6 +13,7 @@ import megengine as mge
import megengine._internal as mgb
from megengine.core import tensor
from megengine.quantization.fake_quant import TQT_Function
from megengine.quantization.internal_fake_quant import *
from megengine.test import assertTensorClose
......@@ -75,3 +76,5 @@ def test_TQT():
a.set_value(a_np)
b.set_value(b_np)
check_inp(a, b, b, a_np, b_np, b_np)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册