PaddlePaddle / PaddleClas
Commit 09200a31
Authored on Oct 19, 2022 by HydrogenSulfate

remove redundant code, fix bugs in lr.step, merge GoodsDataset into Vehicle

Parent: 30cbb183
Showing 8 changed files with 59 additions and 137 deletions (+59 -137)
ppcls/configs/metric_learning/xbm_resnet50.yaml             +3  -3
ppcls/data/__init__.py                                      +0  -1
ppcls/data/dataloader/DistributedRandomIdentitySampler.py   +0  -20
ppcls/data/dataloader/goods_dataset.py                      +0  -95
ppcls/data/dataloader/vehicle_dataset.py                    +42 -10
ppcls/engine/engine.py                                      +2  -1
ppcls/engine/train/train.py                                 +3  -2
ppcls/engine/train/utils.py                                 +9  -5
ppcls/configs/metric_learning/xbm_resnet50.yaml

@@ -88,7 +88,7 @@ Optimizer:
 DataLoader:
   Train:
     dataset:
-      name: GoodsDataset
+      name: VeriWild
       image_root: ./dataset/SOP
       cls_label_path: ./dataset/SOP/train_list.txt
       backend: pil

@@ -117,7 +117,7 @@ DataLoader:
   Eval:
     Gallery:
       dataset:
-        name: GoodsDataset
+        name: VeriWild
        image_root: ./dataset/SOP
        cls_label_path: ./dataset/SOP/test_list.txt
        backend: pil

@@ -141,7 +141,7 @@ DataLoader:
     Query:
       dataset:
-        name: GoodsDataset
+        name: VeriWild
        image_root: ./dataset/SOP
        cls_label_path: ./dataset/SOP/test_list.txt
        backend: pil
ppcls/data/__init__.py

@@ -25,7 +25,6 @@ from ppcls.data.dataloader.imagenet_dataset import ImageNetDataset
 from ppcls.data.dataloader.multilabel_dataset import MultiLabelDataset
 from ppcls.data.dataloader.common_dataset import create_operators
 from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
-from ppcls.data.dataloader.goods_dataset import GoodsDataset
 from ppcls.data.dataloader.logo_dataset import LogoDataset
 from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
 from ppcls.data.dataloader.mix_dataset import MixDataset
ppcls/data/dataloader/DistributedRandomIdentitySampler.py

@@ -82,26 +82,6 @@ class DistributedRandomIdentitySampler(DistributedBatchSampler):
         avai_pids = copy.deepcopy(self.pids)
         return batch_idxs_dict, avai_pids, count

-    def __iter__(self):
-        batch_idxs_dict, avai_pids, count = self._prepare_batch()
-        for _ in range(self.max_iters):
-            final_idxs = []
-            if len(avai_pids) < self.num_pids_per_batch:
-                batch_idxs_dict, avai_pids, count = self._prepare_batch()
-            selected_pids = np.random.choice(
-                avai_pids, self.num_pids_per_batch, False, count / count.sum())
-            for pid in selected_pids:
-                batch_idxs = batch_idxs_dict[pid].pop(0)
-                final_idxs.extend(batch_idxs)
-                pid_idx = avai_pids.index(pid)
-                if len(batch_idxs_dict[pid]) == 0:
-                    avai_pids.pop(pid_idx)
-                    count = np.delete(count, pid_idx)
-                else:
-                    count[pid_idx] = len(batch_idxs_dict[pid])
-            yield final_idxs
-
     def __iter__(self):
         # prepare
         batch_idxs_dict, avai_pids, count = self._prepare_batch()
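For context, a self-contained sketch of the P-K identity sampling that the surviving __iter__ implements (the block deleted above was an older duplicate of it). The function below is illustrative only: it mirrors the sampler's internal names but omits the branch that re-runs _prepare_batch() when fewer than P identities remain.

    import numpy as np

    def pk_batches(batch_idxs_dict, num_pids_per_batch, max_iters):
        # batch_idxs_dict maps each identity to a list of K-sized index chunks,
        # like the structure _prepare_batch() builds inside the sampler.
        avai_pids = list(batch_idxs_dict.keys())
        count = np.array([len(v) for v in batch_idxs_dict.values()], dtype="float64")
        for _ in range(max_iters):
            final_idxs = []
            # draw P identities without replacement, weighted by remaining chunks
            selected_pids = np.random.choice(
                avai_pids, num_pids_per_batch, False, count / count.sum())
            for pid in selected_pids:
                final_idxs.extend(batch_idxs_dict[pid].pop(0))
                pid_idx = avai_pids.index(pid)
                if len(batch_idxs_dict[pid]) == 0:
                    avai_pids.pop(pid_idx)
                    count = np.delete(count, pid_idx)
                else:
                    count[pid_idx] = len(batch_idxs_dict[pid])
            yield final_idxs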
ppcls/data/dataloader/goods_dataset.py (deleted, 100644 → 0)

from __future__ import print_function

import os
from typing import Callable, List

import numpy as np
import paddle
from paddle.io import Dataset
from PIL import Image

from ppcls.data.preprocess import transform
from ppcls.utils import logger

from .common_dataset import create_operators


class GoodsDataset(Dataset):
    """Dataset for Goods, such as SOP, Inshop...

    Args:
        image_root (str): image root
        cls_label_path (str): path to annotation file
        transform_ops (List[Callable], optional): list of transform op(s). Defaults to None.
        backend (str, optional): pil or cv2. Defaults to "cv2".
        relabel (bool, optional): whether do relabel when original label do not starts from 0 or are discontinuous. Defaults to False.
    """

    def __init__(self,
                 image_root: str,
                 cls_label_path: str,
                 transform_ops: List[Callable]=None,
                 backend="cv2",
                 relabel=False):
        self._img_root = image_root
        self._cls_path = cls_label_path
        if transform_ops:
            self._transform_ops = create_operators(transform_ops)
        self.backend = backend
        self._dtype = paddle.get_default_dtype()
        self._load_anno(relabel)

    def _load_anno(self, seed=None, relabel=False):
        assert os.path.exists(
            self._cls_path), f"path {self._cls_path} does not exist."
        assert os.path.exists(
            self._img_root), f"path {self._img_root} does not exist."
        self.images = []
        self.labels = []
        self.cameras = []

        with open(self._cls_path) as fd:
            lines = fd.readlines()
            if relabel:
                label_set = set()
                for line in lines:
                    line = line.strip().split()
                    label_set.add(np.int64(line[1]))
                label_map = {
                    oldlabel: newlabel
                    for newlabel, oldlabel in enumerate(label_set)
                }
            if seed is not None:
                np.random.RandomState(seed).shuffle(lines)
            for line in lines:
                line = line.strip().split()
                self.images.append(os.path.join(self._img_root, line[0]))
                if relabel:
                    self.labels.append(label_map[np.int64(line[1])])
                else:
                    self.labels.append(np.int64(line[1]))
                self.cameras.append(np.int64(line[2]))
                assert os.path.exists(
                    self.images[-1]), f"path {self.images[-1]} does not exist."

    def __getitem__(self, idx):
        try:
            img = Image.open(self.images[idx]).convert("RGB")
            if self.backend == "cv2":
                img = np.array(img, dtype="float32").astype(np.uint8)
            if self._transform_ops:
                img = transform(img, self._transform_ops)
            if self.backend == "cv2":
                img = img.transpose((2, 0, 1))
            return (img, self.labels[idx], self.cameras[idx])
        except Exception as ex:
            logger.error("Exception occured when parse line: {} with msg: {}".
                         format(self.images[idx], ex))
            rnd_idx = np.random.randint(self.__len__())
            return self.__getitem__(rnd_idx)

    def __len__(self):
        return len(self.images)

    @property
    def class_num(self):
        return len(set(self.labels))
ppcls/data/dataloader/vehicle_dataset.py

@@ -19,8 +19,7 @@ import paddle
 from paddle.io import Dataset
 import os
 import cv2
-from ppcls.data import preprocess
 from PIL import Image
 from ppcls.data.preprocess import transform
 from ppcls.utils import logger
 from .common_dataset import create_operators
@@ -89,15 +88,30 @@ class CompCars(Dataset):

 class VeriWild(Dataset):
-    def __init__(self, image_root, cls_label_path, transform_ops=None):
+    """Dataset for Vehicle and other similar data structure, such as VeRI-Wild, SOP, Inshop...
+
+    Args:
+        image_root (str): image root
+        cls_label_path (str): path to annotation file
+        transform_ops (List[Callable], optional): list of transform op(s). Defaults to None.
+        backend (str, optional): pil or cv2. Defaults to "cv2".
+        relabel (bool, optional): whether do relabel when original label do not starts from 0 or are discontinuous. Defaults to False.
+    """
+
+    def __init__(self,
+                 image_root,
+                 cls_label_path,
+                 transform_ops=None,
+                 backend="cv2",
+                 relabel=False):
         self._img_root = image_root
         self._cls_path = cls_label_path
         if transform_ops:
             self._transform_ops = create_operators(transform_ops)
+        self.backend = backend
         self._dtype = paddle.get_default_dtype()
-        self._load_anno()
+        self._load_anno(relabel)

-    def _load_anno(self):
+    def _load_anno(self, relabel):
         assert os.path.exists(
             self._cls_path), f"path {self._cls_path} does not exist."
         assert os.path.exists(

@@ -107,22 +121,40 @@ class VeriWild(Dataset):
         self.cameras = []
         with open(self._cls_path) as fd:
             lines = fd.readlines()
+            if relabel:
+                label_set = set()
+                for line in lines:
+                    line = line.strip().split()
+                    label_set.add(np.int64(line[1]))
+                label_map = {
+                    oldlabel: newlabel
+                    for newlabel, oldlabel in enumerate(label_set)
+                }
             for line in lines:
                 line = line.strip().split()
                 self.images.append(os.path.join(self._img_root, line[0]))
-                self.labels.append(np.int64(line[1]))
+                if relabel:
+                    self.labels.append(label_map[np.int64(line[1])])
+                else:
+                    self.labels.append(np.int64(line[1]))
                 if len(line) >= 3:
                     self.cameras.append(np.int64(line[2]))
-                assert os.path.exists(self.images[-1])
+                assert os.path.exists(self.images[-1]), \
+                    f"path {self.images[-1]} does not exist."
         self.has_camera = len(self.cameras) > 0

     def __getitem__(self, idx):
         try:
-            with open(self.images[idx], 'rb') as f:
-                img = f.read()
+            if self.backend == "cv2":
+                with open(self.images[idx], 'rb') as f:
+                    img = f.read()
+            else:
+                img = Image.open(self.images[idx]).convert("RGB")
             if self._transform_ops:
                 img = transform(img, self._transform_ops)
-            img = img.transpose((2, 0, 1))
+            if self.backend == "cv2":
+                img = img.transpose((2, 0, 1))
             if self.has_camera:
                 return (img, self.labels[idx], self.cameras[idx])
             else:
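With GoodsDataset removed, code that built it directly switches to the extended VeriWild. A minimal usage sketch, assuming the SOP list files from xbm_resnet50.yaml exist locally; the paths and the relabel choice here are illustrative, not taken from the config.

    from ppcls.data.dataloader.vehicle_dataset import VeriWild

    # backend="pil" hands a PIL.Image to the transform pipeline (as GoodsDataset did);
    # relabel=True remaps discontinuous labels to a contiguous 0..N-1 range.
    sop_train = VeriWild(
        image_root="./dataset/SOP",
        cls_label_path="./dataset/SOP/train_list.txt",
        transform_ops=None,
        backend="pil",
        relabel=True)

    sample = sop_train[0]  # (img, label) or (img, label, camera), depending on list-file columns
    print(len(sop_train), len(sample))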
ppcls/engine/engine.py

@@ -42,6 +42,7 @@ from ppcls.data.utils.get_image_list import get_image_list
 from ppcls.data.postprocess import build_postprocess
 from ppcls.data import create_operators
 from ppcls.engine.train import train_epoch
+from ppcls.engine.train.utils import type_name
 from ppcls.engine import evaluation
 from ppcls.arch.gears.identity_head import IdentityHead

@@ -377,7 +378,7 @@ class Engine(object):
                 # step lr (by epoch) according to given metric, such as acc
                 for i in range(len(self.lr_sch)):
                     if getattr(self.lr_sch[i], "by_epoch", False) and \
-                            self.lr_sch[i].__class__.__name__ == "ReduceOnPlateau":
+                            type_name(self.lr_sch[i]) == "ReduceOnPlateau":
                         self.lr_sch[i].step(acc)
                 if acc > best_metric["metric"]:
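The guard above matters because paddle.optimizer.lr.ReduceOnPlateau is stepped with the monitored metric, unlike other by-epoch schedulers whose step() takes no argument; that mismatch is the lr.step bug named in the commit message. A standalone sketch of the scheduler's behaviour with toy numbers (not PaddleClas code):

    import paddle

    # mode="max": treat the metric as "higher is better", e.g. eval accuracy.
    plateau = paddle.optimizer.lr.ReduceOnPlateau(
        learning_rate=0.1, mode="max", factor=0.5, patience=1)

    for epoch, acc in enumerate([0.70, 0.71, 0.71, 0.71]):  # toy eval metrics
        plateau.step(acc)              # must receive the metric it monitors
        print(epoch, plateau.last_lr)  # lr drops once the metric stops improving

    # Calling plateau.step() with no metric is exactly the misuse the new
    # type_name(...) checks prevent in the train/eval loops.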
ppcls/engine/train/train.py

@@ -15,7 +15,7 @@ from __future__ import absolute_import, division, print_function
 import time
 import paddle
-from ppcls.engine.train.utils import update_loss, update_metric, log_info
+from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_name
 from ppcls.utils import profiler

@@ -98,7 +98,8 @@ def train_epoch(engine, epoch_id, print_batch_step):
     # step lr(by epoch)
     for i in range(len(engine.lr_sch)):
-        if getattr(engine.lr_sch[i], "by_epoch", False):
+        if getattr(engine.lr_sch[i], "by_epoch", False) and \
+                type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
             engine.lr_sch[i].step()
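Combined with the engine.py hunk above, end-of-epoch scheduler stepping is now split between the two loops. A hedged summary of the resulting control flow, written as one function for clarity (a sketch, not verbatim PaddleClas code):

    def step_by_epoch_schedulers(lr_schedulers, eval_acc=None):
        """Post-commit behaviour in one place: ordinary by-epoch schedulers are
        stepped in train_epoch(); ReduceOnPlateau is skipped there and stepped
        only after evaluation, with the metric it monitors."""
        for sch in lr_schedulers:
            if not getattr(sch, "by_epoch", False):
                continue                      # per-iteration schedulers handled elsewhere
            if sch.__class__.__name__ == "ReduceOnPlateau":
                if eval_acc is not None:      # engine.py path
                    sch.step(eval_acc)
            else:
                sch.step()                    # train.py path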
ppcls/engine/train/utils.py

@@ -53,14 +53,13 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
     ips_msg = "ips: {:.5f} samples/s".format(
         batch_size / trainer.time_info["batch_cost"].avg)
-    eta_sec = ((trainer.config["Global"]["epochs"] - epoch_id + 1
-                ) * trainer.max_iter - iter_id
-               ) * trainer.time_info["batch_cost"].avg
+    eta_sec = ((trainer.config["Global"]["epochs"] - epoch_id + 1) *
+               trainer.max_iter - iter_id) * trainer.time_info["batch_cost"].avg
     eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))
     logger.info("[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".format(
-        epoch_id, trainer.config["Global"]["epochs"], iter_id, trainer.max_iter,
-        lr_msg, metric_msg, time_msg, ips_msg, eta_msg))
+        epoch_id, trainer.config["Global"]["epochs"], iter_id,
+        trainer.max_iter, lr_msg, metric_msg, time_msg, ips_msg, eta_msg))
     for i, lr in enumerate(trainer.lr_sch):
         logger.scaler(

@@ -74,3 +73,8 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
             value=trainer.output_info[key].avg,
             step=trainer.global_step,
             writer=trainer.vdl_writer)
+
+
+def type_name(object: object) -> str:
+    """get class name of an object"""
+    return object.__class__.__name__
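A small illustration of the new helper; the scheduler below is just an example object:

    import paddle
    from ppcls.engine.train.utils import type_name

    sch = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=0.1)
    print(type_name(sch))  # "ReduceOnPlateau" — the string the engine/train loops compare against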