Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
ec5e07da
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
1 年多 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ec5e07da
编写于
8月 26, 2021
作者:
B
Bin Lu
提交者:
GitHub
8月 26, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1166 from Intsigstephon/develop
add Deephash method: DLBHC
上级
41036408
d388d69a
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
476 addition
and
4 deletion
+476
-4
ppcls/arch/backbone/__init__.py
ppcls/arch/backbone/__init__.py
+1
-0
ppcls/arch/backbone/variant_models/__init__.py
ppcls/arch/backbone/variant_models/__init__.py
+1
-0
ppcls/arch/backbone/variant_models/vgg_variant.py
ppcls/arch/backbone/variant_models/vgg_variant.py
+28
-0
ppcls/arch/gears/circlemargin.py
ppcls/arch/gears/circlemargin.py
+6
-3
ppcls/arch/gears/cosmargin.py
ppcls/arch/gears/cosmargin.py
+3
-0
ppcls/configs/Products/MV3_Large_1x_Aliproduct_DLBHC.yaml
ppcls/configs/Products/MV3_Large_1x_Aliproduct_DLBHC.yaml
+149
-0
ppcls/configs/quick_start/professional/VGG19_CIFAR10_DeepHash.yaml
...figs/quick_start/professional/VGG19_CIFAR10_DeepHash.yaml
+147
-0
ppcls/engine/evaluation/retrieval.py
ppcls/engine/evaluation/retrieval.py
+9
-0
ppcls/loss/deephashloss.py
ppcls/loss/deephashloss.py
+90
-0
ppcls/metric/__init__.py
ppcls/metric/__init__.py
+1
-1
ppcls/metric/metrics.py
ppcls/metric/metrics.py
+41
-0
未找到文件。
ppcls/arch/backbone/__init__.py
浏览文件 @
ec5e07da
...
...
@@ -58,6 +58,7 @@ from ppcls.arch.backbone.model_zoo.rednet import RedNet26, RedNet38, RedNet50, R
from
ppcls.arch.backbone.model_zoo.tnt
import
TNT_small
from
ppcls.arch.backbone.model_zoo.hardnet
import
HarDNet68
,
HarDNet85
,
HarDNet39_ds
,
HarDNet68_ds
from
ppcls.arch.backbone.variant_models.resnet_variant
import
ResNet50_last_stage_stride1
from
ppcls.arch.backbone.variant_models.vgg_variant
import
VGG19Sigmoid
def
get_apis
():
...
...
ppcls/arch/backbone/variant_models/__init__.py
浏览文件 @
ec5e07da
from
.resnet_variant
import
ResNet50_last_stage_stride1
from
.vgg_variant
import
VGG19Sigmoid
ppcls/arch/backbone/variant_models/vgg_variant.py
0 → 100644
浏览文件 @
ec5e07da
import
paddle
from
paddle.nn
import
Sigmoid
from
ppcls.arch.backbone.legendary_models.vgg
import
VGG19
__all__
=
[
"VGG19Sigmoid"
]
class
SigmoidSuffix
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
origin_layer
):
super
(
SigmoidSuffix
,
self
).
__init__
()
self
.
origin_layer
=
origin_layer
self
.
sigmoid
=
Sigmoid
()
def
forward
(
self
,
input
,
res_dict
=
None
,
**
kwargs
):
x
=
self
.
origin_layer
(
input
)
x
=
self
.
sigmoid
(
x
)
return
x
def
VGG19Sigmoid
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
def
replace_function
(
origin_layer
):
new_layer
=
SigmoidSuffix
(
origin_layer
)
return
new_layer
match_re
=
"linear_2"
model
=
VGG19
(
pretrained
=
pretrained
,
use_ssld
=
use_ssld
,
**
kwargs
)
model
.
replace_sub
(
match_re
,
replace_function
,
True
)
return
model
ppcls/arch/gears/circlemargin.py
浏览文件 @
ec5e07da
...
...
@@ -28,7 +28,7 @@ class CircleMargin(nn.Layer):
weight_attr
=
paddle
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
XavierNormal
())
self
.
fc
0
=
paddle
.
nn
.
Linear
(
self
.
fc
=
paddle
.
nn
.
Linear
(
self
.
embedding_size
,
self
.
class_num
,
weight_attr
=
weight_attr
)
def
forward
(
self
,
input
,
label
):
...
...
@@ -36,19 +36,22 @@ class CircleMargin(nn.Layer):
paddle
.
sum
(
paddle
.
square
(
input
),
axis
=
1
,
keepdim
=
True
))
input
=
paddle
.
divide
(
input
,
feat_norm
)
weight
=
self
.
fc
0
.
weight
weight
=
self
.
fc
.
weight
weight_norm
=
paddle
.
sqrt
(
paddle
.
sum
(
paddle
.
square
(
weight
),
axis
=
0
,
keepdim
=
True
))
weight
=
paddle
.
divide
(
weight
,
weight_norm
)
logits
=
paddle
.
matmul
(
input
,
weight
)
if
not
self
.
training
or
label
is
None
:
return
logits
alpha_p
=
paddle
.
clip
(
-
logits
.
detach
()
+
1
+
self
.
margin
,
min
=
0.
)
alpha_n
=
paddle
.
clip
(
logits
.
detach
()
+
self
.
margin
,
min
=
0.
)
delta_p
=
1
-
self
.
margin
delta_n
=
self
.
margin
index
=
paddle
.
fluid
.
layers
.
where
(
label
!=
-
1
).
reshape
([
-
1
])
m_hot
=
F
.
one_hot
(
label
.
reshape
([
-
1
]),
num_classes
=
logits
.
shape
[
1
])
logits_p
=
alpha_p
*
(
logits
-
delta_p
)
logits_n
=
alpha_n
*
(
logits
-
delta_n
)
pre_logits
=
logits_p
*
m_hot
+
logits_n
*
(
1
-
m_hot
)
...
...
ppcls/arch/gears/cosmargin.py
浏览文件 @
ec5e07da
...
...
@@ -46,6 +46,9 @@ class CosMargin(paddle.nn.Layer):
weight
=
paddle
.
divide
(
weight
,
weight_norm
)
cos
=
paddle
.
matmul
(
input
,
weight
)
if
not
self
.
training
or
label
is
None
:
return
cos
cos_m
=
cos
-
self
.
margin
one_hot
=
paddle
.
nn
.
functional
.
one_hot
(
label
,
self
.
class_num
)
...
...
ppcls/configs/Products/MV3_Large_1x_Aliproduct_DLBHC.yaml
0 → 100644
浏览文件 @
ec5e07da
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output_dlbhc/
device
:
gpu
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
100
#eval_mode: "retrieval"
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
#feature postprocess
feature_normalize
:
False
feature_binarize
:
"
round"
# model architecture
Arch
:
name
:
"
RecModel"
Backbone
:
name
:
"
MobileNetV3_large_x1_0"
pretrained
:
True
class_num
:
512
Head
:
name
:
"
FC"
class_num
:
50030
embedding_size
:
512
infer_output_key
:
"
features"
infer_add_softmax
:
"
false"
# loss function config for train/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
epsilon
:
0.1
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Piecewise
learning_rate
:
0.1
decay_epochs
:
[
50
,
150
]
values
:
[
0.1
,
0.01
,
0.001
]
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/Aliproduct/
cls_label_path
:
./dataset/Aliproduct/train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
256
-
RandCropImage
:
size
:
227
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.4914
,
0.4822
,
0.4465
]
std
:
[
0.2023
,
0.1994
,
0.2010
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
128
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
4
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/Aliproduct/
cls_label_path
:
./dataset/Aliproduct/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
227
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.4914
,
0.4822
,
0.4465
]
std
:
[
0.2023
,
0.1994
,
0.2010
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
256
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/whl/demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
227
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.4914
,
0.4822
,
0.4465
]
std
:
[
0.2023
,
0.1994
,
0.2010
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
5
]
Eval
:
-
TopkAcc
:
topk
:
[
1
,
5
]
# switch to metric below when eval by retrieval
# - Recallk:
# topk: [1]
# - mAP:
# - Precisionk:
# topk: [1]
ppcls/configs/quick_start/professional/VGG19_CIFAR10_DeepHash.yaml
0 → 100644
浏览文件 @
ec5e07da
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output
device
:
gpu
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
eval_mode
:
"
retrieval"
epochs
:
128
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
#feature postprocess
feature_normalize
:
False
feature_binarize
:
"
round"
# model architecture
Arch
:
name
:
"
RecModel"
Backbone
:
name
:
"
VGG19Sigmoid"
pretrained
:
True
class_num
:
48
Head
:
name
:
"
FC"
class_num
:
10
embedding_size
:
48
infer_output_key
:
"
features"
infer_add_softmax
:
"
false"
# loss function config for train/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
epsilon
:
0.1
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Piecewise
learning_rate
:
0.01
decay_epochs
:
[
200
]
values
:
[
0.01
,
0.001
]
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/cifar10/
cls_label_path
:
./dataset/cifar10/cifar10-2/train.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
256
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.4914
,
0.4822
,
0.4465
]
std
:
[
0.2023
,
0.1994
,
0.2010
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
128
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
4
use_shared_memory
:
True
Eval
:
Query
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/cifar10/
cls_label_path
:
./dataset/cifar10/cifar10-2/test.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.4914
,
0.4822
,
0.4465
]
std
:
[
0.2023
,
0.1994
,
0.2010
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
512
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Gallery
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/cifar10/
cls_label_path
:
./dataset/cifar10/cifar10-2/database.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.4914
,
0.4822
,
0.4465
]
std
:
[
0.2023
,
0.1994
,
0.2010
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
512
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
5
]
Eval
:
-
mAP
:
-
Precisionk
:
topk
:
[
1
,
5
]
ppcls/engine/evaluation/retrieval.py
浏览文件 @
ec5e07da
...
...
@@ -125,6 +125,13 @@ def cal_feature(evaler, name='gallery'):
paddle
.
sum
(
paddle
.
square
(
batch_feas
),
axis
=
1
,
keepdim
=
True
))
batch_feas
=
paddle
.
divide
(
batch_feas
,
feas_norm
)
# do binarize
if
evaler
.
config
[
"Global"
].
get
(
"feature_binarize"
)
==
"round"
:
batch_feas
=
paddle
.
round
(
batch_feas
).
astype
(
"float32"
)
*
2.0
-
1.0
if
evaler
.
config
[
"Global"
].
get
(
"feature_binarize"
)
==
"sign"
:
batch_feas
=
paddle
.
sign
(
batch_feas
).
astype
(
"float32"
)
if
all_feas
is
None
:
all_feas
=
batch_feas
if
has_unique_id
:
...
...
@@ -135,8 +142,10 @@ def cal_feature(evaler, name='gallery'):
all_image_id
=
paddle
.
concat
([
all_image_id
,
batch
[
1
]])
if
has_unique_id
:
all_unique_id
=
paddle
.
concat
([
all_unique_id
,
batch
[
2
]])
if
evaler
.
use_dali
:
dataloader_tmp
.
reset
()
if
paddle
.
distributed
.
get_world_size
()
>
1
:
feat_list
=
[]
img_id_list
=
[]
...
...
ppcls/loss/deephashloss.py
0 → 100644
浏览文件 @
ec5e07da
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
paddle
import
paddle.nn
as
nn
class
DSHSDLoss
(
nn
.
Layer
):
"""
# DSHSD(IEEE ACCESS 2019)
# paper [Deep Supervised Hashing Based on Stable Distribution](https://ieeexplore.ieee.org/document/8648432/)
# [DSHSD] epoch:70, bit:48, dataset:cifar10-1, MAP:0.809, Best MAP: 0.809
# [DSHSD] epoch:250, bit:48, dataset:nuswide_21, MAP:0.809, Best MAP: 0.815
# [DSHSD] epoch:135, bit:48, dataset:imagenet, MAP:0.647, Best MAP: 0.647
"""
def
__init__
(
self
,
n_class
,
bit
,
alpha
,
multi_label
=
False
):
super
(
DSHSDLoss
,
self
).
__init__
()
self
.
m
=
2
*
bit
self
.
alpha
=
alpha
self
.
multi_label
=
multi_label
self
.
n_class
=
n_class
self
.
fc
=
paddle
.
nn
.
Linear
(
bit
,
n_class
,
bias_attr
=
False
)
def
forward
(
self
,
input
,
label
):
feature
=
input
[
"features"
]
feature
=
feature
.
tanh
().
astype
(
"float32"
)
dist
=
paddle
.
sum
(
paddle
.
square
((
paddle
.
unsqueeze
(
feature
,
1
)
-
paddle
.
unsqueeze
(
feature
,
0
))),
axis
=
2
)
# label to ont-hot
label
=
paddle
.
flatten
(
label
)
label
=
paddle
.
nn
.
functional
.
one_hot
(
label
,
self
.
n_class
).
astype
(
"float32"
)
s
=
(
paddle
.
matmul
(
label
,
label
,
transpose_y
=
True
)
==
0
).
astype
(
"float32"
)
Ld
=
(
1
-
s
)
/
2
*
dist
+
s
/
2
*
(
self
.
m
-
dist
).
clip
(
min
=
0
)
Ld
=
Ld
.
mean
()
logits
=
self
.
fc
(
feature
)
if
self
.
multi_label
:
# multiple labels classification loss
Lc
=
(
logits
-
label
*
logits
+
((
1
+
(
-
logits
).
exp
()).
log
())).
sum
(
axis
=
1
).
mean
()
else
:
# single labels classification loss
Lc
=
(
-
paddle
.
nn
.
functional
.
softmax
(
logits
).
log
()
*
label
).
sum
(
axis
=
1
).
mean
()
return
{
"dshsdloss"
:
Lc
+
Ld
*
self
.
alpha
}
class
LCDSHLoss
(
nn
.
Layer
):
"""
# paper [Locality-Constrained Deep Supervised Hashing for Image Retrieval](https://www.ijcai.org/Proceedings/2017/0499.pdf)
# [LCDSH] epoch:145, bit:48, dataset:cifar10-1, MAP:0.798, Best MAP: 0.798
# [LCDSH] epoch:183, bit:48, dataset:nuswide_21, MAP:0.833, Best MAP: 0.834
"""
def
__init__
(
self
,
n_class
,
_lambda
):
super
(
LCDSHLoss
,
self
).
__init__
()
self
.
_lambda
=
_lambda
self
.
n_class
=
n_class
def
forward
(
self
,
input
,
label
):
feature
=
input
[
"features"
]
# label to ont-hot
label
=
paddle
.
flatten
(
label
)
label
=
paddle
.
nn
.
functional
.
one_hot
(
label
,
self
.
n_class
).
astype
(
"float32"
)
s
=
2
*
(
paddle
.
matmul
(
label
,
label
,
transpose_y
=
True
)
>
0
).
astype
(
"float32"
)
-
1
inner_product
=
paddle
.
matmul
(
feature
,
feature
,
transpose_y
=
True
)
*
0.5
inner_product
=
inner_product
.
clip
(
min
=-
50
,
max
=
50
)
L1
=
paddle
.
log
(
1
+
paddle
.
exp
(
-
s
*
inner_product
)).
mean
()
b
=
feature
.
sign
()
inner_product_
=
paddle
.
matmul
(
b
,
b
,
transpose_y
=
True
)
*
0.5
sigmoid
=
paddle
.
nn
.
Sigmoid
()
L2
=
(
sigmoid
(
inner_product
)
-
sigmoid
(
inner_product_
)).
pow
(
2
).
mean
()
return
{
"lcdshloss"
:
L1
+
self
.
_lambda
*
L2
}
ppcls/metric/__init__.py
浏览文件 @
ec5e07da
...
...
@@ -16,7 +16,7 @@ from paddle import nn
import
copy
from
collections
import
OrderedDict
from
.metrics
import
TopkAcc
,
mAP
,
mINP
,
Recallk
from
.metrics
import
TopkAcc
,
mAP
,
mINP
,
Recallk
,
Precisionk
from
.metrics
import
DistillationTopkAcc
from
.metrics
import
GoogLeNetTopkAcc
...
...
ppcls/metric/metrics.py
浏览文件 @
ec5e07da
...
...
@@ -168,6 +168,47 @@ class Recallk(nn.Layer):
return
metric_dict
class
Precisionk
(
nn
.
Layer
):
def
__init__
(
self
,
topk
=
(
1
,
5
)):
super
().
__init__
()
assert
isinstance
(
topk
,
(
int
,
list
,
tuple
))
if
isinstance
(
topk
,
int
):
topk
=
[
topk
]
self
.
topk
=
topk
def
forward
(
self
,
similarities_matrix
,
query_img_id
,
gallery_img_id
,
keep_mask
):
metric_dict
=
dict
()
#get cmc
choosen_indices
=
paddle
.
argsort
(
similarities_matrix
,
axis
=
1
,
descending
=
True
)
gallery_labels_transpose
=
paddle
.
transpose
(
gallery_img_id
,
[
1
,
0
])
gallery_labels_transpose
=
paddle
.
broadcast_to
(
gallery_labels_transpose
,
shape
=
[
choosen_indices
.
shape
[
0
],
gallery_labels_transpose
.
shape
[
1
]
])
choosen_label
=
paddle
.
index_sample
(
gallery_labels_transpose
,
choosen_indices
)
equal_flag
=
paddle
.
equal
(
choosen_label
,
query_img_id
)
if
keep_mask
is
not
None
:
keep_mask
=
paddle
.
index_sample
(
keep_mask
.
astype
(
'float32'
),
choosen_indices
)
equal_flag
=
paddle
.
logical_and
(
equal_flag
,
keep_mask
.
astype
(
'bool'
))
equal_flag
=
paddle
.
cast
(
equal_flag
,
'float32'
)
Ns
=
paddle
.
arange
(
gallery_img_id
.
shape
[
0
])
+
1
equal_flag_cumsum
=
paddle
.
cumsum
(
equal_flag
,
axis
=
1
)
Precision_at_k
=
(
paddle
.
mean
(
equal_flag_cumsum
,
axis
=
0
)
/
Ns
).
numpy
()
for
k
in
self
.
topk
:
metric_dict
[
"precision@{}"
.
format
(
k
)]
=
Precision_at_k
[
k
-
1
]
return
metric_dict
class
DistillationTopkAcc
(
TopkAcc
):
def
__init__
(
self
,
model_key
,
feature_key
=
None
,
topk
=
(
1
,
5
)):
super
().
__init__
(
topk
=
topk
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录