Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
c089b67b
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c089b67b
编写于
7月 31, 2019
作者:
W
walloollaw
提交者:
qingqing01
7月 31, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support custom reader in ppdet.data (#2965)
上级
3a6c1f95
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
249 addition
and
29 deletion
+249
-29
ppdet/data/data_feed.py
ppdet/data/data_feed.py
+29
-22
ppdet/data/reader.py
ppdet/data/reader.py
+21
-6
ppdet/data/source/__init__.py
ppdet/data/source/__init__.py
+4
-1
ppdet/data/source/iterator_source.py
ppdet/data/source/iterator_source.py
+103
-0
ppdet/data/tests/run_all_tests.py
ppdet/data/tests/run_all_tests.py
+2
-0
ppdet/data/tests/set_env.py
ppdet/data/tests/set_env.py
+3
-0
ppdet/data/tests/test_iterator_source.py
ppdet/data/tests/test_iterator_source.py
+60
-0
ppdet/data/tests/test_reader.py
ppdet/data/tests/test_reader.py
+27
-0
未找到文件。
ppdet/data/data_feed.py
浏览文件 @
c089b67b
...
...
@@ -42,14 +42,7 @@ __all__ = [
]
def
create_reader
(
feed
,
max_iter
=
0
,
args_path
=
None
):
"""
Return iterable data reader.
Args:
max_iter (int): number of iterations.
"""
def
_prepare_data_config
(
feed
,
args_path
):
# if `DATASET_DIR` does not exists, search ~/.paddle/dataset for a directory
# named `DATASET_DIR` (e.g., coco, pascal), if not present either, download
dataset_home
=
args_path
if
args_path
else
feed
.
dataset
.
dataset_dir
...
...
@@ -72,9 +65,7 @@ def create_reader(feed, max_iter=0, args_path=None):
if
getattr
(
feed
,
'use_process'
,
None
)
is
not
None
:
use_process
=
feed
.
use_process
mode
=
feed
.
mode
data_config
=
{
mode
:
{
'ANNO_FILE'
:
feed
.
dataset
.
annotation
,
'IMAGE_DIR'
:
feed
.
dataset
.
image_dir
,
'USE_DEFAULT_LABEL'
:
feed
.
dataset
.
use_default_label
,
...
...
@@ -84,10 +75,26 @@ def create_reader(feed, max_iter=0, args_path=None):
'MIXUP_EPOCH'
:
mixup_epoch
,
'TYPE'
:
type
(
feed
.
dataset
).
__source__
}
}
if
len
(
getattr
(
feed
.
dataset
,
'images'
,
[]))
>
0
:
data_config
[
mode
][
'IMAGES'
]
=
feed
.
dataset
.
images
data_config
[
'IMAGES'
]
=
feed
.
dataset
.
images
return
data_config
def
create_reader
(
feed
,
max_iter
=
0
,
args_path
=
None
,
my_source
=
None
):
"""
Return iterable data reader.
Args:
max_iter (int): number of iterations.
my_source (callable): callable function to create a source iterator
which is used to provide source data in 'ppdet.data.reader'
"""
# if `DATASET_DIR` does not exists, search ~/.paddle/dataset for a directory
# named `DATASET_DIR` (e.g., coco, pascal), if not present either, download
data_config
=
_prepare_data_config
(
feed
,
args_path
)
transform_config
=
{
'WORKER_CONF'
:
{
...
...
@@ -130,8 +137,8 @@ def create_reader(feed, max_iter=0, args_path=None):
ops
.
append
(
op_dict
)
transform_config
[
'OPS'
]
=
ops
re
ader
=
Reader
(
data_config
,
{
mode
:
transform_config
},
max_iter
)
return
reader
.
_make_reader
(
mod
e
)
re
turn
Reader
.
create
(
feed
.
mode
,
data_config
,
transform_config
,
max_iter
,
my_sourc
e
)
# XXX batch transforms are only stubs for now, actually handled by `post_map`
...
...
ppdet/data/reader.py
浏览文件 @
c089b67b
...
...
@@ -40,14 +40,17 @@ class Reader(object):
self
.
_cname2cid
=
None
assert
isinstance
(
self
.
_maxiter
,
Integral
),
"maxiter should be int"
def
_make_reader
(
self
,
mode
):
def
_make_reader
(
self
,
mode
,
my_source
=
None
):
"""Build reader for training or validation"""
if
my_source
is
None
:
file_conf
=
self
.
_data_cf
[
mode
]
# 1, Build data source
sc_conf
=
{
'data_cf'
:
file_conf
,
'cname2cid'
:
self
.
_cname2cid
}
sc
=
build_source
(
sc_conf
)
else
:
sc
=
my_source
# 2, Buid a transformed dataset
ops
=
self
.
_trans_conf
[
mode
][
'OPS'
]
...
...
@@ -87,7 +90,7 @@ class Reader(object):
if
mode
.
lower
()
==
'train'
:
if
self
.
_cname2cid
is
not
None
:
logger
.
warn
(
'cname2cid already set, it will be overridden'
)
self
.
_cname2cid
=
sc
.
cname2cid
self
.
_cname2cid
=
getattr
(
sc
,
'cname2cid'
,
None
)
# 3, Build a reader
maxit
=
-
1
if
self
.
_maxiter
<=
0
else
self
.
_maxiter
...
...
@@ -120,3 +123,15 @@ class Reader(object):
def
test
(
self
):
"""Build reader for inference"""
return
self
.
_make_reader
(
'TEST'
)
@
classmethod
def
create
(
cls
,
mode
,
data_config
,
transform_config
,
max_iter
=-
1
,
my_source
=
None
,
ret_iter
=
True
):
""" create a specific reader """
reader
=
Reader
({
mode
:
data_config
},
{
mode
:
transform_config
},
max_iter
)
if
ret_iter
:
return
reader
.
_make_reader
(
mode
,
my_source
)
else
:
return
reader
ppdet/data/source/__init__.py
浏览文件 @
c089b67b
...
...
@@ -20,6 +20,7 @@ import copy
from
.roidb_source
import
RoiDbSource
from
.simple_source
import
SimpleSource
from
.iterator_source
import
IteratorSource
def
build_source
(
config
):
...
...
@@ -40,11 +41,13 @@ def build_source(config):
}
"""
if
'data_cf'
in
config
:
data_cf
=
{
k
.
lower
():
v
for
k
,
v
in
config
[
'data_cf'
].
items
()}
data_cf
=
config
[
'data_cf'
]
data_cf
[
'cname2cid'
]
=
config
[
'cname2cid'
]
else
:
data_cf
=
config
data_cf
=
{
k
.
lower
():
v
for
k
,
v
in
data_cf
.
items
()}
args
=
copy
.
deepcopy
(
data_cf
)
# defaut type is 'RoiDbSource'
source_type
=
'RoiDbSource'
...
...
ppdet/data/source/iterator_source.py
0 → 100644
浏览文件 @
c089b67b
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
numpy
as
np
import
copy
import
logging
logger
=
logging
.
getLogger
(
__name__
)
from
..dataset
import
Dataset
class
IteratorSource
(
Dataset
):
"""
Load data samples from iterator in stream mode
Args:
iter_maker (callable): callable function to generate a iter
samples (int): number of samples to load, -1 means all
"""
def
__init__
(
self
,
iter_maker
,
samples
=-
1
,
**
kwargs
):
super
(
IteratorSource
,
self
).
__init__
()
self
.
_epoch
=
-
1
self
.
_iter_maker
=
iter_maker
self
.
_data_iter
=
None
self
.
_pos
=
-
1
self
.
_drained
=
False
self
.
_samples
=
samples
self
.
_sample_num
=
-
1
def
next
(
self
):
if
self
.
_epoch
<
0
:
self
.
reset
()
if
self
.
_data_iter
is
not
None
:
try
:
sample
=
next
(
self
.
_data_iter
)
self
.
_pos
+=
1
ret
=
sample
except
StopIteration
as
e
:
if
self
.
_sample_num
<=
0
:
self
.
_sample_num
=
self
.
_pos
elif
self
.
_sample_num
!=
self
.
_pos
:
logger
.
info
(
'num of loaded samples is different '
'with previouse setting[prev:%d,now:%d]'
%
(
self
.
_sample_num
,
self
.
_pos
))
self
.
_sample_num
=
self
.
_pos
self
.
_data_iter
=
None
self
.
_drained
=
True
raise
e
else
:
raise
StopIteration
(
"no more data in "
+
str
(
self
))
if
self
.
_samples
>
0
and
self
.
_pos
>=
self
.
_samples
:
self
.
_data_iter
=
None
self
.
_drained
=
True
raise
StopIteration
(
"no more data in "
+
str
(
self
))
else
:
return
ret
def
reset
(
self
):
if
self
.
_data_iter
is
None
:
self
.
_data_iter
=
self
.
_iter_maker
()
if
self
.
_epoch
<
0
:
self
.
_epoch
=
0
else
:
self
.
_epoch
+=
1
self
.
_pos
=
0
self
.
_drained
=
False
def
size
(
self
):
return
self
.
_sample_num
def
drained
(
self
):
assert
self
.
_epoch
>=
0
,
"the first epoch has not started yet"
return
self
.
_pos
>=
self
.
size
()
def
epoch_id
(
self
):
return
self
.
_epoch
ppdet/data/tests/run_all_tests.py
浏览文件 @
c089b67b
...
...
@@ -7,6 +7,7 @@ import unittest
import
test_loader
import
test_operator
import
test_roidb_source
import
test_iterator_source
import
test_transformer
import
test_reader
...
...
@@ -17,6 +18,7 @@ if __name__ == '__main__':
test_loader
.
TestLoader
,
test_operator
.
TestBase
,
test_roidb_source
.
TestRoiDbSource
,
test_iterator_source
.
TestIteratorSource
,
test_transformer
.
TestTransformer
,
test_reader
.
TestReader
,
]
...
...
ppdet/data/tests/set_env.py
浏览文件 @
c089b67b
...
...
@@ -3,6 +3,9 @@ import os
import
six
import
logging
import
matplotlib
matplotlib
.
use
(
'Agg'
,
force
=
False
)
prefix
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
#coco data for testing
...
...
ppdet/data/tests/test_iterator_source.py
0 → 100644
浏览文件 @
c089b67b
import
os
import
time
import
unittest
import
sys
import
logging
import
set_env
from
ppdet.data.source
import
IteratorSource
def
_generate_iter_maker
(
num
=
10
):
def
_reader
():
for
i
in
range
(
num
):
yield
{
'image'
:
'image_'
+
str
(
i
),
'label'
:
i
}
return
_reader
class
TestIteratorSource
(
unittest
.
TestCase
):
"""Test cases for dataset.source.roidb_source
"""
@
classmethod
def
setUpClass
(
cls
):
""" setup
"""
pass
@
classmethod
def
tearDownClass
(
cls
):
""" tearDownClass """
pass
def
test_basic
(
self
):
""" test basic apis 'next/size/drained'
"""
iter_maker
=
_generate_iter_maker
()
iter_source
=
IteratorSource
(
iter_maker
)
for
i
,
sample
in
enumerate
(
iter_source
):
self
.
assertTrue
(
'image'
in
sample
)
self
.
assertGreater
(
len
(
sample
[
'image'
]),
0
)
self
.
assertTrue
(
iter_source
.
drained
())
self
.
assertEqual
(
i
+
1
,
iter_source
.
size
())
def
test_reset
(
self
):
""" test functions 'reset/epoch_id'
"""
iter_maker
=
_generate_iter_maker
()
iter_source
=
IteratorSource
(
iter_maker
)
self
.
assertTrue
(
iter_source
.
next
()
is
not
None
)
self
.
assertEqual
(
iter_source
.
epoch_id
(),
0
)
iter_source
.
reset
()
self
.
assertEqual
(
iter_source
.
epoch_id
(),
1
)
self
.
assertTrue
(
iter_source
.
next
()
is
not
None
)
if
__name__
==
'__main__'
:
unittest
.
main
()
ppdet/data/tests/test_reader.py
浏览文件 @
c089b67b
...
...
@@ -8,6 +8,8 @@ import yaml
import
set_env
from
ppdet.data.reader
import
Reader
from
ppdet.data.source
import
build_source
from
ppdet.data.source
import
IteratorSource
class
TestReader
(
unittest
.
TestCase
):
...
...
@@ -114,6 +116,31 @@ class TestReader(unittest.TestCase):
self
.
assertEqual
(
out
[
0
][
5
].
shape
[
1
],
1
)
self
.
assertGreaterEqual
(
ct
,
rcnn
.
_maxiter
)
def
test_create
(
self
):
""" Test create a reader using my source
"""
def
_my_data_reader
():
mydata
=
build_source
(
self
.
rcnn_conf
[
'DATA'
][
'TRAIN'
])
for
i
,
sample
in
enumerate
(
mydata
):
yield
sample
my_source
=
IteratorSource
(
_my_data_reader
)
mode
=
'TRAIN'
train_rd
=
Reader
.
create
(
mode
,
self
.
rcnn_conf
[
'DATA'
][
mode
],
self
.
rcnn_conf
[
'TRANSFORM'
][
mode
],
max_iter
=
10
,
my_source
=
my_source
)
out
=
None
for
sample
in
train_rd
():
out
=
sample
self
.
assertTrue
(
sample
is
not
None
)
self
.
assertEqual
(
out
[
0
][
0
].
shape
[
0
],
3
)
self
.
assertEqual
(
out
[
0
][
1
].
shape
[
0
],
3
)
self
.
assertEqual
(
out
[
0
][
3
].
shape
[
1
],
4
)
self
.
assertEqual
(
out
[
0
][
4
].
shape
[
1
],
1
)
self
.
assertEqual
(
out
[
0
][
5
].
shape
[
1
],
1
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录