# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
from abc import abstractmethod
from copy import deepcopy
from typing import Union

import numpy as np

from .. import functional as F
from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
from ..distributed import WORLD, get_rank, is_distributed
from ..functional.distributed import all_reduce_max, all_reduce_min
from ..logger import get_logger
from ..module import Module
from ..tensor import Tensor
from .utils import QParams, QParamsModuleMixin, QuantMode, create_qparams

logger = get_logger(__name__)


class Observer(Module, QParamsModuleMixin):
    r"""
29 30
    A base class for Observer Module. Used to record input tensor's statistics for
    quantization.
31

32
    :param dtype: a string indicating which dtype to collect scale and zero_point of.
33 34
    """

    def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
        super().__init__()
        if isinstance(dtype, str):
            if dtype not in _builtin_quant_dtypes:
                raise ValueError(
                    "unknown dtype: {}, only support {}".format(
                        dtype, _builtin_quant_dtypes.keys()
                    )
                )
            dtype = _builtin_quant_dtypes[dtype]
        if "narrow_range" in kwargs:
            del kwargs["narrow_range"]
            logger.warning(
                "FakeQuantize currently has no narrow_range param "
                "so it is ignored here",
                exc_info=DeprecationWarning,
            )
        self.dtype = dtype
        self.qmin = dtype.qmin
        self.qmax = dtype.qmax
        self.enabled = True

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    def train(self, mode: bool = True, recursive: bool = True) -> None:
        super().train(mode, recursive)
        if mode:
            self.enable()
        else:
            self.disable()

    @abstractmethod
    def forward(self, x):
        pass
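

# A minimal sketch (hypothetical, not part of the module API) of how a concrete
# Observer plugs in: ``forward`` records statistics from detached inputs and
# ``get_qparams`` reports them, while ``train()``/``eval()`` toggle recording
# through ``enable()``/``disable()``.
def _example_custom_observer():
    class AbsMaxObserver(Observer):
        def __init__(self, dtype="qint8", **kwargs):
            super().__init__(dtype, **kwargs)
            self.max_abs = Tensor(0.0, dtype="float32")

        def forward(self, x):
            if self.enabled:
                # record statistics on a detached copy so no gradient flows
                self.max_abs[...] = F.maximum(self.max_abs, F.abs(x.detach()).max())
            return x

        def get_qparams(self):
            scale = F.maximum(self.max_abs / ((self.qmax - self.qmin) / 2), 1e-5)
            return create_qparams(QuantMode.SYMMERTIC, self.dtype, scale=scale)

    obs = AbsMaxObserver()
    obs(Tensor(np.array([-2.0, 0.5], dtype=np.float32)))
    return obs.get_qparams()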


class MinMaxObserver(Observer):
    r"""
    An Observer module that records an input tensor's running min and max
    values to calculate the scale.

    :param mode: set quantization mode.
    :param eps: a lower bound of the scale, to avoid division-by-zero problems.
    :param dtype: a string indicating which dtype to collect scale and zero_point of.
    """

    def __init__(
        self,
        mode: QuantMode = QuantMode.SYMMERTIC,
        eps: float = 0.00001,
        dtype: Union[str, QuantDtypeMeta] = "qint8",
        **kwargs
    ):
        super().__init__(dtype, **kwargs)
        self.mode = mode
        self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32)
        self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32)
        self.scale_limit = eps

    def _calculate_qparams(self, inp_min_val, inp_max_val):
        min_val = F.minimum(0.0, inp_min_val)
        max_val = F.maximum(0.0, inp_max_val)
        if self.mode == QuantMode.SYMMERTIC:
            symmetric_max_vals = F.maximum(-min_val, max_val)
            # use maximum to avoid the scale being too small at the beginning
            scale = F.maximum(
                symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
            )
            zero_point = None
        else:
            # use maximum to avoid the scale being too small at the beginning
            scale = F.maximum(
                (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit
            )
            # calculate zero_point
            zero_point = self.qmin - F.round(min_val / scale)

        return create_qparams(self.mode, self.dtype, scale=scale, zero_point=zero_point)

    def get_qparams(self):
        return self._calculate_qparams(self.min_val, self.max_val)

    def forward(self, x_orig):
        if self.enabled:
            # stop gradient
            x = x_orig.detach()
            # find max and min
            self.min_val[...] = F.minimum(self.min_val, x.min())
            self.max_val[...] = F.maximum(self.max_val, x.max())
        return x_orig
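

# A worked sketch (hypothetical helper, not part of the module API) of the
# update and the symmetric scale formula above: after observing values in
# [-2, 3], min_val = -2 and max_val = 3, so for qint8 (qmin = -128, qmax = 127)
# the scale is max(2, 3) / ((127 - (-128)) / 2) = 3 / 127.5.
def _example_minmax_qparams():
    obs = MinMaxObserver(dtype="qint8")
    obs(Tensor(np.array([-2.0, 0.5, 3.0], dtype=np.float32)))
    qparams = obs.get_qparams()  # symmetric mode: zero_point is None
    return qparams.scale         # ~= 3 / 127.5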


class SyncMinMaxObserver(MinMaxObserver):
    r"""
    A distributed version of :class:`~.MinMaxObserver`.

    :param mode: set quantization mode.
    :param eps: a lower bound of the scale, to avoid division-by-zero problems.
    :param dtype: a string indicating which dtype to collect scale and zero_point of.
    """

    def forward(self, x_orig):
        if self.enabled:
            x = x_orig.detach()
            if is_distributed():
                min_x = all_reduce_min(x.min(), WORLD)
                max_x = all_reduce_max(x.max(), WORLD)
            else:
                min_x = x.min()
                max_x = x.max()
            self.min_val[...] = F.minimum(self.min_val, min_x)
            self.max_val[...] = F.maximum(self.max_val, max_x)
        return x_orig
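

# In the distributed case every rank sees only its own shard of the data, so
# the local min/max are all-reduced across ``WORLD`` before updating the
# running statistics; a single-process run degrades to plain MinMaxObserver
# behaviour. A hedged launch sketch (assumes ``megengine.distributed.launcher``;
# illustration only):
#
#     import megengine.distributed as dist
#
#     @dist.launcher
#     def calibrate(shard):
#         obs = SyncMinMaxObserver(dtype="qint8")
#         obs(shard)                   # min/max agreed upon by all ranks
#         return obs.get_qparams()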


class ExponentialMovingAverageObserver(MinMaxObserver):
    r"""
    A :class:`~.MinMaxObserver` with momentum support for min/max updating.

    :param momentum: momentum ratio for min/max updating.
    :param mode: set quantization mode.
    :param eps: a lower bound of the scale, to avoid division-by-zero problems.
    :param dtype: a string indicating which dtype to collect scale and zero_point of.
    """

    def __init__(
        self,
        momentum: float = 0.9,
        mode: QuantMode = QuantMode.SYMMERTIC,
        eps: float = 0.00001,
        dtype: Union[str, QuantDtypeMeta] = "qint8",
        **kwargs
    ):
        super().__init__(mode, eps, dtype, **kwargs)
        self.momentum = Tensor(momentum, dtype="float32")
        # Used to avoid if-clauses in the first forward, which are not
        # supported in trace mode.
        self.runtime_momentum = Tensor(0.0)

    def set_momentum(self, momentum):
        self.momentum = Tensor(momentum, dtype="float32")

    def forward(self, x_orig):
        if self.enabled:
            # stop gradient
            x = x_orig.detach()
            # Exponential Moving Average
            self.min_val[...] = (
                self.min_val * self.runtime_momentum
                + (1 - self.runtime_momentum) * x.min()
            )
            self.max_val[...] = (
                self.max_val * self.runtime_momentum
                + (1 - self.runtime_momentum) * x.max()
            )
            self.runtime_momentum[...] = self.momentum

        return x_orig
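

# A worked sketch (hypothetical helper, not part of the module API) of the EMA
# update above: runtime_momentum is 0 on the first forward, so min/max come
# straight from the data; afterwards each step computes
# new = old * m + (1 - m) * current. With m = 0.9, a previous max of 4.0 and a
# current batch max of 2.0, the running max becomes 0.9 * 4.0 + 0.1 * 2.0 = 3.8.
def _example_ema_update():
    obs = ExponentialMovingAverageObserver(momentum=0.9, dtype="qint8")
    obs(Tensor(np.array([-1.0, 4.0], dtype=np.float32)))  # min=-1.0, max=4.0
    obs(Tensor(np.array([-3.0, 2.0], dtype=np.float32)))  # max -> 3.8, min -> -1.2
    return obs.min_val, obs.max_val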


class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
    r"""
    A distributed version of :class:`~.ExponentialMovingAverageObserver`.

    :param momentum: momentum ratio for min/max updating.
    :param mode: set quantization mode.
    :param eps: a lower bound of the scale, to avoid division-by-zero problems.
    :param dtype: a string indicating which dtype to collect scale and zero_point of.
    """

    def forward(self, x_orig):
        if self.enabled:
            x = x_orig.detach()
            if is_distributed():
                min_x = all_reduce_min(x.min(), WORLD)
                max_x = all_reduce_max(x.max(), WORLD)
            else:
                min_x = x.min()
                max_x = x.max()
            self.min_val[...] = (
                self.min_val * self.runtime_momentum
                + (1 - self.runtime_momentum) * min_x
            )
            self.max_val[...] = (
                self.max_val * self.runtime_momentum
                + (1 - self.runtime_momentum) * max_x
            )
            self.runtime_momentum[...] = self.momentum
        return x_orig


class HistogramObserver(MinMaxObserver):
    r"""
    A :class:`~.MinMaxObserver` that uses a running histogram of tensor values
    for min/max updating. Usually used for calibration quantization.

    :param bins: number of bins to use for the histogram.
    :param upsample_rate: the ratio at which to upsample the histogram when
        combining it with new data.
    :param mode: set quantization mode.
    :param eps: a lower bound of the scale, to avoid division-by-zero problems.
    :param dtype: a string indicating which dtype to collect scale and zero_point of.
    """

    def __init__(
        self,
        bins: int = 2048,
        upsample_rate: int = 128,
        mode: QuantMode = QuantMode.SYMMERTIC,
        eps: float = 0.00001,
        dtype: Union[str, QuantDtypeMeta] = "qint8",
        **kwargs
    ):
        super().__init__(mode, eps, dtype, **kwargs)
        self.bins = bins
        self.upsample_rate = upsample_rate
        self.dst_nbins = self.qmax - self.qmin + 1
        # the leading -1 marks the histogram as uninitialized
        # (checked in sideeffect_forward)
        self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32")

    def _non_linear_param_search(self):
        r"""
        Non-linear parameter search.

        An approximation of L2-error minimization for selecting min/max.
        By selecting a new min/max, we filter out outliers in the input
        distribution.
        """

        np_min_val = self.min_val.numpy()
        np_max_val = self.max_val.numpy()
        np_histogram = self.histogram.numpy()
        assert len(np_histogram) == self.bins, "bins mismatch"
        bin_width = (np_max_val - np_min_val) / self.bins

        def _get_norm(delta_begin, delta_end, density, norm_type):
            r"""
            Compute the norm of the values uniformly distributed between
            delta_begin and delta_end.
            norm = density * (integral_{begin, end} x^2)
                 = density * (end^3 - begin^3) / 3
            """
            assert norm_type == "L2", "Only L2 norms are currently supported"
            norm = 0.0
            if norm_type == "L2":
                norm = (
                    delta_end * delta_end * delta_end
                    - delta_begin * delta_begin * delta_begin
                ) / 3
            return density * norm

        def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
            r"""
            Compute the quantization error if we use start_bin to end_bin as the
            min and max to do the quantization.
            """

            norm = 0.0
            dst_bin_width = (
                bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
            )
            if dst_bin_width == 0.0:
                return 0.0
            for src_bin in range(self.bins):
                # distances from the beginning of first dst_bin to the beginning and
                # end of src_bin
                src_bin_begin = (src_bin - next_start_bin) * bin_width
                src_bin_end = src_bin_begin + bin_width

                # which dst_bins the beginning and end of src_bin belong to?
                dst_bin_of_begin = min(
                    self.dst_nbins - 1,
                    max(0.0, math.floor(src_bin_begin / dst_bin_width)),
                )
                dst_bin_of_end = min(
                    self.dst_nbins - 1,
                    max(0.0, math.floor(src_bin_end / dst_bin_width)),
                )
                dst_bin_of_begin_center = (
                    dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
                )

                density = np_histogram[src_bin] / bin_width
                if dst_bin_of_begin == dst_bin_of_end:
                    # if src_bin is entirely within 1 dst_bin
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = src_bin_end - dst_bin_of_begin_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
                else:
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = dst_bin_width / 2
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)

                    norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
                        -dst_bin_width / 2, dst_bin_width / 2, density, norm_type
                    )

                    dst_bin_of_end_center = (
                        dst_bin_of_end * dst_bin_width + dst_bin_width / 2
                    )

                    delta_begin = -dst_bin_width / 2
                    delta_end = src_bin_end - dst_bin_of_end_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
            return norm

        # cumulative sum
        total = sum(np_histogram)
        cSum = np.cumsum(np_histogram, axis=0)

        stepsize = 1e-5  # granularity
        alpha = 0.0  # lower bound
        beta = 1.0  # upper bound
        start_bin = 0
        end_bin = self.bins - 1
        norm_min = float("inf")

        while alpha < beta:
            # Find the next step
            next_alpha = alpha + stepsize
            next_beta = beta - stepsize

            # find the left and right bins between the quantile bounds
            l = start_bin
            r = end_bin
            while l < end_bin and cSum[l] < next_alpha * total:
                l = l + 1
            while r > start_bin and cSum[r] > next_beta * total:
                r = r - 1

            # decide the next move
            next_start_bin = start_bin
            next_end_bin = end_bin
            if (l - start_bin) > (end_bin - r):
                # move the start bin
                next_start_bin = l
                alpha = next_alpha
            else:
                # move the end bin
                next_end_bin = r
                beta = next_beta

            if next_start_bin == start_bin and next_end_bin == end_bin:
                continue

            # calculate the quantization error using next_start_bin and next_end_bin
            norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")

            if norm > norm_min:
                break
            norm_min = norm
            start_bin = next_start_bin
            end_bin = next_end_bin

        new_min = self.min_val + Tensor(bin_width * start_bin, dtype=np.float32)
        new_max = self.min_val + Tensor(bin_width * (end_bin + 1), dtype=np.float32)
        return new_min, new_max

    def get_qparams(self):
        new_min, new_max = self._non_linear_param_search()
        return self._calculate_qparams(new_min, new_max)

    def _combine_histograms(
        self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins
    ):
        # First up-sample the histogram with new data by a factor of L
        # This creates an approximate probability density that is piecewise constant
        upsampled_histogram = new_hist.repeat(upsample_rate)
        # Now insert the upsampled histogram into the output
        # histogram, which is initialized with zeros.
        # The offset at which the histogram is introduced is determined
        # by the start index as the output histogram can cover a wider range
        histogram_with_output_range = np.zeros((Nbins * downsample_rate))
        histogram_with_output_range[
            start_idx : Nbins * upsample_rate + start_idx
        ] = upsampled_histogram
        # Compute integral histogram, double precision is needed to ensure
        # that there are no overflows
        integral_histogram = np.cumsum(histogram_with_output_range, 0)[
            downsample_rate - 1 :: downsample_rate
        ]
        # Finally perform interpolation
        shifted_integral_histogram = np.zeros((Nbins))
        shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1]
        interpolated_histogram = (
            integral_histogram - shifted_integral_histogram
        ) / upsample_rate
        orig_hist = orig_hist + interpolated_histogram
        return orig_hist
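
    # A small numpy illustration (hypothetical, not part of the module API) of
    # the resampling above, with Nbins = 2, upsample_rate = 2,
    # downsample_rate = 3 and start_idx = 1: counts are duplicated onto a dense
    # grid, inserted into the wider output range, integrated with a cumulative
    # sum, sampled every downsample_rate bins and divided by the upsample rate,
    # so the total count is conserved.
    #
    #     up = np.array([4.0, 2.0]).repeat(2)     # [4, 4, 2, 2]
    #     wide = np.zeros(2 * 3)                  # output covers a wider range
    #     wide[1 : 1 + 4] = up                    # [0, 4, 4, 2, 2, 0]
    #     integral = np.cumsum(wide)[3 - 1 :: 3]  # [8, 12]
    #     shifted = np.array([0.0, integral[0]])  # [0, 8]
    #     (integral - shifted) / 2                # [4, 2]; total 6 conserved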

    def _adjust_min_max(self, combined_min, combined_max, upsample_rate):
        # We ensure that:
        # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
        # This allows us to have a common grid of resolution s, where we can align
        # the input histogram
        # start_idx maps min_val to the histogram bin index.
        np_min_val = self.min_val.numpy()
        np_max_val = self.max_val.numpy()

        hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate)
        downsample_rate = int(
            np.ceil((combined_max - combined_min) / (self.bins * hist_bin_width))
        )
        e = downsample_rate * (self.bins * hist_bin_width) - (
            combined_max - combined_min
        )
        combined_max = combined_max + e / 2
        combined_min = combined_min - e / 2
        start_idx = int(np.round((np_min_val - combined_min) / hist_bin_width))

        return combined_min, combined_max, downsample_rate, start_idx

    def sideeffect_forward(self, x_orig):
        x = x_orig.numpy()
        min_val = self.min_val.numpy()
        max_val = self.max_val.numpy()
        histogram = self.histogram.numpy()
        new_min = x.min()
        new_max = x.max()
        if histogram[0] == -1:
            new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
        else:
            new_min = min(new_min, min_val)
            new_max = max(new_max, max_val)
            # combine the existing histogram and new histogram into 1 histogram
            # We do this by first upsampling the histogram to a dense grid
            # and then downsampling the histogram efficiently
            (new_min, new_max, downsample_rate, start_idx) = self._adjust_min_max(
                new_min, new_max, self.upsample_rate
            )

            new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
            new_histogram = new_histogram.astype(np.float64)
            if new_min == min_val and new_max == max_val:
                new_histogram += histogram
            else:
                new_histogram = self._combine_histograms(
                    new_histogram,
                    histogram,
                    self.upsample_rate,
                    downsample_rate,
                    start_idx,
                    self.bins,
                )

        self.histogram = Tensor(new_histogram, dtype="float32")
        self.min_val = Tensor(new_min, dtype="float32")
        self.max_val = Tensor(new_max, dtype="float32")

    def forward(self, x_orig):
        self.sideeffect_forward(x_orig)
        return x_orig
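

# A minimal calibration sketch (hypothetical helper, not part of the module
# API): batches accumulate into the running histogram as a numpy side effect
# of forward(), and get_qparams() then runs the non-linear L2-error search to
# clip outliers before deriving the scale.
def _example_histogram_calibration(batches):
    obs = HistogramObserver(bins=2048, dtype="qint8")
    for batch in batches:
        obs(batch)               # Tensor batches update histogram/min/max
    return obs.get_qparams()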


class PassiveObserver(Observer):
    r"""
    An Observer that supports setting :attr:`scale` directly.
    """

    def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
        super().__init__(dtype, **kwargs)
        self.qparams = None
        self.orig_scale = None

    @property
    def scale(self):
        return self.qparams.scale

    @scale.setter
    def scale(self, value: np.ndarray):
        assert np.all(value > 0)
        self.qparams.scale[...] = Tensor(value)

    def get_qparams(self):
        return self.qparams

    def set_qparams(self, qparams: QParams):
        """
        :param qparams: used to set initial scale.
        """
        self.qparams = deepcopy(qparams)
        if qparams.scale is None:
            raise AssertionError("Cannot get an initialized scale")
        if qparams.dtype_meta is None:
            qparams.dtype_meta = self.dtype
        else:
            assert (
                qparams.dtype_meta is self.dtype
            ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
                qparams.dtype_meta, self.dtype
            )
        self.orig_scale = qparams.scale.numpy()

    def forward(self, x):
        r"""
531
        Just return input because :attr:`qparams` is set by :func:`~.apply_easy_quant`.
532 533
        """
        return x
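

# A minimal sketch (hypothetical helper, not part of the module API): a
# PassiveObserver never derives statistics itself. It is seeded with existing
# qparams, e.g. produced by another observer, and its scale can then be
# adjusted in place, which is how apply_easy_quant tunes it.
def _example_passive_observer(qparams):
    obs = PassiveObserver(dtype="qint8")
    obs.set_qparams(qparams)               # deepcopied; must carry a scale
    obs.scale = obs.orig_scale * 1.1       # tweak the scale directly
    return obs.get_qparams()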