Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
15acd6a5
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
1 年多 前同步成功
通知
116
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
15acd6a5
编写于
6月 10, 2021
作者:
C
cuicheng01
提交者:
GitHub
6月 10, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #822 from littletomatodonkey/reg/fix_data
fix reader
上级
799467e1
edec759f
变更
4
隐藏空白更改
内联
并排
Showing
4 changed files
with
38 additions
and
348 deletions
+38
-348
ppcls/arch/__init__.py
ppcls/arch/__init__.py
+0
-1
ppcls/data/__init__.py
ppcls/data/__init__.py
+25
-16
ppcls/data/reader.py
ppcls/data/reader.py
+0
-319
ppcls/engine/trainer.py
ppcls/engine/trainer.py
+13
-12
未找到文件。
ppcls/arch/__init__.py
浏览文件 @
15acd6a5
...
...
@@ -84,7 +84,6 @@ class DistillationModel(nn.Layer):
assert
len
(
model_config
)
==
1
key
=
list
(
model_config
.
keys
())[
0
]
model_config
=
model_config
[
key
]
print
(
model_config
)
model_name
=
model_config
.
pop
(
"name"
)
model
=
eval
(
model_name
)(
**
model_config
)
...
...
ppcls/data/__init__.py
浏览文件 @
15acd6a5
...
...
@@ -19,7 +19,6 @@ from paddle.io import DistributedBatchSampler, BatchSampler, DataLoader
from
ppcls.utils
import
logger
from
ppcls.data
import
dataloader
from
ppcls.data
import
imaug
# dataset
from
ppcls.data.dataloader.imagenet_dataset
import
ImageNetDataset
from
ppcls.data.dataloader.multilabel_dataset
import
MultiLabelDataset
...
...
@@ -28,15 +27,37 @@ from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
from
ppcls.data.dataloader.logo_dataset
import
LogoDataset
from
ppcls.data.dataloader.icartoon_dataset
import
ICartoonDataset
# sampler
from
ppcls.data.dataloader.DistributedRandomIdentitySampler
import
DistributedRandomIdentitySampler
from
ppcls.data.preprocess
import
transform
def create_operators(params):
    """
    Instantiate the operators described by a config list.
    Each list entry must be a one-item dict ``{operator_name: kwargs}``;
    the operator class is looked up by name on ``imaug`` and constructed
    with ``kwargs`` (a ``None`` value means "no arguments").
    Args:
        params(list): a dict list, used to create some operators
    Returns:
        list: the constructed operators, in config order
    """
    assert isinstance(params, list), ('operator config should be a list')
    built = []
    for entry in params:
        assert isinstance(entry,
                          dict) and len(entry) == 1, "yaml format error"
        # exactly one (name, kwargs) pair per entry, guaranteed by the assert
        (op_name, op_kwargs), = entry.items()
        if op_kwargs is None:
            op_kwargs = {}
        built.append(getattr(imaug, op_name)(**op_kwargs))
    return built
