Commit 79cbd735
Authored Feb 01, 2023 by weixin_46524038
Committed Feb 01, 2023 by cuicheng01

Aesthetic

Parent: 4fdcda7c
Changes: 12

Showing 12 changed files with 345 additions and 8 deletions (+345 -8)
deploy/configs/practical_models/aesthetic_score_predictor/inference_aesthetic_score_predictor.yaml  +32 -0
deploy/images/practical/aesthetic_score_predictor/Highscore.png  +0 -0
deploy/images/practical/aesthetic_score_predictor/Lowscore.png  +0 -0
deploy/python/postprocess.py  +16 -0
deploy/python/predict_cls.py  +9 -5
docs/zh_CN/models/practical_models/CLIP_large_patch14_224_aesthetic.md  +135 -0
ppcls/arch/backbone/__init__.py  +1 -1
ppcls/arch/backbone/model_zoo/foundation_vit.py  +3 -1
ppcls/arch/backbone/variant_models/foundation_vit_variant.py  +52 -0
ppcls/configs/practical_models/CLIP_large_patch14_224_aesthetic.yaml  +78 -0
ppcls/data/postprocess/__init__.py  +1 -1
ppcls/data/postprocess/scoreoutput.py  +18 -0

deploy/configs/practical_models/aesthetic_score_predictor/inference_aesthetic_score_predictor.yaml
0 → 100644

Global:
  infer_imgs: "./images/practical/aesthetic_score_predictor/Highscore.png"
  inference_model_dir: "./models/CLIP_large_patch14_224_aesthetic_infer/"
  batch_size: 1
  use_gpu: True
  enable_mkldnn: False
  cpu_num_threads: 10
  enable_benchmark: True
  use_fp16: False
  ir_optim: True
  use_tensorrt: False
  gpu_mem: 8000
  enable_profile: False

PreProcess:
  transform_ops:
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 0.00392157
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
        channel_num: 3
    - ToCHWImage:

PostProcess:
  main_indicator: ScoreOutput
  ScoreOutput:
    decimal_places: 2
\ No newline at end of file
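
To make the `PreProcess` section above more concrete, here is a minimal NumPy sketch of what the crop, normalize, and HWC-to-CHW steps compute. It is an illustration under assumptions, not the PaddleClas implementation: `center_crop` and `normalize_to_chw` are hypothetical helper names, the crop is assumed to be a center crop, and the short-side resize to 256 is skipped by starting from a dummy array that already has that size.

```python
import numpy as np


def center_crop(img, size):
    """Cut a size x size window from the centre of an HWC image (assumed crop policy)."""
    h, w = img.shape[:2]
    top = (h - size) // 2
    left = (w - size) // 2
    return img[top:top + size, left:left + size]


def normalize_to_chw(img,
                     scale=0.00392157,
                     mean=(0.485, 0.456, 0.406),
                     std=(0.229, 0.224, 0.225)):
    """Apply (pixel * scale - mean) / std per channel, then reorder HWC -> CHW."""
    x = img.astype("float32") * scale
    x = (x - np.array(mean, dtype="float32")) / np.array(std, dtype="float32")
    return x.transpose(2, 0, 1)


# Dummy stand-in for an image whose short side was already resized to 256.
resized = np.zeros((256, 341, 3), dtype="uint8")
chw = normalize_to_chw(center_crop(resized, 224))
print(chw.shape)  # (3, 224, 224)
```

The values for `scale`, `mean`, and `std` are copied from the config; `0.00392157` is approximately `1/255`, which matches the `1.0/255.0` used in the training-side config of this commit.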

deploy/images/practical/aesthetic_score_predictor/Highscore.png
0 → 100644
471.9 KB

deploy/images/practical/aesthetic_score_predictor/Lowscore.png
0 → 100644
321.0 KB

deploy/python/postprocess.py
@@ -17,6 +17,7 @@ import copy
 import shutil
 from functools import partial
 import importlib
+import numpy
 import numpy as np
 import paddle
 import paddle.nn.functional as F
@@ -147,6 +148,21 @@ class ThreshOutput(object):
         return multi_classification(x)
 
 
+class ScoreOutput(object):
+    def __init__(self, decimal_places):
+        self.decimal_places = decimal_places
+
+    def __call__(self, x, file_names=None):
+        y = []
+        for idx, probs in enumerate(x):
+            score = np.around(x[idx], self.decimal_places)
+            result = {"scores": score}
+            if file_names is not None:
+                result["file_name"] = file_names[idx]
+            y.append(result)
+        return y
+
+
 class Topk(object):
     def __init__(self, topk=1, class_id_map_file=None, delimiter=None):
         assert isinstance(topk, (int, ))
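
The deploy-side `ScoreOutput` added above rounds each row of the model output and optionally attaches a file name. The sketch below mirrors that logic as a standalone function so it can be run without PaddleClas; `score_output` is a hypothetical name, and the raw score values are made up purely for illustration.

```python
import numpy as np


def score_output(batch, file_names=None, decimal_places=2):
    """Round every row of `batch` and pair it with its file name, if given."""
    results = []
    for idx, row in enumerate(batch):
        result = {"scores": np.around(row, decimal_places)}
        if file_names is not None:
            result["file_name"] = file_names[idx]
        results.append(result)
    return results


# Made-up raw scores for two images, shaped [batch, 1] like the model head output.
batch = np.array([[7.8512], [2.3449]], dtype="float32")
out = score_output(batch, ["Highscore.png", "Lowscore.png"])
for item in out:
    print(item["file_name"], item["scores"])
```

Each result keeps `scores` as a NumPy array rather than a Python float, which is why the documented output in this commit prints `array([7.85], dtype=float32)`.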

deploy/python/predict_cls.py
@@ -142,13 +142,17 @@ def main(config):
                     print("{}:\t{}".format(filename, result_dict))
                 else:
                     filename = batch_names[number]
-                    clas_ids = result_dict["class_ids"]
                     scores_str = "[{}]".format(", ".join("{:.2f}".format(
                         r) for r in result_dict["scores"]))
-                    label_names = result_dict["label_names"]
-                    print(
-                        "{}:\tclass id(s): {}, score(s): {}, label_name(s): {}".
-                        format(filename, clas_ids, scores_str, label_names))
+                    if "class_ids" in result_dict and "label_names" in result_dict:
+                        clas_ids = result_dict["class_ids"]
+                        label_names = result_dict["label_names"]
+                        print(
+                            "{}:\tclass id(s): {}, score(s): {}, label_name(s): {}".
+                            format(filename, clas_ids, scores_str,
+                                   label_names))
+                    else:
+                        print("{}:\tscore(s): {}".format(filename,
+                                                         scores_str))
             batch_imgs = []
             batch_names = []
     if cls_predictor.benchmark:
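
The change to `predict_cls.py` makes the printed line depend on which keys the postprocessor returned. A minimal sketch of that branching, pulled out into a hypothetical helper `format_result` (the real code prints inline inside `main`), might look like this:

```python
def format_result(filename, result_dict):
    """Build the output line; fall back to scores only when class info is absent."""
    scores_str = "[{}]".format(", ".join(
        "{:.2f}".format(r) for r in result_dict["scores"]))
    if "class_ids" in result_dict and "label_names" in result_dict:
        return "{}:\tclass id(s): {}, score(s): {}, label_name(s): {}".format(
            filename, result_dict["class_ids"], scores_str,
            result_dict["label_names"])
    return "{}:\tscore(s): {}".format(filename, scores_str)


# ScoreOutput-style result: only "scores" is present.
print(format_result("Highscore.png", {"scores": [7.85]}))
# Topk-style result: class ids and label names are present too.
print(format_result("a.jpg", {"class_ids": [1], "scores": [0.5], "label_names": ["cat"]}))
```

With this fallback, a regression-style postprocessor such as `ScoreOutput` no longer triggers a `KeyError` on `class_ids`, while classification output keeps its previous format.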

docs/zh_CN/models/practical_models/CLIP_large_patch14_224_aesthetic.md
0 → 100644

# Aesthetic Scoring Model

------

## Contents

- [1. Model and application scenarios](#1)
- [2. Quick start](#2)
    - [2.1 Install paddlepaddle](#2.1)
    - [2.2 Install paddleclas](#2.2)
- [3. Model prediction](#3)
    - [3.1 Model prediction](#3.1)
        - [3.1.1 Prediction with the training engine](#3.1.1)
        - [3.1.2 Prediction with the inference engine](#3.1.2)

<a name="1"></a>
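
## 1. Model and application scenarios

This example shows how to use PaddleClas to build an image aesthetic scoring model on top of the CLIP_large_patch14_224 network. The model scores images automatically: the better an image matches human aesthetic preferences, the higher its score, and the worse it matches, the lower its score. It can be used in scenarios such as recommendation and search. The example is adapted from [improved-aesthetic-predictor](https://github.com/christophschuhmann/improved-aesthetic-predictor), and the weights are converted from the official weights. A high-scoring image and a low-scoring image are shown below:

<center><img src='https://user-images.githubusercontent.com/94225063/215502324-e22b72dc-bb6a-42fa-8f9d-d1069b74c6b7.jpg' width=800></center>

As the comparison shows, the left image matches human aesthetic preferences better than the right one.

**Note:**

* The images are taken from [this page](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html); scores range from 1.00 to 8.00.

<a name="2"></a>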

## 2. Quick start

<a name="2.1"></a>

### 2.1 Install paddlepaddle

- If your machine has CUDA 9 or CUDA 10, run the following command to install:

```bash
python3 -m pip install paddlepaddle-gpu -i https://mirror.baidu.com/pypi/simple
```

- If your machine is CPU-only, run the following command to install:

```bash
python3 -m pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple
```

For other version requirements, follow the instructions in the [PaddlePaddle installation guide](https://www.paddlepaddle.org.cn/install/quick).

<a name="2.2"></a>

### 2.2 Install paddleclas

Make sure you have cloned this project, then build and install it locally:

```
cd path/to/PaddleClas
# build and install with the command below
python3 setup.py install
```

<a name="3"></a>

## 3. Model prediction

<a name="3.1"></a>

### 3.1 Model prediction

<a name="3.1.1"></a>

#### 3.1.1 Prediction with the training engine

Load the pretrained model and run prediction. `tools/infer.py` in the model library provides a complete example; the following command is all that is needed:

```shell
python3 tools/infer.py \
    -c ./ppcls/configs/practical_models/CLIP_large_patch14_224_aesthetic.yaml
```

The output is:

```
[{'scores': array([7.85], dtype=float32), 'file_name': 'deploy/images/practical/aesthetic_score_predictor/Highscore.png'}]
```

**Note:**

* By default the command scores `deploy/images/practical/aesthetic_score_predictor/Highscore.png`. To score a different image, add the option `-o Infer.infer_imgs=xxx`.
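
For the recommendation and search scenarios mentioned in section 1, the result list can be sorted by score. The following is a minimal sketch, assuming results shaped like the output above; the `Lowscore.png` value is a hypothetical number chosen only for illustration.

```python
import numpy as np

# Result dicts shaped like the tools/infer.py output shown above.
results = [
    # Hypothetical score, not taken from a real run.
    {"scores": np.array([3.10], dtype="float32"), "file_name": "Lowscore.png"},
    {"scores": np.array([7.85], dtype="float32"), "file_name": "Highscore.png"},
]

# Each "scores" entry is a length-1 array, so take element 0 as the sort key.
ranked = sorted(results, key=lambda r: float(r["scores"][0]), reverse=True)
names = [r["file_name"] for r in ranked]
print(names)  # ['Highscore.png', 'Lowscore.png']
```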
<a
name=
"3.1.2"
></a>

#### 3.1.2 Prediction with the inference engine

Paddle Inference is PaddlePaddle's native inference library. It targets server and cloud deployments and provides high-performance inference. Compared with predicting directly from the pretrained model, Paddle Inference can use MKLDNN, CUDNN, and TensorRT to accelerate prediction and so achieve better inference performance. For more about the Paddle Inference engine, see the [Paddle Inference tutorial](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/infer/inference/inference_cn.html).

Download the inference model directly:

```
cd deploy/models
# download and extract the inference model
wget https://paddleclas.bj.bcebos.com/models/practical/inference/CLIP_large_patch14_224_aesthetic_infer.tar && tar -xf CLIP_large_patch14_224_aesthetic_infer.tar
```

After extraction, the `models` folder should have the following structure:

```
├── CLIP_large_patch14_224_aesthetic_infer
│   ├── inference.pdiparams
│   ├── inference.pdiparams.info
│   └── inference.pdmodel
```

With the inference model in place, run prediction with the inference engine.

Return to the `deploy` directory:

```
cd ../
```

Run the commands below to compute the aesthetic score of the image `./images/practical/aesthetic_score_predictor/Highscore.png`.

```shell
# predict on GPU
python3.7 python/predict_cls.py -c ./configs/practical_models/aesthetic_score_predictor/inference_aesthetic_score_predictor.yaml

# predict on CPU
python3.7 python/predict_cls.py -c ./configs/practical_models/aesthetic_score_predictor/inference_aesthetic_score_predictor.yaml -o Global.use_gpu=False
```

The output is:

```
Highscore.png:	score(s): [7.85]
```

ppcls/arch/backbone/__init__.py
@@ -38,7 +38,6 @@ from .model_zoo.dpn import DPN68, DPN92, DPN98, DPN107, DPN131
 from .model_zoo.dsnet import DSNet_tiny, DSNet_small, DSNet_base
 from .model_zoo.densenet import DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseNet264
 from .model_zoo.efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7, EfficientNetB0_small
-from .model_zoo.efficientnet_v2 import EfficientNetV2_S
 from .model_zoo.resnest import ResNeSt50_fast_1s1x64d, ResNeSt50, ResNeSt101, ResNeSt200, ResNeSt269
 from .model_zoo.googlenet import GoogLeNet
 from .model_zoo.mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75, MobileNetV2, MobileNetV2_x1_5, MobileNetV2_x2_0
@@ -81,6 +80,7 @@ from .variant_models.vgg_variant import VGG19Sigmoid
 from .variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh
 from .variant_models.pp_lcnetv2_variant import PPLCNetV2_base_ShiTu
 from .variant_models.efficientnet_variant import EfficientNetB3_watermark
+from .variant_models.foundation_vit_variant import CLIP_large_patch14_224_aesthetic
 from .model_zoo.adaface_ir_net import AdaFace_IR_18, AdaFace_IR_34, AdaFace_IR_50, AdaFace_IR_101, AdaFace_IR_152, AdaFace_IR_SE_50, AdaFace_IR_SE_101, AdaFace_IR_SE_152, AdaFace_IR_SE_200
 from .model_zoo.wideresnet import WideResNet
 from .model_zoo.uniformer import UniFormer_small, UniFormer_small_plus, UniFormer_small_plus_dim64, UniFormer_base, UniFormer_base_ls

ppcls/arch/backbone/model_zoo/foundation_vit.py
@@ -23,6 +23,8 @@ import paddle.nn as nn
 import sys
 from paddle.nn.initializer import TruncatedNormal, Constant, Normal
 
+from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
+
 MODEL_URLS = {
     "CLIP_small_patch16_224": None,
     "CLIP_base_patch32_224": None,
@@ -885,4 +887,4 @@ def CAE_base_patch16_224(pretrained=False, use_ssld=False, **kwargs):
         **kwargs, )
     _load_pretrained(
         pretrained, model, MODEL_URLS[model_name], use_ssld=use_ssld)
-    return model
+    return model
\ No newline at end of file

ppcls/arch/backbone/variant_models/foundation_vit_variant.py
0 → 100644

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ..model_zoo.foundation_vit import CLIP_large_patch14_224, _load_pretrained

MODEL_URLS = {
    "CLIP_large_patch14_224_aesthetic":
    "https://paddleclas.bj.bcebos.com/models/practical/pretrained/CLIP_large_patch14_224_aesthetic_pretrained.pdparams"
}

__all__ = list(MODEL_URLS.keys())


class MLP(nn.Layer):
    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1))

    def forward(self, x):
        return self.layers(x)


class Aesthetic_Score_Predictor(nn.Layer):
    def __init__(self):
        super().__init__()
        self.model = CLIP_large_patch14_224()
        self.fc_head = nn.Linear(1024, 768, bias_attr=False)
        self.mlp = MLP(768)

    def forward(self, x):
        x = self.model(x)
        x = x[:, 0, :]
        x = self.fc_head(x)
        x = F.normalize(x, p=2, axis=-1)
        x = self.mlp(x)
        return x


def CLIP_large_patch14_224_aesthetic(pretrained=False,
                                     use_ssld=False,
                                     **kwargs):
    model = Aesthetic_Score_Predictor()
    _load_pretrained(pretrained, model,
                     MODEL_URLS["CLIP_large_patch14_224_aesthetic"], use_ssld)
    return model
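
The forward pass of `Aesthetic_Score_Predictor` can be followed shape by shape without Paddle. The NumPy sketch below is only a walk-through under assumptions: the weights are random, the helper `linear` is hypothetical, and the backbone output is assumed to be `[batch, 257, 1024]` (one class token plus 16 x 16 patches for a 224-pixel input with patch size 14). Dropout is omitted because it is the identity at inference time.

```python
import numpy as np

rng = np.random.default_rng(0)


def linear(x, in_dim, out_dim, bias=True):
    """Random-weight stand-in for nn.Linear; only the shapes matter here."""
    w = rng.standard_normal((in_dim, out_dim)).astype("float32") * 0.02
    b = np.zeros(out_dim, dtype="float32") if bias else 0.0
    return x @ w + b


# Assumed backbone output: batch of 2, 257 tokens, 1024 channels.
tokens = rng.standard_normal((2, 257, 1024)).astype("float32")

cls = tokens[:, 0, :]                                     # x[:, 0, :] keeps the class token
emb = linear(cls, 1024, 768, bias=False)                  # fc_head, no bias
emb = emb / np.linalg.norm(emb, axis=-1, keepdims=True)   # L2 normalization along the last axis

score = emb
for in_dim, out_dim in [(768, 1024), (1024, 128), (128, 64), (64, 16), (16, 1)]:
    score = linear(score, in_dim, out_dim)                # MLP: Linear layers only

print(cls.shape, emb.shape, score.shape)  # (2, 1024) (2, 768) (2, 1)
```

As written in this diff, the `MLP` contains only `Linear` and `Dropout` layers with no activation functions, so at inference time it acts as a single affine map from the normalized 768-dimensional embedding to one score per image.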

ppcls/configs/practical_models/CLIP_large_patch14_224_aesthetic.yaml
0 → 100644

# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 50
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  # training model under @to_static
  to_static: False
  use_dali: False

# model architecture
Arch:
  name: CLIP_large_patch14_224_aesthetic
  pretrained: True

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/
      cls_label_path: ./dataset/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: False

Infer:
  infer_imgs: deploy/images/practical/aesthetic_score_predictor/Highscore.png
  batch_size: 1
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: ScoreOutput
    decimal_places: 2

Metric:
  Eval:
    - TopkAcc:
        topk: [1, 2]
\ No newline at end of file

ppcls/data/postprocess/__init__.py
@@ -19,7 +19,7 @@ from . import topk, threshoutput
 from .topk import Topk
 from .threshoutput import ThreshOutput, MultiLabelThreshOutput
 from .attr_rec import VehicleAttribute, PersonAttribute, TableAttribute
-
+from .scoreoutput import ScoreOutput
 
 
 def build_postprocess(config):

ppcls/data/postprocess/scoreoutput.py
0 → 100644

import numpy
import numpy as np
import paddle


class ScoreOutput(object):
    def __init__(self, decimal_places):
        self.decimal_places = decimal_places

    def __call__(self, x, file_names=None):
        y = []
        for idx, probs in enumerate(x):
            score = np.around(x[idx].numpy(), self.decimal_places)
            result = {"scores": score}
            if file_names is not None:
                result["file_name"] = file_names[idx]
            y.append(result)
        return y
\ No newline at end of file