未验证 提交 2ed9048b 编写于 作者: Z zhouzj 提交者: GitHub

Solve the bug that ReconPTQ cannot skip the specified tensor. (#1605)

上级 71b1769e
......@@ -161,6 +161,7 @@ class ReconstructionQuantization(
region_weights_names=self._config['region_weights_names'],
recon_level=self._config['recon_level'],
simulate_activation_quant=self._config['simulate_activation_quant'],
skip_tensor_list=self._skip_tensor_list,
num_iterations=self._batch_nums,
lr=self._config['lr'],
bias_correction=self._bias_correction,
......@@ -223,6 +224,7 @@ class ReconstructionQuanter(object):
region_weights_names,
recon_level,
simulate_activation_quant,
skip_tensor_list=None,
num_iterations=1000,
lr=0.1,
bias_correction=False,
......@@ -256,6 +258,7 @@ class ReconstructionQuanter(object):
Currently support ['layer-wise', 'region-wise'] types. Default is layer-wise.
simulate_activation_quant(bool, optional): Whether we need the noise caused by activation
quantization during the reconstruction process.
skip_tensor_list(list): List of skip quant tensor name.
regions(list[list], optional): The list of some regions, each region is a subgraph of
fp32 program and it will have exact 1 input operation and 1 output operation. When
the recon-level is region, the reconstruction loss of each region is minimized.
......@@ -299,6 +302,7 @@ class ReconstructionQuanter(object):
self._region_weights_names = region_weights_names
self._bias_correction = bias_correction
self._limit = limit
self._skip_tensor_list = skip_tensor_list
if recon_level == 'region-wise' and regions is None:
builder = RegionBuilder(program=self._program)
......@@ -327,6 +331,8 @@ class ReconstructionQuanter(object):
self._input_weight_pairs[in_var_name] = in_var_names
break
for name in self._weight_var_names:
if self._skip_tensor_list is not None and name in self._skip_tensor_list:
continue
region_weights_names.append([name])
region_ = []
region_.append(self._input_weight_pairs[name][0])
......@@ -735,7 +741,8 @@ class ReconstructionQuanter(object):
def _update_scale(self):
for _name in self._weight_var_names:
if self._skip_tensor_list is not None and _name in self._skip_tensor_list:
continue
scale_name = _name + '.scale'
scale_tensor = paddle.fluid.contrib.slim.quantization.utils.load_variable_data(
self._scope, scale_name)
......@@ -750,6 +757,8 @@ class ReconstructionQuanter(object):
def _update_weights_to_int(self):
for weight_var_name in self._weight_var_names:
if self._skip_tensor_list is not None and weight_var_name in self._skip_tensor_list:
continue
alpha_tensor = paddle.fluid.contrib.slim.quantization.utils.load_variable_data(
self._scope,
weight_var_name + '.alpha', )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册