Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
32c99be6
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
32c99be6
编写于
5月 16, 2022
作者:
D
dongshuilong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add adaface
上级
72835980
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
186 addition
and
124 deletion
+186
-124
ppcls/arch/backbone/model_zoo/ir_net.py
ppcls/arch/backbone/model_zoo/ir_net.py
+117
-46
ppcls/arch/gears/adamargin.py
ppcls/arch/gears/adamargin.py
+32
-10
ppcls/configs/metric_learning/ir18_adaface.yaml
ppcls/configs/metric_learning/ir18_adaface.yaml
+15
-11
ppcls/data/__init__.py
ppcls/data/__init__.py
+2
-1
ppcls/data/dataloader/face_dataset.py
ppcls/data/dataloader/face_dataset.py
+3
-43
ppcls/data/preprocess/__init__.py
ppcls/data/preprocess/__init__.py
+1
-0
ppcls/data/preprocess/ops/operators.py
ppcls/data/preprocess/ops/operators.py
+3
-3
ppcls/engine/engine.py
ppcls/engine/engine.py
+1
-1
ppcls/engine/evaluation/adaface.py
ppcls/engine/evaluation/adaface.py
+12
-9
未找到文件。
ppcls/arch/backbone/model_zoo/ir_net.py
浏览文件 @
32c99be6
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# this code is based on AdaFace(https://github.com/mk-minchul/AdaFace)
# this code is based on AdaFace(https://github.com/mk-minchul/AdaFace)
from
collections
import
namedtuple
from
collections
import
namedtuple
import
paddle
import
paddle
...
@@ -10,28 +23,8 @@ from paddle.nn import BatchNorm1D, BatchNorm2D
...
@@ -10,28 +23,8 @@ from paddle.nn import BatchNorm1D, BatchNorm2D
from
paddle.nn
import
ReLU
,
Sigmoid
from
paddle.nn
import
ReLU
,
Sigmoid
from
paddle.nn
import
Layer
from
paddle.nn
import
Layer
from
paddle.nn
import
PReLU
from
paddle.nn
import
PReLU
from
ppcls.arch.backbone.legendary_models.resnet
import
_load_pretrained
import
os
# from ppcls.arch.backbone.legendary_models.resnet import _load_pretrained
# def initialize_weights(modules):
# """ Weight initilize, conv2d and linear is initialized with kaiming_normal
# """
# for m in modules:
# if isinstance(m, nn.Conv2D):
# nn.init.kaiming_normal_(m.weight,
# mode='fan_out',
# nonlinearity='relu')
# if m.bias is not None:
# m.bias.data.zero_()
# elif isinstance(m, nn.BatchNorm2D):
# m.weight.data.fill_(1)
# m.bias.data.zero_()
# elif isinstance(m, nn.Linear):
# nn.init.kaiming_normal_(m.weight,
# mode='fan_out',
# nonlinearity='relu')
# if m.bias is not None:
# m.bias.data.zero_()
class
Flatten
(
Layer
):
class
Flatten
(
Layer
):
...
@@ -61,8 +54,14 @@ class LinearBlock(Layer):
...
@@ -61,8 +54,14 @@ class LinearBlock(Layer):
stride
,
stride
,
padding
,
padding
,
groups
=
groups
,
groups
=
groups
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
None
)
bias_attr
=
None
)
self
.
bn
=
BatchNorm2D
(
out_c
)
weight_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
1.0
))
bias_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
0.0
))
self
.
bn
=
BatchNorm2D
(
out_c
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
x
=
self
.
conv
(
x
)
...
@@ -106,7 +105,11 @@ class GDC(Layer):
...
@@ -106,7 +105,11 @@ class GDC(Layer):
stride
=
(
1
,
1
),
stride
=
(
1
,
1
),
padding
=
(
0
,
0
))
padding
=
(
0
,
0
))
self
.
conv_6_flatten
=
Flatten
()
self
.
conv_6_flatten
=
Flatten
()
self
.
linear
=
Linear
(
in_c
,
embedding_size
,
bias_attr
=
False
)
self
.
linear
=
Linear
(
in_c
,
embedding_size
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
)
self
.
bn
=
BatchNorm1D
(
self
.
bn
=
BatchNorm1D
(
embedding_size
,
weight_attr
=
False
,
bias_attr
=
False
)
embedding_size
,
weight_attr
=
False
,
bias_attr
=
False
)
...
@@ -125,8 +128,7 @@ class SELayer(Layer):
...
@@ -125,8 +128,7 @@ class SELayer(Layer):
def
__init__
(
self
,
channels
,
reduction
):
def
__init__
(
self
,
channels
,
reduction
):
super
(
SELayer
,
self
).
__init__
()
super
(
SELayer
,
self
).
__init__
()
self
.
avg_pool
=
nn
.
AdaptiveAvgPool2D
(
1
)
self
.
avg_pool
=
nn
.
AdaptiveAvgPool2D
(
1
)
weight_attr
=
paddle
.
framework
.
ParamAttr
(
weight_attr
=
paddle
.
ParamAttr
(
name
=
"linear_weight"
,
initializer
=
paddle
.
nn
.
initializer
.
XavierUniform
())
initializer
=
paddle
.
nn
.
initializer
.
XavierUniform
())
self
.
fc1
=
Conv2D
(
self
.
fc1
=
Conv2D
(
channels
,
channels
,
...
@@ -142,6 +144,7 @@ class SELayer(Layer):
...
@@ -142,6 +144,7 @@ class SELayer(Layer):
channels
,
channels
,
kernel_size
=
1
,
kernel_size
=
1
,
padding
=
0
,
padding
=
0
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
sigmoid
=
Sigmoid
()
self
.
sigmoid
=
Sigmoid
()
...
@@ -163,22 +166,44 @@ class BasicBlockIR(Layer):
...
@@ -163,22 +166,44 @@ class BasicBlockIR(Layer):
def
__init__
(
self
,
in_channel
,
depth
,
stride
):
def
__init__
(
self
,
in_channel
,
depth
,
stride
):
super
(
BasicBlockIR
,
self
).
__init__
()
super
(
BasicBlockIR
,
self
).
__init__
()
weight_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
1.0
))
bias_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
0.0
))
if
in_channel
==
depth
:
if
in_channel
==
depth
:
self
.
shortcut_layer
=
MaxPool2D
(
1
,
stride
)
self
.
shortcut_layer
=
MaxPool2D
(
1
,
stride
)
else
:
else
:
self
.
shortcut_layer
=
Sequential
(
self
.
shortcut_layer
=
Sequential
(
Conv2D
(
Conv2D
(
in_channel
,
depth
,
(
1
,
1
),
stride
,
bias_attr
=
False
),
in_channel
,
BatchNorm2D
(
depth
))
depth
,
(
1
,
1
),
stride
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
depth
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
))
self
.
res_layer
=
Sequential
(
self
.
res_layer
=
Sequential
(
BatchNorm2D
(
in_channel
),
BatchNorm2D
(
in_channel
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
Conv2D
(
Conv2D
(
in_channel
,
depth
,
(
3
,
3
),
(
1
,
1
),
1
,
bias_attr
=
False
),
in_channel
,
BatchNorm2D
(
depth
),
depth
,
(
3
,
3
),
(
1
,
1
),
1
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
depth
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
PReLU
(
depth
),
PReLU
(
depth
),
Conv2D
(
Conv2D
(
depth
,
depth
,
(
3
,
3
),
stride
,
1
,
bias_attr
=
False
),
depth
,
BatchNorm2D
(
depth
))
depth
,
(
3
,
3
),
stride
,
1
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
depth
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
))
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
shortcut
=
self
.
shortcut_layer
(
x
)
shortcut
=
self
.
shortcut_layer
(
x
)
...
@@ -194,32 +219,56 @@ class BottleneckIR(Layer):
...
@@ -194,32 +219,56 @@ class BottleneckIR(Layer):
def
__init__
(
self
,
in_channel
,
depth
,
stride
):
def
__init__
(
self
,
in_channel
,
depth
,
stride
):
super
(
BottleneckIR
,
self
).
__init__
()
super
(
BottleneckIR
,
self
).
__init__
()
reduction_channel
=
depth
//
4
reduction_channel
=
depth
//
4
weight_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
1.0
))
bias_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
0.0
))
if
in_channel
==
depth
:
if
in_channel
==
depth
:
self
.
shortcut_layer
=
MaxPool2D
(
1
,
stride
)
self
.
shortcut_layer
=
MaxPool2D
(
1
,
stride
)
else
:
else
:
self
.
shortcut_layer
=
Sequential
(
self
.
shortcut_layer
=
Sequential
(
Conv2D
(
Conv2D
(
in_channel
,
depth
,
(
1
,
1
),
stride
,
bias_attr
=
False
),
in_channel
,
BatchNorm2D
(
depth
))
depth
,
(
1
,
1
),
stride
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
depth
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
))
self
.
res_layer
=
Sequential
(
self
.
res_layer
=
Sequential
(
BatchNorm2D
(
in_channel
),
BatchNorm2D
(
in_channel
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
Conv2D
(
Conv2D
(
in_channel
,
in_channel
,
reduction_channel
,
(
1
,
1
),
(
1
,
1
),
reduction_channel
,
(
1
,
1
),
(
1
,
1
),
0
,
0
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
bias_attr
=
False
),
BatchNorm2D
(
reduction_channel
),
BatchNorm2D
(
reduction_channel
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
PReLU
(
reduction_channel
),
PReLU
(
reduction_channel
),
Conv2D
(
Conv2D
(
reduction_channel
,
reduction_channel
,
reduction_channel
,
(
3
,
3
),
(
1
,
1
),
reduction_channel
,
(
3
,
3
),
(
1
,
1
),
1
,
1
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
bias_attr
=
False
),
BatchNorm2D
(
reduction_channel
),
BatchNorm2D
(
reduction_channel
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
PReLU
(
reduction_channel
),
PReLU
(
reduction_channel
),
Conv2D
(
Conv2D
(
reduction_channel
,
depth
,
(
1
,
1
),
stride
,
0
,
bias_attr
=
False
),
reduction_channel
,
BatchNorm2D
(
depth
))
depth
,
(
1
,
1
),
stride
,
0
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
depth
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
))
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
shortcut
=
self
.
shortcut_layer
(
x
)
shortcut
=
self
.
shortcut_layer
(
x
)
...
@@ -317,10 +366,20 @@ class Backbone(Layer):
...
@@ -317,10 +366,20 @@ class Backbone(Layer):
"num_layers should be 18, 34, 50, 100 or 152"
"num_layers should be 18, 34, 50, 100 or 152"
assert
mode
in
[
'ir'
,
'ir_se'
],
\
assert
mode
in
[
'ir'
,
'ir_se'
],
\
"mode should be ir or ir_se"
"mode should be ir or ir_se"
weight_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
1.0
))
bias_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
0.0
))
self
.
input_layer
=
Sequential
(
self
.
input_layer
=
Sequential
(
Conv2D
(
Conv2D
(
3
,
64
,
(
3
,
3
),
1
,
1
,
bias_attr
=
False
),
3
,
BatchNorm2D
(
64
),
64
,
(
3
,
3
),
1
,
1
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
64
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
PReLU
(
64
))
PReLU
(
64
))
blocks
=
get_blocks
(
num_layers
)
blocks
=
get_blocks
(
num_layers
)
if
num_layers
<=
100
:
if
num_layers
<=
100
:
...
@@ -338,18 +397,30 @@ class Backbone(Layer):
...
@@ -338,18 +397,30 @@ class Backbone(Layer):
if
input_size
[
0
]
==
112
:
if
input_size
[
0
]
==
112
:
self
.
output_layer
=
Sequential
(
self
.
output_layer
=
Sequential
(
BatchNorm2D
(
output_channel
),
BatchNorm2D
(
output_channel
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
Dropout
(
0.4
),
Dropout
(
0.4
),
Flatten
(),
Flatten
(),
Linear
(
output_channel
*
7
*
7
,
512
),
Linear
(
output_channel
*
7
*
7
,
512
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
()),
BatchNorm1D
(
BatchNorm1D
(
512
,
weight_attr
=
False
,
bias_attr
=
False
))
512
,
weight_attr
=
False
,
bias_attr
=
False
))
else
:
else
:
self
.
output_layer
=
Sequential
(
self
.
output_layer
=
Sequential
(
BatchNorm2D
(
output_channel
),
BatchNorm2D
(
output_channel
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
Dropout
(
0.4
),
Dropout
(
0.4
),
Flatten
(),
Flatten
(),
Linear
(
output_channel
*
14
*
14
,
512
),
Linear
(
output_channel
*
14
*
14
,
512
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
()),
BatchNorm1D
(
BatchNorm1D
(
512
,
weight_attr
=
False
,
bias_attr
=
False
))
512
,
weight_attr
=
False
,
bias_attr
=
False
))
...
...
ppcls/arch/gears/adamargin.py
浏览文件 @
32c99be6
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is based on AdaFace(https://github.com/mk-minchul/AdaFace)
# This code is based on AdaFace(https://github.com/mk-minchul/AdaFace)
# Paper: AdaFace: Quality Adaptive Margin for Face Recognition
from
paddle.nn
import
Layer
from
paddle.nn
import
Layer
import
math
import
math
import
paddle
import
paddle
...
@@ -21,8 +36,17 @@ class AdaMargin(Layer):
...
@@ -21,8 +36,17 @@ class AdaMargin(Layer):
t_alpha
=
1.0
,
):
t_alpha
=
1.0
,
):
super
(
AdaMargin
,
self
).
__init__
()
super
(
AdaMargin
,
self
).
__init__
()
self
.
classnum
=
class_num
self
.
classnum
=
class_num
kernel_weight
=
paddle
.
uniform
(
[
embedding_size
,
class_num
],
min
=-
1
,
max
=
1
)
kernel_weight_norm
=
paddle
.
norm
(
kernel_weight
,
p
=
2
,
axis
=
0
,
keepdim
=
True
)
kernel_weight_norm
=
paddle
.
where
(
kernel_weight_norm
>
1e-5
,
kernel_weight_norm
,
paddle
.
ones_like
(
kernel_weight_norm
))
kernel_weight
=
kernel_weight
/
kernel_weight_norm
self
.
kernel
=
self
.
create_parameter
(
self
.
kernel
=
self
.
create_parameter
(
[
embedding_size
,
class_num
],
attr
=
paddle
.
nn
.
initializer
.
Uniform
())
[
embedding_size
,
class_num
],
attr
=
paddle
.
nn
.
initializer
.
Assign
(
kernel_weight
))
# initial kernel
# initial kernel
# self.kernel.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5)
# self.kernel.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5)
...
@@ -39,14 +63,10 @@ class AdaMargin(Layer):
...
@@ -39,14 +63,10 @@ class AdaMargin(Layer):
self
.
register_buffer
(
self
.
register_buffer
(
'batch_std'
,
paddle
.
ones
([
1
])
*
100
,
persistable
=
True
)
'batch_std'
,
paddle
.
ones
([
1
])
*
100
,
persistable
=
True
)
print
(
'
\n
\AdaFace with the following property'
)
def
forward
(
self
,
embbedings
,
label
):
print
(
'self.m'
,
self
.
m
)
print
(
'self.h'
,
self
.
h
)
print
(
'self.s'
,
self
.
s
)
print
(
'self.t_alpha'
,
self
.
t_alpha
)
def
forward
(
self
,
embbedings
,
norms
,
label
):
norms
=
paddle
.
norm
(
embbedings
,
2
,
1
,
True
)
embbedings
=
paddle
.
divide
(
embbedings
,
norms
)
kernel_norm
=
l2_norm
(
self
.
kernel
,
axis
=
0
)
kernel_norm
=
l2_norm
(
self
.
kernel
,
axis
=
0
)
cosine
=
paddle
.
mm
(
embbedings
,
kernel_norm
)
cosine
=
paddle
.
mm
(
embbedings
,
kernel_norm
)
cosine
=
paddle
.
clip
(
cosine
,
-
1
+
self
.
eps
,
cosine
=
paddle
.
clip
(
cosine
,
-
1
+
self
.
eps
,
...
@@ -70,7 +90,8 @@ class AdaMargin(Layer):
...
@@ -70,7 +90,8 @@ class AdaMargin(Layer):
margin_scaler
=
paddle
.
clip
(
margin_scaler
,
-
1
,
1
)
margin_scaler
=
paddle
.
clip
(
margin_scaler
,
-
1
,
1
)
# g_angular
# g_angular
m_arc
=
paddle
.
nn
.
functional
.
one_hot
(
label
,
self
.
classnum
)
m_arc
=
paddle
.
nn
.
functional
.
one_hot
(
label
.
reshape
([
-
1
]),
self
.
classnum
)
g_angular
=
self
.
m
*
margin_scaler
*
-
1
g_angular
=
self
.
m
*
margin_scaler
*
-
1
m_arc
=
m_arc
*
g_angular
m_arc
=
m_arc
*
g_angular
theta
=
paddle
.
acos
(
cosine
)
theta
=
paddle
.
acos
(
cosine
)
...
@@ -79,7 +100,8 @@ class AdaMargin(Layer):
...
@@ -79,7 +100,8 @@ class AdaMargin(Layer):
cosine
=
paddle
.
cos
(
theta_m
)
cosine
=
paddle
.
cos
(
theta_m
)
# g_additive
# g_additive
m_cos
=
paddle
.
nn
.
functional
.
one_hot
(
label
,
self
.
classnum
)
m_cos
=
paddle
.
nn
.
functional
.
one_hot
(
label
.
reshape
([
-
1
]),
self
.
classnum
)
g_add
=
self
.
m
+
(
self
.
m
*
margin_scaler
)
g_add
=
self
.
m
+
(
self
.
m
*
margin_scaler
)
m_cos
=
m_cos
*
g_add
m_cos
=
m_cos
*
g_add
cosine
=
cosine
-
m_cos
cosine
=
cosine
-
m_cos
...
...
ppcls/configs/metric_learning/ir18_adaface.yaml
浏览文件 @
32c99be6
...
@@ -22,14 +22,14 @@ Arch:
...
@@ -22,14 +22,14 @@ Arch:
infer_add_softmax
:
False
infer_add_softmax
:
False
Backbone
:
Backbone
:
name
:
"
IR_18"
name
:
"
IR_18"
pretrained
:
False
input_size
:
[
112
,
112
]
Head
:
Head
:
name
:
"
AdaMargin"
name
:
"
AdaMargin"
embedding_size
:
512
embedding_size
:
512
class_num
:
70722
class_num
:
70722
m
:
0.4
m
:
0.4
s
cale
:
32
s
:
64
h
:
0.333
3
h
:
0.333
t_alpha
:
0.01
t_alpha
:
0.01
# loss function config for traing/eval process
# loss function config for traing/eval process
...
@@ -48,15 +48,15 @@ Optimizer:
...
@@ -48,15 +48,15 @@ Optimizer:
values
:
[
0.1
,
0.01
,
0.001
,
0.0001
]
values
:
[
0.1
,
0.01
,
0.001
,
0.0001
]
regularizer
:
regularizer
:
name
:
'
L2'
name
:
'
L2'
coeff
:
0.000
1
coeff
:
0.000
5
# data loader for train and eval
# data loader for train and eval
DataLoader
:
DataLoader
:
Train
:
Train
:
dataset
:
dataset
:
name
:
"
AdaFaceDataset"
name
:
"
AdaFaceDataset"
root_dir
:
"
/work/
dataset/face/"
root_dir
:
"
dataset/face/"
label_path
:
"
/work/
dataset/face/train_filter_label.txt"
label_path
:
"
dataset/face/train_filter_label.txt"
low_res_augmentation_prob
:
0.2
low_res_augmentation_prob
:
0.2
crop_augmentation_prob
:
0.2
crop_augmentation_prob
:
0.2
photometric_augmentation_prob
:
0.2
photometric_augmentation_prob
:
0.2
...
@@ -66,7 +66,6 @@ DataLoader:
...
@@ -66,7 +66,6 @@ DataLoader:
-
Normalize
:
-
Normalize
:
mean
:
[
0.5
,
0.5
,
0.5
]
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.5
,
0.5
,
0.5
]
sampler
:
sampler
:
name
:
DistributedBatchSampler
name
:
DistributedBatchSampler
batch_size
:
256
batch_size
:
256
...
@@ -75,16 +74,21 @@ DataLoader:
...
@@ -75,16 +74,21 @@ DataLoader:
loader
:
loader
:
num_workers
:
6
num_workers
:
6
use_shared_memory
:
True
use_shared_memory
:
True
Eval
:
Eval
:
dataset
:
dataset
:
name
:
FiveValidationDataset
name
:
FiveValidationDataset
val_data_path
:
/work/
dataset/face/faces_emore
val_data_path
:
dataset/face/faces_emore
concat_mem_file_name
:
/work/
dataset/face/faces_emore/concat_validation_memfile
concat_mem_file_name
:
dataset/face/faces_emore/concat_validation_memfile
sampler
:
sampler
:
name
:
Distributed
BatchSampler
name
:
BatchSampler
batch_size
:
256
batch_size
:
256
drop_last
:
False
drop_last
:
False
shuffle
:
True
shuffle
:
True
loader
:
loader
:
num_workers
:
6
num_workers
:
6
use_shared_memory
:
True
use_shared_memory
:
True
\ No newline at end of file
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
5
]
\ No newline at end of file
ppcls/data/__init__.py
浏览文件 @
32c99be6
...
@@ -29,6 +29,7 @@ from ppcls.data.dataloader.logo_dataset import LogoDataset
...
@@ -29,6 +29,7 @@ from ppcls.data.dataloader.logo_dataset import LogoDataset
from
ppcls.data.dataloader.icartoon_dataset
import
ICartoonDataset
from
ppcls.data.dataloader.icartoon_dataset
import
ICartoonDataset
from
ppcls.data.dataloader.mix_dataset
import
MixDataset
from
ppcls.data.dataloader.mix_dataset
import
MixDataset
from
ppcls.data.dataloader.person_dataset
import
Market1501
,
MSMT17
from
ppcls.data.dataloader.person_dataset
import
Market1501
,
MSMT17
from
ppcls.data.dataloader.face_dataset
import
FiveValidationDataset
,
AdaFaceDataset
# sampler
# sampler
from
ppcls.data.dataloader.DistributedRandomIdentitySampler
import
DistributedRandomIdentitySampler
from
ppcls.data.dataloader.DistributedRandomIdentitySampler
import
DistributedRandomIdentitySampler
...
@@ -85,7 +86,7 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
...
@@ -85,7 +86,7 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
# build sampler
# build sampler
config_sampler
=
config
[
mode
][
'sampler'
]
config_sampler
=
config
[
mode
][
'sampler'
]
if
"name"
not
in
config_sampler
:
if
config_sampler
and
"name"
not
in
config_sampler
:
batch_sampler
=
None
batch_sampler
=
None
batch_size
=
config_sampler
[
"batch_size"
]
batch_size
=
config_sampler
[
"batch_size"
]
drop_last
=
config_sampler
[
"drop_last"
]
drop_last
=
config_sampler
[
"drop_last"
]
...
...
ppcls/data/dataloader/face_dataset.py
浏览文件 @
32c99be6
...
@@ -10,28 +10,11 @@ from paddle.vision import transforms
...
@@ -10,28 +10,11 @@ from paddle.vision import transforms
from
paddle.vision.transforms
import
functional
as
F
from
paddle.vision.transforms
import
functional
as
F
from
paddle.io
import
Dataset
from
paddle.io
import
Dataset
from
.common_dataset
import
create_operators
from
.common_dataset
import
create_operators
from
ppcls.data.preprocess
import
transform
as
transform_func
# code is based on AdaFace: https://github.com/mk-minchul/AdaFace
# code is based on AdaFace: https://github.com/mk-minchul/AdaFace
def
train_dataset
(
train_dir
,
label_path
,
low_res_augmentation_prob
,
crop_augmentation_prob
,
photometric_augmentation_prob
):
# train_dir = os.path.join(data_root, train_data_path)
train_dataset
=
AdaFaceDataset
(
root_dir
=
train_dir
,
label_path
=
label_path
,
transform
=
transforms
.
Compose
([
transforms
.
RandomHorizontalFlip
(),
transforms
.
ToTensor
(),
transforms
.
Normalize
([
0.5
,
0.5
,
0.5
],
[
0.5
,
0.5
,
0.5
])
]),
low_res_augmentation_prob
=
low_res_augmentation_prob
,
crop_augmentation_prob
=
crop_augmentation_prob
,
photometric_augmentation_prob
=
photometric_augmentation_prob
,
)
return
train_dataset
def
_get_image_size
(
img
):
def
_get_image_size
(
img
):
if
F
.
_is_pil_image
(
img
):
if
F
.
_is_pil_image
(
img
):
return
img
.
size
return
img
.
size
...
@@ -95,7 +78,7 @@ class AdaFaceDataset(Dataset):
...
@@ -95,7 +78,7 @@ class AdaFaceDataset(Dataset):
sample
,
_
=
self
.
augment
(
sample
)
sample
,
_
=
self
.
augment
(
sample
)
if
self
.
transform
is
not
None
:
if
self
.
transform
is
not
None
:
sample
=
self
.
transform
(
sample
)
sample
=
transform_func
(
sample
,
self
.
transform
)
return
sample
,
target
return
sample
,
target
...
@@ -125,16 +108,6 @@ class AdaFaceDataset(Dataset):
...
@@ -125,16 +108,6 @@ class AdaFaceDataset(Dataset):
# photometric augmentation
# photometric augmentation
if
np
.
random
.
random
()
<
self
.
photometric_augmentation_prob
:
if
np
.
random
.
random
()
<
self
.
photometric_augmentation_prob
:
# fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \
# self.photometric._get_params(self.photometric.brightness, self.photometric.contrast,
# self.photometric.saturation, self.photometric.hue)
# for fn_id in fn_idx:
# if fn_id == 0 and brightness_factor is not None:
# sample = F.adjust_brightness(sample, brightness_factor)
# elif fn_id == 1 and contrast_factor is not None:
# sample = F.adjust_contrast(sample, contrast_factor)
# elif fn_id == 2 and saturation_factor is not None:
# sample = F.adjust_saturation(sample, saturation_factor)
sample
=
self
.
photometric
(
sample
)
sample
=
self
.
photometric
(
sample
)
information_score
=
resize_ratio
*
crop_ratio
information_score
=
resize_ratio
*
crop_ratio
return
sample
,
information_score
return
sample
,
information_score
...
@@ -269,17 +242,4 @@ def get_val_data(data_path):
...
@@ -269,17 +242,4 @@ def get_val_data(data_path):
lfw
,
lfw_issame
=
get_val_pair
(
data_path
,
'lfw'
)
lfw
,
lfw_issame
=
get_val_pair
(
data_path
,
'lfw'
)
cplfw
,
cplfw_issame
=
get_val_pair
(
data_path
,
'cplfw'
)
cplfw
,
cplfw_issame
=
get_val_pair
(
data_path
,
'cplfw'
)
calfw
,
calfw_issame
=
get_val_pair
(
data_path
,
'calfw'
)
calfw
,
calfw_issame
=
get_val_pair
(
data_path
,
'calfw'
)
return
agedb_30
,
cfp_fp
,
lfw
,
agedb_30_issame
,
cfp_fp_issame
,
lfw_issame
,
cplfw
,
cplfw_issame
,
calfw
,
calfw_issame
return
agedb_30
,
cfp_fp
,
lfw
,
agedb_30_issame
,
cfp_fp_issame
,
lfw_issame
,
cplfw
,
cplfw_issame
,
calfw
,
calfw_issame
\ No newline at end of file
if
__name__
==
"__main__"
:
t_dataset
=
train_dataset
(
'/work/dataset/face/'
,
'/work/dataset/face/train_filter_label.txt'
,
1
,
1
,
1
)
img
=
t_dataset
.
__getitem__
(
100
)
print
(
len
(
t_dataset
))
val
=
FiveValidationDataset
(
'/work/dataset/face/faces_emore'
,
'/work/dataset/face/faces_emore/concat_validation_memfile'
)
a
=
1
ppcls/data/preprocess/__init__.py
浏览文件 @
32c99be6
...
@@ -33,6 +33,7 @@ from ppcls.data.preprocess.ops.operators import AugMix
...
@@ -33,6 +33,7 @@ from ppcls.data.preprocess.ops.operators import AugMix
from
ppcls.data.preprocess.ops.operators
import
Pad
from
ppcls.data.preprocess.ops.operators
import
Pad
from
ppcls.data.preprocess.ops.operators
import
ToTensor
from
ppcls.data.preprocess.ops.operators
import
ToTensor
from
ppcls.data.preprocess.ops.operators
import
Normalize
from
ppcls.data.preprocess.ops.operators
import
Normalize
from
ppcls.data.preprocess.ops.operators
import
RandomHorizontalFlip
from
ppcls.data.preprocess.batch_ops.batch_operators
import
MixupOperator
,
CutmixOperator
,
OpSampler
,
FmixOperator
from
ppcls.data.preprocess.batch_ops.batch_operators
import
MixupOperator
,
CutmixOperator
,
OpSampler
,
FmixOperator
...
...
ppcls/data/preprocess/ops/operators.py
浏览文件 @
32c99be6
...
@@ -25,7 +25,7 @@ import cv2
...
@@ -25,7 +25,7 @@ import cv2
import
numpy
as
np
import
numpy
as
np
from
PIL
import
Image
,
ImageOps
,
__version__
as
PILLOW_VERSION
from
PIL
import
Image
,
ImageOps
,
__version__
as
PILLOW_VERSION
from
paddle.vision.transforms
import
ColorJitter
as
RawColorJitter
from
paddle.vision.transforms
import
ColorJitter
as
RawColorJitter
from
paddle.vision.transforms
import
ToTensor
,
Normalize
from
paddle.vision.transforms
import
ToTensor
,
Normalize
,
RandomHorizontalFlip
from
.autoaugment
import
ImageNetPolicy
from
.autoaugment
import
ImageNetPolicy
from
.functional
import
augmentations
from
.functional
import
augmentations
...
@@ -463,8 +463,8 @@ class Pad(object):
...
@@ -463,8 +463,8 @@ class Pad(object):
# Process fill color for affine transforms
# Process fill color for affine transforms
major_found
,
minor_found
=
(
int
(
v
)
major_found
,
minor_found
=
(
int
(
v
)
for
v
in
PILLOW_VERSION
.
split
(
'.'
)[:
2
])
for
v
in
PILLOW_VERSION
.
split
(
'.'
)[:
2
])
major_required
,
minor_required
=
(
major_required
,
minor_required
=
(
int
(
v
)
for
v
in
int
(
v
)
for
v
in
min_pil_version
.
split
(
'.'
)[:
2
])
min_pil_version
.
split
(
'.'
)[:
2
])
if
major_found
<
major_required
or
(
major_found
==
major_required
and
if
major_found
<
major_required
or
(
major_found
==
major_required
and
minor_found
<
minor_required
):
minor_found
<
minor_required
):
if
fill
is
None
:
if
fill
is
None
:
...
...
ppcls/engine/engine.py
浏览文件 @
32c99be6
...
@@ -116,7 +116,7 @@ class Engine(object):
...
@@ -116,7 +116,7 @@ class Engine(object):
self
.
config
[
"DataLoader"
],
"Train"
,
self
.
device
,
self
.
use_dali
)
self
.
config
[
"DataLoader"
],
"Train"
,
self
.
device
,
self
.
use_dali
)
if
self
.
mode
==
"eval"
or
(
self
.
mode
==
"train"
and
if
self
.
mode
==
"eval"
or
(
self
.
mode
==
"train"
and
self
.
config
[
"Global"
][
"eval_during_train"
]):
self
.
config
[
"Global"
][
"eval_during_train"
]):
if
self
.
eval_mode
==
"classification"
:
if
self
.
eval_mode
in
[
"classification"
,
"adaface"
]
:
self
.
eval_dataloader
=
build_dataloader
(
self
.
eval_dataloader
=
build_dataloader
(
self
.
config
[
"DataLoader"
],
"Eval"
,
self
.
device
,
self
.
config
[
"DataLoader"
],
"Eval"
,
self
.
device
,
self
.
use_dali
)
self
.
use_dali
)
...
...
ppcls/engine/evaluation/adaface.py
浏览文件 @
32c99be6
...
@@ -30,7 +30,7 @@ def fuse_features_with_norm(stacked_embeddings, stacked_norms):
...
@@ -30,7 +30,7 @@ def fuse_features_with_norm(stacked_embeddings, stacked_norms):
assert
stacked_embeddings
.
ndim
==
3
# (n_features_to_fuse, batch_size, channel)
assert
stacked_embeddings
.
ndim
==
3
# (n_features_to_fuse, batch_size, channel)
assert
stacked_norms
.
ndim
==
3
# (n_features_to_fuse, batch_size, 1)
assert
stacked_norms
.
ndim
==
3
# (n_features_to_fuse, batch_size, 1)
pre_norm_embeddings
=
stacked_embeddings
*
stacked_norms
pre_norm_embeddings
=
stacked_embeddings
*
stacked_norms
fused
=
pre_norm_embeddings
.
sum
(
dim
=
0
)
fused
=
pre_norm_embeddings
.
sum
(
axis
=
0
)
norm
=
paddle
.
norm
(
fused
,
2
,
1
,
True
)
norm
=
paddle
.
norm
(
fused
,
2
,
1
,
True
)
fused
=
paddle
.
divide
(
fused
,
norm
)
fused
=
paddle
.
divide
(
fused
,
norm
)
return
fused
,
norm
return
fused
,
norm
...
@@ -57,12 +57,14 @@ def adaface_eval(engine, epoch_id=0):
...
@@ -57,12 +57,14 @@ def adaface_eval(engine, epoch_id=0):
time_info
[
"reader_cost"
].
update
(
time
.
time
()
-
tic
)
time_info
[
"reader_cost"
].
update
(
time
.
time
()
-
tic
)
batch_size
=
images
.
shape
[
0
]
batch_size
=
images
.
shape
[
0
]
batch
[
0
]
=
paddle
.
to_tensor
(
images
)
batch
[
0
]
=
paddle
.
to_tensor
(
images
)
embeddings
=
engine
.
model
(
images
)[
"features"
]
embeddings
=
engine
.
model
(
images
,
labels
)[
'features'
]
norms
=
paddle
.
divide
(
embeddings
,
paddle
.
norm
(
embeddings
,
2
,
1
,
True
))
norms
=
paddle
.
divide
(
embeddings
,
paddle
.
norm
(
embeddings
,
2
,
1
,
True
))
embeddings
=
paddle
.
divide
(
embeddings
,
norms
)
fliped_images
=
paddle
.
flip
(
images
,
axis
=
[
3
])
fliped_images
=
paddle
.
flip
(
images
,
axis
=
[
3
])
flipped_embeddings
=
engine
.
model
(
fliped_images
)[
"features"
]
flipped_embeddings
=
engine
.
model
(
fliped_images
,
labels
)[
'features'
]
flipped_norms
=
paddle
.
divide
(
flipped_norms
=
paddle
.
divide
(
flipped_embeddings
,
paddle
.
norm
(
flipped_embeddings
,
2
,
1
,
True
))
flipped_embeddings
,
paddle
.
norm
(
flipped_embeddings
,
2
,
1
,
True
))
flipped_embeddings
=
paddle
.
divide
(
flipped_embeddings
,
flipped_norms
)
stacked_embeddings
=
paddle
.
stack
(
stacked_embeddings
=
paddle
.
stack
(
[
embeddings
,
flipped_embeddings
],
axis
=
0
)
[
embeddings
,
flipped_embeddings
],
axis
=
0
)
stacked_norms
=
paddle
.
stack
([
norms
,
flipped_norms
],
axis
=
0
)
stacked_norms
=
paddle
.
stack
([
norms
,
flipped_norms
],
axis
=
0
)
...
@@ -114,20 +116,21 @@ def adaface_eval(engine, epoch_id=0):
...
@@ -114,20 +116,21 @@ def adaface_eval(engine, epoch_id=0):
metric_msg
=
", "
.
join
([
metric_msg
=
", "
.
join
([
"{}: {:.5f}"
.
format
(
key
,
output_info
[
key
].
avg
)
for
key
in
output_info
"{}: {:.5f}"
.
format
(
key
,
output_info
[
key
].
avg
)
for
key
in
output_info
])
])
face_msg
=
", "
.
join
(
face_msg
=
", "
.
join
([
[
"{}: {:.5f}"
.
format
(
key
,
output_info
[
key
])
for
key
in
eval_result
])
"{}: {:.5f}"
.
format
(
key
,
eval_result
[
key
])
for
key
in
eval_result
.
keys
()
])
logger
.
info
(
"[Eval][Epoch {}][Avg]{}"
.
format
(
epoch_id
,
metric_msg
+
", "
+
logger
.
info
(
"[Eval][Epoch {}][Avg]{}"
.
format
(
epoch_id
,
metric_msg
+
", "
+
face_msg
))
face_msg
))
# do not try to save best eval.model
if
engine
.
eval_metric_func
is
None
:
return
-
1
# return 1st metric in the dict
# return 1st metric in the dict
return
output_info
[
metric_key
].
avg
return
eval_result
[
'all_test_acc'
]
def
cal_metric
(
all_output_tensor
,
all_norm_tensor
,
all_target_tensor
,
def
cal_metric
(
all_output_tensor
,
all_norm_tensor
,
all_target_tensor
,
all_dataname_tensor
):
all_dataname_tensor
):
all_target_tensor
=
all_target_tensor
.
reshape
([
-
1
])
all_dataname_tensor
=
all_dataname_tensor
.
reshape
([
-
1
])
dataname_to_idx
=
{
dataname_to_idx
=
{
"agedb_30"
:
0
,
"agedb_30"
:
0
,
"cfp_fp"
:
1
,
"cfp_fp"
:
1
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录