Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
5fd7085d
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 2 年 前同步成功
通知
118
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5fd7085d
编写于
3月 30, 2021
作者:
Y
yaohai
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add multilabel feature
上级
8a469799
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
383 addition
and
40 deletion
+383
-40
configs/quick_start/ResNet50_vd_multilabel.yam
configs/quick_start/ResNet50_vd_multilabel.yam
+79
-0
ppcls/data/reader.py
ppcls/data/reader.py
+39
-1
ppcls/modeling/loss.py
ppcls/modeling/loss.py
+26
-1
ppcls/utils/__init__.py
ppcls/utils/__init__.py
+6
-0
ppcls/utils/metrics.py
ppcls/utils/metrics.py
+107
-0
requirements.txt
requirements.txt
+1
-0
tools/eval.py
tools/eval.py
+39
-5
tools/infer/infer.py
tools/infer/infer.py
+6
-2
tools/infer/utils.py
tools/infer/utils.py
+6
-2
tools/program.py
tools/program.py
+74
-29
未找到文件。
configs/quick_start/ResNet50_vd_multilabel.yam
0 → 100644
浏览文件 @
5fd7085d
mode: 'train'
ARCHITECTURE:
name: 'ResNet50_vd'
pretrained_model: "./pretrained/ResNet50_vd_pretrained"
model_save_dir: "./output/"
classes_num: 33
total_images: 17463
save_interval: 1
validate: True
valid_interval: 1
epochs: 10
topk: 1
image_shape: [3, 224, 224]
multilabel: True
use_mix: False
ls_epsilon: 0.1
LEARNING_RATE:
function: 'Cosine'
params:
lr: 0.07
OPTIMIZER:
function: 'Momentum'
params:
momentum: 0.9
regularizer:
function: 'L2'
factor: 0.000070
TRAIN:
batch_size: 256
num_workers: 4
file_list: "./dataset/NUS-SCENE-dataset/multilabel_train_list.txt"
data_dir: "./dataset/NUS-SCENE-dataset/images"
shuffle_seed: 0
transforms:
- DecodeImage:
to_rgb: True
to_np: False
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
mix:
- MixupOperator:
alpha: 0.2
VALID:
batch_size: 64
num_workers: 4
file_list: "./dataset/NUS-SCENE-dataset/multilabel_test_list.txt"
data_dir: "./dataset/NUS-SCENE-dataset/images"
shuffle_seed: 0
transforms:
- DecodeImage:
to_rgb: True
to_np: False
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
\ No newline at end of file
ppcls/data/reader.py
浏览文件 @
5fd7085d
...
@@ -197,6 +197,40 @@ class CommonDataset(Dataset):
...
@@ -197,6 +197,40 @@ class CommonDataset(Dataset):
def
__len__
(
self
):
def
__len__
(
self
):
return
self
.
num_samples
return
self
.
num_samples
class MultiLabelDataset(Dataset):
    """
    Define dataset class for multilabel image classification.

    Every entry returned by ``get_file_list(params)`` is split on
    ``delimiter`` (tab by default) into an image path, taken relative to
    ``params["data_dir"]``, and a comma-separated string of integer labels.
    """

    def __init__(self, params):
        """
        Args:
            params: reader configuration; the keys read here are "mode",
                "delimiter", "transforms" and (later) "data_dir".
        """
        self.params = params
        self.mode = params.get("mode", "train")
        self.full_lines = get_file_list(params)
        self.delimiter = params.get("delimiter", "\t")
        self.ops = create_operators(params["transforms"])
        self.num_samples = len(self.full_lines)

    def __getitem__(self, idx):
        """
        Return ``(transformed_image, labels)`` for sample ``idx``.

        ``labels`` is a float32 array built from the comma-separated integers
        of the entry (presumably a multi-hot vector of length classes_num --
        confirm against the list files).  If anything fails while reading or
        transforming the sample, the error is logged and a randomly chosen
        sample is returned instead.
        """
        # Bound up front so the error message below can always be formatted,
        # even when the lookup of full_lines[idx] is the statement that fails.
        line = None
        try:
            line = self.full_lines[idx]
            img_path, label_str = line.split(self.delimiter)
            img_path = os.path.join(self.params["data_dir"], img_path)
            with open(img_path, "rb") as f:
                img = f.read()

            labels = [int(i) for i in label_str.split(',')]

            return (transform(img, self.ops),
                    np.array(labels).astype("float32"))
        except Exception as e:
            logger.error("data read failed: {}, exception info: {}".format(
                line, e))
            # random.randint(0, len(self)) includes len(self) itself, which is
            # one past the last valid index; randrange excludes the upper end.
            return self.__getitem__(random.randrange(len(self)))

    def __len__(self):
        """Number of entries in the file list."""
        return self.num_samples
