post_training_quantization.py 81.9 KB
Newer Older
1
#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

15
import logging
16
import os
17
import shutil
18

19
import numpy as np
20

21 22 23 24
try:
    from tqdm import tqdm
except ImportError:
    # Only a missing/unimportable tqdm should trigger the fallback; a bare
    # ``except`` would also hide KeyboardInterrupt and SystemExit.
    from .utils import tqdm
25

26
from inspect import isgeneratorfunction
27 28 29 30 31 32 33 34 35 36 37

from paddle.fluid.framework import IrGraph, _get_var

from ... import io, static
from ...fluid import reader
from ...framework import core
from ...utils import unique_name
from ..log_helper import get_logger
from . import utils
from .adaround import run_adaround
from .cal_kl_threshold import cal_kl_threshold
38 39 40 41 42 43 44
from .quant_config import (
    SUPPORT_QUANTIZATION_OP_DICT,
    ARMCPUQuantizer,
    BaseQuantizer,
    MKLDNNQuantizer,
    TensorRTQuantizer,
)
45
from .quantization_pass import (
46
    AddQuantDequantForInferencePass,
47 48 49
    AddQuantDequantPass,
    AddQuantDequantPassV2,
    QuantizationFreezePass,
50 51 52 53
    QuantizationTransformPass,
    QuantizationTransformPassV2,
    QuantWeightPass,
)
54

55 56 57
# Module-level logger shared by every helper and method in this file. It is
# created through the project's ``get_logger`` helper with ``logging.INFO``
# and a "<time>-<level>: <message>" format string.
_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
58 59


60 61 62 63 64 65 66 67
def _all_persistable_var_names(program):
    persistable_var_names = []
    for var in program.list_vars():
        if var.persistable:
            persistable_var_names.append(var.name)
    return persistable_var_names


68 69 70 71 72 73 74 75 76 77 78 79
def _remove_unused_var_nodes(graph):
    all_used_vars = set()
    ops = graph.all_op_nodes()
    for op_node in ops:
        for input_node in op_node.inputs:
            all_used_vars.add(input_node)
        for output_node in op_node.outputs:
            all_used_vars.add(output_node)

    all_used_vars = {n.node for n in all_used_vars}
    all_unused_vars = {
        n
80 81 82
        for n in filter(
            lambda node: node.node not in all_used_vars, graph.all_var_nodes()
        )
83 84 85 86 87 88 89 90 91 92 93 94 95 96
    }
    graph.safe_remove_nodes(all_unused_vars)
    return graph


def _remove_ctrl_vars(graph):
    remove_ctr_vars = set()
    for node in graph.all_var_nodes():
        if node.is_ctrl_var():
            remove_ctr_vars.add(node)
    graph.safe_remove_nodes(remove_ctr_vars)
    return graph


97 98 99
def _apply_pass(
    scope, graph, pass_name, attrs=None, attr_values=None, debug=False
):
    """Apply the IR pass named ``pass_name`` to ``graph``.

    Args:
        scope: Stored on the underlying graph under ``'__param_scope__'``
            when that attribute is not set yet.
        graph: Graph wrapper; its ``graph`` attribute is what the pass is
            applied to.
        pass_name(str): Name used to look the pass up via ``core.get_pass``.
        attrs(list, optional): Names of pass attributes to set.
        attr_values(list, optional): Values for ``attrs``; must be non-empty
            and of the same length whenever ``attrs`` is given.
        debug(bool, optional): If True, draw the graph into the current
            directory after the pass has run.

    Returns:
        The same ``graph`` wrapper, with unused variable nodes removed.
    """
    ir_pass = core.get_pass(pass_name)
    native_graph = graph.graph

    scope_key = '__param_scope__'
    if not native_graph.has(scope_key):
        native_graph.set_not_owned(scope_key, scope)

    if attrs:
        sizes_match = attr_values and len(attrs) == len(attr_values)
        assert (
            sizes_match
        ), "Different number of pass attributes and their values."
        for attr_name, attr_value in zip(attrs, attr_values):
            ir_pass.set(attr_name, attr_value)

    ir_pass.apply(native_graph)

    if debug:
        drawing_name = 'qat_fp32_{}'.format(pass_name)
        graph.draw('.', drawing_name, graph.all_op_nodes())

    _remove_unused_var_nodes(graph)
    return graph


117
class PostTrainingQuantization:
118 119
    """
    Utilizing post training quantization method to quantize the FP32 model,
120
    and it uses calibrate data to get the quantization information for all
121 122 123
    quantized variables.
    """

124 125 126 127 128 129 130 131 132 133 134 135 136 137
    def __init__(
        self,
        executor,
        model_dir,
        scope=None,
        model_filename=None,
        params_filename=None,
        batch_generator=None,
        sample_generator=None,
        data_loader=None,
        batch_size=10,
        batch_nums=None,
        algo="KL",
        hist_percent=0.99999,
138
        quantizable_op_type=[],
139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
        round_type='round',
        learning_rate=0.001,
        is_full_quantize=False,
        bias_correction=False,
        activation_bits=8,
        weight_bits=8,
        activation_quantize_type='range_abs_max',
        weight_quantize_type='channel_wise_abs_max',
        onnx_format=False,
        freeze_model=True,
        optimize_model=False,
        is_use_cache_file=False,
        skip_tensor_list=None,
        same_scale_tensor_list=None,
        cache_dir=None,
        scale_dict=None,
        return_graph=False,
156
        deploy_backend=None,
157
    ):
158
        '''
159
        Constructor.
160 161

        Args:
162
            executor(static.Executor): The executor to load, run and save the
163
                quantized model.
164 165
            scope(static.Scope, optional): The scope of the program, use it to load
                and save variables. If scope=None, get scope by static.global_scope().
166
            model_dir(str): The path of the fp32 model that will be quantized,
167
                and the model and params files are under the path.
168 169
            model_filename(str, optional): The name of file to load the inference
                program. If it is None, the default filename '__model__' will
170 171
                be used. Default is 'None'.
            params_filename(str, optional): The name of file to load all parameters.
172 173
                When all parameters were saved in a single binary file, set it
                as the real filename. If parameters were saved in separate files,
174
                set it as 'None'. Default is 'None'.
175
            batch_generator(Python Generator): The batch generator provides
176 177 178 179 180 181 182
                calibrate data for DataLoader, and it returns a batch every
                time. Note that, sample_generator and batch_generator, only one
                should be set. Beisdes, batch_generator supports lod tensor.
            sample_generator(Python Generator): The sample generator provides
                calibrate data for DataLoader, and it only returns a sample every
                time. Note that, sample_generator and batch_generator, only one
                should be set. Beisdes, sample_generator dose not support lod tensor.
183 184 185
            data_loader(Python Generator, Paddle.io.DataLoader, optional): The
                Generator or Dataloader provides calibrate data, and it could
                return a batch every time.
186
            batch_size(int, optional): The batch size of DataLoader. Default is 10.
187 188
            batch_nums(int, optional): If batch_nums is not None, the number of
                calibrate data is batch_size*batch_nums. If batch_nums is None, use
189
                all data provided by sample_generator as calibrate data.
190 191
            algo(str, optional): If algo='KL', use KL-divergenc method to
                get the KL threshold for quantized activations and get the abs_max
192 193
                value for quantized weights. If algo='abs_max', get the abs max
                value for activations and weights. If algo= 'min_max', get the min
X
XGZhang 已提交
194
                and max value for quantized activations and weights. If algo='avg',
195
                get the average value among the max values for activations. If
X
XGZhang 已提交
196
                algo= 'hist', get the value of 'hist_percent' quantile as the threshold.
197
                If algo='mse', get the value which makes the quantization mse loss
X
XGZhang 已提交
198 199 200
                minimal. Default is KL.
            hist_percent(float, optional): The threshold of algo 'hist' for activations.
                Default is 0.99999.
201
            quantizable_op_type(list[str], optional): List the type of ops
202 203 204
                that will be quantized. Default is []. If quantizable_op_type is [],
                it will use the default quantization op type of the qunat config in
                the current deploy_backend.
205
            round_type(str, optional): The method of converting the quantized weights
206
                value float->int. Currently supports ['round', 'adaround'] methods.
207 208
                Default is `round`, which is rounding nearest to the integer.
                'adaround' is refer to https://arxiv.org/abs/2004.10568.
209
            learning_rate(float, optional): The learning rate of adaround method.
210
            is_full_quantized(bool, optional): If set is_full_quantized as True,
211
                apply quantization to all supported quantizable op type. If set
212 213
                is_full_quantized as False, it will apply quantization to the op type
                according to the input quantizable_op_type or quant config of deploy_backend.
X
XGZhang 已提交
214 215
            bias_correction(bool, optional): If set as True, use the bias correction
                method of https://arxiv.org/abs/1810.05723. Default is False.
216
            activation_bits(int): quantization bit number for activation.
217 218 219 220 221 222 223 224 225 226 227 228
            weight_bits(int, optional): quantization bit number for weights.
            activation_quantize_type(str): quantization type for activation,
                now support 'range_abs_max', 'moving_average_abs_max' and 'abs_max'.
                This param only specifies the fake ops in saving quantized model.
                If it is 'range_abs_max' or 'moving_average_abs_max', we save the scale
                obtained by post training quantization in fake ops. Note that, if it
                is 'abs_max', the scale will not be saved in fake ops.
            weight_quantize_type(str): quantization type for weights,
                support 'abs_max' and 'channel_wise_abs_max'. This param only specifies
                the fake ops in saving quantized model, and we save the scale obtained
                by post training quantization in fake ops. Compared to 'abs_max',
                the model accuracy is usually higher when it is 'channel_wise_abs_max'.
229 230
            onnx_format(bool): Whether to export the quantized model with format of ONNX.
                Default is False.
231
            freeze_model(bool): Whether to convert quantized and trained ``program`` to final
232 233
                quantized ``program``. Default: True.
            skip_tensor_list(list): List of skip quant tensor name. Default: None.
234 235
            same_scale_tensor_list(list(list)): The list of tensor keep same scale in the outermost
                list, the final scale about every list is the max of the scale in the list
236
                of tensor. Default: None.
237 238 239 240 241 242 243 244
            optimize_model(bool, optional): If set optimize_model as True, it applies
                some passes to the model before quantization, and it supports
                `conv2d/depthwise_conv2d + bn` pass so far. Some targets require the
                weights are quantized by tensor-wise method, which means the weights
                scale for all channel are the same. However, if fuse
                `conv2d/depthwise_conv2d + bn`, the weights scale for all channel will
                be different. In address this problem, fuse the pattern before
                quantization. Default False.
245 246
            is_use_cache_file(bool, optional): This param is deprecated.
            cache_dir(str, optional): This param is deprecated.
247 248 249
            deploy_backend(str, optional): Deploy backend, it can be None, `TensorRT`,
                `MKLDNN`, `ARM`. And it will extend the new backend. Default is None,
                which means to use the default general quantization configuration.
250 251 252
        Returns:
            None

253 254
        Examples:
        .. code-block:: python
255 256
            import paddle.static as static
            from paddle.static.quantization import PostTrainingQuantization
257

258
            exe = static.Executor(paddle.CPUPlace())
259
            model_dir = path/to/fp32_model_params
260
            # set model_filename as None when the filename is __model__,
261
            # otherwise set it as the real filename
262 263
            model_filename = None
            # set params_filename as None when all parameters were saved in
264 265 266
            # separate files, otherwise set it as the real filename
            params_filename = None
            save_model_path = path/to/save_model_path
267
            # prepare the sample generator according to the model, and the
268
            # sample generator must return a sample every time. The reference
269 270 271
            # document: https://www.paddlepaddle.org.cn/documentation/docs/zh
            # /user_guides/howto/prepare_data/use_py_reader.html
            sample_generator = your_sample_generator
272 273 274
            batch_size = 10
            batch_nums = 10
            algo = "KL"
275
            quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
276 277
            ptq = PostTrainingQuantization(
                        executor=exe,
278 279 280 281
                        sample_generator=sample_generator,
                        model_dir=model_dir,
                        model_filename=model_filename,
                        params_filename=params_filename,
282 283 284 285 286 287 288
                        batch_size=batch_size,
                        batch_nums=batch_nums,
                        algo=algo,
                        quantizable_op_type=quantizable_op_type)
            ptq.quantize()
            ptq.save_quantized_model(save_model_path)
        '''
