Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
16f910b4
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 2 年 前同步成功
通知
118
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
16f910b4
编写于
5月 05, 2022
作者:
H
HydrogenSulfate
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add m1 and baseline config
上级
0b148140
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
279 addition
and
343 deletion
+279
-343
ppcls/configs/Pedestrian/strong_baseline_baseline.yaml
ppcls/configs/Pedestrian/strong_baseline_baseline.yaml
+22
-25
ppcls/configs/Pedestrian/strong_baseline_m1.yaml
ppcls/configs/Pedestrian/strong_baseline_m1.yaml
+21
-29
ppcls/configs/Pedestrian/strong_baseline_m1_centerloss.yaml
ppcls/configs/Pedestrian/strong_baseline_m1_centerloss.yaml
+21
-30
ppcls/data/dataloader/person_dataset.py
ppcls/data/dataloader/person_dataset.py
+10
-3
ppcls/data/preprocess/__init__.py
ppcls/data/preprocess/__init__.py
+2
-0
ppcls/data/preprocess/ops/operators.py
ppcls/data/preprocess/ops/operators.py
+115
-11
ppcls/data/preprocess/ops/random_erasing.py
ppcls/data/preprocess/ops/random_erasing.py
+31
-13
ppcls/engine/engine.py
ppcls/engine/engine.py
+0
-13
ppcls/engine/evaluation/retrieval.py
ppcls/engine/evaluation/retrieval.py
+31
-200
ppcls/engine/train/train.py
ppcls/engine/train/train.py
+8
-3
ppcls/loss/centerloss.py
ppcls/loss/centerloss.py
+4
-4
ppcls/loss/triplet.py
ppcls/loss/triplet.py
+3
-3
ppcls/optimizer/learning_rate.py
ppcls/optimizer/learning_rate.py
+11
-9
未找到文件。
ppcls/configs/Pedestrian/strong_baseline_baseline.yaml
浏览文件 @
16f910b4
...
...
@@ -11,8 +11,7 @@ Global:
print_batch_step
:
20
use_visualdl
:
False
eval_mode
:
"
retrieval"
re_ranking
:
False
feat_from
:
"
backbone"
# 'backbone' or 'neck'
retrieval_feature_from
:
"
backbone"
# 'backbone' or 'neck'
# used for static mode and model export
image_shape
:
[
3
,
256
,
128
]
save_inference_dir
:
"
./inference"
...
...
@@ -23,7 +22,7 @@ Arch:
infer_output_key
:
"
features"
infer_add_softmax
:
False
Backbone
:
name
:
"
ResNet50
_last_stage_stride1
"
name
:
"
ResNet50"
pretrained
:
True
stem_act
:
null
BackboneStopLayer
:
...
...
@@ -32,36 +31,30 @@ Arch:
name
:
"
FC"
embedding_size
:
2048
class_num
:
751
weight_attr
:
initializer
:
name
:
Normal
std
:
0.001
bias_attr
:
False
# loss function config for traing/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
epsilon
:
0.1
-
TripletLossV2
:
weight
:
1.0
margin
:
0.3
normalize_feature
:
f
alse
feat_from
:
"
backbone"
normalize_feature
:
F
alse
feat
ure
_from
:
"
backbone"
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
name
:
Adam
lr
:
name
:
Piecewise
decay_epochs
:
[
30
,
6
0
]
decay_epochs
:
[
40
,
7
0
]
values
:
[
0.00035
,
0.000035
,
0.0000035
]
warmup_epoch
:
10
warmup_start_lr
:
0.0000035
by_epoch
:
True
last_epoch
:
0
regularizer
:
name
:
'
L2'
coeff
:
0.0005
...
...
@@ -73,26 +66,26 @@ DataLoader:
name
:
"
Market1501"
image_root
:
"
./dataset/"
cls_label_path
:
"
bounding_box_train"
backend
:
"
pil"
transform_ops
:
-
ResizeImage
:
size
:
[
128
,
256
]
return_numpy
:
False
-
RandFlipImage
:
flip_code
:
1
-
Pad
:
padding
:
10
-
RandCropImage
:
-
RandCropImage
V2
:
size
:
[
128
,
256
]
scale
:
[
0.8022
,
0.8022
]
ratio
:
[
0.5
,
0.5
]
-
NormalizeImage
:
-
ToTensor
:
-
Normalize
:
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedRandomIdentitySampler
batch_size
:
64
num_instances
:
4
drop_last
:
Tru
e
drop_last
:
Fals
e
shuffle
:
True
loader
:
num_workers
:
4
...
...
@@ -103,13 +96,15 @@ DataLoader:
name
:
"
Market1501"
image_root
:
"
./dataset/"
cls_label_path
:
"
query"
backend
:
"
pil"
transform_ops
:
-
ResizeImage
:
size
:
[
128
,
256
]
-
NormalizeImage
:
return_numpy
:
False
-
ToTensor
:
-
Normalize
:
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
128
...
...
@@ -124,13 +119,15 @@ DataLoader:
name
:
"
Market1501"
image_root
:
"
./dataset/"
cls_label_path
:
"
bounding_box_test"
backend
:
"
pil"
transform_ops
:
-
ResizeImage
:
size
:
[
128
,
256
]
-
NormalizeImage
:
return_numpy
:
False
-
ToTensor
:
-
Normalize
:
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
128
...
...
ppcls/configs/Pedestrian/strong_baseline_m1.yaml
浏览文件 @
16f910b4
...
...
@@ -10,10 +10,8 @@ Global:
epochs
:
120
print_batch_step
:
20
use_visualdl
:
False
warmup_by_epoch
:
True
eval_mode
:
"
retrieval"
re_ranking
:
False
feat_from
:
"
neck"
# 'backbone' or 'neck'
retrieval_feature_from
:
"
features"
# 'backbone' or 'features'
# used for static mode and model export
image_shape
:
[
3
,
256
,
128
]
save_inference_dir
:
"
./inference"
...
...
@@ -40,7 +38,7 @@ Arch:
initializer
:
name
:
Constant
value
:
0.0
learning_rate
:
1.0e-20
# NOTE: Temporarily set lr small enough to freeze the bias
learning_rate
:
1.0e-20
# NOTE: Temporarily set lr small enough to freeze the bias
to zero
Head
:
name
:
"
FC"
embedding_size
:
*feat_dim
...
...
@@ -60,8 +58,8 @@ Loss:
-
TripletLossV2
:
weight
:
1.0
margin
:
0.3
normalize_feature
:
f
alse
feat_from
:
"
backbone"
normalize_feature
:
F
alse
feat
ure
_from
:
"
backbone"
Eval
:
-
CELoss
:
weight
:
1.0
...
...
@@ -74,6 +72,8 @@ Optimizer:
values
:
[
0.00035
,
0.000035
,
0.0000035
]
warmup_epoch
:
10
warmup_start_lr
:
0.0000035
by_epoch
:
True
last_epoch
:
0
regularizer
:
name
:
'
L2'
coeff
:
0.0005
...
...
@@ -85,36 +85,32 @@ DataLoader:
name
:
"
Market1501"
image_root
:
"
./dataset/"
cls_label_path
:
"
bounding_box_train"
backend
:
"
pil"
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
[
128
,
256
]
return_numpy
:
False
-
RandFlipImage
:
flip_code
:
1
-
Pad
:
padding
:
10
-
RandCropImage
:
-
RandCropImage
V2
:
size
:
[
128
,
256
]
scale
:
[
0.8022
,
0.8022
]
ratio
:
[
0.5
,
0.5
]
-
NormalizeImage
:
scale
:
0.00392157
-
ToTensor
:
-
Normalize
:
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
RandomErasing
:
EPSILON
:
0.5
sl
:
0.02
sh
:
0.4
r1
:
0.3
mean
:
[
0.4
914
,
0.4822
,
0.4465
]
mean
:
[
0.4
85
,
0.456
,
0.406
]
sampler
:
name
:
DistributedRandomIdentitySampler
batch_size
:
64
num_instances
:
4
drop_last
:
Tru
e
drop_last
:
Fals
e
shuffle
:
True
loader
:
num_workers
:
4
...
...
@@ -125,17 +121,15 @@ DataLoader:
name
:
"
Market1501"
image_root
:
"
./dataset/"
cls_label_path
:
"
query"
backend
:
"
pil"
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
[
128
,
256
]
-
NormalizeImage
:
scale
:
0.00392157
return_numpy
:
False
-
ToTensor
:
-
Normalize
:
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
128
...
...
@@ -150,17 +144,15 @@ DataLoader:
name
:
"
Market1501"
image_root
:
"
./dataset/"
cls_label_path
:
"
bounding_box_test"
backend
:
"
pil"
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
[
128
,
256
]
-
NormalizeImage
:
scale
:
0.00392157
return_numpy
:
False
-
ToTensor
:
-
Normalize
:
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
128
...
...
ppcls/configs/Pedestrian/strong_baseline_m1_centerloss.yaml
浏览文件 @
16f910b4
...
...
@@ -10,10 +10,8 @@ Global:
epochs
:
120
print_batch_step
:
20
use_visualdl
:
False
warmup_by_epoch
:
True
eval_mode
:
"
retrieval"
re_ranking
:
False
feat_from
:
"
neck"
# 'backbone' or 'neck'
retrieval_feature_from
:
"
features"
# 'backbone' or 'features'
# used for static mode and model export
image_shape
:
[
3
,
256
,
128
]
save_inference_dir
:
"
./inference"
...
...
@@ -40,7 +38,7 @@ Arch:
initializer
:
name
:
Constant
value
:
0.0
learning_rate
:
1.0e-20
# NOTE: Temporarily set lr small enough to freeze the bias
learning_rate
:
1.0e-20
# NOTE: Temporarily set lr small enough to freeze the bias
to zero
Head
:
name
:
"
FC"
embedding_size
:
*feat_dim
...
...
@@ -60,8 +58,8 @@ Loss:
-
TripletLossV2
:
weight
:
1.0
margin
:
0.3
normalize_feature
:
f
alse
feat_from
:
"
backbone"
normalize_feature
:
F
alse
feat
ure
_from
:
"
backbone"
-
CenterLoss
:
weight
:
0.0005
num_classes
:
*class_num
...
...
@@ -80,7 +78,8 @@ Optimizer:
values
:
[
0.00035
,
0.000035
,
0.0000035
]
warmup_epoch
:
10
warmup_start_lr
:
0.0000035
warmup_by_epoch
:
True
by_epoch
:
True
last_epoch
:
0
regularizer
:
name
:
'
L2'
coeff
:
0.0005
...
...
@@ -97,36 +96,32 @@ DataLoader:
name
:
"
Market1501"
image_root
:
"
./dataset/"
cls_label_path
:
"
bounding_box_train"
backend
:
"
pil"
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
[
128
,
256
]
return_numpy
:
False
-
RandFlipImage
:
flip_code
:
1
-
Pad
:
padding
:
10
-
RandCropImage
:
-
RandCropImage
V2
:
size
:
[
128
,
256
]
scale
:
[
0.8022
,
0.8022
]
ratio
:
[
0.5
,
0.5
]
-
NormalizeImage
:
scale
:
0.00392157
-
ToTensor
:
-
Normalize
:
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
RandomErasing
:
EPSILON
:
0.5
sl
:
0.02
sh
:
0.4
r1
:
0.3
mean
:
[
0.4
914
,
0.4822
,
0.4465
]
mean
:
[
0.4
85
,
0.456
,
0.406
]
sampler
:
name
:
DistributedRandomIdentitySampler
batch_size
:
64
num_instances
:
4
drop_last
:
Tru
e
drop_last
:
Fals
e
shuffle
:
True
loader
:
num_workers
:
4
...
...
@@ -137,17 +132,15 @@ DataLoader:
name
:
"
Market1501"
image_root
:
"
./dataset/"
cls_label_path
:
"
query"
backend
:
"
pil"
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
[
128
,
256
]
-
NormalizeImage
:
scale
:
0.00392157
return_numpy
:
False
-
ToTensor
:
-
Normalize
:
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
128
...
...
@@ -162,17 +155,15 @@ DataLoader:
name
:
"
Market1501"
image_root
:
"
./dataset/"
cls_label_path
:
"
bounding_box_test"
backend
:
"
pil"
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
[
128
,
256
]
-
NormalizeImage
:
scale
:
0.00392157
return_numpy
:
False
-
ToTensor
:
-
Normalize
:
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
128
...
...
ppcls/data/dataloader/person_dataset.py
浏览文件 @
16f910b4
...
...
@@ -43,7 +43,11 @@ class Market1501(Dataset):
"""
_dataset_dir
=
'market1501/Market-1501-v15.09.15'
def
__init__
(
self
,
image_root
,
cls_label_path
,
transform_ops
=
None
):
def
__init__
(
self
,
image_root
,
cls_label_path
,
transform_ops
=
None
,
backend
=
"cv2"
):
self
.
_img_root
=
image_root
self
.
_cls_path
=
cls_label_path
# the sub folder in the dataset
self
.
_dataset_dir
=
osp
.
join
(
image_root
,
self
.
_dataset_dir
,
...
...
@@ -51,6 +55,7 @@ class Market1501(Dataset):
self
.
_check_before_run
()
if
transform_ops
:
self
.
_transform_ops
=
create_operators
(
transform_ops
)
self
.
backend
=
backend
self
.
_dtype
=
paddle
.
get_default_dtype
()
self
.
_load_anno
(
relabel
=
True
if
'train'
in
self
.
_cls_path
else
False
)
...
...
@@ -92,10 +97,12 @@ class Market1501(Dataset):
def
__getitem__
(
self
,
idx
):
try
:
img
=
Image
.
open
(
self
.
images
[
idx
]).
convert
(
'RGB'
)
img
=
np
.
array
(
img
,
dtype
=
"float32"
).
astype
(
np
.
uint8
)
if
self
.
backend
==
"cv2"
:
img
=
np
.
array
(
img
,
dtype
=
"float32"
).
astype
(
np
.
uint8
)
if
self
.
_transform_ops
:
img
=
transform
(
img
,
self
.
_transform_ops
)
img
=
img
.
transpose
((
2
,
0
,
1
))
if
self
.
backend
==
"cv2"
:
img
=
img
.
transpose
((
2
,
0
,
1
))
return
(
img
,
self
.
labels
[
idx
],
self
.
cameras
[
idx
])
except
Exception
as
ex
:
logger
.
error
(
"Exception occured when parse line: {} with msg: {}"
.
...
...
ppcls/data/preprocess/__init__.py
浏览文件 @
16f910b4
...
...
@@ -30,6 +30,8 @@ from ppcls.data.preprocess.ops.operators import NormalizeImage
from
ppcls.data.preprocess.ops.operators
import
ToCHWImage
from
ppcls.data.preprocess.ops.operators
import
AugMix
from
ppcls.data.preprocess.ops.operators
import
Pad
from
ppcls.data.preprocess.ops.operators
import
ToTensor
from
ppcls.data.preprocess.ops.operators
import
Normalize
from
ppcls.data.preprocess.batch_ops.batch_operators
import
MixupOperator
,
CutmixOperator
,
OpSampler
,
FmixOperator
...
...
ppcls/data/preprocess/ops/operators.py
浏览文件 @
16f910b4
...
...
@@ -22,10 +22,11 @@ import six
import
math
import
random
import
cv2
from
typing
import
Sequence
import
numpy
as
np
from
PIL
import
Image
from
PIL
import
Image
,
ImageOps
,
__version__
as
PILLOW_VERSION
from
paddle.vision.transforms
import
ColorJitter
as
RawColorJitter
from
paddle.vision.transforms
import
Pad
from
paddle.vision.transforms
import
ToTensor
,
Normalize
from
.autoaugment
import
ImageNetPolicy
from
.functional
import
augmentations
...
...
@@ -33,7 +34,7 @@ from ppcls.utils import logger
class
UnifiedResize
(
object
):
def
__init__
(
self
,
interpolation
=
None
,
backend
=
"cv2"
):
def
__init__
(
self
,
interpolation
=
None
,
backend
=
"cv2"
,
return_numpy
=
True
):
_cv2_interp_from_str
=
{
'nearest'
:
cv2
.
INTER_NEAREST
,
'bilinear'
:
cv2
.
INTER_LINEAR
,
...
...
@@ -57,12 +58,15 @@ class UnifiedResize(object):
resample
=
random
.
choice
(
resample
)
return
cv2
.
resize
(
src
,
size
,
interpolation
=
resample
)
def
_pil_resize
(
src
,
size
,
resample
):
def
_pil_resize
(
src
,
size
,
resample
,
return_numpy
=
True
):
if
isinstance
(
resample
,
tuple
):
resample
=
random
.
choice
(
resample
)
pil_img
=
Image
.
fromarray
(
src
)
if
isinstance
(
src
,
np
.
ndarray
):
pil_img
=
Image
.
fromarray
(
src
)
pil_img
=
pil_img
.
resize
(
size
,
resample
)
return
np
.
asarray
(
pil_img
)
if
return_numpy
:
return
np
.
asarray
(
pil_img
)
return
pil_img
if
backend
.
lower
()
==
"cv2"
:
if
isinstance
(
interpolation
,
str
):
...
...
@@ -74,7 +78,8 @@ class UnifiedResize(object):
elif
backend
.
lower
()
==
"pil"
:
if
isinstance
(
interpolation
,
str
):
interpolation
=
_pil_interp_from_str
[
interpolation
.
lower
()]
self
.
resize_func
=
partial
(
_pil_resize
,
resample
=
interpolation
)
self
.
resize_func
=
partial
(
_pil_resize
,
resample
=
interpolation
,
return_numpy
=
return_numpy
)
else
:
logger
.
warning
(
f
"The backend of Resize only support
\"
cv2
\"
or
\"
PIL
\"
.
\"
f
{
backend
}
\"
is unavailable. Use
\"
cv2
\"
instead."
...
...
@@ -129,7 +134,8 @@ class ResizeImage(object):
size
=
None
,
resize_short
=
None
,
interpolation
=
None
,
backend
=
"cv2"
):
backend
=
"cv2"
,
return_numpy
=
True
):
if
resize_short
is
not
None
and
resize_short
>
0
:
self
.
resize_short
=
resize_short
self
.
w
=
None
...
...
@@ -143,10 +149,16 @@ class ResizeImage(object):
'both 'size' and 'resize_short' are None"
)
self
.
_resize_func
=
UnifiedResize
(
interpolation
=
interpolation
,
backend
=
backend
)
interpolation
=
interpolation
,
backend
=
backend
,
return_numpy
=
return_numpy
)
def
__call__
(
self
,
img
):
img_h
,
img_w
=
img
.
shape
[:
2
]
if
isinstance
(
img
,
np
.
ndarray
):
img_h
,
img_w
=
img
.
shape
[:
2
]
else
:
img_w
,
img_h
=
img
.
size
if
self
.
resize_short
is
not
None
:
percent
=
float
(
self
.
resize_short
)
/
min
(
img_w
,
img_h
)
w
=
int
(
round
(
img_w
*
percent
))
...
...
@@ -226,6 +238,40 @@ class RandCropImage(object):
return
self
.
_resize_func
(
img
,
size
)
class
RandCropImageV2
(
object
):
""" RandCropImageV2 is different from RandCropImage,
it will Select a cutting position randomly in a uniform distribution way,
and cut according to the given size without resize at last."""
def
__init__
(
self
,
size
):
if
type
(
size
)
is
int
:
self
.
size
=
(
size
,
size
)
# (h, w)
else
:
self
.
size
=
size
def
__call__
(
self
,
img
):
if
isinstance
(
img
,
np
.
ndarray
):
img_h
,
img_w
=
img
.
shap
[
0
],
img
.
shap
[
1
]
else
:
img_w
,
img_h
=
img
.
size
tw
,
th
=
self
.
size
if
img_h
+
1
<
th
or
img_w
+
1
<
tw
:
raise
ValueError
(
"Required crop size {} is larger then input image size {}"
.
format
((
th
,
tw
),
(
img_h
,
img_w
)))
if
img_w
==
tw
and
img_h
==
th
:
return
img
top
=
random
.
randint
(
0
,
img_h
-
th
+
1
)
left
=
random
.
randint
(
0
,
img_w
-
tw
+
1
)
if
isinstance
(
img
,
np
.
ndarray
):
return
img
[
top
:
top
+
th
,
left
:
left
+
tw
,
:]
else
:
return
img
.
crop
((
left
,
top
,
left
+
tw
,
top
+
th
))
class
RandFlipImage
(
object
):
""" random flip image
flip_code:
...
...
@@ -241,7 +287,10 @@ class RandFlipImage(object):
def
__call__
(
self
,
img
):
if
random
.
randint
(
0
,
1
)
==
1
:
return
cv2
.
flip
(
img
,
self
.
flip_code
)
if
isinstance
(
img
,
np
.
ndarray
):
return
cv2
.
flip
(
img
,
self
.
flip_code
)
else
:
return
img
.
transpose
(
Image
.
FLIP_LEFT_RIGHT
)
else
:
return
img
...
...
@@ -395,3 +444,58 @@ class ColorJitter(RawColorJitter):
if
isinstance
(
img
,
Image
.
Image
):
img
=
np
.
asarray
(
img
)
return
img
class
Pad
(
object
):
"""
Pads the given PIL.Image on all sides with specified padding mode and fill value.
adapted from: https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#Pad
"""
def
__init__
(
self
,
padding
:
int
,
fill
:
int
=
0
,
padding_mode
:
str
=
"constant"
):
self
.
padding
=
padding
self
.
fill
=
fill
self
.
padding_mode
=
padding_mode
def
_parse_fill
(
self
,
fill
,
img
,
min_pil_version
,
name
=
"fillcolor"
):
# Process fill color for affine transforms
major_found
,
minor_found
=
(
int
(
v
)
for
v
in
PILLOW_VERSION
.
split
(
'.'
)[:
2
])
major_required
,
minor_required
=
(
int
(
v
)
for
v
in
min_pil_version
.
split
(
'.'
)[:
2
])
if
major_found
<
major_required
or
(
major_found
==
major_required
and
minor_found
<
minor_required
):
if
fill
is
None
:
return
{}
else
:
msg
=
(
"The option to fill background area of the transformed image, "
"requires pillow>={}"
)
raise
RuntimeError
(
msg
.
format
(
min_pil_version
))
num_bands
=
len
(
img
.
getbands
())
if
fill
is
None
:
fill
=
0
if
isinstance
(
fill
,
(
int
,
float
))
and
num_bands
>
1
:
fill
=
tuple
([
fill
]
*
num_bands
)
if
isinstance
(
fill
,
(
list
,
tuple
)):
if
len
(
fill
)
!=
num_bands
:
msg
=
(
"The number of elements in 'fill' does not match the number of "
"bands of the image ({} != {})"
)
raise
ValueError
(
msg
.
format
(
len
(
fill
),
num_bands
))
fill
=
tuple
(
fill
)
return
{
name
:
fill
}
def
__call__
(
self
,
img
):
opts
=
self
.
_parse_fill
(
self
.
fill
,
img
,
"2.3.0"
,
name
=
"fill"
)
if
img
.
mode
==
"P"
:
palette
=
img
.
getpalette
()
img
=
ImageOps
.
expand
(
img
,
border
=
self
.
padding
,
**
opts
)
img
.
putpalette
(
palette
)
return
img
return
ImageOps
.
expand
(
img
,
border
=
self
.
padding
,
**
opts
)
ppcls/data/preprocess/ops/random_erasing.py
浏览文件 @
16f910b4
...
...
@@ -25,15 +25,21 @@ import numpy as np
class
Pixels
(
object
):
def
__init__
(
self
,
mode
=
"const"
,
mean
=
[
0.
,
0.
,
0.
]):
self
.
_mode
=
mode
self
.
_mean
=
mean
self
.
_mean
=
np
.
array
(
mean
)
def
__call__
(
self
,
h
=
224
,
w
=
224
,
c
=
3
):
def
__call__
(
self
,
h
=
224
,
w
=
224
,
c
=
3
,
channel_first
=
False
):
if
self
.
_mode
==
"rand"
:
return
np
.
random
.
normal
(
size
=
(
1
,
1
,
3
))
return
np
.
random
.
normal
(
size
=
(
1
,
1
,
3
))
if
not
channel_first
else
np
.
random
.
normal
(
size
=
(
3
,
1
,
1
))
elif
self
.
_mode
==
"pixel"
:
return
np
.
random
.
normal
(
size
=
(
h
,
w
,
c
))
return
np
.
random
.
normal
(
size
=
(
h
,
w
,
c
))
if
not
channel_first
else
np
.
random
.
normal
(
size
=
(
c
,
h
,
w
))
elif
self
.
_mode
==
"const"
:
return
self
.
_mean
return
np
.
reshape
(
self
.
_mean
,
(
1
,
1
,
c
))
if
not
channel_first
else
np
.
reshape
(
self
.
_mean
,
(
c
,
1
,
1
))
else
:
raise
Exception
(
"Invalid mode in RandomErasing, only support
\"
const
\"
,
\"
rand
\"
,
\"
pixel
\"
"
...
...
@@ -68,7 +74,13 @@ class RandomErasing(object):
return
img
for
_
in
range
(
self
.
attempt
):
area
=
img
.
shape
[
0
]
*
img
.
shape
[
1
]
if
isinstance
(
img
,
np
.
ndarray
):
img_h
,
img_w
,
img_c
=
img
.
shape
channel_first
=
False
else
:
img_c
,
img_h
,
img_w
=
img
.
shape
channel_first
=
True
area
=
img_h
*
img_w
target_area
=
random
.
uniform
(
self
.
sl
,
self
.
sh
)
*
area
aspect_ratio
=
random
.
uniform
(
*
self
.
r1
)
...
...
@@ -78,13 +90,19 @@ class RandomErasing(object):
h
=
int
(
round
(
math
.
sqrt
(
target_area
*
aspect_ratio
)))
w
=
int
(
round
(
math
.
sqrt
(
target_area
/
aspect_ratio
)))
if
w
<
img
.
shape
[
1
]
and
h
<
img
.
shape
[
0
]:
pixels
=
self
.
get_pixels
(
h
,
w
,
img
.
shape
[
2
])
x1
=
random
.
randint
(
0
,
img
.
shape
[
0
]
-
h
)
y1
=
random
.
randint
(
0
,
img
.
shape
[
1
]
-
w
)
if
img
.
shape
[
2
]
==
3
:
img
[
x1
:
x1
+
h
,
y1
:
y1
+
w
,
:]
=
pixels
if
w
<
img_w
and
h
<
img_h
:
pixels
=
self
.
get_pixels
(
h
,
w
,
img_c
,
channel_first
)
x1
=
random
.
randint
(
0
,
img_h
-
h
)
y1
=
random
.
randint
(
0
,
img_w
-
w
)
if
img_c
==
3
:
if
channel_first
:
img
[:,
x1
:
x1
+
h
,
y1
:
y1
+
w
]
=
pixels
else
:
img
[
x1
:
x1
+
h
,
y1
:
y1
+
w
,
:]
=
pixels
else
:
img
[
x1
:
x1
+
h
,
y1
:
y1
+
w
,
0
]
=
pixels
[
0
]
if
channel_first
:
img
[
0
,
x1
:
x1
+
h
,
y1
:
y1
+
w
]
=
pixels
[
0
]
else
:
img
[
x1
:
x1
+
h
,
y1
:
y1
+
w
,
0
]
=
pixels
[:,
:,
0
]
return
img
return
img
ppcls/engine/engine.py
浏览文件 @
16f910b4
...
...
@@ -304,25 +304,12 @@ class Engine(object):
self
.
max_iter
=
len
(
self
.
train_dataloader
)
-
1
if
platform
.
system
(
)
==
"Windows"
else
len
(
self
.
train_dataloader
)
# step lr once before first epoch when when Global.warmup_by_epoch=True
if
self
.
config
[
"Global"
].
get
(
"warmup_by_epoch"
,
False
):
for
i
in
range
(
len
(
self
.
lr_sch
)):
self
.
lr_sch
[
i
].
step
()
logger
.
info
(
"lr_sch step once before the first epoch, when Global.warmup_by_epoch=True"
)
for
epoch_id
in
range
(
best_metric
[
"epoch"
]
+
1
,
self
.
config
[
"Global"
][
"epochs"
]
+
1
):
acc
=
0.0
# for one epoch train
self
.
train_epoch_func
(
self
,
epoch_id
,
print_batch_step
)
# lr step when Global.warmup_by_epoch=True
if
self
.
config
[
"Global"
].
get
(
"warmup_by_epoch"
,
False
):
for
i
in
range
(
len
(
self
.
lr_sch
)):
self
.
lr_sch
[
i
].
step
()
if
self
.
use_dali
:
self
.
train_dataloader
.
reset
()
metric_msg
=
", "
.
join
([
...
...
ppcls/engine/evaluation/retrieval.py
浏览文件 @
16f910b4
...
...
@@ -16,8 +16,6 @@ from __future__ import division
from
__future__
import
print_function
import
platform
import
numpy
as
np
import
paddle
from
ppcls.utils
import
logger
...
...
@@ -51,48 +49,33 @@ def retrieval_eval(engine, epoch_id=0):
metric_dict
=
{
metric_key
:
0.
}
else
:
metric_dict
=
dict
()
reranking_flag
=
engine
.
config
[
'Global'
].
get
(
're_ranking'
,
False
)
logger
.
info
(
f
"re_ranking=
{
reranking_flag
}
"
)
if
not
reranking_flag
:
for
block_idx
,
block_fea
in
enumerate
(
fea_blocks
):
similarity_matrix
=
paddle
.
matmul
(
block_fea
,
gallery_feas
,
transpose_y
=
True
)
if
query_query_id
is
not
None
:
query_id_block
=
query_id_blocks
[
block_idx
]
query_id_mask
=
(
query_id_block
!=
gallery_unique_id
.
t
())
image_id_block
=
image_id_blocks
[
block_idx
]
image_id_mask
=
(
image_id_block
!=
gallery_img_id
.
t
())
keep_mask
=
paddle
.
logical_or
(
query_id_mask
,
image_id_mask
)
similarity_matrix
=
similarity_matrix
*
keep_mask
.
astype
(
"float32"
)
for
block_idx
,
block_fea
in
enumerate
(
fea_blocks
):
similarity_matrix
=
paddle
.
matmul
(
block_fea
,
gallery_feas
,
transpose_y
=
True
)
if
query_query_id
is
not
None
:
query_id_block
=
query_id_blocks
[
block_idx
]
query_id_mask
=
(
query_id_block
!=
gallery_unique_id
.
t
())
image_id_block
=
image_id_blocks
[
block_idx
]
image_id_mask
=
(
image_id_block
!=
gallery_img_id
.
t
())
keep_mask
=
paddle
.
logical_or
(
query_id_mask
,
image_id_mask
)
similarity_matrix
=
similarity_matrix
*
keep_mask
.
astype
(
"float32"
)
else
:
keep_mask
=
None
metric_tmp
=
engine
.
eval_metric_func
(
similarity_matrix
,
image_id_blocks
[
block_idx
],
gallery_img_id
,
keep_mask
)
for
key
in
metric_tmp
:
if
key
not
in
metric_dict
:
metric_dict
[
key
]
=
metric_tmp
[
key
]
*
block_fea
.
shape
[
0
]
/
len
(
query_feas
)
else
:
keep_mask
=
None
metric_tmp
=
engine
.
eval_metric_func
(
similarity_matrix
,
image_id_blocks
[
block_idx
],
gallery_img_id
,
keep_mask
)
for
key
in
metric_tmp
:
if
key
not
in
metric_dict
:
metric_dict
[
key
]
=
metric_tmp
[
key
]
*
block_fea
.
shape
[
0
]
/
len
(
query_feas
)
else
:
metric_dict
[
key
]
+=
metric_tmp
[
key
]
*
block_fea
.
shape
[
0
]
/
len
(
query_feas
)
else
:
metric_dict
=
dict
()
distmat
=
re_ranking
(
query_feas
,
gallery_feas
,
k1
=
20
,
k2
=
6
,
lambda_value
=
0.3
)
cmc
,
mAP
=
eval_func
(
distmat
,
np
.
squeeze
(
query_img_id
.
numpy
()),
np
.
squeeze
(
gallery_img_id
.
numpy
()),
np
.
squeeze
(
query_query_id
.
numpy
()),
np
.
squeeze
(
gallery_unique_id
.
numpy
()))
metric_dict
[
"recall1(RK)"
]
=
cmc
[
0
]
metric_dict
[
"recall5(RK)"
]
=
cmc
[
4
]
metric_dict
[
"mAP(RK)"
]
=
mAP
metric_dict
[
key
]
+=
metric_tmp
[
key
]
*
block_fea
.
shape
[
0
]
/
len
(
query_feas
)
metric_info_list
=
[]
for
key
in
metric_dict
:
...
...
@@ -105,159 +88,6 @@ def retrieval_eval(engine, epoch_id=0):
return
metric_dict
[
metric_key
]
def
re_ranking
(
queFea
,
galFea
,
k1
=
20
,
k2
=
6
,
lambda_value
=
0.5
,
local_distmat
=
None
,
only_local
=
False
):
# if feature vector is numpy, you should use 'paddle.tensor' transform it to tensor
query_num
=
queFea
.
shape
[
0
]
all_num
=
query_num
+
galFea
.
shape
[
0
]
if
only_local
:
original_dist
=
local_distmat
else
:
feat
=
paddle
.
concat
([
queFea
,
galFea
])
logger
.
info
(
'using GPU to compute original distance'
)
# L2 distance
distmat
=
paddle
.
pow
(
feat
,
2
).
sum
(
axis
=
1
,
keepdim
=
True
).
expand
([
all_num
,
all_num
])
+
\
paddle
.
pow
(
feat
,
2
).
sum
(
axis
=
1
,
keepdim
=
True
).
expand
([
all_num
,
all_num
]).
t
()
distmat
=
distmat
.
addmm
(
x
=
feat
,
y
=
feat
.
t
(),
alpha
=-
2.0
,
beta
=
1.0
)
# Cosine distance
# distmat = paddle.matmul(queFea, galFea, transpose_y=True)
# if query_query_id is not None:
# query_id_mask = (queCid != galCid.t())
# image_id_mask = (queId != galId.t())
# keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
# distmat = distmat * keep_mask.astype("float32")
original_dist
=
distmat
.
cpu
().
numpy
()
del
feat
if
local_distmat
is
not
None
:
original_dist
=
original_dist
+
local_distmat
gallery_num
=
original_dist
.
shape
[
0
]
original_dist
=
np
.
transpose
(
original_dist
/
np
.
max
(
original_dist
,
axis
=
0
))
V
=
np
.
zeros_like
(
original_dist
).
astype
(
np
.
float16
)
initial_rank
=
np
.
argsort
(
original_dist
).
astype
(
np
.
int32
)
logger
.
info
(
'starting re_ranking'
)
for
i
in
range
(
all_num
):
# k-reciprocal neighbors
forward_k_neigh_index
=
initial_rank
[
i
,
:
k1
+
1
]
backward_k_neigh_index
=
initial_rank
[
forward_k_neigh_index
,
:
k1
+
1
]
fi
=
np
.
where
(
backward_k_neigh_index
==
i
)[
0
]
k_reciprocal_index
=
forward_k_neigh_index
[
fi
]
k_reciprocal_expansion_index
=
k_reciprocal_index
for
j
in
range
(
len
(
k_reciprocal_index
)):
candidate
=
k_reciprocal_index
[
j
]
candidate_forward_k_neigh_index
=
initial_rank
[
candidate
,
:
int
(
np
.
around
(
k1
/
2
))
+
1
]
candidate_backward_k_neigh_index
=
initial_rank
[
candidate_forward_k_neigh_index
,
:
int
(
np
.
around
(
k1
/
2
))
+
1
]
fi_candidate
=
np
.
where
(
candidate_backward_k_neigh_index
==
candidate
)[
0
]
candidate_k_reciprocal_index
=
candidate_forward_k_neigh_index
[
fi_candidate
]
if
len
(
np
.
intersect1d
(
candidate_k_reciprocal_index
,
k_reciprocal_index
))
>
2
/
3
*
len
(
candidate_k_reciprocal_index
):
k_reciprocal_expansion_index
=
np
.
append
(
k_reciprocal_expansion_index
,
candidate_k_reciprocal_index
)
k_reciprocal_expansion_index
=
np
.
unique
(
k_reciprocal_expansion_index
)
weight
=
np
.
exp
(
-
original_dist
[
i
,
k_reciprocal_expansion_index
])
V
[
i
,
k_reciprocal_expansion_index
]
=
weight
/
np
.
sum
(
weight
)
original_dist
=
original_dist
[:
query_num
,
]
if
k2
!=
1
:
V_qe
=
np
.
zeros_like
(
V
,
dtype
=
np
.
float16
)
for
i
in
range
(
all_num
):
V_qe
[
i
,
:]
=
np
.
mean
(
V
[
initial_rank
[
i
,
:
k2
],
:],
axis
=
0
)
V
=
V_qe
del
V_qe
del
initial_rank
invIndex
=
[]
for
i
in
range
(
gallery_num
):
invIndex
.
append
(
np
.
where
(
V
[:,
i
]
!=
0
)[
0
])
jaccard_dist
=
np
.
zeros_like
(
original_dist
,
dtype
=
np
.
float16
)
for
i
in
range
(
query_num
):
temp_min
=
np
.
zeros
(
shape
=
[
1
,
gallery_num
],
dtype
=
np
.
float16
)
indNonZero
=
np
.
where
(
V
[
i
,
:]
!=
0
)[
0
]
indImages
=
[
invIndex
[
ind
]
for
ind
in
indNonZero
]
for
j
in
range
(
len
(
indNonZero
)):
temp_min
[
0
,
indImages
[
j
]]
=
temp_min
[
0
,
indImages
[
j
]]
+
np
.
minimum
(
V
[
i
,
indNonZero
[
j
]],
V
[
indImages
[
j
],
indNonZero
[
j
]])
jaccard_dist
[
i
]
=
1
-
temp_min
/
(
2
-
temp_min
)
final_dist
=
jaccard_dist
*
(
1
-
lambda_value
)
+
original_dist
*
lambda_value
del
original_dist
del
V
del
jaccard_dist
final_dist
=
final_dist
[:
query_num
,
query_num
:]
return
final_dist
def
eval_func
(
distmat
,
q_pids
,
g_pids
,
q_camids
,
g_camids
,
max_rank
=
50
):
"""Evaluation with market1501 metric
Key: for each query identity, its gallery images from the same camera view are discarded.
"""
num_q
,
num_g
=
distmat
.
shape
if
num_g
<
max_rank
:
max_rank
=
num_g
print
(
"Note: number of gallery samples is quite small, got {}"
.
format
(
num_g
))
indices
=
np
.
argsort
(
distmat
,
axis
=
1
)
matches
=
(
g_pids
[
indices
]
==
q_pids
[:,
np
.
newaxis
]).
astype
(
np
.
int32
)
# compute cmc curve for each query
all_cmc
=
[]
all_AP
=
[]
num_valid_q
=
0.
# number of valid query
for
q_idx
in
range
(
num_q
):
# get query pid and camid
q_pid
=
q_pids
[
q_idx
]
q_camid
=
q_camids
[
q_idx
]
# remove gallery samples that have the same pid and camid with query
order
=
indices
[
q_idx
]
remove
=
(
g_pids
[
order
]
==
q_pid
)
&
(
g_camids
[
order
]
==
q_camid
)
keep
=
np
.
invert
(
remove
)
# compute cmc curve
# binary vector, positions with value 1 are correct matches
orig_cmc
=
matches
[
q_idx
][
keep
]
if
not
np
.
any
(
orig_cmc
):
# this condition is true when query identity does not appear in gallery
continue
cmc
=
orig_cmc
.
cumsum
()
cmc
[
cmc
>
1
]
=
1
all_cmc
.
append
(
cmc
[:
max_rank
])
num_valid_q
+=
1.
# compute average precision
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
num_rel
=
orig_cmc
.
sum
()
tmp_cmc
=
orig_cmc
.
cumsum
()
tmp_cmc
=
[
x
/
(
i
+
1.
)
for
i
,
x
in
enumerate
(
tmp_cmc
)]
tmp_cmc
=
np
.
asarray
(
tmp_cmc
)
*
orig_cmc
AP
=
tmp_cmc
.
sum
()
/
num_rel
all_AP
.
append
(
AP
)
assert
num_valid_q
>
0
,
"Error: all query identities do not appear in gallery"
all_cmc
=
np
.
asarray
(
all_cmc
).
astype
(
np
.
float32
)
all_cmc
=
all_cmc
.
sum
(
0
)
/
num_valid_q
mAP
=
np
.
mean
(
all_AP
)
return
all_cmc
,
mAP
def
cal_feature
(
engine
,
name
=
'gallery'
):
has_unique_id
=
False
all_unique_id
=
None
...
...
@@ -298,12 +128,13 @@ def cal_feature(engine, name='gallery'):
out
=
out
[
"Student"
]
# get features
if
engine
.
config
[
"Global"
].
get
(
"feat_from"
,
'backbone'
)
==
'backbone'
:
if
engine
.
config
[
"Global"
].
get
(
"retrieval_feature_from"
,
"features"
)
==
"features"
:
# use neck's output as features
batch_feas
=
out
[
"features"
]
else
:
# use backbone's output as features
batch_feas
=
out
[
"backbone"
]
else
:
# use neck's output as features
batch_feas
=
out
[
"neck"
]
# do norm
if
engine
.
config
[
"Global"
].
get
(
"feature_normalize"
,
True
):
...
...
ppcls/engine/train/train.py
浏览文件 @
16f910b4
...
...
@@ -68,9 +68,9 @@ def train_epoch(engine, epoch_id, print_batch_step):
for
i
in
range
(
len
(
engine
.
optimizer
)):
engine
.
optimizer
[
i
].
clear_grad
()
# step lr
if
engine
.
config
[
"Global"
].
get
(
"warmup_by_epoch"
,
False
)
is
False
:
for
i
in
range
(
len
(
engine
.
lr_sch
)
):
# step lr
(by step)
for
i
in
range
(
len
(
engine
.
lr_sch
))
:
if
not
getattr
(
engine
.
lr_sch
[
i
],
"by_epoch"
,
False
):
engine
.
lr_sch
[
i
].
step
()
# below code just for logging
...
...
@@ -83,6 +83,11 @@ def train_epoch(engine, epoch_id, print_batch_step):
log_info
(
engine
,
batch_size
,
epoch_id
,
iter_id
)
tic
=
time
.
time
()
# step lr(by epoch)
for
i
in
range
(
len
(
engine
.
lr_sch
)):
if
getattr
(
engine
.
lr_sch
[
i
],
"by_epoch"
,
False
):
engine
.
lr_sch
[
i
].
step
()
def
forward
(
engine
,
batch
):
if
not
engine
.
is_rec
:
...
...
ppcls/loss/centerloss.py
浏览文件 @
16f910b4
...
...
@@ -28,17 +28,17 @@ class CenterLoss(nn.Layer):
Args:
num_classes (int): number of classes.
feat_dim (int): number of feature dimensions.
feat
_from (str): features from backbone or neck
feat
ure_from (str): feature from "backbone" or "features"
"""
def
__init__
(
self
,
num_classes
:
int
,
feat_dim
:
int
,
feat
_from
:
str
=
'backbone'
):
feat
ure_from
:
str
=
"features"
):
super
(
CenterLoss
,
self
).
__init__
()
self
.
num_classes
=
num_classes
self
.
feat_dim
=
feat_dim
self
.
feat
_from
=
feat
_from
self
.
feat
ure_from
=
feature
_from
random_init_centers
=
paddle
.
randn
(
shape
=
[
self
.
num_classes
,
self
.
feat_dim
])
self
.
centers
=
self
.
create_parameter
(
...
...
@@ -57,7 +57,7 @@ class CenterLoss(nn.Layer):
Returns:
Dict[str, paddle.Tensor]: {'CenterLoss': loss}.
"""
feats
=
input
[
self
.
feat_from
]
feats
=
input
[
self
.
feat
ure
_from
]
labels
=
target
# squeeze labels to shape (batch_size, )
...
...
ppcls/loss/triplet.py
浏览文件 @
16f910b4
...
...
@@ -31,10 +31,10 @@ class TripletLossV2(nn.Layer):
def
__init__
(
self
,
margin
=
0.5
,
normalize_feature
=
True
,
feat
_from
=
'backbone'
):
feat
ure_from
=
"features"
):
super
(
TripletLossV2
,
self
).
__init__
()
self
.
margin
=
margin
self
.
feat
_from
=
feat
_from
self
.
feat
ure_from
=
feature
_from
self
.
ranking_loss
=
paddle
.
nn
.
loss
.
MarginRankingLoss
(
margin
=
margin
)
self
.
normalize_feature
=
normalize_feature
...
...
@@ -44,7 +44,7 @@ class TripletLossV2(nn.Layer):
inputs: feature matrix with shape (batch_size, feat_dim)
target: ground truth labels with shape (num_classes)
"""
inputs
=
input
[
self
.
feat_from
]
inputs
=
input
[
self
.
feat
ure
_from
]
if
self
.
normalize_feature
:
inputs
=
1.
*
inputs
/
(
paddle
.
expand_as
(
...
...
ppcls/optimizer/learning_rate.py
浏览文件 @
16f910b4
...
...
@@ -205,6 +205,7 @@ class Piecewise(object):
The type of element in the list is python float.
warmup_epoch(int): The epoch numbers for LinearWarmup. Default: 0.
warmup_start_lr(float): Initial learning rate of warm up. Default: 0.0.
by_epoch(bool): Whether lr decay by epoch. Default: False.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
...
...
@@ -215,7 +216,7 @@ class Piecewise(object):
epochs
,
warmup_epoch
=
0
,
warmup_start_lr
=
0.0
,
warmup_
by_epoch
=
False
,
by_epoch
=
False
,
last_epoch
=-
1
,
**
kwargs
):
super
().
__init__
()
...
...
@@ -230,33 +231,34 @@ class Piecewise(object):
self
.
warmup_steps
=
round
(
warmup_epoch
*
step_each_epoch
)
self
.
warmup_epoch
=
warmup_epoch
self
.
warmup_start_lr
=
warmup_start_lr
self
.
warmup_by_epoch
=
warmup_
by_epoch
self
.
by_epoch
=
by_epoch
def
__call__
(
self
):
if
self
.
warmup_by_epoch
is
False
:
if
self
.
by_epoch
:
learning_rate
=
lr
.
PiecewiseDecay
(
boundaries
=
self
.
boundaries_
steps
,
boundaries
=
self
.
boundaries_
epoch
,
values
=
self
.
values
,
last_epoch
=
self
.
last_epoch
)
if
self
.
warmup_
steps
>
0
:
if
self
.
warmup_
epoch
>
0
:
learning_rate
=
lr
.
LinearWarmup
(
learning_rate
=
learning_rate
,
warmup_steps
=
self
.
warmup_
steps
,
warmup_steps
=
self
.
warmup_
epoch
,
start_lr
=
self
.
warmup_start_lr
,
end_lr
=
self
.
values
[
0
],
last_epoch
=
self
.
last_epoch
)
else
:
learning_rate
=
lr
.
PiecewiseDecay
(
boundaries
=
self
.
boundaries_
epoch
,
boundaries
=
self
.
boundaries_
steps
,
values
=
self
.
values
,
last_epoch
=
self
.
last_epoch
)
if
self
.
warmup_
epoch
>
0
:
if
self
.
warmup_
steps
>
0
:
learning_rate
=
lr
.
LinearWarmup
(
learning_rate
=
learning_rate
,
warmup_steps
=
self
.
warmup_
epoch
,
warmup_steps
=
self
.
warmup_
steps
,
start_lr
=
self
.
warmup_start_lr
,
end_lr
=
self
.
values
[
0
],
last_epoch
=
self
.
last_epoch
)
setattr
(
learning_rate
,
"by_epoch"
,
self
.
by_epoch
)
return
learning_rate
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录