Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
0e6468c7
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0e6468c7
编写于
3月 25, 2021
作者:
K
Kaipeng Deng
提交者:
GitHub
3月 25, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine trainer (#2412)
* refine trainer
上级
2be546c4
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
30 addition
and
16 deletion
+30
-16
ppdet/data/reader.py
ppdet/data/reader.py
+1
-0
ppdet/data/source/__init__.py
ppdet/data/source/__init__.py
+2
-0
ppdet/data/source/category.py
ppdet/data/source/category.py
+0
-0
ppdet/data/source/coco.py
ppdet/data/source/coco.py
+5
-5
ppdet/data/source/dataset.py
ppdet/data/source/dataset.py
+4
-0
ppdet/engine/export_utils.py
ppdet/engine/export_utils.py
+1
-1
ppdet/engine/trainer.py
ppdet/engine/trainer.py
+12
-3
ppdet/metrics/__init__.py
ppdet/metrics/__init__.py
+1
-5
ppdet/metrics/metrics.py
ppdet/metrics/metrics.py
+4
-2
未找到文件。
ppdet/data/reader.py
浏览文件 @
0e6468c7
...
@@ -184,6 +184,7 @@ class BaseDataLoader(object):
...
@@ -184,6 +184,7 @@ class BaseDataLoader(object):
batch_sampler
=
None
,
batch_sampler
=
None
,
return_list
=
False
):
return_list
=
False
):
self
.
dataset
=
dataset
self
.
dataset
=
dataset
self
.
dataset
.
check_or_download_dataset
()
self
.
dataset
.
parse_dataset
()
self
.
dataset
.
parse_dataset
()
# get data
# get data
self
.
dataset
.
set_transform
(
self
.
_sample_transforms
)
self
.
dataset
.
set_transform
(
self
.
_sample_transforms
)
...
...
ppdet/data/source/__init__.py
浏览文件 @
0e6468c7
...
@@ -16,7 +16,9 @@ from . import coco
...
@@ -16,7 +16,9 @@ from . import coco
# TODO add voc and widerface dataset
# TODO add voc and widerface dataset
from
.
import
voc
from
.
import
voc
#from . import widerface
#from . import widerface
from
.
import
category
from
.coco
import
*
from
.coco
import
*
from
.voc
import
*
from
.voc
import
*
#from .widerface import *
#from .widerface import *
from
.category
import
*
ppdet/
metrics
/category.py
→
ppdet/
data/source
/category.py
浏览文件 @
0e6468c7
文件已移动
ppdet/data/source/coco.py
浏览文件 @
0e6468c7
...
@@ -49,10 +49,10 @@ class COCODataSet(DetDataset):
...
@@ -49,10 +49,10 @@ class COCODataSet(DetDataset):
records
=
[]
records
=
[]
ct
=
0
ct
=
0
catid2clsid
=
dict
({
catid
:
i
for
i
,
catid
in
enumerate
(
cat_ids
)})
self
.
catid2clsid
=
dict
({
catid
:
i
for
i
,
catid
in
enumerate
(
cat_ids
)})
cname2cid
=
dict
({
self
.
cname2cid
=
dict
({
coco
.
loadCats
(
catid
)[
0
][
'name'
]:
clsid
coco
.
loadCats
(
catid
)[
0
][
'name'
]:
clsid
for
catid
,
clsid
in
catid2clsid
.
items
()
for
catid
,
clsid
in
self
.
catid2clsid
.
items
()
})
})
if
'annotations'
not
in
coco
.
dataset
:
if
'annotations'
not
in
coco
.
dataset
:
...
@@ -119,7 +119,7 @@ class COCODataSet(DetDataset):
...
@@ -119,7 +119,7 @@ class COCODataSet(DetDataset):
has_segmentation
=
False
has_segmentation
=
False
for
i
,
box
in
enumerate
(
bboxes
):
for
i
,
box
in
enumerate
(
bboxes
):
catid
=
box
[
'category_id'
]
catid
=
box
[
'category_id'
]
gt_class
[
i
][
0
]
=
catid2clsid
[
catid
]
gt_class
[
i
][
0
]
=
self
.
catid2clsid
[
catid
]
gt_bbox
[
i
,
:]
=
box
[
'clean_bbox'
]
gt_bbox
[
i
,
:]
=
box
[
'clean_bbox'
]
is_crowd
[
i
][
0
]
=
box
[
'iscrowd'
]
is_crowd
[
i
][
0
]
=
box
[
'iscrowd'
]
# check RLE format
# check RLE format
...
@@ -163,4 +163,4 @@ class COCODataSet(DetDataset):
...
@@ -163,4 +163,4 @@ class COCODataSet(DetDataset):
break
break
assert
len
(
records
)
>
0
,
'not found any coco record in %s'
%
(
anno_path
)
assert
len
(
records
)
>
0
,
'not found any coco record in %s'
%
(
anno_path
)
logger
.
debug
(
'{} samples in file {}'
.
format
(
ct
,
anno_path
))
logger
.
debug
(
'{} samples in file {}'
.
format
(
ct
,
anno_path
))
self
.
roidbs
,
self
.
cname2cid
=
records
,
cname2cid
self
.
roidbs
=
records
ppdet/data/source/dataset.py
浏览文件 @
0e6468c7
...
@@ -67,6 +67,10 @@ class DetDataset(Dataset):
...
@@ -67,6 +67,10 @@ class DetDataset(Dataset):
return
self
.
transform
(
roidb
)
return
self
.
transform
(
roidb
)
def
check_or_download_dataset
(
self
):
self
.
dataset_dir
=
get_dataset_path
(
self
.
dataset_dir
,
self
.
anno_path
,
self
.
image_dir
)
def
set_kwargs
(
self
,
**
kwargs
):
def
set_kwargs
(
self
,
**
kwargs
):
self
.
mixup_epoch
=
kwargs
.
get
(
'mixup_epoch'
,
-
1
)
self
.
mixup_epoch
=
kwargs
.
get
(
'mixup_epoch'
,
-
1
)
self
.
cutmix_epoch
=
kwargs
.
get
(
'cutmix_epoch'
,
-
1
)
self
.
cutmix_epoch
=
kwargs
.
get
(
'cutmix_epoch'
,
-
1
)
...
...
ppdet/engine/export_utils.py
浏览文件 @
0e6468c7
...
@@ -20,7 +20,7 @@ import os
...
@@ -20,7 +20,7 @@ import os
import
yaml
import
yaml
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
ppdet.
metrics
import
get_categories
from
ppdet.
data.source.category
import
get_categories
from
ppdet.utils.logger
import
setup_logger
from
ppdet.utils.logger
import
setup_logger
logger
=
setup_logger
(
'ppdet.engine'
)
logger
=
setup_logger
(
'ppdet.engine'
)
...
...
ppdet/engine/trainer.py
浏览文件 @
0e6468c7
...
@@ -31,7 +31,8 @@ from paddle.static import InputSpec
...
@@ -31,7 +31,8 @@ from paddle.static import InputSpec
from
ppdet.core.workspace
import
create
from
ppdet.core.workspace
import
create
from
ppdet.utils.checkpoint
import
load_weight
,
load_pretrain_weight
from
ppdet.utils.checkpoint
import
load_weight
,
load_pretrain_weight
from
ppdet.utils.visualizer
import
visualize_results
from
ppdet.utils.visualizer
import
visualize_results
from
ppdet.metrics
import
Metric
,
COCOMetric
,
VOCMetric
,
WiderFaceMetric
,
get_categories
,
get_infer_results
from
ppdet.metrics
import
Metric
,
COCOMetric
,
VOCMetric
,
WiderFaceMetric
,
get_infer_results
from
ppdet.data.source.category
import
get_categories
import
ppdet.utils.stats
as
stats
import
ppdet.utils.stats
as
stats
from
.callbacks
import
Callback
,
ComposeCallback
,
LogPrinter
,
Checkpointer
,
WiferFaceEval
,
VisualDLWriter
from
.callbacks
import
Callback
,
ComposeCallback
,
LogPrinter
,
Checkpointer
,
WiferFaceEval
,
VisualDLWriter
...
@@ -116,8 +117,8 @@ class Trainer(object):
...
@@ -116,8 +117,8 @@ class Trainer(object):
self
.
_callbacks
=
[]
self
.
_callbacks
=
[]
self
.
_compose_callback
=
None
self
.
_compose_callback
=
None
def
_init_metrics
(
self
):
def
_init_metrics
(
self
,
validate
=
False
):
if
self
.
mode
==
'test'
:
if
self
.
mode
==
'test'
or
(
self
.
mode
==
'train'
and
not
validate
)
:
self
.
_metrics
=
[]
self
.
_metrics
=
[]
return
return
classwise
=
self
.
cfg
[
'classwise'
]
if
'classwise'
in
self
.
cfg
else
False
classwise
=
self
.
cfg
[
'classwise'
]
if
'classwise'
in
self
.
cfg
else
False
...
@@ -126,9 +127,12 @@ class Trainer(object):
...
@@ -126,9 +127,12 @@ class Trainer(object):
bias
=
self
.
cfg
[
'bias'
]
if
'bias'
in
self
.
cfg
else
0
bias
=
self
.
cfg
[
'bias'
]
if
'bias'
in
self
.
cfg
else
0
output_eval
=
self
.
cfg
[
'output_eval'
]
\
output_eval
=
self
.
cfg
[
'output_eval'
]
\
if
'output_eval'
in
self
.
cfg
else
None
if
'output_eval'
in
self
.
cfg
else
None
clsid2catid
=
{
v
:
k
for
k
,
v
in
self
.
dataset
.
catid2clsid
.
items
()}
\
if
self
.
mode
==
'eval'
else
None
self
.
_metrics
=
[
self
.
_metrics
=
[
COCOMetric
(
COCOMetric
(
anno_file
=
self
.
dataset
.
get_anno
(),
anno_file
=
self
.
dataset
.
get_anno
(),
clsid2catid
=
clsid2catid
,
classwise
=
classwise
,
classwise
=
classwise
,
output_eval
=
output_eval
,
output_eval
=
output_eval
,
bias
=
bias
)
bias
=
bias
)
...
@@ -186,6 +190,11 @@ class Trainer(object):
...
@@ -186,6 +190,11 @@ class Trainer(object):
def
train
(
self
,
validate
=
False
):
def
train
(
self
,
validate
=
False
):
assert
self
.
mode
==
'train'
,
"Model not in 'train' mode"
assert
self
.
mode
==
'train'
,
"Model not in 'train' mode"
# if validation in training is enabled, metrics should be re-init
if
validate
:
self
.
_init_metrics
(
validate
=
validate
)
self
.
_reset_metrics
()
model
=
self
.
model
model
=
self
.
model
if
self
.
cfg
.
fleet
:
if
self
.
cfg
.
fleet
:
model
=
fleet
.
distributed_model
(
model
)
model
=
fleet
.
distributed_model
(
model
)
...
...
ppdet/metrics/__init__.py
浏览文件 @
0e6468c7
...
@@ -15,8 +15,4 @@
...
@@ -15,8 +15,4 @@
from
.
import
metrics
from
.
import
metrics
from
.metrics
import
*
from
.metrics
import
*
from
.
import
category
__all__
=
metrics
.
__all__
from
.category
import
*
__all__
=
metrics
.
__all__
\
+
category
.
__all__
ppdet/metrics/metrics.py
浏览文件 @
0e6468c7
...
@@ -22,10 +22,10 @@ import json
...
@@ -22,10 +22,10 @@ import json
import
paddle
import
paddle
import
numpy
as
np
import
numpy
as
np
from
.category
import
get_categories
from
.map_utils
import
prune_zero_padding
,
DetectionMAP
from
.map_utils
import
prune_zero_padding
,
DetectionMAP
from
.coco_utils
import
get_infer_results
,
cocoapi_eval
from
.coco_utils
import
get_infer_results
,
cocoapi_eval
from
.widerface_utils
import
face_eval_run
from
.widerface_utils
import
face_eval_run
from
ppdet.data.source.category
import
get_categories
from
ppdet.utils.logger
import
setup_logger
from
ppdet.utils.logger
import
setup_logger
logger
=
setup_logger
(
__name__
)
logger
=
setup_logger
(
__name__
)
...
@@ -62,7 +62,9 @@ class COCOMetric(Metric):
...
@@ -62,7 +62,9 @@ class COCOMetric(Metric):
assert
os
.
path
.
isfile
(
anno_file
),
\
assert
os
.
path
.
isfile
(
anno_file
),
\
"anno_file {} not a file"
.
format
(
anno_file
)
"anno_file {} not a file"
.
format
(
anno_file
)
self
.
anno_file
=
anno_file
self
.
anno_file
=
anno_file
self
.
clsid2catid
,
self
.
catid2name
=
get_categories
(
'COCO'
,
anno_file
)
self
.
clsid2catid
=
kwargs
.
get
(
'clsid2catid'
,
None
)
if
self
.
clsid2catid
is
None
:
self
.
clsid2catid
,
_
=
get_categories
(
'COCO'
,
anno_file
)
self
.
classwise
=
kwargs
.
get
(
'classwise'
,
False
)
self
.
classwise
=
kwargs
.
get
(
'classwise'
,
False
)
self
.
output_eval
=
kwargs
.
get
(
'output_eval'
,
None
)
self
.
output_eval
=
kwargs
.
get
(
'output_eval'
,
None
)
# TODO: bias should be unified
# TODO: bias should be unified
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录