#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections

import numpy as np

try:
    from tqdm import tqdm
except:
    from .utils import tqdm

import paddle

from ...fluid.framework import IrGraph, IrNode
from ...framework import _get_paddle_place, core
from ...static import Program, data, program_guard, scope_guard
from ...utils import unique_name
from . import utils
from .quant_config import (
    SUPPORT_ACT_QUANTIZATION_OP_DICT,
    SUPPORT_QUANTIZATION_OP_DICT,
    SUPPORT_WEIGHT_QUANTIZATION_OP_DICT,
)

_fake_quant_op_list = [
    'fake_quantize_abs_max',
    'fake_quantize_range_abs_max',
    'fake_quantize_moving_average_abs_max',
    'fake_channel_wise_quantize_abs_max',
]

_fake_dequant_op_list = [
    'fake_dequantize_max_abs',
    'fake_channel_wise_dequantize_max_abs',
]

_fake_quant_dequant_op_list = [
    'fake_quantize_dequantize_moving_average_abs_max',
    "fake_channel_wise_quantize_dequantize_abs_max",
    "fake_quantize_dequantize_abs_max",
]

_conv_ops = ['conv2d', 'depthwise_conv2d', 'conv2d_transpose']

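# Initial value for a scale variable when no previously saved scale tensor
# exists in the scope (see the try/except fallbacks in the quant-op
# insertion methods below).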
_SCALE_DEFAULT_VALUE = 0.001


def _init_var_node(var_node, value, scope, place):
    assert isinstance(
        value, np.ndarray
    ), 'The type of value should be numpy array.'
    assert scope is not None, 'The scope cannot be set None.'
    assert place is not None, 'The place cannot be set None.'
    tensor = scope.var(var_node.name()).get_tensor()
    tensor.set(value, place)


def _is_input_all_not_persistable(graph, op_node):
    '''
    Analyse whether the real inputs of the op node are all not persistable.
    '''
    is_input_all_not_persistable = True
    for var_name in utils._get_op_input_var_names(op_node):
        in_node = graph._find_node_by_name(op_node.inputs, var_name)
        is_input_all_not_persistable = is_input_all_not_persistable and (
            not in_node.persistable()
        )
    return is_input_all_not_persistable


def _check_grandchild_op_node(op_node, grandchild_op_name):
    '''
    Check whether the fake_quant node has a grandchild op node named
    grandchild_op_name.
    '''
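    # A fake_quant node's output feeds a fake_dequant node, whose output
    # feeds the real op, so the consuming op sits two output levels away
    # (a "grandchild" of the fake_quant node).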
    for out1_var_node in op_node.outputs:
        for out1_op_node in out1_var_node.outputs:
            for out2_var_node in out1_op_node.outputs:
                for out2_op_node in out2_var_node.outputs:
                    if out2_op_node.name() == grandchild_op_name:
                        return True
    return False


class QuantizationTransformPass:
    """
    Quantize the ops that have weights. Add quant and dequant ops for
    the quantized ops' inputs.
    """

    def __init__(
        self,
        scope=None,
        place=None,
        weight_bits=8,
        activation_bits=8,
        activation_quantize_type='abs_max',
        weight_quantize_type='abs_max',
        window_size=10000,
        moving_rate=0.9,
        skip_pattern=['skip_quant'],
        quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'],
        weight_quantize_func=None,
        act_quantize_func=None,
        weight_preprocess_func=None,
        act_preprocess_func=None,
        optimizer_func=None,
        executor=None,
        is_test=None,
    ):
        r"""
        Constructor.

        Args:
            scope(static.Scope): When activation uses 'range_abs_max' as the quantize
                type, this pass will create some new parameters. The scope is used to
                initialize these new parameters.
            place(static.CPUPlace|static.CUDAPlace|str): place is used to initialize new
                parameters described above. If it's a string, it can be ``cpu`` or ``gpu:x``,
                where ``x`` is the index of the GPUs.
            weight_bits(int): quantization bit number for weights,
                the bias is not quantized.
            activation_bits(int): quantization bit number for activation.
            activation_quantize_type(str): quantization type for activation,
                now support 'abs_max', 'range_abs_max' and 'moving_average_abs_max'.
                If the 'abs_max' mode is used, the quantization scale will be calculated
                dynamically each step in both the training and testing periods. If
                'range_abs_max' is used, a static quantization scale will be calculated
                during training and used in inference.
            weight_quantize_type(str): quantization type for weights,
                support 'abs_max' and 'channel_wise_abs_max'. The 'range_abs_max'
                usually is not used for weight, since weights are fixed once the
                model is well trained.
            window_size(int): the window size for 'range_abs_max' quantization.
            moving_rate(float): the param for 'moving_average_abs_max' quantization.
            skip_pattern(str or str list): The user-defined quantization skip pattern, which
                will be presented in the name scope of an op. When the skip pattern is
                detected in an op's name scope, the corresponding op will not be quantized.
            quantizable_op_type(list[str]): List the type of ops that will be quantized.
                Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
                QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
            weight_quantize_func(function): Function that defines how to quantize weight.
                Using this can quickly test if user's quantization method works or not.
                In this function, user should both define quantization function and
                dequantization function, that is, the function's input is non-quantized
                weight and function returns dequantized weight. If None, will use
                quantization op defined by 'weight_quantize_type'. Default is None.
            act_quantize_func(function): Function that defines how to quantize activation.
                Using this can quickly test if user's quantization method works or not.
                In this function, user should both define quantization and dequantization
                process, that is, the function's input is non-quantized activation and
                function returns dequantized activation. If None, will use quantization
                op defined by 'activation_quantize_type'. Default is None.
            weight_preprocess_func(function): Function that defines how to preprocess
                weight before quantization. Using this can quickly test if user's preprocess
                method works or not. The function's input is non-quantized weight and
                function returns processed weight to be quantized. If None, the weight will
                be quantized directly. Default is None.
            act_preprocess_func(function): Function that defines how to preprocess
                activation before quantization. Using this can quickly test if user's
                preprocess method works or not. The function's input is non-quantized
                activation and function returns processed activation to be quantized.
                If None, the activation will be quantized directly. Default is None.
            optimizer_func(function): Function that returns an optimizer. When 'is_test' is
                False and the user wants to use a self-defined quantization function and
                preprocess function, this function must be set. Default is None.
            executor(Fluid.Executor): If the user wants to use a self-defined quantization
                function and preprocess function, executor must be set for initialization.
                Default is None.


        Examples:
        .. code-block:: python

            # The original graph will be rewritten.
            import paddle
            import paddle.static as static
            from paddle.static.quantization \
                import QuantizationTransformPass
            from paddle.fluid.framework import IrGraph
            from paddle.framework import core

            graph = IrGraph(core.Graph(static.Program().desc), for_test=False)
            place = paddle.CPUPlace()
            transform_pass = QuantizationTransformPass(static.global_scope(),
            place)
            transform_pass.apply(graph)
        """
        self._scope = scope
        self._place = _get_paddle_place(place)
        self._weight_bits = weight_bits
        self._activation_bits = activation_bits
        self._skip_pattern = skip_pattern
        self._weight_quantize_func = weight_quantize_func
        self._act_quantize_func = act_quantize_func
        self._weight_preprocess_func = weight_preprocess_func
        self._act_preprocess_func = act_preprocess_func
        self._optimizer = optimizer_func
        self._exe = executor
        quant_type = [
            'abs_max',
            'channel_wise_abs_max',
            'range_abs_max',
            'moving_average_abs_max',
        ]
        assert (
            activation_quantize_type != 'channel_wise_abs_max'
        ), "The activation quantization type does not support 'channel_wise_abs_max'."
        if activation_quantize_type not in quant_type:
            raise ValueError(
                "Unknown activation_quantize_type : '%s'. It can only be "
                "'abs_max' or 'range_abs_max' or 'moving_average_abs_max'."
                % (str(activation_quantize_type))
            )
        if weight_quantize_type not in quant_type:
            raise ValueError(
                "Unknown weight_quantize_type: '%s'. It can only be "
                "'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' "
                "or 'moving_average_abs_max'." % (str(weight_quantize_type))
            )

        self._activation_quantize_type = activation_quantize_type
        self._weight_quantize_type = weight_quantize_type
        self._window_size = window_size
        self._moving_rate = moving_rate

        self._quantizable_ops = quantizable_op_type
        for op in self._quantizable_ops:
            assert op in list(SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.keys()), (
                op + " is not supported for quantization."
            )
        self._quantizable_grad_ops = [
            '%s_grad' % (op) for op in self._quantizable_ops
        ]
        self._is_test = is_test
        self._global_step = None

        self.create_var_map = {}
        self.create_op_map = {}
    def apply(self, graph):
        """
        Quantize the graph for the training process. According to weight and
        activation quantization type, fake quantize operators and fake
        dequantize operators will be added to the graph.

        Args:
            graph(IrGraph): the applied graph.
        Returns:
            None
        """
        assert isinstance(
            graph, IrGraph
        ), 'graph must be the instance of IrGraph.'
        if self._is_test is None:
            self._is_test = graph.is_test()
        # mark the variables which have been dequantized.
        dequantized_vars = collections.OrderedDict()
        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
        processed_vars = []

        def _quant_preprocess(op_node):
            user_skipped = False
            if isinstance(self._skip_pattern, list):
                user_skipped = op_node.op().has_attr("op_namescope") and any(
                    pattern in op_node.op().attr("op_namescope")
                    for pattern in self._skip_pattern
                )
            elif isinstance(self._skip_pattern, str):
                user_skipped = (
                    op_node.op().has_attr("op_namescope")
                    and op_node.op()
                    .attr("op_namescope")
                    .find(self._skip_pattern)
                    != -1
                )

            if user_skipped:
                op_node.op()._set_attr("skip_quant", True)
                op_node.op()._set_attr("with_quant_attr", True)

        def _transform_forward(graph, op):
            op.op()._set_attr("quantization_type", "qat_with_weight")
            op.op()._set_attr("with_quant_attr", True)
            inputs = op.inputs
            for var_node in inputs:
                if var_node.name() not in op.input_arg_names():
                    continue
                if var_node.name() in dequantized_vars:
                    dequant_var_node = dequantized_vars[var_node.name()]
                else:
                    name = var_node.name()
                    if name in processed_vars:
                        continue
                    is_weight = (
                        True if var_node.name() in persistable_vars else False
                    )

                    # if var node is weight and weight_preprocess_func is not None,
                    # will insert weight preprocess func
                    # to preprocess weight before quantization
                    # if var node is activation and act_preprocess_func is not None,
                    # will insert activation preprocess func
                    # to preprocess activation before quantization
                    if is_weight and self._weight_preprocess_func is not None:
                        var_node = self._insert_func(
                            graph, self._weight_preprocess_func, var_node, op
                        )
                    elif (
                        not is_weight and self._act_preprocess_func is not None
                    ):
                        var_node = self._insert_func(
                            graph, self._act_preprocess_func, var_node, op
                        )

                    # if var node is weight and weight_quantize_func is not None,
                    # will insert weight quantize func to quantize and dequantize weight
                    # if var node is activation and act_quantize_func is not None,
                    # will insert act quantize func to quantize and dequantize activation
                    if is_weight and self._weight_quantize_func is not None:
                        target_out_node = self._insert_func(
                            graph, self._weight_quantize_func, var_node, op
                        )
                        processed_vars.append(name)
                        continue
                    elif not is_weight and self._act_quantize_func is not None:
                        target_out_node = self._insert_func(
                            graph, self._act_quantize_func, var_node, op
                        )
                        processed_vars.append(name)
                        continue

                    quant_bits = (
                        self._weight_bits
                        if var_node.name() in persistable_vars
                        else self._activation_bits
                    )
                    quant_type = (
                        self._weight_quantize_type
                        if is_weight
                        else self._activation_quantize_type
                    )
                    if (
                        quant_type == 'channel_wise_abs_max'
                    ):  # Weight quantization
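                        # Ops listed in utils._channelwise_quant_axis1_ops
                        # (e.g. conv2d_transpose and mul-like ops) keep their
                        # output channels on axis 1, so quantize per-channel
                        # along that axis; conv2d-like weights use axis 0.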
                        quant_axis = (
                            1
                            if op.name() in utils._channelwise_quant_axis1_ops
                            else 0
                        )
                        (
                            quant_var_node,
                            scale_var_node,
                        ) = self._insert_channel_quant_op(
                            graph, var_node, name, quant_bits, quant_axis
                        )
                        dequant_var_node = self._insert_channel_dequant_op(
                            graph,
                            quant_var_node,
                            [scale_var_node],
                            [quant_bits],
                            quant_axis,
                        )
                    else:
                        quant_var_node, scale_var_node = self._insert_quant_op(
                            graph, var_node, name, quant_bits, quant_type
                        )
                        dequant_var_node = self._insert_dequant_op(
                            graph, quant_var_node, scale_var_node, quant_bits
                        )
                    dequantized_vars[name] = dequant_var_node
                graph.update_input_link(var_node, dequant_var_node, op)

        def _transform_backward(graph, op):
            for var_node in op.inputs:
                if var_node.name() not in op.input_arg_names():
                    continue
                if var_node.name() in dequantized_vars:
                    dequant_var_node = dequantized_vars[var_node.name()]
                    graph.update_input_link(var_node, dequant_var_node, op)

        def _has_weight(op):
            has_weight = False
            for var_node in op.inputs:
                if var_node.name() not in op.input_arg_names():
                    continue
                name = var_node.name()
                if var_node.name() in persistable_vars:
                    has_weight = True
            return has_weight

        if not self._is_test:
            self._create_global_step(graph)
        ops = graph.all_op_nodes()
        # Do the preprocess of quantization, such as skipping some ops
        # for not being quantized.
        for op in ops:
            if (
                op.name() in self._quantizable_ops
                or op.name() in self._quantizable_grad_ops
            ):
                _quant_preprocess(op)
        # Insert mapping table to solve the problem in saving inference model.
        graph.out_node_mapping_table = dict()
        # The process of _transform_forward and _transform_backward is needed in two for loops.
        # The loop for transforming the forward graph:
        with tqdm(
            total=len(ops),
            bar_format='Adding quant op with weight:|{bar}| {n_fmt}/{total_fmt}',
            ncols=80,
        ) as t:
            for op in ops:
                if op.name() in self._quantizable_ops:
                    if not self._is_skip_quant(graph, op) and _has_weight(op):
                        _transform_forward(graph, op)
                t.update()
        # The loop for renaming the inputs of backward op.
        for op in ops:
            if op.name() in self._quantizable_grad_ops and _has_weight(op):
                _transform_backward(graph, op)
        graph.resolve_hazard()
        return graph

    def _create_global_step(self, graph):
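        # 'range_abs_max' records scales into a fixed-size window during
        # training; the global step counter created here is fed to the quant
        # op's 'Iter' input to index that window.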
        if (
            self._weight_quantize_type == 'range_abs_max'
            or self._activation_quantize_type == 'range_abs_max'
        ):
            counter_name = '@STEP_COUNTER@'
            for node in graph.all_var_nodes():
                if node.name() == counter_name:
                    self._global_step = node
            if self._global_step is None:
                global_step_in = graph.create_persistable_node(
                    name=counter_name,
                    var_type=core.VarDesc.VarType.LOD_TENSOR,
                    shape=[1],
                    var_dtype=core.VarDesc.VarType.INT64,
                )
                _init_var_node(
                    global_step_in,
                    np.zeros([1], dtype='int64'),
                    self._scope,
                    self._place,
                )
                global_step_out = graph.create_var_node_from_desc(
                    global_step_in.var()
                )
                # The attribute of `op_role` is needed by ParallelExecutor.
                increment_op = graph.create_op_node(
                    op_type='increment',
                    attrs={
                        'step': 1.0,
                        'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
                    },
                    inputs={'X': global_step_in},
                    outputs={'Out': global_step_out},
                )
                graph.link_to(global_step_in, increment_op)
                graph.link_to(increment_op, global_step_out)
                self._global_step = global_step_out

    def _insert_quant_op(self, graph, var_node, name, quant_bits, quant_type):
        """
        Insert fake_quantize_op in the graph.
        """
        if quant_type == 'abs_max':
            return self._insert_quant_abs_max_op(
                graph, var_node, name, quant_bits
            )
        elif quant_type == 'range_abs_max':
            return self._insert_quant_range_abs_max_op(
                graph, var_node, name, quant_bits
            )
        elif quant_type == 'moving_average_abs_max':
            return self._insert_quant_moving_average_abs_max_op(
                graph, var_node, name, quant_bits
            )

    def _insert_quant_abs_max_op(self, graph, var_node, name, quant_bits):
        """
        Insert fake_quantize_abs_max op in the graph.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(name),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype(),
        )
        scale_name = self._quantized_scale_name(name)
        if var_node.dtype() == core.VarDesc.VarType.FP64:
            data_type = 'float64'
        elif var_node.dtype() == core.VarDesc.VarType.FP32:
            data_type = 'float32'
        else:
            data_type = "float16"
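        # Reuse a scale tensor already saved in the scope if one exists (e.g.
        # when the pass is re-applied); otherwise fall back to zeros below.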
        try:
            scale_value = np.array(
                self._scope.find_var(scale_name).get_tensor()
            )
        except:
            scale_value = np.zeros([1], dtype=data_type)
        scale_var_node = graph.create_persistable_node(
            name=scale_name,
            var_type=var_node.type(),
            shape=[1],
            var_dtype=var_node.dtype(),
        )
        _init_var_node(scale_var_node, scale_value, self._scope, self._place)

        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_abs_max',
            attrs={
                'bit_length': quant_bits,
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
            },
            inputs={'X': var_node},
            outputs={'Out': quant_var_node, 'OutScale': scale_var_node},
        )
        graph.link_to(var_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_var_node)
        return quant_var_node, scale_var_node

    def _insert_quant_range_abs_max_op(self, graph, var_node, name, quant_bits):
        """
        Insert fake_quantize_range_abs_max on the graph.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(name),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype(),
        )

        scale_name = self._quantized_scale_name(name)
        if var_node.dtype() == core.VarDesc.VarType.FP64:
            data_type = 'float64'
        elif var_node.dtype() == core.VarDesc.VarType.FP32:
            data_type = 'float32'
        else:
            data_type = "float16"
        try:
            scale_value = np.array(
                self._scope.find_var(scale_name).get_tensor()
            )
        except:
            scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type)
        scale_in_node = graph.create_persistable_node(
            name=scale_name,
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[1],
            var_dtype=var_node.dtype(),
        )
        _init_var_node(scale_in_node, scale_value, self._scope, self._place)

        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
        inputs = {'X': var_node, 'InScale': scale_in_node}
        outputs = {'Out': quant_var_node, 'OutScale': scale_out_node}

        if not self._is_test:
            # The name of scales_var_node may be 'scales_0', 'scales_1', etc.
            scales_node = graph.create_persistable_node(
                name=unique_name.generate('scales'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                shape=[self._window_size],
                var_dtype=var_node.dtype(),
            )
            if var_node.dtype() == core.VarDesc.VarType.FP64:
                data_type = 'float64'
            elif var_node.dtype() == core.VarDesc.VarType.FP32:
                data_type = 'float32'
            else:
                data_type = "float16"
            _init_var_node(
                scales_node,
                np.zeros([self._window_size], dtype=data_type),
                self._scope,
                self._place,
            )

            inputs['Iter'] = self._global_step
            outputs['OutScales'] = scales_node
        attrs = {
            'window_size': self._window_size,
            'bit_length': quant_bits,
            'is_test': self._is_test,
            'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
        }
        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_range_abs_max',
            attrs=attrs,
            inputs=inputs,
            outputs=outputs,
        )

        graph.link_to(var_node, quant_op_node)
        graph.link_to(scale_in_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_out_node)

        if not self._is_test:
            graph.link_to(self._global_step, quant_op_node)
            graph.link_to(quant_op_node, scales_node)

        return quant_var_node, scale_out_node

    def _insert_quant_moving_average_abs_max_op(
        self, graph, var_node, name, quant_bits
    ):
        """Insert fake_quantize_moving_average_abs_max"""
        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(name),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype(),
        )
        scale_name = self._quantized_scale_name(name)
        if var_node.dtype() == core.VarDesc.VarType.FP64:
            data_type = 'float64'
        elif var_node.dtype() == core.VarDesc.VarType.FP32:
            data_type = 'float32'
        else:
            data_type = "float16"
        try:
            scale_value = np.array(
                self._scope.find_var(scale_name).get_tensor()
            )
        except:
            scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type)
        scale_in_node = graph.create_persistable_node(
            name=scale_name,
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[1],
            var_dtype=var_node.dtype(),
        )
        _init_var_node(scale_in_node, scale_value, self._scope, self._place)

        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
        ins = {'X': var_node, 'InScale': scale_in_node}
        outs = {'Out': quant_var_node, 'OutScale': scale_out_node}
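        # In training, the scale is tracked as a moving average of the
        # input's abs-max:
        #   state = moving_rate * state + 1
        #   accum = moving_rate * accum + max(|x|)
        #   scale = accum / state
        # which is why the extra state/accum inputs and outputs are wired below.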
        if not self._is_test:
            state_in_node = graph.create_persistable_node(
                name=unique_name.generate('state'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_node.dtype(),
                shape=[1],
            )
            if var_node.dtype() == core.VarDesc.VarType.FP64:
                data_type = 'float64'
            elif var_node.dtype() == core.VarDesc.VarType.FP32:
                data_type = 'float32'
            else:
                data_type = "float16"
            _init_var_node(
                state_in_node,
                np.ones([1], dtype=data_type),
                self._scope,
                self._place,
            )
            accum_in_node = graph.create_persistable_node(
                name=unique_name.generate('accum'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_node.dtype(),
                shape=[1],
            )
            _init_var_node(
                accum_in_node,
                np.ones([1], dtype=data_type),
                self._scope,
                self._place,
            )
            state_out_node = graph.create_var_node_from_desc(
                state_in_node.var()
            )
            accum_out_node = graph.create_var_node_from_desc(
                accum_in_node.var()
            )

            ins['InState'] = state_in_node
            ins['InAccum'] = accum_in_node
            outs['OutState'] = state_out_node
            outs['OutAccum'] = accum_out_node

        attrs = {
            'bit_length': quant_bits,
            'moving_rate': self._moving_rate,
            'is_test': self._is_test,
            'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
        }

        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_moving_average_abs_max',
            attrs=attrs,
            inputs=ins,
            outputs=outs,
        )

        graph.link_to(var_node, quant_op_node)
        graph.link_to(scale_in_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_out_node)

        if not self._is_test:
            graph.link_to(state_in_node, quant_op_node)
            graph.link_to(accum_in_node, quant_op_node)
            graph.link_to(quant_op_node, state_out_node)
            graph.link_to(quant_op_node, accum_out_node)

        return quant_var_node, scale_out_node

    def _insert_channel_quant_op(
        self, graph, var_node, name, quant_bits, quant_axis
    ):
        """
        Insert fake_channel_wise_quantize_abs_max op in the graph.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(name),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype(),
        )
        scale_name = self._quantized_scale_name(name)
        if var_node.dtype() == core.VarDesc.VarType.FP64:
            data_type = 'float64'
        elif var_node.dtype() == core.VarDesc.VarType.FP32:
            data_type = 'float32'
        else:
            data_type = "float16"
        try:
            scale_value = np.array(
                self._scope.find_var(scale_name).get_tensor()
            )
        except:
            scale_value = np.zeros(
                [var_node.shape()[quant_axis]], dtype=data_type
            )
        scale_var_node = graph.create_persistable_node(
            name=self._quantized_scale_name(name),
            var_type=var_node.type(),
            shape=[var_node.shape()[quant_axis]],
            var_dtype=var_node.dtype(),
        )
        _init_var_node(scale_var_node, scale_value, self._scope, self._place)
        quant_op_node = graph.create_op_node(
            op_type='fake_channel_wise_quantize_abs_max',
            attrs={
                'bit_length': quant_bits,
                'quant_axis': quant_axis,
                'is_test': self._is_test,
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
            },
            inputs={'X': var_node},
            outputs={'Out': quant_var_node, 'OutScale': scale_var_node},
        )
        graph.link_to(var_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_var_node)
        return quant_var_node, scale_var_node

    def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits):
        """
        Insert fake_dequantize_op in the graph.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype(),
        )
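        # Symmetric dequantization: e.g. 8-bit quantization gives
        # max_range = 127, and the op computes out = x * scale / max_range.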
        max_range = (1 << (quant_bits - 1)) - 1
        dequant_op_node = graph.create_op_node(
            op_type='fake_dequantize_max_abs',
            attrs={
                'max_range': float(max_range),
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
            },
            inputs={'X': var_node, 'Scale': scale_var_node},
            outputs={'Out': dequant_var_node},
        )
        graph.link_to(var_node, dequant_op_node)
        graph.link_to(scale_var_node, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
        return dequant_var_node

    def _insert_channel_dequant_op(
        self, graph, var_node, scale_var_nodes, quant_bits, quant_axis
    ):
        """
        Insert fake_channel_wise_dequantize_max_abs in the graph.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype(),
        )
        dequant_op_node = graph.create_op_node(
            op_type='fake_channel_wise_dequantize_max_abs',
            attrs={
                'quant_bits': quant_bits,
                'quant_axis': quant_axis,
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
            },
            inputs={'X': var_node, 'Scales': scale_var_nodes},
            outputs={'Out': dequant_var_node},
        )
        graph.link_to(var_node, dequant_op_node)
        for scale_n in scale_var_nodes:
            graph.link_to(scale_n, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
        return dequant_var_node

    def _create_new_node(self, graph, in_node):
        """
        Create a node in graph that is the same as in_node.
        Args:
            graph(IrGraph): the graph to create the node in.
            in_node(IrVarNode): the node to duplicate.
        Returns:
            the created new node
        """
        key = ''
        for inp in in_node.inputs:
            key = key + inp.name()
        key = key + in_node.name()
        for inp in in_node.outputs:
            key = key + inp.name()

        if key in self.create_var_map.keys():
            new_node = self.create_var_map[key]
        elif in_node.is_ctrl_var():
            new_node = graph.create_control_dep_var()
            self.create_var_map[key] = new_node
        else:
            new_node = graph.create_var_node_from_desc(in_node.node.var())
            self.create_var_map[key] = new_node
        return new_node

    def _copy_graph(self, graph, source_graph, op_node):
        """
        Copy op_node in source_graph to graph. This will run recursively
        for next ops that link to op_node's outputs.
        Args:
            graph(IrGraph): target graph to copy.
            source_graph(IrGraph): source graph to copy.
            op_node(IrOpNode): op node in source_graph.
        Returns:
            None

        """
        key = ''
        for inp in op_node.inputs:
            key = key + inp.name()
        key = key + op_node.name()
        for inp in op_node.outputs:
            key = key + inp.name()
        has_created = False
        if key in self.create_op_map.keys():
            new_op_node = self.create_op_map[key]
            has_created = True
        else:
            new_op_node = graph.create_op_node_from_desc(op_node.node.op())
            self.create_op_map[key] = new_op_node
        if has_created:
            return
        for in_node in op_node.inputs:
            new_node = self._create_new_node(graph, in_node)
            graph.link_to(new_node, new_op_node)
        for in_node in op_node.outputs:
            new_node = self._create_new_node(graph, in_node)
            graph.link_to(new_op_node, new_node)
        for var_node in op_node.outputs:
            for next_op_node in var_node.outputs:
                self._copy_graph(graph, source_graph, next_op_node)
        return

    def _insert_func(self, graph, func, var_node, op):
        """
        Insert a tmp program returned by func between var_node and op.

        Args:
            graph(IrGraph): target graph to insert tmp program.
            func(Function): function to define a tmp program
            var_node(IrVarNode): node in target graph.
            op(IrOpNode): op in target graph.
        Returns:
            op's new input that replaces var_node
        """
        tmp_program = Program()
        startup_program = Program()
        with program_guard(tmp_program, startup_program):
            with unique_name.guard(var_node.name() + "_"):
                in_node = data(
                    var_node.name() + '_tmp_input',
                    shape=var_node.shape(),
                    dtype='float32',
                )
                out_node = func(in_node)
                graph.out_node_mapping_table[out_node.name] = var_node.name()
                # loss shape must be 1 when minimize
                loss = paddle.mean(out_node)
                if not graph._for_test:
                    assert (
                        self._optimizer
                    ), "optimizer_func must be set when graph is not a test graph"
                    in_node.stop_gradient = False
                    optimizer = self._optimizer()
                    optimizer.minimize(loss)
        with scope_guard(self._scope):
            self._exe.run(startup_program)

        tmp_graph = IrGraph(
            core.Graph(tmp_program.desc), for_test=graph._for_test
        )
        in_node = tmp_graph._find_node_by_name(
            tmp_graph.all_var_nodes(), in_node.name
        )
        out_node = tmp_graph._find_node_by_name(
            tmp_graph.all_var_nodes(), out_node.name
        )

        in_node_params = []
        in_op_node = []
        # copy tmp graph to graph, after that, we can insert tmp graph's copy to graph.
        for node in tmp_graph.all_var_nodes():
            if node.inputs == [] and node.persistable():
                in_node_params.append(node)
        for node in tmp_graph.all_op_nodes():
            if node.inputs == []:
                in_op_node.append(node)
        for node in in_node.outputs:
            self._copy_graph(graph, tmp_graph, node)
        for node in in_node_params:
            for op_node in node.outputs:
                self._copy_graph(graph, tmp_graph, op_node)
        for node in in_op_node:
            self._copy_graph(graph, tmp_graph, node)

        target_in_node = graph._find_node_by_name(
            graph.all_var_nodes(), in_node.name()
        )
        target_out_node = graph._find_node_by_name(
            graph.all_var_nodes(), out_node.name()
        )
        loss_node = graph._find_node_by_name(graph.all_var_nodes(), loss.name)
        outputs = target_in_node.outputs
        for node in outputs:
            graph.update_input_link(target_in_node, var_node, node)
        graph.update_input_link(var_node, target_out_node, op)

        # update grad
        if not graph._for_test:
            op_out = op.outputs[0]
            op_out_grad = graph._find_node_by_name(
                graph.all_var_nodes(), op_out.name() + "@GRAD"
            )
            # find op's gradient op, such as conv2d_grad
            op_grad = op_out_grad.outputs[0]
            target_out_grad_node = graph._find_node_by_name(
                graph.all_var_nodes(), target_out_node.name() + "@GRAD"
            )
            in_node_grad = graph._find_node_by_name(
                graph.all_var_nodes(), target_in_node.name() + "@GRAD"
            )
            in_node_grad_op = in_node_grad.inputs
            # update op_grad's input
            graph.update_input_link(var_node, target_out_node, op_grad)

            op_grad_out = None
            # find var_node's corresponding grad node
            for node in op_grad.outputs:
                if var_node.name() + "@GRAD" in node.name():
                    op_grad_out = node
            # update op_grad's output
            if op_grad_out is not None:
                graph.update_output_link(
                    op_grad_out, target_out_grad_node, op_grad
                )
            else:
                graph.link_to(op_grad, target_out_grad_node)

            for node in in_node_grad_op:
                graph.update_input_link(target_in_node, var_node, node)
                if op_grad_out:
                    graph.update_output_link(in_node_grad, op_grad_out, node)
            # remove useless nodes
            mean_grad = target_out_grad_node.inputs[0]
            mean_out_grad = mean_grad.inputs[0]
            fill_constant_node = mean_out_grad.inputs[0]
            graph.safe_remove_nodes(mean_grad)
            graph.safe_remove_nodes(mean_out_grad)
            graph.safe_remove_nodes(fill_constant_node)
            graph.safe_remove_nodes(in_node_grad)

        graph.safe_remove_nodes(loss_node.inputs[0])
        graph.safe_remove_nodes(loss_node)
        graph.safe_remove_nodes(target_in_node)
        return target_out_node

    def _quantized_var_name(self, var_name):
        """
        Return quantized variable name for the input `var_name`.
        """
        return "%s.quantized" % (var_name)

    def _dequantized_var_name(self, var_name):
        """
        Return dequantized variable name for the input `var_name`.
        """
        return "%s.dequantized" % (var_name)

    def _quantized_scale_name(self, var_name):
        """
        Return the scale name of quantized variable for the input `var_name`.
        """
        return "%s@scale" % (var_name)

    def _is_skip_quant(self, graph, op_node):
        """
        Analyse whether the op node skips quantization.
        """
        is_skip = False
        if op_node.op().has_attr("skip_quant") and op_node.op().attr(
            "skip_quant"
        ):
            is_skip = True
        # if the inputs of mul and matmul are not all persistable, use
        # AddQuantDequantPass to quantize them.
        if op_node.name() in [
            "mul",
            "matmul",
        ] and _is_input_all_not_persistable(graph, op_node):
            is_skip = True
        if (
            op_node.op().has_attr("quantization_type")
            and op_node.op().attr("quantization_type") == "qat_without_weight"
        ):
            is_skip = True
        return is_skip


class QuantizationFreezePass:
    def __init__(
        self,
        scope,
        place,
        bias_correction=False,
        weight_bits=8,
        activation_bits=8,
        round_type='round',
        weight_quantize_type='abs_max',
        quantizable_op_type=None,
    ):
        """
        The freeze pass is used to adjust the quantize operator order, for example:
            1) `activation -> quant -> dequant -> conv2d` will be frozen into
            `activation -> quant -> conv2d -> dequant`
            2) `weight -> quant -> dequant -> conv2d` will be frozen into `weight -> conv2d`,
            and weight will be scaled offline.

        Args:
            scope(static.Scope): scope is used to get the weight tensor values.
            place(static.CPUPlace|static.CUDAPlace|str): place is used to restore the weight tensors.
                If it's a string, it can be ``cpu`` or ``gpu:x``, where ``x`` is the index of the GPUs.
            bias_correction(bool): whether to use bias correction for post-training quantization.
                See https://arxiv.org/abs/1810.05723.
            weight_bits(int): quantization bit number for weights.
            activation_bits(int): quantization bit number for activation.
            round_type(str, optional): The method of converting the quantized weights
                value float->int. Currently supports ['round', 'adaround'] methods.
                Default is `round`, which rounds to the nearest integer.
                'adaround' refers to https://arxiv.org/abs/2004.10568.
            weight_quantize_type(str): quantization type for weights, support 'abs_max' and
                'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight,
                since weights are fixed once the model is well trained.
            quantizable_op_type(list[str]): This input param will be removed later. The pass
                will process all quantized ops, so it is not necessary to set this param.
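
        Examples:
        .. code-block:: python

            # A minimal usage sketch, mirroring the QuantizationTransformPass
            # example above; it assumes `graph` holds a quantized test graph.
            import paddle
            import paddle.static as static
            from paddle.static.quantization import QuantizationFreezePass

            freeze_pass = QuantizationFreezePass(static.global_scope(),
                                                 paddle.CPUPlace())
            freeze_pass.apply(graph)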
        """
        assert scope is not None, 'The scope cannot be set None.'
        assert place is not None, 'The place cannot be set None.'
        self._scope = scope
        self._bias_correction = bias_correction
        self._place = _get_paddle_place(place)
        self._weight_bits = weight_bits
        self._activation_bits = activation_bits
        self._round_type = round_type
        self._weight_quantize_type = weight_quantize_type
        self._fake_quant_op_names = _fake_quant_op_list
        self._fake_dequant_op_names = _fake_dequant_op_list
        self._op_input_rename_map = collections.OrderedDict()
        self._op_output_rename_map = collections.OrderedDict()
        self._quant_var_scale_map = collections.OrderedDict()
        self._quantized_ops = set()

    def apply(self, graph):
        """
        Adjust quantize/dequantize operators order for the inference process.

        Args:
            graph(IrGraph): the applied graph.
        Returns:
            None
        """
        # Get input scales in fake quant op and process weights
        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
        ops = graph.all_op_nodes()
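        # For activation inputs the scale is kept as an IrNode (used later
        # when inserting the post-dequant op); for persistable weights the
        # scale values are read out so the weights can be quantized in place.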
        for op_node in ops:
            op_name = op_node.name()
            if op_name in self._fake_quant_op_names:
                input_arg_name = op_node.input('X')[0]
                if hasattr(graph, 'out_node_mapping_table'):
                    if input_arg_name in graph.out_node_mapping_table.keys():
                        input_arg_name = graph.out_node_mapping_table[
                            input_arg_name
                        ]
                if input_arg_name not in persistable_vars:
                    scale_v = graph._find_node_by_name(
                        op_node.outputs, op_node.output('OutScale')[0]
                    )
                    self._quant_var_scale_map[input_arg_name] = scale_v
                else:
                    # Obtain scale from OutScale var node
                    scale_v = self._load_var(op_node.output('OutScale')[0])
                    assert scale_v.ndim in [
                        1,
                        2,
                    ], "the dim of scale_v should be 1 or 2"
                    if scale_v.ndim == 2:
                        scale_v = scale_v[0]
                    if (
                        scale_v.size == 1
                        and self._weight_quantize_type == 'abs_max'
                    ):
                        scale_v = scale_v[0]
                    else:
                        scale_v = scale_v.tolist()
                    self._quant_var_scale_map[input_arg_name] = scale_v
                    # Quantize weight and restore
                    if self._round_type == 'round':
                        param_v = self._load_var(input_arg_name)
                        if any(
                            _check_grandchild_op_node(op_node, op)
                            for op in utils._channelwise_quant_axis1_ops
                        ):
                            quant_axis = 1
                        else:
                            quant_axis = 0
                        if input_arg_name not in self._quantized_ops:
                            self._quantized_ops.add(input_arg_name)
                            quantized_param_v = utils.quant_tensor(
                                param_v.copy(),
                                scale_v,
                                quant_axis,
                                self._weight_bits,
                            )
                            quantized_param_v = np.round(quantized_param_v)
                            # Weight bias correction
                            if self._bias_correction is True:
                                quantized_param_v = utils.bias_correction_w(
                                    param_v,
                                    quantized_param_v,
                                    scale_v,
                                    quant_axis,
                                    weight_bits=self._weight_bits,
                                )
                                quantized_param_v = np.round(quantized_param_v)
                            self._restore_var(input_arg_name, quantized_param_v)

                    self._remove_fake_quant_and_dequant_op(graph, op_node)

1193
        # Remove all fake dequant op
1194
        ops = graph.all_op_nodes()
W
WangZhen 已提交
1195 1196 1197 1198 1199
        for op_node in ops:
            op_name = op_node.name()
            if op_name in self._fake_dequant_op_names:
                self._remove_fake_quant_and_dequant_op(graph, op_node)

1200
        # Insert post dequant op
1201
        ops = graph.all_op_nodes()
W
WangZhen 已提交
1202
        for op_node in ops:
1203
            op_node_desc = op_node.op()
1204 1205 1206 1207
            if (
                op_node_desc.has_attr("quantization_type")
                and op_node_desc.attr("quantization_type") == "qat_with_weight"
            ):
1208
                if self._weight_quantize_type == 'channel_wise_abs_max':
1209 1210 1211 1212 1213
                    quant_axis = (
                        1
                        if op_node.name() in utils._channelwise_quant_axis1_ops
                        else 0
                    )
1214
                    self._insert_post_channel_dequant_op(
1215 1216
                        graph, op_node, quant_axis
                    )
1217 1218
                else:
                    self._insert_post_dequant_op(graph, op_node)

        # Rename inputs of the following ops after inserting dequant_op after fc/conv
        for op_node in ops:
            for var_node in op_node.inputs:
                if var_node.node in self._op_output_rename_map:
                    old_in = var_node
                    new_in = self._op_output_rename_map[var_node.node]
                    graph.update_input_link(old_in, new_in, op_node)

        # remove the unused var node in the graph
        self._remove_unused_var_nodes(graph)
        graph.resolve_hazard()
        return graph

    def _remove_fake_quant_and_dequant_op(self, graph, op_node):
        k = graph._find_node_by_name(op_node.outputs, op_node.output('Out')[0])
        v = graph._find_node_by_name(op_node.inputs, op_node.input('X')[0])
        if v.node not in self._op_input_rename_map:
            self._op_input_rename_map[k.node] = v
        else:
            self._op_input_rename_map[k.node] = self._op_input_rename_map[
                v.node
            ]
        graph.safe_remove_nodes(op_node)

    def _insert_post_channel_dequant_op(self, graph, op_node, quant_axis):
        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
        for var_node in op_node.inputs:
            name = var_node.name()
            if name not in op_node.input_arg_names():
                continue
            if var_node.node in self._op_input_rename_map:
                old_in = var_node
                new_in = self._op_input_rename_map[var_node.node]
                new_in.clear_outputs()
                graph.update_input_link(old_in, new_in, op_node)
            original_var_name = self._original_var_name(name)
            scale_v = self._quant_var_scale_map[original_var_name]
            if original_var_name in persistable_vars:
                assert isinstance(
                    scale_v, list
                ), 'The scale of parameter %s is not a list.' % (
                    original_var_name
                )
                channel_scale = np.array(scale_v)
            else:
                assert isinstance(scale_v, IrNode)
                scale_var_node = self._quant_var_scale_map[original_var_name]

        if len(op_node.output_arg_names()) != 1:
            raise ValueError(
                "Only support one output, but op %s has"
                " more than one output." % (op_node.name())
            )

        output_var_node = graph._find_node_by_name(
            op_node.outputs, op_node.output_arg_names()[0]
        )
        weight_scale_node = graph.create_persistable_node(
            name=unique_name.generate('channel_scale'),
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[channel_scale.shape[0]],
            var_dtype=output_var_node.dtype(),
        )

        if output_var_node.dtype() == core.VarDesc.VarType.FP64:
            data_type = 'float64'
        elif output_var_node.dtype() == core.VarDesc.VarType.FP32:
            data_type = 'float32'
        else:
            data_type = "float16"
        _init_var_node(
            weight_scale_node,
            channel_scale.astype(data_type),
            self._scope,
            self._place,
        )
        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(output_var_node.name()),
            var_type=output_var_node.type(),
            shape=output_var_node.shape(),
            var_dtype=output_var_node.dtype(),
        )
        x_num_col_dims = 1
        if op_node.name() in ['matmul', 'matmul_v2', 'mul']:
            x_num_col_dims = len(op_node.outputs[0].shape()) - 1
        if op_node.op().has_attr("x_num_col_dims"):
            x_num_col_dims = op_node.op().attr("x_num_col_dims")
        dequant_op_node = graph.create_op_node(
            op_type='fake_channel_wise_dequantize_max_abs',
            attrs={
                'quant_bits': [self._weight_bits, self._activation_bits],
                'quant_axis': quant_axis,
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
                'x_num_col_dims': x_num_col_dims,
            },
            inputs={
                'X': output_var_node,
                'Scales': [weight_scale_node, scale_var_node],
            },
            outputs={'Out': dequant_var_node},
        )
        graph.link_to(output_var_node, dequant_op_node)
        graph.link_to(scale_var_node, dequant_op_node)
        graph.link_to(weight_scale_node, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
        self._op_output_rename_map[output_var_node.node] = dequant_var_node
        return dequant_var_node

    def _insert_post_dequant_op(self, graph, op_node):
        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
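        # With the default 8-bit settings, param_range and act_range are both
        # 127. The loop below folds param_range / weight_scale for every
        # persistable input (and act_range for the activation) into a single
        # max_range, which the fake_dequantize_max_abs op uses, together with
        # the activation scale, to recover the float output.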
        max_range = 1
        param_range = (1 << (self._weight_bits - 1)) - 1
        act_range = (1 << (self._activation_bits - 1)) - 1
        for var_node in op_node.inputs:
            name = var_node.name()
            if name not in op_node.input_arg_names():
                continue
            if var_node.node in self._op_input_rename_map:
                old_in = var_node
                new_in = self._op_input_rename_map[var_node.node]
                new_in.clear_outputs()
                graph.update_input_link(old_in, new_in, op_node)
            original_var_name = self._original_var_name(name)
            scale_v = self._quant_var_scale_map[original_var_name]
            if original_var_name in persistable_vars:
                assert self._is_float(
                    scale_v
                ), 'The scale of parameter %s is not a float.' % (
                    original_var_name
                )
                scale_v = 1e-8 if scale_v == 0.0 else scale_v
                max_range *= param_range / scale_v
            else:
                max_range *= act_range
                assert isinstance(scale_v, IrNode)
                scale_var_node = self._quant_var_scale_map[original_var_name]

        if len(op_node.output_arg_names()) != 1:
            raise ValueError(
                "Only support one output, but op %s has"
                " more than one output." % (op_node.name())
            )

        output_var_node = graph._find_node_by_name(
            op_node.outputs, op_node.output_arg_names()[0]
        )
        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(output_var_node.name()),
            var_type=output_var_node.type(),
            shape=output_var_node.shape(),
            var_dtype=output_var_node.dtype(),
        )
        dequant_op_node = graph.create_op_node(
            op_type='fake_dequantize_max_abs',
            attrs={
                'max_range': float(max_range),
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
            },
            inputs={'X': output_var_node, 'Scale': scale_var_node},
            outputs={'Out': dequant_var_node},
        )
        graph.link_to(output_var_node, dequant_op_node)
        graph.link_to(scale_var_node, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
        self._op_output_rename_map[output_var_node.node] = dequant_var_node
        return dequant_var_node

    def _load_var(self, name):
        return np.array(self._scope.find_var(name).get_tensor())

    def _restore_var(self, name, array):
        tensor = self._scope.find_var(name).get_tensor()
        tensor.set(array, self._place)

    def _remove_unused_var_nodes(self, graph):
        all_used_vars = set()
        ops = graph.all_op_nodes()
        for op_node in ops:
            for input_node in op_node.inputs:
                all_used_vars.add(input_node)
            for output_node in op_node.outputs:
                all_used_vars.add(output_node)

        all_used_vars = {n.node for n in all_used_vars}
        all_unused_vars = {
            n
            for n in filter(
                lambda node: node.node not in all_used_vars,
                graph.all_var_nodes(),
            )
        }
        graph.safe_remove_nodes(all_unused_vars)

    def _original_var_name(self, var_name):
        """
        Return the original variable name.
        """
        if var_name.endswith('.quantized.dequantized'):
            return var_name[: -len('.quantized.dequantized')]
        if var_name.endswith('.quantized'):
            return var_name[: -len('.quantized')]
        if var_name.endswith('.dequantized'):
            return var_name[: -len('.dequantized')]
        if var_name.endswith('@scale'):
            return var_name[: -len('@scale')]
        else:
            return var_name

    def _dequantized_var_name(self, var_name):
        """
        Return dequantized variable name for the input `var_name`.
        """
        return "%s.dequantized" % (var_name)

    def _is_float(self, v):
        return isinstance(v, (float, np.float16, np.float32, np.float64))


class ConvertToInt8Pass:
    def __init__(self, scope, place, quantizable_op_type=None):
        """
        Convert the weights into int8_t type.

        Args:
            scope(static.Scope): scope is used to get the weight tensor values.
            place(static.CPUPlace|static.CUDAPlace|str): place is used to restore the
                8-bit weight tensors. If it's a string, it can be ``cpu`` or ``gpu:x``,
                where ``x`` is the index of the GPUs.
            quantizable_op_type(list[str]): This input param will be removed later. The pass
                will process all quantized ops, so it is not necessary to set this param.
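
        Examples:
        .. code-block:: python
            # A minimal usage sketch, not part of the original docs: it
            # assumes `graph` is an IrGraph already processed by
            # QuantizationFreezePass, and that this pass is importable from
            # paddle.static.quantization like the other passes in this file.
            import paddle
            from paddle.static.quantization import ConvertToInt8Pass

            place = paddle.CPUPlace()
            scope = paddle.static.global_scope()
            convert_pass = ConvertToInt8Pass(scope, place)
            convert_pass.apply(graph)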
        """
        assert scope is not None, 'The scope cannot be set None.'
        assert place is not None, 'The place cannot be set None.'
        self._scope = scope
        self._place = _get_paddle_place(place)

    def apply(self, graph):
        """
        Convert weights' type of the graph. After that, the data type of the
        graph weights is int8_t.

        Args:
            graph(IrGraph): the applied graph.
        Returns:
            IrGraph: the converted graph.
        """
        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
        ops = graph.all_op_nodes()
        input_map = {}
        for op_node in ops:
            if (
                op_node.op().has_attr("quantization_type")
                and op_node.op().attr("quantization_type") == "qat_with_weight"
            ):
                for var_node in op_node.inputs:
                    name = var_node.name()
                    if name in persistable_vars:
                        if name not in input_map:
                            int8_var_node = self._convert_to_int8(
                                graph, var_node
                            )
                            input_map[name] = int8_var_node
                        graph.update_input_link(
                            var_node, input_map[name], op_node
                        )

        # remove the unused var node in the graph
        self._remove_unused_var_nodes(graph)
        graph.resolve_hazard()
        return graph

    def _convert_to_int8(self, graph, var_node):
        int8_var_node_name = var_node.name() + ".int8"
        int8_var_node = graph.create_persistable_node(
            name=int8_var_node_name,
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=core.VarDesc.VarType.INT8,
        )
        array = self._load_var(var_node.name())
        self._scope.var(int8_var_node_name)
        self._store_var(int8_var_node_name, array, np.int8)
        return int8_var_node

    def _load_var(self, name):
        return np.array(self._scope.find_var(name).get_tensor())

    def _store_var(self, name, array, dtype):
        tensor = self._scope.find_var(name).get_tensor()
        tensor.set(array.astype(dtype), self._place)

    def _remove_unused_var_nodes(self, graph):
        all_used_vars = set()
        ops = graph.all_op_nodes()
        for op_node in ops:
            for input_node in op_node.inputs:
                all_used_vars.add(input_node)
            for output_node in op_node.outputs:
                all_used_vars.add(output_node)

        all_used_vars = {n.node for n in all_used_vars}
        all_unused_vars = {
            n
            for n in filter(
                lambda node: node.node not in all_used_vars,
                graph.all_var_nodes(),
            )
        }
        graph.safe_remove_nodes(all_unused_vars)


class TransformForMobilePass:
    def __init__(self):
        """
        This pass is used to convert the frozen graph for paddle-mobile execution.
        """
        self._fake_quant_op_names = _fake_quant_op_list
        self._fake_dequant_op_names = _fake_dequant_op_list

    def apply(self, graph):
        """
        Because paddle-mobile uses `quantize` and `dequantize` as the names of
        its quantize and dequantize operators, the `apply` function just
        realizes this renaming logic.

        Args:
            graph(IrGraph): the graph will be transformed.
        Returns:
            IrGraph: the transformed graph.
        """
        ops = graph.all_op_nodes()
        for op_node in ops:
            name = op_node.name()
            if name in self._fake_quant_op_names:
                op_node.set_type('quantize')
                quant_node = graph.create_op_node_from_desc(op_node.op())
                for input_node in op_node.inputs:
                    graph.link_to(input_node, quant_node)
                for output_node in op_node.outputs:
                    graph.link_to(quant_node, output_node)
                graph.safe_remove_nodes(op_node)
            if name in self._fake_dequant_op_names:
                op_node.set_type('dequantize')
                dequant_node = graph.create_op_node_from_desc(op_node.op())
                for input_node in op_node.inputs:
                    graph.link_to(input_node, dequant_node)
                for output_node in op_node.outputs:
                    graph.link_to(dequant_node, output_node)
                graph.safe_remove_nodes(op_node)
        graph.resolve_hazard()
        return graph


class OutScaleForTrainingPass:
    def __init__(
        self,
        scope=None,
        place=None,
        moving_rate=0.9,
        is_test=None,
        scale_dict=None,
    ):
        """
        This pass is used for calculating output scales of some operators.
        These output scales may be used by tensorRT or some other inference engines.

        Args:
            scope(static.Scope): The scope is used to initialize these new parameters.
            place(static.CPUPlace|static.CUDAPlace|str): The place is used to initialize new parameters.
                If it's a string, it can be ``cpu`` or ``gpu:x``, where ``x`` is the
                index of the GPUs.
            moving_rate(float): The decay coefficient of moving average. The default value is 0.9.
            is_test(bool, optional): Whether the graph is used for inference. If None, it
                will be inferred from the graph itself via ``graph.is_test()``. Default is None.
            scale_dict(dict, optional): Pre-computed output scales keyed by variable name,
                used to initialize the scale variables. Default is None.
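
        Examples:
        .. code-block:: python
            # A minimal usage sketch, not part of the original docs: it
            # assumes `graph` is a trainable IrGraph (for_test=False) and
            # that this pass is importable from paddle.static.quantization.
            import paddle
            from paddle.static.quantization import OutScaleForTrainingPass

            place = paddle.CPUPlace()
            scope = paddle.static.global_scope()
            out_scale_pass = OutScaleForTrainingPass(scope, place)
            out_scale_pass.apply(graph)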
        """
        self._scope = scope
        self._place = _get_paddle_place(place)
        self._moving_rate = moving_rate
        self._is_test = is_test
        self._teller_set = list(SUPPORT_QUANTIZATION_OP_DICT.keys())
        self._scale_dict = scale_dict

    def apply(self, graph):
        """
        Insert the `moving_average_abs_max_scale` op in order to calculate output scales
        of operators in the teller_set.

        Args:
            graph(IrGraph): the target graph.
        """
        assert isinstance(
            graph, IrGraph
        ), 'graph must be the instance of IrGraph.'
        if self._is_test is None:
            self._is_test = graph.is_test()
        target_ops = []
        for op in graph.all_op_nodes():
            if op.name() in self._teller_set:
                target_ops.append(op)
        with tqdm(
            total=len(target_ops),
            bar_format='Adding OutScale op:|{bar}| {n_fmt}/{total_fmt}',
            ncols=80,
        ) as t:
            for op in target_ops:
                for output_var_name in utils._get_op_output_var_names(op):
                    in_node = graph._find_node_by_name(
                        op.outputs, output_var_name
                    )
                    if in_node.dtype() not in [
                        core.VarDesc.VarType.FP64,
                        core.VarDesc.VarType.FP32,
                        core.VarDesc.VarType.FP16,
                    ]:
                        continue

                    if in_node.dtype() == core.VarDesc.VarType.FP64:
                        data_type = 'float64'
                    elif in_node.dtype() == core.VarDesc.VarType.FP32:
                        data_type = 'float32'
                    else:
                        data_type = "float16"

                    try:
                        graph._find_node_by_name(
                            graph.all_var_nodes(),
                            self._scale_name(in_node.name()),
                        )
                        continue
                    except:
                        scale_node = graph.create_persistable_node(
                            name=self._scale_name(in_node.name()),
                            var_type=core.VarDesc.VarType.LOD_TENSOR,
                            shape=[1],
                            var_dtype=in_node.dtype(),
                        )
                        if self._scale_dict is not None:
                            try:
                                scale_value = np.array(
                                    [self._scale_dict[in_node.name()]]
                                )
                            except:
                                scale_value = np.ones([1], dtype=data_type)
                        else:
                            scale_value = np.ones([1], dtype=data_type)
                    _init_var_node(
                        scale_node, scale_value, self._scope, self._place
                    )

                    ins = {'X': in_node}
                    outs = {'OutScale': scale_node}
                    if not self._is_test:
                        state_in_node = graph.create_persistable_node(
                            name=unique_name.generate('scale_state@'),
                            var_type=core.VarDesc.VarType.LOD_TENSOR,
                            var_dtype=in_node.dtype(),
1679 1680 1681 1682 1683 1684 1685 1686
                            shape=[1],
                        )
                        _init_var_node(
                            state_in_node,
                            np.ones([1], dtype=data_type),
                            self._scope,
                            self._place,
                        )
                        accum_in_node = graph.create_persistable_node(
                            name=unique_name.generate('scale_accum@'),
                            var_type=core.VarDesc.VarType.LOD_TENSOR,
                            var_dtype=in_node.dtype(),
                            shape=[1],
                        )
                        _init_var_node(
                            accum_in_node,
                            np.ones([1], dtype=data_type),
                            self._scope,
                            self._place,
                        )
                        state_out_node = graph.create_var_node_from_desc(
                            state_in_node.var()
                        )
                        accum_out_node = graph.create_var_node_from_desc(
                            accum_in_node.var()
                        )

                        ins['InState'] = state_in_node
                        ins['InAccum'] = accum_in_node
                        outs['OutState'] = state_out_node
                        outs['OutAccum'] = accum_out_node

                    attrs = {
                        'moving_rate': self._moving_rate,
                        'is_test': self._is_test,
                        'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
                    }
                    scale_op_node = graph.create_op_node(
                        op_type='moving_average_abs_max_scale',
                        attrs=attrs,
                        inputs=ins,
                        outputs=outs,
                    )

                    next_op_node = None
                    if len(in_node.outputs) > 0:
                        next_op_node = in_node.outputs[0]

                    graph.link_to(in_node, scale_op_node)
                    graph.link_to(scale_op_node, scale_node)
                    if next_op_node:
                        graph.link_to(scale_node, next_op_node)

                    if not self._is_test:
                        graph.link_to(state_in_node, scale_op_node)
                        graph.link_to(accum_in_node, scale_op_node)
                        graph.link_to(scale_op_node, state_out_node)
                        graph.link_to(scale_op_node, accum_out_node)
                t.update()
        return graph

    def _scale_name(self, var_name):
        """
        Return the scale name for the var named `var_name`.
        """
        return "%s@scale" % (var_name)


class OutScaleForInferencePass:
    def __init__(self, scope=None):
        """
        This pass is used for setting output scales of some operators.
        These output scales may be used by tensorRT or some other inference engines.

        Args:
            scope(static.Scope): The scope is used to initialize these new parameters.
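
        Examples:
        .. code-block:: python
            # A minimal usage sketch, not part of the original docs: it
            # assumes `graph` is a test-mode IrGraph whose output scales were
            # collected by OutScaleForTrainingPass, and that this pass is
            # importable from paddle.static.quantization.
            import paddle
            from paddle.static.quantization import OutScaleForInferencePass

            scope = paddle.static.global_scope()
            out_scale_pass = OutScaleForInferencePass(scope)
            out_scale_pass.apply(graph)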
        """
        self._scope = scope
        self._teller_set = list(SUPPORT_QUANTIZATION_OP_DICT.keys())

    def apply(self, graph):
        """
        Get output scales from the scope and set these scales in op_descs
        of operators in the teller_set.

        Args:
            graph(IrGraph): the target graph.
        """
        assert isinstance(
            graph, IrGraph
        ), 'graph must be the instance of IrGraph.'
        op_nodes = graph.all_op_nodes()
        for op_node in op_nodes:
            if op_node.name() in self._teller_set:
                var_names = utils._get_op_output_var_names(op_node)
                for var_name in var_names:
                    in_node = graph._find_node_by_name(
                        op_node.outputs, var_name
                    )
                    if (in_node.node.var() is None) or (
                        in_node.dtype()
                        not in [
                            core.VarDesc.VarType.FP64,
                            core.VarDesc.VarType.FP32,
                            core.VarDesc.VarType.FP16,
                        ]
                    ):
                        continue

                    scale_name = self._scale_name(var_name)
                    scale_var = self._scope.find_var(scale_name)
                    assert (
                        scale_var is not None
                    ), "Can not find {} variable in the scope".format(
                        scale_name
                    )
                    scale_value = np.array(scale_var.get_tensor())[0]

                    # For compatibility, we save output threshold by two methods.
                    op_node.op()._set_attr("out_threshold", float(scale_value))

                    argname_index = utils._get_output_name_index(
                        op_node, var_name
                    )
                    assert argname_index is not None, (
                        var_name + " is not the output of the op"
                    )
                    op_node.op()._set_attr(
                        argname_index[0] + str(argname_index[1]) + "_threshold",
                        float(scale_value),
                    )
                    op_node.op()._set_attr("with_quant_attr", True)
        graph.resolve_hazard()
        return graph

    def _scale_name(self, var_name):
        """
        Return the scale name for the var named `var_name`.
        """
        return "%s@scale" % (var_name)


class AddQuantDequantPass:
    """
    Quantize the ops that do not have weights, and add quant_dequant op for the
    quantized ops's inputs.
    """

    # To be compatible with PaddleSlim, do not remove _activation_type for now
    _activation_type = ["relu", "relu6", "leaky_relu", "tanh", "swish"]

    def __init__(
        self,
        scope=None,
        place=None,
        moving_rate=0.9,
        quant_bits=8,
        skip_pattern=["skip_quant"],
        quantizable_op_type=["elementwise_add", "pool2d"],
        is_test=None,
        scale_dict=None,
    ):
        """
        Constructor.

        Args:
            scope(static.Scope): The scope is used to initialize these new parameters.
            place(static.CPUPlace|static.CUDAPlace|str): place is used to initialize new
                parameters described above. If ``place`` is a string, it can be ``cpu``
                or ``gpu:x``, where ``x`` is the index of the GPUs.
            moving_rate(float, optional): the param for 'quant_dequant_moving_average_abs_max'
                quantization. Default is 0.9.
            quant_bits(int, optional): quantization bit number for activation. Default is 8.
            skip_pattern(str, optional): The user-defined quantization skip pattern, which
                will be presented in the name scope of an op. When the skip pattern is
                detected in an op's name scope, the corresponding op will not be quantized.
                Default is 'skip_quant'.
            quantizable_op_type(list[str], optional): List the type of ops that will be
                quantized. Default is ["elementwise_add", "pool2d"].
            is_test(bool, optional): Whether the graph is used for inference. If None, it
                will be inferred from the graph itself via ``graph.is_test()``. Default is None.
            scale_dict(dict, optional): Pre-computed activation scales keyed by variable
                name, used to initialize the quant_dequant scale variables. Default is None.
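
        Examples:
        .. code-block:: python
            # A minimal usage sketch, not part of the original docs: it
            # assumes `graph` is an IrGraph built from a program and that
            # this pass is importable from paddle.static.quantization.
            import paddle
            from paddle.static.quantization import AddQuantDequantPass

            place = paddle.CPUPlace()
            scope = paddle.static.global_scope()
            add_quant_dequant_pass = AddQuantDequantPass(scope, place)
            add_quant_dequant_pass.apply(graph)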
        """
        self._scope = scope
        self._place = _get_paddle_place(place)
        self._moving_rate = moving_rate
        self._quant_bits = quant_bits
        self._is_test = is_test
        self._skip_pattern = skip_pattern
        self._scale_dict = scale_dict

        self._quantizable_op_type = quantizable_op_type
        for op_type in self._quantizable_op_type:
            assert op_type in list(SUPPORT_ACT_QUANTIZATION_OP_DICT.keys()), (
                op_type + " is not supported for quantization."
            )
        self._quantizable_grad_op_type = [
            '%s_grad' % (op) for op in self._quantizable_op_type
        ]

        assert self._scope is not None, "scope must not be None."
        assert self._place is not None, "place must not be None."

    def apply(self, graph):
        """
        Add quant_dequant before some ops, such as the 'elementwise_add' and
        'pool2d' op.

        Args:
            graph(IrGraph): the target graph.
        Returns:
            IrGraph: the transformed graph.
        """
        assert isinstance(
            graph, IrGraph
        ), 'graph must be the instance of IrGraph.'
        if self._is_test is None:
            self._is_test = graph.is_test()
        dequantized_vars_map = collections.OrderedDict()

        # Forward stage, insert quant_dequant op
        all_op_nodes = graph.all_op_nodes()
        with tqdm(
            total=len(all_op_nodes),
            bar_format='Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}',
            ncols=80,
        ) as t:
            for op_node in all_op_nodes:
                if op_node.name() in self._quantizable_op_type:
                    is_skip = False
                    if isinstance(self._skip_pattern, list):
                        is_skip = op_node.op().has_attr("op_namescope") and any(
                            pattern in op_node.op().attr("op_namescope")
                            for pattern in self._skip_pattern
                        )
                    elif isinstance(self._skip_pattern, str):
                        is_skip = (
                            op_node.op().has_attr("op_namescope")
                            and op_node.op()
                            .attr("op_namescope")
                            .find(self._skip_pattern)
                            != -1
                        )
                    is_quantized = (
                        op_node.op().has_attr("quantization_type")
                        and op_node.op().attr("quantization_type")
                        == "qat_with_weight"
                    )
                    if (
                        is_skip
                        or is_quantized
                        or (not _is_input_all_not_persistable(graph, op_node))
                    ):
                        continue

                    op_node.op()._set_attr(
                        "quantization_type", "qat_without_weight"
                    )
                    op_node.op()._set_attr("activation_bits", self._quant_bits)
                    op_node.op()._set_attr("with_quant_attr", True)
                    arg_names = utils._get_op_input_var_names(op_node)
                    # If already quanted, skip it.
                    skip_quant = False
                    for arg_name in arg_names:
                        if "quantized.dequantized" in arg_name:
                            skip_quant = True
                            break
                    if skip_quant:
                        continue

                    for arg_name in arg_names:
                        in_node = graph._find_node_by_name(
                            op_node.inputs, arg_name
                        )
                        if arg_name in dequantized_vars_map:
                            quant_var_node = dequantized_vars_map[arg_name]
                        else:
                            (
                                quant_var_node,
                                _,
                            ) = self._insert_quant_dequant_moving_average_abs_max_op(
                                graph, in_node, self._quant_bits
                            )
                            dequantized_vars_map[arg_name] = quant_var_node
                        graph.update_input_link(
                            in_node, quant_var_node, op_node
                        )
                t.update()

        # Backward stage, update input link
        for op_node in all_op_nodes:
            if op_node.name() in self._quantizable_grad_op_type:
                for input_name in op_node.input_arg_names():
                    if input_name in dequantized_vars_map:
                        in_node = graph._find_node_by_name(
                            op_node.inputs, input_name
                        )
                        dequant_var_node = dequantized_vars_map[input_name]
                        graph.update_input_link(
                            in_node, dequant_var_node, op_node
                        )

        graph.resolve_hazard()
        return graph

    def _insert_quant_dequant_moving_average_abs_max_op(
        self, graph, var_node, quant_bits
    ):
        """Insert fake_quantize_dequantize_moving_average_abs_max op."""
        quant_var_node = graph.create_var_node(
            name="{}.quant_dequant".format(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype(),
        )
        scale_name = "{}.quant_dequant@scale".format(var_node.name())
        if var_node.dtype() == core.VarDesc.VarType.FP64:
            data_type = 'float64'
        elif var_node.dtype() == core.VarDesc.VarType.FP32:
            data_type = 'float32'
        else:
            data_type = "float16"
        try:
            if (
                self._scale_dict is not None
                and var_node.name() in self._scale_dict.keys()
            ):
                scale_value = np.array(
                    [self._scale_dict[var_node.name()]], dtype=data_type
                )
            else:
                scale_value = np.array(
                    self._scope.find_var(scale_name).get_tensor(),
                    dtype=data_type,
                )
        except:
            scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type)

        scale_in_node = graph.create_persistable_node(
            name=scale_name,
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[1],
            var_dtype=var_node.dtype(),
        )

        _init_var_node(scale_in_node, scale_value, self._scope, self._place)
        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
        ins = {'X': var_node, 'InScale': scale_in_node}
        outs = {'Out': quant_var_node, 'OutScale': scale_out_node}
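        # In training mode, the op additionally threads InState/InAccum
        # buffers through each step; they hold the running accumulation used
        # to form the moving average of the absolute max.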
        if not self._is_test:
            state_in_node = graph.create_persistable_node(
                name=unique_name.generate('quant_dequant.state'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_node.dtype(),
                shape=[1],
            )
            if var_node.dtype() == core.VarDesc.VarType.FP64:
                data_type = 'float64'
            elif var_node.dtype() == core.VarDesc.VarType.FP32:
                data_type = 'float32'
            else:
                data_type = "float16"
            _init_var_node(
                state_in_node,
                np.ones([1], dtype=data_type),
                self._scope,
                self._place,
            )
            accum_in_node = graph.create_persistable_node(
                name=unique_name.generate('quant_dequant.accum'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_node.dtype(),
                shape=[1],
            )
            _init_var_node(
                accum_in_node,
                np.ones([1], dtype=data_type),
                self._scope,
                self._place,
            )
            state_out_node = graph.create_var_node_from_desc(
                state_in_node.var()
            )
            accum_out_node = graph.create_var_node_from_desc(
                accum_in_node.var()
            )

            ins['InState'] = state_in_node
            ins['InAccum'] = accum_in_node
            outs['OutState'] = state_out_node
            outs['OutAccum'] = accum_out_node

        attrs = {
            'bit_length': quant_bits,
            'moving_rate': self._moving_rate,
            'is_test': self._is_test,
            'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
        }

        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_dequantize_moving_average_abs_max',
            attrs=attrs,
            inputs=ins,
            outputs=outs,
        )

        graph.link_to(var_node, quant_op_node)
        graph.link_to(scale_in_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_out_node)

        if not self._is_test:
            graph.link_to(state_in_node, quant_op_node)
            graph.link_to(accum_in_node, quant_op_node)
            graph.link_to(quant_op_node, state_out_node)
            graph.link_to(quant_op_node, accum_out_node)

        return quant_var_node, scale_out_node


class InsertQuantizeLinear:
    """
    Insert quantize_linear and dequantize_linear op before ops.

    Args:
        place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to restore the weight tensors.
            If it's string, It can be ``cpu``, and ``gpu:x``, where ``x`` is the index of the GPUs.
        scope(paddle.Scope): scope is used to get the weight tensor values.
        quant_bits(int, optional): quantization bit number for weight. Default is 8.
        quant_axis(int, optional): quantization dimension of channels. When it is greater than or
            equal to 0, it will quantization with per channel, else quantization with per layer.
            Default is -1.
        channel_wise(bool, optional): Whether quantization with per channel or not. Default is False.
        moving_rate(float, optional): the decay coefficient of the 'moving average' method. Default is 0.9.
        is_test(bool, optional): Whether the pass runs in inference (True) or training (False) mode. Default is True.
        scale_dict(dict, optional): calibration ranges of output tensors, keyed by variable name. Default is None.
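
    Examples:
    .. code-block:: python
        # An illustrative sketch, not part of the original docs: it wires a
        # quantize_linear/dequantize_linear pair around an existing variable
        # node. `graph`, `var_node`, `place` and `scope` are assumed to exist.
        inserter = InsertQuantizeLinear(place, scope, quant_bits=8)
        quant_var, scale_var = inserter.insert_quant_op(graph, var_node)
        dequant_var = inserter.insert_dequant_op(graph, quant_var, scale_var)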
    """

    def __init__(
        self,
        place,
        scope,
        quant_bits=8,
        quant_axis=-1,
        channel_wise=False,
        moving_rate=0.9,
        is_test=True,
        scale_dict=None,
    ):
        self._place = place
        self._scope = scope
        self.quant_bits = quant_bits
        self.quant_axis = quant_axis
        self.channel_wise = channel_wise
        self._is_test = is_test
        self._moving_rate = moving_rate
        self._scale_dict = scale_dict

    def insert_quant_op(
        self, graph, var_node, var_name=None, scale_var_node=None
    ):
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())
        var_name = var_node.name() if not var_name else var_name
        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(var_name),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype(),
        )
        if not scale_var_node:
            if var_node.dtype() == core.VarDesc.VarType.FP64:
                data_type = 'float64'
            elif var_node.dtype() == core.VarDesc.VarType.FP32:
                data_type = 'float32'
            else:
                data_type = "float16"
            scale_name = self._quantized_scale_name(var_name)
            if self.channel_wise:
                scale_var_shape = var_node.shape()[self.quant_axis]
                scale_var_type = core.VarDesc.VarType.LOD_TENSOR
                init_scale_value = (
                    np.ones(scale_var_shape, dtype=data_type)
                    * _SCALE_DEFAULT_VALUE
                )
            else:
                scale_var_shape = 1
                scale_var_type = var_node.type()
                init_scale_value = np.array(
                    [_SCALE_DEFAULT_VALUE], dtype=data_type
                )

            if (
                self._scale_dict is not None
                and var_node.name() in self._scale_dict.keys()
            ):
                init_scale_value = np.array(
                    [self._scale_dict[var_node.name()]], dtype=data_type
                )
            scale_var_node = graph.create_persistable_node(
                name=scale_name,
                var_type=scale_var_type,
                shape=[scale_var_shape],
                var_dtype=var_node.dtype(),
            )
            _init_var_node(
                scale_var_node, init_scale_value, self._scope, self._place
            )

        zero_point_node = None
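        # quantize_linear is used symmetrically here, so the zero point is a
        # constant all-zeros tensor shaped like the scale.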
        if zero_point_node is None:
            zero_point_node = graph.create_persistable_node(
                name=self._zero_point_name(quant_var_node.name()),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                shape=scale_var_node.shape(),
                var_dtype=core.VarDesc.VarType.INT32,
            )
            _init_var_node(
                zero_point_node,
                np.zeros(scale_var_node.shape(), dtype="int32"),
                self._scope,
                self._place,
            )

        inputs = {"X": var_node, "Scale": scale_var_node}
        if zero_point_node is not None:
            inputs["ZeroPoint"] = zero_point_node

        attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits}
        attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
        outputs = {"Y": quant_var_node}
        if not self._is_test:
            scale_out_node = graph.create_var_node_from_desc(
                scale_var_node.var()
            )
            state_in_node = graph.create_persistable_node(
                name=unique_name.generate('state'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_node.dtype(),
                shape=[1],
            )
            if var_node.dtype() == core.VarDesc.VarType.FP64:
                data_type = 'float64'
            elif var_node.dtype() == core.VarDesc.VarType.FP32:
                data_type = 'float32'
            else:
                data_type = "float16"
            _init_var_node(
                state_in_node,
                np.ones([1], dtype=data_type),
                self._scope,
                self._place,
            )
            accum_in_node = graph.create_persistable_node(
                name=unique_name.generate('accum'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_node.dtype(),
                shape=[1],
            )
            _init_var_node(
                accum_in_node,
                np.ones([1], dtype=data_type),
                self._scope,
                self._place,
            )
            state_out_node = graph.create_var_node_from_desc(
                state_in_node.var()
            )
            accum_out_node = graph.create_var_node_from_desc(
                accum_in_node.var()
            )

            outputs["OutScale"] = scale_out_node
            inputs['InState'] = state_in_node
            inputs['InAccum'] = accum_in_node
            outputs['OutState'] = state_out_node
            outputs['OutAccum'] = accum_out_node
            attrs["is_test"] = self._is_test
            attrs['moving_rate'] = self._moving_rate

        quant_op_node = graph.create_op_node(
            op_type="quantize_linear",
            attrs=attrs,
            inputs=inputs,
            outputs=outputs,
        )

        graph.link_to(var_node, quant_op_node)
        graph.link_to(scale_var_node, quant_op_node)
        if zero_point_node is not None:
            graph.link_to(zero_point_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        if not self._is_test:
            graph.link_to(state_in_node, quant_op_node)
            graph.link_to(accum_in_node, quant_op_node)
            graph.link_to(quant_op_node, state_out_node)
            graph.link_to(quant_op_node, accum_out_node)
            graph.link_to(quant_op_node, scale_out_node)
        return quant_var_node, scale_var_node

    def insert_dequant_op(self, graph, var_node, scale_var_node):
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype(),
        )

        zero_point_node = None
        if zero_point_node is None:
            zero_point_node = graph.create_persistable_node(
                name=self._zero_point_name(dequant_var_node.name()),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                shape=scale_var_node.shape(),
                var_dtype=core.VarDesc.VarType.INT32,
            )
            _init_var_node(
                zero_point_node,
                np.zeros(scale_var_node.shape(), dtype="int32"),
                self._scope,
                self._place,
            )

        inputs = {"X": var_node, "Scale": scale_var_node}
        if zero_point_node is not None:
            inputs["ZeroPoint"] = zero_point_node

        attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits}
        attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward

        quant_op_node = graph.create_op_node(
            op_type="dequantize_linear",
            attrs=attrs,
            inputs=inputs,
            outputs={"Y": dequant_var_node},
        )

        graph.link_to(var_node, quant_op_node)
        graph.link_to(scale_var_node, quant_op_node)
        if zero_point_node is not None:
            graph.link_to(zero_point_node, quant_op_node)
        graph.link_to(quant_op_node, dequant_var_node)
        return dequant_var_node

    def _quantized_var_name(self, var_name):
        """
        Return quantized variable name for the input `var_name`.
        """
        return "%s.quantized" % (var_name)

    def _dequantized_var_name(self, var_name):
        """
        Return dequantized variable name for the input `var_name`.
        """
        return "%s.dequantized" % (var_name)

    def _quantized_scale_name(self, var_name):
        """
        Return the scale name of quantized variable for the input `var_name`.
        """
        return "%s@scale" % (var_name)

    def _zero_point_name(self, var_name):
        """
        Return the scale name for the var named `var_name`.
        """
        return "%s@zero_point" % (var_name)


class QuantizationTransformPassV2(QuantizationTransformPass):
    """
    Quantize the ops that have weights. Add quant and dequant ops for
    the quantized ops's inputs. It is used in the new format of quantization.
    """

    def __init__(
        self,
        scope=None,
        place=None,
        weight_bits=8,
        activation_bits=8,
        activation_quantize_type='abs_max',
        weight_quantize_type='abs_max',
        window_size=10000,
        moving_rate=0.9,
        skip_pattern=['skip_quant'],
        quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'],
        weight_quantize_func=None,
        act_quantize_func=None,
        weight_preprocess_func=None,
        act_preprocess_func=None,
        optimizer_func=None,
        executor=None,
        is_test=None,
    ):
        r"""
        Args:
            scope(paddle.Scope): When activation use 'range_abs_max' as the quantize
                type, this pass will create some new parameters. The scope is used to
                initialize these new parameters.
            place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to initialize new
                parameters described above. If it's string, It can be ``cpu``, and ``gpu:x``,
                where ``x`` is the index of the GPUs.
            weight_bits(int): quantization bit number for weights,
                the bias is not quantized.
            activation_bits(int): quantization bit number for activation.
            activation_quantize_type(str): quantization type for activation,
                now support 'abs_max', 'range_abs_max' and 'moving_average_abs_max'.
                If use 'abs_max' mode, the quantization scale will be calculated
                dynamically each step in both training and testing period. If use
                'range_abs_max', a static quantization scale will be calculated
                during training and used in inference.
            weight_quantize_type(str): quantization type for weights,
                support 'abs_max' and 'channel_wise_abs_max'. The 'range_abs_max'
                usually is not used for weight, since weights are fixed once the
                model is well trained.
            window_size(int): the window size for 'range_abs_max' quantization.
            moving_rate(float): the param for 'moving_average_abs_max' quantization.
            skip_pattern(str or str list): The user-defined quantization skip pattern, which
                will be presented in the name scope of an op. When the skip pattern is
                detected in an op's name scope, the corresponding op will not be quantized.
            quantizable_op_type(list[str]): List the type of ops that will be quantized.
                Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
                QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
            weight_quantize_func(function): Function that defines how to quantize weight.
                Using this can quickly test if user's quantization method works or not.
                In this function, user should both define quantization function and
                dequantization function, that is, the function's input is non-quantized
                weight and function returns dequantized weight. If None, will use
                quantization op defined by 'weight_quantize_type'. Default is None.
            act_quantize_func(function): Function that defines how to quantize activation.
                Using this can quickly test if user's quantization method works or not.
                In this function, user should both define quantization and dequantization
                process, that is, the function's input is non-quantized activation and
                function returns dequantized activation. If None, will use quantization
                op defined by 'activation_quantize_type'. Default is None.
            weight_preprocess_func(function): Function that defines how to preprocess
                weight before quantization. Using this can quickly test if user's preprocess
                method works or not. The function's input is non-quantized weight and
                function returns processed weight to be quantized. If None, the weight will
                be quantized directly. Default is None.
            act_preprocess_func(function): Function that defines how to preprocess
                activation before quantization. Using this can quickly test if user's
                preprocess method works or not. The function's input is non-quantized
                activation and function returns processed activation to be quantized.
                If None, the activation will be quantized directly. Default is None.
            optimizer_func(function): Function that returns an optimizer. When 'is_test' is
                False and user want to use self-defined quantization function and
                preprocess function, this function must be set. Default is None.
            executor(paddle.Executor): If user want to use self-defined quantization
                function and preprocess function, executor must be set for initialization.
                Default is None.

        Examples:
        .. code-block:: python
            # The original graph will be rewritten.
            import paddle
            from paddle.static.quantization \
                import QuantizationTransformPassV2
            from paddle.fluid.framework import IrGraph
            from paddle.framework import core

            graph = IrGraph(core.Graph(paddle.static.Program().desc), for_test=False)
            place = paddle.CPUPlace()
            scope = paddle.static.global_scope()
            transform_pass = QuantizationTransformPassV2(scope, place)
            transform_pass.apply(graph)
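
            # A minimal sketch (not part of the pass) of a user-defined
            # quantize-dequantize function that could be passed as
            # ``weight_quantize_func``; the name ``my_weight_quant_dequant``
            # is hypothetical. Its input is the non-quantized weight and it
            # returns the dequantized weight.
            def my_weight_quant_dequant(weight):
                scale = paddle.max(paddle.abs(weight))
                quanted = paddle.round(weight / scale * 127.0)
                return quanted * scale / 127.0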
        """
        self._scope = scope
        self._place = _get_paddle_place(place)
        self._weight_bits = weight_bits
        self._activation_bits = activation_bits
        self._skip_pattern = skip_pattern
        self._weight_quantize_func = weight_quantize_func
        self._act_quantize_func = act_quantize_func
        self._weight_preprocess_func = weight_preprocess_func
        self._act_preprocess_func = act_preprocess_func
        self._optimizer = optimizer_func
        self._exe = executor
        quant_type = [
            'abs_max',
            'channel_wise_abs_max',
            'range_abs_max',
            'moving_average_abs_max',
        ]
        assert (
            activation_quantize_type != 'channel_wise_abs_max'
        ), "The activation quantization type does not support 'channel_wise_abs_max'."
        if activation_quantize_type not in quant_type:
            raise ValueError(
                "Unknown activation_quantize_type : '%s'. It can only be "
                "'abs_max' or 'range_abs_max' or 'moving_average_abs_max'."
                % (str(activation_quantize_type))
            )
        if weight_quantize_type not in quant_type:
            raise ValueError(
                "Unknown weight_quantize_type: '%s'. It can only be "
                "'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' "
                "or 'moving_average_abs_max'." % (str(weight_quantize_type))
            )

        self._activation_quantize_type = activation_quantize_type
        self._weight_quantize_type = weight_quantize_type
        self._window_size = window_size
        self._moving_rate = moving_rate

        self._quantizable_ops = quantizable_op_type
        for op in self._quantizable_ops:
            assert op in list(SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.keys()), (
                op + " is not supported for quantization."
            )
        self._quantizable_grad_ops = [
            '%s_grad' % (op) for op in self._quantizable_ops
        ]
        self._is_test = is_test
        self._global_step = None

        self.create_var_map = {}
        self.create_op_map = {}

    def _quant_preprocess(self, op_node):
        user_skipped = False
        if isinstance(self._skip_pattern, list):
            user_skipped = op_node.op().has_attr("op_namescope") and any(
                pattern in op_node.op().attr("op_namescope")
                for pattern in self._skip_pattern
            )
        elif isinstance(self._skip_pattern, str):
            user_skipped = (
                op_node.op().has_attr("op_namescope")
                and op_node.op().attr("op_namescope").find(self._skip_pattern)
                != -1
            )

        if user_skipped:
            op_node.op()._set_attr("skip_quant", True)
            op_node.op()._set_attr("with_quant_attr", True)

    def _transform_forward(self, graph, op):
        op.op()._set_attr("quantization_type", "qat_with_weight")
        weight_scale_node = None
        inputs = op.inputs
        for var_node in inputs:
            if var_node.name() not in op.input_arg_names():
                continue
            if var_node.name() in self.dequantized_vars:
                dequant_var_node = self.dequantized_vars[var_node.name()]
            else:
                name = var_node.name()
                if name in self.processed_vars:
                    continue
                is_weight = var_node.name() in self.persistable_vars

                # if var node is weight and weight_preprocess_func is not None,
                # will insert weight preprocess func
                # to preprocess weight before quantization
                # if var node is activation and act_preprocess_func is not None,
                # will insert activation preprocess func
                # to preprocess activation before quantization
                if is_weight and self._weight_preprocess_func is not None:
                    var_node = self._insert_func(
                        graph, self._weight_preprocess_func, var_node, op
                    )
                elif not is_weight and self._act_preprocess_func is not None:
                    var_node = self._insert_func(
                        graph, self._act_preprocess_func, var_node, op
                    )

                # if var node is weight and weight_quantize_func is not None,
                # will insert weight quantize func to quantize and dequantize weight
                # if var node is activation and act_quantize_func is not None,
                # will insert act quantize func to quantize and dequantize activation
                if is_weight and self._weight_quantize_func is not None:
                    target_out_node = self._insert_func(
                        graph, self._weight_quantize_func, var_node, op
                    )
                    self.processed_vars.append(name)
                    continue
                elif not is_weight and self._act_quantize_func is not None:
                    target_out_node = self._insert_func(
                        graph, self._act_quantize_func, var_node, op
                    )
                    self.processed_vars.append(name)
                    continue

                quant_bits = (
                    self._weight_bits
                    if var_node.name() in self.persistable_vars
                    else self._activation_bits
                )
                quant_type = (
                    self._weight_quantize_type
                    if is_weight
                    else self._activation_quantize_type
                )
                quant_axis = -1
                channel_wise = False
                if quant_type == 'channel_wise_abs_max':  # Weight quantization
                    channel_wise = True
                    quant_axis = (
                        1
                        if op.name() in utils._channelwise_quant_axis1_ops
                        else 0
                    )
                insert_quant_pass = InsertQuantizeLinear(
                    self._place,
                    self._scope,
                    quant_bits=quant_bits,
                    quant_axis=quant_axis,
                    channel_wise=channel_wise,
                    moving_rate=self._moving_rate,
                    is_test=self._is_test,
                )
                (
                    quant_var_node,
                    scale_var_node,
                ) = insert_quant_pass.insert_quant_op(
                    graph, var_node, var_name=name
                )
                dequant_var_node = insert_quant_pass.insert_dequant_op(
                    graph, quant_var_node, scale_var_node
                )

                self.dequantized_vars[name] = dequant_var_node
                if is_weight:
                    weight_scale_node = scale_var_node
            graph.update_input_link(var_node, dequant_var_node, op)
        return weight_scale_node

    def _transform_backward(self, graph, op):
        for var_node in op.inputs:
            if var_node.name() not in op.input_arg_names():
                continue
            if var_node.name() in self.dequantized_vars:
                dequant_var_node = self.dequantized_vars[var_node.name()]
                graph.update_input_link(var_node, dequant_var_node, op)

    def _has_weight(self, op):
        has_weight = False
        for var_node in op.inputs:
            if var_node.name() not in op.input_arg_names():
                continue
            if var_node.name() in self.persistable_vars:
                has_weight = True
        return has_weight

    def _quant_conv1d(self, graph, op):
        # conv1d in inference is a combination of unsqueeze2 + conv2d
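        # (such a conv2d is recognized by its Filter input: the filter variable
        # is produced by an unsqueeze2 op, so its name contains "unsqueeze2")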
        if ("conv2d" not in op.name()) or (
            "unsqueeze2" not in op.input("Filter")[0]
        ):
            return
        conv_weight_var_name = op.input("Filter")[0]
        # unsqueeze2 and conv2d will share weight scale
        weight_scale_node = None
        # quant unsqueeze2
        for _op in graph.all_op_nodes():
            var_names = utils._get_op_output_var_names(_op)
            if conv_weight_var_name in var_names and self._has_weight(_op):
                weight_scale_node = self._transform_forward(graph, _op)
        # insert qdq before conv2d
        for var_node in op.inputs:
            quant_bits = (
                self._weight_bits
                if var_node.name() == conv_weight_var_name
                else self._activation_bits
            )
            quant_type = (
                self._weight_quantize_type
                if var_node.name() == conv_weight_var_name
                else self._activation_quantize_type
            )
            quant_axis = -1
            channel_wise = False
            if quant_type == 'channel_wise_abs_max':
                channel_wise = True
                quant_axis = (
                    1 if op.name() in utils._channelwise_quant_axis1_ops else 0
                )
            insert_quant_pass = InsertQuantizeLinear(
                self._place,
                self._scope,
                quant_bits=quant_bits,
                quant_axis=quant_axis,
                channel_wise=channel_wise,
                moving_rate=self._moving_rate,
                is_test=self._is_test,
            )
            scale_var_node = (
                weight_scale_node
                if var_node.name() == conv_weight_var_name
                else None
            )
            (
                quant_var_node,
                scale_var_node,
            ) = insert_quant_pass.insert_quant_op(
                graph,
                var_node,
                var_name=var_node.name(),
                scale_var_node=scale_var_node,
            )
            dequant_var_node = insert_quant_pass.insert_dequant_op(
                graph, quant_var_node, scale_var_node
            )
            graph.update_input_link(var_node, dequant_var_node, op)

    def apply(self, graph):
        """
        Quantize the graph for training process. According to weight and
        activation quantization type, the graph will be added some fake
        quantize operators and fake dequantize operators.

        Args:
            graph(IrGraph): the applied graph.
        Returns:
            None
        """
        assert isinstance(
            graph, IrGraph
        ), 'graph must be the instance of IrGraph.'
        if self._is_test is None:
            self._is_test = graph.is_test()
        # marked the variable which has been dequantized.
        self.dequantized_vars = collections.OrderedDict()
        self.persistable_vars = []
        self.processed_vars = []

        self.persistable_vars = [
            p.name() for p in graph.all_persistable_nodes()
        ]

        ops = graph.all_op_nodes()
        # Do the preprocess of quantization, such as skipping some ops
        # for not being quantized.
        for op in ops:
            if (
                op.name() in self._quantizable_ops
                or op.name() in self._quantizable_grad_ops
            ):
                self._quant_preprocess(op)
        # Insert mapping table to solve the problem in saving inference model.
        graph.out_node_mapping_table = dict()
        # The process of _transform_forward and _transform_backward is needed in two for loops.
        # The loop for transforming the forward graph:
        with tqdm(
            total=len(ops),
            bar_format='Adding quant op with weight:|{bar}| {n_fmt}/{total_fmt}',
            ncols=80,
        ) as t:
            for op in ops:
                if op.name() in self._quantizable_ops:
                    if not self._is_skip_quant(graph, op) and self._has_weight(
                        op
                    ):
                        self._transform_forward(graph, op)
                    else:  # op is skipped or has no weight input
                        # support conv1d quantization
                        self._quant_conv1d(graph, op)
                t.update()
        # The loop for renaming the inputs of backward op.
        for op in ops:
            if op.name() in self._quantizable_grad_ops and self._has_weight(op):
                self._transform_backward(graph, op)
        return graph


class AddQuantDequantPassV2:
    """
    Quantize the ops that do not have weights, and add quant_linear and dequant_linear
    ops for the quantized ops' inputs. It is used in the new format of quantization.
    """

    # To be compatible with PaddleSlim, not remove _activation_type for now
    _activation_type = ["relu", "relu6", "leaky_relu", "tanh", "swish"]

    def __init__(
        self,
        scope=None,
        place=None,
        moving_rate=0.9,
        quant_bits=8,
        skip_pattern=["skip_quant"],
        quantizable_op_type=["elementwise_add", "pool2d"],
        is_test=None,
        scale_dict=None,
    ):
        """
        Args:
            scope(paddle.Scope): The scope is used to initialize these new parameters.
            place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to initialize new
                parameters described above. If ``place`` is a string, it can be ``cpu``
                or ``gpu:x``, where ``x`` is the index of the GPUs.
            moving_rate(float, optional): the param for 'quant_dequant_moving_average_abs_max'
                quantization. Default is 0.9.
            quant_bits(int, optional): quantization bit number for activation. Default is 8.
            skip_pattern(str|list[str], optional): The user-defined quantization skip pattern,
                which is matched against the name scope of an op. When the skip pattern is
                detected in an op's name scope, the corresponding op will not be quantized.
                Default is ['skip_quant'].
            quantizable_op_type(list[str], optional): List the type of ops that will be
                quantized. Default is ["elementwise_add", "pool2d"].
            is_test(bool, optional): Whether the graph is used for inference. If None,
                it will be inferred from the graph. Default is None.
            scale_dict(dict, optional): calibration ranges of output tensors. Default is None.

        Examples:
        .. code-block:: python
            # The original graph will be rewritten.
            import paddle
            from paddle.static.quantization \
                import AddQuantDequantPassV2
            from paddle.fluid.framework import IrGraph
            from paddle.framework import core

            graph = IrGraph(core.Graph(paddle.static.Program().desc), for_test=False)
            place = paddle.CPUPlace()
            scope = paddle.static.global_scope()
            add_quant_dequant_pass = AddQuantDequantPassV2(scope, place)
            add_quant_dequant_pass.apply(graph)
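
            # To keep an op un-quantized, build it inside a name scope that
            # matches ``skip_pattern`` (a sketch; ops built in this scope then
            # carry a matching op_namescope attribute):
            # with paddle.static.name_scope('skip_quant'):
            #     ...  # ops built here will not be quantized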
        """
        self._scope = scope
        self._place = _get_paddle_place(place)
        self._moving_rate = moving_rate
        self._quant_bits = quant_bits
        self._is_test = is_test
        self._skip_pattern = skip_pattern
        self._scale_dict = scale_dict

        self._quantizable_op_type = quantizable_op_type
        for op_type in self._quantizable_op_type:
            assert op_type in list(SUPPORT_ACT_QUANTIZATION_OP_DICT.keys()), (
                op_type + " is not supported for quantization."
            )
        self._quantizable_grad_op_type = [
            '%s_grad' % (op) for op in self._quantizable_op_type
        ]

        assert self._scope is not None, "scope must not be None."
        assert self._place is not None, "place must not be None."
        self.persistable_vars = []

    def apply(self, graph):
        """
        Add quant_dequant before some ops, such as the 'elementwise_add' and
        'pool2d' op.

        Args:
            graph(IrGraph): the target graph.
        Returns:
            None
        """
        assert isinstance(
            graph, IrGraph
        ), 'graph must be the instance of IrGraph.'
        if self._is_test is None:
            self._is_test = graph.is_test()
        dequantized_vars_map = collections.OrderedDict()

        self.persistable_vars = [
            p.name() for p in graph.all_persistable_nodes()
        ]

        # Forward stage, insert quant_dequant op
        all_op_nodes = graph.all_op_nodes()
        with tqdm(
            total=len(all_op_nodes),
            bar_format='Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}',
            ncols=80,
        ) as t:
            for op_node in all_op_nodes:
                if op_node.name() in self._quantizable_op_type:
                    is_skip = False
                    if isinstance(self._skip_pattern, list):
                        is_skip = op_node.op().has_attr("op_namescope") and any(
                            pattern in op_node.op().attr("op_namescope")
                            for pattern in self._skip_pattern
                        )
                    elif isinstance(self._skip_pattern, str):
                        is_skip = (
                            op_node.op().has_attr("op_namescope")
                            and op_node.op()
                            .attr("op_namescope")
                            .find(self._skip_pattern)
                            != -1
                        )
                    is_quantized = (
                        op_node.op().has_attr("quantization_type")
                        and op_node.op().attr("quantization_type")
                        == "qat_with_weight"
                    )
                    if is_skip or is_quantized:
                        continue

                    arg_names = utils._get_op_input_var_names(op_node)
                    # If already quantized, skip it.
                    skip_quant = False
                    for arg_name in arg_names:
                        if "quantized.dequantized" in arg_name:
                            skip_quant = True
                            break
                    if skip_quant:
                        continue

                    for arg_name in arg_names:
                        in_node = graph._find_node_by_name(
                            op_node.inputs, arg_name
                        )
                        if in_node.persistable():
                            continue
                        if arg_name in dequantized_vars_map:
                            dequant_var_node = dequantized_vars_map[arg_name]
                        else:
                            insert_quant_pass = InsertQuantizeLinear(
                                self._place,
                                self._scope,
                                quant_bits=self._quant_bits,
                                quant_axis=-1,
                                channel_wise=False,
                                moving_rate=self._moving_rate,
                                is_test=self._is_test,
                                scale_dict=self._scale_dict,
                            )
                            (
                                quant_var_node,
                                scale_var_node,
                            ) = insert_quant_pass.insert_quant_op(
                                graph, in_node
                            )
                            dequant_var_node = (
                                insert_quant_pass.insert_dequant_op(
                                    graph, quant_var_node, scale_var_node
                                )
                            )
                            dequantized_vars_map[arg_name] = dequant_var_node
                        graph.update_input_link(
                            in_node, dequant_var_node, op_node
                        )
                t.update()

        # Backward stage, update input link
        for op_node in all_op_nodes:
            if op_node.name() in self._quantizable_grad_op_type:
                for input_name in op_node.input_arg_names():
                    if input_name in dequantized_vars_map:
                        in_node = graph._find_node_by_name(
                            op_node.inputs, input_name
                        )
                        dequant_var_node = dequantized_vars_map[input_name]
                        graph.update_input_link(
                            in_node, dequant_var_node, op_node
                        )

        return graph


class ReplaceFakeQuantDequantPass:
    """
    Replace the quant-dequant ops with quantize_linear and dequantize_linear ops.
    """

    def __init__(self, scope, place, quant_bits=8):
        r"""
        Args:
            scope(paddle.Scope): The scope is used to initialize these new parameters.
            place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to initialize new
                parameters described above. If ``place`` is a string, it can be ``cpu``
                or ``gpu:x``, where ``x`` is the index of the GPUs.
            quant_bits(int, optional): quantization bit number for activation. Default is 8.

        Examples:
        .. code-block:: python
            # The original graph will be rewritten.
            import paddle
            from paddle.static.quantization \
                import ReplaceFakeQuantDequantPass
            from paddle.fluid.framework import IrGraph
            from paddle.framework import core

            graph = IrGraph(core.Graph(paddle.static.Program().desc), for_test=False)
            place = paddle.CPUPlace()
            scope = paddle.static.global_scope()
            replace_pass = ReplaceFakeQuantDequantPass(scope, place)
            replace_pass.apply(graph)
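
            # After apply, each fake quant-dequant op on a variable X is
            # rewritten into the new-format pair (a sketch of the pattern):
            #   X -> quantize_linear -> X.quantized -> dequantize_linear -> Out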
        """
        self._place = _get_paddle_place(place)
        self._scope = scope
        self._quant_bits = quant_bits
        assert self._scope is not None, "scope must not be None."
        assert self._place is not None, "place must not be None."

    def apply(self, graph):
        assert isinstance(
            graph, IrGraph
        ), 'graph must be the instance of IrGraph.'
        fake_quant_dequant_ops = []
        remove_fake_quant_ops = []
        observer_out_node_names = []
        for op in graph.all_op_nodes():
            # collect observer node
            if op.name() == "moving_average_abs_max_scale":
                observer_out_node_names.append(op.output("Out")[0])

        for op in graph.all_op_nodes():
            if (
                op.name() in _fake_quant_dequant_op_list
                or op.name() == "moving_average_abs_max_scale"
            ):
                var_name = op.input("X")[0]
                if var_name in observer_out_node_names:
                    remove_fake_quant_ops.append(op)
                else:
                    fake_quant_dequant_ops.append(op)

        for _op in remove_fake_quant_ops:
            x_node = graph._find_node_by_name(_op.inputs, _op.input("X")[0])
            out_node = graph._find_node_by_name(
                _op.outputs, _op.output("Out")[0]
            )
            for next_op_node in out_node.outputs:
                graph.update_input_link(out_node, x_node, next_op_node)

        for _op in fake_quant_dequant_ops:
            self._replace_op(graph, _op)
            graph.safe_remove_nodes(_op)

        graph.resolve_hazard()
        return graph

    def _replace_op(self, graph, op):
        x_node = graph._find_node_by_name(op.inputs, op.input("X")[0])
        out_node = graph._find_node_by_name(op.outputs, op.output("Out")[0])
        scale_node = graph._find_node_by_name(
            op.outputs, op.output("OutScale")[0]
        )

        quant_axis = (
            op.op().attr("quant_axis") if op.op().has_attr("quant_axis") else -1
        )
        bit_length = (
            op.op().attr("bit_length")
            if op.op().has_attr("bit_length")
            else self._quant_bits
        )

        zero_point_node = None
        quanted_node = x_node
        if zero_point_node is None:
            zero_point_node = graph.create_persistable_node(
                name=self._zero_point_name(quanted_node.name()),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                shape=scale_node.shape(),
                var_dtype=core.VarDesc.VarType.INT32,
            )
            _init_var_node(
                zero_point_node,
                np.zeros(scale_node.shape(), dtype="int32"),
                self._scope,
                self._place,
            )

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(x_node.name()),
            var_type=x_node.type(),
            shape=x_node.shape(),
            var_dtype=x_node.dtype(),
        )
        quant_op_node = graph.create_op_node(
            op_type="quantize_linear",
            attrs={"quant_axis": quant_axis, "bit_length": bit_length},
            inputs={
                "X": x_node,
                "Scale": scale_node,
                "ZeroPoint": zero_point_node,
            },
            outputs={"Y": quant_var_node},
        )
        graph.link_to(x_node, quant_op_node)
        graph.link_to(scale_node, quant_op_node)
        if zero_point_node is not None:
            graph.link_to(zero_point_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        dequant_op_node = graph.create_op_node(
            op_type="dequantize_linear",
            attrs={"quant_axis": quant_axis, "bit_length": bit_length},
            inputs={
                "X": quant_var_node,
                "Scale": scale_node,
                "ZeroPoint": zero_point_node,
            },
            outputs={"Y": out_node},
        )
        graph.link_to(quant_var_node, dequant_op_node)
        graph.link_to(scale_node, dequant_op_node)
        if zero_point_node is not None:
            graph.link_to(zero_point_node, dequant_op_node)
        graph.link_to(dequant_op_node, out_node)

    def _quantized_var_name(self, var_name):
        """
        Return quantized variable name for the input `var_name`.
        """
        return "%s.quantized" % (var_name)

    def _zero_point_name(self, var_name):
        """
        Return the zero point name for the var named `var_name`.
        """
        return "%s@zero_point" % (var_name)


class QuantWeightPass:
    """
    Quantize the weights and remove the weights' input quantize_linear nodes. For example:
    `weight -> quant -> dequant -> conv2d` will be frozen into `weight -> dequant -> conv2d`,
    and weight will be scaled offline.

    Args:
        scope(paddle.Scope): scope is used to get the weight tensor values.
        place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to restore the weight tensors.
            If it's a string, it can be ``cpu`` or ``gpu:x``, where ``x`` is the index of the GPUs.
        bias_correction(bool): whether to use bias correction for post-training quantization.
             https://arxiv.org/abs/1810.05723.
        quant_bits(int, optional): quantization bit number for weight. Default is 8.
        save_int_weight(bool, optional): Whether to save the weight as an integer type. Default is True.

    Examples:
        .. code-block:: python
            # The original graph will be rewritten.
            import paddle
            from paddle.static.quantization \
                import QuantWeightPass
            from paddle.fluid.framework import IrGraph
            from paddle.framework import core

            graph = IrGraph(core.Graph(paddle.static.Program().desc), for_test=False)
            place = paddle.CPUPlace()
            scope = paddle.static.global_scope()
            quant_weight_pass = QuantWeightPass(scope, place)
            quant_weight_pass.apply(graph)
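
            # Conceptually (a sketch for the per-tensor 8-bit case), each
            # weight tensor w is replaced offline by
            #   q_w = np.round(w / scale * 127).astype(np.int8)
            # and only the dequantize_linear op stays in the graph.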
    """

    def __init__(
        self,
        scope,
        place,
        bias_correction=False,
        quant_bits=8,
        save_int_weight=True,
    ):
        self._place = _get_paddle_place(place)
        self._scope = scope
        self._bias_correction = bias_correction
        self._quant_bits = quant_bits
        self._save_int_weight = save_int_weight
        assert self._scope is not None, "scope must not be None."
        assert self._place is not None, "place must not be None."
        self._quantized_ops = set()

    def apply(self, graph):
        assert isinstance(
            graph, IrGraph
        ), 'graph must be the instance of IrGraph.'
        fake_quant_ops_for_weight = []

        fake_quant_ops = [
            op for op in graph.all_op_nodes() if op.name() == "quantize_linear"
        ]
        for _op in fake_quant_ops:
            x_node = graph._find_node_by_name(_op.inputs, _op.input("X")[0])
            if x_node.persistable():
                scale_node = graph._find_node_by_name(
                    _op.inputs, _op.input("Scale")[0]
                )
                zero_point_node = graph._find_node_by_name(
                    _op.inputs, _op.input("ZeroPoint")[0]
                )
                out_node = graph._find_node_by_name(
                    _op.outputs, _op.output("Y")[0]
                )

                scale_v = self._load_var(scale_node.name())
                assert scale_v.ndim in [
                    1,
                    2,
                ], "the dim of scale_v should be 1 or 2"
                if scale_v.ndim == 2:
                    scale_v = scale_v[0]
                if scale_v.size == 1 and _op.name() == 'abs_max':
                    scale_v = scale_v[0]
                else:
                    scale_v = scale_v.tolist()
                param_v = self._load_var(x_node.name())
                quant_axis = _op.op().attr("quant_axis")
                bits_length = _op.op().attr("bit_length")
                if x_node.name() not in self._quantized_ops:
                    self._quantized_ops.add(x_node.name())
                    quantized_param_v = utils.quant_tensor(
                        param_v.copy(),
                        scale_v,
                        quant_axis,
                        bits_length,
                        onnx_format=True,
                    )
                    if self._bias_correction is True:
                        quantized_param_v = utils.bias_correction_w(
                            param_v,
                            quantized_param_v,
                            scale_v,
                            quant_axis,
                            weight_bits=bits_length,
                        )
                    if self._save_int_weight:
                        # cast weight type to int
                        if self._quant_bits == 8:
                            save_weight_dtype = np.int8
                        quantized_param_v = quantized_param_v.astype(
                            save_weight_dtype
                        )
                    self._restore_var(x_node.name(), quantized_param_v)

                for next_op_node in out_node.outputs:
                    graph.update_input_link(out_node, x_node, next_op_node)
                graph.safe_remove_nodes(out_node)
        self._remove_unused_var_nodes(graph)

    def _remove_unused_var_nodes(self, graph):
        all_used_vars = set()
        ops = graph.all_op_nodes()
        for op_node in ops:
            for input_node in op_node.inputs:
                all_used_vars.add(input_node)
            for output_node in op_node.outputs:
                all_used_vars.add(output_node)

        all_used_vars = {n.node for n in all_used_vars}
        all_unused_vars = {
            n
            for n in filter(
                lambda node: node.node not in all_used_vars,
                graph.all_var_nodes(),
            )
        }
        graph.safe_remove_nodes(all_unused_vars)

    def _load_var(self, name):
        return np.array(self._scope.find_var(name).get_tensor())

    def _restore_var(self, name, array):
        tensor = self._scope.find_var(name).get_tensor()
        tensor.set(array, self._place)


class AddQuantDequantForInferencePass:
    """
    When exporting the quantized model, this pass traverses the graph to find the
    output of each op and inserts quant/dequant ops after it.
    """

    def __init__(
        self,
        scope,
        place,
        quant_bits=8,
        quantizable_op_type=[],
        calibration_range_dict=None,
        only_observer=True,
    ):
        """
        Args:
            scope(paddle.Scope): The scope is used to initialize these new parameters.
            place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to restore the weight
                tensors. If it's a string, it can be ``cpu`` or ``gpu:x``, where ``x`` is
                the index of the GPUs.
            quant_bits(int, optional): quantization bit number for weight. Default is 8.
            quantizable_op_type(list[str], optional): List the type of ops whose outputs
                will be quantized. If empty, all supported quantization op types are used.
                Default is [].
            calibration_range_dict(dict, optional): calibration ranges of output tensors,
                used to create the scale variables when no saved scale is found.
                Default is None.
            only_observer(bool, optional): Whether the inserted quantize_linear/dequantize_linear
                ops act only as observers. Default is True.
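
        Examples:
        .. code-block:: python
            # A usage sketch modeled on the passes above (illustrative setup;
            # assumes a quantized graph produced by the QAT passes):
            import paddle
            from paddle.static.quantization \
                import AddQuantDequantForInferencePass
            from paddle.fluid.framework import IrGraph
            from paddle.framework import core

            graph = IrGraph(core.Graph(paddle.static.Program().desc), for_test=True)
            place = paddle.CPUPlace()
            scope = paddle.static.global_scope()
            out_scale_pass = AddQuantDequantForInferencePass(scope, place, quant_bits=8)
            out_scale_pass.apply(graph)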
        """
        self._scope = scope
        self._place = place
        self._quant_bits = quant_bits
        self._only_observer = only_observer
        self._teller_set = (
            quantizable_op_type
            if quantizable_op_type
            else list(SUPPORT_QUANTIZATION_OP_DICT.keys())
        )
        self._calibration_range_dict = calibration_range_dict

    def apply(self, graph):
        """
        Args:
            graph(IrGraph): the target graph.
        """
        assert isinstance(
            graph, IrGraph
        ), 'graph must be the instance of IrGraph.'
        dequant_node_map = {}
        dequantized_vars_map = collections.OrderedDict()
        for op_node in graph.all_op_nodes():
            if op_node.name() in self._teller_set:
                var_names = utils._get_op_output_var_names(op_node)
                for var_name in var_names:
                    out_node = graph._find_node_by_name(
                        op_node.outputs, var_name
                    )
                    if out_node.dtype() not in [
                        core.VarDesc.VarType.FP64,
                        core.VarDesc.VarType.FP32,
                        core.VarDesc.VarType.FP16,
                    ]:
                        continue
                    if var_name in dequantized_vars_map:
                        dequant_var_node = dequantized_vars_map[var_name]
                    else:
                        dequant_var_node = self._insert_quant_dequant_op(
                            graph, out_node
                        )
                        dequantized_vars_map[var_name] = dequant_var_node
                    dequant_node_map[var_name] = dequant_var_node

        # remove unuse node and link act quant/dequant linear to op node
        for op_node in graph.all_op_nodes():
            if op_node.name() == 'moving_average_abs_max_scale':
                graph.safe_remove_nodes(op_node)
            else:
                var_names = utils._get_op_input_var_names(op_node)
                for var_name in var_names:
                    if var_name in dequant_node_map:
                        in_node = graph._find_node_by_name(
                            op_node.inputs, var_name
                        )
                        graph.update_input_link(
                            in_node, dequant_node_map[var_name], op_node
                        )

        return graph

    def _scale_name(self, var_name):
        """
        Return the scale name for the var named `var_name`.
        """
        return "%s@scale" % (var_name)

    def _insert_quant_dequant_op(self, graph, var_node):
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())
        var_name = var_node.name()
        quant_axis = -1
        quant_var_node = graph.create_var_node(
            name="{}.quantized".format(var_name),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype(),
        )
        if not self._calibration_range_dict:
            scale_var_node = graph._find_node_by_name(
                graph.all_persistable_nodes(), self._scale_name(var_name)
            )
        elif var_name in self._calibration_range_dict:
            scale_value = self._calibration_range_dict[var_name]
            scale_var_node = graph.create_persistable_node(
                name=self._scale_name(var_name),
                var_type=var_node.type(),
                shape=[1],
                var_dtype=var_node.dtype(),
            )
            data_type = (
                'float64'
                if var_node.dtype() == core.VarDesc.VarType.FP64
                else 'float32'
            )
            _init_var_node(
                scale_var_node,
                np.array(scale_value, dtype=data_type),
                self._scope,
                self._place,
            )
        else:
            return None
        try:
            zero_point_node = graph._find_node_by_name(
                graph.all_persistable_nodes(),
                "{}@zero_point".format(quant_var_node.name()),
            )
        except:
            zero_point_node = graph.create_persistable_node(
                name="{}@zero_point".format(quant_var_node.name()),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                shape=scale_var_node.shape(),
                var_dtype=core.VarDesc.VarType.INT32,
            )
            _init_var_node(
                zero_point_node,
                np.zeros(scale_var_node.shape(), dtype="int32"),
                self._scope,
                self._place,
            )

        inputs = {"X": var_node, "Scale": scale_var_node}
        if zero_point_node is not None:
            inputs["ZeroPoint"] = zero_point_node

        attrs = {
            "quant_axis": quant_axis,
            "bit_length": self._quant_bits,
            "only_observer": self._only_observer,
        }
        attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
        outputs = {"Y": quant_var_node}

        quant_op_node = graph.create_op_node(
            op_type="quantize_linear",
            attrs=attrs,
            inputs=inputs,
            outputs=outputs,
        )

        graph.link_to(var_node, quant_op_node)
        graph.link_to(scale_var_node, quant_op_node)
        if zero_point_node is not None:
            graph.link_to(zero_point_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)

        # add dequant_linear node
        dequant_var_node = graph.create_var_node(
            name="{}.dequantized".format(quant_var_node.name()),
            var_type=quant_var_node.type(),
            shape=quant_var_node.shape(),
            var_dtype=quant_var_node.dtype(),
        )

        inputs = {"X": quant_var_node, "Scale": scale_var_node}
        if zero_point_node is not None:
            inputs["ZeroPoint"] = zero_point_node

        attrs = {
            "quant_axis": -1,
            "bit_length": self._quant_bits,
            "only_observer": self._only_observer,
        }
        attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward

        dequant_op_node = graph.create_op_node(
            op_type="dequantize_linear",
            attrs=attrs,
            inputs=inputs,
            outputs={"Y": dequant_var_node},
        )

        graph.link_to(quant_var_node, dequant_op_node)
        graph.link_to(scale_var_node, dequant_op_node)
        if zero_point_node is not None:
            graph.link_to(zero_point_node, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
        return dequant_var_node