Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
a4753fc3
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
1 年多 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a4753fc3
编写于
10月 10, 2020
作者:
L
littletomatodonkey
提交者:
GitHub
10月 10, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #296 from littletomatodonkey/dyg/add_dataloader
add dataloader inferface
上级
53c5850d
2d86d1ac
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
121 addition
and
131 deletion
+121
-131
.pre-commit-config.yaml
.pre-commit-config.yaml
+35
-35
configs/MobileNetV3/MobileNetV3_large_x1_0.yaml
configs/MobileNetV3/MobileNetV3_large_x1_0.yaml
+1
-1
ppcls/data/reader.py
ppcls/data/reader.py
+72
-77
ppcls/modeling/loss.py
ppcls/modeling/loss.py
+0
-1
tools/eval.py
tools/eval.py
+2
-3
tools/train.py
tools/train.py
+11
-14
未找到文件。
.pre-commit-config.yaml
浏览文件 @
a4753fc3
...
...
@@ -3,33 +3,33 @@
hooks
:
-
id
:
yapf
files
:
\.py$
-
repo
:
https://github.com/pre-commit/mirrors-autopep8
rev
:
v1.5
-
repo
:
https://github.com/pre-commit/pre-commit-hooks
sha
:
a11d9314b22d8f8c7556443875b731ef05965464
hooks
:
-
id
:
autopep8
-
id
:
check-merge-conflict
-
id
:
check-symlinks
-
id
:
detect-private-key
files
:
(?!.*paddle)^.*$
-
id
:
end-of-file-fixer
files
:
\.md$
-
id
:
trailing-whitespace
files
:
\.md$
-
repo
:
https://github.com/Lucas-C/pre-commit-hooks
sha
:
v1.0.1
hooks
:
-
id
:
forbid-crlf
files
:
\.(md|yml)
$
files
:
\.md
$
-
id
:
remove-crlf
files
:
\.(md|yml)
$
files
:
\.md
$
-
id
:
forbid-tabs
files
:
\.(md|yml)
$
files
:
\.md
$
-
id
:
remove-tabs
files
:
\.(md|yml)$
-
repo
:
https://github.com/pre-commit/pre-commit-hooks
rev
:
v2.5.0
files
:
\.md$
-
repo
:
local
hooks
:
-
id
:
check-yaml
-
id
:
check-merge-conflict
-
id
:
detect-private-key
files
:
(?!.*paddle)^.*$
-
id
:
end-of-file-fixer
files
:
\.(md|yml)$
-
id
:
trailing-whitespace
files
:
\.(md|yml)$
-
id
:
check-case-conflict
-
id
:
clang-format
name
:
clang-format
description
:
Format files with ClangFormat
entry
:
bash .clang_format.hook -i
language
:
system
files
:
\.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
configs/MobileNetV3/MobileNetV3_large_x1_0.yaml
浏览文件 @
a4753fc3
...
...
@@ -54,7 +54,7 @@ TRAIN:
VALID
:
batch_size
:
32
batch_size
:
1024
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/val_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
...
...
ppcls/data/reader.py
浏览文件 @
a4753fc3
...
...
@@ -17,7 +17,7 @@ import imghdr
import
os
import
signal
from
paddle.
reader
import
multiprocess_read
er
from
paddle.
io
import
Dataset
,
DataLoader
,
DistributedBatchSampl
er
from
.
import
imaug
from
.imaug
import
transform
...
...
@@ -109,7 +109,6 @@ def create_file_list(params):
def
shuffle_lines
(
full_lines
,
seed
=
None
):
"""
random shuffle lines
Args:
full_lines(list):
seed(int): random seed
...
...
@@ -135,12 +134,8 @@ def get_file_list(params):
with
open
(
params
[
'file_list'
])
as
flist
:
full_lines
=
[
line
.
strip
()
for
line
in
flist
]
full_lines
=
shuffle_lines
(
full_lines
,
params
[
"shuffle_seed"
])
# use only partial data for each trainer in distributed training
if
params
[
'mode'
]
==
'train'
:
img_per_trainer
=
len
(
full_lines
)
//
trainers_num
full_lines
=
full_lines
[
trainer_id
::
trainers_num
][:
img_per_trainer
]
if
params
[
"mode"
]
==
"train"
:
full_lines
=
shuffle_lines
(
full_lines
,
seed
=
params
[
'shuffle_seed'
])
return
full_lines
...
...
@@ -165,60 +160,6 @@ def create_operators(params):
return
ops
def
partial_reader
(
params
,
full_lines
,
part_id
=
0
,
part_num
=
1
):
"""
create a reader with partial data
Args:
params(dict):
full_lines: label list
part_id(int): part index of the current partial data
part_num(int): part num of the dataset
"""
assert
part_id
<
part_num
,
(
"part_num: {} should be larger "
"than part_id: {}"
.
format
(
part_num
,
part_id
))
full_lines
=
full_lines
[
part_id
::
part_num
]
batch_size
=
int
(
params
[
'batch_size'
])
//
trainers_num
if
params
[
'mode'
]
!=
"test"
and
len
(
full_lines
)
<
batch_size
:
raise
SampleNumException
(
''
,
len
(
full_lines
),
batch_size
)
def
reader
():
ops
=
create_operators
(
params
[
'transforms'
])
delimiter
=
params
.
get
(
'delimiter'
,
' '
)
for
line
in
full_lines
:
img_path
,
label
=
line
.
split
(
delimiter
)
img_path
=
os
.
path
.
join
(
params
[
'data_dir'
],
img_path
)
with
open
(
img_path
,
'rb'
)
as
f
:
img
=
f
.
read
()
yield
(
transform
(
img
,
ops
),
int
(
label
))
return
reader
def
mp_reader
(
params
):
"""
multiprocess reader
Args:
params(dict):
"""
check_params
(
params
)
full_lines
=
get_file_list
(
params
)
if
params
[
"mode"
]
==
"train"
:
full_lines
=
shuffle_lines
(
full_lines
,
seed
=
None
)
part_num
=
1
if
'num_workers'
not
in
params
else
params
[
'num_workers'
]
readers
=
[]
for
part_id
in
range
(
part_num
):
readers
.
append
(
partial_reader
(
params
,
full_lines
,
part_id
,
part_num
))
return
multiprocess_reader
(
readers
,
use_pipe
=
False
)
def
term_mp
(
sig_num
,
frame
):
""" kill all child processes
"""
...
...
@@ -227,6 +168,29 @@ def term_mp(sig_num, frame):
logger
.
info
(
"main proc {} exit, kill process group "
"{}"
.
format
(
pid
,
pgid
))
os
.
killpg
(
pgid
,
signal
.
SIGKILL
)
return
class
CommonDataset
(
Dataset
):
def
__init__
(
self
,
params
):
self
.
params
=
params
self
.
mode
=
params
.
get
(
"mode"
,
"train"
)
self
.
full_lines
=
get_file_list
(
params
)
self
.
delimiter
=
params
.
get
(
'delimiter'
,
' '
)
self
.
ops
=
create_operators
(
params
[
'transforms'
])
self
.
num_samples
=
len
(
self
.
full_lines
)
return
def
__getitem__
(
self
,
idx
):
line
=
self
.
full_lines
[
idx
]
img_path
,
label
=
line
.
split
(
self
.
delimiter
)
img_path
=
os
.
path
.
join
(
self
.
params
[
'data_dir'
],
img_path
)
with
open
(
img_path
,
'rb'
)
as
f
:
img
=
f
.
read
()
return
(
transform
(
img
,
self
.
ops
),
int
(
label
))
def
__len__
(
self
):
return
self
.
num_samples
class
Reader
:
...
...
@@ -242,7 +206,7 @@ class Reader:
the specific reader
"""
def
__init__
(
self
,
config
,
mode
=
'train'
,
seed
=
None
):
def
__init__
(
self
,
config
,
mode
=
'train'
,
places
=
None
):
try
:
self
.
params
=
config
[
mode
.
upper
()]
except
KeyError
:
...
...
@@ -250,27 +214,58 @@ class Reader:
use_mix
=
config
.
get
(
'use_mix'
)
self
.
params
[
'mode'
]
=
mode
if
seed
is
not
None
:
self
.
params
[
'shuffle_seed'
]
=
seed
self
.
shuffle
=
mode
==
"train"
self
.
collate_fn
=
None
self
.
batch_ops
=
[]
if
use_mix
and
mode
==
"train"
:
self
.
batch_ops
=
create_operators
(
self
.
params
[
'mix'
])
self
.
collate_fn
=
self
.
mix_collate_fn
def
__call__
(
self
):
batch_size
=
int
(
self
.
params
[
'batch_size'
])
//
trainers_num
self
.
places
=
places
def
wrapper
():
reader
=
mp_reader
(
self
.
params
)
batch
=
[]
for
idx
,
sample
in
enumerate
(
reader
()):
img
,
label
=
sample
batch
.
append
((
img
,
label
))
if
(
idx
+
1
)
%
batch_size
==
0
:
def
mix_collate_fn
(
self
,
batch
):
batch
=
transform
(
batch
,
self
.
batch_ops
)
yield
batch
batch
=
[]
# batch each field
slots
=
[]
for
items
in
batch
:
for
i
,
item
in
enumerate
(
items
):
if
len
(
slots
)
<
len
(
items
):
slots
.
append
([
item
])
else
:
slots
[
i
].
append
(
item
)
return
wrapper
return
[
np
.
stack
(
slot
,
axis
=
0
)
for
slot
in
slots
]
def
__call__
(
self
):
batch_size
=
int
(
self
.
params
[
'batch_size'
])
//
trainers_num
dataset
=
CommonDataset
(
self
.
params
)
if
self
.
params
[
'mode'
]
==
"train"
:
batch_sampler
=
DistributedBatchSampler
(
dataset
,
batch_size
=
batch_size
,
shuffle
=
self
.
shuffle
,
drop_last
=
True
)
loader
=
DataLoader
(
dataset
,
batch_sampler
=
batch_sampler
,
collate_fn
=
self
.
collate_fn
,
places
=
self
.
places
,
return_list
=
True
,
num_workers
=
self
.
params
[
"num_workers"
])
else
:
loader
=
DataLoader
(
dataset
,
places
=
self
.
places
,
batch_size
=
batch_size
,
drop_last
=
False
,
return_list
=
True
,
shuffle
=
False
,
num_workers
=
self
.
params
[
"num_workers"
])
return
loader
signal
.
signal
(
signal
.
SIGINT
,
term_mp
)
...
...
ppcls/modeling/loss.py
浏览文件 @
a4753fc3
...
...
@@ -49,7 +49,6 @@ class Loss(object):
input
=
-
F
.
log_softmax
(
input
,
axis
=-
1
)
cost
=
paddle
.
reduce_sum
(
target
*
input
,
dim
=-
1
)
else
:
# softmax_out = F.softmax(input)
cost
=
F
.
cross_entropy
(
input
=
input
,
label
=
target
)
avg_cost
=
paddle
.
mean
(
cost
)
return
avg_cost
...
...
tools/eval.py
浏览文件 @
a4753fc3
...
...
@@ -63,10 +63,9 @@ def main(args):
net
=
program
.
create_model
(
config
.
ARCHITECTURE
,
config
.
classes_num
)
net
=
paddle
.
DataParallel
(
net
,
strategy
)
init_model
(
config
,
net
,
optimizer
=
None
)
valid_dataloader
=
program
.
create_dataloader
()
valid_reader
=
Reader
(
config
,
'valid'
)()
valid_dataloader
.
set_sample_list_generator
(
valid_reader
,
place
)
valid_dataloader
=
Reader
(
config
,
'valid'
,
places
=
place
)()
net
.
eval
()
top1_acc
=
program
.
run
(
valid_dataloader
,
config
,
net
,
None
,
None
,
0
,
'valid'
)
...
...
tools/train.py
浏览文件 @
a4753fc3
...
...
@@ -23,7 +23,6 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
)))
import
paddle
from
paddle.distributed
import
ParallelEnv
...
...
@@ -33,6 +32,7 @@ from ppcls.utils.save_load import init_model, save_model
from
ppcls.utils
import
logger
import
program
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
"PaddleClas train script"
)
parser
.
add_argument
(
...
...
@@ -78,16 +78,13 @@ def main(args):
# load model from checkpoint or pretrained model
init_model
(
config
,
net
,
optimizer
)
train_dataloader
=
program
.
create_dataloader
()
train_reader
=
Reader
(
config
,
'train'
)()
train_dataloader
.
set_sample_list_generator
(
train_reader
,
place
)
train_dataloader
=
Reader
(
config
,
'train'
,
places
=
place
)()
if
config
.
validate
:
valid_dataloader
=
program
.
create_dataloader
()
valid_reader
=
Reader
(
config
,
'valid'
)()
valid_dataloader
.
set_sample_list_generator
(
valid_reader
,
place
)
if
config
.
validate
and
ParallelEnv
().
local_rank
==
0
:
valid_dataloader
=
Reader
(
config
,
'valid'
,
places
=
place
)()
best_top1_acc
=
0.0
# best top1 acc record
best_top1_epoch
=
0
for
epoch_id
in
range
(
config
.
epochs
):
net
.
train
()
# 1. train with train dataset
...
...
@@ -98,18 +95,18 @@ def main(args):
# 2. validate with validate dataset
if
config
.
validate
and
epoch_id
%
config
.
valid_interval
==
0
:
net
.
eval
()
top1_acc
=
program
.
run
(
valid_dataloader
,
config
,
net
,
None
,
None
,
epoch_id
,
'valid'
)
top1_acc
=
program
.
run
(
valid_dataloader
,
config
,
net
,
None
,
None
,
epoch_id
,
'valid'
)
if
top1_acc
>
best_top1_acc
:
best_top1_acc
=
top1_acc
message
=
"The best top1 acc {:.5f}, in epoch: {:d}"
.
format
(
best_top1_acc
,
epoch_id
)
logger
.
info
(
"{:s}"
.
format
(
logger
.
coloring
(
message
,
"RED"
)))
best_top1_epoch
=
epoch_id
if
epoch_id
%
config
.
save_interval
==
0
:
model_path
=
os
.
path
.
join
(
config
.
model_save_dir
,
config
.
ARCHITECTURE
[
"name"
])
save_model
(
net
,
optimizer
,
model_path
,
"best_model"
)
message
=
"The best top1 acc {:.5f}, in epoch: {:d}"
.
format
(
best_top1_acc
,
best_top1_epoch
)
logger
.
info
(
"{:s}"
.
format
(
logger
.
coloring
(
message
,
"RED"
)))
# 3. save the persistable model
if
epoch_id
%
config
.
save_interval
==
0
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录