289

290
        self._support_activation_quantize_type = [
291 292 293
            'range_abs_max',
            'moving_average_abs_max',
            'abs_max',
294 295
        ]
        self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
X
XGZhang 已提交
296
        self._support_algo_type = [
297 298 299 300 301 302 303 304
            'KL',
            'hist',
            'avg',
            'mse',
            'emd',
            'abs_max',
            'min_max',
            'ptf',
X
XGZhang 已提交
305
        ]
306
        assert round_type in ['adaround', 'round']
307 308
        self._round_type = round_type
        self._learning_rate = learning_rate
309
        self._dynamic_quantize_op_type = ['lstm']
310 311

        # Check inputs
312
        assert executor is not None, "The executor cannot be None."
313 314 315 316 317
        assert any(
            [gen is not None]
            for gen in [sample_generator, batch_generator, data_loader]
        ), (
            "The sample_generator, batch_generator "
318
            "and data_loader cannot be None in the same time."
319
        )
320
        if data_loader is not None:
321 322 323 324 325 326 327 328
            assert isinstance(
                data_loader,
                (
                    io.DataLoader,
                    type(isgeneratorfunction),
                    reader.GeneratorLoader,
                ),
            ), "data_loader only accepts `paddle.io.DataLoader` or Generator instance."
329
        assert batch_size > 0, "The batch_size should be greater than 0."
330 331 332 333 334 335 336 337 338 339 340 341 342
        assert (
            algo in self._support_algo_type
        ), "The algo should be KL, hist, mse, avg, abs_max, min_max or ptf."
        assert (
            activation_quantize_type in self._support_activation_quantize_type
        ), "The activation_quantize_type ({}) should in ({}).".format(
            activation_quantize_type, self._support_activation_quantize_type
        )
        assert (
            weight_quantize_type in self._support_weight_quantize_type
        ), "The weight_quantize_type ({}) shoud in ({}).".format(
            weight_quantize_type, self._support_weight_quantize_type
        )
343 344

        # Save input params
X
XGZhang 已提交
345
        self._bias_correction = bias_correction
346
        self._executor = executor
347
        self._scope = static.global_scope() if scope is None else scope
348 349 350
        self._model_dir = model_dir
        self._model_filename = model_filename
        self._params_filename = params_filename
351
        self._sample_generator = sample_generator
352
        self._batch_generator = batch_generator
353 354 355
        self._batch_size = batch_size
        self._batch_nums = batch_nums
        self._algo = algo
X
XGZhang 已提交
356
        self._hist_percent = hist_percent
357 358 359 360
        self._activation_bits = activation_bits
        self._weight_bits = weight_bits
        self._activation_quantize_type = activation_quantize_type
        self._weight_quantize_type = weight_quantize_type
361
        self._onnx_format = onnx_format
G
Guanghua Yu 已提交
362
        self._clip_extra = True if self._onnx_format else False
363
        self._skip_tensor_list = skip_tensor_list
364
        self._optimize_model = optimize_model
365

366
        # Define variables
367 368 369 370
        self._place = self._executor.place
        self._program = None
        self._feed_list = None
        self._fetch_list = None
371
        self._data_loader = data_loader
372

373 374
        self._quantized_weight_var_name = set()
        self._quantized_act_var_name = set()
375
        self._weight_op_pairs = {}
X
XGZhang 已提交
376
        # The vars for alog = KL or hist
377 378
        self._sampling_act_abs_min_max = {}
        self._sampling_act_histogram = {}
379
        self._sampling_data = {}
X
XGZhang 已提交
380
        self._quantized_var_threshold = {}
381 382
        self._histogram_bins = 2048
        # The vars for algo = min_max
383 384
        self._quantized_var_min = {}
        self._quantized_var_max = {}
X
XGZhang 已提交
385 386 387
        # The vars for algo = avg
        self._quantized_var_avg = {}
        # The best loss of algo = mse
388
        self._best_calibration_loss = {}
X
XGZhang 已提交
389 390
        # The threshold for algo = abs_max, mse or avg
        self._quantized_threshold = {}
391 392 393
        # If the tensor is zero-size during any calibration step,
        # it will be stored in self._zero_size_var_names
        self._zero_size_var_names = set()
394 395 396 397
        self._same_scale_tensor_list = same_scale_tensor_list
        self._freeze_model = freeze_model
        self._scale_dict = scale_dict
        self._return_graph = return_graph
398 399 400
        self.FLAG = False
        if self._program is not None:
            self.FLAG = True
401

402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438
        self._is_full_quantize = is_full_quantize
        if is_full_quantize:
            quantizable_op_type = list(SUPPORT_QUANTIZATION_OP_DICT.keys())
        elif quantizable_op_type:
            for op_type in quantizable_op_type:
                assert op_type in list(SUPPORT_QUANTIZATION_OP_DICT.keys()), (
                    op_type + " is not supported for quantization."
                )
        assert (
            activation_bits == weight_bits
        ), "activation_bits and weight_bits must be the same, other cases are not supported."
        support_deploy_backend = [None, "tensorrt", "mkldnn", "arm"]
        if not deploy_backend:
            self.quant_config = BaseQuantizer(
                quantizable_op_type=quantizable_op_type,
                quant_bits=weight_bits,
            )
        elif deploy_backend.lower() == "tensorrt":
            self.quant_config = TensorRTQuantizer(
                quantizable_op_type=quantizable_op_type,
                quant_bits=weight_bits,
            )
        elif deploy_backend.lower() == "mkldnn":
            self.quant_config = MKLDNNQuantizer(
                quantizable_op_type=quantizable_op_type,
                quant_bits=weight_bits,
            )
        elif deploy_backend.lower() == "arm":
            self.quant_config = ARMCPUQuantizer(
                quantizable_op_type=quantizable_op_type,
                quant_bits=weight_bits,
            )
        else:
            assert "Deploy Backend {} not support, please choose one of {}.".format(
                deploy_backend, support_deploy_backend
            )

439 440
    def quantize(self):
        '''
        Load the FP32 model, run it on the calibration data and derive the
        quantization information needed to produce the quantized model.

        For algo 'KL' and 'hist' the data loader is consumed twice: once in a
        preparation stage that collects activation abs min/max values, and
        once in the sampling stage shared by all algorithms.

        Args:
            None
        Returns:
            The program of the quantized model, or an ``IrGraph`` built from
            that program when ``return_graph`` was set in the constructor.
        '''
        self._load_model_data()
        self._collect_target_varnames()
        self._set_activation_persistable()

        if self._algo in ["KL", "hist"]:
            with tqdm(
                total=self._batch_nums,
                bar_format='Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
                ncols=80,
            ) as progress:
                for seen_batches, data in enumerate(
                    self._data_loader(), start=1
                ):
                    self._executor.run(
                        program=self._program,
                        feed=data,
                        fetch_list=self._fetch_list,
                        return_numpy=False,
                        scope=self._scope,
                    )
                    self._collect_activation_abs_min_max()
                    progress.update()
                    # Stop early once the requested number of batches is done.
                    if self._batch_nums and seen_batches >= self._batch_nums:
                        break
            self._init_sampling_act_histogram()

        with tqdm(
            total=self._batch_nums,
            bar_format='Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
            ncols=80,
        ) as progress:
            for seen_batches, data in enumerate(self._data_loader(), start=1):
                self._executor.run(
                    program=self._program,
                    feed=data,
                    fetch_list=self._fetch_list,
                    return_numpy=False,
                    scope=self._scope,
                )
                self._sampling()
                progress.update()
                if self._batch_nums and seen_batches >= self._batch_nums:
                    break

        if self._algo == 'avg':
            # The threshold is the mean of the values sampled per variable;
            # variables without samples are skipped.
            for var_name in self._quantized_act_var_name:
                if var_name in self._quantized_var_avg:
                    samples = np.array(self._quantized_var_avg[var_name])
                    self._quantized_threshold[var_name] = samples.mean()

        if self._algo in ["KL", "hist"]:
            self._calculate_kl_hist_threshold()

        if self._round_type == 'adaround':
            self._adaround_apply()

        self._reset_activation_persistable()

        if self._algo == 'min_max':
            self._save_input_threhold()
        else:
            self._update_program()

        # save out_threshold for quantized ops.
        if not self.FLAG:
            self._save_output_threshold()

        needs_dynamic_thresholds = any(
            op_type in self.quant_config.activation_quant_operation_types
            for op_type in self._dynamic_quantize_op_type
        )
        if needs_dynamic_thresholds:
            self._collect_dynamic_quantize_op_threshold(
                self._dynamic_quantize_op_type
            )

        utils.move_persistable_var_to_global_block(self._program)

        if self._return_graph:
            return IrGraph(core.Graph(self._program.desc), for_test=True)
        return self._program
