Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
1b5e00e8
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
接近 2 年 前同步成功
通知
116
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1b5e00e8
编写于
8月 23, 2022
作者:
H
HydrogenSulfate
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add PP-ShiTuV2 code
上级
dab99e3e
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
418 addition
and
27 deletion
+418
-27
ppcls/arch/backbone/__init__.py
ppcls/arch/backbone/__init__.py
+1
-0
ppcls/arch/backbone/base/theseus_layer.py
ppcls/arch/backbone/base/theseus_layer.py
+2
-0
ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py
ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py
+2
-2
ppcls/arch/backbone/variant_models/__init__.py
ppcls/arch/backbone/variant_models/__init__.py
+1
-0
ppcls/arch/backbone/variant_models/pp_lcnetv2_variant.py
ppcls/arch/backbone/variant_models/pp_lcnetv2_variant.py
+44
-0
ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml
...ralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml
+198
-0
ppcls/data/dataloader/imagenet_dataset.py
ppcls/data/dataloader/imagenet_dataset.py
+13
-12
ppcls/data/dataloader/vehicle_dataset.py
ppcls/data/dataloader/vehicle_dataset.py
+12
-11
ppcls/data/preprocess/__init__.py
ppcls/data/preprocess/__init__.py
+1
-0
ppcls/data/preprocess/ops/operators.py
ppcls/data/preprocess/ops/operators.py
+24
-2
ppcls/engine/engine.py
ppcls/engine/engine.py
+4
-0
ppcls/loss/__init__.py
ppcls/loss/__init__.py
+1
-0
ppcls/loss/tripletangularmarginloss.py
ppcls/loss/tripletangularmarginloss.py
+115
-0
未找到文件。
ppcls/arch/backbone/__init__.py
浏览文件 @
1b5e00e8
...
@@ -73,6 +73,7 @@ from .model_zoo.convnext import ConvNeXt_tiny
...
@@ -73,6 +73,7 @@ from .model_zoo.convnext import ConvNeXt_tiny
from
.variant_models.resnet_variant
import
ResNet50_last_stage_stride1
from
.variant_models.resnet_variant
import
ResNet50_last_stage_stride1
from
.variant_models.vgg_variant
import
VGG19Sigmoid
from
.variant_models.vgg_variant
import
VGG19Sigmoid
from
.variant_models.pp_lcnet_variant
import
PPLCNet_x2_5_Tanh
from
.variant_models.pp_lcnet_variant
import
PPLCNet_x2_5_Tanh
from
.variant_models.pp_lcnetv2_variant
import
PPLCNetV2_base_ShiTu
from
.model_zoo.adaface_ir_net
import
AdaFace_IR_18
,
AdaFace_IR_34
,
AdaFace_IR_50
,
AdaFace_IR_101
,
AdaFace_IR_152
,
AdaFace_IR_SE_50
,
AdaFace_IR_SE_101
,
AdaFace_IR_SE_152
,
AdaFace_IR_SE_200
from
.model_zoo.adaface_ir_net
import
AdaFace_IR_18
,
AdaFace_IR_34
,
AdaFace_IR_50
,
AdaFace_IR_101
,
AdaFace_IR_152
,
AdaFace_IR_SE_50
,
AdaFace_IR_SE_101
,
AdaFace_IR_SE_152
,
AdaFace_IR_SE_200
...
...
ppcls/arch/backbone/base/theseus_layer.py
浏览文件 @
1b5e00e8
...
@@ -158,6 +158,8 @@ class TheseusLayer(nn.Layer):
...
@@ -158,6 +158,8 @@ class TheseusLayer(nn.Layer):
return
False
return
False
parent_layer
=
layer_dict
[
"layer"
]
parent_layer
=
layer_dict
[
"layer"
]
msg
=
f
"Successfully set the layers that after stop_layer_name('
{
stop_layer_name
}
') to IdentityLayer."
logger
.
info
(
msg
)
return
True
return
True
def
update_res
(
def
update_res
(
...
...
ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py
浏览文件 @
1b5e00e8
...
@@ -306,8 +306,8 @@ class PPLCNetV2(TheseusLayer):
...
@@ -306,8 +306,8 @@ class PPLCNetV2(TheseusLayer):
self
.
dropout
=
Dropout
(
p
=
dropout_prob
,
mode
=
"downscale_in_infer"
)
self
.
dropout
=
Dropout
(
p
=
dropout_prob
,
mode
=
"downscale_in_infer"
)
self
.
flatten
=
nn
.
Flatten
(
start_axis
=
1
,
stop_axis
=-
1
)
self
.
flatten
=
nn
.
Flatten
(
start_axis
=
1
,
stop_axis
=-
1
)
in_features
=
self
.
class_expand
if
self
.
use_last_conv
else
NET_CONFIG
[
in_features
=
self
.
class_expand
if
self
.
use_last_conv
else
make_divisible
(
"stage4"
][
0
]
*
2
*
scale
NET_CONFIG
[
"stage4"
][
0
]
*
2
*
scale
)
self
.
fc
=
Linear
(
in_features
,
class_num
)
self
.
fc
=
Linear
(
in_features
,
class_num
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
...
...
ppcls/arch/backbone/variant_models/__init__.py
浏览文件 @
1b5e00e8
from
.resnet_variant
import
ResNet50_last_stage_stride1
from
.resnet_variant
import
ResNet50_last_stage_stride1
from
.vgg_variant
import
VGG19Sigmoid
from
.vgg_variant
import
VGG19Sigmoid
from
.pp_lcnet_variant
import
PPLCNet_x2_5_Tanh
from
.pp_lcnet_variant
import
PPLCNet_x2_5_Tanh
from
.pp_lcnetv2_variant
import
PPLCNetV2_base_ShiTu
ppcls/arch/backbone/variant_models/pp_lcnetv2_variant.py
0 → 100644
浏览文件 @
1b5e00e8
from
paddle.nn
import
Conv2D
,
Identity
from
..legendary_models.pp_lcnet_v2
import
PPLCNetV2_base
,
RepDepthwiseSeparable
,
MODEL_URLS
,
_load_pretrained
__all__
=
[
"PPLCNetV2_base_ShiTu"
]
def PPLCNetV2_base_ShiTu(pretrained=False, use_ssld=False, **kwargs):
    """Build the ShiTu variant of ``PPLCNetV2_base``.

    The base network is first created *without* pretrained weights. Selected
    sublayers are then swapped through ``upgrade_sublayer``: the convolutions
    listed for ``stages[3]`` are rebuilt with ``stride=1`` and the sublayer(s)
    selected by the ``"act"`` pattern are replaced by ``Identity``. Pretrained
    weights are loaded only afterwards, so they land in the replaced layers.

    Args:
        pretrained: forwarded to ``_load_pretrained`` once the sublayers have
            been replaced.
        use_ssld: forwarded to both ``PPLCNetV2_base`` and
            ``_load_pretrained``.
        **kwargs: extra keyword arguments forwarded to ``PPLCNetV2_base``.

    Returns:
        The modified model instance.
    """
    # Weights are deliberately not loaded here; see the final step below.
    net = PPLCNetV2_base(pretrained=False, use_ssld=use_ssld, **kwargs)

    def _to_identity(conv, pattern):
        # Replacement handler: drop the matched layer in favour of a no-op.
        return Identity()

    def _rebuild_with_unit_stride(conv, pattern):
        # Replacement handler: clone the matched convolution's configuration,
        # changing only the stride to 1.
        conv_config = dict(
            weight_attr=conv._weight_attr,
            in_channels=conv._in_channels,
            out_channels=conv._out_channels,
            kernel_size=conv._kernel_size,
            stride=1,
            padding=conv._padding,
            groups=conv._groups,
            bias_attr=conv._bias_attr)
        return Conv2D(**conv_config)

    activation_patterns = ["act"]
    stride_patterns = [
        "stages[3][0].dw_conv_list[0].conv",
        "stages[3][0].dw_conv_list[1].conv",
        "stages[3][0].dw_conv",
        "stages[3][0].pw_conv.conv",
        "stages[3][1].dw_conv_list[0].conv",
        "stages[3][1].dw_conv_list[1].conv",
        "stages[3][1].dw_conv_list[2].conv",
        "stages[3][1].dw_conv",
        "stages[3][1].pw_conv.conv",
    ]

    net.upgrade_sublayer(stride_patterns, _rebuild_with_unit_stride)
    net.upgrade_sublayer(activation_patterns, _to_identity)

    # The freshly created layers carry new parameters, so the pretrained
    # weights are (re)loaded only now that the structure is final.
    _load_pretrained(pretrained, net, MODEL_URLS["PPLCNetV2_base"], use_ssld)
    return net
ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml
0 → 100644
浏览文件 @
1b5e00e8
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output
device
:
gpu
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
100
print_batch_step
:
20
use_visualdl
:
False
eval_mode
:
retrieval
retrieval_feature_from
:
features
# 'backbone' or 'features'
re_ranking
:
False
use_dali
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# AMP:
# scale_loss: 65536
# use_dynamic_loss_scaling: True
# # O1: mixed fp16
# level: O1
# model architecture
Arch
:
name
:
RecModel
infer_output_key
:
features
infer_add_softmax
:
False
Backbone
:
name
:
PPLCNetV2_base_ShiTu
pretrained
:
True
use_ssld
:
True
class_expand
:
&feat_dim
512
BackboneStopLayer
:
name
:
flatten
Neck
:
name
:
BNNeck
num_features
:
*feat_dim
weight_attr
:
initializer
:
name
:
Constant
value
:
1.0
bias_attr
:
initializer
:
name
:
Constant
value
:
0.0
learning_rate
:
1.0e-20
# NOTE: Temporarily set lr small enough to freeze the bias to zero
Head
:
name
:
FC
embedding_size
:
*feat_dim
class_num
:
192612
weight_attr
:
initializer
:
name
:
Normal
std
:
0.001
bias_attr
:
False
# loss function config for training/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
        epsilon: 0.1
    # The key must spell the loss class name exactly; the class added in
    # ppcls/loss/tripletangularmarginloss.py is `TripletAngularMarginLoss`
    # (the previous spelling `TripletAngleMarinLoss` matches no class).
    - TripletAngularMarginLoss:
        weight: 1.0
        margin: 0.5
        reduction: mean
        add_absolute: True
        absolute_loss_weight: 0.1
        normalize_feature: True
        feature_from: features
        ap_value: 0.8
        an_value: 0.4
  Eval:
    - CELoss:
        weight: 1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.04
warmup_epoch
:
5
regularizer
:
name
:
L2
coeff
:
0.00001
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/
cls_label_path
:
./dataset/train_reg_all_data.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
[
224
,
224
]
return_numpy
:
False
interpolation
:
bilinear
backend
:
cv2
-
RandFlipImage
:
flip_code
:
1
-
Pad_cv2
:
padding
:
10
-
RandCropImageV2
:
size
:
[
224
,
224
]
-
RandomRotation
:
prob
:
0.5
degrees
:
90
interpolation
:
bilinear
-
ResizeImage
:
size
:
[
224
,
224
]
return_numpy
:
False
interpolation
:
bilinear
backend
:
cv2
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
hwc
sampler
:
name
:
DistributedBatchSampler
batch_size
:
256
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
4
use_shared_memory
:
True
Eval
:
Query
:
dataset
:
name
:
VeriWild
image_root
:
./dataset/Aliproduct/
cls_label_path
:
./dataset/Aliproduct/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
[
224
,
224
]
return_numpy
:
False
interpolation
:
bilinear
backend
:
cv2
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
hwc
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Gallery
:
dataset
:
name
:
VeriWild
image_root
:
./dataset/Aliproduct/
cls_label_path
:
./dataset/Aliproduct/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
[
224
,
224
]
return_numpy
:
False
interpolation
:
bilinear
backend
:
cv2
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
hwc
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Metric
:
Eval
:
-
Recallk
:
topk
:
[
1
,
5
]
ppcls/data/dataloader/imagenet_dataset.py
浏览文件 @
1b5e00e8
...
@@ -21,14 +21,14 @@ from .common_dataset import CommonDataset
...
@@ -21,14 +21,14 @@ from .common_dataset import CommonDataset
class
ImageNetDataset
(
CommonDataset
):
class
ImageNetDataset
(
CommonDataset
):
def
__init__
(
def
__init__
(
self
,
self
,
image_root
,
image_root
,
cls_label_path
,
cls_label_path
,
transform_ops
=
None
,
transform_ops
=
None
,
delimiter
=
None
):
delimiter
=
None
):
self
.
delimiter
=
delimiter
if
delimiter
is
not
None
else
" "
self
.
delimiter
=
delimiter
if
delimiter
is
not
None
else
" "
super
(
ImageNetDataset
,
self
).
__init__
(
image_root
,
cls_label_path
,
transform_ops
)
super
(
ImageNetDataset
,
self
).
__init__
(
image_root
,
cls_label_path
,
transform_ops
)
def
_load_anno
(
self
,
seed
=
None
):
def
_load_anno
(
self
,
seed
=
None
):
assert
os
.
path
.
exists
(
self
.
_cls_path
)
assert
os
.
path
.
exists
(
self
.
_cls_path
)
...
@@ -40,8 +40,9 @@ class ImageNetDataset(CommonDataset):
...
@@ -40,8 +40,9 @@ class ImageNetDataset(CommonDataset):
lines
=
fd
.
readlines
()
lines
=
fd
.
readlines
()
if
seed
is
not
None
:
if
seed
is
not
None
:
np
.
random
.
RandomState
(
seed
).
shuffle
(
lines
)
np
.
random
.
RandomState
(
seed
).
shuffle
(
lines
)
for
l
in
lines
:
for
line
in
lines
:
l
=
l
.
strip
().
split
(
self
.
delimiter
)
line
=
line
.
strip
().
split
(
self
.
delimiter
)
self
.
images
.
append
(
os
.
path
.
join
(
self
.
_img_root
,
l
[
0
]))
self
.
images
.
append
(
os
.
path
.
join
(
self
.
_img_root
,
line
[
0
]))
self
.
labels
.
append
(
np
.
int64
(
l
[
1
]))
self
.
labels
.
append
(
np
.
int64
(
line
[
1
]))
assert
os
.
path
.
exists
(
self
.
images
[
-
1
])
assert
os
.
path
.
exists
(
self
.
images
[
-
1
]),
f
"path
{
self
.
images
[
-
1
]
}
does not exist."
ppcls/data/dataloader/vehicle_dataset.py
浏览文件 @
1b5e00e8
...
@@ -89,11 +89,7 @@ class CompCars(Dataset):
...
@@ -89,11 +89,7 @@ class CompCars(Dataset):
class
VeriWild
(
Dataset
):
class
VeriWild
(
Dataset
):
def
__init__
(
def
__init__
(
self
,
image_root
,
cls_label_path
,
transform_ops
=
None
):
self
,
image_root
,
cls_label_path
,
transform_ops
=
None
,
):
self
.
_img_root
=
image_root
self
.
_img_root
=
image_root
self
.
_cls_path
=
cls_label_path
self
.
_cls_path
=
cls_label_path
if
transform_ops
:
if
transform_ops
:
...
@@ -109,12 +105,14 @@ class VeriWild(Dataset):
...
@@ -109,12 +105,14 @@ class VeriWild(Dataset):
self
.
cameras
=
[]
self
.
cameras
=
[]
with
open
(
self
.
_cls_path
)
as
fd
:
with
open
(
self
.
_cls_path
)
as
fd
:
lines
=
fd
.
readlines
()
lines
=
fd
.
readlines
()
for
l
in
lines
:
for
line
in
lines
:
l
=
l
.
strip
().
split
()
line
=
line
.
strip
().
split
()
self
.
images
.
append
(
os
.
path
.
join
(
self
.
_img_root
,
l
[
0
]))
self
.
images
.
append
(
os
.
path
.
join
(
self
.
_img_root
,
line
[
0
]))
self
.
labels
.
append
(
np
.
int64
(
l
[
1
]))
self
.
labels
.
append
(
np
.
int64
(
line
[
1
]))
self
.
cameras
.
append
(
np
.
int64
(
l
[
2
]))
if
len
(
line
)
>=
3
:
self
.
cameras
.
append
(
np
.
int64
(
line
[
2
]))
assert
os
.
path
.
exists
(
self
.
images
[
-
1
])
assert
os
.
path
.
exists
(
self
.
images
[
-
1
])
self
.
has_camera
=
len
(
self
.
cameras
)
>
0
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
try
:
try
:
...
@@ -123,7 +121,10 @@ class VeriWild(Dataset):
...
@@ -123,7 +121,10 @@ class VeriWild(Dataset):
if
self
.
_transform_ops
:
if
self
.
_transform_ops
:
img
=
transform
(
img
,
self
.
_transform_ops
)
img
=
transform
(
img
,
self
.
_transform_ops
)
img
=
img
.
transpose
((
2
,
0
,
1
))
img
=
img
.
transpose
((
2
,
0
,
1
))
return
(
img
,
self
.
labels
[
idx
],
self
.
cameras
[
idx
])
if
self
.
has_camera
:
return
(
img
,
self
.
labels
[
idx
],
self
.
cameras
[
idx
])
else
:
return
(
img
,
self
.
labels
[
idx
])
except
Exception
as
ex
:
except
Exception
as
ex
:
logger
.
error
(
"Exception occured when parse line: {} with msg: {}"
.
logger
.
error
(
"Exception occured when parse line: {} with msg: {}"
.
format
(
self
.
images
[
idx
],
ex
))
format
(
self
.
images
[
idx
],
ex
))
...
...
ppcls/data/preprocess/__init__.py
浏览文件 @
1b5e00e8
...
@@ -38,6 +38,7 @@ from ppcls.data.preprocess.ops.operators import CropWithPadding
...
@@ -38,6 +38,7 @@ from ppcls.data.preprocess.ops.operators import CropWithPadding
from
ppcls.data.preprocess.ops.operators
import
RandomInterpolationAugment
from
ppcls.data.preprocess.ops.operators
import
RandomInterpolationAugment
from
ppcls.data.preprocess.ops.operators
import
ColorJitter
from
ppcls.data.preprocess.ops.operators
import
ColorJitter
from
ppcls.data.preprocess.ops.operators
import
RandomCropImage
from
ppcls.data.preprocess.ops.operators
import
RandomCropImage
from
ppcls.data.preprocess.ops.operators
import
RandomRotation
from
ppcls.data.preprocess.ops.operators
import
Padv2
from
ppcls.data.preprocess.ops.operators
import
Padv2
from
ppcls.data.preprocess.batch_ops.batch_operators
import
MixupOperator
,
CutmixOperator
,
OpSampler
,
FmixOperator
from
ppcls.data.preprocess.batch_ops.batch_operators
import
MixupOperator
,
CutmixOperator
,
OpSampler
,
FmixOperator
...
...
ppcls/data/preprocess/ops/operators.py
浏览文件 @
1b5e00e8
...
@@ -26,6 +26,7 @@ import cv2
...
@@ -26,6 +26,7 @@ import cv2
import
numpy
as
np
import
numpy
as
np
from
PIL
import
Image
,
ImageOps
,
__version__
as
PILLOW_VERSION
from
PIL
import
Image
,
ImageOps
,
__version__
as
PILLOW_VERSION
from
paddle.vision.transforms
import
ColorJitter
as
RawColorJitter
from
paddle.vision.transforms
import
ColorJitter
as
RawColorJitter
from
paddle.vision.transforms
import
RandomRotation
as
RawRandomRotation
from
paddle.vision.transforms
import
ToTensor
,
Normalize
,
RandomHorizontalFlip
,
RandomResizedCrop
from
paddle.vision.transforms
import
ToTensor
,
Normalize
,
RandomHorizontalFlip
,
RandomResizedCrop
from
paddle.vision.transforms
import
functional
as
F
from
paddle.vision.transforms
import
functional
as
F
from
.autoaugment
import
ImageNetPolicy
from
.autoaugment
import
ImageNetPolicy
...
@@ -181,7 +182,8 @@ class DecodeImage(object):
...
@@ -181,7 +182,8 @@ class DecodeImage(object):
img
=
np
.
asarray
(
img
)[:,
:,
::
-
1
]
# BRG
img
=
np
.
asarray
(
img
)[:,
:,
::
-
1
]
# BRG
if
self
.
to_rgb
:
if
self
.
to_rgb
:
assert
img
.
shape
[
2
]
==
3
,
f
"invalid shape of image[
{
img
.
shape
}
]"
assert
img
.
shape
[
2
]
==
3
,
f
"invalid shape of image[
{
img
.
shape
}
]"
img
=
img
[:,
:,
::
-
1
]
img
=
img
[:,
:,
::
-
1
]
if
self
.
channel_first
:
if
self
.
channel_first
:
...
@@ -495,7 +497,13 @@ class RandFlipImage(object):
...
@@ -495,7 +497,13 @@ class RandFlipImage(object):
if
isinstance
(
img
,
np
.
ndarray
):
if
isinstance
(
img
,
np
.
ndarray
):
return
cv2
.
flip
(
img
,
self
.
flip_code
)
return
cv2
.
flip
(
img
,
self
.
flip_code
)
else
:
else
:
return
img
.
transpose
(
Image
.
FLIP_LEFT_RIGHT
)
if
self
.
flip_code
==
1
:
return
img
.
transpose
(
Image
.
FLIP_LEFT_RIGHT
)
elif
self
.
flip_code
==
0
:
return
img
.
transpose
(
Image
.
FLIP_TOP_BOTTOM
)
else
:
return
img
.
transpose
(
Image
.
FLIP_LEFT_RIGHT
).
transpose
(
Image
.
FLIP_LEFT_RIGHT
)
else
:
else
:
return
img
return
img
...
@@ -653,6 +661,20 @@ class ColorJitter(RawColorJitter):
...
@@ -653,6 +661,20 @@ class ColorJitter(RawColorJitter):
return
img
return
img
class RandomRotation(RawRandomRotation):
    """Rotation transform that is only applied with a given probability.

    Thin wrapper around ``paddle.vision.transforms.RandomRotation`` (imported
    in this module as ``RawRandomRotation``).

    Args:
        prob: probability of applying the rotation to an image.
            Defaults to 0.5.
        *args: positional arguments forwarded to the parent constructor.
        **kwargs: keyword arguments forwarded to the parent constructor.
    """

    def __init__(self, prob=0.5, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.prob = prob

    def __call__(self, img):
        """Return ``img`` rotated by the parent transform, or unchanged."""
        # One uniform draw per call decides whether the rotation is skipped.
        if not np.random.random() < self.prob:
            return img
        return super()._apply_image(img)
class
Pad
(
object
):
class
Pad
(
object
):
"""
"""
Pads the given PIL.Image on all sides with specified padding mode and fill value.
Pads the given PIL.Image on all sides with specified padding mode and fill value.
...
...
ppcls/engine/engine.py
浏览文件 @
1b5e00e8
...
@@ -114,6 +114,10 @@ class Engine(object):
...
@@ -114,6 +114,10 @@ class Engine(object):
#TODO(gaotingquan): support rec
#TODO(gaotingquan): support rec
class_num
=
config
[
"Arch"
].
get
(
"class_num"
,
None
)
class_num
=
config
[
"Arch"
].
get
(
"class_num"
,
None
)
self
.
config
[
"DataLoader"
].
update
({
"class_num"
:
class_num
})
self
.
config
[
"DataLoader"
].
update
({
"class_num"
:
class_num
})
self
.
model
=
build_model
(
self
.
config
,
self
.
mode
)
# print(*self.model.state_dict().keys(), sep='\n')
print
(
self
.
model
.
backbone
.
stages
[
3
][
0
].
dw_conv_list
[
0
].
conv
)
exit
(
0
)
# build dataloader
# build dataloader
if
self
.
mode
==
'train'
:
if
self
.
mode
==
'train'
:
self
.
train_dataloader
=
build_dataloader
(
self
.
train_dataloader
=
build_dataloader
(
...
...
ppcls/loss/__init__.py
浏览文件 @
1b5e00e8
...
@@ -12,6 +12,7 @@ from .msmloss import MSMLoss
...
@@ -12,6 +12,7 @@ from .msmloss import MSMLoss
from
.npairsloss
import
NpairsLoss
from
.npairsloss
import
NpairsLoss
from
.trihardloss
import
TriHardLoss
from
.trihardloss
import
TriHardLoss
from
.triplet
import
TripletLoss
,
TripletLossV2
from
.triplet
import
TripletLoss
,
TripletLossV2
from .tripletangularmarginloss import TripletAngularMarginLoss
from
.supconloss
import
SupConLoss
from
.supconloss
import
SupConLoss
from
.pairwisecosface
import
PairwiseCosface
from
.pairwisecosface
import
PairwiseCosface
from
.dmlloss
import
DMLLoss
from
.dmlloss
import
DMLLoss
...
...
ppcls/loss/tripletangularmarginloss.py
0 → 100644
浏览文件 @
1b5e00e8
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle
import
paddle.nn
as
nn
class TripletAngularMarginLoss(nn.Layer):
    """Triplet loss with hard positive/negative mining on cosine similarity.

    Instead of comparing relative distances d(a, p) and d(a, n), the loss
    ranks the cosine similarity of the hardest positive against that of the
    hardest negative with a margin.

    Args:
        margin (float, optional): angular margin passed to
            ``MarginRankingLoss``. Defaults to 0.5.
        normalize_feature (bool, optional): whether to L2-normalize the
            features before computing the similarity matrix. Defaults to True.
        reduction (str, optional): reduction passed to ``MarginRankingLoss``.
            Defaults to "mean".
        add_absolute (bool, optional): whether to add the absolute hinge terms
            on sim(a, p) and sim(a, n). Defaults to False.
        absolute_loss_weight (float, optional): weight of the absolute terms.
            Defaults to 1.0.
        ap_value (float, optional): threshold for sim(a, p); the absolute
            positive term is ``max(ap_value - sim(a, p), 0)``. Defaults to 0.9.
        an_value (float, optional): threshold for sim(a, n); the absolute
            negative term is based on ``sim(a, n) - an_value``. Defaults to 0.5.
        feature_from (str, optional): key of the feature tensor in the input
            dict. Defaults to "features".
    """

    def __init__(self,
                 margin=0.5,
                 normalize_feature=True,
                 reduction="mean",
                 add_absolute=False,
                 absolute_loss_weight=1.0,
                 ap_value=0.9,
                 an_value=0.5,
                 feature_from="features"):
        # The class name passed to super() must be this class; the previous
        # spelling (`TripletAngleMarginLoss`) was undefined and raised
        # NameError on every instantiation.
        super(TripletAngularMarginLoss, self).__init__()
        self.margin = margin
        self.feature_from = feature_from
        self.ranking_loss = paddle.nn.loss.MarginRankingLoss(
            margin=margin, reduction=reduction)
        self.normalize_feature = normalize_feature
        self.add_absolute = add_absolute
        self.ap_value = ap_value
        self.an_value = an_value
        self.absolute_loss_weight = absolute_loss_weight

    def forward(self, input, target):
        """Compute the loss.

        Args:
            input: dict holding the feature matrix under ``self.feature_from``
                with shape (batch_size, feat_dim).
            target: ground-truth labels, one per sample; must be expandable
                to (batch_size, batch_size).

        Returns:
            dict: ``{"TripletAngularMarginLoss": loss}``.
        """
        inputs = input[self.feature_from]

        if self.normalize_feature:
            inputs = paddle.divide(
                inputs, paddle.norm(
                    inputs, p=2, axis=-1, keepdim=True))

        bs = inputs.shape[0]

        # pairwise similarity matrix (cosine similarity when the features
        # are normalized)
        dist = paddle.matmul(inputs, inputs.t())

        # label comparison masks: entry (i, j) tells whether samples i and j
        # share a label
        expanded_target = paddle.expand(target, (bs, bs))
        expanded_target_t = expanded_target.t()
        is_pos = expanded_target.equal(expanded_target_t)
        is_neg = expanded_target.not_equal(expanded_target_t)

        # `dist_ap`: similarity(anchor, hardest positive), shape [N, 1].
        # On similarities the hardest positive is the *least* similar one.
        # The reshape to (bs, -1) assumes every anchor has the same number
        # of positives in the batch.
        dist_ap = paddle.min(
            paddle.reshape(
                paddle.masked_select(dist, is_pos), (bs, -1)),
            axis=1,
            keepdim=True)

        # `dist_an`: similarity(anchor, hardest negative), shape [N, 1].
        # The hardest negative is the *most* similar one; the same
        # equal-count assumption applies to the negatives.
        dist_an = paddle.max(
            paddle.reshape(
                paddle.masked_select(dist, is_neg), (bs, -1)),
            axis=1,
            keepdim=True)

        # shape [N]
        dist_ap = paddle.squeeze(dist_ap, axis=1)
        dist_an = paddle.squeeze(dist_an, axis=1)

        # ranking hinge loss: with y = 1 this penalizes
        # dist_an - dist_ap + margin > 0
        y = paddle.ones_like(dist_an)
        loss = self.ranking_loss(dist_ap, dist_an, y)

        if self.add_absolute:
            absolut_loss_ap = self.ap_value - dist_ap
            absolut_loss_ap = paddle.where(absolut_loss_ap > 0,
                                           absolut_loss_ap,
                                           paddle.zeros_like(absolut_loss_ap))

            absolut_loss_an = dist_an - self.an_value
            # NOTE(review): the fill value here is ones, whereas the positive
            # term above uses zeros. The constant fill carries no gradient,
            # so only the reported loss value is affected; kept as-is --
            # confirm whether zeros_like was intended.
            absolut_loss_an = paddle.where(absolut_loss_an > 0,
                                           absolut_loss_an,
                                           paddle.ones_like(absolut_loss_an))

            loss = (absolut_loss_an.mean() + absolut_loss_ap.mean()
                    ) * self.absolute_loss_weight + loss.mean()

        return {"TripletAngularMarginLoss": loss}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录