Unverified commit 70cd4436, authored by zhouzj, committed by GitHub

Fix the bug that ReconPTQ cannot skip the specified tensors. (#1605) (#1606)

Parent 45c8f7ce
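The patch threads a `skip_tensor_list` argument from `ReconstructionQuantization` down into `ReconstructionQuanter` and guards every per-weight loop with the same membership check. A minimal sketch of that guard (names are illustrative; only the pattern matches the patch):

```python
# Illustrative sketch of the guard this commit adds to each weight loop;
# the helper name is hypothetical, the condition mirrors the patch.
def iter_quantizable(weight_var_names, skip_tensor_list=None):
    for name in weight_var_names:
        if skip_tensor_list is not None and name in skip_tensor_list:
            continue  # tensor stays in FP32; no alpha/scale vars are created
        yield name

# e.g. list(iter_quantizable(['conv1.w_0', 'fc_0.w_0'], ['fc_0.w_0']))
# -> ['conv1.w_0']
```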
@@ -23,10 +23,6 @@ import time
 import numpy as np
 import paddle
-import paddle.fluid as fluid
-from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
-from paddle.fluid.contrib.slim.quantization import utils
 from ..dist import merge
 from ..core.graph_wrapper import GraphWrapper
 from ..common import get_logger, recover_program
@@ -52,7 +48,8 @@ class Collections(object):
         return self._config


-class ReconstructionQuantization(PostTrainingQuantization):
+class ReconstructionQuantization(
+        paddle.fluid.contrib.slim.quantization.PostTrainingQuantization):
     """
     Utilizing reconstruction quantization method to quantize the FP32 model,
     and it uses calibrate data to get the quantization information for all
@@ -95,7 +92,7 @@ class ReconstructionQuantization(PostTrainingQuantization):
     def _preparation(self):
         batch_id = 0
-        with utils.tqdm(
+        with paddle.fluid.contrib.slim.quantization.utils.tqdm(
                 total=self._batch_nums,
                 bar_format='Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
                 ncols=80, ) as t:
@@ -115,7 +112,7 @@ class ReconstructionQuantization(PostTrainingQuantization):
     def _sampling_threshold(self):
         batch_id = 0
-        with utils.tqdm(
+        with paddle.fluid.contrib.slim.quantization.utils.tqdm(
                 total=self._batch_nums,
                 bar_format='Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
                 ncols=80, ) as t:
@@ -164,6 +161,7 @@ class ReconstructionQuantization(PostTrainingQuantization):
             region_weights_names=self._config['region_weights_names'],
             recon_level=self._config['recon_level'],
             simulate_activation_quant=self._config['simulate_activation_quant'],
+            skip_tensor_list=self._skip_tensor_list,
             num_iterations=self._batch_nums,
             lr=self._config['lr'],
             bias_correction=self._bias_correction,
@@ -226,6 +224,7 @@ class ReconstructionQuanter(object):
                  region_weights_names,
                  recon_level,
                  simulate_activation_quant,
+                 skip_tensor_list=None,
                  num_iterations=1000,
                  lr=0.1,
                  bias_correction=False,
@@ -239,7 +238,7 @@ class ReconstructionQuanter(object):
             data_loader(Python Generator, Paddle.io.DataLoader, optional): The
                 Generator or Dataloader provides calibrate data, and it could
                 return a batch every time.
-            executor(fluid.Executor): The executor to load, run and save the
+            executor(paddle.static.Executor): The executor to load, run and save the
                 quantized model.
             scope(fluid.Scope, optional): The scope of the program, use it to load
                 and save variables. If scope=None, get scope by global_scope().
@@ -259,6 +258,7 @@ class ReconstructionQuanter(object):
                 Currently support ['layer-wise', 'region-wise'] types. Default is layer-wise.
             simulate_activation_quant(bool, optional): Whether we need the noise caused by activation
                 quantization during the reconstruction process.
+            skip_tensor_list(list): List of skip quant tensor name.
             regions(list[list], optional): The list of some regions, each region is a subgraph of
                 fp32 program and it will have exact 1 input operation and 1 output operation. When
                 the recon-level is region, the reconstruction loss of each region is minimized.
@@ -302,6 +302,7 @@ class ReconstructionQuanter(object):
         self._region_weights_names = region_weights_names
         self._bias_correction = bias_correction
         self._limit = limit
+        self._skip_tensor_list = skip_tensor_list
         if recon_level == 'region-wise' and regions is None:
             builder = RegionBuilder(program=self._program)
@@ -322,13 +323,16 @@ class ReconstructionQuanter(object):
         self._input_weight_pairs = {}
         for block_id in range(len(self._program.blocks)):
             for op in self._program.blocks[block_id].ops:
-                in_var_names = utils._get_op_input_var_names(op)
+                in_var_names = paddle.fluid.contrib.slim.quantization.utils._get_op_input_var_names(
+                    op)
                 for in_var_name in in_var_names:
                     if in_var_name in persistable_var_names:
                         in_var_names.remove(in_var_name)
                         self._input_weight_pairs[in_var_name] = in_var_names
                         break
         for name in self._weight_var_names:
+            if self._skip_tensor_list is not None and name in self._skip_tensor_list:
+                continue
             region_weights_names.append([name])
             region_ = []
             region_.append(self._input_weight_pairs[name][0])
@@ -431,13 +435,14 @@ class ReconstructionQuanter(object):
         return self._program, self._scale_dict

     def _init_alpha(self, name, scale):
-        _tensor = utils.load_variable_data(self._scope, "teacher_" + name)
-        tensor_scaled = utils.quant_tensor(
+        _tensor = paddle.fluid.contrib.slim.quantization.utils.load_variable_data(
+            self._scope, "teacher_" + name)
+        tensor_scaled = paddle.fluid.contrib.slim.quantization.utils.quant_tensor(
             x=_tensor,
             scale=scale,
             weight_bits=self._weight_bits,
-            quant_axis=0 if self._weight_op_pairs[name] not in
-            utils._channelwise_quant_axis1_ops else 1)
+            quant_axis=0 if self._weight_op_pairs[name] not in paddle.fluid.
+            contrib.slim.quantization.utils._channelwise_quant_axis1_ops else 1)
         tensor_floor = np.floor(tensor_scaled)
         tensor = tensor_scaled - tensor_floor
         alpha = -np.log((ZETA - GAMMA) / (tensor - GAMMA) - 1)
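For context, `_init_alpha` inverts the rectified sigmoid later applied in `_compute_soft_rounding_np`, so optimization starts with the soft rounding equal to each weight's fractional part. A quick numpy check of that identity, assuming the usual AdaRound constants ZETA=1.1 and GAMMA=-0.1 (the actual values are defined elsewhere in this file):

```python
# Minimal check (not from the patch) that the alpha initialization inverts
# the rectified sigmoid h(alpha) = clip(sigmoid(alpha)*(ZETA-GAMMA)+GAMMA, 0, 1).
import numpy as np

ZETA, GAMMA = 1.1, -0.1                                # assumed AdaRound constants

v = np.array([0.12, 0.5, 0.88])                        # fractional parts in (0, 1)
alpha = -np.log((ZETA - GAMMA) / (v - GAMMA) - 1)      # same formula as _init_alpha
h = np.clip(1 / (1 + np.exp(-alpha)) * (ZETA - GAMMA) + GAMMA, 0, 1)
print(np.allclose(h, v))                               # True: soft rounding starts at v
```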
@@ -470,8 +475,7 @@ class ReconstructionQuanter(object):
                 shape=weight.shape,
                 dtype=weight.dtype,
                 name=weight.name + ".alpha",
-                default_initializer=fluid.initializer.NumpyArrayInitializer(
-                    self._alpha, ), )
+                default_initializer=paddle.nn.initializer.Assign(self._alpha, ), )
             h_v = paddle.clip(
                 paddle.nn.functional.sigmoid(v) * (ZETA - GAMMA) + GAMMA,
@@ -483,8 +487,7 @@ class ReconstructionQuanter(object):
                 dtype=weight.dtype,
                 shape=weight.shape,
                 name=weight.name + '.scale',
-                default_initializer=fluid.initializer.NumpyArrayInitializer(
-                    scale, ))
+                default_initializer=paddle.nn.initializer.Assign(scale, ))
         else:
             scale_var = scale
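The two hunks above swap `fluid.initializer.NumpyArrayInitializer` for its Paddle 2.x equivalent, `paddle.nn.initializer.Assign`, which likewise seeds a parameter from a numpy array. A standalone sketch with illustrative names (not code from the patch):

```python
# Hedged sketch of the replacement API in static-graph mode; the parameter
# name and values here are hypothetical.
import numpy as np
import paddle

paddle.enable_static()
init = np.full([4], 0.05, dtype='float32')
scale_var = paddle.static.create_parameter(
    shape=list(init.shape),
    dtype='float32',
    name='w.scale',
    default_initializer=paddle.nn.initializer.Assign(init))
```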
@@ -497,6 +500,8 @@ class ReconstructionQuanter(object):
     def _insert_soft_rounding(self):
         for name in self._weight_var_names:
+            if self._skip_tensor_list is not None and name in self._skip_tensor_list:
+                continue
             weight = self._graph.var(name)
             scale = self._scale_dict[name]
             shape = weight.shape()
@@ -738,12 +743,14 @@ class ReconstructionQuanter(object):
     def _update_scale(self):
         for _name in self._weight_var_names:
+            if self._skip_tensor_list is not None and _name in self._skip_tensor_list:
+                continue
             scale_name = _name + '.scale'
-            scale_tensor = utils.load_variable_data(self._scope, scale_name)
+            scale_tensor = paddle.fluid.contrib.slim.quantization.utils.load_variable_data(
+                self._scope, scale_name)
             scale_list = []
             if self._weight_op_pairs[
-                    _name] in utils._channelwise_quant_axis1_ops:
+                    _name] in paddle.fluid.contrib.slim.quantization.utils._channelwise_quant_axis1_ops:
                 scale_list = list(scale_tensor[0])
             else:
                 for i in range(scale_tensor.shape[0]):
@@ -752,21 +759,25 @@ class ReconstructionQuanter(object):
     def _update_weights_to_int(self):
         for weight_var_name in self._weight_var_names:
-            alpha_tensor = utils.load_variable_data(
+            if self._skip_tensor_list is not None and weight_var_name in self._skip_tensor_list:
+                continue
+            alpha_tensor = paddle.fluid.contrib.slim.quantization.utils.load_variable_data(
                 self._scope,
                 weight_var_name + '.alpha', )
             h_alpha_tensor = self._compute_soft_rounding_np(alpha_tensor)
-            weight_tensor = utils.load_variable_data(
+            weight_tensor = paddle.fluid.contrib.slim.quantization.utils.load_variable_data(
                 self._scope,
                 weight_var_name, )
-            weight_quant_tensor = utils.quant_tensor(
+            weight_quant_tensor = paddle.fluid.contrib.slim.quantization.utils.quant_tensor(
                 x=weight_tensor,
                 scale=self._scale_dict[weight_var_name],
                 weight_bits=self._weight_bits,
-                quant_axis=0 if self._weight_op_pairs[weight_var_name] not in
-                utils._channelwise_quant_axis1_ops else 1)
+                quant_axis=0
+                if self._weight_op_pairs[weight_var_name] not in paddle.fluid.
+                contrib.slim.quantization.utils._channelwise_quant_axis1_ops
+                else 1)
-            utils.set_variable_data(
+            paddle.fluid.contrib.slim.quantization.utils.set_variable_data(
                 self._scope,
                 self._place,
                 weight_var_name,
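Here `utils.quant_tensor` performs per-channel symmetric quantization along `quant_axis` (axis 1 for ops listed in `_channelwise_quant_axis1_ops`, axis 0 otherwise). A rough numpy sketch of that idea, not Paddle's actual implementation:

```python
# Hedged sketch of per-channel symmetric weight quantization; for
# illustration only, details differ from utils.quant_tensor.
import numpy as np

def quant_tensor_sketch(x, scale, weight_bits=8, quant_axis=0):
    bnt = (1 << (weight_bits - 1)) - 1   # 127 for int8
    s = np.asarray(scale, dtype='float64')
    if s.ndim:                           # broadcast one scale per channel
        shape = [1] * x.ndim
        shape[quant_axis] = -1
        s = s.reshape(shape)
    return np.clip(np.round(x / s * bnt), -bnt, bnt)
```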
@@ -774,21 +785,23 @@ class ReconstructionQuanter(object):
     def _bias_correction_w(self):
         for weight_var_name in self._weight_var_names:
-            weight_var_tensor = utils.load_variable_data(
+            weight_var_tensor = paddle.fluid.contrib.slim.quantization.utils.load_variable_data(
                 self._scope,
                 "teacher_" + weight_var_name, )
-            weight_quant_tensor = utils.load_variable_data(
+            weight_quant_tensor = paddle.fluid.contrib.slim.quantization.utils.load_variable_data(
                 self._scope,
                 weight_var_name, )
             scale = self._scale_dict[weight_var_name]
-            final_weight_tensor = utils.bias_correction_w(
+            final_weight_tensor = paddle.fluid.contrib.slim.quantization.utils.bias_correction_w(
                 weight_var_tensor,
                 weight_quant_tensor,
                 scale,
-                quant_axis=0 if self._weight_op_pairs[weight_var_name] not in
-                utils._channelwise_quant_axis1_ops else 1,
+                quant_axis=0
+                if self._weight_op_pairs[weight_var_name] not in paddle.fluid.
+                contrib.slim.quantization.utils._channelwise_quant_axis1_ops
+                else 1,
                 weight_bits=self._weight_bits, )
-            utils.set_variable_data(
+            paddle.fluid.contrib.slim.quantization.utils.set_variable_data(
                 self._scope,
                 self._place,
                 weight_var_name,
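`utils.bias_correction_w` nudges the quantized weights so their per-channel statistics match the FP32 "teacher" weights. One common form of this correction, sketched in numpy under that assumption (not the library's exact math):

```python
# Hedged sketch of weight bias correction: shift each channel of the
# dequantized weights so its mean matches the FP32 original. The real
# utils.bias_correction_w may also rescale by the channel std.
import numpy as np

def bias_correction_sketch(w_fp32, w_quant_dequant, quant_axis=0):
    axes = tuple(i for i in range(w_fp32.ndim) if i != quant_axis)
    shift = (w_fp32.mean(axis=axes, keepdims=True) -
             w_quant_dequant.mean(axis=axes, keepdims=True))
    return w_quant_dequant + shift
```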
@@ -796,7 +809,8 @@ class ReconstructionQuanter(object):
     def _compute_soft_rounding_np(self, alpha_v):
         return np.clip(
-            utils.stable_sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA,
+            paddle.fluid.contrib.slim.quantization.utils.stable_sigmoid(alpha_v)
+            * (ZETA - GAMMA) + GAMMA,
             a_min=0,
             a_max=1, )
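`utils.stable_sigmoid` exists to avoid `np.exp` overflow for inputs of large magnitude; a typical numerically stable formulation looks like the following (assumed, not the actual Paddle code):

```python
# Hedged sketch of a numerically stable sigmoid: use exp(-x) only where
# x >= 0 and exp(x) only where x < 0, so the argument is never large positive.
import numpy as np

def stable_sigmoid_sketch(x):
    out = np.empty_like(x, dtype='float64')
    pos = x >= 0
    out[pos] = 1.0 / (1.0 + np.exp(-x[pos]))
    ex = np.exp(x[~pos])          # exp of a negative number cannot overflow
    out[~pos] = ex / (1.0 + ex)
    return out
```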