未验证 提交 72633164 编写于 作者: W whs 提交者: GitHub

Add hist and kl observer (#1679)

* Add histogram observer for PTQ
上级 b692d8ec
Global:
model_dir: './MobileNetV1_infer'
model_dir: './mobilenet_dbb_inference'
model_filename: 'inference.pdmodel'
params_filename: "inference.pdiparams"
batch_size: 128
data_dir: './ILSVRC2012_data_demo/ILSVRC2012/'
data_dir: './ILSVRC2012/'
img_size: 224
resize_size: 256
......@@ -31,12 +31,12 @@ def argsparser():
parser.add_argument(
'--config_path',
type=str,
default='./image_classification/configs/eval.yaml',
default='./configs/eval.yaml',
help="path of compression strategy config.")
parser.add_argument(
'--model_dir',
type=str,
default='./MobileNetV1_infer',
default='./mobilenet_dbb_inference',
help='model directory')
return parser
......@@ -65,6 +65,15 @@ def eval():
exe,
model_filename=global_config["model_filename"],
params_filename=global_config["params_filename"])
features = None
for _var in val_program.list_vars():
print(f"meeting: {_var.name}")
if _var.name == "conv2d_98.tmp_1":
print(f"find {_var.name}")
features = _var
fetch_targets.append(features)
print('Loaded model from: {}'.format(global_config["model_dir"]))
val_loader = eval_reader(
......@@ -77,9 +86,13 @@ def eval():
for batch_id, (image, label) in enumerate(val_loader):
image = np.array(image)
label = np.array(label).astype('int64')
pred = exe.run(val_program,
pred = exe.run(
val_program,
feed={feed_target_names[0]: image},
fetch_list=fetch_targets)
features = np.array(pred[1])
print(f"feature shape: {features.shape}")
pred = np.array(pred[0])
label = np.array(label)
sort_array = pred.argsort(axis=1)
......@@ -92,6 +105,7 @@ def eval():
acc_num += 1
top_5 = float(acc_num) / len(label)
results.append([top_1, top_5])
break
result = np.mean(np.array(results), axis=0)
return result[0]
......@@ -107,10 +121,10 @@ def main(args):
global_config['model_dir'] = args.model_dir
global img_size, resize_size
img_size = int(global_config[
'img_size']) if 'img_size' in global_config else 224
resize_size = int(global_config[
'resize_size']) if 'resize_size' in global_config else 256
img_size = int(
global_config['img_size']) if 'img_size' in global_config else 224
resize_size = int(
global_config['resize_size']) if 'resize_size' in global_config else 256
result = eval()
print('Eval Top1:', result)
......
......@@ -359,8 +359,8 @@ class OFA(OFABase):
if isinstance(v, dict):
sample_cands[k] = self._sample_from_nestdict(
v, sample_type=sample_type, task=task, phase=phase)
elif isinstance(v, list) or isinstance(v, set) or isinstance(v,
tuple):
elif isinstance(v, list) or isinstance(v, set) or isinstance(
v, tuple):
if sample_type == 'largest':
sample_cands[k] = v[-1]
elif sample_type == 'smallest':
......@@ -413,8 +413,8 @@ class OFA(OFABase):
key = all_tokens.index(cand)
self.token_map[self.task][name][key] = cand
else:
raise NotImplementedError("Task {} not in ofa layers".format(
self.task))
raise NotImplementedError(
"Task {} not in ofa layers".format(self.task))
self.search_cands = []
for layer, t_map in self.token_map[self.task].items():
......@@ -610,8 +610,8 @@ class OFA(OFABase):
print(f"hit cpu in ofa-------------------------------")
place = paddle.CPUPlace()
else:
place = paddle.framework.core.CUDAPlace(p.gpu_device_id(
))
place = paddle.framework.core.CUDAPlace(
p.gpu_device_id())
t_value.set(pruned_state_dict[name], place)
if super_model_state_dict != None and len(super_model_state_dict) != 0:
......@@ -741,8 +741,7 @@ class OFA(OFABase):
### if skip_layers and same ss both have same layer,
### the layers related to this layer need to add to skip_layers
if self._skip_layers != None and self._param2key[
key] in self._skip_layers:
if self._skip_layers != None and self._param2key[key] in self._skip_layers:
self._skip_layers += [self._param2key[sk] for sk in ss]
per_ss = []
break
......@@ -794,8 +793,8 @@ class OFA(OFABase):
teacher_output = None
if self._add_teacher:
self._reset_hook_before_forward()
teacher_output = self.ofa_teacher_model.model.forward(*inputs,
**kwargs)
teacher_output = self.ofa_teacher_model.model.forward(
*inputs, **kwargs)
# ============================================================
# ==================== student process =====================
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .hist import HistObserver
from .kl import KLObserver
__all__ = ["HistObserver", "KLObserver"]
\ No newline at end of file
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Tuple
import paddle
import numpy as np
from .uniform import UniformObserver
class BaseHistObserver(UniformObserver):
"""
It is a base class of histogram observers defined some functions to
collects the values of multi batches to a histogram.
Args:
quant_bits (int): The number of bits for quantization.
sign (bool): Whether the quantized integer includes a sign.
symmetric (bool): Whether it is symmetric quantization. the quantization is symmetric.
In symmetric quantization, the range of floating point values is relaxed to be symmetric
around zero and the zero-point is always 0.
bins_count(int): The number of equal-width bins.
"""
def __init__(self, quant_bits=8, bins_count=2048, sign=True,
symmetric=True):
super(BaseHistObserver, self).__init__(
quant_bits=quant_bits,
sign=sign,
symmetric=symmetric, )
self._bin_count = bins_count
self._upsample_bin_count = 64
self._hist_min = None
self._hist_max = None
self._hist = None
def _min_max(self, tensor):
"""" Get the min and max value of a tensor.
"""
return float(paddle.min(tensor).numpy()), float(
paddle.max(tensor).numpy())
def _init_hists(self, inputs):
"""" Initialize the histogram instance based on a tensor.
"""
_min, _max = self._min_max(inputs)
hist = None
if _max > _min:
hist, _ = np.histogram(
inputs.numpy(), range=(_min, _max), bins=self._bin_count)
hist.astype(np.float32)
return hist
def forward(self, inputs):
self._scale = None
self._zero_point = None
self._min = None
self._max = None
if self._hist_min is None or self._hist_max is None:
self._hist_min, self._hist_max = self._min_max(inputs)
self._hist = self._init_hists(inputs)
else:
new_min, new_max, new_hist = self._update_min_max_and_hist(
inputs,
self._hist_min,
self._hist_max,
self._hist,
self._bin_count,
self._upsample_bin_count, )
self._hist_min, self._hist_max = new_min, new_max
self._hist = new_hist
return inputs
def _update_min_max_and_hist(self, tensor, origin_min, origin_max,
origin_hist, bins_count, upsample_bins_count):
""" Update the histogram and its range based on the values of the target tensor.
Args:
tensor: The tensor used to update the histogram.
origin_min(float): The minimum of the original histogram's range.
origin_max(float): The max of the original histogram's range.
origin_hist: The original histogram.
bins_count(int): The number of histogram bins.
upsample_bins_count(int): The number of upsampled bins used to extend the histogram.
"""
_origin_min, _origin_max = origin_min, origin_max
_new_min, _new_max = self._min_max(tensor)
if (_new_max - _new_min) == 0.0:
return _origin_min, _origin_max, origin_hist
elif _origin_max - _origin_min == 0.0:
new_hist, _ = np.histogram(
tensor.numpy(), range=(_new_min, _new_max), bins=bins_count)
new_hist = new_hist.astype(np.float32)
return _new_min, _new_max, new_hist
elif _new_max <= _origin_max and _new_min >= _origin_min:
new_hist, _ = np.histogram(
tensor.numpy(),
range=(_origin_min, _origin_max),
bins=bins_count)
new_hist = new_hist.astype(np.float32)
new_hist += origin_hist
return _origin_min, _origin_max, new_hist
else:
_new_min = min(_new_min, _origin_min)
_new_max = max(_new_max, _origin_max)
_new_min, _new_max, downsample_bins_count, start_bin_idx = self._relax_min_max(
_new_min, _new_max, _origin_min, _origin_max, bins_count,
upsample_bins_count)
new_hist, _ = np.histogram(
tensor.numpy(), range=(_new_min, _new_max), bins=bins_count)
merged_histogram = self._merge_histograms(
new_hist, origin_hist, upsample_bins_count,
downsample_bins_count, start_bin_idx, bins_count)
return _new_min, _new_max, merged_histogram
def _merge_histograms(
self,
new_hist: np.ndarray,
origin_hist: np.ndarray,
upsample_bins_count: int,
downsample_bins_count: int,
start_bin_idx: int,
bins_count: int, ):
upsampled_histogram = np.repeat(origin_hist, upsample_bins_count)
expanded_hist = np.zeros(
(bins_count * downsample_bins_count), dtype=np.float32)
expanded_hist[start_bin_idx:bins_count * upsample_bins_count +
start_bin_idx] = upsampled_histogram
cumsumed_hist = np.cumsum(
expanded_hist,
dtype=np.float64)[downsample_bins_count - 1::downsample_bins_count]
shift_cumsumed_hist = np.zeros((bins_count), dtype=np.float64)
shift_cumsumed_hist[1:] = cumsumed_hist[0:-1]
sampled_hist = (
cumsumed_hist - shift_cumsumed_hist) / upsample_bins_count
new_hist = new_hist.astype(np.float32)
new_hist += sampled_hist.astype(np.float32)
return new_hist
def _relax_min_max(self, new_min, new_max, origin_min, origin_max,
bins_count,
upsample_bins_count) -> Tuple[float, float, int, int]:
_bin_width = (origin_max - origin_min) / (
bins_count * upsample_bins_count)
downsample_bins_count = int(
np.ceil((new_max - new_min) / (bins_count * _bin_width)))
error = downsample_bins_count * bins_count * _bin_width - (
new_max - new_min)
new_max += error
start_bin_idx = round((origin_min - new_min) / _bin_width)
return new_min, new_max, downsample_bins_count, start_bin_idx
@abc.abstractmethod
def cal_min_max(self) -> Tuple[float, float]:
""" Calculate the minimum and maximum based on the histogram. """
raise NotImplementedError("Please implement the abstract method.")
def cal_thresholds(self):
assert self._hist is not None
self._min, self._max = self.cal_min_max()
self._scale, self._zero_point = self.cal_scales_zero_points()
def min_value(self) -> float:
return self._min
def max_value(self) -> float:
return self._max
def bit_length(self):
return self._quant_bits
def quant_axis(self):
return -1
def scales(self):
if self._scale is None:
self.cal_thresholds()
return self._scale
def zero_points(self):
if self._zero_point is None:
self.cal_thresholds()
return self._zero_point
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
from .base_hist import BaseHistObserver
from paddle.quantization.factory import ObserverFactory
class HistObserver(ObserverFactory):
r"""
It collects tensor values into a histogram. And calculate quantization parameters
based on a percent ratio.
Args:
quant_bits (int): The number of bits for quantization.
bins_count(int): The number of equal-width bins.
percent(float): The percentage of bins that are retained when clipping the outliers.
sign (bool): Whether the quantized integer includes a sign.
symmetric (bool): Whether it is symmetric quantization. the quantization is symmetric.
In symmetric quantization, the range of floating point values is relaxed to be symmetric
around zero and the zero-point is always 0.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import HistObserver
quanter = HistObserver()
q_config = QuantConfig(activation=quanter, weight=quanter)
"""
def __init__(self,
quant_bits=8,
bins_count=2048,
percent=0.999,
sign=True,
symmetric=True):
super(HistObserver, self).__init__(
quant_bits=quant_bits,
bins_count=bins_count,
percent=percent,
sign=sign,
symmetric=symmetric)
def _get_class(self):
return PercentHistObserverLayer
class PercentHistObserverLayer(BaseHistObserver):
r"""
It collects tensor values into a histogram. And calculate quantization parameters
based on a percent ratio.
"""
def __init__(self,
layer,
quant_bits=8,
bins_count=2048,
percent=0.999,
sign=True,
symmetric=True):
super(PercentHistObserverLayer, self).__init__(
quant_bits=quant_bits,
bins_count=bins_count,
sign=sign,
symmetric=symmetric)
self._percent = percent
def _cal_min_max_by_percent(self):
hist = self._hist / np.sum(self._hist, dtype=np.float64)
cumsumed_hist = np.cumsum(hist)
max_idx = np.argwhere(cumsumed_hist >= self._percent)[0]
min_idx = np.argwhere(cumsumed_hist >= (1 - self._percent))[0]
bin_width = (self._hist_max - self._hist_min) / hist.shape[0]
_max = self._hist_min + float((max_idx - 0.5) * bin_width)
_min = self._hist_min + float((min_idx - 0.5) * bin_width)
return _min, _max
def cal_min_max(self):
return self._cal_min_max_by_percent()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import math
from .base_hist import BaseHistObserver
from paddle.quantization.factory import ObserverFactory
class KLObserver(ObserverFactory):
r"""
Calculate quantization parameters that minimize the Kullback–Leibler divergence
between the distribution of floating values and the distribution of quantized
floating values.
Args:
quant_bits (int): The number of bits for quantization.
bins_count(int): The number of equal-width bins.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import KLObserver
quanter = KLObserver()
q_config = QuantConfig(activation=quanter, weight=quanter)
"""
def __init__(self, quant_bits=8, bins_count=2048):
super(KLObserver, self).__init__(
quant_bits=quant_bits, bins_count=bins_count)
def _get_class(self):
return KLObserverLayer
class KLObserverLayer(BaseHistObserver):
"""
Per-tensor KL observer.
"""
def __init__(self, layer, quant_bits=8, bins_count=2048):
super(KLObserverLayer, self).__init__(
quant_bits=quant_bits,
bins_count=bins_count,
sign=True,
symmetric=True)
def _search_min_max_by_kl(self):
bin_width = (self._hist_max - self._hist_min) / self._bin_count
_max = cal_kl_threshold(self._hist, bin_width, self.bit_length())
return 0., _max
def cal_min_max(self):
return self._search_min_max_by_kl()
def expand_quantized_bins(quantized_bins, reference_bins):
'''
Expand hist bins.
'''
expanded_quantized_bins = [0] * len(reference_bins)
num_merged_bins = int(len(reference_bins) / len(quantized_bins))
j_start = 0
j_end = num_merged_bins
for idx in range(len(quantized_bins)):
zero_count = reference_bins[j_start:j_end].count(0)
num_merged_bins = j_end - j_start
if zero_count == num_merged_bins:
avg_bin_ele = 0
else:
avg_bin_ele = quantized_bins[idx] / (
num_merged_bins - zero_count + 0.0)
for idx1 in range(j_start, j_end):
expanded_quantized_bins[idx1] = (0 if reference_bins[idx1] == 0 else
avg_bin_ele)
j_start += num_merged_bins
j_end += num_merged_bins
if (idx + 1) == len(quantized_bins) - 1:
j_end = len(reference_bins)
return expanded_quantized_bins
def safe_entropy(reference_distr_P, P_sum, candidate_distr_Q, Q_sum):
'''
Calculate the entropy.
'''
assert len(reference_distr_P) == len(candidate_distr_Q)
tmp_sum1 = 0
tmp_sum2 = 0
for idx in range(len(reference_distr_P)):
p_idx = reference_distr_P[idx]
q_idx = candidate_distr_Q[idx]
if p_idx == 0:
tmp_sum1 += 0
tmp_sum2 += 0
else:
if q_idx == 0:
_logger.error("Fatal error!, idx = " + str(idx) +
" qindex = 0! p_idx = " + str(p_idx))
tmp_sum1 += p_idx * (math.log(Q_sum * p_idx))
tmp_sum2 += p_idx * (math.log(P_sum * q_idx))
return (tmp_sum1 - tmp_sum2) / P_sum
def cal_kl_threshold(hist, bin_width, bits):
'''
Using the KL-divergenc method to get the more precise threshold.
Args:
hist(List): The hist of the tensor.
bin_width(float): The bin width for the hist.
bits(int): The quantization bits.
'''
assert hist.ndim == 1
hist_bins = hist.shape[0]
starting_iter = int((hist_bins - 1) * 0.5)
quant_range = 2**(bits - 1) - 1
P_sum = np.sum(np.array(hist).ravel())
min_kl_divergence = 0
min_kl_index = 0
kl_inited = False
for i in range(starting_iter, hist_bins):
reference_distr_P = hist[0:i].tolist()
outliers_count = sum(hist[i:])
if reference_distr_P[i - 1] == 0:
continue
reference_distr_P[i - 1] += outliers_count
reference_distr_bins = reference_distr_P[:]
candidate_distr_Q = hist[0:i].tolist()
num_merged_bins = int(i / quant_range)
candidate_distr_Q_quantized = [0] * quant_range
j_start = 0
j_end = num_merged_bins
for idx in range(quant_range):
candidate_distr_Q_quantized[idx] = sum(
candidate_distr_Q[j_start:j_end])
j_start += num_merged_bins
j_end += num_merged_bins
if (idx + 1) == quant_range - 1:
j_end = i
candidate_distr_Q = expand_quantized_bins(candidate_distr_Q_quantized,
reference_distr_bins)
Q_sum = sum(candidate_distr_Q)
kl_divergence = safe_entropy(reference_distr_P, P_sum,
candidate_distr_Q, Q_sum)
if not kl_inited:
min_kl_divergence = kl_divergence
min_kl_index = i
kl_inited = True
elif kl_divergence < min_kl_divergence:
min_kl_divergence = kl_divergence
min_kl_index = i
else:
pass
if min_kl_index == 0:
while starting_iter > 0:
if hist[starting_iter] == 0:
starting_iter -= 1
continue
else:
break
min_kl_index = starting_iter
return (min_kl_index + 0.5) * bin_width
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Tuple
import numpy as np
from paddle.quantization.base_observer import BaseObserver
class UniformObserver(BaseObserver):
""" This is the base class for a uniform quantization observer, which provides
common functions for calculating the scale and zero-point used in uniform quantization.
Uniform quantization maps floating point values to integers, where the scale determines
the step size of the quantizer and the floating point zero is mapped to the zero-point,
an integer value ensuring that zero is quantized without error.
Args:
quant_bits (int): The number of bits for quantization.
sign (bool): Whether the quantized integer includes a sign.
symmetric (bool): Whether it is symmetric quantization. the quantization is symmetric.
In symmetric quantization, the range of floating point values is relaxed to be symmetric
around zero and the zero-point is always 0.
"""
def __init__(
self,
quant_bits=8,
sign=True,
symmetric=True, ):
super(UniformObserver, self).__init__()
self._quant_bits = quant_bits
self._sign = sign
self._symmetric = symmetric
self._min = None
self._max = None
self._qmin = None
self._qmax = None
self._scale = None
self._zero_point = None
@property
def qmin_qmax(self):
""" Calculate the range of the quantized integer based on the specified
quant_bits, sign, and symmetric properties."""
if self._sign:
self._qmin = -2**(self.bit_length() - 1)
self._qmax = 2**(self.bit_length() - 1) - 1
else:
self._qmin = 0
self._qmax = 2**self.bit_length()
return self._qmin, self._qmax
@abc.abstractmethod
def min_value(self) -> float:
""" The minimum value of floating-point numbers."""
raise NotImplementedError(
"Please implement the abstract method to get the The minimum value of floating-point numbers."
)
@abc.abstractmethod
def max_value(self) -> float:
""" The maximum value of floating-point numbers."""
raise NotImplementedError(
"Please implement the abstract method to get the the maximum value value of floating-point numbers."
)
def cal_scales_zero_points(self) -> Tuple[float, float]:
""" Calculate the scales and zero points based on the min_value and max_value.
"""
assert self.min_value() is not None and self.max_value() is not None
_qmin, _qmax = self.qmin_qmax
# For one-sided distributions, the range (_min , _max ) is relaxed to include zero.
# It is important to ensure that common operations like zero padding do not cause quantization errors.
_min = min(self.min_value(), 0.)
_max = max(self.max_value(), 0.)
if self._symmetric:
self._scale = max(-_min, _max) / (float(_qmax - _qmin) / 2)
if self._sign:
self._zero_point = 0
else:
self._zero_point = (_qmax + _qmin) / 2
else:
self._scale = (_max - _min) / float(_qmax - _qmin)
self._zero_point = _qmin - round(_min / self._scale)
self._zero_point = np.clip(self._zero_point, _qmin, _qmax)
return self._scale, self._zero_point
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../../")
import os
import unittest
import paddle
import tempfile
from paddle.vision.models import resnet18
from paddle.quantization import QuantConfig
from paddle.quantization import PTQ
from paddleslim.quant.observers import HistObserver, KLObserver
from paddleslim.quant.observers.hist import PercentHistObserverLayer
from paddleslim.quant.observers.kl import KLObserverLayer
from paddle.nn.quant.format import LinearDequanter, LinearQuanter
class TestPTQWithHistObserver(unittest.TestCase):
def __init__(self, observer, observer_type, *args, **kvargs):
super(TestPTQWithHistObserver, self).__init__(*args, **kvargs)
self.observer = observer
self.observer_type = observer_type
def setUp(self):
paddle.set_device("cpu")
self.init_case()
self.dummy_input = paddle.rand([1, 3, 224, 224])
self.temp_dir = tempfile.TemporaryDirectory(dir="./")
self.path = os.path.join(self.temp_dir.name, 'qat')
def tearDown(self):
self.temp_dir.cleanup()
def runTest(self):
self.test_quantize()
self.test_convert()
def init_case(self):
# observer = HistObserver()
# self.observer_type = PercentHistObserverLayer
self.q_config = QuantConfig(activation=None, weight=None)
self.q_config.add_type_config(
paddle.nn.Conv2D, activation=self.observer, weight=self.observer)
def _count_layers(self, model, layer_type):
count = 0
for _layer in model.sublayers(True):
if isinstance(_layer, layer_type):
count += 1
return count
def test_quantize(self):
model = resnet18()
conv_count = self._count_layers(model, paddle.nn.Conv2D)
ptq = PTQ(self.q_config)
model.eval()
quant_model = ptq.quantize(model, inplace=False)
zero_input = paddle.zeros_like(self.dummy_input)
out = quant_model(zero_input)
out = quant_model(self.dummy_input)
out = quant_model(zero_input)
out = quant_model(self.dummy_input + 1.)
quantizer_cnt = self._count_layers(quant_model, self.observer_type)
self.assertEqual(quantizer_cnt, 2 * conv_count)
def test_convert(self):
model = resnet18()
conv_count = self._count_layers(model, paddle.nn.Conv2D)
ptq = PTQ(self.q_config)
model.eval()
quant_model = ptq.quantize(model, inplace=False)
out = quant_model(self.dummy_input)
converted_model = ptq.convert(quant_model, inplace=False)
# check count of LinearQuanter and LinearDequanter in dygraph
quantizer_count_in_dygraph = self._count_layers(converted_model,
LinearQuanter)
dequantizer_count_in_dygraph = self._count_layers(
converted_model, LinearDequanter)
self.assertEqual(quantizer_count_in_dygraph, conv_count)
self.assertEqual(dequantizer_count_in_dygraph, conv_count * 2)
observer_suite = unittest.TestSuite()
observer_suite.addTest(
TestPTQWithHistObserver(
observer=HistObserver(sign=True, symmetric=True),
observer_type=PercentHistObserverLayer))
observer_suite.addTest(
TestPTQWithHistObserver(
observer=HistObserver(sign=False, symmetric=True),
observer_type=PercentHistObserverLayer))
observer_suite.addTest(
TestPTQWithHistObserver(
observer=HistObserver(sign=True, symmetric=False),
observer_type=PercentHistObserverLayer))
observer_suite.addTest(
TestPTQWithHistObserver(
observer=HistObserver(sign=False, symmetric=False),
observer_type=PercentHistObserverLayer))
observer_suite.addTest(
TestPTQWithHistObserver(
observer=KLObserver(bins_count=256), observer_type=KLObserverLayer))
if __name__ == '__main__':
runner = unittest.TextTestRunner(verbosity=2)
runner.run(observer_suite)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册