#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import numpy as np
from .... import core
from ....initializer import Constant
from .... import unique_name
from ..graph import PyGraph

__all__ = ['QuantizationPerformer']


class QuantizationPerformer(object):
    def __init__(self,
                 weight_bits=8,
                 activation_bits=8,
                 activation_quantize_type='abs_max',
                 weight_quantize_type='abs_max',
                 window_size=10000):
        """
        Convert and rewrite the input graph in place according to the weight
        and activation quantization types.
        Args:
            weight_bits (int): quantization bit number for weights,
                the bias is not quantized.
            activation_bits (int): quantization bit number for activation.
            activation_quantize_type (str): quantization type for activations;
                'abs_max' and 'range_abs_max' are supported. With 'abs_max',
                the quantization scale is computed dynamically at each step
                during both training and inference. With 'range_abs_max',
                a static quantization scale is computed during training and
                reused at inference.
            weight_quantize_type (str): quantization type for weights;
                'abs_max' is supported. 'range_abs_max' is usually not used
                for weights, since weights are fixed once the model is trained.
            window_size (int): the window size for 'range_abs_max' quantization.
        Examples:
        .. code-block:: python
            # The original graph is rewritten in place. If you do not want
            # to change it, clone it first:
            # graph = graph.clone()
            from paddle.fluid.contrib.slim import *
            from paddle.fluid.contrib.quantize import *
            graph = IRGraph(program)
            performer = QuantizationPerformer()
            performer.quantize_transform(graph, is_test=False)
        """
        self.weight_bits = weight_bits
        self.activation_bits = activation_bits

        quant_type = ['abs_max', 'range_abs_max']
        if activation_quantize_type not in quant_type:
            raise ValueError(
                "Unknown activation_quantize_type: '%s'. It can only be "
                "'abs_max' or 'range_abs_max'." % str(activation_quantize_type))
        if weight_quantize_type not in quant_type:
            raise ValueError(
                "Unknown weight_quantize_type: '%s'. It can only be "
                "'abs_max' or 'range_abs_max'." % str(weight_quantize_type))

        self.activation_quantize_type = activation_quantize_type
        self.weight_quantize_type = weight_quantize_type
        self.window_size = window_size

        self.need_inited_outer = collections.OrderedDict()
        self.quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
        self.quantizable_grad_ops = [
            '%s_grad' % (op) for op in self.quantizable_ops
        ]
        self.fake_quant_op_types = [
            'fake_quantize_abs_max', 'fake_quantize_range_abs_max'
        ]
        self.fake_dequant_op_types = ['fake_dequantize_max_abs']
        self.is_test = None
        self.global_step = None

    def quantize_transform(self, graph, is_test):
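        """
        Rewrite `graph` in place: insert fake quantize/dequantize ops before
        every quantizable op ('conv2d', 'depthwise_conv2d', 'mul') and rename
        the corresponding inputs of their gradient ops.

        Args:
            graph (PyGraph): the graph to transform; it is modified in place.
            is_test (bool): True for inference graphs. In that case no global
                step counter or scale history is created.

        Returns:
            OrderedDict: variable descs that still need to be initialized by
            the caller, mapped to their Constant initializers.
        """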
        self.need_inited_outer.clear()
        self.is_test = is_test
        assert isinstance(graph,
                          PyGraph), 'graph must be an instance of PyGraph.'
        # Mark the variables that have already been dequantized.
        dequantized_vars = collections.OrderedDict()
        params = [p.name() for p in graph.all_parameters()]

        def _transform_forward(graph, op):
            for var_node in op.inputs:
                if var_node.name() in dequantized_vars:
                    dequant_var_node = dequantized_vars[var_node.name()]
                else:
                    quant_bits = self.weight_bits if var_node.name() in params \
                        else self.activation_bits
                    quant_type = self.weight_quantize_type if var_node.name() \
                        in params else self.activation_quantize_type
                    quant_var_node, scale_var_node = self._insert_quant_op(
                        graph, var_node, quant_bits, quant_type)
                    dequant_var_node = self._insert_dequant_op(
                        graph, quant_var_node, scale_var_node, quant_bits)
                    dequantized_vars[var_node.name()] = dequant_var_node
                self._update_input(var_node, dequant_var_node, op)
                op.op()._rename_input(var_node.name(), dequant_var_node.name())

        def _transform_backward(graph, op):
            no_dequantized_input_vars = True
            for var_node in op.inputs:
                if var_node.name() in dequantized_vars:
                    dequant_var_node = dequantized_vars[var_node.name()]
                    self._update_input(var_node, dequant_var_node, op)
                    op.op()._rename_input(var_node.name(),
                                          dequant_var_node.name())
                    no_dequantized_input_vars = False
            if no_dequantized_input_vars:
                raise ValueError("There are no dequantized inputs for op %s." %
                                 (op.name()))

        if not self.is_test:
            self._create_global_step(graph)
        ops = graph.all_ops()
        # _transform_forward and _transform_backward must run in two separate
        # loops: the forward graph is transformed first, then the inputs of the
        # corresponding backward (gradient) ops are renamed.
        # The loop for transforming the forward graph:
        for op in ops:
            if op.name() in self.quantizable_ops:
                _transform_forward(graph, op)
        # The loop for renaming the inputs of backward ops.
        for op in ops:
            if op.name() in self.quantizable_grad_ops:
                _transform_backward(graph, op)

        return self.need_inited_outer

    def _create_global_step(self, graph):
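        """
        Create (or reuse) the '@STEP_COUNTER@' variable together with an
        `increment` op when 'range_abs_max' quantization is used; the counter
        indexes the scale window during training.
        """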
        if self.weight_quantize_type == 'range_abs_max' or \
                self.activation_quantize_type == 'range_abs_max':
            counter_name = '@STEP_COUNTER@'
            for node in graph.all_vars():
                if node.name() == counter_name:
                    self.global_step = node
            if self.global_step is None:
                global_step_in = graph.create_param_node(
                    name=counter_name,
                    var_type=core.VarDesc.VarType.LOD_TENSOR,
                    shape=[1],
                    var_dtype=core.VarDesc.VarType.INT64)
                self.need_inited_outer[global_step_in.var()] = \
                    Constant(value=0, force_cpu=True)
                global_step_out = graph.create_var_node_from_desc(
                    global_step_in.var())
                increment_op = graph.create_op_node(
                    op_type='increment',
                    attrs={'step': 1.0},
                    inputs={'X': global_step_in},
                    outputs={'Out': global_step_out})
                self._link_to(global_step_in, increment_op)
                self._link_to(increment_op, global_step_out)
                self.global_step = global_step_out

    def _insert_quant_op(self, graph, var_node, quant_bits, quant_type):
        """
        Insert fake_quantize_op in the graph.
        """
        if quant_type == 'abs_max':
            return self._insert_quant_abs_max_op(graph, var_node, quant_bits)
        elif quant_type == 'range_abs_max':
            return self._insert_quant_range_abs_max_op(graph, var_node,
                                                       quant_bits)

    def _insert_quant_abs_max_op(self, graph, var_node, quant_bits):
        """
        Insert fake_quantize_abs_max op in the graph.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(var_node.name()),
            var_type=var_node.var().type(),
            shape=var_node.var().shape(),
            var_dtype=var_node.var().dtype())
        scale_var_node = graph.create_var_node(
            name=self._quantized_scale_name(var_node.name()),
            var_type=var_node.var().type(),
            shape=var_node.var().shape(),
            var_dtype=var_node.var().dtype())
        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_abs_max',
            attrs={'bit_length': quant_bits},
            inputs={'X': var_node},
            outputs={'Out': quant_var_node,
                     'OutScale': scale_var_node})
        self._link_to(var_node, quant_op_node)
        self._link_to(quant_op_node, quant_var_node)
        self._link_to(quant_op_node, scale_var_node)
        return quant_var_node, scale_var_node

    def _insert_quant_range_abs_max_op(self, graph, var_node, quant_bits):
        """
        Insert fake_quantize_range_abs_max op in the graph.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        quant_var_node = graph.create_var_node(
            name=self._quantized_var_name(var_node.name()),
            var_type=var_node.var().type(),
            shape=var_node.var().shape(),
            var_dtype=var_node.var().dtype())

        scale_in_node = graph.create_param_node(
            name=self._quantized_scale_name(var_node.name()),
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[1],
            var_dtype=var_node.var().dtype())
        self.need_inited_outer[scale_in_node.var()] = Constant(value=0.001)

        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
        inputs = {'X': var_node, 'InScale': scale_in_node}
        outputs = {'Out': quant_var_node, 'OutScale': scale_out_node}

        if not self.is_test:
            # The name of scales_node may be 'scales_0', 'scales_1', etc.
            scales_node = graph.create_param_node(
                name=unique_name.generate('scales'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                shape=[self.window_size],
                var_dtype=var_node.var().dtype())
            self.need_inited_outer[scales_node.var()] = Constant(value=0)
            inputs['Iter'] = self.global_step
            outputs['OutScales'] = scales_node
        attrs = {
            'window_size': self.window_size,
            'bit_length': quant_bits,
            'is_test': self.is_test
        }
        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_range_abs_max',
            attrs=attrs,
            inputs=inputs,
            outputs=outputs)

        self._link_to(var_node, quant_op_node)
        self._link_to(scale_in_node, quant_op_node)
        self._link_to(quant_op_node, quant_var_node)
        self._link_to(quant_op_node, scale_out_node)

        if not self.is_test:
            self._link_to(self.global_step, quant_op_node)
            self._link_to(quant_op_node, scales_node)

        return quant_var_node, scale_out_node

    def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits):
        """
        Insert fake_dequantize_op in the graph.
        """
        assert var_node.is_var(), '{} is not a var'.format(var_node.name())

        dequant_var_node = graph.create_var_node(
            name=self._dequantized_var_name(var_node.name()),
            var_type=var_node.var().type(),
            shape=var_node.var().shape(),
            var_dtype=var_node.var().dtype())
        max_range = (1 << (quant_bits - 1)) - 1
        dequant_op_node = graph.create_op_node(
            op_type='fake_dequantize_max_abs',
            attrs={'max_range': float(max_range)},
            inputs={'X': var_node,
                    'Scale': scale_var_node},
            outputs={'Out': dequant_var_node})
        self._link_to(var_node, dequant_op_node)
        self._link_to(scale_var_node, dequant_op_node)
        self._link_to(dequant_op_node, dequant_var_node)
        return dequant_var_node

    def _update_input(self, old_input_node, new_input_node, op_node):
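        """
        Rewire `op_node` so that it reads from `new_input_node` instead of
        `old_input_node`.
        """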
        old_input_node.outputs_remove(op_node)
        op_node.inputs_remove(old_input_node)
        new_input_node.outputs_append(op_node)
        op_node.inputs_append(new_input_node)

    def _link_to(self, node_in, node_out):
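        """
        Add a graph edge from `node_in` to `node_out`.
        """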
        node_in.outputs_append(node_out)
        node_out.inputs_append(node_in)

    def _quantized_var_name(self, var_name):
        """
        Return quantized variable name for the input `var_name`.
        """
        return "%s.quantized" % (var_name)

    def _dequantized_var_name(self, var_name):
        """
        Return dequantized variable name for the input `var_name`.
        """
        return "%s.dequantized" % (var_name)

    def _quantized_scale_name(self, var_name):
        """
        Return the scale variable name for the input `var_name`.
        """
        return "%s.scale" % (var_name)

    def _original_var_name(self, var_name):
        """
        Return the original variable name.
        """
        if var_name.endswith('.quantized.dequantized'):
            return var_name[:-len('.quantized.dequantized')]
        if var_name.endswith('.quantized'):
            return var_name[:-len('.quantized')]
        if var_name.endswith('.dequantized'):
            return var_name[:-len('.dequantized')]
        if var_name.endswith('.scale'):
            return var_name[:-len('.scale')]
        else:
            return var_name

    def _is_float(self, v):
        return isinstance(v, (float, np.float32))

    def _quant(self, x, scale, num_bits):
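        """
        Reference (numpy) quantization: round(x / scale * (2^(num_bits-1) - 1)).
        """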
        y = np.round(x / scale * ((1 << (num_bits - 1)) - 1))
        return y
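

# A minimal usage sketch (hedged: `main_program` and the PyGraph construction
# below are assumptions for illustration only; the exact PyGraph constructor
# may differ):
#
#     graph = PyGraph(core.Graph(main_program.desc))
#     performer = QuantizationPerformer(
#         activation_quantize_type='range_abs_max')
#     need_inited = performer.quantize_transform(graph, is_test=False)
#     # The caller is responsible for running the returned initializers,
#     # e.g. the '@STEP_COUNTER@' counter and the scale history buffers.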