def
build_dataloader
(
config
,
mode
,
device
,
seed
=
None
):
assert
mode
in
[
'Train'
,
'Eval'
,
'Test'
,
],
"Mode should be Train, Eval, Test"
assert
mode
in
[
'Train'
,
'Eval'
,
'Test'
,
],
"Mode should be Train, Eval, Test"
# build dataset
config_dataset
=
config
[
mode
][
'dataset'
]
config_dataset
=
copy
.
deepcopy
(
config_dataset
)
...
...
@@ -109,16 +130,4 @@ def build_dataloader(config, mode, device, seed=None):
collate_fn
=
batch_collate_fn
)
logger
.
info
(
"build data_loader({}) success..."
.
format
(
data_loader
))
return
data_loader
'''
# TODO: fix the format
def build_dataloader(config, mode, device, seed=None):
from . import reader
from .reader import Reader
dataloader = Reader(config, mode=mode, places=device)()
return dataloader
'''
ppcls/data/reader.py
已删除
100755 → 0
浏览文件 @
799467e1
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
random
import
imghdr
import
os
import
signal
from
paddle.io
import
Dataset
,
DataLoader
,
DistributedBatchSampler
from
.
import
imaug
from
.imaug
import
transform
from
ppcls.utils
import
logger
# Distributed-launch settings read from the environment; the defaults
# (1 and 0) are used when the variables are not set.
trainers_num = int(os.getenv('PADDLE_TRAINERS_NUM', 1))
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
class ModeException(Exception):
    """
    Error for an unsupported reader mode.
    Args:
        message(str): text placed in front of the standard explanation
        mode(str): the mode that was requested
    """

    def __init__(self, message='', mode=''):
        explanation = ("\nOnly the following 3 modes are supported: "
                       "train, valid, test. Given mode is {}").format(mode)
        message += explanation
        super(ModeException, self).__init__(message)
class SampleNumException(Exception):
    """
    Error for a dataset that holds fewer samples than one batch while
    drop_last is enabled.
    Args:
        message(str): text placed in front of the standard explanation
        sample_num(int): number of samples in the whole dataset
        batch_size(int): configured batch size
    """

    def __init__(self, message='', sample_num=0, batch_size=1):
        template = ("\nError: The number of the whole data ({}) "
                    "is smaller than the batch_size ({}), and drop_last "
                    "is turnning on, so nothing will feed in program, "
                    "Terminated now. Please reset batch_size to a smaller "
                    "number or feed more data!")
        message += template.format(sample_num, batch_size)
        super(SampleNumException, self).__init__(message)
class ShuffleSeedException(Exception):
    """
    Error for a missing shuffle_seed when more than one trainer is used.
    Args:
        message(str): text placed in front of the standard explanation
    """

    def __init__(self, message=''):
        explanation = ("\nIf trainers_num > 1, the shuffle_seed must be set, "
                       "because the order of batch data generated by reader "
                       "must be the same in the respective processes.")
        message += explanation
        super(ShuffleSeedException, self).__init__(message)
def check_params(params):
    """
    Validate reader params before they are used.
    Fills in 'shuffle_seed' with None when absent, then checks that a seed
    is present for multi-trainer runs, that 'data_dir' is a directory and,
    for every mode except 'test', that 'file_list' is a file.
    Args:
        params(dict): reader params; modified in place ('shuffle_seed')
    Raises:
        ShuffleSeedException: trainers_num > 1 and no shuffle_seed is set
        AssertionError: 'data_dir' or 'file_list' does not exist
    """
    # same effect as "add the key with None only when it is missing"
    params.setdefault('shuffle_seed', None)

    if trainers_num > 1 and params['shuffle_seed'] is None:
        raise ShuffleSeedException()

    data_dir = params.get('data_dir', '')
    assert os.path.isdir(data_dir), \
        "{} doesn't exist, please check datadir path".format(data_dir)

    if params['mode'] == 'test':
        # test mode builds its own file list, nothing more to check
        return

    file_list = params.get('file_list', '')
    assert os.path.isfile(file_list), \
        "{} doesn't exist, please check file list path".format(file_list)
def create_file_list(params):
    """
    if mode is test, create the file list
    Scans ``params['data_dir']`` (non-recursively), keeps the entries that
    imghdr recognises as one of the supported image types and writes them,
    one ``"<file name> 0"`` line each, to ".tmp.txt" in the current working
    directory. ``params['file_list']`` is set to that path.
    Args:
        params(dict): reader params; reads 'data_dir', writes 'file_list'
    """
    data_dir = params.get('data_dir', '')
    params['file_list'] = ".tmp.txt"
    imgtype_list = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff'}
    with open(params['file_list'], "w") as fout:
        # os.listdir order is arbitrary; sort so the list is reproducible
        for file_name in sorted(os.listdir(data_dir)):
            file_path = os.path.join(data_dir, file_name)
            # imghdr.what opens the path, which fails on sub-directories
            # and other non-regular entries, so skip them up front
            if not os.path.isfile(file_path):
                continue
            if imghdr.what(file_path) not in imgtype_list:
                continue
            fout.write(file_name + " 0" + "\n")
