post_training_quantization.py 65.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
15 16
import os
import re
17 18
import logging
import numpy as np
19
import shutil
20
from inspect import isgeneratorfunction
21 22 23
from .... import io
from .... import core
from .... import framework
24
from .... import unique_name
25
from ....executor import global_scope, Executor
26 27
from ....framework import IrGraph
from ....log_helper import get_logger
28
from .quantization_pass import QuantizationTransformPass, QuantizationTransformPassV2, QuantizationFreezePass, QuantWeightPass, AddQuantDequantPass, AddQuantDequantPassV2
29
from .cal_kl_threshold import cal_kl_threshold
30
from .adaround import run_adaround
31
from . import utils
32

# Public API of this module.
__all__ = [
    'PostTrainingQuantization',
    'WeightQuantization',
]

# Module-wide logger at INFO level with a timestamped message format.
_logger = get_logger(__name__,
                     logging.INFO,
                     fmt='%(asctime)s-%(levelname)s: %(message)s')


def _all_persistable_var_names(program):
    persistable_var_names = []
    for var in program.list_vars():
        if var.persistable:
            persistable_var_names.append(var.name)
    return persistable_var_names


47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
def _remove_unused_var_nodes(graph):
    all_used_vars = set()
    ops = graph.all_op_nodes()
    for op_node in ops:
        for input_node in op_node.inputs:
            all_used_vars.add(input_node)
        for output_node in op_node.outputs:
            all_used_vars.add(output_node)

    all_used_vars = {n.node for n in all_used_vars}
    all_unused_vars = {
        n
        for n in filter(lambda node: node.node not in all_used_vars,
                        graph.all_var_nodes())
    }
    graph.safe_remove_nodes(all_unused_vars)
    return graph


def _remove_ctrl_vars(graph):
    remove_ctr_vars = set()
    for node in graph.all_var_nodes():
        if node.is_ctrl_var():
            remove_ctr_vars.add(node)
    graph.safe_remove_nodes(remove_ctr_vars)
    return graph


def _apply_pass(scope,
                graph,
                pass_name,
                attrs=None,
                attr_values=None,
                debug=False):
    """Run the registered IR pass ``pass_name`` over ``graph``.

    Args:
        scope: Parameter scope attached to the underlying C++ graph under
            ``__param_scope__`` when the graph does not carry one yet.
        graph (IrGraph): Wrapper whose ``.graph`` attribute is the C++ graph.
        pass_name (str): Name looked up with ``core.get_pass``.
        attrs (list[str], optional): Pass attribute names to set.
        attr_values (list, optional): Values for ``attrs``; must be non-empty
            and the same length whenever ``attrs`` is given.
        debug (bool): If True, draw the graph after the pass into the current
            directory under the name ``qat_fp32_<pass_name>``.

    Returns:
        The same ``graph`` with the pass applied and unused var nodes removed.
    """
    ir_pass = core.get_pass(pass_name)
    native_graph = graph.graph
    if not native_graph.has('__param_scope__'):
        native_graph.set_not_owned('__param_scope__', scope)

    if attrs:
        same_length = attr_values and len(attrs) == len(attr_values)
        assert same_length, "Different number of pass attributes and their values."
        for attr_name, attr_value in zip(attrs, attr_values):
            ir_pass.set(attr_name, attr_value)

    ir_pass.apply(native_graph)
    if debug:
        dump_name = 'qat_fp32_{}'.format(pass_name)
        graph.draw('.', dump_name, graph.all_op_nodes())
    _remove_unused_var_nodes(graph)
    return graph


