#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import numpy as np

from ...fluid.framework import IrNode, Operator
from .quant_config import SUPPORT_QUANTIZATION_OP_DICT

_channelwise_quant_axis1_ops = [
    'conv2d_transpose',
    'mul',
    'matmul',
    'matmul_v2',
]


def _get_op_input_var_names(op):
    """
    Get the input var names of the op.
    Args:
        op(IrNode, Operator): the input op.
    Returns:
        input_var_names; an empty list if the op is not supported.
    """
    assert isinstance(
        op, (IrNode, Operator)
    ), "The input op should be IrNode or Operator."
    var_names = []
    op_name = op.name() if isinstance(op, IrNode) else op.type
    if op_name not in SUPPORT_QUANTIZATION_OP_DICT:
        return []

    name_list = SUPPORT_QUANTIZATION_OP_DICT[op_name][0]
    for name in name_list:
        var_name = op.input(name)
        if isinstance(var_name, list):
            var_names.extend(var_name)
        else:
            var_names.append(var_name)
    return var_names
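
# For example (hedged; the argument names come from the op's entry in
# SUPPORT_QUANTIZATION_OP_DICT): for a conv2d op whose entry lists
# ['Input', 'Filter'] as quantizable inputs, this returns the variable
# names bound to 'Input' and 'Filter'.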


def _get_op_output_var_names(op):
    """ """
58 59 60
    assert isinstance(
        op, (IrNode, Operator)
    ), "The input op should be IrNode or Operator."
    var_names = []
    op_name = op.name() if isinstance(op, IrNode) else op.type
    if op_name not in SUPPORT_QUANTIZATION_OP_DICT:
        return []

    name_list = SUPPORT_QUANTIZATION_OP_DICT[op_name][1]
    for name in name_list:
        var_name = op.output(name)
        if isinstance(var_name, list):
            var_names.extend(var_name)
        else:
            var_names.append(var_name)
    return var_names


def _get_input_name_index(op, input_var_name):
    """Get the input name and index of the var_name in the op"""
    assert isinstance(
        op, (IrNode, Operator)
    ), "The input op should be IrNode or Operator."
    op_name = op.name() if isinstance(op, IrNode) else op.type
    if op_name not in SUPPORT_QUANTIZATION_OP_DICT:
        return None

    res = None
    for argname in SUPPORT_QUANTIZATION_OP_DICT[op_name][0]:
        var_names = op.input(argname)
        for index, name in enumerate(var_names):
            if name == input_var_name:
                res = (argname, index)
    return res


def _get_output_name_index(op, output_var_name):
    """Get the output name and index of the var_name in the op"""
    assert isinstance(
        op, (IrNode, Operator)
    ), "The input op should be IrNode or Operator."
    op_name = op.name() if isinstance(op, IrNode) else op.type
    if op_name not in SUPPORT_QUANTIZATION_OP_DICT:
        return None

    name_list = SUPPORT_QUANTIZATION_OP_DICT[op_name][1]
    res = None
    for name in name_list:
        var_name = op.output(name)
        for index, val in enumerate(var_name):
            if val == output_var_name:
                res = (name, index)
    return res
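
# For example (hedged): if output_var_name is the first variable bound to an
# op's 'Output' argument, the helper above returns ('Output', 0); names that
# do not match any output yield None.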


def load_variable_data(scope, var_name):
    '''
    Load the value of variable var_name from scope as a numpy array.
    '''
    var_node = scope.find_var(var_name)
    assert var_node is not None, "Cannot find " + var_name + " in scope."
    return np.array(var_node.get_tensor())


def set_variable_data(scope, place, var_name, np_value):
    '''
    Set the value of the variable var_name in scope; do nothing if the
    variable does not exist.
    '''
    assert isinstance(
        np_value, np.ndarray
    ), 'The type of value should be numpy array.'
    var_node = scope.find_var(var_name)
    if var_node is not None:
        tensor = var_node.get_tensor()
        tensor.set(np_value, place)
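
# A hedged usage sketch for the two scope helpers above (the variable name
# "conv2d_0.w_0" is hypothetical; any persistable variable in the program
# works):
#
#   import paddle
#   scope = paddle.static.global_scope()
#   w = load_variable_data(scope, "conv2d_0.w_0")
#   set_variable_data(scope, paddle.CPUPlace(), "conv2d_0.w_0", w * 0.5)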


def quant_tensor(x, scale, quant_axis=0, weight_bits=8, onnx_format=False):
    # symmetric quantization
    def _clip(x, scale):
        x[x > scale] = scale
        x[x < -scale] = -scale
        return x

    # Largest representable magnitude: e.g. 127 for 8-bit weights.
    bnt = (1 << (weight_bits - 1)) - 1
    if isinstance(scale, list) and len(scale) == 1:
        scale = scale[0]
    if isinstance(scale, list):
        assert quant_axis in [-1, 0, 1], 'quant_axis should be -1, 0 or 1 for now.'
        for i, s in enumerate(scale):
            if s == 0.0:
                s = 1e-8
            if quant_axis == 0:
                if onnx_format:
                    x[i] = np.round(x[i] / s * bnt)
                    x[i] = np.clip(x[i], -bnt - 1, bnt)
                else:
                    x[i] = _clip(x[i], s)
                    x[i] = x[i] / s * bnt
            else:
                if onnx_format:
                    x[:, i] = np.round(x[:, i] / s * bnt)
                    x[:, i] = np.clip(x[:, i], -bnt - 1, bnt)
                else:
                    x[:, i] = _clip(x[:, i], s)
                    x[:, i] = x[:, i] / s * bnt
    else:
        scale = 1e-8 if scale == 0.0 else scale
        if onnx_format:
            x = np.round(x / scale * bnt)
            x = np.clip(x, -bnt - 1, bnt)
        else:
            x = _clip(x, scale)
            x = x / scale * bnt
    return x
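
# A worked example (illustrative only): with the default weight_bits=8,
# bnt = 127, so a per-tensor scale of 1.0 maps values into [-127, 127]
# without rounding; -2.0 is first clipped to -1.0:
#
#   >>> w = np.array([0.5, -2.0, 0.25, 1.0])
#   >>> quant_tensor(w.copy(), scale=1.0)
#   array([  63.5 , -127.  ,   31.75,  127.  ])
#
# With onnx_format=True the result is additionally rounded and clipped
# to [-128, 127].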


def dequant_tensor(x, scale, quant_axis=0, weight_bits=8):
    assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
    bnt = (1 << (weight_bits - 1)) - 1
    if isinstance(scale, list):
        for i, s in enumerate(scale):
            if s == 0.0:
                s = 1e-8
            if quant_axis == 0:
                x[i] = x[i] * s / bnt
            else:
                x[:, i] = x[:, i] * s / bnt
    else:
        scale = 1e-8 if scale == 0.0 else scale
        x = x * scale / bnt
    return x
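
# dequant_tensor is the inverse mapping; for values within [-scale, scale]
# the (non-onnx_format) round trip is exact up to float error, e.g.
#
#   >>> dequant_tensor(quant_tensor(np.array([0.5]), scale=1.0), scale=1.0)
#   array([0.5])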


def bias_correction_w(x, x_quant, scale_v, quant_axis, weight_bits=8):
    '''
    Bias correction for weights: shift and rescale the dequantized weight
    so its (per-channel) mean and std approximately match the original
    float weight, then re-quantize.
    '''
    eps = 1e-8
    bnt = (1 << (weight_bits - 1)) - 1
    x_dequant = x_quant.copy()
    if isinstance(scale_v, list):
        if quant_axis == 0:
            for i, s in enumerate(scale_v):
                x_dequant[i] = x_dequant[i] * s / bnt
            quant_bias = x - x_dequant
            mean_bias = quant_bias.reshape(quant_bias.shape[0], -1).mean(-1)
            std_orig = x.reshape(x.shape[0], -1).std(-1)
            std_quant = x_dequant.reshape(x_dequant.shape[0], -1).std(-1)
            std_bias = std_orig / (std_quant + eps)
        else:
            for i, s in enumerate(scale_v):
                x_dequant[:, i] = x_quant[:, i] * s / bnt
            quant_bias = x - x_dequant
            mean_bias = np.array(
                [quant_bias[:, i].mean() for i in range(quant_bias.shape[1])]
            )
            std_orig = np.array([x[:, i].std() for i in range(x.shape[1])])
            std_quant = np.array(
                [x_dequant[:, i].std() for i in range(x_dequant.shape[1])]
            )
            std_bias = std_orig / (std_quant + eps)
    else:
        x_dequant = x_quant * scale_v / bnt
        mean_bias = (x - x_dequant).mean()
        std_bias = x.std() / (x_dequant.std() + eps)
    if mean_bias.ndim == 1:
        std_bias = np.resize(std_bias, x.shape)
        mean_bias = np.resize(mean_bias, x.shape)

    x_dequant = (mean_bias + x_dequant) * std_bias
    quantized_param_v = quant_tensor(
        x_dequant, scale_v, quant_axis, weight_bits
    )
    return quantized_param_v
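
# In short, the correction above computes, per channel c:
#   w_c' = (mean(w_c - dq(w_c)) + dq(w_c)) * std(w_c) / (std(dq(w_c)) + eps)
# where dq denotes the dequantized quantized weight, and then re-quantizes
# w_c' with the original scales.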


def stable_sigmoid(x):
    # Evaluate the overflow-safe branch for each sign of x.
    sig = np.where(x < 0, np.exp(x) / (1 + np.exp(x)), 1 / (1 + np.exp(-x)))
    return sig
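
# For example (illustrative): stable_sigmoid(np.array([-1000.0, 0.0, 1000.0]))
# evaluates to approximately array([0., 0.5, 1.]) without NaN results, though
# np.where still evaluates both branches and may emit overflow warnings.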


def calculate_quant_cos_error(orig_tensor, qdq_tensor):
    '''
    Cosine similarity between a tensor and its quantize-dequantize result.
    '''
    cos_sim = np.inner(orig_tensor.flatten(), qdq_tensor.flatten()) / (
        np.linalg.norm(orig_tensor.flatten())
        * np.linalg.norm(qdq_tensor.flatten())
    )
    return cos_sim
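
# For example (illustrative): identical tensors give a similarity of ~1.0:
#
#   >>> t = np.array([1.0, 2.0, 3.0])
#   >>> round(calculate_quant_cos_error(t, t), 6)
#   1.0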


def move_persistable_var_to_global_block(program):
    # Move persistable vars from while sub-blocks to the global block.
    global_block = program.global_block()
    for _op in global_block.ops:
        if _op.type == "while":
            _block_id = _op.attr("sub_block").id
            _block = program.block(_block_id)
            persistables = []
            for _name, _var in _block.vars.items():
                if _var.persistable:
                    global_block._clone_variable(_var)
                    persistables.append(_name)
            for _name in persistables:
                _block._remove_var(_name)
            persistables.extend(_op.input('X'))
            _op.desc.set_input("X", persistables)
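
# As implemented above: persistable vars created inside a while sub-block are
# cloned into the global block, removed from the sub-block, and appended to
# the while op's 'X' inputs, presumably so the sub-block can still read them
# after the move.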


def l2_loss(gt, pred):
    '''
    Mean squared error between ground truth and prediction.
    '''
    return ((gt - pred) ** 2).mean()


class tqdm:
    """A lightweight stand-in for tqdm that renders a progress bar to stderr."""
    def __init__(self, total, bar_format='Loading|{bar}', ncols=80):
        self.total = total
        self.bar_format = bar_format
        self.ncols = ncols
        self.n = 0

    def update(self, n=1):
        self.n += n
        a = "=" * round((self.n / self.total) * self.ncols)
        b = " " * (self.ncols - len(a))
        prefix = self.bar_format.split('|')[0]
        sys.stderr.write(
            "\r{}|{}=>{}| {}/{}".format(prefix, a, b, self.n, self.total)
        )
        sys.stderr.flush()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stderr.write('\n')
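

# A hedged usage sketch of the fallback progress bar above:
#
#   with tqdm(total=100) as pbar:
#       for _ in range(100):
#           pbar.update(1)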