post_quantization.py 12.1 KB
Newer Older
W
wuyefeilin 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.fluid.contrib.slim.quantization.quantization_pass import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization.quantization_pass import AddQuantDequantPass
C
chenguowei01 已提交
17
from paddle.fluid.contrib.slim.quantization.quantization_pass import _out_scale_op_list
W
wuyefeilin 已提交
18
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
C
chenguowei01 已提交
19
import utils.logging as logging
W
wuyefeilin 已提交
20 21
import paddle.fluid as fluid
import os
C
chenguowei01 已提交
22 23 24
import re
import numpy as np
import time
W
wuyefeilin 已提交
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79


class HumanSegPostTrainingQuantization(PostTrainingQuantization):
    def __init__(self,
                 executor,
                 dataset,
                 program,
                 inputs,
                 outputs,
                 batch_size=10,
                 batch_nums=None,
                 scope=None,
                 algo="KL",
                 quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
                 is_full_quantize=False,
                 is_use_cache_file=False,
                 cache_dir="./temp_post_training"):
        '''
        The class utilizes the post training quantization method to quantize
        the fp32 model. It uses calibration data to calculate the scale factor
        of quantized variables, and inserts fake quant/dequant ops to obtain
        the quantized model.

        Note: the constructor of PostTrainingQuantization is not invoked here;
        every attribute is assigned directly below.

        Args:
            executor(fluid.Executor): The executor to load, run and save the
                quantized model. Its `place` is used for the data loader.
            dataset(Python Iterator): The data reader. It must provide a
                `generator(batch_size, drop_last=True)` method.
            program(fluid.Program): The paddle program, save the parameters for model.
            inputs(dict): The inputs of the program. The dict values are used
                as the feed variables.
            outputs(dict): The outputs of the program. The dict values are
                used as the fetch list.
            batch_size(int, optional): The batch size of DataLoader. Default is 10.
            batch_nums(int, optional): If batch_nums is not None, the number of
                calibration data is batch_size*batch_nums. If batch_nums is None,
                use all data provided by the dataset as calibration data.
            scope(fluid.Scope, optional): The scope of the program, use it to load
                and save variables. If scope=None, get scope by global_scope().
            algo(str, optional): The calibration algorithm, one of 'KL',
                'abs_max' or 'min_max'. If algo='KL', use the KL-divergence
                method to get the more precise scale factor. Default is KL.
            quantizable_op_type(list[str], optional): List the type of ops
                that will be quantized. Default is ["conv2d", "depthwise_conv2d",
                "mul"]. The list is copied, so the caller's list (and the
                shared default) is never aliased by the instance.
            is_full_quantize(bool, optional): If set is_full_quantize as True,
                apply quantization to all supported quantizable op type. If set
                is_full_quantize as False, only apply quantization to the op type
                according to the input quantizable_op_type.
            is_use_cache_file(bool, optional): If set is_use_cache_file as False,
                all temp data will be saved in memory. If set is_use_cache_file as True,
                it will save temp data to disk. When the fp32 model is complex or
                the number of calibration data is large, we should set is_use_cache_file
                as True. Default is False.
            cache_dir(str, optional): When is_use_cache_file is True, set cache_dir as
                the directory for saving temp data. Default is ./temp_post_training.
        Returns:
            None
        Raises:
            AssertionError: If executor is None, batch_size is not positive,
                algo is unsupported, or an entry of quantizable_op_type is
                not supported for quantization.
        '''
        self._support_activation_quantize_type = [
            'range_abs_max', 'moving_average_abs_max', 'abs_max'
        ]
        self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
        self._support_algo_type = ['KL', 'abs_max', 'min_max']
        self._support_quantize_op_type = \
            list(set(QuantizationTransformPass._supported_quantizable_op_type +
                AddQuantDequantPass._supported_quantizable_op_type))

        # Check inputs
        assert executor is not None, "The executor cannot be None."
        assert batch_size > 0, "The batch_size should be greater than 0."
        assert algo in self._support_algo_type, \
            "The algo should be KL, abs_max or min_max."

        self._executor = executor
        self._dataset = dataset
        self._batch_size = batch_size
        self._batch_nums = batch_nums
        self._scope = fluid.global_scope() if scope is None else scope
        self._algo = algo
        self._is_use_cache_file = is_use_cache_file
        self._cache_dir = cache_dir
        # Quantization settings are fixed; they are not configurable through
        # this constructor.
        self._activation_bits = 8
        self._weight_bits = 8
        self._activation_quantize_type = 'range_abs_max'
        self._weight_quantize_type = 'channel_wise_abs_max'
        if self._is_use_cache_file and not os.path.exists(self._cache_dir):
            # makedirs (rather than mkdir) so a nested cache_dir also works.
            os.makedirs(self._cache_dir)

        if is_full_quantize:
            self._quantizable_op_type = self._support_quantize_op_type
        else:
            # Copy so neither the caller's list nor the shared default
            # argument is aliased by this instance.
            self._quantizable_op_type = list(quantizable_op_type)
            allowed_op_type = self._support_quantize_op_type + \
                AddQuantDequantPass._activation_type
            for op_type in self._quantizable_op_type:
                assert op_type in allowed_op_type, \
                    op_type + " is not supported for quantization."

        self._place = self._executor.place
        self._program = program
        self._feed_list = list(inputs.values())
        self._fetch_list = list(outputs.values())
        # Created later by _load_model_data().
        self._data_loader = None

        self._out_scale_op_list = _out_scale_op_list
        self._bit_length = 8
        # Bookkeeping filled in while quantize() runs.
        self._quantized_weight_var_name = set()
        self._quantized_act_var_name = set()
        self._sampling_data = {}
        self._quantized_var_kl_threshold = {}
        self._quantized_var_min = {}
        self._quantized_var_max = {}
        self._quantized_var_abs_max = {}
