#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from ...fluid.framework import IrGraph
from ...framework import _get_paddle_place, core

OpRole = core.op_proto_and_checker_maker.OpRole


class Quant2Int8MkldnnPass:
    """
    Transform a quant model IrGraph into an MKL-DNN supported INT8 IrGraph.
    The pass consists of the following transformations:
        1. gather scale values from fake quantize/dequantize operators,
        2. extract the FP32 inference model graph from the quant graph, i.e.
            a.  remove fake quantize/dequantize operators,
            b.  dequantize conv2d and mul's weights,
        3. optimize the FP32 graph using standard FP32 optimization fuses
            (e.g. `conv2d`+`bn` -> `conv2d`),
        4. quantize the optimized FP32 graph using standard INT8v2 quantization
            passes (`cpu_quantize_pass`, `cpu_quantize_squash_pass`).
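
    A minimal usage sketch (``inference_program``, the scope and the place
    below are illustrative assumptions, not defined by this module):

    .. code-block:: python

        import paddle
        from paddle.fluid.framework import IrGraph
        from paddle.framework import core

        # Build an IrGraph from an existing inference Program.
        graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
        quant2_pass = Quant2Int8MkldnnPass(
            _ops_to_quantize={'conv2d', 'fc'},
            _scope=paddle.static.global_scope(),
            _place=paddle.CPUPlace(),
            _core=core,
        )
        graph = quant2_pass.apply(graph)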
    """

    def __init__(
        self,
        _ops_to_quantize,
        _op_ids_to_skip=None,
        _scope=None,
        _place=None,
        _core=None,
        _debug=False,
    ):
        self._scope = _scope
        self._place = _get_paddle_place(_place)
        self._core = _core
        self._debug = _debug
        self._fake_quantize_types = [
            'fake_quantize_moving_average_abs_max',
            'fake_quantize_range_abs_max',
        ]
        self._fake_dequantize_types = [
            'fake_dequantize_max_abs',
            'fake_channel_wise_dequantize_max_abs',
        ]
        self._fake_quantize_dequantize_types = [
            'fake_quantize_dequantize_abs_max',
            'fake_quantize_dequantize_moving_average_abs_max',
            'fake_channel_wise_quantize_dequantize_abs_max',
        ]
        self._ops_to_quantize = _ops_to_quantize
        self._op_ids_to_skip = (
            _op_ids_to_skip if _op_ids_to_skip is not None else set([-1])
        )
        self._scale_immutable_ops = [
            'transpose2',
            'reshape2',
            'pool2d',
            'slice',
            'shape',
            'nearest_interp',
            'nearest_interp_v2',
            'split',
        ]
        self._scale_ops = ['scale']
        self._conv_ops = ['conv2d', 'depthwise_conv2d']
        self._pool_ops = ['pool2d']
        self._mul_ops = ['mul']
        self._fc_ops = ['fc']
        self._relu_ops = ['relu', 'relu6']
        self._matmul_ops = ['matmul', 'matmul_v2']
        self._gru_ops = ['fusion_gru', 'multi_gru']
        self._lstm_ops = ['fusion_lstm']
        self._weight_thresholds = {}
        # Collect the input and output scales from fake quantized models
        self._var_quant_scales = {}
        self._max_range = {}
        self._s8_max = 127
        self._pass_idx = 0
        self._pass_group = 'int8'

    def apply(self, graph):
        assert isinstance(
            graph, IrGraph
        ), 'graph must be an instance of IrGraph.'

        self._reset_pass_idx_and_group('int8')
        graph = self._label_skip_quantized_op(graph)
        graph = self._gather_weight_thresholds_from_fake(graph)
        graph = self._gather_input_scales_from_fake(graph)
        graph = self._gather_output_scales_from_attr(graph)
        graph = self._remove_fake_ops(graph)
        graph = self._dequantize_weights(graph)
        graph = self._optimize_fp32_graph(graph)
        graph = self._compute_weight_scales(graph)
        # This function causes nondeterministic quantization behavior
        # graph = self._update_relu_output_scales(graph)
        graph = self._propagate_scales(graph)
        graph = self._quantize_fp32_graph(graph)
        graph = self._cleanup(graph)
        return graph

    def prepare_and_optimize_fp32(self, graph):
        assert isinstance(
            graph, IrGraph
        ), 'graph must be an instance of IrGraph.'

        self._reset_pass_idx_and_group('fp32')
        graph = self._optimize_fp32_graph(graph)
        graph = self._cleanup(graph)
        return graph

    def _reset_pass_idx_and_group(self, group):
        self._pass_idx = 0
        self._pass_group = group

    def _convert_scale2tensor(self, scale):
        tensor = core.LoDTensor()
        tensor.set(scale, core.CPUPlace())
        return tensor

    def _is_quantizing_all_ops(self):
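        # An empty ``_ops_to_quantize`` set is treated as "quantize every
        # supported op type".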
        return len(self._ops_to_quantize) == 0

    def _is_any_of_op_types_in_graph(self, op_types, graph):
        return any(op.name() in op_types for op in graph.all_op_nodes())

    def _is_any_of_op_types_quantized(self, op_types, graph):
        return self._is_any_of_op_types_in_graph(op_types, graph) and (
            self._is_quantizing_all_ops()
            or any(op_type in self._ops_to_quantize for op_type in op_types)
        )

    def _is_conv_quantized(self, graph):
        return self._is_any_of_op_types_quantized(self._conv_ops, graph)

    def _is_fc_quantized(self, graph):
        return self._is_any_of_op_types_quantized(self._fc_ops, graph)

    def _label_skip_quantized_op(self, graph):
        """
        For some ops (conv2d, depthwise_conv2d, mul, matmul), find and label
        the skip quantized ops. cpu_quantize_placement_pass will use the
        label to identify them.
        For static models, the skip quantized ops have the `skip_quant` attr.
        Therefore, this method only needs to find and label the skip
        quantized ops for dygraph models, in which the quantized ops don't
        have the `quantization_type` attr.
        """
        target_ops = self._conv_ops + self._mul_ops + self._matmul_ops
        for op_node in graph.all_op_nodes():
            if op_node.name() in target_ops and not op_node.op().has_attr(
                "quantization_type"
            ):
                is_quantized_op = True
                for var_node in op_node.inputs:
                    for front_op_node in var_node.inputs:
                        if "quantize" not in front_op_node.name():
                            is_quantized_op = False
                if not is_quantized_op:
                    op_node.op()._set_attr("skip_quant", True)
        return graph

    def _add_scale_for_vars(self, var_names, use_unsigned_int, lod_tensor):
        """
        Save quantization scales for variables; do not overwrite existing ones.
        """
        scales = self._var_quant_scales
        for var_name in var_names:
            if var_name not in scales:
                scales[var_name] = (use_unsigned_int, lod_tensor)

    def _gather_input_scales_from_fake(self, graph):
        # fake_quantize_dequantize_abs_max doesn't have a scale value
        fake_ops = ['fake_quantize_dequantize_moving_average_abs_max']
        fake_ops.extend(self._fake_quantize_types)

        for op in graph.all_op_nodes():
            if op.name() in fake_ops:
                bit_length = op.op().attr("bit_length")
                assert (
                    bit_length == 8
                ), 'Unsupported number of quantization bits ({}). Only 8 is supported now.'.format(
                    bit_length
                )

                input_name = op.input("X")[0]
                scale_name = op.input("InScale")[0]
                output_name = op.output("Out")[0]
                # Gather new weight scales after folding batchnorm in convolution
                scale = np.array(
                    1.0 / self._load_param(self._scope, scale_name)[0]
                ).astype(np.float64)
                scale[scale == np.Inf] = 0.0
                lod_tensor = self._convert_scale2tensor(scale)
                use_unsigned_int = False
                self._add_scale_for_vars(
                    [input_name, output_name], use_unsigned_int, lod_tensor
                )

        return graph

    def _gather_weight_thresholds_from_fake(self, graph):
        for op in graph.all_op_nodes():
            if op.name() in self._fake_dequantize_types:
                input_name = op.input("X")[0]
                if op.op().has_attr("max_range"):
                    _max_range = np.array(op.op().attr("max_range")).astype(
                        np.float64
                    )
                    self._weight_thresholds[input_name] = np.array(
                        self._s8_max * self._s8_max / _max_range
                    ).astype(np.float64)
                else:
                    scale_name = op.input("Scales")[0]
                    self._weight_thresholds[input_name] = np.array(
                        self._load_param(self._scope, scale_name)
                    ).astype(np.float64)

        return graph

    def _gather_output_scales_from_attr(self, graph):
        for op in graph.all_op_nodes():
            if op.op().has_attr("out_threshold"):
                attr_scale = op.op().attr("out_threshold")
                if attr_scale == 0.0:
                    continue
                scale = np.array(1.0 / attr_scale).astype(np.float64)
                scale[scale == np.Inf] = 0.0
                scale_lod_tensor = self._convert_scale2tensor(scale)
                use_unsigned_int = False
                for output_name in op.op().outputs():
                    for out_var_name in op.op().output(output_name):
                        self._add_scale_for_vars(
                            [out_var_name], use_unsigned_int, scale_lod_tensor
                        )

        return graph

    def _propagate_scales(self, graph):
        def _update_scale_op_in_scale(op, input, output):
            unsigned, tensor = self._var_quant_scales[output]
            scale = np.array(tensor) * op.op().attr("scale")
            new_tensor = self._convert_scale2tensor(scale.astype(np.float64))
            self._var_quant_scales[input] = (unsigned, new_tensor)

        def _update_scales(graph):
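            # Scales propagate in both directions through scale-immutable
            # ops: a known input scale is copied to the output and vice
            # versa; tensors that cannot be resolved yet are collected and
            # revisited by the outer while loop below.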
            waiting_for_scale = set()
            for op in graph.all_op_nodes():
                if op.name() in self._scale_immutable_ops:
                    if op.name() == 'slice' or op.name() == 'shape':
                        input_name = op.input("Input")[0]
                    else:
                        input_name = op.input("X")[0]
                    output_name = op.output("Out")[0]
                    tensor_names = [input_name, output_name]

                    if all(
                        name not in self._var_quant_scales
                        for name in tensor_names
                    ):
                        waiting_for_scale.update(tensor_names)
                        continue
                    elif input_name in self._var_quant_scales:
                        self._var_quant_scales[
                            output_name
                        ] = self._var_quant_scales[input_name]
                    elif output_name in self._var_quant_scales:
                        self._var_quant_scales[
                            input_name
                        ] = self._var_quant_scales[output_name]

                elif op.name() == 'concat':
                    output_name = op.output("Out")[0]
                    if output_name in self._var_quant_scales:
                        input_names = op.input("X")
                        for input_name in input_names:
                            self._var_quant_scales[
                                input_name
                            ] = self._var_quant_scales[output_name]
                elif op.name() in self._scale_ops:
                    input_name = op.input("X")[0]
                    output_name = op.output("Out")[0]
                    if output_name in self._var_quant_scales:
                        _update_scale_op_in_scale(op, input_name, output_name)
            return waiting_for_scale

        waiting_for_scale = _update_scales(graph)
        waiting_for_scale_prev = set()

        while (
            len(waiting_for_scale) != 0
            and waiting_for_scale != waiting_for_scale_prev
        ):
            waiting_for_scale_prev = waiting_for_scale
            waiting_for_scale = _update_scales(graph)

        return graph

    def _load_param(self, scope, param_name):
        return np.array(scope.find_var(param_name).get_tensor())

    def _remove_fake_ops(self, graph):
        for op in graph.all_op_nodes():
            if op.name() in self._fake_quantize_types:
                self._remove_fake_quantize(graph, op)
            elif op.name() in self._fake_dequantize_types:
                self._remove_fake_dequantize(graph, op)
            elif op.name() in self._fake_quantize_dequantize_types:
                self._remove_fake_dequantize(graph, op)

        return graph

    def _remove_fake_quantize(self, graph, op):
        fake_quant_in = graph._find_node_by_name(op.inputs, op.input("X")[0])
        fake_quant_in_scale = graph._find_node_by_name(
            op.inputs, op.input("InScale")[0]
        )
        fake_quant_out = graph._find_node_by_name(
            op.outputs, op.output("Out")[0]
        )
        fake_quant_out_scale = graph._find_node_by_name(
            op.outputs, op.output("OutScale")[0]
        )

        next_ops = fake_quant_out.outputs
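        # Rewire every consumer of the fake quantize output to read directly
        # from the op's input; the fake op and its scale variables are
        # removed afterwards.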
        for next_op in next_ops:
            self._swap_inputs(next_op, fake_quant_out, fake_quant_in)
            graph.link_to(fake_quant_in, next_op)
        graph.safe_remove_nodes(
            {op, fake_quant_in_scale, fake_quant_out, fake_quant_out_scale}
        )

        return graph

    def _remove_fake_dequantize(self, graph, op):
        fake_dequant_in = graph._find_node_by_name(op.inputs, op.input("X")[0])
        fake_dequant_out = graph._find_node_by_name(
            op.outputs, op.output("Out")[0]
        )

        next_ops = fake_dequant_out.outputs
        for next_op in next_ops:
            self._swap_inputs(next_op, fake_dequant_out, fake_dequant_in)
            graph.link_to(fake_dequant_in, next_op)
        graph.safe_remove_nodes({op, fake_dequant_out})

        return graph

    def _swap_inputs(self, op, old_input, new_input):
        for input_name in op.op().input_names():
            if old_input.name() in op.input(input_name):
                op.op().set_input(
                    input_name,
                    [
                        new_input.name() if x == old_input.name() else x
                        for x in op.input(input_name)
                    ],
                )

    def _dequantize_weights(self, graph):
        def _is_int8_weights(op_node, weight_name):
            weight_var_name = op_node.input(weight_name)[0]
            if self._scope.find_var(weight_var_name) is None:
                return False
            weight = self._load_param(self._scope, weight_var_name)
            return np.all(np.mod(weight, 1) == 0)

        mul_and_matmul_ops = self._mul_ops + self._matmul_ops
        for op in graph.all_op_nodes():
            if op.name() in self._conv_ops and _is_int8_weights(op, "Filter"):
                self._dequantize_op_weights(graph, op, "Filter", "Output")
            elif op.name() in mul_and_matmul_ops and _is_int8_weights(op, "Y"):
                self._dequantize_op_weights(graph, op, "Y", "Out")

        return graph

    def _dequantize_op_weights(self, graph, op_node, weight_name, output_name):
        weight_var_name = op_node.input(weight_name)[0]
        output_var_name = op_node.output(output_name)[0]
        # Convert int8 range weights to fp32 range weights
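        # i.e. w_fp32 = (w_int8 / s8_max) * threshold; e.g. with
        # self._s8_max == 127, a stored value of 127 and a (hypothetical)
        # per-channel threshold of 0.5 maps back to 0.5.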
        scales = self._weight_thresholds[output_var_name]
        weight = self._load_param(self._scope, weight_var_name)
        if scales.size == 1 or scales.size == weight.shape[0]:
            w_fp32 = np.multiply(np.divide(weight, self._s8_max).T, scales.T).T
        elif len(weight.shape) > 1 and scales.size == weight.shape[1]:
            w_fp32 = np.multiply(np.divide(weight, self._s8_max), scales)
        else:
            raise ValueError(
                "The size of weight scales vector ({}) does not match the dimensions ({}) of the weights tensor {}.".format(
                    scales.size, weight.shape, weight_var_name
                )
            )
        w_fp32 = w_fp32.reshape(weight.shape).astype(np.float32)
        self._restore_var(weight_var_name, w_fp32)

    def _restore_var(self, name, array):
        tensor = self._scope.find_var(name).get_tensor()
        tensor.set(array, self._place)

    def _update_activations(self, graph):
        for op in graph.all_op_nodes():
            if op.name() in self._conv_ops and not op.op().has_attr(
                "fuse_activation"
            ):
                activation = ""
                if op.op().has_attr("fuse_relu") and op.op().attr("fuse_relu"):
                    activation = "relu"
                op.set_attr("fuse_activation", activation)
        return graph

    def _remove_ctrl_vars(self, graph):
        remove_ctr_vars = set()
        for node in graph.all_var_nodes():
            if node.is_ctrl_var():
                remove_ctr_vars.add(node)
        graph.safe_remove_nodes(remove_ctr_vars)
        return graph

    def _optimize_fp32_graph(self, graph):
        graph = self._update_activations(graph)
        graph = self._remove_ctrl_vars(graph)
        graph = self._apply_pass(
            graph, 'mkldnn_placement_pass', ['mkldnn_enabled_op_types'], [set()]
        )
        # remove dropout ops
        graph = self._apply_pass(graph, 'simplify_with_basic_ops_pass')
        graph = self._apply_pass(graph, 'layer_norm_fuse_pass')
        graph = self._apply_pass(graph, 'attention_lstm_fuse_pass')
        graph = self._apply_pass(graph, 'seqconv_eltadd_relu_fuse_pass')
        graph = self._apply_pass(graph, 'fc_lstm_fuse_pass')
        graph = self._apply_pass(graph, 'mul_lstm_fuse_pass')
        graph = self._apply_pass(graph, 'fc_gru_fuse_pass')
        graph = self._apply_pass(graph, 'mul_gru_fuse_pass')
        graph = self._apply_pass(graph, 'multi_gru_fuse_pass')
        graph = self._apply_pass(graph, 'multi_gru_seq_fuse_pass')
        graph = self._apply_pass(graph, 'seq_concat_fc_fuse_pass')
        graph = self._apply_pass(graph, 'gpu_cpu_squeeze2_matmul_fuse_pass')
        graph = self._apply_pass(graph, 'gpu_cpu_reshape2_matmul_fuse_pass')
        graph = self._apply_pass(graph, 'gpu_cpu_flatten2_matmul_fuse_pass')
        graph = self._apply_pass(graph, 'matmul_v2_scale_fuse_pass')
        graph = self._apply_pass(graph, 'squared_mat_sub_fuse_pass')
        graph = self._apply_pass(graph, 'is_test_pass')
        graph = self._apply_pass(graph, 'gpu_cpu_map_matmul_v2_to_mul_pass')
        graph = self._apply_pass(graph, 'gpu_cpu_map_matmul_v2_to_matmul_pass')
        graph = self._apply_pass(graph, 'matmul_scale_fuse_pass')
        graph = self._apply_pass(graph, 'gpu_cpu_map_matmul_to_mul_pass')
        graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass')
        graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
        graph = self._apply_pass(graph, 'conv_bn_fuse_pass')
        graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass')
        graph = self._apply_pass(graph, 'conv_affine_channel_mkldnn_fuse_pass')
        graph = self._apply_pass(graph, 'conv_transpose_bn_fuse_pass')
        graph = self._apply_pass(
            graph, 'conv_transpose_eltwiseadd_bn_fuse_pass'
        )
        graph = self._apply_pass(graph, 'conv_bias_mkldnn_fuse_pass')
        graph = self._apply_pass(graph, 'conv_transpose_bias_mkldnn_fuse_pass')
        graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass')
        graph = self._apply_pass(graph, 'conv_activation_mkldnn_fuse_pass')
        graph = self._apply_pass(
            graph, 'fc_fuse_pass', ['use_gpu', 'use_fc_padding'], [False, False]
        )
        graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass')
        if self._is_fc_quantized(graph):
            # Disabled due to topology-dependent speed-up
            graph = self._apply_pass(graph, 'fc_mkldnn_pass')
            graph = self._apply_pass(graph, 'fc_act_mkldnn_fuse_pass')
        graph = self._apply_pass(
            graph, 'matmul_transpose_reshape_mkldnn_fuse_pass'
        )
        graph = self._apply_pass(
            graph, 'matmul_elementwise_add_mkldnn_fuse_pass'
        )
        graph = self._apply_pass(graph, 'matmul_activation_mkldnn_fuse_pass')
        graph = self._apply_pass(graph, 'batch_norm_act_fuse_pass')
        graph = self._apply_pass(graph, 'softplus_activation_mkldnn_fuse_pass')
        graph = self._apply_pass(graph, 'scale_matmul_fuse_pass')
        graph = self._apply_pass(
            graph, 'reshape_transpose_matmul_mkldnn_fuse_pass'
        )
        # The following pass must be the last one since it works on all fused ops.
        graph = self._apply_pass(graph, 'runtime_context_cache_pass')
        return graph

    def _apply_pass(self, graph, pass_name, attrs=None, attr_values=None):
        ir_pass = core.get_pass(pass_name)
        cpp_graph = graph.graph
        if not cpp_graph.has('__param_scope__'):
            cpp_graph.set_not_owned('__param_scope__', self._scope)
        if attrs:
            assert attr_values and len(attrs) == len(
                attr_values
            ), "Different number of pass attributes and their values."
            for attr, value in zip(attrs, attr_values):
                ir_pass.set(attr, value)
        ir_pass.apply(cpp_graph)
        if self._debug:
            graph.draw(
                '.',
                '{}_{}_{}'.format(self._pass_group, self._pass_idx, pass_name),
                graph.all_op_nodes(),
            )
        self._remove_unused_var_nodes(graph)
        self._pass_idx += 1
        return graph

    def _cleanup(self, graph):
        graph = self._remove_unused_var_nodes(graph)
        graph = self._set_op_role_forward(graph)
        return graph

    def _remove_unused_var_nodes(self, graph):
        all_used_vars = set()
        ops = graph.all_op_nodes()
        for op_node in ops:
            for input_node in op_node.inputs:
                all_used_vars.add(input_node)
            for output_node in op_node.outputs:
                all_used_vars.add(output_node)

        all_used_vars = {n.node for n in all_used_vars}
        all_unused_vars = {
            n
            for n in filter(
                lambda node: node.node not in all_used_vars,
                graph.all_var_nodes(),
            )
        }
        graph.safe_remove_nodes(all_unused_vars)
        return graph

    def _set_op_role_forward(self, graph):
        ops = graph.all_op_nodes()
        for op in ops:
            op.set_attr("op_role", OpRole.Forward)
        return graph

    def _compute_weight_scales(self, graph):
        def _compute_var_scales(ops, w_name, axis):
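            # Per-slice scale: the reciprocal of the max absolute weight
            # value along ``axis`` of the flattened weights; Inf entries
            # (from all-zero slices) are zeroed out below.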
            for op in graph.all_op_nodes():
                if op.op().type() in ops:
                    weight_var_name = op.input(w_name)[0]
                    weights = np.array(
                        self._load_param(self._scope, weight_var_name)
                    )
                    scales = 1.0 / np.amax(
                        np.abs(weights.reshape(weights.shape[0], -1)).astype(
                            np.float64
                        ),
                        axis=axis,
                    )
                    scales[scales == np.Inf] = 0.0

                    lod_tensor = self._convert_scale2tensor(scales)
                    use_unsigned_int = False
                    self._var_quant_scales[weight_var_name] = (
                        use_unsigned_int,
                        lod_tensor,
                    )

        def _compute_single_gru_weight_scales(wx_var_name, wh_var_name):
            wx = np.array(self._load_param(self._scope, wx_var_name))
            wh = np.array(self._load_param(self._scope, wh_var_name))
            OC = wh.shape[0]
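            # fusion_gru packs the update/reset gate weights into the first
            # 2*OC columns of WeightX (and the first 2*OC*OC elements of
            # WeightH); the remainder belongs to the output/candidate gate.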
            scale_ur = 1.0 / np.max(
                np.abs(
                    np.concatenate(
                        [
                            wx[:, : 2 * OC],
                            wh.flatten()[: 2 * OC * OC].reshape(OC, 2 * OC),
                        ],
                        axis=0,
                    )
                ),
                axis=0,
            )
            scale_o = 1.0 / np.max(
                np.abs(
                    np.concatenate(
                        [
                            wx[:, 2 * OC :],
                            wh.flatten()[2 * OC * OC :].reshape(OC, OC),
                        ],
                        axis=0,
                    )
                ),
                axis=0,
            )

            gru_weights_scale = np.concatenate([scale_ur, scale_o]).astype(
                'float'
            )

            return self._convert_scale2tensor(gru_weights_scale)

        def _compute_gru_weight_scales(wx_name, wh_name):
            for op in graph.all_op_nodes():
                if op.op().type() in self._gru_ops:
                    assert len(op.input(wx_name)) == len(
                        op.input(wh_name)
                    ), 'Mismatch in number of weights inputs ({} for WeightX vs. {} for WeightH).'.format(
                        len(op.input(wx_name)), len(op.input(wh_name))
                    )
                    for i, wx_var_name in enumerate(op.input(wx_name)):
                        wh_var_name = op.input(wh_name)[i]
                        use_unsigned_int = False
                        lod_tensor = _compute_single_gru_weight_scales(
                            wx_var_name, wh_var_name
                        )
                        self._var_quant_scales[wx_var_name] = (
                            use_unsigned_int,
                            lod_tensor,
                        )

        def _compute_single_lstm_weight_scales(wx_var_name, wh_var_name):
            wx = np.array(self._load_param(self._scope, wx_var_name))
            wh = np.array(self._load_param(self._scope, wh_var_name))

            lstm_weights_scale = 1.0 / np.max(
                np.abs(np.concatenate([wx[:, :], wh[:, :]], axis=0)), axis=0
            )
            lstm_weights_scale = lstm_weights_scale.astype('float')

            return self._convert_scale2tensor(lstm_weights_scale)

        def _compute_lstm_weight_scales(wx_name, wh_name):
            for op in graph.all_op_nodes():
                if op.op().type() in self._lstm_ops:
                    assert len(op.input(wx_name)) == len(
                        op.input(wh_name)
                    ), 'Mismatch in number of weights inputs ({} for WeightX vs. {} for WeightH).'.format(
                        len(op.input(wx_name)), len(op.input(wh_name))
                    )
                    for i, wx_var_name in enumerate(op.input(wx_name)):
                        wh_var_name = op.input(wh_name)[i]
                        use_unsigned_int = False
                        lod_tensor = _compute_single_lstm_weight_scales(
                            wx_var_name, wh_var_name
                        )
                        self._var_quant_scales[wx_var_name] = (
                            use_unsigned_int,
                            lod_tensor,
                        )

        _compute_var_scales(self._conv_ops, "Filter", axis=1)
        _compute_var_scales(self._fc_ops, "W", axis=0)
        _compute_var_scales(self._gru_ops, "WeightH", axis=0)
        _compute_var_scales(self._lstm_ops, "WeightH", axis=0)
        _compute_gru_weight_scales("WeightX", "WeightH")
        _compute_lstm_weight_scales("WeightX", "WeightH")
        return graph

    def _update_relu_output_scales(self, graph):
        def _set_unsigned_scale(graph, ops, op_out_name, predicate):
            '''
            Set the type of the output scale of the given op type(s) to
            'unsigned int8' if the predicate applied to the op passes.
            Typically, the predicate checks whether the op's activation
            is set to relu.
            '''
            for op in graph.all_op_nodes():
                if op.name() in ops:
                    out_name = op.output(op_out_name)[0]
                    if out_name in self._var_quant_scales and predicate(
                        op.op()
                    ):
                        is_unsigned, tensor = self._var_quant_scales[out_name]
                        if is_unsigned is False:
                            # If the variable is signed, it means that the scales for this var
                            # were computed for signed data, so the scale must be multiplied by 2
                            # to fill the entire range of uint8
                            scale = np.array(tensor) * 2
                            tensor = self._convert_scale2tensor(
                                scale.astype(np.float64)
                            )
                        self._var_quant_scales[out_name] = (True, tensor)
            return graph

        def conv_predicate(op):
            return op.attr("fuse_activation") in self._relu_ops

        graph = _set_unsigned_scale(
            graph, self._conv_ops, "Output", conv_predicate
        )

        def fc_predicate(op):
            return op.attr("activation_type") in self._relu_ops

        graph = _set_unsigned_scale(graph, self._fc_ops, "Out", fc_predicate)

        graph = _set_unsigned_scale(
            graph, self._relu_ops, 'Out', lambda op: True
        )

        return graph

    def _get_data_layout(self, graph):
        return 'NHWC' if self._is_conv_quantized(graph) else 'NCHW'

    def _quantize_fp32_graph(self, graph):
        graph = self._apply_pass(graph, 'scale_matmul_fuse_pass')
        graph = self._apply_pass(
            graph, 'reshape_transpose_matmul_mkldnn_fuse_pass'
        )
        graph = self._apply_pass(
            graph,
            'cpu_quantize_placement_pass',
            ['quantize_enabled_op_types'],
            [self._ops_to_quantize],
        )
        graph = self._apply_pass(
            graph,
            'cpu_quantize_pass',
            ['quant_var_scales', 'data_layout'],
            [self._var_quant_scales, self._get_data_layout(graph)],
        )
        graph = self._apply_pass(graph, 'cpu_quantize_squash_pass')
        graph = self._apply_pass(graph, 'int8_scale_calculation_mkldnn_pass')
        graph = self._apply_pass(graph, 'params_quantization_mkldnn_pass')
        return graph