class PostTrainingQuantization(object):
    """
    Utilize the post-training quantization method to quantize an FP32 model.
    Calibration data is used to collect the quantization information for all
    quantized variables.
    """

    def __init__(self,
                 executor=None,
                 scope=None,
                 model_dir=None,
                 model_filename=None,
                 params_filename=None,
                 batch_generator=None,
                 sample_generator=None,
                 data_loader=None,
                 batch_size=10,
                 batch_nums=None,
                 algo="KL",
                 hist_percent=0.99999,
                 quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
                 round_type='round',
                 learning_rate=0.001,
                 is_full_quantize=False,
                 bias_correction=False,
                 activation_bits=8,
                 weight_bits=8,
                 activation_quantize_type='range_abs_max',
                 weight_quantize_type='channel_wise_abs_max',
                 onnx_format=False,
                 optimize_model=False,
                 is_use_cache_file=False,
                 cache_dir=None):
        '''
        Constructor.

        Args:
            executor(fluid.Executor): The executor to load, run and save the
                quantized model.
            scope(fluid.Scope, optional): The scope of the program, use it to load
                and save variables. If scope=None, get scope by global_scope().
            model_dir(str): The path of the fp32 model that will be quantized,
                and the model and params files are under the path.
            model_filename(str, optional): The name of file to load the inference
                program. If it is None, the default filename '__model__' will
                be used. Default is 'None'.
            params_filename(str, optional): The name of file to load all parameters.
                When all parameters were saved in a single binary file, set it
                as the real filename. If parameters were saved in separate files,
                set it as 'None'. Default is 'None'.
            batch_generator(Python Generator): The batch generator provides
                calibrate data for DataLoader, and it returns a batch every
                time. Note that, sample_generator and batch_generator, only one
                should be set. Besides, batch_generator supports lod tensor.
            sample_generator(Python Generator): The sample generator provides
                calibrate data for DataLoader, and it only returns a sample every
                time. Note that, sample_generator and batch_generator, only one
                should be set. Besides, sample_generator does not support lod tensor.
            data_loader(Python Generator, Paddle.io.DataLoader, optional): The
                Generator or Dataloader provides calibrate data, and it could
                return a batch every time. When it is set, it is used directly
                instead of building a DataLoader from the generators. At least
                one of sample_generator, batch_generator and data_loader must
                be set.
            batch_size(int, optional): The batch size of DataLoader. Default is 10.
            batch_nums(int, optional): If batch_nums is not None, the number of
                calibrate data is batch_size*batch_nums. If batch_nums is None, use
                all data provided by sample_generator as calibrate data.
            algo(str, optional): If algo='KL', use KL-divergence method to
                get the KL threshold for quantized activations and get the abs_max
                value for quantized weights. If algo='abs_max', get the abs max
                value for activations and weights. If algo='min_max', get the min
                and max value for quantized activations and weights. If algo='avg',
                get the average value among the max values for activations. If
                algo='hist', get the value of 'hist_percent' quantile as the threshold.
                If algo='mse', get the value which makes the quantization mse loss
                minimal. If algo='emd', search the activation threshold in a way
                analogous to 'mse' (see `_sample_emd`). Default is KL.
            hist_percent(float, optional): The threshold of algo 'hist' for activations.
                Default is 0.99999.
            quantizable_op_type(list[str], optional): List the type of ops
                that will be quantized. Default is ["conv2d", "depthwise_conv2d",
                "mul"].
            round_type(str, optional): The method of converting the quantized weights
                value float->int. Currently supports ['round', 'adaround'] methods.
                Default is `round`, which is rounding to the nearest whole number.
            learning_rate(float, optional): The learning rate of adaround method.
                Default is 0.001.
            is_full_quantize(bool, optional): If set is_full_quantize as True,
                apply quantization to all supported quantizable op type. If set
                is_full_quantize as False, only apply quantization to the op type
                according to the input quantizable_op_type. Default is False.
            bias_correction(bool, optional): If set as True, use the bias correction
                method of https://arxiv.org/abs/1810.05723. Default is False.
            activation_bits(int): quantization bit number for activation. Default is 8.
            weight_bits(int, optional): quantization bit number for weights.
                Default is 8.
            activation_quantize_type(str): quantization type for activation,
                now support 'range_abs_max', 'moving_average_abs_max' and 'abs_max'.
                This param only specifies the fake ops in saving quantized model.
                If it is 'range_abs_max' or 'moving_average_abs_max', we save the scale
                obtained by post training quantization in fake ops. Note that, if it
                is 'abs_max', the scale will not be saved in fake ops.
            weight_quantize_type(str): quantization type for weights,
                support 'abs_max' and 'channel_wise_abs_max'. This param only specifies
                the fake ops in saving quantized model, and we save the scale obtained
                by post training quantization in fake ops. Compared to 'abs_max',
                the model accuracy is usually higher when it is 'channel_wise_abs_max'.
            onnx_format(bool): Whether to export the quantized model with format of ONNX.
                Default is False.
            optimize_model(bool, optional): If set optimize_model as True, it applies
                some passes to the model before quantization, and it supports
                `conv2d/depthwise_conv2d + bn` pass so far. Some targets require the
                weights are quantized by tensor-wise method, which means the weights
                scale for all channel are the same. However, if fuse
                `conv2d/depthwise_conv2d + bn`, the weights scale for all channel will
                be different. To address this problem, fuse the pattern before
                quantization. Default False.
            is_use_cache_file(bool, optional): This param is deprecated.
            cache_dir(str, optional): This param is deprecated.
        Returns:
            None

        Examples:
        .. code-block:: python
            import paddle.fluid as fluid
            from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

            exe = fluid.Executor(fluid.CPUPlace())
            model_dir = "path/to/fp32_model_params"
            # set model_filename as None when the filename is __model__,
            # otherwise set it as the real filename
            model_filename = None
            # set params_filename as None when all parameters were saved in
            # separate files, otherwise set it as the real filename
            params_filename = None
            save_model_path = "path/to/save_model_path"
            # prepare the sample generator according to the model, and the
            # sample generator must return a sample every time. The reference
            # document: https://www.paddlepaddle.org.cn/documentation/docs/zh
            # /user_guides/howto/prepare_data/use_py_reader.html
            sample_generator = your_sample_generator
            batch_size = 10
            batch_nums = 10
            algo = "KL"
            quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
            ptq = PostTrainingQuantization(
                        executor=exe,
                        sample_generator=sample_generator,
                        model_dir=model_dir,
                        model_filename=model_filename,
                        params_filename=params_filename,
                        batch_size=batch_size,
                        batch_nums=batch_nums,
                        algo=algo,
                        quantizable_op_type=quantizable_op_type)
            ptq.quantize()
            ptq.save_quantized_model(save_model_path)
        '''

        self._support_activation_quantize_type = [
            'range_abs_max', 'moving_average_abs_max', 'abs_max'
        ]
        self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
        self._support_algo_type = [
            'KL', 'hist', 'avg', 'mse', 'emd', 'abs_max', 'min_max'
        ]
        assert round_type in ['adaround', 'round']
        self._round_type = round_type
        self._learning_rate = learning_rate
        self._dynamic_quantize_op_type = ['lstm']
        self._support_quantize_op_type = \
            list(set(utils._weight_supported_quantizable_op_type +
                     utils._act_supported_quantizable_op_type +
                     self._dynamic_quantize_op_type))

        # Check inputs
        assert executor is not None, "The executor cannot be None."
        assert model_dir is not None, "The model_dir cannot be None."
        # At least one calibration data source is required. The predicate must
        # be a bare boolean: a list such as ``[x is not None]`` is always truthy.
        assert any(gen is not None
                   for gen in [sample_generator, batch_generator, data_loader]), \
            "The sample_generator, batch_generator " \
            "and data_loader cannot be None in the same time."
        if data_loader is not None:
            # ``type(isgeneratorfunction)`` is the plain Python function type,
            # so any function (e.g. a generator function) is accepted here;
            # quantize() later calls ``self._data_loader()``.
            assert isinstance(data_loader, (io.DataLoader, type(isgeneratorfunction))), \
                "data_loader only accepts `paddle.io.DataLoader` or Generator instance."
        assert batch_size > 0, "The batch_size should be greater than 0."
        assert algo in self._support_algo_type, \
            "The algo ({}) should be in ({}).".format(
                algo, self._support_algo_type)
        assert activation_quantize_type in self._support_activation_quantize_type, \
            "The activation_quantize_type ({}) should be in ({}).".format(
                activation_quantize_type, self._support_activation_quantize_type)
        assert weight_quantize_type in self._support_weight_quantize_type, \
            "The weight_quantize_type ({}) should be in ({}).".format(
                weight_quantize_type, self._support_weight_quantize_type)

        # Save input params
        self._bias_correction = bias_correction
        self._executor = executor
        self._scope = global_scope() if scope is None else scope
        self._model_dir = model_dir
        self._model_filename = model_filename
        self._params_filename = params_filename
        self._sample_generator = sample_generator
        self._batch_generator = batch_generator
        self._batch_size = batch_size
        self._batch_nums = batch_nums
        self._algo = algo
        self._hist_percent = hist_percent
        self._activation_bits = activation_bits
        self._weight_bits = weight_bits
        self._activation_quantize_type = activation_quantize_type
        self._weight_quantize_type = weight_quantize_type
        self._onnx_format = onnx_format
        self._is_full_quantize = is_full_quantize
        if is_full_quantize:
            self._quantizable_op_type = self._support_quantize_op_type
        else:
            self._quantizable_op_type = quantizable_op_type
            for op_type in self._quantizable_op_type:
                assert op_type in self._support_quantize_op_type, \
                    op_type + " is not supported for quantization."
        self._optimize_model = optimize_model

        # Define variables
        self._place = self._executor.place
        self._program = None
        self._feed_list = None
        self._fetch_list = None
        self._data_loader = data_loader

        self._out_scale_op_list = utils._out_scale_op_list
        self._quantized_weight_var_name = set()
        self._quantized_act_var_name = set()
        # weight var name -> type of the op that consumes it
        self._weight_op_pairs = {}
        # persistable input var name -> output var name of its quantized op;
        # rebuilt by _collect_target_varnames()
        self._quantized_op_pairs = {}
        # The vars for algo = KL or hist
        self._sampling_act_abs_min_max = {}
        self._sampling_act_histogram = {}
        self._sampling_data = {}
        self._quantized_var_threshold = {}
        self._histogram_bins = 2048
        # The vars for algo = min_max
        self._quantized_var_min = {}
        self._quantized_var_max = {}
        # The vars for algo = avg
        self._quantized_var_avg = {}
        # The best loss of algo = mse
        self._best_calibration_loss = {}
        # The threshold for algo = abs_max, mse or avg
        self._quantized_threshold = {}

    def quantize(self):
        '''
        Load the FP32 model, and use the calibrate data to calculate the forward-stage.
        Based on the sample data, we can get the quantization information, and obtain
        the final quantized model.

        Args:
            None
        Returns:
            the program of quantized model.
        '''
        self._load_model_data()
        self._collect_target_varnames()
        self._set_activation_persistable()

        if self._algo in ["KL", "hist"]:
            # KL/hist first collect activation abs min/max values; these
            # presumably bound the histograms set up by
            # _init_sampling_act_histogram() for the sampling stage.
            _logger.info("Preparation stage ...")
            batch_id = self._run_calibration_batches(
                self._collect_activation_abs_min_max)
            _logger.info("Finish preparation stage, all batch:" + str(batch_id))
            self._init_sampling_act_histogram()

        _logger.info("Sampling stage ...")
        batch_id = self._run_calibration_batches(self._sampling)
        _logger.info("Finish sampling stage, all batch: " + str(batch_id))

        if self._algo == 'avg':
            for var_name in self._quantized_act_var_name:
                self._quantized_threshold[var_name] = \
                    np.array(self._quantized_var_avg[var_name]).mean()
        if self._algo in ["KL", "hist"]:
            self._calculate_kl_hist_threshold()

        if self._round_type == 'adaround':
            self._adaround_apply()

        self._reset_activation_persistable()

        # Compare by value: ``is`` only matches when both strings happen to be
        # the same interned object, which is not guaranteed for user input.
        if self._algo == 'min_max':
            self._save_input_threhold()
        else:
            self._update_program()

        # save out_threshold for quantized ops.
        if not self._onnx_format:
            self._save_output_threshold()

        if any(op_type in self._quantizable_op_type
               for op_type in self._dynamic_quantize_op_type):
            self._collect_dynamic_quantize_op_threshold(
                self._dynamic_quantize_op_type)

        # Move persistable vars created inside a while op's sub-block to the
        # global block and append them to the while op's "X" inputs.
        global_block = self._program.global_block()
        for _op in global_block.ops:
            if _op.type == "while":
                _block_id = _op.attr("sub_block").id
                _block = self._program.block(_block_id)
                persistables = []
                for _name, _var in _block.vars.items():
                    if _var.persistable:
                        global_block._clone_variable(_var)
                        persistables.append(_name)
                for _name in persistables:
                    _block._remove_var(_name)
                persistables.extend(_op.input('X'))
                _op.desc.set_input("X", persistables)

        return self._program

    def _run_calibration_batches(self, on_batch):
        '''
        Feed calibration batches from the data loader through the program.

        After every executor run ``on_batch()`` is called so the caller can
        read the activations it needs. Iteration stops once
        ``self._batch_nums`` batches have run (when it is set), otherwise when
        the data loader is exhausted.

        Args:
            on_batch(callable): Zero-argument callback invoked after each run.
        Returns:
            int: the number of batches that were run.
        '''
        batch_id = 0
        for data in self._data_loader():
            self._executor.run(program=self._program,
                               feed=data,
                               fetch_list=self._fetch_list,
                               return_numpy=False,
                               scope=self._scope)
            on_batch()
            if batch_id % 5 == 0:
                _logger.info("Run batch: " + str(batch_id))
            batch_id += 1
            if self._batch_nums and batch_id >= self._batch_nums:
                break
        return batch_id

    def _adaround_apply(self):
        '''
        Run ``run_adaround`` over the quantized weights.

        The threshold dict handed over is ``self._quantized_var_threshold``
        for the KL/hist algorithms and ``self._quantized_threshold`` for the
        others; algo='min_max' is rejected.
        '''
        assert self._algo != "min_max", "The algo should not be min_max."
        uses_histogram = self._algo in ["KL", "hist"]
        scale_dict = (self._quantized_var_threshold
                      if uses_histogram else self._quantized_threshold)
        run_adaround(self._data_loader,
                     self._program,
                     self._fetch_list,
                     self._executor,
                     self._scope,
                     self._place,
                     self._quantized_op_pairs,
                     self._weight_op_pairs,
                     scale_dict,
                     num_iterations=self._batch_nums,
                     lr=self._learning_rate)

    def save_quantized_model(self,
                             save_model_path,
                             model_filename=None,
                             params_filename=None):
        '''
        Save the quantized model to the disk.

        Args:
            save_model_path(str): The path to save the quantized model.
            model_filename(str, optional): If the model_filename is None,
                save the model to '__model__'. Otherwise, save the model
                to the specified filename. Default: None.
            params_filename(str, optional): If the params_filename is None,
                save params to separate files. Otherwise, save all params
                to the specified filename. Default: None.
        Returns:
            None
        '''
        # ``clip_extra`` is enabled only for the ONNX-format export.
        clip_extra = bool(self._onnx_format)
        io.save_inference_model(dirname=save_model_path,
                                model_filename=model_filename,
                                params_filename=params_filename,
                                feeded_var_names=self._feed_list,
                                target_vars=self._fetch_list,
                                executor=self._executor,
                                main_program=self._program,
                                clip_extra=clip_extra)
        _logger.info("The quantized model is saved in " + save_model_path)

    def _load_model_data(self):
        '''
        Load the FP32 inference model and set up the calibration data loader.

        Fills ``self._program``, ``self._feed_list`` and ``self._fetch_list``,
        optionally fuses the FP32 model, and, unless a data_loader was given
        to the constructor, builds an iterable DataLoader fed by the sample
        generator (preferred) or the batch generator.
        '''
        _logger.info("Load model and set data loader ...")
        loaded = io.load_inference_model(dirname=self._model_dir,
                                         executor=self._executor,
                                         model_filename=self._model_filename,
                                         params_filename=self._params_filename)
        self._program, self._feed_list, self._fetch_list = loaded

        if self._optimize_model:
            self._optimize_fp32_model()

        feed_vars = [
            framework._get_var(str(name), self._program)
            for name in self._feed_list
        ]

        # A user-supplied data_loader is used as-is.
        if self._data_loader is not None:
            return

        loader = io.DataLoader.from_generator(feed_list=feed_vars,
                                              capacity=3 * self._batch_size,
                                              iterable=True)
        self._data_loader = loader
        if self._sample_generator is not None:
            loader.set_sample_generator(self._sample_generator,
                                        batch_size=self._batch_size,
                                        drop_last=True,
                                        places=self._place)
        elif self._batch_generator is not None:
            loader.set_batch_generator(self._batch_generator,
                                       places=self._place)

    def _optimize_fp32_model(self):
        '''
        Fuse the `conv2d/depthwise_conv2d + bn` in FP32 model.

        Control variables are removed first, then the conv / depthwise-conv /
        conv-transpose + bn fuse passes and the conv / depthwise-conv +
        eltwiseadd + bn fuse passes run in order. ``self._program`` is rebuilt
        from the resulting graph.
        '''
        _logger.info("Optimize FP32 model ...")
        fuse_passes = (
            'conv_bn_fuse_pass',
            'depthwise_conv_bn_fuse_pass',
            'conv_transpose_bn_fuse_pass',
            'conv_eltwiseadd_bn_fuse_pass',
            'depthwise_conv_eltwiseadd_bn_fuse_pass',
        )
        graph = IrGraph(core.Graph(self._program.desc), for_test=True)
        graph = _remove_ctrl_vars(graph)
        for pass_name in fuse_passes:
            graph = _apply_pass(self._scope, graph, pass_name)
        self._program = graph.to_program()

    def _collect_target_varnames(self):
        '''
        Collect the variable names for sampling.

        For every op whose type is in ``self._quantizable_op_type``, its input
        and output variables are recorded. Persistable ones are recorded as
        quantized weights, together with the consuming op type. The rest are
        recorded as quantized activations. For ops in
        ``self._out_scale_op_list`` only the outputs are recorded.

        ``self._quantized_op_pairs`` is rebuilt to map each persistable input
        of a quantized op to that op's (last) output variable name.
        '''
        # TODO(juncaipeng), consider the name_scope of skip_quant
        _logger.info("Collect quantized variable names ...")
        self._quantized_op_pairs = {}

        def collect_var_name(var_name_list, persistable_var_names, op_type):
            for var_name in var_name_list:
                if var_name in persistable_var_names:
                    self._quantized_weight_var_name.add(var_name)
                    self._weight_op_pairs[var_name] = op_type
                else:
                    self._quantized_act_var_name.add(var_name)

        # A set keeps the many membership tests below O(1) instead of a linear
        # scan over every persistable variable of the program.
        persistable_var_names = set(_all_persistable_var_names(self._program))
        for block in self._program.blocks:
            for op in block.ops:
                op_type = op.type
                if self._is_full_quantize and \
                    op_type not in self._quantizable_op_type:
                    _logger.warning(op_type +
                                    " is not supported for quantization.")
                # For quantized ops, sample inputs and outputs
                if op_type in self._quantizable_op_type:
                    in_var_names = list(utils._get_op_input_var_names(op))
                    out_var_names = list(utils._get_op_output_var_names(op))
                    collect_var_name(in_var_names, persistable_var_names,
                                     op_type)
                    collect_var_name(out_var_names, persistable_var_names,
                                     op_type)
                    # collect quantized op output var name
                    for out_var_name in out_var_names:
                        for in_var_name in in_var_names:
                            if in_var_name in persistable_var_names:
                                self._quantized_op_pairs[
                                    in_var_name] = out_var_name
                # For other op, only sample output scale
                elif op_type in self._out_scale_op_list:
                    collect_var_name(
                        utils._get_op_output_var_names(op),
                        persistable_var_names, op_type)

    def _set_activation_persistable(self):
        '''
        Set activation variables to be persistable, so can obtain 
        the tensor data in sample_data
        '''
