提交 79cbd735 编写于 作者: weixin_46524038's avatar weixin_46524038 提交者: cuicheng01

Aesthetic

上级 4fdcda7c
Global:
infer_imgs: "./images/practical/aesthetic_score_predictor/Highscore.png"
inference_model_dir: "./models/CLIP_large_patch14_224_aesthetic_infer/"
batch_size: 1
use_gpu: True
enable_mkldnn: False
cpu_num_threads: 10
enable_benchmark: True
use_fp16: False
ir_optim: True
use_tensorrt: False
gpu_mem: 8000
enable_profile: False
PreProcess:
transform_ops:
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
channel_num: 3
- ToCHWImage:
PostProcess:
main_indicator: ScoreOutput
ScoreOutput:
decimal_places: 2
\ No newline at end of file
...@@ -17,6 +17,7 @@ import copy ...@@ -17,6 +17,7 @@ import copy
import shutil import shutil
from functools import partial from functools import partial
import importlib import importlib
import numpy
import numpy as np import numpy as np
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
...@@ -147,6 +148,21 @@ class ThreshOutput(object): ...@@ -147,6 +148,21 @@ class ThreshOutput(object):
return multi_classification(x) return multi_classification(x)
class ScoreOutput(object):
def __init__(self, decimal_places):
self.decimal_places = decimal_places
def __call__(self, x, file_names=None):
y = []
for idx, probs in enumerate(x):
score = np.around(x[idx], self.decimal_places)
result = {"scores": score}
if file_names is not None:
result["file_name"] = file_names[idx]
y.append(result)
return y
class Topk(object): class Topk(object):
def __init__(self, topk=1, class_id_map_file=None, delimiter=None): def __init__(self, topk=1, class_id_map_file=None, delimiter=None):
assert isinstance(topk, (int, )) assert isinstance(topk, (int, ))
......
...@@ -142,13 +142,17 @@ def main(config): ...@@ -142,13 +142,17 @@ def main(config):
print("{}:\t {}".format(filename, result_dict)) print("{}:\t {}".format(filename, result_dict))
else: else:
filename = batch_names[number] filename = batch_names[number]
clas_ids = result_dict["class_ids"]
scores_str = "[{}]".format(", ".join("{:.2f}".format( scores_str = "[{}]".format(", ".join("{:.2f}".format(
r) for r in result_dict["scores"])) r) for r in result_dict["scores"]))
if "class_ids" in result_dict and "label_names" in result_dict:
clas_ids = result_dict["class_ids"]
label_names = result_dict["label_names"] label_names = result_dict["label_names"]
print( print(
"{}:\tclass id(s): {}, score(s): {}, label_name(s): {}". "{}:\tclass id(s): {}, score(s): {}, label_name(s): {}".
format(filename, clas_ids, scores_str, label_names)) format(filename, clas_ids, scores_str,
label_names))
else:
print("{}:\tscore(s): {}".format(filename, scores_str))
batch_imgs = [] batch_imgs = []
batch_names = [] batch_names = []
if cls_predictor.benchmark: if cls_predictor.benchmark:
......
# 美观度打分模型
------
## 目录
- [1. 模型和应用场景介绍](#1)
- [2. 模型快速体验](#2)
- [2.1 安装 paddlepaddle](#2.1)
- [2.2 安装 paddleclas](#2.2)
- [3. 模型预测](#3)
- [3.1 模型预测](#3.1)
- [3.1.1 基于训练引擎预测](#3.1.1)
- [3.1.2 基于推理引擎预测](#3.1.2)
<a name="1"></a>
## 1. 模型和应用场景介绍
该案例提供了用户使用 PaddleClas 的基于 CLIP_large_patch14_224 网络构建图像美观度打分的模型。该模型可以自动为图像打分,对于越符合人类审美的图像,得分越高,越不符合人类审美的图像,得分越低,可用于推荐和搜索等应用场景。本案例引用自[美观度](https://github.com/christophschuhmann/improved-aesthetic-predictor),权重由官方权重转换而来。得分较高和得分较低的两张图片如下:
<center><img src='https://user-images.githubusercontent.com/94225063/215502324-e22b72dc-bb6a-42fa-8f9d-d1069b74c6b7.jpg' width=800></center>
可以看到,相比于右图,左图更加符合人类审美。
**备注:**
* 图片引用自[链接](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html),得分范围为1.00-8.00
<a name="2"></a>
## 2. 模型快速体验
<a name="2.1"></a>
### 2.1 安装 paddlepaddle
- 您的机器安装的是 CUDA9 或 CUDA10,请运行以下命令安装
```bash
python3 -m pip install paddlepaddle-gpu -i https://mirror.baidu.com/pypi/simple
```
- 您的机器是 CPU,请运行以下命令安装
```bash
python3 -m pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple
```
更多的版本需求,请参照[飞桨官网安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
<a name="2.2"></a>
### 2.2 安装 paddleclas
请确保已clone本项目,本地构建安装:
```
cd path/to/PaddleClas
#使用下面的命令构建
python3 setup.py install
```
<a name="3"></a>
## 3. 模型预测
<a name="3.1"></a>
### 3.1模型预测
<a name="3.1.1"></a>
### 3.1.1 基于训练引擎预测
加载预训练模型,进行模型预测。在模型库的 `tools/infer.py` 中提供了完整的示例,只需执行下述命令即可完成模型预测:
```python
python3 tools/infer.py \
-c ./ppcls/configs/practical_models/CLIP_large_patch14_224_aesthetic.yaml
```
输出结果如下:
```
[{'scores': array([7.85], dtype=float32), 'file_name': 'deploy/images/practical/aesthetic_score_predictor/Highscore.png'}]
```
**备注:**
* 默认是对 `deploy/images/practical/aesthetic_score_predictor/Highscore.png` 进行打分,此处也可以通过增加字段 `-o Infer.infer_imgs=xxx` 对其他图片打分。
<a name="3.1.2"></a>
### 3.1.2 基于推理引擎预测
Paddle Inference 是飞桨的原生推理库, 作用于服务器端和云端,提供高性能的推理能力。相比于直接基于预训练模型进行预测,Paddle Inference可使用 MKLDNN、CUDNN、TensorRT 进行预测加速,从而实现更优的推理性能。更多关于 Paddle Inference 推理引擎的介绍,可以参考 [Paddle Inference官网教程](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/infer/inference/inference_cn.html)
选择直接下载的方式得到对应的 inference 模型:
```
cd deploy/models
# 下载 inference 模型并解压
wget https://paddleclas.bj.bcebos.com/models/practical/inference/CLIP_large_patch14_224_aesthetic_infer.tar && tar -xf CLIP_large_patch14_224_aesthetic_infer.tar
```
解压完毕后,`models` 文件夹下应有如下文件结构:
```
├── CLIP_large_patch14_224_aesthetic_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
```
得到 inference 模型之后基于推理引擎进行预测:
返回 `deploy` 目录:
```
cd ../
```
运行下面的命令,对图像 `./images/practical/aesthetic_score_predictor/Highscore.png` 进行美观度打分。
```shell
# 使用下面的命令使用 GPU 进行预测
python3.7 python/predict_cls.py -c ./configs/practical_models/aesthetic_score_predictor/inference_aesthetic_score_predictor.yaml
# 使用下面的命令使用 CPU 进行预测
python3.7 python/predict_cls.py -c ./configs/practical_models/aesthetic_score_predictor/inference_aesthetic_score_predictor.yaml -o Global.use_gpu=False
```
输出结果如下。
```
Highscore.png: score(s): [7.85]
```
...@@ -38,7 +38,6 @@ from .model_zoo.dpn import DPN68, DPN92, DPN98, DPN107, DPN131 ...@@ -38,7 +38,6 @@ from .model_zoo.dpn import DPN68, DPN92, DPN98, DPN107, DPN131
from .model_zoo.dsnet import DSNet_tiny, DSNet_small, DSNet_base from .model_zoo.dsnet import DSNet_tiny, DSNet_small, DSNet_base
from .model_zoo.densenet import DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseNet264 from .model_zoo.densenet import DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseNet264
from .model_zoo.efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7, EfficientNetB0_small from .model_zoo.efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7, EfficientNetB0_small
from .model_zoo.efficientnet_v2 import EfficientNetV2_S
from .model_zoo.resnest import ResNeSt50_fast_1s1x64d, ResNeSt50, ResNeSt101, ResNeSt200, ResNeSt269 from .model_zoo.resnest import ResNeSt50_fast_1s1x64d, ResNeSt50, ResNeSt101, ResNeSt200, ResNeSt269
from .model_zoo.googlenet import GoogLeNet from .model_zoo.googlenet import GoogLeNet
from .model_zoo.mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75, MobileNetV2, MobileNetV2_x1_5, MobileNetV2_x2_0 from .model_zoo.mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75, MobileNetV2, MobileNetV2_x1_5, MobileNetV2_x2_0
...@@ -81,6 +80,7 @@ from .variant_models.vgg_variant import VGG19Sigmoid ...@@ -81,6 +80,7 @@ from .variant_models.vgg_variant import VGG19Sigmoid
from .variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh from .variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh
from .variant_models.pp_lcnetv2_variant import PPLCNetV2_base_ShiTu from .variant_models.pp_lcnetv2_variant import PPLCNetV2_base_ShiTu
from .variant_models.efficientnet_variant import EfficientNetB3_watermark from .variant_models.efficientnet_variant import EfficientNetB3_watermark
from .variant_models.foundation_vit_variant import CLIP_large_patch14_224_aesthetic
from .model_zoo.adaface_ir_net import AdaFace_IR_18, AdaFace_IR_34, AdaFace_IR_50, AdaFace_IR_101, AdaFace_IR_152, AdaFace_IR_SE_50, AdaFace_IR_SE_101, AdaFace_IR_SE_152, AdaFace_IR_SE_200 from .model_zoo.adaface_ir_net import AdaFace_IR_18, AdaFace_IR_34, AdaFace_IR_50, AdaFace_IR_101, AdaFace_IR_152, AdaFace_IR_SE_50, AdaFace_IR_SE_101, AdaFace_IR_SE_152, AdaFace_IR_SE_200
from .model_zoo.wideresnet import WideResNet from .model_zoo.wideresnet import WideResNet
from .model_zoo.uniformer import UniFormer_small, UniFormer_small_plus, UniFormer_small_plus_dim64, UniFormer_base, UniFormer_base_ls from .model_zoo.uniformer import UniFormer_small, UniFormer_small_plus, UniFormer_small_plus_dim64, UniFormer_base, UniFormer_base_ls
......
...@@ -23,6 +23,8 @@ import paddle.nn as nn ...@@ -23,6 +23,8 @@ import paddle.nn as nn
import sys import sys
from paddle.nn.initializer import TruncatedNormal, Constant, Normal from paddle.nn.initializer import TruncatedNormal, Constant, Normal
from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = { MODEL_URLS = {
"CLIP_small_patch16_224": None, "CLIP_small_patch16_224": None,
"CLIP_base_patch32_224": None, "CLIP_base_patch32_224": None,
......
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ..model_zoo.foundation_vit import CLIP_large_patch14_224, _load_pretrained
MODEL_URLS = {
"CLIP_large_patch14_224_aesthetic":
"https://paddleclas.bj.bcebos.com/models/practical/pretrained/CLIP_large_patch14_224_aesthetic_pretrained.pdparams"
}
__all__ = list(MODEL_URLS.keys())
class MLP(nn.Layer):
def __init__(self, input_size):
super().__init__()
self.input_size = input_size
self.layers = nn.Sequential(
nn.Linear(self.input_size, 1024),
nn.Dropout(0.2),
nn.Linear(1024, 128),
nn.Dropout(0.2),
nn.Linear(128, 64),
nn.Dropout(0.1), nn.Linear(64, 16), nn.Linear(16, 1))
def forward(self, x):
return self.layers(x)
class Aesthetic_Score_Predictor(nn.Layer):
def __init__(self):
super().__init__()
self.model = CLIP_large_patch14_224()
self.fc_head = nn.Linear(1024, 768, bias_attr=False)
self.mlp = MLP(768)
def forward(self, x):
x = self.model(x)
x = x[:, 0, :]
x = self.fc_head(x)
x = F.normalize(x, p=2, axis=-1)
x = self.mlp(x)
return x
def CLIP_large_patch14_224_aesthetic(pretrained=False,
use_ssld=False,
**kwargs):
model = Aesthetic_Score_Predictor()
_load_pretrained(pretrained, model,
MODEL_URLS["CLIP_large_patch14_224_aesthetic"], use_ssld)
return model
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 50
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
use_dali: False
# model architecture
Arch:
name: CLIP_large_patch14_224_aesthetic
pretrained: True
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/
cls_label_path: ./dataset/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: True
loader:
num_workers: 4
use_shared_memory: False
Infer:
infer_imgs: deploy/images/practical/aesthetic_score_predictor/Highscore.png
batch_size: 1
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: ScoreOutput
decimal_places: 2
Metric:
Eval:
- TopkAcc:
topk: [1, 2]
\ No newline at end of file
...@@ -19,7 +19,7 @@ from . import topk, threshoutput ...@@ -19,7 +19,7 @@ from . import topk, threshoutput
from .topk import Topk from .topk import Topk
from .threshoutput import ThreshOutput, MultiLabelThreshOutput from .threshoutput import ThreshOutput, MultiLabelThreshOutput
from .attr_rec import VehicleAttribute, PersonAttribute, TableAttribute from .attr_rec import VehicleAttribute, PersonAttribute, TableAttribute
from .scoreoutput import ScoreOutput
def build_postprocess(config): def build_postprocess(config):
......
import numpy
import numpy as np
import paddle
class ScoreOutput(object):
def __init__(self, decimal_places):
self.decimal_places = decimal_places
def __call__(self, x, file_names=None):
y = []
for idx, probs in enumerate(x):
score = np.around(x[idx].numpy(), self.decimal_places)
result = {"scores": score}
if file_names is not None:
result["file_name"] = file_names[idx]
y.append(result)
return y
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册