quantization_pass.py 52.2 KB
Newer Older
W
WangZhen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
W
WangZhen 已提交
16
import numpy as np
W
WangZhen 已提交
17
from ..... import compat as cpt
W
WangZhen 已提交
18
from .... import core
19
from ....framework import IrGraph
20
from ....framework import IrNode
W
WangZhen 已提交
21 22
from .... import unique_name

23 24
# Public passes exported by this module.
__all__ = [
    'QuantizationTransformPass',
    'QuantizationFreezePass',
    'ConvertToInt8Pass',
    'TransformForMobilePass',
    'ScaleForTrainingPass',
    'ScaleForInferencePass',
    'AddQuantDequantPass',
]
W
WangZhen 已提交
28

W
WangZhen 已提交
29

30 31 32 33 34 35 36 37 38 39 40
def _init_var_node(var_node, value, scope, place):
    assert isinstance(value,
                      np.ndarray), 'The type of value should be numpy array.'
    assert scope is not None, \
    'The scope cannot be set None.'
    assert place is not None, \
    'The place cannot be set None.'
    tensor = scope.var(var_node.name()).get_tensor()
    tensor.set(value, place)


41
class QuantizationTransformPass(object):
    def __init__(self,
                 scope=None,
                 place=None,
                 weight_bits=8,
                 activation_bits=8,
                 activation_quantize_type='abs_max',
                 weight_quantize_type='abs_max',
                 window_size=10000,
                 moving_rate=0.9):
        """
        Convert and rewrite the IrGraph according to weight and
        activation quantization type.

        Args:
            scope(fluid.Scope): When 'range_abs_max' or 'moving_average_abs_max'
                is used as a quantize type, this pass will create some new
                persistable variables (scales, a step counter, states). The
                scope is used to initialize these new variables.
            place(fluid.CPUPlace|fluid.CUDAPlace): place is used to initialize
                the new variables described above.
            weight_bits (int): quantization bit number for weights,
                the bias is not quantized.
            activation_bits (int): quantization bit number for activation.
            activation_quantize_type (str): quantization type for activation,
                now support 'abs_max', 'range_abs_max' and 'moving_average_abs_max'.
                If use 'abs_max' mode, the quantization scale will be calculated
                dynamically each step in both training and testing period. If use
                'range_abs_max', a static quantization scale will be calculated
                during training and used in inference.
            weight_quantize_type (str): quantization type for weights,
                support 'abs_max' and 'channel_wise_abs_max'. The 'range_abs_max'
                usually is not used for weight, since weights are fixed once the
                model is well trained.
            window_size (int): the window size for 'range_abs_max' quantization.
            moving_rate (float): value of the 'moving_rate' attribute given to
                the fake_quantize_moving_average_abs_max op; only used by
                'moving_average_abs_max' quantization.

        Raises:
            AssertionError: if activation_quantize_type is
                'channel_wise_abs_max'.
            ValueError: if activation_quantize_type or weight_quantize_type is
                not a known quantization type.

        Examples:
        .. code-block:: python
            # The original graph will be rewritten.
            import paddle.fluid as fluid
            from paddle.fluid.contrib.slim.quantization \
                import QuantizationTransformPass
            from paddle.fluid.framework import IrGraph
            from paddle.fluid import core

            graph = IrGraph(core.Graph(program.desc), for_test=False)
            place = fluid.CPUPlace()
            transform_pass = QuantizationTransformPass(fluid.global_scope(),
                place)
            transform_pass.apply(graph)
        """
        self._scope = scope
        self._place = place
        self._weight_bits = weight_bits
        self._activation_bits = activation_bits

        quant_type = [
            'abs_max', 'channel_wise_abs_max', 'range_abs_max',
            'moving_average_abs_max'
        ]
        assert activation_quantize_type != 'channel_wise_abs_max', "The activation quantization type does not support 'channel_wise_abs_max'."
        if activation_quantize_type not in quant_type:
            raise ValueError(
                "Unknown activation_quantize_type : '%s'. It can only be "
                "'abs_max' or 'range_abs_max' or 'moving_average_abs_max'." %
                (str(activation_quantize_type)))
        if weight_quantize_type not in quant_type:
            raise ValueError(
                "Unknown weight_quantize_type: '%s'. It can only be "
                "'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' or 'moving_average_abs_max'."
                % (str(weight_quantize_type)))

        self._activation_quantize_type = activation_quantize_type
        self._weight_quantize_type = weight_quantize_type
        self._window_size = window_size
        self._moving_rate = moving_rate

        self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
        self._conv_ops = ['conv2d', 'depthwise_conv2d']
        self._quantizable_grad_ops = [
            '%s_grad' % (op) for op in self._quantizable_ops
        ]
        self._is_test = None
        self._global_step = None

    def apply(self, graph):
        """
        Quantize the graph for training process. According to weight and
        activation quantization type, the graph will be added some fake
        quantize operators and fake dequantize operators.

        Args:
            graph(IrGraph): the applied graph.

        Returns:
            IrGraph: the same graph object, rewritten in place.

        Raises:
            ValueError: if a quantizable grad op has no dequantized input.
        """
        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'
        self._is_test = graph.is_test()
        # marked the variable which has been dequantized.
        dequantized_vars = collections.OrderedDict()
        # A set: membership is tested for every input of every quantizable op.
        persistable_vars = set(p.name() for p in graph.all_persistable_nodes())

        def _transform_forward(graph, op):
            # Insert quant/dequant ops in front of each input of `op`, reusing
            # the dequantized node when the same variable was already handled.
            for var_node in op.inputs:
                if var_node.name() not in op.input_arg_names():
                    continue
                if var_node.name() in dequantized_vars:
                    dequant_var_node = dequantized_vars[var_node.name()]
                else:
                    # Persistable inputs are treated as weights.
                    is_weight = var_node.name() in persistable_vars
                    quant_bits = self._weight_bits if is_weight \
                        else self._activation_bits
                    quant_type = self._weight_quantize_type if is_weight \
                        else self._activation_quantize_type
                    if quant_type == 'channel_wise_abs_max':
                        assert is_weight, "'channel_wise_abs_max' can only be applied on weights."
                        if op.name() in self._conv_ops:
                            quant_var_node, scale_var_node = self._insert_channel_quant_op(
                                graph, var_node, quant_bits)
                            dequant_var_node = self._insert_channel_dequant_op(
                                graph, quant_var_node, [scale_var_node],
                                [quant_bits])
                        else:
                            # Non-conv ops fall back to per-tensor abs_max.
                            quant_var_node, scale_var_node = self._insert_quant_op(
                                graph, var_node, quant_bits, 'abs_max')
                            dequant_var_node = self._insert_dequant_op(
                                graph, quant_var_node, scale_var_node,
                                quant_bits)
                    else:
                        quant_var_node, scale_var_node = self._insert_quant_op(
                            graph, var_node, quant_bits, quant_type)
                        dequant_var_node = self._insert_dequant_op(
                            graph, quant_var_node, scale_var_node, quant_bits)
                    dequantized_vars[var_node.name()] = dequant_var_node
                graph.update_input_link(var_node, dequant_var_node, op)

        def _transform_backward(graph, op):
            # Re-point the grad op's inputs at the dequantized variables
            # produced by the forward transformation.
            no_dequanted_input_vars = True
            for var_node in op.inputs:
                if var_node.name() not in op.input_arg_names():
                    continue
                if var_node.name() in dequantized_vars:
                    dequant_var_node = dequantized_vars[var_node.name()]
                    graph.update_input_link(var_node, dequant_var_node, op)
                    no_dequanted_input_vars = False
            if no_dequanted_input_vars:
                raise ValueError("There is no dequanted inputs for op %s." %
                                 (op.name()))

        if not self._is_test:
            self._create_global_step(graph)
        ops = graph.all_op_nodes()
        # The process of _transform_forward and _transform_backward is needed in two for loops.
        # The loop for transforming the forward graph:
        for op in ops:
            if op.name() in self._quantizable_ops:
                _transform_forward(graph, op)
        # The loop for renaming the inputs of backward op.
        for op in ops:
            if op.name() in self._quantizable_grad_ops:
                _transform_backward(graph, op)
        graph.resolve_hazard()
        return graph

    def _create_global_step(self, graph):
        """
        Find or create the '@STEP_COUNTER@' variable of `graph` and store it
        in `self._global_step`. It is only needed (and only created, together
        with an `increment` op) when 'range_abs_max' is one of the quantize
        types; otherwise `self._global_step` is left as None.
        """
        # Forget any counter from a previous `apply` call: that node belongs
        # to another graph and must not be linked into this one.
        self._global_step = None
        if self._weight_quantize_type == 'range_abs_max' or \
                self._activation_quantize_type == 'range_abs_max':
            counter_name = cpt.to_text('@STEP_COUNTER@')
            for node in graph.all_var_nodes():
                if node.name() == counter_name:
                    self._global_step = node
            if self._global_step is None:
                global_step_in = graph.create_persistable_node(
                    name=counter_name,
                    var_type=core.VarDesc.VarType.LOD_TENSOR,
                    shape=[1],
                    var_dtype=core.VarDesc.VarType.INT64)
                _init_var_node(
                    global_step_in,
                    np.zeros(
                        [1], dtype='int64'),
                    self._scope,
                    self._place)
                global_step_out = graph.create_var_node_from_desc(
                    global_step_in.var())
                # The attribute of `op_role` is needed by ParallelExecutor.
                increment_op = graph.create_op_node(
                    op_type='increment',
                    attrs={
                        'step': 1.0,
                        'op_role':
                        core.op_proto_and_checker_maker.OpRole.Forward
                    },
                    inputs={'X': global_step_in},
                    outputs={'Out': global_step_out})
                graph.link_to(global_step_in, increment_op)
                graph.link_to(increment_op, global_step_out)
                self._global_step = global_step_out

    def _insert_quant_op(self, graph, var_node, quant_bits, quant_type):
        """
        Insert fake_quantize_op in the graph.

        Returns:
            tuple: (quantized var node, scale var node).

        Raises:
            ValueError: if `quant_type` is not 'abs_max', 'range_abs_max' or
                'moving_average_abs_max'.
        """
        if quant_type == 'abs_max':
            return self._insert_quant_abs_max_op(graph, var_node, quant_bits)
        elif quant_type == 'range_abs_max':
            return self._insert_quant_range_abs_max_op(graph, var_node,
                                                       quant_bits)
        elif quant_type == 'moving_average_abs_max':
            return self._insert_quant_moving_average_abs_max_op(graph, var_node,
                                                                quant_bits)
        # Fail loudly instead of returning None, which callers would try to
        # unpack into (quant_var_node, scale_var_node).
        raise ValueError(
            "Unknown quant_type: '%s'. It can only be 'abs_max', "
            "'range_abs_max' or 'moving_average_abs_max'." % (str(quant_type)))

    def _insert_quant_abs_max_op(self, graph, var_node, quant_bits):
        """
        Insert fake_quantize_abs_max op in the graph.

        Returns:
            tuple: (quantized var node, scale var node of shape [1]).
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype())
        scale_var_node = graph.create_var_node(
            name=self._quantized_scale_name(var_node.name()),
            var_type=var_node.type(),
            shape=[1],
            var_dtype=var_node.dtype())
        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_abs_max',
            attrs={
                'bit_length': quant_bits,
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
            },
            inputs={'X': var_node},
            outputs={'Out': quant_var_node,
                     'OutScale': scale_var_node})
        graph.link_to(var_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_var_node)
        return quant_var_node, scale_var_node

    def _insert_quant_range_abs_max_op(self, graph, var_node, quant_bits):
        """
        Insert fake_quantize_range_abs_max on the graph.

        A persistable scale variable (initialized to 0.001) is created. In
        training mode a persistable 'scales' window of `self._window_size`
        zeros is also created and the global step is wired in as 'Iter'.

        Returns:
            tuple: (quantized var node, output scale var node).
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype())

        # Numpy dtype used to initialize every new persistable variable here.
        data_type = 'float64' if var_node.dtype(
        ) == core.VarDesc.VarType.FP64 else 'float32'

        scale_in_node = graph.create_persistable_node(
            name=self._quantized_scale_name(var_node.name()),
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[1],
            var_dtype=var_node.dtype())
        _init_var_node(
            scale_in_node,
            np.array(
                [0.001], dtype=data_type),
            self._scope,
            self._place)

        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
        inputs = {'X': var_node, 'InScale': scale_in_node}
        outputs = {'Out': quant_var_node, 'OutScale': scale_out_node}

        if not self._is_test:
            # The name of scales_var_node maybe 'scales_0', 'scales_1', etc.
            scales_node = graph.create_persistable_node(
                name=unique_name.generate('scales'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                shape=[self._window_size],
                var_dtype=var_node.dtype())
            _init_var_node(
                scales_node,
                np.zeros(
                    [self._window_size], dtype=data_type),
                self._scope,
                self._place)

            inputs['Iter'] = self._global_step
            outputs['OutScales'] = scales_node
        attrs = {
            'window_size': self._window_size,
            'bit_length': quant_bits,
            'is_test': self._is_test,
            'op_role': core.op_proto_and_checker_maker.OpRole.Forward
        }
        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_range_abs_max',
            attrs=attrs,
            inputs=inputs,
            outputs=outputs)

        graph.link_to(var_node, quant_op_node)
        graph.link_to(scale_in_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_out_node)

        if not self._is_test:
            graph.link_to(self._global_step, quant_op_node)
            graph.link_to(quant_op_node, scales_node)

        return quant_var_node, scale_out_node

    def _insert_quant_moving_average_abs_max_op(self, graph, var_node,
                                                quant_bits):
        """
        Insert fake_quantize_moving_average_abs_max on the graph.

        A persistable scale variable (initialized to 0.001) is created. In
        training mode persistable 'state' and 'accum' variables (each
        initialized to 1) are also created and wired in as
        InState/InAccum -> OutState/OutAccum.

        Returns:
            tuple: (quantized var node, output scale var node).
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype())

        # Numpy dtype used to initialize every new persistable variable here.
        data_type = 'float64' if var_node.dtype(
        ) == core.VarDesc.VarType.FP64 else 'float32'

        scale_in_node = graph.create_persistable_node(
            name=self._quantized_scale_name(var_node.name()),
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[1],
            var_dtype=var_node.dtype())
        _init_var_node(
            scale_in_node,
            np.array(
                [0.001], dtype=data_type),
            self._scope,
            self._place)

        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
        ins = {'X': var_node, 'InScale': scale_in_node}
        outs = {'Out': quant_var_node, 'OutScale': scale_out_node}
        if not self._is_test:
            state_in_node = graph.create_persistable_node(
                name=unique_name.generate('state'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_node.dtype(),
                shape=[1])
            _init_var_node(
                state_in_node,
                np.ones(
                    [1], dtype=data_type),
                self._scope,
                self._place)
            accum_in_node = graph.create_persistable_node(
                name=unique_name.generate('accum'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_node.dtype(),
                shape=[1])
            _init_var_node(
                accum_in_node,
                np.ones(
                    [1], dtype=data_type),
                self._scope,
                self._place)
            state_out_node = graph.create_var_node_from_desc(state_in_node.var(
            ))
            accum_out_node = graph.create_var_node_from_desc(accum_in_node.var(
            ))

            ins['InState'] = state_in_node
            ins['InAccum'] = accum_in_node
            outs['OutState'] = state_out_node
            outs['OutAccum'] = accum_out_node

        attrs = {
            'bit_length': quant_bits,
            'moving_rate': self._moving_rate,
            'is_test': self._is_test,
            'op_role': core.op_proto_and_checker_maker.OpRole.Forward
        }

        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_moving_average_abs_max',
            attrs=attrs,
            inputs=ins,
            outputs=outs)

        graph.link_to(var_node, quant_op_node)
        graph.link_to(scale_in_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_out_node)

        if not self._is_test:
            graph.link_to(state_in_node, quant_op_node)
            graph.link_to(accum_in_node, quant_op_node)
            graph.link_to(quant_op_node, state_out_node)
            graph.link_to(quant_op_node, accum_out_node)

        return quant_var_node, scale_out_node

    def _insert_channel_quant_op(self, graph, var_node, quant_bits):
        """
        Insert fake_channel_wise_quantize_abs_max op in the graph.

        The scale variable has shape [var_node.shape()[0]], i.e. one entry
        per element of the first dimension.

        Returns:
            tuple: (quantized var node, scale var node).
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype())
        scale_var_node = graph.create_var_node(
            name=self._quantized_scale_name(var_node.name()),
            var_type=var_node.type(),
            shape=[var_node.shape()[0]],
            var_dtype=var_node.dtype())
        quant_op_node = graph.create_op_node(
            op_type='fake_channel_wise_quantize_abs_max',
            attrs={
                'bit_length': quant_bits,
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
            },
            inputs={'X': var_node},
            outputs={'Out': quant_var_node,
                     'OutScale': scale_var_node})
        graph.link_to(var_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_var_node)
        return quant_var_node, scale_var_node

    def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits):
        """
        Insert fake_dequantize_op in the graph.

        The op's 'max_range' attribute is 2 ** (quant_bits - 1) - 1.

        Returns:
            The dequantized var node.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype())
        max_range = (1 << (quant_bits - 1)) - 1
        dequant_op_node = graph.create_op_node(
            op_type='fake_dequantize_max_abs',
            attrs={
                'max_range': float(max_range),
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
            },
            inputs={'X': var_node,
                    'Scale': scale_var_node},
            outputs={'Out': dequant_var_node})
        graph.link_to(var_node, dequant_op_node)
        graph.link_to(scale_var_node, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
        return dequant_var_node

    def _insert_channel_dequant_op(self, graph, var_node, scale_var_nodes,
                                   quant_bits):
        """
        Insert fake_channel_wise_dequantize_max_abs in the graph.

        Args:
            scale_var_nodes (list): scale var nodes fed as 'Scales'.
            quant_bits (list): bit numbers stored in the 'quant_bits' attr.

        Returns:
            The dequantized var node.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype())
        dequant_op_node = graph.create_op_node(
            op_type='fake_channel_wise_dequantize_max_abs',
            attrs={
                'quant_bits': quant_bits,
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
            },
            inputs={'X': var_node,
                    'Scales': scale_var_nodes},
            outputs={'Out': dequant_var_node})
        graph.link_to(var_node, dequant_op_node)
        for scale_n in scale_var_nodes:
            graph.link_to(scale_n, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
        return dequant_var_node

    def _quantized_var_name(self, var_name):
        """
        Return quantized variable name for the input `var_name`.
        """
        return "%s.quantized" % (var_name)

    def _dequantized_var_name(self, var_name):
        """
        Return dequantized variable name for the input `var_name`.
        """
        return "%s.dequantized" % (var_name)

    def _quantized_scale_name(self, var_name):
        """
        Return the scale name of quantized variable for the input `var_name`.
        """
        return "%s.scale" % (var_name)
W
WangZhen 已提交
539 540 541


class QuantizationFreezePass(object):
    """
    The freeze pass is used to adjust the quantize operator order, for example:
        1) `activation -> quant -> dequant -> conv2d` will be frozen into
        `activation -> quant -> conv2d -> dequant`
        2) `weight -> quant -> dequant -> conv2d` will be frozen into `weight -> conv2d`,
        and weight will be scaled offline.

    Args:
        scope(fluid.Scope): scope is used to get the weight tensor values.
        place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors.
        weight_bits (int): quantization bit number for weights.
        activation_bits (int): quantization bit number for activation.
        weight_quantize_type (str): quantization type for weights, support 'abs_max' and 'channel_wise_abs_max'.
        The 'range_abs_max' usually is not used for weight, since weights are fixed once the
        model is well trained.
    """

    def __init__(self,
                 scope,
                 place,
                 weight_bits=8,
                 activation_bits=8,
                 weight_quantize_type='abs_max'):
        assert scope is not None, \
            'The scope cannot be set None.'
        assert place is not None, \
            'The place cannot be set None.'
        self._scope = scope
        self._place = place
        self._weight_bits = weight_bits
        self._activation_bits = activation_bits
        self._weight_quantize_type = weight_quantize_type
        self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
        self._conv_ops = ['conv2d', 'depthwise_conv2d']
        self._fake_quant_op_names = [
            'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
            'fake_quantize_moving_average_abs_max',
            'fake_channel_wise_quantize_abs_max'
        ]
        self._fake_dequant_op_names = [
            'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
        ]
        # Maps the output node (C++ node handle) of a removed fake
        # quant/dequant op to the node that should replace it as an input.
        self._op_input_rename_map = collections.OrderedDict()
        # Maps the output node of a quantizable op to the dequantized node
        # inserted after it, so downstream consumers can be relinked.
        self._op_output_rename_map = collections.OrderedDict()
        # Maps an original variable name to its scale: a numpy scalar for
        # abs_max weights, a list of per-channel scalars for channel-wise
        # conv weights, or an IrNode for activations.
        self._var_scale_map = collections.OrderedDict()

    def apply(self, graph):
        """
        Adjust quantize/dequantize operators order for the inference process.

        The pass runs in phases:
            1) record the scale of every fake-quantized variable; for weights
               the scale is computed offline, the fake ops are removed and the
               weight tensor is overwritten with its quantized values;
            2) remove all fake dequantize ops;
            3) insert a dequantize op after every quantizable op and relink
               the consumers of its output to the dequantized variable.

        Args:
            graph(IrGraph): the applied graph.

        Returns:
            IrGraph: the same graph, modified in place.
        """
        # A set: only used for membership tests.
        persistable_vars = {p.name() for p in graph.all_persistable_nodes()}
        ops = graph.all_op_nodes()
        for op_node in ops:
            op_name = op_node.name()
            if op_name in self._fake_quant_op_names:
                input_arg_name = op_node.input('X')[0]
                if input_arg_name in persistable_vars:
                    if self._weight_quantize_type == 'abs_max':
                        param = self._load_var(input_arg_name)
                        scale_v = np.max(np.abs(param))
                    elif self._weight_quantize_type == 'channel_wise_abs_max':
                        param = self._load_var(input_arg_name)
                        if len(param.shape) == 4:  # conv2d or depthwise_conv2d
                            # One abs-max per output channel (axis 0).
                            scale_v = [
                                np.max(np.abs(param[i]))
                                for i in range(param.shape[0])
                            ]
                        else:
                            scale_v = np.max(np.abs(param))
                    else:
                        scale_v = self._load_var(
                            op_node.output('OutScale')[0])[0]
                    self._var_scale_map[input_arg_name] = scale_v
                    self._remove_fake_quant_and_dequant_op(graph, op_node)
                    # quantize weight and restore
                    param_v = self._load_var(input_arg_name)
                    quantized_param_v = self._quant(param_v, scale_v,
                                                    self._weight_bits)
                    self._restore_var(input_arg_name, quantized_param_v)
                else:
                    scale_v = graph._find_node_by_name(
                        op_node.outputs, op_node.output('OutScale')[0])
                    self._var_scale_map[input_arg_name] = scale_v

        ops = graph.all_op_nodes()
        for op_node in ops:
            op_name = op_node.name()
            if op_name in self._fake_dequant_op_names:
                self._remove_fake_quant_and_dequant_op(graph, op_node)

        ops = graph.all_op_nodes()
        for op_node in ops:
            op_name = op_node.name()
            if op_name in self._quantizable_ops:
                if (self._weight_quantize_type == 'channel_wise_abs_max' and
                        op_name in self._conv_ops):
                    self._insert_post_channel_dequant_op(graph, op_node)
                else:
                    self._insert_post_dequant_op(graph, op_node)

        for op_node in ops:
            # insert dequant_op after fc/conv, need to rename inputs of the followed ops
            for var_node in op_node.inputs:
                if var_node.node in self._op_output_rename_map:
                    old_in = var_node
                    new_in = self._op_output_rename_map[var_node.node]
                    graph.update_input_link(old_in, new_in, op_node)

        # remove the unused var node in the graph
        self._remove_unused_var_nodes(graph)
        graph.resolve_hazard()
        return graph

    def _remove_fake_quant_and_dequant_op(self, graph, op_node):
        """
        Remove a fake quant/dequant op and remember which node should replace
        its output `Out` as an input of downstream ops.

        Chained fake ops are collapsed: if the op's input `X` was itself the
        output of a removed fake op, the mapping points at that op's source.
        """
        k = graph._find_node_by_name(op_node.outputs, op_node.output('Out')[0])
        v = graph._find_node_by_name(op_node.inputs, op_node.input('X')[0])
        if v.node not in self._op_input_rename_map:
            self._op_input_rename_map[k.node] = v
        else:
            self._op_input_rename_map[k.node] = self._op_input_rename_map[
                v.node]
        graph.safe_remove_nodes(op_node)

    def _insert_post_channel_dequant_op(self, graph, op_node):
        """
        Insert a `fake_channel_wise_dequantize_max_abs` op after a conv op
        whose weight was quantized channel-wise.

        The per-channel weight scales are stored in a new persistable
        variable; the activation scale node recorded in `_var_scale_map` is
        passed as the second entry of `Scales`.

        Raises:
            ValueError: if `op_node` has more than one output.

        Returns:
            IrNode: the dequantized output variable node.
        """
        persistable_vars = {p.name() for p in graph.all_persistable_nodes()}
        for var_node in op_node.inputs:
            name = var_node.name()
            if name not in op_node.input_arg_names():
                continue
            if var_node.node in self._op_input_rename_map:
                old_in = var_node
                new_in = self._op_input_rename_map[var_node.node]
                new_in.clear_outputs()
                graph.update_input_link(old_in, new_in, op_node)
            original_var_name = self._original_var_name(name)
            scale_v = self._var_scale_map[original_var_name]
            if original_var_name in persistable_vars:
                assert isinstance(
                    scale_v,
                    list), 'The scale of parameter %s is not a list.' % (
                        original_var_name)
                channel_scale = np.array(scale_v)
            else:
                assert isinstance(scale_v, IrNode)
                scale_var_node = scale_v

        if len(op_node.output_arg_names()) != 1:
            raise ValueError("Only support one output, but op %s has"
                             " more than one output." % (op_node.name()))

        output_var_node = graph._find_node_by_name(
            op_node.outputs, op_node.output_arg_names()[0])
        weight_scale_node = graph.create_persistable_node(
            name=unique_name.generate('channel_scale'),
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[channel_scale.shape[0]],
            var_dtype=output_var_node.dtype())
        data_type = 'float64' if output_var_node.dtype(
        ) == core.VarDesc.VarType.FP64 else 'float32'
        _init_var_node(weight_scale_node,
                       channel_scale.astype(data_type), self._scope,
                       self._place)
        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(output_var_node.name()),
            var_type=output_var_node.type(),
            shape=output_var_node.shape(),
            var_dtype=output_var_node.dtype())
        dequant_op_node = graph.create_op_node(
            op_type='fake_channel_wise_dequantize_max_abs',
            attrs={
                'quant_bits': [self._weight_bits, self._activation_bits],
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
            },
            inputs={
                'X': output_var_node,
                'Scales': [weight_scale_node, scale_var_node]
            },
            outputs={'Out': dequant_var_node})
        graph.link_to(output_var_node, dequant_op_node)
        graph.link_to(scale_var_node, dequant_op_node)
        graph.link_to(weight_scale_node, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
        self._op_output_rename_map[output_var_node.node] = dequant_var_node
        return dequant_var_node

    def _insert_post_dequant_op(self, graph, op_node):
        """
        Insert a `fake_dequantize_max_abs` op after a quantizable op.

        `max_range` folds the weight and activation quantization ranges with
        the offline weight scale; the activation scale node recorded in
        `_var_scale_map` is the op's `Scale` input.

        Raises:
            ValueError: if `op_node` has more than one output.

        Returns:
            IrNode: the dequantized output variable node.
        """
        persistable_vars = {p.name() for p in graph.all_persistable_nodes()}
        for var_node in op_node.inputs:
            name = var_node.name()
            if name not in op_node.input_arg_names():
                continue
            if var_node.node in self._op_input_rename_map:
                old_in = var_node
                new_in = self._op_input_rename_map[var_node.node]
                new_in.clear_outputs()
                graph.update_input_link(old_in, new_in, op_node)
            original_var_name = self._original_var_name(name)
            scale_v = self._var_scale_map[original_var_name]
            if original_var_name in persistable_vars:
                param_range = (1 << (self._weight_bits - 1)) - 1
                act_range = (1 << (self._activation_bits - 1)) - 1
                assert self._is_float(
                    scale_v), 'The scale of parameter %s is not a float.' % (
                        original_var_name)
                max_range = param_range * act_range / scale_v
            else:
                assert isinstance(scale_v, IrNode)
                scale_var_node = scale_v

        if len(op_node.output_arg_names()) != 1:
            raise ValueError("Only support one output, but op %s has"
                             " more than one output." % (op_node.name()))

        output_var_node = graph._find_node_by_name(
            op_node.outputs, op_node.output_arg_names()[0])
        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(output_var_node.name()),
            var_type=output_var_node.type(),
            shape=output_var_node.shape(),
            var_dtype=output_var_node.dtype())
        dequant_op_node = graph.create_op_node(
            op_type='fake_dequantize_max_abs',
            attrs={
                'max_range': float(max_range),
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
            },
            inputs={'X': output_var_node,
                    'Scale': scale_var_node},
            outputs={'Out': dequant_var_node})
        graph.link_to(output_var_node, dequant_op_node)
        graph.link_to(scale_var_node, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
        self._op_output_rename_map[output_var_node.node] = dequant_var_node
        return dequant_var_node

    def _load_var(self, name):
        """Return a numpy copy of the tensor named `name` in the scope."""
        return np.array(self._scope.find_var(name).get_tensor())

    def _restore_var(self, name, array):
        """Overwrite the tensor named `name` in the scope with `array`."""
        tensor = self._scope.find_var(name).get_tensor()
        tensor.set(array, self._place)

    def _remove_unused_var_nodes(self, graph):
        """Remove every var node that no op in `graph` reads or writes."""
        all_used_vars = set()
        for op_node in graph.all_op_nodes():
            all_used_vars.update(n.node for n in op_node.inputs)
            all_used_vars.update(n.node for n in op_node.outputs)
        all_unused_vars = {
            n
            for n in graph.all_var_nodes() if n.node not in all_used_vars
        }
        graph.safe_remove_nodes(all_unused_vars)

    def _original_var_name(self, var_name):
        """
        Return the original variable name.
        """
        if var_name.endswith('.quantized.dequantized'):
            return var_name[:-len('.quantized.dequantized')]
        if var_name.endswith('.quantized'):
            return var_name[:-len('.quantized')]
        if var_name.endswith('.dequantized'):
            return var_name[:-len('.dequantized')]
        if var_name.endswith('.scale'):
            return var_name[:-len('.scale')]
        else:
            return var_name

    def _dequantized_var_name(self, var_name):
        """
        Return dequantized variable name for the input `var_name`.
        """
        return "%s.dequantized" % (var_name)

    def _is_float(self, v):
        """
        Return True if `v` is a Python float or any numpy floating scalar
        (float16/32/64/...).
        """
        return isinstance(v, (float, np.floating))

    def _quant(self, x, scale, num_bits):
        """
        Quantize `x` to signed `num_bits` integer levels (kept in a float
        array).

        Args:
            x(np.ndarray): the tensor to quantize. It is NOT modified.
            scale(float|list): a single abs-max scale, or a list holding one
                scale per slice along axis 0 (channel-wise quantization).
            num_bits(int): the quantization bit number.

        Returns:
            np.ndarray: round(x / scale * (2 ** (num_bits - 1) - 1)).
        """
        bnt = (1 << (num_bits - 1)) - 1
        if isinstance(scale, list):
            # Broadcast one scale per slice along axis 0. The previous
            # implementation wrote rows back into `x`, silently mutating the
            # caller's array; keep x's dtype as that in-place write did.
            x = np.asarray(x)
            scales = np.asarray(scale).reshape([-1] + [1] * (x.ndim - 1))
            return np.round(x / scales * bnt).astype(x.dtype)
        return np.round(x / scale * bnt)


class ConvertToInt8Pass(object):
    """
    Convert the weights into int8_t type.

    Args:
        scope(fluid.Scope): scope is used to get the weight tensor values.
        place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the
        8bits weight tensors.
    """

    def __init__(self, scope, place):
        assert scope is not None, \
            'The scope cannot be set None.'
        assert place is not None, \
            'The place cannot be set None.'
        self._scope = scope
        self._place = place
        self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']

    def apply(self, graph):
        """
        Convert weights' type of the graph. After that, the data type of the
        graph weights is int8_t.

        Every persistable input of a quantizable op is replaced by a new
        `<name>.int8` persistable variable; a weight shared by several ops is
        converted only once.

        Args:
            graph(IrGraph): the applied graph.

        Returns:
            IrGraph: the same graph, modified in place.
        """
        # A set: only used for membership tests.
        persistable_vars = {p.name() for p in graph.all_persistable_nodes()}
        ops = graph.all_op_nodes()
        input_map = {}
        for op_node in ops:
            op_name = op_node.name()
            if op_name in self._quantizable_ops:
                for var_node in op_node.inputs:
                    name = var_node.name()
                    if name in persistable_vars:
                        if name not in input_map:
                            int8_var_node = self._convert_to_int8(graph,
                                                                  var_node)
                            input_map[name] = int8_var_node
                        graph.update_input_link(var_node, input_map[name],
                                                op_node)

        # remove the unused var node in the graph
        self._remove_unused_var_nodes(graph)
        graph.resolve_hazard()
        return graph

    def _convert_to_int8(self, graph, var_node):
        """
        Create a `<name>.int8` persistable node mirroring `var_node` and store
        the weight values cast to int8 under that name in the scope.
        """
        int8_var_node_name = var_node.name() + ".int8"
        int8_var_node = graph.create_persistable_node(
            name=cpt.to_text(int8_var_node_name),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=core.VarDesc.VarType.INT8)
        array = self._load_var(var_node.name())
        self._scope.var(int8_var_node_name)
        self._store_var(int8_var_node_name, array, np.int8)
        return int8_var_node

    def _load_var(self, name):
        """Return a numpy copy of the tensor named `name` in the scope."""
        return np.array(self._scope.find_var(name).get_tensor())

    def _store_var(self, name, array, dtype):
        """Cast `array` to `dtype` and write it into the tensor `name`."""
        tensor = self._scope.find_var(name).get_tensor()
        tensor.set(array.astype(dtype), self._place)

    def _remove_unused_var_nodes(self, graph):
        """Remove every var node that no op in `graph` reads or writes."""
        all_used_vars = set()
        for op_node in graph.all_op_nodes():
            all_used_vars.update(n.node for n in op_node.inputs)
            all_used_vars.update(n.node for n in op_node.outputs)
        all_unused_vars = {
            n
            for n in graph.all_var_nodes() if n.node not in all_used_vars
        }
        graph.safe_remove_nodes(all_unused_vars)


class TransformForMobilePass(object):
    """
    This pass is used to convert the frozen graph for paddle-mobile execution.
    """

    def __init__(self):
        self._fake_quant_op_names = [
            'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
            'fake_quantize_moving_average_abs_max',
            'fake_channel_wise_quantize_abs_max'
        ]
        self._fake_dequant_op_names = [
            'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
        ]

    def apply(self, graph):
        """
        Because paddle-mobile uses `quantize` and `dequantize` as the names of
        quantize operator and dequantize operator, the `apply` function just
        realizes this logic.

        Args:
            graph(IrGraph): the graph will be transformed.

        Returns:
            IrGraph: the same graph, modified in place.
        """
        ops = graph.all_op_nodes()
        for op_node in ops:
            name = op_node.name()
            if name in self._fake_quant_op_names:
                self._replace_op_type(graph, op_node, 'quantize')
            elif name in self._fake_dequant_op_names:
                self._replace_op_type(graph, op_node, 'dequantize')
        graph.resolve_hazard()
        return graph

    def _replace_op_type(self, graph, op_node, new_type):
        """
        Replace `op_node` with a new op node of type `new_type` that keeps
        the same desc and the same input/output links.
        """
        op_node.set_type(new_type)
        new_node = graph.create_op_node_from_desc(op_node.op())
        for input_node in op_node.inputs:
            graph.link_to(input_node, new_node)
        for output_node in op_node.outputs:
            graph.link_to(new_node, output_node)
        graph.safe_remove_nodes(op_node)


class ScaleForTrainingPass(object):
    def __init__(self, scope=None, place=None, moving_rate=0.9):
        """
        This pass is used for calculating output scales of some operators.
        These output scales may be used by tensorRT or some other inference engines.

        Args:
            scope(fluid.Scope): The scope is used to initialize these new parameters.
            place(fluid.CPUPlace|fluid.CUDAPlace): The place is used to initialize new parameters.
            moving_rate(float): The decay coefficient of moving average. The default value is 0.9.
        """
        self._scope = scope
        self._place = place
        self._moving_rate = moving_rate
        self._is_test = None
        self._teller_set = [
            "mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
            "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
            "elementwise_add", "elementwise_mul", "dropout", "split", "prelu",
            "conv2d_transpose", "leaky_relu"
        ]

    def apply(self, graph):
        """
        Insert the `moving_average_abs_max_scale` op in order to calculate output scales
        of operators in the teller_set.

        Ops with more than one output are skipped. In training mode the op
        also gets persistable state/accumulator inputs initialized to ones.

        Args:
            graph(IrGraph): the target graph.

        Returns:
            IrGraph: the same graph, modified in place.
        """
        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'
        self._is_test = graph.is_test()
        ops = graph.all_op_nodes()
        for op_node in ops:
            name = op_node.name()
            if name not in self._teller_set:
                continue
            if len(op_node.output_arg_names()) != 1:
                continue
            in_node = graph._find_node_by_name(
                op_node.outputs, op_node.output_arg_names()[0])
            out_node = graph.create_var_node_from_desc(in_node.var())
            scale_node = graph.create_persistable_node(
                name=self._scale_name(in_node.name()),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                shape=[1],
                var_dtype=in_node.dtype())
            ins = {'X': in_node}
            outs = {'Out': out_node, 'OutScale': scale_node}
            if not self._is_test:
                data_type = 'float64' if in_node.dtype(
                ) == core.VarDesc.VarType.FP64 else 'float32'
                state_in_node = self._create_ones_node(graph, 'scale_state@',
                                                       in_node, data_type)
                accum_in_node = self._create_ones_node(graph, 'scale_accum@',
                                                       in_node, data_type)
                state_out_node = graph.create_var_node_from_desc(
                    state_in_node.var())
                accum_out_node = graph.create_var_node_from_desc(
                    accum_in_node.var())

                ins['InState'] = state_in_node
                ins['InAccum'] = accum_in_node
                outs['OutState'] = state_out_node
                outs['OutAccum'] = accum_out_node

            attrs = {
                'moving_rate': self._moving_rate,
                'is_test': self._is_test,
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
            }
            scale_op_node = graph.create_op_node(
                op_type='moving_average_abs_max_scale',
                attrs=attrs,
                inputs=ins,
                outputs=outs)
            graph.link_to(in_node, scale_op_node)
            graph.link_to(scale_op_node, out_node)
            graph.link_to(scale_op_node, scale_node)
            if not self._is_test:
                graph.link_to(state_in_node, scale_op_node)
                graph.link_to(accum_in_node, scale_op_node)
                graph.link_to(scale_op_node, state_out_node)
                graph.link_to(scale_op_node, accum_out_node)
        graph.resolve_hazard()
        return graph

    def _create_ones_node(self, graph, prefix, like_node, data_type):
        """
        Create a uniquely named persistable [1]-shaped variable with the dtype
        of `like_node`, and initialize it to ones in the scope.
        """
        node = graph.create_persistable_node(
            name=unique_name.generate(prefix),
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            var_dtype=like_node.dtype(),
            shape=[1])
        _init_var_node(node,
                       np.ones(
                           [1], dtype=data_type),
                       self._scope,
                       self._place)
        return node

    def _scale_name(self, var_name):
        """
        Return the scale name for the var named `var_name`.
        """
        return "%s@scale" % (var_name)


class ScaleForInferencePass(object):
    def __init__(self, scope=None):
        """
        This pass is used for setting output scales of some operators.
        These output scales may be used by tensorRT or some other inference engines.

        Args:
            scope(fluid.Scope): The scope is used to initialize these new parameters.
        """
        self._scope = scope
        self._teller_set = [
            "mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
            "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
            "elementwise_add", "elementwise_mul", "dropout", "split", "prelu",
            "conv2d_transpose", "leaky_relu"
        ]

    def apply(self, graph):
        """
        Get output scales from the scope and set these scales in op_descs
        of operators in the teller_set.

        The scale of an op is read from the scope variable named
        `<output_name>@scale` and stored in the op's `out_scale` attribute;
        ops with more than one output are skipped.

        Args:
            graph(IrGraph): the target graph.

        Returns:
            IrGraph: the same graph, modified in place.
        """
        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'
        ops = graph.all_op_nodes()
        for op_node in ops:
            name = op_node.name()
            if name in self._teller_set:
                if len(op_node.output_arg_names()) != 1:
                    continue
                scale_name = self._scale_name(op_node.output_arg_names()[0])
                scale_v = np.array(
                    self._scope.find_var(scale_name).get_tensor())[0]
                op_node.op()._set_attr("out_scale", float(scale_v))
        graph.resolve_hazard()
        return graph

    def _scale_name(self, var_name):
        """
        Return the scale name for the var named `var_name`.
        """
        return "%s@scale" % (var_name)


class AddQuantDequantPass(object):
    def __init__(self,
                 scope=None,
                 place=None,
                 moving_rate=0.9,
                 quant_bits=8,
                 quantizable_op_type=None):
        """
        This pass is used to add quant_dequant op for some ops, such as the
        `elementwise_add` op.

        Args:
            scope: the scope in which the scale/state/accum variables created
                by this pass are initialized.
            place: the place on which those variables are initialized.
            moving_rate(float): the `moving_rate` attribute of the inserted
                `fake_quantize_dequantize_moving_average_abs_max` ops.
                Default is 0.9.
            quant_bits(int): the `bit_length` attribute of the inserted ops.
                Default is 8.
            quantizable_op_type(list[str], optional): the op types whose
                inputs get a quant_dequant op. Default is
                ["elementwise_add", "pool2d"].
        """
        self._scope = scope
        self._place = place
        self._moving_rate = moving_rate
        self._quant_bits = quant_bits
        # Set from the graph in `apply`; decides whether the training-only
        # state/accum inputs are created for the inserted ops.
        self._is_test = None
        if quantizable_op_type is None:
            quantizable_op_type = ["elementwise_add", "pool2d"]
        # Copy so later changes to the caller's list do not affect the pass.
        self._target_ops = list(quantizable_op_type)

    def apply(self, graph):
        """
        Add quant_dequant before some ops, such as the `elementwise_add` op. This
        is required by TensorRT.

        An op is processed only if its type is in the target op list and none
        of its inputs is persistable. Each distinct input variable is
        quantized at most once: when several target ops (or several input
        slots of one op) read the same variable, they all share a single
        quant_dequant output var.

        Args:
            graph(IrGraph): the target graph.

        Returns:
            IrGraph: the same graph, with quant_dequant ops inserted.
        """
        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'
        self._is_test = graph.is_test()
        # Original input var name -> the quant_dequant var node replacing it.
        # Reusing it avoids creating duplicate `<name>.quant_dequant` and
        # `<name>.quant_dequant.scale` nodes for shared inputs.
        dequantized_vars_map = collections.OrderedDict()
        ops = graph.all_op_nodes()
        for op_node in ops:
            if op_node.name() not in self._target_ops:
                continue
            # Process each input name once per op: after update_input_link
            # the old node is no longer linked as an input of op_node.
            # NOTE(review): assumes update_input_link rewrites every slot of
            # op_node that uses the old name; confirm in IrGraph.
            input_names = list(
                collections.OrderedDict.fromkeys(op_node.input_arg_names()))
            # Resolve all input nodes before any link is modified.
            in_nodes = [
                graph._find_node_by_name(op_node.inputs, input_name)
                for input_name in input_names
            ]
            # Ops reading any persistable var are left untouched.
            if any(in_node.persistable() for in_node in in_nodes):
                continue
            for input_name, in_node in zip(input_names, in_nodes):
                if input_name in dequantized_vars_map:
                    quant_var_node = dequantized_vars_map[input_name]
                else:
                    quant_var_node, _ = \
                        self._inser_quant_dequant_moving_average_abs_max_op(
                            graph, in_node, self._quant_bits)
                    dequantized_vars_map[input_name] = quant_var_node
                graph.update_input_link(in_node, quant_var_node, op_node)
        graph.resolve_hazard()
        return graph

    def _inser_quant_dequant_moving_average_abs_max_op(self, graph, var_node,
                                                       quant_bits):
        """Insert a fake_quantize_dequantize_moving_average_abs_max op.

        The op reads `var_node` and writes a new var named
        `<var_name>.quant_dequant`. A persistable scale var
        `<var_name>.quant_dequant.scale`, initialized to 0.001, is always
        created. When `self._is_test` is falsy, persistable state and accum
        vars, both initialized to 1, are created and wired in as well.

        Args:
            graph(IrGraph): the graph to insert the op into.
            var_node(IrNode): the var to quantize-dequantize.
            quant_bits(int): the `bit_length` attribute of the new op.

        Returns:
            tuple: (quant_var_node, scale_out_node), the op's `Out` and
            `OutScale` var nodes.
        """
        # Initial values use FP64 only when the input var is FP64.
        data_type = 'float64' if var_node.dtype(
        ) == core.VarDesc.VarType.FP64 else 'float32'

        quant_var_node = graph.create_var_node(
            name="{}.quant_dequant".format(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype())
        scale_in_node = graph.create_persistable_node(
            name="{}.quant_dequant.scale".format(var_node.name()),
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[1],
            var_dtype=var_node.dtype())
        _init_var_node(
            scale_in_node,
            np.array(
                [0.001], dtype=data_type),
            self._scope,
            self._place)

        # OutScale is built from InScale's var desc, so the op presumably
        # writes the updated scale back to the same variable.
        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
        ins = {'X': var_node, 'InScale': scale_in_node}
        outs = {'Out': quant_var_node, 'OutScale': scale_out_node}
        if not self._is_test:
            # Training only: extra state/accum vars for the moving average.
            state_in_node = graph.create_persistable_node(
                name=unique_name.generate('quant_dequant.state'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_node.dtype(),
                shape=[1])
            _init_var_node(
                state_in_node,
                np.ones(
                    [1], dtype=data_type),
                self._scope,
                self._place)
            accum_in_node = graph.create_persistable_node(
                name=unique_name.generate('quant_dequant.accum'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_node.dtype(),
                shape=[1])
            _init_var_node(
                accum_in_node,
                np.ones(
                    [1], dtype=data_type),
                self._scope,
                self._place)
            state_out_node = graph.create_var_node_from_desc(state_in_node.var(
            ))
            accum_out_node = graph.create_var_node_from_desc(accum_in_node.var(
            ))

            ins['InState'] = state_in_node
            ins['InAccum'] = accum_in_node
            outs['OutState'] = state_out_node
            outs['OutAccum'] = accum_out_node

        attrs = {
            'bit_length': quant_bits,
            'moving_rate': self._moving_rate,
            'is_test': self._is_test,
            'op_role': core.op_proto_and_checker_maker.OpRole.Forward
        }

        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_dequantize_moving_average_abs_max',
            attrs=attrs,
            inputs=ins,
            outputs=outs)

        graph.link_to(var_node, quant_op_node)
        graph.link_to(scale_in_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_out_node)

        if not self._is_test:
            graph.link_to(state_in_node, quant_op_node)
            graph.link_to(accum_in_node, quant_op_node)
            graph.link_to(quant_op_node, state_out_node)
            graph.link_to(quant_op_node, accum_out_node)

        return quant_var_node, scale_out_node