Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
1b5e00e8
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1b5e00e8
编写于
8月 23, 2022
作者:
H
HydrogenSulfate
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add PP-ShiTuV2 code
上级
dab99e3e
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
418 addition
and
27 deletion
+418
-27
ppcls/arch/backbone/__init__.py
ppcls/arch/backbone/__init__.py
+1
-0
ppcls/arch/backbone/base/theseus_layer.py
ppcls/arch/backbone/base/theseus_layer.py
+2
-0
ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py
ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py
+2
-2
ppcls/arch/backbone/variant_models/__init__.py
ppcls/arch/backbone/variant_models/__init__.py
+1
-0
ppcls/arch/backbone/variant_models/pp_lcnetv2_variant.py
ppcls/arch/backbone/variant_models/pp_lcnetv2_variant.py
+44
-0
ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml
...ralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml
+198
-0
ppcls/data/dataloader/imagenet_dataset.py
ppcls/data/dataloader/imagenet_dataset.py
+13
-12
ppcls/data/dataloader/vehicle_dataset.py
ppcls/data/dataloader/vehicle_dataset.py
+12
-11
ppcls/data/preprocess/__init__.py
ppcls/data/preprocess/__init__.py
+1
-0
ppcls/data/preprocess/ops/operators.py
ppcls/data/preprocess/ops/operators.py
+24
-2
ppcls/engine/engine.py
ppcls/engine/engine.py
+4
-0
ppcls/loss/__init__.py
ppcls/loss/__init__.py
+1
-0
ppcls/loss/tripletangularmarginloss.py
ppcls/loss/tripletangularmarginloss.py
+115
-0
未找到文件。
ppcls/arch/backbone/__init__.py
浏览文件 @
1b5e00e8
...
...
@@ -73,6 +73,7 @@ from .model_zoo.convnext import ConvNeXt_tiny
from
.variant_models.resnet_variant
import
ResNet50_last_stage_stride1
from
.variant_models.vgg_variant
import
VGG19Sigmoid
from
.variant_models.pp_lcnet_variant
import
PPLCNet_x2_5_Tanh
from
.variant_models.pp_lcnetv2_variant
import
PPLCNetV2_base_ShiTu
from
.model_zoo.adaface_ir_net
import
AdaFace_IR_18
,
AdaFace_IR_34
,
AdaFace_IR_50
,
AdaFace_IR_101
,
AdaFace_IR_152
,
AdaFace_IR_SE_50
,
AdaFace_IR_SE_101
,
AdaFace_IR_SE_152
,
AdaFace_IR_SE_200
...
...
ppcls/arch/backbone/base/theseus_layer.py
浏览文件 @
1b5e00e8
...
...
@@ -158,6 +158,8 @@ class TheseusLayer(nn.Layer):
return
False
parent_layer
=
layer_dict
[
"layer"
]
msg
=
f
"Successfully set the layers that after stop_layer_name('
{
stop_layer_name
}
') to IdentityLayer."
logger
.
info
(
msg
)
return
True
def
update_res
(
...
...
ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py
浏览文件 @
1b5e00e8
...
...
@@ -306,8 +306,8 @@ class PPLCNetV2(TheseusLayer):
self
.
dropout
=
Dropout
(
p
=
dropout_prob
,
mode
=
"downscale_in_infer"
)
self
.
flatten
=
nn
.
Flatten
(
start_axis
=
1
,
stop_axis
=-
1
)
in_features
=
self
.
class_expand
if
self
.
use_last_conv
else
NET_CONFIG
[
"stage4"
][
0
]
*
2
*
scale
in_features
=
self
.
class_expand
if
self
.
use_last_conv
else
make_divisible
(
NET_CONFIG
[
"stage4"
][
0
]
*
2
*
scale
)
self
.
fc
=
Linear
(
in_features
,
class_num
)
def
forward
(
self
,
x
):
...
...
ppcls/arch/backbone/variant_models/__init__.py
浏览文件 @
1b5e00e8
from
.resnet_variant
import
ResNet50_last_stage_stride1
from
.vgg_variant
import
VGG19Sigmoid
from
.pp_lcnet_variant
import
PPLCNet_x2_5_Tanh
from
.pp_lcnetv2_variant
import
PPLCNetV2_base_ShiTu
ppcls/arch/backbone/variant_models/pp_lcnetv2_variant.py
0 → 100644
浏览文件 @
1b5e00e8
from
paddle.nn
import
Conv2D
,
Identity
from
..legendary_models.pp_lcnet_v2
import
PPLCNetV2_base
,
RepDepthwiseSeparable
,
MODEL_URLS
,
_load_pretrained
__all__
=
[
"PPLCNetV2_base_ShiTu"
]
def PPLCNetV2_base_ShiTu(pretrained=False, use_ssld=False, **kwargs):
    """Build the PP-ShiTu variant of ``PPLCNetV2_base``.

    The base network is first created without pretrained weights. The
    convolutions listed for ``stages[3]`` are then rebuilt with ``stride=1``
    and the sublayer matched by ``"act"`` is replaced with ``Identity``.
    Pretrained weights are loaded only after these replacements.

    Args:
        pretrained: Forwarded to ``_load_pretrained`` once the sublayers have
            been replaced.
        use_ssld: Forwarded to both ``PPLCNetV2_base`` and
            ``_load_pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``PPLCNetV2_base``.

    Returns:
        The modified model instance.
    """

    def _to_identity(conv, pattern):
        # Callback for upgrade_sublayer: the matched layer is dropped entirely.
        return Identity()

    def _with_unit_stride(conv, pattern):
        # Rebuild the matched convolution from its own settings, overriding
        # only the stride.
        conv_kwargs = dict(
            weight_attr=conv._weight_attr,
            in_channels=conv._in_channels,
            out_channels=conv._out_channels,
            kernel_size=conv._kernel_size,
            stride=1,
            padding=conv._padding,
            groups=conv._groups,
            bias_attr=conv._bias_attr,
        )
        return Conv2D(**conv_kwargs)

    activation_patterns = ["act"]
    last_stage_conv_patterns = [
        "stages[3][0].dw_conv_list[0].conv",
        "stages[3][0].dw_conv_list[1].conv",
        "stages[3][0].dw_conv",
        "stages[3][0].pw_conv.conv",
        "stages[3][1].dw_conv_list[0].conv",
        "stages[3][1].dw_conv_list[1].conv",
        "stages[3][1].dw_conv_list[2].conv",
        "stages[3][1].dw_conv",
        "stages[3][1].pw_conv.conv",
    ]

    # pretrained=False here on purpose: weights are loaded at the end.
    backbone = PPLCNetV2_base(pretrained=False, use_ssld=use_ssld, **kwargs)
    backbone.upgrade_sublayer(last_stage_conv_patterns, _with_unit_stride)
    backbone.upgrade_sublayer(activation_patterns, _to_identity)

    # Load params after the layers above have been swapped out.
    _load_pretrained(pretrained, backbone, MODEL_URLS["PPLCNetV2_base"],
                     use_ssld)
    return backbone
ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml
0 → 100644
浏览文件 @
1b5e00e8
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output
device
:
gpu
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
100
print_batch_step
:
20
use_visualdl
:
False
eval_mode
:
retrieval
retrieval_feature_from
:
features
# 'backbone' or 'features'
re_ranking
:
False
use_dali
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# AMP:
# scale_loss: 65536
# use_dynamic_loss_scaling: True
# # O1: mixed fp16
# level: O1
# model architecture
Arch
:
name
:
RecModel
infer_output_key
:
features
infer_add_softmax
:
False
Backbone
:
name
:
PPLCNetV2_base_ShiTu
pretrained
:
True
use_ssld
:
True
class_expand
:
&feat_dim
512
BackboneStopLayer
:
name
:
flatten
Neck
:
name
:
BNNeck
num_features
:
*feat_dim
weight_attr
:
initializer
:
name
:
Constant
value
:
1.0
bias_attr
:
initializer
:
name
:
Constant
value
:
0.0
learning_rate
:
1.0e-20
# NOTE: Temporarily set lr small enough to freeze the bias to zero
Head
:
name
:
FC
embedding_size
:
*feat_dim
class_num
:
192612
weight_attr
:
initializer
:
name
:
Normal
std
:
0.001
bias_attr
:
False
# loss function config for training/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
epsilon
:
0.1
-
TripletAngularMarginLoss
:
weight
:
1.0
margin
:
0.5
reduction
:
mean
add_absolute
:
True
absolute_loss_weight
:
0.1
normalize_feature
:
True
feature_from
:
features
ap_value
:
0.8
an_value
:
0.4
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.04
warmup_epoch
:
5
regularizer
:
name
:
L2
coeff
:
0.00001
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/
cls_label_path
:
./dataset/train_reg_all_data.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
[
224
,
224
]
return_numpy
:
False
interpolation
:
bilinear
backend
:
cv2
-
RandFlipImage
:
flip_code
:
1
-
Pad_cv2
:
padding
:
10
-
RandCropImageV2
:
size
:
[
224
,
224
]
-
RandomRotation
:
prob
:
0.5
degrees
:
90
interpolation
:
bilinear
-
ResizeImage
:
size
:
[
224
,
224
]
return_numpy
:
False
interpolation
:
bilinear
backend
:
cv2
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
hwc
sampler
:
name
:
DistributedBatchSampler
batch_size
:
256
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
4
use_shared_memory
:
True
Eval
:
Query
:
dataset
:
name
:
VeriWild
image_root
:
./dataset/Aliproduct/
cls_label_path
:
./dataset/Aliproduct/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
[
224
,
224
]
return_numpy
:
False
interpolation
:
bilinear
backend
:
cv2
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
hwc
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Gallery
:
dataset
:
name
:
VeriWild
image_root
:
./dataset/Aliproduct/
cls_label_path
:
./dataset/Aliproduct/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
[
224
,
224
]
return_numpy
:
False
interpolation
:
bilinear
backend
:
cv2
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
hwc
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Metric
:
Eval
:
-
Recallk
:
topk
:
[
1
,
5
]
ppcls/data/dataloader/imagenet_dataset.py
浏览文件 @
1b5e00e8
...
...
@@ -21,14 +21,14 @@ from .common_dataset import CommonDataset
class
ImageNetDataset
(
CommonDataset
):
def
__init__
(
self
,
image_root
,
cls_label_path
,
transform_ops
=
None
,
delimiter
=
None
):
def
__init__
(
self
,
image_root
,
cls_label_path
,
transform_ops
=
None
,
delimiter
=
None
):
self
.
delimiter
=
delimiter
if
delimiter
is
not
None
else
" "
super
(
ImageNetDataset
,
self
).
__init__
(
image_root
,
cls_label_path
,
transform_ops
)
super
(
ImageNetDataset
,
self
).
__init__
(
image_root
,
cls_label_path
,
transform_ops
)
def
_load_anno
(
self
,
seed
=
None
):
assert
os
.
path
.
exists
(
self
.
_cls_path
)
...
...
@@ -40,8 +40,9 @@ class ImageNetDataset(CommonDataset):
lines
=
fd
.
readlines
()
if
seed
is
not
None
:
np
.
random
.
RandomState
(
seed
).
shuffle
(
lines
)
for
l
in
lines
:
l
=
l
.
strip
().
split
(
self
.
delimiter
)
self
.
images
.
append
(
os
.
path
.
join
(
self
.
_img_root
,
l
[
0
]))
self
.
labels
.
append
(
np
.
int64
(
l
[
1
]))
assert
os
.
path
.
exists
(
self
.
images
[
-
1
])
for
line
in
lines
:
line
=
line
.
strip
().
split
(
self
.
delimiter
)
self
.
images
.
append
(
os
.
path
.
join
(
self
.
_img_root
,
line
[
0
]))
self
.
labels
.
append
(
np
.
int64
(
line
[
1
]))
assert
os
.
path
.
exists
(
self
.
images
[
-
1
]),
f
"path
{
self
.
images
[
-
1
]
}
does not exist."
ppcls/data/dataloader/vehicle_dataset.py
浏览文件 @
1b5e00e8
...
...
@@ -89,11 +89,7 @@ class CompCars(Dataset):
class
VeriWild
(
Dataset
):
def
__init__
(
self
,
image_root
,
cls_label_path
,
transform_ops
=
None
,
):
def
__init__
(
self
,
image_root
,
cls_label_path
,
transform_ops
=
None
):
self
.
_img_root
=
image_root
self
.
_cls_path
=
cls_label_path
if
transform_ops
:
...
...
@@ -109,12 +105,14 @@ class VeriWild(Dataset):
self
.
cameras
=
[]
with
open
(
self
.
_cls_path
)
as
fd
:
lines
=
fd
.
readlines
()
for
l
in
lines
:
l
=
l
.
strip
().
split
()
self
.
images
.
append
(
os
.
path
.
join
(
self
.
_img_root
,
l
[
0
]))
self
.
labels
.
append
(
np
.
int64
(
l
[
1
]))
self
.
cameras
.
append
(
np
.
int64
(
l
[
2
]))
for
line
in
lines
:
line
=
line
.
strip
().
split
()
self
.
images
.
append
(
os
.
path
.
join
(
self
.
_img_root
,
line
[
0
]))
self
.
labels
.
append
(
np
.
int64
(
line
[
1
]))
if
len
(
line
)
>=
3
:
self
.
cameras
.
append
(
np
.
int64
(
line
[
2
]))
assert
os
.
path
.
exists
(
self
.
images
[
-
1
])
self
.
has_camera
=
len
(
self
.
cameras
)
>
0
def
__getitem__
(
self
,
idx
):
try
:
...
...
@@ -123,7 +121,10 @@ class VeriWild(Dataset):
if
self
.
_transform_ops
:
img
=
transform
(
img
,
self
.
_transform_ops
)
img
=
img
.
transpose
((
2
,
0
,
1
))
return
(
img
,
self
.
labels
[
idx
],
self
.
cameras
[
idx
])
if
self
.
has_camera
:
return
(
img
,
self
.
labels
[
idx
],
self
.
cameras
[
idx
])
else
:
return
(
img
,
self
.
labels
[
idx
])
except
Exception
as
ex
:
logger
.
error
(
"Exception occured when parse line: {} with msg: {}"
.
format
(
self
.
images
[
idx
],
ex
))
...
...
ppcls/data/preprocess/__init__.py
浏览文件 @
1b5e00e8
...
...
@@ -38,6 +38,7 @@ from ppcls.data.preprocess.ops.operators import CropWithPadding
from
ppcls.data.preprocess.ops.operators
import
RandomInterpolationAugment
from
ppcls.data.preprocess.ops.operators
import
ColorJitter
from
ppcls.data.preprocess.ops.operators
import
RandomCropImage
from
ppcls.data.preprocess.ops.operators
import
RandomRotation
from
ppcls.data.preprocess.ops.operators
import
Padv2
from
ppcls.data.preprocess.batch_ops.batch_operators
import
MixupOperator
,
CutmixOperator
,
OpSampler
,
FmixOperator
...
...
ppcls/data/preprocess/ops/operators.py
浏览文件 @
1b5e00e8
...
...
@@ -26,6 +26,7 @@ import cv2
import
numpy
as
np
from
PIL
import
Image
,
ImageOps
,
__version__
as
PILLOW_VERSION
from
paddle.vision.transforms
import
ColorJitter
as
RawColorJitter
from
paddle.vision.transforms
import
RandomRotation
as
RawRandomRotation
from
paddle.vision.transforms
import
ToTensor
,
Normalize
,
RandomHorizontalFlip
,
RandomResizedCrop
from
paddle.vision.transforms
import
functional
as
F
from
.autoaugment
import
ImageNetPolicy
...
...
@@ -181,7 +182,8 @@ class DecodeImage(object):
img
=
np
.
asarray
(
img
)[:,
:,
::
-
1
]
# BGR
if
self
.
to_rgb
:
assert
img
.
shape
[
2
]
==
3
,
f
"invalid shape of image[
{
img
.
shape
}
]"
assert
img
.
shape
[
2
]
==
3
,
f
"invalid shape of image[
{
img
.
shape
}
]"
img
=
img
[:,
:,
::
-
1
]
if
self
.
channel_first
:
...
...
@@ -495,7 +497,13 @@ class RandFlipImage(object):
if
isinstance
(
img
,
np
.
ndarray
):
return
cv2
.
flip
(
img
,
self
.
flip_code
)
else
:
return
img
.
transpose
(
Image
.
FLIP_LEFT_RIGHT
)
if
self
.
flip_code
==
1
:
return
img
.
transpose
(
Image
.
FLIP_LEFT_RIGHT
)
elif
self
.
flip_code
==
0
:
return
img
.
transpose
(
Image
.
FLIP_TOP_BOTTOM
)
else
:
return
img
.
transpose
(
Image
.
FLIP_LEFT_RIGHT
).
transpose
(
Image
.
FLIP_LEFT_RIGHT
)
else
:
return
img
...
...
@@ -653,6 +661,20 @@ class ColorJitter(RawColorJitter):
return
img
class RandomRotation(RawRandomRotation):
    """Apply the base rotation transform to an image with probability ``prob``.

    Args:
        prob: Threshold compared against a ``np.random.random()`` draw; the
            rotation is applied when the draw is strictly smaller.
        *args: Positional arguments forwarded to ``RawRandomRotation``.
        **kwargs: Keyword arguments forwarded to ``RawRandomRotation``.
    """

    def __init__(self, prob=0.5, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.prob = prob

    def __call__(self, img):
        """Return the rotated image, or ``img`` untouched when the draw fails."""
        if np.random.random() < self.prob:
            return super()._apply_image(img)
        return img
class
Pad
(
object
):
"""
Pads the given PIL.Image on all sides with specified padding mode and fill value.
...
...
ppcls/engine/engine.py
浏览文件 @
1b5e00e8
...
...
@@ -114,6 +114,10 @@ class Engine(object):
#TODO(gaotingquan): support rec
class_num
=
config
[
"Arch"
].
get
(
"class_num"
,
None
)
self
.
config
[
"DataLoader"
].
update
({
"class_num"
:
class_num
})
self
.
model
=
build_model
(
self
.
config
,
self
.
mode
)
# print(*self.model.state_dict().keys(), sep='\n')
print
(
self
.
model
.
backbone
.
stages
[
3
][
0
].
dw_conv_list
[
0
].
conv
)
exit
(
0
)
# build dataloader
if
self
.
mode
==
'train'
:
self
.
train_dataloader
=
build_dataloader
(
...
...
ppcls/loss/__init__.py
浏览文件 @
1b5e00e8
...
...
@@ -12,6 +12,7 @@ from .msmloss import MSMLoss
from
.npairsloss
import
NpairsLoss
from
.trihardloss
import
TriHardLoss
from
.triplet
import
TripletLoss
,
TripletLossV2
from .tripletangularmarginloss import TripletAngularMarginLoss
from
.supconloss
import
SupConLoss
from
.pairwisecosface
import
PairwiseCosface
from
.dmlloss
import
DMLLoss
...
...
ppcls/loss/tripletangularmarginloss.py
0 → 100644
浏览文件 @
1b5e00e8
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle
import
paddle.nn
as
nn
class TripletAngularMarginLoss(nn.Layer):
    """Triplet loss with hard positive/negative mining on an angular margin.

    The pairwise score is the inner product of the features (cosine
    similarity when ``normalize_feature`` is True) rather than a relative
    distance between d(a,p) and d(a,n). For each anchor the hardest positive
    is the positive with the smallest score, and the hardest negative is the
    negative with the largest score.

    Args:
        margin (float, optional): angular margin passed to
            ``MarginRankingLoss``. Defaults to 0.5.
        normalize_feature (bool, optional): whether to L2-normalize the
            features before computing the pairwise scores. Defaults to True.
        reduction (str, optional): reduction passed to ``MarginRankingLoss``.
            Defaults to "mean".
        add_absolute (bool, optional): whether to add the absolute terms on
            d(a,p) and d(a,n). Defaults to False.
        absolute_loss_weight (float, optional): weight of the absolute terms.
            Defaults to 1.0.
        ap_value (float, optional): threshold for d(a,p); the absolute term
            is ``ap_value - d(a,p)`` where that is positive. Defaults to 0.9.
        an_value (float, optional): threshold for d(a,n); the absolute term
            is ``d(a,n) - an_value`` where that is positive. Defaults to 0.5.
        feature_from (str, optional): key of the feature tensor in the input
            dict. Defaults to "features".
    """

    def __init__(self,
                 margin=0.5,
                 normalize_feature=True,
                 reduction="mean",
                 add_absolute=False,
                 absolute_loss_weight=1.0,
                 ap_value=0.9,
                 an_value=0.5,
                 feature_from="features"):
        # The original named a non-existent class here (NameError on build).
        super(TripletAngularMarginLoss, self).__init__()
        self.margin = margin
        self.feature_from = feature_from
        self.ranking_loss = paddle.nn.loss.MarginRankingLoss(
            margin=margin, reduction=reduction)
        self.normalize_feature = normalize_feature
        self.add_absolute = add_absolute
        self.ap_value = ap_value
        self.an_value = an_value
        self.absolute_loss_weight = absolute_loss_weight

    def forward(self, input, target):
        """Compute the loss for one batch.

        Args:
            input: dict holding the feature matrix under ``feature_from``,
                with shape (batch_size, feat_dim).
            target: ground truth labels, one per sample. It must be
                expandable to (batch_size, batch_size) by ``paddle.expand``.

        Returns:
            dict: ``{"TripletAngularMarginLoss": loss}``.
        """
        inputs = input[self.feature_from]

        if self.normalize_feature:
            inputs = paddle.divide(
                inputs, paddle.norm(
                    inputs, p=2, axis=-1, keepdim=True))

        bs = inputs.shape[0]

        # compute distance(cos-similarity)
        dist = paddle.matmul(inputs, inputs.t())

        # Pairwise label comparison: entry (i, j) says whether samples i and
        # j share a label.
        labels = paddle.expand(target, (bs, bs))
        labels_t = labels.t()
        is_pos = labels.equal(labels_t)
        is_neg = labels.not_equal(labels_t)

        # Hard mining. NOTE: reshaping the masked values to (bs, -1) only
        # groups them per anchor when every row of the mask selects the same
        # number of entries (same count of positives/negatives per sample).

        # `dist_ap`: smallest similarity between an anchor and its positives,
        # shape [N, 1]
        dist_ap = paddle.min(
            paddle.reshape(paddle.masked_select(dist, is_pos), (bs, -1)),
            axis=1,
            keepdim=True)
        # `dist_an`: largest similarity between an anchor and its negatives,
        # shape [N, 1]
        dist_an = paddle.max(
            paddle.reshape(paddle.masked_select(dist, is_neg), (bs, -1)),
            axis=1,
            keepdim=True)
        # shape [N]
        dist_ap = paddle.squeeze(dist_ap, axis=1)
        dist_an = paddle.squeeze(dist_an, axis=1)

        # Compute ranking hinge loss; y = 1 asks for dist_ap to rank above
        # dist_an by the margin.
        y = paddle.ones_like(dist_an)
        loss = self.ranking_loss(dist_ap, dist_an, y)

        if self.add_absolute:
            absolut_loss_ap = self.ap_value - dist_ap
            absolut_loss_ap = paddle.where(absolut_loss_ap > 0,
                                           absolut_loss_ap,
                                           paddle.zeros_like(absolut_loss_ap))

            absolut_loss_an = dist_an - self.an_value
            # NOTE(review): the fallback here is ones_like, while the d(a,p)
            # term above falls back to zeros_like. Kept as in the original;
            # confirm whether this asymmetry is intended.
            absolut_loss_an = paddle.where(absolut_loss_an > 0,
                                           absolut_loss_an,
                                           paddle.ones_like(absolut_loss_an))

            loss = (absolut_loss_an.mean() + absolut_loss_ap.mean()
                    ) * self.absolute_loss_weight + loss.mean()

        return {"TripletAngularMarginLoss": loss}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录