未验证 提交 82c30f71 编写于 作者: G Guanghua Yu 提交者: GitHub

add EMD method of post_quant (#40421)

上级 1593c7ca
......@@ -272,7 +272,7 @@ class PostTrainingQuantization(object):
]
self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
self._support_algo_type = [
'KL', 'hist', 'avg', 'mse', 'abs_max', 'min_max'
'KL', 'hist', 'avg', 'mse', 'emd', 'abs_max', 'min_max'
]
self._dynamic_quantize_op_type = ['lstm']
self._support_quantize_op_type = \
......@@ -349,7 +349,7 @@ class PostTrainingQuantization(object):
# The vars for algo = avg
self._quantized_var_avg = {}
# The best loss of algo = mse
self._best_mse_loss = {}
self._best_calibration_loss = {}
# The threshold for algo = abs_max, mse or avg
self._quantized_threshold = {}
......@@ -408,7 +408,7 @@ class PostTrainingQuantization(object):
np.array(self._quantized_var_avg[var_name]).mean()
if self._algo in ["KL", "hist"]:
self._calculate_kl_hist_threshold()
if self._algo in ["KL", "abs_max", "hist", "avg", "mse"]:
if self._algo in ["KL", "abs_max", "hist", "avg", "mse", "emd"]:
self._update_program()
else:
self._save_input_threhold()
......@@ -582,6 +582,8 @@ class PostTrainingQuantization(object):
self._sample_min_max()
elif self._algo == "mse":
self._sample_mse()
elif self._algo == "emd":
self._sample_emd()
elif self._algo in ["KL", "hist"]:
self._sample_histogram()
......@@ -610,8 +612,8 @@ class PostTrainingQuantization(object):
abs_max_value = float(np.max(np.abs(var_tensor)))
abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
s = 0.3
if var_name not in self._best_mse_loss:
self._best_mse_loss[var_name] = float('inf')
if var_name not in self._best_calibration_loss:
self._best_calibration_loss[var_name] = float('inf')
while s <= 1.0:
scale = s * abs_max_value
s += 0.02
......@@ -620,8 +622,49 @@ class PostTrainingQuantization(object):
np.clip(var_tensor, 0.0, scale) / scale *
bins) / bins * scale
mse_loss = ((var_tensor - quant_dequant_var)**2).mean()
if mse_loss <= self._best_mse_loss[var_name]:
self._best_mse_loss[var_name] = mse_loss
if mse_loss <= self._best_calibration_loss[var_name]:
self._best_calibration_loss[var_name] = mse_loss
self._quantized_threshold[var_name] = scale
def _sample_emd(self):
if self._quantized_threshold == {}:
for var_name in self._quantized_weight_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
if self._weight_quantize_type == "abs_max":
abs_max_value = float(np.max(np.abs(var_tensor)))
elif self._weight_quantize_type == "channel_wise_abs_max":
abs_max_value = []
if self._weight_op_pairs[
var_name] in _channelwise_quant_axis1_ops:
for i in range(var_tensor.shape[1]):
abs_max_value.append(
float(np.max(np.abs(var_tensor[:, i]))))
else:
for i in range(var_tensor.shape[0]):
abs_max_value.append(
float(np.max(np.abs(var_tensor[i]))))
self._quantized_threshold[var_name] = abs_max_value
_logger.info("EMD searching stage ...")
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = var_tensor.flatten()
abs_max_value = float(np.max(np.abs(var_tensor)))
abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
s = 0.3
if var_name not in self._best_calibration_loss:
self._best_calibration_loss[var_name] = float('inf')
while s <= 1.0:
scale = s * abs_max_value
s += 0.02
bins = 2**(self._activation_bits - 1) - 1
quant_dequant_var = np.round(
np.clip(var_tensor, 0.0, scale) / scale *
bins) / bins * scale
emd_loss = np.abs(
np.mean(var_tensor) - np.mean(quant_dequant_var)) + np.abs(
np.std(var_tensor) - np.std(quant_dequant_var))
if emd_loss <= self._best_calibration_loss[var_name]:
self._best_calibration_loss[var_name] = emd_loss
self._quantized_threshold[var_name] = scale
def _sample_avg(self):
......
......@@ -244,6 +244,26 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
quant_iterations)
class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
def test_post_training_mse(self):
model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "emd"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.01
batch_size = 10
infer_iterations = 50
quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold, batch_size, infer_iterations,
quant_iterations)
class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
def test_post_training_avg(self):
model_name = "mnist_model"
......
......@@ -394,5 +394,27 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
diff_threshold)
class TestPostTrainingEMDForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_avg_mobilenetv1(self):
model = "MobileNet-V1"
algo = "emd"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = [
"conv2d",
"depthwise_conv2d",
"mul",
]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.025
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册