#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import math
17

18 19 20 21
import numpy as np

import paddle

22
from ...static.quantization.cal_kl_threshold import cal_kl_threshold
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
from . import utils


def abs_max_value(tensor):
    """
    Return the largest absolute value found in ``tensor`` as a Python float.
    """
    magnitudes = paddle.abs(tensor)
    peak = paddle.max(magnitudes)
    return float(peak.numpy())


def merge_max_value(old, new):
    """
    Element-wise maximum of two (possibly nested) lists.

    ``new`` is updated in place and returned. When ``old`` is empty there is
    nothing to merge and ``new`` is returned untouched. Otherwise both lists
    must have the same length and matching element types at every position;
    nested lists are merged recursively.
    """
    assert isinstance(old, list) and isinstance(new, list)
    if not old:
        return new

    assert len(old) == len(new)
    for pos, (prev, cur) in enumerate(zip(old, new)):
        assert type(prev) is type(cur)
        if isinstance(prev, list):
            new[pos] = merge_max_value(prev, cur)
        elif cur < prev:
            # Keep the larger of the two; otherwise new[pos] already holds it.
            new[pos] = prev
    return new


46 47 48 49
def combine_abs_max_and_hist(
    tensor, origin_max, origin_hist, bins, upsample_bins
):
    """
    Fold the absolute values of ``tensor`` into a running histogram.

    Args:
        tensor: Tensor whose absolute values are sampled.
        origin_max: Upper edge of the range covered by ``origin_hist``
            (0.0 when no histogram has been collected yet).
        origin_hist: Histogram accumulated so far; it is only read when
            ``origin_max`` is non-zero.
        bins: Number of bins of the returned histogram.
        upsample_bins: Factor by which ``origin_hist`` is upsampled before
            being re-binned onto a wider range.

    Returns:
        A ``(abs_max, hist)`` tuple describing the merged histogram, which
        spans ``[0, abs_max]``.
    """

    def _abs_hist(upper):
        # float32 histogram of |tensor| over [0, upper].
        counts, _ = np.histogram(
            paddle.abs(tensor).numpy(), range=(0, upper), bins=bins
        )
        return counts.astype(np.float32)

    new_max = abs_max_value(tensor)

    # An all-zero tensor adds no information.
    if new_max == 0.0:
        return origin_max, origin_hist

    # Nothing collected so far: the new tensor defines the histogram.
    if origin_max == 0.0:
        return new_max, _abs_hist(new_max)

    # The existing range already covers the new tensor: add the counts.
    if new_max <= origin_max:
        merged = _abs_hist(origin_max)
        merged += origin_hist
        return origin_max, merged

    # The range has to grow. Pick a fine bin width that divides both ranges:
    #   fine_width = origin_max / (bins * upsample_bins)
    #              = new_max    / (bins * downsample_bins)
    # and round new_max up so that downsample_bins is an integer.
    fine_width = origin_max / (bins * upsample_bins)
    downsample_bins = int(math.ceil(new_max / (bins * fine_width)))
    new_max = bins * fine_width * downsample_bins

    # Spread the old histogram over the fine grid, zero-padded up to new_max.
    fine_hist = np.zeros(bins * downsample_bins, dtype=np.float32)
    fine_hist[: bins * upsample_bins] = np.repeat(origin_hist, upsample_bins)

    # Sum each group of downsample_bins fine bins via differences of the
    # running total taken at the group boundaries.
    running = np.cumsum(fine_hist, dtype=np.float64)
    edge_totals = running[downsample_bins - 1 :: downsample_bins]
    previous_totals = np.zeros(bins, dtype=np.float64)
    previous_totals[1:] = edge_totals[:-1]
    # np.repeat copied every original count upsample_bins times; undo that.
    rebinned = (edge_totals - previous_totals) / upsample_bins
    rebinned = rebinned.astype(np.float32)

    merged = _abs_hist(new_max)
    merged += rebinned

    return new_max, merged