580 581 582 583
        for var in self._program.list_vars():
            if var.name in self._quantized_act_var_name:
                var.persistable = True

584 585 586 587
    def _reset_activation_persistable(self):
        '''
        Reset activations to be not persistable.
        '''
588
        to_erase = []
589 590 591
        for var in self._program.list_vars():
            if var.name in self._quantized_act_var_name:
                var.persistable = False
592 593
                to_erase.append(var.name)
        self._scope.erase(to_erase)
594

595
    def _sampling(self):
596
        '''
597
        Sample the min/max, abs_max or histogram in every iterations.
598 599
        '''
        if self._algo == "abs_max":
600
            self._sample_abs_max()
X
XGZhang 已提交
601 602
        elif self._algo == "avg":
            self._sample_avg()
603
        elif self._algo == "min_max":
604
            self._sample_min_max()
X
XGZhang 已提交
605 606
        elif self._algo == "mse":
            self._sample_mse()
607 608
        elif self._algo == "emd":
            self._sample_emd()
X
XGZhang 已提交
609
        elif self._algo in ["KL", "hist"]:
610
            self._sample_histogram()
611

X
XGZhang 已提交
612 613 614
    def _sample_mse(self):
        '''
        Search, per activation, the threshold minimizing the
        quantize-dequantize MSE on the current batch.

        While ``self._quantized_threshold`` is still empty, the weight
        thresholds are filled in first. With 'abs_max' weight quantization
        this is one abs-max value per tensor. With 'channel_wise_abs_max' it
        is one value per channel: channels lie on axis 1 for ops listed in
        ``utils._channelwise_quant_axis1_ops``, otherwise on axis 0.

        For each activation the candidate thresholds are ``s * abs_max``,
        with s starting at 0.3 and stepping by 0.02 while s <= 1.0. The best
        loss is kept in ``self._best_calibration_loss`` across batches. A
        candidate replaces the stored threshold only when its loss is equal
        or lower.
        '''
        if self._quantized_threshold == {}:
            for var_name in self._quantized_weight_var_name:
                weight = utils.load_variable_data(self._scope, var_name)
                if self._weight_quantize_type == "abs_max":
                    threshold = float(np.max(np.abs(weight)))
                elif self._weight_quantize_type == "channel_wise_abs_max":
                    op_type = self._weight_op_pairs[var_name]
                    if op_type in utils._channelwise_quant_axis1_ops:
                        threshold = [
                            float(np.max(np.abs(weight[:, ch])))
                            for ch in range(weight.shape[1])
                        ]
                    else:
                        threshold = [
                            float(np.max(np.abs(weight[ch])))
                            for ch in range(weight.shape[0])
                        ]
                self._quantized_threshold[var_name] = threshold

        _logger.info("MSE searching stage ...")
        for var_name in self._quantized_act_var_name:
            act = utils.load_variable_data(self._scope, var_name).flatten()
            abs_max_value = float(np.max(np.abs(act)))
            if abs_max_value == 0.0:
                # Avoid a zero scale, which is divided by below.
                abs_max_value = 1e-8
            bins = 2**(self._activation_bits - 1) - 1
            self._best_calibration_loss.setdefault(var_name, float('inf'))
            s = 0.3
            while s <= 1.0:
                scale = s * abs_max_value
                s += 0.02
                # NOTE(review): the clip range is [0, scale], so negative
                # activations dequantize to 0 -- confirm this is intended.
                quant_dequant_act = np.round(
                    np.clip(act, 0.0, scale) / scale * bins) / bins * scale
                mse_loss = ((act - quant_dequant_act)**2).mean()
                if mse_loss <= self._best_calibration_loss[var_name]:
                    self._best_calibration_loss[var_name] = mse_loss
                    self._quantized_threshold[var_name] = scale

    def _sample_emd(self):
        '''
        Search activation thresholds using an EMD-style loss.

        The loss is ``|mean(x) - mean(q)| + |std(x) - std(q)|``, where ``x``
        is the activation and ``q`` its simulated quant-dequant. Weight
        thresholds are computed only while ``self._quantized_threshold`` is
        still empty, using the abs max per tensor or per channel. For every
        activation the candidate scales ``s * abs_max`` are tried, with ``s``
        starting at 0.3 and growing by 0.02 while ``s <= 1.0``. The scale with
        the lowest loss seen so far (tracked across calls in
        ``self._best_calibration_loss``) is stored in
        ``self._quantized_threshold``.
        '''
        if self._quantized_threshold == {}:
            for var_name in self._quantized_weight_var_name:
                weight = utils.load_variable_data(self._scope, var_name)
                if self._weight_quantize_type == "abs_max":
                    abs_max_value = float(np.max(np.abs(weight)))
                elif self._weight_quantize_type == "channel_wise_abs_max":
                    # Channels are on axis 1 for the ops listed in
                    # _channelwise_quant_axis1_ops, on axis 0 otherwise.
                    if self._weight_op_pairs[
                            var_name] in utils._channelwise_quant_axis1_ops:
                        abs_max_value = [
                            float(np.max(np.abs(weight[:, ch])))
                            for ch in range(weight.shape[1])
                        ]
                    else:
                        abs_max_value = [
                            float(np.max(np.abs(weight[ch])))
                            for ch in range(weight.shape[0])
                        ]
                self._quantized_threshold[var_name] = abs_max_value

        _logger.info("EMD searching stage ...")
        for var_name in self._quantized_act_var_name:
            flat = utils.load_variable_data(self._scope, var_name).flatten()
            abs_max_value = float(np.max(np.abs(flat)))
            if abs_max_value == 0.0:
                # Keep the candidate scales non-zero for all-zero tensors.
                abs_max_value = 1e-8
            if var_name not in self._best_calibration_loss:
                self._best_calibration_loss[var_name] = float('inf')
            bins = 2**(self._activation_bits - 1) - 1
            ref_mean = np.mean(flat)
            ref_std = np.std(flat)
            s = 0.3
            while s <= 1.0:
                scale = s * abs_max_value
                s += 0.02
                # NOTE(review): one-sided clip to [0, scale]; negative
                # activations become 0 in the simulated output.
                clipped = np.clip(flat, 0.0, scale)
                quant_dequant_var = np.round(clipped / scale * bins) / bins * scale
                mean_gap = np.abs(ref_mean - np.mean(quant_dequant_var))
                std_gap = np.abs(ref_std - np.std(quant_dequant_var))
                emd_loss = mean_gap + std_gap
                if emd_loss <= self._best_calibration_loss[var_name]:
                    self._best_calibration_loss[var_name] = emd_loss
                    self._quantized_threshold[var_name] = scale

    def _sample_avg(self):
        '''
        Collect samples for the "avg" threshold algorithm.

        Weight thresholds are computed only while ``self._quantized_threshold``
        is still empty: the abs max of the whole tensor ("abs_max") or of
        every channel ("channel_wise_abs_max").

        For every activation, the tensor is reshaped to
        ``(shape[0], -1)``. The abs max of each row is taken, the row maxima
        are averaged, and the result is appended to
        ``self._quantized_var_avg[var_name]``, giving one entry per call.
        '''
        if self._quantized_threshold == {}:
            for var_name in self._quantized_weight_var_name:
                var_tensor = utils.load_variable_data(self._scope, var_name)
                if self._weight_quantize_type == "abs_max":
                    abs_max_value = float(np.max(np.abs(var_tensor)))
                elif self._weight_quantize_type == "channel_wise_abs_max":
                    abs_max_value = []
                    # Channels are on axis 1 for the ops listed in
                    # _channelwise_quant_axis1_ops, on axis 0 otherwise.
                    if self._weight_op_pairs[
                            var_name] in utils._channelwise_quant_axis1_ops:
                        for i in range(var_tensor.shape[1]):
                            abs_max_value.append(
                                float(np.max(np.abs(var_tensor[:, i]))))
                    else:
                        for i in range(var_tensor.shape[0]):
                            abs_max_value.append(
                                float(np.max(np.abs(var_tensor[i]))))
                self._quantized_threshold[var_name] = abs_max_value

        for var_name in self._quantized_act_var_name:
            var_tensor = utils.load_variable_data(self._scope, var_name)
            # Keep axis 0 (presumably the batch axis; TODO confirm) and
            # flatten the rest, then average the per-row abs max.
            per_row_abs_max = np.max(
                np.abs(var_tensor.reshape(var_tensor.shape[0], -1)), axis=1)
            abs_avg_value = float(np.mean(per_row_abs_max))
            self._quantized_var_avg.setdefault(var_name,
                                               []).append(abs_avg_value)

    def _sample_abs_max(self):
        '''
        Track the running abs max of every quantized variable.

        Weight thresholds are filled in only while
        ``self._quantized_threshold`` is empty, as the abs max per tensor or
        per channel. An activation threshold is replaced whenever the current
        sample has a larger abs max than the stored value.
        '''
        thresholds = self._quantized_threshold
        if thresholds == {}:
            for var_name in self._quantized_weight_var_name:
                weight = utils.load_variable_data(self._scope, var_name)
                if self._weight_quantize_type == "abs_max":
                    abs_max_value = float(np.max(np.abs(weight)))
                elif self._weight_quantize_type == "channel_wise_abs_max":
                    # Channels are on axis 1 for the ops listed in
                    # _channelwise_quant_axis1_ops, on axis 0 otherwise.
                    if self._weight_op_pairs[
                            var_name] in utils._channelwise_quant_axis1_ops:
                        abs_max_value = [
                            float(np.max(np.abs(weight[:, ch])))
                            for ch in range(weight.shape[1])
                        ]
                    else:
                        abs_max_value = [
                            float(np.max(np.abs(weight[ch])))
                            for ch in range(weight.shape[0])
                        ]
                thresholds[var_name] = abs_max_value

        for var_name in self._quantized_act_var_name:
            tensor = utils.load_variable_data(self._scope, var_name)
            current = float(np.max(np.abs(tensor)))
            seen = var_name in thresholds
            if not seen or current > thresholds[var_name]:
                thresholds[var_name] = current

    def _sample_min_max(self):
        '''
        Track the running (signed) min and max of every quantized variable.

        Weight ranges are computed only while both
        ``self._quantized_var_min`` and ``self._quantized_var_max`` are empty,
        per tensor or per channel. Activation ranges are widened whenever a
        sample exceeds the stored bounds.
        '''
        var_min = self._quantized_var_min
        var_max = self._quantized_var_max
        if var_min == {} and var_max == {}:
            for var_name in self._quantized_weight_var_name:
                weight = utils.load_variable_data(self._scope, var_name)
                if self._weight_quantize_type == "abs_max":
                    min_value = float(np.min(weight))
                    max_value = float(np.max(weight))
                elif self._weight_quantize_type == "channel_wise_abs_max":
                    # Channels are on axis 1 for the ops listed in
                    # _channelwise_quant_axis1_ops, on axis 0 otherwise.
                    if self._weight_op_pairs[
                            var_name] in utils._channelwise_quant_axis1_ops:
                        channels = [
                            weight[:, ch] for ch in range(weight.shape[1])
                        ]
                    else:
                        channels = [weight[ch] for ch in range(weight.shape[0])]
                    min_value = [float(np.min(c)) for c in channels]
                    max_value = [float(np.max(c)) for c in channels]
                var_min[var_name] = min_value
                var_max[var_name] = max_value

        for var_name in self._quantized_act_var_name:
            tensor = utils.load_variable_data(self._scope, var_name)
            lo = float(np.min(tensor))
            hi = float(np.max(tensor))
            if var_name not in var_min or lo < var_min[var_name]:
                var_min[var_name] = lo
            if var_name not in var_max or hi > var_max[var_name]:
                var_max[var_name] = hi

    def _sample_histogram(self):
        '''
        Accumulate a histogram of ``|activation|`` for every quantized activation.

        The bin edges stored at ``self._sampling_act_histogram[var_name][1]``
        are reused. The new counts are added onto the counts stored at index 0.
        '''
        for var_name in self._quantized_act_var_name:
            magnitudes = np.abs(utils.load_variable_data(self._scope, var_name))
            entry = self._sampling_act_histogram[var_name]
            new_counts, _ = np.histogram(magnitudes, bins=entry[1])
            entry[0] += new_counts

    def _save_input_threhold(self):
        '''
        Save input threshold to the quantized op.

        This applies to every op whose type is in ``self._quantizable_op_type``.
        The recorded min and max of each input variable are written to the op
        attributes ``<var>.min`` and ``<var>.max``, and ``with_quant_attr`` is
        set. The method is only valid for the "min_max" algorithm.
        '''
        assert self._algo == "min_max", \
            "The algo should be min_max to save input threshold."
        var_min = self._quantized_var_min
        var_max = self._quantized_var_max
        for block in self._program.blocks:
            quantizable_ops = (op for op in block.ops
                               if op.type in self._quantizable_op_type)
            for op in quantizable_ops:
                for var_name in utils._get_op_input_var_names(op):
                    assert var_name in var_min
                    assert var_name in var_max
                    op._set_attr(var_name + ".min", var_min[var_name])
                    op._set_attr(var_name + ".max", var_max[var_name])
                    op._set_attr("with_quant_attr", True)

    def _collect_activation_abs_min_max(self):
        '''
        Record the running min and max of ``|activation|`` for every activation.

        The bounds are stored as ``[abs_min, abs_max]`` in
        ``self._sampling_act_abs_min_max`` and widened as new samples arrive.
        When algo is KL, these bounds are later used to lay out the histogram
        from which the threshold is calculated.
        '''
        bounds_by_var = self._sampling_act_abs_min_max
        for var_name in self._quantized_act_var_name:
            magnitudes = np.abs(
                utils.load_variable_data(self._scope, var_name))
            lo = float(np.min(magnitudes))
            hi = float(np.max(magnitudes))
            if var_name not in bounds_by_var:
                bounds_by_var[var_name] = [lo, hi]
                continue
            bounds = bounds_by_var[var_name]
            if lo < bounds[0]:
                bounds[0] = lo
            if hi > bounds[1]:
                bounds[1] = hi

    def _init_sampling_act_histogram(self):
        '''
        Based on the min/max value, init the sampling_act_histogram.
        '''
        for var_name in self._quantized_act_var_name:
            if var_name not in self._sampling_act_histogram:
                min_val = self._sampling_act_abs_min_max[var_name][0]
                max_val = self._sampling_act_abs_min_max[var_name][1]
                hist, hist_edeges = np.histogram(
                    [], bins=self._histogram_bins, range=(min_val, max_val))
                self._sampling_act_histogram[var_name] = [hist, hist_edeges]
