PaddlePaddle / PaddleClas
Commit d01c3323 (unverified)

Merge pull request #2224 from HydrogenSulfate/shituv2_reimplement

【WIP】add PP-ShiTuV2

Authored by Walter on Aug 26, 2022; committed via GitHub on Aug 26, 2022
Parents: ced5b523, 73e1d366
Showing 15 changed files with 519 additions and 44 deletions (+519 −44)
- docs/zh_CN/quick_start/quick_start_recognition.md (+2 −1)
- ppcls/arch/backbone/__init__.py (+1 −0)
- ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py (+4 −2)
- ppcls/arch/backbone/variant_models/__init__.py (+1 −0)
- ppcls/arch/backbone/variant_models/pp_lcnetv2_variant.py (+56 −0)
- ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml (+205 −0)
- ppcls/data/dataloader/imagenet_dataset.py (+41 −14)
- ppcls/data/dataloader/pk_sampler.py (+19 −3)
- ppcls/data/dataloader/vehicle_dataset.py (+16 −13)
- ppcls/data/preprocess/__init__.py (+1 −0)
- ppcls/data/preprocess/ops/operators.py (+50 −11)
- ppcls/engine/engine.py (+1 −0)
- ppcls/loss/__init__.py (+1 −0)
- ppcls/loss/tripletangularmarginloss.py (+115 −0)
- ppcls/metric/metrics.py (+6 −0)
docs/zh_CN/quick_start/quick_start_recognition.md

@@ -42,9 +42,10 @@
 ### 1.1 Install the PP-ShiTu Android demo

 The APP can be downloaded and installed by scanning the QR code or [clicking this link](https://paddle-imagenet-models-name.bj.bcebos.com/demos/PP-ShiTu.apk)
+**Note:** On HarmonyOS 3.0 the camera may fail to launch; we recommend using an older system version or a different Android model for the quick experience.

-<img src="../../images/quick_start/android_demo/PPShiTu_qcode.png" height="250" width="250"/>
+<div align=center><img src="../../images/quick_start/android_demo/PPShiTu_qcode.png" height="400" width="400"/></div>

 <a name="功能体验"></a>
ppcls/arch/backbone/__init__.py

@@ -73,6 +73,7 @@ from .model_zoo.convnext import ConvNeXt_tiny
 from .variant_models.resnet_variant import ResNet50_last_stage_stride1
 from .variant_models.vgg_variant import VGG19Sigmoid
 from .variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh
+from .variant_models.pp_lcnetv2_variant import PPLCNetV2_base_ShiTu
 from .model_zoo.adaface_ir_net import AdaFace_IR_18, AdaFace_IR_34, AdaFace_IR_50, AdaFace_IR_101, AdaFace_IR_152, AdaFace_IR_SE_50, AdaFace_IR_SE_101, AdaFace_IR_SE_152, AdaFace_IR_SE_200
ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py

@@ -126,6 +126,8 @@ class RepDepthwiseSeparable(TheseusLayer):
                  use_se=False,
                  use_shortcut=False):
         super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
         self.is_repped = False

         self.dw_size = dw_size

@@ -306,8 +308,8 @@ class PPLCNetV2(TheseusLayer):
         self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
         self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)

-        in_features = self.class_expand if self.use_last_conv else NET_CONFIG[
-            "stage4"][0] * 2 * scale
+        in_features = self.class_expand if self.use_last_conv else make_divisible(
+            NET_CONFIG["stage4"][0] * 2 * scale)
         self.fc = Linear(in_features, class_num)

     def forward(self, x):
ppcls/arch/backbone/variant_models/__init__.py

 from .resnet_variant import ResNet50_last_stage_stride1
 from .vgg_variant import VGG19Sigmoid
 from .pp_lcnet_variant import PPLCNet_x2_5_Tanh
+from .pp_lcnetv2_variant import PPLCNetV2_base_ShiTu
ppcls/arch/backbone/variant_models/pp_lcnetv2_variant.py (new file, 0 → 100644)

from paddle.nn import Conv2D, Identity

from ..legendary_models.pp_lcnet_v2 import MODEL_URLS, PPLCNetV2_base, RepDepthwiseSeparable, _load_pretrained

__all__ = ["PPLCNetV2_base_ShiTu"]


def PPLCNetV2_base_ShiTu(pretrained=False, use_ssld=False, **kwargs):
    """
    A variant network of PPLCNetV2_base
    1. remove the ReLU layer after last_conv
    2. add a bias to last_conv
    3. change the stride of the last two RepDepthwiseSeparable blocks to 1
    """
    model = PPLCNetV2_base(pretrained=False, use_ssld=use_ssld, **kwargs)

    def remove_ReLU_function(conv, pattern):
        new_conv = Identity()
        return new_conv

    def add_bias_last_conv(conv, pattern):
        new_conv = Conv2D(
            in_channels=conv._in_channels,
            out_channels=conv._out_channels,
            kernel_size=conv._kernel_size,
            stride=conv._stride,
            padding=conv._padding,
            groups=conv._groups,
            bias_attr=True)
        return new_conv

    def last_stride_function(rep_block, pattern):
        new_conv = RepDepthwiseSeparable(
            in_channels=rep_block.in_channels,
            out_channels=rep_block.out_channels,
            stride=1,
            dw_size=rep_block.dw_size,
            split_pw=rep_block.split_pw,
            use_rep=rep_block.use_rep,
            use_se=rep_block.use_se,
            use_shortcut=rep_block.use_shortcut)
        return new_conv

    pattern_act = ["act"]
    pattern_lastconv = ["last_conv"]
    pattern_last_stride = [
        "stages[3][0]",
        "stages[3][1]",
    ]
    model.upgrade_sublayer(pattern_act, remove_ReLU_function)
    model.upgrade_sublayer(pattern_lastconv, add_bias_last_conv)
    model.upgrade_sublayer(pattern_last_stride, last_stride_function)

    # load params again after upgrading some layers
    _load_pretrained(pretrained, model, MODEL_URLS["PPLCNetV2_base"], use_ssld)
    return model
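For orientation (commentary, not part of the diff): `upgrade_sublayer` is the pattern-based layer-replacement hook of PaddleClas's TheseusLayer, so the three calls above rewrite the last-conv activation, the last conv itself, and the two stage-4 blocks in place before the pretrained weights are reloaded. A minimal usage sketch, assuming a PaddleClas checkout on `PYTHONPATH`:

```python
import paddle

from ppcls.arch.backbone import PPLCNetV2_base_ShiTu

# Build the ShiTu variant; weights are only fetched with pretrained=True.
model = PPLCNetV2_base_ShiTu(pretrained=False)

x = paddle.randn([1, 3, 224, 224])
y = model(x)
print(y.shape)  # classifier head output; the retrieval config stops at `flatten` instead
```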
ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml (new file, 0 → 100644)

# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 100
  print_batch_step: 20
  use_visualdl: False
  eval_mode: retrieval
  retrieval_feature_from: features # 'backbone' or 'features'
  re_ranking: False
  use_dali: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference

AMP:
  scale_loss: 65536
  use_dynamic_loss_scaling: True
  # O1: mixed fp16
  level: O1

# model architecture
Arch:
  name: RecModel
  infer_output_key: features
  infer_add_softmax: False

  Backbone:
    name: PPLCNetV2_base_ShiTu
    pretrained: True
    use_ssld: True
    class_expand: &feat_dim 512
  BackboneStopLayer:
    name: flatten
  Neck:
    name: BNNeck
    num_features: *feat_dim
    weight_attr:
      initializer:
        name: Constant
        value: 1.0
    bias_attr:
      initializer:
        name: Constant
        value: 0.0
      learning_rate: 1.0e-20 # NOTE: Temporarily set lr small enough to freeze the bias to zero
  Head:
    name: FC
    embedding_size: *feat_dim
    class_num: 192612
    weight_attr:
      initializer:
        name: Normal
        std: 0.001
    bias_attr: False

# loss function config for training/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
        epsilon: 0.1
    - TripletAngularMarginLoss:
        weight: 1.0
        feature_from: features
        margin: 0.5
        reduction: mean
        add_absolute: True
        absolute_loss_weight: 0.1
        normalize_feature: True
        ap_value: 0.8
        an_value: 0.4
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: Cosine
    learning_rate: 0.06 # for 8gpu x 256bs
    warmup_epoch: 5
  regularizer:
    name: L2
    coeff: 0.00001

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/
      cls_label_path: ./dataset/train_reg_all_data_v2.txt
      relabel: True
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            size: [224, 224]
            return_numpy: False
            interpolation: bilinear
            backend: cv2
        - RandFlipImage:
            flip_code: 1
        - Pad:
            padding: 10
            backend: cv2
        - RandCropImageV2:
            size: [224, 224]
        - RandomRotation:
            prob: 0.5
            degrees: 90
            interpolation: bilinear
        - ResizeImage:
            size: [224, 224]
            return_numpy: False
            interpolation: bilinear
            backend: cv2
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: hwc
    sampler:
      name: PKSampler
      batch_size: 256
      sample_per_id: 4
      drop_last: False
      shuffle: True
      sample_method: "id_avg_prob"
      id_list: [50030, 80700, 92019, 96015] # be careful when setting relabel=True
      ratio: [4, 4]
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    Query:
      dataset:
        name: VeriWild
        image_root: ./dataset/Aliproduct/
        cls_label_path: ./dataset/Aliproduct/val_list.txt
        transform_ops:
          - DecodeImage:
              to_rgb: True
              channel_first: False
          - ResizeImage:
              size: [224, 224]
              return_numpy: False
              interpolation: bilinear
              backend: cv2
          - NormalizeImage:
              scale: 1.0/255.0
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
              order: hwc
      sampler:
        name: DistributedBatchSampler
        batch_size: 64
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

    Gallery:
      dataset:
        name: VeriWild
        image_root: ./dataset/Aliproduct/
        cls_label_path: ./dataset/Aliproduct/val_list.txt
        transform_ops:
          - DecodeImage:
              to_rgb: True
              channel_first: False
          - ResizeImage:
              size: [224, 224]
              return_numpy: False
              interpolation: bilinear
              backend: cv2
          - NormalizeImage:
              scale: 1.0/255.0
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
              order: hwc
      sampler:
        name: DistributedBatchSampler
        batch_size: 64
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

Metric:
  Eval:
    - Recallk:
        topk: [1, 5]
    - mAP: {}
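One detail worth noting in this config: `&feat_dim 512` defines a YAML anchor on the backbone's `class_expand`, and `*feat_dim` reuses it for the BNNeck width and the FC head's embedding size, so the embedding dimension only has to be edited in one place. A quick way to watch the anchor resolve, assuming PyYAML and this file on disk:

```python
import yaml

with open("ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml") as f:
    cfg = yaml.safe_load(f)

arch = cfg["Arch"]
# all three resolve to the same 512-dim embedding
print(arch["Backbone"]["class_expand"],
      arch["Neck"]["num_features"],
      arch["Head"]["embedding_size"])
```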
ppcls/data/dataloader/imagenet_dataset.py

@@ -21,27 +21,54 @@ from .common_dataset import CommonDataset

 class ImageNetDataset(CommonDataset):
-    def __init__(self,
-                 image_root,
-                 cls_label_path,
-                 transform_ops=None,
-                 delimiter=None):
+    """ImageNetDataset
+
+    Args:
+        image_root (str): image root, path to `ILSVRC2012`
+        cls_label_path (str): path to annotation file `train_list.txt` or `val_list.txt`
+        transform_ops (list, optional): list of transform op(s). Defaults to None.
+        delimiter (str, optional): delimiter. Defaults to None.
+        relabel (bool, optional): whether to relabel when original labels do not start from 0 or are discontinuous. Defaults to False.
+    """
+
+    def __init__(self,
+                 image_root,
+                 cls_label_path,
+                 transform_ops=None,
+                 delimiter=None,
+                 relabel=False):
         self.delimiter = delimiter if delimiter is not None else " "
+        self.relabel = relabel
         super(ImageNetDataset, self).__init__(image_root, cls_label_path,
                                               transform_ops)

     def _load_anno(self, seed=None):
-        assert os.path.exists(self._cls_path)
-        assert os.path.exists(self._img_root)
+        assert os.path.exists(
+            self._cls_path), f"path {self._cls_path} does not exist."
+        assert os.path.exists(
+            self._img_root), f"path {self._img_root} does not exist."
         self.images = []
         self.labels = []

         with open(self._cls_path) as fd:
             lines = fd.readlines()
+            if self.relabel:
+                label_set = set()
+                for line in lines:
+                    line = line.strip().split(self.delimiter)
+                    label_set.add(np.int64(line[1]))
+                label_map = {
+                    oldlabel: newlabel
+                    for newlabel, oldlabel in enumerate(label_set)
+                }
             if seed is not None:
                 np.random.RandomState(seed).shuffle(lines)
-            for l in lines:
-                l = l.strip().split(self.delimiter)
-                self.images.append(os.path.join(self._img_root, l[0]))
-                self.labels.append(np.int64(l[1]))
-                assert os.path.exists(self.images[-1])
+            for line in lines:
+                line = line.strip().split(self.delimiter)
+                self.images.append(os.path.join(self._img_root, line[0]))
+                if self.relabel:
+                    self.labels.append(label_map[np.int64(line[1])])
+                else:
+                    self.labels.append(np.int64(line[1]))
+                assert os.path.exists(self.images[
+                    -1]), f"path {self.images[-1]} does not exist."
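The core of this change is the `relabel` branch: it remaps an arbitrary, possibly discontinuous label set onto contiguous ids starting at 0, which a fixed-width FC head (192612 classes in the config above) requires. A self-contained sketch of the same mapping, with toy labels rather than real dataset ids:

```python
import numpy as np

# Toy stand-in for the label column of a train list with gaps in the ids.
raw_labels = [np.int64(7), np.int64(42), np.int64(7), np.int64(1000)]

# Same construction as the diff: enumerate the unique labels to get new ids.
label_map = {old: new for new, old in enumerate(set(raw_labels))}
new_labels = [label_map[l] for l in raw_labels]
print(new_labels)  # contiguous ids in [0, 3); exact assignment depends on set order
```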
ppcls/data/dataloader/pk_sampler.py

@@ -32,17 +32,23 @@ class PKSampler(DistributedBatchSampler):
         batch_size (int): batch size
         sample_per_id (int): number of instance(s) within a class
         shuffle (bool, optional): _description_. Defaults to True.
+        id_list (list): list of (start_id, end_id, start_id, end_id, ...) for the sets of ids to be duplicated.
+        ratio (list): list of (ratio1, ratio2, ...), the duplication factor for the ids in id_list.
         drop_last (bool, optional): whether to discard the data at the end. Defaults to True.
         sample_method (str, optional): sample method when generating prob_list. Defaults to "sample_avg_prob".
     """

     def __init__(self,
                  dataset,
                  batch_size,
                  sample_per_id,
                  shuffle=True,
                  drop_last=True,
+                 id_list=None,
+                 ratio=None,
                  sample_method="sample_avg_prob"):
         super().__init__(
             dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
         assert batch_size % sample_per_id == 0, \
             f"PKSampler configs error, sample_per_id({sample_per_id}) must be a divisor of batch_size({batch_size})."
         assert hasattr(self.dataset,

@@ -67,6 +73,16 @@ class PKSampler(DistributedBatchSampler):
             logger.error(
                 "PKSampler only support id_avg_prob and sample_avg_prob sample method, "
                 "but receive {}.".format(self.sample_method))
+        if id_list and ratio:
+            assert len(id_list) % 2 == 0 and len(id_list) == len(ratio) * 2
+            for i in range(len(self.prob_list)):
+                for j in range(len(ratio)):
+                    if i >= id_list[j * 2] and i <= id_list[j * 2 + 1]:
+                        self.prob_list[i] = self.prob_list[i] * ratio[j]
+                        break
+            self.prob_list = self.prob_list / sum(self.prob_list)
         diff = np.abs(sum(self.prob_list) - 1)
         if diff > 0.00000001:
             self.prob_list[-1] = 1 - sum(self.prob_list[:-1])

@@ -74,8 +90,8 @@ class PKSampler(DistributedBatchSampler):
                 logger.error("PKSampler prob list error")
             else:
                 logger.info(
                     "PKSampler: sum of prob list not equal to 1, diff is {}, change the last prob".
                     format(diff))

     def __iter__(self):
         label_per_batch = self.batch_size // self.sample_per_label
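The new `id_list`/`ratio` logic oversamples chosen id ranges: `id_list` is read as flat (start, end) pairs, each matching id's sampling probability is multiplied by the paired ratio, and the list is renormalized into a distribution. A standalone sketch of the same arithmetic with made-up numbers:

```python
import numpy as np

prob_list = np.full(10, 0.1)  # uniform over 10 ids
id_list = [2, 4, 7, 8]        # two inclusive ranges: [2, 4] and [7, 8]
ratio = [4, 2]                # duplication factor per range

for i in range(len(prob_list)):
    for j in range(len(ratio)):
        if id_list[j * 2] <= i <= id_list[j * 2 + 1]:
            prob_list[i] *= ratio[j]
            break

prob_list /= prob_list.sum()  # renormalize so it sums to 1 again
print(prob_list.round(3))     # ids 2-4 and 7-8 are now sampled more often
```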
ppcls/data/dataloader/vehicle_dataset.py

@@ -89,11 +89,7 @@ class CompCars(Dataset):

 class VeriWild(Dataset):
-    def __init__(
-            self,
-            image_root,
-            cls_label_path,
-            transform_ops=None, ):
+    def __init__(self, image_root, cls_label_path, transform_ops=None):
         self._img_root = image_root
         self._cls_path = cls_label_path
         if transform_ops:

@@ -102,19 +98,23 @@ class VeriWild(Dataset):
         self._load_anno()

     def _load_anno(self):
-        assert os.path.exists(self._cls_path)
-        assert os.path.exists(self._img_root)
+        assert os.path.exists(
+            self._cls_path), f"path {self._cls_path} does not exist."
+        assert os.path.exists(
+            self._img_root), f"path {self._img_root} does not exist."
         self.images = []
         self.labels = []
         self.cameras = []

         with open(self._cls_path) as fd:
             lines = fd.readlines()
-            for l in lines:
-                l = l.strip().split()
-                self.images.append(os.path.join(self._img_root, l[0]))
-                self.labels.append(np.int64(l[1]))
-                self.cameras.append(np.int64(l[2]))
+            for line in lines:
+                line = line.strip().split()
+                self.images.append(os.path.join(self._img_root, line[0]))
+                self.labels.append(np.int64(line[1]))
+                if len(line) >= 3:
+                    self.cameras.append(np.int64(line[2]))
                 assert os.path.exists(self.images[-1])
+        self.has_camera = len(self.cameras) > 0

     def __getitem__(self, idx):
         try:

@@ -123,7 +123,10 @@ class VeriWild(Dataset):
             if self._transform_ops:
                 img = transform(img, self._transform_ops)
             img = img.transpose((2, 0, 1))
-            return (img, self.labels[idx], self.cameras[idx])
+            if self.has_camera:
+                return (img, self.labels[idx], self.cameras[idx])
+            else:
+                return (img, self.labels[idx])
         except Exception as ex:
             logger.error("Exception occured when parse line: {} with msg: {}".
                          format(self.images[idx], ex))
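In effect, `VeriWild` now tolerates label files that omit the camera column, which is what the Aliproduct retrieval lists in the config above look like. A contrived illustration of the two accepted layouts (paths and ids invented):

```python
# "path label camera" -> has_camera is True, __getitem__ yields (img, label, camera)
line_with_camera = "images/00001.jpg 12 3"

# "path label" -> has_camera is False, __getitem__ yields (img, label)
line_without_camera = "val/00001.jpg 12"
```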
ppcls/data/preprocess/__init__.py

@@ -38,6 +38,7 @@ from ppcls.data.preprocess.ops.operators import CropWithPadding
 from ppcls.data.preprocess.ops.operators import RandomInterpolationAugment
 from ppcls.data.preprocess.ops.operators import ColorJitter
 from ppcls.data.preprocess.ops.operators import RandomCropImage
+from ppcls.data.preprocess.ops.operators import RandomRotation
 from ppcls.data.preprocess.ops.operators import Padv2

 from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
ppcls/data/preprocess/ops/operators.py

@@ -26,6 +26,7 @@ import cv2
 import numpy as np
 from PIL import Image, ImageOps, __version__ as PILLOW_VERSION
 from paddle.vision.transforms import ColorJitter as RawColorJitter
+from paddle.vision.transforms import RandomRotation as RawRandomRotation
 from paddle.vision.transforms import ToTensor, Normalize, RandomHorizontalFlip, RandomResizedCrop
 from paddle.vision.transforms import functional as F
 from .autoaugment import ImageNetPolicy

@@ -181,7 +182,8 @@ class DecodeImage(object):
             img = np.asarray(img)[:, :, ::-1]  # BRG

         if self.to_rgb:
-            assert img.shape[2] == 3, f"invalid shape of image[{img.shape}]"
+            assert img.shape[2] == 3, \
+                f"invalid shape of image[{img.shape}]"
             img = img[:, :, ::-1]

         if self.channel_first:

@@ -495,7 +497,13 @@ class RandFlipImage(object):
             if isinstance(img, np.ndarray):
                 return cv2.flip(img, self.flip_code)
             else:
-                return img.transpose(Image.FLIP_LEFT_RIGHT)
+                if self.flip_code == 1:
+                    return img.transpose(Image.FLIP_LEFT_RIGHT)
+                elif self.flip_code == 0:
+                    return img.transpose(Image.FLIP_TOP_BOTTOM)
+                else:
+                    return img.transpose(Image.FLIP_LEFT_RIGHT).transpose(
+                        Image.FLIP_TOP_BOTTOM)
         else:
             return img

@@ -653,17 +661,38 @@ class ColorJitter(RawColorJitter):
         return img


+class RandomRotation(RawRandomRotation):
+    """RandomRotation.
+    """
+
+    def __init__(self, prob=0.5, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.prob = prob
+
+    def __call__(self, img):
+        if np.random.random() < self.prob:
+            img = super()._apply_image(img)
+        return img
+
+
 class Pad(object):
     """
     Pads the given PIL.Image on all sides with specified padding mode and fill value.
     adapted from: https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#Pad
     """

-    def __init__(self, padding: int, fill: int=0,
-                 padding_mode: str="constant"):
+    def __init__(self,
+                 padding: int,
+                 fill: int=0,
+                 padding_mode: str="constant",
+                 backend: str="pil"):
         self.padding = padding
         self.fill = fill
         self.padding_mode = padding_mode
+        self.backend = backend
+        assert backend in [
+            "pil", "cv2"
+        ], f"backend must in ['pil', 'cv2'], but got {backend}"

     def _parse_fill(self, fill, img, min_pil_version, name="fillcolor"):
         # Process fill color for affine transforms

@@ -698,11 +727,21 @@ class Pad(object):
         return {name: fill}

     def __call__(self, img):
-        opts = self._parse_fill(self.fill, img, "2.3.0", name="fill")
-        if img.mode == "P":
-            palette = img.getpalette()
-            img = ImageOps.expand(img, border=self.padding, **opts)
-            img.putpalette(palette)
-            return img
-        return ImageOps.expand(img, border=self.padding, **opts)
+        if self.backend == "pil":
+            opts = self._parse_fill(self.fill, img, "2.3.0", name="fill")
+            if img.mode == "P":
+                palette = img.getpalette()
+                img = ImageOps.expand(img, border=self.padding, **opts)
+                img.putpalette(palette)
+                return img
+            return ImageOps.expand(img, border=self.padding, **opts)
+        else:
+            img = cv2.copyMakeBorder(img, self.padding, self.padding,
+                                     self.padding, self.padding,
+                                     cv2.BORDER_CONSTANT,
+                                     value=(self.fill, self.fill, self.fill))
+            return img
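Taken together, the train pipeline in the YAML above keeps images as uint8 HWC ndarrays end to end (cv2 backend), which is why `Pad` gains a cv2 path and `RandomRotation` wraps paddle's transform behind a probability gate. A rough usage sketch, assuming a PaddleClas checkout on `PYTHONPATH`:

```python
import numpy as np

from ppcls.data.preprocess.ops.operators import Pad, RandomRotation

img = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)

rot = RandomRotation(prob=0.5, degrees=90, interpolation="bilinear")
pad = Pad(padding=10, backend="cv2")  # cv2.copyMakeBorder with constant fill

out = pad(rot(img))
print(out.shape)  # (244, 244, 3): 10 px of padding added on every side
```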
ppcls/engine/engine.py

@@ -114,6 +114,7 @@
         #TODO(gaotingquan): support rec
         class_num = config["Arch"].get("class_num", None)
+        self.config["DataLoader"].update({"class_num": class_num})

         # build dataloader
         if self.mode == 'train':
             self.train_dataloader = build_dataloader(
ppcls/loss/__init__.py

@@ -12,6 +12,7 @@ from .msmloss import MSMLoss
 from .npairsloss import NpairsLoss
 from .trihardloss import TriHardLoss
 from .triplet import TripletLoss, TripletLossV2
+from .tripletangularmarginloss import TripletAngularMarginLoss
 from .supconloss import SupConLoss
 from .pairwisecosface import PairwiseCosface
 from .dmlloss import DMLLoss
ppcls/loss/tripletangularmarginloss.py (new file, 0 → 100644)

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn


class TripletAngularMarginLoss(nn.Layer):
    """A more robust triplet loss with hard positive/negative mining on angular margin instead of relative distance between d(a,p) and d(a,n).

    Args:
        margin (float, optional): angular margin. Defaults to 0.5.
        normalize_feature (bool, optional): whether to apply L2-norm to features before computing distance (cosine similarity). Defaults to True.
        reduction (str, optional): reducing option within a batch. Defaults to "mean".
        add_absolute (bool, optional): whether to add an absolute loss on d(a,p) and d(a,n). Defaults to False.
        absolute_loss_weight (float, optional): weight for the absolute loss. Defaults to 1.0.
        ap_value (float, optional): target value for d(a, p). Defaults to 0.9.
        an_value (float, optional): target value for d(a, n). Defaults to 0.5.
        feature_from (str, optional): key to take the feature from. Defaults to "features".
    """

    def __init__(self,
                 margin=0.5,
                 normalize_feature=True,
                 reduction="mean",
                 add_absolute=False,
                 absolute_loss_weight=1.0,
                 ap_value=0.9,
                 an_value=0.5,
                 feature_from="features"):
        super(TripletAngularMarginLoss, self).__init__()
        self.margin = margin
        self.feature_from = feature_from
        self.ranking_loss = paddle.nn.loss.MarginRankingLoss(
            margin=margin, reduction=reduction)
        self.normalize_feature = normalize_feature
        self.add_absolute = add_absolute
        self.ap_value = ap_value
        self.an_value = an_value
        self.absolute_loss_weight = absolute_loss_weight

    def forward(self, input, target):
        """
        Args:
            inputs: feature matrix with shape (batch_size, feat_dim)
            target: ground truth labels with shape (batch_size)
        """
        inputs = input[self.feature_from]

        if self.normalize_feature:
            inputs = paddle.divide(
                inputs, paddle.norm(
                    inputs, p=2, axis=-1, keepdim=True))

        bs = inputs.shape[0]

        # compute distance (cosine similarity)
        dist = paddle.matmul(inputs, inputs.t())

        is_pos = paddle.expand(target, (bs, bs)).equal(
            paddle.expand(target, (bs, bs)).t())
        is_neg = paddle.expand(target, (bs, bs)).not_equal(
            paddle.expand(target, (bs, bs)).t())

        # `dist_ap` means distance(anchor, positive)
        # both `dist_ap` and `relative_p_inds` with shape [N, 1]
        dist_ap = paddle.min(paddle.reshape(
            paddle.masked_select(dist, is_pos), (bs, -1)),
                             axis=1,
                             keepdim=True)
        # `dist_an` means distance(anchor, negative)
        # both `dist_an` and `relative_n_inds` with shape [N, 1]
        dist_an = paddle.max(paddle.reshape(
            paddle.masked_select(dist, is_neg), (bs, -1)),
                             axis=1,
                             keepdim=True)
        # shape [N]
        dist_ap = paddle.squeeze(dist_ap, axis=1)
        dist_an = paddle.squeeze(dist_an, axis=1)

        # Compute ranking hinge loss
        y = paddle.ones_like(dist_an)
        loss = self.ranking_loss(dist_ap, dist_an, y)

        if self.add_absolute:
            absolut_loss_ap = self.ap_value - dist_ap
            absolut_loss_ap = paddle.where(absolut_loss_ap > 0,
                                           absolut_loss_ap,
                                           paddle.zeros_like(absolut_loss_ap))

            absolut_loss_an = dist_an - self.an_value
            absolut_loss_an = paddle.where(absolut_loss_an > 0,
                                           absolut_loss_an,
                                           paddle.ones_like(absolut_loss_an))

            loss = (absolut_loss_an.mean() + absolut_loss_ap.mean()
                    ) * self.absolute_loss_weight + loss.mean()

        return {"TripletAngularMarginLoss": loss}
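A minimal sketch of driving the new loss directly (shapes and labels are made up; in training, RecModel supplies the output dict and PKSampler guarantees every anchor has positives in the batch):

```python
import paddle

from ppcls.loss import TripletAngularMarginLoss

loss_fn = TripletAngularMarginLoss(
    margin=0.5,
    add_absolute=True,
    absolute_loss_weight=0.1,
    ap_value=0.8,
    an_value=0.4,
    feature_from="features")

# 8 samples, 2 per id, 512-dim features, mirroring the PK-sampled batch layout.
feats = paddle.randn([8, 512])
labels = paddle.to_tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype="int64")

out = loss_fn({"features": feats}, labels)
print(out["TripletAngularMarginLoss"])  # scalar with reduction="mean"
```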
ppcls/metric/metrics.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from cmath import nan
 import numpy as np
 import paddle
 import paddle.nn as nn

@@ -97,6 +98,11 @@ class mAP(nn.Layer):
         num_rel = paddle.greater_than(num_rel, paddle.to_tensor(0.))
         num_rel_index = paddle.nonzero(num_rel.astype("int"))
         num_rel_index = paddle.reshape(num_rel_index, [num_rel_index.shape[0]])
+
+        if paddle.numel(num_rel_index).item() == 0:
+            metric_dict["mAP"] = np.nan
+            return metric_dict
+
         equal_flag = paddle.index_select(equal_flag, num_rel_index, axis=0)

         acc_sum = paddle.cumsum(equal_flag, axis=1)
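A brief note on this guard: `num_rel_index` keeps only the queries that have at least one relevant gallery match; when that index comes out empty, the metric now short-circuits and reports `mAP` as NaN for the batch instead of proceeding to `paddle.index_select` with an empty index.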