diff --git a/example/auto_compression/image_classification/README.md b/example/auto_compression/image_classification/README.md
index d857d7bdcaa4664a87954697160e183f73c9b72b..9153fe9b1bd2a676b46135559b01baa3221b5db8 100644
--- a/example/auto_compression/image_classification/README.md
+++ b/example/auto_compression/image_classification/README.md
@@ -45,6 +45,8 @@
 | MobileNetV3_large_x1_0 | Quantization + Distillation | 74.04 | - | 9.85 | [Config](./configs/MobileNetV3_large_x1_0/qat_dis.yaml) | [Model](https://paddle-slim-models.bj.bcebos.com/act/MobileNetV3_large_x1_0_QAT.tar) |
 | MobileNetV3_large_x1_0_ssld | Baseline | 78.96 | - | 16.62 | - | [Model](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV3_large_x1_0_ssld_infer.tar) |
 | MobileNetV3_large_x1_0_ssld | Quantization + Distillation | 77.17 | - | 9.85 | [Config](./configs/MobileNetV3_large_x1_0/qat_dis.yaml) | [Model](https://paddle-slim-models.bj.bcebos.com/act/MobileNetV3_large_x1_0_ssld_QAT.tar) |
+| ViT_base_patch16_224 | Baseline | 81.89 | - | - | - | [Model](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ViT_base_patch16_224_infer.tar) |
+| ViT_base_patch16_224 | Quantization + Distillation | 82.05 | - | - | [Config](./configs/VIT/qat_dis.yaml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/ViT_base_patch16_224_QAT.tar) |
 
 - ARM CPU test environment: `SDM865(4xA77+4xA55)`
 - Nvidia GPU test environment:
@@ -73,6 +75,11 @@
 pip install paddlepaddle-gpu
 pip install paddleslim
 ```
+If you use the `run_ppclas.py` script, you also need to install paddleclas:
+```shell
+pip install paddleclas
+```
+
 
 #### 3.2 Prepare the Dataset
 By default, this example runs automatic compression on the ImageNet-1k dataset. If your dataset is not in ImageNet-1k format, please refer to the [PaddleClas data preparation documentation](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/zh_CN/data_preparation/classification_dataset.md). Place the downloaded dataset under `./ILSVRC2012` in the current directory.
@@ -111,7 +118,12 @@ python -m paddle.distributed.launch run.py --save_dir='./save_quant_mobilev1/' -
 ```
 Multi-GPU training splits the training job across multiple trainer nodes, each of which reads data, runs the forward pass, and computes gradients in the backward pass, then uploads its gradients to the server node. Once the server node has received the gradients from all trainers, it aggregates them, updates the parameters, and sends the updated parameters back to the trainers to start the next round. One multi-GPU iteration trains ```batch size * num gpus``` samples: with a single-GPU ```batch size``` of 32, one iteration covers 32 samples, whereas with four GPUs and a per-GPU ```batch size``` of 32, one iteration covers 128 samples.
-Note that ```learning rate``` scales linearly with ```batch size```. Here the single-GPU ```batch size``` is 32 and the corresponding ```learning rate``` is 0.015, so if the ```batch size``` is reduced 4x to 8, the ```learning rate``` must also be divided by 4; with multiple GPUs and a per-GPU ```batch size``` of 32, the ```learning rate``` must be multiplied by the number of GPUs. In short, whenever you change the ```batch size``` or the number of training GPUs, adjust the ```learning rate``` accordingly.
+Notes:
+
+- Parameter settings: ```learning rate``` scales linearly with ```batch size```. Here the single-GPU ```batch size``` is 32 and the corresponding ```learning rate``` is 0.015, so if the ```batch size``` is reduced 4x to 8, the ```learning rate``` must also be divided by 4; with multiple GPUs and a per-GPU ```batch size``` of 32, the ```learning rate``` must be multiplied by the number of GPUs. In short, whenever you change the ```batch size``` or the number of training GPUs, adjust the ```learning rate``` accordingly.
+
+- To use the data preprocessing and `DataLoader` from `PaddleClas`, launch with the `run_ppclas.py` script instead. It is launched the same way as the examples above, but its data configuration must be aligned with ```PaddleClas```; see the [ViT config file](./configs/VIT/data_reader.yml) for reference.
+
 
 
 ## 4. Inference Deployment
 
diff --git a/example/auto_compression/image_classification/configs/VIT/data_reader.yml b/example/auto_compression/image_classification/configs/VIT/data_reader.yml
new file mode 100644
index 0000000000000000000000000000000000000000..370dde64407de3f2df7634cc68692a54a49f5aa6
--- /dev/null
+++ b/example/auto_compression/image_classification/configs/VIT/data_reader.yml
@@ -0,0 +1,56 @@
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.5, 0.5, 0.5]
+            std: [0.5, 0.5, 0.5]
+            order: ''
+
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 16
+      drop_last: False
+      shuffle: True
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+  Eval:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - ResizeImage:
+            resize_short: 256
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.5, 0.5, 0.5]
+            std: [0.5, 0.5, 0.5]
+            order: ''
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 64
+      drop_last: False
+      shuffle: False
+    loader:
+      num_workers: 4
+      use_shared_memory: True
diff --git a/example/auto_compression/image_classification/configs/VIT/qat_dis.yaml b/example/auto_compression/image_classification/configs/VIT/qat_dis.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..450683758997927810c2466e9fa9e12e6f7bf4a4
--- /dev/null
+++ b/example/auto_compression/image_classification/configs/VIT/qat_dis.yaml
@@ -0,0 +1,28 @@
+Global:
+  model_dir: ViT_base_patch16_224_infer
+  model_filename: inference.pdmodel
+  params_filename: inference.pdiparams
+  batch_size: 16
+  input_name: inputs
+  reader_config: ./configs/VIT/data_reader.yml
+
+Distillation:
+  node:
+  - softmax_12.tmp_0
+
+QuantAware:
+  use_pact: true
+  onnx_format: true
+
+TrainConfig:
+  epochs: 1
+  eval_iter: 500
+  learning_rate:
+    type: CosineAnnealingDecay
+    learning_rate: 0.015
+  optimizer_builder:
+    optimizer:
+      type: Momentum
+    weight_decay: 0.00002
+  origin_metric: 0.8189
+
diff --git a/example/auto_compression/image_classification/run_ppclas.py b/example/auto_compression/image_classification/run_ppclas.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bdf5897d05f518f363c6078996eb04472062808
--- /dev/null
+++ b/example/auto_compression/image_classification/run_ppclas.py
@@ -0,0 +1,170 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import argparse
+import functools
+from functools import partial
+import math
+from tqdm import tqdm
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+from paddle.io import DataLoader
+from paddleslim.common import load_config as load_slim_config
+from paddleslim.auto_compression import AutoCompression
+from ppcls.data import build_dataloader
+from ppcls.utils import config
+from ppcls.utils.logger import init_logger
+
+
+def argsparser():
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        '--config_path',
+        type=str,
+        default=None,
+        help="path of compression strategy config.",
+        required=True)
+    parser.add_argument(
+        '--save_dir',
+        type=str,
+        default='output',
+        help="directory to save compressed model.")
+    parser.add_argument(
+        '--total_images',
+        type=int,
+        default=1281167,
+        help="the number of total training images.")
+    parser.add_argument(
+        '--devices',
+        type=str,
+        default='gpu',
+        help="which device used to compress.")
+    return parser
+
+
+# yapf: enable
+def reader_wrapper(reader, input_name):
+    # Wrap a PaddleClas dataloader so that it yields feed dicts keyed by the
+    # inference model's input name, as expected by AutoCompression.
+    if isinstance(input_name, list) and len(input_name) == 1:
+        input_name = input_name[0]
+
+    def gen():
+        for i, (imgs, label) in enumerate(reader()):
+            yield {input_name: imgs}
+
+    return gen
+
+
+def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
+
+    results = []
+    with tqdm(
+            total=len(eval_loader),
+            bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
+            ncols=80) as t:
+        for batch_id, (image, label) in enumerate(eval_loader):
+            # top1_acc, top5_acc
+            if len(test_feed_names) == 1:
+                # the model takes only the image; compute top-1/top-5 from its logits
+                image = np.array(image)
+                label = np.array(label).astype('int64')
+                if len(label.shape) == 1:
+                    label = label.reshape([label.shape[0], -1])
+                pred = exe.run(compiled_test_program,
+                               feed={test_feed_names[0]: image},
+                               fetch_list=test_fetch_list)
+                pred = np.array(pred[0])
+                sort_array = pred.argsort(axis=1)
+                top_1_pred = sort_array[:, -1:][:, ::-1]
+                top_1 = np.mean(label == top_1_pred)
+                top_5_pred = sort_array[:, -5:][:, ::-1]
+                acc_num = 0
+                for i in range(len(label)):
+                    if label[i][0] in top_5_pred[i]:
+                        acc_num += 1
+                top_5 = float(acc_num) / len(label)
+                results.append([top_1, top_5])
+            else:
+                # the eval model takes image and label as inputs and outputs
+                # top-1 and top-5 accuracy directly
+                image = np.array(image)
+                label = np.array(label).astype('int64')
+                result = exe.run(compiled_test_program,
+                                 feed={
+                                     test_feed_names[0]: image,
+                                     test_feed_names[1]: label
+                                 },
+                                 fetch_list=test_fetch_list)
+                result = [np.mean(r) for r in result]
+                results.append(result)
+            t.update()
+    result = np.mean(np.array(results), axis=0)
+    return result[0]
+
+
+def main():
+    rank_id = paddle.distributed.get_rank()
+    if args.devices == 'gpu':
+        paddle.CUDAPlace(rank_id)
+        device = paddle.set_device('gpu')
+    else:
+        paddle.CPUPlace()
+        device = paddle.set_device('cpu')
+    global global_config
+    all_config = load_slim_config(args.config_path)
+
+    assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}"
+    global_config = all_config["Global"]
+
+    gpu_num = paddle.distributed.get_world_size()
+    if isinstance(all_config['TrainConfig']['learning_rate'],
+                  dict) and all_config['TrainConfig']['learning_rate'][
+                      'type'] == 'CosineAnnealingDecay':
+        step = int(
+            math.ceil(
+                float(args.total_images) / (global_config['batch_size'] *
+                                            gpu_num)))
+        all_config['TrainConfig']['learning_rate']['T_max'] = step
+        print('total training steps:', step)
+
+    init_logger()
+    data_config = config.get_config(global_config['reader_config'], show=False)
+    train_loader = build_dataloader(data_config["DataLoader"], "Train", device,
+                                    False)
+    train_dataloader = reader_wrapper(train_loader, global_config['input_name'])
+
+    global eval_loader
+    eval_loader = build_dataloader(data_config["DataLoader"], "Eval", device,
+                                   False)
+
+    ac = AutoCompression(
+        model_dir=global_config['model_dir'],
+        model_filename=global_config['model_filename'],
+        params_filename=global_config['params_filename'],
+        save_dir=args.save_dir,
+        config=all_config,
+        train_dataloader=train_dataloader,
+        eval_callback=eval_function if rank_id == 0 else None,
+        eval_dataloader=reader_wrapper(eval_loader,
+                                       global_config['input_name']))
+
+    ac.compress()
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    parser = argsparser()
+    args = parser.parse_args()
+    main()
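
For reference, below is a minimal launch sketch for the new `run_ppclas.py` script, mirroring the existing `run.py` examples in the README. The `--config_path` value points at the ViT config added in this diff; the `save_dir` name and the GPU indices are illustrative assumptions.

```shell
# Single-GPU launch (save_dir name is illustrative)
export CUDA_VISIBLE_DEVICES=0
python run_ppclas.py --save_dir='./save_quant_vit/' --config_path='./configs/VIT/qat_dis.yaml'

# Multi-GPU launch via paddle.distributed.launch, same flags
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch run_ppclas.py --save_dir='./save_quant_vit/' --config_path='./configs/VIT/qat_dis.yaml'
```

As noted in the README, if you change the batch size in `data_reader.yml` or the number of GPUs, scale `learning_rate` in `qat_dis.yaml` accordingly.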