quantization_mkldnn_pass.py
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from .... import core
from ....framework import IrGraph
from ....framework import IrNode

__all__ = ['TransformForMkldnnPass']


class TransformForMkldnnPass(object):
    """
    Convert an IrGraph generated by the QuantizationFreezePass into an
    MKL-DNN supported INT8 IrGraph. The following transformations are applied
    in this pass:
        1. Convert the int8-range weights with float32 data type, which are
           generated by the QuantizationFreezePass, back to float32-range
           weights with float32 data type by using the corresponding scales.
           This conversion is needed because the MKL-DNN INT8 conv2d kernel
           currently only accepts float32 weights and quantizes them inside
           the kernel.
        2. Create a new conv2d op with the converted weights, link its output
           to the output of the fake_dequantize_max_abs op, and set the
           conv2d attribute "force_fp32_output" to true.
        3. Transform each fake_quantize_xx op into a quantize op.
        4. Remove the fake_dequantize_max_abs ops.
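
    As a rough sketch, step 1 amounts to the following arithmetic (it mirrors
    _transform_to_conv_mkldnn below; ``max_range`` is read from the matching
    fake_dequantize_max_abs op, and the variable names are illustrative only):

    .. code-block:: python

        w_fp32 = w_int8_range * 127.0 / max_range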
    """

    def __init__(self, scope=None, place=None):
        """
        Args:
            scope(fluid.Scope): scope is used to initialize the new parameters.
            place(fluid.CPUPlace): place is used to initialize the new parameters.


        Examples:
        .. code-block:: python

            # The original graph will be rewritten.
            import paddle.fluid as fluid
            from paddle.fluid.contrib.slim.quantization \
                import TransformForMkldnnPass
            from paddle.fluid.framework import IrGraph
            from paddle.fluid import core

            graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
            place = fluid.CPUPlace()
            mkldnn_pass = TransformForMkldnnPass(fluid.global_scope(), place)
            mkldnn_pass.apply(graph)
        """

        self._scope = scope
        self._place = place

        self.quantize_type = [
            'fake_quantize_moving_average_abs_max',
            'fake_quantize_range_abs_max'
        ]
        self.dequantize_type = ['fake_dequantize_max_abs']

        self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
        self._conv_ops = ['conv2d', 'depthwise_conv2d']

        self.InScale = {}
        self.max_range = {}
        self.conv_new_output = {}
        self.s8_max = 127
        # Temporary code to keep the mul op fake-quantized.
        # TODO(Intel): Remove the following code when the mul INT8 MKL-DNN
        # kernel is enabled.
        self.mul_input_id = []
        self.mul_output_id = []

    def apply(self, graph):
        """
        Quantize the graph so that it can run MKL-DNN INT8 inference. According
        to the activation quantization type, the fake quantize ops are
        transformed into quantize ops and the fake dequantize ops are removed.

        Args:
            graph(IrGraph): the graph to be transformed.
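
        Examples:
        .. code-block:: python

            # A minimal usage sketch; it assumes ``graph`` was produced by
            # the QuantizationFreezePass.
            mkldnn_pass = TransformForMkldnnPass(fluid.global_scope(),
                                                 fluid.CPUPlace())
            graph = mkldnn_pass.apply(graph)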
        """

        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'
        ops = graph.all_op_nodes()

        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
        # Collect the InScale and max_range values needed to calculate the new
        # scales for the MKL-DNN INT8 conv2d ops
        for op_node in ops:
            if op_node.name() in self.dequantize_type:
                input_name = op_node.input("X")[0]
                scale_name = op_node.input("Scale")[0]
                self.InScale[input_name] = self._load_param(self._scope,
                                                            scale_name)[0]
                self.max_range[input_name] = op_node.op().attr("max_range")
                self.conv_new_output[input_name] = op_node.output("Out")[0]
            # Temporary graph transform to keep the mul op fake-quantized
            # TODO(Intel): Remove the following code
            elif op_node.name() in ['mul']:
                input_node = graph._find_node_by_name(op_node.inputs,
                                                      op_node.input('X')[0])
                output_node = graph._find_node_by_name(op_node.outputs,
                                                       op_node.output('Out')[0])
                self.mul_input_id.append(input_node.id())
                self.mul_output_id.append(output_node.id())

        for op_node in ops:
            if op_node.name() in self._conv_ops:
                self._transform_to_conv_mkldnn(graph, op_node)
            elif op_node.name() in self.quantize_type:
                self._transform_to_quantize_mkldnn(graph, op_node)
            elif op_node.name() in self.dequantize_type:
                self._remove_fake_dequantize_op(graph, op_node)
            self._remove_unused_var_nodes(graph)
        return graph

    def _transform_to_conv_mkldnn(self, graph, op_node):
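        """
        Replace a quantized conv2d/depthwise_conv2d op with an MKL-DNN INT8
        conv2d op: restore the int8-range weights to the float32 range, rewire
        the output to the matching fake_dequantize_max_abs op's output, and
        set the INT8 scale attributes.
        """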
        weight_name = op_node.input("Filter")[0]
        output_name = op_node.output("Output")[0]
        # Convert the int8-range weights back to fp32-range weights
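        # (this undoes the int8-range scaling applied by the
        # QuantizationFreezePass, since the MKL-DNN INT8 conv2d kernel expects
        # float32 weights and quantizes them internally)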
        weight = self._load_param(self._scope, weight_name)
        w_fp32 = np.divide(
            np.multiply(weight, 127), self.max_range[output_name])
        w_fp32 = w_fp32.reshape(weight.shape)
        self._restore_var(weight_name, w_fp32)
        input_var_node = graph._find_node_by_name(op_node.inputs,
                                                  op_node.input("Input")[0])
        weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)

        # Set the fake_dequantize_max_abs op's output as the new output of conv2d
        output_var_node = graph._find_node_by_name(
            graph.all_var_nodes(), self.conv_new_output[output_name])
        attrs = {
            name: op_node.op().attr(name)
            for name in op_node.op().attr_names()
        }

        conv_op_node = graph.create_op_node(
            op_type='conv2d',
            attrs=attrs,
            inputs={'Input': input_var_node,
                    'Filter': weight_var_node},
            outputs={'Output': output_var_node})

        # Calculate the scales of the MKL-DNN INT8 conv2d from the QAT scales
        scale_in = self.s8_max / self.InScale[output_name]
        scale_w = [self.max_range[output_name] / self.s8_max]
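        # Scale_in rescales fp32 activations into the int8 range, while
        # Scale_weights is derived from max_range so that the kernel can
        # quantize the restored fp32 weights internally; force_fp32_output
        # keeps the conv2d output in fp32.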

        conv_op_node.set_attr("Scale_weights", scale_w)
        conv_op_node.set_attr("Scale_in", scale_in)
        conv_op_node.set_attr("Scale_out", 1.0)
        conv_op_node.set_attr("use_mkldnn", 1)
        conv_op_node.set_attr("force_fp32_output", 1)
        graph.link_to(input_var_node, conv_op_node)
        graph.link_to(weight_var_node, conv_op_node)
        graph.link_to(conv_op_node, output_var_node)
        graph.safe_remove_nodes(op_node)

    def _transform_to_quantize_mkldnn(self, graph, op_node):
        """
        Transform fake_quantize_xx op to quantize mkldnn op in the graph.
        """
        input_var_node = graph._find_node_by_name(op_node.inputs,
                                                  op_node.input("X")[0])
        output_var_node = graph._find_node_by_name(op_node.outputs,
                                                   op_node.output("Out")[0])
        if output_var_node.id() in self.mul_input_id:
            return
        else:
            scale_in = self.s8_max / self._load_param(
                self._scope, op_node.input("InScale")[0])[0]
            quant_op_node = graph.create_op_node(
                op_type='quantize',
                attrs={
                    'data_format': 'MKLDNNLAYOUT',
                    'use_mkldnn': 1,
                    'Scale': scale_in,
                    'is_negative_input': 1
                },
                inputs={'Input': input_var_node},
                outputs={'Output': output_var_node})
            graph.link_to(input_var_node, quant_op_node)
            graph.link_to(quant_op_node, output_var_node)
            graph.safe_remove_nodes(op_node)

    def _remove_fake_dequantize_op(self, graph, op_node):
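        """
        Remove a fake_dequantize_max_abs op from the graph, unless its input
        comes from a (still fake-quantized) mul op.
        """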
        input_var_node = graph._find_node_by_name(op_node.inputs,
                                                  op_node.input("X")[0])
        if input_var_node.id() in self.mul_output_id:
            return
        else:
            graph.safe_remove_nodes(op_node)

    def _load_param(self, scope, param_name):
        return np.array(scope.find_var(param_name).get_tensor())

    def _restore_var(self, name, array):
        tensor = self._scope.find_var(name).get_tensor()
        tensor.set(array, self._place)

    def _remove_unused_var_nodes(self, graph):
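        """
        Remove variable nodes that are no longer an input or output of any op
        node in the graph.
        """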
        all_used_vars = set()
        ops = graph.all_op_nodes()
        for op_node in ops:
            for input_node in op_node.inputs:
                all_used_vars.add(input_node)
            for output_node in op_node.outputs:
                all_used_vars.add(output_node)

        all_used_vars = {n.node for n in all_used_vars}
        all_unused_vars = {
            n
            for n in graph.all_var_nodes()
            if n.node not in all_used_vars
        }
        graph.safe_remove_nodes(all_unused_vars)