Unverified commit e406740a, authored by itminner, committed by GitHub

quant aware with infer model (#947)

Parent bf123166
# Quantization-Aware Training with an Inference Model: Example
Obtaining an inference model:
- in dynamic-graph mode, save the model with ``paddle.jit.save``;
- in static-graph mode, save it with ``paddle.static.save_inference_model``.
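For example, a minimal dynamic-graph sketch (the network, save path, and input spec here are illustrative assumptions, not part of this demo):
```
import paddle

# Any trained nn.Layer can be exported; MobileNetV2 is used purely as an example.
net = paddle.vision.models.mobilenet_v2(num_classes=1000)
net.eval()
# Tracing needs an input spec; this writes inference.pdmodel / inference.pdiparams.
paddle.jit.save(
    net, "./infermodel/inference",
    input_spec=[paddle.static.InputSpec([None, 3, 224, 224], 'float32')])
```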
This example shows how to run distillation-based quantization-aware training on an inference model:
first train the quantized model with the API ``paddleslim.quant.quant_aware_with_infermodel``,
then export the trained quantized model as an inference model with ``paddleslim.quant.export_quant_infermodel``.
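At a high level the workflow is two API calls; a condensed sketch is shown below (``quant_config``, ``train_config``, ``train_reader``, ``test_callback``, and the paths are built as in the full scripts later in this document):
```
import paddle
from paddleslim.quant import quant_aware_with_infermodel, export_quant_infermodel

paddle.enable_static()
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
# 1) Insert quant ops into the float inference model, distill from the teacher,
#    and save checkpoints under train_config["quant_model_ckpt_path"].
quant_aware_with_infermodel(
    exe, place, scope=None,
    train_reader=train_reader,
    quant_config=quant_config,
    train_config=train_config,
    test_callback=test_callback)
# 2) Convert a chosen checkpoint back into a deployable inference model.
export_quant_infermodel(
    exe, place, scope=None,
    quant_config=quant_config,
    train_config=train_config,
    checkpoint_path="./ckpt/epoch_0_iter_100",
    export_inference_model_path_prefix="./quant_infermodel")
```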
## Quantization-aware training workflow for a classification model
### 1. Prepare the data
Create a ``data`` folder under the ``demo`` directory and extract the ``ImageNet`` dataset into it. After extraction, the ``data/ILSVRC2012`` folder should contain:
- a ``'train'`` folder with the training images
- a ``'train_list.txt'`` file
- a ``'val'`` folder with the validation images
- a ``'val_list.txt'`` file
### 2. Prepare the model to be quantized
PaddleClas, PaddlePaddle's image recognition toolkit for industry and academia, is used in this example to produce the ImageNet classification model.
#### ① Download the MobileNetV2 pretrained model
Pretrained model zoo: ``https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/zh_CN/algorithm_introduction/ImageNet_models.md``
MobileNetV2 pretrained weights: ``https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_pretrained.pdparams``
Create a ``pretrained`` folder in the PaddleClas repository root and save the MobileNetV2 pretrained weights in it.
#### ② Export the inference model
Run the following command from the PaddleClas repository root to export the inference model:
```
python tools/export_model.py \
-c ppcls/configs/ImageNet/MobileNetV2/MobileNetV2.yaml \
-o Global.pretrained_model=pretrained/MobileNetV2_pretrained \
-o Global.save_inference_dir=infermodel_mobilenetv2
```
#### ③ Test model accuracy
Use the [eval.py](../quant_post/eval.py) script to measure the model's classification accuracy:
```
python ../quant_post/eval.py --model_path infermodel_mobilenetv2 --model_name inference.pdmodel --params_name inference.pdiparams
```
The accuracy output is:
```
top1_acc/top5_acc= [0.71918 0.90568]
```
### 3. Run quantization distillation training
The example script for distillation-based quantization-aware training is [quant_aware_with_infermodel.py](./quant_aware_with_infermodel.py), which uses the API ``paddleslim.quant.quant_aware_with_infermodel`` to train the quantized model. Run it with:
```
python quant_aware_with_infermodel.py \
--batch_size=2 \
--num_epoch=30 \
--save_iter_step=100 \
--learning_rate=0.0001 \
--weight_decay=0.00004 \
--use_pact=True \
--checkpoint_path="./inference_model/MobileNet_quantaware_ckpt/" \
--model_path="./infermodel_mobilenetv2/" \
--model_filename="inference.pdmodel" \
--params_filename="inference.pdiparams" \
--teacher_model_path="./infermodel_mobilenetv2/" \
--teacher_model_filename="inference.pdmodel" \
--teacher_params_filename="inference.pdiparams" \
--distill_node_name_list "teacher_conv2d_54.tmp_0" "conv2d_54.tmp_0" "teacher_conv2d_55.tmp_0" "conv2d_55.tmp_0" \
"teacher_conv2d_57.tmp_0" "conv2d_57.tmp_0" "teacher_elementwise_add_0" "elementwise_add_0" \
"teacher_conv2d_61.tmp_0" "conv2d_61.tmp_0" "teacher_elementwise_add_1" "elementwise_add_1" \
"teacher_elementwise_add_2" "elementwise_add_2" "teacher_conv2d_67.tmp_0" "conv2d_67.tmp_0" \
"teacher_elementwise_add_3" "elementwise_add_3" "teacher_elementwise_add_4" "elementwise_add_4" \
"teacher_elementwise_add_5" "elementwise_add_5" "teacher_conv2d_75.tmp_0" "conv2d_75.tmp_0" \
"teacher_elementwise_add_6" "elementwise_add_6" "teacher_elementwise_add_7" "elementwise_add_7" \
"teacher_conv2d_81.tmp_0" "conv2d_81.tmp_0" "teacher_elementwise_add_8" "elementwise_add_8" \
"teacher_elementwise_add_9" "elementwise_add_9" "teacher_conv2d_87.tmp_0" "conv2d_87.tmp_0" \
"teacher_linear_1.tmp_0" "linear_1.tmp_0"
```
- ``batch_size``: batch size for quantization training.
- ``num_epoch``: number of quantization training epochs.
- ``save_iter_step``: save a checkpoint every ``save_iter_step`` iterations.
- ``learning_rate``: learning rate for quantization training; the smallest learning rate used when training the float model is recommended.
- ``weight_decay``: the weight decay used when training the float model is recommended.
- ``use_pact``: whether to use the PACT quantization algorithm; recommended.
- ``checkpoint_path``: directory for saving quantization training checkpoints.
- ``model_path``: directory of the inference model to be quantized.
- ``model_filename``: if the model's parameters are stored in a single file, set this to the model file's name; if they are stored in multiple files, leave it unset.
- ``params_filename``: if the model's parameters are stored in a single file, set this to the parameter file's name; if they are stored in multiple files, leave it unset.
- ``teacher_model_path``: directory of the teacher model; it may be the same as the model being quantized, i.e. self-distillation.
- ``teacher_model_filename``: model file name of the teacher model.
- ``teacher_params_filename``: parameter file name of the teacher model.
- ``distill_node_name_list``: list of distillation node names; each consecutive pair is a teacher-model node followed by the matching quantized-model node (one way to look node names up is sketched below).
After running the command above, checkpoints of the quantized model appear under ``${checkpoint_path}``.
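The node names above are output variable names of ops in the inference program. A minimal sketch for listing candidate names, assuming the MobileNetV2 inference model exported in step 2:
```
import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
# Load the float inference model exported by PaddleClas.
program, feed_names, fetch_targets = paddle.static.load_inference_model(
    "./infermodel_mobilenetv2/inference", exe)
# Print every op's output variable names; pick matching teacher/student pairs
# (the teacher-side name simply gets a "teacher_" prefix after merging).
for op in program.global_block().ops:
    for name in op.output_arg_names:
        print(op.type, name)
```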
### 4. Export the quantized model
Export a quantization training checkpoint as an inference model:
```
python export_quantmodel.py \
--use_gpu=True \
--checkpoint_path="./MobileNetV2_checkpoints/epoch_0_iter_2000" \
--infermodel_save_path="./quant_infermodel_mobilenetv2/"
```
### 5. Test accuracy
Use the [eval.py](../quant_post/eval.py) script to test the accuracy of the quantized model:
```
python ../quant_post/eval.py --model_path ./quant_infermodel_mobilenetv2/ --model_name model --params_name params
```
The accuracy output is:
```
top1_acc/top5_acc= [0.71764 0.90418]
```
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import math
import time
import numpy as np
import paddle
import logging
import argparse
import functools
sys.path[0] = os.path.join(
    os.path.dirname(__file__), os.path.pardir, os.path.pardir)
sys.path[1] = os.path.join(
    os.path.dirname(__file__), os.path.pardir, os.path.pardir, os.path.pardir)
from paddleslim.common import get_logger
from paddleslim.quant import export_quant_infermodel
from utility import add_arguments, print_arguments
import imagenet_reader as reader
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('batch_size', int, 4, "train batch size.")
add_arg('num_epoch', int, 1, "train epoch num.")
add_arg('save_iter_step', int, 1, "save train checkpoint every save_iter_step iter num.")
add_arg('learning_rate', float, 0.0001, "learning rate.")
add_arg('weight_decay', float, 0.00004, "weight decay.")
add_arg('use_pact', bool, True, "whether use pact quantization.")
add_arg('checkpoint_path', str, None, "model dir to save quanted model checkpoints")
add_arg('model_path_prefix', str, None, "storage directory of model + model name (excluding suffix)")
add_arg('teacher_model_path_prefix', str, None, "storage directory of teacher model + teacher model name (excluding suffix)")
add_arg('distill_node_name_list', str, None, "distill node name list", nargs="+")
add_arg('checkpoint_filename', str, None, "checkpoint filename to export inference model")
add_arg('export_inference_model_path_prefix', str, None, "inference model export path prefix")
def export(args):
place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
exe = paddle.static.Executor(place)
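    # Quantization configuration: per-channel abs-max for conv/mul weights,
    # moving-average abs-max for activations; ops matching 'skip_quant'
    # (e.g. the merged teacher subgraph) are left in float.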
quant_config = {
'weight_quantize_type': 'channel_wise_abs_max',
'activation_quantize_type': 'moving_average_abs_max',
'not_quant_pattern': ['skip_quant'],
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul']
}
    train_config = {
        "num_epoch": args.num_epoch,  # training epoch num
        "max_iter": -1,
        "save_iter_step": args.save_iter_step,
        "learning_rate": args.learning_rate,
        "weight_decay": args.weight_decay,
        "use_pact": args.use_pact,
        "quant_model_ckpt_path": args.checkpoint_path,
        "teacher_model_path_prefix": args.teacher_model_path_prefix,
        "model_path_prefix": args.model_path_prefix,
        "distill_node_pair": args.distill_node_name_list
    }
export_quant_infermodel(exe, place,
scope=None,
quant_config=quant_config,
train_config=train_config,
checkpoint_path=os.path.join(args.checkpoint_path, args.checkpoint_filename),
export_inference_model_path_prefix=args.export_inference_model_path_prefix)
def main():
args = parser.parse_args()
args.use_pact = bool(args.use_pact)
print_arguments(args)
export(args)
if __name__ == '__main__':
paddle.enable_static()
main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import math
import time
import numpy as np
import paddle
import logging
import argparse
import functools
sys.path[0] = os.path.join(
    os.path.dirname(__file__), os.path.pardir, os.path.pardir)
sys.path[1] = os.path.join(
    os.path.dirname(__file__), os.path.pardir, os.path.pardir, os.path.pardir)
from paddleslim.common import get_logger
from paddleslim.quant import quant_aware_with_infermodel
from utility import add_arguments, print_arguments
import imagenet_reader as reader
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('use_gpu', bool, True, "whether to use GPU or not.")
add_arg('batch_size', int, 1, "train batch size.")
add_arg('num_epoch', int, 1, "train epoch num.")
add_arg('save_iter_step', int, 1, "save train checkpoint every save_iter_step iter num.")
add_arg('learning_rate', float, 0.0001, "learning rate.")
add_arg('weight_decay', float, 0.00004, "weight decay.")
add_arg('use_pact', bool, True, "whether use pact quantization.")
add_arg('checkpoint_path', str, None, "model dir to save quanted model checkpoints")
add_arg('model_path_prefix', str, None, "storage directory of model + model name (excluding suffix)")
add_arg('teacher_model_path_prefix', str, None, "storage directory of teacher model + teacher model name (excluding suffix)")
add_arg('distill_node_name_list', str, None, "distill node name list", nargs="+")
DATA_DIR = "../../data/ILSVRC2012/"
def eval(exe, place, compiled_test_program, test_feed_names, test_fetch_list):
val_reader = paddle.batch(reader.val(), batch_size=1)
results = []
for batch_id, data in enumerate(val_reader()):
# top1_acc, top5_acc
if len(test_feed_names) == 1:
# eval "infer model", which input is image, output is classification probability
image = data[0][0].reshape((1, 3, 224, 224))
label = [[d[1]] for d in data]
pred = exe.run(compiled_test_program,
feed={test_feed_names[0]: image},
fetch_list=test_fetch_list)
pred = np.array(pred[0])
label = np.array(label)
sort_array = pred.argsort(axis=1)
top_1_pred = sort_array[:, -1:][:, ::-1]
top_1 = np.mean(label == top_1_pred)
top_5_pred = sort_array[:, -5:][:, ::-1]
acc_num = 0
for i in range(len(label)):
if label[i][0] in top_5_pred[i]:
acc_num += 1
top_5 = float(acc_num) / len(label)
results.append([top_1, top_5])
else:
# eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
image = data[0][0].reshape((1, 3, 224, 224))
label = [[d[1]] for d in data]
result = exe.run(compiled_test_program,
feed={
test_feed_names[0]: image,
test_feed_names[1]: label
},
fetch_list=test_fetch_list)
result = [np.mean(r) for r in result]
results.append(result)
result = np.mean(np.array(results), axis=0)
return result
def quantize(args):
place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
#place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
quant_config = {
'weight_quantize_type': 'channel_wise_abs_max',
'activation_quantize_type': 'moving_average_abs_max',
'not_quant_pattern': ['skip_quant'],
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul']
}
    train_config = {
        "num_epoch": args.num_epoch,  # training epoch num
        "max_iter": -1,
        "save_iter_step": args.save_iter_step,
        "learning_rate": args.learning_rate,
        "weight_decay": args.weight_decay,
        "use_pact": args.use_pact,
        "quant_model_ckpt_path": args.checkpoint_path,
        "teacher_model_path_prefix": args.teacher_model_path_prefix,
        "model_path_prefix": args.model_path_prefix,
        "distill_node_pair": args.distill_node_name_list
    }
def test_callback(compiled_test_program, feed_names, fetch_list, checkpoint_name):
ret = eval(exe, place, compiled_test_program, feed_names, fetch_list)
print("{0} top1_acc/top5_acc= {1}".format(checkpoint_name, ret))
train_reader = paddle.batch(reader.train(), batch_size=args.batch_size)
def train_reader_wrapper():
def gen():
for i, data in enumerate(train_reader()):
imgs = np.float32([item[0] for item in data])
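                # Feed-dict keys must match the inference model's input names ("x" here).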
yield {"x":imgs}
return gen
quant_aware_with_infermodel(
exe,
place,
scope=None,
train_reader=train_reader_wrapper(),
quant_config=quant_config,
train_config=train_config,
test_callback=test_callback)
def main():
args = parser.parse_args()
args.use_pact=bool(args.use_pact)
print("args.use_pact", args.use_pact)
print_arguments(args)
quantize(args)
if __name__ == '__main__':
paddle.enable_static()
main()
# Post-Training Quantization Hyperparameter Search: Example
This example shows how to use the post-training quantization hyperparameter search API ``paddleslim.quant.quant_post_hpo`` to search quantization hyperparameters for a trained classification model.
## Hyperparameter search workflow for post-training quantization of a classification model
### Prepare the data
Create a ``data`` folder under the ``demo`` directory and extract the ``ImageNet`` dataset into it. After extraction, the ``data/ILSVRC2012`` folder should contain:
- a ``'train'`` folder with the training images
- a ``'train_list.txt'`` file
- a ``'val'`` folder with the validation images
- a ``'val_list.txt'`` file
### Prepare the model to be quantized
The post-training quantization API only supports models saved with ``paddle.static.save_inference_model``, so a model saved through any other API must be converted first. This example uses a classification model.
First, download the trained ``mobilenetv1`` model from the [ImageNet classification models](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E5%B7%B2%E5%8F%91%E5%B8%83%E6%A8%A1%E5%9E%8B%E5%8F%8A%E5%85%B6%E6%80%A7%E8%83%BD) page.
Create a ``'pretrain'`` folder in the current directory and extract the ``mobilenetv1`` model into it; the extracted directory is ``pretrain/MobileNetV1_pretrained``.
### Export the model
Run the following command to convert the model into one the post-training quantization API can load:
```
python ../quant_post/export_model.py --model "MobileNet" --pretrained_model ./pretrain/MobileNetV1_pretrained --data imagenet
```
The converted model is stored under ``inference_model/MobileNet/``, which contains two files: ``'model'`` and ``'weights'``.
### Post-training quantization
Next, apply post-training quantization to the exported model. The script is [quant_post_hpo.py](./quant_post_hpo.py), which uses the API ``paddleslim.quant.quant_post_hpo`` to quantize the model. Run it with:
```
python quant_post_hpo.py \
--use_gpu=True \
--model_path="./inference_model/MobileNet/" \
--save_path="./inference_model/MobileNet_quant/" \
--model_filename="model" \
--params_filename="weights" \
--max_model_quant_count=26
```
- ``model_path``: directory of the model to be quantized.
- ``save_path``: directory where the quantized model is saved.
- ``model_filename``: if the model's parameters are stored in a single file, set this to the model file's name; if they are stored in multiple files, leave it unset.
- ``params_filename``: if the model's parameters are stored in a single file, set this to the parameter file's name; if they are stored in multiple files, leave it unset.
- ``max_model_quant_count``: maximum number of quantization search trials. More trials raise the chance of producing a high-accuracy quantized model, at a proportional cost in time; a value between 20 and 30 is recommended.
After running the command above, the quantized model and parameter files appear under ``${save_path}``.
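For reference, the script essentially wraps a single API call. A minimal sketch, assuming the keyword names below (they mirror the command-line flags; verify them against your installed PaddleSlim version):
```
import paddle
from paddleslim.quant import quant_post_hpo
import imagenet_reader as reader

paddle.enable_static()
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
# Keyword names here are assumptions mirroring the flags above; the searched
# model is saved under save_path with the default '__model__'/'__params__' names.
quant_post_hpo(
    exe,
    place,
    model_dir="./inference_model/MobileNet/",
    quantize_model_path="./inference_model/MobileNet_quant/",
    train_sample_generator=reader.train(),
    eval_sample_generator=reader.val(),
    model_filename="model",
    params_filename="weights",
    max_model_quant_count=26)
```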
### Test accuracy
Use the [eval.py](../quant_post/eval.py) script to test the model before and after quantization and compare the classification accuracy.
First, test the accuracy of the model before quantization:
```
python ../quant_post/eval.py --model_path ./inference_model/MobileNet --model_name model --params_name weights
```
The accuracy output is:
```
top1_acc/top5_acc= [0.70898 0.89534]
```
Test the accuracy of the post-training quantized model with:
```
python ../quant_post/eval.py --model_path ./inference_model/MobileNet_quant/ --model_name __model__ --params_name __params__
```
The accuracy output is:
```
top1_acc/top5_acc= [0.70653 0.89369]
```
@@ -14,6 +14,7 @@
import numpy as np
import paddle
from paddleslim.core import GraphWrapper
def merge(teacher_program,
@@ -94,6 +95,16 @@ def merge(teacher_program,
student_program.global_block().append_op(
type=op.type, inputs=inputs, outputs=outputs, attrs=attrs)
student_graph = GraphWrapper(student_program)
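    # Any op that consumes a teacher variable belongs to the merged teacher
    # subgraph; mark it "skip_quant" so quantization passes leave it untouched.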
for op in student_graph.ops():
belongsto_teacher = False
for inp in op.all_inputs():
if 'teacher' in inp.name():
belongsto_teacher = True
break
if belongsto_teacher:
op._op._set_attr("skip_quant", True)
def fsp_loss(teacher_var1_name,
teacher_var2_name,
......
@@ -31,6 +31,7 @@ try:
], "training-aware and post-training quant is not supported in 2.0 alpha version paddle"
from .quanter import quant_aware, convert, quant_post_static, quant_post_dynamic
from .quanter import quant_post, quant_post_only_weight
from .quant_aware_with_infermodel import quant_aware_with_infermodel, export_quant_infermodel
from .quant_post_hpo import quant_post_hpo
except Exception as e:
_logger.warning(e)
......
This diff is collapsed.
@@ -29,6 +29,7 @@ from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
from paddle.fluid import core
from paddle.fluid.contrib.slim.quantization import WeightQuantization
from paddle.fluid.layer_helper import LayerHelper
from ..common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
@@ -561,3 +562,30 @@ def quant_post_dynamic(model_dir,
# For compatibility, we keep quant_post_only_weight api for now,
# and it will be deprecated in the future.
quant_post_only_weight = quant_post_dynamic
def pact(x, name=None):
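    # PACT (Choi et al., 2018): learn a clipping threshold u per activation
    # tensor and clip the activation to [-u, u] before it is quantized.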
helper = LayerHelper("pact", **locals())
dtype = 'float32'
init_thres = 20
u_param_attr = paddle.fluid.ParamAttr(
name=x.name + '_pact',
initializer=paddle.fluid.initializer.ConstantInitializer(
value=init_thres),
regularizer=paddle.fluid.regularizer.L2Decay(0.0001),
learning_rate=1)
u_param = helper.create_parameter(attr=u_param_attr, shape=[1], dtype=dtype)
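    # Clip via ReLU identities: x - relu(x - u) caps values above u,
    # then x + relu(-u - x) floors values below -u.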
x = paddle.fluid.layers.elementwise_sub(
x,
paddle.fluid.layers.relu(
paddle.fluid.layers.elementwise_sub(x, u_param)))
x = paddle.fluid.layers.elementwise_add(
x,
paddle.fluid.layers.relu(
paddle.fluid.layers.elementwise_sub(-u_param, x)))
return x
def get_pact_optimizer():
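    # The learnable clipping thresholds are updated by their own momentum optimizer.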
return paddle.fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
sys.path.append("../")
sys.path.append(".")
sys.path[0] = os.path.join(os.path.dirname(__file__), os.path.pardir)
import unittest
import paddle
from paddleslim.quant import quant_aware, convert
from paddleslim.quant import quant_aware_with_infermodel, export_quant_infermodel
from static_case import StaticCase
sys.path.append("../demo")
from models import MobileNet
from layers import conv_bn_layer
import paddle.dataset.mnist as reader
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
import numpy as np
class TestQuantAwareWithInferModelCase1(StaticCase):
def test_accuracy(self):
float_infer_model_path_prefix = "./mv1_float_inference"
image = paddle.static.data(
name='image', shape=[None, 1, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
model = MobileNet()
out = model.net(input=image, class_dim=10)
cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
avg_cost = paddle.mean(x=cost)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
optimizer = paddle.optimizer.Momentum(
momentum=0.9,
learning_rate=0.01,
weight_decay=paddle.regularizer.L2Decay(4e-5))
optimizer.minimize(avg_cost)
main_prog = paddle.static.default_main_program()
val_prog = main_prog.clone(for_test=True)
#place = paddle.CPUPlace()
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
def transform(x):
return np.reshape(x, [1, 28, 28])
train_dataset = paddle.vision.datasets.MNIST(
mode='train', backend='cv2', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(
mode='test', backend='cv2', transform=transform)
train_loader = paddle.io.DataLoader(
train_dataset,
places=place,
feed_list=[image, label],
drop_last=True,
batch_size=64,
return_list=False)
valid_loader = paddle.io.DataLoader(
test_dataset,
places=place,
feed_list=[image, label],
batch_size=64,
return_list=False)
def sample_generator_creator():
def __reader__():
for data in test_dataset:
image, label = data
yield image, label
return __reader__
def train(program):
iter = 0
for data in train_loader():
cost, top1, top5 = exe.run(
program,
feed=data,
fetch_list=[avg_cost, acc_top1, acc_top5])
iter += 1
if iter % 100 == 0:
print(
'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
format(iter, cost, top1, top5))
def test(program, outputs=[avg_cost, acc_top1, acc_top5]):
iter = 0
result = [[], [], []]
for data in valid_loader():
cost, top1, top5 = exe.run(program,
feed=data,
fetch_list=outputs)
iter += 1
if iter % 100 == 0:
print('eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
format(iter, cost, top1, top5))
result[0].append(cost)
result[1].append(top1)
result[2].append(top5)
print(' avg loss {}, acc_top1 {}, acc_top5 {}'.format(
np.mean(result[0]), np.mean(result[1]), np.mean(result[2])))
return np.mean(result[1]), np.mean(result[2])
train(main_prog)
top1_1, top5_1 = test(val_prog)
paddle.static.save_inference_model(
path_prefix=float_infer_model_path_prefix,
feed_vars=[image, label],
fetch_vars=[avg_cost, acc_top1, acc_top5],
executor=exe,
program=val_prog)
quant_config = {
'weight_quantize_type': 'channel_wise_abs_max',
'activation_quantize_type': 'moving_average_abs_max',
'not_quant_pattern': ['skip_quant'],
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul']
}
train_config = {
"num_epoch": 1, # training epoch num
"max_iter": 20,
"save_iter_step": 10,
"learning_rate": 0.0001,
"weight_decay": 0.0001,
"use_pact": False,
"quant_model_ckpt_path":
"./quantaware_with_infermodel_checkpoints/",
"teacher_model_path_prefix": float_infer_model_path_prefix,
"model_path_prefix": float_infer_model_path_prefix,
"distill_node_pair": [
"teacher_fc_0.tmp_0", "fc_0.tmp_0",
"teacher_batch_norm_24.tmp_4", "batch_norm_24.tmp_4",
"teacher_batch_norm_22.tmp_4", "batch_norm_22.tmp_4",
"teacher_batch_norm_18.tmp_4", "batch_norm_18.tmp_4",
"teacher_batch_norm_13.tmp_4", "batch_norm_13.tmp_4",
"teacher_batch_norm_5.tmp_4", "batch_norm_5.tmp_4"
]
}
def test_callback(compiled_test_program, feed_names, fetch_list,
checkpoint_name):
outputs = fetch_list
iter = 0
result = [[], [], []]
for data in valid_loader():
cost, top1, top5 = exe.run(compiled_test_program,
feed=data,
fetch_list=fetch_list)
iter += 1
if iter % 100 == 0:
print('eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
format(iter, cost, top1, top5))
result[0].append(cost)
result[1].append(top1)
result[2].append(top5)
print("quant model checkpoint: " + checkpoint_name +
' avg loss {}, acc_top1 {}, acc_top5 {}'.format(
np.mean(result[0]),
np.mean(result[1]), np.mean(result[2])))
return np.mean(result[1]), np.mean(result[2])
def test_quant_aware_with_infermodel(exe, place):
quant_aware_with_infermodel(
exe,
place,
scope=None,
train_reader=train_loader,
quant_config=quant_config,
train_config=train_config,
test_callback=test_callback)
def test_export_quant_infermodel(exe, place, checkpoint_path,
quant_infermodel_save_path):
export_quant_infermodel(
exe,
place,
scope=None,
quant_config=quant_config,
train_config=train_config,
checkpoint_path=checkpoint_path,
export_inference_model_path_prefix=quant_infermodel_save_path)
#place = paddle.CPUPlace()
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
exe = paddle.static.Executor(place)
test_quant_aware_with_infermodel(exe, place)
checkpoint_path = "./quantaware_with_infermodel_checkpoints/epoch_0_iter_10"
quant_infermodel_save_path = "./quantaware_with_infermodel_export"
test_export_quant_infermodel(exe, place, checkpoint_path,
quant_infermodel_save_path)
train_config["use_pact"] = True
test_quant_aware_with_infermodel(exe, place)
train_config["use_pact"] = False
checkpoint_path = "./quantaware_with_infermodel_checkpoints/epoch_0_iter_10"
quant_infermodel_save_path = "./quantaware_with_infermodel_export"
test_export_quant_infermodel(exe, place, checkpoint_path,
quant_infermodel_save_path)
if __name__ == '__main__':
unittest.main()