Unverified commit de3bbd61, authored by: W Wang Feng, committed by: GitHub

Merge pull request #13 from wjfwzzc/master

feat(detection): support Objects365 and reformat
@@ -110,7 +110,7 @@ class ShuffleV2Block(M.Module):
 class ShuffleNetV2(M.Module):
-    def __init__(self, input_size=224, num_classes=1000, model_size="1.5x"):
+    def __init__(self, num_classes=1000, model_size="1.5x"):
         super(ShuffleNetV2, self).__init__()
         self.stage_repeats = [4, 8, 4]
......
@@ -2,23 +2,22 @@
 ## Introduction

-This directory contains a classic [RetinaNet](https://arxiv.org/pdf/1708.02002) implementation in MegEngine,
-along with complete training and testing code on the COCO2017 dataset.
+This directory contains a classic [RetinaNet](https://arxiv.org/pdf/1708.02002) implementation in MegEngine, along with complete training and testing code on the COCO2017 dataset.

 The model's results on the COCO2017 validation set are as follows:

 | Model | mAP<br>@5-95 | batch<br>/gpu | gpu | speed<br>(8gpu) | speed<br>(1gpu) |
 | --- | --- | --- | --- | --- | --- |
-| retinanet-res50-1x-800size | 36.0 | 2 | 2080 | 2.27(it/s) | 3.7(it/s) |
+| retinanet-res50-coco-1x-800size | 36.0 | 2 | 2080ti | 2.27(it/s) | 3.7(it/s) |

-* MegEngine v0.3.0
+* MegEngine v0.4.0

 ## How to Use

 After the model is trained, a single image can be tested with the following command:

 ```bash
-python3 tools/inference.py -f retinanet_res50_1x_800size.py \
+python3 tools/inference.py -f retinanet_res50_coco_1x_800size.py \
     -i ../../assets/cat.jpg \
     -m /path/to/retinanet_weights.pkl
 ```
@@ -35,8 +34,8 @@ python3 tools/inference.py -f retinanet_res50_1x_800size.py \

 ## How to Train

-1. Before training, make sure the [COCO dataset](http://cocodataset.org/#download) has been downloaded, extracted,
-   and placed in a suitable data directory. The prepared dataset's directory structure is shown below (COCO2017 is used by default):
+1. Before training, make sure the [COCO2017 dataset](http://cocodataset.org/#download) has been downloaded, extracted,
+   and placed in a suitable data directory. The prepared dataset's directory structure is shown below (COCO2017 is used by default):

 ```
 /path/to/
@@ -46,14 +45,14 @@ python3 tools/inference.py -f retinanet_res50_1x_800size.py \
 | |val2017
 ```
-2. Prepare the pretrained `backbone` weights: the official `megengine` resnet50 model trained on ImageNet can be downloaded via megengine.hub and stored at `/path/to/pretrain.pkl`.
+2. Prepare the pretrained `backbone` weights: the official `megengine` ResNet-50 model trained on ImageNet can be downloaded via megengine.hub and stored at `/path/to/pretrain.pkl`.
 3. Before running the code in this directory, make sure the environment has been set up correctly according to the [README](../../../README.md).
 4. Start training:

 ```bash
-python3 tools/train.py -f retinanet_res50_1x_800size.py \
+python3 tools/train.py -f retinanet_res50_coco_1x_800size.py \
     -n 8 \
     --batch_size 2 \
     -w /path/to/pretrain.pkl
@@ -65,7 +64,7 @@ python3 tools/train.py -f retinanet_res50_1x_800size.py \

 - `-n`, number of devices (GPUs) used for training; defaults to all available GPUs.
 - `-w`, path to the pretrained backbone weights.
 - `--batch_size`, the `batch size` used for training; defaults to 2, i.e. 2 images per GPU.
-- `--dataset-dir`, root directory of the COCO dataset; defaults to `/data/datasets/coco`.
+- `--dataset-dir`, parent directory of the COCO2017 dataset; defaults to `/data/datasets`.

 By default the model is saved under the `log-of-retinanet_res50_1x_800size` directory.
@@ -74,10 +73,10 @@ python3 tools/train.py -f retinanet_res50_1x_800size.py \

 ## How to Test

 During training, the model's performance on the `COCO2017` validation set can be tested with the following command:

 ```bash
-python3 tools/test.py -n 8 \
-    -f retinanet_res50_1x_800size.py \
+python3 tools/test.py -f retinanet_res50_coco_1x_800size.py \
+    -n 8 \
     --model /path/to/retinanet_weights.pt \
-    --dataset_dir /data/datasets/coco
+    --dataset_dir /data/datasets
 ```

 The command-line options of `tools/test.py` are:

@@ -85,7 +84,7 @@ python3 tools/test.py -n 8 \
 - `-f`, the network description file to be tested.
 - `-n`, number of devices (GPUs) used for testing; defaults to 1.
 - `--model`, the model to be tested; trained detector weights can be downloaded from the table at the top, or self-trained weights can be used.
-- `--dataset_dir`, root directory of the COCO dataset; defaults to `/data/datasets`.
+- `--dataset_dir`, parent directory of the COCO2017 dataset; defaults to `/data/datasets`.
 ## References
......
@@ -10,8 +10,7 @@ import megengine as mge
 import megengine.functional as F
 import numpy as np

-from megengine import _internal as mgb
-from megengine.core import Tensor, wrap_io_tensor
+from megengine.core import Tensor


 def get_padded_array_np(
@@ -86,8 +85,3 @@ def get_padded_tensor(
     else:
         raise Exception("Not supported tensor dim: %d" % ndim)
     return padded_array
-
-
-@wrap_io_tensor
-def indexing_set_one_hot(inp, axis, idx, value) -> Tensor:
-    return mgb.opr.indexing_set_one_hot(inp, axis, idx, value)
@@ -12,8 +12,6 @@ import numpy as np
 from megengine.core import tensor, Tensor

-from official.vision.detection.layers import basic
-

 def get_focal_loss(
     score: Tensor,
@@ -51,28 +49,19 @@ def get_focal_loss(
     Returns:
         the calculated focal loss.
     """
-    mask = 1 - (label == ignore_label)
-    valid_label = label * mask
-
-    score_shp = score.shape
-    zero_mat = mge.zeros(
-        F.concat([score_shp[0], score_shp[1], score_shp[2] + 1], axis=0),
-        dtype=np.float32,
-    )
-    one_mat = mge.ones(
-        F.concat([score_shp[0], score_shp[1], tensor(1)], axis=0), dtype=np.float32,
-    )
-    one_hot = basic.indexing_set_one_hot(
-        zero_mat, 2, valid_label.astype(np.int32), one_mat
-    )[:, :, 1:]
-
-    pos_part = F.power(1 - score, gamma) * one_hot * F.log(score)
-    neg_part = F.power(score, gamma) * (1 - one_hot) * F.log(1 - score)
-    loss = -(alpha * pos_part + (1 - alpha) * neg_part).sum(axis=2) * mask
+    class_range = F.arange(1, score.shape[2] + 1)
+
+    label = F.add_axis(label, axis=2)
+    pos_part = (1 - score) ** gamma * F.log(score)
+    neg_part = score ** gamma * F.log(1 - score)
+
+    pos_loss = -(label == class_range) * pos_part * alpha
+    neg_loss = -(label != class_range) * (label != ignore_label) * neg_part * (1 - alpha)
+    loss = pos_loss + neg_loss

     if norm_type == "fg":
-        positive_mask = label > background
-        return loss.sum() / F.maximum(positive_mask.sum(), 1)
+        fg_mask = (label != background) * (label != ignore_label)
+        return loss.sum() / F.maximum(fg_mask.sum(), 1)
     elif norm_type == "none":
         return loss.sum()
     else:
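The refactored loss computes per-class focal terms by broadcasting the integer label against a `1..C` class range instead of materializing a one-hot tensor. A minimal NumPy sketch of that computation (the function name, shapes, and the `ignore_label=-1` convention are our assumptions for illustration):

```python
import numpy as np

def focal_loss_np(score, label, alpha=0.25, gamma=2.0, ignore_label=-1):
    """NumPy sketch of the broadcast-comparison focal loss.

    score: (N, A, C) per-class sigmoid scores.
    label: (N, A) integers in 1..C for foreground, 0 for background,
           ignore_label for anchors excluded from the loss.
    """
    # class indices 1..C, broadcast against the per-anchor label
    class_range = np.arange(1, score.shape[2] + 1)
    label = label[:, :, None]

    pos_part = (1 - score) ** gamma * np.log(score)
    neg_part = score ** gamma * np.log(1 - score)

    # explicit float masks avoid negating boolean arrays in NumPy
    pos_mask = (label == class_range).astype(score.dtype)
    neg_mask = ((label != class_range) & (label != ignore_label)).astype(score.dtype)

    pos_loss = -alpha * pos_mask * pos_part
    neg_loss = -(1 - alpha) * neg_mask * neg_part
    return (pos_loss + neg_loss).sum()
```

Ignored anchors contribute nothing: the positive mask can never equal `ignore_label`, and the negative term is explicitly masked by `label != ignore_label`.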
@@ -117,8 +106,7 @@ def get_smooth_l1_loss(
     gt_bbox = gt_bbox.reshape(-1, 4)
     label = label.reshape(-1)

-    valid_mask = 1 - (label == ignore_label)
-    fg_mask = (1 - (label == background)) * valid_mask
+    fg_mask = (label != background) * (label != ignore_label)

     losses = get_smooth_l1_base(pred_bbox, gt_bbox, sigma, is_fix=fix_smooth_l1)
     if norm_type == "fg":
@@ -154,19 +142,16 @@ def get_smooth_l1_base(
         cond_point = sigma
         x = pred_bbox - gt_bbox
         abs_x = F.abs(x)
-        in_mask = abs_x < cond_point
-        out_mask = 1 - in_mask
-        in_loss = 0.5 * (x ** 2)
-        out_loss = sigma * abs_x - 0.5 * (sigma ** 2)
-        loss = in_loss * in_mask + out_loss * out_mask
+        in_loss = 0.5 * x ** 2
+        out_loss = sigma * abs_x - 0.5 * sigma ** 2
     else:
         sigma2 = sigma ** 2
         cond_point = 1 / sigma2
         x = pred_bbox - gt_bbox
         abs_x = F.abs(x)
-        in_mask = abs_x < cond_point
-        out_mask = 1 - in_mask
-        in_loss = 0.5 * (sigma * x) ** 2
+        in_loss = 0.5 * x ** 2 * sigma2
         out_loss = abs_x - 0.5 / sigma2
-        loss = in_loss * in_mask + out_loss * out_mask
+    in_mask = abs_x < cond_point
+    out_mask = 1 - in_mask
+    loss = in_loss * in_mask + out_loss * out_mask
     return loss
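In both branches the loss is quadratic inside the crossover point and linear outside, and the two pieces meet continuously there; `0.5 * x ** 2 * sigma2` is just the earlier `0.5 * (sigma * x) ** 2` rewritten. A standalone NumPy sketch of the non-fixed branch (the helper name is ours):

```python
import numpy as np

def smooth_l1_np(pred, gt, sigma=3.0):
    """NumPy sketch of the non-fixed smooth-L1 branch."""
    sigma2 = sigma ** 2
    x = pred - gt
    abs_x = np.abs(x)
    # quadratic inside |x| < 1/sigma^2, linear outside; both pieces
    # take the value 0.5 / sigma2 at the crossover point
    in_loss = 0.5 * x ** 2 * sigma2
    out_loss = abs_x - 0.5 / sigma2
    return np.where(abs_x < 1 / sigma2, in_loss, out_loss)
```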
@@ -28,7 +28,7 @@ class RetinaNetHead(M.Module):
         num_classes = cfg.num_classes
         num_convs = 4
         prior_prob = cfg.cls_prior_prob
-        num_anchors = [9, 9, 9, 9, 9]
+        num_anchors = [len(cfg.anchor_ratios) * len(cfg.anchor_scales)] * 5
         assert (
             len(set(num_anchors)) == 1
......
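The hard-coded `[9, 9, 9, 9, 9]` is now derived from the config; with the usual RetinaNet choice of 3 aspect ratios and 3 octave scales (the concrete values below are assumptions, the real ones live on `cfg`), the product still works out to 9 anchors per FPN level:

```python
# Hypothetical config values mirroring the standard RetinaNet setup
anchor_ratios = [0.5, 1.0, 2.0]
anchor_scales = [2 ** 0, 2 ** (1 / 3), 2 ** (2 / 3)]

# one entry per FPN level; the head asserts all levels agree
num_anchors = [len(anchor_ratios) * len(anchor_scales)] * 5
```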
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .retinanet import *
_EXCLUDE = {}
__all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")]
@@ -10,7 +10,6 @@ import megengine as mge
 import megengine.functional as F
 import megengine.module as M
 import numpy as np
-from megengine import hub

 from official.vision.classification.resnet.model import resnet50
 from official.vision.detection import layers
@@ -47,7 +46,7 @@ class RetinaNet(M.Module):
         for p in bottom_up.layer1.parameters():
             p.requires_grad = False

-        # -------------------------- build the FPN -------------------------- #
+        # ----------------------- build the FPN ----------------------------- #
         in_channels_p6p7 = 2048
         out_channels = 256
         self.backbone = layers.FPN(
@@ -61,7 +60,7 @@ class RetinaNet(M.Module):
         backbone_shape = self.backbone.output_shape()
         feature_shapes = [backbone_shape[f] for f in self.in_features]

-        # -------------------------- build the RetinaNet Head -------------- #
+        # ----------------------- build the RetinaNet Head ------------------ #
         self.head = layers.RetinaNetHead(cfg, feature_shapes)

         self.inputs = {
@@ -199,13 +198,22 @@ class RetinaNetConfig:
         self.resnet_norm = "FrozenBN"
         self.backbone_freeze_at = 2

-        # ------------------------ data cfg --------------------------- #
+        # ------------------------ data cfg -------------------------- #
+        self.train_dataset = dict(
+            name="coco",
+            root="train2017",
+            ann_file="instances_train2017.json"
+        )
+        self.test_dataset = dict(
+            name="coco",
+            root="val2017",
+            ann_file="instances_val2017.json"
+        )
         self.train_image_short_size = 800
         self.train_image_max_size = 1333
         self.num_classes = 80

         self.img_mean = np.array([103.530, 116.280, 123.675])  # BGR
         self.img_std = np.array([57.375, 57.120, 58.395])
-        # self.img_std = np.array([1.0, 1.0, 1.0])
         self.reg_mean = None
         self.reg_std = np.array([0.1, 0.1, 0.2, 0.2])
@@ -217,7 +225,7 @@ class RetinaNetConfig:
         self.class_aware_box = False
         self.cls_prior_prob = 0.01

-        # ------------------------ losss cfg ------------------------- #
+        # ------------------------ loss cfg -------------------------- #
         self.focal_loss_alpha = 0.25
         self.focal_loss_gamma = 2
         self.reg_loss_weight = 1.0 / 4.0
@@ -229,29 +237,14 @@ class RetinaNetConfig:
         self.log_interval = 20
         self.nr_images_epoch = 80000
         self.max_epoch = 18
-        self.warm_iters = 100
+        self.warm_iters = 500
         self.lr_decay_rate = 0.1
         self.lr_decay_sates = [12, 16, 17]

-        # ------------------------ testing cfg ------------------------- #
+        # ------------------------ testing cfg ----------------------- #
         self.test_image_short_size = 800
         self.test_image_max_size = 1333
         self.test_max_boxes_per_image = 100
         self.test_vis_threshold = 0.3
         self.test_cls_threshold = 0.05
         self.test_nms = 0.5
-
-
-@hub.pretrained(
-    "https://data.megengine.org.cn/models/weights/"
-    "retinanet_d3f58dce_res50_1x_800size_36dot0.pkl"
-)
-def retinanet_res50_1x_800size(batch_size=1, **kwargs):
-    r"""ResNet-18 model from
-    `"RetinaNet" <https://arxiv.org/abs/1708.02002>`_
-    """
-    return RetinaNet(RetinaNetConfig(), batch_size=batch_size, **kwargs)
-
-
-Net = RetinaNet
-Cfg = RetinaNetConfig
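`warm_iters` (raised here from 100 to 500) is the number of learning-rate warmup iterations. A sketch of one common linear warmup rule; this is an assumed scheme for illustration, not a quotation of the trainer's actual code:

```python
def warmup_lr(base_lr, cur_iter, warm_iters=500):
    # linear warmup: ramp the learning rate from ~0 up to base_lr
    # over the first warm_iters iterations, then hold it constant
    if cur_iter < warm_iters:
        return base_lr * (cur_iter + 1) / warm_iters
    return base_lr
```

A longer warmup keeps early gradients from destabilizing the pretrained backbone when training starts at full batch size.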
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine import hub
from official.vision.detection import models
class CustomRetinaNetConfig(models.RetinaNetConfig):
    def __init__(self):
        super().__init__()

        # ------------------------ data cfg -------------------------- #
        self.train_dataset = dict(
            name="coco",
            root="train2017",
            ann_file="annotations/instances_train2017.json"
        )
        self.test_dataset = dict(
            name="coco",
            root="val2017",
            ann_file="annotations/instances_val2017.json"
        )


@hub.pretrained(
    "https://data.megengine.org.cn/models/weights/"
    "retinanet_d3f58dce_res50_1x_800size_36dot0.pkl"
)
def retinanet_res50_coco_1x_800size(batch_size=1, **kwargs):
    r"""RetinaNet (ResNet-50 backbone) model from
    `"RetinaNet" <https://arxiv.org/abs/1708.02002>`_
    """
    return models.RetinaNet(CustomRetinaNetConfig(), batch_size=batch_size, **kwargs)


Net = models.RetinaNet
Cfg = CustomRetinaNetConfig
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine import hub
from official.vision.detection import models
class CustomRetinaNetConfig(models.RetinaNetConfig):
    def __init__(self):
        super().__init__()

        # ------------------------ data cfg -------------------------- #
        self.train_dataset = dict(
            name="objects365",
            root="train",
            ann_file="annotations/objects365_train_20190423.json"
        )
        self.test_dataset = dict(
            name="objects365",
            root="val",
            ann_file="annotations/objects365_val_20190423.json"
        )

        # ------------------------ training cfg ---------------------- #
        self.nr_images_epoch = 400000


def retinanet_objects365_res50_1x_800size(batch_size=1, **kwargs):
    r"""RetinaNet (ResNet-50 backbone) model from
    `"RetinaNet" <https://arxiv.org/abs/1708.02002>`_
    """
    return models.RetinaNet(CustomRetinaNetConfig(), batch_size=batch_size, **kwargs)


Net = models.RetinaNet
Cfg = CustomRetinaNetConfig
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine.data.dataset import COCO, Objects365
data_mapper = dict(
    coco=COCO,
    objects365=Objects365,
)
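`data_mapper` is a plain name-to-class registry: the config's `train_dataset["name"]` / `test_dataset["name"]` select the dataset class, so supporting a new dataset only takes one entry here plus matching config dicts. A self-contained sketch of the dispatch (the stub classes stand in for the real MegEngine dataset classes):

```python
# Stub classes standing in for megengine.data.dataset.COCO / Objects365
class COCO:
    def __init__(self, root, ann_file, **kwargs):
        self.root, self.ann_file = root, ann_file

class Objects365(COCO):
    pass

data_mapper = dict(coco=COCO, objects365=Objects365)

# config-driven dispatch, as the train/test tools do
test_dataset = dict(name="coco", root="val2017",
                    ann_file="annotations/instances_val2017.json")
dataset_class = data_mapper[test_dataset["name"]]
ds = dataset_class(test_dataset["root"], test_dataset["ann_file"])
```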
@@ -19,9 +19,9 @@ import megengine as mge
 import numpy as np
 from megengine import jit
 from megengine.data import DataLoader, SequentialSampler
-from megengine.data.dataset import COCO as COCODataset
 from tqdm import tqdm

+from official.vision.detection.tools.data_mapper import data_mapper
 from official.vision.detection.tools.nms import py_cpu_nms

 logger = mge.get_logger(__name__)
@@ -119,9 +119,10 @@ class DetEvaluator:
         return dtboxes_all

     @staticmethod
-    def format(results):
+    def format(results, cfg):
+        dataset_class = data_mapper[cfg.test_dataset["name"]]
+
         all_results = []
         for record in results:
             image_filename = record["image_id"]
             boxes = record["det_res"]
@@ -133,8 +134,8 @@ class DetEvaluator:
                 elem["image_id"] = image_filename
                 elem["bbox"] = box[:4].tolist()
                 elem["score"] = box[4]
-                elem["category_id"] = COCODataset.classes_originID[
-                    COCODataset.class_names[int(box[5]) + 1]
+                elem["category_id"] = dataset_class.classes_originID[
+                    dataset_class.class_names[int(box[5])]
                 ]
                 all_results.append(elem)
         return all_results
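The mapping from a predicted class index back to the dataset's original category id goes through two lookups: index → class name → original (non-contiguous) id. A toy illustration with made-up tables; the real `class_names` / `classes_originID` live on the dataset classes:

```python
class ToyDataset:
    # contiguous training indices map to names; names map back to the
    # dataset's original, possibly non-contiguous category ids
    class_names = ["person", "bicycle", "car"]
    classes_originID = {"person": 1, "bicycle": 2, "car": 3}

box_class_index = 2  # predicted class index, 0-based as in the new code
category_id = ToyDataset.classes_originID[ToyDataset.class_names[box_class_index]]
```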
@@ -156,7 +157,7 @@ class DetEvaluator:
         for det in dets:
             bb = det[:4].astype(int)
             if is_show_label:
-                cls_id = int(det[5] + 1)
+                cls_id = int(det[5])
                 score = det[4]
                 if cls_id == 0:
break break
def build_dataloader(rank, world_size, data_dir): def build_dataloader(rank, world_size, data_dir, cfg):
val_dataset = COCODataset( val_dataset = data_mapper[cfg.test_dataset["name"]](
os.path.join(data_dir, "val2017"), os.path.join(data_dir, cfg.test_dataset["name"], cfg.test_dataset["root"]),
os.path.join(data_dir, "annotations/instances_val2017.json"), os.path.join(data_dir, cfg.test_dataset["name"], cfg.test_dataset["ann_file"]),
order=["image", "info"], order=["image", "info"],
) )
val_sampler = SequentialSampler(val_dataset, 1, world_size=world_size, rank=rank) val_sampler = SequentialSampler(val_dataset, 1, world_size=world_size, rank=rank)
@@ -236,7 +237,7 @@ def worker(
     evaluator = DetEvaluator(model)
     model.load_state_dict(mge.load(model_file)["state_dict"])

-    loader = build_dataloader(worker_id, total_worker, data_dir)
+    loader = build_dataloader(worker_id, total_worker, data_dir, model.cfg)
     for data_dict in loader:
         data, im_info = DetEvaluator.process_inputs(
             data_dict[0][0],
@@ -262,7 +263,7 @@ def make_parser():
     parser.add_argument(
         "-f", "--file", default="net.py", type=str, help="net description file"
     )
-    parser.add_argument("-d", "--dataset_dir", default="/data/datasets/coco", type=str)
+    parser.add_argument("-d", "--dataset_dir", default="/data/datasets", type=str)
     parser.add_argument("-se", "--start_epoch", default=-1, type=int)
     parser.add_argument("-ee", "--end_epoch", default=-1, type=int)
     parser.add_argument("-m", "--model", default=None, type=str)
@@ -312,7 +313,12 @@ def main():
     for p in procs:
         p.join()

-    all_results = DetEvaluator.format(results_list)
+    sys.path.insert(0, os.path.dirname(args.file))
+    current_network = importlib.import_module(
+        os.path.basename(args.file).split(".")[0]
+    )
+    cfg = current_network.Cfg()
+
+    all_results = DetEvaluator.format(results_list, cfg)
     json_path = "log-of-{}/epoch_{}.json".format(
         os.path.basename(args.file).split(".")[0], epoch_num
     )
@@ -323,7 +329,9 @@ def main():
     logger.info("Save to %s finished, start evaluation!", json_path)
     eval_gt = COCO(
-        os.path.join(args.dataset_dir, "annotations/instances_val2017.json")
+        os.path.join(
+            args.dataset_dir, cfg.test_dataset["name"], cfg.test_dataset["ann_file"]
+        )
     )
     eval_dt = eval_gt.loadRes(json_path)
     cocoEval = COCOeval(eval_gt, eval_dt, iouType="bbox")
......
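`tools/test.py` now re-imports the network description file to recover its `Cfg` class. The import-by-path pattern can be sketched as follows (`load_config` and the toy description file are ours):

```python
import importlib
import os
import sys
import tempfile

def load_config(file_path):
    """Import a network description file by path and build its Cfg."""
    # make the file's directory importable, then import it by module name
    sys.path.insert(0, os.path.dirname(file_path))
    module_name = os.path.basename(file_path).split(".")[0]
    current_network = importlib.import_module(module_name)
    return current_network.Cfg()

# Demo: write a tiny description file and load it back.
tmpdir = tempfile.mkdtemp()
with open(os.path.join(tmpdir, "toy_net_cfg.py"), "w") as f:
    f.write("class Cfg:\n    num_classes = 80\n")
cfg = load_config(os.path.join(tmpdir, "toy_net_cfg.py"))
```

This is why every description file ends with `Net = ...` and `Cfg = ...`: the tools only rely on those two module-level names.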
@@ -22,9 +22,10 @@ from megengine import jit
 from megengine import optimizer as optim
 from megengine.data import Collator, DataLoader, Infinite, RandomSampler
 from megengine.data import transform as T
-from megengine.data.dataset import COCO
 from tabulate import tabulate

+from official.vision.detection.tools.data_mapper import data_mapper
+
 logger = mge.get_logger(__name__)
@@ -175,7 +176,7 @@ def make_parser():
         "-b", "--batch_size", default=2, type=int, help="batchsize for training",
     )
     parser.add_argument(
-        "-d", "--dataset_dir", default="/data/datasets/coco", type=str,
+        "-d", "--dataset_dir", default="/data/datasets", type=str,
     )
     return parser
@@ -232,9 +233,9 @@ def main():

 def build_dataloader(batch_size, data_dir, cfg):
-    train_dataset = COCO(
-        os.path.join(data_dir, "train2017"),
-        os.path.join(data_dir, "annotations/instances_train2017.json"),
+    train_dataset = data_mapper[cfg.train_dataset["name"]](
+        os.path.join(data_dir, cfg.train_dataset["name"], cfg.train_dataset["root"]),
+        os.path.join(data_dir, cfg.train_dataset["name"], cfg.train_dataset["ann_file"]),
         remove_images_without_annotations=True,
         order=["image", "boxes", "boxes_category", "info"],
     )
......