PaddlePaddle / PaddleSlim
Commit f275cefa (unverified)
Authored on Jun 29, 2022 by ceci3; committed via GitHub on Jun 29, 2022
add tf mbv1 demo (#1205)
Parent: e7ef0299

Showing 14 changed files with 516 additions and 99 deletions (+516 −99)
- demo/auto_compression/detection/keypoint_utils.py (+1 −1)
- demo/auto_compression/image_classification/README.md (+0 −14)
- demo/auto_compression/image_classification/eval.py (+14 −17)
- demo/auto_compression/image_classification/infer.py (+14 −0)
- demo/auto_compression/image_classification/preprocess.py (+0 −2)
- demo/auto_compression/image_classification/run.py (+14 −1)
- demo/auto_compression/image_classification/run_tf.sh (+0 −14)
- demo/auto_compression/pytorch_huggingface/infer.py (+1 −1)
- demo/auto_compression/semantic_segmentation/run.py (+14 −0)
- demo/auto_compression/tensorflow_mobilenet/README.md (+100 −0)
- demo/auto_compression/tensorflow_mobilenet/configs/mbv1_qat_dis.yaml (+62 −0)
- demo/auto_compression/tensorflow_mobilenet/eval.py (+108 −0)
- demo/auto_compression/tensorflow_mobilenet/imagenet_reader.py (+46 −49)
- demo/auto_compression/tensorflow_mobilenet/run.py (+142 −0)
demo/auto_compression/detection/keypoint_utils.py
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
demo/auto_compression/image_classification/README.md
@@ -50,20 +50,6 @@
 - Software: CUDA 11.2, cuDNN 8.0, TensorRT 8.4
 - Test configuration: batch_size: 1, image size: 224

-### TensorFlow MobileNetV1 model
-| Model | Strategy | Top-1 Acc | Latency (ms), threads=1 | Inference model |
-|:------:|:------:|:------:|:------:|:------:|
-| MobileNetV1 | Base model | 71.0 | 30.45 | [Model](https://paddle-slim-models.bj.bcebos.com/act/mobilenetv1_inference_model_tf2paddle.tar) |
-| MobileNetV1 | Quantization + distillation | 70.22 | 15.86 | [Model](https://paddle-slim-models.bj.bcebos.com/act/mobilenetv1_quant.tar) |
-
-- Test environment: `Snapdragon 865, 4*A77 + 4*A55`
-
-Notes:
-- The MobileNetV1 model comes from [tensorflow/models](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz)
-
 ## 3. Auto-compression workflow
 #### 3.1 Prepare the environment
...
demo/auto_compression/image_classification/eval.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
sys.path[0] = os.path.join(
...

@@ -12,7 +26,6 @@ import paddle.nn as nn
from paddle.io import Dataset, BatchSampler, DataLoader
import imagenet_reader as reader
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.auto_compression import AutoCompression


def argsparser():
...

@@ -23,22 +36,6 @@ def argsparser():
        default=None,
        help="path of compression strategy config.",
        required=True)
    parser.add_argument(
        '--save_dir',
        type=str,
        default='output',
        help="directory to save compressed model.")
    return parser


# yapf: enable
def reader_wrapper(reader, input_name):
    def gen():
        for i, data in enumerate(reader()):
            imgs = np.float32([item[0] for item in data])
            yield {input_name: imgs}

    return gen


def eval_reader(data_dir, batch_size):
...
demo/auto_compression/image_classification/infer.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import cv2
...
demo/auto_compression/image_classification/preprocess.py
"""
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
...

@@ -12,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import absolute_import
from __future__ import division
...
demo/auto_compression/image_classification/run.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
sys.path[0] = os.path.join(
...

@@ -14,7 +28,6 @@ from paddle.io import Dataset, BatchSampler, DataLoader
import imagenet_reader as reader
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.auto_compression import AutoCompression
from utility import add_arguments, print_arguments


def argsparser():
...
demo/auto_compression/image_classification/run_tf.sh
deleted file (100644 → 0)

# Launch on a single GPU
export CUDA_VISIBLE_DEVICES=0
python run.py \
    --model_dir='inference_model_usex2paddle' \
    --model_filename='model.pdmodel' \
    --params_filename='model.pdiparams' \
    --save_dir='./save_quant_mobilev1/' \
    --batch_size=128 \
    --config_path='./configs/mobilenetv1_qat_dis.yaml' \
    --input_shape 224 224 3 \
    --image_reader_type='tensorflow' \
    --input_name "input" \
    --data_dir='ILSVRC2012'
demo/auto_compression/pytorch_huggingface/infer.py
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
demo/auto_compression/semantic_segmentation/run.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import random
...
demo/auto_compression/tensorflow_mobilenet/README.md
new file (0 → 100644)

# Auto-compression example for TensorFlow image classification models

Contents:
- [1. Introduction](#1-introduction)
- [2. Benchmark](#2-benchmark)
- [3. Auto-compression workflow](#3-auto-compression-workflow)
  - [3.1 Prepare the environment](#31-prepare-the-environment)
  - [3.2 Prepare the dataset](#32-prepare-the-dataset)
  - [3.3 Prepare the inference model](#33-prepare-the-inference-model)
  - [3.4 Run auto compression and export the model](#34-run-auto-compression-and-export-the-model)
- [4. Inference deployment](#4-inference-deployment)
- [5. FAQ](#5-faq)
## 1. Introduction

The Paddle model conversion tool [X2Paddle](https://github.com/PaddlePaddle/X2Paddle) can convert ```Caffe/TensorFlow/ONNX/PyTorch``` models into PaddlePaddle inference models in a single step. Building on X2Paddle, PaddleSlim's auto compression toolkit (ACT) can easily be applied to inference models from other frameworks. This example takes the MobileNetV1 model from the [TensorFlow](https://github.com/tensorflow/tensorflow) framework to show how to automatically compress an image classification model from another framework: the model from the open-source [TensorFlow models](https://github.com/tensorflow/models) library is first converted into a Paddle model and then compressed with ACT. The compression strategy used in this example is quantization-aware training.
## 2. Benchmark

| Model | Strategy | Top-1 Acc | Latency (ms), threads=1 | Inference model |
|:------:|:------:|:------:|:------:|:------:|
| MobileNetV1 | Base model | 71.0 | 30.45 | [Model](https://paddle-slim-models.bj.bcebos.com/act/mobilenetv1_inference_model_tf2paddle.tar) |
| MobileNetV1 | Quantization + distillation | 70.22 | 15.86 | [Model](https://paddle-slim-models.bj.bcebos.com/act/mobilenetv1_quant.tar) |

- Test environment: `Snapdragon 865, 4*A77 + 4*A55`

Notes:
- The MobileNetV1 model comes from [tensorflow/models](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz); a sketch for fetching it is shown below.
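For convenience, here is a minimal, hedged sketch of downloading that release and locating the frozen graph that section 3.3 feeds to x2paddle. The frozen-graph file name inside the archive is an assumption, so check the printed listing:

```python
import tarfile
import urllib.request

# Download the MobileNetV1 release referenced above and list the .pb files it
# contains; the frozen graph (assumed to be mobilenet_v1_1.0_224_frozen.pb)
# is what x2paddle consumes in section 3.3.
url = ("http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/"
       "mobilenet_v1_1.0_224.tgz")
urllib.request.urlretrieve(url, "mobilenet_v1_1.0_224.tgz")

with tarfile.open("mobilenet_v1_1.0_224.tgz") as tar:
    tar.extractall("mobilenet_v1_tf")
    print([m.name for m in tar.getmembers() if m.name.endswith(".pb")])
```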
## 3. Auto-compression workflow

#### 3.1 Prepare the environment

- PaddlePaddle >= 2.3 (install from the [PaddlePaddle website](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html))
- PaddleSlim develop branch
- [X2Paddle](https://github.com/PaddlePaddle/X2Paddle) >= 1.3.6
- opencv-python

(1) Install paddlepaddle:
```shell
# CPU
pip install paddlepaddle
# GPU
pip install paddlepaddle-gpu
```

(2) Install paddleslim:
```shell
git clone https://github.com/PaddlePaddle/PaddleSlim.git
cd PaddleSlim
python setup.py install
```

(3) Install TensorFlow:
```shell
pip install tensorflow==1.14
```

(4) Install X2Paddle 1.3.6 or later:
```shell
pip install x2paddle
```

#### 3.2 Prepare the dataset

This example runs the auto-compression experiment on the ImageNet-1k dataset by default; the expected on-disk layout is sketched below.
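A minimal sketch (not part of the demo scripts) for sanity-checking that layout with the `imagenet_reader.py` shipped in this directory; it assumes the `./ILSVRC2012` path used in `configs/mbv1_qat_dis.yaml`:

```python
import os

from imagenet_reader import ImageNetDataset

# ImageNetDataset reads <data_dir>/train_list.txt and <data_dir>/val_list.txt,
# where every line is "relative/image/path label".
data_dir = "./ILSVRC2012"
for list_file in ("train_list.txt", "val_list.txt"):
    assert os.path.exists(os.path.join(data_dir, list_file)), (
        "missing %s under %s" % (list_file, data_dir))

val_dataset = ImageNetDataset(mode="val", data_dir=data_dir)
image, label = val_dataset[0]
# The reader returns float32 image arrays and int64 labels offset by +1.
print(image.shape, image.dtype, label)
```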
#### 3.3 Prepare the inference model

(1) Convert the model:
```
x2paddle --framework=tensorflow --model=tf_model.pb --save_dir=pd_model
```
This produces the Paddle inference model for MobileNetV1 (`model.pdmodel` and `model.pdiparams`). For a quick start, you can instead download the converted MobileNetV1 [base model](https://paddle-slim-models.bj.bcebos.com/act/mobilenetv1_inference_model_tf2paddle.tar) linked in the table above.

The inference model consists of two files: `model.pdmodel`, which describes the model structure, and `model.pdiparams`, which holds the weights. A quick check of the converted model is sketched below.
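An optional sketch (not part of the demo scripts) that loads the converted model and prints its feed/fetch names, so the `input_name` entry in `configs/mbv1_qat_dis.yaml` can be checked against the actual graph input; the directory name is the one assumed by that config:

```python
import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
# Same loading call that eval.py uses, pointed at the converted model directory.
program, feed_names, fetch_targets = paddle.static.load_inference_model(
    "inference_model_usex2paddle",
    exe,
    model_filename="model.pdmodel",
    params_filename="model.pdiparams")
print("feed:", feed_names)                      # expected to contain "input"
print("fetch:", [v.name for v in fetch_targets])
```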
#### 3.4 Run auto compression and export the model

The distillation + quantization auto-compression example is launched through the run.py script, which uses the ```paddleslim.auto_compression.AutoCompression``` API to compress the model. Configure the model paths and the distillation, quantization, and training parameters in the config file; once that is done, the model can be quantized and distilled. A condensed sketch of the same wiring is given after the command below. The command is:

```
# single GPU
export CUDA_VISIBLE_DEVICES=0
python run.py --config_path=./configs/mbv1_qat_dis.yaml --save_dir='./output/'
```
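For readers who want the whole wiring in one place, here is a condensed, hedged sketch of what run.py (shown in full later in this commit) does with that config. It is a sketch, not a replacement for the script, and it assumes the converted model directory and the `./ILSVRC2012` layout named in the YAML:

```python
import numpy as np
import paddle
from paddle.io import DataLoader

from imagenet_reader import ImageNetDataset
from paddleslim.auto_compression import AutoCompression
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config

paddle.enable_static()

all_config = load_slim_config("./configs/mbv1_qat_dis.yaml")
global_config = all_config["Global"]
data_dir = global_config["data_dir"]

train_loader = DataLoader(
    ImageNetDataset(mode="train", data_dir=data_dir),
    batch_size=global_config["batch_size"],
    shuffle=True,
    drop_last=True)


def train_dataloader():
    # AutoCompression consumes a generator of feed dicts keyed by the model's
    # input name ("input" for this TF-converted MobileNetV1).
    for images, _ in train_loader:
        yield {global_config["input_name"]: np.array(images)}


def eval_function(exe, program, feed_names, fetch_list):
    # Top-1 accuracy on the validation split, mirroring run.py;
    # origin_metric in the YAML gives the expected value of this metric
    # on the uncompressed model.
    val_loader = DataLoader(
        ImageNetDataset(mode="val", data_dir=data_dir),
        batch_size=global_config["batch_size"])
    accs = []
    for images, labels in val_loader:
        pred = exe.run(program,
                       feed={feed_names[0]: np.array(images)},
                       fetch_list=fetch_list)[0]
        top1 = pred.argsort(axis=1)[:, -1:]
        accs.append(np.mean(np.array(labels) == top1))
    return float(np.mean(accs))


ac = AutoCompression(
    model_dir=global_config["model_dir"],
    model_filename=global_config["model_filename"],
    params_filename=global_config["params_filename"],
    save_dir="./output/",
    config=all_config,
    train_dataloader=train_dataloader,
    eval_callback=eval_function)
ac.compress()
```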
#### 3.5 Evaluate model accuracy

Use the eval.py script to obtain the model's Top-1 accuracy:
```
export CUDA_VISIBLE_DEVICES=0
python eval.py --config_path=./configs/mbv1_qat_dis.yaml
```

## 4. Inference deployment

#### 4.1 Deployment with Paddle Lite

For on-device deployment with Paddle Lite, refer to:
- [Paddle Lite deployment](https://github.com/PaddlePaddle/PaddleClas/blob/develop/docs/zh_CN/inference_deployment/paddle_lite_deploy.md)

## 5. FAQ
demo/auto_compression/tensorflow_mobilenet/configs/mbv1_qat_dis.yaml
new file (0 → 100644)

Global:
  input_name: input
  model_dir: inference_model_usex2paddle
  model_filename: model.pdmodel
  params_filename: model.pdiparams
  batch_size: 32
  data_dir: ./ILSVRC2012

Distillation:
  alpha: 1.0
  loss: l2
  node:
  - batch_norm_0.tmp_3
  - batch_norm_1.tmp_3
  - batch_norm_2.tmp_3
  - batch_norm_3.tmp_3
  - batch_norm_4.tmp_3
  - batch_norm_5.tmp_3
  - batch_norm_6.tmp_3
  - batch_norm_7.tmp_3
  - batch_norm_8.tmp_3
  - batch_norm_9.tmp_3
  - batch_norm_10.tmp_3
  - batch_norm_11.tmp_3
  - batch_norm_12.tmp_3
  - batch_norm_13.tmp_3
  - batch_norm_14.tmp_3
  - batch_norm_15.tmp_3
  - batch_norm_16.tmp_3
  - batch_norm_17.tmp_3
  - batch_norm_18.tmp_3
  - batch_norm_19.tmp_3
  - batch_norm_20.tmp_3
  - batch_norm_21.tmp_3
  - batch_norm_22.tmp_3
  - batch_norm_23.tmp_3
  - batch_norm_24.tmp_3
  - batch_norm_25.tmp_3
  - batch_norm_26.tmp_3
  - conv2d_42.tmp_1

Quantization:
  use_pact: true
  activation_bits: 8
  is_full_quantize: false
  not_quant_pattern:
  - skip_quant
  quantize_op_types:
  - conv2d
  - depthwise_conv2d
  weight_bits: 8
  activation_quantize_type: moving_average_abs_max
  weight_quantize_type: channel_wise_abs_max

TrainConfig:
  epochs: 1000
  eval_iter: 1000
  learning_rate: 0.00001
  optimizer_builder:
    optimizer:
      type: SGD
    weight_decay: 4.0e-05
  origin_metric: 0.71028
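When adapting this config to a different converted model, the `node:` list for distillation must name tensors that actually exist in the new graph. A small, hedged sketch for listing candidate outputs; it relies only on Paddle's static-graph API and the model files named in `Global`:

```python
import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
program, _, _ = paddle.static.load_inference_model(
    "inference_model_usex2paddle",
    exe,
    model_filename="model.pdmodel",
    params_filename="model.pdiparams")

# Print the outputs of batch_norm/conv ops; names such as batch_norm_0.tmp_3
# and conv2d_42.tmp_1 in the YAML above come from this kind of listing.
for op in program.global_block().ops:
    if op.type in ("batch_norm", "conv2d", "depthwise_conv2d"):
        print(op.type, op.output_arg_names)
```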
demo/auto_compression/tensorflow_mobilenet/eval.py
new file (0 → 100644)

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import argparse
import functools
from functools import partial

import numpy as np
import paddle
import paddle.nn as nn
from paddle.io import DataLoader
from imagenet_reader import ImageNetDataset
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config


def argsparser():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--config_path',
        type=str,
        default=None,
        help="path of compression strategy config.",
        required=True)
    return parser


def eval_reader(data_dir, batch_size):
    val_reader = ImageNetDataset(mode='val', data_dir=data_dir)
    val_loader = DataLoader(
        val_reader,
        batch_size=global_config['batch_size'],
        shuffle=False,
        drop_last=False,
        num_workers=0)
    return val_loader


def eval():
    devices = paddle.device.get_device().split(':')[0]
    places = paddle.device._convert_to_place(devices)
    exe = paddle.static.Executor(places)
    val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
        global_config["model_dir"],
        exe,
        model_filename=global_config["model_filename"],
        params_filename=global_config["params_filename"])
    print('Loaded model from: {}'.format(global_config["model_dir"]))

    val_reader = eval_reader(data_dir, batch_size=global_config['batch_size'])
    image = paddle.static.data(
        name=global_config['input_name'],
        shape=[None, 224, 224, 3],
        dtype='float32')
    label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')

    results = []
    print('Evaluating... It will take a while. Please wait...')
    for batch_id, (image, label) in enumerate(val_reader):
        # top1_acc, top5_acc
        image = np.array(image)
        label = np.array(label).astype('int64')
        pred = exe.run(val_program,
                       feed={feed_target_names[0]: image},
                       fetch_list=fetch_targets)
        pred = np.array(pred[0])
        label = np.array(label)
        sort_array = pred.argsort(axis=1)
        top_1_pred = sort_array[:, -1:][:, ::-1]
        top_1 = np.mean(label == top_1_pred)
        top_5_pred = sort_array[:, -5:][:, ::-1]
        acc_num = 0
        for i in range(len(label)):
            if label[i][0] in top_5_pred[i]:
                acc_num += 1
        top_5 = float(acc_num) / len(label)
        results.append([top_1, top_5])
    result = np.mean(np.array(results), axis=0)
    return result[0]


def main():
    global global_config
    all_config = load_slim_config(args.config_path)
    assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}"
    global_config = all_config["Global"]

    global data_dir
    data_dir = global_config['data_dir']

    result = eval()
    print('Eval Top1:', result)


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    args = parser.parse_args()
    main()
demo/auto_compression/image_classification/tf_imagenet_reader.py → demo/auto_compression/tensorflow_mobilenet/imagenet_reader.py
@@ -128,54 +128,51 @@ def process_image(sample, mode, color_jitter, rotate):
     return [img]


-def _reader_creator(file_list,
-                    mode,
-                    shuffle=False,
-                    color_jitter=False,
-                    rotate=False,
-                    data_dir=DATA_DIR,
-                    batch_size=1):
-    def reader():
-        try:
-            with open(file_list) as flist:
+class ImageNetDataset(Dataset):
+    def __init__(self, data_dir=DATA_DIR, mode='train'):
+        super(ImageNetDataset, self).__init__()
+        self.data_dir = data_dir
+        train_file_list = os.path.join(data_dir, 'train_list.txt')
+        val_file_list = os.path.join(data_dir, 'val_list.txt')
+        test_file_list = os.path.join(data_dir, 'test_list.txt')
+        self.mode = mode
+        if mode == 'train':
+            with open(train_file_list) as flist:
                 full_lines = [line.strip() for line in flist]
-                if shuffle:
-                    np.random.shuffle(full_lines)
+                np.random.shuffle(full_lines)
                 lines = full_lines
-                for line in lines:
-                    if mode == 'train' or mode == 'val':
-                        img_path, label = line.split()
-                        img_path = os.path.join(data_dir, img_path)
-                        yield img_path, int(label) + 1
-                    elif mode == 'test':
-                        img_path = os.path.join(data_dir, line)
-                        yield [img_path]
-        except Exception as e:
-            print("Reader failed!\n{}".format(str(e)))
-            os._exit(1)
-
-    mapper = functools.partial(
-        process_image, mode=mode, color_jitter=color_jitter, rotate=rotate)
-
-    return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
-
-
-def train(data_dir=DATA_DIR):
-    file_list = os.path.join(data_dir, 'train_list.txt')
-    return _reader_creator(
-        file_list,
-        'train',
-        shuffle=True,
-        color_jitter=False,
-        rotate=False,
-        data_dir=data_dir)
-
-
-def val(data_dir=DATA_DIR):
-    file_list = os.path.join(data_dir, 'val_list.txt')
-    return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir)
-
-
-def test(data_dir=DATA_DIR):
-    file_list = os.path.join(data_dir, 'test_list.txt')
-    return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir)
+                self.data = [line.split() for line in lines]
+        else:
+            with open(val_file_list) as flist:
+                lines = [line.strip() for line in flist]
+                self.data = [line.split() for line in lines]
+
+    def __getitem__(self, index):
+        sample = self.data[index]
+        data_path = os.path.join(self.data_dir, sample[0])
+        if self.mode == 'train':
+            data, label = process_image(
+                [data_path, sample[1]],
+                mode='train',
+                color_jitter=False,
+                rotate=False)
+            return np.array(data).astype('float32'), (
+                np.array([label]).astype('int64') + 1)
+        elif self.mode == 'val':
+            data, label = process_image(
+                [data_path, sample[1]],
+                mode='val',
+                color_jitter=False,
+                rotate=False)
+            return np.array(data).astype('float32'), (
+                np.array([label]).astype('int64') + 1)
+        elif self.mode == 'test':
+            data = process_image(
+                [data_path, sample[1]],
+                mode='test',
+                color_jitter=False,
+                rotate=False)
+            return np.array(data).astype('float32')
+
+    def __len__(self):
+        return len(self.data)
demo/auto_compression/tensorflow_mobilenet/run.py
new file (0 → 100644)

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import argparse
import functools
from functools import partial

import numpy as np
import paddle
import paddle.nn as nn
from paddle.io import DataLoader
from imagenet_reader import ImageNetDataset
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.auto_compression import AutoCompression


def argsparser():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--config_path',
        type=str,
        default=None,
        help="path of compression strategy config.",
        required=True)
    parser.add_argument(
        '--save_dir',
        type=str,
        default='output',
        help="directory to save compressed model.")
    return parser


# yapf: enable
def reader_wrapper(reader, input_name):
    def gen():
        for i, (imgs, label) in enumerate(reader()):
            yield {input_name: imgs}

    return gen


def eval_reader(data_dir, batch_size):
    val_reader = ImageNetDataset(mode='val', data_dir=data_dir)
    val_loader = DataLoader(
        val_reader,
        batch_size=global_config['batch_size'],
        shuffle=False,
        drop_last=False,
        num_workers=0)
    return val_loader


def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
    val_loader = eval_reader(data_dir, batch_size=global_config['batch_size'])
    results = []
    for batch_id, (image, label) in enumerate(val_loader):
        # top1_acc, top5_acc
        if len(test_feed_names) == 1:
            image = np.array(image)
            label = np.array(label).astype('int64')
            pred = exe.run(compiled_test_program,
                           feed={test_feed_names[0]: image},
                           fetch_list=test_fetch_list)
            pred = np.array(pred[0])
            label = np.array(label)
            sort_array = pred.argsort(axis=1)
            top_1_pred = sort_array[:, -1:][:, ::-1]
            top_1 = np.mean(label == top_1_pred)
            top_5_pred = sort_array[:, -5:][:, ::-1]
            acc_num = 0
            for i in range(len(label)):
                if label[i][0] in top_5_pred[i]:
                    acc_num += 1
            top_5 = float(acc_num) / len(label)
            results.append([top_1, top_5])
        else:
            # eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
            image = np.array(image)
            label = np.array(label).astype('int64')
            result = exe.run(compiled_test_program,
                             feed={
                                 test_feed_names[0]: image,
                                 test_feed_names[1]: label
                             },
                             fetch_list=test_fetch_list)
            result = [np.mean(r) for r in result]
            results.append(result)
        if batch_id % 50 == 0:
            print('Eval iter: ', batch_id)
    result = np.mean(np.array(results), axis=0)
    return result[0]


def main():
    global global_config
    all_config = load_slim_config(args.config_path)
    assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}"
    global_config = all_config["Global"]

    global data_dir
    data_dir = global_config['data_dir']

    train_dataset = ImageNetDataset(mode='train', data_dir=data_dir)
    train_loader = DataLoader(
        train_dataset,
        batch_size=global_config['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=0)
    train_dataloader = reader_wrapper(train_loader, global_config['input_name'])

    ac = AutoCompression(
        model_dir=global_config['model_dir'],
        model_filename=global_config['model_filename'],
        params_filename=global_config['params_filename'],
        save_dir=args.save_dir,
        config=all_config,
        train_dataloader=train_dataloader,
        eval_callback=eval_function)
    ac.compress()


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    args = parser.parse_args()
    print_arguments(args)
    main()