Unverified commit 12b170c3 authored by zhouzj, committed by GitHub

Add automatic region division and Fix per-tensor quant (#1517) (#1547)

Co-authored-by: zhouzj <41366441+zzjjay@users.noreply.github.com>
Co-authored-by: gushiqiao <77222802+gushiqiao@users.noreply.github.com>
Parent b020589b
@@ -29,7 +29,7 @@ from paddle.fluid.contrib.slim.quantization import utils
from ..dist import merge
from ..core.graph_wrapper import GraphWrapper
from ..common import get_logger
from ..common import get_logger, recover_program
__all__ = ['ReconstructionQuantization', ]
@@ -75,7 +75,6 @@ class ReconstructionQuantization(PostTrainingQuantization):
Load the FP32 model, and use the calibration data to run the forward pass.
Based on the sample data, we can get the quantization information and obtain
the final quantized model.
Args:
None
Returns:
@@ -156,7 +155,10 @@ class ReconstructionQuantization(PostTrainingQuantization):
scope=self._scope,
place=self._place,
quantized_op_pairs=self._quantized_op_pairs,
weight_op_pairs=self._weight_op_pairs,
weight_quantize_type=self._weight_quantize_type,
activation_bits=self._activation_bits,
weight_bits=self._weight_bits,
scale_dict=copy.deepcopy(self._scale_dict),
regions=self._config['regions'],
region_weights_names=self._config['region_weights_names'],
@@ -169,6 +171,11 @@ class ReconstructionQuantization(PostTrainingQuantization):
limit=self._config['limit'])
self._program, self._scale_dict = reconstruction_quanter._run()
if self._algo in ["KL", "hist"]:
self._quantized_var_threshold = self._scale_dict
else:
self._quantized_threshold = self._scale_dict
def _postprocessing(self):
if self._algo == 'min_max':
self._save_input_threhold()
@@ -210,7 +217,10 @@ class ReconstructionQuanter(object):
scope,
place,
quantized_op_pairs,
weight_op_pairs,
weight_quantize_type,
activation_bits,
weight_bits,
scale_dict,
regions,
region_weights_names,
@@ -225,7 +235,6 @@ class ReconstructionQuanter(object):
'''
Reconstruction Quanter, used to optimize the rounding policy
by reconstructing the intermediate output.
Args:
data_loader(Python Generator, Paddle.io.DataLoader, optional): The
Generator or Dataloader provides calibrate data, and it could
@@ -280,9 +289,12 @@ class ReconstructionQuanter(object):
self._scope = scope
self._place = place
self._quantized_op_pairs = quantized_op_pairs
self._weight_op_pairs = weight_op_pairs
self._weight_var_names = list(self._quantized_op_pairs.keys())
self._weight_quantize_type = weight_quantize_type
self._scale_dict = scale_dict
self._activation_bits = activation_bits
self._weight_bits = weight_bits
self._num_iterations = num_iterations
self._epochs = epochs
self._lr = lr
@@ -344,15 +356,7 @@ class ReconstructionQuanter(object):
teacher_scope=None,
name_prefix="teacher_",
merge_feed=True, )
for name in self._weight_var_names:
weight_np = utils.load_variable_data(self._scope, name)
scale = self._scale_dict[name]
weight_np_floor = np.floor(utils.quant_tensor(weight_np, scale))
utils.set_variable_data(
self._scope,
self._place,
name,
weight_np_floor, )
self._graph = GraphWrapper(self._student_program)
if self._simulate_activation_quant:
@@ -394,7 +398,6 @@ class ReconstructionQuanter(object):
optimizer = paddle.optimizer.Adam(
learning_rate=self._lr, parameters=update_params)
optimizer.minimize(total_loss)
self._exe.run(startup_program)
start_time = time.time()
prev_start_time = start_time
@@ -420,39 +423,48 @@ class ReconstructionQuanter(object):
sys.stdout.flush()
if i + 1 == self._num_iterations:
break
if self._weight_quantize_type == 'channel_wise_abs_max':
self._update_scale()
self._update_weights_to_int()
if self._bias_correction:
self._bias_correction_w()
return self._program
return self._program, self._scale_dict
def _init_alpha(self, name, scale):
_tensor = utils.load_variable_data(self._scope, "teacher_" + name)
tensor_scaled = utils.quant_tensor(_tensor, scale)
tensor_scaled = utils.quant_tensor(
x=_tensor,
scale=scale,
weight_bits=self._weight_bits,
quant_axis=0 if self._weight_op_pairs[name] not in
utils._channelwise_quant_axis1_ops else 1)
tensor_floor = np.floor(tensor_scaled)
tensor = tensor_scaled - tensor_floor
alpha = -np.log((ZETA - GAMMA) / (tensor - GAMMA) - 1)
return alpha
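For orientation: `_init_alpha` inverts AdaRound's rectified sigmoid, so the learned soft-rounding value h(alpha) starts exactly at the FP32 rounding residue. A minimal NumPy sketch, assuming the usual AdaRound constants ZETA = 1.1 and GAMMA = -0.1 (their definitions sit outside the visible hunks):

```python
import numpy as np

ZETA, GAMMA = 1.1, -0.1  # assumed AdaRound constants

def h(alpha):
    # Rectified sigmoid: the soft-rounding value, clipped to [0, 1].
    return np.clip(1.0 / (1.0 + np.exp(-alpha)) * (ZETA - GAMMA) + GAMMA, 0, 1)

def init_alpha(residue):
    # Inverse of h, as computed by _init_alpha above.
    return -np.log((ZETA - GAMMA) / (residue - GAMMA) - 1)

residue = np.array([0.1, 0.5, 0.9])                   # tensor_scaled - tensor_floor
print(np.allclose(h(init_alpha(residue)), residue))   # True: training starts at the residue
```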
def _soft_rounding(self, weight, scale, weight_bits=8):
def _soft_rounding(self, weight, scale):
"""
Define network of soft rounding.
Args:
weight: The quantized weight with dtype=float32
"""
bnt = (1 << (weight_bits - 1)) - 1
bnt = (1 << (self._weight_bits - 1)) - 1
def _quant(x, scale):
s = scale / bnt
quant_x = x / s
return quant_x
def _dequant(x, scale):
s = (scale + 1e-8) / bnt
s = scale / bnt
dequant_x = s * x
return dequant_x
quantized_weight = paddle.static.data(
weight_copy = paddle.static.data(
shape=weight.shape,
dtype=weight.dtype,
name=weight.name + '_quant', )
name=weight.name + '_copy', )
v = paddle.static.create_parameter(
shape=weight.shape,
@@ -472,10 +484,15 @@ class ReconstructionQuanter(object):
shape=weight.shape,
name=weight.name + '.scale',
default_initializer=fluid.initializer.NumpyArrayInitializer(
scale, ), )
scale, ))
else:
scale_var = scale
w = _dequant(quantized_weight + h_v, scale_var)
quantized_weight = _quant(weight_copy, scale_var)
floor_weight = (paddle.floor(quantized_weight) - quantized_weight
).detach() + quantized_weight
clip_weight = paddle.clip(floor_weight + h_v, -bnt, bnt)
w = _dequant(clip_weight, scale_var)
return w
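The `(paddle.floor(q) - q).detach() + q` pattern added above is a straight-through estimator: the forward value is `floor(q)`, but the non-differentiable floor is bypassed in backward so gradients still reach the weight and `alpha`. A small dygraph sketch of just that trick:

```python
import paddle

q = paddle.to_tensor([1.3, 2.7], stop_gradient=False)
# Forward: floor(q). Backward: gradient 1 w.r.t. q (the floor is bypassed).
out = (paddle.floor(q) - q).detach() + q
out.sum().backward()
print(out.numpy())     # [1. 2.]
print(q.grad.numpy())  # [1. 1.]
```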
def _insert_soft_rounding(self):
@@ -491,15 +508,15 @@ class ReconstructionQuanter(object):
scale = scale.repeat(shape[0], axis=1).T
else:
scale = scale.repeat(shape[1] * shape[2] * shape[3], axis=1)
scale = scale.reshape(shape)
scale = scale.reshape(shape)
self._insert_func(var=weight, scale=scale, func="_soft_rounding")
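`_insert_soft_rounding` tiles the per-channel scale vector to the full weight shape so the soft-rounding subgraph can use plain elementwise ops. A sketch of the conv branch above, assuming a [out_c, in_c, kh, kw] weight with one scale per output channel:

```python
import numpy as np

shape = (2, 3, 1, 1)              # [out_c, in_c, kh, kw]
scale = np.array([[0.5], [2.0]])  # shape [out_c, 1], one scale per output channel
scale = scale.repeat(shape[1] * shape[2] * shape[3], axis=1)  # -> shape [2, 3]
scale = scale.reshape(shape)      # -> [2, 3, 1, 1], broadcast over in_c * kh * kw
print(scale[:, 0, 0, 0])          # [0.5 2. ]
```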
def _drop_quant_dequant(self, inputs, scale, weight_bits=8):
def _drop_quant_dequant(self, inputs, scale):
x = paddle.static.data(
shape=inputs.shape,
dtype=inputs.dtype,
name=inputs.name + '.tmp', )
bnt = (1 << (weight_bits - 1)) - 1
bnt = (1 << (self._weight_bits - 1)) - 1
scale = scale / bnt
dequantized_tensor = paddle.round(x / scale) * scale
quant_noise = x - dequantized_tensor
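The rest of `_drop_quant_dequant` falls outside the visible hunk. For context, a hedged NumPy sketch of the QDrop idea it implements: quantization noise is applied to only a random subset of elements, so part of the tensor stays in FP32 during reconstruction (the exact dropout semantics in the elided code may differ):

```python
import numpy as np

def qdrop(x, scale, weight_bits=8, drop_prob=0.5):
    bnt = (1 << (weight_bits - 1)) - 1
    s = scale / bnt
    dequantized = np.round(x / s) * s
    quant_noise = x - dequantized
    # Keep each element's quantization noise with probability 1 - drop_prob,
    # i.e. randomly skip quantization per element, as in QDrop.
    keep = np.random.rand(*x.shape) >= drop_prob
    return x - np.where(keep, quant_noise, 0.0)
```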
@@ -509,13 +526,14 @@ class ReconstructionQuanter(object):
def _insert_drop_quant_dequant(self):
for op in self._graph.ops():
if op.type() in ['conv2d', 'depthwise_conv2d', 'mul']:
if op.type(
) in ['conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2']:
if op.type() in ['conv2d', 'depthwise_conv2d']:
if op.inputs("Filter")[0].name().startswith("teacher"):
break
else:
input = op.inputs("Input")[0]
if op.type() in ['mul']:
if op.type() in ['mul', 'matmul', 'matmul_v2']:
if op.inputs("Y")[0].name().startswith("teacher"):
break
else:
@@ -540,7 +558,7 @@ class ReconstructionQuanter(object):
self._exe.run(startup_program)
# create var in program
for new_var in new_program.list_vars():
if new_var.name == var._var.name + '_quant' or new_var.name == var._var.name + '.tmp':
if new_var.name == var._var.name + '_copy' or new_var.name == var._var.name + '.tmp':
continue
elif new_var.name == var._var.name + '.alpha':
program.global_block().create_parameter(
@@ -548,7 +566,8 @@ class ReconstructionQuanter(object):
shape=new_var.shape,
dtype=new_var.dtype,
type=new_var.type,
stop_gradient=new_var.stop_gradient, )
stop_gradient=False,
trainable=True)
elif new_var.name == var._var.name + '.scale':
program.global_block().create_parameter(
name=new_var.name,
@@ -556,7 +575,7 @@ class ReconstructionQuanter(object):
dtype=new_var.dtype,
type=new_var.type,
stop_gradient=True,
trainable=self._scale_trainable, )
trainable=False)
else:
if func == "_soft_rounding":
program.global_block().create_var(
@@ -568,7 +587,7 @@ class ReconstructionQuanter(object):
stop_gradient=new_var.stop_gradient, )
else:
program.global_block().create_var(
name=new_var.name,
name=new_var.name + '.qdrop',
shape=new_var.shape,
dtype=new_var.dtype,
type=new_var.type,
@@ -579,11 +598,12 @@ class ReconstructionQuanter(object):
block = var._var.block
# prepend new_program's op in program
for _op in ops:
if _op.type() not in ['conv2d', 'depthwise_conv2d', 'mul']:
if _op.type() not in [
'conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2'
]:
continue
idx = block.ops.index(_op._op)
for op in op_list:
# _attrs = op.all_attrs()
_type = op.type
_attrs = {
'use_mkldnn': False,
@@ -603,7 +623,7 @@ class ReconstructionQuanter(object):
'scale': op.attr('scale'),
'bias_after_scale': op.attr('bias_after_scale'),
}
elif _type == 'elementwise_mul':
elif _type in ['elementwise_mul', 'elementwise_div']:
_attrs = {
'use_mkldnn': False,
'with_quant_attr': False,
@@ -615,15 +635,17 @@ class ReconstructionQuanter(object):
if func == "_soft_rounding":
_outputs = {'Out': op.output('Out')[0] + '.rounding'}
if _type == "elementwise_add":
if _type in [
"elementwise_add", "elementwise_sub",
"elementwise_mul"
]:
_inputs = {
'X': var.
_var, # replace tmp var conv.weight_quant with var conv.weight
'X': op.input('X')[0] + '.rounding',
'Y': op.input('Y')[0] + '.rounding',
}
elif _type == "elementwise_mul":
elif _type == "elementwise_div":
_inputs = {
'X': op.input('X')[0] + '.rounding',
'X': var._var,
'Y': op.input('Y')[0] + '.rounding',
}
elif (_type == 'scale' and
@@ -638,23 +660,22 @@ class ReconstructionQuanter(object):
elif func == "_drop_quant_dequant":
if _type == 'dropout':
_outputs = {
'Out': op.output('Out')[0],
'Mask': op.output('Mask')[0],
'Out': op.output('Out')[0] + '.qdrop',
'Mask': op.output('Mask')[0] + '.qdrop',
}
else:
_outputs = {'Out': op.output('Out')[0]}
_outputs = {'Out': op.output('Out')[0] + '.qdrop'}
if _type == 'elementwise_add' or _type == 'elementwise_sub':
_inputs = {
'X': var.
_var, # replace tmp var conv.weight_quant with var conv.weight
'Y': op.input('Y'),
'X': var._var,
'Y': op.input('Y')[0] + '.qdrop',
}
elif _type == 'scale' and op.input('X')[
0] == inputs.name + '.tmp':
_inputs = {'X': var._var}
else:
_inputs = {'X': op.input('X')[0]}
_inputs = {'X': op.input('X')[0] + '.qdrop'}
block._insert_op(
idx,
@@ -663,18 +684,20 @@ class ReconstructionQuanter(object):
inputs=_inputs,
outputs=_outputs, )
for op in ops:
if op.type() not in ['conv2d', 'depthwise_conv2d', 'mul']:
if op.type() not in [
'conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2'
]:
continue
if op.type() in ['conv2d', 'depthwise_conv2d'] and op.inputs(
'Filter')[0].name().startswith('teacher'):
continue
if op.type() in ['mul'] and op.inputs('Y')[0].name().startswith(
'teacher'):
if op.type() in ['mul', 'matmul', 'matmul_v2'] and op.inputs('Y')[
0].name().startswith('teacher'):
continue
if func == '_soft_rounding':
op._op._rename_input(inputs.name, out.name + '.rounding')
else:
op._op._rename_input(inputs.name, out.name)
op._op._rename_input(inputs.name, out.name + '.qdrop')
def _isolate_regions(self):
starts = [region[0] for region in self._regions]
@@ -713,20 +736,41 @@ class ReconstructionQuanter(object):
op_._rename_input(var_.name, duplicated_var.name)
return vars
def _update_scale(self):
for _name in self._weight_var_names:
scale_name = _name + '.scale'
scale_tensor = utils.load_variable_data(self._scope, scale_name)
scale_list = []
if self._weight_op_pairs[
_name] in utils._channelwise_quant_axis1_ops:
scale_list = list(scale_tensor[0])
else:
for i in range(scale_tensor.shape[0]):
scale_list.append(scale_tensor[i][0][0][0])
self._scale_dict[scale_name] = scale_list
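After training, the scale that was broadcast to the full weight shape is collapsed back to one scalar per channel; the indexing depends on whether the op quantizes along axis 0 (conv) or axis 1 (mul/matmul). A sketch of the axis-0 case, assuming a [out_c, in_c, kh, kw] scale tensor:

```python
import numpy as np

# Per-channel scales broadcast to the weight shape, as _insert_soft_rounding produces.
scale_tensor = np.broadcast_to(
    np.array([0.5, 1.0, 2.0]).reshape(3, 1, 1, 1), (3, 4, 3, 3))
# quant_axis == 0: pick one scalar per output channel.
scale_list = [scale_tensor[i][0][0][0] for i in range(scale_tensor.shape[0])]
print(scale_list)  # [0.5, 1.0, 2.0]
```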
def _update_weights_to_int(self):
for weight_var_name in self._weight_var_names:
alpha_tensor = utils.load_variable_data(
self._scope,
weight_var_name + '.alpha', )
h_alpha_tensor = self._compute_soft_rounding_np(alpha_tensor)
weight_quant_tensor = utils.load_variable_data(
weight_tensor = utils.load_variable_data(
self._scope,
weight_var_name, )
weight_quant_tensor = utils.quant_tensor(
x=weight_tensor,
scale=self._scale_dict[weight_var_name],
weight_bits=self._weight_bits,
quant_axis=0 if self._weight_op_pairs[weight_var_name] not in
utils._channelwise_quant_axis1_ops else 1)
utils.set_variable_data(
self._scope,
self._place,
weight_var_name,
np.round(weight_quant_tensor + h_alpha_tensor, ), )
np.floor(weight_quant_tensor) + h_alpha_tensor, )
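The final conversion writes `floor(quant(w)) + h(alpha)` back into the program, so the converged rounding mask chooses floor or ceil per element instead of round-to-nearest. A NumPy sketch, assuming `utils.quant_tensor` maps `x` to `x / scale * bnt`:

```python
import numpy as np

weight_bits = 8
bnt = (1 << (weight_bits - 1)) - 1   # 127
w = np.array([0.30, -0.72])
scale = 1.0
h_alpha = np.array([1.0, 0.0])       # converged soft-rounding mask in {0, 1}
w_int = np.floor(w / scale * bnt) + h_alpha
print(w_int)  # [ 39. -92.]: ceil for the first element, floor for the second
```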
def _bias_correction_w(self):
for weight_var_name in self._weight_var_names:
@@ -741,8 +785,9 @@ class ReconstructionQuanter(object):
weight_var_tensor,
weight_quant_tensor,
scale,
quant_axis=0,
weight_bits=8, )
quant_axis=0 if self._weight_op_pairs[weight_var_name] not in
utils._channelwise_quant_axis1_ops else 1,
weight_bits=self._weight_bits, )
utils.set_variable_data(
self._scope,
self._place,
@@ -773,7 +818,6 @@ class ReconstructionQuanterLoss(object):
weight=0.1):
"""
The loss function of Rounding Optimizer.
Args:
program(Program): The student program.
weight_region_names(list, optional): The weight names inside a region.
@@ -1040,11 +1084,8 @@ def quant_recon_static(executor,
hist_percent=0.9999,
bias_correction=False,
quantizable_op_type=[
"conv2d",
"depthwise_conv2d",
"mul",
"matmul",
"matmul_v2",
"conv2d", "depthwise_conv2d", "mul", "matmul",
"matmul_v2"
],
is_full_quantize=False,
weight_bits=8,
@@ -1059,7 +1100,6 @@ def quant_recon_static(executor,
regions=None,
region_weights_names=None,
epochs=20,
scale_trainable=False,
drop_prob=0.5,
lr=0.1,
limit=6):
@@ -1068,7 +1108,6 @@ def quant_recon_static(executor,
quantize the fp32 model. It uses calibrate data to calculate the
scale factor of quantized variables, and inserts fake quantization
and dequantization operators to obtain the quantized model.
Args:
executor(paddle.static.Executor): The executor to load, run and save the
quantized model.
......
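The docstring and the remaining arguments are truncated here. A hypothetical invocation, using only the parameters visible in this diff (the import path and the elided model/data arguments are assumptions):

```python
import paddle
from paddleslim.quant import quant_recon_static  # import path assumed

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

quant_recon_static(
    executor=exe,
    # ... model and calibration-data arguments elided in this diff ...
    quantizable_op_type=[
        "conv2d", "depthwise_conv2d", "mul", "matmul", "matmul_v2"
    ],
    weight_bits=8,
    regions=None,               # None: use the new automatic region division
    region_weights_names=None,
    epochs=20,
    drop_prob=0.5,
    lr=0.1,
    limit=6)
```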