diff --git a/demo/auto_compression/detection/keypoint_utils.py b/demo/auto_compression/detection/keypoint_utils.py index f378747b9e29e9c3cc37740089e618b0c94f9842..0d5a9b7c64a9dacc0eba1ad9388b4ca0f1f4522b 100644 --- a/demo/auto_compression/detection/keypoint_utils.py +++ b/demo/auto_compression/detection/keypoint_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/demo/auto_compression/image_classification/README.md b/demo/auto_compression/image_classification/README.md index a853c1e6a7150c85a7940b4cdc0955f0f7dff337..aee959c9919a66fc863aa986a03126638d117208 100644 --- a/demo/auto_compression/image_classification/README.md +++ b/demo/auto_compression/image_classification/README.md @@ -50,20 +50,6 @@ - 软件:CUDA 11.2, cuDNN 8.0, TensorRT 8.4 - 测试配置:batch_size: 1, image size: 224 - -### TensorFlow MobileNetV1模型 - -| 模型 | 策略 | Top-1 Acc | 耗时(ms) threads=1 | Inference模型 | -|:------:|:------:|:------:|:------:|:------:| -| MobileNetV1 | Base模型 | 71.0 | 30.45 | [Model](https://paddle-slim-models.bj.bcebos.com/act/mobilenetv1_inference_model_tf2paddle.tar) | -| MobileNetV1 | 量化+蒸馏 | 70.22 | 15.86 | [Model](https://paddle-slim-models.bj.bcebos.com/act/mobilenetv1_quant.tar) | - -- 测试环境:`骁龙865 4*A77 4*A55` - -说明: -- MobileNetV1模型源自[tensorflow/models](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz) - - ## 3. 自动压缩流程 #### 3.1 准备环境 diff --git a/demo/auto_compression/image_classification/eval.py b/demo/auto_compression/image_classification/eval.py index 4469fd08e4143908d37c8ace287f09d05518f17a..5d8a327aa1344354682cbe0ef59d5b1150e88008 100644 --- a/demo/auto_compression/image_classification/eval.py +++ b/demo/auto_compression/image_classification/eval.py @@ -1,3 +1,17 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import os import sys sys.path[0] = os.path.join( @@ -12,7 +26,6 @@ import paddle.nn as nn from paddle.io import Dataset, BatchSampler, DataLoader import imagenet_reader as reader from paddleslim.auto_compression.config_helpers import load_config as load_slim_config -from paddleslim.auto_compression import AutoCompression def argsparser(): @@ -23,22 +36,6 @@ def argsparser(): default=None, help="path of compression strategy config.", required=True) - parser.add_argument( - '--save_dir', - type=str, - default='output', - help="directory to save compressed model.") - return parser - - -# yapf: enable -def reader_wrapper(reader, input_name): - def gen(): - for i, data in enumerate(reader()): - imgs = np.float32([item[0] for item in data]) - yield {input_name: imgs} - - return gen def eval_reader(data_dir, batch_size): diff --git a/demo/auto_compression/image_classification/infer.py b/demo/auto_compression/image_classification/infer.py index 9a98f41633593f0bafd3de49ff645f55f8588526..88e4b82de2f13b448a6e9425174c2297421a1594 100644 --- a/demo/auto_compression/image_classification/infer.py +++ b/demo/auto_compression/image_classification/infer.py @@ -1,3 +1,17 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import sys import cv2 diff --git a/demo/auto_compression/image_classification/preprocess.py b/demo/auto_compression/image_classification/preprocess.py index 2d56789e7cc75998bad3189816cb3288743fa177..6ef16b4b2e926c8ba90a7ee101045e1d9498b916 100644 --- a/demo/auto_compression/image_classification/preprocess.py +++ b/demo/auto_compression/image_classification/preprocess.py @@ -1,4 +1,3 @@ -""" # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" from __future__ import absolute_import from __future__ import division diff --git a/demo/auto_compression/image_classification/run.py b/demo/auto_compression/image_classification/run.py index 666911b21eee53d63c5e9a237fa56238698b0f69..28c8fbf9244e0f5ae894ae4f16bfdd1709ba0a1a 100644 --- a/demo/auto_compression/image_classification/run.py +++ b/demo/auto_compression/image_classification/run.py @@ -1,3 +1,17 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 import sys
 sys.path[0] = os.path.join(
@@ -14,7 +28,6 @@ from paddle.io import Dataset, BatchSampler, DataLoader
 import imagenet_reader as reader
 from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
 from paddleslim.auto_compression import AutoCompression
-from utility import add_arguments, print_arguments
 
 
 def argsparser():
diff --git a/demo/auto_compression/image_classification/run_tf.sh b/demo/auto_compression/image_classification/run_tf.sh
deleted file mode 100644
index d4cf3e73a1f4a5c5ff0c54f4ed41b01d5e1d874e..0000000000000000000000000000000000000000
--- a/demo/auto_compression/image_classification/run_tf.sh
+++ /dev/null
@@ -1,14 +0,0 @@
-# 单卡启动
-export CUDA_VISIBLE_DEVICES=0
-python run.py \
-    --model_dir='inference_model_usex2paddle' \
-    --model_filename='model.pdmodel' \
-    --params_filename='model.pdiparams' \
-    --save_dir='./save_quant_mobilev1/' \
-    --batch_size=128 \
-    --config_path='./configs/mobilenetv1_qat_dis.yaml'\
-    --input_shape 224 224 3 \
-    --image_reader_type='tensorflow' \
-    --input_name "input" \
-    --data_dir='ILSVRC2012'
-
diff --git a/demo/auto_compression/pytorch_huggingface/infer.py b/demo/auto_compression/pytorch_huggingface/infer.py
index 0d95fc2a99a17cd56c9fbb40bfcdb3ea53015c50..d1db4bf478f4a460cf1f7dc2383006b936fd7385 100644
--- a/demo/auto_compression/pytorch_huggingface/infer.py
+++ b/demo/auto_compression/pytorch_huggingface/infer.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/demo/auto_compression/semantic_segmentation/run.py b/demo/auto_compression/semantic_segmentation/run.py
index 4e4706d27c98b53560452cd4042946fa38c1f07d..9a6eea1b25eae094211f0d3689f8e00944ff219d 100644
--- a/demo/auto_compression/semantic_segmentation/run.py
+++ b/demo/auto_compression/semantic_segmentation/run.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 import argparse
 import random
diff --git a/demo/auto_compression/tensorflow_mobilenet/README.md b/demo/auto_compression/tensorflow_mobilenet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7a2dca214647fcf5fb80577ab0bd5322a83d9055
--- /dev/null
+++ b/demo/auto_compression/tensorflow_mobilenet/README.md
@@ -0,0 +1,100 @@
+# TensorFlow图像分类模型自动压缩示例
+
+目录:
+- [1. 简介](#1简介)
+- [2. Benchmark](#2Benchmark)
+- [3. 自动压缩流程](#自动压缩流程)
+  - [3.1 准备环境](#31-准备环境)
+  - [3.2 准备数据集](#32-准备数据集)
+  - [3.3 准备预测模型](#33-准备预测模型)
+  - [3.4 自动压缩并产出模型](#34-自动压缩并产出模型)
+- [4. 预测部署](#4预测部署)
+- [5. FAQ](#5FAQ)
+
+
+## 1. 简介
+飞桨模型转换工具[X2Paddle](https://github.com/PaddlePaddle/X2Paddle)支持将```Caffe/TensorFlow/ONNX/PyTorch```的模型一键转为飞桨(PaddlePaddle)的预测模型。借助X2Paddle的能力,PaddleSlim的自动压缩功能可方便地用于各种框架的推理模型。
+
+本示例将以[TensorFlow](https://github.com/tensorflow/tensorflow)框架的MobileNetV1模型为例,介绍如何自动压缩其他框架中的图像分类模型。本示例基于[tensorflow/models](https://github.com/tensorflow/models)开源库,先将TensorFlow框架模型转换为Paddle框架模型,再使用ACT自动压缩功能进行压缩。本示例使用的自动压缩策略为量化训练。
+
+## 2. Benchmark
+| 模型 | 策略 | Top-1 Acc | 耗时(ms) threads=1 | Inference模型 |
+|:------:|:------:|:------:|:------:|:------:|
+| MobileNetV1 | Base模型 | 71.0 | 30.45 | [Model](https://paddle-slim-models.bj.bcebos.com/act/mobilenetv1_inference_model_tf2paddle.tar) |
+| MobileNetV1 | 量化+蒸馏 | 70.22 | 15.86 | [Model](https://paddle-slim-models.bj.bcebos.com/act/mobilenetv1_quant.tar) |
+
+- 测试环境:`骁龙865 4*A77 4*A55`
+
+说明:
+- MobileNetV1模型源自[tensorflow/models](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz)
+
+## 3. 自动压缩流程
+
+#### 3.1 准备环境
+- PaddlePaddle >= 2.3 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装)
+- PaddleSlim develop版本
+- [X2Paddle](https://github.com/PaddlePaddle/X2Paddle) >= 1.3.6
+- opencv-python
+
+(1)安装paddlepaddle:
+```shell
+# CPU
+pip install paddlepaddle
+# GPU
+pip install paddlepaddle-gpu
+```
+
+(2)安装paddleslim:
+```shell
+git clone https://github.com/PaddlePaddle/PaddleSlim.git && cd PaddleSlim
+python setup.py install
+```
+
+(3)安装TensorFlow:
+```shell
+pip install tensorflow==1.14
+```
+
+(4)安装X2Paddle的1.3.6以上版本:
+```shell
+pip install x2paddle
+```
+
+#### 3.2 准备数据集
+本案例默认以ImageNet1k数据进行自动压缩实验。
+
+#### 3.3 准备预测模型
+
+(1)转换模型
+
+```
+x2paddle --framework=tensorflow --model=tf_model.pb --save_dir=pd_model
+```
+即可得到MobileNetV1模型的预测模型(`model.pdmodel` 和 `model.pdiparams`)。如想快速体验,可直接下载上方表格中MobileNetV1的[Base模型](https://paddle-slim-models.bj.bcebos.com/act/mobilenetv1_inference_model_tf2paddle.tar)。
+
+预测模型的格式为:`model.pdmodel` 和 `model.pdiparams`两个文件,带`pdmodel`后缀的是模型文件,带`pdiparams`后缀的是权重文件。
+
+#### 3.4 自动压缩并产出模型
+
+蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口```paddleslim.auto_compression.AutoCompression```对模型进行自动压缩。配置config文件中模型路径、蒸馏、量化和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。具体运行命令为:
+```
+# 单卡
+export CUDA_VISIBLE_DEVICES=0
+python run.py --config_path=./configs/mbv1_qat_dis.yaml --save_dir='./output/'
+```
+
+#### 3.5 测试模型精度
+
+使用eval.py脚本测试模型的精度(Top-1 Acc):
+```
+export CUDA_VISIBLE_DEVICES=0
+python eval.py --config_path=./configs/mbv1_qat_dis.yaml
+```
+
+## 4.预测部署
+
+#### 4.1 PaddleLite端侧部署
+PaddleLite端侧部署可参考:
+- [Paddle Lite部署](https://github.com/PaddlePaddle/PaddleClas/blob/develop/docs/zh_CN/inference_deployment/paddle_lite_deploy.md)
+
+## 5.FAQ
diff --git a/demo/auto_compression/tensorflow_mobilenet/configs/mbv1_qat_dis.yaml b/demo/auto_compression/tensorflow_mobilenet/configs/mbv1_qat_dis.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..359ac18d11f00bf4a06926f72a2df27257ae6be8
--- /dev/null
+++ b/demo/auto_compression/tensorflow_mobilenet/configs/mbv1_qat_dis.yaml
@@ -0,0 +1,62 @@
+Global:
+  input_name: input
+  model_dir: inference_model_usex2paddle
+  model_filename: model.pdmodel
+  params_filename: model.pdiparams
+  batch_size: 32
+  data_dir: ./ILSVRC2012
+Distillation:
+  alpha: 1.0
+  loss: l2
+  node:
+  - batch_norm_0.tmp_3
+  - batch_norm_1.tmp_3
+  - batch_norm_2.tmp_3
+  - batch_norm_3.tmp_3
+  - batch_norm_4.tmp_3
+  - batch_norm_5.tmp_3
+  - batch_norm_6.tmp_3
+  - batch_norm_7.tmp_3
+  - batch_norm_8.tmp_3
+  - batch_norm_9.tmp_3
+  - batch_norm_10.tmp_3
+  - batch_norm_11.tmp_3
+  - 
batch_norm_12.tmp_3 + - batch_norm_13.tmp_3 + - batch_norm_14.tmp_3 + - batch_norm_15.tmp_3 + - batch_norm_16.tmp_3 + - batch_norm_17.tmp_3 + - batch_norm_18.tmp_3 + - batch_norm_19.tmp_3 + - batch_norm_20.tmp_3 + - batch_norm_21.tmp_3 + - batch_norm_22.tmp_3 + - batch_norm_23.tmp_3 + - batch_norm_24.tmp_3 + - batch_norm_25.tmp_3 + - batch_norm_26.tmp_3 + - conv2d_42.tmp_1 + +Quantization: + use_pact: true + activation_bits: 8 + is_full_quantize: false + not_quant_pattern: + - skip_quant + quantize_op_types: + - conv2d + - depthwise_conv2d + weight_bits: 8 + activation_quantize_type: moving_average_abs_max + weight_quantize_type: channel_wise_abs_max + +TrainConfig: + epochs: 1000 + eval_iter: 1000 + learning_rate: 0.00001 + optimizer_builder: + optimizer: + type: SGD + weight_decay: 4.0e-05 + origin_metric: 0.71028 diff --git a/demo/auto_compression/tensorflow_mobilenet/eval.py b/demo/auto_compression/tensorflow_mobilenet/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..85e5fdaf504ce30c7ddb8d4013371da906006b7b --- /dev/null +++ b/demo/auto_compression/tensorflow_mobilenet/eval.py @@ -0,0 +1,108 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import argparse +import functools +from functools import partial + +import numpy as np +import paddle +import paddle.nn as nn +from paddle.io import DataLoader +from imagenet_reader import ImageNetDataset +from paddleslim.auto_compression.config_helpers import load_config as load_slim_config + + +def argsparser(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + '--config_path', + type=str, + default=None, + help="path of compression strategy config.", + required=True) + return parser + + +def eval_reader(data_dir, batch_size): + val_reader = ImageNetDataset(mode='val', data_dir=data_dir) + val_loader = DataLoader( + val_reader, + batch_size=global_config['batch_size'], + shuffle=False, + drop_last=False, + num_workers=0) + return val_loader + + +def eval(): + devices = paddle.device.get_device().split(':')[0] + places = paddle.device._convert_to_place(devices) + exe = paddle.static.Executor(places) + val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model( + global_config["model_dir"], + exe, + model_filename=global_config["model_filename"], + params_filename=global_config["params_filename"]) + print('Loaded model from: {}'.format(global_config["model_dir"])) + + val_reader = eval_reader(data_dir, batch_size=global_config['batch_size']) + image = paddle.static.data( + name=global_config['input_name'], + shape=[None, 224, 224, 3], + dtype='float32') + label = paddle.static.data(name='label', shape=[None, 1], dtype='int64') + results = [] + print('Evaluating... It will take a while. 
Please wait...') + for batch_id, (image, label) in enumerate(val_reader): + # top1_acc, top5_acc + image = np.array(image) + label = np.array(label).astype('int64') + pred = exe.run(val_program, + feed={feed_target_names[0]: image}, + fetch_list=fetch_targets) + pred = np.array(pred[0]) + label = np.array(label) + sort_array = pred.argsort(axis=1) + top_1_pred = sort_array[:, -1:][:, ::-1] + top_1 = np.mean(label == top_1_pred) + top_5_pred = sort_array[:, -5:][:, ::-1] + acc_num = 0 + for i in range(len(label)): + if label[i][0] in top_5_pred[i]: + acc_num += 1 + top_5 = float(acc_num) / len(label) + results.append([top_1, top_5]) + result = np.mean(np.array(results), axis=0) + return result[0] + + +def main(): + global global_config + all_config = load_slim_config(args.config_path) + assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}" + global_config = all_config["Global"] + global data_dir + data_dir = global_config['data_dir'] + result = eval() + print('Eval Top1:', result) + + +if __name__ == '__main__': + paddle.enable_static() + parser = argsparser() + args = parser.parse_args() + main() diff --git a/demo/auto_compression/image_classification/tf_imagenet_reader.py b/demo/auto_compression/tensorflow_mobilenet/imagenet_reader.py similarity index 66% rename from demo/auto_compression/image_classification/tf_imagenet_reader.py rename to demo/auto_compression/tensorflow_mobilenet/imagenet_reader.py index e2fb33de2b11c85d9b3c7ea6d38a168a482cd202..ed997581b337a77d36ef10a56def3a46aba7f095 100644 --- a/demo/auto_compression/image_classification/tf_imagenet_reader.py +++ b/demo/auto_compression/tensorflow_mobilenet/imagenet_reader.py @@ -128,54 +128,51 @@ def process_image(sample, mode, color_jitter, rotate): return [img] -def _reader_creator(file_list, - mode, - shuffle=False, - color_jitter=False, - rotate=False, - data_dir=DATA_DIR, - batch_size=1): - def reader(): - try: - with open(file_list) as flist: +class ImageNetDataset(Dataset): + def __init__(self, data_dir=DATA_DIR, mode='train'): + super(ImageNetDataset, self).__init__() + self.data_dir = data_dir + train_file_list = os.path.join(data_dir, 'train_list.txt') + val_file_list = os.path.join(data_dir, 'val_list.txt') + test_file_list = os.path.join(data_dir, 'test_list.txt') + self.mode = mode + if mode == 'train': + with open(train_file_list) as flist: full_lines = [line.strip() for line in flist] - if shuffle: - np.random.shuffle(full_lines) + np.random.shuffle(full_lines) lines = full_lines - for line in lines: - if mode == 'train' or mode == 'val': - img_path, label = line.split() - img_path = os.path.join(data_dir, img_path) - yield img_path, int(label) + 1 - elif mode == 'test': - img_path = os.path.join(data_dir, line) - yield [img_path] - except Exception as e: - print("Reader failed!\n{}".format(str(e))) - os._exit(1) - - mapper = functools.partial( - process_image, mode=mode, color_jitter=color_jitter, rotate=rotate) - - return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE) - - -def train(data_dir=DATA_DIR): - file_list = os.path.join(data_dir, 'train_list.txt') - return _reader_creator( - file_list, - 'train', - shuffle=True, - color_jitter=False, - rotate=False, - data_dir=data_dir) - - -def val(data_dir=DATA_DIR): - file_list = os.path.join(data_dir, 'val_list.txt') - return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir) - - -def test(data_dir=DATA_DIR): - file_list = os.path.join(data_dir, 'test_list.txt') - return _reader_creator(file_list, 'test', 
shuffle=False, data_dir=data_dir) + self.data = [line.split() for line in lines] + else: + with open(val_file_list) as flist: + lines = [line.strip() for line in flist] + self.data = [line.split() for line in lines] + + def __getitem__(self, index): + sample = self.data[index] + data_path = os.path.join(self.data_dir, sample[0]) + if self.mode == 'train': + data, label = process_image( + [data_path, sample[1]], + mode='train', + color_jitter=False, + rotate=False) + return np.array(data).astype('float32'), ( + np.array([label]).astype('int64') + 1) + elif self.mode == 'val': + data, label = process_image( + [data_path, sample[1]], + mode='val', + color_jitter=False, + rotate=False) + return np.array(data).astype('float32'), ( + np.array([label]).astype('int64') + 1) + elif self.mode == 'test': + data = process_image( + [data_path, sample[1]], + mode='test', + color_jitter=False, + rotate=False) + return np.array(data).astype('float32') + + def __len__(self): + return len(self.data) diff --git a/demo/auto_compression/tensorflow_mobilenet/run.py b/demo/auto_compression/tensorflow_mobilenet/run.py new file mode 100644 index 0000000000000000000000000000000000000000..86345ec2071a95b3e630120f274cb1e6e4b99ba6 --- /dev/null +++ b/demo/auto_compression/tensorflow_mobilenet/run.py @@ -0,0 +1,142 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import os
+import sys
+import argparse
+import functools
+from functools import partial
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+from paddle.io import DataLoader
+from imagenet_reader import ImageNetDataset
+from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
+from paddleslim.auto_compression import AutoCompression
+
+
+def argsparser():
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        '--config_path',
+        type=str,
+        default=None,
+        help="path of compression strategy config.",
+        required=True)
+    parser.add_argument(
+        '--save_dir',
+        type=str,
+        default='output',
+        help="directory to save compressed model.")
+    return parser
+
+
+# yapf: enable
+def reader_wrapper(reader, input_name):
+    def gen():
+        for i, (imgs, label) in enumerate(reader()):
+            yield {input_name: imgs}
+
+    return gen
+
+
+def eval_reader(data_dir, batch_size):
+    val_reader = ImageNetDataset(mode='val', data_dir=data_dir)
+    val_loader = DataLoader(
+        val_reader,
+        batch_size=global_config['batch_size'],
+        shuffle=False,
+        drop_last=False,
+        num_workers=0)
+    return val_loader
+
+
+def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
+    val_loader = eval_reader(data_dir, batch_size=global_config['batch_size'])
+
+    results = []
+    for batch_id, (image, label) in enumerate(val_loader):
+        # top1_acc, top5_acc
+        if len(test_feed_names) == 1:
+            image = np.array(image)
+            label = np.array(label).astype('int64')
+            pred = exe.run(compiled_test_program,
+                           feed={test_feed_names[0]: image},
+                           fetch_list=test_fetch_list)
+            pred = np.array(pred[0])
+            label = np.array(label)
+            sort_array = pred.argsort(axis=1)
+            top_1_pred = sort_array[:, -1:][:, ::-1]
+            top_1 = np.mean(label == top_1_pred)
+            top_5_pred = sort_array[:, -5:][:, ::-1]
+            acc_num = 0
+            for i in range(len(label)):
+                if label[i][0] in top_5_pred[i]:
+                    acc_num += 1
+            top_5 = float(acc_num) / len(label)
+            results.append([top_1, top_5])
+        else:
+            # the "eval model" takes image and label as inputs and outputs top1 and top5 accuracy
+            image = np.array(image)
+            label = np.array(label).astype('int64')
+            result = exe.run(
+                compiled_test_program,
+                feed={test_feed_names[0]: image,
+                      test_feed_names[1]: label},
+                fetch_list=test_fetch_list)
+            result = [np.mean(r) for r in result]
+            results.append(result)
+        if batch_id % 50 == 0:
+            print('Eval iter: ', batch_id)
+    result = np.mean(np.array(results), axis=0)
+    return result[0]
+
+
+def main():
+    global global_config
+    all_config = load_slim_config(args.config_path)
+    assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}"
+    global_config = all_config["Global"]
+    global data_dir
+    data_dir = global_config['data_dir']
+
+    train_dataset = ImageNetDataset(mode='train', data_dir=data_dir)
+
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=global_config['batch_size'],
+        shuffle=True,
+        drop_last=True,
+        num_workers=0)
+    train_dataloader = reader_wrapper(train_loader, global_config['input_name'])
+
+    ac = AutoCompression(
+        model_dir=global_config['model_dir'],
+        model_filename=global_config['model_filename'],
+        params_filename=global_config['params_filename'],
+        save_dir=args.save_dir,
+        config=all_config,
+        train_dataloader=train_dataloader,
+        eval_callback=eval_function)
+
+    ac.compress()
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    parser = argsparser()
+    args = parser.parse_args()
+    print(args)
+    main()
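
The README's section 3.3 converts the TensorFlow frozen graph with x2paddle but does not show how to check the converted model before compression. A minimal sanity check might look like the sketch below; it reuses the same `paddle.static.load_inference_model` call pattern as this demo's eval.py, while the `pd_model/inference_model` prefix and the NHWC input shape are assumptions based on the x2paddle output layout and the `[None, 224, 224, 3]` shape used elsewhere in this demo.

```python
# Sanity-check sketch (not part of the diff above): load the x2paddle-converted
# Paddle model and run one random NHWC batch through it before compression.
import numpy as np
import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
# Adjust the prefix to wherever model.pdmodel / model.pdiparams actually landed.
program, feed_names, fetch_targets = paddle.static.load_inference_model(
    "pd_model/inference_model",
    exe,
    model_filename="model.pdmodel",
    params_filename="model.pdiparams")

# The converted TF MobileNetV1 keeps the NHWC input layout used by eval.py.
dummy = np.random.rand(1, 224, 224, 3).astype("float32")
out = exe.run(program, feed={feed_names[0]: dummy}, fetch_list=fetch_targets)
print("feed names:", feed_names, "output shape:", out[0].shape)
```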
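The `Distillation.node` list in mbv1_qat_dis.yaml hard-codes tensor names such as `batch_norm_0.tmp_3` and `conv2d_42.tmp_1`; these names are assigned during the x2paddle conversion and can shift if the model is re-exported. A rough way to rediscover them (not part of the PR, sketched under the assumption that the model lives in `inference_model_usex2paddle` as in the yaml's `Global` section) is to walk the converted program and print the outputs of conv and batch-norm ops:

```python
# Sketch: list candidate distillation nodes of the converted model.
import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
program, _, _ = paddle.static.load_inference_model(
    "inference_model_usex2paddle",
    exe,
    model_filename="model.pdmodel",
    params_filename="model.pdiparams")

for op in program.global_block().ops:
    if op.type in ("conv2d", "depthwise_conv2d", "batch_norm"):
        # Each op lists its output variable names; the *.tmp_* activation
        # outputs are the usual candidates for Distillation.node.
        print(op.type, op.output_arg_names)
```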