# Copyright (c) 2021  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging

import paddle
from ...common import get_logger

_logger = get_logger(__name__, level=logging.INFO)

WEIGHT_QUANTIZATION_TYPES = [
    'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max'
]

ACTIVATION_QUANTIZATION_TYPES = [
    'abs_max', 'range_abs_max', 'moving_average_abs_max'
]

BUILT_IN_PREPROCESS_TYPES = ['PACT']

VALID_DTYPES = ['int8']

__all__ = ['QAT']

_quant_config_default = {
    # weight preprocess type, default is None and no preprocessing is performed. 
    'weight_preprocess_type': None,
    # activation preprocess type, default is None and no preprocessing is performed.
    'activation_preprocess_type': None,
    # weight quantize type, default is 'channel_wise_abs_max'
    'weight_quantize_type': 'channel_wise_abs_max',
    # activation quantize type, default is 'moving_average_abs_max'
    'activation_quantize_type': 'moving_average_abs_max',
    # weight quantize bit num, default is 8
    'weight_bits': 8,
    # activation quantize bit num, default is 8
    'activation_bits': 8,
    # data type after quantization, default is 'int8' (currently only 'int8' is supported)
    'dtype': 'int8',
    # window size for 'range_abs_max' quantization. default is 10000
    'window_size': 10000,
    # The decay coefficient of moving average, default is 0.9
    'moving_rate': 0.9,
    # for dygraph quantization, layers whose type is in quantizable_layer_type will be quantized
    'quantizable_layer_type': ['Conv2D', 'Linear'],
    # whether fuse conv and bn before QAT
    'fuse_conv_bn': False,
    # Whether to export the quantized model with format of ONNX. Default is True.
    'onnx_format': True,
}
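
# An illustrative sketch (assumption, not used by this module): a config passed to
# QAT(config=...) only needs the keys that differ from the defaults above, e.g.
#
#     {
#         'activation_preprocess_type': 'PACT',
#         'quantizable_layer_type': ['Conv2D', 'Linear'],
#     }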


def _parse_configs(user_config):
    """
    Check whether the user's config is valid and merge it with the default config.
    Args:
        user_config(dict): user's config.
    Return:
        configs(dict): the final config that will be used.
    """

    configs = copy.deepcopy(_quant_config_default)
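    # user-supplied keys override these defaults; every key missing from user_config
    # keeps its default value (e.g. passing {'weight_bits': 4} only changes weight_bits).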
    configs.update(user_config)

    # check if configs is valid
    weight_types = WEIGHT_QUANTIZATION_TYPES
    activation_types = ACTIVATION_QUANTIZATION_TYPES

    assert configs['weight_preprocess_type'] in BUILT_IN_PREPROCESS_TYPES or configs['weight_preprocess_type'] is None, \
        "Unknown weight_preprocess_type: {}. only supports {} ".format(configs['weight_preprocess_type'],
                BUILT_IN_PREPROCESS_TYPES)

    assert configs['activation_preprocess_type'] in BUILT_IN_PREPROCESS_TYPES or configs['activation_preprocess_type'] is None, \
        "Unknown activation_preprocess_type: {}. only supports {}".format(configs['activation_preprocess_type'],
                BUILT_IN_PREPROCESS_TYPES)

    assert configs['weight_quantize_type'] in WEIGHT_QUANTIZATION_TYPES, \
        "Unknown weight_quantize_type: {}. only supports {} ".format(configs['weight_quantize_type'],
                WEIGHT_QUANTIZATION_TYPES)

    assert configs['activation_quantize_type'] in ACTIVATION_QUANTIZATION_TYPES, \
        "Unknown activation_quantize_type: {}. only supports {}".format(configs['activation_quantize_type'],
                ACTIVATION_QUANTIZATION_TYPES)

    assert isinstance(configs['weight_bits'], int), \
        "weight_bits must be int value."

    assert (configs['weight_bits'] >= 1 and configs['weight_bits'] <= 16), \
        "weight_bits should be between 1 and 16."

    assert isinstance(configs['activation_bits'], int), \
        "activation_bits must be int value."

    assert (configs['activation_bits'] >= 1 and configs['activation_bits'] <= 16), \
        "activation_bits should be between 1 and 16."

    assert isinstance(configs['dtype'], str), \
        "dtype must be a str."

    assert (configs['dtype'] in VALID_DTYPES), \
        "dtype can only be " + " ".join(VALID_DTYPES)

    assert isinstance(configs['window_size'], int), \
        "window_size must be int value, window size for 'range_abs_max' quantization, default is 10000."

    assert isinstance(configs['moving_rate'], float), \
        "moving_rate must be float value, The decay coefficient of moving average, default is 0.9."

    assert isinstance(configs['quantizable_layer_type'], list), \
        "quantizable_layer_type must be a list"

    return configs


class PACT(paddle.nn.Layer):
    def __init__(self):
        super(PACT, self).__init__()
        alpha_attr = paddle.ParamAttr(
            name=self.full_name() + ".pact",
            initializer=paddle.nn.initializer.Constant(value=100),
            learning_rate=1000.0)

        self.alpha = self.create_parameter(
            shape=[1], attr=alpha_attr, dtype='float32')

    def forward(self, x):
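        # PACT clips the activation to the learnable range [-alpha, alpha]; the two
        # ReLU terms below are an equivalent reformulation of clip(x, -alpha, alpha)
        # that lets gradients flow into alpha.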
        out_left = paddle.nn.functional.relu(x - self.alpha)
        out_right = paddle.nn.functional.relu(-self.alpha - x)
        x = x - out_left + out_right
        return x


class QAT(object):
    """
    Quant Aware Training (QAT): adds fake-quant logic to the given quantizable layers,
    i.e. inserts quant/dequant computation for both activation inputs and weight inputs.
    """

    def __init__(self,
                 config=None,
                 weight_preprocess=None,
                 act_preprocess=None,
                 weight_quantize=None,
                 act_quantize=None):
        """
        Args:
            config(dict, optional): Configs for quantization. If None, the default
                    config will be used. Default: None.
            weight_preprocess(class, optional): Defines how to preprocess weights
                    before quantization. Useful for quickly testing whether a custom
                    preprocess method works. The layer's input is the non-quantized
                    weight and its output is the processed weight to be quantized.
                    If None, the preprocess method defined by 'weight_preprocess_type'
                    will be used. Default: None.
            act_preprocess(class, optional): Defines how to preprocess activations
                    before quantization. Useful for quickly testing whether a custom
                    preprocess method works. The layer's input is the non-quantized
                    activation and its output is the processed activation to be
                    quantized. If None, the preprocess method defined by
                    'activation_preprocess_type' will be used. Default: None.
            weight_quantize(class, optional): Defines how to quantize weights. Useful
                    for quickly testing whether a custom quantization method works.
                    The layer should implement both quantization and dequantization,
                    i.e. its input is the non-quantized weight and its output is the
                    dequantized weight. If None, the quantize op defined by
                    'weight_quantize_type' will be used. Default: None.
            act_quantize(class, optional): Defines how to quantize activations. Useful
                    for quickly testing whether a custom quantization method works.
                    The layer should implement both quantization and dequantization,
                    i.e. its input is the non-quantized activation and its output is
                    the dequantized activation. If None, the quantize op defined by
                    'activation_quantize_type' will be used. Default: None.
        """
        if config is None:
            config = _quant_config_default
        else:
            assert isinstance(config, dict), "config must be dict"
            config = _parse_configs(config)
        self.config = config

        self.weight_preprocess = PACT if self.config[
            'weight_preprocess_type'] == 'PACT' else None
        self.act_preprocess = PACT if self.config[
            'activation_preprocess_type'] == 'PACT' else None

        self.weight_preprocess = weight_preprocess if weight_preprocess is not None \
            else self.weight_preprocess
        self.act_preprocess = act_preprocess if act_preprocess is not None \
            else self.act_preprocess
        self.weight_quantize = weight_quantize
        self.act_quantize = act_quantize

        # TODO: remove try-except when the version is stable
        try:
            self.imperative_qat = paddle.quantization.ImperativeQuantAware(
                weight_bits=self.config['weight_bits'],
                activation_bits=self.config['activation_bits'],
                weight_quantize_type=self.config['weight_quantize_type'],
                activation_quantize_type=self.config[
                    'activation_quantize_type'],
                moving_rate=self.config['moving_rate'],
                quantizable_layer_type=self.config['quantizable_layer_type'],
                fuse_conv_bn=self.config[
                    'fuse_conv_bn'],  # support Paddle > 2.3
                weight_preprocess_layer=self.weight_preprocess,
                act_preprocess_layer=self.act_preprocess,
                weight_quantize_layer=self.weight_quantize,
                act_quantize_layer=self.act_quantize,
                onnx_format=self.config['onnx_format'],  # support Paddle >= 2.4
            )
        except Exception:
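            # Fallback for older Paddle versions whose ImperativeQuantAware does not
            # accept the fuse_conv_bn / onnx_format arguments (see the comments above).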
            self.imperative_qat = paddle.quantization.ImperativeQuantAware(
                weight_bits=self.config['weight_bits'],
                activation_bits=self.config['activation_bits'],
                weight_quantize_type=self.config['weight_quantize_type'],
                activation_quantize_type=self.config[
                    'activation_quantize_type'],
                moving_rate=self.config['moving_rate'],
                quantizable_layer_type=self.config['quantizable_layer_type'],
                weight_preprocess_layer=self.weight_preprocess,
                act_preprocess_layer=self.act_preprocess,
                weight_quantize_layer=self.weight_quantize,
                act_quantize_layer=self.act_quantize)

    def quantize(self, model, inplace=True):
        """
        Quantize the input model.

        Args:
            model(paddle.nn.Layer): The model to be quantized.
            inplace(bool): Whether to apply quantization to the input model in place.
                           Default: True.
        Returns:
            quantized_model(paddle.nn.Layer): The quantized model.
        """
        assert isinstance(model, paddle.nn.Layer), \
            "The model must be the instance of paddle.nn.Layer."
        if self.weight_preprocess is not None or self.act_preprocess is not None:
            self._model = copy.deepcopy(model)

        if inplace:
            quantize_model = self.imperative_qat.quantize(model)
            quant_model = quantize_model if quantize_model is not None else model
        else:
            quant_model = copy.deepcopy(model)
            quantize_model = self.imperative_qat.quantize(quant_model)
            if quantize_model is not None:
                quant_model = quantize_model

        return quant_model

    def save_quantized_model(self, model, path, input_spec=None):
        """
        Save the quantized inference model.

        Args:
            model (Layer): The model to be saved.
            path (str): The path prefix to save the model. The format is
                ``dirname/file_prefix`` or ``file_prefix``.
            input_spec (list[InputSpec|Tensor], optional): Describes the input
                of the saved model's forward method, which can be described by
                InputSpec or example Tensor. If None, all input variables of 
                the original Layer's forward method would be the inputs of
                the saved model. Default: None.

        Returns:
            None
        """
        if self.weight_preprocess is not None or self.act_preprocess is not None:
            training = model.training
            model = self._remove_preprocess(model)
            if training:
                model.train()
            else:
                model.eval()

        self.imperative_qat.save_quantized_model(
            layer=model, path=path, input_spec=input_spec)

    def _remove_preprocess(self, model):
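        # Rebuild the quant-aware wrapper without the PACT preprocess layers, then
        # re-quantize the fp32 copy kept in self._model and load the trained weights
        # back, so the exported inference model contains no preprocess ops.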
        state_dict = model.state_dict()
        try:
            self.imperative_qat = paddle.quantization.ImperativeQuantAware(
                weight_bits=self.config['weight_bits'],
                activation_bits=self.config['activation_bits'],
                weight_quantize_type=self.config['weight_quantize_type'],
                activation_quantize_type=self.config[
                    'activation_quantize_type'],
                moving_rate=self.config['moving_rate'],
                quantizable_layer_type=self.config['quantizable_layer_type'],
                onnx_format=self.config['onnx_format'],  # support Paddle >= 2.4
            )
        except Exception:
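            # Fallback for older Paddle versions without onnx_format support.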
            self.imperative_qat = paddle.quantization.ImperativeQuantAware(
                weight_bits=self.config['weight_bits'],
                activation_bits=self.config['activation_bits'],
                weight_quantize_type=self.config['weight_quantize_type'],
                activation_quantize_type=self.config[
                    'activation_quantize_type'],
                moving_rate=self.config['moving_rate'],
                quantizable_layer_type=self.config['quantizable_layer_type'])

        paddle.disable_static()
        if hasattr(model, "_layers"):
            model = model._layers
        model = self._model
        self.imperative_qat.quantize(model)
        model.set_state_dict(state_dict)
        paddle.enable_static()

        return model
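

# A minimal usage sketch (illustration only, not part of this module). `net` is
# assumed to be any paddle.nn.Layer and `train(...)` a user-defined fine-tuning
# loop; the input shape is likewise just an example.
#
#     quanter = QAT(config={'activation_preprocess_type': 'PACT'})
#     net = quanter.quantize(net)          # insert fake-quant (and PACT) layers
#     train(net)                           # fine-tune with quantization in the loop
#     quanter.save_quantized_model(
#         net, './quant_model',
#         input_spec=[paddle.static.InputSpec(
#             shape=[None, 3, 224, 224], dtype='float32')])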