未验证 提交 2b6fc108 编写于 作者: C cc 提交者: GitHub

Dygraph post trainging quantization (#33445)

* dygraph post training quantization

* refine the ptq config

* refine ptq quantizer
上级 1b0c5ef2
......@@ -20,6 +20,22 @@ from .quant_nn import *
from . import qat
from .qat import *
from . import ptq
from .ptq import *
from . import ptq_config
from .ptq_config import *
from . import ptq_quantizer
from .ptq_quantizer import *
from . import ptq_registry
from .ptq_registry import *
__all__ = []
__all__ += quant_nn.__all__
__all__ += qat.__all__
__all__ += ptq.__all__
__all__ += ptq_config.__all__
__all__ += ptq_quantizer.__all__
__all__ += ptq_registry.__all__
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import copy
import numpy as np
import paddle
from paddle.fluid.log_helper import get_logger
from . import utils
from . import ptq_hooks
from . import ptq_config
from .ptq_registry import PTQRegistry
__all__ = ['ImperativePTQ']
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
class ImperativePTQ(object):
"""
Applying static post_training quantization to the dgraph model.
"""
def __init__(self, quant_config=ptq_config.default_ptq_config):
"""
Constructor.
Args:
algo(str): The algorithm in post_training quantizaion to be used.
activation_bits(int): quantization bit number for activations.
weight_bits(int): quantization bit number for weights.
"""
super(ImperativePTQ, self).__init__()
assert isinstance(quant_config, ptq_config.PTQConfig)
self._quant_config = quant_config
def quantize(self, model, inplace=False):
"""
Add hook to the leaf layer to calculate the threshold of inputs and outputs.
Args:
model(paddle.nn.Layer): The model to be quantized.
Returns:
None
"""
assert isinstance(model, paddle.nn.Layer), \
"The model must be the instance of paddle.nn.Layer."
if not inplace:
model = copy.deepcopy(model)
for name, layer in model.named_sublayers():
if PTQRegistry.is_supported_layer(layer) \
and utils.is_leaf_layer(layer):
quant_config = copy.deepcopy(self._quant_config)
layer._quant_config = quant_config
hook = ptq_hooks.quant_forward_post_hook
hook_handle = layer.register_forward_post_hook(hook)
quant_config.hook_handle = hook_handle
layer._forward_post_hooks.move_to_end(
hook_handle._hook_id, last=False)
return model
def convert(self, model):
"""
Process the scales and remove the hooks.
Args:
model(paddle.nn.Layer): The model to be quantized.
Returns:
None
"""
assert isinstance(model, paddle.nn.Layer), \
"The input model must be the instance of paddle.nn.Layer."
for name, sub_layer in model.named_sublayers():
if PTQRegistry.is_supported_layer(sub_layer) \
and utils.is_leaf_layer(sub_layer):
assert hasattr(sub_layer, "_quant_config")
quant_config = sub_layer._quant_config
quant_config.hook_handle.remove()
quant_config.in_act_quantizer.cal_thresholds()
quant_config.out_act_quantizer.cal_thresholds()
# get weight thresholds
if isinstance(sub_layer, tuple(utils.fake_quant_input_layers)):
weights = (sub_layer.weight, )
quant_config.wt_quantizer.sample_data(sub_layer, weights)
# TODO (jc):
# save input activation threshold and quant bits
return model
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
import abc
import copy
import paddle
from .ptq_quantizer import *
__all__ = ['PTQConfig', 'default_ptq_config']
class PTQConfig(object):
"""
The PTQ config shows how to quantize the inputs and outputs.
"""
def __init__(self, activation_quantizer, weight_quantizer):
super(PTQConfig, self).__init__()
assert isinstance(activation_quantizer, BaseQuantizer)
assert isinstance(weight_quantizer, BaseQuantizer)
self.in_act_quantizer = copy.deepcopy(activation_quantizer)
self.out_act_quantizer = copy.deepcopy(activation_quantizer)
self.wt_quantizer = copy.deepcopy(weight_quantizer)
self.hook_handle = None
default_ptq_config = PTQConfig(AbsmaxQuantizer(), AbsmaxQuantizer())
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import math
import numpy as np
from . import ptq_config
def quant_forward_post_hook(layer, inputs, outputs):
"""
The forward_post_hook for PTQ.
"""
assert hasattr(layer, '_quant_config'), \
"The layer should have _quant_config attr"
layer._quant_config.in_act_quantizer.sample_data(layer, inputs)
layer._quant_config.out_act_quantizer.sample_data(layer, (outputs, ))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
import abc
import copy
import math
import numpy as np
import paddle
from . import utils
__all__ = [
'BaseQuantizer',
'AbsmaxQuantizer',
'PerChannelAbsmaxQuantizer',
'KLQuantizer',
'HistQuantizer',
]
def abs_max_value(tensor):
return float(paddle.max(paddle.abs(tensor)).numpy())
def merge_max_value(old, new):
"""
Merge the max element one by one in two lists.
"""
assert isinstance(old, list) and isinstance(new, list)
if old != []:
assert len(old) == len(new)
for i in range(len(old)):
assert type(old[i]) == type(new[i])
if isinstance(old[i], list):
new[i] = merge_max_value(old[i], new[i])
else:
new[i] = old[i] if new[i] < old[i] else new[i]
return new
def combine_abs_max_and_hist(tensor, origin_max, origin_hist, bins,
upsample_bins):
"""
"""
new_max = abs_max_value(tensor)
if new_max == 0.0:
return origin_max, origin_hist
elif origin_max == 0.0:
new_hist, _ = np.histogram(
paddle.abs(tensor).numpy(), range=(0, new_max), bins=bins)
new_hist = new_hist.astype(np.float32)
return new_max, new_hist
elif new_max <= origin_max:
new_hist, _ = np.histogram(
paddle.abs(tensor).numpy(), range=(0, origin_max), bins=bins)
new_hist = new_hist.astype(np.float32)
new_hist += origin_hist
return origin_max, new_hist
else:
# bin_width = origin_max / (bins * upsample_bins)
# = new_max / (bins * downsample_bins)
bin_width = origin_max / (bins * upsample_bins)
downsampe_bins = int(math.ceil(new_max / (bins * bin_width)))
new_max = bins * bin_width * downsampe_bins
upsampled_hist = np.repeat(origin_hist, upsample_bins)
expanded_hist = np.zeros((bins * downsampe_bins), dtype=np.float32)
expanded_hist[0:bins * upsample_bins] = upsampled_hist
cumsumed_hist = np.cumsum(
expanded_hist, dtype=np.float64)[downsampe_bins - 1::downsampe_bins]
shift_cumsumed_hist = np.zeros((bins), dtype=np.float64)
shift_cumsumed_hist[1:] = cumsumed_hist[0:-1]
sampled_hist = (cumsumed_hist - shift_cumsumed_hist) / upsample_bins
sampled_hist = sampled_hist.astype(np.float32)
new_hist, _ = np.histogram(
paddle.abs(tensor).numpy(), range=(0, new_max), bins=bins)
new_hist = new_hist.astype(np.float32)
new_hist += sampled_hist
return new_max, new_hist
@six.add_metaclass(abc.ABCMeta)
class BaseQuantizer(object):
"""
Base quantizer for activation and weight.
"""
def __init__(self, quant_bits=8):
super(BaseQuantizer, self).__init__()
assert isinstance(quant_bits, int)
assert quant_bits > 0 and quant_bits <= 16
self.quant_bits = quant_bits
self.thresholds = []
@abc.abstractmethod
def sample_data(self, layer, tensors):
pass
@abc.abstractmethod
def cal_thresholds(self):
pass
class AbsmaxQuantizer(BaseQuantizer):
"""
Per-tensor abs max quantizer.
"""
def __init__(self, quant_bits=8):
super(AbsmaxQuantizer, self).__init__(quant_bits)
def sample_data(self, layer, tensors):
assert isinstance(tensors, tuple)
abs_max_vals = [abs_max_value(t) for t in tensors]
self.thresholds = merge_max_value(self.thresholds, abs_max_vals)
def cal_thresholds(self):
pass
class PerChannelAbsmaxQuantizer(BaseQuantizer):
"""
Per channel abs max quantizer.
"""
def __init__(self, quant_bits=8):
super(PerChannelAbsmaxQuantizer, self).__init__(quant_bits)
def sample_data(self, layer, tensors):
assert isinstance(layer, paddle.nn.Layer)
assert isinstance(tensors, tuple)
abs_max_vals_list = []
for idx, tensor in enumerate(tensors):
if isinstance(layer, tuple(utils.spec_channel_axis_layers)):
abs_max_vals = [
abs_max_value(tensor[:, i]) for i in range(tensor.shape[1])
]
abs_max_vals_list.append(abs_max_vals)
else:
abs_max_vals = [
abs_max_value(tensor[i]) for i in range(tensor.shape[0])
]
abs_max_vals_list.append(abs_max_vals)
self.thresholds = merge_max_value(self.thresholds, abs_max_vals_list)
def cal_thresholds(self):
pass
@six.add_metaclass(abc.ABCMeta)
class BaseHistQuantizer(BaseQuantizer):
"""
"""
def __init__(self, quant_bits=8, bins=1024, upsample_bins=64):
super(BaseHistQuantizer, self).__init__(quant_bits)
self.bins = bins
self.upsample_bins = upsample_bins
self.abs_max_vals = []
self.hists = []
def sample_data(self, layer, tensors):
assert isinstance(tensors, tuple)
if self.abs_max_vals == []:
abs_max_vals = [abs_max_value(t) for t in tensors]
self.abs_max_vals = abs_max_vals
for idx, tensor in enumerate(tensors):
if abs_max_vals[idx] == 0.0:
self.hists.append(None)
else:
hist, _ = np.histogram(
paddle.abs(tensor).numpy(),
range=(0., abs_max_vals[idx]),
bins=self.bins)
hist = hist.astype(np.float32)
self.hists.append(hist)
else:
assert len(self.abs_max_vals) == len(tensors)
assert len(self.hists) == len(tensors)
for idx, tensor in enumerate(tensors):
new_abs_max, new_hist = combine_abs_max_and_hist(
tensor, self.abs_max_vals[idx], self.hists[idx], self.bins,
self.upsample_bins)
self.abs_max_vals[idx] = new_abs_max
self.hists[idx] = new_hist
@abc.abstractmethod
def cal_thresholds(self):
pass
class HistQuantizer(BaseHistQuantizer):
"""
"""
def __init__(self,
quant_bits=8,
bins=1024,
upsample_bins=64,
hist_percent=0.99999):
super(HistQuantizer, self).__init__(quant_bits, bins, upsample_bins)
self.hist_percent = hist_percent
def cal_thresholds(self):
def _helper(abs_max, hist, percent):
assert hist.ndim == 1 and percent < 1.0
hist = hist / np.sum(hist, dtype=np.float64)
cumsumed_hist = np.cumsum(hist)
index = np.argwhere(cumsumed_hist >= percent)[0]
return float((index - 0.5) * (abs_max / hist.shape[0]))
for idx in range(len(self.hists)):
if self.hists[idx] is None:
self.thresholds.append(self.abs_max_vals[idx])
else:
threshold = _helper(self.abs_max_vals[idx], self.hists[idx],
self.hist_percent)
self.thresholds.append(threshold)
class KLQuantizer(BaseHistQuantizer):
"""
"""
def __init__(self, quant_bits=8, bins=1024, upsample_bins=64):
super(KLQuantizer, self).__init__(quant_bits, bins, upsample_bins)
def cal_thresholds(self):
for idx in range(len(self.hists)):
if self.hists[idx] is None:
self.thresholds.append(self.abs_max_vals[idx])
else:
threshold = utils.cal_kl_scaling_factor(
self.hists[idx], self.abs_max_vals[idx], self.quant_bits)
self.thresholds.append(threshold)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
__all__ = ['PTQRegistry']
class LayerInfo(object):
"""
Store the argnames of the inputs and outputs.
"""
def __init__(self, layer, input_names, weight_names, output_names):
super(LayerInfo, self).__init__()
self.layer = layer
self.input_names = input_names
self.weight_names = weight_names
self.output_names = output_names
PTQ_LAYERS_INFO = [
LayerInfo(paddle.nn.Conv2D, ['Input'], ['Filter'], ['Output']),
LayerInfo(paddle.nn.Linear, ['X'], ['Y'], ['Out']),
LayerInfo(paddle.nn.BatchNorm2D, ['X'], [], ['Y']),
LayerInfo(paddle.nn.AdaptiveMaxPool2D, ['X'], [], ['Out']),
LayerInfo(paddle.nn.AdaptiveAvgPool2D, ['X'], [], ['Out']),
LayerInfo(paddle.nn.AvgPool2D, ['X'], [], ['Out']),
LayerInfo(paddle.nn.MaxPool2D, ['X'], [], ['Out']),
LayerInfo(paddle.nn.ReLU, ['X'], [], ['Out']),
LayerInfo(paddle.nn.ReLU6, ['X'], [], ['Out']),
LayerInfo(paddle.nn.Hardswish, ['X'], [], ['Out']),
LayerInfo(paddle.nn.Sigmoid, ['X'], [], ['Out']),
LayerInfo(paddle.nn.Softmax, ['X'], [], ['Out']),
LayerInfo(paddle.nn.Tanh, ['X'], [], ['Out']),
LayerInfo(paddle.nn.quant.add, ['X', 'Y'], [], ['Out']),
]
class PTQRegistry(object):
"""
Register the supported layers for PTQ and provide layers info.
"""
supported_layers_map = {}
is_inited = False
def __init__(self):
super(PTQRegistry, self).__init__()
@classmethod
def _init(cls):
if not cls.is_inited:
for layer_info in PTQ_LAYERS_INFO:
cls.supported_layers_map[layer_info.layer] = layer_info
cls.is_inited = True
@classmethod
def is_supported_layer(cls, layer):
"""
Analyze whether the layer supports quantization.
"""
cls._init()
return layer in cls.supported_layers_map or \
isinstance(layer, tuple(cls.supported_layers_map.keys()))
def layer_info(cls, layer):
"""
Get the infomation for the supported layer.
"""
assert cls.is_supported_layer(
layer), "The input layer is not supported."
for layer_key, layer_info in cls.supported_layers_map.items():
if layer == layer_key or isinstance(layer, layer_key):
return layer_info
......@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid import dygraph
import math
import numpy as np
import paddle
from . import quant_nn
layer_name_map = {
......@@ -60,6 +62,9 @@ fake_quant_leaf_layers = [
fake_quant_wrap_layers = [quant_nn.QuantizedConv2D, quant_nn.QuantizedLinear]
# The weight format of these layers is Cin * Cout * H * W
spec_channel_axis_layers = [paddle.nn.Conv2D, paddle.nn.Conv2DTranspose]
weight_op_types = [
"conv2d", "depthwise_conv2d", "matmul", "conv2d_transpose",
"depthwise_conv2d_transpose"
......@@ -109,7 +114,7 @@ def find_parent_layer_and_sub_name(model, name):
For example, if name is 'block_1/convbn_1/conv_1', the parent layer is
'block_1/convbn_1' and the sub_name is `conv_1`.
"""
assert isinstance(model, dygraph.Layer), \
assert isinstance(model, paddle.nn.Layer), \
"The model must be the instance of paddle.nn.Layer."
assert len(name) > 0, "The input (name) should not be empty."
......@@ -131,5 +136,111 @@ def is_leaf_layer(layer):
"""
Whether the layer is leaf layer.
"""
return isinstance(layer, dygraph.Layer) \
return isinstance(layer, paddle.nn.Layer) \
and len(layer.sublayers()) == 0
def expand_quantized_bins(quantized_bins, reference_bins):
"""
"""
expanded_quantized_bins = [0] * len(reference_bins)
num_merged_bins = int(len(reference_bins) / len(quantized_bins))
j_start = 0
j_end = num_merged_bins
for idx in range(len(quantized_bins)):
zero_count = reference_bins[j_start:j_end].count(0)
num_merged_bins = j_end - j_start
if zero_count == num_merged_bins:
avg_bin_ele = 0
else:
avg_bin_ele = quantized_bins[idx] / (
num_merged_bins - zero_count + 0.0)
for idx1 in range(j_start, j_end):
expanded_quantized_bins[idx1] = (0 if reference_bins[idx1] == 0 else
avg_bin_ele)
j_start += num_merged_bins
j_end += num_merged_bins
if (idx + 1) == len(quantized_bins) - 1:
j_end = len(reference_bins)
return expanded_quantized_bins
def safe_entropy(reference_distr_P, P_sum, candidate_distr_Q, Q_sum):
'''
Calculate the entropy.
'''
assert len(reference_distr_P) == len(candidate_distr_Q)
tmp_sum1 = 0
tmp_sum2 = 0
for idx in range(len(reference_distr_P)):
p_idx = reference_distr_P[idx]
q_idx = candidate_distr_Q[idx]
if p_idx == 0:
tmp_sum1 += 0
tmp_sum2 += 0
else:
if q_idx == 0:
_logger.error("Fatal error!, idx = " + str(idx) +
" qindex = 0! p_idx = " + str(p_idx))
tmp_sum1 += p_idx * (math.log(Q_sum * p_idx))
tmp_sum2 += p_idx * (math.log(P_sum * q_idx))
return (tmp_sum1 - tmp_sum2) / P_sum
def cal_kl_scaling_factor(hist, abs_max, bits):
'''
Using the KL-divergenc method to get the more precise scaling factor.
'''
assert hist.ndim == 1
hist_bins = hist.shape[0]
starting_iter = int((hist_bins - 1) * 0.5)
bin_width = abs_max / hist_bins
quant_range = 2**(bits - 1) - 1
P_sum = np.sum(np.array(hist).ravel())
min_kl_divergence = 0
min_kl_index = 0
kl_inited = False
for i in range(starting_iter, hist_bins):
reference_distr_P = hist[0:i].tolist()
outliers_count = sum(hist[i:])
if reference_distr_P[i - 1] == 0:
continue
reference_distr_P[i - 1] += outliers_count
reference_distr_bins = reference_distr_P[:]
candidate_distr_Q = hist[0:i].tolist()
num_merged_bins = int(i / quant_range)
candidate_distr_Q_quantized = [0] * quant_range
j_start = 0
j_end = num_merged_bins
for idx in range(quant_range):
candidate_distr_Q_quantized[idx] = sum(candidate_distr_Q[j_start:
j_end])
j_start += num_merged_bins
j_end += num_merged_bins
if (idx + 1) == quant_range - 1:
j_end = i
candidate_distr_Q = expand_quantized_bins(candidate_distr_Q_quantized,
reference_distr_bins)
Q_sum = sum(candidate_distr_Q)
kl_divergence = safe_entropy(reference_distr_P, P_sum,
candidate_distr_Q, Q_sum)
if not kl_inited:
min_kl_divergence = kl_divergence
min_kl_index = i
kl_inited = True
elif kl_divergence < min_kl_divergence:
min_kl_divergence = kl_divergence
min_kl_index = i
else:
pass
if min_kl_index == 0:
while starting_iter > 0:
if hist[starting_iter] == 0:
starting_iter -= 1
continue
else:
break
min_kl_index = starting_iter
return (min_kl_index + 0.5) * bin_width
......@@ -125,6 +125,7 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_lstm_model)
list(REMOVE_ITEM TEST_OPS test_imperative_ptq)
list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1)
list(REMOVE_ITEM TEST_OPS test_quantize_transpiler_v2)
list(REMOVE_ITEM TEST_OPS test_imperative_qat_amp)
......@@ -300,6 +301,7 @@ if(NOT WIN32)
set_tests_properties(test_post_training_quantization_mobilenetv1 PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=NIGHTLY")
set_tests_properties(test_post_training_quantization_resnet50 PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=NIGHTLY")
set_tests_properties(test_post_training_quantization_mnist PROPERTIES TIMEOUT 120)
set_tests_properties(test_imperative_ptq PROPERTIES TIMEOUT 120)
set_tests_properties(test_weight_quantization_mobilenetv1 PROPERTIES TIMEOUT 120)
endif()
......
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
from __future__ import print_function
import os
import numpy as np
import random
import shutil
import time
import unittest
import logging
import paddle
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import *
from paddle.fluid.log_helper import get_logger
from paddle.dataset.common import download
from imperative_test_utils import fix_model_dict, ImperativeLenet
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
class TestImperativePTQ(unittest.TestCase):
"""
"""
@classmethod
def setUpClass(cls):
timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
cls.root_path = os.path.join(os.getcwd(), "imperative_ptq_" + timestamp)
cls.save_path = os.path.join(cls.root_path, "model")
cls.download_path = 'dygraph_int8/download'
cls.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' +
cls.download_path)
cls.lenet_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/lenet_pretrained.tar.gz"
cls.lenet_md5 = "953b802fb73b52fae42896e3c24f0afb"
seed = 1
np.random.seed(seed)
paddle.static.default_main_program().random_seed = seed
paddle.static.default_startup_program().random_seed = seed
@classmethod
def tearDownClass(cls):
try:
shutil.rmtree(cls.root_path)
except Exception as e:
print("Failed to delete {} due to {}".format(cls.root_path, str(e)))
def cache_unzipping(self, target_folder, zip_path):
if not os.path.exists(target_folder):
cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder,
zip_path)
os.system(cmd)
def download_model(self, data_url, data_md5, folder_name):
download(data_url, self.download_path, data_md5)
file_name = data_url.split('/')[-1]
zip_path = os.path.join(self.cache_folder, file_name)
print('Data is downloaded at {0}'.format(zip_path))
data_cache_folder = os.path.join(self.cache_folder, folder_name)
self.cache_unzipping(data_cache_folder, zip_path)
return data_cache_folder
def set_vars(self):
self.ptq = ImperativePTQ(default_ptq_config)
self.batch_num = 10
self.batch_size = 10
self.eval_acc_top1 = 0.99
self.gt_thresholds = {
'conv2d_0': [[1.0], [0.37673383951187134], [0.10933732241392136]],
'batch_norm2d_0': [[0.37673383951187134], [0.44249194860458374]],
're_lu_0': [[0.44249194860458374], [0.25804123282432556]],
'max_pool2d_0': [[0.25804123282432556], [0.25804123282432556]],
'linear_0':
[[1.7058950662612915], [14.405526161193848], [0.4373355209827423]],
'add_0': [[1.7058950662612915, 0.0], [1.7058950662612915]],
}
def model_train(self, model, train_reader, max_step=-1):
model.train()
adam = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
for batch_id, data in enumerate(train_reader()):
x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = paddle.to_tensor(x_data)
label = paddle.to_tensor(y_data)
out = model(img)
acc = fluid.layers.accuracy(out, label)
loss = fluid.layers.cross_entropy(out, label)
avg_loss = fluid.layers.mean(loss)
avg_loss.backward()
adam.minimize(avg_loss)
model.clear_gradients()
if batch_id % 100 == 0:
_logger.info("Train | step {}: loss = {:}, acc= {:}".format(
batch_id, avg_loss.numpy(), acc.numpy()))
if max_step > 0 and batch_id > max_step: # For shortening CI time
break
def model_test(self, model, batch_num=-1, batch_size=8):
model.eval()
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size)
eval_acc_top1_list = []
for batch_id, data in enumerate(test_reader()):
x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = paddle.to_tensor(x_data)
label = paddle.to_tensor(y_data)
out = model(img)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
if batch_id % 100 == 0:
eval_acc_top1_list.append(float(acc_top1.numpy()))
_logger.info("Test | At step {}: acc1 = {:}, acc5 = {:}".format(
batch_id, acc_top1.numpy(), acc_top5.numpy()))
if batch_num > 0 and batch_id + 1 >= batch_num:
break
eval_acc_top1 = sum(eval_acc_top1_list) / len(eval_acc_top1_list)
return eval_acc_top1
def check_thresholds(self, model):
check_num = 0
for name, layer in model.named_sublayers():
layer_name = layer.full_name()
if layer_name in self.gt_thresholds:
ref_val = self.gt_thresholds[layer_name]
assert hasattr(layer, '_quant_config')
quant_config = layer._quant_config
in_val = quant_config.in_act_quantizer.thresholds
out_val = quant_config.out_act_quantizer.thresholds
wt_val = quant_config.wt_quantizer.thresholds
check_num += 1
self.assertTrue(
np.allclose(
ref_val[0], in_val, atol=1e-3),
"%s | The thresholds(%s) is different "
"from the ground truth(%s)." %
(layer_name, str(in_val), str(ref_val[0])))
self.assertTrue(
np.allclose(
ref_val[1], out_val, atol=1e-3),
"%s | The thresholds(%s) is different "
"from the ground truth(%s)." %
(layer_name, str(out_val), str(ref_val[1])))
if len(ref_val) > 2 and ref_val[2] != []:
self.assertTrue(
np.allclose(
ref_val[2], wt_val, atol=1e-3),
"%s | The thresholds(%s) is different "
"from the ground truth(%s)." %
(layer_name, str(wt_val), str(ref_val[2])))
self.assertTrue(check_num == len(self.gt_thresholds))
def test_ptq(self):
start_time = time.time()
self.set_vars()
params_path = self.download_model(self.lenet_url, self.lenet_md5,
"lenet")
params_path += "/lenet_pretrained/lenet.pdparams"
with fluid.dygraph.guard():
model = ImperativeLenet()
model_state_dict = paddle.load(params_path)
model.set_state_dict(model_state_dict)
self.ptq.quantize(model, inplace=True)
acc_top1 = self.model_test(model, self.batch_num, self.batch_size)
print('acc_top1: %s' % acc_top1)
self.assertTrue(
acc_top1 > self.eval_acc_top1,
msg="The test acc {%f} is less than {%f}." %
(acc_top1, self.eval_acc_top1))
self.ptq.convert(model)
self.check_thresholds(model)
input_spec = [
paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
]
paddle.jit.save(layer=model, path=self.save_path, input_spec=input_spec)
print('Quantized model saved in {%s}' % self.save_path)
end_time = time.time()
print("total time: %ss" % (end_time - start_time))
class TestImperativePTQHist(TestImperativePTQ):
"""
"""
def set_vars(self):
config = PTQConfig(HistQuantizer(), AbsmaxQuantizer())
self.ptq = ImperativePTQ(config)
self.batch_num = 10
self.batch_size = 10
self.eval_acc_top1 = 0.99
self.gt_thresholds = {
'conv2d_0':
[[0.99853515625], [0.35732391771364225], [0.10933732241392136]],
'batch_norm2d_0': [[0.35732391771364225], [0.4291427868761275]],
're_lu_0': [[0.4291427868761275], [0.2359918110742001]],
'max_pool2d_0': [[0.2359918110742001], [0.25665526917146053]],
'linear_0':
[[1.7037603475152991], [14.395224522473026], [0.4373355209827423]],
'add_0': [[1.7037603475152991, 0.0], [1.7037603475152991]],
}
class TestImperativePTQKL(TestImperativePTQ):
"""
"""
def set_vars(self):
config = PTQConfig(KLQuantizer(), PerChannelAbsmaxQuantizer())
self.ptq = ImperativePTQ(config)
self.batch_num = 10
self.batch_size = 10
self.eval_acc_top1 = 0.99
conv2d_1_wt_thresholds = [
0.18116560578346252, 0.17079241573810577, 0.1702047884464264,
0.179476797580719, 0.1454375684261322, 0.22981858253479004
]
self.gt_thresholds = {
'conv2d_0': [[0.99267578125], [0.37695913558696836]],
'conv2d_1': [[0.19189296757394914], [0.24514256547263358],
[conv2d_1_wt_thresholds]],
'batch_norm2d_0': [[0.37695913558696836], [0.27462541429440535]],
're_lu_0': [[0.27462541429440535], [0.19189296757394914]],
'max_pool2d_0': [[0.19189296757394914], [0.19189296757394914]],
'linear_0': [[1.2839322163611087], [8.957185942414352]],
'add_0': [[1.2839322163611087, 0.0], [1.2839322163611087]],
}
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册