Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
32c99be6
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
32c99be6
编写于
5月 16, 2022
作者:
D
dongshuilong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add adaface
上级
72835980
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
186 addition
and
124 deletion
+186
-124
ppcls/arch/backbone/model_zoo/ir_net.py
ppcls/arch/backbone/model_zoo/ir_net.py
+117
-46
ppcls/arch/gears/adamargin.py
ppcls/arch/gears/adamargin.py
+32
-10
ppcls/configs/metric_learning/ir18_adaface.yaml
ppcls/configs/metric_learning/ir18_adaface.yaml
+15
-11
ppcls/data/__init__.py
ppcls/data/__init__.py
+2
-1
ppcls/data/dataloader/face_dataset.py
ppcls/data/dataloader/face_dataset.py
+3
-43
ppcls/data/preprocess/__init__.py
ppcls/data/preprocess/__init__.py
+1
-0
ppcls/data/preprocess/ops/operators.py
ppcls/data/preprocess/ops/operators.py
+3
-3
ppcls/engine/engine.py
ppcls/engine/engine.py
+1
-1
ppcls/engine/evaluation/adaface.py
ppcls/engine/evaluation/adaface.py
+12
-9
未找到文件。
ppcls/arch/backbone/model_zoo/ir_net.py
浏览文件 @
32c99be6
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# this code is based on AdaFace(https://github.com/mk-minchul/AdaFace)
from
collections
import
namedtuple
import
paddle
...
...
@@ -10,28 +23,8 @@ from paddle.nn import BatchNorm1D, BatchNorm2D
from
paddle.nn
import
ReLU
,
Sigmoid
from
paddle.nn
import
Layer
from
paddle.nn
import
PReLU
from
ppcls.arch.backbone.legendary_models.resnet
import
_load_pretrained
import
os
# def initialize_weights(modules):
# """ Weight initilize, conv2d and linear is initialized with kaiming_normal
# """
# for m in modules:
# if isinstance(m, nn.Conv2D):
# nn.init.kaiming_normal_(m.weight,
# mode='fan_out',
# nonlinearity='relu')
# if m.bias is not None:
# m.bias.data.zero_()
# elif isinstance(m, nn.BatchNorm2D):
# m.weight.data.fill_(1)
# m.bias.data.zero_()
# elif isinstance(m, nn.Linear):
# nn.init.kaiming_normal_(m.weight,
# mode='fan_out',
# nonlinearity='relu')
# if m.bias is not None:
# m.bias.data.zero_()
# from ppcls.arch.backbone.legendary_models.resnet import _load_pretrained
class
Flatten
(
Layer
):
...
...
@@ -61,8 +54,14 @@ class LinearBlock(Layer):
stride
,
padding
,
groups
=
groups
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
None
)
self
.
bn
=
BatchNorm2D
(
out_c
)
weight_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
1.0
))
bias_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
0.0
))
self
.
bn
=
BatchNorm2D
(
out_c
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
)
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
...
...
@@ -106,7 +105,11 @@ class GDC(Layer):
stride
=
(
1
,
1
),
padding
=
(
0
,
0
))
self
.
conv_6_flatten
=
Flatten
()
self
.
linear
=
Linear
(
in_c
,
embedding_size
,
bias_attr
=
False
)
self
.
linear
=
Linear
(
in_c
,
embedding_size
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
)
self
.
bn
=
BatchNorm1D
(
embedding_size
,
weight_attr
=
False
,
bias_attr
=
False
)
...
...
@@ -125,8 +128,7 @@ class SELayer(Layer):
def
__init__
(
self
,
channels
,
reduction
):
super
(
SELayer
,
self
).
__init__
()
self
.
avg_pool
=
nn
.
AdaptiveAvgPool2D
(
1
)
weight_attr
=
paddle
.
framework
.
ParamAttr
(
name
=
"linear_weight"
,
weight_attr
=
paddle
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
XavierUniform
())
self
.
fc1
=
Conv2D
(
channels
,
...
...
@@ -142,6 +144,7 @@ class SELayer(Layer):
channels
,
kernel_size
=
1
,
padding
=
0
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
)
self
.
sigmoid
=
Sigmoid
()
...
...
@@ -163,22 +166,44 @@ class BasicBlockIR(Layer):
def
__init__
(
self
,
in_channel
,
depth
,
stride
):
super
(
BasicBlockIR
,
self
).
__init__
()
weight_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
1.0
))
bias_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
0.0
))
if
in_channel
==
depth
:
self
.
shortcut_layer
=
MaxPool2D
(
1
,
stride
)
else
:
self
.
shortcut_layer
=
Sequential
(
Conv2D
(
in_channel
,
depth
,
(
1
,
1
),
stride
,
bias_attr
=
False
),
BatchNorm2D
(
depth
))
in_channel
,
depth
,
(
1
,
1
),
stride
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
depth
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
))
self
.
res_layer
=
Sequential
(
BatchNorm2D
(
in_channel
),
BatchNorm2D
(
in_channel
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
Conv2D
(
in_channel
,
depth
,
(
3
,
3
),
(
1
,
1
),
1
,
bias_attr
=
False
),
BatchNorm2D
(
depth
),
in_channel
,
depth
,
(
3
,
3
),
(
1
,
1
),
1
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
depth
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
PReLU
(
depth
),
Conv2D
(
depth
,
depth
,
(
3
,
3
),
stride
,
1
,
bias_attr
=
False
),
BatchNorm2D
(
depth
))
depth
,
depth
,
(
3
,
3
),
stride
,
1
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
depth
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
))
def
forward
(
self
,
x
):
shortcut
=
self
.
shortcut_layer
(
x
)
...
...
@@ -194,32 +219,56 @@ class BottleneckIR(Layer):
def
__init__
(
self
,
in_channel
,
depth
,
stride
):
super
(
BottleneckIR
,
self
).
__init__
()
reduction_channel
=
depth
//
4
weight_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
1.0
))
bias_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
0.0
))
if
in_channel
==
depth
:
self
.
shortcut_layer
=
MaxPool2D
(
1
,
stride
)
else
:
self
.
shortcut_layer
=
Sequential
(
Conv2D
(
in_channel
,
depth
,
(
1
,
1
),
stride
,
bias_attr
=
False
),
BatchNorm2D
(
depth
))
in_channel
,
depth
,
(
1
,
1
),
stride
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
depth
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
))
self
.
res_layer
=
Sequential
(
BatchNorm2D
(
in_channel
),
BatchNorm2D
(
in_channel
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
Conv2D
(
in_channel
,
reduction_channel
,
(
1
,
1
),
(
1
,
1
),
0
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
reduction_channel
),
BatchNorm2D
(
reduction_channel
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
PReLU
(
reduction_channel
),
Conv2D
(
reduction_channel
,
reduction_channel
,
(
3
,
3
),
(
1
,
1
),
1
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
reduction_channel
),
BatchNorm2D
(
reduction_channel
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
PReLU
(
reduction_channel
),
Conv2D
(
reduction_channel
,
depth
,
(
1
,
1
),
stride
,
0
,
bias_attr
=
False
),
BatchNorm2D
(
depth
))
reduction_channel
,
depth
,
(
1
,
1
),
stride
,
0
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
depth
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
))
def
forward
(
self
,
x
):
shortcut
=
self
.
shortcut_layer
(
x
)
...
...
@@ -317,10 +366,20 @@ class Backbone(Layer):
"num_layers should be 18, 34, 50, 100 or 152"
assert
mode
in
[
'ir'
,
'ir_se'
],
\
"mode should be ir or ir_se"
weight_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
1.0
))
bias_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
0.0
))
self
.
input_layer
=
Sequential
(
Conv2D
(
3
,
64
,
(
3
,
3
),
1
,
1
,
bias_attr
=
False
),
BatchNorm2D
(
64
),
3
,
64
,
(
3
,
3
),
1
,
1
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
64
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
PReLU
(
64
))
blocks
=
get_blocks
(
num_layers
)
if
num_layers
<=
100
:
...
...
@@ -338,18 +397,30 @@ class Backbone(Layer):
if
input_size
[
0
]
==
112
:
self
.
output_layer
=
Sequential
(
BatchNorm2D
(
output_channel
),
BatchNorm2D
(
output_channel
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
Dropout
(
0.4
),
Flatten
(),
Linear
(
output_channel
*
7
*
7
,
512
),
Linear
(
output_channel
*
7
*
7
,
512
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
()),
BatchNorm1D
(
512
,
weight_attr
=
False
,
bias_attr
=
False
))
else
:
self
.
output_layer
=
Sequential
(
BatchNorm2D
(
output_channel
),
BatchNorm2D
(
output_channel
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
Dropout
(
0.4
),
Flatten
(),
Linear
(
output_channel
*
14
*
14
,
512
),
Linear
(
output_channel
*
14
*
14
,
512
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
()),
BatchNorm1D
(
512
,
weight_attr
=
False
,
bias_attr
=
False
))
...
...
ppcls/arch/gears/adamargin.py
浏览文件 @
32c99be6
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is based on AdaFace(https://github.com/mk-minchul/AdaFace)
# Paper: AdaFace: Quality Adaptive Margin for Face Recognition
from
paddle.nn
import
Layer
import
math
import
paddle
...
...
@@ -21,8 +36,17 @@ class AdaMargin(Layer):
t_alpha
=
1.0
,
):
super
(
AdaMargin
,
self
).
__init__
()
self
.
classnum
=
class_num
kernel_weight
=
paddle
.
uniform
(
[
embedding_size
,
class_num
],
min
=-
1
,
max
=
1
)
kernel_weight_norm
=
paddle
.
norm
(
kernel_weight
,
p
=
2
,
axis
=
0
,
keepdim
=
True
)
kernel_weight_norm
=
paddle
.
where
(
kernel_weight_norm
>
1e-5
,
kernel_weight_norm
,
paddle
.
ones_like
(
kernel_weight_norm
))
kernel_weight
=
kernel_weight
/
kernel_weight_norm
self
.
kernel
=
self
.
create_parameter
(
[
embedding_size
,
class_num
],
attr
=
paddle
.
nn
.
initializer
.
Uniform
())
[
embedding_size
,
class_num
],
attr
=
paddle
.
nn
.
initializer
.
Assign
(
kernel_weight
))
# initial kernel
# self.kernel.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5)
...
...
@@ -39,14 +63,10 @@ class AdaMargin(Layer):
self
.
register_buffer
(
'batch_std'
,
paddle
.
ones
([
1
])
*
100
,
persistable
=
True
)
print
(
'
\n
\AdaFace with the following property'
)
print
(
'self.m'
,
self
.
m
)
print
(
'self.h'
,
self
.
h
)
print
(
'self.s'
,
self
.
s
)
print
(
'self.t_alpha'
,
self
.
t_alpha
)
def
forward
(
self
,
embbedings
,
norms
,
label
):
def
forward
(
self
,
embbedings
,
label
):
norms
=
paddle
.
norm
(
embbedings
,
2
,
1
,
True
)
embbedings
=
paddle
.
divide
(
embbedings
,
norms
)
kernel_norm
=
l2_norm
(
self
.
kernel
,
axis
=
0
)
cosine
=
paddle
.
mm
(
embbedings
,
kernel_norm
)
cosine
=
paddle
.
clip
(
cosine
,
-
1
+
self
.
eps
,
...
...
@@ -70,7 +90,8 @@ class AdaMargin(Layer):
margin_scaler
=
paddle
.
clip
(
margin_scaler
,
-
1
,
1
)
# g_angular
m_arc
=
paddle
.
nn
.
functional
.
one_hot
(
label
,
self
.
classnum
)
m_arc
=
paddle
.
nn
.
functional
.
one_hot
(
label
.
reshape
([
-
1
]),
self
.
classnum
)
g_angular
=
self
.
m
*
margin_scaler
*
-
1
m_arc
=
m_arc
*
g_angular
theta
=
paddle
.
acos
(
cosine
)
...
...
@@ -79,7 +100,8 @@ class AdaMargin(Layer):
cosine
=
paddle
.
cos
(
theta_m
)
# g_additive
m_cos
=
paddle
.
nn
.
functional
.
one_hot
(
label
,
self
.
classnum
)
m_cos
=
paddle
.
nn
.
functional
.
one_hot
(
label
.
reshape
([
-
1
]),
self
.
classnum
)
g_add
=
self
.
m
+
(
self
.
m
*
margin_scaler
)
m_cos
=
m_cos
*
g_add
cosine
=
cosine
-
m_cos
...
...
ppcls/configs/metric_learning/ir18_adaface.yaml
浏览文件 @
32c99be6
...
...
@@ -22,14 +22,14 @@ Arch:
infer_add_softmax
:
False
Backbone
:
name
:
"
IR_18"
pretrained
:
False
input_size
:
[
112
,
112
]
Head
:
name
:
"
AdaMargin"
embedding_size
:
512
class_num
:
70722
m
:
0.4
s
cale
:
32
h
:
0.333
3
s
:
64
h
:
0.333
t_alpha
:
0.01
# loss function config for traing/eval process
...
...
@@ -48,15 +48,15 @@ Optimizer:
values
:
[
0.1
,
0.01
,
0.001
,
0.0001
]
regularizer
:
name
:
'
L2'
coeff
:
0.000
1
coeff
:
0.000
5
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
"
AdaFaceDataset"
root_dir
:
"
/work/
dataset/face/"
label_path
:
"
/work/
dataset/face/train_filter_label.txt"
root_dir
:
"
dataset/face/"
label_path
:
"
dataset/face/train_filter_label.txt"
low_res_augmentation_prob
:
0.2
crop_augmentation_prob
:
0.2
photometric_augmentation_prob
:
0.2
...
...
@@ -66,7 +66,6 @@ DataLoader:
-
Normalize
:
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.5
,
0.5
,
0.5
]
sampler
:
name
:
DistributedBatchSampler
batch_size
:
256
...
...
@@ -75,16 +74,21 @@ DataLoader:
loader
:
num_workers
:
6
use_shared_memory
:
True
Eval
:
dataset
:
name
:
FiveValidationDataset
val_data_path
:
/work/
dataset/face/faces_emore
concat_mem_file_name
:
/work/
dataset/face/faces_emore/concat_validation_memfile
val_data_path
:
dataset/face/faces_emore
concat_mem_file_name
:
dataset/face/faces_emore/concat_validation_memfile
sampler
:
name
:
Distributed
BatchSampler
name
:
BatchSampler
batch_size
:
256
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
6
use_shared_memory
:
True
\ No newline at end of file
use_shared_memory
:
True
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
5
]
\ No newline at end of file
ppcls/data/__init__.py
浏览文件 @
32c99be6
...
...
@@ -29,6 +29,7 @@ from ppcls.data.dataloader.logo_dataset import LogoDataset
from
ppcls.data.dataloader.icartoon_dataset
import
ICartoonDataset
from
ppcls.data.dataloader.mix_dataset
import
MixDataset
from
ppcls.data.dataloader.person_dataset
import
Market1501
,
MSMT17
from
ppcls.data.dataloader.face_dataset
import
FiveValidationDataset
,
AdaFaceDataset
# sampler
from
ppcls.data.dataloader.DistributedRandomIdentitySampler
import
DistributedRandomIdentitySampler
...
...
@@ -85,7 +86,7 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
# build sampler
config_sampler
=
config
[
mode
][
'sampler'
]
if
"name"
not
in
config_sampler
:
if
config_sampler
and
"name"
not
in
config_sampler
:
batch_sampler
=
None
batch_size
=
config_sampler
[
"batch_size"
]
drop_last
=
config_sampler
[
"drop_last"
]
...
...
ppcls/data/dataloader/face_dataset.py
浏览文件 @
32c99be6
...
...
@@ -10,28 +10,11 @@ from paddle.vision import transforms
from
paddle.vision.transforms
import
functional
as
F
from
paddle.io
import
Dataset
from
.common_dataset
import
create_operators
from
ppcls.data.preprocess
import
transform
as
transform_func
# code is based on AdaFace: https://github.com/mk-minchul/AdaFace
def
train_dataset
(
train_dir
,
label_path
,
low_res_augmentation_prob
,
crop_augmentation_prob
,
photometric_augmentation_prob
):
# train_dir = os.path.join(data_root, train_data_path)
train_dataset
=
AdaFaceDataset
(
root_dir
=
train_dir
,
label_path
=
label_path
,
transform
=
transforms
.
Compose
([
transforms
.
RandomHorizontalFlip
(),
transforms
.
ToTensor
(),
transforms
.
Normalize
([
0.5
,
0.5
,
0.5
],
[
0.5
,
0.5
,
0.5
])
]),
low_res_augmentation_prob
=
low_res_augmentation_prob
,
crop_augmentation_prob
=
crop_augmentation_prob
,
photometric_augmentation_prob
=
photometric_augmentation_prob
,
)
return
train_dataset
def
_get_image_size
(
img
):
if
F
.
_is_pil_image
(
img
):
return
img
.
size
...
...
@@ -95,7 +78,7 @@ class AdaFaceDataset(Dataset):
sample
,
_
=
self
.
augment
(
sample
)
if
self
.
transform
is
not
None
:
sample
=
self
.
transform
(
sample
)
sample
=
transform_func
(
sample
,
self
.
transform
)
return
sample
,
target
...
...
@@ -125,16 +108,6 @@ class AdaFaceDataset(Dataset):
# photometric augmentation
if
np
.
random
.
random
()
<
self
.
photometric_augmentation_prob
:
# fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \
# self.photometric._get_params(self.photometric.brightness, self.photometric.contrast,
# self.photometric.saturation, self.photometric.hue)
# for fn_id in fn_idx:
# if fn_id == 0 and brightness_factor is not None:
# sample = F.adjust_brightness(sample, brightness_factor)
# elif fn_id == 1 and contrast_factor is not None:
# sample = F.adjust_contrast(sample, contrast_factor)
# elif fn_id == 2 and saturation_factor is not None:
# sample = F.adjust_saturation(sample, saturation_factor)
sample
=
self
.
photometric
(
sample
)
information_score
=
resize_ratio
*
crop_ratio
return
sample
,
information_score
...
...
@@ -269,17 +242,4 @@ def get_val_data(data_path):
lfw
,
lfw_issame
=
get_val_pair
(
data_path
,
'lfw'
)
cplfw
,
cplfw_issame
=
get_val_pair
(
data_path
,
'cplfw'
)
calfw
,
calfw_issame
=
get_val_pair
(
data_path
,
'calfw'
)
return
agedb_30
,
cfp_fp
,
lfw
,
agedb_30_issame
,
cfp_fp_issame
,
lfw_issame
,
cplfw
,
cplfw_issame
,
calfw
,
calfw_issame
if
__name__
==
"__main__"
:
t_dataset
=
train_dataset
(
'/work/dataset/face/'
,
'/work/dataset/face/train_filter_label.txt'
,
1
,
1
,
1
)
img
=
t_dataset
.
__getitem__
(
100
)
print
(
len
(
t_dataset
))
val
=
FiveValidationDataset
(
'/work/dataset/face/faces_emore'
,
'/work/dataset/face/faces_emore/concat_validation_memfile'
)
a
=
1
return
agedb_30
,
cfp_fp
,
lfw
,
agedb_30_issame
,
cfp_fp_issame
,
lfw_issame
,
cplfw
,
cplfw_issame
,
calfw
,
calfw_issame
\ No newline at end of file
ppcls/data/preprocess/__init__.py
浏览文件 @
32c99be6
...
...
@@ -33,6 +33,7 @@ from ppcls.data.preprocess.ops.operators import AugMix
from
ppcls.data.preprocess.ops.operators
import
Pad
from
ppcls.data.preprocess.ops.operators
import
ToTensor
from
ppcls.data.preprocess.ops.operators
import
Normalize
from
ppcls.data.preprocess.ops.operators
import
RandomHorizontalFlip
from
ppcls.data.preprocess.batch_ops.batch_operators
import
MixupOperator
,
CutmixOperator
,
OpSampler
,
FmixOperator
...
...
ppcls/data/preprocess/ops/operators.py
浏览文件 @
32c99be6
...
...
@@ -25,7 +25,7 @@ import cv2
import
numpy
as
np
from
PIL
import
Image
,
ImageOps
,
__version__
as
PILLOW_VERSION
from
paddle.vision.transforms
import
ColorJitter
as
RawColorJitter
from
paddle.vision.transforms
import
ToTensor
,
Normalize
from
paddle.vision.transforms
import
ToTensor
,
Normalize
,
RandomHorizontalFlip
from
.autoaugment
import
ImageNetPolicy
from
.functional
import
augmentations
...
...
@@ -463,8 +463,8 @@ class Pad(object):
# Process fill color for affine transforms
major_found
,
minor_found
=
(
int
(
v
)
for
v
in
PILLOW_VERSION
.
split
(
'.'
)[:
2
])
major_required
,
minor_required
=
(
int
(
v
)
for
v
in
min_pil_version
.
split
(
'.'
)[:
2
])
major_required
,
minor_required
=
(
int
(
v
)
for
v
in
min_pil_version
.
split
(
'.'
)[:
2
])
if
major_found
<
major_required
or
(
major_found
==
major_required
and
minor_found
<
minor_required
):
if
fill
is
None
:
...
...
ppcls/engine/engine.py
浏览文件 @
32c99be6
...
...
@@ -116,7 +116,7 @@ class Engine(object):
self
.
config
[
"DataLoader"
],
"Train"
,
self
.
device
,
self
.
use_dali
)
if
self
.
mode
==
"eval"
or
(
self
.
mode
==
"train"
and
self
.
config
[
"Global"
][
"eval_during_train"
]):
if
self
.
eval_mode
==
"classification"
:
if
self
.
eval_mode
in
[
"classification"
,
"adaface"
]
:
self
.
eval_dataloader
=
build_dataloader
(
self
.
config
[
"DataLoader"
],
"Eval"
,
self
.
device
,
self
.
use_dali
)
...
...
ppcls/engine/evaluation/adaface.py
浏览文件 @
32c99be6
...
...
@@ -30,7 +30,7 @@ def fuse_features_with_norm(stacked_embeddings, stacked_norms):
assert
stacked_embeddings
.
ndim
==
3
# (n_features_to_fuse, batch_size, channel)
assert
stacked_norms
.
ndim
==
3
# (n_features_to_fuse, batch_size, 1)
pre_norm_embeddings
=
stacked_embeddings
*
stacked_norms
fused
=
pre_norm_embeddings
.
sum
(
dim
=
0
)
fused
=
pre_norm_embeddings
.
sum
(
axis
=
0
)
norm
=
paddle
.
norm
(
fused
,
2
,
1
,
True
)
fused
=
paddle
.
divide
(
fused
,
norm
)
return
fused
,
norm
...
...
@@ -57,12 +57,14 @@ def adaface_eval(engine, epoch_id=0):
time_info
[
"reader_cost"
].
update
(
time
.
time
()
-
tic
)
batch_size
=
images
.
shape
[
0
]
batch
[
0
]
=
paddle
.
to_tensor
(
images
)
embeddings
=
engine
.
model
(
images
)[
"features"
]
embeddings
=
engine
.
model
(
images
,
labels
)[
'features'
]
norms
=
paddle
.
divide
(
embeddings
,
paddle
.
norm
(
embeddings
,
2
,
1
,
True
))
embeddings
=
paddle
.
divide
(
embeddings
,
norms
)
fliped_images
=
paddle
.
flip
(
images
,
axis
=
[
3
])
flipped_embeddings
=
engine
.
model
(
fliped_images
)[
"features"
]
flipped_embeddings
=
engine
.
model
(
fliped_images
,
labels
)[
'features'
]
flipped_norms
=
paddle
.
divide
(
flipped_embeddings
,
paddle
.
norm
(
flipped_embeddings
,
2
,
1
,
True
))
flipped_embeddings
=
paddle
.
divide
(
flipped_embeddings
,
flipped_norms
)
stacked_embeddings
=
paddle
.
stack
(
[
embeddings
,
flipped_embeddings
],
axis
=
0
)
stacked_norms
=
paddle
.
stack
([
norms
,
flipped_norms
],
axis
=
0
)
...
...
@@ -114,20 +116,21 @@ def adaface_eval(engine, epoch_id=0):
metric_msg
=
", "
.
join
([
"{}: {:.5f}"
.
format
(
key
,
output_info
[
key
].
avg
)
for
key
in
output_info
])
face_msg
=
", "
.
join
(
[
"{}: {:.5f}"
.
format
(
key
,
output_info
[
key
])
for
key
in
eval_result
])
face_msg
=
", "
.
join
([
"{}: {:.5f}"
.
format
(
key
,
eval_result
[
key
])
for
key
in
eval_result
.
keys
()
])
logger
.
info
(
"[Eval][Epoch {}][Avg]{}"
.
format
(
epoch_id
,
metric_msg
+
", "
+
face_msg
))
# do not try to save best eval.model
if
engine
.
eval_metric_func
is
None
:
return
-
1
# return 1st metric in the dict
return
output_info
[
metric_key
].
avg
return
eval_result
[
'all_test_acc'
]
def
cal_metric
(
all_output_tensor
,
all_norm_tensor
,
all_target_tensor
,
all_dataname_tensor
):
all_target_tensor
=
all_target_tensor
.
reshape
([
-
1
])
all_dataname_tensor
=
all_dataname_tensor
.
reshape
([
-
1
])
dataname_to_idx
=
{
"agedb_30"
:
0
,
"cfp_fp"
:
1
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录