未验证 提交 fb3fac19 编写于 作者: G gushiqiao 提交者: GitHub

Add automatic region division and Fix pre-tensor quant (#1517)

* Fixed naming conflicts and fc layer quantification

* Update fine_tune.py

* Update readme

* Add automatic region division and Fix pre-tensor quant

* Update unit test file

* Fix bugs.
Co-authored-by: Nzhouzj <41366441+zzjjay@users.noreply.github.com>
上级 4fb41e4c
......@@ -74,7 +74,7 @@ $$
说明:
- 如果想使用bias_correction,可以在PaddleSlim的[离线量化接口](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/static/quant/quantization_api.rst#quant_post_static)修改`bias_correction`参数为True即可,默认为False。
- 如果想使用Adaround方法,可以在PaddleSlim的[离线量化接口](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/static/quant/quantization_api.rst#quant_post_static)修改`round_type`参数为`adaround`即可,默认为`round`
- 如果想使用BRECQ方法,可以在PaddleSlim的[量化重构接口](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/static/quant/quantization_api.rst#quant_post_static)修改`recon_level`参数为`regionn-wise`即可,默认为`layer-wise`
- 如果想使用BRECQ方法,可以在PaddleSlim的[量化重构接口](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/static/quant/quantization_api.rst#quant_post_static)修改`recon_level`参数为`region-wise`即可,默认为`layer-wise`
- 如果想使用QDrop方法,可以在PaddleSlim的[量化重构接口](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/static/quant/quantization_api.rst#quant_post_static)修改`simulate_activation_quant`参数为`True`即可,默认为`False`
### 效果对比
......
arch: YOLOv6
model_dir: ./yolov6s.onnx
model_dir: ./yolov6s.onnx
dataset_dir: /dataset/coco/
model_filename: model.pdmodel
params_filename: model.pdiparams
......@@ -8,25 +8,3 @@ val_image_dir: val2017
train_anno_path: annotations/instances_train2017.json
val_anno_path: annotations/instances_val2017.json
skip_tensor_list: None
regions: [['x2paddle_image_arrays','relu_8.tmp_0'],
['relu_8.tmp_0','relu_15.tmp_0'],
['relu_15.tmp_0','relu_21.tmp_0'],
['concat_1.tmp_0','relu_26.tmp_0'],
['concat_2.tmp_0', 'relu_30.tmp_0'],
['relu_30.tmp_0', 'concat_4.tmp_0'],
['relu_30.tmp_0', 'relu_31.tmp_0'],
['concat_3.tmp_0', 'relu_35.tmp_0'],
['relu_35.tmp_0', 'relu_36.tmp_0'],
['concat_5.tmp_0', 'concat_10.tmp_0'],
['relu_35.tmp_0', 'concat_8.tmp_0']]
region_weights_names: [['conv2d_0.w_0','conv2d_1.w_0','conv2d_2.w_0','conv2d_3.w_0','conv2d_4.w_0','conv2d_5.w_0','conv2d_6.w_0','conv2d_7.w_0','conv2d_8.w_0'],
['conv2d_9.w_0','conv2d_10.w_0','conv2d_11.w_0','conv2d_12.w_0','conv2d_13.w_0','conv2d_14.w_0','conv2d_15.w_0'],
['conv2d_16.w_0','conv2d_17.w_0','conv2d_18.w_0','conv2d_19.w_0','conv2d_20.w_0','conv2d_21.w_0'],
['conv2d_22.w_0','conv2d_23.w_0','conv2d_24.w_0','conv2d_25.w_0','conv2d_26.w_0'],
['conv2d_27.w_0','conv2d_28.w_0','conv2d_29.w_0','conv2d_30.w_0'],
['conv2d_32.w_0','conv2d_34.w_0','conv2d_35.w_0','conv2d_37.w_0','conv2d_38.w_0','conv2d_39.w_0'],
['conv2d_31.w_0'],
['conv2d_33.w_0','conv2d_36.w_0','conv2d_40.w_0','conv2d_41.w_0'],
['conv2d_42.w_0'],
['conv2d_44.w_0','conv2d_47.w_0','conv2d_51.w_0','conv2d_52.w_0','conv2d_53.w_0','conv2d_54.w_0','conv2d_55.w_0','conv2d_56.w_0','conv2d_57.w_0','conv2d_58.w_0'],
['conv2d_43.w_0','conv2d_45.w_0','conv2d_46.w_0','conv2d_49.w_0','conv2d_48.w_0','conv2d_50.w_0'],]
\ No newline at end of file
......@@ -43,8 +43,6 @@ def argsparser():
help="which device used to compress.")
parser.add_argument(
'--algo', type=str, default='avg', help="post quant algo.")
parser.add_argument(
'--round_type', type=str, default='adaround', help="round type.")
parser.add_argument('--gpu', type=int, default=0, help='gpu index')
parser.add_argument(
......@@ -59,6 +57,10 @@ def argsparser():
help='simulate activation quant')
parser.add_argument(
'--epochs', type=int, default=20, help='steps to reconstruct')
parser.add_argument(
'--lr', type=float, default=0.1, help='learning rate of reconstruct')
parser.add_argument(
'--limit', type=int, default=5, help='size of each region')
return parser
......@@ -104,10 +106,11 @@ def main():
weight_quantize_type='channel_wise_abs_max',
recon_level=FLAGS.recon_level,
simulate_activation_quant=FLAGS.simulate_activation_quant,
regions=config['regions'],
region_weights_names=config['region_weights_names'],
regions=None,
region_weights_names=None,
epochs=FLAGS.epochs,
lr=0.1)
lr=FLAGS.lr,
limit=FLAGS.limit)
if __name__ == '__main__':
......
......@@ -167,7 +167,8 @@ class ReconstructionQuantization(PostTrainingQuantization):
num_iterations=self._batch_nums,
lr=self._config['lr'],
bias_correction=self._bias_correction,
epochs=self._config['epochs'])
epochs=self._config['epochs'],
limit=self._config['limit'])
self._program, self._scale_dict = reconstruction_quanter._run()
if self._algo in ["KL", "hist"]:
......@@ -229,7 +230,8 @@ class ReconstructionQuanter(object):
lr=0.1,
bias_correction=False,
epochs=20,
drop_prob=0.5):
drop_prob=0.5,
limit=5):
'''
Reconstruction Quanter, used to optimize the rounding policy
by reconstructing the intermediate output.
......@@ -266,19 +268,17 @@ class ReconstructionQuanter(object):
lr(float, optional): The learning rate of Reconstruction Quanter. Default is 0.1.
bias_correction(bool, optional): If set as True, use the bias correction
method of https://arxiv.org/abs/1810.05723. Default is False.
drop_prob: The dropout probability of activation quantization, and it is valid only if
drop_prob(float, optional): The dropout probability of activation quantization, and it is valid only if
simulate_activation_quant is True. Default is 0.5.
limit(int, optional): The size of each region. Default is 5.
Returns:
None
'''
assert recon_level in [
'layer-wise', 'region-wise'
], "recon_level must be one of the ['layer-wise', 'region-wise'],but received: {}".format(
], "recon_level must be one of the ['layer-wise', 'region-wise'], but received: {}".format(
recon_level)
if recon_level == 'region-wise':
assert regions is not None, "The regions cannot be None."
assert region_weights_names is not None, "The region_weights_names cannot be None."
self._simulate_activation_quant = simulate_activation_quant
self._program = fp32_program
self._data_loader = data_loader
......@@ -301,7 +301,15 @@ class ReconstructionQuanter(object):
self._regions = regions
self._region_weights_names = region_weights_names
self._bias_correction = bias_correction
if self._recon_level == 'layer-wise':
self._limit = limit
if recon_level == 'region-wise' and regions is None:
builder = RegionBuilder(program=self._program)
_logger.info('Begin Region division')
self._regions, self._region_weights_names = builder._create_regions(
limit=self._limit)
_logger.info('End Region division')
elif self._recon_level == 'layer-wise':
regions, region_weights_names = self._get_layers()
self._regions = regions
self._region_weights_names = region_weights_names
......@@ -330,10 +338,11 @@ class ReconstructionQuanter(object):
def _preprocess(self):
for name in self._weight_var_names:
for i, s in enumerate(self._scale_dict[name]):
if s == 0.0:
self._scale_dict[name][i] = 1e-8
if self._weight_quantize_type == 'channel_wise_abs_max':
for name in self._weight_var_names:
for i, s in enumerate(self._scale_dict[name]):
if s == 0.0:
self._scale_dict[name][i] = 1e-8
data_name_map = {}
for name in self._feed_list:
......@@ -363,9 +372,10 @@ class ReconstructionQuanter(object):
region_ = self._regions[k]
tmp_program.global_block().var(region_[0]).stop_gradient = True
quant_op_out_name = region_[1]
_logger.info(f"Region's input: {region_[0]} output: {region_[1]}")
names = self._region_weights_names[k]
_logger.info(f"Current weights: {names}")
_logger.info(f"Current quanted weights: {names}")
loss_function = ReconstructionQuanterLoss(
program=tmp_program, weight_region_names=names)
update_params = [
......@@ -413,8 +423,8 @@ class ReconstructionQuanter(object):
sys.stdout.flush()
if i + 1 == self._num_iterations:
break
self._update_scale()
if self._weight_quantize_type == 'channel_wise_abs_max':
self._update_scale()
self._update_weights_to_int()
if self._bias_correction:
self._bias_correction_w()
......@@ -495,7 +505,6 @@ class ReconstructionQuanter(object):
scale = np.array(scale)
scale = scale.reshape(scale.shape[0], 1)
if len(shape) == 2:
print(name)
scale = scale.repeat(shape[0], axis=1).T
else:
scale = scale.repeat(shape[1] * shape[2] * shape[3], axis=1)
......@@ -643,6 +652,9 @@ class ReconstructionQuanter(object):
op.input('X')[0].endswith('scale')
) or _type == 'sigmoid':
_inputs = {'X': op.input('X')[0]}
elif (_type == 'scale' and
op.input('X')[0].endswith('copy')):
_inputs = {'X': var._var}
else:
_inputs = {'X': op.input('X')[0] + '.rounding'}
elif func == "_drop_quant_dequant":
......@@ -857,6 +869,202 @@ class ReconstructionQuanterLoss(object):
return total_loss, rec_loss, round_loss
class PriorityQueue:
    """Lazily sorted min-queue of ``(depth, op)`` pairs that admits each op once.

    Entries are appended on ``push`` and only sorted when ``pop`` is called
    after a push. An ``op`` that has ever been pushed is ignored on later
    pushes, even after it has been popped. Entries with equal depth are
    returned in insertion order because the sort is stable and keyed on
    depth only.
    """

    def __init__(self):
        # Every accepted entry; entries before ``_idx`` are already consumed.
        self._data = []
        # Every op ever accepted, used to reject duplicates.
        self._ops = set()
        # Position of the next entry to return from ``pop``.
        self._idx = 0
        # True while the pending part of ``_data`` is known to be sorted.
        self._lazy_tag = True

    def pop(self):
        """Return the pending ``(depth, op)`` entry with the smallest depth.

        Raises:
            IndexError: If there is no pending entry.
        """
        if not self._lazy_tag:
            # Sort only the pending tail. Re-sorting consumed entries as well
            # would shift them relative to ``_idx`` whenever a newly pushed
            # entry has a smaller depth than a consumed one, so an old entry
            # would be returned twice and the new one skipped.
            self._data[self._idx:] = sorted(
                self._data[self._idx:], key=lambda x: x[0])
            self._lazy_tag = True
        if self._idx >= len(self._data):
            raise IndexError('Index out of range!')
        ele = self._data[self._idx]
        self._idx += 1
        return ele

    def push(self, depth, op):
        """Queue ``op`` with priority ``depth`` unless it was pushed before."""
        if op in self._ops:
            return
        self._data.append((depth, op))
        self._ops.add(op)
        self._lazy_tag = False

    def empty(self):
        """Return True when no pending entry is left."""
        return self._idx >= len(self._data)