W
wuyefeilin 已提交
134 135 136 137 138 139 140 141 142 143 144

    def quantize(self):
        '''
        Quantize the fp32 model. Use calibration data to calculate the scale
        factor of quantized variables, and insert fake quant/dequant ops to
        obtain the quantized model.

        Steps: build the data loader, collect the target variable names, run
        the program on the calibration batches while sampling, then compute
        the thresholds and update the program according to the algo.

        Args:
            None
        Returns:
            the program of quantized model.
        '''
        self._load_model_data()
        self._collect_target_varnames()
        self._set_activation_persistable()

        # First pass over the loader only counts the batches that will be
        # run, so the per-batch log line can show "current/total".
        batch_ct = 0
        for data in self._data_loader():
            batch_ct += 1
            if self._batch_nums and batch_ct >= self._batch_nums:
                break

        batch_id = 0
        logging.info("Start to run batch!")
        for data in self._data_loader():
            start = time.time()
            self._executor.run(
                program=self._program,
                feed=data,
                fetch_list=self._fetch_list,
                return_numpy=False)
            # KL needs the raw samples; the other algos only track thresholds.
            if self._algo == "KL":
                self._sample_data(batch_id)
            else:
                self._sample_threshold()
            end = time.time()
            logging.debug(
                '[Run batch data] Batch={}/{}, time_each_batch={} s.'.format(
                    str(batch_id + 1), str(batch_ct), str(end - start)))
            batch_id += 1
            if self._batch_nums and batch_id >= self._batch_nums:
                break
        # The placeholder was missing originally, so the count was never
        # actually written to the log.
        logging.info("All run batch: {}".format(batch_id))
        self._reset_activation_persistable()
        logging.info("Calculate scale factor ...")
        if self._algo == "KL":
            self._calculate_kl_threshold()
        logging.info("Update the program ...")
        if self._algo in ["KL", "abs_max"]:
            self._update_program()
        else:
            self._save_input_threhold()
        logging.info("Save ...")
        self._save_output_threshold()
        logging.info("Finish quant!")
        return self._program

    def save_quantized_model(self, save_model_path):
        '''
        Write the quantized inference model to disk.

        The feed variable names are taken from the stored feed list and the
        stored fetch list is used as the target variables; parameters are
        saved under the file name '__params__'.

        Args:
            save_model_path(str): The directory to save the quantized model to.
        Returns:
            None
        '''
        input_names = []
        for feed_var in self._feed_list:
            input_names.append(feed_var.name)

        fluid.io.save_inference_model(
            dirname=save_model_path,
            feeded_var_names=input_names,
            target_vars=self._fetch_list,
            executor=self._executor,
            main_program=self._program,
            params_filename='__params__')

