#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
from .... import core
from ....framework import IrGraph
from ....framework import Program
from ....framework import Variable
from ....initializer import Constant
from .... import unique_name

__all__ = ['QuantizationTransformPass']


class QuantizationTransformPass(object):
    def __init__(self,
                 scope=None,
                 program_exe=None,
                 weight_bits=8,
                 activation_bits=8,
                 activation_quantize_type='abs_max',
                 weight_quantize_type='abs_max',
                 window_size=10000):
        """
        Convert and rewrite the IrGraph according to weight and
        activation quantization type.
        Args:
            weight_bits (int): quantization bit number for weights;
                biases are not quantized.
            activation_bits (int): quantization bit number for activations.
            activation_quantize_type (str): quantization type for activations;
                currently supports 'abs_max' and 'range_abs_max'. With
                'abs_max', the quantization scale is calculated dynamically
                at each step during both training and testing. With
                'range_abs_max', a static quantization scale is calculated
                during training and used for inference.
            weight_quantize_type (str): quantization type for weights;
                supports 'abs_max'. 'range_abs_max' is usually not used for
                weights, since weights are fixed once the model is well trained.
            window_size (int): the window size for 'range_abs_max' quantization.
        Examples:
        .. code-block:: python

            # The original graph will be rewritten.
            import paddle.fluid as fluid
            from paddle.fluid.contrib.slim.quantization \
                import QuantizationTransformPass
            from paddle.fluid.contrib.slim.graph import IrGraph
            from paddle.fluid import core

            graph = IrGraph(core.Graph(program.desc), for_test=False)
            exe = fluid.Executor(fluid.CPUPlace())
            transform_pass = QuantizationTransformPass(fluid.global_scope(),
                                                       exe)
            transform_pass.apply(graph)
        """
        self._scope = scope
        self._program_exe = program_exe
        self._weight_bits = weight_bits
        self._activation_bits = activation_bits

        quant_type = ['abs_max', 'range_abs_max']
        if activation_quantize_type not in quant_type:
            raise ValueError(
                "Unknown activation_quantize_type: '%s'. It can only be "
                "'abs_max' or 'range_abs_max'." % str(activation_quantize_type))
        if weight_quantize_type not in quant_type:
            raise ValueError(
                "Unknown weight_quantize_type: '%s'. It can only be "
                "'abs_max' or 'range_abs_max'." % str(weight_quantize_type))

        self._activation_quantize_type = activation_quantize_type
        self._weight_quantize_type = weight_quantize_type
        self._window_size = window_size

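        # Maps a variable desc to the initializer that should be applied to
        # it; the collected entries are initialized at the end of apply().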
        self._need_initialized = collections.OrderedDict()
        self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
        self._quantizable_grad_ops = [
            '%s_grad' % (op) for op in self._quantizable_ops
        ]
        self._fake_quant_op_types = [
            'fake_quantize_abs_max', 'fake_quantize_range_abs_max'
        ]
        self._fake_dequant_op_types = ['fake_dequantize_max_abs']
        self._is_test = None
        self._global_step = None

    def apply(self, graph):
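        """
        Rewrite the graph in place: insert a fake_quantize/fake_dequantize
        pair in front of every input of the quantizable ops ('conv2d',
        'depthwise_conv2d', 'mul'), relink the matching grad ops to the
        dequantized variables, and run the collected initializers in
        self._scope. Returns the rewritten graph.
        """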
        assert isinstance(graph,
                          IrGraph), 'graph must be an instance of IrGraph.'
        self._need_initialized.clear()
        self._is_test = graph.is_test()
        # Record the variables that have been dequantized.
        dequantized_vars = collections.OrderedDict()
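        # Parameter names are used below to tell weights apart from
        # activations when choosing quantization bits and type.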
        params = [p.name() for p in graph.all_parameters()]

        def _transform_forward(graph, op):
            for var_node in op.inputs:
                if var_node.name() in dequantized_vars:
                    dequant_var_node = dequantized_vars[var_node.name()]
                else:
                    quant_bits = self._weight_bits if var_node.name() in params \
                        else self._activation_bits
                    quant_type = self._weight_quantize_type if var_node.name() \
                        in params else self._activation_quantize_type
                    quant_var_node, scale_var_node = self._insert_quant_op(
                        graph, var_node, quant_bits, quant_type)
                    dequant_var_node = self._insert_dequant_op(
                        graph, quant_var_node, scale_var_node, quant_bits)
                    dequantized_vars[var_node.name()] = dequant_var_node
                graph.update_input_link(var_node, dequant_var_node, op)

        def _transform_backward(graph, op):
            no_dequantized_input_vars = True
            for var_node in op.inputs:
                if var_node.name() in dequantized_vars:
                    dequant_var_node = dequantized_vars[var_node.name()]
                    graph.update_input_link(var_node, dequant_var_node, op)
                    no_dequantized_input_vars = False
            if no_dequantized_input_vars:
                raise ValueError("There are no dequantized inputs for op %s." %
                                 (op.name()))

        if not self._is_test:
            self._create_global_step(graph)
        ops = graph.all_ops()
        # _transform_forward and _transform_backward must run in two separate
        # loops: quantize the forward ops first, then relink the backward ops.
        # The loop for transforming the forward graph:
        for op in ops:
            if op.name() in self._quantizable_ops:
                _transform_forward(graph, op)
        # The loop for renaming the inputs of the backward ops.
        for op in ops:
            if op.name() in self._quantizable_grad_ops:
                _transform_backward(graph, op)

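        # Initialize all variables collected in self._need_initialized
        # (e.g. the step counter and 'range_abs_max' scales) by building a
        # small init program and running it once in the given scope.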
        if len(self._need_initialized) > 0:
            assert self._scope is not None, \
                'The scope cannot be None when activation_quantize_type is range_abs_max.'
            assert self._program_exe is not None, \
                'The program_exe cannot be None when activation_quantize_type is range_abs_max.'
            init_program = Program()
            for var_desc, initializer in self._need_initialized.iteritems():
                var = Variable(init_program.global_block())
                var._set_desc(var_desc)
                initializer(var, init_program.global_block())
            self._program_exe.run(program=init_program, scope=self._scope)

        return graph

    def _create_global_step(self, graph):
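        """
        Create (or reuse) the '@STEP_COUNTER@' variable together with an
        'increment' op that advances it every iteration. The counter is only
        needed for 'range_abs_max' quantization, where it drives the 'Iter'
        input of fake_quantize_range_abs_max.
        """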
        if self._weight_quantize_type == 'range_abs_max' or \
                self._activation_quantize_type == 'range_abs_max':
            counter_name = '@STEP_COUNTER@'
            for node in graph.all_vars():
                if node.name() == counter_name:
                    self._global_step = node
            if self._global_step is None:
                global_step_in = graph.create_param_node(
                    name=counter_name,
                    var_type=core.VarDesc.VarType.LOD_TENSOR,
                    shape=[1],
                    var_dtype=core.VarDesc.VarType.INT64)
                self._need_initialized[global_step_in.var()] = \
                    Constant(value=0, force_cpu=True)
                global_step_out = graph.create_var_node_from_desc(
                    global_step_in.var())
                increment_op = graph.create_op_node(
                    op_type='increment',
                    attrs={'step': 1.0},
                    inputs={'X': global_step_in},
                    outputs={'Out': global_step_out})
                graph.link_to(global_step_in, increment_op)
                graph.link_to(increment_op, global_step_out)
                self._global_step = global_step_out

    def _insert_quant_op(self, graph, var_node, quant_bits, quant_type):
        """
        Insert fake_quantize_op in the graph.
        """
        if quant_type == 'abs_max':
            return self._insert_quant_abs_max_op(graph, var_node, quant_bits)
        elif quant_type == 'range_abs_max':
            return self._insert_quant_range_abs_max_op(graph, var_node,
                                                       quant_bits)

    def _insert_quant_abs_max_op(self, graph, var_node, quant_bits):
        """
        Insert fake_quantize_abs_max op in the graph.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(var_node.name()),
            var_type=var_node.var().type(),
            shape=var_node.var().shape(),
            var_dtype=var_node.var().dtype())
        scale_var_node = graph.create_var_node(
            name=self._quantized_scale_name(var_node.name()),
            var_type=var_node.var().type(),
            shape=var_node.var().shape(),
            var_dtype=var_node.var().dtype())
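        # The fake quant op writes the simulated-quantized tensor to 'Out'
        # and the dynamically computed abs_max scale to 'OutScale'.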
        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_abs_max',
            attrs={'bit_length': quant_bits},
            inputs={'X': var_node},
            outputs={'Out': quant_var_node,
                     'OutScale': scale_var_node})
        graph.link_to(var_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_var_node)
        return quant_var_node, scale_var_node

    def _insert_quant_range_abs_max_op(self, graph, var_node, quant_bits):
        """
        Insert fake_quantize_range_abs_max op in the graph.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(var_node.name()),
            var_type=var_node.var().type(),
            shape=var_node.var().shape(),
            var_dtype=var_node.var().dtype())

        scale_in_node = graph.create_param_node(
            name=self._quantized_scale_name(var_node.name()),
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[1],
            var_dtype=var_node.var().dtype())
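        # The input scale is a single-element parameter; it is registered
        # below to be initialized with a small positive constant.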
        self._need_initialized[scale_in_node.var()] = Constant(value=0.001)

        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
        inputs = {'X': var_node, 'InScale': scale_in_node}
        outputs = {'Out': quant_var_node, 'OutScale': scale_out_node}

        if not self._is_test:
            # The name of scales_node may be 'scales_0', 'scales_1', etc.
            scales_node = graph.create_param_node(
                name=unique_name.generate('scales'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                shape=[self._window_size],
                var_dtype=var_node.var().dtype())
            self._need_initialized[scales_node.var()] = Constant(value=0)
            inputs['Iter'] = self._global_step
            outputs['OutScales'] = scales_node
        attrs = {
            'window_size': self._window_size,
            'bit_length': quant_bits,
260
            'is_test': self._is_test
        }
        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_range_abs_max',
            attrs=attrs,
            inputs=inputs,
            outputs=outputs)

        graph.link_to(var_node, quant_op_node)
        graph.link_to(scale_in_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_out_node)

        if not self._is_test:
            graph.link_to(self._global_step, quant_op_node)
            graph.link_to(quant_op_node, scales_node)

        return quant_var_node, scale_out_node

    def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits):
        """
        Insert fake_dequantize_op in the graph.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(var_node.name()),
            var_type=var_node.var().type(),
            shape=var_node.var().shape(),
            var_dtype=var_node.var().dtype())
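        # max_range = 2^(quant_bits - 1) - 1, e.g. 127 for 8-bit quantization.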
        max_range = (1 << (quant_bits - 1)) - 1
        dequant_op_node = graph.create_op_node(
            op_type='fake_dequantize_max_abs',
            attrs={'max_range': float(max_range)},
            inputs={'X': var_node,
                    'Scale': scale_var_node},
            outputs={'Out': dequant_var_node})
        graph.link_to(var_node, dequant_op_node)
        graph.link_to(scale_var_node, dequant_op_node)
        graph.link_to(dequant_op_node, dequant_var_node)
        return dequant_var_node

    def _quantized_var_name(self, var_name):
        """
        Return quantized variable name for the input `var_name`.
        """
        return "%s.quantized" % (var_name)

    def _dequantized_var_name(self, var_name):
        """
        Return dequantized variable name for the input `var_name`.
        """
        return "%s.dequantized" % (var_name)

    def _quantized_scale_name(self, var_name):
        """
        Return the scale name of the quantized variable for the input `var_name`.
        """
        return "%s.scale" % (var_name)