未验证 提交 8e1691b4 编写于 作者: C Chang Xu 提交者: GitHub

Add Tests for Analysis & Support EvalFunc is None (#1574)

上级 ef6a8f25
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
| model_dir | 必须传入的模型文件路径,可为文件夹名;若模型为ONNX类型,直接输入'.onnx'模型文件名称即可 | | model_dir | 必须传入的模型文件路径,可为文件夹名;若模型为ONNX类型,直接输入'.onnx'模型文件名称即可 |
| model_filename | 默认为None,若model_dir为文件夹名,则必须传入以'.pdmodel'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入 | | model_filename | 默认为None,若model_dir为文件夹名,则必须传入以'.pdmodel'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入 |
| params_filename | 默认为None,若model_dir为文件夹名,则必须传入以'.pdiparams'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入 | | params_filename | 默认为None,若model_dir为文件夹名,则必须传入以'.pdiparams'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入 |
| eval_function | 若需要验证精度,需要传入自定义的验证函数 | | eval_function | 若需要验证精度,需要传入自定义的验证函数;若不传入,精度误差分析将根据Cosine Similarity计算得出 |
| data_loader | 模型校准时使用的数据,DataLoader继承自`paddle.io.DataLoader`。可以直接使用模型套件中的DataLoader,或者根据[paddle.io.DataLoader](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/io/DataLoader_cn.html#dataloader)自定义所需要的DataLoader | | data_loader | 模型校准时使用的数据,DataLoader继承自`paddle.io.DataLoader`。可以直接使用模型套件中的DataLoader,或者根据[paddle.io.DataLoader](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/io/DataLoader_cn.html#dataloader)自定义所需要的DataLoader |
| save_dir | 分析后保存模型精度或pdf等文件的文件夹,默认为`analysis_results`| | save_dir | 分析后保存模型精度或pdf等文件的文件夹,默认为`analysis_results`|
| resume | 是否加载中间分析文件,默认为False| | resume | 是否加载中间分析文件,默认为False|
...@@ -31,19 +31,65 @@ ...@@ -31,19 +31,65 @@
## 3. 量化分析工具的使用 ## 3. 量化分析工具的使用
**创建量化分析工具** **创建量化分析工具**
```shell
# 下载Inference模型
wget -q https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar
tar -xf MobileNetV1_infer.tar
# 下载demo数据集
wget -q https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz
tar -xf ILSVRC2012_data_demo.tar.gz
``` ```
```shell
import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddleslim.quant.analysis_ptq import AnalysisPTQ
paddle.enable_static()
class ImageNetDataset(DatasetFolder):
def __init__(self, path, image_size=224):
super(ImageNetDataset, self).__init__(path)
normalize = transforms.Normalize(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375])
self.transform = transforms.Compose([
transforms.Resize(256), transforms.CenterCrop(image_size),
transforms.Transpose(), normalize
])
def __getitem__(self, idx):
img_path, _ = self.samples[idx]
return self.transform(Image.open(img_path).convert('RGB'))
def __len__(self):
return len(self.samples)
train_dataset = ImageNetDataset(
"./ILSVRC2012_data_demo/ILSVRC2012/train/")
image = paddle.static.data(
name='inputs', shape=[None] + [3, 224, 224], dtype='float32')
train_loader = paddle.io.DataLoader(
train_dataset, feed_list=[image], batch_size=8, return_list=False)
analyzer = AnalysisPTQ( analyzer = AnalysisPTQ(
model_dir=config["model_dir"], model_dir="./MobileNetV1_infer",
model_filename=config["model_filename"], model_filename="inference.pdmodel",
params_filename=config["params_filename"], params_filename="inference.pdiparams",
eval_function=eval_function, save_dir="MobileNetV1_analysis",
data_loader=data_loader, ptq_config={
save_dir=config['save_dir'], 'quantizable_op_type': ["conv2d", "depthwise_conv2d"],
ptq_config=config['PTQ']) 'weight_quantize_type': 'abs_max',
'activation_quantize_type': 'moving_average_abs_max',
'is_full_quantize': False,
'batch_size': 8,
'batch_nums': 1,
},
data_loader=train_loader)
``` ```
**统计分析** **统计分析**
``` ```shell
analyzer.statistical_analyse() analyzer.statistical_analyse()
``` ```
...@@ -75,7 +121,7 @@ analyzer.statistical_analyse() ...@@ -75,7 +121,7 @@ analyzer.statistical_analyse()
**精度误差分析** **精度误差分析**
``` ```shell
analyzer.metric_error_analyse() analyzer.metric_error_analyse()
``` ```
调用该接口,会遍历量化模型中的一层,并计算量化该层后模型的损失。调用该接口时,需要输入Eval Function。会产出所有只量化一层的模型精度排序,将默认保存在 `./analysis_results/analysis.txt` 中。 调用该接口,会遍历量化模型中的一层,并计算量化该层后模型的损失。调用该接口时,需要输入Eval Function。会产出所有只量化一层的模型精度排序,将默认保存在 `./analysis_results/analysis.txt` 中。
...@@ -83,8 +129,8 @@ analyzer.metric_error_analyse() ...@@ -83,8 +129,8 @@ analyzer.metric_error_analyse()
**直接产出符合预期精度的目标量化模型** **直接产出符合预期精度的目标量化模型**
``` ```shell
analyzer.get_target_quant_model(target_metric) analyzer.get_target_quant_model(target_metric=70.0)
``` ```
## 4. 根据分析结果执行离线量化 ## 4. 根据分析结果执行离线量化
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
| params_filename | 默认为None,若model_dir为文件夹名,则必须传入以'.pdiparams'结尾的模型名称 | | params_filename | 默认为None,若model_dir为文件夹名,则必须传入以'.pdiparams'结尾的模型名称 |
| quantizable_op_type | 需分析的量化的op类型,默认为`conv2d`, `depthwise_conv2d`, `mul` | | quantizable_op_type | 需分析的量化的op类型,默认为`conv2d`, `depthwise_conv2d`, `mul` |
| qat_metric | 量化模型的精度,可不传入,默认为None,不传入时会自动计算 | | qat_metric | 量化模型的精度,可不传入,默认为None,不传入时会自动计算 |
| eval_function | 需要传入自定义的验证函数 | | eval_function | 若需要验证精度,需要传入自定义的验证函数;若不传入,精度误差分析将根据Cosine Similarity计算得出 |
| data_loader | 模型校准时使用的数据,DataLoader继承自`paddle.io.DataLoader`。可以直接使用模型套件中的DataLoader,或者根据[paddle.io.DataLoader](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/io/DataLoader_cn.html#dataloader)自定义所需要的DataLoader | | data_loader | 模型校准时使用的数据,DataLoader继承自`paddle.io.DataLoader`。可以直接使用模型套件中的DataLoader,或者根据[paddle.io.DataLoader](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/io/DataLoader_cn.html#dataloader)自定义所需要的DataLoader |
| save_dir | 分析后保存模型精度或pdf等文件的文件夹,默认为`analysis_results`| | save_dir | 分析后保存模型精度或pdf等文件的文件夹,默认为`analysis_results`|
| resume | 是否加载中间分析文件,默认为False| | resume | 是否加载中间分析文件,默认为False|
...@@ -25,24 +25,66 @@ ...@@ -25,24 +25,66 @@
## 3. 量化分析工具的使用 ## 3. 量化分析工具的使用
**创建量化分析工具** **创建量化分析工具**
```shell
# 下载Inference模型
wget -q https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar
tar -xf MobileNetV1_infer.tar
wget -q https://paddle-slim-models.bj.bcebos.com/act/MobileNetV1_QAT.tar
tar -xf MobileNetV1_QAT.tar
# 下载demo数据集
wget -q https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz
tar -xf ILSVRC2012_data_demo.tar.gz
``` ```
```shell
import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddleslim.quant.analysis_qat import AnalysisQAT
paddle.enable_static()
class ImageNetDataset(DatasetFolder):
def __init__(self, path, image_size=224):
super(ImageNetDataset, self).__init__(path)
normalize = transforms.Normalize(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375])
self.transform = transforms.Compose([
transforms.Resize(256), transforms.CenterCrop(image_size),
transforms.Transpose(), normalize
])
def __getitem__(self, idx):
img_path, _ = self.samples[idx]
return self.transform(Image.open(img_path).convert('RGB'))
def __len__(self):
return len(self.samples)
train_dataset = ImageNetDataset(
"./ILSVRC2012_data_demo/ILSVRC2012/train/")
image = paddle.static.data(
name='inputs', shape=[None] + [3, 224, 224], dtype='float32')
train_loader = paddle.io.DataLoader(
train_dataset, feed_list=[image], batch_size=8, return_list=False)
analyzer = AnalysisQAT( analyzer = AnalysisQAT(
quant_model_dir=config["quant_model_dir"], float_model_dir="./MobileNetV1_infer",
float_model_dir=config["float_model_dir"], quant_model_dir="./MobileNetV1_QAT",
model_filename=config["model_filename"], model_filename="inference.pdmodel",
params_filename=config["params_filename"], params_filename="inference.pdiparams",
quantizable_op_type=config['quantizable_op_type'], save_dir="MobileNetV1_analysis",
qat_metric=config['qat_metric'], data_loader=train_loader)
eval_function=eval_function,
data_loader=eval_loader,
save_dir=config['save_dir'],
resume=config['resume'],
)
``` ```
**精度误差分析** **精度误差分析**
``` ```shell
analyzer.metric_error_analyse() analyzer.metric_error_analyse()
``` ```
调用该接口,会遍历量化模型中的每一层,去掉量化节点并计算当前层不量化的模型精度。调用该接口时,需要输入Eval Function。会产出所有去掉一层量化的模型精度排序,将默认保存在 `./analysis_results/analysis.txt` 中。具体使用可参考[GPT量化训练敏感度分析DEMO](../../../../example/quantization_analysis/GPT/README.md) 调用该接口,会遍历量化模型中的每一层,去掉量化节点并计算当前层不量化的模型精度。调用该接口时,需要输入Eval Function。会产出所有去掉一层量化的模型精度排序,将默认保存在 `./analysis_results/analysis.txt` 中。具体使用可参考[GPT量化训练敏感度分析DEMO](../../../../example/quantization_analysis/GPT/README.md)
......
...@@ -24,10 +24,22 @@ ...@@ -24,10 +24,22 @@
量化敏感度分析基于验证集获得每层的敏感度,可下载和使用 [LAMBADA](https://raw.githubusercontent.com/cybertronai/bflm/master/lambada_test.jsonl) 或者 [WikiText](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip) 数据集。本示例使用LAMBADA数据集来进行敏感度分析。 量化敏感度分析基于验证集获得每层的敏感度,可下载和使用 [LAMBADA](https://raw.githubusercontent.com/cybertronai/bflm/master/lambada_test.jsonl) 或者 [WikiText](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip) 数据集。本示例使用LAMBADA数据集来进行敏感度分析。
```shell
# 下载验证数据
wget https://raw.githubusercontent.com/cybertronai/bflm/master/lambada_test.jsonl
```
#### 3.3 准备预测模型 #### 3.3 准备预测模型
- [GPT-345M](https://bj.bcebos.com/v1/paddle-slim-models/GPT_345M_Baseline.tar) :Base模型 - 下载量化前Base模型
- [GPT-345M](https://bj.bcebos.com/v1/paddle-slim-models/GPT_345_QAT_wo_analysis.tar) :分析前量化训练后的模型 ```shell
wget https://bj.bcebos.com/v1/paddle-slim-models/GPT_345M_Baseline.tar
```
- 下载分析前量化训练后的模型
```shell
wget https://bj.bcebos.com/v1/paddle-slim-models/GPT_345_QAT_wo_analysis.tar
```
如想自行导出,可参考[GPT模型量化训练](https://github.com/PaddlePaddle/PaddleFleetX/blob/release/2.4/projects/gpt/docs/quantization_aware_training.md)
#### 3.4 量化敏感度分析 #### 3.4 量化敏感度分析
量化敏感度分析示例通过analysis.py脚本启动,会使用接口```paddleslim.quant.AnalysisQAT```对模型进行敏感度分析。配置config文件中模型路径、数据路径和量化相关的参数,配置完成后便可对模型进行敏感度分析。具体运行命令为: 量化敏感度分析示例通过analysis.py脚本启动,会使用接口```paddleslim.quant.AnalysisQAT```对模型进行敏感度分析。配置config文件中模型路径、数据路径和量化相关的参数,配置完成后便可对模型进行敏感度分析。具体运行命令为:
...@@ -44,3 +56,4 @@ python analysis.py --config_path=./configs/gpt_345M_analysis.yaml ...@@ -44,3 +56,4 @@ python analysis.py --config_path=./configs/gpt_345M_analysis.yaml
#### 3.5 重新量化训练 #### 3.5 重新量化训练
根据分析结果,重新量化训练时,去掉了```linear_31```,```linear_27```,```linear_22```,```linear_43```,```linear_83```,```linear_15```,```linear_87```七层Linear的量化,最后量化模型精度达到44.94。 根据分析结果,重新量化训练时,去掉了```linear_31```,```linear_27```,```linear_22```,```linear_43```,```linear_83```,```linear_15```,```linear_87```七层Linear的量化,最后量化模型精度达到44.94。
重新量化训练的过程在 PaddleFleetX 中实现,可参考[GPT模型量化训练](https://github.com/PaddlePaddle/PaddleFleetX/blob/release/2.4/projects/gpt/docs/quantization_aware_training.md)。
...@@ -104,8 +104,8 @@ def eval_function(exe, program, feed_names, fetch_list): ...@@ -104,8 +104,8 @@ def eval_function(exe, program, feed_names, fetch_list):
total_score += acc.numpy()[0] total_score += acc.numpy()[0]
if eval_step != 0 and (eval_step % 10 == 0): if eval_step != 0 and (eval_step % 10 == 0):
print("[eval] step: %d, batch: %d, %s: %.9f, speed: %.2f step/s" % print("[eval] step: %d, %s: %.9f, speed: %.2f step/s" %
(eval_step, eval_step, score_name, total_score, (eval_step, score_name, total_score,
1. / (time.time() - tic_eval))) 1. / (time.time() - tic_eval)))
tic_eval = time.time() tic_eval = time.time()
paddle.enable_static() paddle.enable_static()
......
...@@ -24,7 +24,7 @@ import numpy as np ...@@ -24,7 +24,7 @@ import numpy as np
import random import random
import tempfile import tempfile
import paddle import paddle
from .quanter import quant_post import paddle.nn.functional as F
from ..core import GraphWrapper from ..core import GraphWrapper
from ..common import get_logger from ..common import get_logger
from ..common import get_feed_vars, wrap_dataloader, load_inference_model, get_model_dir from ..common import get_feed_vars, wrap_dataloader, load_inference_model, get_model_dir
...@@ -80,6 +80,7 @@ class AnalysisPTQ(object): ...@@ -80,6 +80,7 @@ class AnalysisPTQ(object):
'is_full_quantize'] if 'is_full_quantize' in ptq_config else False 'is_full_quantize'] if 'is_full_quantize' in ptq_config else False
self.onnx_format = ptq_config[ self.onnx_format = ptq_config[
'onnx_format'] if 'onnx_format' in ptq_config else False 'onnx_format'] if 'onnx_format' in ptq_config else False
ptq_config['onnx_format'] = self.onnx_format
if 'algo' not in ptq_config: if 'algo' not in ptq_config:
ptq_config['algo'] = 'avg' ptq_config['algo'] = 'avg'
...@@ -134,9 +135,8 @@ class AnalysisPTQ(object): ...@@ -134,9 +135,8 @@ class AnalysisPTQ(object):
# load tobe_analyized_layer from checkpoint # load tobe_analyized_layer from checkpoint
if resume: if resume:
self.load_checkpoint() self.load_checkpoint()
self.tobe_analyized_layer = set(self.support_quant_val_name_list) - set( self.tobe_analyized_layer = sorted(
list(self.quant_layer_metrics.keys())) list(self.support_quant_val_name_list))
self.tobe_analyized_layer = sorted(list(self.tobe_analyized_layer))
def save_checkpoint(self): def save_checkpoint(self):
if not os.path.exists(self.save_dir): if not os.path.exists(self.save_dir):
...@@ -172,7 +172,6 @@ class AnalysisPTQ(object): ...@@ -172,7 +172,6 @@ class AnalysisPTQ(object):
model_filename=self.model_filename, model_filename=self.model_filename,
params_filename=self.params_filename, params_filename=self.params_filename,
skip_tensor_list=skip_tensor_list, skip_tensor_list=skip_tensor_list,
onnx_format=self.onnx_format,
**self.ptq_config) **self.ptq_config)
def sampling(self, executor, program, scope): def sampling(self, executor, program, scope):
...@@ -187,65 +186,125 @@ class AnalysisPTQ(object): ...@@ -187,65 +186,125 @@ class AnalysisPTQ(object):
if batch_id >= self.batch_nums: if batch_id >= self.batch_nums:
break break
def eval_quant_model(self, skip_list): def fp_int_cosine_similarity(self, executor, float_program, quant_program,
float_scope, quant_scope):
cosine_similarity = []
for step, data in enumerate(self.data_loader()):
with paddle.static.scope_guard(float_scope):
float_preds = executor.run(program=float_program,
feed=data,
fetch_list=self.fetch_list,
return_numpy=False)
float_preds = float_preds[0]
with paddle.static.scope_guard(quant_scope):
quant_preds = executor.run(program=quant_program,
feed=data,
fetch_list=self.fetch_list,
return_numpy=False)
quant_preds = quant_preds[0]
paddle.disable_static()
float_preds = paddle.to_tensor(float_preds)
quant_preds = paddle.to_tensor(quant_preds)
cos_sim = F.cosine_similarity(float_preds, quant_preds).mean()
cos_sim = cos_sim.numpy()
cosine_similarity.append(cos_sim)
if step != 0 and (step % 10 == 0):
_logger.info("[step]: %d, cosine similarity: %.9f" %
(step, np.array(cosine_similarity).mean()))
paddle.enable_static()
return np.array(cosine_similarity).mean()
def get_sensitive_metric(self, skip_list, layer_name):
executor = paddle.static.Executor(self.places) executor = paddle.static.Executor(self.places)
post_training_quantization = self.create_ptq(executor, skip_list) if self.eval_function is not None:
program = post_training_quantization.quantize() post_training_quantization = self.create_ptq(executor, skip_list)
_logger.info('Evaluating...') program = post_training_quantization.quantize()
if self.onnx_format: _logger.info('Evaluating...')
post_training_quantization.save_quantized_model( if self.onnx_format:
self.temp_save_path, post_training_quantization.save_quantized_model(
model_filename='model.pdmodel', self.temp_save_path,
params_filename='model.pdiparams') model_filename='model.pdmodel',
program, _, _ = load_inference_model( params_filename='model.pdiparams')
self.temp_save_path, program, _, _ = load_inference_model(
executor, self.temp_save_path,
model_filename='model.pdmodel', executor,
params_filename='model.pdiparams') model_filename='model.pdmodel',
quant_metric = self.eval_function(executor, program, self.feed_list, params_filename='model.pdiparams')
self.fetch_list) metric = self.eval_function(executor, program, self.feed_list,
self.fetch_list)
if skip_list is None:
executor.close()
return metric
sensitive_metric = self.base_metric - metric
_logger.info(
"Quantized layer name: %s, the accuracy: %.4f, the sensitive metric: %.4f"
% (layer_name, metric, sensitive_metric))
else:
float_scope = paddle.static.Scope()
quant_scope = paddle.static.Scope()
with paddle.static.scope_guard(float_scope):
[float_program, _, _] = load_inference_model(
self.model_dir,
executor=executor,
model_filename=self.model_filename,
params_filename=self.params_filename)
with paddle.static.scope_guard(quant_scope):
post_training_quantization = self.create_ptq(executor,
skip_list)
quant_program = post_training_quantization.quantize()
metric = self.fp_int_cosine_similarity(executor, float_program,
quant_program, float_scope,
quant_scope)
sensitive_metric = 1.0 - metric
_logger.info(
"Quantized layer name: %s, the cosine similarity: %.4f, the sensitive metric: %.4f"
% (layer_name, metric, sensitive_metric))
executor.close() executor.close()
return quant_metric return sensitive_metric
def metric_error_analyse(self): def metric_error_analyse(self):
''' '''
Evaluate the quantized models, which are generated by quantizing each weight operator one by one. The results will be saved into analysis.txt. Evaluate the quantized models, which are generated by quantizing each weight operator one by one. The results will be saved into analysis.txt.
''' '''
assert self.data_loader is not None, "When computing the sensitivity of quantized layers, the data loader is needed" assert self.data_loader is not None, "When computing the sensitivity of quantized layers, the data loader is needed"
assert self.eval_function is not None, "When computing the sensitivity of quantized layers, the eval function is needed" if self.eval_function is not None:
# evaluate before quant
# evaluate before quant _logger.info('Start to evaluate the base model.')
_logger.info('Start to evaluate the base model.') executor = paddle.static.Executor(self.places)
executor = paddle.static.Executor(self.places) [program, feed_list, fetch_list]= load_inference_model( \
[program, feed_list, fetch_list]= load_inference_model( \ self.model_dir, \
self.model_dir, \ executor=executor, \
executor=executor, \ model_filename=self.model_filename, \
model_filename=self.model_filename, \ params_filename=self.params_filename)
params_filename=self.params_filename) self.base_metric = self.eval_function(executor, program, feed_list,
self.base_metric = self.eval_function(executor, program, feed_list, fetch_list)
fetch_list) _logger.info('Before quantized, the accuracy of the model is: {}'.
_logger.info('Before quantized, the accuracy of the model is: {}'. format(self.base_metric))
format(self.base_metric)) executor.close()
# evaluate before quant # evaluate before quant
_logger.info('Start to evaluate the quantized model.') _logger.info('Start to evaluate the quantized model.')
self.quant_metric = self.eval_quant_model(None) self.quant_metric = self.get_sensitive_metric(
_logger.info('After quantized, the accuracy of the model is: {}'.format( None, 'all quantizable layers')
self.quant_metric)) _logger.info('After quantized, the accuracy of the model is: {}'.
format(self.quant_metric))
# For each layer, quantize the weight op and evaluate the quantized model. # For each layer, quantize the weight op and evaluate the quantized model.
for i, layer_name in enumerate(self.tobe_analyized_layer): for i, layer_name in enumerate(self.tobe_analyized_layer):
if layer_name in self.quant_layer_metrics:
continue
_logger.info('Checking {}/{} quant model: quant layer {}'.format( _logger.info('Checking {}/{} quant model: quant layer {}'.format(
i + 1, len(self.tobe_analyized_layer), layer_name)) i + 1, len(self.tobe_analyized_layer), layer_name))
skip_list = copy.copy(list(self.support_quant_val_name_list)) skip_list = copy.copy(list(self.support_quant_val_name_list))
skip_list.remove(layer_name) skip_list.remove(layer_name)
quant_metric = self.eval_quant_model(skip_list) sensitive_metric = self.get_sensitive_metric(skip_list, layer_name)
_logger.info( self.quant_layer_metrics[layer_name] = sensitive_metric
"Quantized layer name: {}, eval metric: {}, the loss caused by this layer: {}".
format(layer_name,
round(quant_metric, 4),
round(self.base_metric - quant_metric, 4)))
self.quant_layer_metrics[layer_name] = quant_metric
self.save_checkpoint() self.save_checkpoint()
if self.onnx_format: if self.onnx_format:
...@@ -254,18 +313,18 @@ class AnalysisPTQ(object): ...@@ -254,18 +313,18 @@ class AnalysisPTQ(object):
self.sensitivity_ranklist = sorted( self.sensitivity_ranklist = sorted(
self.quant_layer_metrics, self.quant_layer_metrics,
key=self.quant_layer_metrics.get, key=self.quant_layer_metrics.get,
reverse=False) reverse=True)
_logger.info('Finished computing the sensitivity of the model.') _logger.info('Finished computing the sensitivity of the model.')
for name in self.sensitivity_ranklist: for name in self.sensitivity_ranklist:
_logger.info("quant layer name: {}, eval metric: {}".format( _logger.info("Quantized layer name: {}, sensitivity metric: {}".
name, self.quant_layer_metrics[name])) format(name, self.quant_layer_metrics[name]))
analysis_file = os.path.join(self.save_dir, "analysis.txt") analysis_file = os.path.join(self.save_dir, "analysis.txt")
with open(analysis_file, "w") as analysis_ret_f: with open(analysis_file, "w") as analysis_ret_f:
for name in self.sensitivity_ranklist: for name in self.sensitivity_ranklist:
analysis_ret_f.write( analysis_ret_f.write(
"quant layer name: {}, eval metric: {}\n".format( "Quantized layer name: {}, sensitivity metric: {}\n".format(
name, self.quant_layer_metrics[name])) name, self.quant_layer_metrics[name]))
_logger.info('Analysis file is saved in {}'.format(analysis_file)) _logger.info('Analysis file is saved in {}'.format(analysis_file))
...@@ -285,7 +344,7 @@ class AnalysisPTQ(object): ...@@ -285,7 +344,7 @@ class AnalysisPTQ(object):
executor=executor, \ executor=executor, \
model_filename=self.model_filename, \ model_filename=self.model_filename, \
params_filename=self.params_filename) params_filename=self.params_filename)
scope = paddle.static.Executor.global_scope() scope = paddle.static.global_scope()
persistable_var_names = [] persistable_var_names = []
for var in program.list_vars(): for var in program.list_vars():
if var.persistable: if var.persistable:
...@@ -294,6 +353,9 @@ class AnalysisPTQ(object): ...@@ -294,6 +353,9 @@ class AnalysisPTQ(object):
self.acts_weight_map = self.get_weight_act_map( self.acts_weight_map = self.get_weight_act_map(
program, self.weight_names, persistable_var_names) program, self.weight_names, persistable_var_names)
activations_names = list(self.acts_weight_map.keys()) activations_names = list(self.acts_weight_map.keys())
for var in program.list_vars():
if var.name in activations_names:
var.persistable = True
# sample # sample
self.sampling(executor, program, scope) self.sampling(executor, program, scope)
...@@ -305,7 +367,7 @@ class AnalysisPTQ(object): ...@@ -305,7 +367,7 @@ class AnalysisPTQ(object):
def collect_quant_stat(self): def collect_quant_stat(self):
_logger.info('Collecting Statistic After PTQ...') _logger.info('Collecting Statistic After PTQ...')
executor = paddle.static.Executor(self.places) executor = paddle.static.Executor(self.places)
scope = paddle.static.Executor.global_scope() scope = paddle.static.global_scope()
post_training_quantization = self.create_ptq(executor, None) post_training_quantization = self.create_ptq(executor, None)
program = post_training_quantization.quantize() program = post_training_quantization.quantize()
...@@ -525,13 +587,13 @@ class AnalysisPTQ(object): ...@@ -525,13 +587,13 @@ class AnalysisPTQ(object):
rank_list = sorted( rank_list = sorted(
self.quant_layer_metrics, self.quant_layer_metrics,
key=self.quant_layer_metrics.get, key=self.quant_layer_metrics.get,
reverse=False) reverse=True)
else: else:
_logger.info( _logger.info(
'Analyse metric error before get target quantized model.') 'Analyse metric error before get target quantized model.')
self.metric_error_analyse() self.metric_error_analyse()
while True: while len(rank_list) > 0:
skip_list.append(rank_list.pop(0)) skip_list.append(rank_list.pop(0))
_logger.info('Skip Ops: {}'.format(skip_list)) _logger.info('Skip Ops: {}'.format(skip_list))
executor = paddle.static.Executor(self.places) executor = paddle.static.Executor(self.places)
...@@ -559,3 +621,7 @@ class AnalysisPTQ(object): ...@@ -559,3 +621,7 @@ class AnalysisPTQ(object):
'The quantized model does not satisfy the target metric. Skip next Op...' 'The quantized model does not satisfy the target metric. Skip next Op...'
) )
executor.close() executor.close()
else:
_logger.info(
'Sorry, the target quantized model cannot be found. Please set lower target metric.'
)
...@@ -20,6 +20,7 @@ import logging ...@@ -20,6 +20,7 @@ import logging
import numpy as np import numpy as np
import paddle import paddle
import paddle.nn.functional as F
from paddle.framework import core from paddle.framework import core
from paddle.fluid.framework import IrGraph from paddle.fluid.framework import IrGraph
from ..common import get_logger, load_inference_model from ..common import get_logger, load_inference_model
...@@ -69,6 +70,7 @@ class AnalysisQAT(object): ...@@ -69,6 +70,7 @@ class AnalysisQAT(object):
self.quantizable_op_type = quantizable_op_type self.quantizable_op_type = quantizable_op_type
self.qat_metric = qat_metric self.qat_metric = qat_metric
self.eval_function = eval_function self.eval_function = eval_function
self.data_loader = data_loader
self.save_dir = save_dir self.save_dir = save_dir
self.checkpoint_name = os.path.join(save_dir, 'analysis_checkpoint.pkl') self.checkpoint_name = os.path.join(save_dir, 'analysis_checkpoint.pkl')
self.nonquant_layer_metrics = {} self.nonquant_layer_metrics = {}
...@@ -98,8 +100,13 @@ class AnalysisQAT(object): ...@@ -98,8 +100,13 @@ class AnalysisQAT(object):
if 'quantized' in input_name: if 'quantized' in input_name:
self.inputs_of_quantized_op.append(input_names) self.inputs_of_quantized_op.append(input_names)
break break
if self.eval_function is None:
assert self.data_loader is not None, "DataLoader cannot be None if Eval Fuction is None."
_logger.info(
'The sensitivity will measured by cosine similarity of the outputs from float model and quantized model.'
)
if self.qat_metric is None: if self.qat_metric is None and self.eval_function is not None:
_logger.info('Calculating the metric of QAT model...') _logger.info('Calculating the metric of QAT model...')
self.qat_metric = self.eval_function( self.qat_metric = self.eval_function(
executor, program, self.feed_list, self.fetch_list) * 100 executor, program, self.feed_list, self.fetch_list) * 100
...@@ -107,6 +114,9 @@ class AnalysisQAT(object): ...@@ -107,6 +114,9 @@ class AnalysisQAT(object):
round(self.qat_metric, 4))) round(self.qat_metric, 4)))
executor.close() executor.close()
if resume:
self.load_checkpoint()
def save_checkpoint(self): def save_checkpoint(self):
if not os.path.exists(self.save_dir): if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir) os.makedirs(self.save_dir)
...@@ -199,6 +209,35 @@ class AnalysisQAT(object): ...@@ -199,6 +209,35 @@ class AnalysisQAT(object):
return graph.to_program() return graph.to_program()
def fp_int_cosine_similarity(self, executor, float_program, quant_program,
float_scope, quant_scope):
cosine_similarity = []
for step, data in enumerate(self.data_loader()):
with paddle.static.scope_guard(float_scope):
float_preds = executor.run(program=float_program,
feed=data,
fetch_list=self.fetch_list,
return_numpy=False)
float_preds = float_preds[0]
with paddle.static.scope_guard(quant_scope):
quant_preds = executor.run(program=quant_program,
feed=data,
fetch_list=self.fetch_list,
return_numpy=False)
quant_preds = quant_preds[0]
paddle.disable_static()
float_preds = paddle.to_tensor(float_preds)
quant_preds = paddle.to_tensor(quant_preds)
cos_sim = F.cosine_similarity(float_preds, quant_preds).mean()
cos_sim = cos_sim.numpy()
cosine_similarity.append(cos_sim)
if step != 0 and (step % 10 == 0):
_logger.info("[step]: %d, cosine similarity: %.9f" %
(step, np.array(cosine_similarity).mean()))
paddle.enable_static()
return np.array(cosine_similarity).mean()
def metric_error_analyse(self): def metric_error_analyse(self):
executor = paddle.static.Executor(self.places) executor = paddle.static.Executor(self.places)
...@@ -207,12 +246,14 @@ class AnalysisQAT(object): ...@@ -207,12 +246,14 @@ class AnalysisQAT(object):
for idx, input_list in enumerate(self.inputs_of_quantized_op): for idx, input_list in enumerate(self.inputs_of_quantized_op):
weight_name = self.get_weight_name(input_list) weight_name = self.get_weight_name(input_list)
if weight_name in self.nonquant_layer_metrics:
continue
_logger.info( _logger.info(
'Checking {}/{} quant model: without quant layer {}'.format( 'Checking {}/{} quant model: without quant layer {}'.format(
idx + 1, len(self.inputs_of_quantized_op), weight_name)) idx + 1, len(self.inputs_of_quantized_op), weight_name))
with paddle.static.scope_guard(float_scope): with paddle.static.scope_guard(float_scope):
load_inference_model( [float_program, _, _] = load_inference_model(
self.float_model_dir, self.float_model_dir,
executor=executor, executor=executor,
model_filename=self.model_filename, model_filename=self.model_filename,
...@@ -232,18 +273,26 @@ class AnalysisQAT(object): ...@@ -232,18 +273,26 @@ class AnalysisQAT(object):
input_list, graph, float_scope, quant_scope) input_list, graph, float_scope, quant_scope)
saved_program = self.relink_graph(graph, input_rename_map, saved_program = self.relink_graph(graph, input_rename_map,
output_rename_map, removed_ops) output_rename_map, removed_ops)
with paddle.static.scope_guard(quant_scope): if self.eval_function is not None:
_logger.info('Skip quant {}, evaluating....'.format( with paddle.static.scope_guard(quant_scope):
weight_name)) _logger.info('Skip quant {}, evaluating....'.format(
metric = self.eval_function(executor, saved_program, weight_name))
self.feed_list, metric = self.eval_function(executor, saved_program,
self.fetch_list) * 100 self.feed_list,
self.nonquant_layer_metrics[weight_name] = metric self.fetch_list) * 100
self.nonquant_layer_metrics[
weight_name] = metric - self.qat_metric
_logger.info(
'When skip quant %s, the eval metric is %.4f, the sensitive metric is %.4f'
% (weight_name, metric, metric - self.qat_metric))
else:
metric = self.fp_int_cosine_similarity(executor, float_program,
saved_program,
float_scope, quant_scope)
self.nonquant_layer_metrics[weight_name] = 1 - metric
_logger.info( _logger.info(
'When skip quant {}, the metric is {}, the diff is {}'. 'When skip quant %s, the cosine similarity is %.4f, the sensitive metric is %.4f'
format(weight_name, % (weight_name, metric, 1 - metric))
round(metric, 4), round(metric - self.qat_metric,
4)))
self.save_checkpoint() self.save_checkpoint()
executor.close() executor.close()
...@@ -254,13 +303,13 @@ class AnalysisQAT(object): ...@@ -254,13 +303,13 @@ class AnalysisQAT(object):
reverse=True) reverse=True)
_logger.info('Finished computing the sensitivity of the model.') _logger.info('Finished computing the sensitivity of the model.')
for name in self.sensitivity_ranklist: for name in self.sensitivity_ranklist:
_logger.info("without quant layer name: {}, eval metric: {}".format( _logger.info("Without quant layer name: {}, sensitive metric: {}".
name, self.nonquant_layer_metrics[name])) format(name, self.nonquant_layer_metrics[name]))
analysis_file = os.path.join(self.save_dir, "analysis.txt") analysis_file = os.path.join(self.save_dir, "analysis.txt")
with open(analysis_file, "w") as analysis_ret_f: with open(analysis_file, "w") as analysis_ret_f:
for name in self.sensitivity_ranklist: for name in self.sensitivity_ranklist:
analysis_ret_f.write( analysis_ret_f.write(
"without layer name: {}, eval metric: {}\n".format( "Without quant layer name: {}, sensitive metric: {}\n".
name, self.nonquant_layer_metrics[name])) format(name, self.nonquant_layer_metrics[name]))
_logger.info('Analysis file is saved in {}'.format(analysis_file)) _logger.info('Analysis file is saved in {}'.format(analysis_file))
import os
import sys
import unittest
sys.path.append("../")
import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddleslim.quant.analysis_ptq import AnalysisPTQ
paddle.enable_static()
class ImageNetDataset(DatasetFolder):
def __init__(self, path, image_size=224):
super(ImageNetDataset, self).__init__(path)
normalize = transforms.Normalize(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375])
self.transform = transforms.Compose([
transforms.Resize(256), transforms.CenterCrop(image_size),
transforms.Transpose(), normalize
])
def __getitem__(self, idx):
img_path, _ = self.samples[idx]
return self.transform(Image.open(img_path).convert('RGB'))
def __len__(self):
return len(self.samples)
class AnalysisPTQDemo(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(AnalysisPTQDemo, self).__init__(*args, **kwargs)
if not os.path.exists('MobileNetV1_infer'):
os.system(
'wget -q https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar'
)
os.system('tar -xf MobileNetV1_infer.tar')
if not os.path.exists('ILSVRC2012_data_demo'):
os.system(
'wget -q https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz'
)
os.system('tar -xf ILSVRC2012_data_demo.tar.gz')
def test_demo(self):
train_dataset = ImageNetDataset(
"./ILSVRC2012_data_demo/ILSVRC2012/train/")
image = paddle.static.data(
name='inputs', shape=[None] + [3, 224, 224], dtype='float32')
train_loader = paddle.io.DataLoader(
train_dataset, feed_list=[image], batch_size=8, return_list=False)
analyzer = AnalysisPTQ(
model_dir="./MobileNetV1_infer",
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="MobileNetV1_analysis",
ptq_config={
'quantizable_op_type': ["conv2d", "depthwise_conv2d"],
'weight_quantize_type': 'abs_max',
'activation_quantize_type': 'moving_average_abs_max',
'is_full_quantize': False,
'batch_size': 8,
'batch_nums': 1,
},
data_loader=train_loader)
analyzer.statistical_analyse()
analyzer.metric_error_analyse()
os.system('rm -rf MobileNetV1_analysis')
if __name__ == '__main__':
unittest.main()
import os
import sys
import unittest
import numpy as np
sys.path.append("../")
import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddleslim.quant.analysis_ptq import AnalysisPTQ
paddle.enable_static()
class ImageNetDataset(DatasetFolder):
def __init__(self, data_dir, image_size=224, mode='train'):
super(ImageNetDataset, self).__init__(data_dir)
self.data_dir = data_dir
normalize = transforms.Normalize(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375])
self.transform = transforms.Compose([
transforms.Resize(256), transforms.CenterCrop(image_size),
transforms.Transpose(), normalize
])
self.mode = mode
train_file_list = os.path.join(data_dir, 'train_list.txt')
val_file_list = os.path.join(data_dir, 'val_list.txt')
self.mode = mode
if mode == 'train':
with open(train_file_list) as flist:
full_lines = [line.strip() for line in flist]
np.random.shuffle(full_lines)
lines = full_lines
self.samples = [line.split() for line in lines]
else:
with open(val_file_list) as flist:
lines = [line.strip() for line in flist]
self.samples = [line.split() for line in lines]
def __getitem__(self, idx):
img_path, label = self.samples[idx]
if self.mode == 'train':
return self.transform(
Image.open(os.path.join(self.data_dir, img_path)).convert(
'RGB'))
else:
return self.transform(
Image.open(os.path.join(self.data_dir, img_path)).convert(
'RGB')), np.array([label]).astype('int64')
def __len__(self):
return len(self.samples)
class AnalysisPTQEvalFunction(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(AnalysisPTQEvalFunction, self).__init__(*args, **kwargs)
if not os.path.exists('MobileNetV1_infer'):
os.system(
'wget -q https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar'
)
os.system('tar -xf MobileNetV1_infer.tar')
if not os.path.exists('ILSVRC2012_data_demo'):
os.system(
'wget -q https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz'
)
os.system('tar -xf ILSVRC2012_data_demo.tar.gz')
def test_demo(self):
train_dataset = ImageNetDataset("./ILSVRC2012_data_demo/ILSVRC2012/")
image = paddle.static.data(
name='inputs', shape=[None] + [3, 224, 224], dtype='float32')
label = paddle.static.data(
name='labels', shape=[None] + [1], dtype='float32')
train_loader = paddle.io.DataLoader(
train_dataset, feed_list=[image], batch_size=8, return_list=False)
def reader_wrapper(reader, input_name):
def gen():
for i, (imgs, label) in enumerate(reader()):
yield {input_name: imgs}
return gen
def eval_reader(data_dir,
batch_size,
crop_size,
resize_size,
place=None):
val_dataset = ImageNetDataset(
"./ILSVRC2012_data_demo/ILSVRC2012/", mode='val')
val_loader = paddle.io.DataLoader(
val_dataset,
feed_list=[image, label],
batch_size=batch_size,
shuffle=False,
drop_last=False,
num_workers=0,
return_list=False)
return val_loader
def eval_function(exe, compiled_test_program, test_feed_names,
test_fetch_list):
val_loader = eval_reader(
'./ILSVRC2012_data_demo/ILSVRC2012/',
batch_size=32,
crop_size=224,
resize_size=256)
results = []
print('Evaluating...')
for batch_id, data in enumerate(val_loader):
image = data[0]['inputs']
label = data[0]['labels']
# top1_acc, top5_acc
if len(test_feed_names) == 1:
image = np.array(image)
label = np.array(label).astype('int64')
pred = exe.run(compiled_test_program,
feed={test_feed_names[0]: image},
fetch_list=test_fetch_list)
pred = np.array(pred[0])
label = np.array(label)
sort_array = pred.argsort(axis=1)
top_1_pred = sort_array[:, -1:][:, ::-1]
top_1 = np.mean(label == top_1_pred)
top_5_pred = sort_array[:, -5:][:, ::-1]
acc_num = 0
for i in range(len(label)):
if label[i][0] in top_5_pred[i]:
acc_num += 1
top_5 = float(acc_num) / len(label)
results.append([top_1, top_5])
else:
# eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
image = np.array(image)
label = np.array(label).astype('int64')
result = exe.run(compiled_test_program,
feed={
test_feed_names[0]: image,
test_feed_names[1]: label
},
fetch_list=test_fetch_list)
result = [np.mean(r) for r in result]
results.append(result)
if batch_id % 100 == 0:
print('Eval iter: ', batch_id)
result = np.mean(np.array(results), axis=0)
return result[0]
analyzer = AnalysisPTQ(
model_dir="./MobileNetV1_infer",
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="MobileNetV1_analysis",
ptq_config={
'quantizable_op_type': ["conv2d", "depthwise_conv2d"],
'weight_quantize_type': 'abs_max',
'activation_quantize_type': 'moving_average_abs_max',
'is_full_quantize': False,
'batch_size': 8,
'batch_nums': 10,
},
data_loader=train_loader,
eval_function=eval_function)
analyzer.metric_error_analyse()
analyzer.get_target_quant_model(69.5)
os.system('rm -rf MobileNetV1_analysis')
if __name__ == '__main__':
unittest.main()
import os
import sys
import unittest
sys.path.append("../")
import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddleslim.quant.analysis_qat import AnalysisQAT
paddle.enable_static()
class ImageNetDataset(DatasetFolder):
def __init__(self, path, image_size=224):
super(ImageNetDataset, self).__init__(path)
normalize = transforms.Normalize(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375])
self.transform = transforms.Compose([
transforms.Resize(256), transforms.CenterCrop(image_size),
transforms.Transpose(), normalize
])
def __getitem__(self, idx):
img_path, _ = self.samples[idx]
return self.transform(Image.open(img_path).convert('RGB'))
def __len__(self):
return len(self.samples)
class AnalysisQATDemo(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(AnalysisQATDemo, self).__init__(*args, **kwargs)
if not os.path.exists('MobileNetV1_infer'):
os.system(
'wget -q https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar'
)
os.system('tar -xf MobileNetV1_infer.tar')
if not os.path.exists('ILSVRC2012_data_demo'):
os.system(
'wget -q https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz'
)
os.system('tar -xf ILSVRC2012_data_demo.tar.gz')
def test_demo(self):
train_dataset = ImageNetDataset(
"./ILSVRC2012_data_demo/ILSVRC2012/train/")
image = paddle.static.data(
name='inputs', shape=[None] + [3, 224, 224], dtype='float32')
train_loader = paddle.io.DataLoader(
train_dataset, feed_list=[image], batch_size=8, return_list=False)
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
executor = paddle.static.Executor(place)
ptq_config = {
'quantizable_op_type': ["conv2d", "depthwise_conv2d"],
'weight_quantize_type': 'abs_max',
'activation_quantize_type': 'moving_average_abs_max',
'is_full_quantize': False,
'batch_size': 8,
'batch_nums': 10,
}
post_training_quantization = PostTrainingQuantization(
executor=executor,
data_loader=train_loader,
model_dir="./MobileNetV1_infer",
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
onnx_format=True,
algo='avg',
**ptq_config)
post_training_quantization.quantize()
post_training_quantization.save_quantized_model(
"./MobileNetV1_quant",
model_filename='inference.pdmodel',
params_filename='inference.pdiparams')
analyzer = AnalysisQAT(
float_model_dir="./MobileNetV1_infer",
quant_model_dir="./MobileNetV1_quant",
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="analysis_result",
data_loader=train_loader)
analyzer.metric_error_analyse()
os.system('rm -rf analysis_result')
os.system('rm -rf MobileNetV1_quant')
if __name__ == '__main__':
unittest.main()
import os
import sys
import unittest
import numpy as np
sys.path.append("../")
import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddleslim.quant.analysis_qat import AnalysisQAT
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
paddle.enable_static()
class ImageNetDataset(DatasetFolder):
def __init__(self, data_dir, image_size=224, mode='train'):
super(ImageNetDataset, self).__init__(data_dir)
self.data_dir = data_dir
normalize = transforms.Normalize(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375])
self.transform = transforms.Compose([
transforms.Resize(256), transforms.CenterCrop(image_size),
transforms.Transpose(), normalize
])
self.mode = mode
train_file_list = os.path.join(data_dir, 'train_list.txt')
val_file_list = os.path.join(data_dir, 'val_list.txt')
self.mode = mode
if mode == 'train':
with open(train_file_list) as flist:
full_lines = [line.strip() for line in flist]
np.random.shuffle(full_lines)
lines = full_lines
self.samples = [line.split() for line in lines]
else:
with open(val_file_list) as flist:
lines = [line.strip() for line in flist]
self.samples = [line.split() for line in lines]
def __getitem__(self, idx):
img_path, label = self.samples[idx]
if self.mode == 'train':
return self.transform(
Image.open(os.path.join(self.data_dir, img_path)).convert(
'RGB'))
else:
return self.transform(
Image.open(os.path.join(self.data_dir, img_path)).convert(
'RGB')), np.array([label]).astype('int64')
def __len__(self):
return len(self.samples)
class AnalysisQATEvalFunction(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(AnalysisQATEvalFunction, self).__init__(*args, **kwargs)
if not os.path.exists('MobileNetV1_infer'):
os.system(
'wget -q https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar'
)
os.system('tar -xf MobileNetV1_infer.tar')
if not os.path.exists('ILSVRC2012_data_demo'):
os.system(
'wget -q https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz'
)
os.system('tar -xf ILSVRC2012_data_demo.tar.gz')
def test_demo(self):
train_dataset = ImageNetDataset("./ILSVRC2012_data_demo/ILSVRC2012/")
image = paddle.static.data(
name='inputs', shape=[None] + [3, 224, 224], dtype='float32')
label = paddle.static.data(
name='labels', shape=[None] + [1], dtype='float32')
train_loader = paddle.io.DataLoader(
train_dataset, feed_list=[image], batch_size=8, return_list=False)
def reader_wrapper(reader, input_name):
def gen():
for i, (imgs, label) in enumerate(reader()):
yield {input_name: imgs}
return gen
def eval_reader(data_dir,
batch_size,
crop_size,
resize_size,
place=None):
val_dataset = ImageNetDataset(
"./ILSVRC2012_data_demo/ILSVRC2012/", mode='val')
val_loader = paddle.io.DataLoader(
val_dataset,
feed_list=[image, label],
batch_size=batch_size,
shuffle=False,
drop_last=False,
num_workers=0,
return_list=False)
return val_loader
def eval_function(exe, compiled_test_program, test_feed_names,
test_fetch_list):
val_loader = eval_reader(
'./ILSVRC2012_data_demo/ILSVRC2012/',
batch_size=32,
crop_size=224,
resize_size=256)
results = []
print('Evaluating...')
for batch_id, data in enumerate(val_loader):
image = data[0]['inputs']
label = data[0]['labels']
# top1_acc, top5_acc
if len(test_feed_names) == 1:
image = np.array(image)
label = np.array(label).astype('int64')
pred = exe.run(compiled_test_program,
feed={test_feed_names[0]: image},
fetch_list=test_fetch_list)
pred = np.array(pred[0])
label = np.array(label)
sort_array = pred.argsort(axis=1)
top_1_pred = sort_array[:, -1:][:, ::-1]
top_1 = np.mean(label == top_1_pred)
top_5_pred = sort_array[:, -5:][:, ::-1]
acc_num = 0
for i in range(len(label)):
if label[i][0] in top_5_pred[i]:
acc_num += 1
top_5 = float(acc_num) / len(label)
results.append([top_1, top_5])
else:
# eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
image = np.array(image)
label = np.array(label).astype('int64')
result = exe.run(compiled_test_program,
feed={
test_feed_names[0]: image,
test_feed_names[1]: label
},
fetch_list=test_fetch_list)
result = [np.mean(r) for r in result]
results.append(result)
if batch_id % 100 == 0:
print('Eval iter: ', batch_id)
result = np.mean(np.array(results), axis=0)
return result[0]
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
executor = paddle.static.Executor(place)
ptq_config = {
'quantizable_op_type': ["conv2d", "depthwise_conv2d"],
'weight_quantize_type': 'abs_max',
'activation_quantize_type': 'moving_average_abs_max',
'is_full_quantize': False,
'batch_size': 8,
'batch_nums': 10,
}
post_training_quantization = PostTrainingQuantization(
executor=executor,
data_loader=train_loader,
model_dir="./MobileNetV1_infer",
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
onnx_format=True,
algo='avg',
**ptq_config)
post_training_quantization.quantize()
post_training_quantization.save_quantized_model(
"./MobileNetV1_QAT",
model_filename='inference.pdmodel',
params_filename='inference.pdiparams')
analyzer = AnalysisQAT(
float_model_dir="./MobileNetV1_infer",
quant_model_dir="./MobileNetV1_QAT",
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="MobileNetV1_analysis",
data_loader=train_loader,
eval_function=eval_function)
analyzer.metric_error_analyse()
os.system('rm -rf MobileNetV1_analysis')
os.system('rm -rf MobileNetV1_QAT')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册