未验证 提交 fb3fac19 编写于 作者: G gushiqiao 提交者: GitHub

Add automatic region division and Fix pre-tensor quant (#1517)

* Fixed naming conflicts and fc layer quantification

* Update fine_tune.py

* Update readme

* Add automatic region division and Fix pre-tensor quant

* Update unit test file

* Fix bugs.
Co-authored-by: Nzhouzj <41366441+zzjjay@users.noreply.github.com>
上级 4fb41e4c
...@@ -74,7 +74,7 @@ $$ ...@@ -74,7 +74,7 @@ $$
说明: 说明:
- 如果想使用bias_correction,可以在PaddleSlim的[离线量化接口](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/static/quant/quantization_api.rst#quant_post_static)修改`bias_correction`参数为True即可,默认为False。 - 如果想使用bias_correction,可以在PaddleSlim的[离线量化接口](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/static/quant/quantization_api.rst#quant_post_static)修改`bias_correction`参数为True即可,默认为False。
- 如果想使用Adaround方法,可以在PaddleSlim的[离线量化接口](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/static/quant/quantization_api.rst#quant_post_static)修改`round_type`参数为`adaround`即可,默认为`round` - 如果想使用Adaround方法,可以在PaddleSlim的[离线量化接口](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/static/quant/quantization_api.rst#quant_post_static)修改`round_type`参数为`adaround`即可,默认为`round`
- 如果想使用BRECQ方法,可以在PaddleSlim的[量化重构接口](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/static/quant/quantization_api.rst#quant_post_static)修改`recon_level`参数为`regionn-wise`即可,默认为`layer-wise` - 如果想使用BRECQ方法,可以在PaddleSlim的[量化重构接口](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/static/quant/quantization_api.rst#quant_post_static)修改`recon_level`参数为`region-wise`即可,默认为`layer-wise`
- 如果想使用QDrop方法,可以在PaddleSlim的[量化重构接口](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/static/quant/quantization_api.rst#quant_post_static)修改`simulate_activation_quant`参数为`True`即可,默认为`False` - 如果想使用QDrop方法,可以在PaddleSlim的[量化重构接口](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/static/quant/quantization_api.rst#quant_post_static)修改`simulate_activation_quant`参数为`True`即可,默认为`False`
### 效果对比 ### 效果对比
......
arch: YOLOv6 arch: YOLOv6
model_dir: ./yolov6s.onnx model_dir: ./yolov6s.onnx
dataset_dir: /dataset/coco/ dataset_dir: /dataset/coco/
model_filename: model.pdmodel model_filename: model.pdmodel
params_filename: model.pdiparams params_filename: model.pdiparams
...@@ -8,25 +8,3 @@ val_image_dir: val2017 ...@@ -8,25 +8,3 @@ val_image_dir: val2017
train_anno_path: annotations/instances_train2017.json train_anno_path: annotations/instances_train2017.json
val_anno_path: annotations/instances_val2017.json val_anno_path: annotations/instances_val2017.json
skip_tensor_list: None skip_tensor_list: None
regions: [['x2paddle_image_arrays','relu_8.tmp_0'],
['relu_8.tmp_0','relu_15.tmp_0'],
['relu_15.tmp_0','relu_21.tmp_0'],
['concat_1.tmp_0','relu_26.tmp_0'],
['concat_2.tmp_0', 'relu_30.tmp_0'],
['relu_30.tmp_0', 'concat_4.tmp_0'],
['relu_30.tmp_0', 'relu_31.tmp_0'],
['concat_3.tmp_0', 'relu_35.tmp_0'],
['relu_35.tmp_0', 'relu_36.tmp_0'],
['concat_5.tmp_0', 'concat_10.tmp_0'],
['relu_35.tmp_0', 'concat_8.tmp_0']]
region_weights_names: [['conv2d_0.w_0','conv2d_1.w_0','conv2d_2.w_0','conv2d_3.w_0','conv2d_4.w_0','conv2d_5.w_0','conv2d_6.w_0','conv2d_7.w_0','conv2d_8.w_0'],
['conv2d_9.w_0','conv2d_10.w_0','conv2d_11.w_0','conv2d_12.w_0','conv2d_13.w_0','conv2d_14.w_0','conv2d_15.w_0'],
['conv2d_16.w_0','conv2d_17.w_0','conv2d_18.w_0','conv2d_19.w_0','conv2d_20.w_0','conv2d_21.w_0'],
['conv2d_22.w_0','conv2d_23.w_0','conv2d_24.w_0','conv2d_25.w_0','conv2d_26.w_0'],
['conv2d_27.w_0','conv2d_28.w_0','conv2d_29.w_0','conv2d_30.w_0'],
['conv2d_32.w_0','conv2d_34.w_0','conv2d_35.w_0','conv2d_37.w_0','conv2d_38.w_0','conv2d_39.w_0'],
['conv2d_31.w_0'],
['conv2d_33.w_0','conv2d_36.w_0','conv2d_40.w_0','conv2d_41.w_0'],
['conv2d_42.w_0'],
['conv2d_44.w_0','conv2d_47.w_0','conv2d_51.w_0','conv2d_52.w_0','conv2d_53.w_0','conv2d_54.w_0','conv2d_55.w_0','conv2d_56.w_0','conv2d_57.w_0','conv2d_58.w_0'],
['conv2d_43.w_0','conv2d_45.w_0','conv2d_46.w_0','conv2d_49.w_0','conv2d_48.w_0','conv2d_50.w_0'],]
\ No newline at end of file
...@@ -43,8 +43,6 @@ def argsparser(): ...@@ -43,8 +43,6 @@ def argsparser():
help="which device used to compress.") help="which device used to compress.")
parser.add_argument( parser.add_argument(
'--algo', type=str, default='avg', help="post quant algo.") '--algo', type=str, default='avg', help="post quant algo.")
parser.add_argument(
'--round_type', type=str, default='adaround', help="round type.")
parser.add_argument('--gpu', type=int, default=0, help='gpu index') parser.add_argument('--gpu', type=int, default=0, help='gpu index')
parser.add_argument( parser.add_argument(
...@@ -59,6 +57,10 @@ def argsparser(): ...@@ -59,6 +57,10 @@ def argsparser():
help='simulate activation quant') help='simulate activation quant')
parser.add_argument( parser.add_argument(
'--epochs', type=int, default=20, help='steps to reconstruct') '--epochs', type=int, default=20, help='steps to reconstruct')
parser.add_argument(
'--lr', type=float, default=0.1, help='learning rate of reconstruct')
parser.add_argument(
'--limit', type=int, default=5, help='size of each region')
return parser return parser
...@@ -104,10 +106,11 @@ def main(): ...@@ -104,10 +106,11 @@ def main():
weight_quantize_type='channel_wise_abs_max', weight_quantize_type='channel_wise_abs_max',
recon_level=FLAGS.recon_level, recon_level=FLAGS.recon_level,
simulate_activation_quant=FLAGS.simulate_activation_quant, simulate_activation_quant=FLAGS.simulate_activation_quant,
regions=config['regions'], regions=None,
region_weights_names=config['region_weights_names'], region_weights_names=None,
epochs=FLAGS.epochs, epochs=FLAGS.epochs,
lr=0.1) lr=FLAGS.lr,
limit=FLAGS.limit)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -167,7 +167,8 @@ class ReconstructionQuantization(PostTrainingQuantization): ...@@ -167,7 +167,8 @@ class ReconstructionQuantization(PostTrainingQuantization):
num_iterations=self._batch_nums, num_iterations=self._batch_nums,
lr=self._config['lr'], lr=self._config['lr'],
bias_correction=self._bias_correction, bias_correction=self._bias_correction,
epochs=self._config['epochs']) epochs=self._config['epochs'],
limit=self._config['limit'])
self._program, self._scale_dict = reconstruction_quanter._run() self._program, self._scale_dict = reconstruction_quanter._run()
if self._algo in ["KL", "hist"]: if self._algo in ["KL", "hist"]:
...@@ -229,7 +230,8 @@ class ReconstructionQuanter(object): ...@@ -229,7 +230,8 @@ class ReconstructionQuanter(object):
lr=0.1, lr=0.1,
bias_correction=False, bias_correction=False,
epochs=20, epochs=20,
drop_prob=0.5): drop_prob=0.5,
limit=5):
''' '''
Reconstruction Quanter, used to optimize the rounding policy Reconstruction Quanter, used to optimize the rounding policy
by reconstructing the intermediate output. by reconstructing the intermediate output.
...@@ -266,19 +268,17 @@ class ReconstructionQuanter(object): ...@@ -266,19 +268,17 @@ class ReconstructionQuanter(object):
lr(float, optional): The learning rate of Reconstruction Quanter. Default is 0.1. lr(float, optional): The learning rate of Reconstruction Quanter. Default is 0.1.
bias_correction(bool, optional): If set as True, use the bias correction bias_correction(bool, optional): If set as True, use the bias correction
method of https://arxiv.org/abs/1810.05723. Default is False. method of https://arxiv.org/abs/1810.05723. Default is False.
drop_prob: The dropout probability of activation quantization, and it is valid only if drop_prob(float, optional): The dropout probability of activation quantization, and it is valid only if
simulate_activation_quant is True. Default is 0.5. simulate_activation_quant is True. Default is 0.5.
limit(int, optional): The size of each region. Default is 5.
Returns: Returns:
None None
''' '''
assert recon_level in [ assert recon_level in [
'layer-wise', 'region-wise' 'layer-wise', 'region-wise'
], "recon_level must be one of the ['layer-wise', 'region-wise'],but received: {}".format( ], "recon_level must be one of the ['layer-wise', 'region-wise'], but received: {}".format(
recon_level) recon_level)
if recon_level == 'region-wise':
assert regions is not None, "The regions cannot be None."
assert region_weights_names is not None, "The region_weights_names cannot be None."
self._simulate_activation_quant = simulate_activation_quant self._simulate_activation_quant = simulate_activation_quant
self._program = fp32_program self._program = fp32_program
self._data_loader = data_loader self._data_loader = data_loader
...@@ -301,7 +301,15 @@ class ReconstructionQuanter(object): ...@@ -301,7 +301,15 @@ class ReconstructionQuanter(object):
self._regions = regions self._regions = regions
self._region_weights_names = region_weights_names self._region_weights_names = region_weights_names
self._bias_correction = bias_correction self._bias_correction = bias_correction
if self._recon_level == 'layer-wise': self._limit = limit
if recon_level == 'region-wise' and regions is None:
builder = RegionBuilder(program=self._program)
_logger.info('Begin Region division')
self._regions, self._region_weights_names = builder._create_regions(
limit=self._limit)
_logger.info('End Region division')
elif self._recon_level == 'layer-wise':
regions, region_weights_names = self._get_layers() regions, region_weights_names = self._get_layers()
self._regions = regions self._regions = regions
self._region_weights_names = region_weights_names self._region_weights_names = region_weights_names
...@@ -330,10 +338,11 @@ class ReconstructionQuanter(object): ...@@ -330,10 +338,11 @@ class ReconstructionQuanter(object):
def _preprocess(self): def _preprocess(self):
for name in self._weight_var_names: if self._weight_quantize_type == 'channel_wise_abs_max':
for i, s in enumerate(self._scale_dict[name]): for name in self._weight_var_names:
if s == 0.0: for i, s in enumerate(self._scale_dict[name]):
self._scale_dict[name][i] = 1e-8 if s == 0.0:
self._scale_dict[name][i] = 1e-8
data_name_map = {} data_name_map = {}
for name in self._feed_list: for name in self._feed_list:
...@@ -363,9 +372,10 @@ class ReconstructionQuanter(object): ...@@ -363,9 +372,10 @@ class ReconstructionQuanter(object):
region_ = self._regions[k] region_ = self._regions[k]
tmp_program.global_block().var(region_[0]).stop_gradient = True tmp_program.global_block().var(region_[0]).stop_gradient = True
quant_op_out_name = region_[1] quant_op_out_name = region_[1]
_logger.info(f"Region's input: {region_[0]} output: {region_[1]}")
names = self._region_weights_names[k] names = self._region_weights_names[k]
_logger.info(f"Current weights: {names}") _logger.info(f"Current quanted weights: {names}")
loss_function = ReconstructionQuanterLoss( loss_function = ReconstructionQuanterLoss(
program=tmp_program, weight_region_names=names) program=tmp_program, weight_region_names=names)
update_params = [ update_params = [
...@@ -413,8 +423,8 @@ class ReconstructionQuanter(object): ...@@ -413,8 +423,8 @@ class ReconstructionQuanter(object):
sys.stdout.flush() sys.stdout.flush()
if i + 1 == self._num_iterations: if i + 1 == self._num_iterations:
break break
if self._weight_quantize_type == 'channel_wise_abs_max':
self._update_scale() self._update_scale()
self._update_weights_to_int() self._update_weights_to_int()
if self._bias_correction: if self._bias_correction:
self._bias_correction_w() self._bias_correction_w()
...@@ -495,7 +505,6 @@ class ReconstructionQuanter(object): ...@@ -495,7 +505,6 @@ class ReconstructionQuanter(object):
scale = np.array(scale) scale = np.array(scale)
scale = scale.reshape(scale.shape[0], 1) scale = scale.reshape(scale.shape[0], 1)
if len(shape) == 2: if len(shape) == 2:
print(name)
scale = scale.repeat(shape[0], axis=1).T scale = scale.repeat(shape[0], axis=1).T
else: else:
scale = scale.repeat(shape[1] * shape[2] * shape[3], axis=1) scale = scale.repeat(shape[1] * shape[2] * shape[3], axis=1)
...@@ -643,6 +652,9 @@ class ReconstructionQuanter(object): ...@@ -643,6 +652,9 @@ class ReconstructionQuanter(object):
op.input('X')[0].endswith('scale') op.input('X')[0].endswith('scale')
) or _type == 'sigmoid': ) or _type == 'sigmoid':
_inputs = {'X': op.input('X')[0]} _inputs = {'X': op.input('X')[0]}
elif (_type == 'scale' and
op.input('X')[0].endswith('copy')):
_inputs = {'X': var._var}
else: else:
_inputs = {'X': op.input('X')[0] + '.rounding'} _inputs = {'X': op.input('X')[0] + '.rounding'}
elif func == "_drop_quant_dequant": elif func == "_drop_quant_dequant":
...@@ -857,6 +869,202 @@ class ReconstructionQuanterLoss(object): ...@@ -857,6 +869,202 @@ class ReconstructionQuanterLoss(object):
return total_loss, rec_loss, round_loss return total_loss, rec_loss, round_loss
class PriorityQueue:
def __init__(self):
self._data = []
self._ops = set()
self._idx = 0
self._lazy_tag = True
def pop(self):
if not self._lazy_tag:
self._data = sorted(self._data, key=lambda x: x[0])
self._lazy_tag = True
if self._idx >= len(self._data): raise IndexError('Index out of range!')
ele = self._data[self._idx]
self._idx += 1
return ele
def push(self, depth, op):
if op in self._ops: return
self._data.append((depth, op))
self._ops.add(op)
self._lazy_tag = False
def empty(self):
return self._idx >= len(self._data)
class RegionBuilder(object):
def __init__(self, program):
self._program = program
self._graph = GraphWrapper(self._program)
self._op_idx_map = {}
for op in self._graph.ops():
self._op_idx_map[op.idx()] = op
self._depth = {}
self._init_depth()
self._cache = {}
self._regions = []
self._region_weights_names = []
def _init_depth(self):
for op in self._graph.ops():
if len(self._graph.pre_ops(op)) == 0:
self._depth[op.idx()] = 0
continue
depths_cache = []
for up_op in self._graph.pre_ops(op):
assert up_op.idx() in self._depth
depths_cache.append(self._depth[up_op.idx()])
self._depth[op.idx()] = max(depths_cache) + 1
def _build(self, op, limit):
def _find_multi_input_ep(op):
least_first_queue = PriorityQueue()
for down_op in self._graph.next_ops(op):
least_first_queue.push(self._depth[down_op.idx()],
down_op.idx())
while not least_first_queue.empty():
iter_op_idx = least_first_queue.pop()[-1]
iter_op = self._op_idx_map[iter_op_idx]
if (least_first_queue.empty() and
len(self._graph.pre_ops(iter_op)) > 1):
return iter_op
for down_op in self._graph.next_ops(iter_op):
least_first_queue.push(self._depth[down_op.idx()],
down_op.idx())
return None
def _find_coherent_ep(op):
ops = self._graph.next_ops(op)
if len(ops) == 1:
following_op = ops[0]
if following_op.type() == 'fetch':
return None
inps = op.all_inputs()
non_parameter_input = 0
for var in inps:
if not var._var.persistable:
non_parameter_input += 1
upstream_ops = len(self._graph.pre_ops(following_op))
if non_parameter_input == 1 and upstream_ops == 1:
return ops[0]
return None
sp, ep, future_ep = op, op, op
while future_ep is not None:
if len(self._graph.next_ops(ep)) <= 1:
future_ep = _find_coherent_ep(ep)
else:
future_ep = _find_multi_input_ep(ep)
if future_ep is None or self._depth[future_ep.idx()] - self._depth[
sp.idx()] >= limit:
return self._create_region(sp, ep)
ep = future_ep
return self._create_region(sp=sp, ep=ep)
def _opset_matching(self, sp, ep):
if sp.idx() in self._cache: return self._cache[sp.idx()]
ret_collection = set()
following_ops = self._graph.next_ops(sp)
if (len(following_ops)) == 0:
return ret_collection.add(sp.idx())
for op in following_ops:
if op == ep:
ret_collection.update([sp.idx(), op.idx()])
else:
further_res = self._opset_matching(sp=op, ep=ep)
if further_res is None:
return None
if len(further_res) > 0:
ret_collection.update(further_res)
ret_collection.add(sp.idx())
self._cache[sp.idx()] = ret_collection
return ret_collection
def opset_matching(self, sp, ep):
ret_collection, candidates = set(), set()
for op in self._graph.ops():
if op == sp:
candidates.add(op.idx())
for idx in candidates:
op = self._op_idx_map[idx]
partial_matchings = self._opset_matching(sp=op, ep=ep)
if partial_matchings is None:
return None
if len(partial_matchings) > 0:
ret_collection.update(partial_matchings)
self._cache.clear()
return ret_collection
def _create_region(self, sp, ep):
rps = self.opset_matching(sp, ep)
return sp, ep, rps
def _create_regions(self, limit):
visited = []
for op in self._graph.ops():
region = []
region_weight_names = []
if op.type() == 'fill_constant': continue
if op.type() == 'feed': continue
if op.type() == 'fetch': continue
if op.idx() in visited: continue
sp, ep, rps = self._build(op=op, limit=limit)
if rps is None:
continue
ops = [self._op_idx_map[idx] for idx in rps]
# add region's input var
inps = sp.all_inputs()
for var in inps:
if not var._var.persistable:
region.append(var._var.name)
break
# add region's output var
if ep.type() == 'batch_norm':
out_var = ep.outputs('Y')
else:
out_var = ep.all_outputs()
if not out_var[0]._var.persistable:
region.append(out_var[0]._var.name)
for idx in rps:
visited.append(idx)
op = self._op_idx_map[idx]
if op.type() not in [
"conv2d", "depthwise_conv2d", "mul", "matmul",
"matmul_v2"
]:
continue
inps = op.all_inputs()
for var in inps:
if var._var.persistable:
region_weight_names.append(var._var.name)
if len(region) < 2 or len(region_weight_names) < 1: continue
self._regions.append(region)
self._region_weights_names.append(region_weight_names)
return self._regions, self._region_weights_names
def quant_recon_static(executor, def quant_recon_static(executor,
model_dir, model_dir,
quantize_model_path, quantize_model_path,
...@@ -893,7 +1101,8 @@ def quant_recon_static(executor, ...@@ -893,7 +1101,8 @@ def quant_recon_static(executor,
region_weights_names=None, region_weights_names=None,
epochs=20, epochs=20,
drop_prob=0.5, drop_prob=0.5,
lr=0.1): lr=0.1,
limit=6):
""" """
The function utilizes static post training quantization method to The function utilizes static post training quantization method to
quantize the fp32 model. It uses calibrate data to calculate the quantize the fp32 model. It uses calibrate data to calculate the
...@@ -966,8 +1175,8 @@ def quant_recon_static(executor, ...@@ -966,8 +1175,8 @@ def quant_recon_static(executor,
skip_tensor_list(list): List of skip quant tensor name. skip_tensor_list(list): List of skip quant tensor name.
is_use_cache_file(bool): This param is deprecated. is_use_cache_file(bool): This param is deprecated.
cache_dir(str): This param is deprecated. cache_dir(str): This param is deprecated.
epochs: The number of steps in the reconstruction proces. Default is 20. epochs(int): The number of steps in the reconstruction proces. Default is 20.
drop_prob: The dropout probability of activation quantization, and it is valid only if drop_prob(float): The dropout probability of activation quantization, and it is valid only if
simulate_activation_quant is True. Default is 0.5. simulate_activation_quant is True. Default is 0.5.
regions(list[list], optional): The list of some regions, each region is a subgraph of regions(list[list], optional): The list of some regions, each region is a subgraph of
fp32 program and it will have exact 1 input operation and 1 output operation. When fp32 program and it will have exact 1 input operation and 1 output operation. When
...@@ -975,6 +1184,7 @@ def quant_recon_static(executor, ...@@ -975,6 +1184,7 @@ def quant_recon_static(executor,
Default is None. Default is None.
region_weights_names(list[list], optional): The weight names inside every region. region_weights_names(list[list], optional): The weight names inside every region.
Default is None. Default is None.
limit(int): The size of each region. Default is 6.
Returns: Returns:
None None
""" """
...@@ -1010,7 +1220,8 @@ def quant_recon_static(executor, ...@@ -1010,7 +1220,8 @@ def quant_recon_static(executor,
regions=regions, regions=regions,
region_weights_names=region_weights_names, region_weights_names=region_weights_names,
epochs=epochs, epochs=epochs,
lr=lr) lr=lr,
limit=limit)
reconstruction_quantization = ReconstructionQuantization( reconstruction_quantization = ReconstructionQuantization(
PTQCollections=PTQCollections, RSQCollections=RSQCollections) PTQCollections=PTQCollections, RSQCollections=RSQCollections)
......
...@@ -14,44 +14,51 @@ ...@@ -14,44 +14,51 @@
import sys import sys
sys.path.append("../") sys.path.append("../")
import unittest import unittest
import tempfile
import paddle import paddle
from paddleslim.quant import quant_post_static from paddleslim.quant import quant_post_static
from static_case import StaticCase from static_case import StaticCase
sys.path.append("../demo") sys.path.append("../demo")
from models import MobileNet from models import *
from layers import conv_bn_layer from layers import conv_bn_layer
import paddle.dataset.mnist as reader import paddle.dataset.mnist as reader
import numpy as np import numpy as np
from paddleslim.quant import quant_recon_static from paddleslim.quant import quant_recon_static
class TestRoundingOptimizer(StaticCase): class ReconPTQ(unittest.TestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(TestRoundingOptimizer, self).__init__(*args, **kwargs) super(ReconPTQ, self).__init__(*args, **kwargs)
paddle.enable_static() paddle.enable_static()
self.tmpdir = tempfile.TemporaryDirectory(prefix="test_")
self._gen_model() self._gen_model()
def _gen_model(self): def _gen_model(self):
image = paddle.static.data(
name='image', shape=[None, 1, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
model = MobileNet()
out = model.net(input=image, class_dim=10)
cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
avg_cost = paddle.mean(x=cost)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
optimizer = paddle.optimizer.Momentum(
momentum=0.9,
learning_rate=0.01,
weight_decay=paddle.regularizer.L2Decay(4e-5))
optimizer.minimize(avg_cost)
main_prog = paddle.static.default_main_program()
val_prog = main_prog.clone(for_test=True)
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda( place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace() ) else paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program()) main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
image = paddle.static.data(
name='image', shape=[None, 1, 28, 28], dtype='float32')
label = paddle.static.data(
name='label', shape=[None, 1], dtype='int64')
model = MobileNetV2()
out = model.net(input=image, class_dim=10)
cost = paddle.nn.functional.loss.cross_entropy(
input=out, label=label)
avg_cost = paddle.mean(x=cost)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
val_program = main_program.clone(for_test=True)
optimizer = paddle.optimizer.Momentum(
momentum=0.9,
learning_rate=0.01,
weight_decay=paddle.regularizer.L2Decay(4e-5))
optimizer.minimize(avg_cost)
exe.run(startup_program)
def transform(x): def transform(x):
return np.reshape(x, [1, 28, 28]) return np.reshape(x, [1, 28, 28])
...@@ -95,64 +102,66 @@ class TestRoundingOptimizer(StaticCase): ...@@ -95,64 +102,66 @@ class TestRoundingOptimizer(StaticCase):
'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'. 'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
format(iter, cost, top1, top5)) format(iter, cost, top1, top5))
train(main_prog) train(main_program)
paddle.fluid.io.save_inference_model( paddle.fluid.io.save_inference_model(
dirname='./test_rounding_optimizer', dirname=self.tmpdir.name,
feeded_var_names=[image.name, label.name], feeded_var_names=[image.name],
target_vars=[avg_cost, acc_top1, acc_top5], target_vars=[out],
main_program=val_prog, main_program=val_program,
executor=exe, executor=exe,
model_filename='model', model_filename='model.pdmodel',
params_filename='params') params_filename='params.pdiparams')
print(f"saved infer model to [{self.tmpdir.name}]")
self.data_loader = sample_generator_creator() self.data_loader = sample_generator_creator()
self._regions = [['image', 'batch_norm_26.tmp_4']] def __del__(self):
self._region_weights_names = [[ self.tmpdir.cleanup()
'conv1_weights', 'conv2_1_dw_weights', 'conv2_1_sep_weights',
'conv2_2_dw_weights', 'conv2_2_sep_weights', 'conv3_1_dw_weights',
'conv3_1_sep_weights', 'conv3_2_dw_weights', 'conv3_2_sep_weights', class TestReconRegion(ReconPTQ):
'conv4_1_dw_weights', 'conv4_1_sep_weights', 'conv4_2_dw_weights', def __init__(self, *args, **kwargs):
'conv4_2_sep_weights', 'conv5_1_dw_weights', 'conv5_1_sep_weights', super(TestReconRegion, self).__init__(*args, **kwargs)
'conv5_2_dw_weights', 'conv5_2_sep_weights', 'conv5_3_dw_weights',
'conv5_3_sep_weights', 'conv5_4_dw_weights', 'conv5_4_sep_weights', def test_qdrop_region(self):
'conv5_5_dw_weights', 'conv5_5_sep_weights', 'conv5_6_dw_weights',
'conv5_6_sep_weights', 'conv6_dw_weights', 'conv6_sep_weights'
]]
def test_qdrop(self):
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda( place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace() ) else paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
quant_recon_static( quant_recon_static(
exe, exe,
'./test_rounding_optimizer', self.tmpdir.name,
quantize_model_path='rsq_out', quantize_model_path='output_region',
sample_generator=self.data_loader, sample_generator=self.data_loader,
model_filename='model', model_filename='model.pdmodel',
params_filename='params', params_filename='params.pdiparams',
batch_nums=10, batch_nums=1,
epochs=1,
algo='abs_max', algo='abs_max',
regions=self._regions, regions=None,
region_weights_names=self._region_weights_names, region_weights_names=None,
recon_level='region-wise', recon_level='region-wise',
simulate_activation_quant=True) simulate_activation_quant=True)
def test_qdrop(self):
class TestReconLayer(ReconPTQ):
def __init__(self, *args, **kwargs):
super(TestReconLayer, self).__init__(*args, **kwargs)
def test_qdrop_layer(self):
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda( place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace() ) else paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
quant_recon_static( quant_recon_static(
exe, exe,
'./test_rounding_optimizer', self.tmpdir.name,
quantize_model_path='rsq_out', quantize_model_path='output_layer',
sample_generator=self.data_loader, sample_generator=self.data_loader,
model_filename='model', model_filename='model.pdmodel',
params_filename='params', params_filename='params.pdiparams',
batch_nums=10, batch_nums=1,
epochs=1,
algo='KL', algo='KL',
regions=self._regions, regions=None,
region_weights_names=self._region_weights_names, region_weights_names=None,
recon_level='layer-wise', recon_level='layer-wise',
simulate_activation_quant=True, simulate_activation_quant=True,
bias_correction=True) bias_correction=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册