Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
759fe41c
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 2 年 前同步成功
通知
118
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
759fe41c
编写于
5月 19, 2022
作者:
D
dongshuilong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify face_dataset and ir_net name issue
上级
511669a4
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
125 addition
and
110 deletion
+125
-110
ppcls/arch/backbone/__init__.py
ppcls/arch/backbone/__init__.py
+1
-1
ppcls/arch/backbone/model_zoo/adaface_ir_net.py
ppcls/arch/backbone/model_zoo/adaface_ir_net.py
+10
-10
ppcls/configs/metric_learning/adaface_ir18.yaml
ppcls/configs/metric_learning/adaface_ir18.yaml
+15
-4
ppcls/data/dataloader/face_dataset.py
ppcls/data/dataloader/face_dataset.py
+3
-86
ppcls/data/preprocess/__init__.py
ppcls/data/preprocess/__init__.py
+3
-0
ppcls/data/preprocess/ops/operators.py
ppcls/data/preprocess/ops/operators.py
+93
-9
未找到文件。
ppcls/arch/backbone/__init__.py
浏览文件 @
759fe41c
...
@@ -68,7 +68,7 @@ from ppcls.arch.backbone.model_zoo.van import VAN_tiny
...
@@ -68,7 +68,7 @@ from ppcls.arch.backbone.model_zoo.van import VAN_tiny
from
ppcls.arch.backbone.variant_models.resnet_variant
import
ResNet50_last_stage_stride1
from
ppcls.arch.backbone.variant_models.resnet_variant
import
ResNet50_last_stage_stride1
from
ppcls.arch.backbone.variant_models.vgg_variant
import
VGG19Sigmoid
from
ppcls.arch.backbone.variant_models.vgg_variant
import
VGG19Sigmoid
from
ppcls.arch.backbone.variant_models.pp_lcnet_variant
import
PPLCNet_x2_5_Tanh
from
ppcls.arch.backbone.variant_models.pp_lcnet_variant
import
PPLCNet_x2_5_Tanh
from
ppcls.arch.backbone.model_zoo.
ir_net
import
IR_18
,
IR_34
,
IR_50
,
IR_101
,
IR_152
,
IR_200
,
IR_SE_50
,
IR_SE_101
,
IR_SE_152
,
IR_SE_200
from
ppcls.arch.backbone.model_zoo.
adaface_ir_net
import
AdaFace_IR_18
,
AdaFace_IR_34
,
AdaFace_IR_50
,
AdaFace_IR_101
,
AdaFace_IR_152
,
AdaFace_IR_SE_50
,
AdaFace_IR_SE_101
,
AdaFace_IR_SE_152
,
AdaFace_
IR_SE_200
# help whl get all the models' api (class type) and components' api (func type)
# help whl get all the models' api (class type) and components' api (func type)
...
...
ppcls/arch/backbone/model_zoo/ir_net.py
→
ppcls/arch/backbone/model_zoo/
adaface_
ir_net.py
浏览文件 @
759fe41c
...
@@ -450,14 +450,14 @@ class Backbone(Layer):
...
@@ -450,14 +450,14 @@ class Backbone(Layer):
return
x
return
x
def
IR_18
(
input_size
=
(
112
,
112
)):
def
AdaFace_
IR_18
(
input_size
=
(
112
,
112
)):
""" Constructs a ir-18 model.
""" Constructs a ir-18 model.
"""
"""
model
=
Backbone
(
input_size
,
18
,
'ir'
)
model
=
Backbone
(
input_size
,
18
,
'ir'
)
return
model
return
model
def
IR_34
(
input_size
=
(
112
,
112
)):
def
AdaFace_
IR_34
(
input_size
=
(
112
,
112
)):
""" Constructs a ir-34 model.
""" Constructs a ir-34 model.
"""
"""
model
=
Backbone
(
input_size
,
34
,
'ir'
)
model
=
Backbone
(
input_size
,
34
,
'ir'
)
...
@@ -465,7 +465,7 @@ def IR_34(input_size=(112, 112)):
...
@@ -465,7 +465,7 @@ def IR_34(input_size=(112, 112)):
return
model
return
model
def
IR_50
(
input_size
=
(
112
,
112
)):
def
AdaFace_
IR_50
(
input_size
=
(
112
,
112
)):
""" Constructs a ir-50 model.
""" Constructs a ir-50 model.
"""
"""
model
=
Backbone
(
input_size
,
50
,
'ir'
)
model
=
Backbone
(
input_size
,
50
,
'ir'
)
...
@@ -473,7 +473,7 @@ def IR_50(input_size=(112, 112)):
...
@@ -473,7 +473,7 @@ def IR_50(input_size=(112, 112)):
return
model
return
model
def
IR_101
(
input_size
=
(
112
,
112
)):
def
AdaFace_
IR_101
(
input_size
=
(
112
,
112
)):
""" Constructs a ir-101 model.
""" Constructs a ir-101 model.
"""
"""
model
=
Backbone
(
input_size
,
100
,
'ir'
)
model
=
Backbone
(
input_size
,
100
,
'ir'
)
...
@@ -481,7 +481,7 @@ def IR_101(input_size=(112, 112)):
...
@@ -481,7 +481,7 @@ def IR_101(input_size=(112, 112)):
return
model
return
model
def
IR_152
(
input_size
=
(
112
,
112
)):
def
AdaFace_
IR_152
(
input_size
=
(
112
,
112
)):
""" Constructs a ir-152 model.
""" Constructs a ir-152 model.
"""
"""
model
=
Backbone
(
input_size
,
152
,
'ir'
)
model
=
Backbone
(
input_size
,
152
,
'ir'
)
...
@@ -489,7 +489,7 @@ def IR_152(input_size=(112, 112)):
...
@@ -489,7 +489,7 @@ def IR_152(input_size=(112, 112)):
return
model
return
model
def
IR_200
(
input_size
=
(
112
,
112
)):
def
AdaFace_
IR_200
(
input_size
=
(
112
,
112
)):
""" Constructs a ir-200 model.
""" Constructs a ir-200 model.
"""
"""
model
=
Backbone
(
input_size
,
200
,
'ir'
)
model
=
Backbone
(
input_size
,
200
,
'ir'
)
...
@@ -497,7 +497,7 @@ def IR_200(input_size=(112, 112)):
...
@@ -497,7 +497,7 @@ def IR_200(input_size=(112, 112)):
return
model
return
model
def
IR_SE_50
(
input_size
=
(
112
,
112
)):
def
AdaFace_
IR_SE_50
(
input_size
=
(
112
,
112
)):
""" Constructs a ir_se-50 model.
""" Constructs a ir_se-50 model.
"""
"""
model
=
Backbone
(
input_size
,
50
,
'ir_se'
)
model
=
Backbone
(
input_size
,
50
,
'ir_se'
)
...
@@ -505,7 +505,7 @@ def IR_SE_50(input_size=(112, 112)):
...
@@ -505,7 +505,7 @@ def IR_SE_50(input_size=(112, 112)):
return
model
return
model
def
IR_SE_101
(
input_size
=
(
112
,
112
)):
def
AdaFace_
IR_SE_101
(
input_size
=
(
112
,
112
)):
""" Constructs a ir_se-101 model.
""" Constructs a ir_se-101 model.
"""
"""
model
=
Backbone
(
input_size
,
100
,
'ir_se'
)
model
=
Backbone
(
input_size
,
100
,
'ir_se'
)
...
@@ -513,7 +513,7 @@ def IR_SE_101(input_size=(112, 112)):
...
@@ -513,7 +513,7 @@ def IR_SE_101(input_size=(112, 112)):
return
model
return
model
def
IR_SE_152
(
input_size
=
(
112
,
112
)):
def
AdaFace_
IR_SE_152
(
input_size
=
(
112
,
112
)):
""" Constructs a ir_se-152 model.
""" Constructs a ir_se-152 model.
"""
"""
model
=
Backbone
(
input_size
,
152
,
'ir_se'
)
model
=
Backbone
(
input_size
,
152
,
'ir_se'
)
...
@@ -521,7 +521,7 @@ def IR_SE_152(input_size=(112, 112)):
...
@@ -521,7 +521,7 @@ def IR_SE_152(input_size=(112, 112)):
return
model
return
model
def
IR_SE_200
(
input_size
=
(
112
,
112
)):
def
AdaFace_
IR_SE_200
(
input_size
=
(
112
,
112
)):
""" Constructs a ir_se-200 model.
""" Constructs a ir_se-200 model.
"""
"""
model
=
Backbone
(
input_size
,
200
,
'ir_se'
)
model
=
Backbone
(
input_size
,
200
,
'ir_se'
)
...
...
ppcls/configs/metric_learning/
ir18_adaface
.yaml
→
ppcls/configs/metric_learning/
adaface_ir18
.yaml
浏览文件 @
759fe41c
...
@@ -21,7 +21,7 @@ Arch:
...
@@ -21,7 +21,7 @@ Arch:
infer_output_key
:
"
features"
infer_output_key
:
"
features"
infer_add_softmax
:
False
infer_add_softmax
:
False
Backbone
:
Backbone
:
name
:
"
IR_18"
name
:
"
AdaFace_
IR_18"
input_size
:
[
112
,
112
]
input_size
:
[
112
,
112
]
Head
:
Head
:
name
:
"
AdaMargin"
name
:
"
AdaMargin"
...
@@ -57,10 +57,21 @@ DataLoader:
...
@@ -57,10 +57,21 @@ DataLoader:
name
:
"
AdaFaceDataset"
name
:
"
AdaFaceDataset"
root_dir
:
"
dataset/face/"
root_dir
:
"
dataset/face/"
label_path
:
"
dataset/face/train_filter_label.txt"
label_path
:
"
dataset/face/train_filter_label.txt"
low_res_augmentation_prob
:
0.2
crop_augmentation_prob
:
0.2
photometric_augmentation_prob
:
0.2
transform
:
transform
:
-
CropWithPadding
:
prob
:
0.2
padding_num
:
0
size
:
[
112
,
112
]
scale
:
[
0.2
,
1.0
]
ratio
:
[
0.75
,
1.3333333333333333
]
-
RandomInterpolationAugment
:
prob
:
0.2
-
ColorJitter
:
prob
:
0.2
brightness
:
0.5
contrast
:
0.5
saturation
:
0.5
hue
:
0
-
RandomHorizontalFlip
:
-
RandomHorizontalFlip
:
-
ToTensor
:
-
ToTensor
:
-
Normalize
:
-
Normalize
:
...
...
ppcls/data/dataloader/face_dataset.py
浏览文件 @
759fe41c
...
@@ -14,40 +14,11 @@ from ppcls.data.preprocess import transform as transform_func
...
@@ -14,40 +14,11 @@ from ppcls.data.preprocess import transform as transform_func
# code is based on AdaFace: https://github.com/mk-minchul/AdaFace
# code is based on AdaFace: https://github.com/mk-minchul/AdaFace
def
_get_image_size
(
img
):
if
F
.
_is_pil_image
(
img
):
return
img
.
size
elif
F
.
_is_numpy_image
(
img
):
return
img
.
shape
[:
2
][::
-
1
]
elif
F
.
_is_tensor_image
(
img
):
return
img
.
shape
[
1
:][::
-
1
]
# chw
else
:
raise
TypeError
(
"Unexpected type {}"
.
format
(
type
(
img
)))
class
AdaFaceDataset
(
Dataset
):
class
AdaFaceDataset
(
Dataset
):
def
__init__
(
def
__init__
(
self
,
root_dir
,
label_path
,
transform
=
None
):
self
,
root_dir
,
label_path
,
transform
=
None
,
low_res_augmentation_prob
=
0.0
,
crop_augmentation_prob
=
0.0
,
photometric_augmentation_prob
=
0.0
,
):
self
.
root_dir
=
root_dir
self
.
root_dir
=
root_dir
self
.
low_res_augmentation_prob
=
low_res_augmentation_prob
self
.
crop_augmentation_prob
=
crop_augmentation_prob
self
.
photometric_augmentation_prob
=
photometric_augmentation_prob
self
.
random_resized_crop
=
transforms
.
RandomResizedCrop
(
size
=
(
112
,
112
),
scale
=
(
0.2
,
1.0
),
ratio
=
(
0.75
,
1.3333333333333333
))
self
.
photometric
=
transforms
.
ColorJitter
(
brightness
=
0.5
,
contrast
=
0.5
,
saturation
=
0.5
,
hue
=
0
)
self
.
transform
=
create_operators
(
transform
)
self
.
transform
=
create_operators
(
transform
)
self
.
tot_rot_try
=
0
self
.
rot_success
=
0
with
open
(
label_path
)
as
fd
:
with
open
(
label_path
)
as
fd
:
lines
=
fd
.
readlines
()
lines
=
fd
.
readlines
()
self
.
samples
=
[]
self
.
samples
=
[]
...
@@ -73,65 +44,11 @@ class AdaFaceDataset(Dataset):
...
@@ -73,65 +44,11 @@ class AdaFaceDataset(Dataset):
# if 'WebFace' in self.root:
# if 'WebFace' in self.root:
# # swap rgb to bgr since image is in rgb for webface
# # swap rgb to bgr since image is in rgb for webface
# sample = Image.fromarray(np.asarray(sample)[:, :, ::-1])
# sample = Image.fromarray(np.asarray(sample)[:, :, ::-1]
sample
,
_
=
self
.
augment
(
sample
)
if
self
.
transform
is
not
None
:
if
self
.
transform
is
not
None
:
sample
=
transform_func
(
sample
,
self
.
transform
)
sample
=
transform_func
(
sample
,
self
.
transform
)
return
sample
,
target
return
sample
,
target
def
augment
(
self
,
sample
):
# crop with zero padding augmentation
if
np
.
random
.
random
()
<
self
.
crop_augmentation_prob
:
# RandomResizedCrop augmentation
new
=
np
.
zeros_like
(
np
.
array
(
sample
))
# orig_W, orig_H = F._get_image_size(sample)
orig_W
,
orig_H
=
_get_image_size
(
sample
)
i
,
j
,
h
,
w
=
self
.
random_resized_crop
.
_get_param
(
sample
)
cropped
=
F
.
crop
(
sample
,
i
,
j
,
h
,
w
)
new
[
i
:
i
+
h
,
j
:
j
+
w
,
:]
=
np
.
array
(
cropped
)
sample
=
Image
.
fromarray
(
new
.
astype
(
np
.
uint8
))
crop_ratio
=
min
(
h
,
w
)
/
max
(
orig_H
,
orig_W
)
else
:
crop_ratio
=
1.0
# low resolution augmentation
if
np
.
random
.
random
()
<
self
.
low_res_augmentation_prob
:
# low res augmentation
img_np
,
resize_ratio
=
low_res_augmentation
(
np
.
array
(
sample
))
sample
=
Image
.
fromarray
(
img_np
.
astype
(
np
.
uint8
))
else
:
resize_ratio
=
1
# photometric augmentation
if
np
.
random
.
random
()
<
self
.
photometric_augmentation_prob
:
sample
=
self
.
photometric
(
sample
)
information_score
=
resize_ratio
*
crop_ratio
return
sample
,
information_score
def
low_res_augmentation
(
img
):
# resize the image to a small size and enlarge it back
img_shape
=
img
.
shape
side_ratio
=
np
.
random
.
uniform
(
0.2
,
1.0
)
small_side
=
int
(
side_ratio
*
img_shape
[
0
])
interpolation
=
np
.
random
.
choice
([
cv2
.
INTER_NEAREST
,
cv2
.
INTER_LINEAR
,
cv2
.
INTER_AREA
,
cv2
.
INTER_CUBIC
,
cv2
.
INTER_LANCZOS4
])
small_img
=
cv2
.
resize
(
img
,
(
small_side
,
small_side
),
interpolation
=
interpolation
)
interpolation
=
np
.
random
.
choice
([
cv2
.
INTER_NEAREST
,
cv2
.
INTER_LINEAR
,
cv2
.
INTER_AREA
,
cv2
.
INTER_CUBIC
,
cv2
.
INTER_LANCZOS4
])
aug_img
=
cv2
.
resize
(
small_img
,
(
img_shape
[
1
],
img_shape
[
0
]),
interpolation
=
interpolation
)
return
aug_img
,
side_ratio
class
FiveValidationDataset
(
Dataset
):
class
FiveValidationDataset
(
Dataset
):
def
__init__
(
self
,
val_data_path
,
concat_mem_file_name
):
def
__init__
(
self
,
val_data_path
,
concat_mem_file_name
):
...
@@ -243,4 +160,4 @@ def get_val_data(data_path):
...
@@ -243,4 +160,4 @@ def get_val_data(data_path):
lfw
,
lfw_issame
=
get_val_pair
(
data_path
,
'lfw'
)
lfw
,
lfw_issame
=
get_val_pair
(
data_path
,
'lfw'
)
cplfw
,
cplfw_issame
=
get_val_pair
(
data_path
,
'cplfw'
)
cplfw
,
cplfw_issame
=
get_val_pair
(
data_path
,
'cplfw'
)
calfw
,
calfw_issame
=
get_val_pair
(
data_path
,
'calfw'
)
calfw
,
calfw_issame
=
get_val_pair
(
data_path
,
'calfw'
)
return
agedb_30
,
cfp_fp
,
lfw
,
agedb_30_issame
,
cfp_fp_issame
,
lfw_issame
,
cplfw
,
cplfw_issame
,
calfw
,
calfw_issame
return
agedb_30
,
cfp_fp
,
lfw
,
agedb_30_issame
,
cfp_fp_issame
,
lfw_issame
,
cplfw
,
cplfw_issame
,
calfw
,
calfw_issame
\ No newline at end of file
ppcls/data/preprocess/__init__.py
浏览文件 @
759fe41c
...
@@ -34,6 +34,9 @@ from ppcls.data.preprocess.ops.operators import Pad
...
@@ -34,6 +34,9 @@ from ppcls.data.preprocess.ops.operators import Pad
from
ppcls.data.preprocess.ops.operators
import
ToTensor
from
ppcls.data.preprocess.ops.operators
import
ToTensor
from
ppcls.data.preprocess.ops.operators
import
Normalize
from
ppcls.data.preprocess.ops.operators
import
Normalize
from
ppcls.data.preprocess.ops.operators
import
RandomHorizontalFlip
from
ppcls.data.preprocess.ops.operators
import
RandomHorizontalFlip
from
ppcls.data.preprocess.ops.operators
import
CropWithPadding
from
ppcls.data.preprocess.ops.operators
import
RandomInterpolationAugment
from
ppcls.data.preprocess.ops.operators
import
ColorJitter
from
ppcls.data.preprocess.batch_ops.batch_operators
import
MixupOperator
,
CutmixOperator
,
OpSampler
,
FmixOperator
from
ppcls.data.preprocess.batch_ops.batch_operators
import
MixupOperator
,
CutmixOperator
,
OpSampler
,
FmixOperator
...
...
ppcls/data/preprocess/ops/operators.py
浏览文件 @
759fe41c
...
@@ -25,8 +25,8 @@ import cv2
...
@@ -25,8 +25,8 @@ import cv2
import
numpy
as
np
import
numpy
as
np
from
PIL
import
Image
,
ImageOps
,
__version__
as
PILLOW_VERSION
from
PIL
import
Image
,
ImageOps
,
__version__
as
PILLOW_VERSION
from
paddle.vision.transforms
import
ColorJitter
as
RawColorJitter
from
paddle.vision.transforms
import
ColorJitter
as
RawColorJitter
from
paddle.vision.transforms
import
ToTensor
,
Normalize
,
RandomHorizontalFlip
from
paddle.vision.transforms
import
ToTensor
,
Normalize
,
RandomHorizontalFlip
,
RandomResizedCrop
from
paddle.vision.transforms
import
functional
as
F
from
.autoaugment
import
ImageNetPolicy
from
.autoaugment
import
ImageNetPolicy
from
.functional
import
augmentations
from
.functional
import
augmentations
from
ppcls.utils
import
logger
from
ppcls.utils
import
logger
...
@@ -93,6 +93,42 @@ class UnifiedResize(object):
...
@@ -93,6 +93,42 @@ class UnifiedResize(object):
return
self
.
resize_func
(
src
,
size
)
return
self
.
resize_func
(
src
,
size
)
class
RandomInterpolationAugment
(
object
):
def
__init__
(
self
,
prob
):
self
.
prob
=
prob
def
_aug
(
self
,
img
):
img_shape
=
img
.
shape
side_ratio
=
np
.
random
.
uniform
(
0.2
,
1.0
)
small_side
=
int
(
side_ratio
*
img_shape
[
0
])
interpolation
=
np
.
random
.
choice
([
cv2
.
INTER_NEAREST
,
cv2
.
INTER_LINEAR
,
cv2
.
INTER_AREA
,
cv2
.
INTER_CUBIC
,
cv2
.
INTER_LANCZOS4
])
small_img
=
cv2
.
resize
(
img
,
(
small_side
,
small_side
),
interpolation
=
interpolation
)
interpolation
=
np
.
random
.
choice
([
cv2
.
INTER_NEAREST
,
cv2
.
INTER_LINEAR
,
cv2
.
INTER_AREA
,
cv2
.
INTER_CUBIC
,
cv2
.
INTER_LANCZOS4
])
aug_img
=
cv2
.
resize
(
small_img
,
(
img_shape
[
1
],
img_shape
[
0
]),
interpolation
=
interpolation
)
return
aug_img
def
__call__
(
self
,
img
):
if
np
.
random
.
random
()
<
self
.
prob
:
if
isinstance
(
img
,
np
.
ndarray
):
return
self
.
_aug
(
img
)
else
:
pil_img
=
np
.
array
(
img
)
aug_img
=
self
.
_aug
(
pil_img
)
img
=
Image
.
fromarray
(
aug_img
.
astype
(
np
.
uint8
))
return
img
else
:
return
img
class
OperatorParamError
(
ValueError
):
class
OperatorParamError
(
ValueError
):
""" OperatorParamError
""" OperatorParamError
"""
"""
...
@@ -170,6 +206,52 @@ class ResizeImage(object):
...
@@ -170,6 +206,52 @@ class ResizeImage(object):
return
self
.
_resize_func
(
img
,
(
w
,
h
))
return
self
.
_resize_func
(
img
,
(
w
,
h
))
class
CropWithPadding
(
RandomResizedCrop
):
"""
crop image and padding to original size
"""
def
__init__
(
self
,
prob
=
1
,
padding_num
=
0
,
size
=
224
,
scale
=
(
0.08
,
1.0
),
ratio
=
(
3.
/
4
,
4.
/
3
),
interpolation
=
'bilinear'
,
key
=
None
):
super
().
__init__
(
size
,
scale
,
ratio
,
interpolation
,
key
)
self
.
prob
=
prob
self
.
padding_num
=
padding_num
def
__call__
(
self
,
img
):
is_cv2_img
=
False
if
isinstance
(
img
,
np
.
ndarray
):
flag
=
True
if
np
.
random
.
random
()
<
self
.
prob
:
# RandomResizedCrop augmentation
new
=
np
.
zeros_like
(
np
.
array
(
img
))
+
self
.
padding_num
# orig_W, orig_H = F._get_image_size(sample)
orig_W
,
orig_H
=
self
.
_get_image_size
(
img
)
i
,
j
,
h
,
w
=
self
.
_get_param
(
img
)
cropped
=
F
.
crop
(
img
,
i
,
j
,
h
,
w
)
new
[
i
:
i
+
h
,
j
:
j
+
w
,
:]
=
np
.
array
(
cropped
)
if
not
isinstance
:
new
=
Image
.
fromarray
(
new
.
astype
(
np
.
uint8
))
return
new
else
:
return
img
def
_get_image_size
(
self
,
img
):
if
F
.
_is_pil_image
(
img
):
return
img
.
size
elif
F
.
_is_numpy_image
(
img
):
return
img
.
shape
[:
2
][::
-
1
]
elif
F
.
_is_tensor_image
(
img
):
return
img
.
shape
[
1
:][::
-
1
]
# chw
else
:
raise
TypeError
(
"Unexpected type {}"
.
format
(
type
(
img
)))
class
CropImage
(
object
):
class
CropImage
(
object
):
""" crop image """
""" crop image """
...
@@ -434,16 +516,18 @@ class ColorJitter(RawColorJitter):
...
@@ -434,16 +516,18 @@ class ColorJitter(RawColorJitter):
"""ColorJitter.
"""ColorJitter.
"""
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
prob
=
2
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
prob
=
prob
def
__call__
(
self
,
img
):
def
__call__
(
self
,
img
):
if
not
isinstance
(
img
,
Image
.
Image
):
if
np
.
random
.
random
()
<
self
.
prob
:
img
=
np
.
ascontiguousarray
(
img
)
if
not
isinstance
(
img
,
Image
.
Image
):
img
=
Image
.
fromarray
(
img
)
img
=
np
.
ascontiguousarray
(
img
)
img
=
super
().
_apply_image
(
img
)
img
=
Image
.
fromarray
(
img
)
if
isinstance
(
img
,
Image
.
Image
):
img
=
super
().
_apply_image
(
img
)
img
=
np
.
asarray
(
img
)
if
isinstance
(
img
,
Image
.
Image
):
img
=
np
.
asarray
(
img
)
return
img
return
img
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录