PaddlePaddle / PaddleClas
Commit bdc535bb, authored May 06, 2022 by dongshuilong
Commit message: add adaface
Parent: fea9522a
Showing 10 changed files with 1,187 additions and 3 deletions (+1187 -3):
ppcls/arch/backbone/__init__.py                    +1   -0
ppcls/arch/backbone/model_zoo/ir_net.py            +458 -0
ppcls/arch/gears/__init__.py                       +2   -1
ppcls/arch/gears/adamargin.py                      +89  -0
ppcls/configs/metric_learning/ir18_adaface.yaml    +90  -0
ppcls/data/dataloader/__init__.py                  +1   -0
ppcls/data/dataloader/face_dataset.py              +285 -0
ppcls/engine/engine.py                             +3   -2
ppcls/engine/evaluation/__init__.py                +1   -0
ppcls/engine/evaluation/adaface.py                 +257 -0
ppcls/arch/backbone/__init__.py

@@ -68,6 +68,7 @@ from ppcls.arch.backbone.model_zoo.van import VAN_tiny
 from ppcls.arch.backbone.variant_models.resnet_variant import ResNet50_last_stage_stride1
 from ppcls.arch.backbone.variant_models.vgg_variant import VGG19Sigmoid
 from ppcls.arch.backbone.variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh
+from ppcls.arch.backbone.model_zoo.ir_net import IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200

 # help whl get all the models' api (class type) and components' api (func type)
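Registering the ir_net factories in this package's namespace is what lets a config reference a backbone by name string. A small sketch of the lookup this import enables (the getattr pattern is an assumption about the surrounding builder code, not part of this diff):

import ppcls.arch.backbone as backbone

arch_fn = getattr(backbone, "IR_18")   # name string taken from a config
model = arch_fn(input_size=(112, 112))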
ppcls/arch/backbone/model_zoo/ir_net.py
new file 0 → 100644

# this code is based on AdaFace(https://github.com/mk-minchul/AdaFace)
from collections import namedtuple
import paddle
import paddle.nn as nn
from paddle.nn import Dropout
from paddle.nn import MaxPool2D
from paddle.nn import Sequential
from paddle.nn import Conv2D, Linear
from paddle.nn import BatchNorm1D, BatchNorm2D
from paddle.nn import ReLU, Sigmoid
from paddle.nn import Layer
from paddle.nn import PReLU

from ppcls.arch.backbone.legendary_models.resnet import _load_pretrained
import os

# def initialize_weights(modules):
#     """ Weight initialization: Conv2D and Linear are initialized with kaiming_normal
#     """
#     for m in modules:
#         if isinstance(m, nn.Conv2D):
#             nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
#             if m.bias is not None:
#                 m.bias.data.zero_()
#         elif isinstance(m, nn.BatchNorm2D):
#             m.weight.data.fill_(1)
#             m.bias.data.zero_()
#         elif isinstance(m, nn.Linear):
#             nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
#             if m.bias is not None:
#                 m.bias.data.zero_()


class Flatten(Layer):
    """ Flatten the input tensor to [batch_size, -1] """

    def forward(self, input):
        return paddle.reshape(input, [input.shape[0], -1])


class LinearBlock(Layer):
    """ Convolution block without a non-linear activation layer """

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1),
                 padding=(0, 0), groups=1):
        super(LinearBlock, self).__init__()
        self.conv = Conv2D(
            in_c, out_c, kernel, stride, padding, groups=groups, bias_attr=None)
        self.bn = BatchNorm2D(out_c)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class GNAP(Layer):
    """ Global Norm-Aware Pooling block """

    def __init__(self, in_c):
        super(GNAP, self).__init__()
        self.bn1 = BatchNorm2D(in_c, weight_attr=False, bias_attr=False)
        self.pool = nn.AdaptiveAvgPool2D((1, 1))
        self.bn2 = BatchNorm1D(in_c, weight_attr=False, bias_attr=False)

    def forward(self, x):
        x = self.bn1(x)
        x_norm = paddle.norm(x, 2, 1, True)
        x_norm_mean = paddle.mean(x_norm)
        weight = x_norm_mean / x_norm
        x = x * weight
        x = self.pool(x)
        # Paddle tensors have no .view(); reshape is the equivalent
        x = paddle.reshape(x, [x.shape[0], -1])
        feature = self.bn2(x)
        return feature


class GDC(Layer):
    """ Global Depthwise Convolution block """

    def __init__(self, in_c, embedding_size):
        super(GDC, self).__init__()
        self.conv_6_dw = LinearBlock(
            in_c, in_c, groups=in_c, kernel=(7, 7), stride=(1, 1), padding=(0, 0))
        self.conv_6_flatten = Flatten()
        self.linear = Linear(in_c, embedding_size, bias_attr=False)
        self.bn = BatchNorm1D(embedding_size, weight_attr=False, bias_attr=False)

    def forward(self, x):
        x = self.conv_6_dw(x)
        x = self.conv_6_flatten(x)
        x = self.linear(x)
        x = self.bn(x)
        return x


class SELayer(Layer):
    """ SE block """

    def __init__(self, channels, reduction):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        weight_attr = paddle.framework.ParamAttr(
            name="linear_weight",
            initializer=paddle.nn.initializer.XavierUniform())
        self.fc1 = Conv2D(
            channels, channels // reduction, kernel_size=1, padding=0,
            weight_attr=weight_attr, bias_attr=False)
        self.relu = ReLU()
        self.fc2 = Conv2D(
            channels // reduction, channels, kernel_size=1, padding=0,
            bias_attr=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x


class BasicBlockIR(Layer):
    """ BasicBlock for IRNet """

    def __init__(self, in_channel, depth, stride):
        super(BasicBlockIR, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2D(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2D(in_channel, depth, (1, 1), stride, bias_attr=False),
                BatchNorm2D(depth))
        self.res_layer = Sequential(
            BatchNorm2D(in_channel),
            Conv2D(in_channel, depth, (3, 3), (1, 1), 1, bias_attr=False),
            BatchNorm2D(depth),
            PReLU(depth),
            Conv2D(depth, depth, (3, 3), stride, 1, bias_attr=False),
            BatchNorm2D(depth))

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class BottleneckIR(Layer):
    """ BasicBlock with bottleneck for IRNet """

    def __init__(self, in_channel, depth, stride):
        super(BottleneckIR, self).__init__()
        reduction_channel = depth // 4
        if in_channel == depth:
            self.shortcut_layer = MaxPool2D(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2D(in_channel, depth, (1, 1), stride, bias_attr=False),
                BatchNorm2D(depth))
        self.res_layer = Sequential(
            BatchNorm2D(in_channel),
            Conv2D(in_channel, reduction_channel, (1, 1), (1, 1), 0, bias_attr=False),
            BatchNorm2D(reduction_channel),
            PReLU(reduction_channel),
            Conv2D(reduction_channel, reduction_channel, (3, 3), (1, 1), 1,
                   bias_attr=False),
            BatchNorm2D(reduction_channel),
            PReLU(reduction_channel),
            Conv2D(reduction_channel, depth, (1, 1), stride, 0, bias_attr=False),
            BatchNorm2D(depth))

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class BasicBlockIRSE(BasicBlockIR):
    def __init__(self, in_channel, depth, stride):
        super(BasicBlockIRSE, self).__init__(in_channel, depth, stride)
        self.res_layer.add_sublayer("se_block", SELayer(depth, 16))


class BottleneckIRSE(BottleneckIR):
    def __init__(self, in_channel, depth, stride):
        super(BottleneckIRSE, self).__init__(in_channel, depth, stride)
        self.res_layer.add_sublayer("se_block", SELayer(depth, 16))


class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    '''A named tuple describing a ResNet block.'''


def get_block(in_channel, depth, num_units, stride=2):
    return [Bottleneck(in_channel, depth, stride)] + \
        [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]


def get_blocks(num_layers):
    if num_layers == 18:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=2),
            get_block(in_channel=64, depth=128, num_units=2),
            get_block(in_channel=128, depth=256, num_units=2),
            get_block(in_channel=256, depth=512, num_units=2)
        ]
    elif num_layers == 34:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=6),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=256, num_units=3),
            get_block(in_channel=256, depth=512, num_units=8),
            get_block(in_channel=512, depth=1024, num_units=36),
            get_block(in_channel=1024, depth=2048, num_units=3)
        ]
    elif num_layers == 200:
        blocks = [
            get_block(in_channel=64, depth=256, num_units=3),
            get_block(in_channel=256, depth=512, num_units=24),
            get_block(in_channel=512, depth=1024, num_units=36),
            get_block(in_channel=1024, depth=2048, num_units=3)
        ]
    return blocks


class Backbone(Layer):
    def __init__(self, input_size, num_layers, mode='ir'):
        """ Args:
            input_size: input_size of backbone
            num_layers: num_layers of backbone
            mode: support 'ir' or 'ir_se'
        """
        super(Backbone, self).__init__()
        assert input_size[0] in [112, 224], \
            "input_size should be [112, 112] or [224, 224]"
        assert num_layers in [18, 34, 50, 100, 152, 200], \
            "num_layers should be 18, 34, 50, 100, 152 or 200"
        assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
        self.input_layer = Sequential(
            Conv2D(3, 64, (3, 3), 1, 1, bias_attr=False),
            BatchNorm2D(64), PReLU(64))
        blocks = get_blocks(num_layers)
        if num_layers <= 100:
            if mode == 'ir':
                unit_module = BasicBlockIR
            elif mode == 'ir_se':
                unit_module = BasicBlockIRSE
            output_channel = 512
        else:
            if mode == 'ir':
                unit_module = BottleneckIR
            elif mode == 'ir_se':
                unit_module = BottleneckIRSE
            output_channel = 2048

        if input_size[0] == 112:
            self.output_layer = Sequential(
                BatchNorm2D(output_channel),
                Dropout(0.4),
                Flatten(),
                Linear(output_channel * 7 * 7, 512),
                BatchNorm1D(512, weight_attr=False, bias_attr=False))
        else:
            self.output_layer = Sequential(
                BatchNorm2D(output_channel),
                Dropout(0.4),
                Flatten(),
                Linear(output_channel * 14 * 14, 512),
                BatchNorm1D(512, weight_attr=False, bias_attr=False))

        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(
                    unit_module(bottleneck.in_channel, bottleneck.depth,
                                bottleneck.stride))
        self.body = Sequential(*modules)

        # initialize_weights(self.modules())

    def forward(self, x):
        x = self.input_layer(x)
        for idx, module in enumerate(self.body):
            x = module(x)
        x = self.output_layer(x)
        # norm = paddle.norm(x, 2, 1, True)
        # output = paddle.divide(x, norm)
        # return output, norm
        return x


def IR_18(input_size=(112, 112)):
    """ Constructs an ir-18 model. """
    model = Backbone(input_size, 18, 'ir')
    return model


def IR_34(input_size=(112, 112)):
    """ Constructs an ir-34 model. """
    model = Backbone(input_size, 34, 'ir')
    return model


def IR_50(input_size=(112, 112)):
    """ Constructs an ir-50 model. """
    model = Backbone(input_size, 50, 'ir')
    return model


def IR_101(input_size=(112, 112)):
    """ Constructs an ir-101 model. """
    model = Backbone(input_size, 100, 'ir')
    return model


def IR_152(input_size=(112, 112)):
    """ Constructs an ir-152 model. """
    model = Backbone(input_size, 152, 'ir')
    return model


def IR_200(input_size=(112, 112)):
    """ Constructs an ir-200 model. """
    model = Backbone(input_size, 200, 'ir')
    return model


def IR_SE_50(input_size=(112, 112)):
    """ Constructs an ir_se-50 model. """
    model = Backbone(input_size, 50, 'ir_se')
    return model


def IR_SE_101(input_size=(112, 112)):
    """ Constructs an ir_se-101 model. """
    model = Backbone(input_size, 100, 'ir_se')
    return model


def IR_SE_152(input_size=(112, 112)):
    """ Constructs an ir_se-152 model. """
    model = Backbone(input_size, 152, 'ir_se')
    return model


def IR_SE_200(input_size=(112, 112)):
    """ Constructs an ir_se-200 model. """
    model = Backbone(input_size, 200, 'ir_se')
    return model
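A minimal smoke test for the backbone above (not part of the commit): build an IR-18 and check the embedding shape on a random batch.

import paddle
from ppcls.arch.backbone.model_zoo.ir_net import IR_18

model = IR_18(input_size=(112, 112))
model.eval()
x = paddle.randn([4, 3, 112, 112])  # NCHW face crops
with paddle.no_grad():
    feat = model(x)
print(feat.shape)  # [4, 512], the 512-d face embedding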
ppcls/arch/gears/__init__.py

@@ -19,6 +19,7 @@ from .fc import FC
 from .vehicle_neck import VehicleNeck
 from paddle.nn import Tanh
 from .bnneck import BNNeck
+from .adamargin import AdaMargin

 __all__ = ['build_gear']

@@ -26,7 +27,7 @@ __all__ = ['build_gear']
 def build_gear(config):
     support_dict = [
         'ArcMargin', 'CosMargin', 'CircleMargin', 'FC', 'VehicleNeck', 'Tanh',
-        'BNNeck'
+        'BNNeck', 'AdaMargin'
     ]
     module_name = config.pop('name')
     assert module_name in support_dict, Exception(
     ...
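A hedged sketch of what this registration buys: assuming the elided remainder of build_gear instantiates the named class with the leftover config keys (as it does for the other gears), the Head section of the YAML below becomes resolvable like this:

from ppcls.arch.gears import build_gear

head = build_gear({"name": "AdaMargin", "embedding_size": 512, "class_num": 70722})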
ppcls/arch/gears/adamargin.py
new file 0 → 100644

# This code is based on AdaFace(https://github.com/mk-minchul/AdaFace)
from paddle.nn import Layer
import math
import paddle


def l2_norm(input, axis=1):
    norm = paddle.norm(input, 2, axis, True)
    output = paddle.divide(input, norm)
    return output


class AdaMargin(Layer):
    def __init__(self,
                 embedding_size=512,
                 class_num=70722,
                 m=0.4,
                 h=0.333,
                 s=64.,
                 t_alpha=1.0):
        super(AdaMargin, self).__init__()
        self.classnum = class_num
        self.kernel = self.create_parameter(
            [embedding_size, class_num],
            attr=paddle.nn.initializer.Uniform())
        # initial kernel
        # self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.m = m
        self.eps = 1e-3
        self.h = h
        self.s = s

        # ema prep
        self.t_alpha = t_alpha
        self.register_buffer('t', paddle.zeros([1]), persistable=True)
        self.register_buffer('batch_mean', paddle.ones([1]) * 20, persistable=True)
        self.register_buffer('batch_std', paddle.ones([1]) * 100, persistable=True)

        print('\nAdaFace with the following property')
        print('self.m', self.m)
        print('self.h', self.h)
        print('self.s', self.s)
        print('self.t_alpha', self.t_alpha)

    def forward(self, embbedings, norms, label):
        kernel_norm = l2_norm(self.kernel, axis=0)
        cosine = paddle.mm(embbedings, kernel_norm)
        cosine = paddle.clip(cosine, -1 + self.eps, 1 - self.eps)  # for stability
        safe_norms = paddle.clip(norms, min=0.001, max=100)  # for stability
        safe_norms = safe_norms.clone().detach()

        # update batch_mean, batch_std
        with paddle.no_grad():
            mean = safe_norms.mean().detach()
            std = safe_norms.std().detach()
            self.batch_mean = mean * self.t_alpha + (
                1 - self.t_alpha) * self.batch_mean
            self.batch_std = std * self.t_alpha + (
                1 - self.t_alpha) * self.batch_std

        margin_scaler = (safe_norms - self.batch_mean) / (
            self.batch_std + self.eps)  # 66% between -1, 1
        margin_scaler = margin_scaler * self.h  # 68% between -0.333, 0.333 when h:0.333
        margin_scaler = paddle.clip(margin_scaler, -1, 1)

        # g_angular
        m_arc = paddle.nn.functional.one_hot(label, self.classnum)
        g_angular = self.m * margin_scaler * -1
        m_arc = m_arc * g_angular
        theta = paddle.acos(cosine)
        theta_m = paddle.clip(
            theta + m_arc, min=self.eps, max=math.pi - self.eps)
        cosine = paddle.cos(theta_m)

        # g_additive
        m_cos = paddle.nn.functional.one_hot(label, self.classnum)
        g_add = self.m + (self.m * margin_scaler)
        m_cos = m_cos * g_add
        cosine = cosine - m_cos

        # scale
        scaled_cosine_m = cosine * self.s
        return scaled_cosine_m
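A minimal sketch (not part of the commit) of how AdaMargin consumes the backbone output: it needs the per-sample feature norms as well as the embeddings, and returns scaled logits suitable for CELoss. A tiny class_num is used here for the demo.

import paddle
from ppcls.arch.gears.adamargin import AdaMargin

head = AdaMargin(embedding_size=512, class_num=10)  # tiny class_num for the demo
feat = paddle.randn([8, 512])
norms = paddle.norm(feat, 2, 1, True)    # per-sample feature norm, shape [8, 1]
embeddings = paddle.divide(feat, norms)  # unit-length embeddings
labels = paddle.randint(0, 10, [8])
logits = head(embeddings, norms, labels)
print(logits.shape)  # [8, 10]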
ppcls/configs/metric_learning/ir18_adaface.yaml
new file 0 → 100644

# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: "./output/"
  device: "gpu"
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 26
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 112, 112]
  save_inference_dir: "./inference"
  eval_mode: "adaface"

# model architecture
Arch:
  name: "RecModel"
  infer_output_key: "features"
  infer_add_softmax: False
  Backbone:
    name: "IR_18"
    pretrained: False
  Head:
    name: "AdaMargin"
    embedding_size: 512
    class_num: 70722
    m: 0.4
    scale: 32
    h: 0.3333
    t_alpha: 0.01

# loss function config for training/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: Piecewise
    learning_rate: 0.1
    decay_epochs: [12, 20, 24]
    values: [0.1, 0.01, 0.001, 0.0001]
  regularizer:
    name: 'L2'
    coeff: 0.0001

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: "AdaFaceDataset"
      root_dir: "/work/dataset/face/"
      label_path: "/work/dataset/face/train_filter_label.txt"
      low_res_augmentation_prob: 0.2
      crop_augmentation_prob: 0.2
      photometric_augmentation_prob: 0.2
      transform:
        - RandomHorizontalFlip:
        - ToTensor:
        - Normalize:
            mean: [0.5, 0.5, 0.5]
            std: [0.5, 0.5, 0.5]
    sampler:
      name: DistributedBatchSampler
      batch_size: 256
      drop_last: False
      shuffle: True
    loader:
      num_workers: 6
      use_shared_memory: True
  Eval:
    dataset:
      name: FiveValidationDataset
      val_data_path: /work/dataset/face/faces_emore
      concat_mem_file_name: /work/dataset/face/faces_emore/concat_validation_memfile
    sampler:
      name: DistributedBatchSampler
      batch_size: 256
      drop_last: False
      shuffle: True
    loader:
      num_workers: 6
      use_shared_memory: True
\ No newline at end of file
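Assuming PaddleClas's standard entry point (tools/train.py is the repo's usual launcher, not shown in this diff), this config would be run with:

python3 tools/train.py -c ppcls/configs/metric_learning/ir18_adaface.yaml

Note that eval_mode: "adaface" is what routes evaluation to the new adaface_eval function via the engine change further down.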
ppcls/data/dataloader/__init__.py

@@ -8,3 +8,4 @@ from ppcls.data.dataloader.mix_dataset import MixDataset
 from ppcls.data.dataloader.mix_sampler import MixSampler
 from ppcls.data.dataloader.pk_sampler import PKSampler
 from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
+from ppcls.data.dataloader.face_dataset import AdaFaceDataset, FiveValidationDataset
ppcls/data/dataloader/face_dataset.py
new file 0 → 100644

import os
import json
import bcolz
import numpy as np
from PIL import Image
import cv2
import paddle
import paddle.vision.datasets as datasets
from paddle.vision import transforms
from paddle.vision.transforms import functional as F
from paddle.io import Dataset
from .common_dataset import create_operators

# code is based on AdaFace: https://github.com/mk-minchul/AdaFace


def train_dataset(train_dir, label_path, low_res_augmentation_prob,
                  crop_augmentation_prob, photometric_augmentation_prob):
    # train_dir = os.path.join(data_root, train_data_path)
    train_dataset = AdaFaceDataset(
        root_dir=train_dir,
        label_path=label_path,
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ]),
        low_res_augmentation_prob=low_res_augmentation_prob,
        crop_augmentation_prob=crop_augmentation_prob,
        photometric_augmentation_prob=photometric_augmentation_prob)
    return train_dataset


def _get_image_size(img):
    if F._is_pil_image(img):
        return img.size
    elif F._is_numpy_image(img):
        return img.shape[:2][::-1]
    elif F._is_tensor_image(img):
        return img.shape[1:][::-1]  # chw
    else:
        raise TypeError("Unexpected type {}".format(type(img)))


class AdaFaceDataset(Dataset):
    def __init__(self,
                 root_dir,
                 label_path,
                 transform=None,
                 low_res_augmentation_prob=0.0,
                 crop_augmentation_prob=0.0,
                 photometric_augmentation_prob=0.0):
        self.root_dir = root_dir
        self.low_res_augmentation_prob = low_res_augmentation_prob
        self.crop_augmentation_prob = crop_augmentation_prob
        self.photometric_augmentation_prob = photometric_augmentation_prob
        self.random_resized_crop = transforms.RandomResizedCrop(
            size=(112, 112), scale=(0.2, 1.0), ratio=(0.75, 1.3333333333333333))
        self.photometric = transforms.ColorJitter(
            brightness=0.5, contrast=0.5, saturation=0.5, hue=0)
        self.transform = create_operators(transform)

        self.tot_rot_try = 0
        self.rot_success = 0
        with open(label_path) as fd:
            lines = fd.readlines()
        self.samples = []
        for l in lines:
            l = l.strip().split()
            self.samples.append([os.path.join(root_dir, l[0]), int(l[1])])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        [path, target] = self.samples[index]
        with open(path, 'rb') as f:
            img = Image.open(f)
            sample = img.convert('RGB')

        # if 'WebFace' in self.root:
        #     # swap rgb to bgr since image is in rgb for webface
        #     sample = Image.fromarray(np.asarray(sample)[:, :, ::-1])

        sample, _ = self.augment(sample)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

    def augment(self, sample):
        # crop with zero padding augmentation
        if np.random.random() < self.crop_augmentation_prob:
            # RandomResizedCrop augmentation
            new = np.zeros_like(np.array(sample))
            # orig_W, orig_H = F._get_image_size(sample)
            orig_W, orig_H = _get_image_size(sample)
            i, j, h, w = self.random_resized_crop._get_param(sample)
            cropped = F.crop(sample, i, j, h, w)
            new[i:i + h, j:j + w, :] = np.array(cropped)
            sample = Image.fromarray(new.astype(np.uint8))
            crop_ratio = min(h, w) / max(orig_H, orig_W)
        else:
            crop_ratio = 1.0

        # low resolution augmentation
        if np.random.random() < self.low_res_augmentation_prob:
            img_np, resize_ratio = low_res_augmentation(np.array(sample))
            sample = Image.fromarray(img_np.astype(np.uint8))
        else:
            resize_ratio = 1

        # photometric augmentation
        if np.random.random() < self.photometric_augmentation_prob:
            # fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \
            #     self.photometric._get_params(self.photometric.brightness, self.photometric.contrast,
            #                                  self.photometric.saturation, self.photometric.hue)
            # for fn_id in fn_idx:
            #     if fn_id == 0 and brightness_factor is not None:
            #         sample = F.adjust_brightness(sample, brightness_factor)
            #     elif fn_id == 1 and contrast_factor is not None:
            #         sample = F.adjust_contrast(sample, contrast_factor)
            #     elif fn_id == 2 and saturation_factor is not None:
            #         sample = F.adjust_saturation(sample, saturation_factor)
            sample = self.photometric(sample)

        information_score = resize_ratio * crop_ratio
        return sample, information_score


def low_res_augmentation(img):
    # resize the image to a small size and enlarge it back
    img_shape = img.shape
    side_ratio = np.random.uniform(0.2, 1.0)
    small_side = int(side_ratio * img_shape[0])
    interpolation = np.random.choice([
        cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC,
        cv2.INTER_LANCZOS4
    ])
    small_img = cv2.resize(
        img, (small_side, small_side), interpolation=interpolation)
    interpolation = np.random.choice([
        cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC,
        cv2.INTER_LANCZOS4
    ])
    aug_img = cv2.resize(
        small_img, (img_shape[1], img_shape[0]), interpolation=interpolation)
    return aug_img, side_ratio


class FiveValidationDataset(Dataset):
    def __init__(self, val_data_path, concat_mem_file_name):
        '''
        concatenates all validation datasets from emore
        val_data_dict = {
            'agedb_30': (agedb_30, agedb_30_issame),
            "cfp_fp": (cfp_fp, cfp_fp_issame),
            "lfw": (lfw, lfw_issame),
            "cplfw": (cplfw, cplfw_issame),
            "calfw": (calfw, calfw_issame),
        }
        agedb_30: 0
        cfp_fp: 1
        lfw: 2
        cplfw: 3
        calfw: 4
        '''
        val_data = get_val_data(val_data_path)
        age_30, cfp_fp, lfw, age_30_issame, cfp_fp_issame, lfw_issame, \
            cplfw, cplfw_issame, calfw, calfw_issame = val_data
        val_data_dict = {
            'agedb_30': (age_30, age_30_issame),
            "cfp_fp": (cfp_fp, cfp_fp_issame),
            "lfw": (lfw, lfw_issame),
            "cplfw": (cplfw, cplfw_issame),
            "calfw": (calfw, calfw_issame),
        }
        self.dataname_to_idx = {
            "agedb_30": 0,
            "cfp_fp": 1,
            "lfw": 2,
            "cplfw": 3,
            "calfw": 4
        }
        self.val_data_dict = val_data_dict
        # concat all dataset
        all_imgs = []
        all_issame = []
        all_dataname = []
        key_orders = []
        for key, (imgs, issame) in val_data_dict.items():
            all_imgs.append(imgs)
            # hacky way to make the issame length same as imgs: [1, 1, 0, 0, ...]
            dup_issame = []
            for same in issame:
                dup_issame.append(same)
                dup_issame.append(same)
            all_issame.append(dup_issame)
            all_dataname.append([self.dataname_to_idx[key]] * len(imgs))
            key_orders.append(key)
        assert key_orders == ['agedb_30', 'cfp_fp', 'lfw', 'cplfw', 'calfw']

        if isinstance(all_imgs[0], np.memmap):
            self.all_imgs = read_memmap(concat_mem_file_name)
        else:
            self.all_imgs = np.concatenate(all_imgs)
        self.all_issame = np.concatenate(all_issame)
        self.all_dataname = np.concatenate(all_dataname)

    def __getitem__(self, index):
        x_np = self.all_imgs[index].copy()
        x = paddle.to_tensor(x_np)
        y = self.all_issame[index]
        dataname = self.all_dataname[index]
        return x, y, dataname, index

    def __len__(self):
        return len(self.all_imgs)


def read_memmap(mem_file_name):
    # r+ mode: Open existing file for reading and writing
    with open(mem_file_name + '.conf', 'r') as file:
        memmap_configs = json.load(file)
        return np.memmap(mem_file_name, mode='r+',
                         shape=tuple(memmap_configs['shape']),
                         dtype=memmap_configs['dtype'])


def get_val_pair(path, name, use_memfile=True):
    if use_memfile:
        mem_file_dir = os.path.join(path, name, 'memfile')
        mem_file_name = os.path.join(mem_file_dir, 'mem_file.dat')
        if os.path.isdir(mem_file_dir):
            print('loading validation data memfile')
            np_array = read_memmap(mem_file_name)
        else:
            os.makedirs(mem_file_dir)
            carray = bcolz.carray(rootdir=os.path.join(path, name), mode='r')
            np_array = np.array(carray)
            # mem_array = make_memmap(mem_file_name, np_array)
            # del np_array, mem_array
            del np_array
            # note: with make_memmap commented out above, this read expects
            # the memfile (and its .conf) to already exist on disk
            np_array = read_memmap(mem_file_name)
    else:
        np_array = bcolz.carray(rootdir=os.path.join(path, name), mode='r')

    issame = np.load(os.path.join(path, '{}_list.npy'.format(name)))
    return np_array, issame


def get_val_data(data_path):
    agedb_30, agedb_30_issame = get_val_pair(data_path, 'agedb_30')
    cfp_fp, cfp_fp_issame = get_val_pair(data_path, 'cfp_fp')
    lfw, lfw_issame = get_val_pair(data_path, 'lfw')
    cplfw, cplfw_issame = get_val_pair(data_path, 'cplfw')
    calfw, calfw_issame = get_val_pair(data_path, 'calfw')
    return agedb_30, cfp_fp, lfw, agedb_30_issame, cfp_fp_issame, lfw_issame, \
        cplfw, cplfw_issame, calfw, calfw_issame


if __name__ == "__main__":
    t_dataset = train_dataset('/work/dataset/face/',
                              '/work/dataset/face/train_filter_label.txt',
                              1, 1, 1)
    img = t_dataset.__getitem__(100)
    print(len(t_dataset))
    val = FiveValidationDataset(
        '/work/dataset/face/faces_emore',
        '/work/dataset/face/faces_emore/concat_validation_memfile')
    a = 1
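A quick sketch (not part of the commit) of the degradation pipeline: low_res_augmentation shrinks the image with one random interpolation and re-enlarges it with another, returning the shrink ratio as an information score.

import numpy as np
from ppcls.data.dataloader.face_dataset import low_res_augmentation

img = (np.random.rand(112, 112, 3) * 255).astype(np.uint8)
aug, ratio = low_res_augmentation(img)
print(aug.shape, ratio)  # (112, 112, 3) and a ratio drawn from [0.2, 1.0)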
ppcls/engine/engine.py

@@ -75,8 +75,9 @@ class Engine(object):
             print_config(config)

         # init train_func and eval_func
-        assert self.eval_mode in ["classification", "retrieval"], logger.error(
-            "Invalid eval mode: {}".format(self.eval_mode))
+        assert self.eval_mode in [
+            "classification", "retrieval", "adaface"
+        ], logger.error("Invalid eval mode: {}".format(self.eval_mode))
         self.train_epoch_func = train_epoch
         self.eval_func = getattr(evaluation, self.eval_mode + "_eval")
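This is the only engine change the new mode needs: with eval_mode set to "adaface", the existing getattr dispatch resolves to the adaface_eval function exported in the next hunk. A minimal illustration of that lookup, using only names from this diff:

from ppcls.engine import evaluation

eval_mode = "adaface"
# resolves to evaluation.adaface_eval once the import below is in place
eval_func = getattr(evaluation, eval_mode + "_eval")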
ppcls/engine/evaluation/__init__.py

@@ -14,3 +14,4 @@
 from ppcls.engine.evaluation.classification import classification_eval
 from ppcls.engine.evaluation.retrieval import retrieval_eval
+from ppcls.engine.evaluation.adaface import adaface_eval
\ No newline at end of file
ppcls/engine/evaluation/adaface.py
new file 0 → 100644

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import numpy as np
import platform
import paddle
import sklearn
import sklearn.preprocessing  # makes sklearn.preprocessing.normalize available below
from sklearn.model_selection import KFold
from sklearn.decomposition import PCA
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger


def fuse_features_with_norm(stacked_embeddings, stacked_norms):
    assert stacked_embeddings.ndim == 3  # (n_features_to_fuse, batch_size, channel)
    assert stacked_norms.ndim == 3  # (n_features_to_fuse, batch_size, 1)

    pre_norm_embeddings = stacked_embeddings * stacked_norms
    # paddle.Tensor.sum takes axis (the torch-style dim keyword would fail)
    fused = pre_norm_embeddings.sum(axis=0)
    norm = paddle.norm(fused, 2, 1, True)
    fused = paddle.divide(fused, norm)
    return fused, norm


def adaface_eval(engine, epoch_id=0):
    output_info = dict()
    time_info = {
        "batch_cost": AverageMeter("batch_cost", '.5f', postfix=" s,"),
        "reader_cost": AverageMeter("reader_cost", ".5f", postfix=" s,"),
    }
    print_batch_step = engine.config["Global"]["print_batch_step"]

    metric_key = None
    tic = time.time()
    unique_dict = {}
    for iter_id, batch in enumerate(engine.eval_dataloader):
        images, labels, dataname, image_index = batch
        if iter_id == 5:
            for key in time_info:
                time_info[key].reset()
        time_info["reader_cost"].update(time.time() - tic)
        batch_size = images.shape[0]
        batch[0] = paddle.to_tensor(images)
        embeddings = engine.model(images)["features"]
        norms = paddle.divide(embeddings, paddle.norm(embeddings, 2, 1, True))
        fliped_images = paddle.flip(images, axis=[3])
        flipped_embeddings = engine.model(fliped_images)["features"]
        flipped_norms = paddle.divide(
            flipped_embeddings, paddle.norm(flipped_embeddings, 2, 1, True))
        stacked_embeddings = paddle.stack(
            [embeddings, flipped_embeddings], axis=0)
        stacked_norms = paddle.stack([norms, flipped_norms], axis=0)
        embeddings, norms = fuse_features_with_norm(stacked_embeddings,
                                                    stacked_norms)

        for out, nor, label, data, idx in zip(embeddings, norms, labels,
                                              dataname, image_index):
            unique_dict[int(idx.numpy())] = {
                'output': out,
                'norm': nor,
                'target': label,
                'dataname': data
            }
        # calc metric
        time_info["batch_cost"].update(time.time() - tic)
        if iter_id % print_batch_step == 0:
            time_msg = "s, ".join([
                "{}: {:.5f}".format(key, time_info[key].avg)
                for key in time_info
            ])
            ips_msg = "ips: {:.5f} images/sec".format(
                batch_size / time_info["batch_cost"].avg)
            metric_msg = ", ".join([
                "{}: {:.5f}".format(key, output_info[key].val)
                for key in output_info
            ])
            logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
                epoch_id, iter_id, len(engine.eval_dataloader), metric_msg,
                time_msg, ips_msg))
        tic = time.time()

    unique_keys = sorted(unique_dict.keys())
    all_output_tensor = paddle.stack(
        [unique_dict[key]['output'] for key in unique_keys], axis=0)
    all_norm_tensor = paddle.stack(
        [unique_dict[key]['norm'] for key in unique_keys], axis=0)
    all_target_tensor = paddle.stack(
        [unique_dict[key]['target'] for key in unique_keys], axis=0)
    all_dataname_tensor = paddle.stack(
        [unique_dict[key]['dataname'] for key in unique_keys], axis=0)

    eval_result = cal_metric(all_output_tensor, all_norm_tensor,
                             all_target_tensor, all_dataname_tensor)

    metric_msg = ", ".join([
        "{}: {:.5f}".format(key, output_info[key].avg) for key in output_info
    ])
    # report the per-dataset face metrics (must read from eval_result,
    # not output_info, which holds only the running meters)
    face_msg = ", ".join([
        "{}: {:.5f}".format(key, eval_result[key]) for key in eval_result
    ])
    logger.info("[Eval][Epoch {}][Avg]{}".format(
        epoch_id, metric_msg + ", " + face_msg))

    # do not try to save best eval.model
    if engine.eval_metric_func is None:
        return -1
    # return 1st metric in the dict
    return output_info[metric_key].avg


def cal_metric(all_output_tensor, all_norm_tensor, all_target_tensor,
               all_dataname_tensor):
    dataname_to_idx = {
        "agedb_30": 0,
        "cfp_fp": 1,
        "lfw": 2,
        "cplfw": 3,
        "calfw": 4
    }
    idx_to_dataname = {val: key for key, val in dataname_to_idx.items()}
    test_logs = {}
    # _, indices = paddle.unique(all_dataname_tensor, return_index=True, return_inverse=False, return_counts=False)
    for dataname_idx in all_dataname_tensor.unique():
        dataname = idx_to_dataname[dataname_idx.item()]
        # per dataset evaluation
        embeddings = all_output_tensor[all_dataname_tensor ==
                                       dataname_idx].numpy()
        labels = all_target_tensor[all_dataname_tensor == dataname_idx].numpy()
        issame = labels[0::2]
        tpr, fpr, accuracy, best_thresholds = evaluate_face(
            embeddings, issame, nrof_folds=10)
        acc, best_threshold = accuracy.mean(), best_thresholds.mean()

        num_test_samples = len(embeddings)
        test_logs[f'{dataname}_test_acc'] = acc
        test_logs[f'{dataname}_test_best_threshold'] = best_threshold
        test_logs[f'{dataname}_num_test_samples'] = num_test_samples

    test_acc = np.mean([
        test_logs[f'{dataname}_test_acc']
        for dataname in dataname_to_idx.keys()
        if f'{dataname}_test_acc' in test_logs
    ])
    test_logs['all_test_acc'] = test_acc
    return test_logs


def evaluate_face(embeddings, actual_issame, nrof_folds=10, pca=0):
    # Calculate evaluation metrics
    thresholds = np.arange(0, 4, 0.01)
    embeddings1 = embeddings[0::2]
    embeddings2 = embeddings[1::2]
    tpr, fpr, accuracy, best_thresholds = calculate_roc(
        thresholds, embeddings1, embeddings2, np.asarray(actual_issame),
        nrof_folds=nrof_folds, pca=pca)
    return tpr, fpr, accuracy, best_thresholds


def calculate_roc(thresholds, embeddings1, embeddings2, actual_issame,
                  nrof_folds=10, pca=0):
    assert (embeddings1.shape[0] == embeddings2.shape[0])
    assert (embeddings1.shape[1] == embeddings2.shape[1])
    nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
    nrof_thresholds = len(thresholds)
    k_fold = KFold(n_splits=nrof_folds, shuffle=False)

    tprs = np.zeros((nrof_folds, nrof_thresholds))
    fprs = np.zeros((nrof_folds, nrof_thresholds))
    accuracy = np.zeros((nrof_folds))
    best_thresholds = np.zeros((nrof_folds))
    indices = np.arange(nrof_pairs)
    # print('pca', pca)

    dist = None
    if pca == 0:
        diff = np.subtract(embeddings1, embeddings2)
        dist = np.sum(np.square(diff), 1)

    for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
        # print('train_set', train_set)
        # print('test_set', test_set)
        if pca > 0:
            print('doing pca on', fold_idx)
            embed1_train = embeddings1[train_set]
            embed2_train = embeddings2[train_set]
            _embed_train = np.concatenate((embed1_train, embed2_train), axis=0)
            # print(_embed_train.shape)
            pca_model = PCA(n_components=pca)
            pca_model.fit(_embed_train)
            embed1 = pca_model.transform(embeddings1)
            embed2 = pca_model.transform(embeddings2)
            embed1 = sklearn.preprocessing.normalize(embed1)
            embed2 = sklearn.preprocessing.normalize(embed2)
            # print(embed1.shape, embed2.shape)
            diff = np.subtract(embed1, embed2)
            dist = np.sum(np.square(diff), 1)

        # Find the best threshold for the fold
        acc_train = np.zeros((nrof_thresholds))
        for threshold_idx, threshold in enumerate(thresholds):
            _, _, acc_train[threshold_idx] = calculate_accuracy(
                threshold, dist[train_set], actual_issame[train_set])
        best_threshold_index = np.argmax(acc_train)
        best_thresholds[fold_idx] = thresholds[best_threshold_index]
        for threshold_idx, threshold in enumerate(thresholds):
            tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = \
                calculate_accuracy(threshold, dist[test_set],
                                   actual_issame[test_set])
        _, _, accuracy[fold_idx] = calculate_accuracy(
            thresholds[best_threshold_index], dist[test_set],
            actual_issame[test_set])

    tpr = np.mean(tprs, 0)
    fpr = np.mean(fprs, 0)
    return tpr, fpr, accuracy, best_thresholds


def calculate_accuracy(threshold, dist, actual_issame):
    predict_issame = np.less(dist, threshold)
    tp = np.sum(np.logical_and(predict_issame, actual_issame))
    fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
    tn = np.sum(
        np.logical_and(
            np.logical_not(predict_issame), np.logical_not(actual_issame)))
    fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))

    tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn)
    fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn)
    acc = float(tp + tn) / dist.size
    return tpr, fpr, acc
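A toy check (not part of the commit) of calculate_accuracy: pairs whose squared L2 distance falls below the threshold are predicted "same".

import numpy as np
from ppcls.engine.evaluation.adaface import calculate_accuracy

dist = np.array([0.5, 1.5, 0.4, 3.0])
actual_issame = np.array([True, True, False, False])
tpr, fpr, acc = calculate_accuracy(1.0, dist, actual_issame)
print(tpr, fpr, acc)  # 0.5 0.5 0.5: one true positive, one false positive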