Unverified  Commit bc442429 authored by gushiqiao, committed by GitHub

Add reconstruction quant algorithm (#1457)

Parent 880ad20b
arch: YOLOv6
model_dir: ./yolov6s.onnx
dataset_dir: /dataset/coco/
model_filename: model.pdmodel
params_filename: model.pdiparams
train_image_dir: train2017
val_image_dir: val2017
train_anno_path: annotations/instances_train2017.json
val_anno_path: annotations/instances_val2017.json
skip_tensor_list: None
regions: [['x2paddle_image_arrays','relu_8.tmp_0'],
['relu_8.tmp_0','relu_15.tmp_0'],
['relu_15.tmp_0','relu_21.tmp_0'],
['concat_1.tmp_0','relu_26.tmp_0'],
['concat_2.tmp_0', 'relu_30.tmp_0'],
['relu_30.tmp_0', 'concat_4.tmp_0'],
['relu_30.tmp_0', 'relu_31.tmp_0'],
['concat_3.tmp_0', 'relu_35.tmp_0'],
['relu_35.tmp_0', 'relu_36.tmp_0'],
['concat_5.tmp_0', 'concat_10.tmp_0'],
['relu_35.tmp_0', 'concat_8.tmp_0']]
region_weights_names: [['conv2d_0.w_0','conv2d_1.w_0','conv2d_2.w_0','conv2d_3.w_0','conv2d_4.w_0','conv2d_5.w_0','conv2d_6.w_0','conv2d_7.w_0','conv2d_8.w_0'],
['conv2d_9.w_0','conv2d_10.w_0','conv2d_11.w_0','conv2d_12.w_0','conv2d_13.w_0','conv2d_14.w_0','conv2d_15.w_0'],
['conv2d_16.w_0','conv2d_17.w_0','conv2d_18.w_0','conv2d_19.w_0','conv2d_20.w_0','conv2d_21.w_0'],
['conv2d_22.w_0','conv2d_23.w_0','conv2d_24.w_0','conv2d_25.w_0','conv2d_26.w_0'],
['conv2d_27.w_0','conv2d_28.w_0','conv2d_29.w_0','conv2d_30.w_0'],
['conv2d_32.w_0','conv2d_34.w_0','conv2d_35.w_0','conv2d_37.w_0','conv2d_38.w_0','conv2d_39.w_0'],
['conv2d_31.w_0'],
['conv2d_33.w_0','conv2d_36.w_0','conv2d_40.w_0','conv2d_41.w_0'],
['conv2d_42.w_0'],
['conv2d_44.w_0','conv2d_47.w_0','conv2d_51.w_0','conv2d_52.w_0','conv2d_53.w_0','conv2d_54.w_0','conv2d_55.w_0','conv2d_56.w_0','conv2d_57.w_0','conv2d_58.w_0'],
['conv2d_43.w_0','conv2d_45.w_0','conv2d_46.w_0','conv2d_49.w_0','conv2d_48.w_0','conv2d_50.w_0'],]
\ No newline at end of file
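Each entry in regions above is an [input_var, output_var] pair delimiting one subgraph of the FP32 program, and the entry at the same index in region_weights_names lists the conv weights that live inside that subgraph. A minimal sanity-check sketch under assumptions: the YAML above is saved at an illustrative path './configs/yolov6s_recon.yaml' and is loaded with paddleslim's load_config (the same helper the example script below uses); this is not part of the patch itself.

from paddleslim.common import load_config

cfg = load_config('./configs/yolov6s_recon.yaml')  # hypothetical path to the config above
assert len(cfg['regions']) == len(cfg['region_weights_names']), 'one weight list per region'
for (start, end), weights in zip(cfg['regions'], cfg['region_weights_names']):
    # each region runs from tensor `start` to tensor `end` and owns `weights`
    print('{} -> {}: {} weight tensors'.format(start, end, len(weights)))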
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import argparse
import paddle
from paddleslim.common import load_config, load_onnx_model
from paddleslim.quant import quant_post_static
from paddleslim.quant import quant_recon_static
from dataset import COCOTrainDataset
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--config_path',
type=str,
default=None,
help="path of post training quantization config.",
required=True)
parser.add_argument(
'--save_dir',
type=str,
default='ptq_out',
help="directory to save compressed model.")
parser.add_argument(
'--devices',
type=str,
default='gpu',
help="which device used to compress.")
parser.add_argument(
'--algo', type=str, default='avg', help="post quant algo.")
parser.add_argument(
'--round_type', type=str, default='adaround', help="round type.")
parser.add_argument('--gpu', type=int, default=0, help='gpu index')
parser.add_argument(
'--recon_level',
type=str,
default='layer-wise',
help='reconstruction level')
parser.add_argument(
'--simulate_activation_quant',
type=bool,
default=False,
help='simulate activation quant')
return parser
def main():
global config
config = load_config(FLAGS.config_path)
input_name = 'x2paddle_image_arrays' if config[
'arch'] == 'YOLOv6' else 'x2paddle_images'
dataset = COCOTrainDataset(
dataset_dir=config['dataset_dir'],
image_dir=config['val_image_dir'],
anno_path=config['val_anno_path'],
input_name=input_name)
train_loader = paddle.io.DataLoader(
dataset, batch_size=1, shuffle=True, drop_last=True, num_workers=0)
place = paddle.CUDAPlace(
FLAGS.gpu) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
exe = paddle.static.Executor(place)
# since the model converted from PyTorch is in ONNX format,
# call load_onnx_model first and then rename the model_dir to the exported inference model
load_onnx_model(config["model_dir"])
inference_model_path = config["model_dir"].rstrip().rstrip(
'.onnx') + '_infer'
quant_recon_static(
executor=exe,
model_dir=inference_model_path,
quantize_model_path=FLAGS.save_dir,
data_loader=train_loader,
model_filename='model.pdmodel',
params_filename='model.pdiparams',
batch_size=32,
batch_nums=10,
algo=FLAGS.algo,
hist_percent=0.999,
is_full_quantize=False,
bias_correction=False,
onnx_format=False,
weight_quantize_type='channel_wise_abs_max',
recon_level=FLAGS.recon_level,
simulate_activation_quant=FLAGS.simulate_activation_quant,
regions=config['regions'],
region_weights_names=config['region_weights_names'],
skip_tensor_list=config['skip_tensor_list']
if 'skip_tensor_list' in config else None,
epochs=20,
lr=0.1)
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
paddle.set_device(FLAGS.devices)
main()
@@ -26,6 +26,7 @@ try:
    from .quanter import quant_aware, convert, quant_post_static, quant_post_dynamic
    from .quanter import quant_post, quant_post_only_weight
    from .quant_aware_with_infermodel import quant_aware_with_infermodel, export_quant_infermodel
    from .reconstruction_quantization import quant_recon_static
    if platform.system().lower() == 'linux':
        from .post_quant_hpo import quant_post_hpo
    else:
@@ -813,4 +813,4 @@ def pact(x, name=None):
def get_pact_optimizer():
    return paddle.fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
\ No newline at end of file
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import math
import os
import re
import shutil
import sys
import time

import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddle.fluid.contrib.slim.quantization import utils

from ..dist import merge
from ..core.graph_wrapper import GraphWrapper
from ..common import get_logger

__all__ = ['ReconstructionQuantization', ]

_logger = get_logger(
    __name__,
    logging.INFO,
    fmt='%(asctime)s-%(levelname)s: %(message)s', )

GAMMA = -0.1
ZETA = 1.1


class Collections(object):
    def __init__(self, **kwargs):
        self._config = dict()
        for k, v in kwargs.items():
            self._config[k] = v

    def _get_config(self):
        return self._config


class ReconstructionQuantization(PostTrainingQuantization):
    """
    Utilizing reconstruction quantization method to quantize the FP32 model,
    and it uses calibrate data to get the quantization information for all
    quantized variables.
    """

    def __init__(self, PTQCollections, RSQCollections):
        '''
        Args:
            PTQCollections(Collections): The parameters set required for post training quantization.
            RSQCollections(Collections): The parameters set required for reconstruction quantization.
        Returns:
            None
        '''
        super().__init__(**PTQCollections._get_config())
        self._config = RSQCollections._get_config()

    def quantize(self):
'''
Load the FP32 model, and use the calibrate data to calculate the forward-stage.
Based on the sample data, we can get the quantization information, and obtain
the final quantized model.
Args:
None
Returns:
the program of quantized model.
'''
self._load_model_data()
self._collect_target_varnames()
self._set_activation_persistable()
if self._algo in ["KL", "hist"]:
self._preparation()
self._sampling_threshold()
self._calculate_threshold()
self._reset_activation_persistable()
self._reconstruction()
self._postprocessing()
return self._program
def _preparation(self):
batch_id = 0
with utils.tqdm(
total=self._batch_nums,
bar_format='Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80, ) as t:
for data in self._data_loader():
self._executor.run(
program=self._program,
feed=data,
fetch_list=self._fetch_list,
return_numpy=False,
scope=self._scope, )
self._collect_activation_abs_min_max()
batch_id += 1
t.update()
if self._batch_nums and batch_id >= self._batch_nums:
break
self._init_sampling_act_histogram()
def _sampling_threshold(self):
batch_id = 0
with utils.tqdm(
total=self._batch_nums,
bar_format='Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80, ) as t:
for data in self._data_loader():
self._executor.run(
program=self._program,
feed=data,
fetch_list=self._fetch_list,
return_numpy=False,
scope=self._scope, )
self._sampling()
batch_id += 1
t.update()
if self._batch_nums and batch_id >= self._batch_nums:
break
def _calculate_threshold(self):
if self._algo == 'avg':
for var_name in self._quantized_act_var_name:
self._quantized_threshold[var_name] = \
np.array(self._quantized_var_avg[var_name]).mean()
self._scale_dict = self._quantized_threshold
elif self._algo in ["KL", "hist"]:
self._calculate_kl_hist_threshold()
self._scale_dict = self._quantized_var_threshold
else:
self._scale_dict = self._quantized_threshold
def _reconstruction(self):
reconstruction_quanter = ReconstructionQuanter(
data_loader=self._data_loader,
fp32_program=self._program,
feed_list=self._feed_list,
fetch_list=self._fetch_list,
exe=self._executor,
scope=self._scope,
place=self._place,
quantized_op_pairs=self._quantized_op_pairs,
weight_quantize_type=self._weight_quantize_type,
scale_dict=copy.deepcopy(self._scale_dict),
regions=self._config['regions'],
region_weights_names=self._config['region_weights_names'],
recon_level=self._config['recon_level'],
simulate_activation_quant=self._config['simulate_activation_quant'],
num_iterations=self._batch_nums,
lr=self._config['lr'],
bias_correction=self._bias_correction,
epochs=self._config['epochs'],
scale_trainable=self._config['scale_trainable'])
self._program = reconstruction_quanter._run()
def _postprocessing(self):
if self._algo is 'min_max':
self._save_input_threhold()
else:
self._update_program()
# save out_threshold for quantized ops.
self._save_output_threshold()
if any(op_type in self._quantizable_op_type
for op_type in self._dynamic_quantize_op_type):
self._collect_dynamic_quantize_op_threshold(
self._dynamic_quantize_op_type, )
# Move sub blocks persistable var to global block
global_block = self._program.global_block()
for _op in global_block.ops:
if _op.type == "while":
_block_id = _op.attr("sub_block").id
_block = self._program.block(_block_id)
persistables = []
for _name, _var in _block.vars.items():
if _var.persistable:
global_block._clone_variable(_var)
persistables.append(_name)
for _name in persistables:
_block._remove_var(_name)
persistables.extend(_op.input('X'))
_op.desc.set_input("X", persistables)
class ReconstructionQuanter(object):
    def __init__(self,
                 data_loader,
                 fp32_program,
@@ -90,16 +212,18 @@ class RoundingOptimizer(object):
                 quantized_op_pairs,
                 weight_quantize_type,
                 scale_dict,
                 regions,
                 region_weights_names,
                 recon_level,
                 simulate_activation_quant,
                 num_iterations=1000,
                 lr=0.1,
                 bias_correction=False,
                 epochs=20,
                 scale_trainable=False,
                 drop_prob=0.5):
        '''
        Reconstruction Quanter, used to optimize the rounding policy
        by reconstructing the intermediate output.
        Args:
@@ -108,44 +232,51 @@ class RoundingOptimizer(object):
                return a batch every time.
            executor(fluid.Executor): The executor to load, run and save the
                quantized model.
            scope(fluid.Scope, optional): The scope of the program, use it to load
                and save variables. If scope=None, get scope by global_scope().
            place(CPUPlace()|CUDAPlace(N)): This parameter represents
                paddle run on which device.
            quantized_op_pairs(dict, optional): Mapping of op's weight name
                and output var name, where key of dict is the weight name of
                op, and value is the output var name of op.
            weight_quantize_type(str): quantization type for weights,
                support 'abs_max' and 'channel_wise_abs_max'. This param only specifies
                the fake ops in saving quantized model, and we save the scale obtained
                by post training quantization in fake ops. Compared to 'abs_max',
                the model accuracy is usually higher when it is 'channel_wise_abs_max'.
            scale_dict(dict, optional): Mapping of var's name and var's scales, where key
                of dict is the var name, and value is the quant scales of var.
            recon_level(str, optional): The type of reconstruction granularity.
                Currently supports ['layer-wise', 'region-wise'] types. Default is layer-wise.
            simulate_activation_quant(bool, optional): Whether to add the noise caused by activation
                quantization during the reconstruction process.
            regions(list[list], optional): The list of some regions, each region is a subgraph of
                the fp32 program and it will have exactly 1 input operation and 1 output operation. When
                the recon_level is region-wise, the reconstruction loss of each region is minimized.
                Default is None.
            region_weights_names(list[list], optional): The weight names inside every region.
                Default is None.
            lr(float, optional): The learning rate of Reconstruction Quanter. Default is 0.1.
            bias_correction(bool, optional): If set as True, use the bias correction
                method of https://arxiv.org/abs/1810.05723. Default is False.
            scale_trainable(bool, optional): Whether the weight's scale is trainable. Default is False.
            drop_prob(float, optional): The dropout probability of activation quantization, and it is
                valid only if simulate_activation_quant is True. Default is 0.5.
        Returns:
            None
        '''
        assert recon_level in [
            'layer-wise', 'region-wise'
        ], "recon_level must be one of the ['layer-wise', 'region-wise'], but received: {}".format(
            recon_level)
        if recon_level == 'region-wise':
            assert regions is not None, "The regions cannot be None."
            assert region_weights_names is not None, "The region_weights_names cannot be None."
        self._simulate_activation_quant = simulate_activation_quant
        self._program = fp32_program
        self._data_loader = data_loader
        self._recon_level = recon_level
        self._feed_list = feed_list
        self._fetch_list = fetch_list
        self._exe = exe
@@ -158,17 +289,19 @@ class RoundingOptimizer(object):
        self._num_iterations = num_iterations
        self._epochs = epochs
        self._lr = lr
        self._regions = regions
        self._region_weights_names = region_weights_names
        self._bias_correction = bias_correction
        if self._recon_level == 'layer-wise':
            regions, region_weights_names = self._get_layers()
            self._regions = regions
            self._region_weights_names = region_weights_names
        self._scale_trainable = scale_trainable
        self._drop_prob = drop_prob
    def _get_layers(self):
        regions = []
        region_weights_names = []
        persistable_var_names = self._all_persistable_var_names()
        self._input_weight_pairs = {}
        for block_id in range(len(self._program.blocks)):
@@ -180,14 +313,14 @@ class RoundingOptimizer(object):
                        self._input_weight_pairs[in_var_name] = in_var_names
                        break
        for name in self._weight_var_names:
            region_weights_names.append([name])
            region_ = []
            region_.append(self._input_weight_pairs[name][0])
            region_.append(self._quantized_op_pairs[name])
            regions.append(region_)
        return regions, region_weights_names
    def _preprocess(self):
        data_name_map = {}
        for name in self._feed_list:
            data_name_map[name] = name
@@ -199,35 +332,51 @@ class RoundingOptimizer(object):
            self._place,
            teacher_scope=None,
            name_prefix="teacher_",
            merge_feed=True, )

        for name in self._weight_var_names:
            weight_np = utils.load_variable_data(self._scope, name)
            scale = self._scale_dict[name]
            weight_np_floor = np.floor(utils.quant_tensor(weight_np, scale))
            utils.set_variable_data(
                self._scope,
                self._place,
                name,
                weight_np_floor, )

        self._graph = GraphWrapper(self._student_program)

        if self._simulate_activation_quant:
            self._insert_drop_quant_dequant()
        self._insert_soft_rounding()
        self._isolate_regions()
    def _run(self):
        self._preprocess()
        startup_program = paddle.static.Program()
        for k in range(len(self._regions)):
            region_ = self._regions[k]
            names = self._region_weights_names[k]
            tmp_program = self._student_program.clone()
            quant_op_out_name = region_[1]
            with paddle.static.program_guard(tmp_program, startup_program):
                loss_function = ReconstructionQuanterLoss(tmp_program, names)
                quant_op_out_name = region_[1]
                student_var = tmp_program.global_block().var(quant_op_out_name)
                teacher_var = tmp_program.global_block().var("teacher_" +
                                                             quant_op_out_name)
                scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
                    learning_rate=20,
                    eta_min=2,
                    T_max=2000,
                    verbose=True, )
                total_loss, recon_loss, round_loss = loss_function.get_loss(
                    student_var,
                    teacher_var,
                    scheduler, )
                train_fetches_loss = {
                    "total_loss": total_loss,
                    "recon_loss": recon_loss,
                    "round_loss": round_loss,
                }
                optimizer = paddle.optimizer.Adam(learning_rate=self._lr)
                optimizer.minimize(total_loss)
@@ -241,11 +390,17 @@ class RoundingOptimizer(object):
                    out = self._exe.run(
                        tmp_program,
                        feed=data,
                        fetch_list=[
                            v.name for v in train_fetches_loss.values()
                        ],
                        return_numpy=True, )
                    _logger.info(
                        "Iter {:d}, lr {}, total_loss {:.5f}, recon_loss {:.5f}, round_loss {:.5f}, time {:.5f}s"
                        .format(epoch, self._lr,
                                np.mean(out[0]),
                                np.mean(out[1]),
                                np.mean(out[2]),
                                start_time - prev_start_time), )
                    sys.stdout.flush()
                    if i == self._num_iterations:
                        break
@@ -255,7 +410,7 @@ class RoundingOptimizer(object):
        return self._program
    def _init_alpha(self, name, scale):
        _tensor = utils.load_variable_data(self._scope, "teacher_" + name)
        tensor_scaled = utils.quant_tensor(_tensor, scale)
        tensor_floor = np.floor(tensor_scaled)
        tensor = tensor_scaled - tensor_floor
@@ -269,31 +424,39 @@ class RoundingOptimizer(object):
            weight: The quanted weight with dtype=float32
        """
        bnt = (1 << (weight_bits - 1)) - 1

        def _dequant(x, scale):
            s = (scale + 1e-8) / bnt
            dequant_x = s * x
            return dequant_x

        quantized_weight = paddle.static.data(
            shape=weight.shape,
            dtype=weight.dtype,
            name=weight.name + '_quant', )

        v = paddle.static.create_parameter(
            shape=weight.shape,
            dtype=weight.dtype,
            name=weight.name + ".alpha",
            default_initializer=fluid.initializer.NumpyArrayInitializer(
                self._alpha, ), )

        h_v = paddle.clip(
            paddle.nn.functional.sigmoid(v) * (ZETA - GAMMA) + GAMMA,
            0,
            1, )

        if self._weight_quantize_type == 'channel_wise_abs_max':
            scale_var = paddle.static.create_parameter(
                dtype=weight.dtype,
                shape=weight.shape,
                name=weight.name + '.scale',
                default_initializer=fluid.initializer.NumpyArrayInitializer(
                    scale, ), )
        else:
            scale_var = scale
        w = _dequant(quantized_weight + h_v, scale_var)
        return w

    def _insert_soft_rounding(self):
@@ -302,26 +465,28 @@ class RoundingOptimizer(object):
            scale = self._scale_dict[name]
            shape = weight.shape()
            self._alpha = self._init_alpha(name, scale)
            if self._weight_quantize_type == 'channel_wise_abs_max':
                scale = np.array(scale)
                scale = scale.reshape(scale.shape[0], 1)
                if len(shape) == 2:
                    scale = scale.repeat(shape[0], axis=0)
                else:
                    scale = scale.repeat(shape[1] * shape[2] * shape[3], axis=1)
                scale = scale.reshape(shape)
            self._insert_func(var=weight, scale=scale, func="_soft_rounding")

    def _drop_quant_dequant(self, inputs, scale, weight_bits=8):
        x = paddle.static.data(
            shape=inputs.shape,
            dtype=inputs.dtype,
            name=inputs.name + '.tmp', )
        bnt = (1 << (weight_bits - 1)) - 1
        scale = scale / bnt
        dequantized_tensor = paddle.round(x / scale) * scale
        quant_noise = x - dequantized_tensor
        random_noise = paddle.nn.functional.dropout(
            quant_noise, p=self._drop_prob)
        return x - random_noise

    def _insert_drop_quant_dequant(self):
        for op in self._graph.ops():
@@ -337,7 +502,10 @@ class RoundingOptimizer(object):
                    else:
                        input = op.inputs("X")[0]
                if input.name() in self._scale_dict.keys():
                    self._insert_func(
                        var=input,
                        scale=self._scale_dict[input.name()],
                        func="_drop_quant_dequant", )
    def _insert_func(self, var, scale, func):
        program = var._graph.program
@@ -346,51 +514,51 @@ class RoundingOptimizer(object):
        startup_program = paddle.static.Program()
        new_program = paddle.static.Program()
        with paddle.static.program_guard(new_program, startup_program):
            if func == "_soft_rounding":
                out = self._soft_rounding(inputs, scale)
            elif func == "_drop_quant_dequant":
                out = self._drop_quant_dequant(inputs, scale)
        self._exe.run(startup_program)

        # create var in program
        for new_var in new_program.list_vars():
            if new_var.name == var._var.name + '_quant' or new_var.name == var._var.name + '.tmp':
                continue
            elif new_var.name == var._var.name + '.alpha':
                program.global_block().create_parameter(
                    name=new_var.name,
                    shape=new_var.shape,
                    dtype=new_var.dtype,
                    type=new_var.type,
                    stop_gradient=new_var.stop_gradient, )
            elif new_var.name == var._var.name + '.scale':
                program.global_block().create_parameter(
                    name=new_var.name,
                    shape=new_var.shape,
                    dtype=new_var.dtype,
                    type=new_var.type,
                    stop_gradient=True,
                    trainable=self._scale_trainable, )
            else:
                if func == "_soft_rounding":
                    program.global_block().create_var(
                        name=new_var.name + '.rounding',
                        shape=new_var.shape,
                        dtype=new_var.dtype,
                        type=new_var.type,
                        persistable=new_var.persistable,
                        stop_gradient=new_var.stop_gradient, )
                else:
                    program.global_block().create_var(
                        name=new_var.name,
                        shape=new_var.shape,
                        dtype=new_var.dtype,
                        type=new_var.type,
                        persistable=new_var.persistable,
                        stop_gradient=new_var.stop_gradient, )

        op_list = new_program.global_block().ops
        op_list = list(reversed(op_list))
        block = var._var.block
        # prepend new_program's op in program
        for _op in ops:
            if _op.type() not in ['conv2d', 'depthwise_conv2d', 'mul']:
                continue
@@ -398,84 +566,96 @@ class RoundingOptimizer(object):
            for op in op_list:
                # _attrs = op.all_attrs()
                _type = op.type
                _attrs = {
                    'use_mkldnn': False,
                    'with_quant_attr': False,
                }
                if _type == 'clip':
                    _attrs = {
                        'use_mkldnn': False,
                        'with_quant_attr': False,
                        'max': op.attr('max'),
                        'min': op.attr('min'),
                    }
                elif _type == 'scale':
                    _attrs = {
                        'use_mkldnn': False,
                        'with_quant_attr': False,
                        'scale': op.attr('scale'),
                        'bias_after_scale': op.attr('bias_after_scale'),
                    }
                elif _type == 'elementwise_mul':
                    _attrs = {
                        'use_mkldnn': False,
                        'with_quant_attr': False,
                        'Scale_out': op.attr('Scale_out'),
                        'Scale_x': op.attr('Scale_x'),
                        'Scale_y': op.attr('Scale_y'),
                        'axis': op.attr('axis'),
                    }

                if func == "_soft_rounding":
                    _outputs = {'Out': op.output('Out')[0] + '.rounding'}
                    if _type == "elementwise_add":
                        _inputs = {
                            # replace tmp var conv.weight_quant with var conv.weight
                            'X': var._var,
                            'Y': op.input('Y')[0] + '.rounding',
                        }
                    elif _type == "elementwise_mul":
                        _inputs = {
                            'X': op.input('X')[0] + '.rounding',
                            'Y': op.input('Y')[0] + '.rounding',
                        }
                    elif (_type == 'scale' and
                          op.input('X')[0].endswith('scale')) or _type == 'sigmoid':
                        _inputs = {'X': op.input('X')[0]}
                    else:
                        _inputs = {'X': op.input('X')[0] + '.rounding'}
                elif func == "_drop_quant_dequant":
                    if _type == 'dropout':
                        _outputs = {
                            'Out': op.output('Out')[0],
                            'Mask': op.output('Mask')[0],
                        }
                    else:
                        _outputs = {'Out': op.output('Out')[0]}

                    if _type == 'elementwise_add' or _type == 'elementwise_sub':
                        _inputs = {
                            # replace tmp var conv.weight_quant with var conv.weight
                            'X': var._var,
                            'Y': op.input('Y'),
                        }
                    elif _type == 'scale' and op.input('X')[0] == inputs.name + '.tmp':
                        _inputs = {'X': var._var}
                    else:
                        _inputs = {'X': op.input('X')[0]}

                block._insert_op(
                    idx,
                    type=_type,
                    attrs=_attrs,
                    inputs=_inputs,
                    outputs=_outputs, )

        for op in ops:
            if op.type() not in ['conv2d', 'depthwise_conv2d', 'mul']:
                continue
            if op.type() in ['conv2d', 'depthwise_conv2d'] and op.inputs(
                    'Filter')[0].name().startswith('teacher'):
                continue
            if op.type() in ['mul'] and op.inputs('Y')[0].name().startswith(
                    'teacher'):
                continue
            if func == '_soft_rounding':
                op._op._rename_input(inputs.name, out.name + '.rounding')
            else:
                op._op._rename_input(inputs.name, out.name)
    def _isolate_regions(self):
        starts = [region[0] for region in self._regions]
        var2duplications = self._duplicate_vars(starts)
        for vars_ in var2duplications.values():
            for var_ in vars_:
@@ -495,49 +675,301 @@ class RoundingOptimizer(object):
        for op in var.outputs():
            var_ = var._var
            op_ = op._op
            duplicated_var = block.create_var(
                name=var_.name + ".assign" + str(index),
                type=var_.type,
                shape=var_.shape,
                dtype=var_.dtype, )
            vars.append(duplicated_var)
            index += 1
            idx = block.ops.index(op_)
            block._insert_op(
                idx,
                type="assign",
                inputs={"X": var_},
                outputs={"Out": duplicated_var}, )
            op_._rename_input(var_.name, duplicated_var.name)
        return vars

    def _update_weights_to_int(self):
        for weight_var_name in self._weight_var_names:
            alpha_tensor = utils.load_variable_data(
                self._scope,
                weight_var_name + '.alpha', )
            h_alpha_tensor = self._compute_soft_rounding_np(alpha_tensor)
            weight_quant_tensor = utils.load_variable_data(
                self._scope,
                weight_var_name, )
            utils.set_variable_data(
                self._scope,
                self._place,
                weight_var_name,
                np.round(weight_quant_tensor + h_alpha_tensor, ), )

    def _bias_correction_w(self):
        for weight_var_name in self._weight_var_names:
            weight_var_tensor = utils.load_variable_data(
                self._scope,
                "teacher_" + weight_var_name, )
            weight_quant_tensor = utils.load_variable_data(
                self._scope,
                weight_var_name, )
            scale = self._scale_dict[weight_var_name]
            final_weight_tensor = utils.bias_correction_w(
                weight_var_tensor,
                weight_quant_tensor,
                scale,
                quant_axis=0,
                weight_bits=8, )
            utils.set_variable_data(
                self._scope,
                self._place,
                weight_var_name,
                final_weight_tensor, )

    def _compute_soft_rounding_np(self, alpha_v):
        return np.clip(
            utils.stable_sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA,
            a_min=0,
            a_max=1, )

    def _all_persistable_var_names(self):
        persistable_var_names = []
        for var in self._program.list_vars():
            if var.persistable:
                persistable_var_names.append(var.name)
        return persistable_var_names
class ReconstructionQuanterLoss(object):
def __init__(self,
program,
weight_region_names=None,
round_loss_type='relaxation',
rec_loss_type='mse',
beta_type='const',
weight=0.1):
"""
The loss function of Rounding Optimizer.
Args:
program(Program): The student program.
weight_region_names(list, optional): The weight names inside a region.
round_loss_type(str): The type of rounding loss function.
rec_loss_type(str): The type of reconstruction loss function.
beta_type(str): The type of hyper-parameter beta.
Returns:
total_loss(Variable): The sum of rounding loss and reconstruction loss.
rec_loss(Variable): The reconstruction loss.
round_loss(Variable): The rounding loss.
"""
self.program = program
self.round_loss_type = round_loss_type
self.weight = weight
self.rec_loss_type = rec_loss_type
self.weight_region_names = weight_region_names
self.beta_type = beta_type
def compute_soft_rounding(self, alpha_v):
return paddle.clip(
paddle.nn.functional.sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA, 0,
1)
def get_loss(self, student_tensor, teacher_tensor, scheduler):
if self.rec_loss_type == 'mse':
rec_loss = paddle.nn.functional.mse_loss(
student_tensor,
teacher_tensor, )
else:
raise ValueError(
'Not supported reconstruction loss function: {}'.format(
self.rec_loss, ), )
if self.beta_type == 'const':
self.beta = 3
else:
self.beta = scheduler.get_lr()
if self.round_loss_type == 'relaxation':
round_loss = 0.0
for name in self.weight_region_names:
alpha_v = self.program.global_block().var(name + '.alpha')
h_v = self.compute_soft_rounding(alpha_v)
round_loss += self.weight * \
paddle.sum(-paddle.pow(paddle.abs(2 * h_v-1), self.beta) + 1)
else:
raise NotImplementedError
total_loss = rec_loss + round_loss
return total_loss, rec_loss, round_loss
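For reference, the 'relaxation' rounding loss above relies on the rectified sigmoid h(alpha) = clip(sigmoid(alpha) * (ZETA - GAMMA) + GAMMA, 0, 1); as beta anneals, 1 - |2h - 1|**beta pushes every h(alpha) toward 0 or 1 so the learned rounding becomes hard. A standalone NumPy sketch of the same two quantities (the names soft_round and round_loss are illustrative, not part of this module):

import numpy as np

GAMMA, ZETA = -0.1, 1.1


def soft_round(alpha):
    # rectified sigmoid: h(alpha) = clip(sigmoid(alpha) * (ZETA - GAMMA) + GAMMA, 0, 1)
    return np.clip(1.0 / (1.0 + np.exp(-alpha)) * (ZETA - GAMMA) + GAMMA, 0.0, 1.0)


def round_loss(alpha, weight=0.1, beta=3.0):
    # relaxation penalty added to the MSE reconstruction loss; pushes h(alpha) to 0 or 1
    h = soft_round(alpha)
    return weight * np.sum(1.0 - np.abs(2.0 * h - 1.0)**beta)


alpha = np.random.randn(4, 3)
print(soft_round(alpha))  # entries lie in [0, 1]
print(round_loss(alpha))  # scalar rounding penalty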
def quant_recon_static(executor,
model_dir,
quantize_model_path,
batch_generator=None,
sample_generator=None,
data_loader=None,
model_filename=None,
params_filename=None,
save_model_filename='model.pdmodel',
save_params_filename='model.pdiparams',
batch_size=1,
batch_nums=None,
scope=None,
algo='hist',
recon_level='layer-wise',
simulate_activation_quant=False,
hist_percent=0.9999,
bias_correction=False,
quantizable_op_type=[
"conv2d",
"depthwise_conv2d",
"mul",
"matmul",
"matmul_v2",
],
is_full_quantize=False,
weight_bits=8,
activation_bits=8,
activation_quantize_type='range_abs_max',
weight_quantize_type='channel_wise_abs_max',
optimize_model=False,
onnx_format=False,
skip_tensor_list=None,
is_use_cache_file=False,
cache_dir="./temp_recon_quantization",
regions=None,
region_weights_names=None,
epochs=20,
scale_trainable=False,
drop_prob=0.5,
lr=0.1):
"""
The function utilizes static post training quantization method to
quantize the fp32 model. It uses calibrate data to calculate the
scale factor of quantized variables, and inserts fake quantization
and dequantization operators to obtain the quantized model.
Args:
executor(paddle.static.Executor): The executor to load, run and save the
quantized model.
model_dir(str): The path of fp32 model that will be quantized, and
the model and params that saved by ``paddle.static.io.save_inference_model``
are under the path.
quantize_model_path(str): The path to save quantized model using api
``paddle.static.io.save_inference_model``.
batch_generator(Python Generator): The batch generator provides
calibrate data for DataLoader, and it returns a batch every
time. For sample_generator and batch_generator, only one
can be set. Besides, batch_generator supports lod tensor.
sample_generator(Python Generator): The sample generator provides
calibrate data for DataLoader, and it only returns a sample every time.
data_loader(Python Generator, Paddle.io.DataLoader, optional): The
Generator or Dataloader provides calibrate data, and it could
return a batch every time.
model_filename(str, optional): The name of model file. If parameters
are saved in separate files, set it as 'None'. Default: 'None'.
params_filename(str, optional): The name of params file.
When all parameters are saved in a single file, set it
as filename. If parameters are saved in separate files,
set it as 'None'. Default : 'None'.
save_model_filename(str): The name of model file to save the quantized inference program. Default: 'model.pdmodel'.
save_params_filename(str): The name of file to save all related parameters.
If it is set None, parameters will be saved in separate files. Default: 'model.pdiparams'.
batch_size(int, optional): The batch size of DataLoader, default is 1.
batch_nums(int, optional): If batch_nums is not None, the number of calibrate
data is 'batch_size*batch_nums'. If batch_nums is None, use all data
generated by sample_generator as calibrate data.
scope(paddle.static.Scope, optional): The scope to run program, use it to load
and save variables. If scope is None, will use paddle.static.global_scope().
algo(str, optional): If algo='KL', use KL-divergenc method to
get the scale factor. If algo='hist', use the hist_percent of histogram
to get the scale factor. If algo='mse', search for the best scale factor which
makes the mse loss minimal. Use one batch of data for mse is enough. If
algo='avg', use the average of abs_max values to get the scale factor. If
algo='abs_max', use abs_max method to get the scale factor. Default: 'hist'.
recon_level(str, optional): The type of reconstruction granularity.
Currently support ['layer-wise', 'region-wise'] types. Default is layer-wise.
simulate_activation_quant(bool, optional): Whether we need the noise caused by activation
quantization during the reconstruction process. Default is False.
hist_percent(float, optional): The percentile of histogram for algo 'hist'. Default: 0.9999.
bias_correction(bool, optional): Bias correction method of https://arxiv.org/abs/1810.05723.
Default: False.
quantizable_op_type(list[str], optional): The list of op types
that will be quantized. Default: ["conv2d", "depthwise_conv2d", "mul"].
weight_bits(int, optional): quantization bit number for weights.
activation_bits(int): quantization bit number for activation.
activation_quantize_type(str): quantization type for activation,
now support 'range_abs_max', 'moving_average_abs_max' and 'abs_max'.
This parameter only specifies the fake ops in quantized model.
If it is 'range_abs_max' or 'moving_average_abs_max', we save the scale
obtained by post training quantization in fake ops. If it
is 'abs_max', the scale will not be saved in fake ops.
weight_quantize_type(str): quantization type for weights,
support 'abs_max' and 'channel_wise_abs_max'. Compared to 'abs_max',
the model accuracy is usually higher when using 'channel_wise_abs_max'.
is_full_quantize(bool): if True, apply quantization to all supported quantizable op type.
If False, only apply quantization to the input quantizable_op_type. Default is False.
optimize_model(bool, optional): If set optimize_model as True, it applies some
passes to optimize the model before quantization. So far, the place of the
executor must be CPU, and it supports fusing batch_norm into convs.
onnx_format(bool): Whether to export the quantized model with format of ONNX. Default is False.
skip_tensor_list(list): List of skip quant tensor name.
is_use_cache_file(bool): This param is deprecated.
cache_dir(str): This param is deprecated.
epochs: The number of steps in the reconstruction process. Default is 20.
scale_trainable: Whether the weight's scale is trainable. Default is False.
drop_prob: The dropout probability of activation quantization, and it is valid only if
simulate_activation_quant is True. Default is 0.5.
regions(list[list], optional): The list of some regions, each region is a subgraph of
fp32 program and it will have exact 1 input operation and 1 output operation. When
the recon-level is region, the reconstruction loss of each region is minimized.
Default is None.
region_weights_names(list[list], optional): The weight names inside every region.
Default is None.
Returns:
None
"""
PTQCollections = Collections(
executor=executor,
sample_generator=sample_generator,
batch_generator=batch_generator,
data_loader=data_loader,
model_dir=model_dir,
model_filename=model_filename,
params_filename=params_filename,
batch_size=batch_size,
batch_nums=batch_nums,
scope=scope,
algo=algo,
hist_percent=hist_percent,
bias_correction=bias_correction,
quantizable_op_type=quantizable_op_type,
is_full_quantize=is_full_quantize,
weight_bits=weight_bits,
activation_bits=activation_bits,
activation_quantize_type=activation_quantize_type,
weight_quantize_type=weight_quantize_type,
onnx_format=onnx_format,
skip_tensor_list=skip_tensor_list,
optimize_model=optimize_model,
round_type='adaround')
RSQCollections = Collections(
recon_level=recon_level,
simulate_activation_quant=simulate_activation_quant,
regions=regions,
region_weights_names=region_weights_names,
epochs=epochs,
scale_trainable=scale_trainable,
lr=lr)
reconstruction_quantization = ReconstructionQuantization(
PTQCollections=PTQCollections, RSQCollections=RSQCollections)
reconstruction_quantization.quantize()
reconstruction_quantization.save_quantized_model(
quantize_model_path,
model_filename=save_model_filename,
params_filename=save_params_filename)
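For reference, a minimal calling sketch of quant_recon_static on an already-exported inference model. The model directory, file names and calibration loader below are placeholders (not part of this patch); only arguments that appear in the signature above are used.

import paddle
from paddleslim.quant import quant_recon_static


def run_recon_ptq(model_dir, data_loader, save_dir='./quant_out'):
    # model_dir: path to a saved inference model, e.g. './mobilenet_infer' (placeholder)
    # data_loader: any iterable yielding feed dicts for calibration
    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())
    quant_recon_static(
        executor=exe,
        model_dir=model_dir,
        quantize_model_path=save_dir,
        data_loader=data_loader,
        model_filename='model.pdmodel',
        params_filename='model.pdiparams',
        batch_nums=10,
        algo='abs_max',
        recon_level='layer-wise',  # use 'region-wise' together with regions / region_weights_names
        simulate_activation_quant=True,
        epochs=20,
        lr=0.1)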
@@ -22,15 +22,15 @@ from models import MobileNet
from layers import conv_bn_layer
import paddle.dataset.mnist as reader
import numpy as np
from paddleslim.quant import quant_recon_static


class TestRoundingOptimizer(StaticCase):
    def __init__(self, *args, **kwargs):
        super(TestRoundingOptimizer, self).__init__(*args, **kwargs)
        paddle.enable_static()
        self._gen_model()

    def _gen_model(self):
        image = paddle.static.data(
            name='image', shape=[None, 1, 28, 28], dtype='float32')
@@ -52,13 +52,15 @@ class TestRoundingOptimizer(StaticCase):
        ) else paddle.CPUPlace()
        exe = paddle.static.Executor(place)
        exe.run(paddle.static.default_startup_program())

        def transform(x):
            return np.reshape(x, [1, 28, 28])

        train_dataset = paddle.vision.datasets.MNIST(
            mode='train', backend='cv2', transform=transform)
        test_dataset = paddle.vision.datasets.MNIST(
            mode='test', backend='cv2', transform=transform)
        self.train_loader = paddle.io.DataLoader(
            train_dataset,
            places=place,
            feed_list=[image, label],
@@ -71,15 +73,18 @@ class TestRoundingOptimizer(StaticCase):
            feed_list=[image, label],
            batch_size=64,
            return_list=False)

        def sample_generator_creator():
            def __reader__():
                for data in test_dataset:
                    image, label = data
                    yield image, label

            return __reader__

        def train(program):
            iter = 0
            for data in self.train_loader():
                cost, top1, top5 = exe.run(
                    program,
                    feed=data,
@@ -89,6 +94,7 @@ class TestRoundingOptimizer(StaticCase):
                print(
                    'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
                    format(iter, cost, top1, top5))

        train(main_prog)
        paddle.fluid.io.save_inference_model(
            dirname='./test_rounding_optimizer',
@@ -98,55 +104,59 @@ class TestRoundingOptimizer(StaticCase):
            executor=exe,
            model_filename='model',
            params_filename='params')

        self.data_loader = sample_generator_creator()

        self._regions = [['image', 'batch_norm_26.tmp_4']]
        self._region_weights_names = [[
            'conv1_weights', 'conv2_1_dw_weights', 'conv2_1_sep_weights',
            'conv2_2_dw_weights', 'conv2_2_sep_weights', 'conv3_1_dw_weights',
            'conv3_1_sep_weights', 'conv3_2_dw_weights', 'conv3_2_sep_weights',
            'conv4_1_dw_weights', 'conv4_1_sep_weights', 'conv4_2_dw_weights',
            'conv4_2_sep_weights', 'conv5_1_dw_weights', 'conv5_1_sep_weights',
            'conv5_2_dw_weights', 'conv5_2_sep_weights', 'conv5_3_dw_weights',
            'conv5_3_sep_weights', 'conv5_4_dw_weights', 'conv5_4_sep_weights',
            'conv5_5_dw_weights', 'conv5_5_sep_weights', 'conv5_6_dw_weights',
            'conv5_6_sep_weights', 'conv6_dw_weights', 'conv6_sep_weights'
        ]]

    def test_qdrop(self):
        place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
        ) else paddle.CPUPlace()
        exe = paddle.static.Executor(place)
        quant_recon_static(
            exe,
            './test_rounding_optimizer',
            quantize_model_path='rsq_out',
            sample_generator=self.data_loader,
            model_filename='model',
            params_filename='params',
            batch_nums=10,
            algo='abs_max',
            regions=self._regions,
            region_weights_names=self._region_weights_names,
            recon_level='region-wise',
            simulate_activation_quant=True)

    def test_qdrop(self):
        place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
        ) else paddle.CPUPlace()
        exe = paddle.static.Executor(place)
        quant_recon_static(
            exe,
            './test_rounding_optimizer',
            quantize_model_path='rsq_out',
            sample_generator=self.data_loader,
            model_filename='model',
            params_filename='params',
            batch_nums=10,
            algo='KL',
            regions=self._regions,
            region_weights_names=self._region_weights_names,
            recon_level='layer-wise',
            simulate_activation_quant=True,
            bias_correction=True)


if __name__ == '__main__':
    unittest.main()
\ No newline at end of file