From 90797a222a32850bdff1d7f5c413f7a3ba2412cf Mon Sep 17 00:00:00 2001 From: whs Date: Wed, 29 Mar 2023 19:11:15 +0800 Subject: [PATCH] Add tutorial of PTQ for classification (#1705) --- .../quantization/ptq/classification/README.md | 59 ++++ .../quantization/ptq/classification/eval.py | 138 ++++++++++ .../quantization/ptq/classification/ptq.py | 251 ++++++++++++++++++ 3 files changed, 448 insertions(+) create mode 100644 example/quantization/ptq/classification/README.md create mode 100644 example/quantization/ptq/classification/eval.py create mode 100644 example/quantization/ptq/classification/ptq.py diff --git a/example/quantization/ptq/classification/README.md b/example/quantization/ptq/classification/README.md new file mode 100644 index 00000000..ddaf4801 --- /dev/null +++ b/example/quantization/ptq/classification/README.md @@ -0,0 +1,59 @@ +# 动态图离线量化 + +本示例介绍如何对动态图模型进行离线量化,示例以常用的MobileNetV1和MobileNetV3模型为例,介绍如何对其进行离线量化。 + + +## 分类模型的离线量化流程 + +#### 准备数据 + +在当前目录下创建``data``文件夹,将``ImageNet``的验证集解压在``data``文件夹下,解压后``data/ILSVRC2012``文件夹下应包含以下文件: +- ``'val'``文件夹,验证图片 +- ``'val_list.txt'``文件 + +#### 准备需要离线量化的模型 + +本示例直接使用[paddle vision](https://github.com/PaddlePaddle/Paddle/tree/develop/python/paddle/vision/models)内置的模型结构和预训练权重。通过以下命令查看支持的所有模型: + +``` +python ptq.py --help +``` + +## 启动命令 +以MobileNetV1为例,通过以下脚本启动PTQ任务: + +```bash +python ptq.py \ + --data=dataset/ILSVRC2012/ \ + --model=mobilenet_v1 \ + --activation_observer='mse' \ + --weight_observer='mse_channel_wise' \ + --quant_batch_num=10 \ + --quant_batch_size=10 \ + --output_dir="output_ptq" +``` + +其中,通过 `activation_observer` 配置用于激活的量化算法,通过 `weight_observer` 配置用于权重的量化算法。 +更多支持的量化算法,请执行以下命令查看: + +``` +python ptq.py --help +``` + +## 评估精度 + +执行以下命令,使用 PaddleInference 推理库测试推理精度: + +```bash +python eval.py --model_path=output_ptq/mobilenet_v1/int8_infer/ --data_dir=dataset/ILSVRC2012/ --use_gpu=True +``` + +- 评估时支持CPU,并且不依赖TensorRT,MKLDNN。 + + +## 量化结果 + +| 模型 | FP32模型准确率(Top1/Top5) | 量化方法(activation/weight) | 量化模型准确率(Top1/Top5) | +| ----------- | --------------------------- | ------------ | --------------------------- | +| MobileNetV1 | 70.10%/90.10% | mse / mes_channel_wise | 69.10%/89.80% | +| MobileNetV2 | 71.10%/90.90% | mse / mes_channel_wise | 70.70%/90.10% | diff --git a/example/quantization/ptq/classification/eval.py b/example/quantization/ptq/classification/eval.py new file mode 100644 index 00000000..ef7dc749 --- /dev/null +++ b/example/quantization/ptq/classification/eval.py @@ -0,0 +1,138 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np +import time +import sys +import argparse +import math + +import paddle +import paddle.inference as paddle_infer +from ptq import ImageNetValDataset + + +def eval(): + # create predictor + model_file = os.path.join(FLAGS.model_path, FLAGS.model_filename) + params_file = os.path.join(FLAGS.model_path, FLAGS.params_filename) + config = paddle_infer.Config(model_file, params_file) + if FLAGS.use_gpu: + config.enable_use_gpu(1000, 0) + if not FLAGS.ir_optim: + config.switch_ir_optim(False) + + predictor = paddle_infer.create_predictor(config) + + input_names = predictor.get_input_names() + input_handle = predictor.get_input_handle(input_names[0]) + output_names = predictor.get_output_names() + output_handle = predictor.get_output_handle(output_names[0]) + + # prepare data + val_dataset = ImageNetValDataset(FLAGS.data_dir) + eval_loader = paddle.io.DataLoader( + val_dataset, batch_size=FLAGS.batch_size, num_workers=5) + + cost_time = 0. + total_num = 0. + correct_1_num = 0 + correct_5_num = 0 + for batch_id, data in enumerate(eval_loader()): + # set input + img_np = np.array([tensor.numpy() for tensor in data[0]]) + label_np = np.array([tensor.numpy() for tensor in data[1]]) + + input_handle.reshape(img_np.shape) + input_handle.copy_from_cpu(img_np) + + # run + t1 = time.time() + predictor.run() + t2 = time.time() + cost_time += (t2 - t1) + + output_data = output_handle.copy_to_cpu() + + # calculate accuracy + for i in range(len(label_np)): + label = label_np[i][0] + result = output_data[i, :] + index = result.argsort() + total_num += 1 + if index[-1] == label: + correct_1_num += 1 + if label in index[-5:]: + correct_5_num += 1 + + if batch_id % 10 == 0: + acc1 = correct_1_num / total_num + acc5 = correct_5_num / total_num + avg_time = cost_time / total_num + print( + "batch_id {}, acc1 {:.3f}, acc5 {:.3f}, avg time {:.5f} sec/img". + format(batch_id, acc1, acc5, avg_time)) + + if FLAGS.test_samples > 0 and \ + (batch_id + 1)* FLAGS.batch_size >= FLAGS.test_samples: + break + + acc1 = correct_1_num / total_num + acc5 = correct_5_num / total_num + avg_time = cost_time / total_num + print("End test: test image {}".format(total_num)) + print("test_acc1: {:.4f}; test_acc5: {:.4f}; avg time: {:.5f} sec/img". + format(acc1, acc5, avg_time)) + print("\n") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + '--model_path', type=str, default="", help="The inference model path.") + parser.add_argument( + '--model_filename', + type=str, + default="model.pdmodel", + help="model filename") + parser.add_argument( + '--params_filename', + type=str, + default="model.pdiparams", + help="params filename") + parser.add_argument( + '--data_dir', + type=str, + default="dataset/ILSVRC2012/", + help="The ImageNet dataset root dir.") + parser.add_argument( + '--test_samples', + type=int, + default=-1, + help="Test samples. If set -1, use all test samples") + parser.add_argument( + '--batch_size', type=int, default=10, help="Batch size.") + parser.add_argument( + '--use_gpu', type=bool, default=False, help=" Whether use gpu or not.") + parser.add_argument( + '--ir_optim', type=bool, default=False, help="Enable ir optim.") + + FLAGS = parser.parse_args() + + eval() diff --git a/example/quantization/ptq/classification/ptq.py b/example/quantization/ptq/classification/ptq.py new file mode 100644 index 00000000..71ab7b37 --- /dev/null +++ b/example/quantization/ptq/classification/ptq.py @@ -0,0 +1,251 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import argparse +import six +from inspect import isfunction +import os +import time +import random +from types import FunctionType +from typing import Dict +import numpy as np +from PIL import Image + +import paddle +from paddle.io import Dataset +from paddle.vision.transforms import transforms +import paddle.vision.models as models +from paddle.quantization import QuantConfig +from paddle.quantization import PTQ +from tqdm import tqdm +from paddleslim.quant.observers import HistObserver, KLObserver, EMDObserver, MSEObserver, AVGObserver +from paddleslim.quant.observers import MSEChannelWiseWeightObserver, AbsMaxChannelWiseWeightObserver + +import sys +sys.path.append(os.path.dirname("__file__")) +sys.path.append( + os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir)) + +SUPPORT_MODELS: Dict[str, FunctionType] = {} +for _name, _module in models.__dict__.items(): + if isfunction(_module) and 'pretrained' in _module.__code__.co_varnames: + SUPPORT_MODELS[_name] = _module + +ACTIVATION_OBSERVERS: Dict[str, type] = { + 'hist': HistObserver, + 'kl': KLObserver, + 'emd': EMDObserver, + 'mse': MSEObserver, + 'avg': AVGObserver, +} + +WEIGHT_OBSERVERS: Dict[str, type] = { + 'mse_channel_wise': MSEChannelWiseWeightObserver, + 'abs_max_channel_wise': AbsMaxChannelWiseWeightObserver, +} + + +class ImageNetValDataset(Dataset): + def __init__(self, data_dir, image_size=224, resize_short_size=256): + super(ImageNetValDataset, self).__init__() + val_file_list = os.path.join(data_dir, 'val_list.txt') + test_file_list = os.path.join(data_dir, 'test_list.txt') + self.data_dir = data_dir + + normalize = transforms.Normalize( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375]) + self.transform = transforms.Compose([ + transforms.Resize(resize_short_size), + transforms.CenterCrop(image_size), + transforms.Transpose(), normalize + ]) + + with open(val_file_list) as flist: + lines = [line.strip() for line in flist] + self.data = [line.split() for line in lines] + + def __getitem__(self, index): + img_path, label = self.data[index] + img_path = os.path.join(self.data_dir, img_path) + img = Image.open(img_path).convert('RGB') + label = np.array([label]).astype(np.int64) + return self.transform(img), label + + def __len__(self): + return len(self.data) + + +def test(net, dataset): + valid_loader = paddle.io.DataLoader(dataset, batch_size=1) + net.eval() + batch_id = 0 + acc_top1_ns = [] + acc_top5_ns = [] + + eval_reader_cost = 0.0 + eval_run_cost = 0.0 + total_samples = 0 + reader_start = time.time() + for data in tqdm(valid_loader()): + eval_reader_cost += time.time() - reader_start + image = data[0] + label = data[1] + eval_start = time.time() + + out = net(image) + acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1) + acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5) + + eval_run_cost += time.time() - eval_start + batch_size = image.shape[0] + total_samples += batch_size + + acc_top1_ns.append(np.mean(acc_top1.numpy())) + acc_top5_ns.append(np.mean(acc_top5.numpy())) + batch_id += 1 + reader_start = time.time() + return np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns)) + + +def calibrate(model, dataset, batch_num, batch_size, num_workers=1): + data_loader = paddle.io.DataLoader( + dataset, batch_size=batch_size, num_workers=num_workers) + + pbar = tqdm(total=batch_num) + for idx, data in enumerate(data_loader()): + model(data[0]) + pbar.update(1) + if (batch_num > 0) and (idx + 1 >= batch_num): + break + pbar.close() + + +def main(): + num_workers = 5 + if FLAGS.ce_test: + # set seed + seed = 111 + paddle.seed(seed) + np.random.seed(seed) + random.seed(seed) + num_workers = 0 + + # 1 load model + fp32_model = SUPPORT_MODELS[FLAGS.model](pretrained=True) + if FLAGS.pretrain_weight: + info_dict = paddle.load(FLAGS.pretrain_weight) + fp32_model.load_dict(info_dict) + print('Finish loading model weights:{}'.format(FLAGS.pretrain_weight)) + fp32_model.eval() + val_dataset = ImageNetValDataset(FLAGS.data) + + # 2 quantizations + activation_observer = ACTIVATION_OBSERVERS[FLAGS.activation_observer]() + weight_observer = WEIGHT_OBSERVERS[FLAGS.weight_observer]() + + config = QuantConfig(weight=None, activation=None) + config.add_type_config( + paddle.nn.Conv2D, + activation=activation_observer, + weight=weight_observer) + ptq = PTQ(config) + top1, top5 = test(fp32_model, val_dataset) + print( + f"\033[31mBaseline(FP32): top1/top5 = {top1*100:.2f}%/{top5*100:.2f}%\033[0m" + ) + quant_model = ptq.quantize(fp32_model) + + print("Start PTQ calibration for quantization") + calibrate( + quant_model, + val_dataset, + FLAGS.quant_batch_num, + FLAGS.quant_batch_size, + num_workers=num_workers) + + infer_model = ptq.convert(quant_model, inplace=True) + + top1, top5 = test(infer_model, val_dataset) + print( + f"\033[31mPTQ with {FLAGS.activation_observer}/{FLAGS.weight_observer}: top1/top5 = {top1*100:.2f}%/{top5*100:.2f}%\033[0m" + ) + + dummy_input = paddle.static.InputSpec( + shape=[None, 3, 224, 224], dtype='float32') + paddle.jit.save(infer_model, "./int8_infer", [dummy_input]) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("Quantization on ImageNet") + + # model + parser.add_argument( + "--model", + type=str, + choices=SUPPORT_MODELS.keys(), + default='mobilenet_v1', + help="model name", ) + parser.add_argument( + "--pretrain_weight", + type=str, + default=None, + help="pretrain weight path") + parser.add_argument( + "--output_dir", type=str, default='output', help="save dir") + + # data + parser.add_argument( + '--data', + default="/dataset/ILSVRC2012", + help= + 'path to dataset (should have subdirectories named "train" and "val"', + required=True, ) + + parser.add_argument( + '--val_dir', + default="val", + help='the dir that saves val images for paddle.Model') + + # quantization + parser.add_argument( + "--activation_observer", + default='mse', + type=str, + choices=ACTIVATION_OBSERVERS.keys(), + help="batch num for quant") + parser.add_argument( + "--weight_observer", + default='mse_channel_wise', + choices=WEIGHT_OBSERVERS.keys(), + type=str, + help="batch size for quant") + + # train + parser.add_argument( + "--quant_batch_num", default=10, type=int, help="batch num for quant") + parser.add_argument( + "--quant_batch_size", default=10, type=int, help="batch size for quant") + parser.add_argument( + '--ce_test', default=False, type=bool, help="Whether to CE test.") + + FLAGS = parser.parse_args() + print("----------- Configuration Arguments -----------") + for arg, value in sorted(six.iteritems(vars(FLAGS))): + print("%s: %s" % (arg, value)) + print("------------------------------------------------") + main() -- GitLab