s920243400 / PaddleDetection (forked from PaddlePaddle / PaddleDetection)
Commit 2ee6dd97
Authored on Oct 14, 2019 by Yuan Gao; committed by qingqing01 on Oct 14, 2019

add class aware sampling strategy (#3104)

* add class aware sampling strategy
* remove redundancy code

Parent: e968c137
Showing 5 changed files with 159 additions and 8 deletions (+159 -8)
ppdet/data/data_feed.py                                   +15  -5
ppdet/data/source/__init__.py                              +9  -1
ppdet/data/source/class_aware_sampling_roidb_source.py   +132  -0
ppdet/data/source/coco_loader.py                           +2  -1
ppdet/utils/coco_eval.py                                   +1  -1
ppdet/data/data_feed.py

@@ -70,6 +70,10 @@ def _prepare_data_config(feed, args_path):
         'TYPE': type(feed.dataset).__source__
     }
 
+    if feed.mode == 'TRAIN':
+        data_config['CLASS_AWARE_SAMPLING'] = getattr(
+            feed, 'class_aware_sampling', False)
+
     if len(getattr(feed.dataset, 'images', [])) > 0:
         data_config['IMAGES'] = feed.dataset.images

@@ -301,7 +305,8 @@ class DataFeed(object):
                  bufsize=10,
                  use_process=False,
                  memsize=None,
-                 use_padded_im_info=False):
+                 use_padded_im_info=False,
+                 class_aware_sampling=False):
         super(DataFeed, self).__init__()
         self.fields = fields
         self.image_shape = image_shape

@@ -318,6 +323,7 @@ class DataFeed(object):
         self.memsize = memsize
         self.dataset = dataset
         self.use_padded_im_info = use_padded_im_info
+        self.class_aware_sampling = class_aware_sampling
         if isinstance(dataset, dict):
             self.dataset = DataSet(**dataset)

@@ -447,7 +453,8 @@ class FasterRCNNTrainFeed(DataFeed):
                  bufsize=10,
                  num_workers=2,
                  use_process=False,
-                 memsize=None):
+                 memsize=None,
+                 class_aware_sampling=False):
         # XXX this should be handled by the data loader, since `fields` is
         # given, just collect them
         sample_transforms.append(ArrangeRCNN())

@@ -464,7 +471,8 @@ class FasterRCNNTrainFeed(DataFeed):
             bufsize=bufsize,
             num_workers=num_workers,
             use_process=use_process,
-            memsize=memsize)
+            memsize=memsize,
+            class_aware_sampling=class_aware_sampling)
         # XXX these modes should be unified
         self.mode = 'TRAIN'

@@ -891,7 +899,8 @@ class YoloTrainFeed(DataFeed):
                  use_process=True,
                  memsize=None,
                  num_max_boxes=50,
-                 mixup_epoch=250):
+                 mixup_epoch=250,
+                 class_aware_sampling=False):
         sample_transforms.append(ArrangeYOLO())
         super(YoloTrainFeed, self).__init__(
             dataset,

@@ -907,7 +916,8 @@ class YoloTrainFeed(DataFeed):
             num_workers=num_workers,
             bufsize=bufsize,
             use_process=use_process,
-            memsize=memsize)
+            memsize=memsize,
+            class_aware_sampling=class_aware_sampling)
         self.num_max_boxes = num_max_boxes
         self.mixup_epoch = mixup_epoch
         self.mode = 'TRAIN'
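
To see how the new flag travels from the feed into the data config, here is a minimal, self-contained sketch of the TRAIN-only branch added to _prepare_data_config above. _FakeFeed and _collect are stand-ins invented for illustration, not classes or functions from the repo.

# Illustrative stand-ins only; the real feed objects live in ppdet/data/data_feed.py.
class _FakeFeed(object):
    mode = 'TRAIN'
    class_aware_sampling = True      # the new constructor argument, default False

def _collect(feed):
    data_config = {}
    if feed.mode == 'TRAIN':
        # mirrors the added branch: read the flag off the feed, defaulting to False
        data_config['CLASS_AWARE_SAMPLING'] = getattr(
            feed, 'class_aware_sampling', False)
    return data_config

print(_collect(_FakeFeed()))         # {'CLASS_AWARE_SAMPLING': True}

In this commit only the DataFeed base class plus the FasterRCNNTrainFeed and YoloTrainFeed constructors gain the argument, and _prepare_data_config forwards it only when feed.mode == 'TRAIN'.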
ppdet/data/source/__init__.py

@@ -21,6 +21,7 @@ import copy
 from .roidb_source import RoiDbSource
 from .simple_source import SimpleSource
 from .iterator_source import IteratorSource
+from .class_aware_sampling_roidb_source import ClassAwareSamplingRoiDbSource
 
 
 def build_source(config):

@@ -53,7 +54,12 @@ def build_source(config):
     source_type = 'RoiDbSource'
     if 'type' in data_cf:
         if data_cf['type'] in ['VOCSource', 'COCOSource', 'RoiDbSource']:
-            source_type = 'RoiDbSource'
+            if 'class_aware_sampling' in args and args['class_aware_sampling']:
+                source_type = 'ClassAwareSamplingRoiDbSource'
+            else:
+                source_type = 'RoiDbSource'
+            if 'class_aware_sampling' in args:
+                del args['class_aware_sampling']
         else:
             source_type = data_cf['type']
         del args['type']

@@ -61,5 +67,7 @@ def build_source(config):
         return RoiDbSource(**args)
     elif source_type == 'SimpleSource':
         return SimpleSource(**args)
+    elif source_type == 'ClassAwareSamplingRoiDbSource':
+        return ClassAwareSamplingRoiDbSource(**args)
     else:
         raise ValueError('source type not supported: ' + source_type)
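
A condensed sketch of the dispatch that build_source now performs. _pick_source_type is a hypothetical helper written for illustration only (it is not a function in the repo); data_cf and args stand for the parsed data-config section and the keyword arguments extracted from it.

def _pick_source_type(data_cf, args):
    # roidb-style sources may be swapped for the class-aware variant
    if data_cf.get('type') in ['VOCSource', 'COCOSource', 'RoiDbSource']:
        if args.get('class_aware_sampling'):
            source_type = 'ClassAwareSamplingRoiDbSource'
        else:
            source_type = 'RoiDbSource'
        # the flag is config plumbing only; the source constructors do not accept it
        args.pop('class_aware_sampling', None)
    else:
        source_type = data_cf.get('type', 'RoiDbSource')
    return source_type

print(_pick_source_type({'type': 'COCOSource'}, {'class_aware_sampling': True}))
# -> ClassAwareSamplingRoiDbSource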
ppdet/data/source/class_aware_sampling_roidb_source.py (new file, mode 100644)

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#function:
#   interface to load data from local files and parse it for samples,
#   eg: roidb data in pickled files

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import random
import copy
import collections
import pickle as pkl
import numpy as np

from .roidb_source import RoiDbSource


class ClassAwareSamplingRoiDbSource(RoiDbSource):
    """ interface to load class aware sampling roidb data from files
    """

    def __init__(self,
                 anno_file,
                 image_dir=None,
                 samples=-1,
                 is_shuffle=True,
                 load_img=False,
                 cname2cid=None,
                 use_default_label=None,
                 mixup_epoch=-1,
                 with_background=True):
        """ Init

        Args:
            fname (str): label file path
            image_dir (str): root dir for images
            samples (int): samples to load, -1 means all
            is_shuffle (bool): whether to shuffle samples
            load_img (bool): whether load data in this class
            cname2cid (dict): the label name to id dictionary
            use_default_label (bool): whether use the default mapping of label to id
            mixup_epoch (int): parse mixup in first n epoch
            with_background (bool): whether load background as a class
        """
        super(ClassAwareSamplingRoiDbSource, self).__init__(
            anno_file=anno_file,
            image_dir=image_dir,
            samples=samples,
            is_shuffle=is_shuffle,
            load_img=load_img,
            cname2cid=cname2cid,
            use_default_label=use_default_label,
            mixup_epoch=mixup_epoch,
            with_background=with_background)
        self._img_weights = None

    def __str__(self):
        return 'ClassAwareSamplingRoidbSource(fname:%s,epoch:%d,size:%d)' \
            % (self._fname, self._epoch, self.size())

    def next(self):
        """ load next sample
        """
        if self._epoch < 0:
            self.reset()

        _pos = np.random.choice(
            self._samples, 1, replace=False, p=self._img_weights)[0]
        sample = copy.deepcopy(self._roidb[_pos])

        if self._load_img:
            sample['image'] = self._load_image(sample['im_file'])
        else:
            sample['im_file'] = os.path.join(self._image_dir,
                                             sample['im_file'])

        return sample

    def _calc_img_weights(self):
        """ calculate the probabilities of each sample
        """
        imgs_cls = []
        num_per_cls = {}
        img_weights = []
        for i, roidb in enumerate(self._roidb):
            img_cls = set(
                [k for cls in self._roidb[i]['gt_class'] for k in cls])
            imgs_cls.append(img_cls)
            for c in img_cls:
                if c not in num_per_cls:
                    num_per_cls[c] = 1
                else:
                    num_per_cls[c] += 1

        for i in range(len(self._roidb)):
            weights = 0
            for c in imgs_cls[i]:
                weights += 1 / num_per_cls[c]
            img_weights.append(weights)
        # Probabilities sum to 1
        img_weights = img_weights / np.sum(img_weights)

        return img_weights

    def reset(self):
        """ implementation of Dataset.reset
        """
        if self._roidb is None:
            self._roidb = self._load()

        if self._img_weights is None:
            self._img_weights = self._calc_img_weights()

        self._samples = len(self._roidb)

        if self._epoch < 0:
            self._epoch = 0
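
The weighting rule in _calc_img_weights is easiest to see on toy data: an image's weight is the sum of 1/frequency over the classes it contains, normalized into the probability vector that next() passes to np.random.choice. Below is a standalone rerun of that logic; the three gt_class arrays are made up for illustration.

import numpy as np

# hypothetical per-image gt_class arrays (one class id per box, as in the roidb)
roidb = [
    {'gt_class': np.array([[0], [0]])},   # image 0: common class only
    {'gt_class': np.array([[0], [1]])},   # image 1: common class + rare class
    {'gt_class': np.array([[0]])},        # image 2: common class only
]

# count how many images contain each class
imgs_cls, num_per_cls = [], {}
for rec in roidb:
    cls_set = set(int(k) for row in rec['gt_class'] for k in row)
    imgs_cls.append(cls_set)
    for c in cls_set:
        num_per_cls[c] = num_per_cls.get(c, 0) + 1

# weight = sum of inverse class frequencies, then normalize to probabilities
weights = np.array([sum(1.0 / num_per_cls[c] for c in s) for s in imgs_cls])
weights /= weights.sum()
print(weights)                        # approximately [0.167, 0.667, 0.167]

# next() draws indices with these probabilities, so the image holding the rare
# class is sampled about four times as often as either common-only image
picks = np.random.choice(len(roidb), size=6000, p=weights)
print(np.bincount(picks) / 6000.0)    # roughly matches `weights`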
ppdet/data/source/coco_loader.py

@@ -101,7 +101,8 @@ def load(anno_path, sample_num=-1, with_background=True):
             gt_class[i][0] = catid2clsid[catid]
             gt_bbox[i, :] = box['clean_bbox']
             is_crowd[i][0] = box['iscrowd']
-            gt_poly[i] = box['segmentation']
+            if 'segmentation' in box:
+                gt_poly[i] = box['segmentation']
 
         coco_rec = {
             'im_file': im_fname,
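
The one-line guard matters for COCO-format annotation files that carry boxes but no masks, which is common for custom detection datasets. A tiny sketch with a made-up box record (hypothetical values, not from the repo):

# hypothetical bbox-only annotation entry: no 'segmentation' key
box = {'clean_bbox': [10.0, 20.0, 30.0, 40.0], 'iscrowd': 0}
gt_poly = [None]

# before this commit, gt_poly[0] = box['segmentation'] would raise KeyError here
if 'segmentation' in box:             # guarded assignment added by the commit
    gt_poly[0] = box['segmentation']

print(gt_poly)                        # [None]; loading proceeds without masks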
ppdet/utils/coco_eval.py

@@ -213,7 +213,7 @@ def bbox2out(results, clsid2catid, is_bbox_normalized=False):
             for j in range(num):
                 dt = bboxes[k]
                 clsid, score, xmin, ymin, xmax, ymax = dt.tolist()
-                catid = clsid2catid[clsid]
+                catid = (clsid2catid[int(clsid)])
 
                 if is_bbox_normalized:
                     xmin, ymin, xmax, ymax = \