Unverified commit 460b5c63, authored by G gushiqiao, committed by GitHub

Fixed naming conflicts and fc layer quantization (#1494)

Parent 10b87911
@@ -29,7 +29,7 @@ from paddle.fluid.contrib.slim.quantization import utils
from ..dist import merge
from ..core.graph_wrapper import GraphWrapper
from ..common import get_logger
from ..common import get_logger, recover_program
__all__ = ['ReconstructionQuantization', ]
@@ -75,7 +75,6 @@ class ReconstructionQuantization(PostTrainingQuantization):
Load the FP32 model, and use the calibration data to run the forward pass.
Based on the sample data, we can get the quantization information, and obtain
the final quantized model.
Args:
None
Returns:
@@ -156,6 +155,7 @@ class ReconstructionQuantization(PostTrainingQuantization):
scope=self._scope,
place=self._place,
quantized_op_pairs=self._quantized_op_pairs,
weight_op_pairs=self._weight_op_pairs,
weight_quantize_type=self._weight_quantize_type,
activation_bits=self._activation_bits,
weight_bits=self._weight_bits,
@@ -167,8 +167,13 @@ class ReconstructionQuantization(PostTrainingQuantization):
num_iterations=self._batch_nums,
lr=self._config['lr'],
bias_correction=self._bias_correction,
epochs=self._config['epochs'], )
self._program = reconstruction_quanter._run()
epochs=self._config['epochs'])
self._program, self._scale_dict = reconstruction_quanter._run()
if self._algo in ["KL", "hist"]:
self._quantized_var_threshold = self._scale_dict
else:
self._quantized_threshold = self._scale_dict
def _postprocessing(self):
if self._algo == 'min_max':
@@ -211,6 +216,7 @@ class ReconstructionQuanter(object):
scope,
place,
quantized_op_pairs,
weight_op_pairs,
weight_quantize_type,
activation_bits,
weight_bits,
@@ -227,7 +233,6 @@ class ReconstructionQuanter(object):
'''
Reconstruction Quanter, used to optimize the rounding policy
by reconstructing the intermediate output.
Args:
data_loader(Python Generator, Paddle.io.DataLoader, optional): The
Generator or Dataloader provides calibration data, and it could
@@ -284,6 +289,7 @@ class ReconstructionQuanter(object):
self._scope = scope
self._place = place
self._quantized_op_pairs = quantized_op_pairs
self._weight_op_pairs = weight_op_pairs
self._weight_var_names = list(self._quantized_op_pairs.keys())
self._weight_quantize_type = weight_quantize_type
self._scale_dict = scale_dict
@@ -323,6 +329,12 @@ class ReconstructionQuanter(object):
return regions, region_weights_names
def _preprocess(self):
for name in self._weight_var_names:
for i, s in enumerate(self._scale_dict[name]):
if s == 0.0:
self._scale_dict[name][i] = 1e-8
data_name_map = {}
for name in self._feed_list:
data_name_map[name] = name
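
The guard added above clamps zero per-channel scales before any quantization arithmetic runs: `_quant` divides by the scale, so a channel whose weights calibrated to 0.0 would otherwise produce inf/NaN. A minimal standalone sketch of the same idea (names here are illustrative, not from the patch):

    import numpy as np

    def clamp_zero_scales(scale_dict, eps=1e-8):
        # A channel whose weights are all zero calibrates to scale 0.0;
        # dividing by it later would yield inf/NaN, so clamp it to eps.
        for name, scales in scale_dict.items():
            scale_dict[name] = [s if s != 0.0 else eps for s in scales]
        return scale_dict

    scales = {"fc_0.w_0": [0.12, 0.0, 0.07]}
    print(clamp_zero_scales(scales))  # {'fc_0.w_0': [0.12, 1e-08, 0.07]}
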
@@ -335,17 +347,7 @@ class ReconstructionQuanter(object):
teacher_scope=None,
name_prefix="teacher_",
merge_feed=True, )
for name in self._weight_var_names:
weight_np = utils.load_variable_data(self._scope, name)
scale = self._scale_dict[name]
weight_np_floor = np.floor(
utils.quant_tensor(
x=weight_np, scale=scale, weight_bits=self._weight_bits))
utils.set_variable_data(
self._scope,
self._place,
name,
weight_np_floor, )
self._graph = GraphWrapper(self._student_program)
if self._simulate_activation_quant:
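
The deleted block above used to floor-quantize every weight tensor in place, before the student program was built. With rounding now learned in-graph (see `_soft_rounding` below), the FP32 weights must stay untouched so the inserted subgraph can quantize them itself. Roughly, what the removed preprocessing did (numpy sketch, assuming a scalar scale for brevity):

    import numpy as np

    def eager_floor_quant(weight, scale, weight_bits=8):
        # Overwrite the variable with floor(weight / (scale / (2**(bits-1) - 1))),
        # i.e. bake a hard floor-quantization into the stored weights.
        bnt = (1 << (weight_bits - 1)) - 1
        return np.floor(weight / (scale / bnt))
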
@@ -362,7 +364,8 @@ class ReconstructionQuanter(object):
tmp_program = self._student_program.clone()
quant_op_out_name = region_[1]
with paddle.static.program_guard(tmp_program, startup_program):
loss_function = ReconstructionQuanterLoss(tmp_program, names)
loss_function = ReconstructionQuanterLoss(
program=tmp_program, weight_region_names=names)
student_var = tmp_program.global_block().var(quant_op_out_name)
teacher_var = tmp_program.global_block().var("teacher_" +
quant_op_out_name)
@@ -382,11 +385,11 @@ class ReconstructionQuanter(object):
}
optimizer = paddle.optimizer.Adam(learning_rate=self._lr)
optimizer.minimize(total_loss)
self._exe.run(startup_program)
start_time = time.time()
prev_start_time = start_time
loader = self._data_loader()
for epoch in range(self._epochs):
for i, data in (
enumerate(loader) if
@@ -412,14 +415,21 @@ class ReconstructionQuanter(object):
sys.stdout.flush()
if i == self._num_iterations:
break
self._update_scale()
self._update_weights_to_int()
if self._bias_correction:
self._bias_correction_w()
return self._program
return self._program, self._scale_dict
def _init_alpha(self, name, scale):
_tensor = utils.load_variable_data(self._scope, "teacher_" + name)
tensor_scaled = utils.quant_tensor(_tensor, scale)
tensor_scaled = utils.quant_tensor(
x=_tensor,
scale=scale,
weight_bits=self._weight_bits,
quant_axis=0 if self._weight_op_pairs[name] not in
utils._channelwise_quant_axis1_ops else 1)
tensor_floor = np.floor(tensor_scaled)
tensor = tensor_scaled - tensor_floor
alpha = -np.log((ZETA - GAMMA) / (tensor - GAMMA) - 1)
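
`_init_alpha` inverts the rectified sigmoid used by soft rounding, h(α) = clip(σ(α)·(ZETA − GAMMA) + GAMMA, 0, 1), so that at initialization h(α) reproduces the fractional residual `tensor_scaled - tensor_floor`. A numpy check, assuming the usual AdaRound stretch constants GAMMA = −0.1 and ZETA = 1.1 (they are defined elsewhere in this file):

    import numpy as np

    GAMMA, ZETA = -0.1, 1.1  # assumed AdaRound stretch constants

    def h(alpha):
        # Rectified sigmoid used by the soft-rounding mask.
        return np.clip(1 / (1 + np.exp(-alpha)) * (ZETA - GAMMA) + GAMMA, 0, 1)

    r = 0.37                                         # fractional part of a scaled weight
    alpha = -np.log((ZETA - GAMMA) / (r - GAMMA) - 1)
    assert np.isclose(h(alpha), r)                   # soft rounding starts at plain rounding
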
@@ -433,15 +443,20 @@ class ReconstructionQuanter(object):
"""
bnt = (1 << (self._weight_bits - 1)) - 1
def _quant(x, scale):
s = scale / bnt
quant_x = x / s
return quant_x
def _dequant(x, scale):
s = (scale + 1e-8) / bnt
s = scale / bnt
dequant_x = s * x
return dequant_x
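
For 8-bit weights `bnt` is (1 << 7) − 1 = 127, so `_quant` maps values onto the integer grid [−127, 127] and `_dequant` maps them back; note the old `+ 1e-8` guard in `_dequant` is gone because zero scales are now clamped in `_preprocess`. A quick numeric check:

    bnt = (1 << (8 - 1)) - 1                     # 127 for weight_bits=8
    scale = 0.5                                  # max-abs calibration value
    x = 0.3
    quant_x = x / (scale / bnt)                  # 76.2, later floored onto the int grid
    dequant_x = (scale / bnt) * round(quant_x)   # ~0.2992: quantization error ~8e-4
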
quantized_weight = paddle.static.data(
weight_copy = paddle.static.data(
shape=weight.shape,
dtype=weight.dtype,
name=weight.name + '_quant', )
name=weight.name + '_copy', )
v = paddle.static.create_parameter(
shape=weight.shape,
@@ -461,10 +476,15 @@ class ReconstructionQuanter(object):
shape=weight.shape,
name=weight.name + '.scale',
default_initializer=fluid.initializer.NumpyArrayInitializer(
scale, ), )
scale, ))
else:
scale_var = scale
w = _dequant(quantized_weight + h_v, scale_var)
quantized_weight = _quant(weight_copy, scale_var)
floor_weight = (paddle.floor(quantized_weight) - quantized_weight
).detach() + quantized_weight
clip_weight = paddle.clip(floor_weight + h_v, -bnt, bnt)
w = _dequant(clip_weight, scale_var)
return w
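
The `(paddle.floor(q) - q).detach() + q` idiom is a straight-through estimator: forward it evaluates to `floor(q)`, but the detached difference contributes no gradient, so backward sees the identity and gradients keep flowing into `h_v` and through `weight_copy`. The same trick in a self-contained dygraph form:

    import paddle

    x = paddle.to_tensor([1.7, -0.3], stop_gradient=False)
    floor_x = (paddle.floor(x) - x).detach() + x  # forward: floor(x); backward: identity
    floor_x.sum().backward()
    print(floor_x.numpy())  # [ 1. -1.]
    print(x.grad.numpy())   # [1. 1.] -- gradient passes straight through the floor
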
def _insert_soft_rounding(self):
@@ -477,6 +497,7 @@ class ReconstructionQuanter(object):
scale = np.array(scale)
scale = scale.reshape(scale.shape[0], 1)
if len(shape) == 2:
scale = scale.repeat(shape[0], axis=1).T
else:
scale = scale.repeat(shape[1] * shape[2] * shape[3], axis=1)
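
Per-channel scales arrive as a column vector and are tiled out to the weight's shape so the inserted subgraph can apply them elementwise. The new `len(shape) == 2` branch is the FC/matmul case (output channels on axis 1, hence the transpose); the other branch covers 4-D conv filters (output channels on axis 0). In numpy terms, with hypothetical shapes:

    import numpy as np

    scale = np.array([0.1, 0.2, 0.4]).reshape(3, 1)  # one scale per output channel

    fc_scale = scale.repeat(128, axis=1).T           # (128, 3) for a [in=128, out=3] weight
    conv_scale = scale.repeat(16 * 3 * 3, axis=1)    # (3, 144) for a [3, 16, 3, 3] filter;
                                                     # reshaped back to 4-D outside the shown hunk
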
@@ -498,13 +519,14 @@ class ReconstructionQuanter(object):
def _insert_drop_quant_dequant(self):
for op in self._graph.ops():
if op.type() in ['conv2d', 'depthwise_conv2d', 'mul']:
if op.type(
) in ['conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2']:
if op.type() in ['conv2d', 'depthwise_conv2d']:
if op.inputs("Filter")[0].name().startswith("teacher"):
break
else:
input = op.inputs("Input")[0]
if op.type() in ['mul']:
if op.type() in ['mul', 'matmul', 'matmul_v2']:
if op.inputs("Y")[0].name().startswith("teacher"):
break
else:
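
This hunk is the FC half of the commit title: wherever the pass previously matched only `mul` (the legacy FC op), it now also matches `matmul` and `matmul_v2`, while still skipping ops that belong to the frozen teacher copy. A condensed sketch of the filter, assuming the same GraphWrapper accessors used above:

    QUANT_OPS = ['conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2']

    def should_process(op):
        # Skip non-quantizable ops and the teacher subgraph (its weight
        # inputs carry the "teacher" prefix added by merge()).
        if op.type() not in QUANT_OPS:
            return False
        weight_slot = 'Filter' if op.type() in ('conv2d', 'depthwise_conv2d') else 'Y'
        return not op.inputs(weight_slot)[0].name().startswith('teacher')
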
@@ -529,7 +551,7 @@ class ReconstructionQuanter(object):
self._exe.run(startup_program)
# create var in program
for new_var in new_program.list_vars():
if new_var.name == var._var.name + '_quant' or new_var.name == var._var.name + '.tmp':
if new_var.name == var._var.name + '_copy' or new_var.name == var._var.name + '.tmp':
continue
elif new_var.name == var._var.name + '.alpha':
program.global_block().create_parameter(
@@ -537,14 +559,16 @@ class ReconstructionQuanter(object):
shape=new_var.shape,
dtype=new_var.dtype,
type=new_var.type,
stop_gradient=new_var.stop_gradient, )
stop_gradient=False,
trainable=True)
elif new_var.name == var._var.name + '.scale':
program.global_block().create_parameter(
name=new_var.name,
shape=new_var.shape,
dtype=new_var.dtype,
type=new_var.type,
stop_gradient=True, )
stop_gradient=True,
trainable=False)
else:
if func == "_soft_rounding":
program.global_block().create_var(
@@ -556,7 +580,7 @@ class ReconstructionQuanter(object):
stop_gradient=new_var.stop_gradient, )
else:
program.global_block().create_var(
name=new_var.name,
name=new_var.name + '.qdrop',
shape=new_var.shape,
dtype=new_var.dtype,
type=new_var.type,
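
Two details matter when the helper variables are re-created in the main program: `.alpha` must be a genuinely trainable parameter (`stop_gradient=False, trainable=True`), `.scale` is materialized but frozen, and qdrop intermediates take the `.qdrop` suffix to stay out of the original namespace. A static-graph sketch of the alpha/scale split (shapes hypothetical):

    import paddle

    paddle.enable_static()
    prog = paddle.static.Program()
    with paddle.static.program_guard(prog):
        # Learnable rounding offsets: gradients must reach them.
        alpha = paddle.static.create_parameter(
            shape=[32, 16, 3, 3], dtype='float32',
            attr=paddle.ParamAttr(name='w.alpha', trainable=True))
        # Calibrated scales: constants during reconstruction.
        scale = paddle.static.create_parameter(
            shape=[32, 16, 3, 3], dtype='float32',
            attr=paddle.ParamAttr(name='w.scale', trainable=False))
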
@@ -567,11 +591,12 @@ class ReconstructionQuanter(object):
block = var._var.block
# prepend new_program's op in program
for _op in ops:
if _op.type() not in ['conv2d', 'depthwise_conv2d', 'mul']:
if _op.type() not in [
'conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2'
]:
continue
idx = block.ops.index(_op._op)
for op in op_list:
# _attrs = op.all_attrs()
_type = op.type
_attrs = {
'use_mkldnn': False,
@@ -591,7 +616,7 @@ class ReconstructionQuanter(object):
'scale': op.attr('scale'),
'bias_after_scale': op.attr('bias_after_scale'),
}
elif _type == 'elementwise_mul':
elif _type in ['elementwise_mul', 'elementwise_div']:
_attrs = {
'use_mkldnn': False,
'with_quant_attr': False,
@@ -603,15 +628,17 @@ class ReconstructionQuanter(object):
if func == "_soft_rounding":
_outputs = {'Out': op.output('Out')[0] + '.rounding'}
if _type == "elementwise_add":
if _type in [
"elementwise_add", "elementwise_sub",
"elementwise_mul"
]:
_inputs = {
'X': var.
_var, # replace tmp var conv.weight_quant with var conv.weight
'X': op.input('X')[0] + '.rounding',
'Y': op.input('Y')[0] + '.rounding',
}
elif _type == "elementwise_mul":
elif _type == "elementwise_div":
_inputs = {
'X': op.input('X')[0] + '.rounding',
'X': var._var,
'Y': op.input('Y')[0] + '.rounding',
}
elif (_type == 'scale' and
@@ -623,23 +650,22 @@ class ReconstructionQuanter(object):
elif func == "_drop_quant_dequant":
if _type == 'dropout':
_outputs = {
'Out': op.output('Out')[0],
'Mask': op.output('Mask')[0],
'Out': op.output('Out')[0] + '.qdrop',
'Mask': op.output('Mask')[0] + '.qdrop',
}
else:
_outputs = {'Out': op.output('Out')[0]}
_outputs = {'Out': op.output('Out')[0] + '.qdrop'}
if _type == 'elementwise_add' or _type == 'elementwise_sub':
_inputs = {
'X': var.
_var, # replace tmp var conv.weight_quant with var conv.weight
'Y': op.input('Y'),
'X': var._var,
'Y': op.input('Y')[0] + '.qdrop',
}
elif _type == 'scale' and op.input('X')[
0] == inputs.name + '.tmp':
_inputs = {'X': var._var}
else:
_inputs = {'X': op.input('X')[0]}
_inputs = {'X': op.input('X')[0] + '.qdrop'}
block._insert_op(
idx,
@@ -648,18 +674,20 @@ class ReconstructionQuanter(object):
inputs=_inputs,
outputs=_outputs, )
for op in ops:
if op.type() not in ['conv2d', 'depthwise_conv2d', 'mul']:
if op.type() not in [
'conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2'
]:
continue
if op.type() in ['conv2d', 'depthwise_conv2d'] and op.inputs(
'Filter')[0].name().startswith('teacher'):
continue
if op.type() in ['mul'] and op.inputs('Y')[0].name().startswith(
'teacher'):
if op.type() in ['mul', 'matmul', 'matmul_v2'] and op.inputs('Y')[
0].name().startswith('teacher'):
continue
if func == '_soft_rounding':
op._op._rename_input(inputs.name, out.name + '.rounding')
else:
op._op._rename_input(inputs.name, out.name)
op._op._rename_input(inputs.name, out.name + '.qdrop')
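
And this is the naming-conflict half: previously the qdrop pass emitted its outputs under the original variable names, so when it and the soft-rounding pass touched the same tensor, the two inserted subgraphs clobbered each other. Each pass now owns a distinct suffix, and `_rename_input` points every consumer at the right copy. Schematically:

    def pass_output_name(var_name, pass_name):
        # Hypothetical helper showing the scheme: one suffix per inserted
        # pass, so the two subgraphs never share a variable name.
        suffix = {'_soft_rounding': '.rounding',
                  '_drop_quant_dequant': '.qdrop'}[pass_name]
        return var_name + suffix

    assert pass_output_name('conv1.tmp_0', '_drop_quant_dequant') == 'conv1.tmp_0.qdrop'
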
def _isolate_regions(self):
starts = [region[0] for region in self._regions]
@@ -698,20 +726,41 @@ class ReconstructionQuanter(object):
op_._rename_input(var_.name, duplicated_var.name)
return vars
def _update_scale(self):
for _name in self._weight_var_names:
scale_name = _name + '.scale'
scale_tensor = utils.load_variable_data(self._scope, scale_name)
scale_list = []
if self._weight_op_pairs[
_name] in utils._channelwise_quant_axis1_ops:
scale_list = list(scale_tensor[0])
else:
for i in range(scale_tensor.shape[0]):
scale_list.append(scale_tensor[i][0][0][0])
self._scale_dict[scale_name] = scale_list
def _update_weights_to_int(self):
for weight_var_name in self._weight_var_names:
alpha_tensor = utils.load_variable_data(
self._scope,
weight_var_name + '.alpha', )
h_alpha_tensor = self._compute_soft_rounding_np(alpha_tensor)
weight_quant_tensor = utils.load_variable_data(
weight_tensor = utils.load_variable_data(
self._scope,
weight_var_name, )
weight_quant_tensor = utils.quant_tensor(
x=weight_tensor,
scale=self._scale_dict[weight_var_name],
weight_bits=self._weight_bits,
quant_axis=0 if self._weight_op_pairs[weight_var_name] not in
utils._channelwise_quant_axis1_ops else 1)
utils.set_variable_data(
self._scope,
self._place,
weight_var_name,
np.round(weight_quant_tensor + h_alpha_tensor, ), )
np.floor(weight_quant_tensor) + h_alpha_tensor, )
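
`_update_weights_to_int` now quantizes along the correct channel axis per op type: conv-style filters are laid out [out, in, kh, kw] and use `quant_axis=0`, while `mul`/`matmul` weights are [in, out] and use `quant_axis=1`; `utils._channelwise_quant_axis1_ops` lists the latter group. A numpy sketch of per-channel quantization along a chosen axis:

    import numpy as np

    def quant_per_channel(w, scales, weight_bits=8, quant_axis=0):
        # Divide each channel slice by its own scale (bnt = 127 for 8 bits).
        bnt = (1 << (weight_bits - 1)) - 1
        shape = [1] * w.ndim
        shape[quant_axis] = -1            # broadcast scales along the channel axis
        return w / (np.asarray(scales).reshape(shape) / bnt)

    w_fc = np.random.randn(64, 10)        # matmul weight: channels on axis 1
    q = quant_per_channel(w_fc, np.abs(w_fc).max(axis=0), quant_axis=1)
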
def _bias_correction_w(self):
for weight_var_name in self._weight_var_names:
@@ -726,7 +775,8 @@ class ReconstructionQuanter(object):
weight_var_tensor,
weight_quant_tensor,
scale,
quant_axis=0,
quant_axis=0 if self._weight_op_pairs[weight_var_name] not in
utils._channelwise_quant_axis1_ops else 1,
weight_bits=self._weight_bits, )
utils.set_variable_data(
self._scope,
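
The same axis fix lands in bias correction. The idea behind `utils.bias_correction_w` is to compare per-channel statistics of the FP32 weights with their quantize-dequantize reconstruction and fold the difference back in; a rough numpy sketch of the mean-shift part (the actual utility may also rescale by the std ratio):

    import numpy as np

    def bias_correct(w_fp, w_qdq, channel_axes=(1, 2, 3)):
        # Shift the dequantized weights so each output channel's mean
        # matches the FP32 original, compensating quantization bias.
        shift = (w_fp.mean(axis=channel_axes, keepdims=True) -
                 w_qdq.mean(axis=channel_axes, keepdims=True))
        return w_qdq + shift
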
@@ -758,7 +808,6 @@ class ReconstructionQuanterLoss(object):
weight=0.1):
"""
The loss function of Rounding Optimizer.
Args:
program(Program): The student program.
weight_region_names(list, optional): The weight names inside a region.
@@ -829,9 +878,8 @@ def quant_recon_static(executor,
hist_percent=0.9999,
bias_correction=False,
quantizable_op_type=[
"conv2d",
"depthwise_conv2d",
"mul",
"conv2d", "depthwise_conv2d", "mul", "matmul",
"matmul_v2"
],
is_full_quantize=False,
weight_bits=8,
@@ -853,7 +901,6 @@ def quant_recon_static(executor,
quantize the fp32 model. It uses calibration data to calculate the
scale factor of quantized variables, and inserts fake quantization
and dequantization operators to obtain the quantized model.
Args:
executor(paddle.static.Executor): The executor to load, run and save the
quantized model.
......
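
For reference, a call to `quant_recon_static` might look like the sketch below. Only `executor`, `hist_percent`, `bias_correction`, `quantizable_op_type`, `is_full_quantize`, and `weight_bits` are visible in the signature above; the model and data arguments (`model_dir`, `quantize_model_path`, `sample_generator`) are assumed from the usual post-training-quantization convention and may differ, as may the import path:

    import numpy as np
    import paddle
    from paddleslim.quant import quant_recon_static  # import path assumed

    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())

    def calib_reader():
        # Hypothetical calibration data generator.
        for _ in range(8):
            yield [np.random.rand(1, 3, 224, 224).astype('float32')]

    quant_recon_static(
        executor=exe,
        model_dir='./fp32_model',             # assumed argument name
        quantize_model_path='./quant_model',  # assumed argument name
        sample_generator=calib_reader,        # assumed argument name
        quantizable_op_type=[
            'conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2'
        ],
        weight_bits=8,
        bias_correction=False)
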