95
class BaseQuantizer(metaclass=abc.ABCMeta):
    """
    Abstract base class shared by the activation and weight quantizers.

    Subclasses implement ``sample_data`` and ``cal_thresholds``.
    """

    def __init__(self, quant_bits=8):
        super().__init__()
        # quant_bits must be an integer in the range [1, 16].
        assert isinstance(quant_bits, int)
        assert 0 < quant_bits <= 16

        self.quant_bits = quant_bits

        # Filled in by the subclasses' sample_data / cal_thresholds.
        self.abs_max_vals = []
        self.thresholds = []

    @abc.abstractmethod
    def sample_data(self, layer, tensors):
        """Collect statistics from ``tensors``; implemented by subclasses."""
        pass

    @abc.abstractmethod
    def cal_thresholds(self):
        """Fill ``self.thresholds``; implemented by subclasses."""
        pass


class AbsmaxQuantizer(BaseQuantizer):
    """
    Per-tensor quantizer that keeps the running abs-max of every tensor.
    """

    def __init__(self, quant_bits=8):
        super().__init__(quant_bits=quant_bits)

    def sample_data(self, layer, tensors):
        """
        Merge the abs-max of each tensor in ``tensors`` (a tuple) into
        ``self.abs_max_vals``. ``layer`` is not used here.
        """
        assert isinstance(tensors, tuple)

        batch_maxima = list(map(abs_max_value, tensors))
        self.abs_max_vals = merge_max_value(self.abs_max_vals, batch_maxima)

    def cal_thresholds(self):
        """
        Use the collected abs-max values directly as the thresholds.
        """
        self.thresholds = self.abs_max_vals
135 136 137 138 139 140 141 142


class PerChannelAbsmaxQuantizer(BaseQuantizer):
    """
    Quantizer that keeps one running abs-max per channel of every tensor.
    """

    def __init__(self, quant_bits=8):
        super().__init__(quant_bits=quant_bits)

    def sample_data(self, layer, tensors):
        """
        Merge the per-channel abs-max of each tensor into ``self.abs_max_vals``.

        Channels are taken along axis 1 when ``layer`` is an instance of one
        of ``utils.spec_channel_axis_layers`` and along axis 0 otherwise.
        """
        assert isinstance(layer, paddle.nn.Layer)
        assert isinstance(tensors, tuple)

        channels_on_axis_one = isinstance(
            layer, tuple(utils.spec_channel_axis_layers)
        )

        def _channel_maxima(tensor):
            # One abs-max per slice along the channel axis.
            if channels_on_axis_one:
                return [
                    abs_max_value(tensor[:, ch])
                    for ch in range(tensor.shape[1])
                ]
            return [abs_max_value(tensor[ch]) for ch in range(tensor.shape[0])]

        per_tensor_maxima = [_channel_maxima(tensor) for tensor in tensors]
        self.abs_max_vals = merge_max_value(
            self.abs_max_vals, per_tensor_maxima
        )

    def cal_thresholds(self):
        """
        Use the collected per-channel abs-max values as the thresholds.
        """
        self.thresholds = self.abs_max_vals
168 169


170
class BaseHistQuantizer(BaseQuantizer, metaclass=abc.ABCMeta):
    """
    Base class for quantizers that accumulate a histogram of absolute values
    for every sampled tensor. Subclasses implement ``cal_thresholds``.
    """

    def __init__(self, quant_bits=8, bins=1024, upsample_bins=64):
        super().__init__(quant_bits=quant_bits)
        self.bins = bins
        self.upsample_bins = upsample_bins

        # One entry per tensor; None while that tensor has only been zeros.
        self.hists = []

    def sample_data(self, layer, tensors):
        """
        Update ``self.abs_max_vals`` and ``self.hists`` with ``tensors``.

        The first call creates one histogram per tensor; later calls merge
        the new data into the existing histograms and therefore require the
        same number of tensors. ``layer`` is not used here.
        """
        assert isinstance(tensors, tuple)

        if not self.abs_max_vals:
            # Nothing recorded yet: build the initial histograms.
            self.abs_max_vals = [abs_max_value(t) for t in tensors]
            for tensor, peak in zip(tensors, self.abs_max_vals):
                if peak == 0.0:
                    # An all-zero tensor has no usable range yet.
                    self.hists.append(None)
                    continue
                counts, _ = np.histogram(
                    paddle.abs(tensor).numpy(),
                    range=(0.0, peak),
                    bins=self.bins,
                )
                self.hists.append(counts.astype(np.float32))
            return

        assert len(self.abs_max_vals) == len(tensors)
        assert len(self.hists) == len(tensors)

        for pos, tensor in enumerate(tensors):
            merged_max, merged_hist = combine_abs_max_and_hist(
                tensor,
                self.abs_max_vals[pos],
                self.hists[pos],
                self.bins,
                self.upsample_bins,
            )
            self.abs_max_vals[pos] = merged_max
            self.hists[pos] = merged_hist

    @abc.abstractmethod
    def cal_thresholds(self):
        """Fill ``self.thresholds``; implemented by subclasses."""
        pass


