未验证 提交 d6596a63 编写于 作者: C Chang Xu 提交者: GitHub

Add ViT demo for ACT (#1552)

* Add ViT demo for ACT

* update
Co-authored-by: Nceci3 <ceci3@users.noreply.github.com>
上级 15aa529b
......@@ -45,6 +45,8 @@
| MobileNetV3_large_x1_0 | 量化+蒸馏 | 74.04 | - | 9.85 | [Config](./configs/MobileNetV3_large_x1_0/qat_dis.yaml) | [Model](https://paddle-slim-models.bj.bcebos.com/act/MobileNetV3_large_x1_0_QAT.tar) |
| MobileNetV3_large_x1_0_ssld | Baseline | 78.96 | - | 16.62 | - | [Model](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV3_large_x1_0_ssld_infer.tar) |
| MobileNetV3_large_x1_0_ssld | 量化+蒸馏 | 77.17 | - | 9.85 | [Config](./configs/MobileNetV3_large_x1_0/qat_dis.yaml) | [Model](https://paddle-slim-models.bj.bcebos.com/act/MobileNetV3_large_x1_0_ssld_QAT.tar) |
| ViT_base_patch16_224 | Baseline | 81.89 | - | - | - | [Model](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ViT_base_patch16_224_infer.tar) |
| ViT_base_patch16_224 | 量化+蒸馏 | 82.05 | - | - | [Config](./configs/VIT/qat_dis.yaml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/ViT_base_patch16_224_QAT.tar) |
- ARM CPU 测试环境:`SDM865(4xA77+4xA55)`
- Nvidia GPU 测试环境:
......@@ -73,6 +75,11 @@ pip install paddlepaddle-gpu
pip install paddleslim
```
若使用`run_ppclas.py`脚本,需安装paddleclas:
```shell
pip install paddleclas
```
#### 3.2 准备数据集
本案例默认以ImageNet1k数据进行自动压缩实验,如数据集为非ImageNet1k格式数据, 请参考[PaddleClas数据准备文档](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/zh_CN/data_preparation/classification_dataset.md)。将下载好的数据集放在当前目录下`./ILSVRC2012`
......@@ -111,7 +118,12 @@ python -m paddle.distributed.launch run.py --save_dir='./save_quant_mobilev1/' -
```
多卡训练指的是将训练任务按照一定方法拆分到多个训练节点完成数据读取、前向计算、反向梯度计算等过程,并将计算出的梯度上传至服务节点。服务节点在收到所有训练节点传来的梯度后,会将梯度聚合并更新参数。最后将参数发送给训练节点,开始新一轮的训练。多卡训练一轮训练能训练```batch size * num gpus```的数据,比如单卡的```batch size```为32,单轮训练的数据量即32,而四卡训练的```batch size```为32,单轮训练的数据量为128。
注意 ```learning rate``````batch size``` 呈线性关系,这里单卡 ```batch size``` 为32,对应的 ```learning rate``` 为0.015,那么如果 ```batch size``` 减小4倍改为8,```learning rate``` 也需除以4;多卡时 ```batch size``` 为32,```learning rate``` 需乘上卡数。所以改变 ```batch size``` 或改变训练卡数都需要对应修改 ```learning rate```
注意:
- 参数设置:```learning rate``````batch size``` 呈线性关系,这里单卡 ```batch size``` 为32,对应的 ```learning rate``` 为0.015,那么如果 ```batch size``` 减小4倍改为8,```learning rate``` 也需除以4;多卡时 ```batch size``` 为32,```learning rate``` 需乘上卡数。所以改变 ```batch size``` 或改变训练卡数都需要对应修改 ```learning rate```
- 如需要使用`PaddleClas`中的数据预处理和`DataLoader`,可以使用`run_ppclas.py`脚本启动,启动方式跟以上示例相同,但配置需要对其```PaddleClas```,可参考[ViT配置文件](./configs/VIT/data_reader.yml)
## 4.预测部署
......
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 16
drop_last: False
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Global:
model_dir: ViT_base_patch16_224_infer
model_filename: inference.pdmodel
params_filename: inference.pdiparams
batch_size: 16
input_name: inputs
reader_config: ./configs/VIT/data_reader.yml
Distillation:
node:
- softmax_12.tmp_0
QuantAware:
use_pact: true
onnx_format: true
TrainConfig:
epochs: 1
eval_iter: 500
learning_rate:
type: CosineAnnealingDecay
learning_rate: 0.015
optimizer_builder:
optimizer:
type: Momentum
weight_decay: 0.00002
origin_metric: 0.8189
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse
import functools
from functools import partial
import math
from tqdm import tqdm
import numpy as np
import paddle
import paddle.nn as nn
from paddle.io import DataLoader
from paddleslim.common import load_config as load_slim_config
from paddleslim.auto_compression import AutoCompression
from ppcls.data import build_dataloader
from ppcls.utils import config
from ppcls.utils.logger import init_logger
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--config_path',
type=str,
default=None,
help="path of compression strategy config.",
required=True)
parser.add_argument(
'--save_dir',
type=str,
default='output',
help="directory to save compressed model.")
parser.add_argument(
'--total_images',
type=int,
default=1281167,
help="the number of total training images.")
parser.add_argument(
'--devices',
type=str,
default='gpu',
help="which device used to compress.")
return parser
# yapf: enable
def reader_wrapper(reader, input_name):
if isinstance(input_name, list) and len(input_name) == 1:
input_name = input_name[0]
def gen():
for i, (imgs, label) in enumerate(reader()):
yield {input_name: imgs}
return gen
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
results = []
with tqdm(
total=len(eval_loader),
bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for batch_id, (image, label) in enumerate(eval_loader):
# top1_acc, top5_acc
if len(test_feed_names) == 1:
image = np.array(image)
label = np.array(label).astype('int64')
if len(label.shape) == 1:
label = label.reshape([label.shape[0], -1])
pred = exe.run(compiled_test_program,
feed={test_feed_names[0]: image},
fetch_list=test_fetch_list)
pred = np.array(pred[0])
sort_array = pred.argsort(axis=1)
top_1_pred = sort_array[:, -1:][:, ::-1]
top_1 = np.mean(label == top_1_pred)
top_5_pred = sort_array[:, -5:][:, ::-1]
acc_num = 0
for i in range(len(label)):
if label[i][0] in top_5_pred[i]:
acc_num += 1
top_5 = float(acc_num) / len(label)
results.append([top_1, top_5])
else:
# eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
image = np.array(image)
label = np.array(label).astype('int64')
result = exe.run(compiled_test_program,
feed={
test_feed_names[0]: image,
test_feed_names[1]: label
},
fetch_list=test_fetch_list)
result = [np.mean(r) for r in result]
results.append(result)
t.update()
result = np.mean(np.array(results), axis=0)
return result[0]
def main():
rank_id = paddle.distributed.get_rank()
if args.devices == 'gpu':
paddle.CUDAPlace(rank_id)
device = paddle.set_device('gpu')
else:
paddle.CPUPlace()
device = paddle.set_device('cpu')
global global_config
all_config = load_slim_config(args.config_path)
assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}"
global_config = all_config["Global"]
gpu_num = paddle.distributed.get_world_size()
if isinstance(all_config['TrainConfig']['learning_rate'],
dict) and all_config['TrainConfig']['learning_rate'][
'type'] == 'CosineAnnealingDecay':
step = int(
math.ceil(
float(args.total_images) / (global_config['batch_size'] *
gpu_num)))
all_config['TrainConfig']['learning_rate']['T_max'] = step
print('total training steps:', step)
init_logger()
data_config = config.get_config(global_config['reader_config'], show=False)
train_loader = build_dataloader(data_config["DataLoader"], "Train", device,
False)
train_dataloader = reader_wrapper(train_loader, global_config['input_name'])
global eval_loader
eval_loader = build_dataloader(data_config["DataLoader"], "Eval", device,
False)
ac = AutoCompression(
model_dir=global_config['model_dir'],
model_filename=global_config['model_filename'],
params_filename=global_config['params_filename'],
save_dir=args.save_dir,
config=all_config,
train_dataloader=train_dataloader,
eval_callback=eval_function if rank_id == 0 else None,
eval_dataloader=reader_wrapper(eval_loader,
global_config['input_name']))
ac.compress()
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
args = parser.parse_args()
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册