536

537
    def _adaround_apply(self):
538
        assert self._algo != "min_max", "The algo should not be min_max."
539 540 541 542
        if self._algo in ["KL", "hist"]:
            scale_dict = self._quantized_var_threshold
        else:
            scale_dict = self._quantized_threshold
543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560
        run_adaround(
            self._data_loader,
            self._program,
            self._fetch_list,
            self._executor,
            self._scope,
            self._place,
            self._quantized_op_pairs,
            self._weight_op_pairs,
            scale_dict,
            num_iterations=self._batch_nums,
            bias_correction=self._bias_correction,
            lr=self._learning_rate,
        )

    def save_quantized_model(
        self, save_model_path, model_filename=None, params_filename=None
    ):
        '''
        Save the quantized model to the disk.

        Args:
            save_model_path(str): The directory to save the quantized model in.
            model_filename(str, optional): Base name of the saved model. If it
                is None, 'model' is used. A trailing '.pdmodel' is stripped.
                The result is joined with save_model_path and passed to
                ``static.save_inference_model`` as the path prefix.
                Default: None.
            params_filename(str, optional): Not used by this method.
                Default: None.
        Returns:
            None
        '''
        model_suffix = ".pdmodel"
        if model_filename is None:
            model_name = "model"
        elif model_filename.endswith(model_suffix):
            # Drop the extension; only a prefix is handed to the saver.
            model_name = model_filename[: -len(model_suffix)]
        else:
            model_name = model_filename

        path_prefix = os.path.join(save_model_path, model_name)
        global_block = self._program.global_block()
        feed_vars = []
        for feed_name in self._feed_list:
            feed_vars.append(global_block.var(feed_name))

        static.save_inference_model(
            path_prefix,
            feed_vars,
            self._fetch_list,
            executor=self._executor,
            program=self._program,
            clip_extra=self._clip_extra,
        )
        _logger.info("The quantized model is saved in " + save_model_path)
593

594
    def _load_model_data(self):
        '''
        Load the inference model (unless a program is already set) and make
        sure a data loader and a batch count are available.

        When a data loader was passed to the constructor it is used as-is;
        otherwise one is created from the sample or batch generator. If
        ``batch_nums`` is unset it is filled in from the length of the data
        loader.
        '''
        if self._program is None:
            _logger.info("Load model and set data loader ...")
            loaded_model = static.load_inference_model(
                self._model_dir,
                executor=self._executor,
                model_filename=self._model_filename,
                params_filename=self._params_filename,
            )
            [self._program, self._feed_list, self._fetch_list] = loaded_model

        if self._optimize_model:
            self._optimize_fp32_model()

        feed_vars = []
        for var_name in self._feed_list:
            feed_vars.append(_get_var(str(var_name), self._program))

        # A user-provided loader only needs the batch count filled in.
        if self._data_loader is not None:
            if not self._batch_nums:
                self._batch_nums = len(self._data_loader)
            return

        self._data_loader = io.DataLoader.from_generator(
            feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True
        )
        if self._sample_generator is not None:
            self._data_loader.set_sample_generator(
                self._sample_generator,
                batch_size=self._batch_size,
                drop_last=True,
                places=self._place,
            )
        elif self._batch_generator is not None:
            self._data_loader.set_batch_generator(
                self._batch_generator, places=self._place
            )

        if not self._batch_nums:
            # Counting the batches means iterating the loader once in full.
            self._batch_nums = len(list(self._data_loader))
643

644 645 646 647 648 649 650 651
    def _optimize_fp32_model(self):
        '''
        Apply the conv/bn fusion passes to the FP32 program and store the
        resulting program back on the instance.
        '''
        _logger.info("Optimize FP32 model ...")
        fuse_pass_names = (
            'conv_bn_fuse_pass',
            'depthwise_conv_bn_fuse_pass',
            'conv_transpose_bn_fuse_pass',
            'conv_eltwiseadd_bn_fuse_pass',
            'depthwise_conv_eltwiseadd_bn_fuse_pass',
        )
        graph = _remove_ctrl_vars(
            IrGraph(core.Graph(self._program.desc), for_test=True)
        )
        # The passes are applied in the order listed above.
        for pass_name in fuse_pass_names:
            graph = _apply_pass(self._scope, graph, pass_name)

        self._program = graph.to_program()

661
    def _collect_target_varnames(self):
        '''
        Collect the names of the variables that will be sampled.

        Walks every op of every block of the program and fills
        ``_quantized_weight_var_name``, ``_quantized_act_var_name``,
        ``_weight_op_pairs`` and ``_quantized_op_pairs``. Ops that read a
        tensor listed in ``skip_tensor_list`` get their ``op_namescope``
        attribute set to ``skip_quant``. This method does not change the
        persistable flag of any variable.
        '''
        # TODO(juncaipeng), consider the name_scope of skip_quant
        _logger.info("Collect quantized variable names ...")
        self._quantized_op_pairs = {}

        # Only membership tests are performed on these names, and they happen
        # inside nested per-op/per-variable loops, so use a set.
        persistable_var_names = set(_all_persistable_var_names(self._program))

        def collect_var_name(var_name_list, op_type):
            # Persistable names are weights; everything else is an activation.
            for var_name in var_name_list:
                if var_name in persistable_var_names:
                    self._quantized_weight_var_name.add(var_name)
                    self._weight_op_pairs[var_name] = op_type
                else:
                    self._quantized_act_var_name.add(var_name)

        for block in self._program.blocks:
            for op in block.ops:
                # skip quant from self._skip_tensor_list
                if self._skip_tensor_list is not None:
                    for inp_name in utils._get_op_input_var_names(op):
                        if inp_name in self._skip_tensor_list:
                            op._set_attr("op_namescope", "skip_quant")

                op_type = op.type
                if (
                    self._is_full_quantize
                    and op_type not in SUPPORT_QUANTIZATION_OP_DICT
                ):
                    _logger.warning(
                        op_type + " is not supported for quantization."
                    )
                # An unsqueeze2 op whose first input is persistable is treated
                # like a quantized op.
                is_conv1d_quant = (op_type == "unsqueeze2") and (
                    utils._get_op_input_var_names(op)[0]
                    in persistable_var_names
                )
                # For quantized ops, sample inputs and outputs
                if (
                    op_type in self.quant_config.weight_quant_operation_types
                    or op_type
                    in self.quant_config.activation_quant_operation_types
                    or is_conv1d_quant
                ):
                    input_names = utils._get_op_input_var_names(op)
                    output_names = utils._get_op_output_var_names(op)
                    collect_var_name(input_names, op_type)
                    collect_var_name(output_names, op_type)
                    # collect quanted op output var name; when an op has
                    # several outputs the last one wins for each weight.
                    for out_var_name in output_names:
                        for in_var_name in input_names:
                            if in_var_name in persistable_var_names:
                                self._quantized_op_pairs[
                                    in_var_name
                                ] = out_var_name
                # For other op, only sample output scale
                elif op_type in self.quant_config.observer_operation_types:
                    collect_var_name(
                        utils._get_op_output_var_names(op), op_type
                    )
729 730 731

    def _set_activation_persistable(self):
        '''
        Flag every program variable whose name is listed in
        ``self._quantized_act_var_name`` as persistable, so that its
        tensor data can be obtained while sampling.
        '''
        act_names = self._quantized_act_var_name
        for program_var in self._program.list_vars():
            if program_var.name not in act_names:
                continue
            program_var.persistable = True

739 740 741 742 743 744 745
    def _reset_activation_persistable(self):
        '''
        For every program variable listed in
        ``self._quantized_act_var_name``: switch its persistable flag off
        and call ``_clear()`` on the tensor found for it in ``self._scope``.
        '''
        act_names = self._quantized_act_var_name
        for program_var in self._program.list_vars():
            if program_var.name not in act_names:
                continue
            program_var.persistable = False
            tensor = self._scope.find_var(program_var.name).get_tensor()
            tensor._clear()
747

748
    def _sampling(self):
        '''
        Run the sampling routine that matches ``self._algo`` (min/max,
        abs_max, histogram, ...) for the current iteration. An algo with
        no matching routine is silently ignored.
        '''
        dispatch = (
            (("abs_max",), "_sample_abs_max"),
            (("avg",), "_sample_avg"),
            (("min_max",), "_sample_min_max"),
            (("mse",), "_sample_mse"),
            (("emd",), "_sample_emd"),
            (("ptf",), "_sample_ptf"),
            (("KL", "hist"), "_sample_histogram"),
        )
        for algos, sampler_name in dispatch:
            if self._algo in algos:
                getattr(self, sampler_name)()
                break
