Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
af9aae73
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
af9aae73
编写于
9月 26, 2021
作者:
C
cuicheng01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add multilabel feature
上级
e431fe33
变更
20
隐藏空白更改
内联
并排
Showing
20 changed file
with
519 addition
and
64 deletion
+519
-64
deploy/configs/inference_multilabel_cls.yaml
deploy/configs/inference_multilabel_cls.yaml
+33
-0
deploy/images/0517_2715693311.jpg
deploy/images/0517_2715693311.jpg
+0
-0
deploy/python/postprocess.py
deploy/python/postprocess.py
+17
-6
deploy/python/predict_cls.py
deploy/python/predict_cls.py
+0
-2
deploy/shell/predict.sh
deploy/shell/predict.sh
+3
-0
docs/zh_CN/advanced_tutorials/multilabel/multilabel.md
docs/zh_CN/advanced_tutorials/multilabel/multilabel.md
+43
-35
ppcls/configs/quick_start/MobileNetV1_multilabel.yaml
ppcls/configs/quick_start/MobileNetV1_multilabel.yaml
+129
-0
ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml
...figs/quick_start/professional/MobileNetV1_multilabel.yaml
+129
-0
ppcls/data/dataloader/multilabel_dataset.py
ppcls/data/dataloader/multilabel_dataset.py
+4
-3
ppcls/data/postprocess/__init__.py
ppcls/data/postprocess/__init__.py
+1
-1
ppcls/data/postprocess/topk.py
ppcls/data/postprocess/topk.py
+13
-3
ppcls/engine/engine.py
ppcls/engine/engine.py
+11
-8
ppcls/engine/evaluation/classification.py
ppcls/engine/evaluation/classification.py
+2
-1
ppcls/engine/train/train.py
ppcls/engine/train/train.py
+2
-2
ppcls/loss/__init__.py
ppcls/loss/__init__.py
+1
-0
ppcls/loss/multilabelloss.py
ppcls/loss/multilabelloss.py
+43
-0
ppcls/metric/__init__.py
ppcls/metric/__init__.py
+5
-1
ppcls/metric/metrics.py
ppcls/metric/metrics.py
+75
-1
tools/train.sh
tools/train.sh
+1
-1
train.sh
train.sh
+7
-0
未找到文件。
deploy/configs/inference_multilabel_cls.yaml
0 → 100644
浏览文件 @
af9aae73
Global
:
infer_imgs
:
"
./images/0517_2715693311.jpg"
inference_model_dir
:
"
../inference/"
batch_size
:
1
use_gpu
:
True
enable_mkldnn
:
False
cpu_num_threads
:
10
enable_benchmark
:
True
use_fp16
:
False
ir_optim
:
True
use_tensorrt
:
False
gpu_mem
:
8000
enable_profile
:
False
PreProcess
:
transform_ops
:
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
0.00392157
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
channel_num
:
3
-
ToCHWImage
:
PostProcess
:
main_indicator
:
MultiLabelTopk
MultiLabelTopk
:
topk
:
5
class_id_map_file
:
None
SavePreLabel
:
save_dir
:
./pre_label/
deploy/images/0517_2715693311.jpg
0 → 100644
浏览文件 @
af9aae73
16.3 KB
deploy/python/postprocess.py
浏览文件 @
af9aae73
...
...
@@ -81,12 +81,14 @@ class Topk(object):
class_id_map
=
None
return
class_id_map
def
__call__
(
self
,
x
,
file_names
=
None
):
def
__call__
(
self
,
x
,
file_names
=
None
,
multilabel
=
False
):
if
file_names
is
not
None
:
assert
x
.
shape
[
0
]
==
len
(
file_names
)
y
=
[]
for
idx
,
probs
in
enumerate
(
x
):
index
=
probs
.
argsort
(
axis
=
0
)[
-
self
.
topk
:][::
-
1
].
astype
(
"int32"
)
index
=
probs
.
argsort
(
axis
=
0
)[
-
self
.
topk
:][::
-
1
].
astype
(
"int32"
)
if
not
multilabel
else
np
.
where
(
probs
>=
0.5
)[
0
].
astype
(
"int32"
)
clas_id_list
=
[]
score_list
=
[]
label_name_list
=
[]
...
...
@@ -108,6 +110,14 @@ class Topk(object):
return
y
class
MultiLabelTopk
(
Topk
):
def
__init__
(
self
,
topk
=
1
,
class_id_map_file
=
None
):
super
().
__init__
()
def
__call__
(
self
,
x
,
file_names
=
None
):
return
super
().
__call__
(
x
,
file_names
,
multilabel
=
True
)
class
SavePreLabel
(
object
):
def
__init__
(
self
,
save_dir
):
if
save_dir
is
None
:
...
...
@@ -128,23 +138,24 @@ class SavePreLabel(object):
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
shutil
.
copy
(
image_file
,
output_dir
)
class
Binarize
(
object
):
def
__init__
(
self
,
method
=
"round"
):
def
__init__
(
self
,
method
=
"round"
):
self
.
method
=
method
self
.
unit
=
np
.
array
([[
128
,
64
,
32
,
16
,
8
,
4
,
2
,
1
]]).
T
def
__call__
(
self
,
x
,
file_names
=
None
):
if
self
.
method
==
"round"
:
x
=
np
.
round
(
x
+
1
).
astype
(
"uint8"
)
-
1
if
self
.
method
==
"sign"
:
x
=
((
np
.
sign
(
x
)
+
1
)
/
2
).
astype
(
"uint8"
)
embedding_size
=
x
.
shape
[
1
]
assert
embedding_size
%
8
==
0
,
"The Binary index only support vectors with sizes multiple of 8"
byte
=
np
.
zeros
([
x
.
shape
[
0
],
embedding_size
//
8
],
dtype
=
np
.
uint8
)
for
i
in
range
(
embedding_size
//
8
):
byte
[:,
i
:
i
+
1
]
=
np
.
dot
(
x
[:,
i
*
8
:
(
i
+
1
)
*
8
],
self
.
unit
)
byte
[:,
i
:
i
+
1
]
=
np
.
dot
(
x
[:,
i
*
8
:(
i
+
1
)
*
8
],
self
.
unit
)
return
byte
deploy/python/predict_cls.py
浏览文件 @
af9aae73
...
...
@@ -71,7 +71,6 @@ class ClsPredictor(Predictor):
output_names
=
self
.
paddle_predictor
.
get_output_names
()
output_tensor
=
self
.
paddle_predictor
.
get_output_handle
(
output_names
[
0
])
if
self
.
benchmark
:
self
.
auto_logger
.
times
.
start
()
if
not
isinstance
(
images
,
(
list
,
)):
...
...
@@ -119,7 +118,6 @@ def main(config):
)
==
len
(
image_list
):
if
len
(
batch_imgs
)
==
0
:
continue
batch_results
=
cls_predictor
.
predict
(
batch_imgs
)
for
number
,
result_dict
in
enumerate
(
batch_results
):
filename
=
batch_names
[
number
]
...
...
deploy/shell/predict.sh
浏览文件 @
af9aae73
# classification
python3.7 python/predict_cls.py
-c
configs/inference_cls.yaml
# multilabel_classification
#python3.7 python/predict_cls.py -c configs/inference_multilabel_cls.yaml
# feature extractor
# python3.7 python/predict_rec.py -c configs/inference_rec.yaml
...
...
docs/zh_CN/advanced_tutorials/multilabel/multilabel.md
浏览文件 @
af9aae73
...
...
@@ -25,58 +25,66 @@ tar -xf NUS-SCENE-dataset.tar
cd ../../
```
## 二、
环境准备
## 二、
模型训练
### 2.1 下载预训练模型
```
shell
export
CUDA_VISIBLE_DEVICES
=
0,1,2,3
python3
-m
paddle.distributed.launch
\
--gpus
=
"0,1,2,3"
\
tools/train.py
\
-c
./ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml
```
训练10epoch之后,验证集最好的正确率应该在0.95左右。
本例展示基于ResNet50_vd模型的多标签分类流程,因此首先下载ResNet50_vd的预训练模型
## 三、模型评估
```
bash
mkdir
pretrained
cd
pretrained
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_pretrained.pdparams
cd
../
python3 tools/eval.py
\
-c
./ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml
\
-o
Arch.pretrained
=
"./output/MobileNetV1/best_model"
```
##
三、模型训练
##
四、模型预测
```
shell
export
CUDA_VISIBLE_DEVICES
=
0
python
-m
paddle.distributed.launch
\
--gpus
=
"0"
\
tools/train.py
\
-c
./configs/quick_start/ResNet50_vd_multilabel.yaml
```
bash
python3 tools/infer.py
\
-c
./ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml
\
-o
Arch.pretrained
=
"./output/MobileNetV1/best_model"
```
得到类似下面的输出:
```
[{'class_ids': [6, 13, 17, 23, 26, 30], 'scores': [0.95683, 0.5567, 0.55211, 0.99088, 0.5943, 0.78767], 'file_name': './deploy/images/0517_2715693311.jpg', 'label_names': []}]
```
训练10epoch之后,验证集最好的正确率应该在0.72左右。
## 五、基于预测引擎预测
##
四、模型评估
##
# 5.1 导出inference model
```
bash
python tools/eval.py
\
-c
./configs/quick_start/ResNet50_vd_multilabel.yaml
\
-o
pretrained_model
=
"./output/ResNet50_vd/best_model/ppcls"
\
-o
load_static_weights
=
False
python3 tools/export_model.py
\
-c
./ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml
\
-o
Arch.pretrained
=
"./output/MobileNetV1/best_model"
```
inference model的路径默认在当前路径下
`./inference`
评估指标采用mAP,验证集的mAP应该在0.57左右。
### 5.2 基于预测引擎预测
## 五、模型预测
首先进入deploy目录下:
```
bash
python tools/infer/infer.py
\
-i
"./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images/0199_434752251.jpg"
\
--model
ResNet50_vd
\
--pretrained_model
"./output/ResNet50_vd/best_model/ppcls"
\
--use_gpu
True
\
--load_static_weights
False
\
--multilabel
True
\
--class_num
33
cd
./deploy
```
通过预测引擎推理预测:
```
python3 python/predict_cls.py
\
-c configs/inference_multilabel_cls.yaml
```
得到类似下面的输出:
```
class id: 3, probability: 0.6025
class id: 23, probability: 0.5491
class id: 32, probability: 0.7006
```
\ No newline at end of file
```
0517_2715693311.jpg: class id(s): [6, 13, 17, 23, 26, 30], score(s): [0.96, 0.56, 0.55, 0.99, 0.59, 0.79], label_name(s): []
```
ppcls/configs/quick_start/MobileNetV1_multilabel.yaml
0 → 100644
浏览文件 @
af9aae73
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
10
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
use_multilabel
:
True
# model architecture
Arch
:
name
:
MobileNetV1
class_num
:
33
pretrained
:
True
# loss function config for traing/eval process
Loss
:
Train
:
-
MultiLabelLoss
:
weight
:
1.0
Eval
:
-
MultiLabelLoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.1
regularizer
:
name
:
'
L2'
coeff
:
0.00004
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
MultiLabelDataset
image_root
:
./dataset/NUS-SCENE-dataset/images/
cls_label_path
:
./dataset/NUS-SCENE-dataset/multilabel_train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
4
use_shared_memory
:
True
Eval
:
dataset
:
name
:
MultiLabelDataset
image_root
:
./dataset/NUS-SCENE-dataset/images/
cls_label_path
:
./dataset/NUS-SCENE-dataset/multilabel_test_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
256
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Infer
:
infer_imgs
:
dataset/NUS-SCENE-dataset/images/0001_109549716.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
MutiLabelTopk
topk
:
5
class_id_map_file
:
None
Metric
:
Train
:
-
HammingDistance
:
-
AccuracyScore
:
Eval
:
-
HammingDistance
:
-
AccuracyScore
:
ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml
0 → 100644
浏览文件 @
af9aae73
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
10
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
use_multilabel
:
True
# model architecture
Arch
:
name
:
MobileNetV1
class_num
:
33
pretrained
:
True
# loss function config for traing/eval process
Loss
:
Train
:
-
MultiLabelLoss
:
weight
:
1.0
Eval
:
-
MultiLabelLoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.1
regularizer
:
name
:
'
L2'
coeff
:
0.00004
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
MultiLabelDataset
image_root
:
./dataset/NUS-SCENE-dataset/images/
cls_label_path
:
./dataset/NUS-SCENE-dataset/multilabel_train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
4
use_shared_memory
:
True
Eval
:
dataset
:
name
:
MultiLabelDataset
image_root
:
./dataset/NUS-SCENE-dataset/images/
cls_label_path
:
./dataset/NUS-SCENE-dataset/multilabel_test_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
256
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Infer
:
infer_imgs
:
./deploy/images/0517_2715693311.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
MultiLabelTopk
topk
:
5
class_id_map_file
:
None
Metric
:
Train
:
-
HammingDistance
:
-
AccuracyScore
:
Eval
:
-
HammingDistance
:
-
AccuracyScore
:
ppcls/data/dataloader/multilabel_dataset.py
浏览文件 @
af9aae73
...
...
@@ -33,7 +33,7 @@ class MultiLabelDataset(CommonDataset):
with
open
(
self
.
_cls_path
)
as
fd
:
lines
=
fd
.
readlines
()
for
l
in
lines
:
l
=
l
.
strip
().
split
(
"
"
)
l
=
l
.
strip
().
split
(
"
\t
"
)
self
.
images
.
append
(
os
.
path
.
join
(
self
.
_img_root
,
l
[
0
]))
labels
=
l
[
1
].
split
(
','
)
...
...
@@ -44,13 +44,14 @@ class MultiLabelDataset(CommonDataset):
def
__getitem__
(
self
,
idx
):
try
:
img
=
cv2
.
imread
(
self
.
images
[
idx
])
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
with
open
(
self
.
images
[
idx
],
'rb'
)
as
f
:
img
=
f
.
read
(
)
if
self
.
_transform_ops
:
img
=
transform
(
img
,
self
.
_transform_ops
)
img
=
img
.
transpose
((
2
,
0
,
1
))
label
=
np
.
array
(
self
.
labels
[
idx
]).
astype
(
"float32"
)
return
(
img
,
label
)
except
Exception
as
ex
:
logger
.
error
(
"Exception occured when parse line: {} with msg: {}"
.
format
(
self
.
images
[
idx
],
ex
))
...
...
ppcls/data/postprocess/__init__.py
浏览文件 @
af9aae73
...
...
@@ -16,7 +16,7 @@ import importlib
from
.
import
topk
from
.topk
import
Topk
from
.topk
import
Topk
,
MultiLabelTopk
def
build_postprocess
(
config
):
...
...
ppcls/data/postprocess/topk.py
浏览文件 @
af9aae73
...
...
@@ -45,15 +45,17 @@ class Topk(object):
class_id_map
=
None
return
class_id_map
def
__call__
(
self
,
x
,
file_names
=
None
):
def
__call__
(
self
,
x
,
file_names
=
None
,
multilabel
=
False
):
assert
isinstance
(
x
,
paddle
.
Tensor
)
if
file_names
is
not
None
:
assert
x
.
shape
[
0
]
==
len
(
file_names
)
x
=
F
.
softmax
(
x
,
axis
=-
1
)
x
=
F
.
softmax
(
x
,
axis
=-
1
)
if
not
multilabel
else
F
.
sigmoid
(
x
)
x
=
x
.
numpy
()
y
=
[]
for
idx
,
probs
in
enumerate
(
x
):
index
=
probs
.
argsort
(
axis
=
0
)[
-
self
.
topk
:][::
-
1
].
astype
(
"int32"
)
index
=
probs
.
argsort
(
axis
=
0
)[
-
self
.
topk
:][::
-
1
].
astype
(
"int32"
)
if
not
multilabel
else
np
.
where
(
probs
>=
0.5
)[
0
].
astype
(
"int32"
)
clas_id_list
=
[]
score_list
=
[]
label_name_list
=
[]
...
...
@@ -73,3 +75,11 @@ class Topk(object):
result
[
"label_names"
]
=
label_name_list
y
.
append
(
result
)
return
y
class
MultiLabelTopk
(
Topk
):
def
__init__
(
self
,
topk
=
1
,
class_id_map_file
=
None
):
super
().
__init__
()
def
__call__
(
self
,
x
,
file_names
=
None
):
return
super
().
__call__
(
x
,
file_names
,
multilabel
=
True
)
ppcls/engine/engine.py
浏览文件 @
af9aae73
...
...
@@ -355,7 +355,8 @@ class Engine(object):
def
export
(
self
):
assert
self
.
mode
==
"export"
model
=
ExportModel
(
self
.
config
[
"Arch"
],
self
.
model
)
use_multilabel
=
self
.
config
[
"Global"
].
get
(
"use_multilabel"
,
False
)
model
=
ExportModel
(
self
.
config
[
"Arch"
],
self
.
model
,
use_multilabel
)
if
self
.
config
[
"Global"
][
"pretrained_model"
]
is
not
None
:
load_dygraph_pretrain
(
model
.
base_model
,
self
.
config
[
"Global"
][
"pretrained_model"
])
...
...
@@ -388,10 +389,9 @@ class ExportModel(nn.Layer):
ExportModel: add softmax onto the model
"""
def
__init__
(
self
,
config
,
model
):
def
__init__
(
self
,
config
,
model
,
use_multilabel
):
super
().
__init__
()
self
.
base_model
=
model
# we should choose a final model to export
if
isinstance
(
self
.
base_model
,
DistillationModel
):
self
.
infer_model_name
=
config
[
"infer_model_name"
]
...
...
@@ -402,10 +402,13 @@ class ExportModel(nn.Layer):
if
self
.
infer_output_key
==
"features"
and
isinstance
(
self
.
base_model
,
RecModel
):
self
.
base_model
.
head
=
IdentityHead
()
if
config
.
get
(
"infer_add_softmax"
,
True
)
:
self
.
softmax
=
nn
.
Softmax
(
axis
=-
1
)
if
use_multilabel
:
self
.
out_act
=
nn
.
Sigmoid
(
)
else
:
self
.
softmax
=
None
if
config
.
get
(
"infer_add_softmax"
,
True
):
self
.
out_act
=
nn
.
Softmax
(
axis
=-
1
)
else
:
self
.
out_act
=
None
def
eval
(
self
):
self
.
training
=
False
...
...
@@ -421,6 +424,6 @@ class ExportModel(nn.Layer):
x
=
x
[
self
.
infer_model_name
]
if
self
.
infer_output_key
is
not
None
:
x
=
x
[
self
.
infer_output_key
]
if
self
.
softmax
is
not
None
:
x
=
self
.
softmax
(
x
)
if
self
.
out_act
is
not
None
:
x
=
self
.
out_act
(
x
)
return
x
ppcls/engine/evaluation/classification.py
浏览文件 @
af9aae73
...
...
@@ -52,7 +52,8 @@ def classification_eval(evaler, epoch_id=0):
time_info
[
"reader_cost"
].
update
(
time
.
time
()
-
tic
)
batch_size
=
batch
[
0
].
shape
[
0
]
batch
[
0
]
=
paddle
.
to_tensor
(
batch
[
0
]).
astype
(
"float32"
)
batch
[
1
]
=
batch
[
1
].
reshape
([
-
1
,
1
]).
astype
(
"int64"
)
if
not
evaler
.
config
[
"Global"
].
get
(
"use_multilabel"
,
False
):
batch
[
1
]
=
batch
[
1
].
reshape
([
-
1
,
1
]).
astype
(
"int64"
)
# image input
out
=
evaler
.
model
(
batch
[
0
])
# calc loss
...
...
ppcls/engine/train/train.py
浏览文件 @
af9aae73
...
...
@@ -36,8 +36,8 @@ def train_epoch(trainer, epoch_id, print_batch_step):
paddle
.
to_tensor
(
batch
[
0
][
'label'
])
]
batch_size
=
batch
[
0
].
shape
[
0
]
batch
[
1
]
=
batch
[
1
].
reshape
([
-
1
,
1
]).
astype
(
"int64"
)
if
not
trainer
.
config
[
"Global"
].
get
(
"use_multilabel"
,
False
):
batch
[
1
]
=
batch
[
1
].
reshape
([
-
1
,
1
]).
astype
(
"int64"
)
trainer
.
global_step
+=
1
# image input
if
trainer
.
amp
:
...
...
ppcls/loss/__init__.py
浏览文件 @
af9aae73
...
...
@@ -20,6 +20,7 @@ from .distanceloss import DistanceLoss
from
.distillationloss
import
DistillationCELoss
from
.distillationloss
import
DistillationGTCELoss
from
.distillationloss
import
DistillationDMLLoss
from
.multilabelloss
import
MultiLabelLoss
class
CombinedLoss
(
nn
.
Layer
):
...
...
ppcls/loss/multilabelloss.py
0 → 100644
浏览文件 @
af9aae73
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
class
MultiLabelLoss
(
nn
.
Layer
):
"""
Multi-label loss
"""
def
__init__
(
self
,
epsilon
=
None
):
super
().
__init__
()
if
epsilon
is
not
None
and
(
epsilon
<=
0
or
epsilon
>=
1
):
epsilon
=
None
self
.
epsilon
=
epsilon
def
_labelsmoothing
(
self
,
target
,
class_num
):
if
target
.
ndim
==
1
or
target
.
shape
[
-
1
]
!=
class_num
:
one_hot_target
=
F
.
one_hot
(
target
,
class_num
)
else
:
one_hot_target
=
target
soft_target
=
F
.
label_smooth
(
one_hot_target
,
epsilon
=
self
.
epsilon
)
soft_target
=
paddle
.
reshape
(
soft_target
,
shape
=
[
-
1
,
class_num
])
return
soft_target
def
_binary_crossentropy
(
self
,
input
,
target
,
class_num
):
if
self
.
epsilon
is
not
None
:
target
=
self
.
_labelsmoothing
(
target
,
class_num
)
cost
=
F
.
binary_cross_entropy_with_logits
(
logit
=
input
,
label
=
target
)
else
:
cost
=
F
.
binary_cross_entropy_with_logits
(
logit
=
input
,
label
=
target
)
return
cost
def
forward
(
self
,
x
,
target
):
if
isinstance
(
x
,
dict
):
x
=
x
[
"logits"
]
class_num
=
x
.
shape
[
-
1
]
loss
=
self
.
_binary_crossentropy
(
x
,
target
,
class_num
)
loss
=
loss
.
mean
()
return
{
"MultiLabelLoss"
:
loss
}
ppcls/metric/__init__.py
浏览文件 @
af9aae73
...
...
@@ -19,6 +19,8 @@ from collections import OrderedDict
from
.metrics
import
TopkAcc
,
mAP
,
mINP
,
Recallk
,
Precisionk
from
.metrics
import
DistillationTopkAcc
from
.metrics
import
GoogLeNetTopkAcc
from
.metrics
import
HammingDistance
,
AccuracyScore
class
CombinedMetrics
(
nn
.
Layer
):
def
__init__
(
self
,
config_list
):
...
...
@@ -32,7 +34,8 @@ class CombinedMetrics(nn.Layer):
metric_name
=
list
(
config
)[
0
]
metric_params
=
config
[
metric_name
]
if
metric_params
is
not
None
:
self
.
metric_func_list
.
append
(
eval
(
metric_name
)(
**
metric_params
))
self
.
metric_func_list
.
append
(
eval
(
metric_name
)(
**
metric_params
))
else
:
self
.
metric_func_list
.
append
(
eval
(
metric_name
)())
...
...
@@ -42,6 +45,7 @@ class CombinedMetrics(nn.Layer):
metric_dict
.
update
(
metric_func
(
*
args
,
**
kwargs
))
return
metric_dict
def
build_metrics
(
config
):
metrics_list
=
CombinedMetrics
(
copy
.
deepcopy
(
config
))
return
metrics_list
ppcls/metric/metrics.py
浏览文件 @
af9aae73
...
...
@@ -15,6 +15,12 @@
import
numpy
as
np
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
sklearn.metrics
import
hamming_loss
from
sklearn.metrics
import
accuracy_score
as
accuracy_metric
from
sklearn.metrics
import
multilabel_confusion_matrix
from
sklearn.preprocessing
import
binarize
class
TopkAcc
(
nn
.
Layer
):
...
...
@@ -198,7 +204,7 @@ class Precisionk(nn.Layer):
equal_flag
=
paddle
.
logical_and
(
equal_flag
,
keep_mask
.
astype
(
'bool'
))
equal_flag
=
paddle
.
cast
(
equal_flag
,
'float32'
)
Ns
=
paddle
.
arange
(
gallery_img_id
.
shape
[
0
])
+
1
equal_flag_cumsum
=
paddle
.
cumsum
(
equal_flag
,
axis
=
1
)
Precision_at_k
=
(
paddle
.
mean
(
equal_flag_cumsum
,
axis
=
0
)
/
Ns
).
numpy
()
...
...
@@ -232,3 +238,71 @@ class GoogLeNetTopkAcc(TopkAcc):
def
forward
(
self
,
x
,
label
):
return
super
().
forward
(
x
[
0
],
label
)
class
MutiLabelMetric
(
object
):
def
__init__
(
self
):
pass
def
_multi_hot_encode
(
self
,
logits
,
threshold
=
0.5
):
return
binarize
(
logits
,
threshold
=
threshold
)
def
__call__
(
self
,
output
):
output
=
F
.
sigmoid
(
output
)
preds
=
self
.
_multi_hot_encode
(
logits
=
output
.
numpy
(),
threshold
=
0.5
)
return
preds
class
HammingDistance
(
MutiLabelMetric
):
"""
Soft metric based label for multilabel classification
Returns:
The smaller the return value is, the better model is.
"""
def
__init__
(
self
):
super
().
__init__
()
def
__call__
(
self
,
output
,
target
):
preds
=
super
().
__call__
(
output
)
metric_dict
=
dict
()
metric_dict
[
"HammingDistance"
]
=
paddle
.
to_tensor
(
hamming_loss
(
target
,
preds
))
return
metric_dict
class
AccuracyScore
(
MutiLabelMetric
):
"""
Hard metric for multilabel classification
Args:
base: ["sample", "label"], default="sample"
if "sample", return metric score based sample,
if "label", return metric score based label.
Returns:
accuracy:
"""
def
__init__
(
self
,
base
=
"label"
):
super
().
__init__
()
assert
base
in
[
"sample"
,
"label"
],
'must be one of ["sample", "label"]'
self
.
base
=
base
def
__call__
(
self
,
output
,
target
):
preds
=
super
().
__call__
(
output
)
metric_dict
=
dict
()
if
self
.
base
==
"sample"
:
accuracy
=
accuracy_metric
(
target
,
preds
)
elif
self
.
base
==
"label"
:
mcm
=
multilabel_confusion_matrix
(
target
,
preds
)
tns
=
mcm
[:,
0
,
0
]
fns
=
mcm
[:,
1
,
0
]
tps
=
mcm
[:,
1
,
1
]
fps
=
mcm
[:,
0
,
1
]
accuracy
=
(
sum
(
tps
)
+
sum
(
tns
))
/
(
sum
(
tps
)
+
sum
(
tns
)
+
sum
(
fns
)
+
sum
(
fps
))
precision
=
sum
(
tps
)
/
(
sum
(
tps
)
+
sum
(
fps
))
recall
=
sum
(
tps
)
/
(
sum
(
tps
)
+
sum
(
fns
))
F1
=
2
*
(
accuracy
*
recall
)
/
(
accuracy
+
recall
)
metric_dict
[
"AccuracyScore"
]
=
paddle
.
to_tensor
(
accuracy
)
return
metric_dict
tools/train.sh
浏览文件 @
af9aae73
...
...
@@ -4,4 +4,4 @@
# python3.7 tools/train.py -c ./ppcls/configs/ImageNet/ResNet/ResNet50.yaml
# for multi-cards train
python3.7
-m
paddle.distributed.launch
--gpus
=
"0,1,2,3"
tools/train.py
-c
./ppcls/configs/ImageNet/ResNet/ResNet50.yaml
\ No newline at end of file
python3.7
-m
paddle.distributed.launch
--gpus
=
"0,1,2,3"
tools/train.py
-c
./ppcls/configs/ImageNet/ResNet/ResNet50.yaml
train.sh
0 → 100755
浏览文件 @
af9aae73
#!/usr/bin/env bash
# for single card train
# python3.7 tools/train.py -c ./ppcls/configs/ImageNet/ResNet/ResNet50.yaml
# for multi-cards train
python3.7
-m
paddle.distributed.launch
--gpus
=
"0"
tools/train.py
-c
./MobileNetV2.yaml
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录