#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import logging

import numpy as np

try:
    from tqdm import tqdm
except ImportError:
    from .utils import tqdm

import paddle

from ...fluid.framework import IrGraph, IrNode
from ...framework import _get_paddle_place, core
from ...static import Program, data, program_guard, scope_guard
from ...utils import unique_name
from ..log_helper import get_logger
from . import utils
from .quant_config import (
    SUPPORT_ACT_QUANTIZATION_OP_DICT,
    SUPPORT_QUANTIZATION_OP_DICT,
    SUPPORT_WEIGHT_QUANTIZATION_OP_DICT,
)

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)

_fake_quant_op_list = [
    'fake_quantize_abs_max',
    'fake_quantize_range_abs_max',
    'fake_quantize_moving_average_abs_max',
    'fake_channel_wise_quantize_abs_max',
]

_fake_dequant_op_list = [
    'fake_dequantize_max_abs',
    'fake_channel_wise_dequantize_max_abs',
]

_fake_quant_dequant_op_list = [
    'fake_quantize_dequantize_moving_average_abs_max',
    "fake_channel_wise_quantize_dequantize_abs_max",
    "fake_quantize_dequantize_abs_max",
]

_conv_ops = ['conv2d', 'depthwise_conv2d', 'conv2d_transpose']

_SCALE_DEFAULT_VALUE = 0.001


def _init_var_node(var_node, value, scope, place):
    assert isinstance(
        value, np.ndarray
    ), 'The type of value should be numpy array.'
    assert scope is not None, 'The scope cannot be set None.'
    assert place is not None, 'The place cannot be set None.'
    tensor = scope.var(var_node.name()).get_tensor()
    tensor.set(value, place)


def _is_input_all_not_persistable(graph, op_node):
    '''
    Analyse whether the real inputs of the op node are all not persistable.
    '''
    is_input_all_not_persistable = True
    for var_name in utils._get_op_input_var_names(op_node):
        in_node = graph._find_node_by_name(op_node.inputs, var_name)
        is_input_all_not_persistable = is_input_all_not_persistable and (
            not in_node.persistable()
        )
    return is_input_all_not_persistable


def _check_grandchild_op_node(op_node, grandchild_op_name):
    '''
    Check whether the fake_quant node has a grandchild op node named
    grandchild_op_name.
    '''
    for out1_var_node in op_node.outputs:
        for out1_op_node in out1_var_node.outputs:
            for out2_var_node in out1_op_node.outputs:
                for out2_op_node in out2_var_node.outputs:
                    if out2_op_node.name() == grandchild_op_name:
                        return True
    return False


class QuantizationTransformPass:
    """
    Quantize the ops that have weights. Add quant and dequant ops for
    the quantized ops' inputs.
    """

    def __init__(
        self,
        scope=None,
        place=None,
        weight_bits=8,
        activation_bits=8,
        activation_quantize_type='abs_max',
        weight_quantize_type='abs_max',
        window_size=10000,
        moving_rate=0.9,
        skip_pattern=['skip_quant'],
        quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'],
        weight_quantize_func=None,
        act_quantize_func=None,
        weight_preprocess_func=None,
        act_preprocess_func=None,
        optimizer_func=None,
        executor=None,
        is_test=None,
    ):
        r"""
        Constructor.

        Args:
            scope(static.Scope): When activation uses 'range_abs_max' as the quantize
                type, this pass will create some new parameters. The scope is used to
                initialize these new parameters.
            place(static.CPUPlace|static.CUDAPlace|str): place is used to initialize new
                parameters described above. If it's a string, it can be ``cpu`` or ``gpu:x``,
                where ``x`` is the index of the GPUs.
            weight_bits(int): quantization bit number for weights,
                the bias is not quantized.
            activation_bits(int): quantization bit number for activation.
            activation_quantize_type(str): quantization type for activation,
                now support 'abs_max', 'range_abs_max' and 'moving_average_abs_max'.
                If use 'abs_max' mode, the quantization scale will be calculated
                dynamically each step in both training and testing period. If use
                'range_abs_max', a static quantization scale will be calculated
                during training and used in inference.
            weight_quantize_type(str): quantization type for weights,
                support 'abs_max' and 'channel_wise_abs_max'. The 'range_abs_max'
                usually is not used for weight, since weights are fixed once the
                model is well trained.
            window_size(int): the window size for 'range_abs_max' quantization.
            moving_rate(float): the param for 'moving_average_abs_max' quantization.
            skip_pattern(str or str list): The user-defined quantization skip pattern, which
                will be presented in the name scope of an op. When the skip pattern is
                detected in an op's name scope, the corresponding op will not be quantized.
            quantizable_op_type(list[str]): List the type of ops that will be quantized.
                Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
                QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
            weight_quantize_func(function): Function that defines how to quantize weight.
                Using this can quickly test if user's quantization method works or not.
                In this function, user should both define quantization function and
                dequantization function, that is, the function's input is non-quantized
                weight and function returns dequantized weight. If None, will use
                quantization op defined by 'weight_quantize_type'. Default is None.
            act_quantize_func(function): Function that defines how to quantize activation.
                Using this can quickly test if user's quantization method works or not.
                In this function, user should both define quantization and dequantization
                process, that is, the function's input is non-quantized activation and
                function returns dequantized activation. If None, will use quantization
                op defined by 'activation_quantize_type'. Default is None.
            weight_preprocess_func(function): Function that defines how to preprocess
                weight before quantization. Using this can quickly test if user's preprocess
                method works or not. The function's input is non-quantized weight and
                function returns processed weight to be quantized. If None, the weight will
                be quantized directly. Default is None.
            act_preprocess_func(function): Function that defines how to preprocess
                activation before quantization. Using this can quickly test if user's
                preprocess method works or not. The function's input is non-quantized
                activation and function returns processed activation to be quantized.
                If None, the activation will be quantized directly. Default is None.
            optimizer_func(function): Function that returns an optimizer. When 'is_test'
                is False and the user wants to use a self-defined quantization function
                and preprocess function, this function must be set. Default is None.
            executor(Fluid.Executor): If the user wants to use a self-defined quantization
                function and preprocess function, executor must be set for initialization.
                Default is None.

        Examples:
        .. code-block:: python
            # The original graph will be rewritten.
193 194
            import paddle
            import paddle.static as static
            from paddle.static.quantization \
                import QuantizationTransformPass
            from paddle.fluid.framework import IrGraph
            from paddle.framework import core

            graph = IrGraph(core.Graph(static.Program().desc), for_test=False)
            place = paddle.CPUPlace()
            transform_pass = QuantizationTransformPass(static.global_scope(),
                                                       place)
            transform_pass.apply(graph)
        """
        self._scope = scope
        self._place = _get_paddle_place(place)
        self._weight_bits = weight_bits
        self._activation_bits = activation_bits
        self._skip_pattern = skip_pattern
        self._weight_quantize_func = weight_quantize_func
        self._act_quantize_func = act_quantize_func
        self._weight_preprocess_func = weight_preprocess_func
        self._act_preprocess_func = act_preprocess_func
        self._optimizer = optimizer_func
        self._exe = executor
        quant_type = [
            'abs_max',
            'channel_wise_abs_max',
            'range_abs_max',
            'moving_average_abs_max',
        ]
        assert (
            activation_quantize_type != 'channel_wise_abs_max'
        ), "The activation quantization type does not support 'channel_wise_abs_max'."
        if activation_quantize_type not in quant_type:
            raise ValueError(
                "Unknown activation_quantize_type : '%s'. It can only be "
                "'abs_max' or 'range_abs_max' or 'moving_average_abs_max'."
                % (str(activation_quantize_type))
            )
        if weight_quantize_type not in quant_type:
            raise ValueError(
                "Unknown weight_quantize_type: '%s'. It can only be "
                "'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' "
                "or 'moving_average_abs_max'." % (str(weight_quantize_type))
            )

        self._activation_quantize_type = activation_quantize_type
        self._weight_quantize_type = weight_quantize_type
        self._window_size = window_size
        self._moving_rate = moving_rate

        self._quantizable_ops = quantizable_op_type
        for op in self._quantizable_ops:
            assert op in list(SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.keys()), (
                op + " is not supported for quantization."
            )
        self._quantizable_grad_ops = [
            '%s_grad' % (op) for op in self._quantizable_ops
        ]
        self._is_test = is_test
        self._global_step = None

        self.create_var_map = {}
        self.create_op_map = {}

    def apply(self, graph):
        """
        Quantize the graph for training process. According to weight and
        activation quantization type, the graph will be added some fake
        quantize operators and fake dequantize operators.

        Args:
            graph(IrGraph): the applied graph.
        Returns:
            None
        """
        assert isinstance(
            graph, IrGraph
        ), 'graph must be the instance of IrGraph.'
        if self._is_test is None:
            self._is_test = graph.is_test()
        # mark the variables which have been dequantized.
        dequantized_vars = collections.OrderedDict()
        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
        processed_vars = []

        def _quant_preprocess(op_node):
            user_skipped = False
            if isinstance(self._skip_pattern, list):
                user_skipped = op_node.op().has_attr("op_namescope") and any(
                    pattern in op_node.op().attr("op_namescope")
                    for pattern in self._skip_pattern
                )
            elif isinstance(self._skip_pattern, str):
                user_skipped = (
                    op_node.op().has_attr("op_namescope")
                    and op_node.op()
                    .attr("op_namescope")
                    .find(self._skip_pattern)
                    != -1
                )

            if user_skipped:
                op_node.op()._set_attr("skip_quant", True)
                op_node.op()._set_attr("with_quant_attr", True)

        def _transform_forward(graph, op):
            op.op()._set_attr("quantization_type", "qat_with_weight")
            op.op()._set_attr("with_quant_attr", True)
            inputs = op.inputs
            for var_node in inputs:
                if var_node.name() not in op.input_arg_names():
                    continue
                if var_node.name() in dequantized_vars:
                    dequant_var_node = dequantized_vars[var_node.name()]
                else:
                    name = var_node.name()
                    if name in processed_vars:
                        continue
                    is_weight = (
                        True if var_node.name() in persistable_vars else False
                    )

                    # if var node is weight and weight_preprocess_func is not None,
                    # will insert weight preprocess func
                    # to preprocess weight before quantization
                    # if var node is activation and act_preprocess_func is not None,
                    # will insert activation preprocess func
                    # to preprocess activation before quantization
                    if is_weight and self._weight_preprocess_func is not None:
                        var_node = self._insert_func(
                            graph, self._weight_preprocess_func, var_node, op
                        )
                    elif (
                        not is_weight and self._act_preprocess_func is not None
                    ):
                        var_node = self._insert_func(
                            graph, self._act_preprocess_func, var_node, op
                        )

                    # if var node is weight and weight_quantize_func is not None,
                    # will insert weight quantize func to quantize and dequantize weight
                    # if var node is activation and act_quantize_func is not None,
                    # will insert act quantize func to quantize and dequantize activation
                    if is_weight and self._weight_quantize_func is not None:
                        target_out_node = self._insert_func(
                            graph, self._weight_quantize_func, var_node, op
                        )
                        processed_vars.append(name)
                        continue
                    elif not is_weight and self._act_quantize_func is not None:
                        target_out_node = self._insert_func(
                            graph, self._act_quantize_func, var_node, op
                        )
                        processed_vars.append(name)
                        continue

                    quant_bits = (
                        self._weight_bits
                        if var_node.name() in persistable_vars
                        else self._activation_bits
                    )
                    quant_type = (
                        self._weight_quantize_type
                        if is_weight
                        else self._activation_quantize_type
                    )
                    if (
                        quant_type == 'channel_wise_abs_max'
                    ):  # Weight quantization
                        quant_axis = (
                            1
                            if op.name() in utils._channelwise_quant_axis1_ops
                            else 0
                        )
                        (
                            quant_var_node,
                            scale_var_node,
                        ) = self._insert_channel_quant_op(
                            graph, var_node, name, quant_bits, quant_axis
                        )
                        dequant_var_node = self._insert_channel_dequant_op(
                            graph,
                            quant_var_node,
                            [scale_var_node],
                            [quant_bits],
                            quant_axis,
                        )
                    else:
                        quant_var_node, scale_var_node = self._insert_quant_op(
                            graph, var_node, name, quant_bits, quant_type
                        )
                        dequant_var_node = self._insert_dequant_op(
                            graph, quant_var_node, scale_var_node, quant_bits
                        )
                    dequantized_vars[name] = dequant_var_node
                graph.update_input_link(var_node, dequant_var_node, op)

        def _transform_backward(graph, op):
            for var_node in op.inputs:
                if var_node.name() not in op.input_arg_names():
                    continue
                if var_node.name() in dequantized_vars:
                    dequant_var_node = dequantized_vars[var_node.name()]
                    graph.update_input_link(var_node, dequant_var_node, op)

        def _has_weight(op):
            has_weight = False
            for var_node in op.inputs:
                if var_node.name() not in op.input_arg_names():
                    continue
                name = var_node.name()
                if var_node.name() in persistable_vars:
                    has_weight = True
            return has_weight

        if not self._is_test:
            self._create_global_step(graph)
        ops = graph.all_op_nodes()
        # Do the preprocessing for quantization, such as marking the ops
        # that should be skipped and not quantized.
        for op in ops:
            if (
                op.name() in self._quantizable_ops
                or op.name() in self._quantizable_grad_ops
            ):
                _quant_preprocess(op)
        # Insert mapping table to solve the problem in saving inference model.
        graph.out_node_mapping_table = dict()
        # _transform_forward and _transform_backward need to run in two separate loops.
        # The loop for transforming the forward graph:
        with tqdm(
            total=len(ops),
            bar_format='Adding quant op with weight:|{bar}| {n_fmt}/{total_fmt}',
            ncols=80,
        ) as t:
            for op in ops:
                if op.name() in self._quantizable_ops:
                    if not self._is_skip_quant(graph, op) and _has_weight(op):
                        _transform_forward(graph, op)
                t.update()
        # The loop for renaming the inputs of backward op.
        for op in ops:
            if op.name() in self._quantizable_grad_ops and _has_weight(op):
                _transform_backward(graph, op)
        graph.resolve_hazard()
        return graph

    def _create_global_step(self, graph):
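        # 'range_abs_max' quantization indexes a window of recorded scales
        # by a global step counter, so create the '@STEP_COUNTER@' var and
        # an increment op that advances it once per forward pass.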
        if (
            self._weight_quantize_type == 'range_abs_max'
            or self._activation_quantize_type == 'range_abs_max'
        ):
            counter_name = '@STEP_COUNTER@'
            for node in graph.all_var_nodes():
                if node.name() == counter_name:
                    self._global_step = node
            if self._global_step is None:
                global_step_in = graph.create_persistable_node(
                    name=counter_name,
                    var_type=core.VarDesc.VarType.LOD_TENSOR,
                    shape=[1],
                    var_dtype=core.VarDesc.VarType.INT64,
                )
                _init_var_node(
                    global_step_in,
                    np.zeros([1], dtype='int64'),
                    self._scope,
                    self._place,
                )
                global_step_out = graph.create_var_node_from_desc(
                    global_step_in.var()
                )
                # The attribute of `op_role` is needed by ParallelExecutor.
                increment_op = graph.create_op_node(
                    op_type='increment',
                    attrs={
                        'step': 1.0,
                        'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
                    },
                    inputs={'X': global_step_in},
                    outputs={'Out': global_step_out},
                )
                graph.link_to(global_step_in, increment_op)
                graph.link_to(increment_op, global_step_out)
                self._global_step = global_step_out

    def _insert_quant_op(self, graph, var_node, name, quant_bits, quant_type):
        """
        Insert fake_quantize_op in the graph.
        """
        if quant_type == 'abs_max':
            return self._insert_quant_abs_max_op(
                graph, var_node, name, quant_bits
            )
        elif quant_type == 'range_abs_max':
            return self._insert_quant_range_abs_max_op(
                graph, var_node, name, quant_bits
            )
        elif quant_type == 'moving_average_abs_max':
            return self._insert_quant_moving_average_abs_max_op(
                graph, var_node, name, quant_bits
            )

    def _insert_quant_abs_max_op(self, graph, var_node, name, quant_bits):
        """
        Insert fake_quantize_abs_max op in the graph.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(name),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype(),
        )
        scale_name = self._quantized_scale_name(name)
        if var_node.dtype() == core.VarDesc.VarType.FP64:
            data_type = 'float64'
        elif var_node.dtype() == core.VarDesc.VarType.FP32:
            data_type = 'float32'
        else:
            data_type = "float16"
        try:
            scale_value = np.array(
                self._scope.find_var(scale_name).get_tensor()
            )
        except:
            scale_value = np.zeros([1], dtype=data_type)
        scale_var_node = graph.create_persistable_node(
            name=scale_name,
            var_type=var_node.type(),
            shape=[1],
            var_dtype=var_node.dtype(),
        )
        _init_var_node(scale_var_node, scale_value, self._scope, self._place)

        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_abs_max',
            attrs={
                'bit_length': quant_bits,
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
            },
            inputs={'X': var_node},
            outputs={'Out': quant_var_node, 'OutScale': scale_var_node},
        )
        graph.link_to(var_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_var_node)
        return quant_var_node, scale_var_node

    def _insert_quant_range_abs_max_op(self, graph, var_node, name, quant_bits):
        """
        Insert fake_quantize_range_abs_max on the graph.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(name),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype(),
        )

        scale_name = self._quantized_scale_name(name)
        if var_node.dtype() == core.VarDesc.VarType.FP64:
            data_type = 'float64'
        elif var_node.dtype() == core.VarDesc.VarType.FP32:
            data_type = 'float32'
        else:
            data_type = "float16"
        try:
            scale_value = np.array(
                self._scope.find_var(scale_name).get_tensor()
            )
        except:
            scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type)
        scale_in_node = graph.create_persistable_node(
            name=scale_name,
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[1],
            var_dtype=var_node.dtype(),
        )
        _init_var_node(scale_in_node, scale_value, self._scope, self._place)

        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
        inputs = {'X': var_node, 'InScale': scale_in_node}
        outputs = {'Out': quant_var_node, 'OutScale': scale_out_node}

        if not self._is_test:
            # The name of scales_var_node may be 'scales_0', 'scales_1', etc.
            scales_node = graph.create_persistable_node(
                name=unique_name.generate('scales'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                shape=[self._window_size],
                var_dtype=var_node.dtype(),
            )
            if var_node.dtype() == core.VarDesc.VarType.FP64:
                data_type = 'float64'
            elif var_node.dtype() == core.VarDesc.VarType.FP32:
                data_type = 'float32'
            else:
                data_type = "float16"
            _init_var_node(
                scales_node,
                np.zeros([self._window_size], dtype=data_type),
                self._scope,
                self._place,
            )

            inputs['Iter'] = self._global_step
            outputs['OutScales'] = scales_node
        attrs = {
            'window_size': self._window_size,
            'bit_length': quant_bits,
            'is_test': self._is_test,
            'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
        }
        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_range_abs_max',
            attrs=attrs,
            inputs=inputs,
            outputs=outputs,
        )

        graph.link_to(var_node, quant_op_node)
        graph.link_to(scale_in_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_out_node)

        if not self._is_test:
            graph.link_to(self._global_step, quant_op_node)
            graph.link_to(quant_op_node, scales_node)

        return quant_var_node, scale_out_node

    def _insert_quant_moving_average_abs_max_op(
        self, graph, var_node, name, quant_bits
    ):
        """Insert fake_quantize_moving_average_abs_max"""
        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(name),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype(),
        )
        scale_name = self._quantized_scale_name(name)
        if var_node.dtype() == core.VarDesc.VarType.FP64:
            data_type = 'float64'
        elif var_node.dtype() == core.VarDesc.VarType.FP32:
            data_type = 'float32'
        else:
            data_type = "float16"
        try:
            scale_value = np.array(
                self._scope.find_var(scale_name).get_tensor()
            )
        except:
            scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type)
        scale_in_node = graph.create_persistable_node(
            name=scale_name,
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[1],
            var_dtype=var_node.dtype(),
        )
        _init_var_node(scale_in_node, scale_value, self._scope, self._place)

        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
        ins = {'X': var_node, 'InScale': scale_in_node}
        outs = {'Out': quant_var_node, 'OutScale': scale_out_node}
        if not self._is_test:
            state_in_node = graph.create_persistable_node(
                name=unique_name.generate('state'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_node.dtype(),
                shape=[1],
            )
            if var_node.dtype() == core.VarDesc.VarType.FP64:
                data_type = 'float64'
            elif var_node.dtype() == core.VarDesc.VarType.FP32:
                data_type = 'float32'
            else:
                data_type = "float16"
            _init_var_node(
                state_in_node,
                np.ones([1], dtype=data_type),
                self._scope,
                self._place,
            )
            accum_in_node = graph.create_persistable_node(
                name=unique_name.generate('accum'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_node.dtype(),
                shape=[1],
            )
            _init_var_node(
                accum_in_node,
                np.ones([1], dtype=data_type),
                self._scope,
                self._place,
            )
            state_out_node = graph.create_var_node_from_desc(
                state_in_node.var()
            )
            accum_out_node = graph.create_var_node_from_desc(
                accum_in_node.var()
            )

            ins['InState'] = state_in_node
            ins['InAccum'] = accum_in_node
            outs['OutState'] = state_out_node
            outs['OutAccum'] = accum_out_node

        attrs = {
            'bit_length': quant_bits,
            'moving_rate': self._moving_rate,
            'is_test': self._is_test,
            'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
        }

        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_moving_average_abs_max',
            attrs=attrs,
            inputs=ins,
            outputs=outs,
        )

        graph.link_to(var_node, quant_op_node)
        graph.link_to(scale_in_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_out_node)

        if not self._is_test:
            graph.link_to(state_in_node, quant_op_node)
            graph.link_to(accum_in_node, quant_op_node)
            graph.link_to(quant_op_node, state_out_node)
            graph.link_to(quant_op_node, accum_out_node)

        return quant_var_node, scale_out_node

    def _insert_channel_quant_op(
        self, graph, var_node, name, quant_bits, quant_axis
    ):
        """
        Insert fake_channel_wise_quantize_abs_max op in the graph.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(name),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype(),
        )
        scale_name = self._quantized_scale_name(name)
        if var_node.dtype() == core.VarDesc.VarType.FP64:
            data_type = 'float64'
        elif var_node.dtype() == core.VarDesc.VarType.FP32:
            data_type = 'float32'
        else:
            data_type = "float16"
        try:
            scale_value = np.array(
                self._scope.find_var(scale_name).get_tensor()
            )
        except:
            scale_value = np.zeros(
                [var_node.shape()[quant_axis]], dtype=data_type
            )
        scale_var_node = graph.create_persistable_node(
            name=self._quantized_scale_name(name),
            var_type=var_node.type(),
            shape=[var_node.shape()[quant_axis]],
            var_dtype=var_node.dtype(),
        )
        _init_var_node(scale_var_node, scale_value, self._scope, self._place)
        quant_op_node = graph.create_op_node(
            op_type='fake_channel_wise_quantize_abs_max',
            attrs={
                'bit_length': quant_bits,
                'quant_axis': quant_axis,
                'is_test': self._is_test,
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
            },
            inputs={'X': var_node},
            outputs={'Out': quant_var_node, 'OutScale': scale_var_node},
        )
        graph.link_to(var_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_var_node)
        return quant_var_node, scale_var_node

    def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits):
        """
        Insert fake_dequantize_op in the graph.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype(),
        )
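        # max_range is the largest representable magnitude for symmetric
        # signed quantization, e.g. 127 when quant_bits is 8.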
        max_range = (1 << (quant_bits - 1)) - 1
        dequant_op_node = graph.create_op_node(
            op_type='fake_dequantize_max_abs',
            attrs={
                'max_range': float(max_range),
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
            },
            inputs={'X': var_node, 'Scale': scale_var_node},
            outputs={'Out': dequant_var_node},
        )
        graph.link_to(var_node, dequant_op_node)
        graph.link_to(scale_var_node, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
        return dequant_var_node

    def _insert_channel_dequant_op(
        self, graph, var_node, scale_var_nodes, quant_bits, quant_axis
    ):
        """
        Insert fake_channel_wise_dequantize_max_abs in the graph.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype(),
        )
        dequant_op_node = graph.create_op_node(
            op_type='fake_channel_wise_dequantize_max_abs',
            attrs={
                'quant_bits': quant_bits,
                'quant_axis': quant_axis,
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
            },
            inputs={'X': var_node, 'Scales': scale_var_nodes},
            outputs={'Out': dequant_var_node},
        )
        graph.link_to(var_node, dequant_op_node)
        for scale_n in scale_var_nodes:
            graph.link_to(scale_n, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
        return dequant_var_node

    def _create_new_node(self, graph, in_node):
        """
        Create a node in `graph` that is the same as `in_node`.
        Args:
            graph(IrGraph): the graph in which to create the node.
            in_node(IrVarNode): the node to replicate.
        Returns:
            the created new node
        """
        key = ''
        for inp in in_node.inputs:
            key = key + inp.name()
        key = key + in_node.name()
        for inp in in_node.outputs:
            key = key + inp.name()

        if key in self.create_var_map.keys():
            new_node = self.create_var_map[key]
        elif in_node.is_ctrl_var():
            new_node = graph.create_control_dep_var()
            self.create_var_map[key] = new_node
        else:
            new_node = graph.create_var_node_from_desc(in_node.node.var())
            self.create_var_map[key] = new_node
        return new_node

    def _copy_graph(self, graph, source_graph, op_node):
        """
        Copy op_node in source_graph to graph, recursing into the next
        ops linked to op_node's outputs.
        Args:
            graph(IrGraph): the target graph to copy into.
            source_graph(IrGraph): the source graph to copy from.
            op_node(IrOpNode): op node in source_graph.
        Returns:
            None

        """
        key = ''
        for inp in op_node.inputs:
            key = key + inp.name()
        key = key + op_node.name()
        for inp in op_node.outputs:
            key = key + inp.name()
        has_created = False
        if key in self.create_op_map.keys():
            new_op_node = self.create_op_map[key]
            has_created = True
        else:
            new_op_node = graph.create_op_node_from_desc(op_node.node.op())
            self.create_op_map[key] = new_op_node
        if has_created:
            return
        for in_node in op_node.inputs:
            new_node = self._create_new_node(graph, in_node)
            graph.link_to(new_node, new_op_node)
        for in_node in op_node.outputs:
            new_node = self._create_new_node(graph, in_node)
            graph.link_to(new_op_node, new_node)
        for var_node in op_node.outputs:
            for next_op_node in var_node.outputs:
                self._copy_graph(graph, source_graph, next_op_node)
        return

    def _insert_func(self, graph, func, var_node, op):
        """
        Insert a tmp program returned by func between var_node and op.

        Args:
            graph(IrGraph): target graph to insert tmp program.
            func(Function): function to define a tmp program
            var_node(IrVarNode): node in target graph.
            op(IrOpNode): op in target graph.
        Returns:
            op's new input that replaces var_node
        """
        tmp_program = Program()
        startup_program = Program()
        with program_guard(tmp_program, startup_program):
            with unique_name.guard(var_node.name() + "_"):
                in_node = data(
                    var_node.name() + '_tmp_input',
                    shape=var_node.shape(),
                    dtype='float32',
                )
                out_node = func(in_node)
                graph.out_node_mapping_table[out_node.name] = var_node.name()
                # the loss must be a scalar (shape [1]) for minimize
                loss = paddle.mean(out_node)
                if not graph._for_test:
                    assert (
                        self._optimizer
                    ), "optimizer_func must be set when graph is the train graph"
                    in_node.stop_gradient = False
                    optimizer = self._optimizer()
                    optimizer.minimize(loss)
        with scope_guard(self._scope):
            self._exe.run(startup_program)

        tmp_graph = IrGraph(
            core.Graph(tmp_program.desc), for_test=graph._for_test
        )
        in_node = tmp_graph._find_node_by_name(
            tmp_graph.all_var_nodes(), in_node.name
        )
        out_node = tmp_graph._find_node_by_name(
            tmp_graph.all_var_nodes(), out_node.name
        )

        in_node_params = []
        in_op_node = []
        # Copy the tmp graph into the target graph so that its subgraph can
        # be inserted between var_node and op.
        for node in tmp_graph.all_var_nodes():
            if node.inputs == [] and node.persistable():
                in_node_params.append(node)
        for node in tmp_graph.all_op_nodes():
            if node.inputs == []:
                in_op_node.append(node)
        for node in in_node.outputs:
            self._copy_graph(graph, tmp_graph, node)
        for node in in_node_params:
            for op_node in node.outputs:
                self._copy_graph(graph, tmp_graph, op_node)
        for node in in_op_node:
            self._copy_graph(graph, tmp_graph, node)

        target_in_node = graph._find_node_by_name(
            graph.all_var_nodes(), in_node.name()
        )
        target_out_node = graph._find_node_by_name(
            graph.all_var_nodes(), out_node.name()
        )
        loss_node = graph._find_node_by_name(graph.all_var_nodes(), loss.name)
        outputs = target_in_node.outputs
        for node in outputs:
            graph.update_input_link(target_in_node, var_node, node)
        graph.update_input_link(var_node, target_out_node, op)

        # update grad
        if not graph._for_test:
            op_out = op.outputs[0]
            op_out_grad = graph._find_node_by_name(
                graph.all_var_nodes(), op_out.name() + "@GRAD"
            )
            # find op's gradient op, such as conv2d_grad
            op_grad = op_out_grad.outputs[0]
            target_out_grad_node = graph._find_node_by_name(
                graph.all_var_nodes(), target_out_node.name() + "@GRAD"
            )
            in_node_grad = graph._find_node_by_name(
                graph.all_var_nodes(), target_in_node.name() + "@GRAD"
            )
            in_node_grad_op = in_node_grad.inputs
            # update op_grad's input
            graph.update_input_link(var_node, target_out_node, op_grad)

            op_grad_out = None
            # find var_node's corresponding grad node
            for node in op_grad.outputs:
                if var_node.name() + "@GRAD" in node.name():
                    op_grad_out = node
            # update op_grad's output
            if op_grad_out is not None:
                graph.update_output_link(
                    op_grad_out, target_out_grad_node, op_grad
                )
            else:
                graph.link_to(op_grad, target_out_grad_node)

            for node in in_node_grad_op:
                graph.update_input_link(target_in_node, var_node, node)
                if op_grad_out:
                    graph.update_output_link(in_node_grad, op_grad_out, node)
            # remove useless nodes
            mean_grad = target_out_grad_node.inputs[0]
            mean_out_grad = mean_grad.inputs[0]
            fill_constant_node = mean_out_grad.inputs[0]
            graph.safe_remove_nodes(mean_grad)
            graph.safe_remove_nodes(mean_out_grad)
            graph.safe_remove_nodes(fill_constant_node)
            graph.safe_remove_nodes(in_node_grad)

        graph.safe_remove_nodes(loss_node.inputs[0])
        graph.safe_remove_nodes(loss_node)
        graph.safe_remove_nodes(target_in_node)
        return target_out_node

    def _quantized_var_name(self, var_name):
        """
        Return quantized variable name for the input `var_name`.
        """
        return "%s.quantized" % (var_name)

    def _dequantized_var_name(self, var_name):
        """
        Return dequantized variable name for the input `var_name`.
        """
        return "%s.dequantized" % (var_name)

    def _quantized_scale_name(self, var_name):
        """
        Return the scale name of quantized variable for the input `var_name`.
        """
        return "%s@scale" % (var_name)

    def _is_skip_quant(self, graph, op_node):
        """
        Analyse whether the op node skips quantization.
        """
        is_skip = False
        if op_node.op().has_attr("skip_quant") and op_node.op().attr(
            "skip_quant"
        ):
            is_skip = True
        # if the inputs of mul and matmul are not all persistable, use
        # AddQuantDequantPass to quantize them.
        if op_node.name() in [
            "mul",
            "matmul",
        ] and _is_input_all_not_persistable(graph, op_node):
            is_skip = True
        if (
            op_node.op().has_attr("quantization_type")
            and op_node.op().attr("quantization_type") == "qat_without_weight"
        ):
            is_skip = True
        return is_skip


class QuantizationFreezePass:
    def __init__(
        self,
        scope,
        place,
        bias_correction=False,
        weight_bits=8,
        activation_bits=8,
        round_type='round',
        weight_quantize_type='abs_max',
        quantizable_op_type=None,
    ):
        """
        The freeze pass is used to adjust the quantize operator order, for example:
            1) `activation -> quant -> dequant -> conv2d` will be frozen into
            `activation -> quant -> conv2d -> dequant`
            2) `weight -> quant -> dequant -> conv2d` will be frozen into `weight -> conv2d`,
            and weight will be scaled offline.

        Args:
            scope(static.Scope): scope is used to get the weight tensor values.
            place(static.CPUPlace|static.CUDAPlace|str): place is used to restore the weight tensors.
                If it's a string, it can be ``cpu`` or ``gpu:x``, where ``x`` is the index of the GPUs.
            bias_correction(bool): whether to use bias correction for post-training quantization.
                 https://arxiv.org/abs/1810.05723.
            weight_bits(int): quantization bit number for weights.
            activation_bits(int): quantization bit number for activation.
            round_type(str, optional): The method of converting the quantized weights
                value float->int. Currently supports ['round', 'adaround'] methods.
                Default is `round`, which rounds to the nearest integer.
                'adaround' refers to https://arxiv.org/abs/2004.10568.
            weight_quantize_type(str): quantization type for weights, support 'abs_max' and
                'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight,
                since weights are fixed once the model is well trained.
            quantizable_op_type(list[str]): This input param will be removed later. The pass
                processes all quantized ops, so it is not necessary to set this param.
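
        Examples:
        .. code-block:: python

            # A minimal usage sketch: `graph` is assumed to be an IrGraph
            # that was quantized and trained with QuantizationTransformPass.
            import paddle
            import paddle.static as static
            from paddle.static.quantization import QuantizationFreezePass

            place = paddle.CPUPlace()
            freeze_pass = QuantizationFreezePass(static.global_scope(), place)
            freeze_pass.apply(graph)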
        """
        assert scope is not None, 'The scope cannot be set None.'
        assert place is not None, 'The place cannot be set None.'
        self._scope = scope
        self._bias_correction = bias_correction
        self._place = _get_paddle_place(place)
        self._weight_bits = weight_bits
        self._activation_bits = activation_bits
        self._round_type = round_type
        self._weight_quantize_type = weight_quantize_type
        self._fake_quant_op_names = _fake_quant_op_list
        self._fake_dequant_op_names = _fake_dequant_op_list
        self._op_input_rename_map = collections.OrderedDict()
        self._op_output_rename_map = collections.OrderedDict()
        self._quant_var_scale_map = collections.OrderedDict()
        self._quantized_ops = set()

    def apply(self, graph):
        """
        Adjust quantize/dequantize operators order for the inference process.

        Args:
            graph(IrGraph): the applied graph.
        Returns:
            None
        """
        # Get input scales in fake quant op and process weights
        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
        ops = graph.all_op_nodes()
        for op_node in ops:
            op_name = op_node.name()
            if op_name in self._fake_quant_op_names:
                input_arg_name = op_node.input('X')[0]
                if hasattr(graph, 'out_node_mapping_table'):
                    if input_arg_name in graph.out_node_mapping_table.keys():
                        input_arg_name = graph.out_node_mapping_table[
                            input_arg_name
                        ]
                if input_arg_name not in persistable_vars:
                    scale_v = graph._find_node_by_name(
                        op_node.outputs, op_node.output('OutScale')[0]
                    )
                    self._quant_var_scale_map[input_arg_name] = scale_v
                else:
                    # Obtain scale from OutScale var node
                    scale_v = self._load_var(op_node.output('OutScale')[0])
                    assert scale_v.ndim in [
                        1,
                        2,
                    ], "the dim of scale_v should be 1 or 2"
                    if scale_v.ndim == 2:
                        scale_v = scale_v[0]
                    if (
                        scale_v.size == 1
                        and self._weight_quantize_type == 'abs_max'
                    ):
                        scale_v = scale_v[0]
                    else:
                        scale_v = scale_v.tolist()
                    self._quant_var_scale_map[input_arg_name] = scale_v
                    # Quantize weight and restore
                    if self._round_type == 'round':
                        param_v = self._load_var(input_arg_name)
                        if any(
                            _check_grandchild_op_node(op_node, op)
                            for op in utils._channelwise_quant_axis1_ops
                        ):
                            quant_axis = 1
                        else:
                            quant_axis = 0
                        if input_arg_name not in self._quantized_ops:
                            self._quantized_ops.add(input_arg_name)
                            quantized_param_v = utils.quant_tensor(
                                param_v.copy(),
                                scale_v,
                                quant_axis,
                                self._weight_bits,
                            )
                            quantized_param_v = np.round(quantized_param_v)
                            # Weight bias correction
                            if self._bias_correction is True:
                                quantized_param_v = utils.bias_correction_w(
                                    param_v,
                                    quantized_param_v,
                                    scale_v,
                                    quant_axis,
                                    weight_bits=self._weight_bits,
                                )
                                quantized_param_v = np.round(quantized_param_v)
                            self._restore_var(input_arg_name, quantized_param_v)

                    self._remove_fake_quant_and_dequant_op(graph, op_node)

        # Remove all fake dequant op
        ops = graph.all_op_nodes()
        for op_node in ops:
            op_name = op_node.name()
            if op_name in self._fake_dequant_op_names:
                self._remove_fake_quant_and_dequant_op(graph, op_node)

        # Insert post dequant op
        ops = graph.all_op_nodes()
        for op_node in ops:
            op_node_desc = op_node.op()
            if (
                op_node_desc.has_attr("quantization_type")
                and op_node_desc.attr("quantization_type") == "qat_with_weight"
            ):
                if self._weight_quantize_type == 'channel_wise_abs_max':
                    quant_axis = (
                        1
                        if op_node.name() in utils._channelwise_quant_axis1_ops
                        else 0
                    )
                    self._insert_post_channel_dequant_op(
                        graph, op_node, quant_axis
                    )
                else:
                    self._insert_post_dequant_op(graph, op_node)

        # Rename inputs of the following ops after inserting dequant_op after fc/conv
        for op_node in ops:
            for var_node in op_node.inputs:
                if var_node.node in self._op_output_rename_map:
                    old_in = var_node
                    new_in = self._op_output_rename_map[var_node.node]
                    graph.update_input_link(old_in, new_in, op_node)

        # remove the unused var nodes in the graph
        self._remove_unused_var_nodes(graph)
        graph.resolve_hazard()
        return graph

    def _remove_fake_quant_and_dequant_op(self, graph, op_node):
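        # Map each fake op's output var back to its real input var; looking up
        # an existing mapping collapses chains of fake quant/dequant ops so
        # later relinking jumps straight to the original source var.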
        k = graph._find_node_by_name(op_node.outputs, op_node.output('Out')[0])
        v = graph._find_node_by_name(op_node.inputs, op_node.input('X')[0])
        if v.node not in self._op_input_rename_map:
            self._op_input_rename_map[k.node] = v
        else:
            self._op_input_rename_map[k.node] = self._op_input_rename_map[
                v.node
            ]
        graph.safe_remove_nodes(op_node)

    def _insert_post_channel_dequant_op(self, graph, op_node, quant_axis):
        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
        for var_node in op_node.inputs:
            name = var_node.name()
            if name not in op_node.input_arg_names():
                continue
            if var_node.node in self._op_input_rename_map:
                old_in = var_node
                new_in = self._op_input_rename_map[var_node.node]
                new_in.clear_outputs()
                graph.update_input_link(old_in, new_in, op_node)
            original_var_name = self._original_var_name(name)
            scale_v = self._quant_var_scale_map[original_var_name]
            if original_var_name in persistable_vars:
                assert isinstance(
                    scale_v, list
                ), 'The scale of parameter %s is not a list.' % (
                    original_var_name
                )
                channel_scale = np.array(scale_v)
            else:
                assert isinstance(scale_v, IrNode)
                scale_var_node = self._quant_var_scale_map[original_var_name]

        if len(op_node.output_arg_names()) != 1:
            raise ValueError(
                "Only support one output, but op %s has"
                " more than one output." % (op_node.name())
            )

        output_var_node = graph._find_node_by_name(
            op_node.outputs, op_node.output_arg_names()[0]
        )
        weight_scale_node = graph.create_persistable_node(
            name=unique_name.generate('channel_scale'),
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[channel_scale.shape[0]],
            var_dtype=output_var_node.dtype(),
        )

        if output_var_node.dtype() == core.VarDesc.VarType.FP64:
            data_type = 'float64'
        elif output_var_node.dtype() == core.VarDesc.VarType.FP32:
            data_type = 'float32'
        else:
            data_type = "float16"
        _init_var_node(
            weight_scale_node,
            channel_scale.astype(data_type),
            self._scope,
            self._place,
        )
        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(output_var_node.name()),
            var_type=output_var_node.type(),
            shape=output_var_node.shape(),
            var_dtype=output_var_node.dtype(),
        )
        x_num_col_dims = 1
        if op_node.name() in ['matmul', 'matmul_v2', 'mul']:
            x_num_col_dims = len(op_node.outputs[0].shape()) - 1
        if op_node.op().has_attr("x_num_col_dims"):
            x_num_col_dims = op_node.op().attr("x_num_col_dims")
        dequant_op_node = graph.create_op_node(
            op_type='fake_channel_wise_dequantize_max_abs',
            attrs={
                'quant_bits': [self._weight_bits, self._activation_bits],
                'quant_axis': quant_axis,
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
                'x_num_col_dims': x_num_col_dims,
            },
            inputs={
                'X': output_var_node,
                'Scales': [weight_scale_node, scale_var_node],
            },
            outputs={'Out': dequant_var_node},
        )
        graph.link_to(output_var_node, dequant_op_node)
        graph.link_to(scale_var_node, dequant_op_node)
        graph.link_to(weight_scale_node, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
        self._op_output_rename_map[output_var_node.node] = dequant_var_node
        return dequant_var_node

    def _insert_post_dequant_op(self, graph, op_node):
        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
        max_range = 1
        param_range = (1 << (self._weight_bits - 1)) - 1
        act_range = (1 << (self._activation_bits - 1)) - 1
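        # max_range folds every input's dequant factor into a single divisor:
        # each quantized weight contributes param_range / weight_scale and
        # each quantized activation contributes act_range.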
        for var_node in op_node.inputs:
            name = var_node.name()
            if name not in op_node.input_arg_names():
                continue
            if var_node.node in self._op_input_rename_map:
                old_in = var_node
                new_in = self._op_input_rename_map[var_node.node]
                new_in.clear_outputs()
                graph.update_input_link(old_in, new_in, op_node)
            original_var_name = self._original_var_name(name)
            scale_v = self._quant_var_scale_map[original_var_name]
            if original_var_name in persistable_vars:
                assert self._is_float(
                    scale_v
                ), 'The scale of parameter %s is not a float.' % (
                    original_var_name
                )
                scale_v = 1e-8 if scale_v == 0.0 else scale_v
                max_range *= param_range / scale_v
            else:
                max_range *= act_range
                assert isinstance(scale_v, IrNode)
                scale_var_node = self._quant_var_scale_map[original_var_name]

        if len(op_node.output_arg_names()) != 1:
            raise ValueError(
                "Only support one output, but op %s has"
                " more than one output." % (op_node.name())
            )

        output_var_node = graph._find_node_by_name(
            op_node.outputs, op_node.output_arg_names()[0]
        )
        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(output_var_node.name()),
            var_type=output_var_node.type(),
            shape=output_var_node.shape(),
            var_dtype=output_var_node.dtype(),
        )
        dequant_op_node = graph.create_op_node(
            op_type='fake_dequantize_max_abs',
            attrs={
                'max_range': float(max_range),
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
            },
            inputs={'X': output_var_node, 'Scale': scale_var_node},
            outputs={'Out': dequant_var_node},
        )
        graph.link_to(output_var_node, dequant_op_node)
        graph.link_to(scale_var_node, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
        self._op_output_rename_map[output_var_node.node] = dequant_var_node
        return dequant_var_node

    def _load_var(self, name):
        return np.array(self._scope.find_var(name).get_tensor())

    def _restore_var(self, name, array):
        tensor = self._scope.find_var(name).get_tensor()
        tensor.set(array, self._place)

    def _remove_unused_var_nodes(self, graph):
        all_used_vars = set()
        ops = graph.all_op_nodes()
        for op_node in ops:
            for input_node in op_node.inputs:
                all_used_vars.add(input_node)
            for output_node in op_node.outputs:
                all_used_vars.add(output_node)

        all_used_vars = {n.node for n in all_used_vars}
        all_unused_vars = {
            n
            for n in filter(
                lambda node: node.node not in all_used_vars,
                graph.all_var_nodes(),
            )
        }
        graph.safe_remove_nodes(all_unused_vars)

    def _original_var_name(self, var_name):
        """
        Return the original variable name.
        """
        if var_name.endswith('.quantized.dequantized'):
            return var_name[: -len('.quantized.dequantized')]
        if var_name.endswith('.quantized'):
            return var_name[: -len('.quantized')]
        if var_name.endswith('.dequantized'):
            return var_name[: -len('.dequantized')]
        if var_name.endswith('@scale'):
            return var_name[: -len('@scale')]
        else:
            return var_name

    def _dequantized_var_name(self, var_name):
        """
        Return dequantized variable name for the input `var_name`.
        """
        return "%s.dequantized" % (var_name)

    def _is_float(self, v):
        return isinstance(v, (float, np.float16, np.float32, np.float64))


class ConvertToInt8Pass:
    def __init__(self, scope, place, quantizable_op_type=None):
        """
        Convert the weights into int8_t type.

        Args:
            scope(static.Scope): scope is used to get the weight tensor values.
            place(static.CPUPlace|static.CUDAPlace|str): place is used to restore the
                8bits weight tensors. If it's a string, it can be ``cpu`` or ``gpu:x``,
                where ``x`` is the index of the GPUs.
            quantizable_op_type(list[str]): This input param will be removed later. The pass
                will process all quantized ops, so it is not necessary to set the input param.
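
        Examples:
            A minimal usage sketch (assumes ``graph`` is an ``IrGraph`` that has
            already been processed by ``QuantizationFreezePass``):

            .. code-block:: python

                import paddle
                from paddle.static.quantization import ConvertToInt8Pass

                place = paddle.CPUPlace()
                scope = paddle.static.global_scope()
                convert_pass = ConvertToInt8Pass(scope, place)
                convert_pass.apply(graph)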
        """
        assert scope is not None, 'The scope cannot be set None.'
        assert place is not None, 'The place cannot be set None.'
        self._scope = scope
        self._place = _get_paddle_place(place)

    def apply(self, graph):
        """
        Convert weights' type of the graph. After that, the data type of the
        graph weights is int8_t.

        Args:
            graph(IrGraph): the applied graph.
        Returns:
            None
        """
        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
        ops = graph.all_op_nodes()
        input_map = {}
        for op_node in ops:
            if (
                op_node.op().has_attr("quantization_type")
                and op_node.op().attr("quantization_type") == "qat_with_weight"
            ):
                for var_node in op_node.inputs:
                    name = var_node.name()
                    if name in persistable_vars:
                        if name not in input_map:
                            int8_var_node = self._convert_to_int8(
                                graph, var_node
                            )
                            input_map[name] = int8_var_node
                        graph.update_input_link(
                            var_node, input_map[name], op_node
                        )

        # remove the unused var nodes in the graph
        self._remove_unused_var_nodes(graph)
        graph.resolve_hazard()
        return graph

    def _convert_to_int8(self, graph, var_node):
        int8_var_node_name = var_node.name() + ".int8"
        int8_var_node = graph.create_persistable_node(
            name=int8_var_node_name,
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=core.VarDesc.VarType.INT8,
        )
        array = self._load_var(var_node.name())
        self._scope.var(int8_var_node_name)
        self._store_var(int8_var_node_name, array, np.int8)
        return int8_var_node

    def _load_var(self, name):
        return np.array(self._scope.find_var(name).get_tensor())

    def _store_var(self, name, array, dtype):
        tensor = self._scope.find_var(name).get_tensor()
        tensor.set(array.astype(dtype), self._place)

    def _remove_unused_var_nodes(self, graph):
        all_used_vars = set()
        ops = graph.all_op_nodes()
        for op_node in ops:
            for input_node in op_node.inputs:
                all_used_vars.add(input_node)
            for output_node in op_node.outputs:
                all_used_vars.add(output_node)

        all_used_vars = {n.node for n in all_used_vars}
        all_unused_vars = {
            n
            for n in filter(
                lambda node: node.node not in all_used_vars,
                graph.all_var_nodes(),
            )
        }
        graph.safe_remove_nodes(all_unused_vars)


class TransformForMobilePass:
    def __init__(self):
        """
        This pass is used to convert the frozen graph for paddle-mobile execution.
        """
        self._fake_quant_op_names = _fake_quant_op_list
        self._fake_dequant_op_names = _fake_dequant_op_list

    def apply(self, graph):
        """
        Because paddle-mobile uses `quantize` and `dequantize` as the names of
        its quantize and dequantize operators, the `apply` function simply
        renames the fake quant/dequant ops accordingly.

        Args:
            graph(IrGraph): the graph will be transformed.
        Returns:
            None
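
        Examples:
            A minimal usage sketch (assumes ``graph`` is an ``IrGraph`` holding
            a frozen, quantized inference program):

            .. code-block:: python

                from paddle.static.quantization import TransformForMobilePass

                mobile_pass = TransformForMobilePass()
                mobile_pass.apply(graph)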
        """
        ops = graph.all_op_nodes()
        for op_node in ops:
            name = op_node.name()
            if name in self._fake_quant_op_names:
                op_node.set_type('quantize')
                quant_node = graph.create_op_node_from_desc(op_node.op())
                for input_node in op_node.inputs:
                    graph.link_to(input_node, quant_node)
                for output_node in op_node.outputs:
                    graph.link_to(quant_node, output_node)
                graph.safe_remove_nodes(op_node)
            if name in self._fake_dequant_op_names:
                op_node.set_type('dequantize')
                dequant_node = graph.create_op_node_from_desc(op_node.op())
                for input_node in op_node.inputs:
                    graph.link_to(input_node, dequant_node)
                for output_node in op_node.outputs:
                    graph.link_to(dequant_node, output_node)
                graph.safe_remove_nodes(op_node)
        graph.resolve_hazard()
        return graph


class OutScaleForTrainingPass:
    def __init__(
        self,
        scope=None,
        place=None,
        moving_rate=0.9,
        is_test=None,
        scale_dict=None,
    ):
        """
        This pass is used for calculating output scales of some operators.
        These output scales may be used by TensorRT or some other inference engines.

        Args:
            scope(static.Scope): The scope is used to initialize these new parameters.
            place(static.CPUPlace|static.CUDAPlace|str): The place is used to initialize new parameters.
                If it's string, It can be ``cpu``, and ``gpu:x``, where ``x`` is the
                index of the GPUs.
            moving_rate(float): The decay coefficient of moving average. The default value is 0.9.
            is_test(bool, optional): Whether the graph is for inference. If None, it is
                inferred from the graph itself. Default is None.
            scale_dict(dict, optional): A dict of pre-computed output scales used to
                initialize the scale variables. Default is None.
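
        Examples:
            A minimal usage sketch (assumes ``graph`` is an ``IrGraph`` built
            from a quantized training program):

            .. code-block:: python

                import paddle
                from paddle.static.quantization import OutScaleForTrainingPass

                place = paddle.CPUPlace()
                scope = paddle.static.global_scope()
                out_scale_pass = OutScaleForTrainingPass(scope, place)
                out_scale_pass.apply(graph)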
        """
        self._scope = scope
        self._place = _get_paddle_place(place)
        self._moving_rate = moving_rate
        self._is_test = is_test
        self._teller_set = list(SUPPORT_QUANTIZATION_OP_DICT.keys())
        self._scale_dict = scale_dict

    def apply(self, graph):
        """
        Insert the `moving_average_abs_max_scale` op in order to calculate output scales
        of operators in the teller_set.

        Args:
            graph(IrGraph): the target graph.
        """
        assert isinstance(
            graph, IrGraph
        ), 'graph must be the instance of IrGraph.'
        if self._is_test is None:
            self._is_test = graph.is_test()
        target_ops = []
        for op in graph.all_op_nodes():
            if op.name() in self._teller_set:
                target_ops.append(op)
        with tqdm(
            total=len(target_ops),
            bar_format='Adding OutScale op:|{bar}| {n_fmt}/{total_fmt}',
            ncols=80,
        ) as t:
            for op in target_ops:
                for output_var_name in utils._get_op_output_var_names(op):
                    in_node = graph._find_node_by_name(
                        op.outputs, output_var_name
                    )
                    if in_node.dtype() not in [
                        core.VarDesc.VarType.FP64,
                        core.VarDesc.VarType.FP32,
                        core.VarDesc.VarType.FP16,
                    ]:
                        continue

                    if in_node.dtype() == core.VarDesc.VarType.FP64:
                        data_type = 'float64'
                    elif in_node.dtype() == core.VarDesc.VarType.FP32:
                        data_type = 'float32'
                    else:
                        data_type = "float16"

                    try:
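                        # Skip outputs that already have a scale variable;
                        # the lookup raises when no such node exists yet.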
                        graph._find_node_by_name(
                            graph.all_var_nodes(),
                            self._scale_name(in_node.name()),
                        )
                        continue
                    except:
                        scale_node = graph.create_persistable_node(
                            name=self._scale_name(in_node.name()),
                            var_type=core.VarDesc.VarType.LOD_TENSOR,
                            shape=[1],
                            var_dtype=in_node.dtype(),
                        )
                        if self._scale_dict is not None:
                            try:
                                scale_value = np.array(
                                    [self._scale_dict[in_node.name()]]
                                )
                            except:
                                scale_value = np.ones([1], dtype=data_type)
                        else:
                            scale_value = np.ones([1], dtype=data_type)
                    _init_var_node(
                        scale_node, scale_value, self._scope, self._place
                    )

                    ins = {'X': in_node}
                    outs = {'OutScale': scale_node}
                    if not self._is_test:
                        state_in_node = graph.create_persistable_node(
                            name=unique_name.generate('scale_state@'),
                            var_type=core.VarDesc.VarType.LOD_TENSOR,
                            var_dtype=in_node.dtype(),
                            shape=[1],
                        )
                        _init_var_node(
                            state_in_node,
                            np.ones([1], dtype=data_type),
                            self._scope,
                            self._place,
                        )
                        accum_in_node = graph.create_persistable_node(
                            name=unique_name.generate('scale_accum@'),
                            var_type=core.VarDesc.VarType.LOD_TENSOR,
                            var_dtype=in_node.dtype(),
                            shape=[1],
                        )
                        _init_var_node(
                            accum_in_node,
                            np.ones([1], dtype=data_type),
                            self._scope,
                            self._place,
                        )
                        state_out_node = graph.create_var_node_from_desc(
                            state_in_node.var()
                        )
                        accum_out_node = graph.create_var_node_from_desc(
                            accum_in_node.var()
                        )

                        ins['InState'] = state_in_node
                        ins['InAccum'] = accum_in_node
                        outs['OutState'] = state_out_node
                        outs['OutAccum'] = accum_out_node

                    attrs = {
                        'moving_rate': self._moving_rate,
                        'is_test': self._is_test,
                        'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
                    }
                    scale_op_node = graph.create_op_node(
                        op_type='moving_average_abs_max_scale',
                        attrs=attrs,
                        inputs=ins,
                        outputs=outs,
                    )

                    next_op_node = None
                    if len(in_node.outputs) > 0:
                        next_op_node = in_node.outputs[0]

                    graph.link_to(in_node, scale_op_node)
                    graph.link_to(scale_op_node, scale_node)
                    if next_op_node:
                        graph.link_to(scale_node, next_op_node)

                    if not self._is_test:
                        graph.link_to(state_in_node, scale_op_node)
                        graph.link_to(accum_in_node, scale_op_node)
                        graph.link_to(scale_op_node, state_out_node)
                        graph.link_to(scale_op_node, accum_out_node)
                t.update()
        return graph

    def _scale_name(self, var_name):
        """
        Return the scale name for the var named `var_name`.
        """
        return "%s@scale" % (var_name)


class OutScaleForInferencePass:
    def __init__(self, scope=None):
        """
        This pass is used for setting output scales of some operators.
        These output scales may be used by TensorRT or some other inference engines.

        Args:
            scope(static.Scope): The scope is used to initialize these new parameters.
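
        Examples:
            A minimal usage sketch (assumes ``graph`` is a test-mode ``IrGraph``
            whose scales were learned with ``OutScaleForTrainingPass``):

            .. code-block:: python

                import paddle
                from paddle.static.quantization import OutScaleForInferencePass

                scope = paddle.static.global_scope()
                out_scale_pass = OutScaleForInferencePass(scope)
                out_scale_pass.apply(graph)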
        """
        self._scope = scope
        self._teller_set = list(SUPPORT_QUANTIZATION_OP_DICT.keys())

    def apply(self, graph):
        """
        Get output scales from the scope and set these scales in op_descs
        of operators in the teller_set.

        Args:
            graph(IrGraph): the target graph.
        """
        assert isinstance(
            graph, IrGraph
        ), 'graph must be the instance of IrGraph.'
        op_nodes = graph.all_op_nodes()
        for op_node in op_nodes:
            if op_node.name() in self._teller_set:
                var_names = utils._get_op_output_var_names(op_node)
                for var_name in var_names:
                    in_node = graph._find_node_by_name(
                        op_node.outputs, var_name
                    )
                    if (in_node.node.var() is None) or (
                        in_node.dtype()
                        not in [
                            core.VarDesc.VarType.FP64,
                            core.VarDesc.VarType.FP32,
                            core.VarDesc.VarType.FP16,
                        ]
                    ):
                        continue

                    scale_name = self._scale_name(var_name)
                    scale_var = self._scope.find_var(scale_name)
                    assert (
                        scale_var is not None
                    ), "Can not find {} variable in the scope".format(
                        scale_name
                    )
                    scale_value = np.array(scale_var.get_tensor())[0]

                    # For compatibility, we save output threshold by two methods.
                    op_node.op()._set_attr("out_threshold", float(scale_value))

                    argname_index = utils._get_output_name_index(
                        op_node, var_name
                    )
                    assert argname_index is not None, (
                        var_name + " is not the output of the op"
                    )
                    op_node.op()._set_attr(
                        argname_index[0] + str(argname_index[1]) + "_threshold",
                        float(scale_value),
                    )
                    op_node.op()._set_attr("with_quant_attr", True)
        graph.resolve_hazard()
        return graph

    def _scale_name(self, var_name):
        """
        Return the scale name for the var named `var_name`.
        """
        return "%s@scale" % (var_name)


class AddQuantDequantPass:
    """
    Quantize the ops that do not have weights, and add quant_dequant op for the
    quantized ops' inputs.
    """

    # To be compatible with PaddleSlim, not remove _activation_type for now
    _activation_type = ["relu", "relu6", "leaky_relu", "tanh", "swish"]

    def __init__(
        self,
        scope=None,
        place=None,
        moving_rate=0.9,
        quant_bits=8,
        skip_pattern=["skip_quant"],
        quantizable_op_type=["elementwise_add", "pool2d"],
        is_test=None,
        scale_dict=None,
    ):
        """
        Constructor.

        Args:
            scope(static.Scope): The scope is used to initialize these new parameters.
            place(static.CPUPlace|static.CUDAPlace|str): place is used to initialize new
                parameters described above. If ``place`` is a string, it can be ``cpu``
                or ``gpu:x``, where ``x`` is the index of the GPUs.
            moving_rate(float, optional): the param for 'quant_dequant_moving_average_abs_max'
                quantization. Default is 0.9.
            quant_bits(int, optional): quantization bit number for activation. Default is 8.
            skip_pattern(str, optional): The user-defined quantization skip pattern, which
                will be presented in the name scope of an op. When the skip pattern is
                detected in an op's name scope, the corresponding op will not be quantized.
                Default is 'skip_quant'.
            quantizable_op_type(list[str], optional): List the type of ops that will be
                quantized. Default is ["elementwise_add", "pool2d"].
            is_test(bool, optional): Whether the graph is for inference. If None, it is
                inferred from the graph itself. Default is None.
            scale_dict(dict, optional): A dict of pre-computed activation scales used to
                initialize the scale variables. Default is None.
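
        Examples:
            A minimal usage sketch (assumes ``graph`` is the ``IrGraph`` to be
            quantized):

            .. code-block:: python

                import paddle
                from paddle.static.quantization import AddQuantDequantPass

                place = paddle.CPUPlace()
                scope = paddle.static.global_scope()
                add_qdq_pass = AddQuantDequantPass(scope, place)
                add_qdq_pass.apply(graph)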
        """
        self._scope = scope
        self._place = _get_paddle_place(place)
        self._moving_rate = moving_rate
        self._quant_bits = quant_bits
        self._is_test = is_test
        self._skip_pattern = skip_pattern
        self._scale_dict = scale_dict

        self._quantizable_op_type = quantizable_op_type
        for op_type in self._quantizable_op_type:
            assert op_type in list(SUPPORT_ACT_QUANTIZATION_OP_DICT.keys()), (
                op_type + " is not supported for quantization."
            )
        self._quantizable_grad_op_type = [
            '%s_grad' % (op) for op in self._quantizable_op_type
        ]

        assert self._scope is not None, "scope must not be None."
        assert self._place is not None, "place must not be None."

    def apply(self, graph):
        """
        Add quant_dequant before some ops, such as the 'elementwise_add' and
        'pool2d' ops.

        Args:
            graph(IrGraph): the target graph.
        Returns:
            None
        """
        assert isinstance(
            graph, IrGraph
        ), 'graph must be the instance of IrGraph.'
        if self._is_test is None:
            self._is_test = graph.is_test()
        dequantized_vars_map = collections.OrderedDict()

        # Forward stage, insert quant_dequant op
        all_op_nodes = graph.all_op_nodes()
        with tqdm(
            total=len(all_op_nodes),
            bar_format='Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}',
            ncols=80,
        ) as t:
            for op_node in all_op_nodes:
                if op_node.name() in self._quantizable_op_type:
                    is_skip = False
                    if isinstance(self._skip_pattern, list):
                        is_skip = op_node.op().has_attr("op_namescope") and any(
                            pattern in op_node.op().attr("op_namescope")
                            for pattern in self._skip_pattern
                        )
                    elif isinstance(self._skip_pattern, str):
                        is_skip = (
                            op_node.op().has_attr("op_namescope")
                            and op_node.op()
                            .attr("op_namescope")
                            .find(self._skip_pattern)
                            != -1
                        )
                    is_quantized = (
                        op_node.op().has_attr("quantization_type")
                        and op_node.op().attr("quantization_type")
                        == "qat_with_weight"
                    )
                    if (
                        is_skip
                        or is_quantized
                        or (not _is_input_all_not_persistable(graph, op_node))
                    ):
                        continue

                    op_node.op()._set_attr(
                        "quantization_type", "qat_without_weight"
                    )
                    op_node.op()._set_attr("activation_bits", self._quant_bits)
                    op_node.op()._set_attr("with_quant_attr", True)
                    arg_names = utils._get_op_input_var_names(op_node)
                    # If already quanted, skip it.
                    skip_quant = False
                    for arg_name in arg_names:
                        if "quantized.dequantized" in arg_name:
                            skip_quant = True
                            break
                    if skip_quant:
                        continue

                    for arg_name in arg_names:
                        in_node = graph._find_node_by_name(
                            op_node.inputs, arg_name
                        )
                        if arg_name in dequantized_vars_map:
                            quant_var_node = dequantized_vars_map[arg_name]
                        else:
                            (
                                quant_var_node,
                                _,
                            ) = self._insert_quant_dequant_moving_average_abs_max_op(
                                graph, in_node, self._quant_bits
                            )
                            dequantized_vars_map[arg_name] = quant_var_node
                        graph.update_input_link(
                            in_node, quant_var_node, op_node
                        )
                t.update()

        # Backward stage, update input link
        for op_node in all_op_nodes:
            if op_node.name() in self._quantizable_grad_op_type:
                for input_name in op_node.input_arg_names():
                    if input_name in dequantized_vars_map:
                        in_node = graph._find_node_by_name(
                            op_node.inputs, input_name
                        )
                        dequant_var_node = dequantized_vars_map[input_name]
                        graph.update_input_link(
                            in_node, dequant_var_node, op_node
                        )

        graph.resolve_hazard()
        return graph

    def _insert_quant_dequant_moving_average_abs_max_op(
        self, graph, var_node, quant_bits
    ):
        """Insert fake_quantize_dequantize_moving_average_abs_max op."""
        quant_var_node = graph.create_var_node(
            name="{}.quant_dequant".format(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype(),
        )
        scale_name = "{}.quant_dequant@scale".format(var_node.name())
        if var_node.dtype() == core.VarDesc.VarType.FP64:
            data_type = 'float64'
        elif var_node.dtype() == core.VarDesc.VarType.FP32:
            data_type = 'float32'
        else:
            data_type = "float16"
        try:
            if (
                self._scale_dict is not None
                and var_node.name() in self._scale_dict.keys()
            ):
                scale_value = np.array(
                    [self._scale_dict[var_node.name()]], dtype=data_type
                )
            else:
                scale_value = np.array(
                    self._scope.find_var(scale_name).get_tensor(),
                    dtype=data_type,
                )
        except:
            scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type)

        scale_in_node = graph.create_persistable_node(
            name="{}.quant_dequant@scale".format(var_node.name()),
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[1],
            var_dtype=var_node.dtype(),
        )

        _init_var_node(scale_in_node, scale_value, self._scope, self._place)
        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
        ins = {'X': var_node, 'InScale': scale_in_node}
        outs = {'Out': quant_var_node, 'OutScale': scale_out_node}
        if not self._is_test:
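            # In training mode, the scale is tracked as a moving average of
            # the per-batch abs-max: state accumulates the decayed step count,
            # accum the decayed abs-max sum, and scale = accum / state.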
            state_in_node = graph.create_persistable_node(
                name=unique_name.generate('quant_dequant.state'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_node.dtype(),
                shape=[1],
            )
            if var_node.dtype() == core.VarDesc.VarType.FP64:
                data_type = 'float64'
            elif var_node.dtype() == core.VarDesc.VarType.FP32:
                data_type = 'float32'
            else:
                data_type = "float16"
            _init_var_node(
                state_in_node,
                np.ones([1], dtype=data_type),
                self._scope,
                self._place,
            )
            accum_in_node = graph.create_persistable_node(
                name=unique_name.generate('quant_dequant.accum'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_node.dtype(),
                shape=[1],
            )
            _init_var_node(
                accum_in_node,
                np.ones([1], dtype=data_type),
                self._scope,
                self._place,
            )
            state_out_node = graph.create_var_node_from_desc(
                state_in_node.var()
            )
            accum_out_node = graph.create_var_node_from_desc(
                accum_in_node.var()
            )

            ins['InState'] = state_in_node
            ins['InAccum'] = accum_in_node
            outs['OutState'] = state_out_node
            outs['OutAccum'] = accum_out_node

        attrs = {
            'bit_length': quant_bits,
            'moving_rate': self._moving_rate,
            'is_test': self._is_test,
            'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
        }

        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_dequantize_moving_average_abs_max',
            attrs=attrs,
            inputs=ins,
            outputs=outs,
        )

        graph.link_to(var_node, quant_op_node)
        graph.link_to(scale_in_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_out_node)

        if not self._is_test:
            graph.link_to(state_in_node, quant_op_node)
            graph.link_to(accum_in_node, quant_op_node)
            graph.link_to(quant_op_node, state_out_node)
            graph.link_to(quant_op_node, accum_out_node)

        return quant_var_node, scale_out_node


class InsertQuantizeLinear:
    """
    Insert quantize_linear and dequantize_linear op before ops.

    Args:
        place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to restore the weight tensors.
            If it's a string, it can be ``cpu`` or ``gpu:x``, where ``x`` is the index of the GPUs.
        scope(paddle.Scope): scope is used to get the weight tensor values.
        quant_bits(int, optional): quantization bit number for weight. Default is 8.
        quant_axis(int, optional): quantization dimension of channels. When it is greater than or
            equal to 0, quantization is applied per channel; otherwise it is applied per layer.
            Default is -1.
        channel_wise(bool, optional): Whether to quantize per channel or not. Default is False.
        moving_rate(float, optional): the decay rate for the 'moving average' method. Default is 0.9.
        is_test(bool, optional): Whether it runs in the test (inference) phase. Default is True.
        scale_dict(dict, optional): calibrated scales of output tensors. Default is None.
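
    Examples:
        A minimal usage sketch (assumes ``graph`` is an ``IrGraph`` and
        ``var_node`` is an activation var node inside it):

        .. code-block:: python

            import paddle
            from paddle.static.quantization.quantization_pass import (
                InsertQuantizeLinear,
            )

            place = paddle.CPUPlace()
            scope = paddle.static.global_scope()
            inserter = InsertQuantizeLinear(place, scope, quant_bits=8)
            quant_var, scale_var = inserter.insert_quant_op(graph, var_node)
            dequant_var = inserter.insert_dequant_op(graph, quant_var, scale_var)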
    """

    def __init__(
        self,
        place,
        scope,
        quant_bits=8,
        quant_axis=-1,
        channel_wise=False,
        moving_rate=0.9,
        is_test=True,
        scale_dict=None,
    ):
        self._place = place
        self._scope = scope
        self.quant_bits = quant_bits
        self.quant_axis = quant_axis
        self.channel_wise = channel_wise
        self._is_test = is_test
        self._moving_rate = moving_rate
        self._scale_dict = scale_dict

    def insert_quant_op(
        self, graph, var_node, var_name=None, scale_var_node=None
    ):
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())
        var_name = var_node.name() if not var_name else var_name
        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(var_name),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype(),
        )
        if not scale_var_node:
            if var_node.dtype() == core.VarDesc.VarType.FP64:
                data_type = 'float64'
            elif var_node.dtype() == core.VarDesc.VarType.FP32:
                data_type = 'float32'
            else:
                data_type = "float16"
            scale_name = self._quantized_scale_name(var_name)
            if self.channel_wise:
                scale_var_shape = var_node.shape()[self.quant_axis]
                scale_var_type = core.VarDesc.VarType.LOD_TENSOR
                init_scale_value = (
                    np.ones(scale_var_shape, dtype=data_type)
                    * _SCALE_DEFAULT_VALUE
                )
            else:
                scale_var_shape = 1
                scale_var_type = var_node.type()
                init_scale_value = np.array(
                    [_SCALE_DEFAULT_VALUE], dtype=data_type
                )

            if (
                self._scale_dict is not None
                and var_node.name() in self._scale_dict.keys()
            ):
                init_scale_value = np.array(
                    [self._scale_dict[var_node.name()]], dtype=data_type
                )
            scale_var_node = graph.create_persistable_node(
                name=scale_name,
                var_type=scale_var_type,
                shape=[scale_var_shape],
                var_dtype=var_node.dtype(),
            )
            _init_var_node(
                scale_var_node, init_scale_value, self._scope, self._place
            )

        zero_point_node = None
        if zero_point_node is None:
            zero_point_node = graph.create_persistable_node(
                name=self._zero_point_name(quant_var_node.name()),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                shape=scale_var_node.shape(),
                var_dtype=core.VarDesc.VarType.INT32,
            )
            _init_var_node(
                zero_point_node,
                np.zeros(scale_var_node.shape(), dtype="int32"),
                self._scope,
                self._place,
            )

        inputs = {"X": var_node, "Scale": scale_var_node}
        if zero_point_node is not None:
            inputs["ZeroPoint"] = zero_point_node

        attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits}
        attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
        outputs = {"Y": quant_var_node}
        if not self._is_test:
            scale_out_node = graph.create_var_node_from_desc(
                scale_var_node.var()
            )
            state_in_node = graph.create_persistable_node(
                name=unique_name.generate('state'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_node.dtype(),
                shape=[1],
            )
            if var_node.dtype() == core.VarDesc.VarType.FP64:
                data_type = 'float64'
            elif var_node.dtype() == core.VarDesc.VarType.FP32:
                data_type = 'float32'
            else:
                data_type = "float16"
            _init_var_node(
                state_in_node,
                np.ones([1], dtype=data_type),
                self._scope,
                self._place,
            )
            accum_in_node = graph.create_persistable_node(
                name=unique_name.generate('accum'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_node.dtype(),
                shape=[1],
            )
            _init_var_node(
                accum_in_node,
                np.ones([1], dtype=data_type),
                self._scope,
                self._place,
            )
            state_out_node = graph.create_var_node_from_desc(
                state_in_node.var()
            )
            accum_out_node = graph.create_var_node_from_desc(
                accum_in_node.var()
            )

            outputs["OutScale"] = scale_out_node
            inputs['InState'] = state_in_node
            inputs['InAccum'] = accum_in_node
            outputs['OutState'] = state_out_node
            outputs['OutAccum'] = accum_out_node
            attrs["is_test"] = self._is_test
            attrs['moving_rate'] = self._moving_rate

        quant_op_node = graph.create_op_node(
            op_type="quantize_linear",
            attrs=attrs,
            inputs=inputs,
            outputs=outputs,
        )

        graph.link_to(var_node, quant_op_node)
        graph.link_to(scale_var_node, quant_op_node)
        if zero_point_node is not None:
            graph.link_to(zero_point_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        if not self._is_test:
            graph.link_to(state_in_node, quant_op_node)
            graph.link_to(accum_in_node, quant_op_node)
            graph.link_to(quant_op_node, state_out_node)
            graph.link_to(quant_op_node, accum_out_node)
            graph.link_to(quant_op_node, scale_out_node)
        return quant_var_node, scale_var_node

    def insert_dequant_op(self, graph, var_node, scale_var_node):
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype(),
        )

        zero_point_node = None
        if zero_point_node is None:
            zero_point_node = graph.create_persistable_node(
                name=self._zero_point_name(dequant_var_node.name()),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                shape=scale_var_node.shape(),
                var_dtype=core.VarDesc.VarType.INT32,
            )
            _init_var_node(
                zero_point_node,
                np.zeros(scale_var_node.shape(), dtype="int32"),
                self._scope,
                self._place,
            )

        inputs = {"X": var_node, "Scale": scale_var_node}
        if zero_point_node is not None:
            inputs["ZeroPoint"] = zero_point_node

        attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits}
        attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward

        quant_op_node = graph.create_op_node(
            op_type="dequantize_linear",
            attrs=attrs,
            inputs=inputs,
            outputs={"Y": dequant_var_node},
        )

        graph.link_to(var_node, quant_op_node)
        graph.link_to(scale_var_node, quant_op_node)
        if zero_point_node is not None:
            graph.link_to(zero_point_node, quant_op_node)
        graph.link_to(quant_op_node, dequant_var_node)
        return dequant_var_node

    def _quantized_var_name(self, var_name):
        """
        Return quantized variable name for the input `var_name`.
        """
        return "%s.quantized" % (var_name)

    def _dequantized_var_name(self, var_name):
        """
        Return dequantized variable name for the input `var_name`.
        """
        return "%s.dequantized" % (var_name)

    def _quantized_scale_name(self, var_name):
        """
        Return the scale name of quantized variable for the input `var_name`.
        """
        return "%s@scale" % (var_name)

    def _zero_point_name(self, var_name):
        """
        Return the scale name for the var named `var_name`.
        """
        return "%s@zero_point" % (var_name)


class QuantizationTransformPassV2(QuantizationTransformPass):
    """
    Quantize the ops that have weights. Add quant and dequant ops for
    the quantized ops' inputs. It is used in the new format of quantization.
    """

    def __init__(
        self,
        scope=None,
        place=None,
        weight_bits=8,
        activation_bits=8,
        activation_quantize_type='abs_max',
        weight_quantize_type='abs_max',
        window_size=10000,
        moving_rate=0.9,
        skip_pattern=['skip_quant'],
        quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'],
        weight_quantize_func=None,
        act_quantize_func=None,
        weight_preprocess_func=None,
        act_preprocess_func=None,
        optimizer_func=None,
        executor=None,
        is_test=None,
    ):
        r"""
        Args:
            scope(paddle.Scope): When activation uses 'range_abs_max' as the quantize
                type, this pass will create some new parameters. The scope is used to
                initialize these new parameters.
            place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to initialize new
                parameters described above. If it's a string, it can be ``cpu`` or ``gpu:x``,
                where ``x`` is the index of the GPUs.
            weight_bits(int): quantization bit number for weights,
                the bias is not quantized.
            activation_bits(int): quantization bit number for activation.
            activation_quantize_type(str): quantization type for activation,
                now support 'abs_max', 'range_abs_max' and 'moving_average_abs_max'.
                If use 'abs_max' mode, the quantization scale will be calculated
                dynamically each step in both training and testing period. If use
                'range_abs_max', a static quantization scale will be calculated
                during training and used in inference.
            weight_quantize_type(str): quantization type for weights,
                support 'abs_max' and 'channel_wise_abs_max'. The 'range_abs_max'
                usually is not used for weight, since weights are fixed once the
                model is well trained.
            window_size(int): the window size for 'range_abs_max' quantization.
            moving_rate(float): the param for 'moving_average_abs_max' quantization.
            skip_pattern(str or str list): The user-defined quantization skip pattern, which
                will be presented in the name scope of an op. When the skip pattern is
                detected in an op's name scope, the corresponding op will not be quantized.
            quantizable_op_type(list[str]): List the type of ops that will be quantized.
                Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
                QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
            weight_quantize_func(function): Function that defines how to quantize weight.
                Using this can quickly test if user's quantization method works or not.
                In this function, user should both define quantization function and
                dequantization function, that is, the function's input is non-quantized
                weight and function returns dequantized weight. If None, will use
                quantization op defined by 'weight_quantize_type'. Default is None.
            act_quantize_func(function): Function that defines how to quantize activation.
                Using this can quickly test if user's quantization method works or not.
                In this function, user should both define quantization and dequantization
                process, that is, the function's input is non-quantized activation and
                function returns dequantized activation. If None, will use quantization
                op defined by 'activation_quantize_type'. Default is None.
            weight_preprocess_func(function): Function that defines how to preprocess
                weight before quantization. Using this can quickly test if user's preprocess
                method works or not. The function's input is non-quantized weight and
                function returns processed weight to be quantized. If None, the weight will
                be quantized directly. Default is None.
            act_preprocess_func(function): Function that defines how to preprocess
                activation before quantization. Using this can quickly test if user's
                preprocess method works or not. The function's input is non-quantized
                activation and function returns processed activation to be quantized.
                If None, the activation will be quantized directly. Default is None.
            optimizer_func(function): Function that returns an optimizer. When 'is_test' is
                False and user want to use self-defined quantization function and
                preprocess function, this function must be set. Default is None.
            executor(paddle.Executor): If user want to use self-defined quantization
                function and preprocess function, executor must be set for initialization.
                Default is None.

        Examples:
        .. code-block:: python
            # The original graph will be rewritten.
            import paddle
2440
            from paddle.static.quantization \
2441
                import QuantizationTransformPassV2
2442 2443
            from paddle.fluid.framework import IrGraph
            from paddle.framework import core
2444

2445
            graph = IrGraph(core.Graph(static.Program().desc), for_test=False)
2446 2447 2448 2449 2450 2451 2452 2453 2454 2455 2456 2457 2458 2459 2460 2461 2462
            place = paddle.CPUPlace()
            scope = paddle.static.global_scope()
            transform_pass = QuantizationTransformPassV2(scope, place)
            transform_pass.apply(graph)
        """
        self._scope = scope
        self._place = _get_paddle_place(place)
        self._weight_bits = weight_bits
        self._activation_bits = activation_bits
        self._skip_pattern = skip_pattern
        self._weight_quantize_func = weight_quantize_func
        self._act_quantize_func = act_quantize_func
        self._weight_preprocess_func = weight_preprocess_func
        self._act_preprocess_func = act_preprocess_func
        self._optimizer = optimizer_func
        self._exe = executor
        quant_type = [
            'abs_max',
            'channel_wise_abs_max',
            'range_abs_max',
            'moving_average_abs_max',
        ]
        assert (
            activation_quantize_type != 'channel_wise_abs_max'
        ), "The activation quantization type does not support 'channel_wise_abs_max'."
        if activation_quantize_type not in quant_type:
            raise ValueError(
                "Unknown activation_quantize_type: '%s'. It can only be "
                "'abs_max' or 'range_abs_max' or 'moving_average_abs_max'."
                % (str(activation_quantize_type))
            )
        if weight_quantize_type not in quant_type:
            raise ValueError(
                "Unknown weight_quantize_type: '%s'. It can only be "
                "'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' "
                "or 'moving_average_abs_max'." % (str(weight_quantize_type))
            )

        self._activation_quantize_type = activation_quantize_type
        self._weight_quantize_type = weight_quantize_type
        self._window_size = window_size
        self._moving_rate = moving_rate

        self._quantizable_ops = quantizable_op_type
        for op in self._quantizable_ops:
            assert op in SUPPORT_WEIGHT_QUANTIZATION_OP_DICT, (
                op + " is not supported for quantization."
            )
        self._quantizable_grad_ops = [
            '%s_grad' % (op) for op in self._quantizable_ops
        ]
        self._is_test = is_test
        self._global_step = None

        self.create_var_map = {}
        self.create_op_map = {}

    def _quant_preprocess(self, op_node):
        user_skipped = False
        if isinstance(self._skip_pattern, list):
            user_skipped = op_node.op().has_attr("op_namescope") and any(
                pattern in op_node.op().attr("op_namescope")
                for pattern in self._skip_pattern
            )
        elif isinstance(self._skip_pattern, str):
            user_skipped = (
                op_node.op().has_attr("op_namescope")
                and op_node.op().attr("op_namescope").find(self._skip_pattern)
                != -1
            )

        if user_skipped:
            op_node.op()._set_attr("skip_quant", True)
            op_node.op()._set_attr("with_quant_attr", True)

    def _transform_forward(self, graph, op):
        op.op()._set_attr("quantization_type", "qat_with_weight")
        weight_scale_node = None
        inputs = op.inputs
        for var_node in inputs:
            if var_node.name() not in op.input_arg_names():
                continue
            if var_node.name() in self.dequantized_vars:
                dequant_var_node = self.dequantized_vars[var_node.name()]
            else:
                name = var_node.name()
                if name in self.processed_vars:
                    continue
                is_weight = var_node.name() in self.persistable_vars

                # If the var node is a weight and weight_preprocess_func is not None,
                # insert the weight preprocess func
                # to preprocess the weight before quantization.
                # If the var node is an activation and act_preprocess_func is not None,
                # insert the activation preprocess func
                # to preprocess the activation before quantization.
                if is_weight and self._weight_preprocess_func is not None:
                    var_node = self._insert_func(
                        graph, self._weight_preprocess_func, var_node, op
                    )
                elif not is_weight and self._act_preprocess_func is not None:
                    var_node = self._insert_func(
                        graph, self._act_preprocess_func, var_node, op
                    )

                # if var node is weight and weight_quantize_func is not None,
                # will insert weight quantize func to quantize and dequantize weight
                # if var node is activation and act_quantize_func is not None,
                # will insert act quantize func to quantize and dequantize activation
                if is_weight and self._weight_quantize_func is not None:
                    target_out_node = self._insert_func(
                        graph, self._weight_quantize_func, var_node, op
                    )
                    self.processed_vars.append(name)
                    continue
                elif not is_weight and self._act_quantize_func is not None:
                    target_out_node = self._insert_func(
                        graph, self._act_quantize_func, var_node, op
                    )
                    self.processed_vars.append(name)
                    continue

                quant_bits = (
                    self._weight_bits if is_weight else self._activation_bits
                )
                quant_type = (
                    self._weight_quantize_type
                    if is_weight
                    else self._activation_quantize_type
                )
                quant_axis = -1
                channel_wise = False
                if quant_type == 'channel_wise_abs_max':  # Weight quantization
                    channel_wise = True
                    quant_axis = (
                        1
                        if op.name() in utils._channelwise_quant_axis1_ops
                        else 0
                    )
                insert_quant_pass = InsertQuantizeLinear(
                    self._place,
                    self._scope,
                    quant_bits=quant_bits,
                    quant_axis=quant_axis,
                    channel_wise=channel_wise,
                    moving_rate=self._moving_rate,
                    is_test=self._is_test,
                )
                (
                    quant_var_node,
                    scale_var_node,
                ) = insert_quant_pass.insert_quant_op(
                    graph, var_node, var_name=name
                )
                dequant_var_node = insert_quant_pass.insert_dequant_op(
                    graph, quant_var_node, scale_var_node
                )

                self.dequantized_vars[name] = dequant_var_node
                if is_weight:
                    weight_scale_node = scale_var_node
            graph.update_input_link(var_node, dequant_var_node, op)
        return weight_scale_node

    def _transform_backward(self, graph, op):
        for var_node in op.inputs:
            if var_node.name() not in op.input_arg_names():
                continue
            if var_node.name() in self.dequantized_vars:
                dequant_var_node = self.dequantized_vars[var_node.name()]
                graph.update_input_link(var_node, dequant_var_node, op)

    def _has_weight(self, op):
        has_weight = False
        for var_node in op.inputs:
            if var_node.name() not in op.input_arg_names():
                continue
            if var_node.name() in self.persistable_vars:
                has_weight = True
        return has_weight

    def _quant_conv1d(self, graph, op):
        # conv1d in inference is a combination of unsqueeze2 + conv2d
        if ("conv2d" not in op.name()) or (
            "unsqueeze2" not in op.input("Filter")[0]
        ):
            return
        conv_weight_var_name = op.input("Filter")[0]
        # unsqueeze2 and conv2d will share weight scale
        weight_scale_node = None
        # quant unsqueeze2
        for _op in graph.all_op_nodes():
            var_names = utils._get_op_output_var_names(_op)
            if conv_weight_var_name in var_names and self._has_weight(_op):
                weight_scale_node = self._transform_forward(graph, _op)
        # insert qdq before conv2d
        for var_node in op.inputs:
            quant_bits = (
                self._weight_bits
                if var_node.name() == conv_weight_var_name
                else self._activation_bits
            )
            quant_type = (
                self._weight_quantize_type
                if var_node.name() == conv_weight_var_name
                else self._activation_quantize_type
            )
            quant_axis = -1
            channel_wise = False
            if quant_type == 'channel_wise_abs_max':
                channel_wise = True
                quant_axis = (
                    1 if op.name() in utils._channelwise_quant_axis1_ops else 0
                )
            insert_quant_pass = InsertQuantizeLinear(
                self._place,
                self._scope,
                quant_bits=quant_bits,
                quant_axis=quant_axis,
                channel_wise=channel_wise,
                moving_rate=self._moving_rate,
                is_test=self._is_test,
            )
            scale_var_node = (
                weight_scale_node
                if var_node.name() == conv_weight_var_name
                else None
            )
            (
                quant_var_node,
                scale_var_node,
            ) = insert_quant_pass.insert_quant_op(
                graph,
                var_node,
                var_name=var_node.name(),
                scale_var_node=scale_var_node,
            )
            dequant_var_node = insert_quant_pass.insert_dequant_op(
                graph, quant_var_node, scale_var_node
            )
            graph.update_input_link(var_node, dequant_var_node, op)
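
    # Note: an inference-time conv1d appears as unsqueeze2 + conv2d, with the
    # filter also passing through unsqueeze2; that is the pattern matched
    # above, and the unsqueeze2 and conv2d ops share a single weight scale.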

    def apply(self, graph):
        """
        Quantize the graph for the training process. According to the weight
        and activation quantization types, fake quantize and fake dequantize
        operators will be inserted into the graph.

        Args:
            graph(IrGraph): the applied graph.
        Returns:
            None
        """
        assert isinstance(
            graph, IrGraph
        ), 'graph must be the instance of IrGraph.'
        if self._is_test is None:
            self._is_test = graph.is_test()
        # Record the variables that have been dequantized.
        self.dequantized_vars = collections.OrderedDict()
        self.processed_vars = []

        self.persistable_vars = [
            p.name() for p in graph.all_persistable_nodes()
        ]

        ops = graph.all_op_nodes()
        # Do the preprocessing for quantization, such as marking some ops
        # to be skipped and left unquantized.
        for op in ops:
            if (
                op.name() in self._quantizable_ops
                or op.name() in self._quantizable_grad_ops
            ):
                self._quant_preprocess(op)
        # Insert a mapping table to solve the problem of saving the inference model.
        graph.out_node_mapping_table = dict()
        # _transform_forward and _transform_backward need to run in two separate loops.
        # The loop for transforming the forward graph:
        with tqdm(
            total=len(ops),
            bar_format='Adding quant op with weight:|{bar}| {n_fmt}/{total_fmt}',
            ncols=80,
        ) as t:
            for op in ops:
                if op.name() in self._quantizable_ops:
                    if not self._is_skip_quant(graph, op) and self._has_weight(
                        op
                    ):
                        self._transform_forward(graph, op)
                    else:  # op is skipped or has no weight
                        # support conv1d quantization
                        self._quant_conv1d(graph, op)
                t.update()
        # The loop for renaming the inputs of backward op.
        for op in ops:
            if op.name() in self._quantizable_grad_ops and self._has_weight(op):
                self._transform_backward(graph, op)
        return graph


class AddQuantDequantPassV2:
    """
    Quantize the ops that do not have weights, and add quant_linear and dequant_linear
    op for the quantized ops' inputs. It is used in the new format of quantization.
    """

    # To be compatible with PaddleSlim, do not remove _activation_type for now
    _activation_type = ["relu", "relu6", "leaky_relu", "tanh", "swish"]

    def __init__(
        self,
        scope=None,
        place=None,
        moving_rate=0.9,
        quant_bits=8,
        skip_pattern=["skip_quant"],
        quantizable_op_type=["elementwise_add", "pool2d"],
        is_test=None,
        scale_dict=None,
    ):
        """
        Args:
            scope(paddle.Scope): The scope is used to initialize these new parameters.
            place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to initialize new
                parameters described above. If ``place`` is a string, it can be ``cpu``
                or ``gpu:x``, where ``x`` is the index of the GPUs.
            moving_rate(float, optional): the param for 'quant_dequant_moving_average_abs_max'
                quantization. Default is 0.9.
            quant_bits(int, optional): quantization bit number for activation. Default is 8.
            skip_pattern(str, optional): The user-defined quantization skip pattern, which
                will be presented in the name scope of an op. When the skip pattern is
                detected in an op's name scope, the corresponding op will not be quantized.
                Default is 'skip_quant'.
            quantizable_op_type(list[str], optional): List the type of ops that will be
                quantized. Default is ["elementwise_add", "pool2d"].
            scale_dict(dict, optional): calibration ranges (scales) of the output
                tensors. Default is None.

        Examples:
        .. code-block:: python
            # The original graph will be rewritten.
            import paddle
            from paddle.static.quantization \
                import AddQuantDequantPassV2
            from paddle.fluid.framework import IrGraph
            from paddle.framework import core

            graph = IrGraph(core.Graph(paddle.static.Program().desc), for_test=False)
            place = paddle.CPUPlace()
            scope = paddle.static.global_scope()
            add_quant_dequant_pass = AddQuantDequantPassV2(scope, place)
            add_quant_dequant_pass.apply(graph)
        """
        self._scope = scope
        self._place = _get_paddle_place(place)
        self._moving_rate = moving_rate
        self._quant_bits = quant_bits
        self._is_test = is_test
        self._skip_pattern = skip_pattern
        self._scale_dict = scale_dict

        self._quantizable_op_type = quantizable_op_type
        for op_type in self._quantizable_op_type:
            assert op_type in SUPPORT_ACT_QUANTIZATION_OP_DICT, (
                op_type + " is not supported for quantization."
            )
        self._quantizable_grad_op_type = [
            '%s_grad' % (op) for op in self._quantizable_op_type
        ]

        assert self._scope is not None, "scope must not be None."
        assert self._place is not None, "place must not be None."
        self.persistable_vars = []

    def apply(self, graph):
        """
        Add quant_dequant before some ops, such as the 'elementwise_add' and
        'pool2d' op.

        Args:
            graph(IrGraph): the target graph.
        Returns:
            None
        """
        assert isinstance(
            graph, IrGraph
        ), 'graph must be the instance of IrGraph.'
        if self._is_test is None:
            self._is_test = graph.is_test()
        dequantized_vars_map = collections.OrderedDict()

        self.persistable_vars = [
            p.name() for p in graph.all_persistable_nodes()
        ]

        # Forward stage, insert quant_dequant op
        all_op_nodes = graph.all_op_nodes()
        with tqdm(
            total=len(all_op_nodes),
            bar_format='Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}',
            ncols=80,
        ) as t:
            for op_node in all_op_nodes:
                if op_node.name() in self._quantizable_op_type:
                    is_skip = False
                    if isinstance(self._skip_pattern, list):
                        is_skip = op_node.op().has_attr("op_namescope") and any(
                            pattern in op_node.op().attr("op_namescope")
                            for pattern in self._skip_pattern
                        )
                    elif isinstance(self._skip_pattern, str):
                        is_skip = (
                            op_node.op().has_attr("op_namescope")
                            and op_node.op()
                            .attr("op_namescope")
                            .find(self._skip_pattern)
                            != -1
                        )
                    is_quantized = (
                        op_node.op().has_attr("quantization_type")
                        and op_node.op().attr("quantization_type")
                        == "qat_with_weight"
                    )
                    if is_skip or is_quantized:
                        continue

                    arg_names = utils._get_op_input_var_names(op_node)
                    # If already quantized, skip it.
                    skip_quant = False
                    for arg_name in arg_names:
                        if "quantized.dequantized" in arg_name:
                            skip_quant = True
                            break
                    if skip_quant:
                        continue

                    for arg_name in arg_names:
                        in_node = graph._find_node_by_name(
                            op_node.inputs, arg_name
                        )
                        if in_node.persistable():
                            continue

                        if in_node.dtype() not in [
                            paddle.float64,
                            paddle.float32,
                            paddle.float16,
                        ]:
                            _logger.warning(
                                "Since {} has a non-float input, the quantization of this op is skipped.".format(
                                    op_node.name()
                                )
                            )
                            break

                        if arg_name in dequantized_vars_map:
                            dequant_var_node = dequantized_vars_map[arg_name]
                        else:
                            insert_quant_pass = InsertQuantizeLinear(
                                self._place,
                                self._scope,
                                quant_bits=self._quant_bits,
                                quant_axis=-1,
                                channel_wise=False,
                                moving_rate=self._moving_rate,
                                is_test=self._is_test,
                                scale_dict=self._scale_dict,
                            )
                            (
                                quant_var_node,
                                scale_var_node,
                            ) = insert_quant_pass.insert_quant_op(
                                graph, in_node
                            )
                            dequant_var_node = (
                                insert_quant_pass.insert_dequant_op(
                                    graph, quant_var_node, scale_var_node
                                )
                            )
                            dequantized_vars_map[arg_name] = dequant_var_node
                        graph.update_input_link(
                            in_node, dequant_var_node, op_node
                        )
                t.update()

        # Backward stage, update input link
        for op_node in all_op_nodes:
            if op_node.name() in self._quantizable_grad_op_type:
                for input_name in op_node.input_arg_names():
                    if input_name in dequantized_vars_map:
                        in_node = graph._find_node_by_name(
                            op_node.inputs, input_name
                        )
                        dequant_var_node = dequantized_vars_map[input_name]
                        graph.update_input_link(
                            in_node, dequant_var_node, op_node
                        )

        return graph


class ReplaceFakeQuantDequantPass:
    """
    Replace quant-dequant ops with quantize_linear and dequantize_linear ops.
    """

    def __init__(self, scope, place, quant_bits=8):
        r"""
        Args:
            scope(paddle.Scope): The scope is used to initialize these new parameters.
            place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to initialize new
                parameters described above. If ``place`` is a string, it can be ``cpu``
                or ``gpu:x``, where ``x`` is the index of the GPUs.
            quant_bits(int, optional): quantization bit number for activation. Default is 8.

        Examples:
        .. code-block:: python
            # The original graph will be rewritten.
            import paddle
            from paddle.static.quantization \
                import ReplaceFakeQuantDequantPass
            from paddle.fluid.framework import IrGraph
            from paddle.framework import core

            graph = IrGraph(core.Graph(paddle.static.Program().desc), for_test=False)
            place = paddle.CPUPlace()
            scope = paddle.static.global_scope()
            replace_pass = ReplaceFakeQuantDequantPass(scope, place)
            replace_pass.apply(graph)
        """
        self._place = _get_paddle_place(place)
        self._scope = scope
        self._quant_bits = quant_bits
        assert self._scope is not None, "scope must not be None."
        assert self._place is not None, "place must not be None."

    def apply(self, graph):
        assert isinstance(
            graph, IrGraph
        ), 'graph must be the instance of IrGraph.'
        fake_quant_dequant_ops = []
        remove_fake_quant_ops = []
        observer_out_node_names = []
        for op in graph.all_op_nodes():
            # collect observer node
            if op.name() == "moving_average_abs_max_scale":
                observer_out_node_names.append(op.output("Out")[0])

        for op in graph.all_op_nodes():
            if (
                op.name() in _fake_quant_dequant_op_list
                or op.name() == "moving_average_abs_max_scale"
            ):
                var_name = op.input("X")[0]
                if var_name in observer_out_node_names:
                    remove_fake_quant_ops.append(op)
                else:
                    fake_quant_dequant_ops.append(op)

        for _op in remove_fake_quant_ops:
            x_node = graph._find_node_by_name(_op.inputs, _op.input("X")[0])
            out_node = graph._find_node_by_name(
                _op.outputs, _op.output("Out")[0]
            )
            for next_op_node in out_node.outputs:
                graph.update_input_link(out_node, x_node, next_op_node)

        for _op in fake_quant_dequant_ops:
            self._replace_op(graph, _op)
            graph.safe_remove_nodes(_op)

        graph.resolve_hazard()
        return graph

    def _replace_op(self, graph, op):
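        # Rewrite  x -> fake_quant_dequant -> out  into
        #   x -> quantize_linear -> x.quantized -> dequantize_linear -> out,
        # reusing the op's OutScale as the Scale input and creating an all-zero
        # int32 zero point.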
        x_node = graph._find_node_by_name(op.inputs, op.input("X")[0])
        out_node = graph._find_node_by_name(op.outputs, op.output("Out")[0])
        scale_node = graph._find_node_by_name(
            op.outputs, op.output("OutScale")[0]
        )

        quant_axis = (
            op.op().attr("quant_axis") if op.op().has_attr("quant_axis") else -1
        )
        bit_length = (
            op.op().attr("bit_length")
            if op.op().has_attr("bit_length")
            else self._quant_bits
        )

        # The fake quant ops being replaced use symmetric scales, so the zero
        # point is a constant zero tensor with the same shape as the scale.
        zero_point_node = graph.create_persistable_node(
            name=self._zero_point_name(x_node.name()),
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=scale_node.shape(),
            var_dtype=core.VarDesc.VarType.INT32,
        )
        _init_var_node(
            zero_point_node,
            np.zeros(scale_node.shape(), dtype="int32"),
            self._scope,
            self._place,
        )

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(x_node.name()),
            var_type=x_node.type(),
            shape=x_node.shape(),
            var_dtype=x_node.dtype(),
        )
        quant_op_node = graph.create_op_node(
            op_type="quantize_linear",
            attrs={"quant_axis": quant_axis, "bit_length": bit_length},
            inputs={
                "X": x_node,
                "Scale": scale_node,
                "ZeroPoint": zero_point_node,
            },
            outputs={"Y": quant_var_node},
        )
        graph.link_to(x_node, quant_op_node)
        graph.link_to(scale_node, quant_op_node)
        if zero_point_node is not None:
            graph.link_to(zero_point_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        dequant_op_node = graph.create_op_node(
            op_type="dequantize_linear",
            attrs={"quant_axis": quant_axis, "bit_length": bit_length},
            inputs={
                "X": quant_var_node,
                "Scale": scale_node,
                "ZeroPoint": zero_point_node,
            },
            outputs={"Y": out_node},
        )
        graph.link_to(quant_var_node, dequant_op_node)
        graph.link_to(scale_node, dequant_op_node)
        if zero_point_node is not None:
            graph.link_to(zero_point_node, dequant_op_node)
        graph.link_to(dequant_op_node, out_node)

    def _quantized_var_name(self, var_name):
        """
        Return quantized variable name for the input `var_name`.
        """
        return "%s.quantized" % (var_name)

    def _zero_point_name(self, var_name):
        """
        Return the zero point name for the var named `var_name`.
        """
        return "%s@zero_point" % (var_name)


class QuantWeightPass:
    """
    Quantize weights and remove the weight-input quantize_linear node. For example:
    `weight -> quant -> dequant -> conv2d` will be frozen into `weight -> dequant -> conv2d`,
    and weight will be scaled offline.
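
    After freezing, the remaining dequantize_linear op recovers a float weight as
    roughly ``w_fp = (w_int - zero_point) * scale`` (ONNX-style linear
    dequantization; the zero point is zero here).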

    Args:
        scope(paddle.Scope): scope is used to get the weight tensor values.
        place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to restore the weight tensors.
            If it's string, It can be ``cpu``, and ``gpu:x``, where ``x`` is the index of the GPUs.
        bias_correction(bool): whether to use bias correction for post-training quantization.
             https://arxiv.org/abs/1810.05723.
        quant_bits(int, optional): quantization bit number for weight. Default is 8.
        save_int_weight(bool, optional): Whether the type saving the weight is int. Default is True.
    Examples:
        .. code-block:: python
            # The original graph will be rewritten.
            import paddle
            from paddle.static.quantization \
                import QuantWeightPass
            from paddle.fluid.framework import IrGraph
            from paddle.framework import core

            graph = IrGraph(core.Graph(paddle.static.Program().desc), for_test=False)
            place = paddle.CPUPlace()
            scope = paddle.static.global_scope()
            quant_weight_pass = QuantWeightPass(scope, place)
            quant_weight_pass.apply(graph)
    """

    def __init__(
        self,
        scope,
        place,
        bias_correction=False,
        quant_bits=8,
        save_int_weight=True,
    ):
        self._place = _get_paddle_place(place)
        self._scope = scope
        self._bias_correction = bias_correction
        self._quant_bits = quant_bits
        self._save_int_weight = save_int_weight
        assert self._scope is not None, "scope must not be None."
        assert self._place is not None, "place must not be None."
        self._quantized_ops = {}

    def apply(self, graph):
        assert isinstance(
            graph, IrGraph
        ), 'graph must be the instance of IrGraph.'

        fake_quant_ops = [
            op for op in graph.all_op_nodes() if op.name() == "quantize_linear"
        ]
        for _op in fake_quant_ops:
            x_node = graph._find_node_by_name(_op.inputs, _op.input("X")[0])
            if x_node.persistable():
                scale_node = graph._find_node_by_name(
                    _op.inputs, _op.input("Scale")[0]
                )
                zero_point_node = graph._find_node_by_name(
                    _op.inputs, _op.input("ZeroPoint")[0]
                )
                out_node = graph._find_node_by_name(
                    _op.outputs, _op.output("Y")[0]
                )

                scale_v = self._load_var(scale_node.name())
                assert scale_v.ndim in [
                    1,
                    2,
                ], "the dim of scale_v should be 1 or 2"
                if scale_v.ndim == 2:
                    scale_v = scale_v[0]
                if scale_v.size == 1 and _op.name() == 'abs_max':
                    scale_v = scale_v[0]
                else:
                    scale_v = scale_v.tolist()
                param_v = self._load_var(x_node.name())
                quant_axis = _op.op().attr("quant_axis")
                bits_length = _op.op().attr("bit_length")
                if x_node.name() not in self._quantized_ops:
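                    # Quantize the weight offline (roughly
                    # q = round(w / scale * (2 ** (bits - 1) - 1)), applied per
                    # tensor or per channel along quant_axis); onnx_format=True
                    # keeps the quantize_linear/dequantize_linear convention.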
                    quantized_param_v = utils.quant_tensor(
                        param_v.copy(),
                        scale_v,
                        quant_axis,
                        bits_length,
                        onnx_format=True,
                    )
                    if self._bias_correction:
                        quantized_param_v = utils.bias_correction_w(
                            param_v,
                            quantized_param_v,
                            scale_v,
                            quant_axis,
                            weight_bits=bits_length,
                        )
                    if self._save_int_weight:
                        # Cast the weight to an int type; otherwise
                        # save_weight_dtype would be undefined below.
                        assert (
                            self._quant_bits == 8
                        ), "save_int_weight only supports quant_bits=8."
                        save_weight_dtype = np.int8
                        quantized_param_v = quantized_param_v.astype(
                            save_weight_dtype
                        )
                    quant_weight_node = graph.create_persistable_node(
                        name=self._quantized_var_name(x_node.name()),
                        var_type=core.VarDesc.VarType.LOD_TENSOR,
                        shape=x_node.shape(),
                        var_dtype=core.VarDesc.VarType.INT8,
                    )
                    _init_var_node(
                        quant_weight_node,
                        quantized_param_v,
                        self._scope,
                        self._place,
                    )
                    self._quantized_ops[x_node.name()] = quant_weight_node

                for next_op_node in out_node.outputs:
                    graph.update_input_link(
                        out_node,
                        self._quantized_ops[x_node.name()],
                        next_op_node,
                    )
                graph.safe_remove_nodes(_op)
        self._remove_unused_var_nodes(graph)

    def _remove_unused_var_nodes(self, graph):
        all_used_vars = set()
        ops = graph.all_op_nodes()
        for op_node in ops:
            for input_node in op_node.inputs:
                all_used_vars.add(input_node)
            for output_node in op_node.outputs:
                all_used_vars.add(output_node)

        all_used_vars = {n.node for n in all_used_vars}
        all_unused_vars = {
            n
            for n in filter(
                lambda node: node.node not in all_used_vars,
                graph.all_var_nodes(),
            )
        }
        graph.safe_remove_nodes(all_unused_vars)

    def _load_var(self, name):
        return np.array(self._scope.find_var(name).get_tensor())

    def _quantized_var_name(self, var_name):
        """
        Return quantized variable name for the input `var_name`.
        """
        return "%s.quantized" % (var_name)


class AddQuantDequantForInferencePass:
    """
    When exporting a quantized model, this pass traverses the graph to find the
    outputs of each op and inserts a quant/dequant op after each of them.
    """

    def __init__(
        self,
        scope,
        place,
        quant_bits=8,
        quantizable_op_type=[],
        calibration_range_dict=None,
        only_observer=True,
    ):
        """
        Args:
            scope(paddle.Scope): The scope is used to initialize these new parameters.
            place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to restore the weight tensors.
                If it's a string, it can be ``cpu`` or ``gpu:x``, where ``x`` is the index of the GPUs.
            quant_bits(int, optional): quantization bit number for weight. Default is 8.
            quantizable_op_type(list[str], optional): List the type of ops whose outputs
                will get quant/dequant ops. Defaults to all supported quantization ops.
            calibration_range_dict(dict, optional): mapping from variable names to
                calibration scales, used as a fallback when no observer scale for the
                variable is found in the graph. Default is None.
            only_observer(bool, optional): whether the inserted quantize/dequantize_linear
                ops act only as observers rather than performing real quantization.
                Default is True.
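
        Examples:
        .. code-block:: python
            # A minimal usage sketch, mirroring the other passes in this file.
            import paddle
            from paddle.static.quantization \
                import AddQuantDequantForInferencePass
            from paddle.fluid.framework import IrGraph
            from paddle.framework import core

            graph = IrGraph(core.Graph(paddle.static.Program().desc), for_test=True)
            place = paddle.CPUPlace()
            scope = paddle.static.global_scope()
            export_pass = AddQuantDequantForInferencePass(scope, place)
            export_pass.apply(graph)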
        """
        self._scope = scope
        self._place = place
        self._quant_bits = quant_bits
        self._only_observer = only_observer
        self._teller_set = (
            quantizable_op_type
            if quantizable_op_type
            else list(SUPPORT_QUANTIZATION_OP_DICT.keys())
        )
        self._calibration_range_dict = calibration_range_dict

    def apply(self, graph):
        """
        Args:
            graph(IrGraph): the target graph.
        """
        assert isinstance(
            graph, IrGraph
        ), 'graph must be the instance of IrGraph.'
        dequant_node_map = {}
        dequantized_vars_map = collections.OrderedDict()
        for op_node in graph.all_op_nodes():
            if op_node.name() in self._teller_set:
                var_names = utils._get_op_output_var_names(op_node)
                for var_name in var_names:
                    out_node = graph._find_node_by_name(
                        op_node.outputs, var_name
                    )
                    if out_node.dtype() not in [
                        paddle.float64,
                        paddle.float32,
                        paddle.float16,
                    ]:
                        continue
                    if var_name in dequantized_vars_map:
                        dequant_var_node = dequantized_vars_map[var_name]
                    else:
                        dequant_var_node = self._insert_quant_dequant_op(
                            graph, out_node
                        )
                        dequantized_vars_map[var_name] = dequant_var_node
                    dequant_node_map[var_name] = dequant_var_node

        # Remove unused observer nodes and link the activation quant/dequant
        # nodes to the consuming op nodes.
        for op_node in graph.all_op_nodes():
            if op_node.name() == 'moving_average_abs_max_scale':
                graph.safe_remove_nodes(op_node)
            else:
                var_names = utils._get_op_input_var_names(op_node)
                for var_name in var_names:
                    if (
                        var_name in dequant_node_map
                        and dequant_node_map[var_name]
                    ):
                        in_node = graph._find_node_by_name(
                            op_node.inputs, var_name
                        )
                        graph.update_input_link(
                            in_node, dequant_node_map[var_name], op_node
                        )

        return graph

    def _scale_name(self, var_name):
        """
        Return the scale name for the var named `var_name`.
        """
        return "%s@scale" % (var_name)

    def _insert_quant_dequant_op(self, graph, var_node):
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())
        var_name = var_node.name()
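        # Prefer the observer-produced persistable scale node "<var>@scale";
        # otherwise fall back to a scale built from calibration_range_dict, and
        # skip inserting quant/dequant if neither is available.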
        quant_axis = -1
        quant_var_node = graph.create_var_node(
            name="{}.quantized".format(var_name),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype(),
        )

        try:
            scale_var_node = graph._find_node_by_name(
                graph.all_persistable_nodes(), self._scale_name(var_name)
            )
        except:
            if (
                self._calibration_range_dict
                and var_name in self._calibration_range_dict
            ):
                scale_value = self._calibration_range_dict[var_name]
                scale_var_node = graph.create_persistable_node(
                    name=self._scale_name(var_name),
                    var_type=var_node.type(),
                    shape=[1],
                    var_dtype=var_node.dtype(),
                )
                data_type = (
                    'float64'
                    if var_node.dtype() == core.VarDesc.VarType.FP64
                    else 'float32'
                )
                _init_var_node(
                    scale_var_node,
                    np.array(scale_value, dtype=data_type),
                    self._scope,
                    self._place,
                )
            else:
                _logger.warning(
                    "Cannot find the target node {} in scope, so skip adding quant node.".format(
                        var_name
                    )
                )
                return None
        try:
            zero_point_node = graph._find_node_by_name(
                graph.all_persistable_nodes(),
                "{}@zero_point".format(quant_var_node.name()),
            )
        except:
            zero_point_node = graph.create_persistable_node(
                name="{}@zero_point".format(quant_var_node.name()),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                shape=scale_var_node.shape(),
                var_dtype=core.VarDesc.VarType.INT32,
            )
            _init_var_node(
                zero_point_node,
                np.zeros(scale_var_node.shape(), dtype="int32"),
                self._scope,
                self._place,
            )

        inputs = {"X": var_node, "Scale": scale_var_node}
        if zero_point_node is not None:
            inputs["ZeroPoint"] = zero_point_node

        attrs = {
            "quant_axis": quant_axis,
            "bit_length": self._quant_bits,
            "only_observer": self._only_observer,
        }
        attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
        outputs = {"Y": quant_var_node}

        quant_op_node = graph.create_op_node(
            op_type="quantize_linear",
            attrs=attrs,
            inputs=inputs,
            outputs=outputs,
        )

        graph.link_to(var_node, quant_op_node)
        graph.link_to(scale_var_node, quant_op_node)
        if zero_point_node is not None:
            graph.link_to(zero_point_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)

        # add dequant_linear node
        dequant_var_node = graph.create_var_node(
            name="{}.dequantized".format(quant_var_node.name()),
            var_type=quant_var_node.type(),
            shape=quant_var_node.shape(),
            var_dtype=quant_var_node.dtype(),
        )

        inputs = {"X": quant_var_node, "Scale": scale_var_node}
        if zero_point_node is not None:
            inputs["ZeroPoint"] = zero_point_node

        attrs = {
            "quant_axis": -1,
            "bit_length": self._quant_bits,
            "only_observer": self._only_observer,
        }
        attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward

        dequant_op_node = graph.create_op_node(
            op_type="dequantize_linear",
            attrs=attrs,
            inputs=inputs,
            outputs={"Y": dequant_var_node},
        )

        graph.link_to(quant_var_node, dequant_op_node)
        graph.link_to(scale_var_node, dequant_op_node)
        if zero_point_node is not None:
            graph.link_to(zero_point_node, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
        return dequant_var_node