# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import sys
import numpy as np
from .... import Executor
from .... import io
from .... import core, scope_guard
from ....compiler import CompiledProgram
from ....compiler import BuildStrategy
from ....framework import IrGraph, Variable, Program
from ....log_helper import get_logger
from ..core.strategy import Strategy
from .quantization_pass import *

__all__ = ['QuantizationStrategy']

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')


class QuantizationStrategy(Strategy):
    """
    The strategy for quantization-aware training: it inserts fake quantize and
    dequantize ops into the graph at start_epoch, and saves the quantized
    inference models at end_epoch.
    """

    def __init__(self,
                 start_epoch=0,
                 end_epoch=0,
                 float_model_save_path=None,
                 mobile_model_save_path=None,
                 int8_model_save_path=None,
                 activation_bits=8,
                 weight_bits=8,
                 activation_quantize_type='abs_max',
                 weight_quantize_type='abs_max',
                 save_in_nodes=None,
                 save_out_nodes=None):
        """
        Args:
            start_epoch(int): The 'on_epoch_begin' function will be called at start_epoch. default: 0.
            end_epoch(int): The 'on_epoch_end' function will be called at end_epoch. default: 0.
            float_model_save_path(str): The path to save the model with float weights.
                            None means the float model is not saved. default: None.
            mobile_model_save_path(str): The path to save the model for paddle-mobile execution.
                            None means the mobile model is not saved. default: None.
            int8_model_save_path(str): The path to save the model with int8_t weights.
                            None means the int8 model is not saved. default: None.
            activation_bits(int): quantization bit number for activations. default: 8.
            weight_bits(int): quantization bit number for weights. The bias is not quantized.
                              default: 8.
            activation_quantize_type(str): quantization type for activations,
                now supporting 'abs_max', 'range_abs_max' and 'moving_average_abs_max'.
                In 'abs_max' mode, the quantization scale is calculated dynamically
                at each step during both training and testing (see the numeric
                sketch after this docstring). In 'range_abs_max' mode, a static
                quantization scale is calculated during training and used in
                inference. default: 'abs_max'.
            weight_quantize_type(str): quantization type for weights, supporting
                'abs_max' and 'channel_wise_abs_max'. 'range_abs_max' is usually
                not used for weights, since weights are fixed once the model is
                well trained. default: 'abs_max'.
            save_in_nodes(list<str>): A list of variable names used to prune the
                                      graph for saving the inference model. default: None.
            save_out_nodes(list<str>): A list of variable names used to prune the
                                      graph for saving the inference model. default: None.

        """
        super(QuantizationStrategy, self).__init__(start_epoch, end_epoch)
        self.start_epoch = start_epoch
        self.end_epoch = end_epoch
        self.float_model_save_path = float_model_save_path
        self.mobile_model_save_path = mobile_model_save_path
        self.int8_model_save_path = int8_model_save_path
        self.activation_bits = activation_bits
        self.weight_bits = weight_bits
        self.activation_quantize_type = activation_quantize_type
87
        self.weight_quantize_type = weight_quantize_type
88 89 90
        self.save_out_nodes = save_out_nodes
        self.save_in_nodes = save_in_nodes
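
    # A minimal usage sketch. How the strategy is wired into the surrounding
    # compression framework is an assumption and not defined in this file:
    #
    #   strategy = QuantizationStrategy(
    #       start_epoch=0,
    #       end_epoch=10,
    #       float_model_save_path='./output/float',
    #       int8_model_save_path='./output/int8',
    #       activation_quantize_type='moving_average_abs_max',
    #       weight_quantize_type='abs_max')
    #
    # The framework is expected to call on_epoch_begin/on_epoch_end around the
    # training loop, and restore_from_checkpoint when resuming.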

    def restore_from_checkpoint(self, context):
        """
        Restore the graph when the compression task is initialized from a checkpoint.
        """
        # The task was initialized from a checkpoint and has already passed the start epoch.
        if context.epoch_id != 0 and context.epoch_id > self.start_epoch:
            _logger.info("Restore quantization task from checkpoint")
            self._modify_graph_for_quantization(context)
            _logger.info("Finish restoring quantization task from checkpoint")

    def _modify_graph_for_quantization(self, context):
        """
翟飞跃 已提交
103
        Insert fake_quantize_op and fake_dequantize_op before training and testing.
W
wanghaoshuang 已提交
104 105 106 107 108 109 110 111 112 113 114
        """
        train_ir_graph = IrGraph(
            core.Graph(context.optimize_graph.program.clone().desc),
            for_test=False)
        test_ir_graph = IrGraph(
            core.Graph(context.eval_graph.program.clone().desc), for_test=True)
        transform_pass = QuantizationTransformPass(
            scope=context.scope,
            place=context.place,
            weight_bits=self.weight_bits,
            activation_bits=self.activation_bits,
            activation_quantize_type=self.activation_quantize_type,
            weight_quantize_type=self.weight_quantize_type)
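        # QuantizationTransformPass inserts the fake quantize/dequantize ops
        # in front of the quantizable ops of both graphs.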
        transform_pass.apply(train_ir_graph)
        transform_pass.apply(test_ir_graph)
        # Put persistables created by transform_pass into context.optimize_graph.persistables
        # for saving checkpoint.
        program_persistables = set()
        for var in context.optimize_graph.program.list_vars():
            if var.persistable:
                program_persistables.add(var.name)

        program = Program()
        for var_node in train_ir_graph.all_persistable_nodes():
            if var_node.name() not in program_persistables:
                var_desc = var_node.var()
                var = program.global_block().create_var(
                    name=var_node.name(),
                    shape=var_desc.shape(),
                    dtype=var_desc.dtype(),
                    type=var_desc.type(),
                    lod_level=var_desc.lod_level())
                context.optimize_graph.persistables[var.name] = var

        build_strategy = BuildStrategy()
        build_strategy.enable_inplace = False
        build_strategy.memory_optimize = False
        build_strategy.fuse_all_reduce_ops = False
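        # (Assumption) These graph optimizations are disabled because they may
        # reuse or rewrite variables that the inserted fake quantize ops and
        # their scale variables depend on.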
        # for quantization training
        context.optimize_graph.compiled_graph = CompiledProgram(
            train_ir_graph.graph).with_data_parallel(
                loss_name=context.optimize_graph.out_nodes['loss'],
                build_strategy=build_strategy)

        context.eval_graph.program = test_ir_graph.to_program()

        # for saving inference model after training
        context.put('quantization_test_ir_graph_backup', test_ir_graph)

    def on_epoch_begin(self, context):
        """
翟飞跃 已提交
155
        Insert fake_quantize_op and fake_dequantize_op before training and testing.
156
        """
W
wanghaoshuang 已提交
157
        super(QuantizationStrategy, self).on_epoch_begin(context)
        if self.start_epoch == context.epoch_id:
            _logger.info('QuantizationStrategy::on_epoch_begin')
            self._modify_graph_for_quantization(context)
            _logger.info('Finish QuantizationStrategy::on_epoch_begin')

    def on_epoch_end(self, context):
        """
        Freeze the graph and save the inference models.
        """
        super(QuantizationStrategy, self).on_epoch_end(context)

        if context.epoch_id == self.end_epoch:
            _logger.info('QuantizationStrategy::on_epoch_end')
            test_ir_graph = context.get('quantization_test_ir_graph_backup')
            # Freeze the graph after training: convert the fake quantization
            # ops and the trained scales into their inference-time form.
            freeze_pass = QuantizationFreezePass(
                scope=context.scope,
                place=context.place,
                weight_bits=self.weight_bits,
                activation_bits=self.activation_bits,
                weight_quantize_type=self.weight_quantize_type)
            freeze_pass.apply(test_ir_graph)

            # Update the eval graph so that subsequent strategies see the frozen program.
            context.eval_graph.program = test_ir_graph.to_program()

            if self.save_out_nodes is None:
                out_vars = [
                    context.eval_graph.var(var_name)._var
                    for var_name in context.eval_graph.out_nodes.values()
                ]
            else:
                out_vars = [
                    context.eval_graph.var(var_name)._var
                    for var_name in self.save_out_nodes
                ]

            if self.save_in_nodes is None:
                in_vars = list(context.eval_graph.in_nodes.values())
            else:
                in_vars = self.save_in_nodes

            # save float model
            if self.float_model_save_path:
                executor = Executor(context.place)
                with scope_guard(context.scope):
                    io.save_inference_model(
                        self.float_model_save_path,
                        in_vars,
                        out_vars,
                        executor,
                        main_program=test_ir_graph.to_program(),
                        model_filename='model',
                        params_filename='weights',
                        export_for_deployment=True)
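                # The saved model can be reloaded later; a sketch, assuming the
                # standard fluid API (not used in this file):
                #
                #   fluid.io.load_inference_model(
                #       self.float_model_save_path, executor,
                #       model_filename='model', params_filename='weights')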

            # save int8 model
            if self.int8_model_save_path:
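                # (Assumption) ConvertToInt8Pass re-stores the trained float
                # weights as int8_t so the saved parameter files are smaller.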
                convert_int8_pass = ConvertToInt8Pass(
                    scope=context.scope, place=context.place)
                convert_int8_pass.apply(test_ir_graph)

                executor = Executor(context.place)

                with scope_guard(context.scope):
                    io.save_inference_model(
                        self.int8_model_save_path,
                        in_vars,
                        out_vars,
                        executor,
                        main_program=test_ir_graph.to_program(),
                        model_filename='model',
                        params_filename='weights',
                        export_for_deployment=True)

            # save mobile model
            if self.mobile_model_save_path:
                if not self.int8_model_save_path:
                    # convert the weights to int8_t type
                    convert_int8_pass = ConvertToInt8Pass(
                        scope=context.scope, place=context.place)
                    convert_int8_pass.apply(test_ir_graph)
                # make the graph changes required for paddle-mobile inference
                mobile_pass = TransformForMobilePass()
                mobile_pass.apply(test_ir_graph)
                executor = Executor(context.place)
                with scope_guard(context.scope):
                    io.save_inference_model(
                        self.mobile_model_save_path,
                        in_vars,
                        out_vars,
                        executor,
                        main_program=test_ir_graph.to_program(),
                        model_filename='model',
                        params_filename='weights',
                        export_for_deployment=True)
            _logger.info('Finish QuantizationStrategy::on_epoch_end')