Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
6e4bf593
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 2 年 前同步成功
通知
118
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6e4bf593
编写于
11月 12, 2021
作者:
C
cuicheng01
提交者:
GitHub
11月 12, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1390 from Intsigstephon/feature_binary_model
add Binary general recog configure
上级
680961dc
3b4a45f8
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
199 addition
and
17 deletion
+199
-17
ppcls/arch/backbone/__init__.py
ppcls/arch/backbone/__init__.py
+1
-0
ppcls/arch/backbone/variant_models/__init__.py
ppcls/arch/backbone/variant_models/__init__.py
+1
-0
ppcls/arch/backbone/variant_models/pp_lcnet_variant.py
ppcls/arch/backbone/variant_models/pp_lcnet_variant.py
+29
-0
ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5_binary.yaml
...alRecognition/GeneralRecognition_PPLCNet_x2_5_binary.yaml
+147
-0
ppcls/loss/__init__.py
ppcls/loss/__init__.py
+2
-0
ppcls/loss/deephashloss.py
ppcls/loss/deephashloss.py
+19
-17
未找到文件。
ppcls/arch/backbone/__init__.py
浏览文件 @
6e4bf593
...
...
@@ -62,6 +62,7 @@ from ppcls.arch.backbone.model_zoo.hardnet import HarDNet68, HarDNet85, HarDNet3
from
ppcls.arch.backbone.model_zoo.cspnet
import
CSPDarkNet53
from
ppcls.arch.backbone.variant_models.resnet_variant
import
ResNet50_last_stage_stride1
from
ppcls.arch.backbone.variant_models.vgg_variant
import
VGG19Sigmoid
from
ppcls.arch.backbone.variant_models.pp_lcnet_variant
import
PPLCNet_x2_5_Tanh
def
get_apis
():
...
...
ppcls/arch/backbone/variant_models/__init__.py
浏览文件 @
6e4bf593
from
.resnet_variant
import
ResNet50_last_stage_stride1
from
.vgg_variant
import
VGG19Sigmoid
from
.pp_lcnet_variant
import
PPLCNet_x2_5_Tanh
ppcls/arch/backbone/variant_models/pp_lcnet_variant.py
0 → 100644
浏览文件 @
6e4bf593
import
paddle
from
paddle.nn
import
Sigmoid
from
paddle.nn
import
Tanh
from
ppcls.arch.backbone.legendary_models.pp_lcnet
import
PPLCNet_x2_5
__all__
=
[
"PPLCNet_x2_5_Tanh"
]
class
TanhSuffix
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
origin_layer
):
super
(
TanhSuffix
,
self
).
__init__
()
self
.
origin_layer
=
origin_layer
self
.
tanh
=
Tanh
()
def
forward
(
self
,
input
,
res_dict
=
None
,
**
kwargs
):
x
=
self
.
origin_layer
(
input
)
x
=
self
.
tanh
(
x
)
return
x
def
PPLCNet_x2_5_Tanh
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
def
replace_function
(
origin_layer
):
new_layer
=
TanhSuffix
(
origin_layer
)
return
new_layer
match_re
=
"linear_0"
model
=
PPLCNet_x2_5
(
pretrained
=
pretrained
,
use_ssld
=
use_ssld
,
**
kwargs
)
model
.
replace_sub
(
match_re
,
replace_function
,
True
)
return
model
ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5_binary.yaml
0 → 100644
浏览文件 @
6e4bf593
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
100
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
eval_mode
:
retrieval
use_dali
:
False
to_static
:
False
#feature postprocess
feature_normalize
:
False
feature_binarize
:
"
sign"
# model architecture
Arch
:
name
:
RecModel
infer_output_key
:
features
infer_add_softmax
:
False
Backbone
:
name
:
PPLCNet_x2_5_Tanh
pretrained
:
True
use_ssld
:
True
class_num
:
512
Head
:
name
:
FC
embedding_size
:
512
class_num
:
185341
# loss function config for traing/eval process
Loss
:
Train
:
-
DSHSDLoss
:
weight
:
1.0
alpha
:
0.1
Eval
:
-
DSHSDLoss
:
weight
:
1.0
alpha
:
0.1
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.04
warmup_epoch
:
5
regularizer
:
name
:
'
L2'
coeff
:
0.00001
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/all_data
cls_label_path
:
./dataset/all_data/train_reg_all_data.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
256
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
4
use_shared_memory
:
True
Eval
:
Query
:
dataset
:
name
:
VeriWild
image_root
:
./dataset/Aliproduct/
cls_label_path
:
./dataset/Aliproduct/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
224
-
NormalizeImage
:
scale
:
0.00392157
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Gallery
:
dataset
:
name
:
VeriWild
image_root
:
./dataset/Aliproduct/
cls_label_path
:
./dataset/Aliproduct/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
224
-
NormalizeImage
:
scale
:
0.00392157
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Metric
:
Eval
:
-
Recallk
:
topk
:
[
1
,
5
]
ppcls/loss/__init__.py
浏览文件 @
6e4bf593
...
...
@@ -22,6 +22,8 @@ from .distillationloss import DistillationGTCELoss
from
.distillationloss
import
DistillationDMLLoss
from
.multilabelloss
import
MultiLabelLoss
from
.deephashloss
import
DSHSDLoss
,
LCDSHLoss
class
CombinedLoss
(
nn
.
Layer
):
def
__init__
(
self
,
config_list
):
...
...
ppcls/loss/deephashloss.py
浏览文件 @
6e4bf593
...
...
@@ -23,40 +23,42 @@ class DSHSDLoss(nn.Layer):
# [DSHSD] epoch:250, bit:48, dataset:nuswide_21, MAP:0.809, Best MAP: 0.815
# [DSHSD] epoch:135, bit:48, dataset:imagenet, MAP:0.647, Best MAP: 0.647
"""
def
__init__
(
self
,
n_class
,
bit
,
alpha
,
multi_label
=
False
):
def
__init__
(
self
,
alpha
,
multi_label
=
False
):
super
(
DSHSDLoss
,
self
).
__init__
()
self
.
m
=
2
*
bit
self
.
alpha
=
alpha
self
.
multi_label
=
multi_label
self
.
n_class
=
n_class
self
.
fc
=
paddle
.
nn
.
Linear
(
bit
,
n_class
,
bias_attr
=
False
)
def
forward
(
self
,
input
,
label
):
def
forward
(
self
,
input
,
label
):
feature
=
input
[
"features"
]
feature
=
feature
.
tanh
().
astype
(
"float32"
)
logits
=
input
[
"logits"
]
dist
=
paddle
.
sum
(
paddle
.
square
(
(
paddle
.
unsqueeze
(
feature
,
1
)
-
paddle
.
unsqueeze
(
feature
,
0
))),
axis
=
2
)
dist
=
paddle
.
sum
(
paddle
.
square
((
paddle
.
unsqueeze
(
feature
,
1
)
-
paddle
.
unsqueeze
(
feature
,
0
))),
axis
=
2
)
# label to ont-hot
label
=
paddle
.
flatten
(
label
)
label
=
paddle
.
nn
.
functional
.
one_hot
(
label
,
self
.
n_class
).
astype
(
"float32"
)
n_class
=
logits
.
shape
[
1
]
label
=
paddle
.
nn
.
functional
.
one_hot
(
label
,
n_class
).
astype
(
"float32"
)
s
=
(
paddle
.
matmul
(
label
,
label
,
transpose_y
=
True
)
==
0
).
astype
(
"float32"
)
Ld
=
(
1
-
s
)
/
2
*
dist
+
s
/
2
*
(
self
.
m
-
dist
).
clip
(
min
=
0
)
s
=
(
paddle
.
matmul
(
label
,
label
,
transpose_y
=
True
)
==
0
).
astype
(
"float32"
)
margin
=
2
*
feature
.
shape
[
1
]
Ld
=
(
1
-
s
)
/
2
*
dist
+
s
/
2
*
(
margin
-
dist
).
clip
(
min
=
0
)
Ld
=
Ld
.
mean
()
logits
=
self
.
fc
(
feature
)
if
self
.
multi_label
:
# multiple labels classification loss
Lc
=
(
logits
-
label
*
logits
+
((
1
+
(
-
logits
).
exp
()).
log
())).
sum
(
axis
=
1
).
mean
()
Lc
=
(
logits
-
label
*
logits
+
(
(
1
+
(
-
logits
).
exp
()).
log
())).
sum
(
axis
=
1
).
mean
()
else
:
# single labels classification loss
Lc
=
(
-
paddle
.
nn
.
functional
.
softmax
(
logits
).
log
()
*
label
).
sum
(
axis
=
1
).
mean
()
Lc
=
(
-
paddle
.
nn
.
functional
.
softmax
(
logits
).
log
()
*
label
).
sum
(
axis
=
1
).
mean
()
return
{
"dshsdloss"
:
Lc
+
Ld
*
self
.
alpha
}
class
LCDSHLoss
(
nn
.
Layer
):
"""
# paper [Locality-Constrained Deep Supervised Hashing for Image Retrieval](https://www.ijcai.org/Proceedings/2017/0499.pdf)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录