quantization_pass.py 62.0 KB
Newer Older
W
WangZhen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
W
WangZhen 已提交
16
import numpy as np
W
WangZhen 已提交
17
from ..... import compat as cpt
W
WangZhen 已提交
18
from .... import core
19
from ....framework import IrGraph
20
from ....framework import IrNode
W
WangZhen 已提交
21 22
from .... import unique_name

23 24
# Public names exported by this module.
__all__ = [
    'QuantizationTransformPass', 'QuantizationFreezePass', 'ConvertToInt8Pass',
    'TransformForMobilePass', 'ScaleForTrainingPass', 'ScaleForInferencePass',
    'AddQuantDequantPass'
]
W
WangZhen 已提交
28

29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
# Op type names of the fake quantize operators.
_fake_quant_op_list = [
    'fake_quantize_abs_max',
    'fake_quantize_range_abs_max',
    'fake_quantize_moving_average_abs_max',
    'fake_channel_wise_quantize_abs_max',
]

# Op type names of the fake dequantize operators.
_fake_dequant_op_list = [
    'fake_dequantize_max_abs',
    'fake_channel_wise_dequantize_max_abs',
]

# NOTE(review): presumably the op types whose output scale is tracked by the
# scale passes defined elsewhere in this file -- confirm against its users.
_out_scale_op_list = [
    "mul",
    "conv2d",
    "pool2d",
    "relu",
    "softmax",
    "sigmoid",
    "depthwise_conv2d",
    "batch_norm",
    "concat",
    "tanh",
    "pad",
    "elementwise_add",
    "elementwise_mul",
    "dropout",
    "split",
    "prelu",
    "conv2d_transpose",
    "leaky_relu",
]

44 45 46 47 48
# Maps an op type to [real input slot names, real output slot names], so that
# auxiliary inputs such as AxisTensor are not processed.
# NOTE(review): "depthwise_conv2d" lists only "Input" while "conv2d" also
# lists "Filter" -- confirm whether the difference is intentional.
_op_real_in_out_name = {
    "conv2d": [["Input", "Filter"], ["Output"]],
    "depthwise_conv2d": [["Input"], ["Output"]],
    "mul": [["X", "Y"], ["Out"]],
    "matmul": [["X", "Y"], ["Out"]],
    "pool2d": [["X"], ["Out"]],
    "elementwise_add": [["X", "Y"], ["Out"]],
    "concat": [["X"], ["Out"]],
    "softmax": [["X"], ["Out"]],
    "argmax": [["X"], ["Out"]],
    "transpose": [["X"], ["Out"]],
    "equal": [["X", "Y"], ["Out"]],
    "gather": [["X"], ["Out"]],
    "greater_equal": [["X", "Y"], ["Out"]],
    "greater_than": [["X", "Y"], ["Out"]],
    "less_equal": [["X", "Y"], ["Out"]],
    "less_than": [["X", "Y"], ["Out"]],
    "mean": [["X"], ["Out"]],
    "not_equal": [["X", "Y"], ["Out"]],
    "reshape": [["X"], ["Out"]],
    "reshape2": [["X"], ["Out"]],
    "bilinear_interp": [["X"], ["Out"]],
    "nearest_interp": [["X"], ["Out"]],
    "trilinear_interp": [["X"], ["Out"]],
    "slice": [["Input"], ["Out"]],
    "squeeze": [["X"], ["Out"]],
    "elementwise_sub": [["X", "Y"], ["Out"]],
    "relu": [["X"], ["Out"]],
    "relu6": [["X"], ["Out"]],
    "leaky_relu": [["X"], ["Out"]],
    "tanh": [["X"], ["Out"]],
    "swish": [["X"], ["Out"]],
}

W
WangZhen 已提交
79

80 81 82 83
def _init_var_node(var_node, value, scope, place):
    """
    Write `value` into the tensor that `scope` holds for `var_node`.

    Args:
        var_node: the variable node to initialize; only its `name()` is used.
        value(np.ndarray): the data to store in the tensor.
        scope: the scope whose `var(name).get_tensor()` gives the tensor.
        place: forwarded to `tensor.set` together with `value`.

    Raises:
        AssertionError: if `value` is not a numpy array, or if `scope` or
            `place` is None.
    """
    assert isinstance(value,
                      np.ndarray), 'The type of value should be numpy array.'
    assert scope is not None, \
        'The scope cannot be set None.'
    assert place is not None, \
        'The place cannot be set None.'
    tensor = scope.var(var_node.name()).get_tensor()
    tensor.set(value, place)


91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
def _is_input_all_not_persistable(graph, op_node):
    '''
    Tell whether none of the real inputs of the op node is persistable.

    The real input slots of the op are looked up in `_op_real_in_out_name`
    by the op node name; every argument of those slots is resolved to its
    input node and checked with `persistable()`.
    '''
    real_input_slots = _op_real_in_out_name[op_node.name()][0]
    input_nodes = [
        graph._find_node_by_name(op_node.inputs, arg_name)
        for slot_name in real_input_slots
        for arg_name in op_node.input(slot_name)
    ]
    return all(not node.persistable() for node in input_nodes)


