From 992a90bbaddb52e73d0f9d1fe2e22ddf92c213c9 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 16 Mar 2021 18:59:23 +0800 Subject: [PATCH] docs(mge/quantization): add docstring for Observer GitOrigin-RevId: 043be3886dc05205426fd60c9af3bc6172c70ce9 --- .../python/megengine/quantization/observer.py | 50 ++++++++++++++++++- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/imperative/python/megengine/quantization/observer.py b/imperative/python/megengine/quantization/observer.py index dbeaf821..381884ae 100644 --- a/imperative/python/megengine/quantization/observer.py +++ b/imperative/python/megengine/quantization/observer.py @@ -26,9 +26,10 @@ logger = get_logger(__name__) class Observer(Module, QParamsModuleMixin): r""" - A base class for Observer Module. + A base class for Observer Module. Used to record input tensor's statistics for + quantization. - :param dtype: a string indicating to collect scale and zero_point of which dtype. + :param dtype: a string indicating which dtype to collect scale and zero_point of. """ def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs): @@ -72,6 +73,14 @@ class Observer(Module, QParamsModuleMixin): class MinMaxObserver(Observer): + r""" + A Observer Module records input tensor's running min and max values to calc scale. + + :param mode: set quantization mode. + :param eps: a initial maximum value to avoid division by zero problem. + :param dtype: a string indicating which dtype to collect scale and zero_point of. + """ + def __init__( self, mode: QuantMode = QuantMode.SYMMERTIC, @@ -119,6 +128,14 @@ class MinMaxObserver(Observer): class SyncMinMaxObserver(MinMaxObserver): + r""" + A distributed version of :class:`~.MinMaxObserver`. + + :param mode: set quantization mode. + :param eps: a initial maximum value to avoid division by zero problem. + :param dtype: a string indicating which dtype to collect scale and zero_point of. + """ + def forward(self, x_orig): if self.enable: x = x_orig.detach() @@ -134,6 +151,15 @@ class SyncMinMaxObserver(MinMaxObserver): class ExponentialMovingAverageObserver(MinMaxObserver): + r""" + A :class:`~.MinMaxObserver` with momentum support for min/max updating. + + :param momentum: momentum ratio for min/max updating. + :param mode: set quantization mode. + :param eps: a initial maximum value to avoid division by zero problem. + :param dtype: a string indicating which dtype to collect scale and zero_point of. + """ + def __init__( self, momentum: float = 0.9, @@ -170,6 +196,15 @@ class ExponentialMovingAverageObserver(MinMaxObserver): class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver): + r""" + A distributed version of :class:`~.ExponentialMovingAverageObserver`. + + :param momentum: momentum ratio for min/max updating. + :param mode: set quantization mode. + :param eps: a initial maximum value to avoid division by zero problem. + :param dtype: a string indicating which dtype to collect scale and zero_point of. + """ + def forward(self, x_orig): if self.enabled: x = x_orig.detach() @@ -192,6 +227,17 @@ class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver): class HistogramObserver(MinMaxObserver): + r""" + A :class:`~.MinMaxObserver` using running histogram of tensor values + for min/max updating. Usually used for calibration quantization. + + :param bins: number of bins to use for the histogram. + :param upsample_rate: which ratio to interpolate histograms in. + :param mode: set quantization mode. + :param eps: a initial maximum value to avoid division by zero problem. + :param dtype: a string indicating which dtype to collect scale and zero_point of. + """ + def __init__( self, bins: int = 2048, -- GitLab