class
Reader
:
class
Reader
:
...
@@ -229,6 +263,7 @@ class Reader:
...
@@ -229,6 +263,7 @@ class Reader:
self
.
collate_fn
=
self
.
mix_collate_fn
self
.
collate_fn
=
self
.
mix_collate_fn
self
.
places
=
places
self
.
places
=
places
self
.
multilabel
=
config
.
get
(
"multilabel"
,
False
)
def
mix_collate_fn
(
self
,
batch
):
def
mix_collate_fn
(
self
,
batch
):
batch
=
transform
(
batch
,
self
.
batch_ops
)
batch
=
transform
(
batch
,
self
.
batch_ops
)
...
@@ -246,7 +281,10 @@ class Reader:
...
@@ -246,7 +281,10 @@ class Reader:
def
__call__
(
self
):
def
__call__
(
self
):
batch_size
=
int
(
self
.
params
[
'batch_size'
])
//
trainers_num
batch_size
=
int
(
self
.
params
[
'batch_size'
])
//
trainers_num
dataset
=
CommonDataset
(
self
.
params
)
if
self
.
multilabel
:
dataset
=
MultiLabelDataset
(
self
.
params
)
else
:
dataset
=
CommonDataset
(
self
.
params
)
is_train
=
self
.
params
[
'mode'
]
==
"train"
is_train
=
self
.
params
[
'mode'
]
==
"train"
batch_sampler
=
DistributedBatchSampler
(
batch_sampler
=
DistributedBatchSampler
(
...
...
ppcls/modeling/loss.py
浏览文件 @
5fd7085d
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
import
paddle
import
paddle
import
paddle.nn.functional
as
F
import
paddle.nn.functional
as
F
__all__
=
[
'CELoss'
,
'MixCELoss'
,
'GoogLeNetLoss'
,
'JSDivLoss'
]
__all__
=
[
'CELoss'
,
'MixCELoss'
,
'GoogLeNetLoss'
,
'JSDivLoss'
,
'MultiLabelLoss'
]
class
Loss
(
object
):
class
Loss
(
object
):
...
@@ -41,6 +41,17 @@ class Loss(object):
...
@@ -41,6 +41,17 @@ class Loss(object):
soft_target
=
F
.
label_smooth
(
one_hot_target
,
epsilon
=
self
.
_epsilon
)
soft_target
=
F
.
label_smooth
(
one_hot_target
,
epsilon
=
self
.
_epsilon
)
soft_target
=
paddle
.
reshape
(
soft_target
,
shape
=
[
-
1
,
self
.
_class_dim
])
soft_target
=
paddle
.
reshape
(
soft_target
,
shape
=
[
-
1
,
self
.
_class_dim
])
return
soft_target
return
soft_target
def _binary_crossentropy(self, input, target):
    """
    Mean of the binary cross entropy (with logits) of ``input`` vs ``target``.

    When ``self._label_smoothing`` is truthy, ``target`` is first passed
    through ``self._labelsmoothing``; otherwise it is used unchanged.
    """
    # Only the label differs between the smoothed and the plain case.
    label = self._labelsmoothing(target) if self._label_smoothing else target
    bce = F.binary_cross_entropy_with_logits(logit=input, label=label)
    return paddle.mean(bce)
def
_crossentropy
(
self
,
input
,
target
):
def
_crossentropy
(
self
,
input
,
target
):
if
self
.
_label_smoothing
:
if
self
.
_label_smoothing
:
...
@@ -68,6 +79,20 @@ class Loss(object):
...
@@ -68,6 +79,20 @@ class Loss(object):
def
__call__
(
self
,
input
,
target
):
def
__call__
(
self
,
input
,
target
):
pass
pass
class MultiLabelLoss(Loss):
    """
    Multilabel loss based on binary cross entropy.
    """

    def __init__(self, class_dim=1000, epsilon=None):
        """
        Args:
            class_dim: number of classes, forwarded to ``Loss.__init__``.
            epsilon: label smoothing factor, forwarded to ``Loss.__init__``.
        """
        super(MultiLabelLoss, self).__init__(class_dim, epsilon)

    def __call__(self, input, target, use_pure_fp16=False):
        """
        Args:
            input: network output (logits).
            target: ground-truth labels.
            use_pure_fp16: kept in the signature for callers that pass it;
                it is not used by this loss.
        Returns:
            The value returned by ``self._binary_crossentropy``.
        """
        # Loss._binary_crossentropy is declared as (self, input, target);
        # passing use_pure_fp16 as a third positional argument raises
        # TypeError, so it is deliberately not forwarded.
        cost = self._binary_crossentropy(input, target)
        return cost
class
CELoss
(
Loss
):
class
CELoss
(
Loss
):
...
...
ppcls/utils/__init__.py
浏览文件 @
5fd7085d
...
@@ -15,7 +15,13 @@
...
@@ -15,7 +15,13 @@
from
.
import
logger
from
.
import
logger
from
.
import
misc
from
.
import
misc
from
.
import
model_zoo
from
.
import
model_zoo
from
.
import
metrics
from
.save_load
import
init_model
,
save_model
from
.save_load
import
init_model
,
save_model
from
.config
import
get_config
from
.config
import
get_config
from
.misc
import
AverageMeter
from
.misc
import
AverageMeter
from
.metrics
import
multi_hot_encode
from
.metrics
import
hamming_distance
from
.metrics
import
accuracy_score
from
.metrics
import
precision_recall_fscore
from
.metrics
import
mean_average_precision
ppcls/utils/metrics.py
0 → 100644
浏览文件 @
5fd7085d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
sklearn.metrics
import
hamming_loss
from
sklearn.metrics
import
accuracy_score
as
accuracy_metric
from
sklearn.metrics
import
multilabel_confusion_matrix
from
sklearn.metrics
import
precision_recall_fscore_support
from
sklearn.metrics
import
average_precision_score
from
sklearn.preprocessing
import
binarize
import
numpy
as
np
__all__
=
[
"multi_hot_encode"
,
"hamming_distance"
,
"accuracy_score"
,
"precision_recall_fscore"
,
"mean_average_precision"
]
def multi_hot_encode(logits, threshold=0.5):
    """
    Encode scores to a multi-hot array, element by element, for multilabel.

    Args:
        logits: array-like of scores accepted by ``sklearn.preprocessing.binarize``.
        threshold: cut-off forwarded to ``binarize``; default 0.5.
    Returns:
        The result of ``binarize(logits, threshold=threshold)``.
    """
    # threshold must be passed by keyword: positional use is deprecated in
    # scikit-learn 0.23 and rejected by later releases.
    return binarize(logits, threshold=threshold)
def hamming_distance(output, target):
    """
    Soft, label-based metric for multilabel classification.

    Args:
        output: predicted labels.
        target: ground-truth labels.
    Returns:
        ``hamming_loss(target, output)``; the smaller the value, the better
        the model.
    """
    # sklearn expects the ground truth first, the prediction second.
    distance = hamming_loss(target, output)
    return distance
def accuracy_score(output, target, base="sample"):
    """
    Hard metric for multilabel classification.

    Args:
        output: predicted labels.
        target: ground-truth labels.
        base: one of ["sample", "label"], default="sample".
            "sample": score from sklearn's ``accuracy_score``.
            "label": (TP + TN) / (TP + TN + FN + FP), with the counts summed
            over all labels of the multilabel confusion matrix.
    Returns:
        accuracy: the score selected by ``base``.
    """
    assert base in ["sample", "label"], 'must be one of ["sample", "label"]'

    if base == "sample":
        accuracy = accuracy_metric(target, output)
    elif base == "label":
        confusion = multilabel_confusion_matrix(target, output)
        # One 2x2 matrix per label, indexed [label, true, predicted].
        true_neg = sum(confusion[:, 0, 0])
        false_neg = sum(confusion[:, 1, 0])
        true_pos = sum(confusion[:, 1, 1])
        false_pos = sum(confusion[:, 0, 1])

        correct = true_pos + true_neg
        accuracy = correct / (correct + false_neg + false_pos)

    return accuracy
def precision_recall_fscore(output, target):
    """
    Label-based metric for multilabel classification.

    Args:
        output: predicted labels.
        target: ground-truth labels.
    Returns:
        precisions, recalls, fscores: the first three values returned by
        ``precision_recall_fscore_support(target, output)``.
    """
    # The fourth value of the sklearn result is not returned.
    prec, rec, fscore, _unused = precision_recall_fscore_support(target,
                                                                 output)
    return prec, rec, fscore
def mean_average_precision(logits, target):
    """
    Calculate the mean of the per-label average precision.

    Args:
        logits: np.ndarray of scores from the network, indexed as [:, label].
        target: np.ndarray of ground truth, 0 or 1, indexed as [:, label].
    Returns:
        ``np.mean`` of ``average_precision_score`` over the columns of
        ``target``.
    Raises:
        TypeError: if either argument is not an np.ndarray.
    """
    both_arrays = (isinstance(logits, np.ndarray) and
                   isinstance(target, np.ndarray))
    if not both_arrays:
        raise TypeError("logits and target should be np.ndarray.")

    # One average-precision value per label column.
    per_label_ap = [
        average_precision_score(target[:, col], logits[:, col])
        for col in range(target.shape[1])
    ]
    return np.mean(per_label_ap)
requirements.txt
浏览文件 @
5fd7085d
...
@@ -5,3 +5,4 @@ tqdm
...
@@ -5,3 +5,4 @@ tqdm
PyYAML
PyYAML
visualdl
>= 2.0.0b
visualdl
>= 2.0.0b
scipy
scipy
scikit-learn
==0.23.2
tools/eval.py
浏览文件 @
5fd7085d
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
import
paddle
import
paddle
import
paddle.nn.functional
as
F
import
argparse
import
argparse
import
os
import
os
...
@@ -24,9 +25,15 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
...
@@ -24,9 +25,15 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
from
ppcls.utils
import
logger
from
ppcls.utils
import
logger
from
ppcls.utils.save_load
import
init_model
from
ppcls.utils.save_load
import
init_model
from
ppcls.utils.config
import
get_config
from
ppcls.utils.config
import
get_config
from
ppcls.utils
import
multi_hot_encode
from
ppcls.utils
import
accuracy_score
from
ppcls.utils
import
mean_average_precision
from
ppcls.utils
import
precision_recall_fscore
from
ppcls.data
import
Reader
from
ppcls.data
import
Reader
import
program
import
program
import
numpy
as
np
def
parse_args
():
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
"PaddleClas eval script"
)
parser
=
argparse
.
ArgumentParser
(
"PaddleClas eval script"
)
...
@@ -52,6 +59,7 @@ def main(args, return_dict={}):
...
@@ -52,6 +59,7 @@ def main(args, return_dict={}):
# assign place
# assign place
use_gpu
=
config
.
get
(
"use_gpu"
,
True
)
use_gpu
=
config
.
get
(
"use_gpu"
,
True
)
place
=
paddle
.
set_device
(
'gpu'
if
use_gpu
else
'cpu'
)
place
=
paddle
.
set_device
(
'gpu'
if
use_gpu
else
'cpu'
)
multilabel
=
config
.
get
(
"multilabel"
,
False
)
trainer_num
=
paddle
.
distributed
.
get_world_size
()
trainer_num
=
paddle
.
distributed
.
get_world_size
()
use_data_parallel
=
trainer_num
!=
1
use_data_parallel
=
trainer_num
!=
1
...
@@ -68,12 +76,38 @@ def main(args, return_dict={}):
...
@@ -68,12 +76,38 @@ def main(args, return_dict={}):
valid_dataloader
=
Reader
(
config
,
'valid'
,
places
=
place
)()
valid_dataloader
=
Reader
(
config
,
'valid'
,
places
=
place
)()
net
.
eval
()
net
.
eval
()
with
paddle
.
no_grad
():
with
paddle
.
no_grad
():
top1_acc
=
program
.
run
(
valid_dataloader
,
config
,
net
,
None
,
None
,
0
,
if
not
multilabel
:
'valid'
)
top1_acc
=
program
.
run
(
valid_dataloader
,
config
,
net
,
None
,
None
,
0
,
return_dict
[
"top1_acc"
]
=
top1_acc
'valid'
)
return
top1_acc
return_dict
[
"top1_acc"
]
=
top1_acc
return
top1_acc
else
:
all_outs
=
[]
targets
=
[]
for
idx
,
batch
in
enumerate
(
valid_dataloader
()):
feeds
=
program
.
create_feeds
(
batch
,
False
,
config
.
classes_num
,
multilabel
)
out
=
net
(
feeds
[
"image"
])
out
=
F
.
sigmoid
(
out
)
use_distillation
=
config
.
get
(
"use_distillation"
,
False
)
if
use_distillation
:
out
=
out
[
1
]
all_outs
.
extend
(
list
(
out
.
numpy
()))
targets
.
extend
(
list
(
feeds
[
"label"
].
numpy
()))
all_outs
=
np
.
array
(
all_outs
)
targets
=
np
.
array
(
targets
)
mAP
=
mean_average_precision
(
all_outs
,
targets
)
return_dict
[
"mean average precision"
]
=
mAP
return
mAP
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
args
=
parse_args
()
args
=
parse_args
()
main
(
args
)
return_dict
=
{}
main
(
args
,
return_dict
)
print
(
return_dict
)
tools/infer/infer.py
浏览文件 @
5fd7085d
...
@@ -34,6 +34,7 @@ def main():
...
@@ -34,6 +34,7 @@ def main():
args
=
parse_args
()
args
=
parse_args
()
# assign the place
# assign the place
place
=
paddle
.
set_device
(
'gpu'
if
args
.
use_gpu
else
'cpu'
)
place
=
paddle
.
set_device
(
'gpu'
if
args
.
use_gpu
else
'cpu'
)
multilabel
=
True
if
args
.
multilabel
else
False
net
=
architectures
.
__dict__
[
args
.
model
](
class_dim
=
args
.
class_num
)
net
=
architectures
.
__dict__
[
args
.
model
](
class_dim
=
args
.
class_num
)
load_dygraph_pretrain
(
net
,
args
.
pretrained_model
,
args
.
load_static_weights
)
load_dygraph_pretrain
(
net
,
args
.
pretrained_model
,
args
.
load_static_weights
)
...
@@ -61,9 +62,12 @@ def main():
...
@@ -61,9 +62,12 @@ def main():
batch_outputs
=
net
(
batch_tensor
)
batch_outputs
=
net
(
batch_tensor
)
if
args
.
model
==
"GoogLeNet"
:
if
args
.
model
==
"GoogLeNet"
:
batch_outputs
=
batch_outputs
[
0
]
batch_outputs
=
batch_outputs
[
0
]
batch_outputs
=
F
.
softmax
(
batch_outputs
)
if
multilabel
:
batch_outputs
=
F
.
sigmoid
(
batch_outputs
)
else
:
batch_outputs
=
F
.
softmax
(
batch_outputs
)
batch_outputs
=
batch_outputs
.
numpy
()
batch_outputs
=
batch_outputs
.
numpy
()
batch_result_list
=
postprocess
(
batch_outputs
,
args
.
top_k
)
batch_result_list
=
postprocess
(
batch_outputs
,
args
.
top_k
,
multilabel
=
multilabel
)
for
number
,
result_dict
in
enumerate
(
batch_result_list
):
for
number
,
result_dict
in
enumerate
(
batch_result_list
):
filename
=
img_path_list
[
number
].
split
(
"/"
)[
-
1
]
filename
=
img_path_list
[
number
].
split
(
"/"
)[
-
1
]
...
...
tools/infer/utils.py
浏览文件 @
5fd7085d
...
@@ -31,6 +31,7 @@ def parse_args():
...
@@ -31,6 +31,7 @@ def parse_args():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"-i"
,
"--image_file"
,
type
=
str
)
parser
.
add_argument
(
"-i"
,
"--image_file"
,
type
=
str
)
parser
.
add_argument
(
"--use_gpu"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--use_gpu"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--multilabel"
,
type
=
str2bool
,
default
=
False
)
# params for preprocess
# params for preprocess
parser
.
add_argument
(
"--resize_short"
,
type
=
int
,
default
=
256
)
parser
.
add_argument
(
"--resize_short"
,
type
=
int
,
default
=
256
)
...
@@ -124,11 +125,14 @@ def preprocess(img, args):
...
@@ -124,11 +125,14 @@ def preprocess(img, args):
return
img
return
img
def
postprocess
(
batch_outputs
,
topk
=
5
):
def
postprocess
(
batch_outputs
,
topk
=
5
,
multilabel
=
False
):
batch_results
=
[]
batch_results
=
[]
for
probs
in
batch_outputs
:
for
probs
in
batch_outputs
:
results
=
[]
results
=
[]
index
=
probs
.
argsort
(
axis
=
0
)[
-
topk
:][::
-
1
].
astype
(
"int32"
)
if
multilabel
:
index
=
np
.
where
(
probs
>=
0.5
)[
0
].
astype
(
'int32'
)
else
:
index
=
probs
.
argsort
(
axis
=
0
)[
-
topk
:][::
-
1
].
astype
(
"int32"
)
clas_id_list
=
[]
clas_id_list
=
[]
score_list
=
[]
score_list
=
[]
for
i
in
index
:
for
i
in
index
:
...
...
tools/program.py
浏览文件 @
5fd7085d
...
@@ -29,12 +29,16 @@ import paddle.nn.functional as F
...
@@ -29,12 +29,16 @@ import paddle.nn.functional as F
from
ppcls.optimizer
import
LearningRateBuilder
from
ppcls.optimizer
import
LearningRateBuilder
from
ppcls.optimizer
import
OptimizerBuilder
from
ppcls.optimizer
import
OptimizerBuilder
from
ppcls.modeling
import
architectures
from
ppcls.modeling
import
architectures
from
ppcls.modeling.loss
import
MultiLabelLoss
from
ppcls.modeling.loss
import
CELoss
from
ppcls.modeling.loss
import
CELoss
from
ppcls.modeling.loss
import
MixCELoss
from
ppcls.modeling.loss
import
MixCELoss
from
ppcls.modeling.loss
import
JSDivLoss
from
ppcls.modeling.loss
import
JSDivLoss
from
ppcls.modeling.loss
import
GoogLeNetLoss
from
ppcls.modeling.loss
import
GoogLeNetLoss
from
ppcls.utils.misc
import
AverageMeter
from
ppcls.utils.misc
import
AverageMeter
from
ppcls.utils
import
logger
from
ppcls.utils
import
logger
from
ppcls.utils
import
multi_hot_encode
from
ppcls.utils
import
hamming_distance
from
ppcls.utils
import
accuracy_score
def
create_model
(
architecture
,
classes_num
):
def
create_model
(
architecture
,
classes_num
):
...
@@ -61,7 +65,8 @@ def create_loss(feeds,
...
@@ -61,7 +65,8 @@ def create_loss(feeds,
classes_num
=
1000
,
classes_num
=
1000
,
epsilon
=
None
,
epsilon
=
None
,
use_mix
=
False
,
use_mix
=
False
,
use_distillation
=
False
):
use_distillation
=
False
,
multilabel
=
False
):
"""
"""
Create a loss for optimization, such as:
Create a loss for optimization, such as:
1. CrossEnotry loss
1. CrossEnotry loss
...
@@ -100,7 +105,10 @@ def create_loss(feeds,
...
@@ -100,7 +105,10 @@ def create_loss(feeds,
feed_lam
=
feeds
[
'lam'
]
feed_lam
=
feeds
[
'lam'
]
return
loss
(
out
,
feed_y_a
,
feed_y_b
,
feed_lam
)
return
loss
(
out
,
feed_y_a
,
feed_y_b
,
feed_lam
)
else
:
else
:
loss
=
CELoss
(
class_dim
=
classes_num
,
epsilon
=
epsilon
)
if
not
multilabel
:
loss
=
CELoss
(
class_dim
=
classes_num
,
epsilon
=
epsilon
)
else
:
loss
=
MultiLabelLoss
(
class_dim
=
classes_num
,
epsilon
=
epsilon
)
return
loss
(
out
,
feeds
[
"label"
])
return
loss
(
out
,
feeds
[
"label"
])
...
@@ -110,6 +118,7 @@ def create_metric(out,
...
@@ -110,6 +118,7 @@ def create_metric(out,
topk
=
5
,
topk
=
5
,
classes_num
=
1000
,
classes_num
=
1000
,
use_distillation
=
False
,
use_distillation
=
False
,
multilabel
=
False
,
mode
=
"train"
):
mode
=
"train"
):
"""
"""
Create measures of model accuracy, such as top1 and top5
Create measures of model accuracy, such as top1 and top5
...
@@ -135,24 +144,43 @@ def create_metric(out,
...
@@ -135,24 +144,43 @@ def create_metric(out,
softmax_out
=
F
.
softmax
(
out
)
softmax_out
=
F
.
softmax
(
out
)
fetchs
=
OrderedDict
()
fetchs
=
OrderedDict
()
# set top1 to fetchs
metric_names
=
set
()
top1
=
paddle
.
metric
.
accuracy
(
softmax_out
,
label
=
label
,
k
=
1
)
if
not
multilabel
:
# set topk to fetchs
softmax_out
=
F
.
softmax
(
out
)
k
=
min
(
topk
,
classes_num
)
topk
=
paddle
.
metric
.
accuracy
(
softmax_out
,
label
=
label
,
k
=
k
)
# set top1 to fetchs
top1
=
paddle
.
metric
.
accuracy
(
softmax_out
,
label
=
label
,
k
=
1
)
# set topk to fetchs
k
=
min
(
topk
,
classes_num
)
topk
=
paddle
.
metric
.
accuracy
(
softmax_out
,
label
=
label
,
k
=
k
)
metric_names
.
add
(
"top1"
)
metric_names
.
add
(
"top{}"
.
format
(
k
))
fetchs
[
'top1'
]
=
top1
topk_name
=
"top{}"
.
format
(
k
)
fetchs
[
topk_name
]
=
topk
else
:
out
=
F
.
sigmoid
(
out
)
preds
=
multi_hot_encode
(
out
.
numpy
())
targets
=
label
.
numpy
()
ham_dist
=
to_tensor
(
hamming_distance
(
preds
,
targets
))
accuracy
=
to_tensor
(
accuracy_score
(
preds
,
targets
,
base
=
"label"
))
ham_dist_name
=
"hamming_distance"
accuracy_name
=
"multilabel_accuracy"
metric_names
.
add
(
ham_dist_name
)
metric_names
.
add
(
accuracy_name
)
fetchs
[
accuracy_name
]
=
accuracy
fetchs
[
ham_dist_name
]
=
ham_dist
# multi cards' eval
# multi cards' eval
if
mode
!=
"train"
and
paddle
.
distributed
.
get_world_size
()
>
1
:
if
mode
!=
"train"
and
paddle
.
distributed
.
get_world_size
()
>
1
:
top1
=
paddle
.
distributed
.
all_reduce
(
for
metric_name
in
metric_names
:
top1
,
op
=
paddle
.
distributed
.
ReduceOp
.
fetchs
[
metric_name
]
=
paddle
.
distributed
.
all_reduce
(
SUM
)
/
paddle
.
distributed
.
get_world_size
()
fetchs
[
metric_name
],
op
=
paddle
.
distributed
.
ReduceOp
.
topk
=
paddle
.
distributed
.
all_reduce
(
SUM
)
/
paddle
.
distributed
.
get_world_size
()
topk
,
op
=
paddle
.
distributed
.
ReduceOp
.
SUM
)
/
paddle
.
distributed
.
get_world_size
()
fetchs
[
'top1'
]
=
top1
topk_name
=
'top{}'
.
format
(
k
)
fetchs
[
topk_name
]
=
topk
return
fetchs
return
fetchs
...
@@ -182,12 +210,14 @@ def create_fetchs(feeds, net, config, mode="train"):
...
@@ -182,12 +210,14 @@ def create_fetchs(feeds, net, config, mode="train"):
epsilon
=
config
.
get
(
'ls_epsilon'
)
epsilon
=
config
.
get
(
'ls_epsilon'
)
use_mix
=
config
.
get
(
'use_mix'
)
and
mode
==
'train'
use_mix
=
config
.
get
(
'use_mix'
)
and
mode
==
'train'
use_distillation
=
config
.
get
(
'use_distillation'
)
use_distillation
=
config
.
get
(
'use_distillation'
)
multilabel
=
config
.
get
(
'multilabel'
,
False
)
out
=
net
(
feeds
[
"image"
])
out
=
net
(
feeds
[
"image"
])
fetchs
=
OrderedDict
()
fetchs
=
OrderedDict
()
fetchs
[
'loss'
]
=
create_loss
(
feeds
,
out
,
architecture
,
classes_num
,
fetchs
[
'loss'
]
=
create_loss
(
feeds
,
out
,
architecture
,
classes_num
,
epsilon
,
use_mix
,
use_distillation
)
epsilon
,
use_mix
,
use_distillation
,
multilabel
)
if
not
use_mix
:
if
not
use_mix
:
metric
=
create_metric
(
metric
=
create_metric
(
out
,
out
,
...
@@ -196,6 +226,7 @@ def create_fetchs(feeds, net, config, mode="train"):
...
@@ -196,6 +226,7 @@ def create_fetchs(feeds, net, config, mode="train"):
topk
,
topk
,
classes_num
,
classes_num
,
use_distillation
,
use_distillation
,
multilabel
=
multilabel
,
mode
=
mode
)
mode
=
mode
)
fetchs
.
update
(
metric
)
fetchs
.
update
(
metric
)
...
@@ -240,7 +271,7 @@ def create_optimizer(config, parameter_list=None):
...
@@ -240,7 +271,7 @@ def create_optimizer(config, parameter_list=None):
return
opt
(
lr
,
parameter_list
),
lr
return
opt
(
lr
,
parameter_list
),
lr
def
create_feeds
(
batch
,
use_mix
):
def
create_feeds
(
batch
,
use_mix
,
num_classes
,
multilabel
=
False
):
image
=
batch
[
0
]
image
=
batch
[
0
]
if
use_mix
:
if
use_mix
:
y_a
=
to_tensor
(
batch
[
1
].
numpy
().
astype
(
"int64"
).
reshape
(
-
1
,
1
))
y_a
=
to_tensor
(
batch
[
1
].
numpy
().
astype
(
"int64"
).
reshape
(
-
1
,
1
))
...
@@ -248,7 +279,10 @@ def create_feeds(batch, use_mix):
...
@@ -248,7 +279,10 @@ def create_feeds(batch, use_mix):
lam
=
to_tensor
(
batch
[
3
].
numpy
().
astype
(
"float32"
).
reshape
(
-
1
,
1
))
lam
=
to_tensor
(
batch
[
3
].
numpy
().
astype
(
"float32"
).
reshape
(
-
1
,
1
))
feeds
=
{
"image"
:
image
,
"y_a"
:
y_a
,
"y_b"
:
y_b
,
"lam"
:
lam
}
feeds
=
{
"image"
:
image
,
"y_a"
:
y_a
,
"y_b"
:
y_b
,
"lam"
:
lam
}
else
:
else
:
label
=
to_tensor
(
batch
[
1
].
numpy
().
astype
(
'int64'
).
reshape
(
-
1
,
1
))
if
not
multilabel
:
label
=
to_tensor
(
batch
[
1
].
numpy
().
astype
(
"int64"
).
reshape
(
-
1
,
1
))
else
:
label
=
to_tensor
(
batch
[
1
].
numpy
().
astype
(
'float32'
).
reshape
(
-
1
,
num_classes
))
feeds
=
{
"image"
:
image
,
"label"
:
label
}
feeds
=
{
"image"
:
image
,
"label"
:
label
}
return
feeds
return
feeds
...
@@ -279,6 +313,8 @@ def run(dataloader,
...
@@ -279,6 +313,8 @@ def run(dataloader,
"""
"""
print_interval
=
config
.
get
(
"print_interval"
,
10
)
print_interval
=
config
.
get
(
"print_interval"
,
10
)
use_mix
=
config
.
get
(
"use_mix"
,
False
)
and
mode
==
"train"
use_mix
=
config
.
get
(
"use_mix"
,
False
)
and
mode
==
"train"
multilabel
=
config
.
get
(
"multilabel"
,
False
)
classes_num
=
config
.
get
(
"classes_num"
)
metric_list
=
[
metric_list
=
[
(
"loss"
,
AverageMeter
(
(
"loss"
,
AverageMeter
(
...
@@ -291,13 +327,19 @@ def run(dataloader,
...
@@ -291,13 +327,19 @@ def run(dataloader,
'reader_cost'
,
'.5f'
,
postfix
=
" s,"
)),
'reader_cost'
,
'.5f'
,
postfix
=
" s,"
)),
]
]
if
not
use_mix
:
if
not
use_mix
:
topk_name
=
'top{}'
.
format
(
config
.
topk
)
if
not
multilabel
:
metric_list
.
insert
(
topk_name
=
'top{}'
.
format
(
config
.
topk
)
0
,
(
topk_name
,
AverageMeter
(
metric_list
.
insert
(
topk_name
,
'.5f'
,
postfix
=
","
)))
0
,
(
topk_name
,
AverageMeter
(
metric_list
.
insert
(
topk_name
,
'.5f'
,
postfix
=
","
)))
0
,
(
"top1"
,
AverageMeter
(
metric_list
.
insert
(
"top1"
,
'.5f'
,
postfix
=
","
)))
0
,
(
"top1"
,
AverageMeter
(
"top1"
,
'.5f'
,
postfix
=
","
)))
else
:
metric_list
.
insert
(
0
,
(
"multilabel_accuracy"
,
AverageMeter
(
"multilabel_accuracy"
,
'.5f'
,
postfix
=
","
)))
metric_list
.
insert
(
0
,
(
"hamming_distance"
,
AverageMeter
(
"hamming_distance"
,
'.5f'
,
postfix
=
","
)))
metric_list
=
OrderedDict
(
metric_list
)
metric_list
=
OrderedDict
(
metric_list
)
...
@@ -310,7 +352,7 @@ def run(dataloader,
...
@@ -310,7 +352,7 @@ def run(dataloader,
metric_list
[
'reader_time'
].
update
(
time
.
time
()
-
tic
)
metric_list
[
'reader_time'
].
update
(
time
.
time
()
-
tic
)
batch_size
=
len
(
batch
[
0
])
batch_size
=
len
(
batch
[
0
])
feeds
=
create_feeds
(
batch
,
use_mix
)
feeds
=
create_feeds
(
batch
,
use_mix
,
classes_num
,
multilabel
)
fetchs
=
create_fetchs
(
feeds
,
net
,
config
,
mode
)
fetchs
=
create_fetchs
(
feeds
,
net
,
config
,
mode
)
if
mode
==
'train'
:
if
mode
==
'train'
:
avg_loss
=
fetchs
[
'loss'
]
avg_loss
=
fetchs
[
'loss'
]
...
@@ -387,4 +429,7 @@ def run(dataloader,
...
@@ -387,4 +429,7 @@ def run(dataloader,
# return top1_acc in order to save the best model
# return top1_acc in order to save the best model
if
mode
==
'valid'
:
if
mode
==
'valid'
:
return
metric_list
[
'top1'
].
avg
if
multilabel
:
return
metric_list
[
'multilabel_accuracy'
].
avg
else
:
return
metric_list
[
'top1'
].
avg
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录