836

X
XGZhang 已提交
837
    def _calculate_kl_hist_threshold(self):
        '''
        Calculate the KL or hist threshold of quantized variables.

        Weights get their abs max, per tensor or per channel. Each activation
        uses its accumulated histogram from ``self._sampling_act_histogram``,
        reduced with ``cal_kl_threshold`` for "KL" or
        ``self._get_hist_scaling_factor`` for "hist". All results are written
        to ``self._quantized_var_threshold``.
        '''
        _logger.info("Calculate {} threshold ...".format(self._algo))
        assert self._algo in ["KL", "hist"], "The algo should be KL or hist."

        # Abs_max threshold for weights.
        for var_name in self._quantized_weight_var_name:
            weight_data = utils.load_variable_data(self._scope, var_name)
            if self._weight_quantize_type == "abs_max":
                weight_threshold = float(np.max(np.abs(weight_data)))
            elif self._weight_quantize_type == "channel_wise_abs_max":
                # Channels are on axis 1 for the ops listed in
                # _channelwise_quant_axis1_ops, on axis 0 otherwise.
                on_axis1 = self._weight_op_pairs[
                    var_name] in utils._channelwise_quant_axis1_ops
                if on_axis1:
                    channels = [
                        weight_data[:, ch]
                        for ch in range(weight_data.shape[1])
                    ]
                else:
                    channels = [
                        weight_data[ch] for ch in range(weight_data.shape[0])
                    ]
                weight_threshold = [float(np.max(np.abs(c))) for c in channels]
            self._quantized_var_threshold[var_name] = weight_threshold

        for var_name in self._quantized_act_var_name:
            counts, edges = self._sampling_act_histogram[var_name]
            if self._algo == "KL":
                self._quantized_var_threshold[var_name] = cal_kl_threshold(
                    counts, edges[1] - edges[0], self._activation_bits)
            elif self._algo == "hist":
                self._quantized_var_threshold[var_name] = \
                    self._get_hist_scaling_factor(counts, edges)

    def _update_program(self):
        '''
        Rewrite ``self._program`` into its quantized form.

        The steps are:

        1. A transform pass (``QuantizationTransformPass``, or the V2 variant
           when ``self._onnx_format`` is set) inserts fake_quantize and
           fake_dequantize ops for the weight-quantizable op types.
        2. An add-quant-dequant pass (``AddQuantDequantPass`` or its V2
           variant) inserts fake_quant_dequant ops for the activation-only
           op types.
        3. Every collected threshold is written to the ``<var>.scale`` and
           ``<var>.quant_dequant.scale`` variables.
        4. ``QuantizationFreezePass`` is applied, or ``QuantWeightPass`` when
           ``self._onnx_format`` is set.
        '''
        _logger.info("Update the program ...")
        graph = IrGraph(core.Graph(self._program.desc), for_test=True)

        def apply_to_all_sub_graphs(ir_pass):
            # Fake quant ops must be inserted into test graphs, so every sub
            # graph is flagged _for_test before the pass runs on it.
            for sub_graph in graph.all_sub_graphs():
                sub_graph._for_test = True
                ir_pass.apply(sub_graph)

        # Insert fake_quant/fake_dequantize ops.
        major_quantizable_op_types = [
            op_type for op_type in utils._weight_supported_quantizable_op_type
            if op_type in self._quantizable_op_type
        ]
        transform_pass_cls = QuantizationTransformPassV2 \
            if self._onnx_format else QuantizationTransformPass
        transform_pass = transform_pass_cls(
            scope=self._scope,
            place=self._place,
            weight_bits=self._weight_bits,
            activation_bits=self._activation_bits,
            activation_quantize_type=self._activation_quantize_type,
            weight_quantize_type=self._weight_quantize_type,
            quantizable_op_type=major_quantizable_op_types)
        apply_to_all_sub_graphs(transform_pass)

        # Insert fake_quant_dequant ops.
        minor_quantizable_op_types = [
            op_type for op_type in utils._act_supported_quantizable_op_type
            if op_type in self._quantizable_op_type
        ]
        if self._onnx_format:
            add_quant_dequant_pass = AddQuantDequantPassV2(
                scope=self._scope,
                place=self._place,
                quantizable_op_type=minor_quantizable_op_types,
                is_full_quantized=self._is_full_quantize)
        else:
            add_quant_dequant_pass = AddQuantDequantPass(
                scope=self._scope,
                place=self._place,
                quantizable_op_type=minor_quantizable_op_types)
        apply_to_all_sub_graphs(add_quant_dequant_pass)

        # Save every threshold into its scale var nodes.
        scale_dict = self._quantized_var_threshold \
            if self._algo in ["KL", "hist"] else self._quantized_threshold
        for key, val in scale_dict.items():
            for scale_var_name in (key + ".scale",
                                   key + ".quant_dequant.scale"):
                utils.set_variable_data(self._scope, self._place,
                                        scale_var_name,
                                        np.array(
                                            [val], dtype=np.float32))

        if self._onnx_format:
            apply_to_all_sub_graphs(QuantWeightPass(self._scope, self._place))
        else:
            # QuantizationFreezePass produces the final quantized model.
            freeze_pass = QuantizationFreezePass(
                scope=self._scope,
                place=self._place,
                bias_correction=self._bias_correction,
                weight_bits=self._weight_bits,
                round_type=self._round_type,
                activation_bits=self._activation_bits,
                weight_quantize_type=self._weight_quantize_type,
                quantizable_op_type=major_quantizable_op_types)
            apply_to_all_sub_graphs(freeze_pass)

        self._program = graph.to_program()

    def _save_output_threshold(self):
        '''
        Save output threshold to the quantized op.

        This visits every op whose type is in ``self._quantizable_op_type`` or
        ``self._out_scale_op_list``. For each output variable it writes:

        * algo KL / hist / avg / abs_max / mse / emd: the threshold under
          ``out_threshold`` and ``<argname><index>_threshold``. Both names
          are kept for compatibility.
        * algo min_max: ``out_min`` and ``out_max``.

        Each saving op also gets ``with_quant_attr`` set. Quantizable ops
        additionally get ``quantization_type`` (e.g. "post_kl").
        '''

        def save_info(op_node, out_var_name, threshold_map, out_info_name,
                      quantized_type):
            assert out_var_name in threshold_map, \
                "The output ({}) of {} node does not have threshold.".format(
                out_var_name, op_node.type)
            # Read the variable and op from this function's own arguments,
            # not from the enclosing loop's ``var_name``/``op`` closure vars.
            op_node._set_attr(out_info_name, threshold_map[out_var_name])
            op_node._set_attr("with_quant_attr", True)
            if op_node.type in self._quantizable_op_type:
                op_node._set_attr("quantization_type", quantized_type)

        def analysis_and_save_info(op_node, out_var_name):
            argname_index = utils._get_output_name_index(op_node, out_var_name)
            assert argname_index is not None, \
                out_var_name + " is not the output of the op"
            if self._algo == "min_max":
                save_info(op_node, out_var_name, self._quantized_var_min,
                          "out_min", "post_min_max")
                save_info(op_node, out_var_name, self._quantized_var_max,
                          "out_max", "post_min_max")
                return
            if self._algo in ["KL", "hist"]:
                threshold_map = self._quantized_var_threshold
                quantized_type = "post_kl" if self._algo == "KL" \
                    else "post_hist"
            elif self._algo in ["avg", "abs_max", "mse", "emd"]:
                threshold_map = self._quantized_threshold
                quantized_type = "post_" + str(self._algo)
            else:
                return
            # For compatibility, we save output threshold by two methods.
            save_info(op_node, out_var_name, threshold_map, "out_threshold",
                      quantized_type)
            save_info(op_node, out_var_name, threshold_map,
                      argname_index[0] + str(argname_index[1]) + "_threshold",
                      quantized_type)

        for block_id in range(len(self._program.blocks)):
            for op in self._program.blocks[block_id].ops:
                if op.type in (
                        self._quantizable_op_type + self._out_scale_op_list):
                    out_var_names = utils._get_op_output_var_names(op)
                    for var_name in out_var_names:
                        analysis_and_save_info(op, var_name)

    def _collect_dynamic_quantize_op_threshold(self, target_ops_type):
        """
        Collect and save the weight threshold for dynamic quantize ops,
        such as lstm and gru.

        This applies to every op (in any block) whose type is in
        ``target_ops_type``. Each persistable input gets its abs max written
        as ``<argname><index>_threshold``, and the op also gets
        ``quantization_type``, ``bit_length`` and ``with_quant_attr``.

        Args:
            target_ops_type(list): the op type of target ops
        Returns:
            None
        """
        target_ops = []
        for index in range(self._program.num_blocks):
            for op in self._program.block(index).ops:
                if op.type in target_ops_type:
                    target_ops.append(op)

        quantization_type = str("post_" + self._algo).lower()
        # Build a set once: membership is tested for every input of every
        # target op, and a list lookup would be linear each time.
        persistable_var_names = set(_all_persistable_var_names(self._program))
        for op in target_ops:
            for var_name in utils._get_op_input_var_names(op):
                if var_name not in persistable_var_names:
                    continue
                var_data = utils.load_variable_data(self._scope, var_name)
                threshold = float(np.max(np.abs(var_data)))
                argname, index = utils._get_input_name_index(op, var_name)
                op._set_attr(argname + str(index) + "_threshold", threshold)
                op._set_attr("quantization_type", quantization_type)
                op._set_attr("bit_length", self._weight_bits)
                op._set_attr("with_quant_attr", True)

    def _get_hist_scaling_factor(self, hist, hist_edges):
        '''
        Using the hist method to get the scaling factor.
        '''
        threshold_rate = self._hist_percent
        hist = hist / float(sum(hist))
        hist_sum = 0
        hist_index = 0
        for i in range(len(hist)):
            hist_sum += hist[i]
            if hist_sum >= threshold_rate:
                hist_index = i + 1
                break
        bin_width = hist_edges[1] - hist_edges[0]
        return (hist_index - 0.5) * bin_width

