From 175ba39c03cf08dd4afb62d513404e00d05ff470 Mon Sep 17 00:00:00 2001 From: juncaipeng <52520497+juncaipeng@users.noreply.github.com> Date: Tue, 5 Nov 2019 15:18:31 +0800 Subject: [PATCH] Add post_training_quantization (#20800) * add post training quantization, test=develop * specify the quantizable op type, test=develop --- .../contrib/slim/quantization/__init__.py | 3 + .../post_training_quantization.py | 448 ++++++++++++++++++ .../slim/quantization/quantization_pass.py | 72 ++- .../fluid/contrib/slim/tests/CMakeLists.txt | 1 + .../tests/test_post_training_quantization.py | 354 ++++++++++++++ 5 files changed, 857 insertions(+), 21 deletions(-) create mode 100644 python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py create mode 100644 python/paddle/fluid/contrib/slim/tests/test_post_training_quantization.py diff --git a/python/paddle/fluid/contrib/slim/quantization/__init__.py b/python/paddle/fluid/contrib/slim/quantization/__init__.py index 659265895a5..ad62006ddbc 100644 --- a/python/paddle/fluid/contrib/slim/quantization/__init__.py +++ b/python/paddle/fluid/contrib/slim/quantization/__init__.py @@ -22,7 +22,10 @@ from . import mkldnn_post_training_strategy from .mkldnn_post_training_strategy import * from . import quantization_mkldnn_pass from .quantization_mkldnn_pass import * +from . import post_training_quantization +from .post_training_quantization import * __all__ = quantization_pass.__all__ + quantization_strategy.__all__ __all__ += mkldnn_post_training_strategy.__all__ __all__ += quantization_mkldnn_pass.__all__ +__all__ += post_training_quantization.__all__ diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py new file mode 100644 index 00000000000..59b77ea6a9b --- /dev/null +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -0,0 +1,448 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import logging +import numpy as np +from ....executor import global_scope +from .... import io +from .... import core +from .... import framework +from ....framework import IrGraph +from ....log_helper import get_logger +from .quantization_pass import QuantizationTransformPass +from .quantization_pass import QuantizationFreezePass +from .quantization_pass import AddQuantDequantPass + +__all__ = ['PostTrainingQuantization'] + +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') + + +class PostTrainingQuantization(object): + def __init__(self, + executor, + model_path, + data_reader, + batch_size=10, + batch_nums=None, + scope=None, + algo="KL", + quantizable_op_type=[ + "conv2d", "depthwise_conv2d", "mul", "pool2d", + "elementwise_add" + ]): + ''' + The class utilizes post training quantization methon to quantize the + fp32 model. It uses calibrate data to calculate the scale factor of + quantized variables, and inserts fake quant/dequant op to obtain the + quantized model. + + Args: + executor(fluid.Executor): The executor to load, run and save the + quantized model. + model_path(str): The path of fp32 model that will be quantized. + data_reader(Reader): The data reader generates a sample every time, + and it provides calibrate data for DataLoader. + batch_size(int, optional): The batch size of DataLoader, default is 10. + batch_nums(int, optional): If set batch_nums, the number of calibrate + data is batch_size*batch_nums. If batch_nums=None, use all data + provided by data_reader as calibrate data. + scope(fluid.Scope, optional): The scope of the program, use it to load + and save variables. If scope=None, get scope by global_scope(). + algo(str, optional): If algo=KL, use KL-divergenc method to + get the more precise scale factor. If algo='direct', use + abs_max methon to get the scale factor. Default is KL. + quantizable_op_type(list[str], optional): List the type of ops + that will be quantized. Default is ["conv2d", "depthwise_conv2d", + "mul", "pool2d", "elementwise_add"]. + Examples: + .. code-block:: python + import paddle.fluid as fluid + from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization + + exe = fluid.Executor(fluid.CPUPlace()) + model_path = load_fp32_model_path + save_model_path = save_int8_path + data_reader = your_data_reader + batch_size = 10 + batch_nums = 10 + algo = "KL" + quantizable_op_type = ["conv2d", \ + "depthwise_conv2d", "mul", "pool2d", "elementwise_add"] + ptq = PostTrainingQuantization( + executor=exe, + model_path=model_path, + data_reader=data_reader, + batch_size=batch_size, + batch_nums=batch_nums, + algo=algo, + quantizable_op_type=quantizable_op_type) + ptq.quantize() + ptq.save_quantized_model(save_model_path) + ''' + self._executor = executor + self._model_path = model_path + self._data_reader = data_reader + self._batch_size = batch_size + self._batch_nums = batch_nums + self._scope = global_scope() if scope == None else scope + self._quantizable_op_type = quantizable_op_type + self._algo = algo + supported_quantizable_op_type = [ + "conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add" + ] + for op_type in self._quantizable_op_type: + assert op_type in supported_quantizable_op_type, \ + op_type + " is not supported for quantization." + + self._place = self._executor.place + self._program = None + self._feed_list = None + self._fetch_list = None + self._data_loader = None + + self._bit_length = 8 + self._quantized_weight_var_name = [] + self._quantized_act_var_name = [] + self._sampling_data = {} + self._quantized_var_scale_factor = {} + + def quantize(self): + ''' + Quantize the fp32 model. Use calibrate data to calculate the scale factor of + quantized variables, and inserts fake quant/dequant op to obtain the + quantized model. + + Return: + the program of quantized model. + ''' + self._prepare() + + batch_id = 0 + for data in self._data_loader(): + self._executor.run(program=self._program, + feed=data, + fetch_list=self._fetch_list) + self._sample_data() + + if batch_id % 5 == 0: + _logger.info("run batch: " + str(batch_id)) + batch_id += 1 + if self._batch_nums and batch_id >= self._batch_nums: + break + _logger.info("all run batch: " + str(batch_id)) + + self._calculate_scale_factor() + self._update_program() + + return self._program + + def save_quantized_model(self, save_model_path): + ''' + Save the quantized model to the disk. + + Args: + save_model_path(str): The path to save the quantized model + Return: + None + ''' + io.save_inference_model( + dirname=save_model_path, + feeded_var_names=self._feed_list, + target_vars=self._fetch_list, + executor=self._executor, + main_program=self._program) + + def _prepare(self): + ''' + Load model and set data loader, collect the variable names for sampling, + and set activation variables to be persistable. + ''' + # load model and set data loader + [self._program, self._feed_list, self._fetch_list] = \ + io.load_inference_model(self._model_path, self._executor) + feed_vars = [framework._get_var(str(var_name), self._program) \ + for var_name in self._feed_list] + self._data_loader = io.DataLoader.from_generator( + feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True) + self._data_loader.set_sample_generator( + self._data_reader, + batch_size=self._batch_size, + drop_last=True, + places=self._place) + + #collect the variable names for sampling + persistable_var_names = [] + for var in self._program.list_vars(): + if var.persistable: + persistable_var_names.append(var.name) + + block = self._program.global_block() + for op in block.ops: + op_type = op.type + if op_type in self._quantizable_op_type: + if op_type in ("conv2d", "depthwise_conv2d"): + self._quantized_act_var_name.append(op.input("Input")[0]) + self._quantized_weight_var_name.append( + op.input("Filter")[0]) + self._quantized_act_var_name.append(op.output("Output")[0]) + elif op_type == "mul": + x_var_name = op.input("X")[0] + y_var_name = op.input("Y")[0] + if x_var_name not in persistable_var_names and \ + y_var_name not in persistable_var_names: + op._set_attr("skip_quant", True) + _logger.warning("A mul op skip quant for two " + "input variables are not persistable") + else: + self._quantized_act_var_name.append(x_var_name) + self._quantized_weight_var_name.append(y_var_name) + self._quantized_act_var_name.append(op.output("Out")[0]) + elif op_type == "pool2d": + self._quantized_act_var_name.append(op.input("X")[0]) + elif op_type == "elementwise_add": + x_var_name = op.input("X")[0] + y_var_name = op.input("Y")[0] + if x_var_name not in persistable_var_names and \ + y_var_name not in persistable_var_names: + self._quantized_act_var_name.append(x_var_name) + self._quantized_act_var_name.append(y_var_name) + + # set activation variables to be persistable, + # so can obtain the tensor data in sample_data stage + for var in self._program.list_vars(): + if var.name in self._quantized_act_var_name: + var.persistable = True + + def _sample_data(self): + ''' + Sample the tensor data of quantized variables, + applied in every iteration. + ''' + for var_name in self._quantized_weight_var_name: + if var_name not in self._sampling_data: + var_tensor = self._load_var_value(var_name) + self._sampling_data[var_name] = var_tensor + + for var_name in self._quantized_act_var_name: + if var_name not in self._sampling_data: + self._sampling_data[var_name] = [] + var_tensor = self._load_var_value(var_name) + self._sampling_data[var_name].append(var_tensor) + + def _calculate_scale_factor(self): + ''' + Calculate the scale factor of quantized variables. + ''' + _logger.info("calculate scale factor ...") + + for var_name in self._quantized_weight_var_name: + data = self._sampling_data[var_name] + scale_factor_per_channel = [] + for i in range(data.shape[0]): + abs_max_value = np.max(np.abs(data[i])) + scale_factor_per_channel.append(abs_max_value) + self._quantized_var_scale_factor[ + var_name] = scale_factor_per_channel + + for var_name in self._quantized_act_var_name: + if self._algo == "KL": + self._quantized_var_scale_factor[var_name] = \ + self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name])) + else: + self._quantized_var_scale_factor[var_name] = \ + np.max(np.abs(self._sampling_data[var_name])) + + def _update_program(self): + ''' + Insert fake_quantize/fake_dequantize op to the program. + ''' + _logger.info("update the program ...") + + for var in self._program.list_vars(): + if var.name in self._quantized_act_var_name: + var.persistable = False + + # use QuantizationTransformPass to insert fake_quantize/fake_dequantize op + graph = IrGraph(core.Graph(self._program.desc), for_test=True) + + qtp_quantizable_op_type = [] + for op_type in ["conv2d", "depthwise_conv2d", "mul"]: + if op_type in self._quantizable_op_type: + qtp_quantizable_op_type.append(op_type) + transform_pass = QuantizationTransformPass( + scope=self._scope, + place=self._place, + weight_bits=self._bit_length, + activation_bits=self._bit_length, + activation_quantize_type='moving_average_abs_max', + weight_quantize_type='channel_wise_abs_max', + quantizable_op_type=qtp_quantizable_op_type) + transform_pass.apply(graph) + + # use AddQuantDequantPass to insert fake_quant_dequant op + aqdp_quantizable_op_type = [] + for op_type in ["pool2d", "elementwise_add"]: + if op_type in self._quantizable_op_type: + aqdp_quantizable_op_type.append(op_type) + add_quant_dequant_pass = AddQuantDequantPass( + scope=self._scope, + place=self._place, + quantizable_op_type=aqdp_quantizable_op_type) + add_quant_dequant_pass.apply(graph) + + # save scale factor to scale var node + for key, val in self._quantized_var_scale_factor.items(): + self._set_var_node_value( + key + ".scale", np.array( + [val], dtype=np.float32)) + self._set_var_node_value( + key + ".quant_dequant.scale", np.array( + [val], dtype=np.float32)) + + # apply QuantizationFreezePass, and obtain the final quant model + freeze_pass = QuantizationFreezePass( + scope=self._scope, + place=self._place, + weight_bits=self._bit_length, + activation_bits=self._bit_length, + weight_quantize_type='channel_wise_abs_max', + quantizable_op_type=qtp_quantizable_op_type) + freeze_pass.apply(graph) + self._program = graph.to_program() + + def _load_var_value(self, var_name): + ''' + Load variable value from scope + ''' + return np.array(self._scope.find_var(var_name).get_tensor()) + + def _set_var_node_value(self, var_node_name, np_value): + ''' + Set the value of var node by name, if the node is not exits, + ''' + assert isinstance(np_value, np.ndarray), \ + 'The type of value should be numpy array.' + var_node = self._scope.find_var(var_node_name) + if var_node != None: + tensor = var_node.get_tensor() + tensor.set(np_value, self._place) + + def _get_kl_scaling_factor(self, activation_blob, num_quantized_bins=255): + ''' + Using the KL-divergenc method to get the more precise scaling factor. + ''' + max_val = np.max(activation_blob) + min_val = np.min(activation_blob) + if min_val >= 0: + hist, hist_edeges = np.histogram( + activation_blob, bins=2048, range=(min_val, max_val)) + ending_iter = 2047 + starting_iter = int(ending_iter * 0.7) + else: + _logger.error("Please first apply abs to activation_blob.") + bin_width = hist_edeges[1] - hist_edeges[0] + + P_sum = len(np.array(activation_blob).ravel()) + min_kl_divergence = 0 + min_kl_index = 0 + kl_inited = False + for i in range(starting_iter, ending_iter + 1): + reference_distr_P = hist[0:i].tolist() + outliers_count = sum(hist[i:2048]) + if reference_distr_P[i - 1] == 0: + continue + reference_distr_P[i - 1] += outliers_count + reference_distr_bins = reference_distr_P[:] + candidate_distr_Q = hist[0:i].tolist() + num_merged_bins = int(i / num_quantized_bins) + candidate_distr_Q_quantized = [0] * num_quantized_bins + j_start = 0 + j_end = num_merged_bins + for idx in range(num_quantized_bins): + candidate_distr_Q_quantized[idx] = sum(candidate_distr_Q[ + j_start:j_end]) + j_start += num_merged_bins + j_end += num_merged_bins + if (idx + 1) == num_quantized_bins - 1: + j_end = i + candidate_distr_Q = self._expand_quantized_bins( + candidate_distr_Q_quantized, reference_distr_bins) + Q_sum = sum(candidate_distr_Q) + kl_divergence = self._safe_entropy(reference_distr_P, P_sum, + candidate_distr_Q, Q_sum) + if not kl_inited: + min_kl_divergence = kl_divergence + min_kl_index = i + kl_inited = True + elif kl_divergence < min_kl_divergence: + min_kl_divergence = kl_divergence + min_kl_index = i + else: + pass + if min_kl_index == 0: + while starting_iter > 0: + if hist[starting_iter] == 0: + starting_iter -= 1 + continue + else: + break + min_kl_index = starting_iter + return (min_kl_index + 0.5) * bin_width + + def _expand_quantized_bins(self, quantized_bins, reference_bins): + ''' + ''' + expanded_quantized_bins = [0] * len(reference_bins) + num_merged_bins = int(len(reference_bins) / len(quantized_bins)) + j_start = 0 + j_end = num_merged_bins + for idx in range(len(quantized_bins)): + zero_count = reference_bins[j_start:j_end].count(0) + num_merged_bins = j_end - j_start + if zero_count == num_merged_bins: + avg_bin_ele = 0 + else: + avg_bin_ele = quantized_bins[idx] / ( + num_merged_bins - zero_count + 0.0) + for idx1 in range(j_start, j_end): + expanded_quantized_bins[idx1] = (0 if reference_bins[idx1] == 0 + else avg_bin_ele) + j_start += num_merged_bins + j_end += num_merged_bins + if (idx + 1) == len(quantized_bins) - 1: + j_end = len(reference_bins) + return expanded_quantized_bins + + def _safe_entropy(self, reference_distr_P, P_sum, candidate_distr_Q, Q_sum): + ''' + Calculate the entropy. + ''' + assert len(reference_distr_P) == len(candidate_distr_Q) + tmp_sum1 = 0 + tmp_sum2 = 0 + for idx in range(len(reference_distr_P)): + p_idx = reference_distr_P[idx] + q_idx = candidate_distr_Q[idx] + if p_idx == 0: + tmp_sum1 += 0 + tmp_sum2 += 0 + else: + if q_idx == 0: + print("Fatal error!, idx = " + str(idx) + + " qindex = 0! p_idx = " + str(p_idx)) + tmp_sum1 += p_idx * (math.log(Q_sum * p_idx)) + tmp_sum2 += p_idx * (math.log(P_sum * q_idx)) + return (tmp_sum1 - tmp_sum2) / P_sum diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 39473773f45..702ec4d9f94 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -26,8 +26,6 @@ __all__ = [ 'AddQuantDequantPass' ] -_quantizable_op_list = ['conv2d', 'depthwise_conv2d', 'mul'] - _fake_quant_op_list = [ 'fake_quantize_abs_max', 'fake_quantize_range_abs_max', 'fake_quantize_moving_average_abs_max', 'fake_channel_wise_quantize_abs_max' @@ -65,17 +63,18 @@ class QuantizationTransformPass(object): weight_quantize_type='abs_max', window_size=10000, moving_rate=0.9, - skip_pattern='skip_quant'): + skip_pattern='skip_quant', + quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']): """ Convert and rewrite the IrGraph according to weight and activation quantization type. Args: scope(fluid.Scope): When activation use 'range_abs_max' as the quantize - type, this pass will create some new parameters. The scope is used to - initialize these new parameters. + type, this pass will create some new parameters. The scope is used to + initialize these new parameters. place(fluid.CPUPlace|fluid.CUDAPlace): place is used to initialize new - parameters described above. + parameters described above. weight_bits (int): quantization bit number for weights, the bias is not quantized. activation_bits (int): quantization bit number for activation. @@ -93,6 +92,8 @@ class QuantizationTransformPass(object): skip_pattern(str): The user-defined quantization skip pattern, which will be presented in the name scope of an op. When the skip pattern is detected in an op's name scope, the corresponding op will not be quantized. + quantizable_op_type(list[str]): List the type of ops that will be quantized. + Default is ["conv2d", "depthwise_conv2d", "mul"]. Examples: .. code-block:: python @@ -119,7 +120,8 @@ class QuantizationTransformPass(object): 'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max' ] - assert activation_quantize_type != 'channel_wise_abs_max', "The activation quantization type does not support 'channel_wise_abs_max'." + assert activation_quantize_type != 'channel_wise_abs_max', \ + "The activation quantization type does not support 'channel_wise_abs_max'." if activation_quantize_type not in quant_type: raise ValueError( "Unknown activation_quantize_type : '%s'. It can only be " @@ -136,7 +138,11 @@ class QuantizationTransformPass(object): self._window_size = window_size self._moving_rate = moving_rate - self._quantizable_ops = _quantizable_op_list + self._quantizable_ops = quantizable_op_type + supported_quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul'] + for op in self._quantizable_ops: + assert op in supported_quantizable_ops, \ + op + " is not supported for quantization." self._conv_ops = ['conv2d', 'depthwise_conv2d'] self._quantizable_grad_ops = [ '%s_grad' % (op) for op in self._quantizable_ops @@ -595,9 +601,11 @@ class QuantizationFreezePass(object): place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors. weight_bits (int): quantization bit number for weights. activation_bits (int): quantization bit number for activation. - weight_quantize_type (str): quantization type for weights, support 'abs_max' and 'channel_wise_abs_max'. - The 'range_abs_max' usually is not used for weight, since weights are fixed once the - model is well trained. + weight_quantize_type (str): quantization type for weights, support 'abs_max' and + 'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight, + since weights are fixed once the model is well trained. + quantizable_op_type(list[str]): List the type of ops that will be quantized. + Default is ["conv2d", "depthwise_conv2d", "mul"]. """ def __init__(self, @@ -605,7 +613,8 @@ class QuantizationFreezePass(object): place, weight_bits=8, activation_bits=8, - weight_quantize_type='abs_max'): + weight_quantize_type='abs_max', + quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']): assert scope is not None, \ 'The scope cannot be set None.' assert place is not None, \ @@ -615,7 +624,11 @@ class QuantizationFreezePass(object): self._weight_bits = weight_bits self._activation_bits = activation_bits self._weight_quantize_type = weight_quantize_type - self._quantizable_ops = _quantizable_op_list + self._quantizable_ops = quantizable_op_type + supported_quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul'] + for op in self._quantizable_ops: + assert op in supported_quantizable_ops, \ + op + " is not supported for quantization." self._conv_ops = ['conv2d', 'depthwise_conv2d'] self._fake_quant_op_names = _fake_quant_op_list self._fake_dequant_op_names = _fake_dequant_op_list @@ -888,17 +901,26 @@ class ConvertToInt8Pass(object): Args: scope(fluid.Scope): scope is used to get the weight tensor values. place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the - 8bits weight tensors. + 8bits weight tensors. + quantizable_op_type(list[str]): List the type of ops that will be quantized. + Default is ["conv2d", "depthwise_conv2d", "mul"]. """ - def __init__(self, scope, place): + def __init__(self, + scope, + place, + quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']): assert scope is not None, \ 'The scope cannot be set None.' assert place is not None, \ 'The place cannot be set None.' self._scope = scope self._place = place - self._quantizable_ops = _quantizable_op_list + self._quantizable_ops = quantizable_op_type + supported_quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul'] + for op in self._quantizable_ops: + assert op in supported_quantizable_ops, \ + op + " is not supported for quantization." def apply(self, graph): """ @@ -1166,7 +1188,8 @@ class AddQuantDequantPass(object): place=None, moving_rate=0.9, quant_bits=8, - skip_pattern='skip_quant'): + skip_pattern='skip_quant', + quantizable_op_type=["elementwise_add", "pool2d"]): """ This pass is used to add quant_dequant op for some ops, such as the 'elementwise_add' and 'pool2d' op. @@ -1176,9 +1199,16 @@ class AddQuantDequantPass(object): self._moving_rate = moving_rate self._quant_bits = quant_bits self._is_test = None - self._target_ops = ["elementwise_add", "pool2d"] - self._target_grad_ops = ['%s_grad' % (op) for op in self._target_ops] self._skip_pattern = skip_pattern + self._quantizable_op_type = quantizable_op_type + self._quantizable_grad_op_type = [ + '%s_grad' % (op) for op in self._quantizable_op_type + ] + + supported_quantizable_op_type = ["elementwise_add", "pool2d"] + for op_type in quantizable_op_type: + assert op_type in supported_quantizable_op_type, \ + op_type + " is not supported for quantization." def apply(self, graph): """ @@ -1194,7 +1224,7 @@ class AddQuantDequantPass(object): ops = graph.all_op_nodes() for op_node in ops: - if op_node.name() in self._target_ops: + if op_node.name() in self._quantizable_op_type: if isinstance(self._skip_pattern, str) and \ op_node.op().has_attr("op_namescope") and \ op_node.op().attr("op_namescope").find(self._skip_pattern) != -1: @@ -1221,7 +1251,7 @@ class AddQuantDequantPass(object): graph.update_input_link(in_node, quant_var_node, op_node) for op_node in ops: - if op_node.name() in self._target_grad_ops: + if op_node.name() in self._quantizable_grad_op_type: for input_name in op_node.input_arg_names(): if input_name in dequantized_vars_map: in_node = graph._find_node_by_name(op_node.inputs, diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index 71e6c99fb54..40746ce33ed 100644 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -48,6 +48,7 @@ endfunction() if(WIN32) list(REMOVE_ITEM TEST_OPS test_light_nas) + list(REMOVE_ITEM TEST_OPS test_post_training_quantization) endif() # int8 image classification python api test diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization.py new file mode 100644 index 00000000000..3c86a74612f --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization.py @@ -0,0 +1,354 @@ +# copyright (c) 2018 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. +import unittest +import os +import time +import sys +import random +import math +import functools +import contextlib +import numpy as np +from PIL import Image, ImageEnhance +import paddle +import paddle.fluid as fluid +from paddle.dataset.common import download +from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization + +random.seed(0) +np.random.seed(0) + +DATA_DIM = 224 +THREAD = 1 +BUF_SIZE = 102400 +DATA_DIR = 'data/ILSVRC2012' + +img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) +img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) + + +def resize_short(img, target_size): + percent = float(target_size) / min(img.size[0], img.size[1]) + resized_width = int(round(img.size[0] * percent)) + resized_height = int(round(img.size[1] * percent)) + img = img.resize((resized_width, resized_height), Image.LANCZOS) + return img + + +def crop_image(img, target_size, center): + width, height = img.size + size = target_size + if center == True: + w_start = (width - size) / 2 + h_start = (height - size) / 2 + else: + w_start = np.random.randint(0, width - size + 1) + h_start = np.random.randint(0, height - size + 1) + w_end = w_start + size + h_end = h_start + size + img = img.crop((w_start, h_start, w_end, h_end)) + return img + + +def process_image(sample, mode, color_jitter, rotate): + img_path = sample[0] + img = Image.open(img_path) + img = resize_short(img, target_size=256) + img = crop_image(img, target_size=DATA_DIM, center=True) + if img.mode != 'RGB': + img = img.convert('RGB') + img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255 + img -= img_mean + img /= img_std + return img, sample[1] + + +def _reader_creator(file_list, + mode, + shuffle=False, + color_jitter=False, + rotate=False, + data_dir=DATA_DIR): + def reader(): + with open(file_list) as flist: + full_lines = [line.strip() for line in flist] + if shuffle: + np.random.shuffle(full_lines) + lines = full_lines + + for line in lines: + img_path, label = line.split() + img_path = os.path.join(data_dir, img_path) + if not os.path.exists(img_path): + continue + yield img_path, int(label) + + mapper = functools.partial( + process_image, mode=mode, color_jitter=color_jitter, rotate=rotate) + + return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE) + + +def val(data_dir=DATA_DIR): + file_list = os.path.join(data_dir, 'val_list.txt') + return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir) + + +class TestPostTrainingQuantization(unittest.TestCase): + def setUp(self): + self.int8_download = 'int8/download' + self.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' + + self.int8_download) + + data_urls = [] + data_md5s = [] + self.data_cache_folder = '' + if os.environ.get('DATASET') == 'full': + data_urls.append( + 'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partaa' + ) + data_md5s.append('60f6525b0e1d127f345641d75d41f0a8') + data_urls.append( + 'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partab' + ) + data_md5s.append('1e9f15f64e015e58d6f9ec3210ed18b5') + self.data_cache_folder = self.download_data(data_urls, data_md5s, + "full_data", False) + else: + data_urls.append( + 'http://paddle-inference-dist.bj.bcebos.com/int8/calibration_test_data.tar.gz' + ) + data_md5s.append('1b6c1c434172cca1bf9ba1e4d7a3157d') + self.data_cache_folder = self.download_data(data_urls, data_md5s, + "small_data", False) + + # reader/decorator.py requires the relative path to the data folder + cmd = 'rm -rf {0} && ln -s {1} {0}'.format("data", + self.data_cache_folder) + os.system(cmd) + + self.batch_size = 1 if os.environ.get('DATASET') == 'full' else 50 + self.sample_iterations = 50 if os.environ.get( + 'DATASET') == 'full' else 1 + self.infer_iterations = 50000 if os.environ.get( + 'DATASET') == 'full' else 1 + + self.timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) + self.int8_model = '' + + def tearDown(self): + try: + os.system("rm -rf {}".format(self.int8_model)) + except Exception as e: + print("Failed to delete {} due to {}".format(self.int8_model, + str(e))) + + def cache_unzipping(self, target_folder, zip_path): + if not os.path.exists(target_folder): + cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder, + zip_path) + os.system(cmd) + + def download_data(self, data_urls, data_md5s, folder_name, is_model=True): + data_cache_folder = os.path.join(self.cache_folder, folder_name) + zip_path = '' + if os.environ.get('DATASET') == 'full': + file_names = [] + for i in range(0, len(data_urls)): + download(data_urls[i], self.int8_download, data_md5s[i]) + file_names.append(data_urls[i].split('/')[-1]) + + zip_path = os.path.join(self.cache_folder, + 'full_imagenet_val.tar.gz') + if not os.path.exists(zip_path): + cat_command = 'cat' + for file_name in file_names: + cat_command += ' ' + os.path.join(self.cache_folder, + file_name) + cat_command += ' > ' + zip_path + os.system(cat_command) + + if os.environ.get('DATASET') != 'full' or is_model: + download(data_urls[0], self.int8_download, data_md5s[0]) + file_name = data_urls[0].split('/')[-1] + zip_path = os.path.join(self.cache_folder, file_name) + + print('Data is downloaded at {0}'.format(zip_path)) + self.cache_unzipping(data_cache_folder, zip_path) + return data_cache_folder + + def download_model(self): + pass + + def run_program(self, model_path): + image_shape = [3, 224, 224] + place = fluid.CPUPlace() + exe = fluid.Executor(place) + [infer_program, feed_dict, fetch_targets] = \ + fluid.io.load_inference_model(model_path, exe) + val_reader = paddle.batch(val(), self.batch_size) + iterations = self.infer_iterations + + test_info = [] + cnt = 0 + periods = [] + for batch_id, data in enumerate(val_reader()): + image = np.array( + [x[0].reshape(image_shape) for x in data]).astype("float32") + label = np.array([x[1] for x in data]).astype("int64") + label = label.reshape([-1, 1]) + + t1 = time.time() + _, acc1, _ = exe.run( + infer_program, + feed={feed_dict[0]: image, + feed_dict[1]: label}, + fetch_list=fetch_targets) + t2 = time.time() + period = t2 - t1 + periods.append(period) + + test_info.append(np.mean(acc1) * len(data)) + cnt += len(data) + + if (batch_id + 1) % 100 == 0: + print("{0} images,".format(batch_id + 1)) + sys.stdout.flush() + if (batch_id + 1) == iterations: + break + + throughput = cnt / np.sum(periods) + latency = np.average(periods) + acc1 = np.sum(test_info) / cnt + return (throughput, latency, acc1) + + def generate_quantized_model(self, model_path, algo="KL"): + self.int8_model = os.path.join(os.getcwd(), + "post_training_" + self.timestamp) + try: + os.system("mkdir " + self.int8_model) + except Exception as e: + print("Failed to create {} due to {}".format(self.int8_model, + str(e))) + sys.exit(-1) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + scope = fluid.global_scope() + val_reader = val() + quantizable_op_type = [ + "conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add" + ] + + ptq = PostTrainingQuantization( + executor=exe, + scope=scope, + model_path=model_path, + data_reader=val_reader, + algo=algo, + quantizable_op_type=quantizable_op_type) + ptq.quantize() + ptq.save_quantized_model(self.int8_model) + + +class TestPostTrainingForResnet50(TestPostTrainingQuantization): + def download_model(self): + # resnet50 fp32 data + data_urls = [ + 'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz' + ] + data_md5s = ['4a5194524823d9b76da6e738e1367881'] + self.model_cache_folder = self.download_data(data_urls, data_md5s, + "resnet50_fp32") + self.model = "ResNet-50" + self.algo = "KL" + + def test_post_training_resnet50(self): + self.download_model() + + print("Start FP32 inference for {0} on {1} images ...".format( + self.model, self.infer_iterations * self.batch_size)) + (fp32_throughput, fp32_latency, + fp32_acc1) = self.run_program(self.model_cache_folder + "/model") + + print("Start INT8 post training quantization for {0} on {1} images ...". + format(self.model, self.sample_iterations * self.batch_size)) + self.generate_quantized_model( + self.model_cache_folder + "/model", algo=self.algo) + + print("Start INT8 inference for {0} on {1} images ...".format( + self.model, self.infer_iterations * self.batch_size)) + (int8_throughput, int8_latency, + int8_acc1) = self.run_program(self.int8_model) + + print( + "FP32 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}". + format(self.model, self.batch_size, fp32_throughput, fp32_latency, + fp32_acc1)) + print( + "INT8 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}". + format(self.model, self.batch_size, int8_throughput, int8_latency, + int8_acc1)) + sys.stdout.flush() + + delta_value = fp32_acc1 - int8_acc1 + self.assertLess(delta_value, 0.025) + + +class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization): + def download_model(self): + # mobilenetv1 fp32 data + data_urls = [ + 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' + ] + data_md5s = ['13892b0716d26443a8cdea15b3c6438b'] + self.model_cache_folder = self.download_data(data_urls, data_md5s, + "mobilenetv1_fp32") + self.model = "MobileNet-V1" + self.algo = "KL" + + def test_post_training_mobilenetv1(self): + self.download_model() + + print("Start FP32 inference for {0} on {1} images ...".format( + self.model, self.infer_iterations * self.batch_size)) + (fp32_throughput, fp32_latency, + fp32_acc1) = self.run_program(self.model_cache_folder + "/model") + + print("Start INT8 post training quantization for {0} on {1} images ...". + format(self.model, self.sample_iterations * self.batch_size)) + self.generate_quantized_model( + self.model_cache_folder + "/model", algo=self.algo) + + print("Start INT8 inference for {0} on {1} images ...".format( + self.model, self.infer_iterations * self.batch_size)) + (int8_throughput, int8_latency, + int8_acc1) = self.run_program(self.int8_model) + + print( + "FP32 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}". + format(self.model, self.batch_size, fp32_throughput, fp32_latency, + fp32_acc1)) + print( + "INT8 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}". + format(self.model, self.batch_size, int8_throughput, int8_latency, + int8_acc1)) + sys.stdout.flush() + + delta_value = fp32_acc1 - int8_acc1 + self.assertLess(delta_value, 0.025) + + +if __name__ == '__main__': + unittest.main() -- GitLab