766

X
XGZhang 已提交
767 768 769
    def _sample_mse(self):
        '''
        Sample thresholds for the "mse" algo.

        While ``self._quantized_threshold`` is still empty, the abs-max of
        every quantized weight is stored first (one value per tensor for
        "abs_max", one value per channel for "channel_wise_abs_max").

        On every call, each non-empty activation is flattened and
        quant-dequantized with candidate scales running from 0.3 to 1.0
        times its abs-max in steps of 0.02. A candidate replaces the stored
        threshold whenever its mean squared error is not worse than the
        best loss recorded in ``self._best_calibration_loss``. Zero-size
        activations are recorded in ``self._zero_size_var_names``.
        '''
        if self._quantized_threshold == {}:
            for weight_name in self._quantized_weight_var_name:
                weight = utils.load_variable_data(self._scope, weight_name)
                if self._weight_quantize_type == "abs_max":
                    abs_max_value = float(np.max(np.abs(weight)))
                elif self._weight_quantize_type == "channel_wise_abs_max":
                    op_type = self._weight_op_pairs[weight_name]
                    if op_type in utils._channelwise_quant_axis1_ops:
                        # These ops keep their channels on axis 1.
                        abs_max_value = [
                            float(np.max(np.abs(weight[:, ch])))
                            for ch in range(weight.shape[1])
                        ]
                    else:
                        abs_max_value = [
                            float(np.max(np.abs(weight[ch])))
                            for ch in range(weight.shape[0])
                        ]
                self._quantized_threshold[weight_name] = abs_max_value

        _logger.info("MSE searching stage ...")
        for act_name in self._quantized_act_var_name:
            act = utils.load_variable_data(self._scope, act_name)
            if act.size == 0:
                self._zero_size_var_names.add(act_name)
                continue
            act = act.flatten()
            abs_max_value = float(np.max(np.abs(act)))
            if abs_max_value == 0.0:
                # Keep the candidate scales non-zero for an all-zero tensor.
                abs_max_value = 1e-8
            if act_name not in self._best_calibration_loss:
                self._best_calibration_loss[act_name] = float('inf')
            bins = 2 ** (self._activation_bits - 1) - 1
            ratio = 0.3
            while ratio <= 1.0:
                scale = ratio * abs_max_value
                ratio += 0.02
                if self._onnx_format:
                    quantized = np.clip(
                        np.round(act / scale * bins), -bins - 1, bins
                    )
                    restored = quantized / bins * scale
                else:
                    # NOTE(review): this path clips to [0, scale], so negative
                    # values become 0 -- confirm this is intended.
                    clipped = np.clip(act, 0.0, scale)
                    restored = np.round(clipped / scale * bins) / bins * scale
                mse_loss = ((act - restored) ** 2).mean()
                if mse_loss <= self._best_calibration_loss[act_name]:
                    self._best_calibration_loss[act_name] = mse_loss
                    self._quantized_threshold[act_name] = scale

    def _sample_emd(self):
        '''
        Sample thresholds for the "emd" algo.

        While ``self._quantized_threshold`` is still empty, the abs-max of
        every quantized weight is stored first (one value per tensor for
        "abs_max", one value per channel for "channel_wise_abs_max").

        On every call, each non-empty activation is flattened and
        quant-dequantized with candidate scales running from 0.3 to 1.0
        times its abs-max in steps of 0.02. The loss of a candidate is
        |mean difference| + |std difference| between the activation and its
        quant-dequantized version; a candidate replaces the stored threshold
        whenever its loss is not worse than the best one recorded in
        ``self._best_calibration_loss``.
        '''
        if self._quantized_threshold == {}:
            for weight_name in self._quantized_weight_var_name:
                weight = utils.load_variable_data(self._scope, weight_name)
                if self._weight_quantize_type == "abs_max":
                    abs_max_value = float(np.max(np.abs(weight)))
                elif self._weight_quantize_type == "channel_wise_abs_max":
                    op_type = self._weight_op_pairs[weight_name]
                    if op_type in utils._channelwise_quant_axis1_ops:
                        # These ops keep their channels on axis 1.
                        abs_max_value = [
                            float(np.max(np.abs(weight[:, ch])))
                            for ch in range(weight.shape[1])
                        ]
                    else:
                        abs_max_value = [
                            float(np.max(np.abs(weight[ch])))
                            for ch in range(weight.shape[0])
                        ]
                self._quantized_threshold[weight_name] = abs_max_value

        _logger.info("EMD searching stage ...")
        for act_name in self._quantized_act_var_name:
            act = utils.load_variable_data(self._scope, act_name)
            if act.size == 0:
                self._zero_size_var_names.add(act_name)
                continue
            act = act.flatten()
            abs_max_value = float(np.max(np.abs(act)))
            if abs_max_value == 0.0:
                # Keep the candidate scales non-zero for an all-zero tensor.
                abs_max_value = 1e-8
            if act_name not in self._best_calibration_loss:
                self._best_calibration_loss[act_name] = float('inf')
            bins = 2 ** (self._activation_bits - 1) - 1
            # Statistics of the reference tensor do not depend on the scale.
            act_mean = np.mean(act)
            act_std = np.std(act)
            ratio = 0.3
            while ratio <= 1.0:
                scale = ratio * abs_max_value
                ratio += 0.02
                if self._onnx_format:
                    quantized = np.clip(
                        np.round(act / scale * bins), -bins - 1, bins
                    )
                    restored = quantized / bins * scale
                else:
                    # NOTE(review): this path clips to [0, scale], so negative
                    # values become 0 -- confirm this is intended.
                    clipped = np.clip(act, 0.0, scale)
                    restored = np.round(clipped / scale * bins) / bins * scale
                emd_loss = np.abs(act_mean - np.mean(restored)) + np.abs(
                    act_std - np.std(restored)
                )
                if emd_loss <= self._best_calibration_loss[act_name]:
                    self._best_calibration_loss[act_name] = emd_loss
                    self._quantized_threshold[act_name] = scale

    def _sample_avg(self):
        '''
        Sample data for the "avg" algo.

        While ``self._quantized_threshold`` is still empty, the abs-max of
        every quantized weight is stored first (one value per tensor for
        "abs_max", one value per channel for "channel_wise_abs_max").

        On every call, each non-empty activation is reshaped to
        ``(shape[0], -1)``; the abs-max of every row is taken and the mean
        of those maxima is appended to ``self._quantized_var_avg[name]``.
        Zero-size activations are recorded in ``self._zero_size_var_names``
        and skipped.
        '''
        if not self._quantized_threshold:
            for var_name in self._quantized_weight_var_name:
                var_tensor = utils.load_variable_data(self._scope, var_name)
                if self._weight_quantize_type == "abs_max":
                    abs_max_value = float(np.max(np.abs(var_tensor)))
                elif self._weight_quantize_type == "channel_wise_abs_max":
                    op_type = self._weight_op_pairs[var_name]
                    if op_type in utils._channelwise_quant_axis1_ops:
                        # These ops keep their channels on axis 1.
                        abs_max_value = [
                            float(np.max(np.abs(var_tensor[:, i])))
                            for i in range(var_tensor.shape[1])
                        ]
                    else:
                        abs_max_value = [
                            float(np.max(np.abs(var_tensor[i])))
                            for i in range(var_tensor.shape[0])
                        ]
                self._quantized_threshold[var_name] = abs_max_value

        for var_name in self._quantized_act_var_name:
            var_tensor = utils.load_variable_data(self._scope, var_name)
            if var_tensor.size == 0:
                self._zero_size_var_names.add(var_name)
                continue
            if var_name not in self._quantized_var_avg:
                self._quantized_var_avg[var_name] = []
            # Only the mean of the row-wise maxima is recorded; the global
            # abs-max that used to be computed here was never read.
            rows = np.abs(var_tensor.reshape(var_tensor.shape[0], -1))
            abs_avg_value = float(np.mean(np.max(rows, axis=1)))
            self._quantized_var_avg[var_name].append(abs_avg_value)

918
    def _sample_abs_max(self):
        '''
        Sample thresholds for the "abs_max" algo.

        While ``self._quantized_threshold`` is still empty, the abs-max of
        every quantized weight is stored first (one value per tensor for
        "abs_max", one value per channel for "channel_wise_abs_max").

        On every call, the abs-max of each non-empty activation is kept as
        its threshold if it is the first one seen or larger than the stored
        one. Zero-size activations are recorded in
        ``self._zero_size_var_names`` and skipped.
        '''
        thresholds = self._quantized_threshold
        if thresholds == {}:
            for weight_name in self._quantized_weight_var_name:
                weight = utils.load_variable_data(self._scope, weight_name)
                if self._weight_quantize_type == "abs_max":
                    abs_max_value = float(np.max(np.abs(weight)))
                elif self._weight_quantize_type == "channel_wise_abs_max":
                    op_type = self._weight_op_pairs[weight_name]
                    if op_type in utils._channelwise_quant_axis1_ops:
                        # These ops keep their channels on axis 1.
                        abs_max_value = [
                            float(np.max(np.abs(weight[:, ch])))
                            for ch in range(weight.shape[1])
                        ]
                    else:
                        abs_max_value = [
                            float(np.max(np.abs(weight[ch])))
                            for ch in range(weight.shape[0])
                        ]
                thresholds[weight_name] = abs_max_value

        for act_name in self._quantized_act_var_name:
            act = utils.load_variable_data(self._scope, act_name)
            if act.size == 0:
                self._zero_size_var_names.add(act_name)
                continue
            batch_abs_max = float(np.max(np.abs(act)))
            if act_name not in thresholds or batch_abs_max > thresholds[act_name]:
                thresholds[act_name] = batch_abs_max
951

