Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
dc651d47
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
dc651d47
编写于
5月 29, 2022
作者:
C
cuicheng01
提交者:
GitHub
5月 29, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'PaddlePaddle:develop' into develop
上级
a226a058
ad71254e
变更
21
隐藏空白更改
内联
并排
Showing
21 changed file
with
1441 addition
and
31 deletion
+1441
-31
deploy/configs/inference_attr.yaml
deploy/configs/inference_attr.yaml
+33
-0
deploy/images/Pedestrain_Attr.jpg
deploy/images/Pedestrain_Attr.jpg
+0
-0
deploy/python/postprocess.py
deploy/python/postprocess.py
+103
-2
deploy/python/predict_cls.py
deploy/python/predict_cls.py
+15
-7
ppcls/arch/backbone/__init__.py
ppcls/arch/backbone/__init__.py
+1
-0
ppcls/arch/backbone/legendary_models/resnet.py
ppcls/arch/backbone/legendary_models/resnet.py
+7
-3
ppcls/arch/backbone/model_zoo/adaface_ir_net.py
ppcls/arch/backbone/model_zoo/adaface_ir_net.py
+529
-0
ppcls/arch/gears/__init__.py
ppcls/arch/gears/__init__.py
+2
-1
ppcls/arch/gears/adamargin.py
ppcls/arch/gears/adamargin.py
+111
-0
ppcls/configs/Attr/StrongBaselineAttr.yaml
ppcls/configs/Attr/StrongBaselineAttr.yaml
+1
-2
ppcls/configs/metric_learning/adaface_ir18.yaml
ppcls/configs/metric_learning/adaface_ir18.yaml
+105
-0
ppcls/data/__init__.py
ppcls/data/__init__.py
+2
-1
ppcls/data/dataloader/__init__.py
ppcls/data/dataloader/__init__.py
+1
-0
ppcls/data/dataloader/face_dataset.py
ppcls/data/dataloader/face_dataset.py
+163
-0
ppcls/data/preprocess/__init__.py
ppcls/data/preprocess/__init__.py
+4
-0
ppcls/data/preprocess/ops/operators.py
ppcls/data/preprocess/ops/operators.py
+93
-9
ppcls/engine/engine.py
ppcls/engine/engine.py
+7
-4
ppcls/engine/evaluation/__init__.py
ppcls/engine/evaluation/__init__.py
+1
-0
ppcls/engine/evaluation/adaface.py
ppcls/engine/evaluation/adaface.py
+260
-0
ppcls/metric/metrics.py
ppcls/metric/metrics.py
+1
-0
requirements.txt
requirements.txt
+2
-2
未找到文件。
deploy/configs/inference_attr.yaml
0 → 100644
浏览文件 @
dc651d47
Global
:
infer_imgs
:
"
./images/Pedestrain_Attr.jpg"
inference_model_dir
:
"
../inference/"
batch_size
:
1
use_gpu
:
True
enable_mkldnn
:
False
cpu_num_threads
:
10
enable_benchmark
:
True
use_fp16
:
False
ir_optim
:
True
use_tensorrt
:
False
gpu_mem
:
8000
enable_profile
:
False
PreProcess
:
transform_ops
:
-
ResizeImage
:
size
:
[
192
,
256
]
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
channel_num
:
3
-
ToCHWImage
:
PostProcess
:
main_indicator
:
Attribute
Attribute
:
threshold
:
0.5
#default threshold
glasses_threshold
:
0.3
#threshold only for glasses
hold_threshold
:
0.6
#threshold only for hold
\ No newline at end of file
deploy/images/Pedestrain_Attr.jpg
0 → 100644
浏览文件 @
dc651d47
12.2 KB
deploy/python/postprocess.py
浏览文件 @
dc651d47
...
...
@@ -64,9 +64,17 @@ class ThreshOutput(object):
for
idx
,
probs
in
enumerate
(
x
):
score
=
probs
[
1
]
if
score
<
self
.
threshold
:
result
=
{
"class_ids"
:
[
0
],
"scores"
:
[
1
-
score
],
"label_names"
:
[
self
.
label_0
]}
result
=
{
"class_ids"
:
[
0
],
"scores"
:
[
1
-
score
],
"label_names"
:
[
self
.
label_0
]
}
else
:
result
=
{
"class_ids"
:
[
1
],
"scores"
:
[
score
],
"label_names"
:
[
self
.
label_1
]}
result
=
{
"class_ids"
:
[
1
],
"scores"
:
[
score
],
"label_names"
:
[
self
.
label_1
]
}
if
file_names
is
not
None
:
result
[
"file_name"
]
=
file_names
[
idx
]
y
.
append
(
result
)
...
...
@@ -179,3 +187,96 @@ class Binarize(object):
byte
[:,
i
:
i
+
1
]
=
np
.
dot
(
x
[:,
i
*
8
:(
i
+
1
)
*
8
],
self
.
unit
)
return
byte
class
Attribute
(
object
):
def
__init__
(
self
,
threshold
=
0.5
,
glasses_threshold
=
0.3
,
hold_threshold
=
0.6
):
self
.
threshold
=
threshold
self
.
glasses_threshold
=
glasses_threshold
self
.
hold_threshold
=
hold_threshold
def
__call__
(
self
,
batch_preds
,
file_names
=
None
):
# postprocess output of predictor
age_list
=
[
'AgeLess18'
,
'Age18-60'
,
'AgeOver60'
]
direct_list
=
[
'Front'
,
'Side'
,
'Back'
]
bag_list
=
[
'HandBag'
,
'ShoulderBag'
,
'Backpack'
]
upper_list
=
[
'UpperStride'
,
'UpperLogo'
,
'UpperPlaid'
,
'UpperSplice'
]
lower_list
=
[
'LowerStripe'
,
'LowerPattern'
,
'LongCoat'
,
'Trousers'
,
'Shorts'
,
'Skirt&Dress'
]
batch_res
=
[]
for
res
in
batch_preds
:
res
=
res
.
tolist
()
label_res
=
[]
# gender
gender
=
'Female'
if
res
[
22
]
>
self
.
threshold
else
'Male'
label_res
.
append
(
gender
)
# age
age
=
age_list
[
np
.
argmax
(
res
[
19
:
22
])]
label_res
.
append
(
age
)
# direction
direction
=
direct_list
[
np
.
argmax
(
res
[
23
:])]
label_res
.
append
(
direction
)
# glasses
glasses
=
'Glasses: '
if
res
[
1
]
>
self
.
glasses_threshold
:
glasses
+=
'True'
else
:
glasses
+=
'False'
label_res
.
append
(
glasses
)
# hat
hat
=
'Hat: '
if
res
[
0
]
>
self
.
threshold
:
hat
+=
'True'
else
:
hat
+=
'False'
label_res
.
append
(
hat
)
# hold obj
hold_obj
=
'HoldObjectsInFront: '
if
res
[
18
]
>
self
.
hold_threshold
:
hold_obj
+=
'True'
else
:
hold_obj
+=
'False'
label_res
.
append
(
hold_obj
)
# bag
bag
=
bag_list
[
np
.
argmax
(
res
[
15
:
18
])]
bag_score
=
res
[
15
+
np
.
argmax
(
res
[
15
:
18
])]
bag_label
=
bag
if
bag_score
>
self
.
threshold
else
'No bag'
label_res
.
append
(
bag_label
)
# upper
upper_res
=
res
[
4
:
8
]
upper_label
=
'Upper:'
sleeve
=
'LongSleeve'
if
res
[
3
]
>
res
[
2
]
else
'ShortSleeve'
upper_label
+=
' {}'
.
format
(
sleeve
)
for
i
,
r
in
enumerate
(
upper_res
):
if
r
>
self
.
threshold
:
upper_label
+=
' {}'
.
format
(
upper_list
[
i
])
label_res
.
append
(
upper_label
)
# lower
lower_res
=
res
[
8
:
14
]
lower_label
=
'Lower: '
has_lower
=
False
for
i
,
l
in
enumerate
(
lower_res
):
if
l
>
self
.
threshold
:
lower_label
+=
' {}'
.
format
(
lower_list
[
i
])
has_lower
=
True
if
not
has_lower
:
lower_label
+=
' {}'
.
format
(
lower_list
[
np
.
argmax
(
lower_res
)])
label_res
.
append
(
lower_label
)
# shoe
shoe
=
'Boots'
if
res
[
14
]
>
self
.
threshold
else
'No boots'
label_res
.
append
(
shoe
)
threshold_list
=
[
0.5
]
*
len
(
res
)
threshold_list
[
1
]
=
self
.
glasses_threshold
threshold_list
[
18
]
=
self
.
hold_threshold
pred_res
=
(
np
.
array
(
res
)
>
np
.
array
(
threshold_list
)
).
astype
(
np
.
int8
).
tolist
()
batch_res
.
append
([
label_res
,
pred_res
])
return
batch_res
deploy/python/predict_cls.py
浏览文件 @
dc651d47
...
...
@@ -138,13 +138,21 @@ def main(config):
continue
batch_results
=
cls_predictor
.
predict
(
batch_imgs
)
for
number
,
result_dict
in
enumerate
(
batch_results
):
filename
=
batch_names
[
number
]
clas_ids
=
result_dict
[
"class_ids"
]
scores_str
=
"[{}]"
.
format
(
", "
.
join
(
"{:.2f}"
.
format
(
r
)
for
r
in
result_dict
[
"scores"
]))
label_names
=
result_dict
[
"label_names"
]
print
(
"{}:
\t
class id(s): {}, score(s): {}, label_name(s): {}"
.
format
(
filename
,
clas_ids
,
scores_str
,
label_names
))
if
"Attribute"
in
config
[
"PostProcess"
]:
filename
=
batch_names
[
number
]
attr_message
=
result_dict
[
0
]
pred_res
=
result_dict
[
1
]
print
(
"{}:
\t
attributes: {},
\n
predict output: {}"
.
format
(
filename
,
attr_message
,
pred_res
))
else
:
filename
=
batch_names
[
number
]
clas_ids
=
result_dict
[
"class_ids"
]
scores_str
=
"[{}]"
.
format
(
", "
.
join
(
"{:.2f}"
.
format
(
r
)
for
r
in
result_dict
[
"scores"
]))
label_names
=
result_dict
[
"label_names"
]
print
(
"{}:
\t
class id(s): {}, score(s): {}, label_name(s): {}"
.
format
(
filename
,
clas_ids
,
scores_str
,
label_names
))
batch_imgs
=
[]
batch_names
=
[]
if
cls_predictor
.
benchmark
:
...
...
ppcls/arch/backbone/__init__.py
浏览文件 @
dc651d47
...
...
@@ -70,6 +70,7 @@ from ppcls.arch.backbone.model_zoo.van import VAN_tiny
from
ppcls.arch.backbone.variant_models.resnet_variant
import
ResNet50_last_stage_stride1
from
ppcls.arch.backbone.variant_models.vgg_variant
import
VGG19Sigmoid
from
ppcls.arch.backbone.variant_models.pp_lcnet_variant
import
PPLCNet_x2_5_Tanh
from
ppcls.arch.backbone.model_zoo.adaface_ir_net
import
AdaFace_IR_18
,
AdaFace_IR_34
,
AdaFace_IR_50
,
AdaFace_IR_101
,
AdaFace_IR_152
,
AdaFace_IR_SE_50
,
AdaFace_IR_SE_101
,
AdaFace_IR_SE_152
,
AdaFace_IR_SE_200
# help whl get all the models' api (class type) and components' api (func type)
...
...
ppcls/arch/backbone/legendary_models/resnet.py
浏览文件 @
dc651d47
...
...
@@ -137,8 +137,11 @@ class ConvBNLayer(TheseusLayer):
weight_attr
=
ParamAttr
(
learning_rate
=
lr_mult
,
trainable
=
True
)
bias_attr
=
ParamAttr
(
learning_rate
=
lr_mult
,
trainable
=
True
)
self
.
bn
=
BatchNorm2D
(
num_filters
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
)
self
.
bn
=
BatchNorm
(
num_filters
,
param_attr
=
ParamAttr
(
learning_rate
=
lr_mult
),
bias_attr
=
ParamAttr
(
learning_rate
=
lr_mult
),
data_layout
=
data_format
)
self
.
relu
=
nn
.
ReLU
()
def
forward
(
self
,
x
):
...
...
@@ -287,7 +290,8 @@ class ResNet(TheseusLayer):
data_format
=
"NCHW"
,
input_image_channel
=
3
,
return_patterns
=
None
,
return_stages
=
None
):
return_stages
=
None
,
**
kargs
):
super
().
__init__
()
self
.
cfg
=
config
...
...
ppcls/arch/backbone/model_zoo/adaface_ir_net.py
0 → 100644
浏览文件 @
dc651d47
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# this code is based on AdaFace(https://github.com/mk-minchul/AdaFace)
from
collections
import
namedtuple
import
paddle
import
paddle.nn
as
nn
from
paddle.nn
import
Dropout
from
paddle.nn
import
MaxPool2D
from
paddle.nn
import
Sequential
from
paddle.nn
import
Conv2D
,
Linear
from
paddle.nn
import
BatchNorm1D
,
BatchNorm2D
from
paddle.nn
import
ReLU
,
Sigmoid
from
paddle.nn
import
Layer
from
paddle.nn
import
PReLU
# from ppcls.arch.backbone.legendary_models.resnet import _load_pretrained
class
Flatten
(
Layer
):
""" Flat tensor
"""
def
forward
(
self
,
input
):
return
paddle
.
reshape
(
input
,
[
input
.
shape
[
0
],
-
1
])
class
LinearBlock
(
Layer
):
""" Convolution block without no-linear activation layer
"""
def
__init__
(
self
,
in_c
,
out_c
,
kernel
=
(
1
,
1
),
stride
=
(
1
,
1
),
padding
=
(
0
,
0
),
groups
=
1
):
super
(
LinearBlock
,
self
).
__init__
()
self
.
conv
=
Conv2D
(
in_c
,
out_c
,
kernel
,
stride
,
padding
,
groups
=
groups
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
None
)
weight_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
1.0
))
bias_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
0.0
))
self
.
bn
=
BatchNorm2D
(
out_c
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
)
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
x
=
self
.
bn
(
x
)
return
x
class
GNAP
(
Layer
):
""" Global Norm-Aware Pooling block
"""
def
__init__
(
self
,
in_c
):
super
(
GNAP
,
self
).
__init__
()
self
.
bn1
=
BatchNorm2D
(
in_c
,
weight_attr
=
False
,
bias_attr
=
False
)
self
.
pool
=
nn
.
AdaptiveAvgPool2D
((
1
,
1
))
self
.
bn2
=
BatchNorm1D
(
in_c
,
weight_attr
=
False
,
bias_attr
=
False
)
def
forward
(
self
,
x
):
x
=
self
.
bn1
(
x
)
x_norm
=
paddle
.
norm
(
x
,
2
,
1
,
True
)
x_norm_mean
=
paddle
.
mean
(
x_norm
)
weight
=
x_norm_mean
/
x_norm
x
=
x
*
weight
x
=
self
.
pool
(
x
)
x
=
x
.
view
(
x
.
shape
[
0
],
-
1
)
feature
=
self
.
bn2
(
x
)
return
feature
class
GDC
(
Layer
):
""" Global Depthwise Convolution block
"""
def
__init__
(
self
,
in_c
,
embedding_size
):
super
(
GDC
,
self
).
__init__
()
self
.
conv_6_dw
=
LinearBlock
(
in_c
,
in_c
,
groups
=
in_c
,
kernel
=
(
7
,
7
),
stride
=
(
1
,
1
),
padding
=
(
0
,
0
))
self
.
conv_6_flatten
=
Flatten
()
self
.
linear
=
Linear
(
in_c
,
embedding_size
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
)
self
.
bn
=
BatchNorm1D
(
embedding_size
,
weight_attr
=
False
,
bias_attr
=
False
)
def
forward
(
self
,
x
):
x
=
self
.
conv_6_dw
(
x
)
x
=
self
.
conv_6_flatten
(
x
)
x
=
self
.
linear
(
x
)
x
=
self
.
bn
(
x
)
return
x
class
SELayer
(
Layer
):
""" SE block
"""
def
__init__
(
self
,
channels
,
reduction
):
super
(
SELayer
,
self
).
__init__
()
self
.
avg_pool
=
nn
.
AdaptiveAvgPool2D
(
1
)
weight_attr
=
paddle
.
ParamAttr
(
initializer
=
paddle
.
nn
.
initializer
.
XavierUniform
())
self
.
fc1
=
Conv2D
(
channels
,
channels
//
reduction
,
kernel_size
=
1
,
padding
=
0
,
weight_attr
=
weight_attr
,
bias_attr
=
False
)
self
.
relu
=
ReLU
()
self
.
fc2
=
Conv2D
(
channels
//
reduction
,
channels
,
kernel_size
=
1
,
padding
=
0
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
)
self
.
sigmoid
=
Sigmoid
()
def
forward
(
self
,
x
):
module_input
=
x
x
=
self
.
avg_pool
(
x
)
x
=
self
.
fc1
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
fc2
(
x
)
x
=
self
.
sigmoid
(
x
)
return
module_input
*
x
class
BasicBlockIR
(
Layer
):
""" BasicBlock for IRNet
"""
def
__init__
(
self
,
in_channel
,
depth
,
stride
):
super
(
BasicBlockIR
,
self
).
__init__
()
weight_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
1.0
))
bias_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
0.0
))
if
in_channel
==
depth
:
self
.
shortcut_layer
=
MaxPool2D
(
1
,
stride
)
else
:
self
.
shortcut_layer
=
Sequential
(
Conv2D
(
in_channel
,
depth
,
(
1
,
1
),
stride
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
depth
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
))
self
.
res_layer
=
Sequential
(
BatchNorm2D
(
in_channel
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
Conv2D
(
in_channel
,
depth
,
(
3
,
3
),
(
1
,
1
),
1
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
depth
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
PReLU
(
depth
),
Conv2D
(
depth
,
depth
,
(
3
,
3
),
stride
,
1
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
depth
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
))
def
forward
(
self
,
x
):
shortcut
=
self
.
shortcut_layer
(
x
)
res
=
self
.
res_layer
(
x
)
return
res
+
shortcut
class
BottleneckIR
(
Layer
):
""" BasicBlock with bottleneck for IRNet
"""
def
__init__
(
self
,
in_channel
,
depth
,
stride
):
super
(
BottleneckIR
,
self
).
__init__
()
reduction_channel
=
depth
//
4
weight_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
1.0
))
bias_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
0.0
))
if
in_channel
==
depth
:
self
.
shortcut_layer
=
MaxPool2D
(
1
,
stride
)
else
:
self
.
shortcut_layer
=
Sequential
(
Conv2D
(
in_channel
,
depth
,
(
1
,
1
),
stride
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
depth
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
))
self
.
res_layer
=
Sequential
(
BatchNorm2D
(
in_channel
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
Conv2D
(
in_channel
,
reduction_channel
,
(
1
,
1
),
(
1
,
1
),
0
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
reduction_channel
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
PReLU
(
reduction_channel
),
Conv2D
(
reduction_channel
,
reduction_channel
,
(
3
,
3
),
(
1
,
1
),
1
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
reduction_channel
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
PReLU
(
reduction_channel
),
Conv2D
(
reduction_channel
,
depth
,
(
1
,
1
),
stride
,
0
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
depth
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
))
def
forward
(
self
,
x
):
shortcut
=
self
.
shortcut_layer
(
x
)
res
=
self
.
res_layer
(
x
)
return
res
+
shortcut
class
BasicBlockIRSE
(
BasicBlockIR
):
def
__init__
(
self
,
in_channel
,
depth
,
stride
):
super
(
BasicBlockIRSE
,
self
).
__init__
(
in_channel
,
depth
,
stride
)
self
.
res_layer
.
add_sublayer
(
"se_block"
,
SELayer
(
depth
,
16
))
class
BottleneckIRSE
(
BottleneckIR
):
def
__init__
(
self
,
in_channel
,
depth
,
stride
):
super
(
BottleneckIRSE
,
self
).
__init__
(
in_channel
,
depth
,
stride
)
self
.
res_layer
.
add_sublayer
(
"se_block"
,
SELayer
(
depth
,
16
))
class
Bottleneck
(
namedtuple
(
'Block'
,
[
'in_channel'
,
'depth'
,
'stride'
])):
'''A named tuple describing a ResNet block.'''
def
get_block
(
in_channel
,
depth
,
num_units
,
stride
=
2
):
return
[
Bottleneck
(
in_channel
,
depth
,
stride
)]
+
\
[
Bottleneck
(
depth
,
depth
,
1
)
for
i
in
range
(
num_units
-
1
)]
def
get_blocks
(
num_layers
):
if
num_layers
==
18
:
blocks
=
[
get_block
(
in_channel
=
64
,
depth
=
64
,
num_units
=
2
),
get_block
(
in_channel
=
64
,
depth
=
128
,
num_units
=
2
),
get_block
(
in_channel
=
128
,
depth
=
256
,
num_units
=
2
),
get_block
(
in_channel
=
256
,
depth
=
512
,
num_units
=
2
)
]
elif
num_layers
==
34
:
blocks
=
[
get_block
(
in_channel
=
64
,
depth
=
64
,
num_units
=
3
),
get_block
(
in_channel
=
64
,
depth
=
128
,
num_units
=
4
),
get_block
(
in_channel
=
128
,
depth
=
256
,
num_units
=
6
),
get_block
(
in_channel
=
256
,
depth
=
512
,
num_units
=
3
)
]
elif
num_layers
==
50
:
blocks
=
[
get_block
(
in_channel
=
64
,
depth
=
64
,
num_units
=
3
),
get_block
(
in_channel
=
64
,
depth
=
128
,
num_units
=
4
),
get_block
(
in_channel
=
128
,
depth
=
256
,
num_units
=
14
),
get_block
(
in_channel
=
256
,
depth
=
512
,
num_units
=
3
)
]
elif
num_layers
==
100
:
blocks
=
[
get_block
(
in_channel
=
64
,
depth
=
64
,
num_units
=
3
),
get_block
(
in_channel
=
64
,
depth
=
128
,
num_units
=
13
),
get_block
(
in_channel
=
128
,
depth
=
256
,
num_units
=
30
),
get_block
(
in_channel
=
256
,
depth
=
512
,
num_units
=
3
)
]
elif
num_layers
==
152
:
blocks
=
[
get_block
(
in_channel
=
64
,
depth
=
256
,
num_units
=
3
),
get_block
(
in_channel
=
256
,
depth
=
512
,
num_units
=
8
),
get_block
(
in_channel
=
512
,
depth
=
1024
,
num_units
=
36
),
get_block
(
in_channel
=
1024
,
depth
=
2048
,
num_units
=
3
)
]
elif
num_layers
==
200
:
blocks
=
[
get_block
(
in_channel
=
64
,
depth
=
256
,
num_units
=
3
),
get_block
(
in_channel
=
256
,
depth
=
512
,
num_units
=
24
),
get_block
(
in_channel
=
512
,
depth
=
1024
,
num_units
=
36
),
get_block
(
in_channel
=
1024
,
depth
=
2048
,
num_units
=
3
)
]
return
blocks
class
Backbone
(
Layer
):
def
__init__
(
self
,
input_size
,
num_layers
,
mode
=
'ir'
):
""" Args:
input_size: input_size of backbone
num_layers: num_layers of backbone
mode: support ir or irse
"""
super
(
Backbone
,
self
).
__init__
()
assert
input_size
[
0
]
in
[
112
,
224
],
\
"input_size should be [112, 112] or [224, 224]"
assert
num_layers
in
[
18
,
34
,
50
,
100
,
152
,
200
],
\
"num_layers should be 18, 34, 50, 100 or 152"
assert
mode
in
[
'ir'
,
'ir_se'
],
\
"mode should be ir or ir_se"
weight_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
1.0
))
bias_attr
=
paddle
.
ParamAttr
(
regularizer
=
None
,
initializer
=
nn
.
initializer
.
Constant
(
value
=
0.0
))
self
.
input_layer
=
Sequential
(
Conv2D
(
3
,
64
,
(
3
,
3
),
1
,
1
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
(),
bias_attr
=
False
),
BatchNorm2D
(
64
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
PReLU
(
64
))
blocks
=
get_blocks
(
num_layers
)
if
num_layers
<=
100
:
if
mode
==
'ir'
:
unit_module
=
BasicBlockIR
elif
mode
==
'ir_se'
:
unit_module
=
BasicBlockIRSE
output_channel
=
512
else
:
if
mode
==
'ir'
:
unit_module
=
BottleneckIR
elif
mode
==
'ir_se'
:
unit_module
=
BottleneckIRSE
output_channel
=
2048
if
input_size
[
0
]
==
112
:
self
.
output_layer
=
Sequential
(
BatchNorm2D
(
output_channel
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
Dropout
(
0.4
),
Flatten
(),
Linear
(
output_channel
*
7
*
7
,
512
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
()),
BatchNorm1D
(
512
,
weight_attr
=
False
,
bias_attr
=
False
))
else
:
self
.
output_layer
=
Sequential
(
BatchNorm2D
(
output_channel
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
),
Dropout
(
0.4
),
Flatten
(),
Linear
(
output_channel
*
14
*
14
,
512
,
weight_attr
=
nn
.
initializer
.
KaimingNormal
()),
BatchNorm1D
(
512
,
weight_attr
=
False
,
bias_attr
=
False
))
modules
=
[]
for
block
in
blocks
:
for
bottleneck
in
block
:
modules
.
append
(
unit_module
(
bottleneck
.
in_channel
,
bottleneck
.
depth
,
bottleneck
.
stride
))
self
.
body
=
Sequential
(
*
modules
)
# initialize_weights(self.modules())
def
forward
(
self
,
x
):
# current code only supports one extra image
# it comes with a extra dimension for number of extra image. We will just squeeze it out for now
x
=
self
.
input_layer
(
x
)
for
idx
,
module
in
enumerate
(
self
.
body
):
x
=
module
(
x
)
x
=
self
.
output_layer
(
x
)
# norm = paddle.norm(x, 2, 1, True)
# output = paddle.divide(x, norm)
# return output, norm
return
x
def
AdaFace_IR_18
(
input_size
=
(
112
,
112
)):
""" Constructs a ir-18 model.
"""
model
=
Backbone
(
input_size
,
18
,
'ir'
)
return
model
def
AdaFace_IR_34
(
input_size
=
(
112
,
112
)):
""" Constructs a ir-34 model.
"""
model
=
Backbone
(
input_size
,
34
,
'ir'
)
return
model
def
AdaFace_IR_50
(
input_size
=
(
112
,
112
)):
""" Constructs a ir-50 model.
"""
model
=
Backbone
(
input_size
,
50
,
'ir'
)
return
model
def
AdaFace_IR_101
(
input_size
=
(
112
,
112
)):
""" Constructs a ir-101 model.
"""
model
=
Backbone
(
input_size
,
100
,
'ir'
)
return
model
def
AdaFace_IR_152
(
input_size
=
(
112
,
112
)):
""" Constructs a ir-152 model.
"""
model
=
Backbone
(
input_size
,
152
,
'ir'
)
return
model
def
AdaFace_IR_200
(
input_size
=
(
112
,
112
)):
""" Constructs a ir-200 model.
"""
model
=
Backbone
(
input_size
,
200
,
'ir'
)
return
model
def
AdaFace_IR_SE_50
(
input_size
=
(
112
,
112
)):
""" Constructs a ir_se-50 model.
"""
model
=
Backbone
(
input_size
,
50
,
'ir_se'
)
return
model
def
AdaFace_IR_SE_101
(
input_size
=
(
112
,
112
)):
""" Constructs a ir_se-101 model.
"""
model
=
Backbone
(
input_size
,
100
,
'ir_se'
)
return
model
def
AdaFace_IR_SE_152
(
input_size
=
(
112
,
112
)):
""" Constructs a ir_se-152 model.
"""
model
=
Backbone
(
input_size
,
152
,
'ir_se'
)
return
model
def
AdaFace_IR_SE_200
(
input_size
=
(
112
,
112
)):
""" Constructs a ir_se-200 model.
"""
model
=
Backbone
(
input_size
,
200
,
'ir_se'
)
return
model
ppcls/arch/gears/__init__.py
浏览文件 @
dc651d47
...
...
@@ -19,6 +19,7 @@ from .fc import FC
from
.vehicle_neck
import
VehicleNeck
from
paddle.nn
import
Tanh
from
.bnneck
import
BNNeck
from
.adamargin
import
AdaMargin
__all__
=
[
'build_gear'
]
...
...
@@ -26,7 +27,7 @@ __all__ = ['build_gear']
def
build_gear
(
config
):
support_dict
=
[
'ArcMargin'
,
'CosMargin'
,
'CircleMargin'
,
'FC'
,
'VehicleNeck'
,
'Tanh'
,
'BNNeck'
'BNNeck'
,
'AdaMargin'
]
module_name
=
config
.
pop
(
'name'
)
assert
module_name
in
support_dict
,
Exception
(
...
...
ppcls/arch/gears/adamargin.py
0 → 100644
浏览文件 @
dc651d47
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is based on AdaFace(https://github.com/mk-minchul/AdaFace)
# Paper: AdaFace: Quality Adaptive Margin for Face Recognition
from
paddle.nn
import
Layer
import
math
import
paddle
def
l2_norm
(
input
,
axis
=
1
):
norm
=
paddle
.
norm
(
input
,
2
,
axis
,
True
)
output
=
paddle
.
divide
(
input
,
norm
)
return
output
class
AdaMargin
(
Layer
):
def
__init__
(
self
,
embedding_size
=
512
,
class_num
=
70722
,
m
=
0.4
,
h
=
0.333
,
s
=
64.
,
t_alpha
=
1.0
,
):
super
(
AdaMargin
,
self
).
__init__
()
self
.
classnum
=
class_num
kernel_weight
=
paddle
.
uniform
(
[
embedding_size
,
class_num
],
min
=-
1
,
max
=
1
)
kernel_weight_norm
=
paddle
.
norm
(
kernel_weight
,
p
=
2
,
axis
=
0
,
keepdim
=
True
)
kernel_weight_norm
=
paddle
.
where
(
kernel_weight_norm
>
1e-5
,
kernel_weight_norm
,
paddle
.
ones_like
(
kernel_weight_norm
))
kernel_weight
=
kernel_weight
/
kernel_weight_norm
self
.
kernel
=
self
.
create_parameter
(
[
embedding_size
,
class_num
],
attr
=
paddle
.
nn
.
initializer
.
Assign
(
kernel_weight
))
# initial kernel
# self.kernel.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5)
self
.
m
=
m
self
.
eps
=
1e-3
self
.
h
=
h
self
.
s
=
s
# ema prep
self
.
t_alpha
=
t_alpha
self
.
register_buffer
(
't'
,
paddle
.
zeros
([
1
]),
persistable
=
True
)
self
.
register_buffer
(
'batch_mean'
,
paddle
.
ones
([
1
])
*
20
,
persistable
=
True
)
self
.
register_buffer
(
'batch_std'
,
paddle
.
ones
([
1
])
*
100
,
persistable
=
True
)
def
forward
(
self
,
embbedings
,
label
):
norms
=
paddle
.
norm
(
embbedings
,
2
,
1
,
True
)
embbedings
=
paddle
.
divide
(
embbedings
,
norms
)
kernel_norm
=
l2_norm
(
self
.
kernel
,
axis
=
0
)
cosine
=
paddle
.
mm
(
embbedings
,
kernel_norm
)
cosine
=
paddle
.
clip
(
cosine
,
-
1
+
self
.
eps
,
1
-
self
.
eps
)
# for stability
safe_norms
=
paddle
.
clip
(
norms
,
min
=
0.001
,
max
=
100
)
# for stability
safe_norms
=
safe_norms
.
clone
().
detach
()
# update batchmean batchstd
with
paddle
.
no_grad
():
mean
=
safe_norms
.
mean
().
detach
()
std
=
safe_norms
.
std
().
detach
()
self
.
batch_mean
=
mean
*
self
.
t_alpha
+
(
1
-
self
.
t_alpha
)
*
self
.
batch_mean
self
.
batch_std
=
std
*
self
.
t_alpha
+
(
1
-
self
.
t_alpha
)
*
self
.
batch_std
margin_scaler
=
(
safe_norms
-
self
.
batch_mean
)
/
(
self
.
batch_std
+
self
.
eps
)
# 66% between -1, 1
margin_scaler
=
margin_scaler
*
self
.
h
# 68% between -0.333 ,0.333 when h:0.333
margin_scaler
=
paddle
.
clip
(
margin_scaler
,
-
1
,
1
)
# g_angular
m_arc
=
paddle
.
nn
.
functional
.
one_hot
(
label
.
reshape
([
-
1
]),
self
.
classnum
)
g_angular
=
self
.
m
*
margin_scaler
*
-
1
m_arc
=
m_arc
*
g_angular
theta
=
paddle
.
acos
(
cosine
)
theta_m
=
paddle
.
clip
(
theta
+
m_arc
,
min
=
self
.
eps
,
max
=
math
.
pi
-
self
.
eps
)
cosine
=
paddle
.
cos
(
theta_m
)
# g_additive
m_cos
=
paddle
.
nn
.
functional
.
one_hot
(
label
.
reshape
([
-
1
]),
self
.
classnum
)
g_add
=
self
.
m
+
(
self
.
m
*
margin_scaler
)
m_cos
=
m_cos
*
g_add
cosine
=
cosine
-
m_cos
# scale
scaled_cosine_m
=
cosine
*
self
.
s
return
scaled_cosine_m
ppcls/configs/Attr/StrongBaselineAttr.yaml
浏览文件 @
dc651d47
...
...
@@ -20,6 +20,7 @@ Arch:
name
:
"
ResNet50"
pretrained
:
True
class_num
:
26
infer_add_softmax
:
False
# loss function config for traing/eval process
Loss
:
...
...
@@ -110,5 +111,3 @@ DataLoader:
Metric
:
Eval
:
-
ATTRMetric
:
ppcls/configs/metric_learning/adaface_ir18.yaml
0 → 100644
浏览文件 @
dc651d47
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
"
./output/"
device
:
"
gpu"
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
26
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
112
,
112
]
save_inference_dir
:
"
./inference"
eval_mode
:
"
adaface"
# model architecture
Arch
:
name
:
"
RecModel"
infer_output_key
:
"
features"
infer_add_softmax
:
False
Backbone
:
name
:
"
AdaFace_IR_18"
input_size
:
[
112
,
112
]
Head
:
name
:
"
AdaMargin"
embedding_size
:
512
class_num
:
70722
m
:
0.4
s
:
64
h
:
0.333
t_alpha
:
0.01
# loss function config for traing/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Piecewise
learning_rate
:
0.1
decay_epochs
:
[
12
,
20
,
24
]
values
:
[
0.1
,
0.01
,
0.001
,
0.0001
]
regularizer
:
name
:
'
L2'
coeff
:
0.0005
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
"
AdaFaceDataset"
root_dir
:
"
dataset/face/"
label_path
:
"
dataset/face/train_filter_label.txt"
transform
:
-
CropWithPadding
:
prob
:
0.2
padding_num
:
0
size
:
[
112
,
112
]
scale
:
[
0.2
,
1.0
]
ratio
:
[
0.75
,
1.3333333333333333
]
-
RandomInterpolationAugment
:
prob
:
0.2
-
ColorJitter
:
prob
:
0.2
brightness
:
0.5
contrast
:
0.5
saturation
:
0.5
hue
:
0
-
RandomHorizontalFlip
:
-
ToTensor
:
-
Normalize
:
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.5
,
0.5
,
0.5
]
sampler
:
name
:
DistributedBatchSampler
batch_size
:
256
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
6
use_shared_memory
:
True
Eval
:
dataset
:
name
:
FiveValidationDataset
val_data_path
:
dataset/face/faces_emore
concat_mem_file_name
:
dataset/face/faces_emore/concat_validation_memfile
sampler
:
name
:
BatchSampler
batch_size
:
256
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
6
use_shared_memory
:
True
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
5
]
\ No newline at end of file
ppcls/data/__init__.py
浏览文件 @
dc651d47
...
...
@@ -30,6 +30,7 @@ from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
from
ppcls.data.dataloader.mix_dataset
import
MixDataset
from
ppcls.data.dataloader.multi_scale_dataset
import
MultiScaleDataset
from
ppcls.data.dataloader.person_dataset
import
Market1501
,
MSMT17
from
ppcls.data.dataloader.face_dataset
import
FiveValidationDataset
,
AdaFaceDataset
# sampler
...
...
@@ -88,7 +89,7 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
# build sampler
config_sampler
=
config
[
mode
][
'sampler'
]
if
"name"
not
in
config_sampler
:
if
config_sampler
and
"name"
not
in
config_sampler
:
batch_sampler
=
None
batch_size
=
config_sampler
[
"batch_size"
]
drop_last
=
config_sampler
[
"drop_last"
]
...
...
ppcls/data/dataloader/__init__.py
浏览文件 @
dc651d47
...
...
@@ -10,3 +10,4 @@ from ppcls.data.dataloader.mix_sampler import MixSampler
from
ppcls.data.dataloader.multi_scale_sampler
import
MultiScaleSampler
from
ppcls.data.dataloader.pk_sampler
import
PKSampler
from
ppcls.data.dataloader.person_dataset
import
Market1501
,
MSMT17
from
ppcls.data.dataloader.face_dataset
import
AdaFaceDataset
,
FiveValidationDataset
ppcls/data/dataloader/face_dataset.py
0 → 100644
浏览文件 @
dc651d47
import
os
import
json
import
numpy
as
np
from
PIL
import
Image
import
cv2
import
paddle
import
paddle.vision.datasets
as
datasets
from
paddle.vision
import
transforms
from
paddle.vision.transforms
import
functional
as
F
from
paddle.io
import
Dataset
from
.common_dataset
import
create_operators
from
ppcls.data.preprocess
import
transform
as
transform_func
# code is based on AdaFace: https://github.com/mk-minchul/AdaFace
class
AdaFaceDataset
(
Dataset
):
def
__init__
(
self
,
root_dir
,
label_path
,
transform
=
None
):
self
.
root_dir
=
root_dir
self
.
transform
=
create_operators
(
transform
)
with
open
(
label_path
)
as
fd
:
lines
=
fd
.
readlines
()
self
.
samples
=
[]
for
l
in
lines
:
l
=
l
.
strip
().
split
()
self
.
samples
.
append
([
os
.
path
.
join
(
root_dir
,
l
[
0
]),
int
(
l
[
1
])])
def
__len__
(
self
):
return
len
(
self
.
samples
)
def
__getitem__
(
self
,
index
):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
[
path
,
target
]
=
self
.
samples
[
index
]
with
open
(
path
,
'rb'
)
as
f
:
img
=
Image
.
open
(
f
)
sample
=
img
.
convert
(
'RGB'
)
# if 'WebFace' in self.root:
# # swap rgb to bgr since image is in rgb for webface
# sample = Image.fromarray(np.asarray(sample)[:, :, ::-1]
if
self
.
transform
is
not
None
:
sample
=
transform_func
(
sample
,
self
.
transform
)
return
sample
,
target
class
FiveValidationDataset
(
Dataset
):
def
__init__
(
self
,
val_data_path
,
concat_mem_file_name
):
'''
concatenates all validation datasets from emore
val_data_dict = {
'agedb_30': (agedb_30, agedb_30_issame),
"cfp_fp": (cfp_fp, cfp_fp_issame),
"lfw": (lfw, lfw_issame),
"cplfw": (cplfw, cplfw_issame),
"calfw": (calfw, calfw_issame),
}
agedb_30: 0
cfp_fp: 1
lfw: 2
cplfw: 3
calfw: 4
'''
val_data
=
get_val_data
(
val_data_path
)
age_30
,
cfp_fp
,
lfw
,
age_30_issame
,
cfp_fp_issame
,
lfw_issame
,
cplfw
,
cplfw_issame
,
calfw
,
calfw_issame
=
val_data
val_data_dict
=
{
'agedb_30'
:
(
age_30
,
age_30_issame
),
"cfp_fp"
:
(
cfp_fp
,
cfp_fp_issame
),
"lfw"
:
(
lfw
,
lfw_issame
),
"cplfw"
:
(
cplfw
,
cplfw_issame
),
"calfw"
:
(
calfw
,
calfw_issame
),
}
self
.
dataname_to_idx
=
{
"agedb_30"
:
0
,
"cfp_fp"
:
1
,
"lfw"
:
2
,
"cplfw"
:
3
,
"calfw"
:
4
}
self
.
val_data_dict
=
val_data_dict
# concat all dataset
all_imgs
=
[]
all_issame
=
[]
all_dataname
=
[]
key_orders
=
[]
for
key
,
(
imgs
,
issame
)
in
val_data_dict
.
items
():
all_imgs
.
append
(
imgs
)
dup_issame
=
[
]
# hacky way to make the issame length same as imgs. [1, 1, 0, 0, ...]
for
same
in
issame
:
dup_issame
.
append
(
same
)
dup_issame
.
append
(
same
)
all_issame
.
append
(
dup_issame
)
all_dataname
.
append
([
self
.
dataname_to_idx
[
key
]]
*
len
(
imgs
))
key_orders
.
append
(
key
)
assert
key_orders
==
[
'agedb_30'
,
'cfp_fp'
,
'lfw'
,
'cplfw'
,
'calfw'
]
if
isinstance
(
all_imgs
[
0
],
np
.
memmap
):
self
.
all_imgs
=
read_memmap
(
concat_mem_file_name
)
else
:
self
.
all_imgs
=
np
.
concatenate
(
all_imgs
)
self
.
all_issame
=
np
.
concatenate
(
all_issame
)
self
.
all_dataname
=
np
.
concatenate
(
all_dataname
)
def
__getitem__
(
self
,
index
):
x_np
=
self
.
all_imgs
[
index
].
copy
()
x
=
paddle
.
to_tensor
(
x_np
)
y
=
self
.
all_issame
[
index
]
dataname
=
self
.
all_dataname
[
index
]
return
x
,
y
,
dataname
,
index
def
__len__
(
self
):
return
len
(
self
.
all_imgs
)
def
read_memmap
(
mem_file_name
):
# r+ mode: Open existing file for reading and writing
with
open
(
mem_file_name
+
'.conf'
,
'r'
)
as
file
:
memmap_configs
=
json
.
load
(
file
)
return
np
.
memmap
(
mem_file_name
,
mode
=
'r+'
,
\
shape
=
tuple
(
memmap_configs
[
'shape'
]),
\
dtype
=
memmap_configs
[
'dtype'
])
def
get_val_pair
(
path
,
name
,
use_memfile
=
True
):
# installing bcolz should set proxy to access internet
import
bcolz
if
use_memfile
:
mem_file_dir
=
os
.
path
.
join
(
path
,
name
,
'memfile'
)
mem_file_name
=
os
.
path
.
join
(
mem_file_dir
,
'mem_file.dat'
)
if
os
.
path
.
isdir
(
mem_file_dir
):
print
(
'laoding validation data memfile'
)
np_array
=
read_memmap
(
mem_file_name
)
else
:
os
.
makedirs
(
mem_file_dir
)
carray
=
bcolz
.
carray
(
rootdir
=
os
.
path
.
join
(
path
,
name
),
mode
=
'r'
)
np_array
=
np
.
array
(
carray
)
# mem_array = make_memmap(mem_file_name, np_array)
# del np_array, mem_array
del
np_array
np_array
=
read_memmap
(
mem_file_name
)
else
:
np_array
=
bcolz
.
carray
(
rootdir
=
os
.
path
.
join
(
path
,
name
),
mode
=
'r'
)
issame
=
np
.
load
(
os
.
path
.
join
(
path
,
'{}_list.npy'
.
format
(
name
)))
return
np_array
,
issame
def
get_val_data
(
data_path
):
agedb_30
,
agedb_30_issame
=
get_val_pair
(
data_path
,
'agedb_30'
)
cfp_fp
,
cfp_fp_issame
=
get_val_pair
(
data_path
,
'cfp_fp'
)
lfw
,
lfw_issame
=
get_val_pair
(
data_path
,
'lfw'
)
cplfw
,
cplfw_issame
=
get_val_pair
(
data_path
,
'cplfw'
)
calfw
,
calfw_issame
=
get_val_pair
(
data_path
,
'calfw'
)
return
agedb_30
,
cfp_fp
,
lfw
,
agedb_30_issame
,
cfp_fp_issame
,
lfw_issame
,
cplfw
,
cplfw_issame
,
calfw
,
calfw_issame
ppcls/data/preprocess/__init__.py
浏览文件 @
dc651d47
...
...
@@ -33,6 +33,10 @@ from ppcls.data.preprocess.ops.operators import AugMix
from
ppcls.data.preprocess.ops.operators
import
Pad
from
ppcls.data.preprocess.ops.operators
import
ToTensor
from
ppcls.data.preprocess.ops.operators
import
Normalize
from
ppcls.data.preprocess.ops.operators
import
RandomHorizontalFlip
from
ppcls.data.preprocess.ops.operators
import
CropWithPadding
from
ppcls.data.preprocess.ops.operators
import
RandomInterpolationAugment
from
ppcls.data.preprocess.ops.operators
import
ColorJitter
from
ppcls.data.preprocess.ops.operators
import
RandomCropImage
from
ppcls.data.preprocess.ops.operators
import
Padv2
...
...
ppcls/data/preprocess/ops/operators.py
浏览文件 @
dc651d47
...
...
@@ -25,8 +25,8 @@ import cv2
import
numpy
as
np
from
PIL
import
Image
,
ImageOps
,
__version__
as
PILLOW_VERSION
from
paddle.vision.transforms
import
ColorJitter
as
RawColorJitter
from
paddle.vision.transforms
import
ToTensor
,
Normalize
from
paddle.vision.transforms
import
ToTensor
,
Normalize
,
RandomHorizontalFlip
,
RandomResizedCrop
from
paddle.vision.transforms
import
functional
as
F
from
.autoaugment
import
ImageNetPolicy
from
.functional
import
augmentations
from
ppcls.utils
import
logger
...
...
@@ -93,6 +93,42 @@ class UnifiedResize(object):
return
self
.
resize_func
(
src
,
size
)
class
RandomInterpolationAugment
(
object
):
def
__init__
(
self
,
prob
):
self
.
prob
=
prob
def
_aug
(
self
,
img
):
img_shape
=
img
.
shape
side_ratio
=
np
.
random
.
uniform
(
0.2
,
1.0
)
small_side
=
int
(
side_ratio
*
img_shape
[
0
])
interpolation
=
np
.
random
.
choice
([
cv2
.
INTER_NEAREST
,
cv2
.
INTER_LINEAR
,
cv2
.
INTER_AREA
,
cv2
.
INTER_CUBIC
,
cv2
.
INTER_LANCZOS4
])
small_img
=
cv2
.
resize
(
img
,
(
small_side
,
small_side
),
interpolation
=
interpolation
)
interpolation
=
np
.
random
.
choice
([
cv2
.
INTER_NEAREST
,
cv2
.
INTER_LINEAR
,
cv2
.
INTER_AREA
,
cv2
.
INTER_CUBIC
,
cv2
.
INTER_LANCZOS4
])
aug_img
=
cv2
.
resize
(
small_img
,
(
img_shape
[
1
],
img_shape
[
0
]),
interpolation
=
interpolation
)
return
aug_img
def
__call__
(
self
,
img
):
if
np
.
random
.
random
()
<
self
.
prob
:
if
isinstance
(
img
,
np
.
ndarray
):
return
self
.
_aug
(
img
)
else
:
pil_img
=
np
.
array
(
img
)
aug_img
=
self
.
_aug
(
pil_img
)
img
=
Image
.
fromarray
(
aug_img
.
astype
(
np
.
uint8
))
return
img
else
:
return
img
class
OperatorParamError
(
ValueError
):
""" OperatorParamError
"""
...
...
@@ -170,6 +206,52 @@ class ResizeImage(object):
return
self
.
_resize_func
(
img
,
(
w
,
h
))
class
CropWithPadding
(
RandomResizedCrop
):
"""
crop image and padding to original size
"""
def
__init__
(
self
,
prob
=
1
,
padding_num
=
0
,
size
=
224
,
scale
=
(
0.08
,
1.0
),
ratio
=
(
3.
/
4
,
4.
/
3
),
interpolation
=
'bilinear'
,
key
=
None
):
super
().
__init__
(
size
,
scale
,
ratio
,
interpolation
,
key
)
self
.
prob
=
prob
self
.
padding_num
=
padding_num
def
__call__
(
self
,
img
):
is_cv2_img
=
False
if
isinstance
(
img
,
np
.
ndarray
):
flag
=
True
if
np
.
random
.
random
()
<
self
.
prob
:
# RandomResizedCrop augmentation
new
=
np
.
zeros_like
(
np
.
array
(
img
))
+
self
.
padding_num
# orig_W, orig_H = F._get_image_size(sample)
orig_W
,
orig_H
=
self
.
_get_image_size
(
img
)
i
,
j
,
h
,
w
=
self
.
_get_param
(
img
)
cropped
=
F
.
crop
(
img
,
i
,
j
,
h
,
w
)
new
[
i
:
i
+
h
,
j
:
j
+
w
,
:]
=
np
.
array
(
cropped
)
if
not
isinstance
:
new
=
Image
.
fromarray
(
new
.
astype
(
np
.
uint8
))
return
new
else
:
return
img
def
_get_image_size
(
self
,
img
):
if
F
.
_is_pil_image
(
img
):
return
img
.
size
elif
F
.
_is_numpy_image
(
img
):
return
img
.
shape
[:
2
][::
-
1
]
elif
F
.
_is_tensor_image
(
img
):
return
img
.
shape
[
1
:][::
-
1
]
# chw
else
:
raise
TypeError
(
"Unexpected type {}"
.
format
(
type
(
img
)))
class
CropImage
(
object
):
""" crop image """
...
...
@@ -533,16 +615,18 @@ class ColorJitter(RawColorJitter):
"""ColorJitter.
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
prob
=
2
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
prob
=
prob
def
__call__
(
self
,
img
):
if
not
isinstance
(
img
,
Image
.
Image
):
img
=
np
.
ascontiguousarray
(
img
)
img
=
Image
.
fromarray
(
img
)
img
=
super
().
_apply_image
(
img
)
if
isinstance
(
img
,
Image
.
Image
):
img
=
np
.
asarray
(
img
)
if
np
.
random
.
random
()
<
self
.
prob
:
if
not
isinstance
(
img
,
Image
.
Image
):
img
=
np
.
ascontiguousarray
(
img
)
img
=
Image
.
fromarray
(
img
)
img
=
super
().
_apply_image
(
img
)
if
isinstance
(
img
,
Image
.
Image
):
img
=
np
.
asarray
(
img
)
return
img
...
...
ppcls/engine/engine.py
浏览文件 @
dc651d47
...
...
@@ -75,8 +75,9 @@ class Engine(object):
print_config
(
config
)
# init train_func and eval_func
assert
self
.
eval_mode
in
[
"classification"
,
"retrieval"
],
logger
.
error
(
"Invalid eval mode: {}"
.
format
(
self
.
eval_mode
))
assert
self
.
eval_mode
in
[
"classification"
,
"retrieval"
,
"adaface"
],
logger
.
error
(
"Invalid eval mode: {}"
.
format
(
self
.
eval_mode
))
self
.
train_epoch_func
=
train_epoch
self
.
eval_func
=
getattr
(
evaluation
,
self
.
eval_mode
+
"_eval"
)
...
...
@@ -115,7 +116,7 @@ class Engine(object):
self
.
config
[
"DataLoader"
],
"Train"
,
self
.
device
,
self
.
use_dali
)
if
self
.
mode
==
"eval"
or
(
self
.
mode
==
"train"
and
self
.
config
[
"Global"
][
"eval_during_train"
]):
if
self
.
eval_mode
==
"classification"
:
if
self
.
eval_mode
in
[
"classification"
,
"adaface"
]
:
self
.
eval_dataloader
=
build_dataloader
(
self
.
config
[
"DataLoader"
],
"Eval"
,
self
.
device
,
self
.
use_dali
)
...
...
@@ -457,7 +458,9 @@ class Engine(object):
def
export
(
self
):
assert
self
.
mode
==
"export"
use_multilabel
=
self
.
config
[
"Global"
].
get
(
"use_multilabel"
,
False
)
use_multilabel
=
self
.
config
[
"Global"
].
get
(
"use_multilabel"
,
False
)
and
not
"ATTRMetric"
in
self
.
config
[
"Metric"
][
"Eval"
][
0
]
model
=
ExportModel
(
self
.
config
[
"Arch"
],
self
.
model
,
use_multilabel
)
if
self
.
config
[
"Global"
][
"pretrained_model"
]
is
not
None
:
load_dygraph_pretrain
(
model
.
base_model
,
...
...
ppcls/engine/evaluation/__init__.py
浏览文件 @
dc651d47
...
...
@@ -14,3 +14,4 @@
from
ppcls.engine.evaluation.classification
import
classification_eval
from
ppcls.engine.evaluation.retrieval
import
retrieval_eval
from
ppcls.engine.evaluation.adaface
import
adaface_eval
\ No newline at end of file
ppcls/engine/evaluation/adaface.py
0 → 100644
浏览文件 @
dc651d47
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
time
import
numpy
as
np
import
platform
import
paddle
import
sklearn
from
sklearn.model_selection
import
KFold
from
sklearn.decomposition
import
PCA
from
ppcls.utils.misc
import
AverageMeter
from
ppcls.utils
import
logger
def
fuse_features_with_norm
(
stacked_embeddings
,
stacked_norms
):
assert
stacked_embeddings
.
ndim
==
3
# (n_features_to_fuse, batch_size, channel)
assert
stacked_norms
.
ndim
==
3
# (n_features_to_fuse, batch_size, 1)
pre_norm_embeddings
=
stacked_embeddings
*
stacked_norms
fused
=
pre_norm_embeddings
.
sum
(
axis
=
0
)
norm
=
paddle
.
norm
(
fused
,
2
,
1
,
True
)
fused
=
paddle
.
divide
(
fused
,
norm
)
return
fused
,
norm
def
adaface_eval
(
engine
,
epoch_id
=
0
):
output_info
=
dict
()
time_info
=
{
"batch_cost"
:
AverageMeter
(
"batch_cost"
,
'.5f'
,
postfix
=
" s,"
),
"reader_cost"
:
AverageMeter
(
"reader_cost"
,
".5f"
,
postfix
=
" s,"
),
}
print_batch_step
=
engine
.
config
[
"Global"
][
"print_batch_step"
]
metric_key
=
None
tic
=
time
.
time
()
unique_dict
=
{}
for
iter_id
,
batch
in
enumerate
(
engine
.
eval_dataloader
):
images
,
labels
,
dataname
,
image_index
=
batch
if
iter_id
==
5
:
for
key
in
time_info
:
time_info
[
key
].
reset
()
time_info
[
"reader_cost"
].
update
(
time
.
time
()
-
tic
)
batch_size
=
images
.
shape
[
0
]
batch
[
0
]
=
paddle
.
to_tensor
(
images
)
embeddings
=
engine
.
model
(
images
,
labels
)[
'features'
]
norms
=
paddle
.
divide
(
embeddings
,
paddle
.
norm
(
embeddings
,
2
,
1
,
True
))
embeddings
=
paddle
.
divide
(
embeddings
,
norms
)
fliped_images
=
paddle
.
flip
(
images
,
axis
=
[
3
])
flipped_embeddings
=
engine
.
model
(
fliped_images
,
labels
)[
'features'
]
flipped_norms
=
paddle
.
divide
(
flipped_embeddings
,
paddle
.
norm
(
flipped_embeddings
,
2
,
1
,
True
))
flipped_embeddings
=
paddle
.
divide
(
flipped_embeddings
,
flipped_norms
)
stacked_embeddings
=
paddle
.
stack
(
[
embeddings
,
flipped_embeddings
],
axis
=
0
)
stacked_norms
=
paddle
.
stack
([
norms
,
flipped_norms
],
axis
=
0
)
embeddings
,
norms
=
fuse_features_with_norm
(
stacked_embeddings
,
stacked_norms
)
for
out
,
nor
,
label
,
data
,
idx
in
zip
(
embeddings
,
norms
,
labels
,
dataname
,
image_index
):
unique_dict
[
int
(
idx
.
numpy
())]
=
{
'output'
:
out
,
'norm'
:
nor
,
'target'
:
label
,
'dataname'
:
data
}
# calc metric
time_info
[
"batch_cost"
].
update
(
time
.
time
()
-
tic
)
if
iter_id
%
print_batch_step
==
0
:
time_msg
=
"s, "
.
join
([
"{}: {:.5f}"
.
format
(
key
,
time_info
[
key
].
avg
)
for
key
in
time_info
])
ips_msg
=
"ips: {:.5f} images/sec"
.
format
(
batch_size
/
time_info
[
"batch_cost"
].
avg
)
metric_msg
=
", "
.
join
([
"{}: {:.5f}"
.
format
(
key
,
output_info
[
key
].
val
)
for
key
in
output_info
])
logger
.
info
(
"[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}"
.
format
(
epoch_id
,
iter_id
,
len
(
engine
.
eval_dataloader
),
metric_msg
,
time_msg
,
ips_msg
))
tic
=
time
.
time
()
unique_keys
=
sorted
(
unique_dict
.
keys
())
all_output_tensor
=
paddle
.
stack
(
[
unique_dict
[
key
][
'output'
]
for
key
in
unique_keys
],
axis
=
0
)
all_norm_tensor
=
paddle
.
stack
(
[
unique_dict
[
key
][
'norm'
]
for
key
in
unique_keys
],
axis
=
0
)
all_target_tensor
=
paddle
.
stack
(
[
unique_dict
[
key
][
'target'
]
for
key
in
unique_keys
],
axis
=
0
)
all_dataname_tensor
=
paddle
.
stack
(
[
unique_dict
[
key
][
'dataname'
]
for
key
in
unique_keys
],
axis
=
0
)
eval_result
=
cal_metric
(
all_output_tensor
,
all_norm_tensor
,
all_target_tensor
,
all_dataname_tensor
)
metric_msg
=
", "
.
join
([
"{}: {:.5f}"
.
format
(
key
,
output_info
[
key
].
avg
)
for
key
in
output_info
])
face_msg
=
", "
.
join
([
"{}: {:.5f}"
.
format
(
key
,
eval_result
[
key
])
for
key
in
eval_result
.
keys
()
])
logger
.
info
(
"[Eval][Epoch {}][Avg]{}"
.
format
(
epoch_id
,
metric_msg
+
", "
+
face_msg
))
# return 1st metric in the dict
return
eval_result
[
'all_test_acc'
]
def
cal_metric
(
all_output_tensor
,
all_norm_tensor
,
all_target_tensor
,
all_dataname_tensor
):
all_target_tensor
=
all_target_tensor
.
reshape
([
-
1
])
all_dataname_tensor
=
all_dataname_tensor
.
reshape
([
-
1
])
dataname_to_idx
=
{
"agedb_30"
:
0
,
"cfp_fp"
:
1
,
"lfw"
:
2
,
"cplfw"
:
3
,
"calfw"
:
4
}
idx_to_dataname
=
{
val
:
key
for
key
,
val
in
dataname_to_idx
.
items
()}
test_logs
=
{}
# _, indices = paddle.unique(all_dataname_tensor, return_index=True, return_inverse=False, return_counts=False)
for
dataname_idx
in
all_dataname_tensor
.
unique
():
dataname
=
idx_to_dataname
[
dataname_idx
.
item
()]
# per dataset evaluation
embeddings
=
all_output_tensor
[
all_dataname_tensor
==
dataname_idx
].
numpy
()
labels
=
all_target_tensor
[
all_dataname_tensor
==
dataname_idx
].
numpy
()
issame
=
labels
[
0
::
2
]
tpr
,
fpr
,
accuracy
,
best_thresholds
=
evaluate_face
(
embeddings
,
issame
,
nrof_folds
=
10
)
acc
,
best_threshold
=
accuracy
.
mean
(),
best_thresholds
.
mean
()
num_test_samples
=
len
(
embeddings
)
test_logs
[
f
'
{
dataname
}
_test_acc'
]
=
acc
test_logs
[
f
'
{
dataname
}
_test_best_threshold'
]
=
best_threshold
test_logs
[
f
'
{
dataname
}
_num_test_samples'
]
=
num_test_samples
test_acc
=
np
.
mean
([
test_logs
[
f
'
{
dataname
}
_test_acc'
]
for
dataname
in
dataname_to_idx
.
keys
()
if
f
'
{
dataname
}
_test_acc'
in
test_logs
])
test_logs
[
'all_test_acc'
]
=
test_acc
return
test_logs
def
evaluate_face
(
embeddings
,
actual_issame
,
nrof_folds
=
10
,
pca
=
0
):
# Calculate evaluation metrics
thresholds
=
np
.
arange
(
0
,
4
,
0.01
)
embeddings1
=
embeddings
[
0
::
2
]
embeddings2
=
embeddings
[
1
::
2
]
tpr
,
fpr
,
accuracy
,
best_thresholds
=
calculate_roc
(
thresholds
,
embeddings1
,
embeddings2
,
np
.
asarray
(
actual_issame
),
nrof_folds
=
nrof_folds
,
pca
=
pca
)
return
tpr
,
fpr
,
accuracy
,
best_thresholds
def
calculate_roc
(
thresholds
,
embeddings1
,
embeddings2
,
actual_issame
,
nrof_folds
=
10
,
pca
=
0
):
assert
(
embeddings1
.
shape
[
0
]
==
embeddings2
.
shape
[
0
])
assert
(
embeddings1
.
shape
[
1
]
==
embeddings2
.
shape
[
1
])
nrof_pairs
=
min
(
len
(
actual_issame
),
embeddings1
.
shape
[
0
])
nrof_thresholds
=
len
(
thresholds
)
k_fold
=
KFold
(
n_splits
=
nrof_folds
,
shuffle
=
False
)
tprs
=
np
.
zeros
((
nrof_folds
,
nrof_thresholds
))
fprs
=
np
.
zeros
((
nrof_folds
,
nrof_thresholds
))
accuracy
=
np
.
zeros
((
nrof_folds
))
best_thresholds
=
np
.
zeros
((
nrof_folds
))
indices
=
np
.
arange
(
nrof_pairs
)
# print('pca', pca)
dist
=
None
if
pca
==
0
:
diff
=
np
.
subtract
(
embeddings1
,
embeddings2
)
dist
=
np
.
sum
(
np
.
square
(
diff
),
1
)
for
fold_idx
,
(
train_set
,
test_set
)
in
enumerate
(
k_fold
.
split
(
indices
)):
# print('train_set', train_set)
# print('test_set', test_set)
if
pca
>
0
:
print
(
'doing pca on'
,
fold_idx
)
embed1_train
=
embeddings1
[
train_set
]
embed2_train
=
embeddings2
[
train_set
]
_embed_train
=
np
.
concatenate
((
embed1_train
,
embed2_train
),
axis
=
0
)
# print(_embed_train.shape)
pca_model
=
PCA
(
n_components
=
pca
)
pca_model
.
fit
(
_embed_train
)
embed1
=
pca_model
.
transform
(
embeddings1
)
embed2
=
pca_model
.
transform
(
embeddings2
)
embed1
=
sklearn
.
preprocessing
.
normalize
(
embed1
)
embed2
=
sklearn
.
preprocessing
.
normalize
(
embed2
)
# print(embed1.shape, embed2.shape)
diff
=
np
.
subtract
(
embed1
,
embed2
)
dist
=
np
.
sum
(
np
.
square
(
diff
),
1
)
# Find the best threshold for the fold
acc_train
=
np
.
zeros
((
nrof_thresholds
))
for
threshold_idx
,
threshold
in
enumerate
(
thresholds
):
_
,
_
,
acc_train
[
threshold_idx
]
=
calculate_accuracy
(
threshold
,
dist
[
train_set
],
actual_issame
[
train_set
])
best_threshold_index
=
np
.
argmax
(
acc_train
)
best_thresholds
[
fold_idx
]
=
thresholds
[
best_threshold_index
]
for
threshold_idx
,
threshold
in
enumerate
(
thresholds
):
tprs
[
fold_idx
,
threshold_idx
],
fprs
[
fold_idx
,
threshold_idx
],
_
=
calculate_accuracy
(
threshold
,
dist
[
test_set
],
actual_issame
[
test_set
])
_
,
_
,
accuracy
[
fold_idx
]
=
calculate_accuracy
(
thresholds
[
best_threshold_index
],
dist
[
test_set
],
actual_issame
[
test_set
])
tpr
=
np
.
mean
(
tprs
,
0
)
fpr
=
np
.
mean
(
fprs
,
0
)
return
tpr
,
fpr
,
accuracy
,
best_thresholds
def
calculate_accuracy
(
threshold
,
dist
,
actual_issame
):
predict_issame
=
np
.
less
(
dist
,
threshold
)
tp
=
np
.
sum
(
np
.
logical_and
(
predict_issame
,
actual_issame
))
fp
=
np
.
sum
(
np
.
logical_and
(
predict_issame
,
np
.
logical_not
(
actual_issame
)))
tn
=
np
.
sum
(
np
.
logical_and
(
np
.
logical_not
(
predict_issame
),
np
.
logical_not
(
actual_issame
)))
fn
=
np
.
sum
(
np
.
logical_and
(
np
.
logical_not
(
predict_issame
),
actual_issame
))
tpr
=
0
if
(
tp
+
fn
==
0
)
else
float
(
tp
)
/
float
(
tp
+
fn
)
fpr
=
0
if
(
fp
+
tn
==
0
)
else
float
(
fp
)
/
float
(
fp
+
tn
)
acc
=
float
(
tp
+
tn
)
/
dist
.
size
return
tpr
,
fpr
,
acc
ppcls/metric/metrics.py
浏览文件 @
dc651d47
...
...
@@ -390,6 +390,7 @@ class AccuracyScore(MultiLabelMetric):
def
get_attr_metrics
(
gt_label
,
preds_probs
,
threshold
):
"""
index: evaluated label index
adapted from "https://github.com/valencebond/Rethinking_of_PAR/blob/master/metrics/pedestrian_metrics.py"
"""
pred_label
=
(
preds_probs
>
threshold
).
astype
(
int
)
...
...
requirements.txt
浏览文件 @
dc651d47
...
...
@@ -4,9 +4,9 @@ opencv-python==4.4.0.46
pillow
tqdm
PyYAML
visualdl
>=
2.2.0
visualdl
>=
2.2.0
scipy
scikit-learn
==0.23.2
scikit-learn
>=0.21.0
gast
==0.3.3
faiss-cpu
==1.7.1.post2
easydict
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录