Unverified · Commit 3d5a27f0 authored by Guanghua Yu, committed by GitHub

add adaround post-quant method (#38460)

* add adaround post-quant method
Parent 56dc8c79
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import time
import sys
import logging
import paddle.fluid as fluid
from ....log_helper import get_logger
from .utils import (load_variable_data, set_variable_data, stable_sigmoid,
quant_tensor, dequant_tensor, _channelwise_quant_axis1_ops,
calculate_quant_cos_error)
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
GAMMA = -0.1
ZETA = 1.1
def compute_soft_rounding(alpha_v):
return fluid.layers.clip(
fluid.layers.sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA, min=0, max=1)
def compute_soft_rounding_np(alpha_v):
return np.clip(
stable_sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA, a_min=0, a_max=1)
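# GAMMA and ZETA stretch the sigmoid to [GAMMA, ZETA] = [-0.1, 1.1] before clipping.
# This is the rectified sigmoid of AdaRound (Nagel et al., 2020):
# h(alpha) = clip(sigmoid(alpha) * (ZETA - GAMMA) + GAMMA, 0, 1),
# which can reach exactly 0 and 1, so the learned rounding decision can become
# hard by the end of optimization.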
class AdaRoundLoss(object):
def __init__(self, reg_param=0.01, default_beta_range=(20, 2)):
self.default_reg_param = reg_param
self.default_beta_range = default_beta_range
def compute_recon_loss(self, ada_quantized_output, orig_output):
square_cost = fluid.layers.square_error_cost(ada_quantized_output,
orig_output)
recon_loss = fluid.layers.reduce_mean(
fluid.layers.reduce_sum(
square_cost, dim=-1))
return recon_loss
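# The reconstruction loss is the mean over samples of the summed squared error
# between the layer output computed with soft-rounded weights and the FP32 output.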
def compute_round_loss(self, alpha_v, warm_start, beta):
def round_loss_fn():
# compute the rectified sigmoid of parameter 'alpha', which maps it into [0, 1]
h_v = compute_soft_rounding(alpha_v)
# calculate the regularization term, which drives the parameters to converge
# to exactly zero or one by the end of optimization
reg_term = fluid.layers.reduce_sum(-fluid.layers.pow(
fluid.layers.abs(2 * h_v - 1), factor=beta) + 1)
# calculate the rounding loss
round_loss = self.default_reg_param * reg_term
return round_loss
round_loss = fluid.layers.cond(
warm_start,
lambda: fluid.layers.fill_constant(shape=[1], dtype='float32', value=0.0),
round_loss_fn)
return round_loss
def compute_beta(self, max_iter, cur_iter, warm_start):
# Start and stop beta for annealing of rounding loss (start_beta, end_beta)
start_beta, end_beta = self.default_beta_range
# iteration at the end of the warm-start period (a `warm_start` fraction of max iterations)
warm_start_end_iter = warm_start * max_iter
# compute the relative progress of the current iteration within the annealing phase
rel_iter = (cur_iter - warm_start_end_iter) / (
max_iter - warm_start_end_iter)
beta = end_beta + 0.5 * (start_beta - end_beta) * (1 + np.cos(rel_iter *
np.pi))
return beta
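# For illustration: with default_beta_range=(20, 2), beta equals 20 at the end of
# the warm-start period (rel_iter = 0) and anneals along a half cosine down to 2
# at the final iteration (rel_iter = 1).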
class AdaRound(object):
def __init__(self,
scale,
weight_tensor,
scope=None,
weight_var_name=None,
weight_op_type=None,
is_train=True,
num_iterations=1000):
self.is_train = is_train
self.num_iterations = num_iterations
self.warm_start = 0.1
self.weight_bits = 8
self.offset = 0. # zero-point offset
self.adaround_loss = AdaRoundLoss()
self.ori_weight_tensor = weight_tensor
self.scale = scale
self.scope = scope
self.quant_axis = 0
if weight_op_type in _channelwise_quant_axis1_ops:
self.quant_axis = 1
self.weight_var_name = weight_var_name
self.alpha_name = weight_var_name + ".alpha"
self.initialize_alpha(weight_tensor.copy(), scale, weight_var_name)
def initialize_alpha(self, tensor, scale, var_name):
"""
Initialize the alpha parameter with the same shape as the weight tensor.
"""
tensor_scale = quant_tensor(tensor, scale, quant_axis=self.quant_axis)
tensor_floor = np.floor(tensor_scale)
tensor = tensor_scale - tensor_floor
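# Invert the rectified sigmoid so that h(alpha) initially reproduces the
# fractional part of the scaled weight: solving
# sigmoid(alpha) * (ZETA - GAMMA) + GAMMA = tensor for alpha gives the line below.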
alpha = -np.log((ZETA - GAMMA) / (tensor - GAMMA) - 1)
self.alpha_v = fluid.layers.create_parameter(
shape=alpha.shape,
dtype="float32",
name=var_name + ".alpha",
default_initializer=fluid.initializer.NumpyArrayInitializer(alpha))
def _calculate_output_with_adarounded_weights(self, program, place, exe,
data, fp32_fetch_list,
weight_tensor_dequant):
set_variable_data(self.scope, place, self.weight_var_name,
weight_tensor_dequant)
adaround_out_tensor = exe.run(program=program,
feed=data,
fetch_list=[fp32_fetch_list],
return_numpy=True,
scope=self.scope)
return adaround_out_tensor
def _calculate_quant_weight(self):
np_alpha = load_variable_data(self.scope, self.alpha_name)
h_alpha = compute_soft_rounding_np(np_alpha)
# Scale the tensor
tensor_scale = quant_tensor(
self.ori_weight_tensor.copy(),
self.scale,
quant_axis=self.quant_axis)
weight_tensor = np.floor(tensor_scale)
# Adaround the tensor
weight_tensor_quant = np.add(weight_tensor, h_alpha)
return weight_tensor_quant
def _calculate_adarounded_weights(self):
weight_tensor_quant = self._calculate_quant_weight()
# Dequantize the tensor
weight_tensor_dequant = dequant_tensor(
weight_tensor_quant + self.offset,
self.scale,
quant_axis=self.quant_axis)
return weight_tensor_dequant
def update_final_weights(self):
weight_tensor_quant = self._calculate_quant_weight()
return weight_tensor_quant
def get_loss(self, beta, warm_start, adaround_out_tensor, orig_out_tensor):
round_loss = self.adaround_loss.compute_round_loss(self.alpha_v,
warm_start, beta)
recon_loss = self.adaround_loss.compute_recon_loss(adaround_out_tensor,
orig_out_tensor)
loss = round_loss + recon_loss
losses = {
'loss': loss,
'round_loss': round_loss,
'recon_loss': recon_loss
}
return losses
def update_beta_warm(self, cur_iteration):
warm_start = cur_iteration < self.num_iterations * self.warm_start
beta = self.adaround_loss.compute_beta(self.num_iterations,
cur_iteration, self.warm_start)
return beta, warm_start
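# For example, with num_iterations=1000 and warm_start=0.1, iterations 0-99 are
# the warm-start phase (round loss held at zero) and beta anneals over the rest.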
def run_adaround(data_loader,
fp32_program,
fetch_list,
exe,
scope,
place,
quantized_op_pairs,
weight_op_pairs,
scale_dict,
num_iterations=1000,
lr=0.001,
fast_mode=True):
fetch_op_name = fetch_list[0].name
final_weight_tensor_quant_dict = {}
for weight_var_name, quant_op_out_name in quantized_op_pairs.items():
_logger.info('Start adaround op: {}'.format(weight_var_name))
weight_op_type = weight_op_pairs[weight_var_name]
# get scale and weight tensor
weight_var_tensor = load_variable_data(scope, weight_var_name)
scale = scale_dict[weight_var_name]
fp32_fetch_list = None
for _op in fp32_program.global_block().ops:
if _op.type == "fetch":
_op._rename_input(fetch_op_name, quant_op_out_name)
fp32_fetch_list = fp32_program.global_block().var(
quant_op_out_name)
fetch_op_name = quant_op_out_name
# build adaround program
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_iteration_per_drop_scope = 1
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
with fluid.unique_name.guard():
# initialize adaround
adaround = AdaRound(
scale,
weight_var_tensor,
scope=scope,
weight_var_name=weight_var_name,
weight_op_type=weight_op_type,
num_iterations=num_iterations)
orig_out_tensor = fluid.data(
name='orig_out_tensor',
shape=fp32_fetch_list.shape,
dtype='float32')
adaround_out_tensor = fluid.data(
name='adaround_out_tensor',
shape=fp32_fetch_list.shape,
dtype='float32')
beta_tensor = fluid.data(
name='beta', shape=[1], dtype='float32')
warm_start_tensor = fluid.data(
name='warm_start', shape=[1], dtype='bool')
train_fetches_loss = adaround.get_loss(
beta_tensor, warm_start_tensor, adaround_out_tensor,
orig_out_tensor)
optimizer = fluid.optimizer.Adam(learning_rate=lr)
loss = train_fetches_loss['loss']
optimizer.minimize(loss)
exe.run(startup_program)
start_time = time.time()
prev_start_time = start_time
for i, data in enumerate(data_loader()):
prev_start_time = start_time
start_time = time.time()
# run fp32 model
np_orig_out_tensor = exe.run(program=fp32_program,
feed=data,
fetch_list=[fp32_fetch_list],
return_numpy=True,
scope=scope)
adaround_weight_tensor_dequant = adaround._calculate_adarounded_weights(
)
np_adaround_out_tensor = adaround._calculate_output_with_adarounded_weights(
fp32_program, place, exe, data, fp32_fetch_list,
adaround_weight_tensor_dequant)
# If the cosine similarity of the two tensors is high enough, skip training
cos_error = calculate_quant_cos_error(np_orig_out_tensor[0],
np_adaround_out_tensor[0])
if fast_mode and cos_error > 0.99:
_logger.info("The cosine error is small, skip training.")
break
beta, warm_start = adaround.update_beta_warm(i)
feed_dict = {
'orig_out_tensor': np_orig_out_tensor[0],
'adaround_out_tensor': np_adaround_out_tensor[0],
'beta': beta,
'warm_start': warm_start
}
out = exe.run(
train_program,
feed=feed_dict,
fetch_list=[v.name for v in train_fetches_loss.values()],
return_numpy=True)
_logger.info(
"Iter {:d}, lr {:.5f}, loss {:.5f}, loss_round {:.5f}, loss_recon {:.5f}, time {:.5f}s".
format(i, lr,
np.mean(out[0]),
np.mean(out[1]),
np.mean(out[2]), start_time - prev_start_time))
sys.stdout.flush()
if i == num_iterations:
break
final_weight_tensor_quant_dict[
weight_var_name] = adaround.update_final_weights()
del adaround
# update adarounded calibrated weights
for weight_var_name in quantized_op_pairs.keys():
set_variable_data(scope, place, weight_var_name,
final_weight_tensor_quant_dict[weight_var_name])
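For reference, a minimal sketch of how the new option is driven from the public API (illustrative only; 'fp32_model_dir' and 'sample_generator' are placeholders, and the arguments follow the PostTrainingQuantization constructor extended by this patch):

import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

place = fluid.CPUPlace()
exe = fluid.Executor(place)
ptq = PostTrainingQuantization(
    executor=exe,
    model_dir='fp32_model_dir',         # hypothetical model path
    sample_generator=sample_generator,  # hypothetical calibration reader
    batch_nums=10,
    algo='mse',
    round_type='adaround',              # new option added by this patch
    learning_rate=0.001)                # only used when round_type='adaround'
ptq.quantize()
ptq.save_quantized_model('int8_model_dir')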
......@@ -35,6 +35,8 @@ from .quantization_pass import _get_output_name_index
from .quantization_pass import _get_input_name_index
from .quantization_pass import _channelwise_quant_axis1_ops
from .cal_kl_threshold import cal_kl_threshold
from .adaround import run_adaround
from .utils import load_variable_data, set_variable_data
__all__ = ['PostTrainingQuantization', 'WeightQuantization']
......@@ -42,28 +44,6 @@ _logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
def _load_variable_data(scope, var_name):
'''
Load variable value from scope
'''
var_node = scope.find_var(var_name)
assert var_node is not None, \
"Cannot find " + var_name + " in scope."
return np.array(var_node.get_tensor())
def _set_variable_data(scope, place, var_name, np_value):
'''
Set the value of the variable node by name, if the node exists.
'''
assert isinstance(np_value, np.ndarray), \
'The type of value should be numpy array.'
var_node = scope.find_var(var_name)
if var_node is not None:
tensor = var_node.get_tensor()
tensor.set(np_value, place)
def _all_persistable_var_names(program):
persistable_var_names = []
for var in program.list_vars():
......@@ -143,6 +123,8 @@ class PostTrainingQuantization(object):
algo="KL",
hist_percent=0.99999,
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
round_type='round',
learning_rate=0.001,
is_full_quantize=False,
bias_correction=False,
activation_bits=8,
......@@ -198,6 +180,10 @@ class PostTrainingQuantization(object):
quantizable_op_type(list[str], optional): List the type of ops
that will be quantized. Default is ["conv2d", "depthwise_conv2d",
"mul"].
round_type(str, optional): The method used to convert quantized weight
values from float to int. Currently supports the ['round', 'adaround']
methods. Default is `round`, which rounds to the nearest whole number.
learning_rate(float, optional): The learning rate used by the adaround method. Default is 0.001.
is_full_quantize(bool, optional): If is_full_quantize is set to True,
apply quantization to all supported quantizable op types. If set
to False, only apply quantization to the op type
......@@ -274,6 +260,9 @@ class PostTrainingQuantization(object):
self._support_algo_type = [
'KL', 'hist', 'avg', 'mse', 'emd', 'abs_max', 'min_max'
]
assert round_type in ['adaround', 'round']
self._round_type = round_type
self._learning_rate = learning_rate
self._dynamic_quantize_op_type = ['lstm']
self._support_quantize_op_type = \
list(set(QuantizationTransformPass._supported_quantizable_op_type +
......@@ -401,6 +390,10 @@ class PostTrainingQuantization(object):
if self._batch_nums and batch_id >= self._batch_nums:
break
_logger.info("Finish sampling stage, all batch: " + str(batch_id))
if self._round_type == 'adaround':
self._adaround_apply()
self._reset_activation_persistable()
if self._algo == 'avg':
for var_name in self._quantized_act_var_name:
......@@ -437,6 +430,24 @@ class PostTrainingQuantization(object):
return self._program
def _adaround_apply(self):
if self._algo in ["KL", "hist"]:
scale_dict = self._quantized_var_threshold
else:
scale_dict = self._quantized_threshold
run_adaround(
self._data_loader,
self._program,
self._fetch_list,
self._executor,
self._scope,
self._place,
self._quantized_op_pairs,
self._weight_op_pairs,
scale_dict,
num_iterations=self._batch_nums,
lr=self._learning_rate)
def save_quantized_model(self,
save_model_path,
model_filename=None,
......@@ -519,6 +530,7 @@ class PostTrainingQuantization(object):
'''
# TODO(juncaipeng), consider the name_scope of skip_quant
_logger.info("Collect quantized variable names ...")
self._quantized_op_pairs = {}
def collect_var_name(var_name_list, persistable_var_names, op_type):
for var_name in var_name_list:
......@@ -544,6 +556,12 @@ class PostTrainingQuantization(object):
collect_var_name(
_get_op_output_var_names(op), persistable_var_names,
op_type)
# collect the output var names of quantized ops
for out_var_name in _get_op_output_var_names(op):
for in_var_name in _get_op_input_var_names(op):
if in_var_name in persistable_var_names:
self._quantized_op_pairs[
in_var_name] = out_var_name
# For other op, only sample output scale
elif op_type in self._out_scale_op_list:
collect_var_name(
......@@ -590,7 +608,7 @@ class PostTrainingQuantization(object):
def _sample_mse(self):
if self._quantized_threshold == {}:
for var_name in self._quantized_weight_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = load_variable_data(self._scope, var_name)
if self._weight_quantize_type == "abs_max":
abs_max_value = float(np.max(np.abs(var_tensor)))
elif self._weight_quantize_type == "channel_wise_abs_max":
......@@ -607,7 +625,7 @@ class PostTrainingQuantization(object):
self._quantized_threshold[var_name] = abs_max_value
_logger.info("MSE searching stage ...")
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = load_variable_data(self._scope, var_name)
var_tensor = var_tensor.flatten()
abs_max_value = float(np.max(np.abs(var_tensor)))
abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
......@@ -629,7 +647,7 @@ class PostTrainingQuantization(object):
def _sample_emd(self):
if self._quantized_threshold == {}:
for var_name in self._quantized_weight_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = load_variable_data(self._scope, var_name)
if self._weight_quantize_type == "abs_max":
abs_max_value = float(np.max(np.abs(var_tensor)))
elif self._weight_quantize_type == "channel_wise_abs_max":
......@@ -646,7 +664,7 @@ class PostTrainingQuantization(object):
self._quantized_threshold[var_name] = abs_max_value
_logger.info("EMD searching stage ...")
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = load_variable_data(self._scope, var_name)
var_tensor = var_tensor.flatten()
abs_max_value = float(np.max(np.abs(var_tensor)))
abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
......@@ -670,7 +688,7 @@ class PostTrainingQuantization(object):
def _sample_avg(self):
if self._quantized_threshold == {}:
for var_name in self._quantized_weight_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = load_variable_data(self._scope, var_name)
if self._weight_quantize_type == "abs_max":
abs_max_value = float(np.max(np.abs(var_tensor)))
elif self._weight_quantize_type == "channel_wise_abs_max":
......@@ -687,7 +705,7 @@ class PostTrainingQuantization(object):
self._quantized_threshold[var_name] = abs_max_value
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = load_variable_data(self._scope, var_name)
abs_max_value = float(np.max(np.abs(var_tensor)))
if (var_name not in self._quantized_var_avg):
self._quantized_var_avg[var_name] = []
......@@ -699,7 +717,7 @@ class PostTrainingQuantization(object):
def _sample_abs_max(self):
if self._quantized_threshold == {}:
for var_name in self._quantized_weight_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = load_variable_data(self._scope, var_name)
if self._weight_quantize_type == "abs_max":
abs_max_value = float(np.max(np.abs(var_tensor)))
elif self._weight_quantize_type == "channel_wise_abs_max":
......@@ -716,7 +734,7 @@ class PostTrainingQuantization(object):
self._quantized_threshold[var_name] = abs_max_value
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = load_variable_data(self._scope, var_name)
abs_max_value = float(np.max(np.abs(var_tensor)))
if (var_name not in self._quantized_threshold) or \
(abs_max_value > self._quantized_threshold[var_name]):
......@@ -725,7 +743,7 @@ class PostTrainingQuantization(object):
def _sample_min_max(self):
if self._quantized_var_min == {} and self._quantized_var_max == {}:
for var_name in self._quantized_weight_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = load_variable_data(self._scope, var_name)
if self._weight_quantize_type == "abs_max":
min_value = float(np.min(var_tensor))
max_value = float(np.max(var_tensor))
......@@ -745,7 +763,7 @@ class PostTrainingQuantization(object):
self._quantized_var_max[var_name] = max_value
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = load_variable_data(self._scope, var_name)
min_value = float(np.min(var_tensor))
max_value = float(np.max(var_tensor))
if (var_name not in self._quantized_var_min) or \
......@@ -757,7 +775,7 @@ class PostTrainingQuantization(object):
def _sample_histogram(self):
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = load_variable_data(self._scope, var_name)
var_tensor_abs = np.abs(var_tensor)
bins = self._sampling_act_histogram[var_name][1]
hist, _ = np.histogram(var_tensor_abs, bins=bins)
......@@ -787,7 +805,7 @@ class PostTrainingQuantization(object):
get the min and max value, and then calculate the threshold.
'''
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = load_variable_data(self._scope, var_name)
var_tensor = np.abs(var_tensor)
min_value = float(np.min(var_tensor))
max_value = float(np.max(var_tensor))
......@@ -821,7 +839,7 @@ class PostTrainingQuantization(object):
# Abs_max threshold for weights
for var_name in self._quantized_weight_var_name:
weight_data = _load_variable_data(self._scope, var_name)
weight_data = load_variable_data(self._scope, var_name)
if self._weight_quantize_type == "abs_max":
weight_threshold = float(np.max(np.abs(weight_data)))
elif self._weight_quantize_type == "channel_wise_abs_max":
......@@ -896,13 +914,13 @@ class PostTrainingQuantization(object):
else:
scale_dict = self._quantized_threshold
for key, val in scale_dict.items():
_set_variable_data(
set_variable_data(
self._scope,
self._place,
key + ".scale",
np.array(
[val], dtype=np.float32))
_set_variable_data(
set_variable_data(
self._scope,
self._place,
key + ".quant_dequant.scale",
......@@ -915,6 +933,7 @@ class PostTrainingQuantization(object):
place=self._place,
bias_correction=self._bias_correction,
weight_bits=self._weight_bits,
round_type=self._round_type,
activation_bits=self._activation_bits,
weight_quantize_type=self._weight_quantize_type,
quantizable_op_type=major_quantizable_op_types)
......@@ -961,7 +980,7 @@ class PostTrainingQuantization(object):
argname_index[0] + str(argname_index[1]) + "_threshold",
"post_hist")
elif self._algo in ["avg", "abs_max", "mse"]:
elif self._algo in ["avg", "abs_max", "mse", "emd"]:
save_info(op_node, out_var_name, self._quantized_threshold,
"out_threshold", "post_" + str(self._algo))
save_info(
......@@ -1003,7 +1022,7 @@ class PostTrainingQuantization(object):
for op in target_ops:
for var_name in _get_op_input_var_names(op):
if var_name in persistable_var_names:
var_data = _load_variable_data(self._scope, var_name)
var_data = load_variable_data(self._scope, var_name)
threshold = float(np.max(np.abs(var_data)))
argname, index = _get_input_name_index(op, var_name)
op._set_attr(argname + str(index) + "_threshold", threshold)
......@@ -1249,7 +1268,7 @@ class WeightQuantization(object):
save_weight_dtype = np.int8 if weight_bits == 8 else np.int16
# Get quantized scale and weight data
weight_data = _load_variable_data(scope, var_name)
weight_data = load_variable_data(scope, var_name)
if abs(threshold_rate) < 1e-10:
threshold_value = np.max(np.abs(weight_data))
else:
......@@ -1263,11 +1282,11 @@ class WeightQuantization(object):
# Set weight data
if not for_test:
_set_variable_data(scope, place, var_name, quantized_weight_data)
set_variable_data(scope, place, var_name, quantized_weight_data)
else:
dequantized_weight_data = \
(quantized_weight_data * scale).astype(np.float32)
_set_variable_data(scope, place, var_name, dequantized_weight_data)
set_variable_data(scope, place, var_name, dequantized_weight_data)
# Save info
op._set_attr('quantization_type', 'post_weight_abs_max')
......@@ -1284,7 +1303,7 @@ class WeightQuantization(object):
save_weight_dtype = np.int8 if weight_bits == 8 else np.int16
# Get quantized scale and weight data
weight_data = _load_variable_data(scope, var_name)
weight_data = load_variable_data(scope, var_name)
if op.type == "mul":
scales, quantized_weight_data = \
self._mul_channel_wise_quantization(weight_data,
......@@ -1298,7 +1317,7 @@ class WeightQuantization(object):
# Set weight data
if not for_test:
_set_variable_data(scope, place, var_name, quantized_weight_data)
set_variable_data(scope, place, var_name, quantized_weight_data)
else:
if op.type == "mul":
dequantized_weight_data = \
......@@ -1309,7 +1328,7 @@ class WeightQuantization(object):
else:
_logger.error(op.type +
" is not supported by weight quantization")
_set_variable_data(scope, place, var_name, dequantized_weight_data)
set_variable_data(scope, place, var_name, dequantized_weight_data)
# Save info
op._set_attr('quantization_type', 'post_weight_channel_wise_abs_max')
......
......@@ -26,6 +26,7 @@ from ....data import data
from ....layers import mean
from ....executor import scope_guard
from ....framework import _get_paddle_place
from .utils import _channelwise_quant_axis1_ops, quant_tensor
__all__ = [
'QuantizationTransformPass', 'QuantizationFreezePass', 'ConvertToInt8Pass',
......@@ -233,10 +234,6 @@ _op_real_in_out_name = {
_conv_ops = ['conv2d', 'depthwise_conv2d', 'conv2d_transpose']
_channelwise_quant_axis1_ops = [
'conv2d_transpose', 'mul', 'matmul', 'matmul_v2'
]
def _get_op_input_var_names(op):
"""
......@@ -1206,6 +1203,7 @@ class QuantizationFreezePass(object):
bias_correction=False,
weight_bits=8,
activation_bits=8,
round_type='round',
weight_quantize_type='abs_max',
quantizable_op_type=None):
"""
......@@ -1223,6 +1221,9 @@ class QuantizationFreezePass(object):
https://arxiv.org/abs/1810.05723.
weight_bits(int): quantization bit number for weights.
activation_bits(int): quantization bit number for activation.
round_type(str, optional): The method used to convert quantized weight
values from float to int. Currently supports the ['round', 'adaround']
methods. Default is `round`, which rounds to the nearest whole number.
weight_quantize_type(str): quantization type for weights, support 'abs_max' and
'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight,
since weights are fixed once the model is well trained.
......@@ -1238,6 +1239,7 @@ class QuantizationFreezePass(object):
self._place = _get_paddle_place(place)
self._weight_bits = weight_bits
self._activation_bits = activation_bits
self._round_type = round_type
self._weight_quantize_type = weight_quantize_type
self._fake_quant_op_names = _fake_quant_op_list
self._fake_dequant_op_names = _fake_dequant_op_list
......@@ -1284,18 +1286,22 @@ class QuantizationFreezePass(object):
self._quant_var_scale_map[input_arg_name] = scale_v
# Quantize weight and restore
param_v = self._load_var(input_arg_name)
if isinstance(scale_v, list) and \
any(_check_grandchild_op_node(op_node, op)
for op in _channelwise_quant_axis1_ops):
quant_axis = 1
else:
quant_axis = 0
quantized_param_v = self._quant(
param_v.copy(), scale_v, self._weight_bits, quant_axis)
if self._bias_correction == True:
quantized_param_v = self._bias_correction_w(
param_v, quantized_param_v, scale_v, quant_axis)
self._restore_var(input_arg_name, quantized_param_v)
if self._round_type == 'round':
if any(
_check_grandchild_op_node(op_node, op)
for op in _channelwise_quant_axis1_ops):
quant_axis = 1
else:
quant_axis = 0
quantized_param_v = quant_tensor(param_v.copy(),
scale_v, quant_axis,
self._weight_bits)
quantized_param_v = np.round(quantized_param_v)
if self._bias_correction == True:
quantized_param_v = self._bias_correction_w(
param_v, quantized_param_v, scale_v, quant_axis)
quantized_param_v = np.round(quantized_param_v)
self._restore_var(input_arg_name, quantized_param_v)
self._remove_fake_quant_and_dequant_op(graph, op_node)
# Remove all fake dequant op
......@@ -1513,31 +1519,6 @@ class QuantizationFreezePass(object):
return isinstance(v, float) or isinstance(v, np.float32) \
or isinstance(v, np.float64)
def _quant(self, x, scale, num_bits, quant_axis):
assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
bnt = (1 << (num_bits - 1)) - 1
def _clip(x, scale):
x[x > scale] = scale
x[x < -scale] = -scale
return x
if isinstance(scale, list):
for i, s in enumerate(scale):
if s == 0.0:
s = 1e-8
if quant_axis == 0:
x[i] = _clip(x[i], s)
x[i] = np.round(x[i] / s * bnt)
else:
x[:, i] = _clip(x[:, i], s)
x[:, i] = np.round(x[:, i] / s * bnt)
else:
scale = 1e-8 if scale == 0.0 else scale
x = _clip(x, scale)
x = np.round(x / scale * bnt)
return x
def _bias_correction_w(self, x, x_quant, scale_v, quant_axis):
'''
Bias correction for weight
......@@ -1574,8 +1555,8 @@ class QuantizationFreezePass(object):
mean_bias = np.resize(mean_bias, x.shape)
x_dequant = (mean_bias + x_dequant) * std_bias
quantized_param_v = self._quant(x_dequant, scale_v, self._weight_bits,
quant_axis)
quantized_param_v = quant_tensor(x_dequant, scale_v, quant_axis,
self._weight_bits)
return quantized_param_v
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
_channelwise_quant_axis1_ops = [
'conv2d_transpose', 'mul', 'matmul', 'matmul_v2'
]
def load_variable_data(scope, var_name):
'''
Load variable value from scope
'''
var_node = scope.find_var(var_name)
assert var_node is not None, \
"Cannot find " + var_name + " in scope."
return np.array(var_node.get_tensor())
def set_variable_data(scope, place, var_name, np_value):
'''
Set the value of the variable node by name, if the node exists.
'''
assert isinstance(np_value, np.ndarray), \
'The type of value should be numpy array.'
var_node = scope.find_var(var_name)
if var_node is not None:
tensor = var_node.get_tensor()
tensor.set(np_value, place)
def quant_tensor(x, scale, quant_axis=0, weight_bits=8):
# symmetric quantization
def _clip(x, scale):
x[x > scale] = scale
x[x < -scale] = -scale
return x
assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
bnt = (1 << (weight_bits - 1)) - 1
if isinstance(scale, list):
for i, s in enumerate(scale):
if s == 0.0:
s = 1e-8
if quant_axis == 0:
x[i] = _clip(x[i], s)
x[i] = x[i] / s * bnt
else:
x[:, i] = _clip(x[:, i], s)
x[:, i] = x[:, i] / s * bnt
else:
scale = 1e-8 if scale == 0.0 else scale
x = _clip(x, scale)
x = x / scale * bnt
return x
def dequant_tensor(x, scale, quant_axis=0, weight_bits=8):
assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
bnt = (1 << (weight_bits - 1)) - 1
if isinstance(scale, list):
for i, s in enumerate(scale):
if s == 0.0:
s = 1e-8
if quant_axis == 0:
x[i] = x[i] * s / bnt
else:
x[:, i] = x[:, i] * s / bnt
else:
scale = 1e-8 if scale == 0.0 else scale
x = x * scale / bnt
return x
def stable_sigmoid(x):
sig = np.where(x < 0, np.exp(x) / (1 + np.exp(x)), 1 / (1 + np.exp(-x)))
return sig
def calculate_quant_cos_error(orig_tensor, qdq_tensor):
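# Note: despite the name, this returns the cosine *similarity* of the two tensors
# (1.0 means identical direction); run_adaround skips training when it exceeds a threshold.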
cos_sim = np.inner(orig_tensor.flatten(), qdq_tensor.flatten()) \
/ (np.linalg.norm(orig_tensor.flatten()) * np.linalg.norm(qdq_tensor.flatten()))
return cos_sim
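As a rough sanity check of the helpers above (a sketch, not part of the patch; assumes quant_tensor and dequant_tensor from this module):

import numpy as np

w = np.array([[0.5, -0.25], [0.1, -0.9]], dtype=np.float32)
scale = float(np.max(np.abs(w)))              # abs_max scale of the whole tensor
q = np.round(quant_tensor(w.copy(), scale))   # map onto the int8 grid [-127, 127]
w_dq = dequant_tensor(q, scale)               # back to float32
assert np.abs(w - w_dq).max() <= scale / 127  # quantization error is bounded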
......@@ -167,6 +167,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
model_path,
data_path,
algo="KL",
round_type="round",
quantizable_op_type=["conv2d"],
is_full_quantize=False,
is_use_cache_file=False,
......@@ -186,6 +187,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
batch_nums=batch_nums,
algo=algo,
quantizable_op_type=quantizable_op_type,
round_type=round_type,
is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model,
is_use_cache_file=is_use_cache_file)
......@@ -193,9 +195,9 @@ class TestPostTrainingQuantization(unittest.TestCase):
ptq.save_quantized_model(self.int8_model_path)
def run_test(self, model_name, model_url, model_md5, data_name, data_url,
data_md5, algo, quantizable_op_type, is_full_quantize,
is_use_cache_file, is_optimize_model, diff_threshold,
infer_iterations, quant_iterations):
data_md5, algo, round_type, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold, infer_iterations, quant_iterations):
fp32_model_path = self.download_model(model_url, model_md5, model_name)
fp32_model_path = os.path.join(fp32_model_path, model_name)
......@@ -210,9 +212,9 @@ class TestPostTrainingQuantization(unittest.TestCase):
print("Start post training quantization for {0} on {1} samples ...".
format(model_name, quant_iterations))
self.generate_quantized_model(fp32_model_path, data_path, algo,
quantizable_op_type, is_full_quantize,
is_use_cache_file, is_optimize_model,
quant_iterations)
round_type, quantizable_op_type,
is_full_quantize, is_use_cache_file,
is_optimize_model, quant_iterations)
print("Start INT8 inference for {0} on {1} samples ...".format(
model_name, infer_iterations))
......@@ -239,6 +241,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
data_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz"
data_md5 = "add84c754e9b792fea1fbd728d134ab7"
algo = "KL"
round_type = "round"
quantizable_op_type = ["mul", "lstm"]
is_full_quantize = False
is_use_cache_file = False
......@@ -247,9 +250,32 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
infer_iterations = 100
quant_iterations = 10
self.run_test(model_name, model_url, model_md5, data_name, data_url,
data_md5, algo, quantizable_op_type, is_full_quantize,
is_use_cache_file, is_optimize_model, diff_threshold,
infer_iterations, quant_iterations)
data_md5, algo, round_type, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold, infer_iterations, quant_iterations)
class TestPostTrainingKLForMnistAdaround(TestPostTrainingQuantization):
def test_post_training_kl(self):
model_name = "nlp_lstm_fp32_model"
model_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/nlp_lstm_fp32_model.tar.gz"
model_md5 = "519b8eeac756e7b4b7bcb2868e880452"
data_name = "quant_lstm_input_data"
data_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz"
data_md5 = "add84c754e9b792fea1fbd728d134ab7"
algo = "KL"
round_type = "adaround"
quantizable_op_type = ["mul", "lstm"]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = False
diff_threshold = 0.01
infer_iterations = 100
quant_iterations = 10
self.run_test(model_name, model_url, model_md5, data_name, data_url,
data_md5, algo, round_type, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold, infer_iterations, quant_iterations)
if __name__ == '__main__':
......
......@@ -110,6 +110,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
def generate_quantized_model(self,
model_path,
algo="KL",
round_type="round",
quantizable_op_type=["conv2d"],
is_full_quantize=False,
is_use_cache_file=False,
......@@ -130,6 +131,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
batch_nums=batch_nums,
algo=algo,
quantizable_op_type=quantizable_op_type,
round_type=round_type,
is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model,
is_use_cache_file=is_use_cache_file)
......@@ -141,6 +143,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
data_url,
data_md5,
algo,
round_type,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
......@@ -160,9 +163,10 @@ class TestPostTrainingQuantization(unittest.TestCase):
print("Start INT8 post training quantization for {0} on {1} images ...".
format(model_name, quant_iterations * batch_size))
self.generate_quantized_model(
origin_model_path, algo, quantizable_op_type, is_full_quantize,
is_use_cache_file, is_optimize_model, batch_size, quant_iterations)
self.generate_quantized_model(origin_model_path, algo, round_type,
quantizable_op_type, is_full_quantize,
is_use_cache_file, is_optimize_model,
batch_size, quant_iterations)
print("Start INT8 inference for {0} on {1} images ...".format(
model_name, infer_iterations * batch_size))
......@@ -190,6 +194,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "KL"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
......@@ -198,10 +203,10 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
batch_size = 10
infer_iterations = 50
quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold, batch_size, infer_iterations,
quant_iterations)
self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations)
class TestPostTraininghistForMnist(TestPostTrainingQuantization):
......@@ -210,6 +215,7 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "hist"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
......@@ -218,10 +224,10 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization):
batch_size = 10
infer_iterations = 50
quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold, batch_size, infer_iterations,
quant_iterations)
self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations)
class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
......@@ -230,6 +236,7 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "mse"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
......@@ -238,10 +245,10 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
batch_size = 10
infer_iterations = 50
quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold, batch_size, infer_iterations,
quant_iterations)
self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations)
class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
......@@ -250,6 +257,7 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "emd"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
......@@ -258,10 +266,10 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
batch_size = 10
infer_iterations = 50
quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold, batch_size, infer_iterations,
quant_iterations)
self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations)
class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
......@@ -270,6 +278,7 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "avg"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
......@@ -278,10 +287,10 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
batch_size = 10
infer_iterations = 50
quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold, batch_size, infer_iterations,
quant_iterations)
self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations)
class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
......@@ -290,6 +299,7 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "abs_max"
round_type = "round"
quantizable_op_type = ["conv2d", "mul"]
is_full_quantize = True
is_use_cache_file = False
......@@ -298,10 +308,31 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
batch_size = 10
infer_iterations = 50
quant_iterations = 10
self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold, batch_size, infer_iterations,
quant_iterations)
self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations)
class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization):
def test_post_training_mse(self):
model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "mse"
round_type = "adaround"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.01
batch_size = 10
infer_iterations = 50
quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations)
if __name__ == '__main__':
......
......@@ -240,6 +240,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
model_path,
quantizable_op_type,
algo="KL",
round_type="round",
is_full_quantize=False,
is_use_cache_file=False,
is_optimize_model=False):
......@@ -261,15 +262,16 @@ class TestPostTrainingQuantization(unittest.TestCase):
model_dir=model_path,
algo=algo,
quantizable_op_type=quantizable_op_type,
round_type=round_type,
is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model,
is_use_cache_file=is_use_cache_file)
ptq.quantize()
ptq.save_quantized_model(self.int8_model)
def run_test(self, model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold):
def run_test(self, model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold):
infer_iterations = self.infer_iterations
batch_size = self.batch_size
sample_iterations = self.sample_iterations
......@@ -285,7 +287,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
format(model, sample_iterations * batch_size))
self.generate_quantized_model(
model_cache_folder + "/model", quantizable_op_type, algo,
is_full_quantize, is_use_cache_file, is_optimize_model)
round_type, is_full_quantize, is_use_cache_file, is_optimize_model)
print("Start INT8 inference for {0} on {1} images ...".format(
model, infer_iterations * batch_size))
......@@ -309,6 +311,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_kl_mobilenetv1(self):
model = "MobileNet-V1"
algo = "KL"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
......@@ -323,15 +326,16 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.025
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold)
self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold)
class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_avg_mobilenetv1(self):
model = "MobileNet-V1"
algo = "avg"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
......@@ -345,15 +349,16 @@ class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization):
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.025
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold)
self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold)
class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_hist_mobilenetv1(self):
model = "MobileNet-V1"
algo = "hist"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
......@@ -367,15 +372,16 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.025
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold)
self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold)
class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_abs_max_mobilenetv1(self):
model = "MobileNet-V1"
algo = "abs_max"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
......@@ -389,15 +395,110 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
is_optimize_model = False
# The accuracy diff of post-training quantization (abs_max) may be bigger
diff_threshold = 0.05
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold)
self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold)
class TestPostTrainingAvgAdaRoundForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_adaround_mobilenetv1(self):
model = "MobileNet-V1"
algo = "avg"
round_type = "adaround"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = [
"conv2d",
"depthwise_conv2d",
"mul",
]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.025
self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold)
class TestPostTrainingAbsMaxAdaRoundForMobilenetv1(
TestPostTrainingQuantization):
def test_post_training_adaround_mobilenetv1(self):
model = "MobileNet-V1"
algo = "abs_max"
round_type = "adaround"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = [
"conv2d",
"depthwise_conv2d",
"mul",
]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.025
self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold)
class TestPostTraininghistAdaroundForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_hist_mobilenetv1(self):
model = "MobileNet-V1"
algo = "hist"
round_type = "adaround"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = [
"conv2d",
"depthwise_conv2d",
"mul",
]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.025
self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold)
class TestPostTrainingKLAdaroundForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_kl_mobilenetv1(self):
model = "MobileNet-V1"
algo = "KL"
round_type = "adaround"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = [
"conv2d",
"depthwise_conv2d",
"mul",
"pool2d",
]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.025
self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold)
class TestPostTrainingEMDForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_avg_mobilenetv1(self):
model = "MobileNet-V1"
algo = "emd"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
......@@ -411,9 +512,9 @@ class TestPostTrainingEMDForMobilenetv1(TestPostTrainingQuantization):
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.025
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold)
self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold)
if __name__ == '__main__':
......
......@@ -24,6 +24,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
def test_post_training_resnet50(self):
model = "ResNet-50"
algo = "min_max"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
]
......@@ -33,9 +34,9 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
is_use_cache_file = False
is_optimize_model = False
diff_threshold = 0.025
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold)
self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold)
if __name__ == '__main__':
......