952
    def _sample_min_max(self):
        '''
        Sample min/max values for the "min_max" algo.

        While both ``self._quantized_var_min`` and
        ``self._quantized_var_max`` are still empty, the min and max of
        every quantized weight are stored first (one pair per tensor for
        "abs_max", one pair per channel for "channel_wise_abs_max").

        On every call, the stored min/max of each non-empty activation is
        widened with the values of the current tensor. Zero-size
        activations are recorded in ``self._zero_size_var_names``.
        '''
        mins = self._quantized_var_min
        maxs = self._quantized_var_max
        if mins == {} and maxs == {}:
            for weight_name in self._quantized_weight_var_name:
                weight = utils.load_variable_data(self._scope, weight_name)
                if self._weight_quantize_type == "abs_max":
                    min_value = float(np.min(weight))
                    max_value = float(np.max(weight))
                elif self._weight_quantize_type == "channel_wise_abs_max":
                    op_type = self._weight_op_pairs[weight_name]
                    if op_type in utils._channelwise_quant_axis1_ops:
                        # These ops keep their channels on axis 1.
                        channels = [
                            weight[:, ch] for ch in range(weight.shape[1])
                        ]
                    else:
                        channels = [weight[ch] for ch in range(weight.shape[0])]
                    min_value = [float(np.min(chan)) for chan in channels]
                    max_value = [float(np.max(chan)) for chan in channels]
                mins[weight_name] = min_value
                maxs[weight_name] = max_value

        for act_name in self._quantized_act_var_name:
            act = utils.load_variable_data(self._scope, act_name)
            if act.size == 0:
                self._zero_size_var_names.add(act_name)
                continue
            batch_min = float(np.min(act))
            batch_max = float(np.max(act))
            if act_name not in mins or batch_min < mins[act_name]:
                mins[act_name] = batch_min
            if act_name not in maxs or batch_max > maxs[act_name]:
                maxs[act_name] = batch_max
991

992 993
    def _sample_histogram(self):
        '''
        Accumulate the histogram of absolute values of every activation.

        An activation that is empty, or that has no entry in
        ``self._sampling_act_histogram``, is recorded in
        ``self._zero_size_var_names`` and skipped. Otherwise the absolute
        values are binned with the stored bin edges and the counts are
        added to the stored histogram.
        '''
        for act_name in self._quantized_act_var_name:
            act = utils.load_variable_data(self._scope, act_name)
            has_data = act.size != 0
            if not (has_data and act_name in self._sampling_act_histogram):
                self._zero_size_var_names.add(act_name)
                continue
            entry = self._sampling_act_histogram[act_name]
            # entry is [accumulated counts, bin edges]
            counts, _ = np.histogram(np.abs(act), bins=entry[1])
            entry[0] += counts

H
handiz 已提交
1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016
    def _sample_ptf(self):
        """
        Sample thresholds for the "ptf" algo.

        The following code is modified from:
        https://github.com/megvii-research/FQ-ViT/

        While ``self._quantized_threshold`` is still empty, the abs-max of
        every quantized weight is stored first (one value per tensor for
        "abs_max", one value per channel for "channel_wise_abs_max").

        On every call, each non-empty activation is quant-dequantized with
        four scales: ``abs_max / q_max`` and that value halved one, two and
        three times. The scale with the smallest ``utils.l2_loss`` gives
        the stored threshold ``q_max * scale``.
        """
        if self._quantized_threshold == {}:
            for weight_name in self._quantized_weight_var_name:
                weight = utils.load_variable_data(self._scope, weight_name)
                if self._weight_quantize_type == "abs_max":
                    abs_max_value = float(np.max(np.abs(weight)))
                elif self._weight_quantize_type == "channel_wise_abs_max":
                    op_type = self._weight_op_pairs[weight_name]
                    if op_type in utils._channelwise_quant_axis1_ops:
                        # These ops keep their channels on axis 1.
                        abs_max_value = [
                            float(np.max(np.abs(weight[:, ch])))
                            for ch in range(weight.shape[1])
                        ]
                    else:
                        abs_max_value = [
                            float(np.max(np.abs(weight[ch])))
                            for ch in range(weight.shape[0])
                        ]
                self._quantized_threshold[weight_name] = abs_max_value

        for act_name in self._quantized_act_var_name:
            act = utils.load_variable_data(self._scope, act_name)
            if act.size == 0:
                self._zero_size_var_names.add(act_name)
                continue
            abs_max_value = float(np.max(np.abs(act)))
            q_max = 2 ** (self._activation_bits - 1) - 1
            # Candidates ordered from the smallest scale (base / 8) up to the
            # base scale, each obtained by halving the next larger one.
            scales = [abs_max_value / q_max]
            for _ in range(3):
                scales.insert(0, scales[0] / 2)
            # NOTE(review): values are clipped to [0, q_max], so negative
            # inputs quantize to 0 -- confirm this is intended.
            restored = [
                np.clip(np.round(act / step), 0, q_max) * step
                for step in scales
            ]
            scores = [utils.l2_loss(act, candidate) for candidate in restored]
            best = scores.index(min(scores))
            scale = scales[0] * 2**best
            self._quantized_threshold[act_name] = q_max * scale

1065 1066 1067 1068
    def _save_input_threhold(self):
        '''
        Attach the sampled input min/max to every quantized op.

        For each op whose type is one of the weight or activation quant
        operation types of ``self.quant_config``, every input variable gets
        the attributes ``<name>.min`` and ``<name>.max`` and the op is
        flagged with ``with_quant_attr``. Only valid for the "min_max" algo.
        '''
        assert (
            self._algo == "min_max"
        ), "The algo should be min_max to save input threshold."
        for block in self._program.blocks:
            for op in block.ops:
                is_quant_op = (
                    op.type in self.quant_config.weight_quant_operation_types
                    or op.type
                    in self.quant_config.activation_quant_operation_types
                )
                if not is_quant_op:
                    continue
                for input_name in utils._get_op_input_var_names(op):
                    assert input_name in self._quantized_var_min
                    assert input_name in self._quantized_var_max
                    op._set_attr(
                        input_name + ".min", self._quantized_var_min[input_name]
                    )
                    op._set_attr(
                        input_name + ".max", self._quantized_var_max[input_name]
                    )
                    op._set_attr("with_quant_attr", True)
1089

1090
    def _collect_activation_abs_min_max(self):
        '''
        Track, for every activation, the smallest and largest absolute
        value seen so far in ``self._sampling_act_abs_min_max`` as a
        ``[min, max]`` pair. When algo = KL, these values are later used to
        calculate the threshold. Zero-size activations are recorded in
        ``self._zero_size_var_names`` and skipped.
        '''
        ranges = self._sampling_act_abs_min_max
        for act_name in self._quantized_act_var_name:
            act = utils.load_variable_data(self._scope, act_name)
            if act.size == 0:
                self._zero_size_var_names.add(act_name)
                continue
            magnitudes = np.abs(act)
            batch_min = float(np.min(magnitudes))
            batch_max = float(np.max(magnitudes))
            if act_name not in ranges:
                ranges[act_name] = [batch_min, batch_max]
                continue
            seen = ranges[act_name]
            # Argument order keeps the stored value unless the new one is
            # strictly smaller / larger.
            seen[0] = min(seen[0], batch_min)
            seen[1] = max(seen[1], batch_max)

    def _init_sampling_act_histogram(self):
        '''
        Create an empty histogram for every activation that does not have
        one yet, using ``self._histogram_bins`` bins over the abs min/max
        range collected for it. Activations that are marked zero-size and
        have no collected range are skipped.
        '''
        for act_name in self._quantized_act_var_name:
            no_range = (
                act_name in self._zero_size_var_names
                and act_name not in self._sampling_act_abs_min_max
            )
            if no_range or act_name in self._sampling_act_histogram:
                continue
            abs_range = self._sampling_act_abs_min_max[act_name]
            empty_hist = np.histogram(
                [],
                bins=self._histogram_bins,
                range=(abs_range[0], abs_range[1]),
            )
            # Stored as a mutable [counts, bin edges] pair.
            self._sampling_act_histogram[act_name] = list(empty_hist)
1130

X
XGZhang 已提交
1131
    def _calculate_kl_hist_threshold(self):
        '''
        Fill ``self._quantized_var_threshold`` for the "KL" and "hist" algos.

        Weights get their abs-max (one value per tensor for "abs_max", one
        value per channel for "channel_wise_abs_max"). Activations get a
        threshold computed from their sampled histogram: ``cal_kl_threshold``
        for "KL", ``self._get_hist_scaling_factor`` for "hist". Activations
        marked zero-size that have no histogram are skipped.
        '''
        _logger.info("Calculate {} threshold ...".format(self._algo))
        assert self._algo in ["KL", "hist"], "The algo should be KL or hist."

        # Abs_max threshold for weights
        for weight_name in self._quantized_weight_var_name:
            weight = utils.load_variable_data(self._scope, weight_name)
            if self._weight_quantize_type == "abs_max":
                weight_threshold = float(np.max(np.abs(weight)))
            elif self._weight_quantize_type == "channel_wise_abs_max":
                op_type = self._weight_op_pairs[weight_name]
                if op_type in utils._channelwise_quant_axis1_ops:
                    # These ops keep their channels on axis 1.
                    weight_threshold = [
                        float(np.max(np.abs(weight[:, ch])))
                        for ch in range(weight.shape[1])
                    ]
                else:
                    weight_threshold = [
                        float(np.max(np.abs(weight[ch])))
                        for ch in range(weight.shape[0])
                    ]
            self._quantized_var_threshold[weight_name] = weight_threshold

        # Histogram based threshold for activations
        for act_name in self._quantized_act_var_name:
            unsampled = (
                act_name in self._zero_size_var_names
                and act_name not in self._sampling_act_histogram
            )
            if unsampled:
                continue
            counts, edges = self._sampling_act_histogram[act_name]
            if self._algo == "KL":
                bin_width = edges[1] - edges[0]
                self._quantized_var_threshold[act_name] = cal_kl_threshold(
                    counts, bin_width, self._activation_bits
                )
            elif self._algo == "hist":
                self._quantized_var_threshold[
                    act_name
                ] = self._get_hist_scaling_factor(counts, edges)