class RegionBuilder(object):
    """Divide the op graph of a program into regions.

    A region is recorded as a two-element list ``[input_var_name,
    output_var_name]`` plus the list of persistable input names of the
    conv/mul/matmul ops it contains.

    Args:
        program: The program wrapped with ``GraphWrapper`` to obtain the op
            graph that is traversed.
    """

    # Op types whose persistable inputs are collected as region weights.
    _WEIGHT_OP_TYPES = ("conv2d", "depthwise_conv2d", "mul", "matmul",
                        "matmul_v2")
    # Op types that never start a region.
    _SKIPPED_OP_TYPES = ('fill_constant', 'feed', 'fetch')

    def __init__(self, program):
        self._program = program
        self._graph = GraphWrapper(self._program)
        self._op_idx_map = {op.idx(): op for op in self._graph.ops()}
        self._depth = {}
        self._init_depth()
        # Memo for ``_opset_matching``; only valid within a single search.
        self._cache = {}
        self._regions = []
        self._region_weights_names = []

    def _init_depth(self):
        """Fill ``self._depth`` with the depth of every op.

        An op without upstream ops has depth 0; any other op has the maximum
        depth of its upstream ops plus one. Assumes ``ops()`` yields upstream
        ops before their downstream ops; the assert below checks that.
        """
        for op in self._graph.ops():
            upstream = self._graph.pre_ops(op)
            if len(upstream) == 0:
                self._depth[op.idx()] = 0
                continue
            upstream_depths = []
            for up_op in upstream:
                assert up_op.idx() in self._depth
                upstream_depths.append(self._depth[up_op.idx()])
            self._depth[op.idx()] = max(upstream_depths) + 1

    def _build(self, op, limit):
        """Grow a region that starts at ``op`` and return it.

        The end point is moved downstream step by step. Growth stops when no
        further end point is found or when the depth difference between the
        candidate end point and the start op reaches ``limit``.

        Args:
            op: The op at which the region starts.
            limit(int): Upper bound (exclusive) on the depth difference
                between the end op and the start op.

        Returns:
            tuple: ``(sp, ep, rps)`` as produced by ``_create_region``.
        """

        def _find_multi_input_ep(op):
            # Visit downstream ops in order of increasing depth. The first
            # popped op that leaves the queue empty and has more than one
            # upstream op is returned.
            least_first_queue = PriorityQueue()
            for down_op in self._graph.next_ops(op):
                least_first_queue.push(self._depth[down_op.idx()],
                                       down_op.idx())

            while not least_first_queue.empty():
                iter_op_idx = least_first_queue.pop()[-1]
                iter_op = self._op_idx_map[iter_op_idx]
                if (least_first_queue.empty() and
                        len(self._graph.pre_ops(iter_op)) > 1):
                    return iter_op
                for down_op in self._graph.next_ops(iter_op):
                    least_first_queue.push(self._depth[down_op.idx()],
                                           down_op.idx())
            return None

        def _find_coherent_ep(op):
            # Return the single downstream op when it is not ``fetch``, it
            # has exactly one upstream op, and ``op`` itself has exactly one
            # non-persistable input; otherwise return None.
            successors = self._graph.next_ops(op)
            if len(successors) != 1:
                return None
            following_op = successors[0]
            if following_op.type() == 'fetch':
                return None
            non_parameter_input = 0
            for var in op.all_inputs():
                if not var._var.persistable:
                    non_parameter_input += 1
            upstream_ops = len(self._graph.pre_ops(following_op))
            if non_parameter_input == 1 and upstream_ops == 1:
                return following_op
            return None

        sp, ep = op, op
        while True:
            if len(self._graph.next_ops(ep)) <= 1:
                future_ep = _find_coherent_ep(ep)
            else:
                future_ep = _find_multi_input_ep(ep)

            if future_ep is None or self._depth[future_ep.idx()] - self._depth[
                    sp.idx()] >= limit:
                return self._create_region(sp, ep)
            ep = future_ep

    def _opset_matching(self, sp, ep):
        """Collect the indices of ops on the paths from ``sp`` down to ``ep``.

        Returns:
            set|None: The op indices found, or None as soon as a path reaches
            an op without downstream ops before meeting ``ep``. Callers treat
            None as "no valid region".
        """
        if sp.idx() in self._cache:
            return self._cache[sp.idx()]

        following_ops = self._graph.next_ops(sp)
        if len(following_ops) == 0:
            # The original code returned the result of ``set.add``, which is
            # always None; the behaviour is kept and made explicit here.
            return None

        ret_collection = set()
        for op in following_ops:
            if op == ep:
                ret_collection.update([sp.idx(), op.idx()])
            else:
                further_res = self._opset_matching(sp=op, ep=ep)
                if further_res is None:
                    return None
                if len(further_res) > 0:
                    ret_collection.update(further_res)
                    ret_collection.add(sp.idx())
        self._cache[sp.idx()] = ret_collection
        return ret_collection

    def opset_matching(self, sp, ep):
        """Return the indices of the ops between ``sp`` and ``ep``.

        Returns:
            set|None: The op indices of the region, or None when
            ``_opset_matching`` reports that no valid region exists.
        """
        ret_collection = set()
        candidates = {op.idx() for op in self._graph.ops() if op == sp}
        try:
            for idx in candidates:
                partial_matchings = self._opset_matching(
                    sp=self._op_idx_map[idx], ep=ep)
                if partial_matchings is None:
                    return None
                if len(partial_matchings) > 0:
                    ret_collection.update(partial_matchings)
            return ret_collection
        finally:
            # The memo is keyed by start op only, so it is valid for a single
            # ``ep``. It must also be dropped when the search fails, or stale
            # entries would be reused by the next search with another ``ep``.
            self._cache.clear()

    def _create_region(self, sp, ep):
        """Return ``(sp, ep, rps)`` where ``rps`` is ``opset_matching(sp, ep)``."""
        rps = self.opset_matching(sp, ep)
        return sp, ep, rps

    def _create_regions(self, limit):
        """Divide the graph into regions and collect the weights of each.

        Args:
            limit(int): Passed to ``_build`` to bound the size of each region.

        Returns:
            tuple: ``(regions, region_weights_names)``. Each region is
            ``[input_var_name, output_var_name]`` and the entry at the same
            position in ``region_weights_names`` lists the persistable input
            names of its conv/mul/matmul ops.
        """
        visited = set()
        for op in self._graph.ops():
            if op.type() in self._SKIPPED_OP_TYPES:
                continue
            if op.idx() in visited:
                continue

            sp, ep, rps = self._build(op=op, limit=limit)
            if rps is None:
                continue

            region = []
            region_weight_names = []

            # add region's input var: the first non-persistable input of sp
            for var in sp.all_inputs():
                if not var._var.persistable:
                    region.append(var._var.name)
                    break

            # add region's output var
            if ep.type() == 'batch_norm':
                out_var = ep.outputs('Y')
            else:
                out_var = ep.all_outputs()
            if not out_var[0]._var.persistable:
                region.append(out_var[0]._var.name)

            for idx in rps:
                visited.add(idx)
                member = self._op_idx_map[idx]
                if member.type() not in self._WEIGHT_OP_TYPES:
                    continue
                for var in member.all_inputs():
                    if var._var.persistable:
                        region_weight_names.append(var._var.name)

            # A region needs both an input and an output name and at least
            # one weight; otherwise it is dropped (its ops stay visited).
            if len(region) < 2 or len(region_weight_names) < 1:
                continue
            self._regions.append(region)
            self._region_weights_names.append(region_weight_names)

        return self._regions, self._region_weights_names
