Unverified · Commit 9ac27ac3 authored by handiz, committed by GitHub

Add new workflow: run PTQ first, then initialize the QAT scales with the PTQ scales (#44747)

Parent bdd0b0f1
@@ -11,12 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 import os
 import re
-import math
+import shutil
 import logging
 import numpy as np
-import shutil
 try:
     from tqdm import tqdm
 except:
@@ -34,7 +36,10 @@ from .cal_kl_threshold import cal_kl_threshold
 from .adaround import run_adaround
 from . import utils

-__all__ = ['PostTrainingQuantization', 'WeightQuantization']
+__all__ = [
+    'PostTrainingQuantization', 'WeightQuantization',
+    'PostTrainingQuantizationProgram'
+]

 _logger = get_logger(__name__,
                      logging.INFO,
@@ -108,9 +113,9 @@ class PostTrainingQuantization(object):
     """

     def __init__(self,
-                 executor=None,
+                 executor,
+                 model_dir,
                  scope=None,
-                 model_dir=None,
                  model_filename=None,
                  params_filename=None,
                  batch_generator=None,
@@ -130,10 +135,15 @@ class PostTrainingQuantization(object):
                 activation_quantize_type='range_abs_max',
                 weight_quantize_type='channel_wise_abs_max',
                 onnx_format=False,
+                freeze_model=True,
                 optimize_model=False,
                 is_use_cache_file=False,
                 skip_tensor_list=None,
-                cache_dir=None):
+                same_scale_tensor_list=None,
+                scale_trainable=False,
+                cache_dir=None,
+                scale_dict=None,
+                return_graph=False):
        '''
        Constructor.
@@ -206,7 +216,12 @@ class PostTrainingQuantization(object):
                 the model accuracy is usually higher when it is 'channel_wise_abs_max'.
             onnx_format(bool): Whether to export the quantized model with format of ONNX.
                 Default is False.
-            skip_tensor_list(list): List of skip quant tensor name.
+            freeze_model(bool): Whether to convert the quantized and trained ``program``
+                to the final quantized ``program``. Default: True.
+            skip_tensor_list(list): List of skip quant tensor name. Default: None.
+            same_scale_tensor_list(list(list)): Each inner list holds tensors that must
+                share the same scale; the final scale for every inner list is the maximum
+                of the scales of the tensors in that list. Default: None.
             optimize_model(bool, optional): If set optimize_model as True, it applies
                 some passes to the model before quantization, and it supports
                 `conv2d/depthwise_conv2d + bn` pass so far. Some targets require the
@@ -215,6 +230,7 @@ class PostTrainingQuantization(object):
                 `conv2d/depthwise_conv2d + bn`, the weights scale for all channel will
                 be different. In address this problem, fuse the pattern before
                 quantization. Default False.
+            scale_trainable(bool, optional): Whether the inserted scales are trainable. Default: False.
             is_use_cache_file(bool, optional): This param is deprecated.
             cache_dir(str, optional): This param is deprecated.
        Returns:
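For orientation, an entry of `same_scale_tensor_list` is a tensor name, optionally written as `name#op#scalar`: the tensor's calibrated scale is first adjusted by `op` ('*' or '/') and `scalar`, the maximum over the group is taken, and the shared scale is written back with the inverse operation. A minimal sketch, reusing the values from the ResNet-50 test added later in this change:

# Groups of tensors that must end up with a common scale; the names come from the
# new test_post_training_quantization_program_resnet50 test further below.
same_scale_tensor_list = [
    ['batch_norm_3.tmp_2#/#1', 'batch_norm_4.tmp_2#*#1'],  # scales adjusted by 1 before the max
    ['batch_norm_27.tmp_2', 'batch_norm_26.tmp_2'],        # share max(scale) directly
]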
@@ -275,7 +291,6 @@ class PostTrainingQuantization(object):
         # Check inputs
         assert executor is not None, "The executor cannot be None."
-        assert model_dir is not None, "The model_dir cannot be None."
         assert any([gen is not None] for gen in [sample_generator,
             batch_generator, data_loader]), "The sample_generator, batch_generator " \
             "and data_loader cannot be None in the same time."
@@ -347,6 +362,11 @@ class PostTrainingQuantization(object):
         self._best_calibration_loss = {}
         # The threshold for algo = abs_max, mse or avg
         self._quantized_threshold = {}
+        self._same_scale_tensor_list = same_scale_tensor_list
+        self._freeze_model = freeze_model
+        self._scale_trainable = scale_trainable
+        self._scale_dict = scale_dict
+        self._return_graph = return_graph

     def quantize(self):
         '''
@@ -441,7 +461,11 @@ class PostTrainingQuantization(object):
                     persistables.extend(_op.input('X'))
                     _op.desc.set_input("X", persistables)

-        return self._program
+        if not self._return_graph:
+            return self._program
+        else:
+            main_graph = IrGraph(core.Graph(self._program.desc), for_test=True)
+            return main_graph

     def _adaround_apply(self):
         assert self._algo != "min_max", "The algo should not be min_max."
@@ -495,12 +519,13 @@ class PostTrainingQuantization(object):
         '''
         Load model and set data loader.
         '''
-        _logger.info("Load model and set data loader ...")
-        [self._program, self._feed_list, self._fetch_list] = \
-            io.load_inference_model(dirname=self._model_dir,
-                                    executor=self._executor,
-                                    model_filename=self._model_filename,
-                                    params_filename=self._params_filename)
+        if self._program is None:
+            _logger.info("Load model and set data loader ...")
+            [self._program, self._feed_list, self._fetch_list] = \
+                io.load_inference_model(dirname=self._model_dir,
+                                        executor=self._executor,
+                                        model_filename=self._model_filename,
+                                        params_filename=self._params_filename)

         if self._optimize_model:
             self._optimize_fp32_model()
@@ -972,7 +997,8 @@ class PostTrainingQuantization(object):
                 activation_bits=self._activation_bits,
                 activation_quantize_type=self._activation_quantize_type,
                 weight_quantize_type=self._weight_quantize_type,
-                quantizable_op_type=major_quantizable_op_types)
+                quantizable_op_type=major_quantizable_op_types,
+                is_test=not self._scale_trainable)
         else:
             transform_pass = QuantizationTransformPassV2(
                 scope=self._scope,
@@ -981,7 +1007,8 @@ class PostTrainingQuantization(object):
                 activation_bits=self._activation_bits,
                 activation_quantize_type=self._activation_quantize_type,
                 weight_quantize_type=self._weight_quantize_type,
-                quantizable_op_type=major_quantizable_op_types)
+                quantizable_op_type=major_quantizable_op_types,
+                is_test=not self._scale_trainable)

         for sub_graph in graph.all_sub_graphs():
             # Insert fake_quant/fake_dequantize op must in test graph, so
@@ -998,24 +1025,68 @@ class PostTrainingQuantization(object):
             add_quant_dequant_pass = AddQuantDequantPass(
                 scope=self._scope,
                 place=self._place,
-                quantizable_op_type=minor_quantizable_op_types)
+                quantizable_op_type=minor_quantizable_op_types,
+                is_test=not self._scale_trainable)
         else:
             add_quant_dequant_pass = AddQuantDequantPassV2(
                 scope=self._scope,
                 place=self._place,
                 quantizable_op_type=minor_quantizable_op_types,
-                is_full_quantized=self._is_full_quantize)
+                is_full_quantized=self._is_full_quantize,
+                is_test=not self._scale_trainable)

         for sub_graph in graph.all_sub_graphs():
             sub_graph._for_test = True
             add_quant_dequant_pass.apply(sub_graph)

         # save threshold to scale var node
-        if self._algo in ["KL", "hist"]:
-            scale_dict = self._quantized_var_threshold
-        else:
-            scale_dict = self._quantized_threshold
-        for key, val in scale_dict.items():
+        if self._scale_dict is None:
+            if self._algo in ["KL", "hist"]:
+                scale_dict = self._quantized_var_threshold
+            else:
+                scale_dict = self._quantized_threshold
+
+            if self._same_scale_tensor_list is not None:
+                for tensor_list in self._same_scale_tensor_list:
+                    max_scale = None
+                    tmp_tensor_list = []
+                    for tensor_name in tensor_list:
+                        if '#' in tensor_name:
+                            real_tensor_name, opera, scalar = tensor_name.split('#')
+                            if opera == '*':
+                                scale_dict[real_tensor_name] = float(
+                                    scale_dict[real_tensor_name]) * float(scalar)
+                            elif opera == '/':
+                                scale_dict[real_tensor_name] = float(
+                                    scale_dict[real_tensor_name]) / float(scalar)
+                            max_scale = scale_dict[
+                                real_tensor_name] if max_scale is None else max(
+                                    max_scale, scale_dict[real_tensor_name])
+                        else:
+                            max_scale = scale_dict[
+                                tensor_name] if max_scale is None else max(
+                                    max_scale, scale_dict[tensor_name])
+
+                    for tensor_name in tensor_list:
+                        if '#' in tensor_name:
+                            real_tensor_name, opera, scalar = tensor_name.split('#')
+                            if opera == '*':
+                                scale_dict[real_tensor_name] = max_scale / float(scalar)
+                            elif opera == '/':
+                                scale_dict[real_tensor_name] = max_scale * float(scalar)
+                        else:
+                            scale_dict[tensor_name] = max_scale
+
+            self._scale_dict = scale_dict
+
+        for key, val in self._scale_dict.items():
             utils.set_variable_data(self._scope, self._place, key + "@scale",
                                     np.array([val], dtype=np.float32))
             utils.set_variable_data(self._scope, self._place,
@@ -1024,19 +1095,20 @@ class PostTrainingQuantization(object):
         if not self._onnx_format:
             # apply QuantizationFreezePass, and obtain the final quant model
-            freeze_pass = QuantizationFreezePass(
-                scope=self._scope,
-                place=self._place,
-                bias_correction=self._bias_correction,
-                weight_bits=self._weight_bits,
-                round_type=self._round_type,
-                activation_bits=self._activation_bits,
-                weight_quantize_type=self._weight_quantize_type,
-                quantizable_op_type=major_quantizable_op_types)
-
-            for sub_graph in graph.all_sub_graphs():
-                sub_graph._for_test = True
-                freeze_pass.apply(sub_graph)
+            if self._freeze_model:
+                freeze_pass = QuantizationFreezePass(
+                    scope=self._scope,
+                    place=self._place,
+                    bias_correction=self._bias_correction,
+                    weight_bits=self._weight_bits,
+                    round_type=self._round_type,
+                    activation_bits=self._activation_bits,
+                    weight_quantize_type=self._weight_quantize_type,
+                    quantizable_op_type=major_quantizable_op_types)
+
+                for sub_graph in graph.all_sub_graphs():
+                    sub_graph._for_test = True
+                    freeze_pass.apply(sub_graph)
         else:
             quant_weight_pass = QuantWeightPass(self._scope, self._place)
             for sub_graph in graph.all_sub_graphs():
@@ -1155,6 +1227,58 @@ class PostTrainingQuantization(object):
         return (hist_index - 0.5) * bin_width


+class PostTrainingQuantizationProgram(PostTrainingQuantization):
+
+    def __init__(self,
+                 executor,
+                 program,
+                 feed_list=None,
+                 fetch_list=None,
+                 scope=None,
+                 batch_generator=None,
+                 sample_generator=None,
+                 data_loader=None,
+                 batch_size=10,
+                 batch_nums=None,
+                 algo="KL",
+                 hist_percent=0.99999,
+                 quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
+                 round_type='round',
+                 learning_rate=0.001,
+                 is_full_quantize=False,
+                 bias_correction=False,
+                 activation_bits=8,
+                 weight_bits=8,
+                 activation_quantize_type='range_abs_max',
+                 weight_quantize_type='channel_wise_abs_max',
+                 onnx_format=False,
+                 freeze_model=True,
+                 optimize_model=False,
+                 is_use_cache_file=False,
+                 skip_tensor_list=None,
+                 same_scale_tensor_list=None,
+                 scale_trainable=False,
+                 cache_dir=None,
+                 scale_dict=None,
+                 return_graph=True):
+        super().__init__(executor, scope, None, None, None, batch_generator,
+                         sample_generator, data_loader, batch_size, batch_nums,
+                         algo, hist_percent, quantizable_op_type, round_type,
+                         learning_rate, is_full_quantize, bias_correction,
+                         activation_bits, weight_bits, activation_quantize_type,
+                         weight_quantize_type, onnx_format, freeze_model,
+                         optimize_model, is_use_cache_file, skip_tensor_list,
+                         same_scale_tensor_list, scale_trainable, cache_dir,
+                         scale_dict, return_graph)
+        self._program = program
+        assert feed_list is not None, \
+            "Feed list should not be None."
+        assert fetch_list is not None, \
+            "Fetch list should not be None."
+        self._feed_list = feed_list
+        self._fetch_list = fetch_list
+
+
 class WeightQuantization(object):
     _supported_quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
     _supported_weight_quantize_type = ['channel_wise_abs_max', 'abs_max']
......
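To show how the pieces above fit together, here is a minimal usage sketch of the new `PostTrainingQuantizationProgram` class. It assumes an inference `Program`, feed/fetch targets, an executor, and a calibration reader already exist; the names `exe`, `infer_program`, `feed_names`, `fetch_targets` and `sample_generator` are illustrative, not part of this change.

# Sketch only: mirrors the new ResNet-50 test further below.
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantizationProgram

ptq = PostTrainingQuantizationProgram(
    executor=exe,
    program=infer_program,
    feed_list=feed_names,
    fetch_list=fetch_targets,
    sample_generator=sample_generator,
    batch_nums=10,
    algo="abs_max",
    freeze_model=False,     # keep the fake_quant/fake_dequant ops so QAT can continue
    scale_trainable=True,   # is_test=False for the inserted passes, so scales stay trainable
    return_graph=True)      # quantize() then returns an IrGraph instead of a Program
calibrated_graph = ptq.quantize()
# After quantize(), ptq._scale_dict holds the PTQ thresholds that QAT starts from.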
...@@ -124,7 +124,8 @@ class QuantizationTransformPass(object): ...@@ -124,7 +124,8 @@ class QuantizationTransformPass(object):
weight_preprocess_func=None, weight_preprocess_func=None,
act_preprocess_func=None, act_preprocess_func=None,
optimizer_func=None, optimizer_func=None,
executor=None): executor=None,
is_test=None):
r""" r"""
Constructor. Constructor.
@@ -241,7 +242,7 @@ class QuantizationTransformPass(object):
         self._quantizable_grad_ops = [
             '%s_grad' % (op) for op in self._quantizable_ops
         ]
-        self._is_test = None
+        self._is_test = is_test
         self._global_step = None
         self.create_var_map = {}
@@ -260,7 +261,8 @@ class QuantizationTransformPass(object):
         """
         assert isinstance(graph,
                           IrGraph), 'graph must be the instance of IrGraph.'
-        self._is_test = graph.is_test()
+        if self._is_test is None:
+            self._is_test = graph.is_test()
         # marked the variable which has been dequantized.
         dequantized_vars = collections.OrderedDict()
         persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
@@ -449,16 +451,21 @@ class QuantizationTransformPass(object):
             var_type=var_node.type(),
             shape=var_node.shape(),
             var_dtype=var_node.dtype())
+        scale_name = self._quantized_scale_name(name)
+        data_type = 'float64' if var_node.dtype(
+        ) == core.VarDesc.VarType.FP64 else 'float32'
+        try:
+            scale_value = np.array(
+                self._scope.find_var(scale_name).get_tensor())
+        except:
+            scale_value = np.zeros([1], dtype=data_type)
         scale_var_node = graph.create_persistable_node(
-            name=self._quantized_scale_name(name),
+            name=scale_name,
             var_type=var_node.type(),
             shape=[1],
             var_dtype=var_node.dtype())
-        data_type = 'float64' if var_node.dtype(
-        ) == core.VarDesc.VarType.FP64 else 'float32'
-        _init_var_node(scale_var_node,
-                       np.zeros(scale_var_node.shape(), dtype=data_type),
-                       self._scope, self._place)
+        _init_var_node(scale_var_node, scale_value, self._scope, self._place)
         quant_op_node = graph.create_op_node(
             op_type='fake_quantize_abs_max',
             attrs={
...@@ -487,16 +494,20 @@ class QuantizationTransformPass(object): ...@@ -487,16 +494,20 @@ class QuantizationTransformPass(object):
shape=var_node.shape(), shape=var_node.shape(),
var_dtype=var_node.dtype()) var_dtype=var_node.dtype())
scale_name = self._quantized_scale_name(name)
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
try:
scale_value = np.array(
self._scope.find_var(scale_name).get_tensor())
except:
scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type)
scale_in_node = graph.create_persistable_node( scale_in_node = graph.create_persistable_node(
name=self._quantized_scale_name(name), name=scale_name,
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1], shape=[1],
var_dtype=var_node.dtype()) var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype( _init_var_node(scale_in_node, scale_value, self._scope, self._place)
) == core.VarDesc.VarType.FP64 else 'float32'
_init_var_node(scale_in_node,
np.array([_SCALE_DEFAULT_VALUE], dtype=data_type),
self._scope, self._place)
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var()) scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
inputs = {'X': var_node, 'InScale': scale_in_node} inputs = {'X': var_node, 'InScale': scale_in_node}
@@ -549,16 +560,20 @@ class QuantizationTransformPass(object):
             var_type=var_node.type(),
             shape=var_node.shape(),
             var_dtype=var_node.dtype())
+        scale_name = self._quantized_scale_name(name)
+        data_type = 'float64' if var_node.dtype(
+        ) == core.VarDesc.VarType.FP64 else 'float32'
+        try:
+            scale_value = np.array(
+                self._scope.find_var(scale_name).get_tensor())
+        except:
+            scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type)
         scale_in_node = graph.create_persistable_node(
-            name=self._quantized_scale_name(name),
+            name=scale_name,
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[1],
             var_dtype=var_node.dtype())
-        data_type = 'float64' if var_node.dtype(
-        ) == core.VarDesc.VarType.FP64 else 'float32'
-        _init_var_node(scale_in_node,
-                       np.array([_SCALE_DEFAULT_VALUE], dtype=data_type),
-                       self._scope, self._place)
+        _init_var_node(scale_in_node, scale_value, self._scope, self._place)

         scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
         ins = {'X': var_node, 'InScale': scale_in_node}
@@ -628,16 +643,21 @@ class QuantizationTransformPass(object):
             var_type=var_node.type(),
             shape=var_node.shape(),
             var_dtype=var_node.dtype())
+        scale_name = self._quantized_scale_name(name)
+        data_type = 'float64' if var_node.dtype(
+        ) == core.VarDesc.VarType.FP64 else 'float32'
+        try:
+            scale_value = np.array(
+                self._scope.find_var(scale_name).get_tensor())
+        except:
+            scale_value = np.zeros([var_node.shape()[quant_axis]],
+                                   dtype=data_type)
         scale_var_node = graph.create_persistable_node(
             name=self._quantized_scale_name(name),
             var_type=var_node.type(),
             shape=[var_node.shape()[quant_axis]],
             var_dtype=var_node.dtype())
-        data_type = 'float64' if var_node.dtype(
-        ) == core.VarDesc.VarType.FP64 else 'float32'
-        _init_var_node(scale_var_node,
-                       np.zeros(scale_var_node.shape(), dtype=data_type),
-                       self._scope, self._place)
+        _init_var_node(scale_var_node, scale_value, self._scope, self._place)
         quant_op_node = graph.create_op_node(
             op_type='fake_channel_wise_quantize_abs_max',
             attrs={
@@ -1396,7 +1416,12 @@ class TransformForMobilePass(object):
 class OutScaleForTrainingPass(object):

-    def __init__(self, scope=None, place=None, moving_rate=0.9):
+    def __init__(self,
+                 scope=None,
+                 place=None,
+                 moving_rate=0.9,
+                 is_test=None,
+                 scale_dict=None):
         """
         This pass is used for calculating output scales of some operators.
         These output scales may be used by tensorRT or some other inference engines.
@@ -1411,8 +1436,9 @@ class OutScaleForTrainingPass(object):
         self._scope = scope
         self._place = _get_paddle_place(place)
         self._moving_rate = moving_rate
-        self._is_test = None
+        self._is_test = is_test
         self._teller_set = utils._out_scale_op_list
+        self._scale_dict = scale_dict

     def apply(self, graph):
         """
@@ -1424,7 +1450,8 @@ class OutScaleForTrainingPass(object):
         """
         assert isinstance(graph,
                           IrGraph), 'graph must be the instance of IrGraph.'
-        self._is_test = graph.is_test()
+        if self._is_test is None:
+            self._is_test = graph.is_test()
         target_ops = []
         for op in graph.all_op_nodes():
             if op.name() in self._teller_set:
@@ -1440,22 +1467,29 @@ class OutScaleForTrainingPass(object):
                         [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
                     continue

+                data_type = 'float64' if in_node.dtype() \
+                    == core.VarDesc.VarType.FP64 else 'float32'
                 try:
-                    graph._find_node_by_name(
+                    scale_node = graph._find_node_by_name(
                         graph.all_var_nodes(),
                         self._scale_name(in_node.name()))
-                    continue
                 except:
                     scale_node = graph.create_persistable_node(
                         name=self._scale_name(in_node.name()),
                         var_type=core.VarDesc.VarType.LOD_TENSOR,
                         shape=[1],
                         var_dtype=in_node.dtype())
-                    data_type = 'float64' if in_node.dtype() \
-                        == core.VarDesc.VarType.FP64 else 'float32'
-                    _init_var_node(scale_node, np.ones([1], dtype=data_type),
-                                   self._scope, self._place)
+                    if self._scale_dict is not None:
+                        try:
+                            scale_value = np.array(
+                                [self._scale_dict[in_node.name()]])
+                        except:
+                            scale_value = np.ones([1], dtype=data_type)
+                    else:
+                        scale_value = np.ones([1], dtype=data_type)
+                    _init_var_node(scale_node, scale_value, self._scope,
+                                   self._place)

                 ins = {'X': in_node}
                 outs = {'OutScale': scale_node}
                 if not self._is_test:
@@ -1589,7 +1623,9 @@ class AddQuantDequantPass(object):
                 quant_bits=8,
                 skip_pattern=["skip_quant"],
                 quantizable_op_type=["elementwise_add", "pool2d"],
-                is_full_quantized=False):
+                is_full_quantized=False,
+                is_test=None,
+                scale_dict=None):
         """
         Constructor.
@@ -1616,8 +1652,9 @@ class AddQuantDequantPass(object):
         self._place = _get_paddle_place(place)
         self._moving_rate = moving_rate
         self._quant_bits = quant_bits
-        self._is_test = None
+        self._is_test = is_test
         self._skip_pattern = skip_pattern
+        self._scale_dict = scale_dict
         if is_full_quantized:
             self._quantizable_op_type = utils._act_supported_quantizable_op_type
@@ -1645,7 +1682,8 @@ class AddQuantDequantPass(object):
         """
         assert isinstance(graph,
                           IrGraph), 'graph must be the instance of IrGraph.'
-        self._is_test = graph.is_test()
+        if self._is_test is None:
+            self._is_test = graph.is_test()
         dequantized_vars_map = collections.OrderedDict()

         # Forward stage, insert quant_dequant op
@@ -1711,17 +1749,28 @@ class AddQuantDequantPass(object):
             var_type=var_node.type(),
             shape=var_node.shape(),
             var_dtype=var_node.dtype())
+        scale_name = "{}.quant_dequant@scale".format(var_node.name())
+        data_type = 'float64' if var_node.dtype(
+        ) == core.VarDesc.VarType.FP64 else 'float32'
+        try:
+            if self._scale_dict is not None and var_node.name(
+            ) in self._scale_dict.keys():
+                scale_value = np.array([self._scale_dict[var_node.name()]],
+                                       dtype=data_type)
+            else:
+                scale_value = np.array(
+                    self._scope.find_var(scale_name).get_tensor(),
+                    dtype=data_type)
+        except:
+            scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type)
         scale_in_node = graph.create_persistable_node(
             name="{}.quant_dequant@scale".format(var_node.name()),
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[1],
             var_dtype=var_node.dtype())
-        data_type = 'float64' if var_node.dtype(
-        ) == core.VarDesc.VarType.FP64 else 'float32'
-        _init_var_node(scale_in_node,
-                       np.array([_SCALE_DEFAULT_VALUE], dtype=data_type),
-                       self._scope, self._place)
+        _init_var_node(scale_in_node, scale_value, self._scope, self._place)

         scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
         ins = {'X': var_node, 'InScale': scale_in_node}
         outs = {'Out': quant_var_node, 'OutScale': scale_out_node}
@@ -1992,7 +2041,8 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
                 weight_preprocess_func=None,
                 act_preprocess_func=None,
                 optimizer_func=None,
-                executor=None):
+                executor=None,
+                is_test=None):
        r"""
        Args:
            scope(paddle.Scope): When activation use 'range_abs_max' as the quantize
@@ -2106,7 +2156,7 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
         self._quantizable_grad_ops = [
             '%s_grad' % (op) for op in self._quantizable_ops
         ]
-        self._is_test = None
+        self._is_test = is_test
         self._global_step = None
         self.create_var_map = {}
@@ -2235,7 +2285,8 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
         """
         assert isinstance(graph,
                           IrGraph), 'graph must be the instance of IrGraph.'
-        self._is_test = graph.is_test()
+        if self._is_test is None:
+            self._is_test = graph.is_test()

         self.persistable_vars = [
             p.name() for p in graph.all_persistable_nodes()
@@ -2285,7 +2336,8 @@ class AddQuantDequantPassV2(object):
                 quant_bits=8,
                 skip_pattern=["skip_quant"],
                 quantizable_op_type=["elementwise_add", "pool2d"],
-                is_full_quantized=False):
+                is_full_quantized=False,
+                is_test=None):
         """
         Args:
             scope(paddle.Scope): The scope is used to initialize these new parameters.
@@ -2325,7 +2377,7 @@ class AddQuantDequantPassV2(object):
         self._place = _get_paddle_place(place)
         self._moving_rate = moving_rate
         self._quant_bits = quant_bits
-        self._is_test = None
+        self._is_test = is_test
         self._skip_pattern = skip_pattern

         if is_full_quantized:
@@ -2355,7 +2407,8 @@ class AddQuantDequantPassV2(object):
         """
         assert isinstance(graph,
                           IrGraph), 'graph must be the instance of IrGraph.'
-        self._is_test = graph.is_test()
+        if self._is_test is None:
+            self._is_test = graph.is_test()
         dequantized_vars_map = collections.OrderedDict()

         self.persistable_vars = [
......
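The new `is_test` and `scale_dict` hooks on the passes are what make "PTQ first, then QAT" possible: a pass built on a training graph can be told to stay in training mode while its scale variables are seeded from previously calibrated values. A rough sketch of that wiring, assuming `main_graph` (a train-mode IrGraph), `scope`, `place`, and a `scale_dict` produced by the PTQ step above already exist, and that the passes are importable from the same package the new test uses:

# Rough sketch only; variable names and the exact import path are assumptions.
from paddle.fluid.contrib.slim.quantization import (AddQuantDequantPass,
                                                    OutScaleForTrainingPass)

# is_test=False keeps training behaviour; scale_dict seeds each tensor's scale
# variable with the PTQ value instead of the default.
act_pass = AddQuantDequantPass(scope=scope,
                               place=place,
                               is_test=False,
                               scale_dict=scale_dict)
act_pass.apply(main_graph)

out_scale_pass = OutScaleForTrainingPass(scope=scope,
                                         place=place,
                                         is_test=False,
                                         scale_dict=scale_dict)
out_scale_pass.apply(main_graph)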
@@ -38,7 +38,6 @@ _act_supported_quantizable_op_type = [
     "mean",
     "not_equal",
     "reshape",
-    "reshape2",
     "dropout",
     "bilinear_interp",
     "nearest_interp",
......
@@ -246,6 +246,7 @@ if(WIN32)
   list(REMOVE_ITEM TEST_OPS test_post_training_quantization_while)
   list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
   list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50)
+  list(REMOVE_ITEM TEST_OPS test_post_training_quantization_program_resnet50)
   list(REMOVE_ITEM TEST_OPS test_post_training_quantization_lstm_model)
   list(REMOVE_ITEM TEST_OPS test_imperative_ptq)
   list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1)
@@ -520,6 +521,8 @@ endforeach()
 if(NOT WIN32)
   set_tests_properties(test_post_training_quantization_lstm_model
                        PROPERTIES TIMEOUT 120)
+  set_tests_properties(test_post_training_quantization_program_resnet50
+                       PROPERTIES TIMEOUT 240)
   set_tests_properties(test_post_training_quantization_mobilenetv1
                        PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=NIGHTLY")
   set_tests_properties(test_post_training_quantization_resnet50
......
@@ -292,13 +292,13 @@ class TestPostTrainingQuantization(unittest.TestCase):
         print("Start FP32 inference for {0} on {1} images ...".format(
             model, infer_iterations * batch_size))
-        (fp32_throughput, fp32_latency,
-         fp32_acc1) = self.run_program(model_cache_folder + "/model",
-                                       batch_size, infer_iterations)
+        (fp32_throughput, fp32_latency, fp32_acc1) = self.run_program(
+            os.path.join(model_cache_folder, "model"), batch_size,
+            infer_iterations)

         print("Start INT8 post training quantization for {0} on {1} images ...".
               format(model, sample_iterations * batch_size))
-        self.generate_quantized_model(model_cache_folder + "/model",
+        self.generate_quantized_model(os.path.join(model_cache_folder, "model"),
                                       quantizable_op_type, algo, round_type,
                                       is_full_quantize, is_use_cache_file,
                                       is_optimize_model, onnx_format)
@@ -454,29 +454,5 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
             onnx_format=onnx_format)


-class TestPostTrainingPtfForMobilenetv1(TestPostTrainingQuantization):
-
-    def test_post_training_ptf_mobilenetv1(self):
-        model = "MobileNet-V1"
-        algo = "ptf"
-        round_type = "round"
-        data_urls = [
-            'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
-        ]
-        data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
-        quantizable_op_type = [
-            "conv2d",
-            "mul",
-        ]
-        is_full_quantize = False
-        is_use_cache_file = False
-        is_optimize_model = False
-        # The accuracy diff of post-training quantization (abs_max) maybe bigger
-        diff_threshold = 0.05
-        self.run_test(model, algo, round_type, data_urls, data_md5s,
-                      quantizable_op_type, is_full_quantize, is_use_cache_file,
-                      is_optimize_model, diff_threshold)
-
-
 if __name__ == '__main__':
     unittest.main()
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import os
import sys
import time
import paddle
import random
import unittest
import functools
import contextlib
import numpy as np
import paddle.fluid as fluid
from PIL import Image, ImageEnhance
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantizationProgram
from test_post_training_quantization_mobilenetv1 import TestPostTrainingQuantization
paddle.enable_static()
random.seed(0)
np.random.seed(0)
THREAD = 1
DATA_DIM = 224
BUF_SIZE = 102400
DATA_DIR = 'data/ILSVRC2012'
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
def resize_short(img, target_size):
percent = float(target_size) / min(img.size[0], img.size[1])
resized_width = int(round(img.size[0] * percent))
resized_height = int(round(img.size[1] * percent))
img = img.resize((resized_width, resized_height), Image.LANCZOS)
return img
def crop_image(img, target_size, center):
width, height = img.size
size = target_size
if center == True:
w_start = (width - size) / 2
h_start = (height - size) / 2
else:
w_start = np.random.randint(0, width - size + 1)
h_start = np.random.randint(0, height - size + 1)
w_end = w_start + size
h_end = h_start + size
img = img.crop((w_start, h_start, w_end, h_end))
return img
def process_image(sample, mode, color_jitter, rotate):
img_path = sample[0]
img = Image.open(img_path)
img = resize_short(img, target_size=256)
img = crop_image(img, target_size=DATA_DIM, center=True)
if img.mode != 'RGB':
img = img.convert('RGB')
img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255
img -= img_mean
img /= img_std
return img, sample[1]
def _reader_creator(file_list,
mode,
shuffle=False,
color_jitter=False,
rotate=False,
data_dir=DATA_DIR):
def reader():
with open(file_list) as flist:
full_lines = [line.strip() for line in flist]
if shuffle:
np.random.shuffle(full_lines)
lines = full_lines
for line in lines:
img_path, label = line.split()
img_path = os.path.join(data_dir, img_path)
if not os.path.exists(img_path):
continue
yield img_path, int(label)
mapper = functools.partial(process_image,
mode=mode,
color_jitter=color_jitter,
rotate=rotate)
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def val(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir)
class TestPostTrainingQuantizationProgram(TestPostTrainingQuantization):
def run_program(self, model_path, batch_size, infer_iterations):
image_shape = [3, 224, 224]
place = fluid.CPUPlace()
exe = fluid.Executor(place)
[infer_program, feed_dict, fetch_targets] = \
fluid.io.load_inference_model(model_path, exe)
val_reader = paddle.batch(val(), batch_size)
iterations = infer_iterations
test_info = []
cnt = 0
periods = []
for batch_id, data in enumerate(val_reader()):
image = np.array([x[0].reshape(image_shape)
for x in data]).astype("float32")
label = np.array([x[1] for x in data]).astype("int64")
label = label.reshape([-1, 1])
t1 = time.time()
_, acc1, _ = exe.run(infer_program,
feed={
feed_dict[0]: image,
feed_dict[1]: label
},
fetch_list=fetch_targets)
t2 = time.time()
period = t2 - t1
periods.append(period)
test_info.append(np.mean(acc1) * len(data))
cnt += len(data)
if (batch_id + 1) % 100 == 0:
print("{0} images,".format(batch_id + 1))
sys.stdout.flush()
if (batch_id + 1) == iterations:
break
throughput = cnt / np.sum(periods)
latency = np.average(periods)
acc1 = np.sum(test_info) / cnt
[infer_program, feed_dict, fetch_targets] = \
fluid.io.load_inference_model(model_path, exe)
return (throughput, latency, acc1, infer_program, feed_dict,
fetch_targets)
def generate_quantized_model(
self,
program,
quantizable_op_type,
feed_list,
fetch_list,
algo="KL",
round_type="round",
is_full_quantize=False,
is_use_cache_file=False,
is_optimize_model=False,
onnx_format=False,
):
try:
os.system("mkdir " + self.int8_model)
except Exception as e:
print("Failed to create {} due to {}".format(
self.int8_model, str(e)))
sys.exit(-1)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.global_scope()
val_reader = val()
same_scale_tensor_list = [[
'batch_norm_3.tmp_2#/#1', 'batch_norm_4.tmp_2#*#1'
], ['batch_norm_27.tmp_2', 'batch_norm_26.tmp_2']]
ptq = PostTrainingQuantizationProgram(
executor=exe,
program=program,
sample_generator=val_reader,
batch_nums=10,
algo=algo,
quantizable_op_type=quantizable_op_type,
round_type=round_type,
is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model,
onnx_format=onnx_format,
is_use_cache_file=is_use_cache_file,
feed_list=feed_list,
fetch_list=fetch_list,
same_scale_tensor_list=same_scale_tensor_list)
ptq.quantize()
ptq.save_quantized_model(self.int8_model)
def run_test(self,
model,
algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
onnx_format=False):
infer_iterations = self.infer_iterations
batch_size = self.batch_size
sample_iterations = self.sample_iterations
model_cache_folder = self.download_data(data_urls, data_md5s, model)
print("Start FP32 inference for {0} on {1} images ...".format(
model, infer_iterations * batch_size))
(fp32_throughput, fp32_latency, fp32_acc1, infer_program, feed_dict,
fetch_targets) = self.run_program(
os.path.join(model_cache_folder, "model"), batch_size,
infer_iterations)
print("Start INT8 post training quantization for {0} on {1} images ...".
format(model, sample_iterations * batch_size))
self.generate_quantized_model(infer_program, quantizable_op_type,
feed_dict, fetch_targets, algo,
round_type, is_full_quantize,
is_use_cache_file, is_optimize_model,
onnx_format)
print("Start INT8 inference for {0} on {1} images ...".format(
model, infer_iterations * batch_size))
(int8_throughput, int8_latency, int8_acc1, _, _,
_) = self.run_program(self.int8_model, batch_size, infer_iterations)
print("---Post training quantization of {} method---".format(algo))
print(
"FP32 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}."
.format(model, batch_size, fp32_throughput, fp32_latency,
fp32_acc1))
print(
"INT8 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}.\n"
.format(model, batch_size, int8_throughput, int8_latency,
int8_acc1))
sys.stdout.flush()
delta_value = fp32_acc1 - int8_acc1
self.assertLess(delta_value, diff_threshold)
class TestPostTrainingProgramAbsMaxForResnet50(
TestPostTrainingQuantizationProgram):
def test_post_training_abs_max_resnet50(self):
model = "ResNet-50"
algo = "abs_max"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
]
data_md5s = ['4a5194524823d9b76da6e738e1367881']
quantizable_op_type = ["conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = False
diff_threshold = 0.025
self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold)
if __name__ == '__main__':
unittest.main()
@@ -118,6 +118,11 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
             activation_quantize_type=activation_quant_type,
             weight_quantize_type=weight_quant_type)
         transform_pass.apply(main_graph)
+        transform_pass = QuantizationTransformPass(
+            scope=scope,
+            place=place,
+            activation_quantize_type=activation_quant_type,
+            weight_quantize_type=weight_quant_type)
         transform_pass.apply(test_graph)

         build_strategy = fluid.BuildStrategy()
......
@@ -313,6 +313,12 @@ class TestQuantizationFreezePass(unittest.TestCase):
             weight_quantize_type=weight_quant_type,
             skip_pattern=quant_skip_pattern)
         transform_pass.apply(main_graph)
+        transform_pass = QuantizationTransformPass(
+            scope=scope,
+            place=place,
+            activation_quantize_type=activation_quant_type,
+            weight_quantize_type=weight_quant_type,
+            skip_pattern=quant_skip_pattern)
         transform_pass.apply(test_graph)

         dev_name = '_gpu_' if use_cuda else '_cpu_'
         if not for_ci:
......