1175 1176 1177

    def _update_program(self):
        '''
        Rewrite ``self._program`` into its quantized form.

        Steps, in order:
          1. Apply the quantization transform pass and then the
             add-quant-dequant pass to every sub graph (the ``V2`` variants
             when ``self._onnx_format`` is set), inserting the fake
             quantize / dequantize ops.
          2. If ``self._scale_dict`` is None, build it from the sampled
             thresholds and unify the scales inside every group of
             ``self._same_scale_tensor_list``.
          3. Save every scale to the ``@scale`` and
             ``.quant_dequant@scale`` variables.
          4. Apply QuantizationFreezePass (only when ``self._freeze_model``),
             or, for the onnx format, QuantWeightPass followed by
             AddQuantDequantForInferencePass.
          5. Convert the graph back into ``self._program``.
        '''
        _logger.info("Update the program ...")
        graph = IrGraph(core.Graph(self._program.desc), for_test=True)

        def apply_to_sub_graphs(ir_pass):
            # Fake quant/dequant ops must be inserted in a test graph, so
            # every sub graph is flagged as such before the pass runs.
            for sub_graph in graph.all_sub_graphs():
                sub_graph._for_test = True
                ir_pass.apply(sub_graph)

        def split_entry(entry):
            # "name#op#scalar" entries carry an operator and a scalar;
            # plain names carry neither.
            if '#' in entry:
                return entry.split('#')
            return entry, None, None

        # use QuantizationTransformPass to insert fake_quant/fake_dequantize op
        transform_cls = (
            QuantizationTransformPassV2
            if self._onnx_format
            else QuantizationTransformPass
        )
        apply_to_sub_graphs(
            transform_cls(
                scope=self._scope,
                place=self._place,
                weight_bits=self._weight_bits,
                activation_bits=self._activation_bits,
                activation_quantize_type=self._activation_quantize_type,
                weight_quantize_type=self._weight_quantize_type,
                quantizable_op_type=self.quant_config.weight_quant_operation_types,
            )
        )

        # use AddQuantDequantPass to insert fake_quant_dequant op
        quant_dequant_cls = (
            AddQuantDequantPassV2 if self._onnx_format else AddQuantDequantPass
        )
        apply_to_sub_graphs(
            quant_dequant_cls(
                scope=self._scope,
                place=self._place,
                quantizable_op_type=self.quant_config.activation_quant_operation_types,
            )
        )

        # save threshold to scale var node
        if self._scale_dict is None:
            scale_dict = (
                self._quantized_var_threshold
                if self._algo in ["KL", "hist"]
                else self._quantized_threshold
            )

            if self._same_scale_tensor_list is not None:
                for tensor_list in self._same_scale_tensor_list:
                    # First sweep: apply each entry's operator to its own
                    # scale and find the largest scale of the group.
                    max_scale = None
                    for entry in tensor_list:
                        name, opera, scalar = split_entry(entry)
                        if name not in scale_dict:
                            continue
                        if opera == '*':
                            scale_dict[name] = float(scale_dict[name]) * float(
                                scalar
                            )
                        elif opera == '/':
                            scale_dict[name] = float(scale_dict[name]) / float(
                                scalar
                            )
                        current = scale_dict[name]
                        max_scale = (
                            current
                            if max_scale is None
                            else max(max_scale, current)
                        )

                    # Second sweep: give every entry the shared scale, with
                    # the entry's operator inverted.
                    for entry in tensor_list:
                        name, opera, scalar = split_entry(entry)
                        if name not in scale_dict:
                            continue
                        if opera == '*':
                            scale_dict[name] = max_scale / float(scalar)
                        elif opera == '/':
                            scale_dict[name] = max_scale * float(scalar)
                        elif opera is None:
                            scale_dict[name] = max_scale
            self._scale_dict = scale_dict

        for key, val in self._scale_dict.items():
            for suffix in ("@scale", ".quant_dequant@scale"):
                utils.set_variable_data(
                    self._scope,
                    self._place,
                    key + suffix,
                    np.array([val], dtype=np.float32),
                )

        if not self._onnx_format:
            # apply QuantizationFreezePass, and obtain the final quant model
            if self._freeze_model:
                apply_to_sub_graphs(
                    QuantizationFreezePass(
                        scope=self._scope,
                        place=self._place,
                        bias_correction=self._bias_correction,
                        weight_bits=self._weight_bits,
                        round_type=self._round_type,
                        activation_bits=self._activation_bits,
                        weight_quantize_type=self._weight_quantize_type,
                        quantizable_op_type=self.quant_config.weight_quant_operation_types,
                    )
                )
        else:
            apply_to_sub_graphs(QuantWeightPass(self._scope, self._place))

            infer_pass_quant_op_types = (
                self.quant_config.weight_quant_operation_types
                + self.quant_config.activation_quant_operation_types
                + self.quant_config.observer_operation_types
            )
            apply_to_sub_graphs(
                AddQuantDequantForInferencePass(
                    scope=self._scope,
                    place=self._place,
                    quant_bits=self._activation_bits,
                    quantizable_op_type=infer_pass_quant_op_types,
                    calibration_range_dict=self._scale_dict,
                )
            )

        self._program = graph.to_program()

1348
    def _save_output_threshold(self):
        '''
        Save output threshold to the quantized op.

        Every op of every block in ``self._program`` whose type is listed in
        the quant config (weight, activation or observer operation types) is
        visited, and the calibrated statistic selected by ``self._algo`` is
        recorded for each of its output variables. When ``self._onnx_format``
        is set the value is stored in ``self._calibration_scales``; otherwise
        it is written onto the op as attributes.
        '''
        self._calibration_scales = {}

        def save_info(
            op_node,
            out_var_name,
            threshold_map,
            out_info_name,
            argname_index,
            quantized_type,
        ):
            '''
            Record ``threshold_map[out_var_name]`` for one output of ``op_node``.

            ``argname_index`` is the (argument name, index) pair locating the
            output on the op; ``out_info_name`` is the attribute name used for
            the raw value in the non-ONNX path.
            '''
            # A zero-size tensor that was never calibrated is skipped with a
            # warning; every other output must have a threshold.
            if (out_var_name in self._zero_size_var_names) and (
                out_var_name not in threshold_map
            ):
                _logger.warning(
                    "{} is zero-size tensor and unable to calibrate, so skip quant it.".format(
                        out_var_name
                    )
                )
                return
            assert (
                out_var_name in threshold_map
            ), "The output ({}) of {} node does not have threshold.".format(
                out_var_name, op_node.type
            )

            threshold = threshold_map[out_var_name]
            if self._onnx_format:
                # For easy extension, every var_node set a dict to save parameters of quant.
                self._calibration_scales[out_var_name] = {'scale': threshold}
                return

            op_node._set_attr(out_info_name, threshold)
            op_node._set_attr(
                argname_index[0] + str(argname_index[1]) + "_threshold",
                threshold,
            )
            op_node._set_attr("with_quant_attr", True)
            if (
                op_node.type in self.quant_config.weight_quant_operation_types
                or op_node.type
                in self.quant_config.activation_quant_operation_types
            ):
                # Use the op passed in, not the enclosing loop variable: the
                # helper must not depend on how its caller names things.
                op_node._set_attr("quantization_type", quantized_type)

        def analysis_and_save_info(op_node, out_var_name):
            '''Pick the statistic map(s) matching ``self._algo`` and save them.'''
            argname_index = utils._get_output_name_index(op_node, out_var_name)
            assert argname_index is not None, (
                out_var_name + " is not the output of the op"
            )
            if self._algo in ["KL", "hist"]:
                # For compatibility, we save output threshold by two methods.
                save_info(
                    op_node,
                    out_var_name,
                    self._quantized_var_threshold,
                    "out_threshold",
                    argname_index,
                    "post_" + str(self._algo).lower(),
                )
            elif self._algo in ["avg", "abs_max", "mse", "emd", "ptf"]:
                save_info(
                    op_node,
                    out_var_name,
                    self._quantized_threshold,
                    "out_threshold",
                    argname_index,
                    "post_" + str(self._algo),
                )
            elif self._algo == "min_max":
                # min and max are saved separately; in the ONNX path the
                # second call overwrites the 'scale' entry of the first.
                save_info(
                    op_node,
                    out_var_name,
                    self._quantized_var_min,
                    "out_min",
                    argname_index,
                    "post_min_max",
                )
                save_info(
                    op_node,
                    out_var_name,
                    self._quantized_var_max,
                    "out_max",
                    argname_index,
                    "post_min_max",
                )

        # The set of op types does not change while walking the program, so
        # build it once instead of once per op.
        calibrated_op_types = (
            self.quant_config.weight_quant_operation_types
            + self.quant_config.activation_quant_operation_types
            + self.quant_config.observer_operation_types
        )
        for block_id in range(len(self._program.blocks)):
            for op in self._program.blocks[block_id].ops:
                if op.type in calibrated_op_types:
                    for var_name in utils._get_op_output_var_names(op):
                        analysis_and_save_info(op, var_name)
1450

