未验证 提交 9254183d 编写于 作者: C cc 提交者: GitHub

Refine the dygraph ptq and the module of calculating KL threshold (#33898)

* refine ptq according comments
* reuse the module to calculate kl threshold
上级 0b911330
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
__all__ = ['cal_kl_threshold']
def expand_quantized_bins(quantized_bins, reference_bins):
'''
Expand hist bins.
'''
expanded_quantized_bins = [0] * len(reference_bins)
num_merged_bins = int(len(reference_bins) / len(quantized_bins))
j_start = 0
j_end = num_merged_bins
for idx in range(len(quantized_bins)):
zero_count = reference_bins[j_start:j_end].count(0)
num_merged_bins = j_end - j_start
if zero_count == num_merged_bins:
avg_bin_ele = 0
else:
avg_bin_ele = quantized_bins[idx] / (
num_merged_bins - zero_count + 0.0)
for idx1 in range(j_start, j_end):
expanded_quantized_bins[idx1] = (0 if reference_bins[idx1] == 0 else
avg_bin_ele)
j_start += num_merged_bins
j_end += num_merged_bins
if (idx + 1) == len(quantized_bins) - 1:
j_end = len(reference_bins)
return expanded_quantized_bins
def safe_entropy(reference_distr_P, P_sum, candidate_distr_Q, Q_sum):
'''
Calculate the entropy.
'''
assert len(reference_distr_P) == len(candidate_distr_Q)
tmp_sum1 = 0
tmp_sum2 = 0
for idx in range(len(reference_distr_P)):
p_idx = reference_distr_P[idx]
q_idx = candidate_distr_Q[idx]
if p_idx == 0:
tmp_sum1 += 0
tmp_sum2 += 0
else:
if q_idx == 0:
_logger.error("Fatal error!, idx = " + str(idx) +
" qindex = 0! p_idx = " + str(p_idx))
tmp_sum1 += p_idx * (math.log(Q_sum * p_idx))
tmp_sum2 += p_idx * (math.log(P_sum * q_idx))
return (tmp_sum1 - tmp_sum2) / P_sum
def cal_kl_threshold(hist, bin_width, bits):
'''
Using the KL-divergenc method to get the more precise threshold.
Args:
hist(List): The hist of the tensor.
bin_width(float): The bin width for the hist.
bits(int): The quantization bits.
'''
assert hist.ndim == 1
hist_bins = hist.shape[0]
starting_iter = int((hist_bins - 1) * 0.5)
quant_range = 2**(bits - 1) - 1
P_sum = np.sum(np.array(hist).ravel())
min_kl_divergence = 0
min_kl_index = 0
kl_inited = False
for i in range(starting_iter, hist_bins):
reference_distr_P = hist[0:i].tolist()
outliers_count = sum(hist[i:])
if reference_distr_P[i - 1] == 0:
continue
reference_distr_P[i - 1] += outliers_count
reference_distr_bins = reference_distr_P[:]
candidate_distr_Q = hist[0:i].tolist()
num_merged_bins = int(i / quant_range)
candidate_distr_Q_quantized = [0] * quant_range
j_start = 0
j_end = num_merged_bins
for idx in range(quant_range):
candidate_distr_Q_quantized[idx] = sum(candidate_distr_Q[j_start:
j_end])
j_start += num_merged_bins
j_end += num_merged_bins
if (idx + 1) == quant_range - 1:
j_end = i
candidate_distr_Q = expand_quantized_bins(candidate_distr_Q_quantized,
reference_distr_bins)
Q_sum = sum(candidate_distr_Q)
kl_divergence = safe_entropy(reference_distr_P, P_sum,
candidate_distr_Q, Q_sum)
if not kl_inited:
min_kl_divergence = kl_divergence
min_kl_index = i
kl_inited = True
elif kl_divergence < min_kl_divergence:
min_kl_divergence = kl_divergence
min_kl_index = i
else:
pass
if min_kl_index == 0:
while starting_iter > 0:
if hist[starting_iter] == 0:
starting_iter -= 1
continue
else:
break
min_kl_index = starting_iter
return (min_kl_index + 0.5) * bin_width
...@@ -32,16 +32,18 @@ _logger = get_logger( ...@@ -32,16 +32,18 @@ _logger = get_logger(
class ImperativePTQ(object): class ImperativePTQ(object):
""" """
Applying static post_training quantization to the dgraph model. Static post training quantization.
""" """
def __init__(self, quant_config=ptq_config.default_ptq_config): def __init__(self, quant_config=ptq_config.default_ptq_config):
""" """
Constructor. Constructor.
Args: Args:
algo(str): The algorithm in post_training quantizaion to be used. quant_config(PTQConfig): the config of post training quantization.
activation_bits(int): quantization bit number for activations. The config has weight_quantizer and activation_quantizer.
weight_bits(int): quantization bit number for weights. In default, the weight_quantizer and activation_quantizer are
AbsmaxQuantizer.
""" """
super(ImperativePTQ, self).__init__() super(ImperativePTQ, self).__init__()
...@@ -55,28 +57,30 @@ class ImperativePTQ(object): ...@@ -55,28 +57,30 @@ class ImperativePTQ(object):
Args: Args:
model(paddle.nn.Layer): The model to be quantized. model(paddle.nn.Layer): The model to be quantized.
inplace(bool): Whether apply quantization to the input model.
Default: False.
Returns: Returns:
None quantized_model(paddle.nn.Layer): The quantized model.
""" """
assert isinstance(model, paddle.nn.Layer), \ assert isinstance(model, paddle.nn.Layer), \
"The model must be the instance of paddle.nn.Layer." "The model must be the instance of paddle.nn.Layer."
if not inplace: if not inplace:
model = copy.deepcopy(model) new_model = copy.deepcopy(model)
for name, layer in model.named_sublayers(): for name, layer in new_model.named_sublayers():
if PTQRegistry.is_supported_layer(layer) \ if PTQRegistry.is_supported_layer(layer) \
and utils.is_leaf_layer(layer): and utils.is_leaf_layer(layer):
quant_config = copy.deepcopy(self._quant_config) quant_config = copy.deepcopy(self._quant_config)
layer._quant_config = quant_config layer._quant_config = quant_config
hook = ptq_hooks.quant_forward_post_hook hook = ptq_hooks.quant_forward_post_hook
hook_handle = layer.register_forward_post_hook(hook) quant_hook_handle = layer.register_forward_post_hook(hook)
quant_config.hook_handle = hook_handle quant_config.quant_hook_handle = quant_hook_handle
layer._forward_post_hooks.move_to_end( layer._forward_post_hooks.move_to_end(
hook_handle._hook_id, last=False) quant_hook_handle._hook_id, last=False)
return model return new_model
def convert(self, model): def convert(self, model):
""" """
...@@ -85,7 +89,7 @@ class ImperativePTQ(object): ...@@ -85,7 +89,7 @@ class ImperativePTQ(object):
Args: Args:
model(paddle.nn.Layer): The model to be quantized. model(paddle.nn.Layer): The model to be quantized.
Returns: Returns:
None converted_model(paddle.nn.Layer): The converted model.
""" """
assert isinstance(model, paddle.nn.Layer), \ assert isinstance(model, paddle.nn.Layer), \
"The input model must be the instance of paddle.nn.Layer." "The input model must be the instance of paddle.nn.Layer."
...@@ -96,7 +100,7 @@ class ImperativePTQ(object): ...@@ -96,7 +100,7 @@ class ImperativePTQ(object):
assert hasattr(sub_layer, "_quant_config") assert hasattr(sub_layer, "_quant_config")
quant_config = sub_layer._quant_config quant_config = sub_layer._quant_config
quant_config.hook_handle.remove() quant_config.quant_hook_handle.remove()
quant_config.in_act_quantizer.cal_thresholds() quant_config.in_act_quantizer.cal_thresholds()
quant_config.out_act_quantizer.cal_thresholds() quant_config.out_act_quantizer.cal_thresholds()
......
...@@ -29,6 +29,15 @@ class PTQConfig(object): ...@@ -29,6 +29,15 @@ class PTQConfig(object):
""" """
def __init__(self, activation_quantizer, weight_quantizer): def __init__(self, activation_quantizer, weight_quantizer):
"""
Constructor.
Args:
activation_quantizer(BaseQuantizer): The activation quantizer.
It should be the instance of BaseQuantizer.
weight_quantizer(BaseQuantizer): The weight quantizer.
It should be the instance of BaseQuantizer.
"""
super(PTQConfig, self).__init__() super(PTQConfig, self).__init__()
assert isinstance(activation_quantizer, BaseQuantizer) assert isinstance(activation_quantizer, BaseQuantizer)
...@@ -38,7 +47,7 @@ class PTQConfig(object): ...@@ -38,7 +47,7 @@ class PTQConfig(object):
self.out_act_quantizer = copy.deepcopy(activation_quantizer) self.out_act_quantizer = copy.deepcopy(activation_quantizer)
self.wt_quantizer = copy.deepcopy(weight_quantizer) self.wt_quantizer = copy.deepcopy(weight_quantizer)
self.hook_handle = None self.quant_hook_handle = None
default_ptq_config = PTQConfig(AbsmaxQuantizer(), AbsmaxQuantizer()) default_ptq_config = PTQConfig(AbsmaxQuantizer(), AbsmaxQuantizer())
...@@ -21,6 +21,7 @@ import numpy as np ...@@ -21,6 +21,7 @@ import numpy as np
import paddle import paddle
from . import utils from . import utils
from ..cal_kl_threshold import cal_kl_threshold
__all__ = [ __all__ = [
'BaseQuantizer', 'BaseQuantizer',
...@@ -256,6 +257,8 @@ class KLQuantizer(BaseHistQuantizer): ...@@ -256,6 +257,8 @@ class KLQuantizer(BaseHistQuantizer):
if self.hists[idx] is None: if self.hists[idx] is None:
self.thresholds.append(self.abs_max_vals[idx]) self.thresholds.append(self.abs_max_vals[idx])
else: else:
threshold = utils.cal_kl_scaling_factor( hist = self.hists[idx]
self.hists[idx], self.abs_max_vals[idx], self.quant_bits) abs_max_val = self.abs_max_vals[idx]
bin_width = abs_max_val / hist.shape[0]
threshold = cal_kl_threshold(hist, bin_width, self.quant_bits)
self.thresholds.append(threshold) self.thresholds.append(threshold)
...@@ -147,113 +147,10 @@ def is_leaf_layer(layer): ...@@ -147,113 +147,10 @@ def is_leaf_layer(layer):
and len(layer.sublayers()) == 0 and len(layer.sublayers()) == 0
def expand_quantized_bins(quantized_bins, reference_bins): def fp_numpy_to_naive(x_np):
""" """
Convert numpy to float or list.
""" """
expanded_quantized_bins = [0] * len(reference_bins)
num_merged_bins = int(len(reference_bins) / len(quantized_bins))
j_start = 0
j_end = num_merged_bins
for idx in range(len(quantized_bins)):
zero_count = reference_bins[j_start:j_end].count(0)
num_merged_bins = j_end - j_start
if zero_count == num_merged_bins:
avg_bin_ele = 0
else:
avg_bin_ele = quantized_bins[idx] / (
num_merged_bins - zero_count + 0.0)
for idx1 in range(j_start, j_end):
expanded_quantized_bins[idx1] = (0 if reference_bins[idx1] == 0 else
avg_bin_ele)
j_start += num_merged_bins
j_end += num_merged_bins
if (idx + 1) == len(quantized_bins) - 1:
j_end = len(reference_bins)
return expanded_quantized_bins
def safe_entropy(reference_distr_P, P_sum, candidate_distr_Q, Q_sum):
'''
Calculate the entropy.
'''
assert len(reference_distr_P) == len(candidate_distr_Q)
tmp_sum1 = 0
tmp_sum2 = 0
for idx in range(len(reference_distr_P)):
p_idx = reference_distr_P[idx]
q_idx = candidate_distr_Q[idx]
if p_idx == 0:
tmp_sum1 += 0
tmp_sum2 += 0
else:
if q_idx == 0:
_logger.error("Fatal error!, idx = " + str(idx) +
" qindex = 0! p_idx = " + str(p_idx))
tmp_sum1 += p_idx * (math.log(Q_sum * p_idx))
tmp_sum2 += p_idx * (math.log(P_sum * q_idx))
return (tmp_sum1 - tmp_sum2) / P_sum
def cal_kl_scaling_factor(hist, abs_max, bits):
'''
Using the KL-divergenc method to get the more precise scaling factor.
'''
assert hist.ndim == 1
hist_bins = hist.shape[0]
starting_iter = int((hist_bins - 1) * 0.5)
bin_width = abs_max / hist_bins
quant_range = 2**(bits - 1) - 1
P_sum = np.sum(np.array(hist).ravel())
min_kl_divergence = 0
min_kl_index = 0
kl_inited = False
for i in range(starting_iter, hist_bins):
reference_distr_P = hist[0:i].tolist()
outliers_count = sum(hist[i:])
if reference_distr_P[i - 1] == 0:
continue
reference_distr_P[i - 1] += outliers_count
reference_distr_bins = reference_distr_P[:]
candidate_distr_Q = hist[0:i].tolist()
num_merged_bins = int(i / quant_range)
candidate_distr_Q_quantized = [0] * quant_range
j_start = 0
j_end = num_merged_bins
for idx in range(quant_range):
candidate_distr_Q_quantized[idx] = sum(candidate_distr_Q[j_start:
j_end])
j_start += num_merged_bins
j_end += num_merged_bins
if (idx + 1) == quant_range - 1:
j_end = i
candidate_distr_Q = expand_quantized_bins(candidate_distr_Q_quantized,
reference_distr_bins)
Q_sum = sum(candidate_distr_Q)
kl_divergence = safe_entropy(reference_distr_P, P_sum,
candidate_distr_Q, Q_sum)
if not kl_inited:
min_kl_divergence = kl_divergence
min_kl_index = i
kl_inited = True
elif kl_divergence < min_kl_divergence:
min_kl_divergence = kl_divergence
min_kl_index = i
else:
pass
if min_kl_index == 0:
while starting_iter > 0:
if hist[starting_iter] == 0:
starting_iter -= 1
continue
else:
break
min_kl_index = starting_iter
return (min_kl_index + 0.5) * bin_width
def fp_numpy_to_naive(x_np):
if x_np.size == 1: if x_np.size == 1:
return float(x_np) return float(x_np)
else: else:
......
...@@ -33,6 +33,7 @@ from .quantization_pass import _get_op_output_var_names ...@@ -33,6 +33,7 @@ from .quantization_pass import _get_op_output_var_names
from .quantization_pass import _get_output_name_index from .quantization_pass import _get_output_name_index
from .quantization_pass import _get_input_name_index from .quantization_pass import _get_input_name_index
from .quantization_pass import _channelwise_quant_axis1_ops from .quantization_pass import _channelwise_quant_axis1_ops
from .cal_kl_threshold import cal_kl_threshold
__all__ = ['PostTrainingQuantization', 'WeightQuantization'] __all__ = ['PostTrainingQuantization', 'WeightQuantization']
...@@ -763,8 +764,9 @@ class PostTrainingQuantization(object): ...@@ -763,8 +764,9 @@ class PostTrainingQuantization(object):
for var_name in self._quantized_act_var_name: for var_name in self._quantized_act_var_name:
hist, hist_edeges = self._sampling_act_histogram[var_name] hist, hist_edeges = self._sampling_act_histogram[var_name]
if self._algo == "KL": if self._algo == "KL":
bin_width = hist_edeges[1] - hist_edeges[0]
self._quantized_var_threshold[var_name] = \ self._quantized_var_threshold[var_name] = \
self._get_kl_scaling_factor(hist, hist_edeges) cal_kl_threshold(hist, bin_width, self._activation_bits)
elif self._algo == "hist": elif self._algo == "hist":
self._quantized_var_threshold[var_name] = \ self._quantized_var_threshold[var_name] = \
self._get_hist_scaling_factor(hist, hist_edeges) self._get_hist_scaling_factor(hist, hist_edeges)
...@@ -935,107 +937,6 @@ class PostTrainingQuantization(object): ...@@ -935,107 +937,6 @@ class PostTrainingQuantization(object):
bin_width = hist_edges[1] - hist_edges[0] bin_width = hist_edges[1] - hist_edges[0]
return (hist_index - 0.5) * bin_width return (hist_index - 0.5) * bin_width
def _get_kl_scaling_factor(self, hist, hist_edeges):
'''
Using the KL-divergenc method to get the more precise scaling factor.
'''
num_quantized_bins = 2**(self._activation_bits - 1) - 1
ending_iter = self._histogram_bins - 1
starting_iter = int(ending_iter * 0.7)
bin_width = hist_edeges[1] - hist_edeges[0]
P_sum = np.sum(np.array(hist).ravel())
min_kl_divergence = 0
min_kl_index = 0
kl_inited = False
for i in range(starting_iter, ending_iter + 1):
reference_distr_P = hist[0:i].tolist()
outliers_count = sum(hist[i:2048])
if reference_distr_P[i - 1] == 0:
continue
reference_distr_P[i - 1] += outliers_count
reference_distr_bins = reference_distr_P[:]
candidate_distr_Q = hist[0:i].tolist()
num_merged_bins = int(i / num_quantized_bins)
candidate_distr_Q_quantized = [0] * num_quantized_bins
j_start = 0
j_end = num_merged_bins
for idx in range(num_quantized_bins):
candidate_distr_Q_quantized[idx] = sum(candidate_distr_Q[
j_start:j_end])
j_start += num_merged_bins
j_end += num_merged_bins
if (idx + 1) == num_quantized_bins - 1:
j_end = i
candidate_distr_Q = self._expand_quantized_bins(
candidate_distr_Q_quantized, reference_distr_bins)
Q_sum = sum(candidate_distr_Q)
kl_divergence = self._safe_entropy(reference_distr_P, P_sum,
candidate_distr_Q, Q_sum)
if not kl_inited:
min_kl_divergence = kl_divergence
min_kl_index = i
kl_inited = True
elif kl_divergence < min_kl_divergence:
min_kl_divergence = kl_divergence
min_kl_index = i
else:
pass
if min_kl_index == 0:
while starting_iter > 0:
if hist[starting_iter] == 0:
starting_iter -= 1
continue
else:
break
min_kl_index = starting_iter
return (min_kl_index + 0.5) * bin_width
def _expand_quantized_bins(self, quantized_bins, reference_bins):
'''
'''
expanded_quantized_bins = [0] * len(reference_bins)
num_merged_bins = int(len(reference_bins) / len(quantized_bins))
j_start = 0
j_end = num_merged_bins
for idx in range(len(quantized_bins)):
zero_count = reference_bins[j_start:j_end].count(0)
num_merged_bins = j_end - j_start
if zero_count == num_merged_bins:
avg_bin_ele = 0
else:
avg_bin_ele = quantized_bins[idx] / (
num_merged_bins - zero_count + 0.0)
for idx1 in range(j_start, j_end):
expanded_quantized_bins[idx1] = (0 if reference_bins[idx1] == 0
else avg_bin_ele)
j_start += num_merged_bins
j_end += num_merged_bins
if (idx + 1) == len(quantized_bins) - 1:
j_end = len(reference_bins)
return expanded_quantized_bins
def _safe_entropy(self, reference_distr_P, P_sum, candidate_distr_Q, Q_sum):
'''
Calculate the entropy.
'''
assert len(reference_distr_P) == len(candidate_distr_Q)
tmp_sum1 = 0
tmp_sum2 = 0
for idx in range(len(reference_distr_P)):
p_idx = reference_distr_P[idx]
q_idx = candidate_distr_Q[idx]
if p_idx == 0:
tmp_sum1 += 0
tmp_sum2 += 0
else:
if q_idx == 0:
_logger.error("Fatal error!, idx = " + str(idx) +
" qindex = 0! p_idx = " + str(p_idx))
tmp_sum1 += p_idx * (math.log(Q_sum * p_idx))
tmp_sum2 += p_idx * (math.log(P_sum * q_idx))
return (tmp_sum1 - tmp_sum2) / P_sum
class WeightQuantization(object): class WeightQuantization(object):
_supported_quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul'] _supported_quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
......
...@@ -208,24 +208,26 @@ class TestImperativePTQ(unittest.TestCase): ...@@ -208,24 +208,26 @@ class TestImperativePTQ(unittest.TestCase):
model_state_dict = paddle.load(params_path) model_state_dict = paddle.load(params_path)
model.set_state_dict(model_state_dict) model.set_state_dict(model_state_dict)
self.ptq.quantize(model, inplace=True) quant_model = self.ptq.quantize(model)
acc_top1 = self.model_test(model, self.batch_num, self.batch_size) acc_top1 = self.model_test(quant_model, self.batch_num,
self.batch_size)
print('acc_top1: %s' % acc_top1) print('acc_top1: %s' % acc_top1)
self.assertTrue( self.assertTrue(
acc_top1 > self.eval_acc_top1, acc_top1 > self.eval_acc_top1,
msg="The test acc {%f} is less than {%f}." % msg="The test acc {%f} is less than {%f}." %
(acc_top1, self.eval_acc_top1)) (acc_top1, self.eval_acc_top1))
self.ptq.convert(model) final_model = self.ptq.convert(quant_model)
self.check_thresholds(model) self.check_thresholds(final_model)
input_spec = [ input_spec = [
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32') shape=[None, 1, 28, 28], dtype='float32')
] ]
paddle.jit.save(layer=model, path=self.save_path, input_spec=input_spec) paddle.jit.save(
layer=final_model, path=self.save_path, input_spec=input_spec)
print('Quantized model saved in {%s}' % self.save_path) print('Quantized model saved in {%s}' % self.save_path)
end_time = time.time() end_time = time.time()
...@@ -233,9 +235,6 @@ class TestImperativePTQ(unittest.TestCase): ...@@ -233,9 +235,6 @@ class TestImperativePTQ(unittest.TestCase):
class TestImperativePTQHist(TestImperativePTQ): class TestImperativePTQHist(TestImperativePTQ):
"""
"""
def set_vars(self): def set_vars(self):
config = PTQConfig(HistQuantizer(), AbsmaxQuantizer()) config = PTQConfig(HistQuantizer(), AbsmaxQuantizer())
self.ptq = ImperativePTQ(config) self.ptq = ImperativePTQ(config)
...@@ -257,9 +256,6 @@ class TestImperativePTQHist(TestImperativePTQ): ...@@ -257,9 +256,6 @@ class TestImperativePTQHist(TestImperativePTQ):
class TestImperativePTQKL(TestImperativePTQ): class TestImperativePTQKL(TestImperativePTQ):
"""
"""
def set_vars(self): def set_vars(self):
config = PTQConfig(KLQuantizer(), PerChannelAbsmaxQuantizer()) config = PTQConfig(KLQuantizer(), PerChannelAbsmaxQuantizer())
self.ptq = ImperativePTQ(config) self.ptq = ImperativePTQ(config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册