Unverified commit fd85be80, authored by cc, committed by GitHub

[PTQ] wrap simulated layers and save the quantized model (#33962)

* PTQ save quantized model

* Wrap simulated layer

* post process the inference model
Parent 477d9f1e
......@@ -14,14 +14,18 @@
import logging
import copy
import os
import numpy as np
import paddle
import paddle.nn.quant.quant_layers as quant_layers
from paddle.fluid.log_helper import get_logger
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from . import utils
from . import ptq_hooks
from . import ptq_config
from . import ptq_quantizer
from .ptq_registry import PTQRegistry
__all__ = ['ImperativePTQ']
......@@ -53,7 +57,7 @@ class ImperativePTQ(object):
def quantize(self, model, inplace=False):
"""
Add hook to the leaf layer to calculate the threshold of inputs and outputs.
Add quant config and hook to the target layer.
Args:
model(paddle.nn.Layer): The model to be quantized.
......@@ -70,10 +74,16 @@ class ImperativePTQ(object):
for name, layer in new_model.named_sublayers():
if PTQRegistry.is_supported_layer(layer) \
and utils.is_leaf_layer(layer):
and utils.is_leaf_layer(layer) \
and not self._is_skip_layer(layer):
# Add quant config
quant_config = copy.deepcopy(self._quant_config)
if PTQRegistry.is_simulated_quant_layer(layer):
quant_config.enable_in_act_quantizer = True
layer._quant_config = quant_config
# register hook
hook = ptq_hooks.quant_forward_post_hook
quant_hook_handle = layer.register_forward_post_hook(hook)
quant_config.quant_hook_handle = quant_hook_handle
......@@ -82,35 +92,330 @@ class ImperativePTQ(object):
return new_model
def convert(self, model):
def save_quantized_model(self, model, path, input_spec=None, **config):
"""
Process the scales and remove the hooks.
1. Convert the quantized model
2. Call jit.save to save the inference model
3. Load and postprocess the inference model.
Args:
model(paddle.nn.Layer): The model to be quantized.
model (Layer): The model to be saved.
path (str): The path prefix to save model. The format is
``dirname/file_prefix`` or ``file_prefix``.
input_spec (list[InputSpec|Tensor], optional): Describes the input
of the saved model's forward method, which can be described by
InputSpec or example Tensor. If None, all input variables of
the original Layer's forward method would be the inputs of
the saved model. Default None.
**config (dict, optional): Other save configuration options for
compatibility. We do not recommend using these configurations;
they may be removed in the future. If not necessary, DO NOT use
them. Default None.
The following options are currently supported:
(1) output_spec (list[Tensor]): Selects the output targets of
the saved model. By default, all return variables of original
Layer's forward method are kept as the output of the saved model.
If the provided ``output_spec`` list is not all output variables,
the saved model will be pruned according to the given
``output_spec`` list.
Returns:
converted_model(paddle.nn.Layer): The converted model.
None
"""
assert isinstance(model, paddle.nn.Layer), \
"The input model must be the instance of paddle.nn.Layer."
"The model must be the instance of paddle.nn.Layer."
# Convert and save dygraph quantized model
self._convert(model)
paddle.jit.save(layer=model, path=path, input_spec=input_spec, **config)
# Load inference program
is_dynamic_mode = False
if paddle.in_dynamic_mode():
is_dynamic_mode = True
paddle.enable_static()
place = paddle.CPUPlace()
scope = paddle.static.global_scope()
exe = paddle.static.Executor(place)
dirname = os.path.dirname(path)
basename = os.path.basename(path)
model_filename = basename + INFER_MODEL_SUFFIX
params_filename = basename + INFER_PARAMS_SUFFIX
[infer_program, feed_target_names, fetch_targets] = (
paddle.fluid.io.load_inference_model(
dirname=dirname,
executor=exe,
model_filename=model_filename,
params_filename=params_filename))
# Process inference program
self._clean_up(infer_program)
self._gather_input_thresholds(infer_program, scope)
self._remove_scale_op(infer_program)
# Save final program
paddle.fluid.io.save_inference_model(
dirname=dirname,
feeded_var_names=feed_target_names,
target_vars=fetch_targets,
executor=exe,
main_program=infer_program.clone(),
model_filename=model_filename,
params_filename=params_filename)
if is_dynamic_mode:
paddle.disable_static()
def _convert(self, model):
"""
Convert the quantized model.
Args:
model(paddle.nn.Layer): The quantized model.
Returns:
None
"""
for name, sub_layer in model.named_sublayers():
if PTQRegistry.is_supported_layer(sub_layer) \
and utils.is_leaf_layer(sub_layer):
if self._is_quant_layer(sub_layer):
sub_layer._quant_config.quant_hook_handle.remove()
assert hasattr(sub_layer, "_quant_config")
self._cal_thresholds(model)
for name, sub_layer in model.named_sublayers():
if self._is_quant_layer(sub_layer):
self._save_output_thresholds(sub_layer, sub_layer._quant_config)
self._wrap_simulated_layers(model)
def _cal_thresholds(self, model):
"""
Calculate the thresholds of inputs and outputs.
Args:
model(paddle.nn.Layer): The quantized model.
Returns:
None
"""
assert isinstance(model, paddle.nn.Layer), \
"The input model must be the instance of paddle.nn.Layer."
for name, sub_layer in model.named_sublayers():
if self._is_quant_layer(sub_layer):
quant_config = sub_layer._quant_config
quant_config.quant_hook_handle.remove()
quant_config.in_act_quantizer.cal_thresholds()
if quant_config.enable_in_act_quantizer:
quant_config.in_act_quantizer.cal_thresholds()
quant_config.out_act_quantizer.cal_thresholds()
# get weight thresholds
if isinstance(sub_layer, tuple(utils.fake_quant_input_layers)):
if PTQRegistry.is_simulated_quant_layer(sub_layer):
weights = (sub_layer.weight, )
quant_config.wt_quantizer.sample_data(sub_layer, weights)
quant_config.wt_quantizer.cal_thresholds()
def _save_output_thresholds(self, sub_layer, quant_config):
"""
Save the output thresholds to the layer.
Args:
sub_layer(paddle.nn.Layer): The quantized layer.
quant_config(PTQConfig): the quant config for the layer.
Returns:
None
"""
assert isinstance(sub_layer, paddle.nn.Layer), \
"The input model must be the instance of paddle.nn.Layer."
layer_info = PTQRegistry.layer_info(sub_layer)
output_names = layer_info.output_names
output_thresholds = quant_config.out_act_quantizer.thresholds
assert len(output_names) == 1
assert len(output_thresholds) == 1
save_name = output_names[0] + str(0) + "_threshold"
sub_layer._set_op_attrs({save_name: output_thresholds[0]})
sub_layer._set_op_attrs({"out_threshold": output_thresholds[0]})
def _wrap_simulated_layers(self, model):
"""
Replace conv2d and linear with the quantized layers, and save the
thresholds into the fake quant layers.
Args:
model(paddle.nn.Layer): The model to be quantized.
Returns:
None
"""
assert isinstance(model, paddle.nn.Layer), \
"The input model must be the instance of paddle.nn.Layer."
for name, sub_layer in model.named_sublayers():
if self._is_quant_layer(sub_layer) \
and PTQRegistry.is_simulated_quant_layer(sub_layer):
quant_config = sub_layer._quant_config
assert quant_config.enable_in_act_quantizer == True
wt_quantizer = quant_config.wt_quantizer
in_act_quantizer = quant_config.in_act_quantizer
# create layer
quant_layer_name = None
for key, value in utils.layer_name_map.items():
if isinstance(sub_layer, value):
quant_layer_name = 'Quantized' + key
break
assert quant_layer_name is not None
if isinstance(wt_quantizer, ptq_quantizer.AbsmaxQuantizer):
weight_quantize_type = "abs_max"
else:
weight_quantize_type = "channel_wise_abs_max"
kwargs = {
"weight_quantize_type": weight_quantize_type,
"activation_quantize_type": "moving_average_abs_max",
"weight_bits": wt_quantizer.quant_bits,
"activation_bits": in_act_quantizer.quant_bits,
}
quant_layer = quant_layers.__dict__[quant_layer_name](sub_layer,
**kwargs)
# save the input thresholds
assert hasattr(quant_layer, "_fake_quant_input")
assert hasattr(quant_layer._fake_quant_input, "_scale")
assert len(in_act_quantizer.thresholds) == 1
input_threshold = np.array(
[in_act_quantizer.thresholds[0]], dtype=np.float32)
quant_layer._fake_quant_input._scale.set_value(input_threshold)
assert hasattr(quant_layer, "_fake_quant_weight")
assert hasattr(quant_layer._fake_quant_weight, "_scale")
assert len(wt_quantizer.thresholds) == 1
weight_threshold = wt_quantizer.thresholds[0]
if isinstance(weight_threshold, list):
weight_threshold = np.array(
weight_threshold, dtype=np.float32)
else:
weight_threshold = np.array(
[weight_threshold], dtype=np.float32)
quant_layer._fake_quant_weight._scale.set_value(
weight_threshold)
# save the output thresholds
self._save_output_thresholds(quant_layer, quant_config)
# replace the layer
parent_layer, sub_name = \
utils.find_parent_layer_and_sub_name(model, name)
setattr(parent_layer, sub_name, quant_layer)
def _gather_input_thresholds(self, program, scope):
"""
Get the input thresholds from the preceding ops and save them to the current op.
Args:
program(Program): the input infer program.
scope(Scope): the corresponding scope for the program.
Returns:
None
"""
for op in utils.program_all_ops(program):
for in_var_name in utils._get_op_input_var_names(op):
previous_op = utils.find_previous_op(op.block, in_var_name)
if previous_op is None:
continue
if "quantize_dequantize" in previous_op.type or \
previous_op.type == "moving_average_abs_max_scale":
attr_name = previous_op.output('OutScale')[0]
in_threshold = utils.load_variable_data(scope, attr_name)
in_threshold = utils.fp_numpy_to_naive(in_threshold)
argname, index = utils._get_input_name_index(op,
in_var_name)
op._set_attr(argname + str(index) + "_threshold",
in_threshold)
else:
for out_var_name in utils._get_op_output_var_names(
previous_op):
if out_var_name != in_var_name:
continue
argname, index = utils._get_output_name_index(
previous_op, out_var_name)
attr_name = argname + str(index) + "_threshold"
if not previous_op.has_attr(attr_name):
continue
threshold = previous_op.attr(attr_name)
argname, index = utils._get_input_name_index(
op, in_var_name)
attr_name = argname + str(index) + "_threshold"
op._set_attr(attr_name, threshold)
def _clean_up(self, program):
"""
Remove useless thresholds which are added by jit.save.
Args:
program(Program): the input infer program.
Returns:
None
"""
def _helper(op, next_op, old_attr_name, new_attr_name):
if op.has_attr(old_attr_name) and next_op.has_attr(old_attr_name) \
and op.attr(old_attr_name) == next_op.attr(old_attr_name):
threshold = op.attr(old_attr_name)
op._remove_attr(old_attr_name)
next_op._remove_attr(old_attr_name)
next_op._set_attr(new_attr_name, threshold)
for op in utils.program_all_ops(program):
if "quantize_dequantize" in op.type:
# remove the thresholds in fake ops
for attr_name in op.attr_names:
if "_threshold" in attr_name:
op._remove_attr(attr_name)
elif op.type in ["conv2d", "matmul"]:
# change the thresholds in conv2d/matmul + eleadd
arg_name = "Output" if op.type == "conv2d" else "Out"
out_var_name = op.output(arg_name)[0]
next_ops = utils.find_next_ops(op.block, out_var_name)
if len(next_ops) > 1 or next_ops[0].type != "elementwise_add":
continue
next_op = next_ops[0]
argname, index = utils._get_output_name_index(op, out_var_name)
old_attr_name = argname + str(index) + "_threshold"
argname, index = utils._get_output_name_index(
next_op, next_op.output("Out")[0])
new_attr_name = argname + str(index) + "_threshold"
_helper(op, next_op, old_attr_name, new_attr_name)
_helper(op, next_op, "out_threshold", "out_threshold")
def _remove_scale_op(self, program):
"""
Remove the moving_average_abs_max_scale op.
"""
for op in utils.program_all_ops(program):
if op.type == "moving_average_abs_max_scale":
in_var_name = op.input("X")[0]
out_var_name = op.output("Out")[0]
next_ops = utils.find_next_ops(op.block, out_var_name)
for next_op in next_ops:
next_op._rename_input(out_var_name, in_var_name)
# TODO (jc):
# save input activation threshold and quant bits
@staticmethod
def _is_skip_layer(layer):
return hasattr(layer, "skip_quant") and layer.skip_quant == True
return model
@staticmethod
def _is_quant_layer(layer):
return hasattr(layer, "_quant_config")
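
A minimal end-to-end sketch of the workflow this commit enables: wrap the model with `quantize`, run a few calibration forward passes, then export with the new `save_quantized_model`. `MyModel`, `calib_loader` and the save path are placeholders, and the import path is an assumption, not something fixed by this diff.

```python
import paddle
# Import path assumed; ImperativePTQ and default_ptq_config live in the
# imperative quantization package touched by this commit.
from paddle.fluid.contrib.slim.quantization import ImperativePTQ, default_ptq_config

model = MyModel()  # placeholder: any paddle.nn.Layer
# Layers carrying `layer.skip_quant = True` are left untouched by quantize().
ptq = ImperativePTQ(default_ptq_config)

# 1. Register forward-post hooks on the supported leaf layers.
quant_model = ptq.quantize(model)

# 2. Calibrate: plain forward passes feed the activation/weight quantizers.
for data in calib_loader():  # placeholder data source
    quant_model(paddle.to_tensor(data))

# 3. Convert the layers, call jit.save, and post-process the inference program.
input_spec = [paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32')]
ptq.save_quantized_model(
    model=quant_model, path='./ptq_out/model', input_spec=input_spec)
```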
......@@ -39,9 +39,8 @@ class PTQConfig(object):
It should be the instance of BaseQuantizer.
"""
super(PTQConfig, self).__init__()
assert isinstance(activation_quantizer, BaseQuantizer)
assert isinstance(weight_quantizer, BaseQuantizer)
assert isinstance(activation_quantizer, tuple(SUPPORT_ACT_QUANTIZERS))
assert isinstance(weight_quantizer, tuple(SUPPORT_WT_QUANTIZERS))
self.in_act_quantizer = copy.deepcopy(activation_quantizer)
self.out_act_quantizer = copy.deepcopy(activation_quantizer)
......@@ -49,5 +48,9 @@ class PTQConfig(object):
self.quant_hook_handle = None
# In order to wrap simulated layers, use in_act_quantizer
# to calculate the input thresholds for conv2d, linear and etc.
self.enable_in_act_quantizer = False
default_ptq_config = PTQConfig(AbsmaxQuantizer(), AbsmaxQuantizer())
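
A short sketch of building a custom config under the tightened checks above: the activation quantizer must be one of `SUPPORT_ACT_QUANTIZERS` and the weight quantizer one of `SUPPORT_WT_QUANTIZERS`, otherwise the asserts fail. The import path is assumed, as before.

```python
# Import path assumed.
from paddle.fluid.contrib.slim.quantization import (
    ImperativePTQ, PTQConfig, KLQuantizer, PerChannelAbsmaxQuantizer)

# KLQuantizer is an allowed activation quantizer and PerChannelAbsmaxQuantizer
# an allowed weight quantizer; passing anything outside the SUPPORT_* lists
# now trips the isinstance asserts in PTQConfig.__init__.
custom_config = PTQConfig(KLQuantizer(), PerChannelAbsmaxQuantizer())
ptq = ImperativePTQ(custom_config)
```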
......@@ -16,6 +16,7 @@ import paddle
import math
import numpy as np
from . import ptq_config
from .ptq_registry import PTQRegistry
def quant_forward_post_hook(layer, inputs, outputs):
......@@ -24,5 +25,8 @@ def quant_forward_post_hook(layer, inputs, outputs):
"""
assert hasattr(layer, '_quant_config'), \
"The layer should have _quant_config attr"
layer._quant_config.in_act_quantizer.sample_data(layer, inputs)
layer._quant_config.out_act_quantizer.sample_data(layer, (outputs, ))
qc = layer._quant_config
if qc.enable_in_act_quantizer:
qc.in_act_quantizer.sample_data(layer, inputs)
qc.out_act_quantizer.sample_data(layer, (outputs, ))
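
A standalone illustration of the hook semantics above, using hypothetical stand-in objects (the Dummy* classes below are not part of Paddle): the forward-post hook always samples the output activations, and samples the input activations only when `enable_in_act_quantizer` is set, i.e. for layers that will later be wrapped as simulated quant layers.

```python
import paddle

class DummyQuantizer:
    """Stand-in for a PTQ quantizer; just records per-tensor abs-max values."""
    def __init__(self):
        self.samples = []
    def sample_data(self, layer, tensors):
        self.samples.append([float(paddle.abs(t).max()) for t in tensors])

class DummyConfig:
    """Stand-in for PTQConfig with only the fields the hook touches."""
    def __init__(self, enable_in_act):
        self.enable_in_act_quantizer = enable_in_act
        self.in_act_quantizer = DummyQuantizer()
        self.out_act_quantizer = DummyQuantizer()

def quant_forward_post_hook(layer, inputs, outputs):
    qc = layer._quant_config
    if qc.enable_in_act_quantizer:
        qc.in_act_quantizer.sample_data(layer, inputs)
    qc.out_act_quantizer.sample_data(layer, (outputs, ))

linear = paddle.nn.Linear(4, 4)
linear._quant_config = DummyConfig(enable_in_act=True)
handle = linear.register_forward_post_hook(quant_forward_post_hook)
linear(paddle.rand([2, 4]))  # one calibration pass
handle.remove()
print(linear._quant_config.out_act_quantizer.samples)
```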
......@@ -24,11 +24,9 @@ from . import utils
from ..cal_kl_threshold import cal_kl_threshold
__all__ = [
'BaseQuantizer',
'AbsmaxQuantizer',
'PerChannelAbsmaxQuantizer',
'KLQuantizer',
'HistQuantizer',
'BaseQuantizer', 'AbsmaxQuantizer', 'PerChannelAbsmaxQuantizer',
'KLQuantizer', 'HistQuantizer', 'SUPPORT_ACT_QUANTIZERS',
'SUPPORT_WT_QUANTIZERS'
]
......@@ -110,6 +108,7 @@ class BaseQuantizer(object):
self.quant_bits = quant_bits
self.abs_max_vals = []
self.thresholds = []
@abc.abstractmethod
......@@ -133,10 +132,10 @@ class AbsmaxQuantizer(BaseQuantizer):
assert isinstance(tensors, tuple)
abs_max_vals = [abs_max_value(t) for t in tensors]
self.thresholds = merge_max_value(self.thresholds, abs_max_vals)
self.abs_max_vals = merge_max_value(self.abs_max_vals, abs_max_vals)
def cal_thresholds(self):
pass
self.thresholds = self.abs_max_vals
class PerChannelAbsmaxQuantizer(BaseQuantizer):
......@@ -164,10 +163,11 @@ class PerChannelAbsmaxQuantizer(BaseQuantizer):
]
abs_max_vals_list.append(abs_max_vals)
self.thresholds = merge_max_value(self.thresholds, abs_max_vals_list)
self.abs_max_vals = merge_max_value(self.abs_max_vals,
abs_max_vals_list)
def cal_thresholds(self):
pass
self.thresholds = self.abs_max_vals
@six.add_metaclass(abc.ABCMeta)
......@@ -180,7 +180,6 @@ class BaseHistQuantizer(BaseQuantizer):
self.bins = bins
self.upsample_bins = upsample_bins
self.abs_max_vals = []
self.hists = []
def sample_data(self, layer, tensors):
......@@ -262,3 +261,7 @@ class KLQuantizer(BaseHistQuantizer):
bin_width = abs_max_val / hist.shape[0]
threshold = cal_kl_threshold(hist, bin_width, self.quant_bits)
self.thresholds.append(threshold)
SUPPORT_ACT_QUANTIZERS = [AbsmaxQuantizer, HistQuantizer, KLQuantizer]
SUPPORT_WT_QUANTIZERS = [AbsmaxQuantizer, PerChannelAbsmaxQuantizer]
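
A small self-contained sketch of the abs-max bookkeeping introduced above, written against plain NumPy rather than the Paddle classes: `sample_data` keeps a running maximum of `|x|` per tensor across calibration batches (the role of `merge_max_value`), and `cal_thresholds` then publishes the accumulated values as thresholds instead of computing them eagerly.

```python
import numpy as np

class TinyAbsmax:
    """Toy version of AbsmaxQuantizer's accumulate-then-calculate flow."""
    def __init__(self):
        self.abs_max_vals = []
        self.thresholds = []

    def sample_data(self, tensors):
        vals = [float(np.abs(np.asarray(t)).max()) for t in tensors]
        if not self.abs_max_vals:
            self.abs_max_vals = vals
        else:
            # merge_max_value: keep the element-wise larger value
            self.abs_max_vals = [max(o, n) for o, n in zip(self.abs_max_vals, vals)]

    def cal_thresholds(self):
        self.thresholds = self.abs_max_vals

q = TinyAbsmax()
q.sample_data([np.array([0.1, -0.9])])
q.sample_data([np.array([0.5, 0.3])])
q.cal_thresholds()
print(q.thresholds)  # [0.9]
```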
......@@ -47,12 +47,22 @@ PTQ_LAYERS_INFO = [
LayerInfo(paddle.nn.quant.add, ['X', 'Y'], [], ['Out']),
]
QUANT_LAYERS_INFO = [
LayerInfo(paddle.nn.quant.quant_layers.QuantizedConv2D, ['Input'],
['Filter'], ['Output']),
LayerInfo(paddle.nn.quant.quant_layers.QuantizedLinear, ['X'], ['Y'],
['Out']),
]
SIMULATED_LAYERS = [paddle.nn.Conv2D, paddle.nn.Linear]
class PTQRegistry(object):
"""
Register the supported layers for PTQ and provide layers info.
"""
supported_layers_map = {}
registered_layers_map = {}
is_inited = False
def __init__(self):
......@@ -63,24 +73,62 @@ class PTQRegistry(object):
if not cls.is_inited:
for layer_info in PTQ_LAYERS_INFO:
cls.supported_layers_map[layer_info.layer] = layer_info
all_layers_info = PTQ_LAYERS_INFO + QUANT_LAYERS_INFO
for layer_info in all_layers_info:
cls.registered_layers_map[layer_info.layer] = layer_info
cls.is_inited = True
@classmethod
def is_supported_layer(cls, layer):
"""
Analyze whether the layer supports quantization.
Args:
layer(Layer): The input layer can be a python class or an instance.
Returns:
flag(bool): Whether the layer is supported.
"""
cls._init()
return layer in cls.supported_layers_map or \
isinstance(layer, tuple(cls.supported_layers_map.keys()))
@classmethod
def is_registered_layer(cls, layer):
"""
Analyze whether the layer has registered layer_info.
Args:
layer(Layer): The input layer can be a python class or an instance.
Returns:
flag(bool): Whether the layer has registered layer_info.
"""
cls._init()
return layer in cls.registered_layers_map or \
isinstance(layer, tuple(cls.registered_layers_map.keys()))
@classmethod
def is_simulated_quant_layer(cls, layer):
"""
Analyze whether the layer is a simulated quant layer.
Args:
layer(Layer): The input layer can be a python class or an instance.
Returns:
flag(bool): Whether the layer is a simulated quant layer.
"""
return layer in SIMULATED_LAYERS or \
isinstance(layer, tuple(SIMULATED_LAYERS))
@classmethod
def layer_info(cls, layer):
"""
Get the information for the supported layer.
Get the information for the layer.
Args:
layer(Layer): The input layer can be a python class or an instance.
Returns:
layer_info(LayerInfo): The layer info of the input layer.
"""
assert cls.is_supported_layer(
layer), "The input layer is not supported."
assert cls.is_registered_layer(layer), \
"The input layer is not register."
for layer_key, layer_info in cls.supported_layers_map.items():
for layer_key, layer_info in cls.registered_layers_map.items():
if layer == layer_key or isinstance(layer, layer_key):
return layer_info
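
A quick sketch of how the three registry queries added here differ; the printed values are what the registration tables above imply, and the import path simply mirrors the module's file location.

```python
import paddle
from paddle.fluid.contrib.slim.quantization.imperative.ptq_registry import PTQRegistry

conv = paddle.nn.Conv2D(3, 8, 3)

# Supported: PTQ hooks it and collects input/output thresholds.
print(PTQRegistry.is_supported_layer(conv))                     # True
# Simulated: it is replaced by QuantizedConv2D in _wrap_simulated_layers.
print(PTQRegistry.is_simulated_quant_layer(conv))               # True
print(PTQRegistry.is_simulated_quant_layer(paddle.nn.ReLU()))   # False
# Registered layers (plain and Quantized*) expose a LayerInfo.
print(PTQRegistry.layer_info(conv).output_names)                # e.g. ['Output']
```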
......@@ -379,12 +379,12 @@ class ImperativeQuantizeOutputs(object):
setattr(parent_layer, sub_name, cur_quant_layer)
def save_quantized_model(self, layer, path, input_spec=None, **config):
def save_quantized_model(self, model, path, input_spec=None, **config):
"""
Save the quantized model for the inference.
Args:
layer (Layer): The Layer to be saved.
model (Layer): The model to be saved.
path (str): The path prefix to save model. The format is
``dirname/file_prefix`` or ``file_prefix``.
input_spec (list[InputSpec|Tensor], optional): Describes the input
......@@ -407,10 +407,10 @@ class ImperativeQuantizeOutputs(object):
Returns:
None
"""
assert isinstance(layer, dygraph.Layer), \
assert isinstance(model, dygraph.Layer), \
"The model must be the instance of dygraph.Layer."
paddle.jit.save(layer=layer, path=path, input_spec=input_spec, **config)
paddle.jit.save(layer=model, path=path, input_spec=input_spec, **config)
is_dynamic_mode = False
if paddle.in_dynamic_mode():
......
......@@ -69,7 +69,7 @@ fake_quant_wrap_layers = [
]
# The weight format of these layers is Cin * Cout * H * W
spec_channel_axis_layers = [paddle.nn.Conv2D, paddle.nn.Conv2DTranspose]
spec_channel_axis_layers = [paddle.nn.Conv2DTranspose, paddle.nn.Linear]
weight_op_types = [
"conv2d", "depthwise_conv2d", "matmul", "conv2d_transpose",
......@@ -139,6 +139,17 @@ def find_parent_layer_and_sub_name(model, name):
return parent_layer, sub_name
def program_all_ops(program):
"""
Return all ops for the input program.
"""
all_ops = []
for block in program.blocks:
for op in block.ops:
all_ops.append(op)
return all_ops
def is_leaf_layer(layer):
"""
Whether the layer is leaf layer.
......
......@@ -128,9 +128,11 @@ class ImperativeLenet(fluid.dygraph.Layer):
bias_attr=fc_b3_attr),
Softmax())
self.add = paddle.nn.quant.add()
self.quant_stub = paddle.nn.quant.QuantStub()
def forward(self, inputs):
x = self.features(inputs)
x = self.quant_stub(inputs)
x = self.features(x)
x = fluid.layers.flatten(x, 1)
x = self.add(x, paddle.to_tensor(0.0)) # For CI
......
......@@ -20,6 +20,7 @@ import random
import shutil
import time
import unittest
import copy
import logging
import paddle
......@@ -59,7 +60,8 @@ class TestImperativePTQ(unittest.TestCase):
@classmethod
def tearDownClass(cls):
try:
shutil.rmtree(cls.root_path)
pass
# shutil.rmtree(cls.root_path)
except Exception as e:
print("Failed to delete {} due to {}".format(cls.root_path, str(e)))
......@@ -84,8 +86,9 @@ class TestImperativePTQ(unittest.TestCase):
self.batch_num = 10
self.batch_size = 10
self.eval_acc_top1 = 0.99
self.eval_acc_top1 = 0.95
# the input, output and weight thresholds of quantized op
self.gt_thresholds = {
'conv2d_0': [[1.0], [0.37673383951187134], [0.10933732241392136]],
'batch_norm2d_0': [[0.37673383951187134], [0.44249194860458374]],
......@@ -96,36 +99,6 @@ class TestImperativePTQ(unittest.TestCase):
'add_0': [[1.7058950662612915, 0.0], [1.7058950662612915]],
}
def model_train(self, model, train_reader, max_step=-1):
model.train()
adam = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
for batch_id, data in enumerate(train_reader()):
x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = paddle.to_tensor(x_data)
label = paddle.to_tensor(y_data)
out = model(img)
acc = fluid.layers.accuracy(out, label)
loss = fluid.layers.cross_entropy(out, label)
avg_loss = fluid.layers.mean(loss)
avg_loss.backward()
adam.minimize(avg_loss)
model.clear_gradients()
if batch_id % 100 == 0:
_logger.info("Train | step {}: loss = {:}, acc= {:}".format(
batch_id, avg_loss.numpy(), acc.numpy()))
if max_step > 0 and batch_id > max_step: # For shortening CI time
break
def model_test(self, model, batch_num=-1, batch_size=8):
model.eval()
......@@ -145,9 +118,9 @@ class TestImperativePTQ(unittest.TestCase):
out = model(img)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
eval_acc_top1_list.append(float(acc_top1.numpy()))
if batch_id % 100 == 0:
eval_acc_top1_list.append(float(acc_top1.numpy()))
if batch_id % 50 == 0:
_logger.info("Test | At step {}: acc1 = {:}, acc5 = {:}".format(
batch_id, acc_top1.numpy(), acc_top5.numpy()))
......@@ -158,80 +131,88 @@ class TestImperativePTQ(unittest.TestCase):
return eval_acc_top1
def check_thresholds(self, model):
check_num = 0
for name, layer in model.named_sublayers():
layer_name = layer.full_name()
if layer_name in self.gt_thresholds:
ref_val = self.gt_thresholds[layer_name]
assert hasattr(layer, '_quant_config')
quant_config = layer._quant_config
in_val = quant_config.in_act_quantizer.thresholds
out_val = quant_config.out_act_quantizer.thresholds
wt_val = quant_config.wt_quantizer.thresholds
check_num += 1
self.assertTrue(
np.allclose(
ref_val[0], in_val, atol=1e-3),
"%s | The thresholds(%s) is different "
"from the ground truth(%s)." %
(layer_name, str(in_val), str(ref_val[0])))
self.assertTrue(
np.allclose(
ref_val[1], out_val, atol=1e-3),
"%s | The thresholds(%s) is different "
"from the ground truth(%s)." %
(layer_name, str(out_val), str(ref_val[1])))
if len(ref_val) > 2 and ref_val[2] != []:
self.assertTrue(
np.allclose(
ref_val[2], wt_val, atol=1e-3),
"%s | The thresholds(%s) is different "
"from the ground truth(%s)." %
(layer_name, str(wt_val), str(ref_val[2])))
self.assertTrue(check_num == len(self.gt_thresholds))
def program_test(self, program_path, batch_num=-1, batch_size=8):
exe = paddle.static.Executor(paddle.CPUPlace())
[inference_program, feed_target_names, fetch_targets] = (
paddle.static.load_inference_model(program_path, exe))
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size)
top1_correct_num = 0.
total_num = 0.
for batch_id, data in enumerate(test_reader()):
img = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
label = np.array([x[1] for x in data]).astype('int64')
feed = {feed_target_names[0]: img}
results = exe.run(inference_program,
feed=feed,
fetch_list=fetch_targets)
pred = np.argmax(results[0], axis=1)
top1_correct_num += np.sum(np.equal(pred, label))
total_num += len(img)
if total_num % 50 == 49:
_logger.info("Test | Test num {}: acc1 = {:}".format(
total_num, top1_correct_num / total_num))
if batch_num > 0 and batch_id + 1 >= batch_num:
break
return top1_correct_num / total_num
def test_ptq(self):
start_time = time.time()
self.set_vars()
# Load model
params_path = self.download_model(self.lenet_url, self.lenet_md5,
"lenet")
params_path += "/lenet_pretrained/lenet.pdparams"
with fluid.dygraph.guard():
model = ImperativeLenet()
model_state_dict = paddle.load(params_path)
model.set_state_dict(model_state_dict)
quant_model = self.ptq.quantize(model)
acc_top1 = self.model_test(quant_model, self.batch_num,
self.batch_size)
print('acc_top1: %s' % acc_top1)
self.assertTrue(
acc_top1 > self.eval_acc_top1,
msg="The test acc {%f} is less than {%f}." %
(acc_top1, self.eval_acc_top1))
final_model = self.ptq.convert(quant_model)
model = ImperativeLenet()
model_state_dict = paddle.load(params_path)
model.set_state_dict(model_state_dict)
self.check_thresholds(final_model)
# Quantize, calibrate and save
quant_model = self.ptq.quantize(model)
before_acc_top1 = self.model_test(quant_model, self.batch_num,
self.batch_size)
input_spec = [
paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
]
paddle.jit.save(
layer=final_model, path=self.save_path, input_spec=input_spec)
self.ptq.save_quantized_model(
model=quant_model, path=self.save_path, input_spec=input_spec)
print('Quantized model saved in {%s}' % self.save_path)
after_acc_top1 = self.model_test(quant_model, self.batch_num,
self.batch_size)
paddle.enable_static()
infer_acc_top1 = self.program_test(self.save_path, self.batch_num,
self.batch_size)
paddle.disable_static()
# Check
print('Before converted acc_top1: %s' % before_acc_top1)
print('After converted acc_top1: %s' % after_acc_top1)
print('Infer acc_top1: %s' % infer_acc_top1)
self.assertTrue(
after_acc_top1 >= self.eval_acc_top1,
msg="The test acc {%f} is less than {%f}." %
(after_acc_top1, self.eval_acc_top1))
self.assertTrue(
infer_acc_top1 >= after_acc_top1,
msg='The acc is lower after converting model.')
end_time = time.time()
print("total time: %ss" % (end_time - start_time))
print("total time: %ss \n" % (end_time - start_time))
class TestImperativePTQHist(TestImperativePTQ):
......@@ -241,7 +222,7 @@ class TestImperativePTQHist(TestImperativePTQ):
self.batch_num = 10
self.batch_size = 10
self.eval_acc_top1 = 0.99
self.eval_acc_top1 = 0.98
self.gt_thresholds = {
'conv2d_0':
......@@ -262,7 +243,7 @@ class TestImperativePTQKL(TestImperativePTQ):
self.batch_num = 10
self.batch_size = 10
self.eval_acc_top1 = 0.99
self.eval_acc_top1 = 1.0
conv2d_1_wt_thresholds = [
0.18116560578346252, 0.17079241573810577, 0.1702047884464264,
......