未验证 提交 e406740a 编写于 作者: I itminner 提交者: GitHub

quant aware with infer model (#947)

quant aware with infer model
上级 bf123166
# 使用预测模型进行量化训练示例
预测模型获取
动态图使用paddle.jit.save保存;
静态图使用paddle.static.save_inference_model保存。
本示例将介绍如何使用预测模型进行蒸馏量化训练,
首先使用接口``paddleslim.quant.quant_aware_with_infermodel``训练量化模型,
训练完成后,使用接口``paddleslim.quant.export_quant_infermodel``将训好的量化模型导出为预测模型。
## 分类模型量化训练流程
###1. 准备数据
``demo``文件夹下创建``data``文件夹,将``ImageNet``数据集解压在``data``文件夹下,解压后``data/ILSVRC2012``文件夹下应包含以下文件:
- ``'train'``文件夹,训练图片
- ``'train_list.txt'``文件
- ``'val'``文件夹,验证图片
- ``'val_list.txt'``文件
### 2. 准备需要量化的模型
飞桨图像识别套件PaddleClas是飞桨为工业界和学术界所准备的一个图像识别任务的工具集,本示例使用该套件产出imagenet分类模型。
####① 下载MobileNetV2预训练模型
预训练模型库地址 ``https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/zh_CN/algorithm_introduction/ImageNet_models.md``
MobileNetV2预训练模型地址 ``https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_pretrained.pdparams``
在PaddleClas代码库根目录创建pretrained文件夹,MobileNetV2预训练参数保存在该文件夹中。
#### ② 导出预测模型
PaddleClas代码库根目录执行如下命令,导出预测模型
```
python tools/export_model.py \
-c ppcls/configs/ImageNet/MobileNetV2/MobileNetV2.yaml \
-o Global.pretrained_model=pretrained/MobileNetV2_pretrained \
-o Global.save_inference_dir=infermodel_mobilenetv2 \
```
#### ③ 测试模型精度
使用[eval.py](../quant_post/eval.py)脚本得到模型的分类精度:
```
python ../quant_post/eval.py --model_path infermodel_mobilenetv2 --model_name inference.pdmodel --params_name inference.pdiparams
```
精度输出为:
```
top1_acc/top5_acc= [0.71918 0.90568]
```
### 3. 进行量化蒸馏训练
蒸馏量化训练示例脚本为[quant_aware_with_infermodel.py](./quant_aware_with_infermodel.py),使用接口``paddleslim.quant.quant_aware_with_infermodel``对模型进行量化训练。运行命令为:
```
python quant_aware_with_infermodel.py \
--batch_size=2 \
--num_epoch=30 \
--save_iter_step=100 \
--learning_rate=0.0001 \
--weight_decay=0.00004 \
--use_pact=True \
--checkpoint_path="./inference_model/MobileNet_quantaware_ckpt/" \
--model_path="./infermodel_mobilenetv2/" \
--model_filename="inference.pdmodel" \
--params_filename="inference.pdiparams" \
--teacher_model_path="./infermodel_mobilenetv2/" \
--teacher_model_filename="inference.pdmodel" \
--teacher_params_filename="inference.pdiparams" \
--distill_node_name_list "teacher_conv2d_54.tmp_0" "conv2d_54.tmp_0" "teacher_conv2d_55.tmp_0" "conv2d_55.tmp_0" \
"teacher_conv2d_57.tmp_0" "conv2d_57.tmp_0" "teacher_elementwise_add_0" "elementwise_add_0" \
"teacher_conv2d_61.tmp_0" "conv2d_61.tmp_0" "teacher_elementwise_add_1" "elementwise_add_1" \
"teacher_elementwise_add_2" "elementwise_add_2" "teacher_conv2d_67.tmp_0" "conv2d_67.tmp_0" \
"teacher_elementwise_add_3" "elementwise_add_3" "teacher_elementwise_add_4" "elementwise_add_4" \
"teacher_elementwise_add_5" "elementwise_add_5" "teacher_conv2d_75.tmp_0" "conv2d_75.tmp_0" \
"teacher_elementwise_add_6" "elementwise_add_6" "teacher_elementwise_add_7" "elementwise_add_7" \
"teacher_conv2d_81.tmp_0" "conv2d_81.tmp_0" "teacher_elementwise_add_8" "elementwise_add_8" \
"teacher_elementwise_add_9" "elementwise_add_9" "teacher_conv2d_87.tmp_0" "conv2d_87.tmp_0" \
"teacher_linear_1.tmp_0" "linear_1.tmp_0"
```
- ``batch_size``: 量化训练batch size。
- ``num_epoch``: 量化训练epoch数。
- ``save_iter_step``: 每隔save_iter_step保存一次checkpoint。
- ``learning_rate``: 量化训练学习率,推荐使用float模型训练最小一级学习率。
- ``weight_decay``: 推荐使用float模型训练weight decay设置。
- ``use_pact``: 是否使用pact量化算法, 推荐使用。
- ``checkpoint_path``: 量化训练模型checkpoint保存路径。
- ``model_path``: 需要量化的预测模型所在路径。
- ``model_filename``: 如果需要量化的模型的参数文件保存在一个文件中,则设置为该模型的模型文件名称,如果参数文件保存在多个文件中,则不需要设置。
- ``params_filename``: 如果需要量化的模型的参数文件保存在一个文件中,则设置为该模型的参数文件名称,如果参数文件保存在多个文件中,则不需要设置。
- ``teacher_model_path``: teacher模型所在路径, 可以和量化模型是同一个,即自蒸馏。
- ``teacher_model_filename``: teacher模型model文件名字。
- ``teacher_params_filename``: teacher模型参数文件名字。
- ``distill_node_name_list``: 蒸馏节点名字列表,每两个节点组成一对,分别属于teacher模型和量化模型。
运行以上命令后,可在``${checkpoint_path}``下看到量化后模型的checkpoint。
### 4. 量化模型导出
量化模型checkpoint导出为预测模型。
```
python export_quantmodel.py \
--use_gpu=True \
--checkpoint_path="./MobileNetV2_checkpoints/epoch_0_iter_2000" \
--infermodel_save_path="./quant_infermodel_mobilenetv2/" \
```
###5. 测试精度
使用[eval.py](../quant_post/eval.py)脚本对量化后的模型进行精度测试:
```
python ../quant_post/eval.py --model_path ./quant_infermodel_mobilenetv2/ --model_name model --params_name params
```
精度输出为:
```
top1_acc/top5_acc= [0.71764 0.90418]
```
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import math
import time
import numpy as np
import paddle
import logging
import argparse
import functools
sys.path[0] = os.path.join(
os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
sys.path[1] = os.path.join(
os.path.dirname("__file__"), os.path.pardir, os.path.pardir, os.path.pardir)
from paddleslim.common import get_logger
from paddleslim.quant import export_quant_infermodel
from utility import add_arguments, print_arguments
import imagenet_reader as reader
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('batch_size', int, 4, "train batch size.")
add_arg('num_epoch', int, 1, "train epoch num.")
add_arg('save_iter_step', int, 1, "save train checkpoint every save_iter_step iter num.")
add_arg('learning_rate', float, 0.0001, "learning rate.")
add_arg('weight_decay', float, 0.00004, "weight decay.")
add_arg('use_pact', bool, True, "whether use pact quantization.")
add_arg('checkpoint_path', str, None, "model dir to save quanted model checkpoints")
add_arg('model_path_prefix', str, None, "storage directory of model + model name (excluding suffix)")
add_arg('teacher_model_path_prefix', str, None, "storage directory of teacher model + teacher model name (excluding suffix)")
add_arg('distill_node_name_list', str, None, "distill node name list", nargs="+")
add_arg('checkpoint_filename', str, None, "checkpoint filename to export inference model")
add_arg('export_inference_model_path_prefix', str, None, "inference model export path prefix")
def export(args):
place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
exe = paddle.static.Executor(place)
quant_config = {
'weight_quantize_type': 'channel_wise_abs_max',
'activation_quantize_type': 'moving_average_abs_max',
'not_quant_pattern': ['skip_quant'],
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul']
}
train_config={
"num_epoch": args.num_epoch, # training epoch num
"max_iter": -1,
"save_iter_step": args.save_iter_step,
"learning_rate": args.learning_rate,
"weight_decay": args.weight_decay,
"use_pact": args.use_pact,
"quant_model_ckpt_path":args.checkpoint_path,
"teacher_model_path_prefix": args.teacher_model_path_prefix,
"model_path_prefix": args.model_path_prefix,
"distill_node_pair": args.distill_node_name_list
}
export_quant_infermodel(exe, place,
scope=None,
quant_config=quant_config,
train_config=train_config,
checkpoint_path=os.path.join(args.checkpoint_path, args.checkpoint_filename),
export_inference_model_path_prefix=args.export_inference_model_path_prefix)
def main():
args = parser.parse_args()
args.use_pact = bool(args.use_pact)
print_arguments(args)
export(args)
if __name__ == '__main__':
paddle.enable_static()
main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import math
import time
import numpy as np
import paddle
import logging
import argparse
import functools
sys.path[0] = os.path.join(
os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
sys.path[1] = os.path.join(
os.path.dirname("__file__"), os.path.pardir, os.path.pardir, os.path.pardir)
from paddleslim.common import get_logger
from paddleslim.quant import quant_aware_with_infermodel
from utility import add_arguments, print_arguments
import imagenet_reader as reader
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('use_gpu', bool, True, "whether to use GPU or not.")
add_arg('batch_size', int, 1, "train batch size.")
add_arg('num_epoch', int, 1, "train epoch num.")
add_arg('save_iter_step', int, 1, "save train checkpoint every save_iter_step iter num.")
add_arg('learning_rate', float, 0.0001, "learning rate.")
add_arg('weight_decay', float, 0.00004, "weight decay.")
add_arg('use_pact', bool, True, "whether use pact quantization.")
add_arg('checkpoint_path', str, None, "model dir to save quanted model checkpoints")
add_arg('model_path_prefix', str, None, "storage directory of model + model name (excluding suffix)")
add_arg('teacher_model_path_prefix', str, None, "storage directory of teacher model + teacher model name (excluding suffix)")
add_arg('distill_node_name_list', str, None, "distill node name list", nargs="+")
DATA_DIR = "../../data/ILSVRC2012/"
def eval(exe, place, compiled_test_program, test_feed_names, test_fetch_list):
val_reader = paddle.batch(reader.val(), batch_size=1)
image = paddle.static.data(
name='x', shape=[None, 3, 224, 224], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
results = []
for batch_id, data in enumerate(val_reader()):
# top1_acc, top5_acc
if len(test_feed_names) == 1:
# eval "infer model", which input is image, output is classification probability
image = data[0][0].reshape((1, 3, 224, 224))
label = [[d[1]] for d in data]
pred = exe.run(compiled_test_program,
feed={test_feed_names[0]: image},
fetch_list=test_fetch_list)
pred = np.array(pred[0])
label = np.array(label)
sort_array = pred.argsort(axis=1)
top_1_pred = sort_array[:, -1:][:, ::-1]
top_1 = np.mean(label == top_1_pred)
top_5_pred = sort_array[:, -5:][:, ::-1]
acc_num = 0
for i in range(len(label)):
if label[i][0] in top_5_pred[i]:
acc_num += 1
top_5 = float(acc_num) / len(label)
results.append([top_1, top_5])
else:
# eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
image = data[0][0].reshape((1, 3, 224, 224))
label = [[d[1]] for d in data]
result = exe.run(compiled_test_program,
feed={
test_feed_names[0]: image,
test_feed_names[1]: label
},
fetch_list=test_fetch_list)
result = [np.mean(r) for r in result]
results.append(result)
result = np.mean(np.array(results), axis=0)
return result
def quantize(args):
place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
#place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
quant_config = {
'weight_quantize_type': 'channel_wise_abs_max',
'activation_quantize_type': 'moving_average_abs_max',
'not_quant_pattern': ['skip_quant'],
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul']
}
train_config={
"num_epoch": args.num_epoch, # training epoch num
"max_iter": -1,
"save_iter_step": args.save_iter_step,
"learning_rate": args.learning_rate,
"weight_decay": args.weight_decay,
"use_pact": args.use_pact,
"quant_model_ckpt_path":args.checkpoint_path,
"teacher_model_path_prefix": args.teacher_model_path_prefix,
"model_path_prefix": args.model_path_prefix,
"distill_node_pair": args.distill_node_name_list
}
def test_callback(compiled_test_program, feed_names, fetch_list, checkpoint_name):
ret = eval(exe, place, compiled_test_program, feed_names, fetch_list)
print("{0} top1_acc/top5_acc= {1}".format(checkpoint_name, ret))
train_reader = paddle.batch(reader.train(), batch_size=args.batch_size)
def train_reader_wrapper():
def gen():
for i, data in enumerate(train_reader()):
imgs = np.float32([item[0] for item in data])
yield {"x":imgs}
return gen
quant_aware_with_infermodel(
exe,
place,
scope=None,
train_reader=train_reader_wrapper(),
quant_config=quant_config,
train_config=train_config,
test_callback=test_callback)
def main():
args = parser.parse_args()
args.use_pact=bool(args.use_pact)
print("args.use_pact", args.use_pact)
print_arguments(args)
quantize(args)
if __name__ == '__main__':
paddle.enable_static()
main()
# 静态离线量化超参搜索示例
本示例将介绍如何使用离线量化超参搜索接口``paddleslim.quant.quant_post_hpo``来对训练好的分类模型进行离线量化超参搜索。
## 分类模型的离线量化超参搜索流程
### 准备数据
``demo``文件夹下创建``data``文件夹,将``ImageNet``数据集解压在``data``文件夹下,解压后``data/ILSVRC2012``文件夹下应包含以下文件:
- ``'train'``文件夹,训练图片
- ``'train_list.txt'``文件
- ``'val'``文件夹,验证图片
- ``'val_list.txt'``文件
### 准备需要量化的模型
离线量化接口只支持加载通过``paddle.static.save_inference_model``接口保存的模型。因此如果您的模型是通过其他接口保存的,需要先将模型进行转化。本示例将以分类模型为例进行说明。
首先在[imagenet分类模型](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E5%B7%B2%E5%8F%91%E5%B8%83%E6%A8%A1%E5%9E%8B%E5%8F%8A%E5%85%B6%E6%80%A7%E8%83%BD)中下载训练好的``mobilenetv1``模型。
在当前文件夹下创建``'pretrain'``文件夹,将``mobilenetv1``模型在该文件夹下解压,解压后的目录为``pretrain/MobileNetV1_pretrained``
### 导出模型
通过运行以下命令可将模型转化为离线量化接口可用的模型:
```
python ../quant_post/export_model.py --model "MobileNet" --pretrained_model ./pretrain/MobileNetV1_pretrained --data imagenet
```
转化之后的模型存储在``inference_model/MobileNet/``文件夹下,可看到该文件夹下有``'model'``, ``'weights'``两个文件。
### 静态离线量化
接下来对导出的模型文件进行静态离线量化,静态离线量化的脚本为[quant_post_hpo.py](./quant_post_hpo.py),脚本中使用接口``paddleslim.quant.quant_post_hpo``对模型进行离线量化。运行命令为:
```
python quant_post_hpo.py \
--use_gpu=True \
--model_path="./inference_model/MobileNet/" \
--save_path="./inference_model/MobileNet_quant/" \
--model_filename="model" \
--params_filename="weights" \
--max_model_quant_count=26
```
- ``model_path``: 需要量化的模型所在路径
- ``save_path``: 量化后的模型保存的路径
- ``model_filename``: 如果需要量化的模型的参数文件保存在一个文件中,则设置为该模型的模型文件名称,如果参数文件保存在多个文件中,则不需要设置。
- ``params_filename``: 如果需要量化的模型的参数文件保存在一个文件中,则设置为该模型的参数文件名称,如果参数文件保存在多个文件中,则不需要设置。
- ``max_model_quant_count``: 最大离线量化搜索次数,次数越多产出高精度量化模型概率越大,耗时也会相应增加。建议值:大于20小于30。
运行以上命令后,可在``${save_path}``下看到量化后的模型文件和参数文件。
### 测试精度
使用[eval.py](../quant_post/eval.py)脚本对量化前后的模型进行测试,得到模型的分类精度进行对比。
首先测试量化前的模型的精度,运行以下命令:
```
python ../quant_post/eval.py --model_path ./inference_model/MobileNet --model_name model --params_name weights
```
精度输出为:
```
top1_acc/top5_acc= [0.70898 0.89534]
```
使用以下命令测试离线量化后的模型的精度:
```
python ../quant_post/eval.py --model_path ./inference_model/MobileNet_quant/ --model_name __model__ --params_name __params__
```
精度输出为
```
top1_acc/top5_acc= [0.70653 0.89369]
```
......@@ -14,6 +14,7 @@
import numpy as np
import paddle
from paddleslim.core import GraphWrapper
def merge(teacher_program,
......@@ -94,6 +95,16 @@ def merge(teacher_program,
student_program.global_block().append_op(
type=op.type, inputs=inputs, outputs=outputs, attrs=attrs)
student_graph = GraphWrapper(student_program)
for op in student_graph.ops():
belongsto_teacher = False
for inp in op.all_inputs():
if 'teacher' in inp.name():
belongsto_teacher = True
break
if belongsto_teacher:
op._op._set_attr("skip_quant", True)
def fsp_loss(teacher_var1_name,
teacher_var2_name,
......
......@@ -31,6 +31,7 @@ try:
], "training-aware and post-training quant is not supported in 2.0 alpha version paddle"
from .quanter import quant_aware, convert, quant_post_static, quant_post_dynamic
from .quanter import quant_post, quant_post_only_weight
from .quant_aware_with_infermodel import quant_aware_with_infermodel, export_quant_infermodel
from .quant_post_hpo import quant_post_hpo
except Exception as e:
_logger.warning(e)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""train aware quant with infermodel"""
import copy
import os
import argparse
import json
import six
from collections import namedtuple
import time
import shutil
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import unique_name
from paddle.fluid import core
from paddle.fluid.framework import Parameter
from paddleslim.dist import merge, l2_loss, soft_label_loss, fsp_loss
from paddleslim.core import GraphWrapper
from paddleslim.quant import quant_aware, convert
from .quanter import _quant_config_default, _parse_configs, pact, get_pact_optimizer
import logging
logging.getLogger().setLevel(logging.INFO)
from ..common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
############################################################################################################
# quantization training configs
############################################################################################################
_train_config_default = {
# configs of training aware quantization with infermodel
"num_epoch": 1000, # training epoch num
"max_iter": -1, # max training iteration num
"save_iter_step":
1000, # save quant model checkpoint every save_iter_step iteration
"learning_rate": 0.0001, # learning rate
"weight_decay": 0.0001, # weight decay
"use_pact": False, # use pact quantization or not
# quant model checkpoints save path
"quant_model_ckpt_path": "./quant_model_checkpoints/",
# storage directory of teacher model + teacher model name (excluding suffix)
"teacher_model_path_prefix": None,
# storage directory of model + model name (excluding suffix)
"model_path_prefix": None,
""" distillation node configuration:
the name of the distillation supervision nodes is configured as a list,
and the teacher node and student node are arranged in pairs.
for example, ["teacher_fc_0.tmp_0", "fc_0.tmp_0", "teacher_batch_norm_24.tmp_4", "batch_norm_24.tmp_4"]
"""
"distill_node_pair": None
}
def _parse_train_configs(train_config):
"""
check if user's train configs are valid.
Args:
train_config(dict): user's train config.
Return:
configs(dict): final configs will be used.
"""
configs = copy.deepcopy(_train_config_default)
configs.update(train_config)
assert isinstance(configs['num_epoch'], int), \
"'num_epoch' must be int value"
assert isinstance(configs['max_iter'], int), \
"'max_iter' must be int value"
assert isinstance(configs['save_iter_step'], int), \
"'save_iter_step' must be int value"
assert isinstance(configs['learning_rate'], float), \
"'learning_rate' must be float"
assert isinstance(configs['weight_decay'], float), \
"'weight_decay' must be float"
assert isinstance(configs['use_pact'], bool), \
"'use_pact' must be bool"
assert isinstance(configs['quant_model_ckpt_path'], str), \
"'quant_model_ckpt_path' must be str"
assert isinstance(configs['teacher_model_path_prefix'], str), \
"'teacher_model_path_prefix' must both be string"
assert isinstance(configs['model_path_prefix'], str), \
"'model_path_prefix' must both be str"
assert isinstance(configs['distill_node_pair'], list), \
"'distill_node_pair' must both be list"
assert len(configs['distill_node_pair']) > 0, \
"'distill_node_pair' not configured with distillation nodes"
assert len(configs['distill_node_pair']) % 2 == 0, \
"'distill_node_pair' distillation nodes need to be configured in pairs"
return train_config
def _create_optimizer(train_config):
"""create optimizer"""
optimizer = paddle.optimizer.SGD(
learning_rate=train_config["learning_rate"],
weight_decay=paddle.regularizer.L2Decay(train_config["weight_decay"]))
return optimizer
def _remove_fetch_node(program):
"""remove fetch node in program"""
for block in program.blocks:
removed = 0
ops = list(block.ops)
for op in ops:
if op.type == "fetch":
idx = ops.index(op)
block._remove_op(idx - removed)
removed += 1
def _recover_reserve_space_with_bn(program):
"""Add the outputs which is only used for training and not saved in
inference program."""
for block_idx in six.moves.range(program.num_blocks):
block = program.block(block_idx)
for op in block.ops:
if op.type == "batch_norm":
if "ReserveSpace" not in op.output_names or len(
op.output("ReserveSpace")) == 0:
reserve_space = block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["reserve_space", 'tmp'])),
dtype=block.var(op.input("X")[0]).dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=True)
op.desc.set_output("ReserveSpace", [reserve_space.name])
return program
def _recover_param_attr(program):
"""recover parameters attribute.
Params in infermodel are stored in the form of variable, which can not be trained."""
all_weights = [param for param in program.list_vars() \
if param.persistable is True and param.name != 'feed' and param.name != 'fetch']
for w in all_weights:
new_w = Parameter(
block=program.block(0),
shape=w.shape,
dtype=w.dtype,
type=w.type,
name=w.name)
new_w.set_value(w.get_value())
program.block(0).vars[w.name] = new_w
return program
def _parse_distill_loss(train_config):
"""parse distill loss config"""
assert len(train_config["distill_node_pair"]) % 2 == 0, \
"distill_node_pair config wrong, the length needs to be an even number"
print("train config.distill_node_pair: ", train_config["distill_node_pair"])
distill_loss = 0
for i in range(len(train_config["distill_node_pair"]) // 2):
print(train_config["distill_node_pair"][i * 2],
train_config["distill_node_pair"][i * 2 + 1])
distill_loss += l2_loss(train_config["distill_node_pair"][i * 2],
train_config["distill_node_pair"][i * 2 + 1])
return distill_loss
DistillProgramInfo = namedtuple("DistillProgramInfo", \
"startup_program train_program train_feed_names train_fetch_list \
optimizer test_program test_feed_names test_fetch_list"
)
def build_distill_prog_with_infermodel(executor, place, train_config):
"""build distill program with infermodel"""
[train_program, feed_target_names, fetch_targets]= paddle.static.load_inference_model( \
path_prefix=train_config["model_path_prefix"], \
executor=executor)
_remove_fetch_node(train_program)
[teacher_program, teacher_feed_target_names, teacher_fetch_targets]= paddle.static.load_inference_model( \
path_prefix=train_config["teacher_model_path_prefix"], \
executor=executor)
_remove_fetch_node(teacher_program)
test_program = train_program.clone(for_test=True)
train_program = _recover_param_attr(train_program)
train_program = _recover_reserve_space_with_bn(train_program)
for var in train_program.list_vars():
var.stop_gradient = False
train_graph = GraphWrapper(train_program)
for op in train_graph.ops():
op._op._set_attr("is_test", False)
############################################################################
# distill
############################################################################
data_name_map = {}
assert len(feed_target_names) == len(teacher_feed_target_names), \
"the number of feed nodes in the teacher model is not equal to the student model"
for i, name in enumerate(feed_target_names):
data_name_map[teacher_feed_target_names[i]] = name
merge(teacher_program, train_program, data_name_map, place)
# all feed node should set stop_gradient is False, for using pact quant algo.
for var in train_program.list_vars():
if var.name in data_name_map.values() or var.name in data_name_map.keys(
):
var.stop_gradient = False
train_fetch_list = []
train_fetch_name_list = []
startup_program = paddle.static.Program()
with paddle.static.program_guard(train_program, startup_program):
with fluid.unique_name.guard('merge'):
optimizer = _create_optimizer(train_config)
distill_loss = _parse_distill_loss(train_config)
loss = paddle.mean(distill_loss)
loss.stop_gradient = False
p_g_list = paddle.static.append_backward(loss=loss)
opts = optimizer.apply_gradients(p_g_list)
train_fetch_list.append(loss)
train_fetch_name_list.append(loss.name)
return DistillProgramInfo(startup_program, train_program, \
feed_target_names, train_fetch_list, optimizer, \
test_program, feed_target_names, fetch_targets)
def _compile_program(program, fetch_var_name):
"""compiling program"""
compiled_prog = paddle.static.CompiledProgram(program)
build_strategy = paddle.static.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = False
build_strategy.fuse_all_reduce_ops = False
build_strategy.sync_batch_norm = False
exec_strategy = paddle.static.ExecutionStrategy()
compiled_prog = compiled_prog.with_data_parallel(
loss_name=fetch_var_name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
return compiled_prog
def quant_aware_with_infermodel(executor,
place,
scope=None,
train_reader=None,
quant_config=None,
train_config=None,
test_callback=None):
"""train aware quantization with infermodel
Args:
executor(paddle.static.Executor): The executor to load, run and save the
quantized model.
place(paddle.CPUPlace or paddle.CUDAPlace): This parameter represents
the executor run on which device.
scope(paddle.static.Scope, optional): Scope records the mapping between
variable names and variables, similar to brackets in
programming languages. Usually users can use
`paddle.static.global_scope <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html>`_.
When ``None`` will use
`paddle.static.global_scope() <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html>`_
. Default: ``None``.
train_reader(data generator): data generator, yield feed_dictionary, {feed_name[0]:data[0], feed_name[1]:data[1]}.
quant_config(dict, optional): configs for convert. if set None, will use
default config. It must be same with config that used in
'quant_aware'. Default is None.
train_config(dict):train aware configs, include num_epoch, save_iter_step, learning_rate,
weight_decay, use_pact, quant_model_ckpt_path,
model_path_prefix, teacher_model_path_prefix,
distill_node_pair(teacher_node_name1, node_name1, teacher_node_name2, teacher_node_name2, ...)
test_callback(callback function): callback function include two params: compiled test quant program and checkpoint save filename.
user can implement test logic.
Returns:
None
"""
scope = paddle.static.global_scope() if not scope else scope
# parse quant config
if quant_config is None:
quant_config = _quant_config_default
else:
assert isinstance(quant_config, dict), "quant config must be dict"
quant_config = _parse_configs(quant_config)
_logger.info("quant_aware config {}".format(quant_config))
train_config = _parse_train_configs(train_config)
distill_program_info = build_distill_prog_with_infermodel(executor, place,
train_config)
startup_program = distill_program_info.startup_program
train_program = distill_program_info.train_program
train_feed_names = distill_program_info.train_feed_names
train_fetch_list = distill_program_info.train_fetch_list
optimizer = distill_program_info.optimizer
test_program = distill_program_info.test_program
test_feed_names = distill_program_info.test_feed_names
test_fetch_list = distill_program_info.test_fetch_list
############################################################################
# quant
############################################################################
use_pact = train_config["use_pact"]
if use_pact:
act_preprocess_func = pact
optimizer_func = get_pact_optimizer
pact_executor = executor
else:
act_preprocess_func = None
optimizer_func = None
pact_executor = None
test_program = quant_aware(
test_program,
place,
quant_config,
scope=scope,
act_preprocess_func=act_preprocess_func,
optimizer_func=optimizer_func,
executor=pact_executor,
for_test=True)
train_program = quant_aware(
train_program,
place,
quant_config,
scope=scope,
act_preprocess_func=act_preprocess_func,
optimizer_func=optimizer_func,
executor=pact_executor,
for_test=False,
return_program=True)
executor.run(startup_program)
compiled_train_prog = _compile_program(train_program,
train_fetch_list[0].name)
compiled_test_prog = _compile_program(test_program, test_fetch_list[0].name)
num_epoch = train_config["num_epoch"]
save_iter_step = train_config["save_iter_step"]
iter_sum = 0
for epoch in range(num_epoch):
for iter_num, feed_dict in enumerate(train_reader()):
np_probs_float = executor.run(compiled_train_prog, \
feed=feed_dict, \
fetch_list=train_fetch_list)
print("loss: ", np_probs_float)
if iter_num > 0 and iter_num % save_iter_step == 0:
checkpoint_name = "epoch_" + str(epoch) + "_iter_" + str(
iter_num)
paddle.static.save(
program=test_program,
model_path=os.path.join(
train_config["quant_model_ckpt_path"], checkpoint_name))
test_callback(compiled_test_prog, test_feed_names,
test_fetch_list, checkpoint_name)
iter_sum += 1
if train_config["max_iter"] >= 0 and iter_sum > train_config[
"max_iter"]:
return
def export_quant_infermodel(
executor,
place=None,
scope=None,
quant_config=None,
train_config=None,
checkpoint_path=None,
export_inference_model_path_prefix="./export_quant_infermodel"):
"""export quant model checkpoints to infermodel.
Args:
executor(paddle.static.Executor): The executor to load, run and save the
quantized model.
place(paddle.CPUPlace or paddle.CUDAPlace): This parameter represents
the executor run on which device.
scope(paddle.static.Scope, optional): Scope records the mapping between
variable names and variables, similar to brackets in
programming languages. Usually users can use
`paddle.static.global_scope <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html>`_.
When ``None`` will use
`paddle.static.global_scope() <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html>`_
. Default: ``None``.
quant_config(dict, optional): configs for convert. if set None, will use
default config. It must be same with config that used in
'quant_aware'. Default is None.
train_config(dict):train aware configs, include num_epoch, save_iter_step, learning_rate,
weight_decay, use_pact, quant_model_ckpt_path,
model_path_prefix, teacher_model_path_prefix,
distill_node_pair(teacher_node_name1, node_name1, teacher_node_name2, teacher_node_name2, ...)
checkpoint_path(str): checkpoint path need to export quant infer model.
export_inference_model_path_prefix(str): export infer model path prefix, storage directory of model + model name (excluding suffix).
Returns:
None
"""
scope = paddle.static.global_scope() if not scope else scope
# parse quant config
if quant_config is None:
quant_config = _quant_config_default
else:
assert isinstance(quant_config, dict), "quant config must be dict"
quant_config = _parse_configs(quant_config)
_logger.info("quant_aware config {}".format(quant_config))
train_config = _parse_train_configs(train_config)
distill_program_info = build_distill_prog_with_infermodel(executor, place,
train_config)
test_program = distill_program_info.test_program
test_feed_names = distill_program_info.test_feed_names
test_fetch_list = distill_program_info.test_fetch_list
############################################################################
# quant
############################################################################
use_pact = train_config["use_pact"]
if use_pact:
act_preprocess_func = pact
optimizer_func = get_pact_optimizer
pact_executor = executor
else:
act_preprocess_func = None
optimizer_func = None
pact_executor = None
test_program = quant_aware(
test_program,
place,
quant_config,
scope=scope,
act_preprocess_func=act_preprocess_func,
optimizer_func=optimizer_func,
executor=pact_executor,
for_test=True)
paddle.static.load(
executor=executor,
model_path=os.path.join(checkpoint_path),
program=test_program)
############################################################################################################
# 3. Freeze the graph after training by adjusting the quantize
# operators' order for the inference.
# The dtype of float_program's weights is float32, but in int8 range.
############################################################################################################
float_program, int8_program = convert(test_program, place, quant_config, \
scope=scope, \
save_int8=True)
############################################################################################################
# 4. Save inference model
############################################################################################################
export_model_dir = os.path.abspath(
os.path.join(export_inference_model_path_prefix, os.path.pardir))
if not os.path.exists(export_model_dir):
os.makedirs(export_model_dir)
feed_vars = []
for name in test_feed_names:
for var in float_program.list_vars():
if var.name == name:
feed_vars.append(var)
break
assert len(feed_vars) > 0, "can not find feed vars in quant program"
paddle.static.save_inference_model(
path_prefix=export_inference_model_path_prefix,
feed_vars=feed_vars,
fetch_vars=test_fetch_list,
executor=executor,
program=float_program)
......@@ -29,6 +29,7 @@ from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
from paddle.fluid import core
from paddle.fluid.contrib.slim.quantization import WeightQuantization
from paddle.fluid.layer_helper import LayerHelper
from ..common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
......@@ -561,3 +562,30 @@ def quant_post_dynamic(model_dir,
# For compatibility, we keep quant_post_only_weight api for now,
# and it will be deprecated in the future.
quant_post_only_weight = quant_post_dynamic
def pact(x, name=None):
helper = LayerHelper("pact", **locals())
dtype = 'float32'
init_thres = 20
u_param_attr = paddle.fluid.ParamAttr(
name=x.name + '_pact',
initializer=paddle.fluid.initializer.ConstantInitializer(
value=init_thres),
regularizer=paddle.fluid.regularizer.L2Decay(0.0001),
learning_rate=1)
u_param = helper.create_parameter(attr=u_param_attr, shape=[1], dtype=dtype)
x = paddle.fluid.layers.elementwise_sub(
x,
paddle.fluid.layers.relu(
paddle.fluid.layers.elementwise_sub(x, u_param)))
x = paddle.fluid.layers.elementwise_add(
x,
paddle.fluid.layers.relu(
paddle.fluid.layers.elementwise_sub(-u_param, x)))
return x
def get_pact_optimizer():
return paddle.fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
sys.path.append("../")
sys.path.append(".")
sys.path[0] = os.path.join(os.path.dirname("__file__"), os.path.pardir)
import unittest
import paddle
from paddleslim.quant import quant_aware, convert
from paddleslim.quant import quant_aware_with_infermodel, export_quant_infermodel
from static_case import StaticCase
sys.path.append("../demo")
from models import MobileNet
from layers import conv_bn_layer
import paddle.dataset.mnist as reader
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
import numpy as np
class TestQuantAwareWithInferModelCase1(StaticCase):
def test_accuracy(self):
float_infer_model_path_prefix = "./mv1_float_inference"
image = paddle.static.data(
name='image', shape=[None, 1, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
model = MobileNet()
out = model.net(input=image, class_dim=10)
cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
avg_cost = paddle.mean(x=cost)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
optimizer = paddle.optimizer.Momentum(
momentum=0.9,
learning_rate=0.01,
weight_decay=paddle.regularizer.L2Decay(4e-5))
optimizer.minimize(avg_cost)
main_prog = paddle.static.default_main_program()
val_prog = main_prog.clone(for_test=True)
#place = paddle.CPUPlace()
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
def transform(x):
return np.reshape(x, [1, 28, 28])
train_dataset = paddle.vision.datasets.MNIST(
mode='train', backend='cv2', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(
mode='test', backend='cv2', transform=transform)
train_loader = paddle.io.DataLoader(
train_dataset,
places=place,
feed_list=[image, label],
drop_last=True,
batch_size=64,
return_list=False)
valid_loader = paddle.io.DataLoader(
test_dataset,
places=place,
feed_list=[image, label],
batch_size=64,
return_list=False)
def sample_generator_creator():
def __reader__():
for data in test_dataset:
image, label = data
yield image, label
return __reader__
def train(program):
iter = 0
for data in train_loader():
cost, top1, top5 = exe.run(
program,
feed=data,
fetch_list=[avg_cost, acc_top1, acc_top5])
iter += 1
if iter % 100 == 0:
print(
'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
format(iter, cost, top1, top5))
def test(program, outputs=[avg_cost, acc_top1, acc_top5]):
iter = 0
result = [[], [], []]
for data in valid_loader():
cost, top1, top5 = exe.run(program,
feed=data,
fetch_list=outputs)
iter += 1
if iter % 100 == 0:
print('eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
format(iter, cost, top1, top5))
result[0].append(cost)
result[1].append(top1)
result[2].append(top5)
print(' avg loss {}, acc_top1 {}, acc_top5 {}'.format(
np.mean(result[0]), np.mean(result[1]), np.mean(result[2])))
return np.mean(result[1]), np.mean(result[2])
train(main_prog)
top1_1, top5_1 = test(val_prog)
paddle.static.save_inference_model(
path_prefix=float_infer_model_path_prefix,
feed_vars=[image, label],
fetch_vars=[avg_cost, acc_top1, acc_top5],
executor=exe,
program=val_prog)
quant_config = {
'weight_quantize_type': 'channel_wise_abs_max',
'activation_quantize_type': 'moving_average_abs_max',
'not_quant_pattern': ['skip_quant'],
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul']
}
train_config = {
"num_epoch": 1, # training epoch num
"max_iter": 20,
"save_iter_step": 10,
"learning_rate": 0.0001,
"weight_decay": 0.0001,
"use_pact": False,
"quant_model_ckpt_path":
"./quantaware_with_infermodel_checkpoints/",
"teacher_model_path_prefix": float_infer_model_path_prefix,
"model_path_prefix": float_infer_model_path_prefix,
"distill_node_pair": [
"teacher_fc_0.tmp_0", "fc_0.tmp_0",
"teacher_batch_norm_24.tmp_4", "batch_norm_24.tmp_4",
"teacher_batch_norm_22.tmp_4", "batch_norm_22.tmp_4",
"teacher_batch_norm_18.tmp_4", "batch_norm_18.tmp_4",
"teacher_batch_norm_13.tmp_4", "batch_norm_13.tmp_4",
"teacher_batch_norm_5.tmp_4", "batch_norm_5.tmp_4"
]
}
def test_callback(compiled_test_program, feed_names, fetch_list,
checkpoint_name):
outputs = fetch_list
iter = 0
result = [[], [], []]
for data in valid_loader():
cost, top1, top5 = exe.run(compiled_test_program,
feed=data,
fetch_list=fetch_list)
iter += 1
if iter % 100 == 0:
print('eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
format(iter, cost, top1, top5))
result[0].append(cost)
result[1].append(top1)
result[2].append(top5)
print("quant model checkpoint: " + checkpoint_name +
' avg loss {}, acc_top1 {}, acc_top5 {}'.format(
np.mean(result[0]),
np.mean(result[1]), np.mean(result[2])))
return np.mean(result[1]), np.mean(result[2])
def test_quant_aware_with_infermodel(exe, place):
quant_aware_with_infermodel(
exe,
place,
scope=None,
train_reader=train_loader,
quant_config=quant_config,
train_config=train_config,
test_callback=test_callback)
def test_export_quant_infermodel(exe, place, checkpoint_path,
quant_infermodel_save_path):
export_quant_infermodel(
exe,
place,
scope=None,
quant_config=quant_config,
train_config=train_config,
checkpoint_path=checkpoint_path,
export_inference_model_path_prefix=quant_infermodel_save_path)
#place = paddle.CPUPlace()
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
exe = paddle.static.Executor(place)
test_quant_aware_with_infermodel(exe, place)
checkpoint_path = "./quantaware_with_infermodel_checkpoints/epoch_0_iter_10"
quant_infermodel_save_path = "./quantaware_with_infermodel_export"
test_export_quant_infermodel(exe, place, checkpoint_path,
quant_infermodel_save_path)
train_config["use_pact"] = True
test_quant_aware_with_infermodel(exe, place)
train_config["use_pact"] = False
checkpoint_path = "./quantaware_with_infermodel_checkpoints/epoch_0_iter_10"
quant_infermodel_save_path = "./quantaware_with_infermodel_export"
test_export_quant_infermodel(exe, place, checkpoint_path,
quant_infermodel_save_path)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册