def quant_recon_static(executor,
model_dir,
quantize_model_path,
......@@ -893,7 +1101,8 @@ def quant_recon_static(executor,
region_weights_names=None,
epochs=20,
drop_prob=0.5,
lr=0.1):
lr=0.1,
limit=6):
"""
The function utilizes static post training quantization method to
quantize the fp32 model. It uses calibrate data to calculate the
......@@ -966,8 +1175,8 @@ def quant_recon_static(executor,
skip_tensor_list(list): List of skip quant tensor name.
is_use_cache_file(bool): This param is deprecated.
cache_dir(str): This param is deprecated.
epochs: The number of steps in the reconstruction proces. Default is 20.
drop_prob: The dropout probability of activation quantization, and it is valid only if
epochs(int): The number of steps in the reconstruction process. Default is 20.
drop_prob(float): The dropout probability of activation quantization, and it is valid only if
simulate_activation_quant is True. Default is 0.5.
regions(list[list], optional): The list of some regions, each region is a subgraph of
fp32 program and it will have exact 1 input operation and 1 output operation. When
......@@ -975,6 +1184,7 @@ def quant_recon_static(executor,
Default is None.
region_weights_names(list[list], optional): The weight names inside every region.
Default is None.
limit(int): The size of each region. Default is 6.
Returns:
None
"""
......@@ -1010,7 +1220,8 @@ def quant_recon_static(executor,
regions=regions,
region_weights_names=region_weights_names,
epochs=epochs,
lr=lr)
lr=lr,
limit=limit)
reconstruction_quantization = ReconstructionQuantization(
PTQCollections=PTQCollections, RSQCollections=RSQCollections)
......
......@@ -14,44 +14,51 @@
import sys
sys.path.append("../")
import unittest
import tempfile
import paddle
from paddleslim.quant import quant_post_static
from static_case import StaticCase
sys.path.append("../demo")
from models import MobileNet
from models import *
from layers import conv_bn_layer
import paddle.dataset.mnist as reader
import numpy as np
from paddleslim.quant import quant_recon_static
class TestRoundingOptimizer(StaticCase):
class ReconPTQ(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestRoundingOptimizer, self).__init__(*args, **kwargs)
super(ReconPTQ, self).__init__(*args, **kwargs)
paddle.enable_static()
self.tmpdir = tempfile.TemporaryDirectory(prefix="test_")
self._gen_model()
def _gen_model(self):
image = paddle.static.data(
name='image', shape=[None, 1, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
model = MobileNet()
out = model.net(input=image, class_dim=10)
cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
avg_cost = paddle.mean(x=cost)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
optimizer = paddle.optimizer.Momentum(
momentum=0.9,
learning_rate=0.01,
weight_decay=paddle.regularizer.L2Decay(4e-5))
optimizer.minimize(avg_cost)
main_prog = paddle.static.default_main_program()
val_prog = main_prog.clone(for_test=True)
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
image = paddle.static.data(
name='image', shape=[None, 1, 28, 28], dtype='float32')
label = paddle.static.data(
name='label', shape=[None, 1], dtype='int64')
model = MobileNetV2()
out = model.net(input=image, class_dim=10)
cost = paddle.nn.functional.loss.cross_entropy(
input=out, label=label)
avg_cost = paddle.mean(x=cost)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
val_program = main_program.clone(for_test=True)
optimizer = paddle.optimizer.Momentum(
momentum=0.9,
learning_rate=0.01,
weight_decay=paddle.regularizer.L2Decay(4e-5))
optimizer.minimize(avg_cost)
exe.run(startup_program)
def transform(x):
return np.reshape(x, [1, 28, 28])
......@@ -95,64 +102,66 @@ class TestRoundingOptimizer(StaticCase):
'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
format(iter, cost, top1, top5))
train(main_prog)
train(main_program)
paddle.fluid.io.save_inference_model(
dirname='./test_rounding_optimizer',
feeded_var_names=[image.name, label.name],
target_vars=[avg_cost, acc_top1, acc_top5],
main_program=val_prog,
dirname=self.tmpdir.name,
feeded_var_names=[image.name],
target_vars=[out],
main_program=val_program,
executor=exe,
model_filename='model',
params_filename='params')
model_filename='model.pdmodel',
params_filename='params.pdiparams')
print(f"saved infer model to [{self.tmpdir.name}]")
self.data_loader = sample_generator_creator()
self._regions = [['image', 'batch_norm_26.tmp_4']]
self._region_weights_names = [[
'conv1_weights', 'conv2_1_dw_weights', 'conv2_1_sep_weights',
'conv2_2_dw_weights', 'conv2_2_sep_weights', 'conv3_1_dw_weights',
'conv3_1_sep_weights', 'conv3_2_dw_weights', 'conv3_2_sep_weights',
'conv4_1_dw_weights', 'conv4_1_sep_weights', 'conv4_2_dw_weights',
'conv4_2_sep_weights', 'conv5_1_dw_weights', 'conv5_1_sep_weights',
'conv5_2_dw_weights', 'conv5_2_sep_weights', 'conv5_3_dw_weights',
'conv5_3_sep_weights', 'conv5_4_dw_weights', 'conv5_4_sep_weights',
'conv5_5_dw_weights', 'conv5_5_sep_weights', 'conv5_6_dw_weights',
'conv5_6_sep_weights', 'conv6_dw_weights', 'conv6_sep_weights'
]]
def test_qdrop(self):
def __del__(self):
self.tmpdir.cleanup()
class TestReconRegion(ReconPTQ):
def __init__(self, *args, **kwargs):
super(TestReconRegion, self).__init__(*args, **kwargs)
def test_qdrop_region(self):
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
exe = paddle.static.Executor(place)
quant_recon_static(
exe,
'./test_rounding_optimizer',
quantize_model_path='rsq_out',
self.tmpdir.name,
quantize_model_path='output_region',
sample_generator=self.data_loader,
model_filename='model',
params_filename='params',
batch_nums=10,
model_filename='model.pdmodel',
params_filename='params.pdiparams',
batch_nums=1,
epochs=1,
algo='abs_max',
regions=self._regions,
region_weights_names=self._region_weights_names,
regions=None,
region_weights_names=None,
recon_level='region-wise',
simulate_activation_quant=True)
def test_qdrop(self):
class TestReconLayer(ReconPTQ):
def __init__(self, *args, **kwargs):
super(TestReconLayer, self).__init__(*args, **kwargs)
def test_qdrop_layer(self):
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
exe = paddle.static.Executor(place)
quant_recon_static(
exe,
'./test_rounding_optimizer',
quantize_model_path='rsq_out',
self.tmpdir.name,
quantize_model_path='output_layer',
sample_generator=self.data_loader,
model_filename='model',
params_filename='params',
batch_nums=10,
model_filename='model.pdmodel',
params_filename='params.pdiparams',
batch_nums=1,
epochs=1,
algo='KL',
regions=self._regions,
region_weights_names=self._region_weights_names,
regions=None,
region_weights_names=None,
recon_level='layer-wise',
simulate_activation_quant=True,
bias_correction=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册