Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
759fe41c
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 2 年 前同步成功
通知
118
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
759fe41c
编写于
5月 19, 2022
作者:
D
dongshuilong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify face_dataset and ir_net name issue
上级
511669a4
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
125 addition
and
110 deletion
+125
-110
ppcls/arch/backbone/__init__.py
ppcls/arch/backbone/__init__.py
+1
-1
ppcls/arch/backbone/model_zoo/adaface_ir_net.py
ppcls/arch/backbone/model_zoo/adaface_ir_net.py
+10
-10
ppcls/configs/metric_learning/adaface_ir18.yaml
ppcls/configs/metric_learning/adaface_ir18.yaml
+15
-4
ppcls/data/dataloader/face_dataset.py
ppcls/data/dataloader/face_dataset.py
+3
-86
ppcls/data/preprocess/__init__.py
ppcls/data/preprocess/__init__.py
+3
-0
ppcls/data/preprocess/ops/operators.py
ppcls/data/preprocess/ops/operators.py
+93
-9
未找到文件。
ppcls/arch/backbone/__init__.py
浏览文件 @
759fe41c
...
...
@@ -68,7 +68,7 @@ from ppcls.arch.backbone.model_zoo.van import VAN_tiny
from
ppcls.arch.backbone.variant_models.resnet_variant
import
ResNet50_last_stage_stride1
from
ppcls.arch.backbone.variant_models.vgg_variant
import
VGG19Sigmoid
from
ppcls.arch.backbone.variant_models.pp_lcnet_variant
import
PPLCNet_x2_5_Tanh
from
ppcls.arch.backbone.model_zoo.
ir_net
import
IR_18
,
IR_34
,
IR_50
,
IR_101
,
IR_152
,
IR_200
,
IR_SE_50
,
IR_SE_101
,
IR_SE_152
,
IR_SE_200
from
ppcls.arch.backbone.model_zoo.
adaface_ir_net
import
AdaFace_IR_18
,
AdaFace_IR_34
,
AdaFace_IR_50
,
AdaFace_IR_101
,
AdaFace_IR_152
,
AdaFace_IR_SE_50
,
AdaFace_IR_SE_101
,
AdaFace_IR_SE_152
,
AdaFace_
IR_SE_200
# help whl get all the models' api (class type) and components' api (func type)
...
...
ppcls/arch/backbone/model_zoo/ir_net.py
→
ppcls/arch/backbone/model_zoo/
adaface_
ir_net.py
浏览文件 @
759fe41c
...
...
@@ -450,14 +450,14 @@ class Backbone(Layer):
return
x
def
IR_18
(
input_size
=
(
112
,
112
)):
def
AdaFace_
IR_18
(
input_size
=
(
112
,
112
)):
""" Constructs a ir-18 model.
"""
model
=
Backbone
(
input_size
,
18
,
'ir'
)
return
model
def
IR_34
(
input_size
=
(
112
,
112
)):
def
AdaFace_
IR_34
(
input_size
=
(
112
,
112
)):
""" Constructs a ir-34 model.
"""
model
=
Backbone
(
input_size
,
34
,
'ir'
)
...
...
@@ -465,7 +465,7 @@ def IR_34(input_size=(112, 112)):
return
model
def
IR_50
(
input_size
=
(
112
,
112
)):
def
AdaFace_
IR_50
(
input_size
=
(
112
,
112
)):
""" Constructs a ir-50 model.
"""
model
=
Backbone
(
input_size
,
50
,
'ir'
)
...
...
@@ -473,7 +473,7 @@ def IR_50(input_size=(112, 112)):
return
model
def
IR_101
(
input_size
=
(
112
,
112
)):
def
AdaFace_
IR_101
(
input_size
=
(
112
,
112
)):
""" Constructs a ir-101 model.
"""
model
=
Backbone
(
input_size
,
100
,
'ir'
)
...
...
@@ -481,7 +481,7 @@ def IR_101(input_size=(112, 112)):
return
model
def
IR_152
(
input_size
=
(
112
,
112
)):
def
AdaFace_
IR_152
(
input_size
=
(
112
,
112
)):
""" Constructs a ir-152 model.
"""
model
=
Backbone
(
input_size
,
152
,
'ir'
)
...
...
@@ -489,7 +489,7 @@ def IR_152(input_size=(112, 112)):
return
model
def
IR_200
(
input_size
=
(
112
,
112
)):
def
AdaFace_
IR_200
(
input_size
=
(
112
,
112
)):
""" Constructs a ir-200 model.
"""
model
=
Backbone
(
input_size
,
200
,
'ir'
)
...
...
@@ -497,7 +497,7 @@ def IR_200(input_size=(112, 112)):
return
model
def
IR_SE_50
(
input_size
=
(
112
,
112
)):
def
AdaFace_
IR_SE_50
(
input_size
=
(
112
,
112
)):
""" Constructs a ir_se-50 model.
"""
model
=
Backbone
(
input_size
,
50
,
'ir_se'
)
...
...
@@ -505,7 +505,7 @@ def IR_SE_50(input_size=(112, 112)):
return
model
def
IR_SE_101
(
input_size
=
(
112
,
112
)):
def
AdaFace_
IR_SE_101
(
input_size
=
(
112
,
112
)):
""" Constructs a ir_se-101 model.
"""
model
=
Backbone
(
input_size
,
100
,
'ir_se'
)
...
...
@@ -513,7 +513,7 @@ def IR_SE_101(input_size=(112, 112)):
return
model
def
IR_SE_152
(
input_size
=
(
112
,
112
)):
def
AdaFace_
IR_SE_152
(
input_size
=
(
112
,
112
)):
""" Constructs a ir_se-152 model.
"""
model
=
Backbone
(
input_size
,
152
,
'ir_se'
)
...
...
@@ -521,7 +521,7 @@ def IR_SE_152(input_size=(112, 112)):
return
model
def
IR_SE_200
(
input_size
=
(
112
,
112
)):
def
AdaFace_
IR_SE_200
(
input_size
=
(
112
,
112
)):
""" Constructs a ir_se-200 model.
"""
model
=
Backbone
(
input_size
,
200
,
'ir_se'
)
...
...
ppcls/configs/metric_learning/
ir18_adaface
.yaml
→
ppcls/configs/metric_learning/
adaface_ir18
.yaml
浏览文件 @
759fe41c
...
...
@@ -21,7 +21,7 @@ Arch:
infer_output_key
:
"
features"
infer_add_softmax
:
False
Backbone
:
name
:
"
IR_18"
name
:
"
AdaFace_
IR_18"
input_size
:
[
112
,
112
]
Head
:
name
:
"
AdaMargin"
...
...
@@ -57,10 +57,21 @@ DataLoader:
name
:
"
AdaFaceDataset"
root_dir
:
"
dataset/face/"
label_path
:
"
dataset/face/train_filter_label.txt"
low_res_augmentation_prob
:
0.2
crop_augmentation_prob
:
0.2
photometric_augmentation_prob
:
0.2
transform
:
-
CropWithPadding
:
prob
:
0.2
padding_num
:
0
size
:
[
112
,
112
]
scale
:
[
0.2
,
1.0
]
ratio
:
[
0.75
,
1.3333333333333333
]
-
RandomInterpolationAugment
:
prob
:
0.2
-
ColorJitter
:
prob
:
0.2
brightness
:
0.5
contrast
:
0.5
saturation
:
0.5
hue
:
0
-
RandomHorizontalFlip
:
-
ToTensor
:
-
Normalize
:
...
...
ppcls/data/dataloader/face_dataset.py
浏览文件 @
759fe41c
...
...
@@ -14,40 +14,11 @@ from ppcls.data.preprocess import transform as transform_func
# code is based on AdaFace: https://github.com/mk-minchul/AdaFace
def
_get_image_size
(
img
):
if
F
.
_is_pil_image
(
img
):
return
img
.
size
elif
F
.
_is_numpy_image
(
img
):
return
img
.
shape
[:
2
][::
-
1
]
elif
F
.
_is_tensor_image
(
img
):
return
img
.
shape
[
1
:][::
-
1
]
# chw
else
:
raise
TypeError
(
"Unexpected type {}"
.
format
(
type
(
img
)))
class
AdaFaceDataset
(
Dataset
):
def
__init__
(
self
,
root_dir
,
label_path
,
transform
=
None
,
low_res_augmentation_prob
=
0.0
,
crop_augmentation_prob
=
0.0
,
photometric_augmentation_prob
=
0.0
,
):
def
__init__
(
self
,
root_dir
,
label_path
,
transform
=
None
):
self
.
root_dir
=
root_dir
self
.
low_res_augmentation_prob
=
low_res_augmentation_prob
self
.
crop_augmentation_prob
=
crop_augmentation_prob
self
.
photometric_augmentation_prob
=
photometric_augmentation_prob
self
.
random_resized_crop
=
transforms
.
RandomResizedCrop
(
size
=
(
112
,
112
),
scale
=
(
0.2
,
1.0
),
ratio
=
(
0.75
,
1.3333333333333333
))
self
.
photometric
=
transforms
.
ColorJitter
(
brightness
=
0.5
,
contrast
=
0.5
,
saturation
=
0.5
,
hue
=
0
)
self
.
transform
=
create_operators
(
transform
)
self
.
tot_rot_try
=
0
self
.
rot_success
=
0
with
open
(
label_path
)
as
fd
:
lines
=
fd
.
readlines
()
self
.
samples
=
[]
...
...
@@ -73,65 +44,11 @@ class AdaFaceDataset(Dataset):
# if 'WebFace' in self.root:
# # swap rgb to bgr since image is in rgb for webface
# sample = Image.fromarray(np.asarray(sample)[:, :, ::-1])
sample
,
_
=
self
.
augment
(
sample
)
# sample = Image.fromarray(np.asarray(sample)[:, :, ::-1]
if
self
.
transform
is
not
None
:
sample
=
transform_func
(
sample
,
self
.
transform
)
return
sample
,
target
def
augment
(
self
,
sample
):
# crop with zero padding augmentation
if
np
.
random
.
random
()
<
self
.
crop_augmentation_prob
:
# RandomResizedCrop augmentation
new
=
np
.
zeros_like
(
np
.
array
(
sample
))
# orig_W, orig_H = F._get_image_size(sample)
orig_W
,
orig_H
=
_get_image_size
(
sample
)
i
,
j
,
h
,
w
=
self
.
random_resized_crop
.
_get_param
(
sample
)
cropped
=
F
.
crop
(
sample
,
i
,
j
,
h
,
w
)
new
[
i
:
i
+
h
,
j
:
j
+
w
,
:]
=
np
.
array
(
cropped
)
sample
=
Image
.
fromarray
(
new
.
astype
(
np
.
uint8
))
crop_ratio
=
min
(
h
,
w
)
/
max
(
orig_H
,
orig_W
)
else
:
crop_ratio
=
1.0
# low resolution augmentation
if
np
.
random
.
random
()
<
self
.
low_res_augmentation_prob
:
# low res augmentation
img_np
,
resize_ratio
=
low_res_augmentation
(
np
.
array
(
sample
))
sample
=
Image
.
fromarray
(
img_np
.
astype
(
np
.
uint8
))
else
:
resize_ratio
=
1
# photometric augmentation
if
np
.
random
.
random
()
<
self
.
photometric_augmentation_prob
:
sample
=
self
.
photometric
(
sample
)
information_score
=
resize_ratio
*
crop_ratio
return
sample
,
information_score
def
low_res_augmentation
(
img
):
# resize the image to a small size and enlarge it back
img_shape
=
img
.
shape
side_ratio
=
np
.
random
.
uniform
(
0.2
,
1.0
)
small_side
=
int
(
side_ratio
*
img_shape
[
0
])
interpolation
=
np
.
random
.
choice
([
cv2
.
INTER_NEAREST
,
cv2
.
INTER_LINEAR
,
cv2
.
INTER_AREA
,
cv2
.
INTER_CUBIC
,
cv2
.
INTER_LANCZOS4
])
small_img
=
cv2
.
resize
(
img
,
(
small_side
,
small_side
),
interpolation
=
interpolation
)
interpolation
=
np
.
random
.
choice
([
cv2
.
INTER_NEAREST
,
cv2
.
INTER_LINEAR
,
cv2
.
INTER_AREA
,
cv2
.
INTER_CUBIC
,
cv2
.
INTER_LANCZOS4
])
aug_img
=
cv2
.
resize
(
small_img
,
(
img_shape
[
1
],
img_shape
[
0
]),
interpolation
=
interpolation
)
return
aug_img
,
side_ratio
class
FiveValidationDataset
(
Dataset
):
def
__init__
(
self
,
val_data_path
,
concat_mem_file_name
):
...
...
ppcls/data/preprocess/__init__.py
浏览文件 @
759fe41c
...
...
@@ -34,6 +34,9 @@ from ppcls.data.preprocess.ops.operators import Pad
from
ppcls.data.preprocess.ops.operators
import
ToTensor
from
ppcls.data.preprocess.ops.operators
import
Normalize
from
ppcls.data.preprocess.ops.operators
import
RandomHorizontalFlip
from
ppcls.data.preprocess.ops.operators
import
CropWithPadding
from
ppcls.data.preprocess.ops.operators
import
RandomInterpolationAugment
from
ppcls.data.preprocess.ops.operators
import
ColorJitter
from
ppcls.data.preprocess.batch_ops.batch_operators
import
MixupOperator
,
CutmixOperator
,
OpSampler
,
FmixOperator
...
...
ppcls/data/preprocess/ops/operators.py
浏览文件 @
759fe41c
...
...
@@ -25,8 +25,8 @@ import cv2
import
numpy
as
np
from
PIL
import
Image
,
ImageOps
,
__version__
as
PILLOW_VERSION
from
paddle.vision.transforms
import
ColorJitter
as
RawColorJitter
from
paddle.vision.transforms
import
ToTensor
,
Normalize
,
RandomHorizontalFlip
from
paddle.vision.transforms
import
ToTensor
,
Normalize
,
RandomHorizontalFlip
,
RandomResizedCrop
from
paddle.vision.transforms
import
functional
as
F
from
.autoaugment
import
ImageNetPolicy
from
.functional
import
augmentations
from
ppcls.utils
import
logger
...
...
@@ -93,6 +93,42 @@ class UnifiedResize(object):
return
self
.
resize_func
(
src
,
size
)
class
RandomInterpolationAugment
(
object
):
def
__init__
(
self
,
prob
):
self
.
prob
=
prob
def
_aug
(
self
,
img
):
img_shape
=
img
.
shape
side_ratio
=
np
.
random
.
uniform
(
0.2
,
1.0
)
small_side
=
int
(
side_ratio
*
img_shape
[
0
])
interpolation
=
np
.
random
.
choice
([
cv2
.
INTER_NEAREST
,
cv2
.
INTER_LINEAR
,
cv2
.
INTER_AREA
,
cv2
.
INTER_CUBIC
,
cv2
.
INTER_LANCZOS4
])
small_img
=
cv2
.
resize
(
img
,
(
small_side
,
small_side
),
interpolation
=
interpolation
)
interpolation
=
np
.
random
.
choice
([
cv2
.
INTER_NEAREST
,
cv2
.
INTER_LINEAR
,
cv2
.
INTER_AREA
,
cv2
.
INTER_CUBIC
,
cv2
.
INTER_LANCZOS4
])
aug_img
=
cv2
.
resize
(
small_img
,
(
img_shape
[
1
],
img_shape
[
0
]),
interpolation
=
interpolation
)
return
aug_img
def
__call__
(
self
,
img
):
if
np
.
random
.
random
()
<
self
.
prob
:
if
isinstance
(
img
,
np
.
ndarray
):
return
self
.
_aug
(
img
)
else
:
pil_img
=
np
.
array
(
img
)
aug_img
=
self
.
_aug
(
pil_img
)
img
=
Image
.
fromarray
(
aug_img
.
astype
(
np
.
uint8
))
return
img
else
:
return
img
class
OperatorParamError
(
ValueError
):
""" OperatorParamError
"""
...
...
@@ -170,6 +206,52 @@ class ResizeImage(object):
return
self
.
_resize_func
(
img
,
(
w
,
h
))
class
CropWithPadding
(
RandomResizedCrop
):
"""
crop image and padding to original size
"""
def
__init__
(
self
,
prob
=
1
,
padding_num
=
0
,
size
=
224
,
scale
=
(
0.08
,
1.0
),
ratio
=
(
3.
/
4
,
4.
/
3
),
interpolation
=
'bilinear'
,
key
=
None
):
super
().
__init__
(
size
,
scale
,
ratio
,
interpolation
,
key
)
self
.
prob
=
prob
self
.
padding_num
=
padding_num
def
__call__
(
self
,
img
):
is_cv2_img
=
False
if
isinstance
(
img
,
np
.
ndarray
):
flag
=
True
if
np
.
random
.
random
()
<
self
.
prob
:
# RandomResizedCrop augmentation
new
=
np
.
zeros_like
(
np
.
array
(
img
))
+
self
.
padding_num
# orig_W, orig_H = F._get_image_size(sample)
orig_W
,
orig_H
=
self
.
_get_image_size
(
img
)
i
,
j
,
h
,
w
=
self
.
_get_param
(
img
)
cropped
=
F
.
crop
(
img
,
i
,
j
,
h
,
w
)
new
[
i
:
i
+
h
,
j
:
j
+
w
,
:]
=
np
.
array
(
cropped
)
if
not
isinstance
:
new
=
Image
.
fromarray
(
new
.
astype
(
np
.
uint8
))
return
new
else
:
return
img
def
_get_image_size
(
self
,
img
):
if
F
.
_is_pil_image
(
img
):
return
img
.
size
elif
F
.
_is_numpy_image
(
img
):
return
img
.
shape
[:
2
][::
-
1
]
elif
F
.
_is_tensor_image
(
img
):
return
img
.
shape
[
1
:][::
-
1
]
# chw
else
:
raise
TypeError
(
"Unexpected type {}"
.
format
(
type
(
img
)))
class
CropImage
(
object
):
""" crop image """
...
...
@@ -434,10 +516,12 @@ class ColorJitter(RawColorJitter):
"""ColorJitter.
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
prob
=
2
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
prob
=
prob
def
__call__
(
self
,
img
):
if
np
.
random
.
random
()
<
self
.
prob
:
if
not
isinstance
(
img
,
Image
.
Image
):
img
=
np
.
ascontiguousarray
(
img
)
img
=
Image
.
fromarray
(
img
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录