1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469
    def _collect_dynamic_quantize_op_threshold(self, target_ops_type):
        """
        Collect and save the weight threshold for dynamic quantize ops,
        such as lstm and gru.

        For each op whose type is in ``target_ops_type``, every persistable
        input is loaded from ``self._scope`` and its maximum absolute value is
        stored on the op as ``<argname><index>_threshold``, together with the
        quantization type, the weight bit length and ``with_quant_attr``.

        Args:
            target_ops_type(list): the op type of target ops
        Returns:
            None
        """
        program = self._program

        # Gather the matching ops from every block first.
        matched_ops = []
        for block_idx in range(program.num_blocks):
            matched_ops.extend(
                candidate
                for candidate in program.block(block_idx).ops
                if candidate.type in target_ops_type
            )

        quant_type_attr = str("post_" + self._algo).lower()
        persistable_names = _all_persistable_var_names(program)

        for op in matched_ops:
            for input_name in utils._get_op_input_var_names(op):
                if input_name not in persistable_names:
                    continue
                weights = utils.load_variable_data(self._scope, input_name)
                max_abs = float(np.max(np.abs(weights)))
                argname, index = utils._get_input_name_index(op, input_name)
                op._set_attr(argname + str(index) + "_threshold", max_abs)
                op._set_attr("quantization_type", quant_type_attr)
                op._set_attr("bit_length", self._weight_bits)
                op._set_attr("with_quant_attr", True)
1479

X
XGZhang 已提交
1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495
    def _get_hist_scaling_factor(self, hist, hist_edges):
        '''
        Using the hist method to get the scaling factor.
        '''
        threshold_rate = self._hist_percent
        hist = hist / float(sum(hist))
        hist_sum = 0
        hist_index = 0
        for i in range(len(hist)):
            hist_sum += hist[i]
            if hist_sum >= threshold_rate:
                hist_index = i + 1
                break
        bin_width = hist_edges[1] - hist_edges[0]
        return (hist_index - 0.5) * bin_width

1496

1497
class PostTrainingQuantizationProgram(PostTrainingQuantization):
    """
    Post-training quantization driven by an already-built program.

    The caller supplies ``program``, ``feed_list`` and ``fetch_list``
    directly; all remaining options are forwarded to the base class
    unchanged.
    """

    def __init__(
        self,
        executor,
        program,
        feed_list=None,
        fetch_list=None,
        scope=None,
        batch_generator=None,
        sample_generator=None,
        data_loader=None,
        batch_size=10,
        batch_nums=None,
        algo="KL",
        hist_percent=0.99999,
        quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
        round_type='round',
        learning_rate=0.001,
        is_full_quantize=False,
        bias_correction=False,
        activation_bits=8,
        weight_bits=8,
        activation_quantize_type='range_abs_max',
        weight_quantize_type='channel_wise_abs_max',
        onnx_format=False,
        freeze_model=True,
        optimize_model=False,
        is_use_cache_file=False,
        skip_tensor_list=None,
        same_scale_tensor_list=None,
        cache_dir=None,
        scale_dict=None,
        return_graph=True,
    ):
        # NOTE(review): the three positional ``None`` values presumably fill
        # the base class's model path / filename slots, which are not needed
        # when a program is handed in directly - confirm against the base
        # signature.
        super().__init__(
            executor,
            scope,
            None,
            None,
            None,
            batch_generator,
            sample_generator,
            data_loader,
            batch_size,
            batch_nums,
            algo,
            hist_percent,
            quantizable_op_type,
            round_type,
            learning_rate,
            is_full_quantize,
            bias_correction,
            activation_bits,
            weight_bits,
            activation_quantize_type,
            weight_quantize_type,
            onnx_format,
            freeze_model,
            optimize_model,
            is_use_cache_file,
            skip_tensor_list,
            same_scale_tensor_list,
            cache_dir,
            scale_dict,
            return_graph,
        )
        self._program = program
        # FLAG records whether a program was supplied by the caller.
        self.FLAG = self._program is not None

        assert feed_list is not None, "Feed list should not be None."
        assert fetch_list is not None, "Fetch list should not be None."
        self._feed_list, self._fetch_list = feed_list, fetch_list


1573
class WeightQuantization:
1574
    _supported_quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
1575
    _supported_weight_quantize_type = ['channel_wise_abs_max', 'abs_max']
1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596

    def __init__(self, model_dir, model_filename=None, params_filename=None):
        '''
        This class quantizes the weight of some ops to reduce the size of model
        or improve the perforemace.

        Args:
            model_dir(str): The path of the fp32 model that will be quantized,
                and the model and params files are under the path.
            model_filename(str, optional): The name of file to load the inference
                program. If it is None, the default filename '__model__' will
                be used. Default is 'None'.
            params_filename(str, optional): The name of file to load all parameters.
                When all parameters were saved in a single binary file, set it
                as the real filename. If parameters were saved in separate files,
                set it as 'None'. Default is 'None'.
        '''
        self._model_dir = model_dir
        self._model_filename = model_filename
        self._params_filename = params_filename

1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607
    def quantize_weight_to_int(
        self,
        save_model_dir,
        save_model_filename=None,
        save_params_filename=None,
        quantizable_op_type=["conv2d", "mul"],
        weight_bits=8,
        weight_quantize_type="channel_wise_abs_max",
        generate_test_model=False,
        threshold_rate=0.0,
    ):
1608 1609
        '''
        In order to reduce the size of model, this api quantizes the weight
1610
        of some ops from float32 to int8/16. In the inference stage, the
1611
        quantized weight will be dequantized to float32 again.
1612

1613 1614
        Args:
            save_model_dir(str): The path to save the quantized model.
1615 1616
            save_model_filename(str, optional): The name of file to
                save the inference program. If it is None, the default
1617
                filename '__model__' will be used. Default is 'None'.
1618 1619 1620
            save_params_filename(str, optional): The name of file to
                save all parameters. If it is None, parameters were
                saved in separate files. If it is not None, all
1621
                parameters were saved in a single binary file.
1622
            quantizable_op_type(list[str], optional): The list of ops
1623
                that will be quantized, and the quantized ops should be
1624
                contained in ["conv2d", "depthwise_conv2d", "mul"].
1625
                Default is ["conv2d","mul"].
1626
            weight_bits(int, optional): The bits for the quantized weight,
1627
                and it should be 8 or 16. Default is 8.
1628 1629 1630
            weight_quantize_type(str, optional): quantization type for weights,
                support 'channel_wise_abs_max' and 'abs_max'. Set it as
                'channel_wise_abs_max', the accuracy performs better.
1631 1632 1633
            generate_test_model(bool, optional): If set generate_test_model
                as True, it saves a fake quantized model, in which the weights
                are quantized and dequantized. We can use PaddlePaddle to load
1634
                the fake quantized model and test the accuracy on GPU or CPU.
1635 1636 1637 1638 1639
            threshold_rate(float, optional): This api uses abs_max methd to
                quantize the weight from float32 to int8/16, and the abs max
                value is important for quantization diff. When the abs_max
                value is far away from the center of the numerical distribution,
                we can set threshold_rate between 1e-6 and 1e-8, so the abs max
1640 1641 1642
                value will be optimized. Default is 0.0.
        '''
        for op_type in quantizable_op_type:
1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656
            assert op_type in self._supported_quantizable_op_type, (
                "Input error:"
                + op_type
                + " is not supported for weight quantization."
            )
        assert weight_bits in [
            8,
            16,
        ], "Input error: weight_bits should be 8 or 16."
        assert (
            weight_quantize_type in self._supported_weight_quantize_type
        ), "Input error: weight_quantize_type should in {}".format(
            self._supported_weight_quantize_type
        )
1657 1658

        quantized_model_dir = os.path.join(save_model_dir, "quantized_model")
1659 1660 1661 1662 1663 1664 1665 1666 1667 1668
        self._quantize_weight_to_int(
            quantized_model_dir,
            save_model_filename,
            save_params_filename,
            quantizable_op_type,
            weight_bits,
            weight_quantize_type,
            False,
            threshold_rate,
        )
1669 1670 1671

        if generate_test_model:
            test_model_dir = os.path.join(save_model_dir, "test_model")
1672 1673 1674 1675 1676 1677 1678 1679 1680 1681
            self._quantize_weight_to_int(
                test_model_dir,
                save_model_filename,
                save_params_filename,
                quantizable_op_type,
                weight_bits,
                weight_quantize_type,
                True,
                threshold_rate,
            )
1682

1683 1684 1685 1686
    def convert_weight_to_fp16(self, save_model_dir):
        """
        Convert all persistable vars from fp32 to fp16.
        Note that, this api only changes the data type of variables in
        __params__ file, and the __model__ file remains unchanged.

        Args:
            save_model_dir(str): The path to save the fp16 model.
        """
        # Load model
        place = core.CPUPlace()
        exe = static.Executor(place)
        scope = static.global_scope()
        [infer_program, feed_list, fetch_list] = static.load_inference_model(
            self._model_dir,
            executor=exe,
            model_filename=self._model_filename,
            params_filename=self._params_filename,
        )

        # Build a helper program whose only job is to save the weights as fp16.
        save_program = static.Program()
        save_block = save_program.global_block()
        combined = self._params_filename is not None
        target_dir = os.path.normpath(save_model_dir)
        cloned_by_name = {}

        for var in infer_program.list_vars():
            # Only persistable fp32 tensors are converted; feed/fetch holders
            # and RAW vars are left out.
            skip = (
                (var.type == core.VarDesc.VarType.RAW)
                or (not var.persistable)
                or (var.name in ['feed', 'fetch'])
                or (var.dtype != core.VarDesc.VarType.FP32)
            )
            if skip:
                continue

            cloned = save_block._clone_variable(var)
            if combined:
                # Saved together below, once every variable is known.
                cloned_by_name[cloned.name] = cloned
                continue

            # One file per variable, named after the variable.
            per_var_path = os.path.join(target_dir, cloned.name)
            save_block.append_op(
                type='save',
                inputs={'X': [cloned]},
                outputs={},
                attrs={
                    'file_path': os.path.normpath(per_var_path),
                    'save_as_fp16': True,
                },
            )

        if combined:
            # Variables go into the combined file in name order.
            ordered_vars = [
                cloned_by_name[name] for name in sorted(cloned_by_name.keys())
            ]

            saved_params_var = save_block.create_var(
                type=core.VarDesc.VarType.RAW,
                name=unique_name.generate("saved_params"),
            )
            saved_params_var.desc.set_persistable(True)

            save_block.append_op(
                type='save_combine',
                inputs={'X': ordered_vars},
                outputs={'Y': saved_params_var},
                attrs={
                    'file_path': os.path.join(
                        target_dir, self._params_filename
                    ),
                    'save_as_fp16': True,
                },
            )

        save_program._sync_with_cpp()
        exe.run(save_program)

        # Copy model: the program file itself is reused untouched.
        model_filename = (
            self._model_filename
            if self._model_filename is not None
            else "__model__"
        )
        shutil.copyfile(
            os.path.join(self._model_dir, model_filename),
            os.path.join(save_model_dir, model_filename),
        )