def shuffle_lines(full_lines, seed=None):
    """
    Shuffle the given lines in place and return them.
    Args:
        full_lines(list): lines to shuffle; the same list object is returned
        seed(int): random seed; when None the global numpy random state
            is used, otherwise a fresh RandomState built from the seed
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    rng.shuffle(full_lines)
    return full_lines
def get_file_list(params):
    """
    Load the label list and, for training, shuffle it.
    In 'test' mode the file list is first generated from the data
    directory; every line of the list is stripped of surrounding
    whitespace.
    Args:
        params(dict): reader params; reads 'mode', 'file_list' and, in
            'train' mode, 'shuffle_seed'
    Returns:
        list: the stripped lines of the file list
    """
    mode = params['mode']
    if mode == 'test':
        create_file_list(params)

    with open(params['file_list']) as flist:
        full_lines = [raw.strip() for raw in flist]

    if mode != "train":
        return full_lines
    return shuffle_lines(full_lines, seed=params['shuffle_seed'])
def create_operators(params):
    """
    Build operator instances from their config.
    Every element of ``params`` is a single-key dict mapping an operator
    name (an attribute of ``imaug``) to its keyword arguments, or to None
    when the operator takes none.
    Args:
        params(list): a dict list, used to create some operators
    Returns:
        list: one operator instance per config entry, in order
    """
    assert isinstance(params, list), ('operator config should be a list')

    def _build(operator):
        # one operator per entry: {name: kwargs-or-None}
        assert isinstance(operator,
                          dict) and len(operator) == 1, "yaml format error"
        op_name = next(iter(operator))
        kwargs = operator[op_name]
        if kwargs is None:
            kwargs = {}
        return getattr(imaug, op_name)(**kwargs)

    return [_build(operator) for operator in params]
def term_mp(sig_num, frame):
    """ Signal handler: kill the whole process group of this process.
    Args:
        sig_num: signal number passed by the signal module (unused)
        frame: current stack frame passed by the signal module (unused)
    """
    main_pid = os.getpid()
    group_id = os.getpgid(main_pid)
    logger.info("main proc {} exit, kill process group "
                "{}".format(main_pid, group_id))
    # SIGKILL is sent to the group, which includes the calling process
    os.killpg(group_id, signal.SIGKILL)
class CommonDataset(Dataset):
    """
    Dataset for single-label image classification.
    Each line of the file list is split by the delimiter into an image
    path (relative to ``params['data_dir']``) and an integer label.
    Args:
        params(dict): reader params; reads 'data_dir' and 'transforms',
            the optional 'mode' (default "train") and 'delimiter'
            (default ' '), plus the keys get_file_list needs
    """

    def __init__(self, params):
        self.params = params
        self.mode = params.get("mode", "train")
        self.full_lines = get_file_list(params)
        self.delimiter = params.get('delimiter', ' ')
        self.ops = create_operators(params['transforms'])
        self.num_samples = len(self.full_lines)

    def __getitem__(self, idx):
        """
        Return ``(transformed image, int label)`` for sample ``idx``.
        A sample that cannot be read or transformed is logged and replaced
        by a randomly chosen other sample.
        Raises:
            IndexError: ``idx`` is outside the file list
        """
        # Looked up outside the try block: an invalid index is the caller's
        # error and must not be masked by the random-retry fallback; it also
        # guarantees `line` is bound when the handler logs it.
        line = self.full_lines[idx]
        try:
            img_path, label = line.split(self.delimiter)
            img_path = os.path.join(self.params['data_dir'], img_path)
            with open(img_path, 'rb') as f:
                img = f.read()
            return (transform(img, self.ops), int(label))
        except Exception as e:
            logger.error("data read faild: {}, exception info: {}".format(
                line, e))
            # randint includes both ends, so the last valid index is
            # len(self) - 1
            return self.__getitem__(random.randint(0, len(self) - 1))

    def __len__(self):
        return self.num_samples
class MultiLabelDataset(Dataset):
    """
    Define dataset class for multilabel image classification
    Each line of the file list is split by the delimiter into an image
    path (relative to ``params['data_dir']``) and a comma-separated list
    of integer labels.
    Args:
        params(dict): reader params; reads 'data_dir' and 'transforms',
            the optional 'mode' (default "train") and 'delimiter'
            (default tab), plus the keys get_file_list needs
    """

    def __init__(self, params):
        self.params = params
        self.mode = params.get("mode", "train")
        self.full_lines = get_file_list(params)
        self.delimiter = params.get("delimiter", "\t")
        self.ops = create_operators(params["transforms"])
        self.num_samples = len(self.full_lines)

    def __getitem__(self, idx):
        """
        Return ``(transformed image, float32 label array)`` for sample
        ``idx``. A sample that cannot be read or transformed is logged and
        replaced by a randomly chosen other sample.
        Raises:
            IndexError: ``idx`` is outside the file list
        """
        # Looked up outside the try block: an invalid index is the caller's
        # error and must not be masked by the random-retry fallback; it also
        # guarantees `line` is bound when the handler logs it.
        line = self.full_lines[idx]
        try:
            img_path, label_str = line.split(self.delimiter)
            img_path = os.path.join(self.params["data_dir"], img_path)
            with open(img_path, "rb") as f:
                img = f.read()

            labels = [int(i) for i in label_str.split(',')]

            return (transform(img, self.ops),
                    np.array(labels).astype("float32"))
        except Exception as e:
            logger.error("data read failed: {}, exception info: {}".format(
                line, e))
            # randint includes both ends, so the last valid index is
            # len(self) - 1
            return self.__getitem__(random.randint(0, len(self) - 1))

    def __len__(self):
        return self.num_samples
class Reader:
    """
    Build a DataLoader for training, validation or test.
    Args:
        config(dict): global config; the section ``config[mode.capitalize()]``
            holds the reader params, and the optional keys 'use_mix',
            'use_xpu' and 'multilabel' are read from the top level
        mode(str): reader mode; shuffling, drop_last and the mix collate
            function are enabled only for "train"
        places: passed through unchanged to DataLoader
    Returns:
        calling the instance returns the DataLoader
    Raises:
        ModeException: config has no section for the given mode
    """

    def __init__(self, config, mode='train', places=None):
        section = mode.capitalize()
        try:
            self.params = config[section]
        except KeyError:
            raise ModeException(mode=mode)

        self.params['mode'] = mode
        training = mode == "train"
        self.shuffle = training
        self.is_train = training

        # batch-level operators are only built when mixing is requested
        # for training
        self.collate_fn = None
        self.batch_ops = []
        if config.get('use_mix') and training:
            self.batch_ops = create_operators(self.params['mix'])
            self.collate_fn = self.mix_collate_fn

        self.places = places
        self.use_xpu = config.get("use_xpu", False)
        self.multilabel = config.get("multilabel", False)

    def mix_collate_fn(self, batch):
        """
        Apply the batch operators, then stack each field of the samples.
        Args:
            batch(list): samples of one batch
        Returns:
            list: one array per field, stacked along axis 0
        """
        batch = transform(batch, self.batch_ops)
        # batch each field
        columns = []
        for sample in batch:
            width = len(sample)
            for pos, field in enumerate(sample):
                if len(columns) < width:
                    # first time this field position is seen
                    columns.append([field])
                else:
                    columns[pos].append(field)

        return [np.stack(column, axis=0) for column in columns]

    def __call__(self):
        """
        Create the dataset and wrap it in a DataLoader.
        The configured batch size is divided evenly (integer division)
        among the trainers.
        """
        batch_size = int(self.params['batch_size']) // trainers_num

        dataset_cls = MultiLabelDataset if self.multilabel else CommonDataset
        dataset = dataset_cls(self.params)

        if self.use_xpu and self.params['mode'] != "train":
            # non-training on XPU: plain loader without a batch sampler
            return DataLoader(
                dataset,
                places=self.places,
                batch_size=batch_size,
                drop_last=False,
                return_list=True,
                shuffle=False,
                num_workers=self.params["num_workers"])

        is_train = self.is_train
        batch_sampler = DistributedBatchSampler(
            dataset,
            batch_size=batch_size,
            shuffle=self.shuffle and is_train,
            drop_last=is_train)
        return DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            collate_fn=self.collate_fn if is_train else None,
            places=self.places,
            return_list=True,
            num_workers=self.params["num_workers"])
# Install term_mp for SIGINT and SIGTERM (in that order) at import time.
for _signum in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_signum, term_mp)
ppcls/engine/trainer.py
浏览文件 @
15acd6a5
...
...
@@ -41,7 +41,7 @@ from ppcls.utils import save_load
from
ppcls.data.utils.get_image_list
import
get_image_list
from
ppcls.data.postprocess
import
build_postprocess
from
ppcls.data
.reader
import
create_operators
from
ppcls.data
import
create_operators
class
Trainer
(
object
):
...
...
@@ -413,8 +413,7 @@ class Trainer(object):
if
query_query_id
is
not
None
:
query_id_blocks
=
paddle
.
split
(
query_query_id
,
num_or_sections
=
sections
)
image_id_blocks
=
paddle
.
split
(
query_img_id
,
num_or_sections
=
sections
)
image_id_blocks
=
paddle
.
split
(
query_img_id
,
num_or_sections
=
sections
)
metric_key
=
None
if
self
.
eval_metric_func
is
None
:
...
...
@@ -432,20 +431,23 @@ class Trainer(object):
image_id_mask
=
(
image_id_block
!=
gallery_img_id
.
t
())
keep_mask
=
paddle
.
logical_or
(
query_id_mask
,
image_id_mask
)
similarity_matrix
=
similarity_matrix
*
keep_mask
.
astype
(
"float32"
)
metric_tmp
=
self
.
eval_metric_func
(
similarity_matrix
,
image_id_blocks
[
block_idx
],
gallery_img_id
)
similarity_matrix
=
similarity_matrix
*
keep_mask
.
astype
(
"float32"
)
metric_tmp
=
self
.
eval_metric_func
(
similarity_matrix
,
image_id_blocks
[
block_idx
],
gallery_img_id
)
for
key
in
metric_tmp
:
if
key
not
in
metric_dict
:
metric_dict
[
key
]
=
metric_tmp
[
key
]
else
:
metric_dict
[
key
]
+=
metric_tmp
[
key
]
num_sections
=
len
(
fea_blocks
)
for
key
in
metric_dict
:
metric_dict
[
key
]
=
metric_dict
[
key
]
/
num_sections
metric_dict
[
key
]
=
metric_dict
[
key
]
/
num_sections
metric_info_list
=
[]
for
key
in
metric_dict
:
if
metric_key
is
None
:
...
...
@@ -454,8 +456,7 @@ class Trainer(object):
metric_msg
=
", "
.
join
(
metric_info_list
)
logger
.
info
(
"[Eval][Epoch {}][Avg]{}"
.
format
(
epoch_id
,
metric_msg
))
return
metric_dict
[
metric_key
]
return
metric_dict
[
metric_key
]
def
_cal_feature
(
self
,
name
=
'gallery'
):
all_feas
=
None
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录