Unverified commit c7ffb3a3, authored by zhouzj, committed by GitHub

[clean fluid api] replace fluid/contrib/slim api. (#1615)

Parent 00f11040
......@@ -31,7 +31,7 @@ from paddle.metric.metrics import Accuracy
import paddle.vision.models as models
from paddleslim import QAT
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from paddle.quantization import ImperativeQuantAware
from imagenet_dataset import ImageNetDataset
......
......@@ -23,7 +23,7 @@ import numpy as np
import time
import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Quant2Int8MkldnnPass
from paddle.quantization import Quant2Int8MkldnnPass
from paddle.framework import core
paddle.enable_static()
......
......@@ -54,7 +54,7 @@ import numpy as np
To deploy on CPU, we convert the saved quant model with a conversion script that removes the fake_quantize/fake_dequantize ops, performs operator fusion and optimization, and produces the INT8 model.
The script is located at [save_quant_model.py](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/contrib/slim/tests/save_quant_model.py)
The script is located at [save_quant_model.py](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/static/quantization/tests/save_quant_model.py)
Copy the script to the directory of this example (`/PATH_TO_PaddleSlim/demo/mkldnn_quant/`) and run the following command:
```
......@@ -181,4 +181,4 @@ INT8模型精度和性能结果参考[CPU部署预测INT8模型的精度和性
## FAQ
- For deployment and inference of NLP models on CPU, refer to the example [ERNIE model QUANT INT8 accuracy and performance reproduction](https://github.com/PaddlePaddle/benchmark/tree/master/Inference/c++/ernie/mkldnn)
- For details of the DNNL quantization mechanism, see [SLIM Quant for INT8 MKLDNN](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/contrib/slim/tests/README.md)
- For details of the DNNL quantization mechanism, see [SLIM Quant for INT8 MKLDNN](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/static/quantization/tests/README.md)
......@@ -45,7 +45,7 @@ To generate fake quantized model with quant-aware strategy, see [Quant-aware tra
To generate a post-training fake quantized model, see [Offline post-training quantization tutorial](https://paddleslim.readthedocs.io/en/latest/quick_start/index_en.html)
## 3. Convert the fake quantized model to DNNL INT8 model
In order to deploy an INT8 model on the CPU, we need to collect scales, remove all fake_quantize/fake_dequantize operators, optimize the graph and quantize it, turning it into the final DNNL INT8 model. This is done by the script [save_quant_model.py](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/contrib/slim/tests/save_quant_model.py). Copy the script to the directory where the demo is located: `/PATH_TO_PaddleSlim/demo/mkldnn_quant/` and run it as follows:
In order to deploy an INT8 model on the CPU, we need to collect scales, remove all fake_quantize/fake_dequantize operators, optimize the graph and quantize it, turning it into the final DNNL INT8 model. This is done by the script [save_quant_model.py](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/static/quantization/tests/save_quant_model.py). Copy the script to the directory where the demo is located: `/PATH_TO_PaddleSlim/demo/mkldnn_quant/` and run it as follows:
```
python save_quant_model.py --quant_model_path=/PATH/TO/SAVE/FLOAT32/quant/MODEL --int8_model_save_path=/PATH/TO/SAVE/INT8/MODEL
```
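For reference, here is a minimal sketch of loading the converted INT8 model with the Paddle Inference Python API and running it on CPU with DNNL enabled; the model directory, thread count and input shape are illustrative assumptions, not part of this demo:
```
import numpy as np
from paddle.inference import Config, create_predictor

# Point Config at the directory produced by save_quant_model.py
# ("/PATH/TO/SAVE/INT8/MODEL" is a placeholder).
config = Config("/PATH/TO/SAVE/INT8/MODEL")
config.disable_gpu()                         # run on CPU
config.enable_mkldnn()                       # use the DNNL (oneDNN) kernels
config.set_cpu_math_library_num_threads(4)   # assumed thread count

predictor = create_predictor(config)
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
input_handle.copy_from_cpu(
    np.random.rand(1, 3, 224, 224).astype("float32"))  # assumed input shape

predictor.run()
output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
print(output_handle.copy_to_cpu().shape)
```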
......@@ -176,4 +176,4 @@ For INT8 models accuracy and performance results see [CPU deployment predicts th
## FAQ
- For deploying INT8 NLP models on CPU, see [ERNIE model quant INT8 accuracy and performance reproduction](https://github.com/PaddlePaddle/benchmark/tree/master/Inference/c++/ernie/mkldnn)
- The detailed DNNL quantization process is described in [SLIM quant for INT8 DNNL](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/contrib/slim/tests/README.md)
- The detailed DNNL quantization process is described in [SLIM quant for INT8 DNNL](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/static/quantization/tests/README.md)
......@@ -35,9 +35,9 @@ config->EnableTensorRtEngine(1 << 20 /* workspace_size*/,
false /* use_calib_mode*/);
```
- If the quantized model is deployed on x86, use [INT8 MKL-DNN](https://github.com/PaddlePaddle/Paddle/tree/develop/python/paddle/fluid/contrib/slim/tests)
- If the quantized model is deployed on x86, use [INT8 MKL-DNN](https://github.com/PaddlePaddle/Paddle/tree/develop/python/paddle/static/quantization/tests)
- First convert the model; refer to this [script](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/contrib/slim/tests/save_quant_model.py)
- First convert the model; refer to this [script](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/static/quantization/tests/save_quant_model.py)
- After conversion, the model can be loaded with the inference deployment API, e.g. the [c++ API](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/advanced_guide/inference_deployment/inference/native_infer.html); a Python configuration sketch is given below.
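For illustration, a minimal Python counterpart of the configuration described above, covering both deployment targets; the model file names and the `use_gpu` switch are placeholders/assumptions:
```
from paddle.inference import Config, PrecisionType, create_predictor

# Placeholder model files; for CPU deployment use the converted model.
config = Config("model.pdmodel", "model.pdiparams")

use_gpu = True  # assumption: choose the deployment target here
if use_gpu:
    # GPU: run the quantized model with the TensorRT INT8 engine
    config.enable_use_gpu(100, 0)
    config.enable_tensorrt_engine(
        workspace_size=1 << 20,
        max_batch_size=1,
        min_subgraph_size=3,
        precision_mode=PrecisionType.Int8,
        use_static=False,
        use_calib_mode=False)
else:
    # x86 CPU: load the model converted by save_quant_model.py and use DNNL
    config.disable_gpu()
    config.enable_mkldnn()

predictor = create_predictor(config)
```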
......
......@@ -43,7 +43,7 @@ import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddle.quantization import PostTrainingQuantization
from paddleslim.quant.analysis_qat import AnalysisQAT
paddle.enable_static()
......
......@@ -276,7 +276,7 @@ print(f"Operators in inference model:\n{op_types.keys()}")
Run the following code to see which OP types are supported by the quantization functionality of the current PaddlePaddle version:
```
from paddle.fluid.contrib.slim.quantization.utils import _weight_supported_quantizable_op_type, _act_supported_quantizable_op_type
from paddle.static.quantization.utils import _weight_supported_quantizable_op_type, _act_supported_quantizable_op_type
print(f"_supported_quantizable_op_type:\n{_weight_supported_quantizable_op_type}")
print(f"_supported_quantizable_op_type:\n{_act_supported_quantizable_op_type}")
```
......
......@@ -19,9 +19,11 @@ import paddle
import paddleslim
import subprocess
import time
import ssl
import requests
import shutil
import logging
from ..common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
__all__ = [
"save_cls_model", "save_det_model", "nearest_interpolate", "opt_model",
"load_predictor"
......@@ -36,7 +38,7 @@ def _get_download(url, fullname):
try:
req = requests.get(url, stream=True)
except Exception as e: # requests.exceptions.ConnectionError
logger.info("Downloading {} from {} failed with exception {}".format(
_logger.info("Downloading {} from {} failed with exception {}".format(
fname, url, str(e)))
return False
......
......@@ -804,7 +804,7 @@ class AutoCompression:
else:
logging_iter = train_config.logging_iter
if batch_id % int(logging_iter) == 0:
print_info = "Total iter: {}, epoch: {}, batch: {}, loss: {}".format(
print_info = "Total iter: {}, epoch: {}, batch: {}, loss: {} ".format(
total_train_iter, epoch_id, batch_id, loss[0])
for idx, loss_value in enumerate(loss[1:]):
print_info += '{}: {} '.format(loss_names[idx],
......
......@@ -2,10 +2,10 @@ import os
import paddle
from paddle.fluid.framework import IrGraph
from paddle.framework import core
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass, QuantizationTransformPassV2, AddQuantDequantPass, AddQuantDequantPassV2, QuantizationFreezePass, QuantWeightPass
from paddle.static.quantization import QuantizationTransformPass, QuantizationTransformPassV2, AddQuantDequantPass, AddQuantDequantPassV2, QuantizationFreezePass, QuantWeightPass
try:
from paddle.fluid.contrib.slim.quantization import utils
from paddle.static.quantization import utils
TRANSFORM_PASS_OP_TYPES = utils._weight_supported_quantizable_op_type
QUANT_DEQUANT_PASS_OP_TYPES = utils._act_supported_quantizable_op_type
except:
......
......@@ -17,11 +17,16 @@ import logging
import paddle
import paddle.nn as nn
import paddle.fluid.contrib.slim.quantization as Q
from paddle.fluid.contrib.slim.quantization import AbsmaxQuantizer
from paddle.fluid.contrib.slim.quantization import HistQuantizer
from paddle.fluid.contrib.slim.quantization import KLQuantizer
from paddle.fluid.contrib.slim.quantization import PerChannelAbsmaxQuantizer
from paddle.quantization import (
PTQConfig,
ImperativePTQ,
AbsmaxQuantizer,
HistQuantizer,
KLQuantizer,
PerChannelAbsmaxQuantizer,
SUPPORT_ACT_QUANTIZERS,
SUPPORT_WT_QUANTIZERS, )
from ...common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
......@@ -56,14 +61,14 @@ class PTQ(object):
print("activation_quantizer", activation_quantizer)
activation_quantizer = eval(activation_quantizer)(**kwargs)
weight_quantizer = eval(weight_quantizer)()
assert isinstance(activation_quantizer, tuple(Q.SUPPORT_ACT_QUANTIZERS))
assert isinstance(weight_quantizer, tuple(Q.SUPPORT_WT_QUANTIZERS))
assert isinstance(activation_quantizer, tuple(SUPPORT_ACT_QUANTIZERS))
assert isinstance(weight_quantizer, tuple(SUPPORT_WT_QUANTIZERS))
quant_config = Q.PTQConfig(
quant_config = PTQConfig(
activation_quantizer=activation_quantizer,
weight_quantizer=weight_quantizer)
self.ptq = Q.ImperativePTQ(quant_config=quant_config)
self.ptq = ImperativePTQ(quant_config=quant_config)
def quantize(self, model, inplace=False, fuse=False, fuse_list=None):
"""
......
......@@ -203,7 +203,7 @@ class QAT(object):
# TODO: remove try-except when the version is stable
try:
self.imperative_qat = paddle.fluid.contrib.slim.quantization.ImperativeQuantAware(
self.imperative_qat = paddle.quantization.ImperativeQuantAware(
weight_bits=self.config['weight_bits'],
activation_bits=self.config['activation_bits'],
weight_quantize_type=self.config['weight_quantize_type'],
......@@ -220,7 +220,7 @@ class QAT(object):
onnx_format=self.config['onnx_format'], # support Paddle >= 2.4
)
except:
self.imperative_qat = paddle.fluid.contrib.slim.quantization.ImperativeQuantAware(
self.imperative_qat = paddle.quantization.ImperativeQuantAware(
weight_bits=self.config['weight_bits'],
activation_bits=self.config['activation_bits'],
weight_quantize_type=self.config['weight_quantize_type'],
......@@ -291,7 +291,7 @@ class QAT(object):
def _remove_preprocess(self, model):
state_dict = model.state_dict()
try:
self.imperative_qat = paddle.fluid.contrib.slim.quantization.ImperativeQuantAware(
self.imperative_qat = paddle.quantization.ImperativeQuantAware(
weight_bits=self.config['weight_bits'],
activation_bits=self.config['activation_bits'],
weight_quantize_type=self.config['weight_quantize_type'],
......@@ -302,7 +302,7 @@ class QAT(object):
onnx_format=self.config['onnx_format'], # support Paddle >= 2.4
)
except:
self.imperative_qat = paddle.fluid.contrib.slim.quantization.ImperativeQuantAware(
self.imperative_qat = paddle.quantization.ImperativeQuantAware(
weight_bits=self.config['weight_bits'],
activation_bits=self.config['activation_bits'],
weight_quantize_type=self.config['weight_quantize_type'],
......
......@@ -165,7 +165,7 @@ class AnalysisPTQ(object):
_logger.info('Activation Statistic is saved in {}'.format(save_path))
def create_ptq(self, executor, skip_tensor_list):
return paddle.fluid.contrib.slim.quantization.PostTrainingQuantization(
return paddle.static.quantization.PostTrainingQuantization(
executor=executor,
data_loader=self.data_loader,
model_dir=self.model_dir,
......@@ -331,7 +331,7 @@ class AnalysisPTQ(object):
def collect_vars(self, scope, var_names):
all_vars = {}
for var_name in var_names:
var_tensor = paddle.fluid.contrib.slim.quantization.utils.load_variable_data(
var_tensor = paddle.static.quantization.utils.load_variable_data(
scope, var_name)
all_vars[var_name] = var_tensor
return all_vars
......@@ -446,7 +446,7 @@ class AnalysisPTQ(object):
for op_name in weight_names:
for block_id in range(len(program.blocks)):
for op in program.blocks[block_id].ops:
var_name_list = paddle.fluid.contrib.slim.quantization.utils._get_op_input_var_names(
var_name_list = paddle.static.quantization.utils._get_op_input_var_names(
op)
if op_name in var_name_list:
for var_name in var_name_list:
......
......@@ -68,8 +68,8 @@ class QuantConfig(object):
eval_function=None,
model_filename=None,
params_filename=None,
save_model_filename='__model__',
save_params_filename='__params__',
save_model_filename='model.pdmodel',
save_params_filename='model.pdiparams',
scope=None,
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False,
......@@ -190,14 +190,14 @@ def eval_quant_model():
quant_scope = paddle.static.Scope()
with paddle.static.scope_guard(float_scope):
[infer_prog_float, feed_target_names_float, fetch_targets_float] = \
paddle.fluid.io.load_inference_model(dirname=g_quant_config.float_infer_model_path, \
paddle.static.load_inference_model(path_prefix=g_quant_config.float_infer_model_path, \
model_filename=g_quant_config.model_filename, \
params_filename=g_quant_config.params_filename, \
executor=g_quant_config.executor)
with paddle.static.scope_guard(quant_scope):
[infer_prog_quant, feed_target_names_quant, fetch_targets_quant] = \
paddle.fluid.io.load_inference_model(dirname=g_quant_model_cache_path, \
paddle.static.load_inference_model(path_prefix=g_quant_model_cache_path, \
model_filename=g_quant_config.save_model_filename, \
params_filename=g_quant_config.save_params_filename, \
executor=g_quant_config.executor)
......@@ -304,7 +304,7 @@ def quantize(cfg):
quant_scope = paddle.static.Scope()
with paddle.static.scope_guard(float_scope):
[float_inference_program, feed_target_names, fetch_targets]= paddle.static.load_inference_model( \
dirname=g_quant_config.float_infer_model_path, \
path_prefix=g_quant_config.float_infer_model_path, \
model_filename=g_quant_config.model_filename, params_filename=g_quant_config.params_filename,
executor=g_quant_config.executor)
float_metric = g_quant_config.eval_function(
......@@ -313,7 +313,7 @@ def quantize(cfg):
with paddle.static.scope_guard(quant_scope):
[quant_inference_program, feed_target_names, fetch_targets] = paddle.static.load_inference_model( \
dirname=g_quant_model_cache_path, \
path_prefix=g_quant_model_cache_path, \
model_filename=g_quant_config.model_filename, params_filename=g_quant_config.params_filename,
executor=g_quant_config.executor)
quant_metric = g_quant_config.eval_function(
......@@ -344,8 +344,8 @@ def quant_post_hpo(
eval_function=None,
model_filename=None,
params_filename=None,
save_model_filename='__model__',
save_params_filename='__params__',
save_model_filename='model.pdmodel',
save_params_filename='model.pdiparams',
scope=None,
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False,
......@@ -388,9 +388,8 @@ def quant_post_hpo(
When all parameters are saved in a single file, set it
as filename. If parameters are saved in separate files,
set it as 'None'. Default : 'None'.
save_model_filename(str): The name of model file to save the quantized inference program. Default: '__model__'.
save_params_filename(str): The name of file to save all related parameters.
If it is set None, parameters will be saved in separate files. Default: '__params__'.
save_model_filename(str): The name of model file to save the quantized inference program. Default: 'model.pdmodel'.
save_params_filename(str): The name of file to save all related parameters. Default: 'model.pdiparams'.
scope(paddle.static.Scope, optional): The scope to run program, use it to load
and save variables. If scope is None, will use paddle.static.global_scope().
quantizable_op_type(list[str], optional): The list of op types
......
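As a hedged usage sketch of `quant_post_hpo` with the defaults discussed above: the leading positional arguments (executor, place, float model directory, save directory) and the dummy sample generator are assumptions for illustration and should be checked against the signature in your PaddleSlim version.
```
import numpy as np
import paddle
from paddleslim.quant import quant_post_hpo

paddle.enable_static()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)

def sample_generator():
    # placeholder: yield samples matching the model's input spec
    for _ in range(16):
        yield np.random.random((3, 224, 224)).astype('float32')

quant_post_hpo(
    exe,
    place,
    './float_model',           # placeholder FP32 inference model directory
    './quant_model_hpo',       # placeholder output directory
    train_sample_generator=sample_generator,
    eval_sample_generator=sample_generator,
    model_filename='model.pdmodel',
    params_filename='model.pdiparams',
    save_model_filename='model.pdmodel',
    save_params_filename='model.pdiparams',
    runcount_limit=2)
```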
......@@ -21,15 +21,14 @@ import paddle
from paddle.framework import core
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import WeightQuantization
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
from paddle.static.quantization import WeightQuantization
from paddle.static.quantization import QuantizationTransformPass
from paddle.static.quantization import QuantizationFreezePass
from paddle.static.quantization import ConvertToInt8Pass
from paddle.static.quantization import PostTrainingQuantization
from paddle.static.quantization import AddQuantDequantPass
from paddle.static.quantization import OutScaleForTrainingPass
from paddle.static.quantization import OutScaleForInferencePass
from ..common import get_logger
from ..common.patterns import get_patterns
from ..common.patterns_common import has_trainable_var, get_weight
......@@ -37,11 +36,11 @@ from ..core.graph_wrapper import GraphWrapper
_logger = get_logger(__name__, level=logging.INFO)
try:
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPassV2
from paddle.fluid.contrib.slim.quantization import QuantWeightPass
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPassV2
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantizationProgram
from paddle.fluid.contrib.slim.quantization import AddQuantDequantForInferencePass
from paddle.static.quantization import QuantizationTransformPassV2
from paddle.static.quantization import QuantWeightPass
from paddle.static.quantization import AddQuantDequantPassV2
from paddle.static.quantization import PostTrainingQuantizationProgram
from paddle.static.quantization import AddQuantDequantForInferencePass
except:
_logger.warning(
"Some functions fail to import, please update PaddlePaddle version to 2.4+"
......@@ -62,7 +61,7 @@ ACTIVATION_QUANTIZATION_TYPES_TENSORRT = [
VALID_DTYPES = ['int8']
try:
from paddle.fluid.contrib.slim.quantization import utils
from paddle.static.quantization import utils
TRANSFORM_PASS_OP_TYPES = utils._weight_supported_quantizable_op_type
QUANT_DEQUANT_PASS_OP_TYPES = utils._act_supported_quantizable_op_type
except:
......@@ -383,6 +382,7 @@ def quant_aware(program,
**calib_config)
main_graph = post_training_quantization.quantize()
scale_dict = post_training_quantization._scale_dict
sub_graphs = [sub_graph for sub_graph in main_graph.all_sub_graphs()]
else:
main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)
sub_graphs = [sub_graph for sub_graph in main_graph.all_sub_graphs()]
......@@ -582,7 +582,7 @@ def quant_post_static(executor,
sample_generator=sample_generator,
batch_generator=batch_generator,
data_loader=data_loader,
model_dir=model_dir,
model_dir=model_dir.rstrip('/'),
model_filename=model_filename,
params_filename=params_filename,
batch_size=batch_size,
......@@ -607,7 +607,7 @@ def quant_post_static(executor,
sample_generator=sample_generator,
batch_generator=batch_generator,
data_loader=data_loader,
model_dir=model_dir,
model_dir=model_dir.rstrip('/'),
model_filename=model_filename,
params_filename=params_filename,
batch_size=batch_size,
......@@ -744,10 +744,10 @@ def convert(program,
def quant_post_dynamic(model_dir,
save_model_dir,
model_filename=None,
params_filename=None,
save_model_filename=None,
save_params_filename=None,
model_filename,
params_filename,
save_model_filename='model.pdmodel',
save_params_filename='model.pdiparams',
quantizable_op_type=["conv2d", "mul"],
weight_bits=8,
generate_test_model=False):
......@@ -764,22 +764,15 @@ def quant_post_dynamic(model_dir,
model_dir(str): The path of the fp32 model that will be quantized,
and the model and params files are under the path.
save_model_dir(str): The path to save the quantized model.
model_filename(str, optional): The name of file used to load the
inference program. If it is None, the default filename
'__model__' will be used. Default is 'None'.
params_filename(str, optional): The name of file used to load all
parameters. When all parameters were saved in a single
binary file, set it as the real filename. If parameters
were saved in separate files, set it as 'None'. Default is
'None'.
model_filename(str): The name of file used to load the
inference program.
params_filename(str): The name of file used to load all
parameters.
save_model_dir(str): The path used to save the quantized model.
save_model_filename(str, optional): The name of file to
save the inference program. If it is None, the default
filename '__model__' will be used. Default is 'None'.
save the inference program. Default is 'model.pdmodel'.
save_params_filename(str, optional): The name of file to
save all parameters. If it is None, parameters were
saved in separate files. If it is not None, all
parameters were saved in a single binary file.
save all parameters. Default is 'model.pdiparams'.
quantizable_op_type(list[str], optional): The list of ops
that will be quantized, and the quantized ops should be
contained in ["conv2d", "depthwise_conv2d", "mul"].
......
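For reference, a usage sketch of `quant_post_dynamic` matching the updated unit test later in this commit; the directories are placeholders and the model is assumed to be saved as model.pdmodel / model.pdiparams:
```
import paddle
from paddleslim.quant import quant_post_dynamic

paddle.enable_static()

quant_post_dynamic(
    model_dir='./test_quant_post_dynamic',        # FP32 inference model directory
    save_model_dir='./test_quant_post_inference',
    model_filename='model.pdmodel',
    params_filename='model.pdiparams',
    generate_test_model=True)
```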
......@@ -13,19 +13,17 @@
# limitations under the License.
import copy
import logging
import math
import os
import re
import shutil
import sys
import time
import numpy as np
import paddle
from paddle.static.quantization import utils
from paddle.static.quantization import PostTrainingQuantization
from ..dist import merge
from ..core.graph_wrapper import GraphWrapper
from ..common import get_logger, recover_program
from ..common import get_logger
__all__ = ['ReconstructionQuantization', ]
......@@ -48,8 +46,7 @@ class Collections(object):
return self._config
class ReconstructionQuantization(
paddle.fluid.contrib.slim.quantization.PostTrainingQuantization):
class ReconstructionQuantization(PostTrainingQuantization):
"""
Utilizing reconstruction quantization method to quantize the FP32 model,
and it uses calibrate data to get the quantization information for all
......@@ -92,7 +89,7 @@ class ReconstructionQuantization(
def _preparation(self):
batch_id = 0
with paddle.fluid.contrib.slim.quantization.utils.tqdm(
with utils.tqdm(
total=self._batch_nums,
bar_format='Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80, ) as t:
......@@ -112,7 +109,7 @@ class ReconstructionQuantization(
def _sampling_threshold(self):
batch_id = 0
with paddle.fluid.contrib.slim.quantization.utils.tqdm(
with utils.tqdm(
total=self._batch_nums,
bar_format='Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80, ) as t:
......@@ -175,7 +172,7 @@ class ReconstructionQuantization(
self._quantized_threshold = self._scale_dict
def _postprocessing(self):
if self._algo is 'min_max':
if self._algo == 'min_max':
self._save_input_threhold()
else:
self._update_program()
......@@ -323,8 +320,7 @@ class ReconstructionQuanter(object):
self._input_weight_pairs = {}
for block_id in range(len(self._program.blocks)):
for op in self._program.blocks[block_id].ops:
in_var_names = paddle.fluid.contrib.slim.quantization.utils._get_op_input_var_names(
op)
in_var_names = utils._get_op_input_var_names(op)
for in_var_name in in_var_names:
if in_var_name in persistable_var_names:
in_var_names.remove(in_var_name)
......@@ -435,14 +431,14 @@ class ReconstructionQuanter(object):
return self._program, self._scale_dict
def _init_alpha(self, name, scale):
_tensor = paddle.fluid.contrib.slim.quantization.utils.load_variable_data(
_tensor = paddle.static.quantization.utils.load_variable_data(
self._scope, "teacher_" + name)
tensor_scaled = paddle.fluid.contrib.slim.quantization.utils.quant_tensor(
tensor_scaled = paddle.static.quantization.utils.quant_tensor(
x=_tensor,
scale=scale,
weight_bits=self._weight_bits,
quant_axis=0 if self._weight_op_pairs[name] not in paddle.fluid.
contrib.slim.quantization.utils._channelwise_quant_axis1_ops else 1)
quant_axis=0 if self._weight_op_pairs[name] not in
utils._channelwise_quant_axis1_ops else 1)
tensor_floor = np.floor(tensor_scaled)
tensor = tensor_scaled - tensor_floor
alpha = -np.log((ZETA - GAMMA) / (tensor - GAMMA) - 1)
......@@ -744,11 +740,10 @@ class ReconstructionQuanter(object):
if self._skip_tensor_list is not None and _name in self._skip_tensor_list:
continue
scale_name = _name + '.scale'
scale_tensor = paddle.fluid.contrib.slim.quantization.utils.load_variable_data(
self._scope, scale_name)
scale_tensor = utils.load_variable_data(self._scope, scale_name)
scale_list = []
if self._weight_op_pairs[
_name] in paddle.fluid.contrib.slim.quantization.utils._channelwise_quant_axis1_ops:
_name] in utils._channelwise_quant_axis1_ops:
scale_list = list(scale_tensor[0])
else:
for i in range(scale_tensor.shape[0]):
......@@ -759,23 +754,21 @@ class ReconstructionQuanter(object):
for weight_var_name in self._weight_var_names:
if self._skip_tensor_list is not None and weight_var_name in self._skip_tensor_list:
continue
alpha_tensor = paddle.fluid.contrib.slim.quantization.utils.load_variable_data(
alpha_tensor = utils.load_variable_data(
self._scope,
weight_var_name + '.alpha', )
h_alpha_tensor = self._compute_soft_rounding_np(alpha_tensor)
weight_tensor = paddle.fluid.contrib.slim.quantization.utils.load_variable_data(
weight_tensor = utils.load_variable_data(
self._scope,
weight_var_name, )
weight_quant_tensor = paddle.fluid.contrib.slim.quantization.utils.quant_tensor(
weight_quant_tensor = utils.quant_tensor(
x=weight_tensor,
scale=self._scale_dict[weight_var_name],
weight_bits=self._weight_bits,
quant_axis=0
if self._weight_op_pairs[weight_var_name] not in paddle.fluid.
contrib.slim.quantization.utils._channelwise_quant_axis1_ops
else 1)
quant_axis=0 if self._weight_op_pairs[weight_var_name] not in
utils._channelwise_quant_axis1_ops else 1)
paddle.fluid.contrib.slim.quantization.utils.set_variable_data(
utils.set_variable_data(
self._scope,
self._place,
weight_var_name,
......@@ -783,23 +776,21 @@ class ReconstructionQuanter(object):
def _bias_correction_w(self):
for weight_var_name in self._weight_var_names:
weight_var_tensor = paddle.fluid.contrib.slim.quantization.utils.load_variable_data(
weight_var_tensor = utils.load_variable_data(
self._scope,
"teacher_" + weight_var_name, )
weight_quant_tensor = paddle.fluid.contrib.slim.quantization.utils.load_variable_data(
weight_quant_tensor = utils.load_variable_data(
self._scope,
weight_var_name, )
scale = self._scale_dict[weight_var_name]
final_weight_tensor = paddle.fluid.contrib.slim.quantization.utils.bias_correction_w(
final_weight_tensor = utils.bias_correction_w(
weight_var_tensor,
weight_quant_tensor,
scale,
quant_axis=0
if self._weight_op_pairs[weight_var_name] not in paddle.fluid.
contrib.slim.quantization.utils._channelwise_quant_axis1_ops
else 1,
quant_axis=0 if self._weight_op_pairs[weight_var_name] not in
utils._channelwise_quant_axis1_ops else 1,
weight_bits=self._weight_bits, )
paddle.fluid.contrib.slim.quantization.utils.set_variable_data(
utils.set_variable_data(
self._scope,
self._place,
weight_var_name,
......@@ -807,8 +798,7 @@ class ReconstructionQuanter(object):
def _compute_soft_rounding_np(self, alpha_v):
return np.clip(
paddle.fluid.contrib.slim.quantization.utils.stable_sigmoid(alpha_v)
* (ZETA - GAMMA) + GAMMA,
utils.stable_sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA,
a_min=0,
a_max=1, )
......@@ -1206,7 +1196,7 @@ def quant_recon_static(executor,
sample_generator=sample_generator,
batch_generator=batch_generator,
data_loader=data_loader,
model_dir=model_dir,
model_dir=model_dir.rstrip('/'),
model_filename=model_filename,
params_filename=params_filename,
batch_size=batch_size,
......
......@@ -7,7 +7,7 @@ import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddle.static.quantization import PostTrainingQuantization
from paddleslim.quant.analysis_qat import AnalysisQAT
paddle.enable_static()
......
......@@ -9,7 +9,7 @@ from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddleslim.quant.analysis_qat import AnalysisQAT
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddle.static.quantization import PostTrainingQuantization
paddle.enable_static()
......
......@@ -116,7 +116,7 @@ class ModelCase5(paddle.nn.Layer):
anchors=anchors,
conf_thresh=0.01,
downsample_ratio=32)
out = paddle.fluid.layers.matrix_nms(
out = paddle.vision.ops.matrix_nms(
bboxes=boxes,
scores=scores,
background_label=0,
......@@ -125,7 +125,7 @@ class ModelCase5(paddle.nn.Layer):
nms_top_k=400,
keep_top_k=200,
normalized=False)
box, var = paddle.fluid.layers.prior_box(
box, var = paddle.vision.ops.prior_box(
input=image, image=image, min_sizes=[2.], clip=True, flip=True)
return boxes, scores, box, var, out
......@@ -185,7 +185,7 @@ class ModelCase7(paddle.nn.Layer):
anchors=anchors,
conf_thresh=0.01,
downsample_ratio=32)
box, var = paddle.fluid.layers.prior_box(
box, var = paddle.vision.ops.prior_box(
input=image, image=image, min_sizes=[2.], clip=True, flip=True)
return boxes, scores, box, var
......
......@@ -116,14 +116,12 @@ class TestQuantPostHpoCase1(StaticCase):
train(main_prog)
top1_1, top5_1 = test(val_prog)
paddle.fluid.io.save_inference_model(
dirname='./test_quant_post_hpo',
feeded_var_names=[image.name, label.name],
target_vars=[avg_cost, acc_top1, acc_top5],
main_program=val_prog,
paddle.static.save_inference_model(
path_prefix='./test_quant_post_hpo/model',
feed_vars=[image, label],
fetch_vars=[avg_cost, acc_top1, acc_top5],
executor=exe,
model_filename='model',
params_filename='params')
program=val_prog)
quant_post_hpo(
exe,
......@@ -132,16 +130,13 @@ class TestQuantPostHpoCase1(StaticCase):
"./test_quant_post_hpo_inference",
train_sample_generator=sample_generator_creator(),
eval_sample_generator=sample_generator_creator(),
model_filename="model",
params_filename="params",
save_model_filename='__model__',
save_params_filename='__params__',
model_filename="model.pdmodel",
params_filename="model.pdiparams",
save_model_filename='model.pdmodel',
save_params_filename='model.pdiparams',
runcount_limit=2)
quant_post_prog, feed_target_names, fetch_targets = paddle.fluid.io.load_inference_model(
dirname='./test_quant_post_hpo_inference',
executor=exe,
model_filename='__model__',
params_filename='__params__')
quant_post_prog, feed_target_names, fetch_targets = paddle.static.load_inference_model(
path_prefix='./test_quant_post_hpo_inference/model', executor=exe)
top1_2, top5_2 = test(quant_post_prog, fetch_targets)
print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1))
print("after quantization: top1: {}, top5: {}".format(top1_2, top5_2))
......
......@@ -101,23 +101,22 @@ class TestQuantPostOnlyWeightCase1(StaticCase):
train(main_prog)
top1_1, top5_1 = test(val_prog)
paddle.fluid.io.save_inference_model(
dirname='./test_quant_post_dynamic',
feeded_var_names=[image.name, label.name],
target_vars=[avg_cost, acc_top1, acc_top5],
main_program=val_prog,
paddle.static.save_inference_model(
path_prefix='./test_quant_post_dynamic/model',
feed_vars=[image, label],
fetch_vars=[avg_cost, acc_top1, acc_top5],
executor=exe,
model_filename='model',
params_filename='params')
program=val_prog)
quant_post_dynamic(
model_dir='./test_quant_post_dynamic',
save_model_dir='./test_quant_post_inference',
model_filename='model',
params_filename='params',
model_filename='model.pdmodel',
params_filename='model.pdiparams',
generate_test_model=True)
quant_post_prog, feed_target_names, fetch_targets = paddle.fluid.io.load_inference_model(
dirname='./test_quant_post_inference/test_model', executor=exe)
quant_post_prog, feed_target_names, fetch_targets = paddle.static.load_inference_model(
path_prefix='./test_quant_post_inference/test_model/model',
executor=exe)
top1_2, top5_2 = test(quant_post_prog, fetch_targets)
print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1))
print("after quantization: top1: {}, top5: {}".format(top1_2, top5_2))
......