1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780
    def _quantize_weight_to_int(
        self,
        save_model_dir,
        save_model_filename,
        save_params_filename,
        quantizable_op_type,
        weight_bits,
        weight_quantize_type,
        for_test,
        threshold_rate,
    ):
        """
        Generate quantized model or fake quantized model.

        Loads the inference model from ``self._model_dir``, quantizes every
        persistable input of the ops whose type is in ``quantizable_op_type``
        and saves the result under ``save_model_dir``. With ``for_test`` the
        weights are written back dequantized (fake quantization).
        """
        # Load model
        place = core.CPUPlace()
        exe = static.Executor(place)
        scope = static.global_scope()
        [program, feed_list, fetch_list] = static.load_inference_model(
            self._model_dir,
            executor=exe,
            model_filename=self._model_filename,
            params_filename=self._params_filename,
        )

        # Collect the ops to quantize across all blocks.
        quantized_ops = []
        for block_idx in range(program.num_blocks):
            quantized_ops.extend(
                candidate
                for candidate in program.block(block_idx).ops
                if candidate.type in quantizable_op_type
            )

        # Quantize weights
        persistable_var_names = _all_persistable_var_names(program)
        for op in quantized_ops:
            for var_name in op.input_arg_names:
                if var_name not in persistable_var_names:
                    continue
                if weight_quantize_type == "abs_max":
                    self._weight_abs_max_quantization(
                        scope,
                        place,
                        weight_bits,
                        threshold_rate,
                        op,
                        var_name,
                        for_test,
                    )
                elif weight_quantize_type == "channel_wise_abs_max":
                    self._weight_channel_wise_abs_max_quantization(
                        scope, place, weight_bits, op, var_name, for_test
                    )

        # Derive the path prefix: default name, or the given name with a
        # trailing ".pdmodel" suffix stripped.
        if save_model_filename is None:
            model_name = "model"
        else:
            model_name = save_model_filename
            if model_name.endswith(".pdmodel"):
                model_name = model_name.rsplit(".", 1)[0]

        global_block = program.global_block()
        static.save_inference_model(
            os.path.join(save_model_dir, model_name),
            [global_block.var(name) for name in feed_list],
            fetch_list,
            executor=exe,
            program=program,
        )

    def _weight_abs_max_quantization(
        self, scope, place, weight_bits, threshold_rate, op, var_name, for_test
    ):
        '''
        Use abs_max method to quantize weight.

        A single scale is derived for the whole tensor ``var_name``. When
        ``threshold_rate`` is effectively zero the scale comes from the max
        absolute value; otherwise the weights are first clipped to the
        histogram-based threshold from ``_calculate_threshold``.
        '''
        int_dtype = np.int8 if weight_bits == 8 else np.int16
        max_level = (1 << (weight_bits - 1)) - 1

        # Get quantized scale and weight data
        weights = utils.load_variable_data(scope, var_name)
        use_plain_abs_max = abs(threshold_rate) < 1e-10
        if use_plain_abs_max:
            limit = np.max(np.abs(weights))
        else:
            limit = self._calculate_threshold(weights, threshold_rate)
            # Saturate outliers at +/- limit before quantizing.
            weights[weights > limit] = limit
            weights[weights < -limit] = -limit
        # NOTE(review): an all-zero tensor gives limit == 0 and therefore a
        # zero scale; the division below is then 0/0.
        scale = limit / max_level
        quantized = np.around(weights / scale).astype(int_dtype)

        # Set weight data: integers for the real model, dequantized floats
        # for the fake-quantized test model.
        if for_test:
            stored = (quantized * scale).astype(np.float32)
        else:
            stored = quantized
        utils.set_variable_data(scope, place, var_name, stored)

        # Save info
        op._set_attr('quantization_type', 'post_weight_abs_max')
        op._set_attr('quantize_weight_bits', weight_bits)
        op._set_attr(var_name + "_quant_scale", [scale])  # Save as list
        op._set_attr("with_quant_attr", True)
1881

1882 1883 1884
    def _weight_channel_wise_abs_max_quantization(
        self, scope, place, weight_bits, op, var_name, for_test
    ):
        '''
        Use channel_wise_abs_max method to quantize weight.

        The weight ``var_name`` is loaded from ``scope``, quantized with one
        scale per channel and written back: as integers normally, or
        dequantized to float32 when ``for_test`` is set. The scales and the
        quantization metadata are stored as attributes on ``op``.

        Ops other than mul, conv2d and depthwise_conv2d are reported through
        the logger and left untouched.
        '''
        # Resolve the op-specific kernels first, so an unsupported op is
        # rejected before anything is read or written. Previously the error
        # was logged and execution then fell through to variables that were
        # never assigned.
        if op.type == "mul":
            quantize = self._mul_channel_wise_quantization
            dequantize = self._mul_channel_wise_dequantization
        elif op.type in ["conv2d", "depthwise_conv2d"]:
            quantize = self._conv_channel_wise_quantization
            dequantize = self._conv_channel_wise_dequantization
        else:
            _logger.error(op.type + " is not supported by weight quantization")
            return

        quantize_range = (1 << (weight_bits - 1)) - 1
        save_weight_dtype = np.int8 if weight_bits == 8 else np.int16

        # Get quantized scale and weight data
        weight_data = utils.load_variable_data(scope, var_name)
        scales, quantized_weight_data = quantize(
            weight_data, quantize_range, save_weight_dtype
        )

        # Set weight data
        if for_test:
            new_weight_data = dequantize(quantized_weight_data, scales)
        else:
            new_weight_data = quantized_weight_data
        utils.set_variable_data(scope, place, var_name, new_weight_data)

        # Save info
        op._set_attr('quantization_type', 'post_weight_channel_wise_abs_max')
        op._set_attr('quantize_weight_bits', weight_bits)
        op._set_attr(var_name + "_quant_scale", scales)
        op._set_attr("with_quant_attr", True)
1936

1937 1938 1939
    def _conv_channel_wise_quantization(
        self, weight_data, quantize_range, save_weight_dtype
    ):
1940 1941 1942 1943 1944
        '''
        Get channel wise scale for the weights of conv2d and depthwise_conv2d,
        and quantize the weights.
        '''
        scales = []
1945 1946 1947
        quantized_weight_data = np.zeros_like(
            weight_data, dtype=save_weight_dtype
        )
1948 1949 1950 1951
        channel_num = weight_data.shape[0]
        for i in range(channel_num):
            scale = np.max(np.abs(weight_data[i])) / quantize_range
            scales.append(scale)
1952 1953 1954
            quantized_weight_data[i] = np.around(weight_data[i] / scale).astype(
                save_weight_dtype
            )
1955 1956 1957 1958 1959 1960
        return scales, quantized_weight_data

    def _conv_channel_wise_dequantization(self, quantized_weight_data, scales):
        '''
        For conv2d and depthwise_conv2d, dequantize the weights to fp32.
        '''
1961 1962 1963
        dequantized_weight_data = np.zeros_like(
            quantized_weight_data, dtype=np.float32
        )
1964
        for i in range(len(scales)):
1965 1966 1967
            dequantized_weight_data[i] = (
                quantized_weight_data[i] * scales[i]
            ).astype(np.float32)
1968 1969
        return dequantized_weight_data

1970 1971 1972
    def _mul_channel_wise_quantization(
        self, weight_data, quantize_range, save_weight_dtype
    ):
1973 1974 1975 1976 1977
        '''
        Get channel wise scale for the weights of conv2d and depthwise_conv2d,
        and quantize the weights.
        '''
        scales = []
1978 1979 1980
        quantized_weight_data = np.zeros_like(
            weight_data, dtype=save_weight_dtype
        )
1981 1982 1983 1984
        channel_num = weight_data.shape[-1]
        for i in range(channel_num):
            scale = np.max(np.abs(weight_data[:, i])) / quantize_range
            scales.append(scale)
1985 1986 1987
            quantized_weight_data[:, i] = np.around(
                weight_data[:, i] / scale
            ).astype(save_weight_dtype)
1988 1989 1990 1991 1992 1993
        return scales, quantized_weight_data

    def _mul_channel_wise_dequantization(self, quantized_weight_data, scales):
        '''
        For mul, dequantize the weights to fp32.
        '''
1994 1995 1996
        dequantized_weight_data = np.zeros_like(
            quantized_weight_data, dtype=np.float32
        )
1997
        for i in range(len(scales)):
1998 1999 2000
            dequantized_weight_data[:, i] = (
                quantized_weight_data[:, i] * scales[i]
            ).astype(np.float32)
2001 2002
        return dequantized_weight_data

2003 2004
    def _calculate_threshold(self, input, threshold_rate, histogram_bins=5000):
        input_abs = np.abs(input)
2005 2006 2007
        hist, hist_edeges = np.histogram(
            input_abs, bins=histogram_bins, range=(0, np.max(input_abs))
        )
2008 2009 2010 2011 2012 2013 2014 2015 2016 2017
        hist = hist / float(sum(hist))
        hist_sum = 0
        hist_index = 0
        for i in range(len(hist)):
            hist_sum += hist[i]
            if hist_sum >= 1.0 - threshold_rate:
                hist_index = i + 1
                break
        bin_width = hist_edeges[1] - hist_edeges[0]
        return hist_index * bin_width