#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import numpy as np
from .... import core
from ....framework import Program
from ....framework import Variable
from ....initializer import Constant
from .... import unique_name
from ..graph import PyGraph

__all__ = ['QuantizationTransformPass']


class QuantizationTransformPass(object):
    def __init__(self,
                 scope=None,
                 program_exe=None,
                 weight_bits=8,
                 activation_bits=8,
                 activation_quantize_type='abs_max',
                 weight_quantize_type='abs_max',
                 window_size=10000):
        """
        Convert and rewrite the PyGraph according to weight and
        activation quantization type.
        Args:
            weight_bits (int): quantization bit number for weights,
                the bias is not quantized.
            activation_bits (int): quantization bit number for activation.
            activation_quantize_type (str): quantization type for activations,
                now supporting 'abs_max' and 'range_abs_max'. With 'abs_max',
                the quantization scale is calculated dynamically at each step
                during both training and testing. With 'range_abs_max', a
                static quantization scale is calculated during training and
                then used in inference.
            weight_quantize_type (str): quantization type for weights,
                supporting 'abs_max'. 'range_abs_max' is usually not used for
                weights, since weights are fixed once the model is well trained.
            window_size (int): the window size for 'range_abs_max' quantization.
        Examples:
        .. code-block:: python
            # The original graph will be rewritten.
            import paddle.fluid as fluid
            from paddle.fluid.contrib.slim.quantization \
                import QuantizationTransformPass
            from paddle.fluid.contrib.slim.graph import PyGraph
            from paddle.fluid import core

            graph = PyGraph(core.Graph(program.desc), for_test=False)
            exe = fluid.Executor(fluid.CPUPlace())
            transform_pass = QuantizationTransformPass(
                fluid.global_scope(), exe)
            transform_pass.apply(graph)
        """
        self.scope = scope
        self.program_exe = program_exe
        self.weight_bits = weight_bits
        self.activation_bits = activation_bits

        quant_type = ['abs_max', 'range_abs_max']
        if activation_quantize_type not in quant_type:
            raise ValueError(
                "Unknown activation_quantize_type: '%s'. It can only be "
                "'abs_max' or 'range_abs_max'." % str(activation_quantize_type))
        if weight_quantize_type not in quant_type:
            raise ValueError(
                "Unknown weight_quantize_type: '%s'. It can only be "
                "'abs_max' or 'range_abs_max'." % str(weight_quantize_type))

        self.activation_quantize_type = activation_quantize_type
        self.weight_quantize_type = weight_quantize_type
        self.window_size = window_size

        self.need_initialized = collections.OrderedDict()
        self.quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
        self.quantizable_grad_ops = [
            '%s_grad' % (op) for op in self.quantizable_ops
        ]
        self.fake_quant_op_types = [
            'fake_quantize_abs_max', 'fake_quantize_range_abs_max'
        ]
        self.fake_dequant_op_types = ['fake_dequantize_max_abs']
        self.is_test = None
        self.global_step = None

    def apply(self, graph):
        assert isinstance(graph,
                          PyGraph), 'graph must be an instance of PyGraph.'
        self.need_initialized.clear()
        self.is_test = graph.is_test()
        # Record the variables that have been dequantized.
        dequantized_vars = collections.OrderedDict()
        params = [p.name() for p in graph.all_parameters()]

        def _transform_forward(graph, op):
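            # For every input of a quantizable op, insert a quantize op
            # followed by a dequantize op, then feed the dequantized variable
            # back into the op. Keeping the graph in floating point while
            # injecting the quantization (rounding) error is the usual
            # fake-quantization scheme for quantization-aware training.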
            for var_node in op.inputs:
                if var_node.name() in dequantized_vars:
                    dequant_var_node = dequantized_vars[var_node.name()]
                else:
                    quant_bits = self.weight_bits if var_node.name() in params \
                    else self.activation_bits
                    quant_type = self.weight_quantize_type if var_node.name() \
                        in params else self.activation_quantize_type
                    quant_var_node, scale_var_node = self._insert_quant_op(
                        graph, var_node, quant_bits, quant_type)
                    dequant_var_node = self._insert_dequant_op(
                        graph, quant_var_node, scale_var_node, quant_bits)
                    dequantized_vars[var_node.name()] = dequant_var_node
                self._update_input(var_node, dequant_var_node, op)
                op.op()._rename_input(var_node.name(), dequant_var_node.name())

        def _transform_backward(graph, op):
            no_dequantized_input_vars = True
            for var_node in op.inputs:
                if var_node.name() in dequantized_vars:
                    dequant_var_node = dequantized_vars[var_node.name()]
                    self._update_input(var_node, dequant_var_node, op)
                    op.op()._rename_input(var_node.name(),
                                          dequant_var_node.name())
                    no_dequantized_input_vars = False
            if no_dequantized_input_vars:
                raise ValueError("There are no dequantized inputs for op %s." %
                                 (op.name()))

        if not self.is_test:
            self._create_global_step(graph)
        ops = graph.all_ops()
        # _transform_forward and _transform_backward are applied in two
        # separate loops. The first loop transforms the forward graph:
        for op in ops:
            if op.name() in self.quantizable_ops:
                _transform_forward(graph, op)
        # The second loop renames the inputs of the backward ops.
        for op in ops:
            if op.name() in self.quantizable_grad_ops:
                _transform_backward(graph, op)

        if len(self.need_initialized) > 0:
            assert self.scope is not None, \
                'The scope cannot be None when activation_quantize_type is range_abs_max.'
            assert self.program_exe is not None, \
                'The program_exe cannot be None when activation_quantize_type is range_abs_max.'
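            # Build a small one-off program holding the newly added persistable
            # variables (step counter and scale buffers) together with their
            # initializers, then run it once with program_exe.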
            init_program = Program()
            for var_desc, initializer in self.need_initialized.items():
                var = Variable.construct_from_desc(init_program.global_block(),
                                                   var_desc)
                initializer(var, init_program.global_block())
            self.program_exe.run(program=init_program, scope=self.scope)

        return graph

    def _create_global_step(self, graph):
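        # 'range_abs_max' quantization needs a global step counter: it is fed
        # to fake_quantize_range_abs_max as its 'Iter' input so that the op
        # can track training steps and update its scale window accordingly.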
        if self.weight_quantize_type == 'range_abs_max' or \
                self.activation_quantize_type == 'range_abs_max':
            counter_name = '@STEP_COUNTER@'
            for node in graph.all_vars():
                if node.name() == counter_name:
                    self.global_step = node
            if self.global_step is None:
                global_step_in = graph.create_param_node(
                    name=counter_name,
                    var_type=core.VarDesc.VarType.LOD_TENSOR,
                    shape=[1],
                    var_dtype=core.VarDesc.VarType.INT64)
                self.need_initialized[global_step_in.var()] = \
                    Constant(value=0, force_cpu=True)
                global_step_out = graph.create_var_node_from_desc(
                    global_step_in.var())
                increment_op = graph.create_op_node(
                    op_type='increment',
                    attrs={'step': 1.0},
                    inputs={'X': global_step_in},
                    outputs={'Out': global_step_out})
                self._link_to(global_step_in, increment_op)
                self._link_to(increment_op, global_step_out)
                self.global_step = global_step_out

    def _insert_quant_op(self, graph, var_node, quant_bits, quant_type):
        """
        Insert fake_quantize_op in the graph.
        """
        if quant_type == 'abs_max':
            return self._insert_quant_abs_max_op(graph, var_node, quant_bits)
        elif quant_type == 'range_abs_max':
            return self._insert_quant_range_abs_max_op(graph, var_node,
                                                       quant_bits)

    def _insert_quant_abs_max_op(self, graph, var_node, quant_bits):
        """
        Insert fake_quantize_abs_max op in the graph.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(var_node.name()),
            var_type=var_node.var().type(),
            shape=var_node.var().shape(),
            var_dtype=var_node.var().dtype())
        scale_var_node = graph.create_var_node(
            name=self._quantized_scale_name(var_node.name()),
            var_type=var_node.var().type(),
            shape=var_node.var().shape(),
            var_dtype=var_node.var().dtype())
        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_abs_max',
            attrs={'bit_length': quant_bits},
            inputs={'X': var_node},
            outputs={'Out': quant_var_node,
                     'OutScale': scale_var_node})
        self._link_to(var_node, quant_op_node)
        self._link_to(quant_op_node, quant_var_node)
        self._link_to(quant_op_node, scale_var_node)
        return quant_var_node, scale_var_node

    def _insert_quant_range_abs_max_op(self, graph, var_node, quant_bits):
        """
        Insert fake_quantize_range_abs_max on the graph.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(var_node.name()),
            var_type=var_node.var().type(),
            shape=var_node.var().shape(),
            var_dtype=var_node.var().dtype())

        scale_in_node = graph.create_param_node(
            name=self._quantized_scale_name(var_node.name()),
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[1],
            var_dtype=var_node.var().dtype())
        self.need_initialized[scale_in_node.var()] = Constant(value=0.001)

        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
        inputs = {'X': var_node, 'InScale': scale_in_node}
        outputs = {'Out': quant_var_node, 'OutScale': scale_out_node}

        if not self.is_test:
            # The name of scales_node may be 'scales_0', 'scales_1', etc.
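            # Roughly, it buffers the scales of the last `window_size` steps;
            # the op derives the running scale from this window during training.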
            scales_node = graph.create_param_node(
                name=unique_name.generate('scales'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                shape=[self.window_size],
                var_dtype=var_node.var().dtype())
            self.need_initialized[scales_node.var()] = Constant(value=0)
            inputs['Iter'] = self.global_step
            outputs['OutScales'] = scales_node
        attrs = {
            'window_size': self.window_size,
            'bit_length': quant_bits,
            'is_test': self.is_test
        }
        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_range_abs_max',
            attrs=attrs,
            inputs=inputs,
            outputs=outputs)

        self._link_to(var_node, quant_op_node)
        self._link_to(scale_in_node, quant_op_node)
        self._link_to(quant_op_node, quant_var_node)
        self._link_to(quant_op_node, scale_out_node)

        if not self.is_test:
            self._link_to(self.global_step, quant_op_node)
            self._link_to(quant_op_node, scales_node)

        return quant_var_node, scale_out_node

    def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits):
        """
        Insert fake_dequantize_op in the graph.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(var_node.name()),
            var_type=var_node.var().type(),
            shape=var_node.var().shape(),
            var_dtype=var_node.var().dtype())
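        # For b-bit quantization the integer range is
        # [-(2**(b-1) - 1), 2**(b-1) - 1], e.g. [-127, 127] for 8 bits;
        # the dequantized value is X * Scale / max_range.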
        max_range = (1 << (quant_bits - 1)) - 1
        dequant_op_node = graph.create_op_node(
            op_type='fake_dequantize_max_abs',
            attrs={'max_range': float(max_range)},
            inputs={'X': var_node,
                    'Scale': scale_var_node},
            outputs={'Out': dequant_var_node})
        self._link_to(var_node, dequant_op_node)
        self._link_to(scale_var_node, dequant_op_node)
        self._link_to(dequant_op_node, dequant_var_node)
        return dequant_var_node

    def _update_input(self, old_input_node, new_input_node, op_node):
        old_input_node.outputs_remove(op_node)
        op_node.inputs_remove(old_input_node)
        new_input_node.outputs_append(op_node)
        op_node.inputs_append(new_input_node)

    def _link_to(self, node_in, node_out):
        node_in.outputs_append(node_out)
        node_out.inputs_append(node_in)

    def _quantized_var_name(self, var_name):
        """
        Return quantized variable name for the input `var_name`.
        """
        return "%s.quantized" % (var_name)

    def _dequantized_var_name(self, var_name):
        """
        Return dequantized variable name for the input `var_name`.
        """
        return "%s.dequantized" % (var_name)

    def _quantized_scale_name(self, var_name):
        """
        Return the scale name of the quantized variable for the input `var_name`.
        """
        return "%s.scale" % (var_name)

    def _original_var_name(self, var_name):
        """
        Return the original variable name.
        """
        if var_name.endswith('.quantized.dequantized'):
            return var_name[:-len('.quantized.dequantized')]
        if var_name.endswith('.quantized'):
            return var_name[:-len('.quantized')]
        if var_name.endswith('.dequantized'):
            return var_name[:-len('.dequantized')]
        if var_name.endswith('.scale'):
            return var_name[:-len('.scale')]
        else:
            return var_name

    def _is_float(self, v):
        return isinstance(v, float) or isinstance(v, np.float32)

    def _quant(self, x, scale, num_bits):
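        # Map x onto the integer grid [-(2**(num_bits-1) - 1), 2**(num_bits-1) - 1],
        # e.g. _quant(0.5, 2.0, 8) == np.round(0.5 / 2.0 * 127) == 32.0.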
        y = np.round(x / scale * ((1 << (num_bits - 1)) - 1))
        return y