MegEngine 天元 / Models · Commit de3bbd61 (unverified)

Merge pull request #13 from wjfwzzc/master

feat(detection): support Objects365 and reformat

Authored by Wang Feng on May 13, 2020; committed by GitHub on May 13, 2020.
Parents: 9766a399, 95a5e7bd

12 changed files with 191 additions and 100 deletions (+191 −100):
official/vision/classification/shufflenet/model.py                  +1  −1
official/vision/detection/README.md                                 +15 −16
official/vision/detection/layers/basic/functional.py                +1  −7
official/vision/detection/layers/det/loss.py                        +18 −33
official/vision/detection/layers/det/retinanet.py                   +1  −1
official/vision/detection/models/__init__.py                        +12 −0
official/vision/detection/models/retinanet.py                       +16 −23
official/vision/detection/retinanet_res50_coco_1x_800size.py        +43 −0
official/vision/detection/retinanet_res50_objects365_1x_800size.py  +42 −0
official/vision/detection/tools/data_mapper.py                      +14 −0
official/vision/detection/tools/test.py                             +22 −14
official/vision/detection/tools/train.py                            +6  −5
official/vision/classification/shufflenet/model.py (+1 −1)

```diff
@@ -110,7 +110,7 @@ class ShuffleV2Block(M.Module):


 class ShuffleNetV2(M.Module):
-    def __init__(self, input_size=224, num_classes=1000, model_size="1.5x"):
+    def __init__(self, num_classes=1000, model_size="1.5x"):
         super(ShuffleNetV2, self).__init__()
         self.stage_repeats = [4, 8, 4]
```
official/vision/detection/README.md (+15 −16)

````diff
@@ -2,23 +2,22 @@
 ## Introduction

-This directory contains the classic [RetinaNet](https://arxiv.org/pdf/1708.02002>) network structure implemented with MegEngine,
-along with complete training and testing code on the COCO2017 dataset.
+This directory contains the classic [RetinaNet](https://arxiv.org/pdf/1708.02002>) network structure implemented with MegEngine, along with complete training and testing code on the COCO2017 dataset.

 The performance of the network on the COCO2017 validation set is as follows:

-| model | mAP<br>@5-95 | batch<br>/gpu | gpu | speed<br>(8gpu) | speed<br>(1gpu) |
-| --- | --- | --- | --- | --- | --- |
-| retinanet-res50-1x-800size | 36.0 | 2 | 2080 | 2.27(it/s) | 3.7(it/s) |
+| model | mAP<br>@5-95 | batch<br>/gpu | gpu | speed<br>(8gpu) | speed<br>(1gpu) |
+| --- | --- | --- | --- | --- | --- |
+| retinanet-res50-coco-1x-800size | 36.0 | 2 | 2080ti | 2.27(it/s) | 3.7(it/s) |

-* MegEngine v0.3.0
+* MegEngine v0.4.0

 ## How to use

 After the model has been trained, a single image can be tested with:

 ```bash
-python3 tools/inference.py -f retinanet_res50_1x_800size.py \
+python3 tools/inference.py -f retinanet_res50_coco_1x_800size.py \
     -i ../../assets/cat.jpg \
     -m /path/to/retinanet_weights.pkl
 ```
@@ -35,8 +34,8 @@ python3 tools/inference.py -f retinanet_res50_1x_800size.py \
 ## How to train

-1. Before training, make sure the [COCO dataset](http://cocodataset.org/#download) has been downloaded and extracted into a suitable data directory; the prepared dataset should have the directory structure shown below (the coco2017 dataset is used by default):
+1. Before training, make sure the [COCO2017 dataset](http://cocodataset.org/#download) has been downloaded and extracted into a suitable data directory; the prepared dataset should have the directory structure shown below (the COCO2017 dataset is used by default):

 ```
 /path/to/
@@ -46,14 +45,14 @@ python3 tools/inference.py -f retinanet_res50_1x_800size.py \
 | |val2017
 ```

-2. Prepare the pretrained `backbone` weights: the resnet50 model trained on ImageNet and released officially by `megengine` can be downloaded via megengine.hub and stored at `/path/to/pretrain.pkl`.
+2. Prepare the pretrained `backbone` weights: the ResNet-50 model trained on ImageNet and released officially by `megengine` can be downloaded via megengine.hub and stored at `/path/to/pretrain.pkl`.

 3. Before running the code in this directory, make sure the environment has been configured correctly as described in the [README](../../../README.md).

 4. Start training:

 ```bash
-python3 tools/train.py -f retinanet_res50_1x_800size.py \
+python3 tools/train.py -f retinanet_res50_coco_1x_800size.py \
     -n 8 \
     --batch_size 2 \
     -w /path/to/pretrain.pkl
@@ -65,7 +64,7 @@ python3 tools/train.py -f retinanet_res50_1x_800size.py \
 - `-n`, number of devices (gpus) used for training; by default all available gpus are used.
 - `-w`, path to the pretrained backbone weights.
 - `--batch_size`, the `batch size` used for training; default 2, i.e. 2 images per gpu.
-- `--dataset-dir`, root directory of the coco dataset; default `/data/datasets/coco`.
+- `--dataset-dir`, parent directory of the COCO2017 dataset; default `/data/datasets`.

 By default the model is saved under the `log-of-retinanet_res50_1x_800size` directory.
@@ -74,10 +73,10 @@ python3 tools/train.py -f retinanet_res50_1x_800size.py \
 During training, the model's performance on the `COCO2017` validation set can be tested with:

 ```bash
-python3 tools/test.py -n 8 \
-    -f retinanet_res50_1x_800size.py \
+python3 tools/test.py -f retinanet_res50_coco_1x_800size.py \
+    -n 8 \
     --model /path/to/retinanet_weights.pt \
-    --dataset_dir /data/datasets/coco
+    --dataset_dir /data/datasets
 ```

 The command line options of `tools/test.py` are:
@@ -85,7 +84,7 @@ python3 tools/test.py -n 8 \
 - `-f`, the network description file to test.
 - `-n`, number of devices (gpus) used for testing; default 1.
 - `--model`, the model to test; the trained detector weights can be downloaded from the table at the top, or self-trained weights can be used.
-- `--dataset_dir`, root directory of the coco dataset; default `/data/datasets`
+- `--dataset_dir`, parent directory of the COCO2017 dataset; default `/data/datasets`

 ## References
````
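The `--dataset-dir` change above pairs with the path handling added to tools/train.py and tools/test.py further down: the dataset name from the config is now appended to the directory you pass in. A minimal sketch, assuming the layout the README describes, of where the new default expects COCO2017 to sit:

```python
# Illustrative only: checks the directory layout implied by the new defaults
# (/data/datasets as the parent directory, with the "coco" dataset name appended
# by the tools). Adjust dataset_dir if your data lives elsewhere.
import os

dataset_dir = "/data/datasets"                 # new default for --dataset-dir / --dataset_dir
coco_root = os.path.join(dataset_dir, "coco")  # "coco" comes from the dataset config name
expected = [
    os.path.join(coco_root, "train2017"),
    os.path.join(coco_root, "val2017"),
    os.path.join(coco_root, "annotations", "instances_train2017.json"),
    os.path.join(coco_root, "annotations", "instances_val2017.json"),
]
missing = [p for p in expected if not os.path.exists(p)]
print("missing paths:", missing or "none")
```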
official/vision/detection/layers/basic/functional.py (+1 −7)

```diff
@@ -10,8 +10,7 @@ import megengine as mge
 import megengine.functional as F
 import numpy as np

-from megengine import _internal as mgb
-from megengine.core import Tensor, wrap_io_tensor
+from megengine.core import Tensor


 def get_padded_array_np(
@@ -86,8 +85,3 @@ def get_padded_tensor(
     else:
         raise Exception("Not supported tensor dim: %d" % ndim)
     return padded_array
-
-
-@wrap_io_tensor
-def indexing_set_one_hot(inp, axis, idx, value) -> Tensor:
-    return mgb.opr.indexing_set_one_hot(inp, axis, idx, value)
```
official/vision/detection/layers/det/loss.py (+18 −33)

```diff
@@ -12,8 +12,6 @@ import numpy as np
 from megengine.core import tensor, Tensor

-from official.vision.detection.layers import basic
-

 def get_focal_loss(
     score: Tensor,
@@ -51,28 +49,19 @@ def get_focal_loss(
     Returns:
         the calculated focal loss.
     """
-    mask = 1 - (label == ignore_label)
-    valid_label = label * mask
-
-    score_shp = score.shape
-    zero_mat = mge.zeros(
-        F.concat([score_shp[0], score_shp[1], score_shp[2] + 1], axis=0),
-        dtype=np.float32,
-    )
-    one_mat = mge.ones(
-        F.concat([score_shp[0], score_shp[1], tensor(1)], axis=0), dtype=np.float32,
-    )
-    one_hot = basic.indexing_set_one_hot(
-        zero_mat, 2, valid_label.astype(np.int32), one_mat
-    )[:, :, 1:]
-
-    pos_part = F.power(1 - score, gamma) * one_hot * F.log(score)
-    neg_part = F.power(score, gamma) * (1 - one_hot) * F.log(1 - score)
-    loss = -(alpha * pos_part + (1 - alpha) * neg_part).sum(axis=2) * mask
+    class_range = F.arange(1, score.shape[2] + 1)
+
+    label = F.add_axis(label, axis=2)
+    pos_part = (1 - score) ** gamma * F.log(score)
+    neg_part = score ** gamma * F.log(1 - score)
+
+    pos_loss = -(label == class_range) * pos_part * alpha
+    neg_loss = -(label != class_range) * (label != ignore_label) * neg_part * (1 - alpha)
+    loss = pos_loss + neg_loss

     if norm_type == "fg":
-        positive_mask = label > background
-        return loss.sum() / F.maximum(positive_mask.sum(), 1)
+        fg_mask = (label != background) * (label != ignore_label)
+        return loss.sum() / F.maximum(fg_mask.sum(), 1)
     elif norm_type == "none":
         return loss.sum()
     else:
@@ -117,8 +106,7 @@ def get_smooth_l1_loss(
     gt_bbox = gt_bbox.reshape(-1, 4)
     label = label.reshape(-1)

-    valid_mask = 1 - (label == ignore_label)
-    fg_mask = (1 - (label == background)) * valid_mask
+    fg_mask = (label != background) * (label != ignore_label)

     losses = get_smooth_l1_base(pred_bbox, gt_bbox, sigma, is_fix=fix_smooth_l1)
     if norm_type == "fg":
@@ -154,19 +142,16 @@ def get_smooth_l1_base(
         cond_point = sigma
         x = pred_bbox - gt_bbox
         abs_x = F.abs(x)
-        in_mask = abs_x < cond_point
-        out_mask = 1 - in_mask
-        in_loss = 0.5 * (x ** 2)
-        out_loss = sigma * abs_x - 0.5 * (sigma ** 2)
-        loss = in_loss * in_mask + out_loss * out_mask
+        in_loss = 0.5 * x ** 2
+        out_loss = sigma * abs_x - 0.5 * sigma ** 2
     else:
         sigma2 = sigma ** 2
         cond_point = 1 / sigma2
         x = pred_bbox - gt_bbox
         abs_x = F.abs(x)
-        in_mask = abs_x < cond_point
-        out_mask = 1 - in_mask
-        in_loss = 0.5 * (sigma * x) ** 2
+        in_loss = 0.5 * x ** 2 * sigma2
         out_loss = abs_x - 0.5 / sigma2
-        loss = in_loss * in_mask + out_loss * out_mask
+
+    in_mask = abs_x < cond_point
+    out_mask = 1 - in_mask
+    loss = in_loss * in_mask + out_loss * out_mask

     return loss
```
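The focal-loss rewrite above drops the explicit one-hot matrix (and with it the `indexing_set_one_hot` helper removed from functional.py) in favour of broadcasting the labels against a `1..num_classes` range. A minimal NumPy sketch, with toy shapes and the usual background/ignore conventions assumed, checking that the two formulations agree:

```python
# NumPy sketch only (not MegEngine code): compares the new class_range formulation
# with the old one-hot formulation on fabricated scores and labels.
import numpy as np

num_classes = 4
rng = np.random.default_rng(0)
score = rng.uniform(0.05, 0.95, size=(2, 3, num_classes))   # (batch, anchors, classes)
label = np.array([[1, 0, -1], [4, 2, 0]])                    # assumed: 0 = background, -1 = ignore
alpha, gamma, background, ignore_label = 0.25, 2.0, 0, -1

# New-style formulation (mirrors the + lines above): broadcast against 1..num_classes.
class_range = np.arange(1, num_classes + 1)
lab = label[..., None]                                        # add a class axis
pos_part = (1 - score) ** gamma * np.log(score)
neg_part = score ** gamma * np.log(1 - score)
pos_loss = -(lab == class_range).astype(np.float32) * pos_part * alpha
neg_loss = (
    -(lab != class_range).astype(np.float32)
    * (lab != ignore_label).astype(np.float32)
    * neg_part
    * (1 - alpha)
)
loss_new = (pos_loss + neg_loss).sum(axis=2)

# Old-style formulation (mirrors the - lines above): explicit one-hot plus a mask.
mask = (label != ignore_label).astype(np.float32)
valid_label = label * mask                                    # ignored anchors collapse to class 0
one_hot = (valid_label[..., None] == class_range).astype(np.float32)
loss_old = (
    -(alpha * one_hot * pos_part + (1 - alpha) * (1 - one_hot) * neg_part).sum(axis=2)
    * mask
)

assert np.allclose(loss_new, loss_old)   # identical per-anchor losses
```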
official/vision/detection/layers/det/retinanet.py (+1 −1)

```diff
@@ -28,7 +28,7 @@ class RetinaNetHead(M.Module):
         num_classes = cfg.num_classes
         num_convs = 4
         prior_prob = cfg.cls_prior_prob
-        num_anchors = [9, 9, 9, 9, 9]
+        num_anchors = [len(cfg.anchor_ratios) * len(cfg.anchor_scales)] * 5

         assert (
             len(set(num_anchors)) == 1
```
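A tiny sanity check of the `num_anchors` change above: the head now derives the per-level anchor count from the config instead of hard-coding 9. The ratio and scale values below are illustrative assumptions (the defaults are not shown in this diff); with three of each, the result matches the old `[9, 9, 9, 9, 9]`:

```python
# Illustrative values only; cfg.anchor_ratios / cfg.anchor_scales are not shown in this diff.
anchor_ratios = [0.5, 1.0, 2.0]
anchor_scales = [2 ** 0, 2 ** (1 / 3), 2 ** (2 / 3)]

num_anchors = [len(anchor_ratios) * len(anchor_scales)] * 5
assert num_anchors == [9, 9, 9, 9, 9]
assert len(set(num_anchors)) == 1   # the assert kept in RetinaNetHead still holds
```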
official/vision/detection/models/__init__.py (new file, +12 −0)

```python
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .retinanet import *

_EXCLUDE = {}
__all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")]
```
official/vision/detection/retinanet_res50_1x_800size.py → official/vision/detection/models/retinanet.py (renamed, +16 −23)

```diff
@@ -10,7 +10,6 @@ import megengine as mge
 import megengine.functional as F
 import megengine.module as M
 import numpy as np
-from megengine import hub

 from official.vision.classification.resnet.model import resnet50
 from official.vision.detection import layers
@@ -47,7 +46,7 @@ class RetinaNet(M.Module):
         for p in bottom_up.layer1.parameters():
             p.requires_grad = False

-        # --------------------------- build the FPN -------------------------- #
+        # ----------------------- build the FPN ----------------------------- #
         in_channels_p6p7 = 2048
         out_channels = 256
         self.backbone = layers.FPN(
@@ -61,7 +60,7 @@ class RetinaNet(M.Module):
         backbone_shape = self.backbone.output_shape()
         feature_shapes = [backbone_shape[f] for f in self.in_features]

-        # --------------------------- build the RetinaNet Head -------------- #
+        # ----------------------- build the RetinaNet Head ------------------ #
         self.head = layers.RetinaNetHead(cfg, feature_shapes)

         self.inputs = {
@@ -199,13 +198,22 @@ class RetinaNetConfig:
         self.resnet_norm = "FrozenBN"
         self.backbone_freeze_at = 2

-        # ------------------------ data cfg --------------------------- #
+        # ------------------------ data cfg -------------------------- #
+        self.train_dataset = dict(
+            name="coco",
+            root="train2017",
+            ann_file="instances_train2017.json"
+        )
+        self.test_dataset = dict(
+            name="coco",
+            root="val2017",
+            ann_file="instances_val2017.json"
+        )
         self.train_image_short_size = 800
         self.train_image_max_size = 1333
         self.num_classes = 80
         self.img_mean = np.array([103.530, 116.280, 123.675])  # BGR
         self.img_std = np.array([57.375, 57.120, 58.395])
-        # self.img_std = np.array([1.0, 1.0, 1.0])
         self.reg_mean = None
         self.reg_std = np.array([0.1, 0.1, 0.2, 0.2])
@@ -217,7 +225,7 @@ class RetinaNetConfig:
         self.class_aware_box = False
         self.cls_prior_prob = 0.01

-        # ------------------------ losss cfg ------------------------- #
+        # ------------------------ loss cfg -------------------------- #
        self.focal_loss_alpha = 0.25
         self.focal_loss_gamma = 2
         self.reg_loss_weight = 1.0 / 4.0
@@ -229,29 +237,14 @@ class RetinaNetConfig:
         self.log_interval = 20
         self.nr_images_epoch = 80000
         self.max_epoch = 18
-        self.warm_iters = 100
+        self.warm_iters = 500
         self.lr_decay_rate = 0.1
         self.lr_decay_sates = [12, 16, 17]

-        # ------------------------ testing cfg ------------------------- #
+        # ------------------------ testing cfg ----------------------- #
         self.test_image_short_size = 800
         self.test_image_max_size = 1333
         self.test_max_boxes_per_image = 100
         self.test_vis_threshold = 0.3
         self.test_cls_threshold = 0.05
         self.test_nms = 0.5
-
-
-@hub.pretrained(
-    "https://data.megengine.org.cn/models/weights/"
-    "retinanet_d3f58dce_res50_1x_800size_36dot0.pkl"
-)
-def retinanet_res50_1x_800size(batch_size=1, **kwargs):
-    r"""ResNet-18 model from
-    `"RetinaNet" <https://arxiv.org/abs/1708.02002>`_
-    """
-    return RetinaNet(RetinaNetConfig(), batch_size=batch_size, **kwargs)
-
-
-Net = RetinaNet
-Cfg = RetinaNetConfig
```
official/vision/detection/retinanet_res50_coco_1x_800size.py (new file, +43 −0)

```python
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine import hub

from official.vision.detection import models


class CustomRetinaNetConfig(models.RetinaNetConfig):
    def __init__(self):
        super().__init__()

        # ------------------------ data cfg -------------------------- #
        self.train_dataset = dict(
            name="coco",
            root="train2017",
            ann_file="annotations/instances_train2017.json"
        )
        self.test_dataset = dict(
            name="coco",
            root="val2017",
            ann_file="annotations/instances_val2017.json"
        )


@hub.pretrained(
    "https://data.megengine.org.cn/models/weights/"
    "retinanet_d3f58dce_res50_1x_800size_36dot0.pkl"
)
def retinanet_res50_coco_1x_800size(batch_size=1, **kwargs):
    r"""ResNet-18 model from
    `"RetinaNet" <https://arxiv.org/abs/1708.02002>`_
    """
    return models.RetinaNet(RetinaNetConfig(), batch_size=batch_size, **kwargs)


Net = models.RetinaNet
Cfg = CustomRetinaNetConfig
```
official/vision/detection/retinanet_res50_objects365_1x_800size.py (new file, +42 −0)

```python
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine import hub

from official.vision.detection import models


class CustomRetinaNetConfig(models.RetinaNetConfig):
    def __init__(self):
        super().__init__()

        # ------------------------ data cfg -------------------------- #
        self.train_dataset = dict(
            name="objects365",
            root="train",
            ann_file="annotations/objects365_train_20190423.json"
        )
        self.test_dataset = dict(
            name="objects365",
            root="val",
            ann_file="annotations/objects365_val_20190423.json"
        )

        # ------------------------ training cfg ---------------------- #
        self.nr_images_epoch = 400000


def retinanet_objects365_res50_1x_800size(batch_size=1, **kwargs):
    r"""ResNet-18 model from
    `"RetinaNet" <https://arxiv.org/abs/1708.02002>`_
    """
    return models.RetinaNet(RetinaNetConfig(), batch_size=batch_size, **kwargs)


Net = models.RetinaNet
Cfg = CustomRetinaNetConfig
```
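For context, a hedged sketch of how these per-experiment files are consumed: tools/train.py and tools/test.py (below) import the file by path and rely only on the module-level `Net` and `Cfg` names. The loading snippet itself is illustrative, not code from this commit, and assumes the repository and MegEngine are importable:

```python
# Illustrative loader mirroring the sys.path / importlib pattern used in tools/test.py below.
import importlib
import os
import sys

config_file = "official/vision/detection/retinanet_res50_objects365_1x_800size.py"

sys.path.insert(0, os.path.dirname(config_file))
current_network = importlib.import_module(os.path.basename(config_file).split(".")[0])

cfg = current_network.Cfg()                      # CustomRetinaNetConfig with Objects365 paths
model = current_network.Net(cfg, batch_size=2)   # models.RetinaNet built from that config
print(type(model).__name__, cfg.train_dataset["name"], cfg.nr_images_epoch)
```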
official/vision/detection/tools/data_mapper.py (new file, +14 −0)

```python
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine.data.dataset import COCO, Objects365

data_mapper = dict(
    coco=COCO,
    objects365=Objects365,
)
```
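A hedged usage sketch of the new registry, mirroring how tools/train.py and tools/test.py (below) resolve a dataset class from a config dict. The Objects365 paths come from the config file above and assume the dataset has been extracted under `/data/datasets/objects365`:

```python
# Sketch only: running it requires the Objects365 images and annotations to exist
# under the assumed /data/datasets/objects365 directory.
import os

from official.vision.detection.tools.data_mapper import data_mapper

dataset_dir = "/data/datasets"
train_cfg = dict(
    name="objects365",
    root="train",
    ann_file="annotations/objects365_train_20190423.json",
)

dataset_class = data_mapper[train_cfg["name"]]   # -> megengine.data.dataset.Objects365
train_dataset = dataset_class(
    os.path.join(dataset_dir, train_cfg["name"], train_cfg["root"]),
    os.path.join(dataset_dir, train_cfg["name"], train_cfg["ann_file"]),
    remove_images_without_annotations=True,
    order=["image", "boxes", "boxes_category", "info"],
)
print(len(train_dataset), "training images")
```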
official/vision/detection/tools/test.py (+22 −14)

```diff
@@ -19,9 +19,9 @@ import megengine as mge
 import numpy as np
 from megengine import jit
 from megengine.data import DataLoader, SequentialSampler
-from megengine.data.dataset import COCO as COCODataset
 from tqdm import tqdm

+from official.vision.detection.tools.data_mapper import data_mapper
 from official.vision.detection.tools.nms import py_cpu_nms

 logger = mge.get_logger(__name__)
@@ -119,9 +119,10 @@ class DetEvaluator:
         return dtboxes_all

     @staticmethod
-    def format(results):
-        all_results = []
+    def format(results, cfg):
+        dataset_class = data_mapper[cfg.test_dataset["name"]]
+        all_results = []
         for record in results:
             image_filename = record["image_id"]
             boxes = record["det_res"]
@@ -133,8 +134,8 @@ class DetEvaluator:
                 elem["image_id"] = image_filename
                 elem["bbox"] = box[:4].tolist()
                 elem["score"] = box[4]
-                elem["category_id"] = COCODataset.classes_originID[
-                    COCODataset.class_names[int(box[5]) + 1]
-                ]
+                elem["category_id"] = dataset_class.classes_originID[
+                    dataset_class.class_names[int(box[5])]
+                ]
                 all_results.append(elem)
         return all_results
@@ -156,7 +157,7 @@ class DetEvaluator:
         for det in dets:
             bb = det[:4].astype(int)
             if is_show_label:
-                cls_id = int(det[5] + 1)
+                cls_id = int(det[5])
                 score = det[4]
                 if cls_id == 0:
@@ -200,10 +201,10 @@ class DetEvaluator:
             break


-def build_dataloader(rank, world_size, data_dir):
-    val_dataset = COCODataset(
-        os.path.join(data_dir, "val2017"),
-        os.path.join(data_dir, "annotations/instances_val2017.json"),
+def build_dataloader(rank, world_size, data_dir, cfg):
+    val_dataset = data_mapper[cfg.test_dataset["name"]](
+        os.path.join(data_dir, cfg.test_dataset["name"], cfg.test_dataset["root"]),
+        os.path.join(data_dir, cfg.test_dataset["name"], cfg.test_dataset["ann_file"]),
         order=["image", "info"],
     )
     val_sampler = SequentialSampler(val_dataset, 1, world_size=world_size, rank=rank)
@@ -236,7 +237,7 @@ def worker(
     evaluator = DetEvaluator(model)
     model.load_state_dict(mge.load(model_file)["state_dict"])

-    loader = build_dataloader(worker_id, total_worker, data_dir)
+    loader = build_dataloader(worker_id, total_worker, data_dir, model.cfg)
     for data_dict in loader:
         data, im_info = DetEvaluator.process_inputs(
             data_dict[0][0],
@@ -262,7 +263,7 @@ def make_parser():
     parser.add_argument(
         "-f", "--file", default="net.py", type=str, help="net description file"
     )
-    parser.add_argument("-d", "--dataset_dir", default="/data/datasets/coco", type=str)
+    parser.add_argument("-d", "--dataset_dir", default="/data/datasets", type=str)
     parser.add_argument("-se", "--start_epoch", default=-1, type=int)
     parser.add_argument("-ee", "--end_epoch", default=-1, type=int)
     parser.add_argument("-m", "--model", default=None, type=str)
@@ -312,7 +313,12 @@ def main():
     for p in procs:
         p.join()

-    all_results = DetEvaluator.format(results_list)
+    sys.path.insert(0, os.path.dirname(args.file))
+    current_network = importlib.import_module(
+        os.path.basename(args.file).split(".")[0]
+    )
+    cfg = current_network.Cfg()
+    all_results = DetEvaluator.format(results_list, cfg)
     json_path = "log-of-{}/epoch_{}.json".format(
         os.path.basename(args.file).split(".")[0], epoch_num
     )
@@ -323,7 +329,9 @@ def main():
     logger.info("Save to %s finished, start evaluation!", json_path)

     eval_gt = COCO(
-        os.path.join(args.dataset_dir, "annotations/instances_val2017.json")
+        os.path.join(
+            args.dataset_dir, cfg.test_dataset["name"], cfg.test_dataset["ann_file"]
+        )
     )
     eval_dt = eval_gt.loadRes(json_path)
     cocoEval = COCOeval(eval_gt, eval_dt, iouType="bbox")
```
official/vision/detection/tools/train.py (+6 −5)

```diff
@@ -22,9 +22,10 @@ from megengine import jit
 from megengine import optimizer as optim
 from megengine.data import Collator, DataLoader, Infinite, RandomSampler
 from megengine.data import transform as T
-from megengine.data.dataset import COCO
 from tabulate import tabulate

+from official.vision.detection.tools.data_mapper import data_mapper
+
 logger = mge.get_logger(__name__)
@@ -175,7 +176,7 @@ def make_parser():
         "-b", "--batch_size", default=2, type=int, help="batchsize for training",
     )
     parser.add_argument(
-        "-d", "--dataset_dir", default="/data/datasets/coco", type=str,
+        "-d", "--dataset_dir", default="/data/datasets", type=str,
     )
     return parser
@@ -232,9 +233,9 @@ def main():
 def build_dataloader(batch_size, data_dir, cfg):
-    train_dataset = COCO(
-        os.path.join(data_dir, "train2017"),
-        os.path.join(data_dir, "annotations/instances_train2017.json"),
+    train_dataset = data_mapper[cfg.train_dataset["name"]](
+        os.path.join(data_dir, cfg.train_dataset["name"], cfg.train_dataset["root"]),
+        os.path.join(data_dir, cfg.train_dataset["name"], cfg.train_dataset["ann_file"]),
         remove_images_without_annotations=True,
         order=["image", "boxes", "boxes_category", "info"],
     )
```