106
class QuantizationTransformPass(object):
    """
    Graph pass that inserts fake quantize and fake dequantize operators in
    front of the inputs of the quantizable ops of an IrGraph.
    """

    _supported_quantizable_op_type = [
        'conv2d', 'depthwise_conv2d', 'mul', 'matmul'
    ]

    def __init__(self,
                 scope=None,
                 place=None,
                 weight_bits=8,
                 activation_bits=8,
                 activation_quantize_type='abs_max',
                 weight_quantize_type='abs_max',
                 window_size=10000,
                 moving_rate=0.9,
                 skip_pattern=['skip_quant'],
                 quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
        """
        Convert and rewrite the IrGraph according to weight and
        activation quantization type.

        Args:
            scope(fluid.Scope): When activation use 'range_abs_max' as the quantize
                type, this pass will create some new parameters. The scope is used to
                initialize these new parameters.
            place(fluid.CPUPlace|fluid.CUDAPlace): place is used to initialize new
                parameters described above.
            weight_bits(int): quantization bit number for weights,
                the bias is not quantized.
            activation_bits(int): quantization bit number for activation.
            activation_quantize_type(str): quantization type for activation,
                now support 'abs_max', 'range_abs_max' and 'moving_average_abs_max'.
                If use 'abs_max' mode, the quantization scale will be calculated
                dynamically each step in both training and testing period. If use
                'range_abs_max', a static quantization scale will be calculated
                during training and used in inference.
            weight_quantize_type(str): quantization type for weights,
                support 'abs_max' and 'channel_wise_abs_max'. The 'range_abs_max'
                usually is not used for weight, since weights are fixed once the
                model is well trained.
            window_size(int): the window size for 'range_abs_max' quantization.
            moving_rate(float): the param for 'moving_average_abs_max' quantization.
            skip_pattern(str or str list): The user-defined quantization skip pattern, which
                will be presented in the name scope of an op. When the skip pattern is
                detected in an op's name scope, the corresponding op will not be quantized.
            quantizable_op_type(list[str]): List the type of ops that will be quantized.
                Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
                QuantizationFreezePass and ConvertToInt8Pass must be the same as this.

        Raises:
            AssertionError: if activation_quantize_type is 'channel_wise_abs_max'
                or an entry of quantizable_op_type is not supported.
            ValueError: if activation_quantize_type or weight_quantize_type is
                not a known quantization type.

        Examples:
        .. code-block:: python
            # The original graph will be rewritten.
            import paddle.fluid as fluid
            from paddle.fluid.contrib.slim.quantization \
                import QuantizationTransformPass
            from paddle.fluid.contrib.slim.graph import IrGraph
            from paddle.fluid import core

            graph = IrGraph(core.Graph(program.desc), for_test=False)
            place = fluid.CPUPlace()
            transform_pass = QuantizationTransformPass(fluid.global_scope(),
            place)
            transform_pass.apply(graph)
        """
        self._scope = scope
        self._place = place
        self._weight_bits = weight_bits
        self._activation_bits = activation_bits
        self._skip_pattern = skip_pattern

        quant_type = [
            'abs_max', 'channel_wise_abs_max', 'range_abs_max',
            'moving_average_abs_max'
        ]
        assert activation_quantize_type != 'channel_wise_abs_max', \
            "The activation quantization type does not support 'channel_wise_abs_max'."
        if activation_quantize_type not in quant_type:
            raise ValueError(
                "Unknown activation_quantize_type : '%s'. It can only be "
                "'abs_max' or 'range_abs_max' or 'moving_average_abs_max'." %
                (str(activation_quantize_type)))
        if weight_quantize_type not in quant_type:
            raise ValueError(
                "Unknown weight_quantize_type: '%s'. It can only be "
                "'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' or 'moving_average_abs_max'."
                % (str(weight_quantize_type)))

        self._activation_quantize_type = activation_quantize_type
        self._weight_quantize_type = weight_quantize_type
        self._window_size = window_size
        self._moving_rate = moving_rate

        # Copy the list so that neither the shared default argument nor a
        # list owned by the caller is aliased by this instance.
        self._quantizable_ops = list(quantizable_op_type)
        for op in self._quantizable_ops:
            assert op in QuantizationTransformPass._supported_quantizable_op_type, \
                op + " is not supported for quantization."
        self._conv_ops = ['conv2d', 'depthwise_conv2d']
        self._quantizable_grad_ops = [
            '%s_grad' % (op) for op in self._quantizable_ops
        ]
        # Both are filled in by apply().
        self._is_test = None
        self._global_step = None

    def apply(self, graph):
        """
        Quantize the graph for training process. According to weight and
        activation quantization type, the graph will be added some fake
        quantize operators and fake dequantize operators.

        Args:
            graph(IrGraph): the applied graph.

        Returns:
            IrGraph: the graph that was passed in, after being rewritten.
        """
        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'
        self._is_test = graph.is_test()
        # Maps the name of a variable that has already been dequantized to
        # its dequantized var node, so each variable is quantized only once.
        dequantized_vars = collections.OrderedDict()
        # Only membership is tested on this collection, so a set is enough.
        persistable_vars = {p.name() for p in graph.all_persistable_nodes()}

        def _quant_preprocess(op_node):
            # Mark the op with "skip_quant" when its name scope contains one
            # of the user-defined skip patterns.
            user_skipped = False
            if isinstance(self._skip_pattern, list):
                user_skipped = op_node.op().has_attr("op_namescope") and \
                               any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern)
            elif isinstance(self._skip_pattern, str):
                user_skipped = op_node.op().has_attr("op_namescope") and \
                               op_node.op().attr("op_namescope").find(self._skip_pattern) != -1

            if user_skipped:
                op_node.op()._set_attr("skip_quant", True)

        def _transform_forward(graph, op):
            for var_node in op.inputs:
                # input_arg_names() is queried on every iteration on purpose:
                # update_input_link below may change the op's input names.
                if var_node.name() not in op.input_arg_names():
                    continue
                if var_node.name() in dequantized_vars:
                    dequant_var_node = dequantized_vars[var_node.name()]
                else:
                    # Persistable inputs are handled as weights, all other
                    # inputs as activations.
                    is_weight = var_node.name() in persistable_vars
                    quant_bits = self._weight_bits if is_weight \
                        else self._activation_bits
                    quant_type = self._weight_quantize_type if is_weight \
                        else self._activation_quantize_type
                    if quant_type == 'channel_wise_abs_max':
                        assert is_weight, \
                            "'channel_wise_abs_max' can only be applied on weights."
                        if op.name() in self._conv_ops:
                            quant_var_node, scale_var_node = self._insert_channel_quant_op(
                                graph, var_node, quant_bits)
                            dequant_var_node = self._insert_channel_dequant_op(
                                graph, quant_var_node, [scale_var_node],
                                [quant_bits])
                        else:
                            # Ops other than convolutions fall back to the
                            # per-tensor 'abs_max' quantization.
                            quant_var_node, scale_var_node = self._insert_quant_op(
                                graph, var_node, quant_bits, 'abs_max')
                            dequant_var_node = self._insert_dequant_op(
                                graph, quant_var_node, scale_var_node,
                                quant_bits)
                    else:
                        quant_var_node, scale_var_node = self._insert_quant_op(
                            graph, var_node, quant_bits, quant_type)
                        dequant_var_node = self._insert_dequant_op(
                            graph, quant_var_node, scale_var_node, quant_bits)
                    dequantized_vars[var_node.name()] = dequant_var_node
                graph.update_input_link(var_node, dequant_var_node, op)

        def _transform_backward(graph, op):
            # Re-link the inputs of a grad op to the dequantized variables
            # created while transforming the forward ops.
            for var_node in op.inputs:
                if var_node.name() not in op.input_arg_names():
                    continue
                if var_node.name() in dequantized_vars:
                    dequant_var_node = dequantized_vars[var_node.name()]
                    graph.update_input_link(var_node, dequant_var_node, op)

        if not self._is_test:
            self._create_global_step(graph)
        ops = graph.all_op_nodes()
        # Do the preprocess of quantization, such as skipping some ops
        # for not being quantized.
        for op in ops:
            if op.name() in self._quantizable_ops or \
                    op.name() in self._quantizable_grad_ops:
                _quant_preprocess(op)
        # The process of _transform_forward and _transform_backward is needed in two for loops.
        # The loop for transforming the forward graph:
        for op in ops:
            if op.name() in self._quantizable_ops:
                if not QuantizationTransformPass._is_skip_quant(graph, op):
                    _transform_forward(graph, op)
        # The loop for renaming the inputs of backward op.
        for op in ops:
            if op.name() in self._quantizable_grad_ops:
                _transform_backward(graph, op)
        graph.resolve_hazard()
        return graph

    def _create_global_step(self, graph):
        """
        Find or create the '@STEP_COUNTER@' variable and store it in
        `self._global_step`.

        Nothing is done unless the weight or the activation quantization
        type is 'range_abs_max'. When no variable with the counter name is
        found in the graph, a persistable INT64 counter initialized to zero
        and an `increment` op on it are added to the graph.
        """
        if self._weight_quantize_type == 'range_abs_max' or \
                self._activation_quantize_type == 'range_abs_max':
            counter_name = cpt.to_text('@STEP_COUNTER@')
            for node in graph.all_var_nodes():
                if node.name() == counter_name:
                    self._global_step = node
            if self._global_step is None:
                global_step_in = graph.create_persistable_node(
                    name=counter_name,
                    var_type=core.VarDesc.VarType.LOD_TENSOR,
                    shape=[1],
                    var_dtype=core.VarDesc.VarType.INT64)
                _init_var_node(
                    global_step_in,
                    np.zeros(
                        [1], dtype='int64'),
                    self._scope,
                    self._place)
                global_step_out = graph.create_var_node_from_desc(
                    global_step_in.var())
                # The attribute of `op_role` is needed by ParallelExecutor.
                increment_op = graph.create_op_node(
                    op_type='increment',
                    attrs={
                        'step': 1.0,
                        'op_role':
                        core.op_proto_and_checker_maker.OpRole.Forward
                    },
                    inputs={'X': global_step_in},
                    outputs={'Out': global_step_out})
                graph.link_to(global_step_in, increment_op)
                graph.link_to(increment_op, global_step_out)
                self._global_step = global_step_out

    def _insert_quant_op(self, graph, var_node, quant_bits, quant_type):
        """
        Insert fake_quantize_op in the graph.

        Returns:
            tuple: (quantized var node, scale var node) of the inserted op.

        Raises:
            ValueError: if `quant_type` is not 'abs_max', 'range_abs_max' or
                'moving_average_abs_max'.
        """
        if quant_type == 'abs_max':
            return self._insert_quant_abs_max_op(graph, var_node, quant_bits)
        elif quant_type == 'range_abs_max':
            return self._insert_quant_range_abs_max_op(graph, var_node,
                                                       quant_bits)
        elif quant_type == 'moving_average_abs_max':
            return self._insert_quant_moving_average_abs_max_op(graph, var_node,
                                                                quant_bits)
        # Fail here with a clear message rather than returning None, which
        # the callers would trip over when unpacking the result.
        raise ValueError("Unsupported quant_type for _insert_quant_op: '%s'." %
                         (str(quant_type)))

    def _insert_quant_abs_max_op(self, graph, var_node, quant_bits):
        """
        Insert fake_quantize_abs_max op in the graph.

        Returns:
            tuple: (quantized var node, scale var node).
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype())
        scale_var_node = graph.create_var_node(
            name=self._quantized_scale_name(var_node.name()),
            var_type=var_node.type(),
            shape=[1],
            var_dtype=var_node.dtype())
        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_abs_max',
            attrs={
                'bit_length': quant_bits,
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
            },
            inputs={'X': var_node},
            outputs={'Out': quant_var_node,
                     'OutScale': scale_var_node})
        graph.link_to(var_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_var_node)
        return quant_var_node, scale_var_node

    def _insert_quant_range_abs_max_op(self, graph, var_node, quant_bits):
        """
        Insert fake_quantize_range_abs_max on the graph.

        Returns:
            tuple: (quantized var node, output scale var node).
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype())

        scale_in_node = graph.create_persistable_node(
            name=self._quantized_scale_name(var_node.name()),
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[1],
            var_dtype=var_node.dtype())
        # The numpy dtype used to initialize the new persistable variables.
        data_type = 'float64' if var_node.dtype(
        ) == core.VarDesc.VarType.FP64 else 'float32'
        _init_var_node(
            scale_in_node,
            np.array(
                [0.001], dtype=data_type),
            self._scope,
            self._place)

        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
        inputs = {'X': var_node, 'InScale': scale_in_node}
        outputs = {'Out': quant_var_node, 'OutScale': scale_out_node}

        if not self._is_test:
            # The name of scales_var_node maybe 'scales_0', 'scales_1', etc.
            scales_node = graph.create_persistable_node(
                name=unique_name.generate('scales'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                shape=[self._window_size],
                var_dtype=var_node.dtype())
            _init_var_node(
                scales_node,
                np.zeros(
                    [self._window_size], dtype=data_type),
                self._scope,
                self._place)

            inputs['Iter'] = self._global_step
            outputs['OutScales'] = scales_node
        attrs = {
            'window_size': self._window_size,
            'bit_length': quant_bits,
            'is_test': self._is_test,
            'op_role': core.op_proto_and_checker_maker.OpRole.Forward
        }
        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_range_abs_max',
            attrs=attrs,
            inputs=inputs,
            outputs=outputs)

        graph.link_to(var_node, quant_op_node)
        graph.link_to(scale_in_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_out_node)

        if not self._is_test:
            graph.link_to(self._global_step, quant_op_node)
            graph.link_to(quant_op_node, scales_node)

        return quant_var_node, scale_out_node

    def _insert_quant_moving_average_abs_max_op(self, graph, var_node,
                                                quant_bits):
        """
        Insert fake_quantize_moving_average_abs_max op in the graph.

        Returns:
            tuple: (quantized var node, output scale var node).
        """
        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype())
        scale_in_node = graph.create_persistable_node(
            name=self._quantized_scale_name(var_node.name()),
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[1],
            var_dtype=var_node.dtype())
        # The numpy dtype used to initialize the new persistable variables.
        data_type = 'float64' if var_node.dtype(
        ) == core.VarDesc.VarType.FP64 else 'float32'
        _init_var_node(
            scale_in_node,
            np.array(
                [0.001], dtype=data_type),
            self._scope,
            self._place)

        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
        ins = {'X': var_node, 'InScale': scale_in_node}
        outs = {'Out': quant_var_node, 'OutScale': scale_out_node}
        if not self._is_test:
            # The state and accum variables only exist in the training graph.
            state_in_node = graph.create_persistable_node(
                name=unique_name.generate('state'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_node.dtype(),
                shape=[1])
            _init_var_node(
                state_in_node,
                np.ones(
                    [1], dtype=data_type),
                self._scope,
                self._place)
            accum_in_node = graph.create_persistable_node(
                name=unique_name.generate('accum'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_node.dtype(),
                shape=[1])
            _init_var_node(
                accum_in_node,
                np.ones(
                    [1], dtype=data_type),
                self._scope,
                self._place)
            state_out_node = graph.create_var_node_from_desc(state_in_node.var(
            ))
            accum_out_node = graph.create_var_node_from_desc(accum_in_node.var(
            ))

            ins['InState'] = state_in_node
            ins['InAccum'] = accum_in_node
            outs['OutState'] = state_out_node
            outs['OutAccum'] = accum_out_node

        attrs = {
            'bit_length': quant_bits,
            'moving_rate': self._moving_rate,
            'is_test': self._is_test,
            'op_role': core.op_proto_and_checker_maker.OpRole.Forward
        }

        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_moving_average_abs_max',
            attrs=attrs,
            inputs=ins,
            outputs=outs)

        graph.link_to(var_node, quant_op_node)
        graph.link_to(scale_in_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_out_node)

        if not self._is_test:
            graph.link_to(state_in_node, quant_op_node)
            graph.link_to(accum_in_node, quant_op_node)
            graph.link_to(quant_op_node, state_out_node)
            graph.link_to(quant_op_node, accum_out_node)

        return quant_var_node, scale_out_node

    def _insert_channel_quant_op(self, graph, var_node, quant_bits):
        """
        Insert fake_channel_wise_quantize_abs_max op in the graph.

        The scale variable has one entry per element of the first dimension
        of `var_node`.

        Returns:
            tuple: (quantized var node, scale var node).
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype())
        scale_var_node = graph.create_var_node(
            name=self._quantized_scale_name(var_node.name()),
            var_type=var_node.type(),
            shape=[var_node.shape()[0]],
            var_dtype=var_node.dtype())
        quant_op_node = graph.create_op_node(
            op_type='fake_channel_wise_quantize_abs_max',
            attrs={
                'bit_length': quant_bits,
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
            },
            inputs={'X': var_node},
            outputs={'Out': quant_var_node,
                     'OutScale': scale_var_node})
        graph.link_to(var_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_var_node)
        return quant_var_node, scale_var_node

    def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits):
        """
        Insert fake_dequantize_op in the graph.

        Returns:
            The dequantized var node.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype())
        # 2 ** (quant_bits - 1) - 1, e.g. 127 when quant_bits is 8.
        max_range = (1 << (quant_bits - 1)) - 1
        dequant_op_node = graph.create_op_node(
            op_type='fake_dequantize_max_abs',
            attrs={
                'max_range': float(max_range),
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
            },
            inputs={'X': var_node,
                    'Scale': scale_var_node},
            outputs={'Out': dequant_var_node})
        graph.link_to(var_node, dequant_op_node)
        graph.link_to(scale_var_node, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
        return dequant_var_node

    def _insert_channel_dequant_op(self, graph, var_node, scale_var_nodes,
                                   quant_bits):
        """
        Insert fake_channel_wise_dequantize_max_abs in the graph.

        Args:
            scale_var_nodes(list): the scale var nodes; each one is linked to
                the dequantize op.
            quant_bits: stored as the 'quant_bits' attribute of the op.

        Returns:
            The dequantized var node.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype())
        dequant_op_node = graph.create_op_node(
            op_type='fake_channel_wise_dequantize_max_abs',
            attrs={
                'quant_bits': quant_bits,
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
            },
            inputs={'X': var_node,
                    'Scales': scale_var_nodes},
            outputs={'Out': dequant_var_node})
        graph.link_to(var_node, dequant_op_node)
        for scale_n in scale_var_nodes:
            graph.link_to(scale_n, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
        return dequant_var_node

    def _quantized_var_name(self, var_name):
        """
        Return quantized variable name for the input `var_name`.
        """
        return "%s.quantized" % (var_name)

    def _dequantized_var_name(self, var_name):
        """
        Return dequantized variable name for the input `var_name`.
        """
        return "%s.dequantized" % (var_name)

    def _quantized_scale_name(self, var_name):
        """
        Return the scale name of quantized variable for the input `var_name`.
        """
        return "%s.scale" % (var_name)

    @staticmethod
    def _is_skip_quant(graph, op_node):
        """
        Analyse whether the op node skips quantization.
        """
        is_skip = False
        if op_node.op().has_attr("skip_quant") and \
            op_node.op().attr("skip_quant"):
            is_skip = True
        # if the inputs of mul and matmul are not all persistable, use
        # AddQuantDequantPass to quantize them.
        if op_node.name() in ["mul", "matmul"] and \
            _is_input_all_not_persistable(graph, op_node):
            is_skip = True
        return is_skip

W
WangZhen 已提交
655 656

class QuantizationFreezePass(object):
    """
    Rewrite a graph containing fake quantize/dequantize ops into its frozen
    form: weights are quantized offline in the scope, the fake quant/dequant
    ops are removed, and a single dequantize op is inserted after every
    quantized op.
    """
    _supported_quantizable_op_type = \
        QuantizationTransformPass._supported_quantizable_op_type

    def __init__(self,
                 scope,
                 place,
                 weight_bits=8,
                 activation_bits=8,
                 weight_quantize_type='abs_max',
                 quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
        """
        The freeze pass is used to adjust the quantize operator order, for example:
            1) `activation -> quant -> dequant -> conv2d` will be frozen into
            `activation -> quant -> conv2d -> dequant`
            2) `weight -> quant -> dequant -> conv2d` will be frozen into `weight -> conv2d`,
            and weight will be scaled offline.

        Args:
            scope(fluid.Scope): scope is used to get the weight tensor values.
            place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors.
            weight_bits(int): quantization bit number for weights.
            activation_bits(int): quantization bit number for activation.
            weight_quantize_type(str): quantization type for weights, support 'abs_max' and
                'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight,
                since weights are fixed once the model is well trained.
            quantizable_op_type(list[str]): List the type of ops that will be quantized.
                Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
                QuantizationTransformPass and ConvertToInt8Pass must be the same as this.
        """
        assert scope is not None, \
            'The scope cannot be set None.'
        assert place is not None, \
            'The place cannot be set None.'
        self._scope = scope
        self._place = place
        self._weight_bits = weight_bits
        self._activation_bits = activation_bits
        self._weight_quantize_type = weight_quantize_type
        # Copy the argument so that neither the shared default list nor the
        # caller's list is aliased by this instance.
        self._quantizable_ops = list(quantizable_op_type)
        for op in self._quantizable_ops:
            assert op in QuantizationFreezePass._supported_quantizable_op_type, \
                op + " is not supported for quantization."
        self._conv_ops = ['conv2d', 'depthwise_conv2d']
        self._fake_quant_op_names = _fake_quant_op_list
        self._fake_dequant_op_names = _fake_dequant_op_list
        # Maps the `.node` of a removed fake op's output to the var node
        # that should be used in its place.
        self._op_input_rename_map = collections.OrderedDict()
        # Maps the `.node` of a quantized op's output to the output var node
        # of the dequantize op inserted after it.
        self._op_output_rename_map = collections.OrderedDict()
        # Maps an original variable name to its scale: a number or a list of
        # numbers for persistable inputs, an IrNode otherwise.
        self._var_scale_map = collections.OrderedDict()

    def apply(self, graph):
        """
        Adjust quantize/dequantize operators order for the inference process.

        Args:
            graph(IrGraph): the applied graph.
        Returns:
            The same `graph` object, after it has been modified in place.
        """
        # A set keeps the membership tests below constant-time.
        persistable_vars = {p.name() for p in graph.all_persistable_nodes()}

        # Step 1: record the scale of every fake quant op input; quantize
        # persistable inputs (weights) offline and drop their fake quant op.
        ops = graph.all_op_nodes()
        for op_node in ops:
            op_name = op_node.name()
            if op_name in self._fake_quant_op_names:
                input_arg_name = op_node.input('X')[0]
                if input_arg_name in persistable_vars:
                    if self._weight_quantize_type == 'abs_max':
                        param = self._load_var(input_arg_name)
                        scale_v = np.max(np.abs(param))
                    elif self._weight_quantize_type == 'channel_wise_abs_max':
                        param = self._load_var(input_arg_name)
                        if len(param.shape) == 4:  # conv2d or depthwise_conv2d
                            # One scale per slice along the first axis.
                            scale_v = [
                                np.max(np.abs(param[i]))
                                for i in range(param.shape[0])
                            ]
                        else:
                            scale_v = np.max(np.abs(param))
                    else:
                        scale_v = self._load_var(
                            op_node.output('OutScale')[0])[0]
                    self._var_scale_map[input_arg_name] = scale_v
                    self._remove_fake_quant_and_dequant_op(graph, op_node)
                    # quantize weight and restore
                    param_v = self._load_var(input_arg_name)
                    quantized_param_v = self._quant(param_v, scale_v,
                                                    self._weight_bits)
                    self._restore_var(input_arg_name, quantized_param_v)
                else:
                    # For non-persistable inputs keep the scale var node.
                    scale_v = graph._find_node_by_name(
                        op_node.outputs, op_node.output('OutScale')[0])
                    self._var_scale_map[input_arg_name] = scale_v

        # Step 2: remove every fake dequant op.
        ops = graph.all_op_nodes()
        for op_node in ops:
            op_name = op_node.name()
            if op_name in self._fake_dequant_op_names:
                self._remove_fake_quant_and_dequant_op(graph, op_node)

        # Step 3: insert one dequant op after each quantized op.
        ops = graph.all_op_nodes()
        for op_node in ops:
            op_name = op_node.name()
            if op_name in self._quantizable_ops:
                # only process the node that is quantized by QuantizationTransformPass
                is_op_node_quantized = any(
                    var_node.name().endswith('.dequantized')
                    for var_node in op_node.inputs)
                if is_op_node_quantized:
                    if self._weight_quantize_type == 'channel_wise_abs_max' \
                            and op_name in self._conv_ops:
                        self._insert_post_channel_dequant_op(graph, op_node)
                    else:
                        self._insert_post_dequant_op(graph, op_node)

        # Step 4: `ops` is deliberately NOT refreshed here. It was collected
        # before step 3 inserted the dequant ops, so those new ops keep
        # reading the original outputs instead of being rewired to their own
        # outputs.
        for op_node in ops:
            # insert dequant_op after fc/conv, need to rename inputs of the followed ops
            for var_node in op_node.inputs:
                if var_node.node in self._op_output_rename_map:
                    old_in = var_node
                    new_in = self._op_output_rename_map[var_node.node]
                    graph.update_input_link(old_in, new_in, op_node)

        # remove the unused var node in the graph
        self._remove_unused_var_nodes(graph)
        graph.resolve_hazard()
        return graph

    def _remove_fake_quant_and_dequant_op(self, graph, op_node):
        """
        Remove the fake quant/dequant `op_node` from `graph` and remember
        which var node replaces its `Out` output.
        """
        k = graph._find_node_by_name(op_node.outputs, op_node.output('Out')[0])
        v = graph._find_node_by_name(op_node.inputs, op_node.input('X')[0])
        if v.node not in self._op_input_rename_map:
            self._op_input_rename_map[k.node] = v
        else:
            # The input was itself the output of a removed op: follow the
            # existing mapping so the chain collapses to one replacement.
            self._op_input_rename_map[k.node] = self._op_input_rename_map[
                v.node]
        graph.safe_remove_nodes(op_node)

    def _insert_post_channel_dequant_op(self, graph, op_node):
        """
        Insert a `fake_channel_wise_dequantize_max_abs` op after `op_node`.

        Args:
            graph(IrGraph): the graph being frozen.
            op_node(IrNode): the quantized op node.
        Returns:
            The output var node of the inserted dequantize op.
        Raises:
            ValueError: if `op_node` does not have exactly one output, or if
                no weight scale or no activation scale was found among its
                inputs.
        """
        persistable_vars = {p.name() for p in graph.all_persistable_nodes()}
        channel_scale = None
        scale_var_node = None
        for var_node in op_node.inputs:
            name = var_node.name()
            # input_arg_names() is queried on every iteration on purpose:
            # update_input_link() below is called on this same op.
            if name not in op_node.input_arg_names():
                continue
            if var_node.node in self._op_input_rename_map:
                old_in = var_node
                new_in = self._op_input_rename_map[var_node.node]
                new_in.clear_outputs()
                graph.update_input_link(old_in, new_in, op_node)
            original_var_name = self._original_var_name(name)
            scale_v = self._var_scale_map[original_var_name]
            if original_var_name in persistable_vars:
                assert isinstance(
                    scale_v,
                    list), 'The scale of parameter %s is not a list.' % (
                        original_var_name)
                channel_scale = np.array(scale_v)
            else:
                assert isinstance(scale_v, IrNode)
                scale_var_node = scale_v

        if len(op_node.output_arg_names()) != 1:
            raise ValueError("Only support one output, but op %s has"
                             " more than one output." % (op_node.name()))
        # Fail with a clear message instead of an unbound-local error below.
        if channel_scale is None or scale_var_node is None:
            raise ValueError("The op %s needs both a quantized weight input"
                             " and a quantized activation input." %
                             (op_node.name()))

        output_var_node = graph._find_node_by_name(
            op_node.outputs, op_node.output_arg_names()[0])
        weight_scale_node = graph.create_persistable_node(
            name=unique_name.generate('channel_scale'),
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[channel_scale.shape[0]],
            var_dtype=output_var_node.dtype())
        data_type = 'float64' if output_var_node.dtype(
        ) == core.VarDesc.VarType.FP64 else 'float32'
        _init_var_node(weight_scale_node,
                       channel_scale.astype(data_type), self._scope,
                       self._place)
        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(output_var_node.name()),
            var_type=output_var_node.type(),
            shape=output_var_node.shape(),
            var_dtype=output_var_node.dtype())
        dequant_op_node = graph.create_op_node(
            op_type='fake_channel_wise_dequantize_max_abs',
            attrs={
                'quant_bits': [self._weight_bits, self._activation_bits],
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
            },
            inputs={
                'X': output_var_node,
                'Scales': [weight_scale_node, scale_var_node]
            },
            outputs={'Out': dequant_var_node})
        graph.link_to(output_var_node, dequant_op_node)
        graph.link_to(scale_var_node, dequant_op_node)
        graph.link_to(weight_scale_node, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
        self._op_output_rename_map[output_var_node.node] = dequant_var_node
        return dequant_var_node

    def _insert_post_dequant_op(self, graph, op_node):
        """
        Insert a `fake_dequantize_max_abs` op after `op_node`.

        The `max_range` attribute is the product of `param_range / scale`
        for every persistable input and of `act_range` for every other
        input, where both ranges are `2^(bits - 1) - 1`.

        Args:
            graph(IrGraph): the graph being frozen.
            op_node(IrNode): the quantized op node.
        Returns:
            The output var node of the inserted dequantize op.
        Raises:
            ValueError: if `op_node` does not have exactly one output, or if
                no activation scale was found among its inputs.
        """
        persistable_vars = {p.name() for p in graph.all_persistable_nodes()}
        max_range = 1
        param_range = (1 << (self._weight_bits - 1)) - 1
        act_range = (1 << (self._activation_bits - 1)) - 1
        scale_var_node = None
        for var_node in op_node.inputs:
            name = var_node.name()
            # input_arg_names() is queried on every iteration on purpose:
            # update_input_link() below is called on this same op.
            if name not in op_node.input_arg_names():
                continue
            if var_node.node in self._op_input_rename_map:
                old_in = var_node
                new_in = self._op_input_rename_map[var_node.node]
                new_in.clear_outputs()
                graph.update_input_link(old_in, new_in, op_node)
            original_var_name = self._original_var_name(name)
            scale_v = self._var_scale_map[original_var_name]
            if original_var_name in persistable_vars:
                assert self._is_float(
                    scale_v), 'The scale of parameter %s is not a float.' % (
                        original_var_name)
                max_range *= param_range / scale_v
            else:
                max_range *= act_range
                assert isinstance(scale_v, IrNode)
                scale_var_node = scale_v

        if len(op_node.output_arg_names()) != 1:
            raise ValueError("Only support one output, but op %s has"
                             " more than one output." % (op_node.name()))
        # Fail with a clear message instead of an unbound-local error below.
        if scale_var_node is None:
            raise ValueError("The op %s needs a quantized activation input."
                             % (op_node.name()))

        output_var_node = graph._find_node_by_name(
            op_node.outputs, op_node.output_arg_names()[0])
        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(output_var_node.name()),
            var_type=output_var_node.type(),
            shape=output_var_node.shape(),
            var_dtype=output_var_node.dtype())
        dequant_op_node = graph.create_op_node(
            op_type='fake_dequantize_max_abs',
            attrs={
                'max_range': float(max_range),
                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
            },
            inputs={'X': output_var_node,
                    'Scale': scale_var_node},
            outputs={'Out': dequant_var_node})
        graph.link_to(output_var_node, dequant_op_node)
        graph.link_to(scale_var_node, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
        self._op_output_rename_map[output_var_node.node] = dequant_var_node
        return dequant_var_node

    def _load_var(self, name):
        """
        Return the tensor of the scope variable `name` as a numpy array.
        """
        return np.array(self._scope.find_var(name).get_tensor())

    def _restore_var(self, name, array):
        """
        Write `array` into the tensor of the scope variable `name`.
        """
        tensor = self._scope.find_var(name).get_tensor()
        tensor.set(array, self._place)

    def _remove_unused_var_nodes(self, graph):
        """
        Remove every var node that is neither an input nor an output of any
        op node in `graph`.
        """
        all_used_vars = set()
        for op_node in graph.all_op_nodes():
            all_used_vars.update(op_node.inputs)
            all_used_vars.update(op_node.outputs)

        # Compare on the wrapped `.node` objects.
        all_used_vars = {n.node for n in all_used_vars}
        all_unused_vars = {
            n
            for n in graph.all_var_nodes() if n.node not in all_used_vars
        }
        graph.safe_remove_nodes(all_unused_vars)

    def _original_var_name(self, var_name):
        """
        Return the original variable name.

        The first matching suffix among `.quantized.dequantized`,
        `.quantized`, `.dequantized` and `.scale` is stripped; a name with
        none of them is returned unchanged.
        """
        # Order matters: the combined suffix must be tried first.
        for suffix in ('.quantized.dequantized', '.quantized', '.dequantized',
                       '.scale'):
            if var_name.endswith(suffix):
                return var_name[:-len(suffix)]
        return var_name

    def _dequantized_var_name(self, var_name):
        """
        Return dequantized variable name for the input `var_name`.
        """
        return "%s.dequantized" % (var_name)

    def _is_float(self, v):
        """
        Return True if `v` is a Python float, a np.float32 or a np.float64.
        """
        return isinstance(v, (float, np.float32, np.float64))

    def _quant(self, x, scale, num_bits):
        """
        Quantize `x` to `num_bits` bits with the given scale(s).

        Args:
            x(numpy.ndarray): the values to quantize.
            scale: a single scale, or a list of scales where `scale[i]` is
                applied to `x[i]`.
            num_bits(int): the bit number; results are rounded values of
                `x / scale * (2^(num_bits - 1) - 1)`.
        Returns:
            The quantized array. With a list of scales, `x` itself is
            modified in place and returned.
        """
        if isinstance(scale, list):
            for i, s in enumerate(scale):
                x[i] = np.round(x[i] / s * ((1 << (num_bits - 1)) - 1))
            return x
        else:
            return np.round(x / scale * ((1 << (num_bits - 1)) - 1))
963 964 965


class ConvertToInt8Pass(object):
    """
    Replace the persistable inputs of the quantizable ops with new
    persistable `<name>.int8` nodes whose tensors are stored as int8.
    """
    _supported_quantizable_op_type = \
        QuantizationTransformPass._supported_quantizable_op_type

    def __init__(self,
                 scope,
                 place,
                 quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
        """
        Convert the weights into int8_t type.

        Args:
            scope(fluid.Scope): scope is used to get the weight tensor values.
            place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the
                8bits weight tensors.
            quantizable_op_type(list[str]): List the type of ops that will be quantized.
                Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
                QuantizationTransformPass and QuantizationFreezePass must be the same as this.
        """
        assert scope is not None, \
            'The scope cannot be set None.'
        assert place is not None, \
            'The place cannot be set None.'
        self._scope = scope
        self._place = place
        # Copy the argument so that neither the shared default list nor the
        # caller's list is aliased by this instance.
        self._quantizable_ops = list(quantizable_op_type)
        for op in self._quantizable_ops:
            assert op in ConvertToInt8Pass._supported_quantizable_op_type, \
                op + " is not supported for quantization."

    def apply(self, graph):
        """
        Convert weights' type of the graph. After that, the data type of the
        graph weights is int8_t.

        Args:
            graph(IrGraph): the applied graph.
        Returns:
            The same `graph` object, after it has been modified in place.
        """
        # A set keeps the membership tests below constant-time.
        persistable_vars = {p.name() for p in graph.all_persistable_nodes()}
        ops = graph.all_op_nodes()
        # Maps a weight name to its int8 node so that a weight shared by
        # several ops is converted only once.
        input_map = {}
        for op_node in ops:
            op_name = op_node.name()
            if op_name in self._quantizable_ops:
                if QuantizationTransformPass._is_skip_quant(graph, op_node):
                    continue
                for var_node in op_node.inputs:
                    name = var_node.name()
                    if name in persistable_vars:
                        if name not in input_map:
                            int8_var_node = self._convert_to_int8(graph,
                                                                  var_node)
                            input_map[name] = int8_var_node
                        graph.update_input_link(var_node, input_map[name],
                                                op_node)

        # remove the unused var node in the graph
        self._remove_unused_var_nodes(graph)
        graph.resolve_hazard()
        return graph

    def _convert_to_int8(self, graph, var_node):
        """
        Create the persistable `<name>.int8` node for `var_node` and store
        the int8-cast values of `var_node` under that name in the scope.

        Returns:
            The newly created int8 var node.
        """
        int8_var_node_name = var_node.name() + ".int8"
        int8_var_node = graph.create_persistable_node(
            name=cpt.to_text(int8_var_node_name),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=core.VarDesc.VarType.INT8)
        array = self._load_var(var_node.name())
        # Make sure the scope variable exists before writing into it.
        self._scope.var(int8_var_node_name)
        self._store_var(int8_var_node_name, array, np.int8)
        return int8_var_node

    def _load_var(self, name):
        """
        Return the tensor of the scope variable `name` as a numpy array.
        """
        return np.array(self._scope.find_var(name).get_tensor())

    def _store_var(self, name, array, dtype):
        """
        Cast `array` to `dtype` and write it into the tensor of the scope
        variable `name`.
        """
        tensor = self._scope.find_var(name).get_tensor()
        tensor.set(array.astype(dtype), self._place)

    def _remove_unused_var_nodes(self, graph):
        """
        Remove every var node that is neither an input nor an output of any
        op node in `graph`.
        """
        all_used_vars = set()
        for op_node in graph.all_op_nodes():
            all_used_vars.update(op_node.inputs)
            all_used_vars.update(op_node.outputs)

        # Compare on the wrapped `.node` objects.
        all_used_vars = {n.node for n in all_used_vars}
        all_unused_vars = {
            n
            for n in graph.all_var_nodes() if n.node not in all_used_vars
        }
        graph.safe_remove_nodes(all_unused_vars)


class TransformForMobilePass(object):
    def __init__(self):
        """
        This pass is used to convert the frozen graph for paddle-mobile execution.
        """
        self._fake_quant_op_names = _fake_quant_op_list
        self._fake_dequant_op_names = _fake_dequant_op_list

    def apply(self, graph):
        """
        Because paddle-mobile uses `quantize` and `dequantize` as the names of
        the quantize operator and the dequantize operator, the `apply`
        function just realizes this renaming.

        Args:
            graph(IrGraph): the graph will be transformed.
        Returns:
            The same `graph` object, after it has been modified in place.
        """
        ops = graph.all_op_nodes()
        for op_node in ops:
            name = op_node.name()
            if name in self._fake_quant_op_names:
                self._replace_with_type(graph, op_node, 'quantize')
            if name in self._fake_dequant_op_names:
                self._replace_with_type(graph, op_node, 'dequantize')
        graph.resolve_hazard()
        return graph

    def _replace_with_type(self, graph, op_node, new_type):
        """
        Replace `op_node` by a new op node of type `new_type` that keeps the
        same input and output links.
        """
        op_node.set_type(new_type)
        # The new node is built from the op desc that now carries `new_type`.
        new_op_node = graph.create_op_node_from_desc(op_node.op())
        for input_node in op_node.inputs:
            graph.link_to(input_node, new_op_node)
        for output_node in op_node.outputs:
            graph.link_to(new_op_node, output_node)
        graph.safe_remove_nodes(op_node)
1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121


class ScaleForTrainingPass(object):
    def __init__(self, scope=None, place=None, moving_rate=0.9):
        """
        This pass is used for calculating output scales of some operators.
        These output scales may be used by tensorRT or some other inference engines.

        Args:
            scope(fluid.Scope): The scope is used to initialize these new parameters.
            place(fluid.CPUPlace|fluid.CUDAPlace): The place is used to initialize new parameters.
            moving_rate(float): The decay coefficient of moving average. The default value is 0.9.
        """
        self._scope = scope
        self._place = place
        self._moving_rate = moving_rate
        # Set from `graph.is_test()` each time `apply` is called.
        self._is_test = None
        self._teller_set = _out_scale_op_list

    def apply(self, graph):
        """
        Insert the `moving_average_abs_max_scale` op in order to calculate output scales
        of operators in the teller_set.

        Ops in the teller_set that do not have exactly one output are skipped.

        Args:
            graph(IrGraph): the target graph.
        Returns:
            The same `graph` object, after it has been modified in place.
        """
        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'
        self._is_test = graph.is_test()
        ops = graph.all_op_nodes()
        for op_node in ops:
            name = op_node.name()
            if name in self._teller_set:
                if len(op_node.output_arg_names()) != 1:
                    continue
                # `in_node` is the output of the observed op, i.e. the input
                # of the scale op inserted below.
                in_node = graph._find_node_by_name(
                    op_node.outputs, op_node.output_arg_names()[0])
                out_node = graph.create_var_node_from_desc(in_node.var())
                scale_node = graph.create_persistable_node(
                    name=self._scale_name(in_node.name()),
                    var_type=core.VarDesc.VarType.LOD_TENSOR,
                    shape=[1],
                    var_dtype=in_node.dtype())
                ins = {'X': in_node}
                outs = {'Out': out_node, 'OutScale': scale_node}
                if not self._is_test:
                    # The state and accumulator inputs are only wired in
                    # when the graph is not a test graph.
                    state_in_node = self._create_ones_node(
                        graph, 'scale_state@', in_node)
                    accum_in_node = self._create_ones_node(
                        graph, 'scale_accum@', in_node)
                    state_out_node = graph.create_var_node_from_desc(
                        state_in_node.var())
                    accum_out_node = graph.create_var_node_from_desc(
                        accum_in_node.var())

                    ins['InState'] = state_in_node
                    ins['InAccum'] = accum_in_node
                    outs['OutState'] = state_out_node
                    outs['OutAccum'] = accum_out_node

                attrs = {
                    'moving_rate': self._moving_rate,
                    'is_test': self._is_test,
                    'op_role': core.op_proto_and_checker_maker.OpRole.Forward
                }
                scale_op_node = graph.create_op_node(
                    op_type='moving_average_abs_max_scale',
                    attrs=attrs,
                    inputs=ins,
                    outputs=outs)
                graph.link_to(in_node, scale_op_node)
                graph.link_to(scale_op_node, out_node)
                graph.link_to(scale_op_node, scale_node)
                if not self._is_test:
                    graph.link_to(state_in_node, scale_op_node)
                    graph.link_to(accum_in_node, scale_op_node)
                    graph.link_to(scale_op_node, state_out_node)
                    graph.link_to(scale_op_node, accum_out_node)
        graph.resolve_hazard()
        return graph

    def _create_ones_node(self, graph, name_prefix, ref_node):
        """
        Create a persistable LOD_TENSOR node of shape `[1]` with the dtype
        of `ref_node`, and pass an all-ones array for it to `_init_var_node`.

        Args:
            graph(IrGraph): the graph in which the node is created.
            name_prefix(str): the prefix given to `unique_name.generate`.
            ref_node: the node whose dtype is reused.
        Returns:
            The newly created persistable node.
        """
        node = graph.create_persistable_node(
            name=unique_name.generate(name_prefix),
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            var_dtype=ref_node.dtype(),
            shape=[1])
        data_type = 'float64' if ref_node.dtype(
        ) == core.VarDesc.VarType.FP64 else 'float32'
        _init_var_node(node,
                       np.ones(
                           [1], dtype=data_type),
                       self._scope,
                       self._place)
        return node

    def _scale_name(self, var_name):
        """
        Return the scale name for the var named `var_name`.
        """
        return "%s@scale" % (var_name)


class ScaleForInferencePass(object):
    def __init__(self, scope=None):
        """
        This pass is used for setting output scales of some operators.
        These output scales may be used by tensorRT or some other inference engines.

        Args:
            scope(fluid.Scope): The scope from which the output scales are read.
        """
        self._scope = scope
        self._teller_set = _out_scale_op_list

    def apply(self, graph):
        """
        Get output scales from the scope and set these scales in op_descs
        of operators in the teller_set.

        Ops in the teller_set that do not have exactly one output are skipped.
        For the other ones, the scope variable named `<output name>@scale`
        is read and its first element is stored in the `out_scale` attribute.

        Args:
            graph(IrGraph): the target graph.
        Returns:
            The same `graph` object, after it has been modified in place.
        """
        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'
        ops = graph.all_op_nodes()
        for op_node in ops:
            name = op_node.name()
            if name in self._teller_set:
                if len(op_node.output_arg_names()) != 1:
                    continue
                scale_name = self._scale_name(op_node.output_arg_names()[0])
                scale_v = np.array(
                    self._scope.find_var(scale_name).get_tensor())[0]
                op_node.op()._set_attr("out_scale", float(scale_v))
        graph.resolve_hazard()
        return graph

    def _scale_name(self, var_name):
        """
        Return the scale name for the var named `var_name`.
        """
        return "%s@scale" % (var_name)
1254 1255 1256


class AddQuantDequantPass(object):
    """
    Insert fake_quantize_dequantize_moving_average_abs_max ops in front of the
    inputs of the selected op types of an IrGraph.
    """
    # Op types that can be selected through `quantizable_op_type`; this is also
    # the list used when `is_full_quantized` is True.
    _supported_quantizable_op_type = [
        "pool2d", "elementwise_add", "concat", "softmax", "argmax", "transpose",
        "equal", "gather", "greater_equal", "greater_than", "less_equal",
        "less_than", "mean", "not_equal", "reshape", "reshape2",
        "bilinear_interp", "nearest_interp", "trilinear_interp", "slice",
        "squeeze", "elementwise_sub", "mul", "matmul"
    ]
    # Activation op types that are additionally accepted in
    # `quantizable_op_type`.
    _activation_type = ["relu", "relu6", "leaky_relu", "tanh", "swish"]

    def __init__(self,
                 scope=None,
                 place=None,
                 moving_rate=0.9,
                 quant_bits=8,
                 skip_pattern=["skip_quant"],
                 quantizable_op_type=["elementwise_add", "pool2d"],
                 is_full_quantized=False):
        """
        This pass add quant_dequant op for some ops, of which all the inputs must be
        not persistable.
        The input scales can be obtained from the quant_dequant op.

        Args:
            scope(fluid.Scope): The scope is used to initialize these new parameters.
            place(fluid.CPUPlace|fluid.CUDAPlace): place is used to initialize new
                parameters described above.
            moving_rate(float, optional): the param for 'quant_dequant_moving_average_abs_max'
                quantization. Default is 0.9.
            quant_bits(int, optional): quantization bit number for activation. Default is 8.
            skip_pattern(str|list[str], optional): The user-defined quantization skip
                pattern(s), which will be presented in the name scope of an op. When a
                skip pattern is detected in an op's name scope, the corresponding op
                will not be quantized. A value that is neither a str nor a list
                disables skipping. Default is ["skip_quant"].
            quantizable_op_type(list[str], optional): List the type of ops that will be
                quantized. Default is ["elementwise_add", "pool2d"].
            is_full_quantized(bool, optional): If set is_full_quantized as True, apply
                quantization to all supported quantizable op type. If set is_full_quantized
                as False, only apply quantization to the op type according to the input
                quantizable_op_type.

        Raises:
            AssertionError: if an entry of `quantizable_op_type` is not supported
                (only checked when `is_full_quantized` is False), or if `scope`
                or `place` is None.
        """
        self._scope = scope
        self._place = place
        self._moving_rate = moving_rate
        self._quant_bits = quant_bits
        self._is_test = None
        # The list defaults in the signature are shared objects. Store copies
        # so that mutating one instance's state can never leak into other
        # instances (or into the class-level op lists below).
        if isinstance(skip_pattern, list):
            skip_pattern = list(skip_pattern)
        self._skip_pattern = skip_pattern

        supported_op_type = AddQuantDequantPass._supported_quantizable_op_type
        if is_full_quantized:
            self._quantizable_op_type = list(supported_op_type)
        else:
            self._quantizable_op_type = list(quantizable_op_type)
            allowed_op_type = supported_op_type + \
                AddQuantDequantPass._activation_type
            for op_type in self._quantizable_op_type:
                assert op_type in allowed_op_type, \
                    op_type + " is not supported for quantization."
        self._quantizable_grad_op_type = [
            '%s_grad' % (op) for op in self._quantizable_op_type
        ]

        assert self._scope is not None, "scope must not be None."
        assert self._place is not None, "place must not be None."

    def apply(self, graph):
        """
        Add quant_dequant before some ops, such as the 'elementwise_add' and
        'pool2d' op.

        Args:
            graph(IrGraph): the target graph.
        Returns:
            IrGraph: the graph passed in, after the quant_dequant ops were
            inserted and `resolve_hazard()` was called.
        """
        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'
        self._is_test = graph.is_test()
        # Maps an input var name to the quant_dequant output node replacing it,
        # so a var that feeds several ops gets a single quant_dequant op.
        dequantized_vars_map = collections.OrderedDict()

        # Forward stage, insert quant_dequant op
        all_op_nodes = graph.all_op_nodes()
        for op_node in all_op_nodes:
            if op_node.name() not in self._quantizable_op_type:
                continue
            if self._is_skip_quant(op_node) or \
                    self._has_dequantized_input(op_node) or \
                    (not _is_input_all_not_persistable(graph, op_node)):
                continue

            # Only the inputs listed in _op_real_in_out_name are quantized.
            arg_names = []
            for input_name in _op_real_in_out_name[op_node.name()][0]:
                arg_names.extend(op_node.input(input_name))
            for arg_name in arg_names:
                in_node = graph._find_node_by_name(op_node.inputs, arg_name)
                if arg_name in dequantized_vars_map:
                    quant_var_node = dequantized_vars_map[arg_name]
                else:
                    quant_var_node, _ = \
                        self._inser_quant_dequant_moving_average_abs_max_op(
                        graph, in_node, self._quant_bits)
                    dequantized_vars_map[arg_name] = quant_var_node
                graph.update_input_link(in_node, quant_var_node, op_node)

        # Backward stage, update input link
        for op_node in all_op_nodes:
            if op_node.name() not in self._quantizable_grad_op_type:
                continue
            for input_name in op_node.input_arg_names():
                if input_name in dequantized_vars_map:
                    in_node = graph._find_node_by_name(op_node.inputs,
                                                       input_name)
                    dequant_var_node = dequantized_vars_map[input_name]
                    graph.update_input_link(in_node, dequant_var_node,
                                            op_node)

        graph.resolve_hazard()
        return graph

    def _is_skip_quant(self, op_node):
        """
        Return True if the "op_namescope" attribute of `op_node` contains the
        skip pattern (or any of the skip patterns, when a list was given).
        """
        if isinstance(self._skip_pattern, list):
            patterns = self._skip_pattern
        elif isinstance(self._skip_pattern, str):
            patterns = [self._skip_pattern]
        else:
            # Any other value disables skipping.
            return False
        op_desc = op_node.op()
        if not op_desc.has_attr("op_namescope"):
            return False
        name_scope = op_desc.attr("op_namescope")
        return any(pattern in name_scope for pattern in patterns)

    def _has_dequantized_input(self, op_node):
        """
        Return True if any input var of `op_node` has a name ending with
        '.dequantized'; such ops are treated as already quantized.
        """
        return any(var_node.name().endswith('.dequantized')
                   for var_node in op_node.inputs)

    def _inser_quant_dequant_moving_average_abs_max_op(self, graph, var_node,
                                                       quant_bits):
        """Insert fake_quantize_dequantize_moving_average_abs_max op.

        Args:
            graph(IrGraph): the graph the new nodes are added to.
            var_node(IrNode): the var node that becomes the 'X' input of the
                new op.
            quant_bits(int): value of the op's 'bit_length' attribute.

        Returns:
            tuple: (quant_var_node, scale_out_node), the 'Out' and 'OutScale'
            output var nodes of the inserted op.
        """
        quant_var_node = graph.create_var_node(
            name="{}.quant_dequant".format(var_node.name()),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype())
        scale_in_node = graph.create_persistable_node(
            name="{}.quant_dequant.scale".format(var_node.name()),
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[1],
            var_dtype=var_node.dtype())
        # numpy dtype used for the initial values of all persistable nodes
        # created here.
        data_type = 'float64' if var_node.dtype(
        ) == core.VarDesc.VarType.FP64 else 'float32'
        _init_var_node(
            scale_in_node,
            np.array(
                [0.001], dtype=data_type),
            self._scope,
            self._place)

        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
        ins = {'X': var_node, 'InScale': scale_in_node}
        outs = {'Out': quant_var_node, 'OutScale': scale_out_node}
        if not self._is_test:
            # The state/accum inputs and outputs are only created when the
            # graph is not in test mode.
            state_in_node = graph.create_persistable_node(
                name=unique_name.generate('quant_dequant.state'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_node.dtype(),
                shape=[1])
            _init_var_node(
                state_in_node,
                np.ones(
                    [1], dtype=data_type),
                self._scope,
                self._place)
            accum_in_node = graph.create_persistable_node(
                name=unique_name.generate('quant_dequant.accum'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_node.dtype(),
                shape=[1])
            _init_var_node(
                accum_in_node,
                np.ones(
                    [1], dtype=data_type),
                self._scope,
                self._place)
            state_out_node = graph.create_var_node_from_desc(state_in_node.var(
            ))
            accum_out_node = graph.create_var_node_from_desc(accum_in_node.var(
            ))

            ins['InState'] = state_in_node
            ins['InAccum'] = accum_in_node
            outs['OutState'] = state_out_node
            outs['OutAccum'] = accum_out_node

        attrs = {
            'bit_length': quant_bits,
            'moving_rate': self._moving_rate,
            'is_test': self._is_test,
            'op_role': core.op_proto_and_checker_maker.OpRole.Forward
        }

        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_dequantize_moving_average_abs_max',
            attrs=attrs,
            inputs=ins,
            outputs=outs)

        graph.link_to(var_node, quant_op_node)
        graph.link_to(scale_in_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_out_node)

        if not self._is_test:
            graph.link_to(state_in_node, quant_op_node)
            graph.link_to(accum_in_node, quant_op_node)
            graph.link_to(quant_op_node, state_out_node)
            graph.link_to(quant_op_node, accum_out_node)

        return quant_var_node, scale_out_node