diff --git a/imperative/python/megengine/quantization/observer.py b/imperative/python/megengine/quantization/observer.py index dbeaf821b0b744cfbae347ae4f61a912c0a3fa71..381884ae9885a7c448444114653fd90ae2f530ee 100644 --- a/imperative/python/megengine/quantization/observer.py +++ b/imperative/python/megengine/quantization/observer.py @@ -26,9 +26,10 @@ logger = get_logger(__name__) class Observer(Module, QParamsModuleMixin): r""" - A base class for Observer Module. + A base class for Observer Module. Used to record input tensor's statistics for + quantization. - :param dtype: a string indicating to collect scale and zero_point of which dtype. + :param dtype: a string indicating which dtype to collect scale and zero_point of. """ def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs): @@ -72,6 +73,14 @@ class Observer(Module, QParamsModuleMixin): class MinMaxObserver(Observer): + r""" + A Observer Module records input tensor's running min and max values to calc scale. + + :param mode: set quantization mode. + :param eps: a initial maximum value to avoid division by zero problem. + :param dtype: a string indicating which dtype to collect scale and zero_point of. + """ + def __init__( self, mode: QuantMode = QuantMode.SYMMERTIC, @@ -119,6 +128,14 @@ class MinMaxObserver(Observer): class SyncMinMaxObserver(MinMaxObserver): + r""" + A distributed version of :class:`~.MinMaxObserver`. + + :param mode: set quantization mode. + :param eps: a initial maximum value to avoid division by zero problem. + :param dtype: a string indicating which dtype to collect scale and zero_point of. + """ + def forward(self, x_orig): if self.enable: x = x_orig.detach() @@ -134,6 +151,15 @@ class SyncMinMaxObserver(MinMaxObserver): class ExponentialMovingAverageObserver(MinMaxObserver): + r""" + A :class:`~.MinMaxObserver` with momentum support for min/max updating. + + :param momentum: momentum ratio for min/max updating. + :param mode: set quantization mode. + :param eps: a initial maximum value to avoid division by zero problem. + :param dtype: a string indicating which dtype to collect scale and zero_point of. + """ + def __init__( self, momentum: float = 0.9, @@ -170,6 +196,15 @@ class ExponentialMovingAverageObserver(MinMaxObserver): class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver): + r""" + A distributed version of :class:`~.ExponentialMovingAverageObserver`. + + :param momentum: momentum ratio for min/max updating. + :param mode: set quantization mode. + :param eps: a initial maximum value to avoid division by zero problem. + :param dtype: a string indicating which dtype to collect scale and zero_point of. + """ + def forward(self, x_orig): if self.enabled: x = x_orig.detach() @@ -192,6 +227,17 @@ class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver): class HistogramObserver(MinMaxObserver): + r""" + A :class:`~.MinMaxObserver` using running histogram of tensor values + for min/max updating. Usually used for calibration quantization. + + :param bins: number of bins to use for the histogram. + :param upsample_rate: which ratio to interpolate histograms in. + :param mode: set quantization mode. + :param eps: a initial maximum value to avoid division by zero problem. + :param dtype: a string indicating which dtype to collect scale and zero_point of. + """ + def __init__( self, bins: int = 2048,