Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
c089b67b
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c089b67b
编写于
7月 31, 2019
作者:
W
walloollaw
提交者:
qingqing01
7月 31, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support custom reader in ppdet.data (#2965)
上级
3a6c1f95
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
249 addition
and
29 deletion
+249
-29
ppdet/data/data_feed.py
ppdet/data/data_feed.py
+29
-22
ppdet/data/reader.py
ppdet/data/reader.py
+21
-6
ppdet/data/source/__init__.py
ppdet/data/source/__init__.py
+4
-1
ppdet/data/source/iterator_source.py
ppdet/data/source/iterator_source.py
+103
-0
ppdet/data/tests/run_all_tests.py
ppdet/data/tests/run_all_tests.py
+2
-0
ppdet/data/tests/set_env.py
ppdet/data/tests/set_env.py
+3
-0
ppdet/data/tests/test_iterator_source.py
ppdet/data/tests/test_iterator_source.py
+60
-0
ppdet/data/tests/test_reader.py
ppdet/data/tests/test_reader.py
+27
-0
未找到文件。
ppdet/data/data_feed.py
浏览文件 @
c089b67b
...
...
@@ -42,14 +42,7 @@ __all__ = [
]
def
create_reader
(
feed
,
max_iter
=
0
,
args_path
=
None
):
"""
Return iterable data reader.
Args:
max_iter (int): number of iterations.
"""
def
_prepare_data_config
(
feed
,
args_path
):
# if `DATASET_DIR` does not exists, search ~/.paddle/dataset for a directory
# named `DATASET_DIR` (e.g., coco, pascal), if not present either, download
dataset_home
=
args_path
if
args_path
else
feed
.
dataset
.
dataset_dir
...
...
@@ -72,22 +65,36 @@ def create_reader(feed, max_iter=0, args_path=None):
if
getattr
(
feed
,
'use_process'
,
None
)
is
not
None
:
use_process
=
feed
.
use_process
mode
=
feed
.
mode
data_config
=
{
mode
:
{
'ANNO_FILE'
:
feed
.
dataset
.
annotation
,
'IMAGE_DIR'
:
feed
.
dataset
.
image_dir
,
'USE_DEFAULT_LABEL'
:
feed
.
dataset
.
use_default_label
,
'IS_SHUFFLE'
:
feed
.
shuffle
,
'SAMPLES'
:
feed
.
samples
,
'WITH_BACKGROUND'
:
feed
.
with_background
,
'MIXUP_EPOCH'
:
mixup_epoch
,
'TYPE'
:
type
(
feed
.
dataset
).
__source__
}
'ANNO_FILE'
:
feed
.
dataset
.
annotation
,
'IMAGE_DIR'
:
feed
.
dataset
.
image_dir
,
'USE_DEFAULT_LABEL'
:
feed
.
dataset
.
use_default_label
,
'IS_SHUFFLE'
:
feed
.
shuffle
,
'SAMPLES'
:
feed
.
samples
,
'WITH_BACKGROUND'
:
feed
.
with_background
,
'MIXUP_EPOCH'
:
mixup_epoch
,
'TYPE'
:
type
(
feed
.
dataset
).
__source__
}
if
len
(
getattr
(
feed
.
dataset
,
'images'
,
[]))
>
0
:
data_config
[
mode
][
'IMAGES'
]
=
feed
.
dataset
.
images
data_config
[
'IMAGES'
]
=
feed
.
dataset
.
images
return
data_config
def
create_reader
(
feed
,
max_iter
=
0
,
args_path
=
None
,
my_source
=
None
):
"""
Return iterable data reader.
Args:
max_iter (int): number of iterations.
my_source (callable): callable function to create a source iterator
which is used to provide source data in 'ppdet.data.reader'
"""
# if `DATASET_DIR` does not exists, search ~/.paddle/dataset for a directory
# named `DATASET_DIR` (e.g., coco, pascal), if not present either, download
data_config
=
_prepare_data_config
(
feed
,
args_path
)
transform_config
=
{
'WORKER_CONF'
:
{
...
...
@@ -130,8 +137,8 @@ def create_reader(feed, max_iter=0, args_path=None):
ops
.
append
(
op_dict
)
transform_config
[
'OPS'
]
=
ops
re
ader
=
Reader
(
data_config
,
{
mode
:
transform_config
},
max_iter
)
return
reader
.
_make_reader
(
mod
e
)
re
turn
Reader
.
create
(
feed
.
mode
,
data_config
,
transform_config
,
max_iter
,
my_sourc
e
)
# XXX batch transforms are only stubs for now, actually handled by `post_map`
...
...
ppdet/data/reader.py
浏览文件 @
c089b67b
...
...
@@ -40,14 +40,17 @@ class Reader(object):
self
.
_cname2cid
=
None
assert
isinstance
(
self
.
_maxiter
,
Integral
),
"maxiter should be int"
def
_make_reader
(
self
,
mode
):
def
_make_reader
(
self
,
mode
,
my_source
=
None
):
"""Build reader for training or validation"""
file_conf
=
self
.
_data_cf
[
mode
]
if
my_source
is
None
:
file_conf
=
self
.
_data_cf
[
mode
]
# 1, Build data source
# 1, Build data source
sc_conf
=
{
'data_cf'
:
file_conf
,
'cname2cid'
:
self
.
_cname2cid
}
sc
=
build_source
(
sc_conf
)
sc_conf
=
{
'data_cf'
:
file_conf
,
'cname2cid'
:
self
.
_cname2cid
}
sc
=
build_source
(
sc_conf
)
else
:
sc
=
my_source
# 2, Buid a transformed dataset
ops
=
self
.
_trans_conf
[
mode
][
'OPS'
]
...
...
@@ -87,7 +90,7 @@ class Reader(object):
if
mode
.
lower
()
==
'train'
:
if
self
.
_cname2cid
is
not
None
:
logger
.
warn
(
'cname2cid already set, it will be overridden'
)
self
.
_cname2cid
=
sc
.
cname2cid
self
.
_cname2cid
=
getattr
(
sc
,
'cname2cid'
,
None
)
# 3, Build a reader
maxit
=
-
1
if
self
.
_maxiter
<=
0
else
self
.
_maxiter
...
...
@@ -120,3 +123,15 @@ class Reader(object):
def
test
(
self
):
"""Build reader for inference"""
return
self
.
_make_reader
(
'TEST'
)
@
classmethod
def
create
(
cls
,
mode
,
data_config
,
transform_config
,
max_iter
=-
1
,
my_source
=
None
,
ret_iter
=
True
):
""" create a specific reader """
reader
=
Reader
({
mode
:
data_config
},
{
mode
:
transform_config
},
max_iter
)
if
ret_iter
:
return
reader
.
_make_reader
(
mode
,
my_source
)
else
:
return
reader
ppdet/data/source/__init__.py
浏览文件 @
c089b67b
...
...
@@ -20,6 +20,7 @@ import copy
from
.roidb_source
import
RoiDbSource
from
.simple_source
import
SimpleSource
from
.iterator_source
import
IteratorSource
def
build_source
(
config
):
...
...
@@ -40,11 +41,13 @@ def build_source(config):
}
"""
if
'data_cf'
in
config
:
data_cf
=
{
k
.
lower
():
v
for
k
,
v
in
config
[
'data_cf'
].
items
()}
data_cf
=
config
[
'data_cf'
]
data_cf
[
'cname2cid'
]
=
config
[
'cname2cid'
]
else
:
data_cf
=
config
data_cf
=
{
k
.
lower
():
v
for
k
,
v
in
data_cf
.
items
()}
args
=
copy
.
deepcopy
(
data_cf
)
# defaut type is 'RoiDbSource'
source_type
=
'RoiDbSource'
...
...
ppdet/data/source/iterator_source.py
0 → 100644
浏览文件 @
c089b67b
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
numpy
as
np
import
copy
import
logging
logger
=
logging
.
getLogger
(
__name__
)
from
..dataset
import
Dataset
class
IteratorSource
(
Dataset
):
"""
Load data samples from iterator in stream mode
Args:
iter_maker (callable): callable function to generate a iter
samples (int): number of samples to load, -1 means all
"""
def
__init__
(
self
,
iter_maker
,
samples
=-
1
,
**
kwargs
):
super
(
IteratorSource
,
self
).
__init__
()
self
.
_epoch
=
-
1
self
.
_iter_maker
=
iter_maker
self
.
_data_iter
=
None
self
.
_pos
=
-
1
self
.
_drained
=
False
self
.
_samples
=
samples
self
.
_sample_num
=
-
1
def
next
(
self
):
if
self
.
_epoch
<
0
:
self
.
reset
()
if
self
.
_data_iter
is
not
None
:
try
:
sample
=
next
(
self
.
_data_iter
)
self
.
_pos
+=
1
ret
=
sample
except
StopIteration
as
e
:
if
self
.
_sample_num
<=
0
:
self
.
_sample_num
=
self
.
_pos
elif
self
.
_sample_num
!=
self
.
_pos
:
logger
.
info
(
'num of loaded samples is different '
'with previouse setting[prev:%d,now:%d]'
%
(
self
.
_sample_num
,
self
.
_pos
))
self
.
_sample_num
=
self
.
_pos
self
.
_data_iter
=
None
self
.
_drained
=
True
raise
e
else
:
raise
StopIteration
(
"no more data in "
+
str
(
self
))
if
self
.
_samples
>
0
and
self
.
_pos
>=
self
.
_samples
:
self
.
_data_iter
=
None
self
.
_drained
=
True
raise
StopIteration
(
"no more data in "
+
str
(
self
))
else
:
return
ret
def
reset
(
self
):
if
self
.
_data_iter
is
None
:
self
.
_data_iter
=
self
.
_iter_maker
()
if
self
.
_epoch
<
0
:
self
.
_epoch
=
0
else
:
self
.
_epoch
+=
1
self
.
_pos
=
0
self
.
_drained
=
False
def
size
(
self
):
return
self
.
_sample_num
def
drained
(
self
):
assert
self
.
_epoch
>=
0
,
"the first epoch has not started yet"
return
self
.
_pos
>=
self
.
size
()
def
epoch_id
(
self
):
return
self
.
_epoch
ppdet/data/tests/run_all_tests.py
浏览文件 @
c089b67b
...
...
@@ -7,6 +7,7 @@ import unittest
import
test_loader
import
test_operator
import
test_roidb_source
import
test_iterator_source
import
test_transformer
import
test_reader
...
...
@@ -17,6 +18,7 @@ if __name__ == '__main__':
test_loader
.
TestLoader
,
test_operator
.
TestBase
,
test_roidb_source
.
TestRoiDbSource
,
test_iterator_source
.
TestIteratorSource
,
test_transformer
.
TestTransformer
,
test_reader
.
TestReader
,
]
...
...
ppdet/data/tests/set_env.py
浏览文件 @
c089b67b
...
...
@@ -3,6 +3,9 @@ import os
import
six
import
logging
import
matplotlib
matplotlib
.
use
(
'Agg'
,
force
=
False
)
prefix
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
#coco data for testing
...
...
ppdet/data/tests/test_iterator_source.py
0 → 100644
浏览文件 @
c089b67b
import
os
import
time
import
unittest
import
sys
import
logging
import
set_env
from
ppdet.data.source
import
IteratorSource
def
_generate_iter_maker
(
num
=
10
):
def
_reader
():
for
i
in
range
(
num
):
yield
{
'image'
:
'image_'
+
str
(
i
),
'label'
:
i
}
return
_reader
class
TestIteratorSource
(
unittest
.
TestCase
):
"""Test cases for dataset.source.roidb_source
"""
@
classmethod
def
setUpClass
(
cls
):
""" setup
"""
pass
@
classmethod
def
tearDownClass
(
cls
):
""" tearDownClass """
pass
def
test_basic
(
self
):
""" test basic apis 'next/size/drained'
"""
iter_maker
=
_generate_iter_maker
()
iter_source
=
IteratorSource
(
iter_maker
)
for
i
,
sample
in
enumerate
(
iter_source
):
self
.
assertTrue
(
'image'
in
sample
)
self
.
assertGreater
(
len
(
sample
[
'image'
]),
0
)
self
.
assertTrue
(
iter_source
.
drained
())
self
.
assertEqual
(
i
+
1
,
iter_source
.
size
())
def
test_reset
(
self
):
""" test functions 'reset/epoch_id'
"""
iter_maker
=
_generate_iter_maker
()
iter_source
=
IteratorSource
(
iter_maker
)
self
.
assertTrue
(
iter_source
.
next
()
is
not
None
)
self
.
assertEqual
(
iter_source
.
epoch_id
(),
0
)
iter_source
.
reset
()
self
.
assertEqual
(
iter_source
.
epoch_id
(),
1
)
self
.
assertTrue
(
iter_source
.
next
()
is
not
None
)
if
__name__
==
'__main__'
:
unittest
.
main
()
ppdet/data/tests/test_reader.py
浏览文件 @
c089b67b
...
...
@@ -8,6 +8,8 @@ import yaml
import
set_env
from
ppdet.data.reader
import
Reader
from
ppdet.data.source
import
build_source
from
ppdet.data.source
import
IteratorSource
class
TestReader
(
unittest
.
TestCase
):
...
...
@@ -114,6 +116,31 @@ class TestReader(unittest.TestCase):
self
.
assertEqual
(
out
[
0
][
5
].
shape
[
1
],
1
)
self
.
assertGreaterEqual
(
ct
,
rcnn
.
_maxiter
)
def
test_create
(
self
):
""" Test create a reader using my source
"""
def
_my_data_reader
():
mydata
=
build_source
(
self
.
rcnn_conf
[
'DATA'
][
'TRAIN'
])
for
i
,
sample
in
enumerate
(
mydata
):
yield
sample
my_source
=
IteratorSource
(
_my_data_reader
)
mode
=
'TRAIN'
train_rd
=
Reader
.
create
(
mode
,
self
.
rcnn_conf
[
'DATA'
][
mode
],
self
.
rcnn_conf
[
'TRANSFORM'
][
mode
],
max_iter
=
10
,
my_source
=
my_source
)
out
=
None
for
sample
in
train_rd
():
out
=
sample
self
.
assertTrue
(
sample
is
not
None
)
self
.
assertEqual
(
out
[
0
][
0
].
shape
[
0
],
3
)
self
.
assertEqual
(
out
[
0
][
1
].
shape
[
0
],
3
)
self
.
assertEqual
(
out
[
0
][
3
].
shape
[
1
],
4
)
self
.
assertEqual
(
out
[
0
][
4
].
shape
[
1
],
1
)
self
.
assertEqual
(
out
[
0
][
5
].
shape
[
1
],
1
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录