PaddlePaddle
PaddleClas
Commit 79cbd735
Authored Feb 01, 2023 by weixin_46524038; committed Feb 01, 2023 by cuicheng01

Aesthetic

Parent: 4fdcda7c
Showing 12 changed files with 345 additions and 8 deletions (+345 -8)
deploy/configs/practical_models/aesthetic_score_predictor/inference_aesthetic_score_predictor.yaml  +32 -0
deploy/images/practical/aesthetic_score_predictor/Highscore.png  +0 -0
deploy/images/practical/aesthetic_score_predictor/Lowscore.png  +0 -0
deploy/python/postprocess.py  +16 -0
deploy/python/predict_cls.py  +9 -5
docs/zh_CN/models/practical_models/CLIP_large_patch14_224_aesthetic.md  +135 -0
ppcls/arch/backbone/__init__.py  +1 -1
ppcls/arch/backbone/model_zoo/foundation_vit.py  +3 -1
ppcls/arch/backbone/variant_models/foundation_vit_variant.py  +52 -0
ppcls/configs/practical_models/CLIP_large_patch14_224_aesthetic.yaml  +78 -0
ppcls/data/postprocess/__init__.py  +1 -1
ppcls/data/postprocess/scoreoutput.py  +18 -0
deploy/configs/practical_models/aesthetic_score_predictor/inference_aesthetic_score_predictor.yaml
0 → 100644 @ 79cbd735

Global:
  infer_imgs: "./images/practical/aesthetic_score_predictor/Highscore.png"
  inference_model_dir: "./models/CLIP_large_patch14_224_aesthetic_infer/"
  batch_size: 1
  use_gpu: True
  enable_mkldnn: False
  cpu_num_threads: 10
  enable_benchmark: True
  use_fp16: False
  ir_optim: True
  use_tensorrt: False
  gpu_mem: 8000
  enable_profile: False

PreProcess:
  transform_ops:
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 0.00392157
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
        channel_num: 3
    - ToCHWImage:

PostProcess:
  main_indicator: ScoreOutput
  ScoreOutput:
    decimal_places: 2
\ No newline at end of file
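The `ResizeImage` and `CropImage` entries above describe a resize followed by a crop. The following is a minimal sketch of the arithmetic involved, assuming that `resize_short` scales the shorter side to the given length while keeping the aspect ratio and that the crop is centered. The helper names `resize_short_size` and `center_crop_box` are hypothetical and are not part of PaddleClas.

```python
def resize_short_size(width, height, resize_short=256):
    """Return (new_width, new_height) after scaling the shorter side to resize_short.

    Assumption: the aspect ratio is preserved and sizes are rounded to integers.
    """
    scale = resize_short / min(width, height)
    return int(round(width * scale)), int(round(height * scale))


def center_crop_box(width, height, size=224):
    """Return (left, top, right, bottom) of a centered size x size crop (assumed centering)."""
    left = (width - size) // 2
    top = (height - size) // 2
    return left, top, left + size, top + size


# A 1024x768 input becomes 341x256, then a centered 224x224 window is cut out.
resized = resize_short_size(1024, 768)
box = center_crop_box(*resized)
print(resized, box)
```

Under these assumptions the network always receives a 224x224 input, which matches `image_shape: [3, 224, 224]` in the training config of this commit.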
deploy/images/practical/aesthetic_score_predictor/Highscore.png
0 → 100644 @ 79cbd735
471.9 KB

deploy/images/practical/aesthetic_score_predictor/Lowscore.png
0 → 100644 @ 79cbd735
321.0 KB
deploy/python/postprocess.py @ 79cbd735

@@ -17,6 +17,7 @@ import copy
 import shutil
 from functools import partial
 import importlib
+import numpy
 import numpy as np
 import paddle
 import paddle.nn.functional as F
...
@@ -147,6 +148,21 @@ class ThreshOutput(object):
         return multi_classification(x)
 
 
+class ScoreOutput(object):
+    def __init__(self, decimal_places):
+        self.decimal_places = decimal_places
+
+    def __call__(self, x, file_names=None):
+        y = []
+        for idx, probs in enumerate(x):
+            score = np.around(x[idx], self.decimal_places)
+            result = {"scores": score}
+            if file_names is not None:
+                result["file_name"] = file_names[idx]
+            y.append(result)
+        return y
+
+
 class Topk(object):
     def __init__(self, topk=1, class_id_map_file=None, delimiter=None):
         assert isinstance(topk, (int, ))
...
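To illustrate the shape of the result that `ScoreOutput` produces, here is a NumPy-free stand-in. It is only a sketch: the function `score_output` is hypothetical, it uses Python's built-in `round` on nested lists instead of `np.around` on arrays, and it is not the code from this commit.

```python
def score_output(batch, decimal_places=2, file_names=None):
    """Round every score in each row and attach the file name when given.

    Stand-in for the ScoreOutput class in the diff; operates on plain lists.
    """
    results = []
    for idx, row in enumerate(batch):
        result = {"scores": [round(float(v), decimal_places) for v in row]}
        if file_names is not None:
            result["file_name"] = file_names[idx]
        results.append(result)
    return results


# One image in the batch, one raw score per image.
out = score_output([[7.8479]], decimal_places=2, file_names=["Highscore.png"])
print(out)
```

Each batch element yields one dictionary with a `scores` entry and, optionally, a `file_name` entry, which is the structure the updated `predict_cls.py` relies on when no `class_ids` or `label_names` are present.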
deploy/python/predict_cls.py @ 79cbd735

@@ -142,13 +142,17 @@ def main(config):
                         print("{}:\t{}".format(filename, result_dict))
                     else:
                         filename = batch_names[number]
-                        clas_ids = result_dict["class_ids"]
                         scores_str = "[{}]".format(", ".join("{:.2f}".format(
                             r) for r in result_dict["scores"]))
-                        label_names = result_dict["label_names"]
-                        print(
-                            "{}:\tclass id(s): {}, score(s): {}, label_name(s): {}".
-                            format(filename, clas_ids, scores_str, label_names))
+                        if "class_ids" in result_dict and "label_names" in result_dict:
+                            clas_ids = result_dict["class_ids"]
+                            label_names = result_dict["label_names"]
+                            print(
+                                "{}:\tclass id(s): {}, score(s): {}, label_name(s): {}".
+                                format(filename, clas_ids, scores_str, label_names))
+                        else:
+                            print("{}:\tscore(s): {}".format(filename,
+                                                             scores_str))
                 batch_imgs = []
                 batch_names = []
     if cls_predictor.benchmark:
...
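The branch added in this hunk can be isolated as a small pure function. The sketch below mirrors the format strings from the diff but returns the line instead of printing it; the function name `format_result` is hypothetical and does not exist in `predict_cls.py`.

```python
def format_result(filename, result_dict):
    """Build the output line for one image, with or without class information."""
    scores_str = "[{}]".format(", ".join(
        "{:.2f}".format(r) for r in result_dict["scores"]))
    if "class_ids" in result_dict and "label_names" in result_dict:
        # Classification-style result: class ids, scores and label names.
        return "{}:\tclass id(s): {}, score(s): {}, label_name(s): {}".format(
            filename, result_dict["class_ids"], scores_str,
            result_dict["label_names"])
    # Score-only result, such as the output of ScoreOutput.
    return "{}:\tscore(s): {}".format(filename, scores_str)


score_line = format_result("Highscore.png", {"scores": [7.85]})
class_line = format_result(
    "a.jpg", {"class_ids": [8], "scores": [0.9], "label_names": ["hen"]})
print(score_line)
print(class_line)
```

Before this change the code read `result_dict["class_ids"]` unconditionally, so a score-only result would have raised a `KeyError`; the guard lets both kinds of post-processors share the same printing path.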
docs/zh_CN/models/practical_models/CLIP_large_patch14_224_aesthetic.md
0 → 100644 @ 79cbd735

# Aesthetic Score Prediction Model

------

## Contents

- [1. Model and application scenarios](#1)
- [2. Quick start](#2)
    - [2.1 Install paddlepaddle](#2.1)
    - [2.2 Install paddleclas](#2.2)
- [3. Model prediction](#3)
    - [3.1 Model prediction](#3.1)
        - [3.1.1 Prediction with the training engine](#3.1.1)
        - [3.1.2 Prediction with the inference engine](#3.1.2)
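<a name="1"></a>

## 1. Model and application scenarios

This example shows how to use PaddleClas to build an image aesthetic scoring model on top of the CLIP_large_patch14_224 network. The model scores images automatically: the better an image matches human aesthetic preferences, the higher its score, and the worse it matches, the lower its score. It can be used in scenarios such as recommendation and search. This example is based on the [aesthetic score predictor](https://github.com/christophschuhmann/improved-aesthetic-predictor), and the weights were converted from the official weights. A high-scoring image and a low-scoring image are shown below:

<center><img src='https://user-images.githubusercontent.com/94225063/215502324-e22b72dc-bb6a-42fa-8f9d-d1069b74c6b7.jpg' width=800></center>

As the figure shows, the image on the left matches human aesthetic preferences better than the image on the right.

**Note:**

* The images are taken from [this page](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html); scores range from 1.00 to 8.00.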
<a name="2"></a>

## 2. Quick start

<a name="2.1"></a>

### 2.1 Install paddlepaddle

- If your machine has CUDA 9 or CUDA 10 installed, run the following command to install:

```bash
python3 -m pip install paddlepaddle-gpu -i https://mirror.baidu.com/pypi/simple
```

- If your machine only has a CPU, run the following command to install:

```bash
python3 -m pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple
```

For other version requirements, follow the instructions in the [PaddlePaddle installation guide](https://www.paddlepaddle.org.cn/install/quick).
<a name="2.2"></a>

### 2.2 Install paddleclas

Make sure you have cloned this project, then build and install it locally:

```
cd path/to/PaddleClas
# Build and install with the following command
python3 setup.py install
```
<a name="3"></a>

## 3. Model prediction

<a name="3.1"></a>

### 3.1 Model prediction

<a name="3.1.1"></a>

### 3.1.1 Prediction with the training engine

Load the pretrained model and run prediction. A complete example is provided in `tools/infer.py` in the repository; running the following command is enough to complete prediction:

```python
python3 tools/infer.py \
    -c ./ppcls/configs/practical_models/CLIP_large_patch14_224_aesthetic.yaml
```

The output is as follows:

```
[{'scores': array([7.85], dtype=float32), 'file_name': 'deploy/images/practical/aesthetic_score_predictor/Highscore.png'}]
```

**Note:**

* By default, `deploy/images/practical/aesthetic_score_predictor/Highscore.png` is scored. To score a different image, add the option `-o Infer.infer_imgs=xxx`.
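The `-o Infer.infer_imgs=xxx` option overrides one entry of the nested YAML config by its dotted path. The sketch below shows one way such an override can be applied to a plain dictionary. It is an illustration under assumptions, not the PaddleClas implementation: `apply_override` is a hypothetical helper, and it stores the value as a string without any type conversion.

```python
import copy


def apply_override(config, override):
    """Return a copy of config with one 'Dotted.key=value' override applied."""
    key_path, _, raw_value = override.partition("=")
    keys = key_path.split(".")
    updated = copy.deepcopy(config)  # leave the caller's config untouched
    node = updated
    for key in keys[:-1]:
        node = node.setdefault(key, {})
    node[keys[-1]] = raw_value  # kept as a string in this sketch
    return updated


config = {
    "Infer": {
        "infer_imgs": "deploy/images/practical/aesthetic_score_predictor/Highscore.png",
        "batch_size": 1,
    }
}
new_config = apply_override(
    config,
    "Infer.infer_imgs=deploy/images/practical/aesthetic_score_predictor/Lowscore.png")
print(new_config["Infer"]["infer_imgs"])
```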
<a name="3.1.2"></a>
### 3.1.2 Prediction with the inference engine

Paddle Inference is PaddlePaddle's native inference library. It targets server-side and cloud deployments and provides high-performance inference. Compared with predicting directly from the pretrained model, Paddle Inference can use MKLDNN, CUDNN, and TensorRT to accelerate prediction and achieve better inference performance. For more information about the Paddle Inference engine, see the [Paddle Inference tutorial](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/infer/inference/inference_cn.html).

Download the inference model directly:

```
cd deploy/models
# Download and extract the inference model
wget https://paddleclas.bj.bcebos.com/models/practical/inference/CLIP_large_patch14_224_aesthetic_infer.tar && tar -xf CLIP_large_patch14_224_aesthetic_infer.tar
```

After extraction, the `models` folder should have the following structure:

```
├── CLIP_large_patch14_224_aesthetic_infer
│   ├── inference.pdiparams
│   ├── inference.pdiparams.info
│   └── inference.pdmodel
```

Once the inference model is available, run prediction with the inference engine.

Return to the `deploy` directory:

```
cd ../
```

Run the following commands to compute the aesthetic score of the image `./images/practical/aesthetic_score_predictor/Highscore.png`.

```shell
# Use the following command to predict on GPU
python3.7 python/predict_cls.py -c ./configs/practical_models/aesthetic_score_predictor/inference_aesthetic_score_predictor.yaml
# Use the following command to predict on CPU
python3.7 python/predict_cls.py -c ./configs/practical_models/aesthetic_score_predictor/inference_aesthetic_score_predictor.yaml -o Global.use_gpu=False
```

The output is as follows.

```
Highscore.png: score(s): [7.85]
```
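Before running `predict_cls.py`, it can help to confirm that the extracted model directory contains the three files listed in the tree above. The following is a small sketch using only the standard library; `missing_model_files` is a hypothetical helper, and the demo uses a temporary directory rather than a real download.

```python
import tempfile
from pathlib import Path

# File names taken from the directory tree shown in the document.
EXPECTED_FILES = (
    "inference.pdiparams",
    "inference.pdiparams.info",
    "inference.pdmodel",
)


def missing_model_files(model_dir):
    """Return the expected inference files that are absent from model_dir."""
    model_dir = Path(model_dir)
    return [name for name in EXPECTED_FILES if not (model_dir / name).is_file()]


# Demo on a temporary directory standing in for
# deploy/models/CLIP_large_patch14_224_aesthetic_infer.
with tempfile.TemporaryDirectory() as tmp:
    missing_before = missing_model_files(tmp)
    for name in EXPECTED_FILES[:2]:
        (Path(tmp) / name).touch()
    missing_after = missing_model_files(tmp)

print(missing_before, missing_after)
```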
ppcls/arch/backbone/__init__.py @ 79cbd735

@@ -38,7 +38,6 @@ from .model_zoo.dpn import DPN68, DPN92, DPN98, DPN107, DPN131
 from .model_zoo.dsnet import DSNet_tiny, DSNet_small, DSNet_base
 from .model_zoo.densenet import DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseNet264
 from .model_zoo.efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7, EfficientNetB0_small
-from .model_zoo.efficientnet_v2 import EfficientNetV2_S
 from .model_zoo.resnest import ResNeSt50_fast_1s1x64d, ResNeSt50, ResNeSt101, ResNeSt200, ResNeSt269
 from .model_zoo.googlenet import GoogLeNet
 from .model_zoo.mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75, MobileNetV2, MobileNetV2_x1_5, MobileNetV2_x2_0
...
@@ -81,6 +80,7 @@ from .variant_models.vgg_variant import VGG19Sigmoid
 from .variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh
 from .variant_models.pp_lcnetv2_variant import PPLCNetV2_base_ShiTu
 from .variant_models.efficientnet_variant import EfficientNetB3_watermark
+from .variant_models.foundation_vit_variant import CLIP_large_patch14_224_aesthetic
 from .model_zoo.adaface_ir_net import AdaFace_IR_18, AdaFace_IR_34, AdaFace_IR_50, AdaFace_IR_101, AdaFace_IR_152, AdaFace_IR_SE_50, AdaFace_IR_SE_101, AdaFace_IR_SE_152, AdaFace_IR_SE_200
 from .model_zoo.wideresnet import WideResNet
 from .model_zoo.uniformer import UniFormer_small, UniFormer_small_plus, UniFormer_small_plus_dim64, UniFormer_base, UniFormer_base_ls
...
ppcls/arch/backbone/model_zoo/foundation_vit.py @ 79cbd735

@@ -23,6 +23,8 @@ import paddle.nn as nn
 import sys
 from paddle.nn.initializer import TruncatedNormal, Constant, Normal
+from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
 
 MODEL_URLS = {
     "CLIP_small_patch16_224": None,
     "CLIP_base_patch32_224": None,
...
ppcls/arch/backbone/variant_models/foundation_vit_variant.py
0 → 100644 @ 79cbd735

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ..model_zoo.foundation_vit import CLIP_large_patch14_224, _load_pretrained

MODEL_URLS = {
    "CLIP_large_patch14_224_aesthetic":
    "https://paddleclas.bj.bcebos.com/models/practical/pretrained/CLIP_large_patch14_224_aesthetic_pretrained.pdparams"
}

__all__ = list(MODEL_URLS.keys())


class MLP(nn.Layer):
    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1), nn.Linear(64, 16), nn.Linear(16, 1))

    def forward(self, x):
        return self.layers(x)


class Aesthetic_Score_Predictor(nn.Layer):
    def __init__(self):
        super().__init__()
        self.model = CLIP_large_patch14_224()
        self.fc_head = nn.Linear(1024, 768, bias_attr=False)
        self.mlp = MLP(768)

    def forward(self, x):
        x = self.model(x)
        x = x[:, 0, :]
        x = self.fc_head(x)
        x = F.normalize(x, p=2, axis=-1)
        x = self.mlp(x)
        return x


def CLIP_large_patch14_224_aesthetic(pretrained=False, use_ssld=False,
                                     **kwargs):
    model = Aesthetic_Score_Predictor()
    _load_pretrained(pretrained, model,
                     MODEL_URLS["CLIP_large_patch14_224_aesthetic"], use_ssld)
    return model
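The head added on top of the CLIP backbone takes the first token (`x[:, 0, :]`, 1024 features), projects it to 768 features with a bias-free linear layer, L2-normalizes it, and passes it through a stack of `Linear` and `Dropout` layers down to a single score. The following framework-free sketch reproduces the two pieces of arithmetic that are easy to check by hand: the L2 normalization and the parameter count implied by the layer sizes in the diff. The helper names are hypothetical, and the `eps` guard is an assumption about how a zero norm would be handled.

```python
import math


def l2_normalize(vector, eps=1e-12):
    """Scale a vector to unit Euclidean length (eps guards against a zero norm)."""
    norm = math.sqrt(sum(v * v for v in vector))
    return [v / max(norm, eps) for v in vector]


def linear_params(in_features, out_features, bias=True):
    """Number of weights (plus biases, if any) in a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


# Layer widths copied from the MLP in the diff: 768 -> 1024 -> 128 -> 64 -> 16 -> 1.
MLP_DIMS = [768, 1024, 128, 64, 16, 1]

fc_head_params = linear_params(1024, 768, bias=False)
mlp_params = sum(linear_params(i, o) for i, o in zip(MLP_DIMS, MLP_DIMS[1:]))
unit = l2_normalize([3.0, 4.0])

print(fc_head_params, mlp_params, unit)
```

Since the MLP in the diff contains only `Linear` and `Dropout` layers and no activation functions, it acts as a single affine map of the normalized embedding whenever dropout is disabled, as it is in evaluation mode.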
ppcls/configs/practical_models/CLIP_large_patch14_224_aesthetic.yaml
0 → 100644 @ 79cbd735

# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 50
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  # training model under @to_static
  to_static: False
  use_dali: False

# model architecture
Arch:
  name: CLIP_large_patch14_224_aesthetic
  pretrained: True

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/
      cls_label_path: ./dataset/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''

    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: False

Infer:
  infer_imgs: deploy/images/practical/aesthetic_score_predictor/Highscore.png
  batch_size: 1
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: ScoreOutput
    decimal_places: 2

Metric:
  Eval:
    - TopkAcc:
        topk: [1, 2]
\ No newline at end of file
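The `NormalizeImage` entries use `scale: 1.0/255.0` here and `scale: 0.00392157` in the deploy config; both denote the same factor. The sketch below works through the per-channel arithmetic, assuming the usual form `(pixel * scale - mean) / std`. The helpers `parse_scale` and `normalize_pixel` are hypothetical, and the string parsing is a simplified stand-in for however the framework evaluates the `1.0/255.0` expression.

```python
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def parse_scale(scale):
    """Accept a number or a simple 'a/b' string and return the scale as a float."""
    if isinstance(scale, str):
        numerator, _, denominator = scale.partition("/")
        if denominator:
            return float(numerator) / float(denominator)
        return float(numerator)
    return float(scale)


def normalize_pixel(rgb, scale):
    """Apply (value * scale - mean) / std to each channel of one RGB pixel."""
    s = parse_scale(scale)
    return [(value * s - m) / sd for value, m, sd in zip(rgb, MEAN, STD)]


# A pure white pixel under the training config's scale string.
white = normalize_pixel([255, 255, 255], "1.0/255.0")
print(parse_scale("1.0/255.0"), white)
```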
ppcls/data/postprocess/__init__.py @ 79cbd735

@@ -19,7 +19,7 @@ from . import topk, threshoutput
 from .topk import Topk
 from .threshoutput import ThreshOutput, MultiLabelThreshOutput
 from .attr_rec import VehicleAttribute, PersonAttribute, TableAttribute
+from .scoreoutput import ScoreOutput
 
 
 def build_postprocess(config):
...
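The body of `build_postprocess` is elided in this diff, so the following is only a sketch of a common pattern that would connect the YAML block `PostProcess: {name: ScoreOutput, decimal_places: 2}` to the newly imported class: look the class up by its `name` and pass the remaining keys as constructor arguments. The registry, the stand-in class, and the function `build_postprocess_sketch` are all assumptions for illustration.

```python
import copy


class ScoreOutput:
    """Stand-in with the same constructor signature as the class in the diff."""

    def __init__(self, decimal_places):
        self.decimal_places = decimal_places


# Hypothetical name-to-class table; the real module exposes classes via imports.
REGISTRY = {"ScoreOutput": ScoreOutput}


def build_postprocess_sketch(config):
    """Instantiate the post-processor named by config['name'] with the other keys."""
    config = copy.deepcopy(config)  # do not mutate the caller's config
    name = config.pop("name")
    return REGISTRY[name](**config)


post = build_postprocess_sketch({"name": "ScoreOutput", "decimal_places": 2})
print(type(post).__name__, post.decimal_places)
```

With a lookup of this kind, a class has to be importable from the package namespace to be usable from a config file, which is consistent with the single import line this hunk adds.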
ppcls/data/postprocess/scoreoutput.py
0 → 100644 @ 79cbd735

import numpy
import numpy as np
import paddle


class ScoreOutput(object):
    def __init__(self, decimal_places):
        self.decimal_places = decimal_places

    def __call__(self, x, file_names=None):
        y = []
        for idx, probs in enumerate(x):
            score = np.around(x[idx].numpy(), self.decimal_places)
            result = {"scores": score}
            if file_names is not None:
                result["file_name"] = file_names[idx]
            y.append(result)
        return y
\ No newline at end of file