1076 1077 1078

class WeightQuantization(object):
    """
    Weight-only post-training quantization for a saved inference model.

    The fp32 model is loaded on CPU from ``model_dir``. Either the weights
    of selected ops are quantized to int8/int16 (``quantize_weight_to_int``)
    or the FP32 parameters are re-saved as FP16 (``convert_weight_to_fp16``).
    """
    _supported_quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
    _supported_weight_quantize_type = ['channel_wise_abs_max', 'abs_max']

    def __init__(self, model_dir, model_filename=None, params_filename=None):
        '''
        This class quantizes the weight of some ops to reduce the size of model
        or improve the performance.

        Args:
            model_dir(str): The path of the fp32 model that will be quantized,
                and the model and params files are under the path.
            model_filename(str, optional): The name of file to load the inference
                program. If it is None, the default filename '__model__' will
                be used. Default is 'None'.
            params_filename(str, optional): The name of file to load all parameters.
                When all parameters were saved in a single binary file, set it
                as the real filename. If parameters were saved in separate files,
                set it as 'None'. Default is 'None'.
        '''
        self._model_dir = model_dir
        self._model_filename = model_filename
        self._params_filename = params_filename

    def quantize_weight_to_int(self,
                               save_model_dir,
                               save_model_filename=None,
                               save_params_filename=None,
                               quantizable_op_type=("conv2d", "mul"),
                               weight_bits=8,
                               weight_quantize_type="channel_wise_abs_max",
                               generate_test_model=False,
                               threshold_rate=0.0):
        '''
        In order to reduce the size of model, this api quantizes the weight
        of some ops from float32 to int8/16. In the inference stage, the
        quantized weight will be dequantized to float32 again.

        The quantized model is saved under ``save_model_dir/quantized_model``;
        the optional fake quantized model under ``save_model_dir/test_model``.

        Args:
            save_model_dir(str): The path to save the quantized model.
            save_model_filename(str, optional): The name of file to
                save the inference program. If it is None, the default
                filename '__model__' will be used. Default is 'None'.
            save_params_filename(str, optional): The name of file to
                save all parameters. If it is None, parameters were
                saved in separate files. If it is not None, all
                parameters were saved in a single binary file.
            quantizable_op_type(list[str]|tuple[str], optional): The ops
                that will be quantized, and the quantized ops should be
                contained in ["conv2d", "depthwise_conv2d", "mul"].
                Default is ("conv2d", "mul").
            weight_bits(int, optional): The bits for the quantized weight,
                and it should be 8 or 16. Default is 8.
            weight_quantize_type(str, optional): quantization type for weights,
                support 'channel_wise_abs_max' and 'abs_max'. Set it as
                'channel_wise_abs_max', the accuracy performs better.
            generate_test_model(bool, optional): If set generate_test_model
                as True, it saves a fake quantized model, in which the weights
                are quantized and dequantized. We can use PaddlePaddle to load
                the fake quantized model and test the accuracy on GPU or CPU.
            threshold_rate(float, optional): This api uses abs_max method to
                quantize the weight from float32 to int8/16, and the abs max
                value is important for quantization diff. When the abs_max
                value is far away from the center of the numerical distribution,
                we can set threshold_rate between 1e-6 and 1e-8, so the abs max
                value will be optimized. Only used by 'abs_max'. Default is 0.0.
        '''
        for op_type in quantizable_op_type:
            assert op_type in self._supported_quantizable_op_type, \
                "Input error:" + op_type + \
                " is not supported for weight quantization."
        assert weight_bits in [8, 16], \
            "Input error: weight_bits should be 8 or 16."
        assert weight_quantize_type in self._supported_weight_quantize_type, \
            "Input error: weight_quantize_type should in {}".format(
                self._supported_weight_quantize_type)

        quantized_model_dir = os.path.join(save_model_dir, "quantized_model")
        self._quantize_weight_to_int(quantized_model_dir, save_model_filename,
                                     save_params_filename, quantizable_op_type,
                                     weight_bits, weight_quantize_type, False,
                                     threshold_rate)

        if generate_test_model:
            test_model_dir = os.path.join(save_model_dir, "test_model")
            self._quantize_weight_to_int(
                test_model_dir, save_model_filename, save_params_filename,
                quantizable_op_type, weight_bits, weight_quantize_type, True,
                threshold_rate)

    def convert_weight_to_fp16(self, save_model_dir):
        """
        Convert the FP32 persistable vars to FP16.
        Note that, this api only changes the data type of variables in
        the params file(s); the model file is copied unchanged.

        Args:
            save_model_dir(str): The path to save the fp16 model.
        """
        # Load model
        place = core.CPUPlace()
        exe = Executor(place)
        [infer_program, _, _] = \
            io.load_inference_model(dirname=self._model_dir,
                                    executor=exe,
                                    model_filename=self._model_filename,
                                    params_filename=self._params_filename)

        # Clone and save fp16 weights
        save_program = framework.Program()
        save_block = save_program.global_block()
        save_var_map = {}

        for var in infer_program.list_vars():
            # Only FP32 persistable variables are written (as FP16).
            # NOTE(review): persistable vars of other dtypes are not saved
            # at all here - confirm the fp16 model never needs them.
            if (var.type == core.VarDesc.VarType.RAW) or \
                (not var.persistable) or (var.name in ['feed', 'fetch']) \
                or (var.dtype != core.VarDesc.VarType.FP32):
                continue

            new_var = save_block._clone_variable(var)
            if self._params_filename is not None:
                save_var_map[new_var.name] = new_var
            else:
                # One file per variable, named after the variable.
                save_file_path = os.path.join(
                    os.path.normpath(save_model_dir), new_var.name)
                save_block.append_op(
                    type='save',
                    inputs={'X': [new_var]},
                    outputs={},
                    attrs={
                        'file_path': os.path.normpath(save_file_path),
                        'save_as_fp16': True
                    })

        if self._params_filename is not None:
            # All variables go into a single combined file, in name order.
            save_var_list = [
                save_var_map[name] for name in sorted(save_var_map)
            ]

            saved_params_var = save_block.create_var(
                type=core.VarDesc.VarType.RAW,
                name=unique_name.generate("saved_params"))
            saved_params_var.desc.set_persistable(True)

            save_path = os.path.join(
                os.path.normpath(save_model_dir), self._params_filename)
            save_block.append_op(
                type='save_combine',
                inputs={'X': save_var_list},
                outputs={'Y': saved_params_var},
                attrs={'file_path': save_path,
                       'save_as_fp16': True})

        save_program._sync_with_cpp()
        exe.run(save_program)

        # Copy model. Make sure the destination exists first: when no save
        # op was appended above, nothing in this method would create it.
        model_filename = "__model__" if self._model_filename is None \
                    else self._model_filename
        if not os.path.isdir(save_model_dir):
            os.makedirs(save_model_dir)
        src_model = os.path.join(self._model_dir, model_filename)
        dest_model = os.path.join(save_model_dir, model_filename)
        shutil.copyfile(src_model, dest_model)

    def _quantize_weight_to_int(self, save_model_dir, save_model_filename,
                                save_params_filename, quantizable_op_type,
                                weight_bits, weight_quantize_type, for_test,
                                threshold_rate):
        """
        Generate quantized model or fake quantized model.

        Loads the fp32 model, quantizes every persistable input of the ops
        whose type is in ``quantizable_op_type`` and saves the program to
        ``save_model_dir``. With ``for_test=True`` the weights are written
        back dequantized to float32 (fake quantized model).
        """
        # Load model
        place = core.CPUPlace()
        exe = Executor(place)
        scope = global_scope()
        [program, feed_list, fetch_list] = \
            io.load_inference_model(dirname=self._model_dir,
                                    executor=exe,
                                    model_filename=self._model_filename,
                                    params_filename=self._params_filename)

        quantized_ops = [
            op for block_idx in range(program.num_blocks)
            for op in program.block(block_idx).ops
            if op.type in quantizable_op_type
        ]

        # Quantize weights (persistable inputs of the selected ops).
        persistable_var_names = set(_all_persistable_var_names(program))
        for op in quantized_ops:
            for var_name in op.input_arg_names:
                if var_name not in persistable_var_names:
                    continue
                if weight_quantize_type == "abs_max":
                    self._weight_abs_max_quantization(
                        scope, place, weight_bits, threshold_rate, op,
                        var_name, for_test)
                elif weight_quantize_type == "channel_wise_abs_max":
                    self._weight_channel_wise_abs_max_quantization(
                        scope, place, weight_bits, op, var_name, for_test)

        io.save_inference_model(
            dirname=save_model_dir,
            feeded_var_names=feed_list,
            target_vars=fetch_list,
            executor=exe,
            main_program=program,
            model_filename=save_model_filename,
            params_filename=save_params_filename)

    def _weight_abs_max_quantization(self, scope, place, weight_bits,
                                     threshold_rate, op, var_name, for_test):
        '''
        Use abs_max method to quantize weight.

        With ``threshold_rate`` (close to) 0 the threshold is the abs max of
        the weight; otherwise it comes from ``_calculate_threshold`` and the
        weight is clipped to [-threshold, threshold] first. The single scale
        is ``threshold / (2**(weight_bits-1) - 1)`` and is stored on the op
        as ``<var_name>_quant_scale``.
        '''
        quantize_range = (1 << (weight_bits - 1)) - 1
        save_weight_dtype = np.int8 if weight_bits == 8 else np.int16

        # Get quantized scale and weight data
        weight_data = utils.load_variable_data(scope, var_name)
        if abs(threshold_rate) < 1e-10:
            threshold_value = np.max(np.abs(weight_data))
        else:
            threshold_value = self._calculate_threshold(weight_data,
                                                        threshold_rate)
            weight_data[weight_data > threshold_value] = threshold_value
            weight_data[weight_data < -threshold_value] = -threshold_value
        scale = threshold_value / quantize_range
        if scale == 0:
            # All-zero weight: dividing by 0 would produce NaN before the
            # integer cast, so quantize straight to zeros.
            quantized_weight_data = np.zeros_like(
                weight_data, dtype=save_weight_dtype)
        else:
            quantized_weight_data = \
                np.around(weight_data / scale).astype(save_weight_dtype)

        # Set weight data
        if not for_test:
            utils.set_variable_data(scope, place, var_name,
                                    quantized_weight_data)
        else:
            dequantized_weight_data = \
                (quantized_weight_data * scale).astype(np.float32)
            utils.set_variable_data(scope, place, var_name,
                                    dequantized_weight_data)

        # Save info
        op._set_attr('quantization_type', 'post_weight_abs_max')
        op._set_attr('quantize_weight_bits', weight_bits)
        op._set_attr(var_name + "_quant_scale", [scale])  # Save as list
        op._set_attr("with_quant_attr", True)

    def _weight_channel_wise_abs_max_quantization(
            self, scope, place, weight_bits, op, var_name, for_test):
        '''
        Use channel_wise_abs_max method to quantize weight.

        mul weights get one scale per column (last axis); conv2d and
        depthwise_conv2d weights get one scale per slice along axis 0.

        Raises:
            ValueError: if op.type is not mul, conv2d or depthwise_conv2d.
        '''
        # Pick the helpers before touching the scope, so an unsupported op
        # fails with a clear error instead of an unbound local later on.
        if op.type == "mul":
            quantize_fn = self._mul_channel_wise_quantization
            dequantize_fn = self._mul_channel_wise_dequantization
        elif op.type in ["conv2d", "depthwise_conv2d"]:
            quantize_fn = self._conv_channel_wise_quantization
            dequantize_fn = self._conv_channel_wise_dequantization
        else:
            raise ValueError(op.type +
                             " is not supported by weight quantization")

        quantize_range = (1 << (weight_bits - 1)) - 1
        save_weight_dtype = np.int8 if weight_bits == 8 else np.int16

        # Get quantized scale and weight data
        weight_data = utils.load_variable_data(scope, var_name)
        scales, quantized_weight_data = quantize_fn(
            weight_data, quantize_range, save_weight_dtype)

        # Set weight data
        if not for_test:
            utils.set_variable_data(scope, place, var_name,
                                    quantized_weight_data)
        else:
            utils.set_variable_data(
                scope, place, var_name,
                dequantize_fn(quantized_weight_data, scales))

        # Save info
        op._set_attr('quantization_type', 'post_weight_channel_wise_abs_max')
        op._set_attr('quantize_weight_bits', weight_bits)
        op._set_attr(var_name + "_quant_scale", scales)
        op._set_attr("with_quant_attr", True)

    def _conv_channel_wise_quantization(self, weight_data, quantize_range,
                                        save_weight_dtype):
        '''
        Get channel wise scale for the weights of conv2d and depthwise_conv2d,
        and quantize the weights.

        The channel axis is axis 0. Each channel's scale is its abs max
        divided by ``quantize_range``; a channel whose weights are all zero
        gets scale 0 and is quantized to zeros (instead of NaN).

        Returns:
            (list[float], ndarray): the per-channel scales and the quantized
            weights with dtype ``save_weight_dtype``.
        '''
        channel_num = weight_data.shape[0]
        abs_max = np.max(np.abs(weight_data.reshape(channel_num, -1)), axis=1)
        scales = abs_max / quantize_range
        # Divide zero-scale channels by 1 so their zeros stay 0.
        divisor = np.where(scales == 0, np.ones_like(scales), scales)
        divisor = divisor.reshape((channel_num, ) + (1, ) *
                                  (weight_data.ndim - 1))
        quantized_weight_data = np.around(weight_data / divisor).astype(
            save_weight_dtype)
        return scales.tolist(), quantized_weight_data

    def _conv_channel_wise_dequantization(self, quantized_weight_data, scales):
        '''
        For conv2d and depthwise_conv2d, dequantize the weights to fp32.

        Each slice along axis 0 is multiplied by its scale; ``len(scales)``
        must equal ``quantized_weight_data.shape[0]``.
        '''
        scale_shape = (len(scales), ) + (1, ) * (quantized_weight_data.ndim -
                                                 1)
        scale_array = np.asarray(scales).reshape(scale_shape)
        return (quantized_weight_data * scale_array).astype(np.float32)

    def _mul_channel_wise_quantization(self, weight_data, quantize_range,
                                       save_weight_dtype):
        '''
        Get channel wise scale for the weights of mul, and quantize the
        weights.

        The channel axis is the last axis (one scale per column of the
        weight matrix, which is presumably 2-D for mul). Each scale is the
        column's abs max divided by ``quantize_range``; an all-zero column
        gets scale 0 and is quantized to zeros (instead of NaN).

        Returns:
            (list[float], ndarray): the per-channel scales and the quantized
            weights with dtype ``save_weight_dtype``.
        '''
        channel_num = weight_data.shape[-1]
        abs_max = np.max(np.abs(weight_data.reshape(-1, channel_num)), axis=0)
        scales = abs_max / quantize_range
        # Divide zero-scale columns by 1 so their zeros stay 0.
        divisor = np.where(scales == 0, np.ones_like(scales), scales)
        quantized_weight_data = np.around(weight_data / divisor).astype(
            save_weight_dtype)
        return scales.tolist(), quantized_weight_data

    def _mul_channel_wise_dequantization(self, quantized_weight_data, scales):
        '''
        For mul, dequantize the weights to fp32.

        Each column (last axis) is multiplied by its scale; ``len(scales)``
        must equal ``quantized_weight_data.shape[-1]``.
        '''
        return (quantized_weight_data * np.asarray(scales)).astype(np.float32)

    def _calculate_threshold(self, input, threshold_rate, histogram_bins=5000):
        '''
        Return the abs value below which a ``1 - threshold_rate`` fraction
        of ``input`` falls, measured on a ``histogram_bins``-bin histogram of
        ``abs(input)`` over [0, max]. The upper edge of the first bin whose
        cumulative ratio reaches that fraction is returned; if it is never
        reached (float round-off), the full range (the abs max) is returned
        instead of 0.
        '''
        input_abs = np.abs(input)
        hist, hist_edges = np.histogram(
            input_abs, bins=histogram_bins, range=(0, np.max(input_abs)))
        cumulative = np.cumsum(hist / float(np.sum(hist)))
        reached = np.flatnonzero(cumulative >= 1.0 - threshold_rate)
        hist_index = int(reached[0]) + 1 if reached.size > 0 else len(hist)
        bin_width = hist_edges[1] - hist_edges[0]
        return hist_index * bin_width