C
chenguowei01 已提交
205
    def _load_model_data(self):
        '''
        Build the data loader used for calibration and store it on
        self._data_loader.

        The feed variables are looked up by name in the program, and the
        loader is fed by the dataset's sample-list generator (incomplete
        last batches are dropped).
        '''
        program_feed_vars = []
        for feed_var in self._feed_list:
            program_feed_vars.append(
                fluid.framework._get_var(feed_var.name, self._program))

        loader = fluid.io.DataLoader.from_generator(
            feed_list=program_feed_vars,
            capacity=3 * self._batch_size,
            iterable=True)
        self._data_loader = loader

        sample_list_generator = self._dataset.generator(
            self._batch_size, drop_last=True)
        loader.set_sample_list_generator(
            sample_list_generator, places=self._place)

C
chenguowei01 已提交
217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242
    def _calculate_kl_threshold(self):
        '''
        Calculate the KL threshold of quantized variables.

        Weights get an abs_max threshold (a single value for "abs_max", one
        value per index of the first axis for "channel_wise_abs_max").
        Activations get the value returned by _get_kl_scaling_factor on the
        absolute sampled data. Results are stored in
        self._quantized_var_kl_threshold, keyed by variable name.

        When is_use_cache_file is set, the sampled activation data is read
        from the .npy files in the cache dir and each file is deleted once
        loaded; otherwise the in-memory samples are concatenated in place.

        Raises:
            AssertionError: If the algo is not KL.
            ValueError: If cache files are used but none is found for an
                activation variable.
        '''
        assert self._algo == "KL", "The algo should be KL to calculate kl threshold."

        # Abs_max threshold for weights
        weight_total = len(self._quantized_weight_var_name)
        for ct, var_name in enumerate(self._quantized_weight_var_name, 1):
            start = time.time()
            weight_data = self._sampling_data[var_name]
            weight_threshold = None
            if self._weight_quantize_type == "abs_max":
                weight_threshold = np.max(np.abs(weight_data))
            elif self._weight_quantize_type == "channel_wise_abs_max":
                weight_threshold = [
                    np.max(np.abs(weight_data[i]))
                    for i in range(weight_data.shape[0])
                ]
            self._quantized_var_kl_threshold[var_name] = weight_threshold
            end = time.time()
            logging.debug(
                '[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'.
                format(str(ct), str(weight_total), str(end - start)))

        # KL threshold for activations
        act_total = len(self._quantized_act_var_name)
        for ct, var_name in enumerate(self._quantized_act_var_name, 1):
            start = time.time()
            if self._is_use_cache_file:
                # Escape the variable name and the extension dot, and anchor
                # the end, so regex metacharacters in a name (e.g. '.') can
                # not match files that belong to another variable.
                # NOTE(review): assumes the sampling step names its files
                # "<var_name>_<batch_id>.npy" - confirm against _sample_data.
                pattern = re.compile(re.escape(var_name) + r'_[0-9]+\.npy$')
                filenames = [f for f in os.listdir(self._cache_dir)
                             if pattern.match(f)]
                if not filenames:
                    raise ValueError(
                        "No cached sampling data found for variable {} in {}."
                        .format(var_name, self._cache_dir))
                sampling_data = []
                for filename in filenames:
                    file_path = os.path.join(self._cache_dir, filename)
                    sampling_data.append(np.load(file_path))
                    # Cache files are single-use; remove them once loaded.
                    os.remove(file_path)
                sampling_data = np.concatenate(sampling_data)
            else:
                sampling_data = np.concatenate(self._sampling_data[var_name])
                # Keep the concatenated array in place of the list of samples.
                self._sampling_data[var_name] = sampling_data
            self._quantized_var_kl_threshold[var_name] = \
                self._get_kl_scaling_factor(np.abs(sampling_data))
            end = time.time()
            logging.debug(
                '[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'
                .format(str(ct), str(act_total), str(end - start)))