From e406740ab13ee0fc2c7f983d30526aa668b328c7 Mon Sep 17 00:00:00 2001
From: itminner <397809320@qq.com>
Date: Mon, 27 Dec 2021 18:55:43 +0800
Subject: [PATCH] quant aware with infer model (#947)

quant aware with infer model
---
 .../quant_aware_with_infermodel/README.md     | 113 +++++
 .../export_quantmodel.py                      |  91 ++++
 .../quant_aware_with_infermodel.py            | 148 ++++++
 demo/quant/quant_post_hpo/README.md           |  72 +++
 paddleslim/dist/single_distiller.py           |  11 +
 paddleslim/quant/__init__.py                  |   1 +
 .../quant/quant_aware_with_infermodel.py      | 474 ++++++++++++++++++
 paddleslim/quant/quanter.py                   |  28 ++
 tests/test_quant_aware_with_infermodel.py     | 221 ++++++++
 9 files changed, 1159 insertions(+)
 create mode 100644 demo/quant/quant_aware_with_infermodel/README.md
 create mode 100755 demo/quant/quant_aware_with_infermodel/export_quantmodel.py
 create mode 100755 demo/quant/quant_aware_with_infermodel/quant_aware_with_infermodel.py
 create mode 100755 demo/quant/quant_post_hpo/README.md
 create mode 100644 paddleslim/quant/quant_aware_with_infermodel.py
 create mode 100644 tests/test_quant_aware_with_infermodel.py

diff --git a/demo/quant/quant_aware_with_infermodel/README.md b/demo/quant/quant_aware_with_infermodel/README.md
new file mode 100644
index 00000000..71e66e47
--- /dev/null
+++ b/demo/quant/quant_aware_with_infermodel/README.md
@@ -0,0 +1,113 @@
+# Quantization-Aware Training with an Inference Model
+
+Obtaining the inference model:
+for dygraph models, save it with paddle.jit.save;
+for static-graph models, save it with paddle.static.save_inference_model.
+
+This example shows how to run distillation-based quantization-aware training from an inference model.
+First, train the quantized model with the API ``paddleslim.quant.quant_aware_with_infermodel``;
+after training finishes, export the trained quantized model as an inference model with ``paddleslim.quant.export_quant_infermodel``.
+
+## Quantization-aware training workflow for a classification model
+
+### 1. Prepare the data
+
+Create a ``data`` folder under the ``demo`` folder and extract the ``ImageNet`` dataset into it. After extraction, the ``data/ILSVRC2012`` folder should contain:
+- the ``'train'`` folder with the training images
+- the ``'train_list.txt'`` file
+- the ``'val'`` folder with the validation images
+- the ``'val_list.txt'`` file
+
+### 2. Prepare the model to be quantized
+
+PaddleClas is PaddlePaddle's image recognition toolkit for industry and academia; this example uses it to produce the ImageNet classification model.
+#### ① Download the MobileNetV2 pretrained model
+Pretrained model zoo: ``https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/zh_CN/algorithm_introduction/ImageNet_models.md``
+MobileNetV2 pretrained weights: ``https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_pretrained.pdparams``
+Create a ``pretrained`` folder in the PaddleClas repository root and save the MobileNetV2 pretrained weights in it.
+
+#### ② Export the inference model
+Run the following command from the PaddleClas repository root to export the inference model:
+```
+python tools/export_model.py \
+    -c ppcls/configs/ImageNet/MobileNetV2/MobileNetV2.yaml \
+    -o Global.pretrained_model=pretrained/MobileNetV2_pretrained \
+    -o Global.save_inference_dir=infermodel_mobilenetv2
+```
+#### ③ Test model accuracy
+Use the [eval.py](../quant_post/eval.py) script to obtain the model's classification accuracy:
+```
+python ../quant_post/eval.py --model_path infermodel_mobilenetv2 --model_name inference.pdmodel --params_name inference.pdiparams
+```
+The accuracy output is:
+```
+top1_acc/top5_acc= [0.71918 0.90568]
+```
+
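+Before starting quantization training, you can optionally sanity-check that the exported files load as an inference model. This is a minimal sketch, not part of the demo scripts; it assumes the ``infermodel_mobilenetv2`` directory produced above:
+```python
+import paddle
+
+paddle.enable_static()
+exe = paddle.static.Executor(paddle.CPUPlace())
+# path_prefix is the directory plus the model name, excluding file suffixes
+program, feed_names, fetch_targets = paddle.static.load_inference_model(
+    path_prefix="./infermodel_mobilenetv2/inference", executor=exe)
+print(feed_names, [v.name for v in fetch_targets])
+```
+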
+### 3. Quantization-aware training with distillation
+
+The distillation quantization training example script is [quant_aware_with_infermodel.py](./quant_aware_with_infermodel.py); it uses the API ``paddleslim.quant.quant_aware_with_infermodel`` to run quantization-aware training on the model. The command is:
+```
+python quant_aware_with_infermodel.py \
+    --batch_size=2 \
+    --num_epoch=30 \
+    --save_iter_step=100 \
+    --learning_rate=0.0001 \
+    --weight_decay=0.00004 \
+    --use_pact=True \
+    --checkpoint_path="./inference_model/MobileNet_quantaware_ckpt/" \
+    --model_path="./infermodel_mobilenetv2/" \
+    --model_filename="inference.pdmodel" \
+    --params_filename="inference.pdiparams" \
+    --teacher_model_path="./infermodel_mobilenetv2/" \
+    --teacher_model_filename="inference.pdmodel" \
+    --teacher_params_filename="inference.pdiparams" \
+    --distill_node_name_list "teacher_conv2d_54.tmp_0" "conv2d_54.tmp_0" "teacher_conv2d_55.tmp_0" "conv2d_55.tmp_0" \
+        "teacher_conv2d_57.tmp_0" "conv2d_57.tmp_0" "teacher_elementwise_add_0" "elementwise_add_0" \
+        "teacher_conv2d_61.tmp_0" "conv2d_61.tmp_0" "teacher_elementwise_add_1" "elementwise_add_1" \
+        "teacher_elementwise_add_2" "elementwise_add_2" "teacher_conv2d_67.tmp_0" "conv2d_67.tmp_0" \
+        "teacher_elementwise_add_3" "elementwise_add_3" "teacher_elementwise_add_4" "elementwise_add_4" \
+        "teacher_elementwise_add_5" "elementwise_add_5" "teacher_conv2d_75.tmp_0" "conv2d_75.tmp_0" \
+        "teacher_elementwise_add_6" "elementwise_add_6" "teacher_elementwise_add_7" "elementwise_add_7" \
+        "teacher_conv2d_81.tmp_0" "conv2d_81.tmp_0" "teacher_elementwise_add_8" "elementwise_add_8" \
+        "teacher_elementwise_add_9" "elementwise_add_9" "teacher_conv2d_87.tmp_0" "conv2d_87.tmp_0" \
+        "teacher_linear_1.tmp_0" "linear_1.tmp_0"
+```
+- ``batch_size``: batch size for quantization training.
+- ``num_epoch``: number of quantization training epochs.
+- ``save_iter_step``: save a checkpoint every ``save_iter_step`` iterations.
+- ``learning_rate``: learning rate for quantization training; the smallest (final-stage) learning rate used when training the float model is recommended.
+- ``weight_decay``: the weight decay setting used when training the float model is recommended.
+- ``use_pact``: whether to use the PACT quantization algorithm; recommended.
+- ``checkpoint_path``: save path for quantization training checkpoints.
+- ``model_path``: path of the inference model to be quantized.
+- ``model_filename``: if the parameters of the model to be quantized are stored in a single file, set this to the name of the model file; if the parameters are stored in multiple files, it does not need to be set.
+- ``params_filename``: if the parameters of the model to be quantized are stored in a single file, set this to the name of the parameters file; if the parameters are stored in multiple files, it does not need to be set.
+- ``teacher_model_path``: path of the teacher model; it can be the same model as the one being quantized, i.e. self-distillation.
+- ``teacher_model_filename``: file name of the teacher model file.
+- ``teacher_params_filename``: file name of the teacher model's parameters file.
+- ``distill_node_name_list``: list of distillation node names; consecutive nodes form pairs, belonging to the teacher model and the quantized model respectively.
+
+After running the command above, checkpoints of the quantized model can be found under ``${checkpoint_path}``.
+
+### 4. Export the quantized model
+
+Export a quantization training checkpoint as an inference model.
+(The full argument list of ``export_quantmodel.py``, including ``--checkpoint_filename`` and ``--export_inference_model_path_prefix``, is defined in the script's argument parser.)
+
+```
+python export_quantmodel.py \
+    --use_gpu=True \
+    --checkpoint_path="./MobileNetV2_checkpoints/epoch_0_iter_2000" \
+    --infermodel_save_path="./quant_infermodel_mobilenetv2/"
+```
+
+### 5. Test accuracy
+
+Use the [eval.py](../quant_post/eval.py) script to run an accuracy test on the quantized model:
+```
+python ../quant_post/eval.py --model_path ./quant_infermodel_mobilenetv2/ --model_name model --params_name params
+```
+The accuracy output is:
+```
+top1_acc/top5_acc= [0.71764 0.90418]
+```
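+
+The demo script wraps the API below. If you prefer to call it directly, the following is a minimal sketch of the same flow; the reader body, config values, and node names are placeholders to adapt to your model:
+```python
+import paddle
+from paddleslim.quant import quant_aware_with_infermodel
+
+paddle.enable_static()
+place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
+exe = paddle.static.Executor(place)
+
+quant_config = {
+    'weight_quantize_type': 'channel_wise_abs_max',
+    'activation_quantize_type': 'moving_average_abs_max',
+    'not_quant_pattern': ['skip_quant'],
+    'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul']
+}
+train_config = {
+    "num_epoch": 30, "max_iter": -1, "save_iter_step": 100,
+    "learning_rate": 0.0001, "weight_decay": 0.00004, "use_pact": True,
+    "quant_model_ckpt_path": "./inference_model/MobileNet_quantaware_ckpt/",
+    "teacher_model_path_prefix": "./infermodel_mobilenetv2/inference",
+    "model_path_prefix": "./infermodel_mobilenetv2/inference",
+    # flat list of (teacher node, student node) pairs, as in the command above
+    "distill_node_pair": ["teacher_conv2d_54.tmp_0", "conv2d_54.tmp_0"]
+}
+
+def train_reader():
+    # yield one feed dict per batch, e.g. {"x": batch_of_images}
+    ...
+
+def test_callback(compiled_test_program, feed_names, fetch_list, checkpoint_name):
+    pass  # evaluate the checkpoint here
+
+quant_aware_with_infermodel(exe, place, scope=None, train_reader=train_reader,
+                            quant_config=quant_config, train_config=train_config,
+                            test_callback=test_callback)
+```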
diff --git a/demo/quant/quant_aware_with_infermodel/export_quantmodel.py b/demo/quant/quant_aware_with_infermodel/export_quantmodel.py
new file mode 100755
index 00000000..e7e2306e
--- /dev/null
+++ b/demo/quant/quant_aware_with_infermodel/export_quantmodel.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import math
+import time
+import numpy as np
+import paddle
+import logging
+import argparse
+import functools
+
+# make demo/ and the repository root importable (utility.py, imagenet_reader, paddleslim)
+sys.path[0] = os.path.join(
+    os.path.dirname(__file__), os.path.pardir, os.path.pardir)
+sys.path[1] = os.path.join(
+    os.path.dirname(__file__), os.path.pardir, os.path.pardir, os.path.pardir)
+from paddleslim.common import get_logger
+from paddleslim.quant import export_quant_infermodel
+from utility import add_arguments, print_arguments
+import imagenet_reader as reader
+_logger = get_logger(__name__, level=logging.INFO)
+
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
+add_arg('batch_size', int, 4, "train batch size.")
+add_arg('num_epoch', int, 1, "train epoch num.")
+add_arg('save_iter_step', int, 1, "save train checkpoint every save_iter_step iter num.")
+add_arg('learning_rate', float, 0.0001, "learning rate.")
+add_arg('weight_decay', float, 0.00004, "weight decay.")
+add_arg('use_pact', bool, True, "whether use pact quantization.")
+add_arg('checkpoint_path', str, None, "model dir to save quanted model checkpoints")
+add_arg('model_path_prefix', str, None, "storage directory of model + model name (excluding suffix)")
+add_arg('teacher_model_path_prefix', str, None, "storage directory of teacher model + teacher model name (excluding suffix)")
+add_arg('distill_node_name_list', str, None, "distill node name list", nargs="+")
+add_arg('checkpoint_filename', str, None, "checkpoint filename to export inference model")
+add_arg('export_inference_model_path_prefix', str, None, "inference model export path prefix")
+# yapf: enable
+
+
+def export(args):
+    place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
+    exe = paddle.static.Executor(place)
+
+    quant_config = {
+        'weight_quantize_type': 'channel_wise_abs_max',
+        'activation_quantize_type': 'moving_average_abs_max',
+        'not_quant_pattern': ['skip_quant'],
+        'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul']
+    }
+    train_config = {
+        "num_epoch": args.num_epoch,  # training epoch num
+        "max_iter": -1,
+        "save_iter_step": args.save_iter_step,
+        "learning_rate": args.learning_rate,
+        "weight_decay": args.weight_decay,
+        "use_pact": args.use_pact,
+        "quant_model_ckpt_path": args.checkpoint_path,
+        "teacher_model_path_prefix": args.teacher_model_path_prefix,
+        "model_path_prefix": args.model_path_prefix,
+        "distill_node_pair": args.distill_node_name_list
+    }
+
+    export_quant_infermodel(
+        exe,
+        place,
+        scope=None,
+        quant_config=quant_config,
+        train_config=train_config,
+        checkpoint_path=os.path.join(args.checkpoint_path,
+                                     args.checkpoint_filename),
+        export_inference_model_path_prefix=args.export_inference_model_path_prefix)
+
+
+def main():
+    args = parser.parse_args()
+    args.use_pact = bool(args.use_pact)
+    print_arguments(args)
+    export(args)
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    main()
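+
+# Example invocation (hypothetical paths, mirroring the training arguments above):
+#   python export_quantmodel.py \
+#       --use_gpu=True \
+#       --checkpoint_path="./inference_model/MobileNet_quantaware_ckpt/" \
+#       --checkpoint_filename="epoch_0_iter_2000" \
+#       --export_inference_model_path_prefix="./quant_infermodel_mobilenetv2/model" \
+#       --model_path_prefix="./infermodel_mobilenetv2/inference" \
+#       --teacher_model_path_prefix="./infermodel_mobilenetv2/inference" \
+#       --distill_node_name_list "teacher_conv2d_54.tmp_0" "conv2d_54.tmp_0"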
diff --git a/demo/quant/quant_aware_with_infermodel/quant_aware_with_infermodel.py b/demo/quant/quant_aware_with_infermodel/quant_aware_with_infermodel.py
new file mode 100755
index 00000000..430defbf
--- /dev/null
+++ b/demo/quant/quant_aware_with_infermodel/quant_aware_with_infermodel.py
@@ -0,0 +1,148 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import math
+import time
+import numpy as np
+import paddle
+import logging
+import argparse
+import functools
+
+# make demo/ and the repository root importable (utility.py, imagenet_reader, paddleslim)
+sys.path[0] = os.path.join(
+    os.path.dirname(__file__), os.path.pardir, os.path.pardir)
+sys.path[1] = os.path.join(
+    os.path.dirname(__file__), os.path.pardir, os.path.pardir, os.path.pardir)
+from paddleslim.common import get_logger
+from paddleslim.quant import quant_aware_with_infermodel
+from utility import add_arguments, print_arguments
+import imagenet_reader as reader
+_logger = get_logger(__name__, level=logging.INFO)
+
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('use_gpu', bool, True, "whether to use GPU or not.")
+add_arg('batch_size', int, 1, "train batch size.")
+add_arg('num_epoch', int, 1, "train epoch num.")
+add_arg('save_iter_step', int, 1, "save train checkpoint every save_iter_step iter num.")
+add_arg('learning_rate', float, 0.0001, "learning rate.")
+add_arg('weight_decay', float, 0.00004, "weight decay.")
+add_arg('use_pact', bool, True, "whether use pact quantization.")
+add_arg('checkpoint_path', str, None, "model dir to save quanted model checkpoints")
+add_arg('model_path_prefix', str, None, "storage directory of model + model name (excluding suffix)")
+add_arg('teacher_model_path_prefix', str, None, "storage directory of teacher model + teacher model name (excluding suffix)")
+add_arg('distill_node_name_list', str, None, "distill node name list", nargs="+")
+# yapf: enable
+
+DATA_DIR = "../../data/ILSVRC2012/"
+
+
+def eval(exe, place, compiled_test_program, test_feed_names, test_fetch_list):
+    val_reader = paddle.batch(reader.val(), batch_size=1)
+
+    results = []
+    for batch_id, data in enumerate(val_reader()):
+        # compute top1_acc and top5_acc per batch
+        if len(test_feed_names) == 1:
+            # eval the "infer model", whose input is the image and whose
+            # output is the classification probability
+            image = data[0][0].reshape((1, 3, 224, 224))
+            label = [[d[1]] for d in data]
+            pred = exe.run(compiled_test_program,
+                           feed={test_feed_names[0]: image},
+                           fetch_list=test_fetch_list)
+            pred = np.array(pred[0])
+            label = np.array(label)
+            sort_array = pred.argsort(axis=1)
+            top_1_pred = sort_array[:, -1:][:, ::-1]
+            top_1 = np.mean(label == top_1_pred)
+            top_5_pred = sort_array[:, -5:][:, ::-1]
+            acc_num = 0
+            for i in range(len(label)):
+                if label[i][0] in top_5_pred[i]:
+                    acc_num += 1
+            top_5 = float(acc_num) / len(label)
+            results.append([top_1, top_5])
+        else:
+            # eval the "eval model", whose inputs are image and label and
+            # whose outputs are top1 and top5 accuracy
+            image = data[0][0].reshape((1, 3, 224, 224))
+            label = [[d[1]] for d in data]
+            result = exe.run(compiled_test_program,
+                             feed={
+                                 test_feed_names[0]: image,
+                                 test_feed_names[1]: label
+                             },
+                             fetch_list=test_fetch_list)
+            result = [np.mean(r) for r in result]
+            results.append(result)
+    result = np.mean(np.array(results), axis=0)
+    return result
+
+
+def quantize(args):
+    place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
+    #place = paddle.CPUPlace()
+
+    exe = paddle.static.Executor(place)
+    quant_config = {
+        'weight_quantize_type': 'channel_wise_abs_max',
+        'activation_quantize_type': 'moving_average_abs_max',
+        'not_quant_pattern': ['skip_quant'],
+        'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul']
+    }
+    train_config = {
+        "num_epoch": args.num_epoch,  # training epoch num
+        "max_iter": -1,
+        "save_iter_step": args.save_iter_step,
+        "learning_rate": args.learning_rate,
+        "weight_decay": args.weight_decay,
+        "use_pact": args.use_pact,
+        "quant_model_ckpt_path": args.checkpoint_path,
+        "teacher_model_path_prefix": args.teacher_model_path_prefix,
+        "model_path_prefix": args.model_path_prefix,
+        "distill_node_pair": args.distill_node_name_list
+    }
+
+    def test_callback(compiled_test_program, feed_names, fetch_list,
+                      checkpoint_name):
+        ret = eval(exe, place, compiled_test_program, feed_names, fetch_list)
+        print("{0} top1_acc/top5_acc= {1}".format(checkpoint_name, ret))
+
+    train_reader = paddle.batch(reader.train(), batch_size=args.batch_size)
+
+    def train_reader_wrapper():
+        def gen():
+            for i, data in enumerate(train_reader()):
+                imgs = np.float32([item[0] for item in data])
+                yield {"x": imgs}
+
+        return gen
+
+    quant_aware_with_infermodel(
+        exe,
+        place,
+        scope=None,
+        train_reader=train_reader_wrapper(),
+        quant_config=quant_config,
+        train_config=train_config,
+        test_callback=test_callback)
+
+
+def main():
+    args = parser.parse_args()
+    args.use_pact = bool(args.use_pact)
+    print("args.use_pact", args.use_pact)
+    print_arguments(args)
+    quantize(args)
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    main()
diff --git a/demo/quant/quant_post_hpo/README.md b/demo/quant/quant_post_hpo/README.md
new file mode 100755
index 00000000..dddc9baa
--- /dev/null
+++ b/demo/quant/quant_post_hpo/README.md
@@ -0,0 +1,72 @@
+# Post-Training Quantization Hyperparameter Search Example
+
+This example shows how to use the post-training quantization hyperparameter search API ``paddleslim.quant.quant_post_hpo`` to search quantization hyperparameters for a trained classification model.
+
+## Hyperparameter search workflow for a classification model
+
+### Prepare the data
+
+Create a ``data`` folder under the ``demo`` folder and extract the ``ImageNet`` dataset into it. After extraction, the ``data/ILSVRC2012`` folder should contain:
+- the ``'train'`` folder with the training images
+- the ``'train_list.txt'`` file
+- the ``'val'`` folder with the validation images
+- the ``'val_list.txt'`` file
+
+### Prepare the model to be quantized
+The post-training quantization API only supports loading models saved with ``paddle.static.save_inference_model``. If your model was saved with a different API, it must be converted first. This example uses a classification model.
+
+First, download the trained ``mobilenetv1`` model from the [ImageNet classification models](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E5%B7%B2%E5%8F%91%E5%B8%83%E6%A8%A1%E5%9E%8B%E5%8F%8A%E5%85%B6%E6%80%A7%E8%83%BD) zoo.
+
+Create a ``'pretrain'`` folder in the current directory and extract the ``mobilenetv1`` model into it; the extracted directory is ``pretrain/MobileNetV1_pretrained``.
+
+### Export the model
+Convert the model into one the post-training quantization API can load by running:
+```
+python ../quant_post/export_model.py --model "MobileNet" --pretrained_model ./pretrain/MobileNetV1_pretrained --data imagenet
+```
+The converted model is stored under ``inference_model/MobileNet/``, which contains the two files ``'model'`` and ``'weights'``.
+
+### Post-training quantization with hyperparameter search
+Next, run post-training quantization on the exported model. The script is [quant_post_hpo.py](./quant_post_hpo.py), which uses the API ``paddleslim.quant.quant_post_hpo``. The command is:
+```
+python quant_post_hpo.py \
+    --use_gpu=True \
+    --model_path="./inference_model/MobileNet/" \
+    --save_path="./inference_model/MobileNet_quant/" \
+    --model_filename="model" \
+    --params_filename="weights" \
+    --max_model_quant_count=26
+```
+
+- ``model_path``: path of the model to be quantized
+- ``save_path``: save path for the quantized model
+- ``model_filename``: if the parameters of the model to be quantized are stored in a single file, set this to the name of the model file; if the parameters are stored in multiple files, it does not need to be set.
+- ``params_filename``: if the parameters of the model to be quantized are stored in a single file, set this to the name of the parameters file; if the parameters are stored in multiple files, it does not need to be set.
+- ``max_model_quant_count``: maximum number of quantization searches. More searches raise the probability of producing a high-accuracy quantized model, but also increase the runtime. Recommended range: more than 20 and fewer than 30.
+
+After running the command above, the quantized model and parameter files can be found under ``${save_path}``.
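+
+The CLI flags above map onto the Python API. The following is a rough sketch of calling it directly; the parameter names are assumptions inferred from the flags, not shown in this patch, so check the ``paddleslim.quant.quant_post_hpo`` docstring before relying on them:
+```python
+import paddle
+from paddleslim.quant import quant_post_hpo
+
+paddle.enable_static()
+place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
+exe = paddle.static.Executor(place)
+
+def sample_generator():
+    # yield single calibration samples, e.g. (image, label)
+    ...
+
+quant_post_hpo(
+    exe, place,
+    model_dir="./inference_model/MobileNet/",
+    quantize_model_path="./inference_model/MobileNet_quant/",
+    train_sample_generator=sample_generator,
+    eval_sample_generator=sample_generator,
+    model_filename="model",
+    params_filename="weights",
+    max_model_quant_count=26)
+```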
+
+### Test accuracy
+
+Use the [eval.py](../quant_post/eval.py) script to test the model before and after quantization and compare the classification accuracy.
+
+First, test the accuracy of the model before quantization:
+```
+python ../quant_post/eval.py --model_path ./inference_model/MobileNet --model_name model --params_name weights
+```
+The accuracy output is:
+```
+top1_acc/top5_acc= [0.70898 0.89534]
+```
+
+Test the accuracy of the post-training quantized model with:
+
+```
+python ../quant_post/eval.py --model_path ./inference_model/MobileNet_quant/ --model_name __model__ --params_name __params__
+```
+
+The accuracy output is:
+```
+top1_acc/top5_acc= [0.70653 0.89369]
+```
diff --git a/paddleslim/dist/single_distiller.py b/paddleslim/dist/single_distiller.py
index be29db6d..fbc190ca 100644
--- a/paddleslim/dist/single_distiller.py
+++ b/paddleslim/dist/single_distiller.py
@@ -14,6 +14,7 @@
 
 import numpy as np
 import paddle
+from paddleslim.core import GraphWrapper
 
 
 def merge(teacher_program,
@@ -94,6 +95,16 @@ def merge(teacher_program,
             student_program.global_block().append_op(
                 type=op.type, inputs=inputs, outputs=outputs, attrs=attrs)
 
+    # mark every op that consumes a teacher variable with skip_quant so the
+    # quantization passes leave the merged teacher subgraph untouched
+    student_graph = GraphWrapper(student_program)
+    for op in student_graph.ops():
+        belongsto_teacher = False
+        for inp in op.all_inputs():
+            if 'teacher' in inp.name():
+                belongsto_teacher = True
+                break
+        if belongsto_teacher:
+            op._op._set_attr("skip_quant", True)
+
 
 def fsp_loss(teacher_var1_name,
              teacher_var2_name,
diff --git a/paddleslim/quant/__init__.py b/paddleslim/quant/__init__.py
index cc8acda5..134165b3 100644
--- a/paddleslim/quant/__init__.py
+++ b/paddleslim/quant/__init__.py
@@ -31,6 +31,7 @@ try:
     ], "training-aware and post-training quant is not supported in 2.0 alpha version paddle"
     from .quanter import quant_aware, convert, quant_post_static, quant_post_dynamic
     from .quanter import quant_post, quant_post_only_weight
+    from .quant_aware_with_infermodel import quant_aware_with_infermodel, export_quant_infermodel
     from .quant_post_hpo import quant_post_hpo
 except Exception as e:
     _logger.warning(e)
diff --git a/paddleslim/quant/quant_aware_with_infermodel.py b/paddleslim/quant/quant_aware_with_infermodel.py
new file mode 100644
index 00000000..433c7ce1
--- /dev/null
+++ b/paddleslim/quant/quant_aware_with_infermodel.py
@@ -0,0 +1,474 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Quantization-aware training with distillation on an inference model."""
+
+import copy
+import os
+import argparse
+import json
+import six
+from collections import namedtuple
+import time
+import shutil
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid import unique_name
+from paddle.fluid import core
+from paddle.fluid.framework import Parameter
+from paddleslim.dist import merge, l2_loss, soft_label_loss, fsp_loss
+from paddleslim.core import GraphWrapper
+from paddleslim.quant import quant_aware, convert
+from .quanter import _quant_config_default, _parse_configs, pact, get_pact_optimizer
+import logging
+logging.getLogger().setLevel(logging.INFO)
+from ..common import get_logger
+_logger = get_logger(__name__, level=logging.INFO)
+
+############################################################################################################
+# quantization training configs
+############################################################################################################
+_train_config_default = {
+    # configs of training-aware quantization with an infermodel
+    "num_epoch": 1000,  # training epoch num
+    "max_iter": -1,  # max training iteration num
+    "save_iter_step": 1000,  # save a quant model checkpoint every save_iter_step iterations
+    "learning_rate": 0.0001,  # learning rate
+    "weight_decay": 0.0001,  # weight decay
+    "use_pact": False,  # use pact quantization or not
+    # quant model checkpoints save path
+    "quant_model_ckpt_path": "./quant_model_checkpoints/",
+    # storage directory of teacher model + teacher model name (excluding suffix)
+    "teacher_model_path_prefix": None,
+    # storage directory of model + model name (excluding suffix)
+    "model_path_prefix": None,
+    # distillation node configuration: the distillation supervision nodes are
+    # given as a flat list in which teacher node and student node appear in
+    # pairs, for example
+    # ["teacher_fc_0.tmp_0", "fc_0.tmp_0", "teacher_batch_norm_24.tmp_4", "batch_norm_24.tmp_4"]
+    "distill_node_pair": None
+}
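+
+# Example override merged over the defaults above by _parse_train_configs
+# (hypothetical paths; mirrors the config used in the unit test):
+#     _parse_train_configs({
+#         "num_epoch": 1,
+#         "save_iter_step": 10,
+#         "teacher_model_path_prefix": "./mv1_float_inference",
+#         "model_path_prefix": "./mv1_float_inference",
+#         "distill_node_pair": ["teacher_fc_0.tmp_0", "fc_0.tmp_0"],
+#     })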
+
+
+def _parse_train_configs(train_config):
+    """
+    check if user's train configs are valid.
+    Args:
+        train_config(dict): user's train config.
+    Return:
+        configs(dict): final configs that will be used.
+    """
+
+    configs = copy.deepcopy(_train_config_default)
+    configs.update(train_config)
+
+    assert isinstance(configs['num_epoch'], int), \
+        "'num_epoch' must be int value"
+    assert isinstance(configs['max_iter'], int), \
+        "'max_iter' must be int value"
+    assert isinstance(configs['save_iter_step'], int), \
+        "'save_iter_step' must be int value"
+    assert isinstance(configs['learning_rate'], float), \
+        "'learning_rate' must be float"
+    assert isinstance(configs['weight_decay'], float), \
+        "'weight_decay' must be float"
+    assert isinstance(configs['use_pact'], bool), \
+        "'use_pact' must be bool"
+    assert isinstance(configs['quant_model_ckpt_path'], str), \
+        "'quant_model_ckpt_path' must be str"
+    assert isinstance(configs['teacher_model_path_prefix'], str), \
+        "'teacher_model_path_prefix' must be str"
+    assert isinstance(configs['model_path_prefix'], str), \
+        "'model_path_prefix' must be str"
+    assert isinstance(configs['distill_node_pair'], list), \
+        "'distill_node_pair' must be list"
+    assert len(configs['distill_node_pair']) > 0, \
+        "'distill_node_pair' not configured with distillation nodes"
+    assert len(configs['distill_node_pair']) % 2 == 0, \
+        "'distill_node_pair' distillation nodes need to be configured in pairs"
+    return configs
+
+
+def _create_optimizer(train_config):
+    """create optimizer"""
+    optimizer = paddle.optimizer.SGD(
+        learning_rate=train_config["learning_rate"],
+        weight_decay=paddle.regularizer.L2Decay(train_config["weight_decay"]))
+    return optimizer
+
+
+def _remove_fetch_node(program):
+    """remove fetch nodes in program"""
+    for block in program.blocks:
+        removed = 0
+        ops = list(block.ops)
+        for op in ops:
+            if op.type == "fetch":
+                idx = ops.index(op)
+                block._remove_op(idx - removed)
+                removed += 1
+
+
+def _recover_reserve_space_with_bn(program):
+    """Add the outputs which are only used for training and not saved in
+    the inference program."""
+    for block_idx in six.moves.range(program.num_blocks):
+        block = program.block(block_idx)
+        for op in block.ops:
+            if op.type == "batch_norm":
+                if "ReserveSpace" not in op.output_names or len(
+                        op.output("ReserveSpace")) == 0:
+                    reserve_space = block.create_var(
+                        name=unique_name.generate_with_ignorable_key(".".join(
+                            ["reserve_space", 'tmp'])),
+                        dtype=block.var(op.input("X")[0]).dtype,
+                        type=core.VarDesc.VarType.LOD_TENSOR,
+                        persistable=False,
+                        stop_gradient=True)
+                    op.desc.set_output("ReserveSpace", [reserve_space.name])
+    return program
+
+
+def _recover_param_attr(program):
+    """Recover parameter attributes.
+    Params in an infermodel are stored as plain variables, which cannot be
+    trained; recreate them as Parameters."""
+    all_weights = [param for param in program.list_vars() \
+        if param.persistable is True and param.name != 'feed' and param.name != 'fetch']
+    for w in all_weights:
+        new_w = Parameter(
+            block=program.block(0),
+            shape=w.shape,
+            dtype=w.dtype,
+            type=w.type,
+            name=w.name)
+        new_w.set_value(w.get_value())
+        program.block(0).vars[w.name] = new_w
+    return program
+
+
+def _parse_distill_loss(train_config):
+    """parse the distill loss config: sum an l2 loss over every
+    (teacher node, student node) pair"""
+    assert len(train_config["distill_node_pair"]) % 2 == 0, \
+        "distill_node_pair config wrong, the length needs to be an even number"
+    _logger.info("train config.distill_node_pair: {}".format(
+        train_config["distill_node_pair"]))
+    distill_loss = 0
+    for i in range(len(train_config["distill_node_pair"]) // 2):
+        _logger.info("distill pair: {} {}".format(
+            train_config["distill_node_pair"][i * 2],
+            train_config["distill_node_pair"][i * 2 + 1]))
+        distill_loss += l2_loss(train_config["distill_node_pair"][i * 2],
+                                train_config["distill_node_pair"][i * 2 + 1])
+    return distill_loss
+
+
+DistillProgramInfo = namedtuple("DistillProgramInfo", \
+    "startup_program train_program train_feed_names train_fetch_list \
+    optimizer test_program test_feed_names test_fetch_list")
+
+
+def build_distill_prog_with_infermodel(executor, place, train_config):
+    """build distill program with infermodel"""
+    [train_program, feed_target_names, fetch_targets] = paddle.static.load_inference_model( \
+        path_prefix=train_config["model_path_prefix"], \
+        executor=executor)
+    _remove_fetch_node(train_program)
+    [teacher_program, teacher_feed_target_names, teacher_fetch_targets] = paddle.static.load_inference_model( \
+        path_prefix=train_config["teacher_model_path_prefix"], \
+        executor=executor)
+    _remove_fetch_node(teacher_program)
+    test_program = train_program.clone(for_test=True)
+
+    train_program = _recover_param_attr(train_program)
+    train_program = _recover_reserve_space_with_bn(train_program)
+    for var in train_program.list_vars():
+        var.stop_gradient = False
+    train_graph = GraphWrapper(train_program)
+    for op in train_graph.ops():
+        op._op._set_attr("is_test", False)
+
+    ############################################################################
+    # distill
+    ############################################################################
+    data_name_map = {}
+    assert len(feed_target_names) == len(teacher_feed_target_names), \
+        "the number of feed nodes in the teacher model is not equal to the student model"
+    for i, name in enumerate(feed_target_names):
+        data_name_map[teacher_feed_target_names[i]] = name
+    merge(teacher_program, train_program, data_name_map, place)
+
+    # all feed nodes must have stop_gradient set to False, otherwise the
+    # PACT quant algorithm cannot learn its clipping thresholds
+    for var in train_program.list_vars():
+        if var.name in data_name_map.values() or var.name in data_name_map.keys():
+            var.stop_gradient = False
+
+    train_fetch_list = []
+    train_fetch_name_list = []
+    startup_program = paddle.static.Program()
+    with paddle.static.program_guard(train_program, startup_program):
+        with fluid.unique_name.guard('merge'):
+            optimizer = _create_optimizer(train_config)
+
+            distill_loss = _parse_distill_loss(train_config)
+            loss = paddle.mean(distill_loss)
+            loss.stop_gradient = False
+            p_g_list = paddle.static.append_backward(loss=loss)
+            opts = optimizer.apply_gradients(p_g_list)
+
+            train_fetch_list.append(loss)
+            train_fetch_name_list.append(loss.name)
+
+    return DistillProgramInfo(startup_program, train_program, \
+        feed_target_names, train_fetch_list, optimizer, \
+        test_program, feed_target_names, fetch_targets)
+
+
+def _compile_program(program, fetch_var_name):
+    """compile the program for faster executor-side execution"""
+    compiled_prog = paddle.static.CompiledProgram(program)
+    build_strategy = paddle.static.BuildStrategy()
+    build_strategy.memory_optimize = False
+    build_strategy.enable_inplace = False
+    build_strategy.fuse_all_reduce_ops = False
+    build_strategy.sync_batch_norm = False
+    exec_strategy = paddle.static.ExecutionStrategy()
+    compiled_prog = compiled_prog.with_data_parallel(
+        loss_name=fetch_var_name,
+        build_strategy=build_strategy,
+        exec_strategy=exec_strategy)
+    return compiled_prog
+
+
+def quant_aware_with_infermodel(executor,
+                                place,
+                                scope=None,
+                                train_reader=None,
+                                quant_config=None,
+                                train_config=None,
+                                test_callback=None):
+    """training-aware quantization with an infermodel
+    Args:
+        executor(paddle.static.Executor): The executor to load, run and save the
+            quantized model.
+        place(paddle.CPUPlace or paddle.CUDAPlace): The device the executor
+            runs on.
+        scope(paddle.static.Scope, optional): Scope records the mapping between
+            variable names and variables. When set to ``None``,
+            ``paddle.static.global_scope()`` is used. Default: ``None``.
+        train_reader(data generator): data generator that yields one feed
+            dictionary per batch, e.g. {feed_name[0]: data[0], feed_name[1]: data[1]}.
+        quant_config(dict, optional): configs for convert. If set to None, the
+            default config is used. It must be the same config that was used
+            in 'quant_aware'. Default is None.
+        train_config(dict): training-aware configs, including num_epoch,
+            save_iter_step, learning_rate, weight_decay, use_pact,
+            quant_model_ckpt_path, model_path_prefix,
+            teacher_model_path_prefix, and distill_node_pair
+            (teacher_node_name1, node_name1, teacher_node_name2, node_name2, ...).
+        test_callback(callback function): called with four arguments: the
+            compiled test program, the test feed names, the test fetch list
+            and the checkpoint name; users implement their test logic here.
+    Returns:
+        None
+    """
+    scope = paddle.static.global_scope() if not scope else scope
+    # parse quant config
+    if quant_config is None:
+        quant_config = _quant_config_default
+    else:
+        assert isinstance(quant_config, dict), "quant config must be dict"
+        quant_config = _parse_configs(quant_config)
+    _logger.info("quant_aware config {}".format(quant_config))
+
+    train_config = _parse_train_configs(train_config)
+    distill_program_info = build_distill_prog_with_infermodel(executor, place,
+                                                              train_config)
+    startup_program = distill_program_info.startup_program
+    train_program = distill_program_info.train_program
+    train_feed_names = distill_program_info.train_feed_names
+    train_fetch_list = distill_program_info.train_fetch_list
+    optimizer = distill_program_info.optimizer
+    test_program = distill_program_info.test_program
+    test_feed_names = distill_program_info.test_feed_names
+    test_fetch_list = distill_program_info.test_fetch_list
+
+    ############################################################################
+    # quant
+    ############################################################################
+    use_pact = train_config["use_pact"]
+    if use_pact:
+        act_preprocess_func = pact
+        optimizer_func = get_pact_optimizer
+        pact_executor = executor
+    else:
+        act_preprocess_func = None
+        optimizer_func = None
+        pact_executor = None
+
+    test_program = quant_aware(
+        test_program,
+        place,
+        quant_config,
+        scope=scope,
+        act_preprocess_func=act_preprocess_func,
+        optimizer_func=optimizer_func,
+        executor=pact_executor,
+        for_test=True)
+    train_program = quant_aware(
+        train_program,
+        place,
+        quant_config,
+        scope=scope,
+        act_preprocess_func=act_preprocess_func,
+        optimizer_func=optimizer_func,
+        executor=pact_executor,
+        for_test=False,
+        return_program=True)
+
+    executor.run(startup_program)
+    compiled_train_prog = _compile_program(train_program,
+                                           train_fetch_list[0].name)
+    compiled_test_prog = _compile_program(test_program, test_fetch_list[0].name)
+    num_epoch = train_config["num_epoch"]
+    save_iter_step = train_config["save_iter_step"]
+    iter_sum = 0
+    for epoch in range(num_epoch):
+        for iter_num, feed_dict in enumerate(train_reader()):
+            np_probs_float = executor.run(compiled_train_prog, \
+                feed=feed_dict, \
+                fetch_list=train_fetch_list)
+            _logger.info("loss: {}".format(np_probs_float))
+
+            if iter_num > 0 and iter_num % save_iter_step == 0:
+                checkpoint_name = "epoch_" + str(epoch) + "_iter_" + str(
+                    iter_num)
+                paddle.static.save(
+                    program=test_program,
+                    model_path=os.path.join(
+                        train_config["quant_model_ckpt_path"], checkpoint_name))
+                test_callback(compiled_test_prog, test_feed_names,
+                              test_fetch_list, checkpoint_name)
+            iter_sum += 1
+            if train_config["max_iter"] >= 0 and iter_sum > train_config[
+                    "max_iter"]:
+                return
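+
+# Checkpoints are written as "<quant_model_ckpt_path>/epoch_<E>_iter_<I>" by
+# quant_aware_with_infermodel above; pass one of those paths as
+# checkpoint_path to export_quant_infermodel below.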
+
+
+def export_quant_infermodel(
+        executor,
+        place=None,
+        scope=None,
+        quant_config=None,
+        train_config=None,
+        checkpoint_path=None,
+        export_inference_model_path_prefix="./export_quant_infermodel"):
+    """export a quant model checkpoint as an infermodel.
+    Args:
+        executor(paddle.static.Executor): The executor to load, run and save the
+            quantized model.
+        place(paddle.CPUPlace or paddle.CUDAPlace): The device the executor
+            runs on.
+        scope(paddle.static.Scope, optional): Scope records the mapping between
+            variable names and variables. When set to ``None``,
+            ``paddle.static.global_scope()`` is used. Default: ``None``.
+        quant_config(dict, optional): configs for convert. If set to None, the
+            default config is used. It must be the same config that was used
+            in 'quant_aware'. Default is None.
+        train_config(dict): training-aware configs, including num_epoch,
+            save_iter_step, learning_rate, weight_decay, use_pact,
+            quant_model_ckpt_path, model_path_prefix,
+            teacher_model_path_prefix, and distill_node_pair
+            (teacher_node_name1, node_name1, teacher_node_name2, node_name2, ...).
+        checkpoint_path(str): checkpoint path of the quant infermodel to export.
+        export_inference_model_path_prefix(str): export path prefix of the
+            inference model, i.e. storage directory + model name (excluding suffix).
+    Returns:
+        None
+    """
+    scope = paddle.static.global_scope() if not scope else scope
+    # parse quant config
+    if quant_config is None:
+        quant_config = _quant_config_default
+    else:
+        assert isinstance(quant_config, dict), "quant config must be dict"
+        quant_config = _parse_configs(quant_config)
+    _logger.info("quant_aware config {}".format(quant_config))
+
+    train_config = _parse_train_configs(train_config)
+    distill_program_info = build_distill_prog_with_infermodel(executor, place,
+                                                              train_config)
+    test_program = distill_program_info.test_program
+    test_feed_names = distill_program_info.test_feed_names
+    test_fetch_list = distill_program_info.test_fetch_list
+
+    ############################################################################
+    # quant
+    ############################################################################
+    use_pact = train_config["use_pact"]
+    if use_pact:
+        act_preprocess_func = pact
+        optimizer_func = get_pact_optimizer
+        pact_executor = executor
+    else:
+        act_preprocess_func = None
+        optimizer_func = None
+        pact_executor = None
+
+    test_program = quant_aware(
+        test_program,
+        place,
+        quant_config,
+        scope=scope,
+        act_preprocess_func=act_preprocess_func,
+        optimizer_func=optimizer_func,
+        executor=pact_executor,
+        for_test=True)
+
+    paddle.static.load(
+        executor=executor,
+        model_path=os.path.join(checkpoint_path),
+        program=test_program)
+    ############################################################################
+    # Freeze the graph after training by adjusting the quantize operators'
+    # order for inference. The dtype of float_program's weights is float32,
+    # but their values are within the int8 range.
+    ############################################################################
+    float_program, int8_program = convert(test_program, place, quant_config, \
+        scope=scope, \
+        save_int8=True)
+    ############################################################################
+    # Save the inference model.
+    ############################################################################
+    export_model_dir = os.path.abspath(
+        os.path.join(export_inference_model_path_prefix, os.path.pardir))
+    if not os.path.exists(export_model_dir):
+        os.makedirs(export_model_dir)
+
+    feed_vars = []
+    for name in test_feed_names:
+        for var in float_program.list_vars():
+            if var.name == name:
+                feed_vars.append(var)
+                break
+    assert len(feed_vars) > 0, "can not find feed vars in quant program"
+    paddle.static.save_inference_model(
+        path_prefix=export_inference_model_path_prefix,
+        feed_vars=feed_vars,
+        fetch_vars=test_fetch_list,
+        executor=executor,
+        program=float_program)
diff --git a/paddleslim/quant/quanter.py b/paddleslim/quant/quanter.py
index 2522fed7..43c68173 100755
--- a/paddleslim/quant/quanter.py
+++ b/paddleslim/quant/quanter.py
@@ -29,6 +29,7 @@ from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
 from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
 from paddle.fluid import core
 from paddle.fluid.contrib.slim.quantization import WeightQuantization
+from paddle.fluid.layer_helper import LayerHelper
 
 from ..common import get_logger
 _logger = get_logger(__name__, level=logging.INFO)
@@ -561,3 +562,30 @@ def quant_post_dynamic(model_dir,
 
 # For compatibility, we keep quant_post_only_weight api for now,
 # and it will be deprecated in the future.
 quant_post_only_weight = quant_post_dynamic
+
+
+def pact(x, name=None):
+    """PACT activation preprocessing: learn a clipping threshold u per
+    quantized activation and clip x to [-u, u] before quantization."""
+    helper = LayerHelper("pact", **locals())
+    dtype = 'float32'
+    init_thres = 20
+    u_param_attr = paddle.fluid.ParamAttr(
+        name=x.name + '_pact',
+        initializer=paddle.fluid.initializer.ConstantInitializer(
+            value=init_thres),
+        regularizer=paddle.fluid.regularizer.L2Decay(0.0001),
+        learning_rate=1)
+    u_param = helper.create_parameter(attr=u_param_attr, shape=[1], dtype=dtype)
+    # x - relu(x - u) clips the positive side at u; adding relu(-u - x)
+    # then clips the negative side at -u, so the result is clip(x, -u, u)
+    x = paddle.fluid.layers.elementwise_sub(
+        x,
+        paddle.fluid.layers.relu(
+            paddle.fluid.layers.elementwise_sub(x, u_param)))
+    x = paddle.fluid.layers.elementwise_add(
+        x,
+        paddle.fluid.layers.relu(
+            paddle.fluid.layers.elementwise_sub(-u_param, x)))
+
+    return x
+
+
+def get_pact_optimizer():
+    return paddle.fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
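
As a side note (not part of the patch), the relu combination used by `pact()` above is algebraically identical to clipping; a quick numpy check of the identity:

```python
import numpy as np

def pact_ref(x, u):
    # x - relu(x - u) + relu(-u - x) == clip(x, -u, u)
    relu = lambda v: np.maximum(v, 0.0)
    return x - relu(x - u) + relu(-u - x)

x = np.linspace(-30.0, 30.0, 7)
print(pact_ref(x, 20.0))                                         # [-20. -20. -10. 0. 10. 20. 20.]
print(np.allclose(pact_ref(x, 20.0), np.clip(x, -20.0, 20.0)))   # True
```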
diff --git a/tests/test_quant_aware_with_infermodel.py b/tests/test_quant_aware_with_infermodel.py
new file mode 100644
index 00000000..687152ab
--- /dev/null
+++ b/tests/test_quant_aware_with_infermodel.py
@@ -0,0 +1,221 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import os
+sys.path.append("../")
+sys.path.append(".")
+sys.path[0] = os.path.join(os.path.dirname(__file__), os.path.pardir)
+import unittest
+import paddle
+from paddleslim.quant import quant_aware, convert
+from paddleslim.quant import quant_aware_with_infermodel, export_quant_infermodel
+from static_case import StaticCase
+sys.path.append("../demo")
+from models import MobileNet
+from layers import conv_bn_layer
+import paddle.dataset.mnist as reader
+from paddle.fluid.framework import IrGraph
+from paddle.fluid import core
+import numpy as np
+
+
+class TestQuantAwareWithInferModelCase1(StaticCase):
+    def test_accuracy(self):
+        float_infer_model_path_prefix = "./mv1_float_inference"
+
+        image = paddle.static.data(
+            name='image', shape=[None, 1, 28, 28], dtype='float32')
+        label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
+        model = MobileNet()
+        out = model.net(input=image, class_dim=10)
+        cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
+        avg_cost = paddle.mean(x=cost)
+        acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
+        acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
+        optimizer = paddle.optimizer.Momentum(
+            momentum=0.9,
+            learning_rate=0.01,
+            weight_decay=paddle.regularizer.L2Decay(4e-5))
+        optimizer.minimize(avg_cost)
+        main_prog = paddle.static.default_main_program()
+        val_prog = main_prog.clone(for_test=True)
+
+        #place = paddle.CPUPlace()
+        place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
+        ) else paddle.CPUPlace()
+        exe = paddle.static.Executor(place)
+        exe.run(paddle.static.default_startup_program())
+
+        def transform(x):
+            return np.reshape(x, [1, 28, 28])
+
+        train_dataset = paddle.vision.datasets.MNIST(
+            mode='train', backend='cv2', transform=transform)
+        test_dataset = paddle.vision.datasets.MNIST(
+            mode='test', backend='cv2', transform=transform)
+
+        train_loader = paddle.io.DataLoader(
+            train_dataset,
+            places=place,
+            feed_list=[image, label],
+            drop_last=True,
+            batch_size=64,
+            return_list=False)
+
+        valid_loader = paddle.io.DataLoader(
+            test_dataset,
+            places=place,
+            feed_list=[image, label],
+            batch_size=64,
+            return_list=False)
+
+        def sample_generator_creator():
+            def __reader__():
+                for data in test_dataset:
+                    image, label = data
+                    yield image, label
+
+            return __reader__
+
+        def train(program):
+            iter = 0
+            for data in train_loader():
+                cost, top1, top5 = exe.run(
+                    program,
+                    feed=data,
+                    fetch_list=[avg_cost, acc_top1, acc_top5])
+                iter += 1
+                if iter % 100 == 0:
+                    print(
+                        'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
+                        format(iter, cost, top1, top5))
+
+        def test(program, outputs=[avg_cost, acc_top1, acc_top5]):
+            iter = 0
+            result = [[], [], []]
+            for data in valid_loader():
+                cost, top1, top5 = exe.run(program,
+                                           feed=data,
+                                           fetch_list=outputs)
+                iter += 1
+                if iter % 100 == 0:
+                    print('eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
+                          format(iter, cost, top1, top5))
+                result[0].append(cost)
+                result[1].append(top1)
+                result[2].append(top5)
+            print(' avg loss {}, acc_top1 {}, acc_top5 {}'.format(
+                np.mean(result[0]), np.mean(result[1]), np.mean(result[2])))
+            return np.mean(result[1]), np.mean(result[2])
+
+        train(main_prog)
+        top1_1, top5_1 = test(val_prog)
+        paddle.static.save_inference_model(
+            path_prefix=float_infer_model_path_prefix,
+            feed_vars=[image, label],
+            fetch_vars=[avg_cost, acc_top1, acc_top5],
+            executor=exe,
+            program=val_prog)
+
+        quant_config = {
+            'weight_quantize_type': 'channel_wise_abs_max',
+            'activation_quantize_type': 'moving_average_abs_max',
+            'not_quant_pattern': ['skip_quant'],
+            'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul']
+        }
+        train_config = {
+            "num_epoch": 1,  # training epoch num
+            "max_iter": 20,
+            "save_iter_step": 10,
+            "learning_rate": 0.0001,
+            "weight_decay": 0.0001,
+            "use_pact": False,
+            "quant_model_ckpt_path":
+            "./quantaware_with_infermodel_checkpoints/",
+            "teacher_model_path_prefix": float_infer_model_path_prefix,
+            "model_path_prefix": float_infer_model_path_prefix,
+            "distill_node_pair": [
+                "teacher_fc_0.tmp_0", "fc_0.tmp_0",
+                "teacher_batch_norm_24.tmp_4", "batch_norm_24.tmp_4",
+                "teacher_batch_norm_22.tmp_4", "batch_norm_22.tmp_4",
+                "teacher_batch_norm_18.tmp_4", "batch_norm_18.tmp_4",
+                "teacher_batch_norm_13.tmp_4", "batch_norm_13.tmp_4",
+                "teacher_batch_norm_5.tmp_4", "batch_norm_5.tmp_4"
+            ]
+        }
+
+        def test_callback(compiled_test_program, feed_names, fetch_list,
+                          checkpoint_name):
+            outputs = fetch_list
+            iter = 0
+            result = [[], [], []]
+            for data in valid_loader():
+                cost, top1, top5 = exe.run(compiled_test_program,
+                                           feed=data,
+                                           fetch_list=fetch_list)
+                iter += 1
+                if iter % 100 == 0:
+                    print('eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.
+                          format(iter, cost, top1, top5))
+                result[0].append(cost)
+                result[1].append(top1)
+                result[2].append(top5)
+            print("quant model checkpoint: " + checkpoint_name +
+                  ' avg loss {}, acc_top1 {}, acc_top5 {}'.format(
+                      np.mean(result[0]),
+                      np.mean(result[1]), np.mean(result[2])))
+            return np.mean(result[1]), np.mean(result[2])
+
+        def test_quant_aware_with_infermodel(exe, place):
+            quant_aware_with_infermodel(
+                exe,
+                place,
+                scope=None,
+                train_reader=train_loader,
+                quant_config=quant_config,
+                train_config=train_config,
+                test_callback=test_callback)
+
+        def test_export_quant_infermodel(exe, place, checkpoint_path,
+                                         quant_infermodel_save_path):
+            export_quant_infermodel(
+                exe,
+                place,
+                scope=None,
+                quant_config=quant_config,
+                train_config=train_config,
+                checkpoint_path=checkpoint_path,
+                export_inference_model_path_prefix=quant_infermodel_save_path)
+
+        #place = paddle.CPUPlace()
+        place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
+        ) else paddle.CPUPlace()
+        exe = paddle.static.Executor(place)
+
+        test_quant_aware_with_infermodel(exe, place)
+        checkpoint_path = "./quantaware_with_infermodel_checkpoints/epoch_0_iter_10"
+        quant_infermodel_save_path = "./quantaware_with_infermodel_export"
+        test_export_quant_infermodel(exe, place, checkpoint_path,
+                                     quant_infermodel_save_path)
+        train_config["use_pact"] = True
+        test_quant_aware_with_infermodel(exe, place)
+        train_config["use_pact"] = False
+        checkpoint_path = "./quantaware_with_infermodel_checkpoints/epoch_0_iter_10"
+        quant_infermodel_save_path = "./quantaware_with_infermodel_export"
+        test_export_quant_infermodel(exe, place, checkpoint_path,
+                                     quant_infermodel_save_path)
+
+
+if __name__ == '__main__':
+    unittest.main()
--
GitLab