Unverified commit 175ba39c authored by juncaipeng and committed by GitHub

Add post_training_quantization (#20800)

* add post training quantization, test=develop
* specify the quantizable op type, test=develop
Parent 0059404e
......@@ -22,7 +22,10 @@ from . import mkldnn_post_training_strategy
from .mkldnn_post_training_strategy import *
from . import quantization_mkldnn_pass
from .quantization_mkldnn_pass import *
from . import post_training_quantization
from .post_training_quantization import *
__all__ = quantization_pass.__all__ + quantization_strategy.__all__
__all__ += mkldnn_post_training_strategy.__all__
__all__ += quantization_mkldnn_pass.__all__
__all__ += post_training_quantization.__all__
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import logging
import numpy as np
from ....executor import global_scope
from .... import io
from .... import core
from .... import framework
from ....framework import IrGraph
from ....log_helper import get_logger
from .quantization_pass import QuantizationTransformPass
from .quantization_pass import QuantizationFreezePass
from .quantization_pass import AddQuantDequantPass
__all__ = ['PostTrainingQuantization']
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
class PostTrainingQuantization(object):
def __init__(self,
executor,
model_path,
data_reader,
batch_size=10,
batch_nums=None,
scope=None,
algo="KL",
quantizable_op_type=[
"conv2d", "depthwise_conv2d", "mul", "pool2d",
"elementwise_add"
]):
'''
The class utilizes the post training quantization method to quantize the
fp32 model. It uses calibration data to calculate the scale factors of the
quantized variables, and inserts fake quant/dequant ops to obtain the
quantized model.
Args:
executor(fluid.Executor): The executor to load, run and save the
quantized model.
model_path(str): The path of fp32 model that will be quantized.
data_reader(Reader): The data reader generates one sample at a time,
and it provides calibration data for the DataLoader.
batch_size(int, optional): The batch size of the DataLoader. Default is 10.
batch_nums(int, optional): If batch_nums is set, the number of calibration
samples is batch_size*batch_nums. If batch_nums is None, all data
provided by data_reader is used as calibration data.
scope(fluid.Scope, optional): The scope of the program, used to load
and save variables. If scope is None, global_scope() is used.
algo(str, optional): If algo is 'KL', use the KL-divergence method to
get a more precise scale factor. If algo is 'direct', use the
abs_max method to get the scale factor. Default is 'KL'.
quantizable_op_type(list[str], optional): List the type of ops
that will be quantized. Default is ["conv2d", "depthwise_conv2d",
"mul", "pool2d", "elementwise_add"].
Examples:
.. code-block:: python
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
exe = fluid.Executor(fluid.CPUPlace())
model_path = load_fp32_model_path
save_model_path = save_int8_path
data_reader = your_data_reader
batch_size = 10
batch_nums = 10
algo = "KL"
quantizable_op_type = ["conv2d", \
"depthwise_conv2d", "mul", "pool2d", "elementwise_add"]
ptq = PostTrainingQuantization(
executor=exe,
model_path=model_path,
data_reader=data_reader,
batch_size=batch_size,
batch_nums=batch_nums,
algo=algo,
quantizable_op_type=quantizable_op_type)
ptq.quantize()
ptq.save_quantized_model(save_model_path)
'''
self._executor = executor
self._model_path = model_path
self._data_reader = data_reader
self._batch_size = batch_size
self._batch_nums = batch_nums
self._scope = global_scope() if scope is None else scope
self._quantizable_op_type = quantizable_op_type
self._algo = algo
supported_quantizable_op_type = [
"conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
]
for op_type in self._quantizable_op_type:
assert op_type in supported_quantizable_op_type, \
op_type + " is not supported for quantization."
self._place = self._executor.place
self._program = None
self._feed_list = None
self._fetch_list = None
self._data_loader = None
self._bit_length = 8
self._quantized_weight_var_name = []
self._quantized_act_var_name = []
self._sampling_data = {}
self._quantized_var_scale_factor = {}
def quantize(self):
'''
Quantize the fp32 model. Use calibration data to calculate the scale factors of
the quantized variables, and insert fake quant/dequant ops to obtain the
quantized model.
Return:
the program of the quantized model.
'''
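# Overall flow, as implemented below: load the fp32 model and build the
# DataLoader (_prepare), run the model over the calibration batches while
# sampling the tensors of the variables to be quantized, compute the scale
# factors, and finally insert fake quant/dequant ops into the program
# (_update_program).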
self._prepare()
batch_id = 0
for data in self._data_loader():
self._executor.run(program=self._program,
feed=data,
fetch_list=self._fetch_list)
self._sample_data()
if batch_id % 5 == 0:
_logger.info("run batch: " + str(batch_id))
batch_id += 1
if self._batch_nums and batch_id >= self._batch_nums:
break
_logger.info("all run batch: " + str(batch_id))
self._calculate_scale_factor()
self._update_program()
return self._program
def save_quantized_model(self, save_model_path):
'''
Save the quantized model to the disk.
Args:
save_model_path(str): The path to save the quantized model
Return:
None
'''
io.save_inference_model(
dirname=save_model_path,
feeded_var_names=self._feed_list,
target_vars=self._fetch_list,
executor=self._executor,
main_program=self._program)
def _prepare(self):
'''
Load model and set data loader, collect the variable names for sampling,
and set activation variables to be persistable.
'''
# load model and set data loader
[self._program, self._feed_list, self._fetch_list] = \
io.load_inference_model(self._model_path, self._executor)
feed_vars = [framework._get_var(str(var_name), self._program) \
for var_name in self._feed_list]
self._data_loader = io.DataLoader.from_generator(
feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
self._data_loader.set_sample_generator(
self._data_reader,
batch_size=self._batch_size,
drop_last=True,
places=self._place)
# collect the variable names for sampling
persistable_var_names = []
for var in self._program.list_vars():
if var.persistable:
persistable_var_names.append(var.name)
block = self._program.global_block()
for op in block.ops:
op_type = op.type
if op_type in self._quantizable_op_type:
if op_type in ("conv2d", "depthwise_conv2d"):
self._quantized_act_var_name.append(op.input("Input")[0])
self._quantized_weight_var_name.append(
op.input("Filter")[0])
self._quantized_act_var_name.append(op.output("Output")[0])
elif op_type == "mul":
x_var_name = op.input("X")[0]
y_var_name = op.input("Y")[0]
if x_var_name not in persistable_var_names and \
y_var_name not in persistable_var_names:
op._set_attr("skip_quant", True)
_logger.warning("A mul op skip quant for two "
"input variables are not persistable")
else:
self._quantized_act_var_name.append(x_var_name)
self._quantized_weight_var_name.append(y_var_name)
self._quantized_act_var_name.append(op.output("Out")[0])
elif op_type == "pool2d":
self._quantized_act_var_name.append(op.input("X")[0])
elif op_type == "elementwise_add":
x_var_name = op.input("X")[0]
y_var_name = op.input("Y")[0]
if x_var_name not in persistable_var_names and \
y_var_name not in persistable_var_names:
self._quantized_act_var_name.append(x_var_name)
self._quantized_act_var_name.append(y_var_name)
# set activation variables to be persistable,
# so that the tensor data can be obtained in the sample_data stage
for var in self._program.list_vars():
if var.name in self._quantized_act_var_name:
var.persistable = True
def _sample_data(self):
'''
Sample the tensor data of quantized variables,
applied in every iteration.
'''
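# Weights are persistable and do not change between batches, so each weight
# is sampled only once; activation tensors are collected for every batch.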
for var_name in self._quantized_weight_var_name:
if var_name not in self._sampling_data:
var_tensor = self._load_var_value(var_name)
self._sampling_data[var_name] = var_tensor
for var_name in self._quantized_act_var_name:
if var_name not in self._sampling_data:
self._sampling_data[var_name] = []
var_tensor = self._load_var_value(var_name)
self._sampling_data[var_name].append(var_tensor)
def _calculate_scale_factor(self):
'''
Calculate the scale factor of quantized variables.
'''
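# Weight scales are computed per slice along the weight's first axis (the
# abs_max of each slice); activation scales use the KL-histogram search or
# a plain abs_max over all sampled batches, depending on self._algo.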
_logger.info("calculate scale factor ...")
for var_name in self._quantized_weight_var_name:
data = self._sampling_data[var_name]
scale_factor_per_channel = []
for i in range(data.shape[0]):
abs_max_value = np.max(np.abs(data[i]))
scale_factor_per_channel.append(abs_max_value)
self._quantized_var_scale_factor[
var_name] = scale_factor_per_channel
for var_name in self._quantized_act_var_name:
if self._algo == "KL":
self._quantized_var_scale_factor[var_name] = \
self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
else:
self._quantized_var_scale_factor[var_name] = \
np.max(np.abs(self._sampling_data[var_name]))
def _update_program(self):
'''
Insert fake_quantize/fake_dequantize op to the program.
'''
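# Three passes are applied to the IrGraph: QuantizationTransformPass inserts
# fake quant/dequant ops for the selected conv2d/depthwise_conv2d/mul ops,
# AddQuantDequantPass handles the selected pool2d/elementwise_add ops, the
# pre-computed scale factors are written into the scale variables, and
# QuantizationFreezePass then produces the final quantized program.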
_logger.info("update the program ...")
for var in self._program.list_vars():
if var.name in self._quantized_act_var_name:
var.persistable = False
# use QuantizationTransformPass to insert fake_quantize/fake_dequantize op
graph = IrGraph(core.Graph(self._program.desc), for_test=True)
qtp_quantizable_op_type = []
for op_type in ["conv2d", "depthwise_conv2d", "mul"]:
if op_type in self._quantizable_op_type:
qtp_quantizable_op_type.append(op_type)
transform_pass = QuantizationTransformPass(
scope=self._scope,
place=self._place,
weight_bits=self._bit_length,
activation_bits=self._bit_length,
activation_quantize_type='moving_average_abs_max',
weight_quantize_type='channel_wise_abs_max',
quantizable_op_type=qtp_quantizable_op_type)
transform_pass.apply(graph)
# use AddQuantDequantPass to insert fake_quant_dequant op
aqdp_quantizable_op_type = []
for op_type in ["pool2d", "elementwise_add"]:
if op_type in self._quantizable_op_type:
aqdp_quantizable_op_type.append(op_type)
add_quant_dequant_pass = AddQuantDequantPass(
scope=self._scope,
place=self._place,
quantizable_op_type=aqdp_quantizable_op_type)
add_quant_dequant_pass.apply(graph)
# save scale factor to scale var node
for key, val in self._quantized_var_scale_factor.items():
self._set_var_node_value(
key + ".scale", np.array(
[val], dtype=np.float32))
self._set_var_node_value(
key + ".quant_dequant.scale", np.array(
[val], dtype=np.float32))
# apply QuantizationFreezePass, and obtain the final quant model
freeze_pass = QuantizationFreezePass(
scope=self._scope,
place=self._place,
weight_bits=self._bit_length,
activation_bits=self._bit_length,
weight_quantize_type='channel_wise_abs_max',
quantizable_op_type=qtp_quantizable_op_type)
freeze_pass.apply(graph)
self._program = graph.to_program()
def _load_var_value(self, var_name):
'''
Load variable value from scope
'''
return np.array(self._scope.find_var(var_name).get_tensor())
def _set_var_node_value(self, var_node_name, np_value):
'''
Set the value of the variable node by name. If the node does not exist, do nothing.
'''
assert isinstance(np_value, np.ndarray), \
'The type of value should be numpy array.'
var_node = self._scope.find_var(var_node_name)
if var_node is not None:
tensor = var_node.get_tensor()
tensor.set(np_value, self._place)
def _get_kl_scaling_factor(self, activation_blob, num_quantized_bins=255):
'''
Use the KL-divergence method to get a more precise scaling factor.
'''
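# Sketch of the threshold search implemented below:
# 1. Build a 2048-bin histogram of the non-negative activation values.
# 2. For each candidate threshold i in [int(0.7 * 2047), 2047]:
#    - reference distribution P = hist[0:i], with the clipped outliers
#      folded into the last bin;
#    - candidate distribution Q = hist[0:i] merged into num_quantized_bins
#      bins and expanded back to i bins;
#    - compute KL(P || Q) via _safe_entropy.
# 3. Return (best_i + 0.5) * bin_width, where best_i minimizes the KL value.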
max_val = np.max(activation_blob)
min_val = np.min(activation_blob)
if min_val >= 0:
hist, hist_edges = np.histogram(
activation_blob, bins=2048, range=(min_val, max_val))
ending_iter = 2047
starting_iter = int(ending_iter * 0.7)
else:
_logger.error("Please first apply abs to activation_blob.")
raise ValueError("activation_blob must be non-negative; "
"apply np.abs before calling this method.")
bin_width = hist_edges[1] - hist_edges[0]
P_sum = len(np.array(activation_blob).ravel())
min_kl_divergence = 0
min_kl_index = 0
kl_inited = False
for i in range(starting_iter, ending_iter + 1):
reference_distr_P = hist[0:i].tolist()
outliers_count = sum(hist[i:2048])
if reference_distr_P[i - 1] == 0:
continue
reference_distr_P[i - 1] += outliers_count
reference_distr_bins = reference_distr_P[:]
candidate_distr_Q = hist[0:i].tolist()
num_merged_bins = int(i / num_quantized_bins)
candidate_distr_Q_quantized = [0] * num_quantized_bins
j_start = 0
j_end = num_merged_bins
for idx in range(num_quantized_bins):
candidate_distr_Q_quantized[idx] = sum(candidate_distr_Q[
j_start:j_end])
j_start += num_merged_bins
j_end += num_merged_bins
if (idx + 1) == num_quantized_bins - 1:
j_end = i
candidate_distr_Q = self._expand_quantized_bins(
candidate_distr_Q_quantized, reference_distr_bins)
Q_sum = sum(candidate_distr_Q)
kl_divergence = self._safe_entropy(reference_distr_P, P_sum,
candidate_distr_Q, Q_sum)
if not kl_inited:
min_kl_divergence = kl_divergence
min_kl_index = i
kl_inited = True
elif kl_divergence < min_kl_divergence:
min_kl_divergence = kl_divergence
min_kl_index = i
else:
pass
if min_kl_index == 0:
while starting_iter > 0:
if hist[starting_iter] == 0:
starting_iter -= 1
continue
else:
break
min_kl_index = starting_iter
return (min_kl_index + 0.5) * bin_width
def _expand_quantized_bins(self, quantized_bins, reference_bins):
'''
Expand the quantized bins back to the length of the reference bins, spreading
each quantized value evenly over the non-zero reference positions it was merged from.
'''
expanded_quantized_bins = [0] * len(reference_bins)
num_merged_bins = int(len(reference_bins) / len(quantized_bins))
j_start = 0
j_end = num_merged_bins
for idx in range(len(quantized_bins)):
zero_count = reference_bins[j_start:j_end].count(0)
num_merged_bins = j_end - j_start
if zero_count == num_merged_bins:
avg_bin_ele = 0
else:
avg_bin_ele = quantized_bins[idx] / (
num_merged_bins - zero_count + 0.0)
for idx1 in range(j_start, j_end):
expanded_quantized_bins[idx1] = (0 if reference_bins[idx1] == 0
else avg_bin_ele)
j_start += num_merged_bins
j_end += num_merged_bins
if (idx + 1) == len(quantized_bins) - 1:
j_end = len(reference_bins)
return expanded_quantized_bins
def _safe_entropy(self, reference_distr_P, P_sum, candidate_distr_Q, Q_sum):
'''
Calculate the KL divergence between the reference distribution P and the candidate distribution Q.
'''
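# With p_i = P_i / P_sum and q_i = Q_i / Q_sum, the value returned below is
# sum_i p_i * log(p_i / q_i), i.e. KL(p || q), because
# log(Q_sum * P_i) - log(P_sum * Q_i) = log((P_i / P_sum) / (Q_i / Q_sum)).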
assert len(reference_distr_P) == len(candidate_distr_Q)
tmp_sum1 = 0
tmp_sum2 = 0
for idx in range(len(reference_distr_P)):
p_idx = reference_distr_P[idx]
q_idx = candidate_distr_Q[idx]
if p_idx == 0:
tmp_sum1 += 0
tmp_sum2 += 0
else:
if q_idx == 0:
print("Fatal error!, idx = " + str(idx) +
" qindex = 0! p_idx = " + str(p_idx))
tmp_sum1 += p_idx * (math.log(Q_sum * p_idx))
tmp_sum2 += p_idx * (math.log(P_sum * q_idx))
return (tmp_sum1 - tmp_sum2) / P_sum
......@@ -26,8 +26,6 @@ __all__ = [
'AddQuantDequantPass'
]
_quantizable_op_list = ['conv2d', 'depthwise_conv2d', 'mul']
_fake_quant_op_list = [
'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
'fake_quantize_moving_average_abs_max', 'fake_channel_wise_quantize_abs_max'
......@@ -65,17 +63,18 @@ class QuantizationTransformPass(object):
weight_quantize_type='abs_max',
window_size=10000,
moving_rate=0.9,
skip_pattern='skip_quant'):
skip_pattern='skip_quant',
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
"""
Convert and rewrite the IrGraph according to weight and
activation quantization type.
Args:
scope(fluid.Scope): When activation use 'range_abs_max' as the quantize
type, this pass will create some new parameters. The scope is used to
initialize these new parameters.
type, this pass will create some new parameters. The scope is used to
initialize these new parameters.
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to initialize new
parameters described above.
parameters described above.
weight_bits (int): quantization bit number for weights,
the bias is not quantized.
activation_bits (int): quantization bit number for activation.
......@@ -93,6 +92,8 @@ class QuantizationTransformPass(object):
skip_pattern(str): The user-defined quantization skip pattern, which
will be presented in the name scope of an op. When the skip pattern is
detected in an op's name scope, the corresponding op will not be quantized.
quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"].
Examples:
.. code-block:: python
......@@ -119,7 +120,8 @@ class QuantizationTransformPass(object):
'abs_max', 'channel_wise_abs_max', 'range_abs_max',
'moving_average_abs_max'
]
assert activation_quantize_type != 'channel_wise_abs_max', "The activation quantization type does not support 'channel_wise_abs_max'."
assert activation_quantize_type != 'channel_wise_abs_max', \
"The activation quantization type does not support 'channel_wise_abs_max'."
if activation_quantize_type not in quant_type:
raise ValueError(
"Unknown activation_quantize_type : '%s'. It can only be "
......@@ -136,7 +138,11 @@ class QuantizationTransformPass(object):
self._window_size = window_size
self._moving_rate = moving_rate
self._quantizable_ops = _quantizable_op_list
self._quantizable_ops = quantizable_op_type
supported_quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
for op in self._quantizable_ops:
assert op in supported_quantizable_ops, \
op + " is not supported for quantization."
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._quantizable_grad_ops = [
'%s_grad' % (op) for op in self._quantizable_ops
......@@ -595,9 +601,11 @@ class QuantizationFreezePass(object):
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors.
weight_bits (int): quantization bit number for weights.
activation_bits (int): quantization bit number for activation.
weight_quantize_type (str): quantization type for weights, support 'abs_max' and 'channel_wise_abs_max'.
The 'range_abs_max' usually is not used for weight, since weights are fixed once the
model is well trained.
weight_quantize_type (str): quantization type for weights, support 'abs_max' and
'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight,
since weights are fixed once the model is well trained.
quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"].
"""
def __init__(self,
......@@ -605,7 +613,8 @@ class QuantizationFreezePass(object):
place,
weight_bits=8,
activation_bits=8,
weight_quantize_type='abs_max'):
weight_quantize_type='abs_max',
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
assert scope is not None, \
'The scope cannot be set None.'
assert place is not None, \
......@@ -615,7 +624,11 @@ class QuantizationFreezePass(object):
self._weight_bits = weight_bits
self._activation_bits = activation_bits
self._weight_quantize_type = weight_quantize_type
self._quantizable_ops = _quantizable_op_list
self._quantizable_ops = quantizable_op_type
supported_quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
for op in self._quantizable_ops:
assert op in supported_quantizable_ops, \
op + " is not supported for quantization."
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._fake_quant_op_names = _fake_quant_op_list
self._fake_dequant_op_names = _fake_dequant_op_list
......@@ -888,17 +901,26 @@ class ConvertToInt8Pass(object):
Args:
scope(fluid.Scope): scope is used to get the weight tensor values.
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the
8bits weight tensors.
8bits weight tensors.
quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"].
"""
def __init__(self, scope, place):
def __init__(self,
scope,
place,
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
assert scope is not None, \
'The scope cannot be set None.'
assert place is not None, \
'The place cannot be set None.'
self._scope = scope
self._place = place
self._quantizable_ops = _quantizable_op_list
self._quantizable_ops = quantizable_op_type
supported_quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
for op in self._quantizable_ops:
assert op in supported_quantizable_ops, \
op + " is not supported for quantization."
def apply(self, graph):
"""
......@@ -1166,7 +1188,8 @@ class AddQuantDequantPass(object):
place=None,
moving_rate=0.9,
quant_bits=8,
skip_pattern='skip_quant'):
skip_pattern='skip_quant',
quantizable_op_type=["elementwise_add", "pool2d"]):
"""
This pass is used to add quant_dequant ops for some ops, such as
the 'elementwise_add' and 'pool2d' ops.
......@@ -1176,9 +1199,16 @@ class AddQuantDequantPass(object):
self._moving_rate = moving_rate
self._quant_bits = quant_bits
self._is_test = None
self._target_ops = ["elementwise_add", "pool2d"]
self._target_grad_ops = ['%s_grad' % (op) for op in self._target_ops]
self._skip_pattern = skip_pattern
self._quantizable_op_type = quantizable_op_type
self._quantizable_grad_op_type = [
'%s_grad' % (op) for op in self._quantizable_op_type
]
supported_quantizable_op_type = ["elementwise_add", "pool2d"]
for op_type in quantizable_op_type:
assert op_type in supported_quantizable_op_type, \
op_type + " is not supported for quantization."
def apply(self, graph):
"""
......@@ -1194,7 +1224,7 @@ class AddQuantDequantPass(object):
ops = graph.all_op_nodes()
for op_node in ops:
if op_node.name() in self._target_ops:
if op_node.name() in self._quantizable_op_type:
if isinstance(self._skip_pattern, str) and \
op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1:
......@@ -1221,7 +1251,7 @@ class AddQuantDequantPass(object):
graph.update_input_link(in_node, quant_var_node, op_node)
for op_node in ops:
if op_node.name() in self._target_grad_ops:
if op_node.name() in self._quantizable_grad_op_type:
for input_name in op_node.input_arg_names():
if input_name in dequantized_vars_map:
in_node = graph._find_node_by_name(op_node.inputs,
......
......@@ -48,6 +48,7 @@ endfunction()
if(WIN32)
list(REMOVE_ITEM TEST_OPS test_light_nas)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization)
endif()
# int8 image classification python api test
......
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import unittest
import os
import time
import sys
import random
import math
import functools
import contextlib
import numpy as np
from PIL import Image, ImageEnhance
import paddle
import paddle.fluid as fluid
from paddle.dataset.common import download
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
random.seed(0)
np.random.seed(0)
DATA_DIM = 224
THREAD = 1
BUF_SIZE = 102400
DATA_DIR = 'data/ILSVRC2012'
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
def resize_short(img, target_size):
percent = float(target_size) / min(img.size[0], img.size[1])
resized_width = int(round(img.size[0] * percent))
resized_height = int(round(img.size[1] * percent))
img = img.resize((resized_width, resized_height), Image.LANCZOS)
return img
def crop_image(img, target_size, center):
width, height = img.size
size = target_size
if center:
w_start = (width - size) / 2
h_start = (height - size) / 2
else:
w_start = np.random.randint(0, width - size + 1)
h_start = np.random.randint(0, height - size + 1)
w_end = w_start + size
h_end = h_start + size
img = img.crop((w_start, h_start, w_end, h_end))
return img
def process_image(sample, mode, color_jitter, rotate):
img_path = sample[0]
img = Image.open(img_path)
img = resize_short(img, target_size=256)
img = crop_image(img, target_size=DATA_DIM, center=True)
if img.mode != 'RGB':
img = img.convert('RGB')
img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255
img -= img_mean
img /= img_std
return img, sample[1]
def _reader_creator(file_list,
mode,
shuffle=False,
color_jitter=False,
rotate=False,
data_dir=DATA_DIR):
def reader():
with open(file_list) as flist:
full_lines = [line.strip() for line in flist]
if shuffle:
np.random.shuffle(full_lines)
lines = full_lines
for line in lines:
img_path, label = line.split()
img_path = os.path.join(data_dir, img_path)
if not os.path.exists(img_path):
continue
yield img_path, int(label)
mapper = functools.partial(
process_image, mode=mode, color_jitter=color_jitter, rotate=rotate)
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def val(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir)
class TestPostTrainingQuantization(unittest.TestCase):
def setUp(self):
self.int8_download = 'int8/download'
self.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' +
self.int8_download)
data_urls = []
data_md5s = []
self.data_cache_folder = ''
if os.environ.get('DATASET') == 'full':
data_urls.append(
'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partaa'
)
data_md5s.append('60f6525b0e1d127f345641d75d41f0a8')
data_urls.append(
'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partab'
)
data_md5s.append('1e9f15f64e015e58d6f9ec3210ed18b5')
self.data_cache_folder = self.download_data(data_urls, data_md5s,
"full_data", False)
else:
data_urls.append(
'http://paddle-inference-dist.bj.bcebos.com/int8/calibration_test_data.tar.gz'
)
data_md5s.append('1b6c1c434172cca1bf9ba1e4d7a3157d')
self.data_cache_folder = self.download_data(data_urls, data_md5s,
"small_data", False)
# reader/decorator.py requires the relative path to the data folder
cmd = 'rm -rf {0} && ln -s {1} {0}'.format("data",
self.data_cache_folder)
os.system(cmd)
self.batch_size = 1 if os.environ.get('DATASET') == 'full' else 50
self.sample_iterations = 50 if os.environ.get(
'DATASET') == 'full' else 1
self.infer_iterations = 50000 if os.environ.get(
'DATASET') == 'full' else 1
self.timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
self.int8_model = ''
def tearDown(self):
try:
os.system("rm -rf {}".format(self.int8_model))
except Exception as e:
print("Failed to delete {} due to {}".format(self.int8_model,
str(e)))
def cache_unzipping(self, target_folder, zip_path):
if not os.path.exists(target_folder):
cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder,
zip_path)
os.system(cmd)
def download_data(self, data_urls, data_md5s, folder_name, is_model=True):
data_cache_folder = os.path.join(self.cache_folder, folder_name)
zip_path = ''
if os.environ.get('DATASET') == 'full':
file_names = []
for i in range(0, len(data_urls)):
download(data_urls[i], self.int8_download, data_md5s[i])
file_names.append(data_urls[i].split('/')[-1])
zip_path = os.path.join(self.cache_folder,
'full_imagenet_val.tar.gz')
if not os.path.exists(zip_path):
cat_command = 'cat'
for file_name in file_names:
cat_command += ' ' + os.path.join(self.cache_folder,
file_name)
cat_command += ' > ' + zip_path
os.system(cat_command)
if os.environ.get('DATASET') != 'full' or is_model:
download(data_urls[0], self.int8_download, data_md5s[0])
file_name = data_urls[0].split('/')[-1]
zip_path = os.path.join(self.cache_folder, file_name)
print('Data is downloaded at {0}'.format(zip_path))
self.cache_unzipping(data_cache_folder, zip_path)
return data_cache_folder
def download_model(self):
pass
def run_program(self, model_path):
image_shape = [3, 224, 224]
place = fluid.CPUPlace()
exe = fluid.Executor(place)
[infer_program, feed_dict, fetch_targets] = \
fluid.io.load_inference_model(model_path, exe)
val_reader = paddle.batch(val(), self.batch_size)
iterations = self.infer_iterations
test_info = []
cnt = 0
periods = []
for batch_id, data in enumerate(val_reader()):
image = np.array(
[x[0].reshape(image_shape) for x in data]).astype("float32")
label = np.array([x[1] for x in data]).astype("int64")
label = label.reshape([-1, 1])
t1 = time.time()
_, acc1, _ = exe.run(
infer_program,
feed={feed_dict[0]: image,
feed_dict[1]: label},
fetch_list=fetch_targets)
t2 = time.time()
period = t2 - t1
periods.append(period)
test_info.append(np.mean(acc1) * len(data))
cnt += len(data)
if (batch_id + 1) % 100 == 0:
print("{0} images,".format(batch_id + 1))
sys.stdout.flush()
if (batch_id + 1) == iterations:
break
throughput = cnt / np.sum(periods)
latency = np.average(periods)
acc1 = np.sum(test_info) / cnt
return (throughput, latency, acc1)
def generate_quantized_model(self, model_path, algo="KL"):
self.int8_model = os.path.join(os.getcwd(),
"post_training_" + self.timestamp)
try:
os.system("mkdir " + self.int8_model)
except Exception as e:
print("Failed to create {} due to {}".format(self.int8_model,
str(e)))
sys.exit(-1)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.global_scope()
val_reader = val()
quantizable_op_type = [
"conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
]
ptq = PostTrainingQuantization(
executor=exe,
scope=scope,
model_path=model_path,
data_reader=val_reader,
algo=algo,
quantizable_op_type=quantizable_op_type)
ptq.quantize()
ptq.save_quantized_model(self.int8_model)
class TestPostTrainingForResnet50(TestPostTrainingQuantization):
def download_model(self):
# resnet50 fp32 data
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
]
data_md5s = ['4a5194524823d9b76da6e738e1367881']
self.model_cache_folder = self.download_data(data_urls, data_md5s,
"resnet50_fp32")
self.model = "ResNet-50"
self.algo = "KL"
def test_post_training_resnet50(self):
self.download_model()
print("Start FP32 inference for {0} on {1} images ...".format(
self.model, self.infer_iterations * self.batch_size))
(fp32_throughput, fp32_latency,
fp32_acc1) = self.run_program(self.model_cache_folder + "/model")
print("Start INT8 post training quantization for {0} on {1} images ...".
format(self.model, self.sample_iterations * self.batch_size))
self.generate_quantized_model(
self.model_cache_folder + "/model", algo=self.algo)
print("Start INT8 inference for {0} on {1} images ...".format(
self.model, self.infer_iterations * self.batch_size))
(int8_throughput, int8_latency,
int8_acc1) = self.run_program(self.int8_model)
print(
"FP32 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}".
format(self.model, self.batch_size, fp32_throughput, fp32_latency,
fp32_acc1))
print(
"INT8 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}".
format(self.model, self.batch_size, int8_throughput, int8_latency,
int8_acc1))
sys.stdout.flush()
delta_value = fp32_acc1 - int8_acc1
self.assertLess(delta_value, 0.025)
class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
def download_model(self):
# mobilenetv1 fp32 data
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
self.model_cache_folder = self.download_data(data_urls, data_md5s,
"mobilenetv1_fp32")
self.model = "MobileNet-V1"
self.algo = "KL"
def test_post_training_mobilenetv1(self):
self.download_model()
print("Start FP32 inference for {0} on {1} images ...".format(
self.model, self.infer_iterations * self.batch_size))
(fp32_throughput, fp32_latency,
fp32_acc1) = self.run_program(self.model_cache_folder + "/model")
print("Start INT8 post training quantization for {0} on {1} images ...".
format(self.model, self.sample_iterations * self.batch_size))
self.generate_quantized_model(
self.model_cache_folder + "/model", algo=self.algo)
print("Start INT8 inference for {0} on {1} images ...".format(
self.model, self.infer_iterations * self.batch_size))
(int8_throughput, int8_latency,
int8_acc1) = self.run_program(self.int8_model)
print(
"FP32 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}".
format(self.model, self.batch_size, fp32_throughput, fp32_latency,
fp32_acc1))
print(
"INT8 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}".
format(self.model, self.batch_size, int8_throughput, int8_latency,
int8_acc1))
sys.stdout.flush()
delta_value = fp32_acc1 - int8_acc1
self.assertLess(delta_value, 0.025)
if __name__ == '__main__':
unittest.main()