class HistQuantizer(BaseHistQuantizer):
    """
    Histogram quantizer whose threshold is read off the cumulative
    histogram at the ``hist_percent`` fraction of the samples.
    """

    def __init__(
        self, quant_bits=8, bins=1024, upsample_bins=64, hist_percent=0.99999
    ):
        super().__init__(quant_bits, bins, upsample_bins)
        self.hist_percent = hist_percent

    def cal_thresholds(self):
        """
        Compute one threshold per collected histogram into ``self.thresholds``.

        The list is rebuilt on every call, so calling this method more than
        once does not accumulate duplicate entries. Entries without a
        histogram (``None``) fall back to the recorded abs-max value.
        """

        def _helper(abs_max, hist, percent):
            assert hist.ndim == 1 and percent < 1.0
            hist = hist / np.sum(hist, dtype=np.float64)
            cumsumed_hist = np.cumsum(hist)
            # First bin whose cumulative share reaches `percent`. Taking a
            # scalar index (instead of a 1-element array) keeps the float()
            # conversion below valid on NumPy versions that deprecate
            # converting ndim > 0 arrays to Python scalars.
            index = np.flatnonzero(cumsumed_hist >= percent)[0]
            # NOTE(review): for index == 0 this expression is negative
            # (-0.5 * bin width) - confirm that this is intended.
            return float((index - 0.5) * (abs_max / hist.shape[0]))

        thresholds = []
        for abs_max, hist in zip(self.abs_max_vals, self.hists):
            if hist is None:
                thresholds.append(abs_max)
            else:
                thresholds.append(_helper(abs_max, hist, self.hist_percent))
        self.thresholds = thresholds


class KLQuantizer(BaseHistQuantizer):
    """
    Histogram quantizer whose thresholds are computed by
    ``cal_kl_threshold`` from the collected histograms.
    """

    def __init__(self, quant_bits=8, bins=1024, upsample_bins=64):
        super().__init__(quant_bits, bins, upsample_bins)

    def cal_thresholds(self):
        """
        Compute one threshold per collected histogram into ``self.thresholds``.

        The list is rebuilt on every call, so calling this method more than
        once does not accumulate duplicate entries. Entries without a
        histogram (``None``) fall back to the recorded abs-max value.
        """
        thresholds = []
        for abs_max_val, hist in zip(self.abs_max_vals, self.hists):
            if hist is None:
                thresholds.append(abs_max_val)
            else:
                # Width of one histogram bin over the range [0, abs_max_val].
                bin_width = abs_max_val / hist.shape[0]
                thresholds.append(
                    cal_kl_threshold(hist, bin_width, self.quant_bits)
                )
        self.thresholds = thresholds
261 262 263 264


# Quantizer classes listed as supported for activations (presumably consulted
# by the PTQ configuration code - verify against the callers).
SUPPORT_ACT_QUANTIZERS = [AbsmaxQuantizer, HistQuantizer, KLQuantizer]
# Quantizer classes listed as supported for weights (presumably consulted by
# the PTQ configuration code - verify against the callers).
SUPPORT_WT_QUANTIZERS = [AbsmaxQuantizer, PerChannelAbsmaxQuantizer]