diff --git a/official/vision/classification/shufflenet/model.py b/official/vision/classification/shufflenet/model.py
index 68d05d7e4ef0dda4411da527b47e955426093aa0..7622a7c17ba631002797430653d1d4da531a5215 100644
--- a/official/vision/classification/shufflenet/model.py
+++ b/official/vision/classification/shufflenet/model.py
@@ -110,7 +110,7 @@ class ShuffleV2Block(M.Module):
class ShuffleNetV2(M.Module):
- def __init__(self, input_size=224, num_classes=1000, model_size="1.5x"):
+ def __init__(self, num_classes=1000, model_size="1.5x"):
super(ShuffleNetV2, self).__init__()
self.stage_repeats = [4, 8, 4]
diff --git a/official/vision/detection/README.md b/official/vision/detection/README.md
index f66403c004800de4bf7b518701d5dcc40fed3ead..4b0481d230e2403cb44c1f155ba956cfae1ae417 100644
--- a/official/vision/detection/README.md
+++ b/official/vision/detection/README.md
@@ -2,23 +2,22 @@
## 介绍
-本目录包含了采用MegEngine实现的经典[RetinaNet](https://arxiv.org/pdf/1708.02002>)网络结构,
-同时提供了在COCO2017数据集上的完整训练和测试代码。
+本目录包含了采用MegEngine实现的经典[RetinaNet](https://arxiv.org/pdf/1708.02002>)网络结构,同时提供了在COCO2017数据集上的完整训练和测试代码。
网络的性能在COCO2017验证集上的测试结果如下:
-| 模型 | mAP
@5-95 | batch
/gpu | gpu | speed
(8gpu) | speed
(1gpu)|
-| --- | --- | --- | --- | --- | --- |
-| retinanet-res50-1x-800size | 36.0 | 2 | 2080 | 2.27(it/s) | 3.7(it/s) |
+| 模型 | mAP
@5-95 | batch
/gpu | gpu | speed
(8gpu) | speed
(1gpu) |
+| --- | --- | --- | --- | --- | --- |
+| retinanet-res50-coco-1x-800size | 36.0 | 2 | 2080ti | 2.27(it/s) | 3.7(it/s) |
-* MegEngine v0.3.0
+* MegEngine v0.4.0
## 如何使用
模型训练好之后,可以通过如下命令测试单张图片:
```bash
-python3 tools/inference.py -f retinanet_res50_1x_800size.py \
+python3 tools/inference.py -f retinanet_res50_coco_1x_800size.py \
-i ../../assets/cat.jpg \
-m /path/to/retinanet_weights.pkl
```
@@ -35,8 +34,8 @@ python3 tools/inference.py -f retinanet_res50_1x_800size.py \
## 如何训练
-1. 在开始训练前,请确保已经下载解压好[COCO数据集](http://cocodataset.org/#download),
-并放在合适的数据目录下,准备好的数据集的目录结构如下所示(目前默认使用coco2017的数据集):
+1. 在开始训练前,请确保已经下载解压好[COCO2017数据集](http://cocodataset.org/#download),
+并放在合适的数据目录下,准备好的数据集的目录结构如下所示(目前默认使用COCO2017数据集):
```
/path/to/
@@ -46,14 +45,14 @@ python3 tools/inference.py -f retinanet_res50_1x_800size.py \
| |val2017
```
-2. 准备预训练的`backbone`网络权重:可使用 megengine.hub 下载`megengine`官方提供的在ImageNet上训练的resnet50模型, 并存放在 `/path/to/pretrain.pkl`。
+2. 准备预训练的`backbone`网络权重:可使用 megengine.hub 下载`megengine`官方提供的在ImageNet上训练的ResNet-50模型, 并存放在 `/path/to/pretrain.pkl`。
3. 在开始运行本目录下的代码之前,请确保按照[README](../../../README.md)进行了正确的环境配置。
4. 开始训练:
```bash
-python3 tools/train.py -f retinanet_res50_1x_800size.py \
+python3 tools/train.py -f retinanet_res50_coco_1x_800size.py \
-n 8 \
--batch_size 2 \
-w /path/to/pretrain.pkl
@@ -65,7 +64,7 @@ python3 tools/train.py -f retinanet_res50_1x_800size.py \
- `-n`, 用于训练的devices(gpu)数量,默认使用所有可用的gpu.
- `-w`, 预训练的backbone网络权重的路径。
- `--batch_size`,训练时采用的`batch size`, 默认2,表示每张卡训2张图。
-- `--dataset-dir`, coco数据集的根目录,默认`/data/datasets/coco`。
+- `--dataset-dir`, COCO2017数据集的上级目录,默认`/data/datasets`。
默认情况下模型会存在 `log-of-retinanet_res50_1x_800size`目录下。
@@ -74,10 +73,10 @@ python3 tools/train.py -f retinanet_res50_1x_800size.py \
在训练的过程中,可以通过如下命令测试模型在`COCO2017`验证集的性能:
```bash
-python3 tools/test.py -n 8 \
- -f retinanet_res50_1x_800size.py \
+python3 tools/test.py -f retinanet_res50_coco_1x_800size.py \
+ -n 8 \
--model /path/to/retinanet_weights.pt \
- --dataset_dir /data/datasets/coco
+ --dataset_dir /data/datasets
```
`tools/test.py`的命令行选项如下:
@@ -85,7 +84,7 @@ python3 tools/test.py -n 8 \
- `-f`, 所需要测试的网络结构描述文件。
- `-n`, 用于测试的devices(gpu)数量,默认1;
- `--model`, 需要测试的模型;可以从顶部的表格中下载训练好的检测器权重, 也可以用自行训练好的权重。
-- `--dataset_dir`,coco数据集的根目录,默认`/data/datasets`
+- `--dataset_dir`,COCO2017数据集的上级目录,默认`/data/datasets`
## 参考文献
diff --git a/official/vision/detection/layers/basic/functional.py b/official/vision/detection/layers/basic/functional.py
index ae4f117b0706b18b3786e12f8e17ec37a54eedcc..8fdff3f5c5db77f2ed12310960a8ba4c981fd674 100644
--- a/official/vision/detection/layers/basic/functional.py
+++ b/official/vision/detection/layers/basic/functional.py
@@ -10,8 +10,7 @@ import megengine as mge
import megengine.functional as F
import numpy as np
-from megengine import _internal as mgb
-from megengine.core import Tensor, wrap_io_tensor
+from megengine.core import Tensor
def get_padded_array_np(
@@ -86,8 +85,3 @@ def get_padded_tensor(
else:
raise Exception("Not supported tensor dim: %d" % ndim)
return padded_array
-
-
-@wrap_io_tensor
-def indexing_set_one_hot(inp, axis, idx, value) -> Tensor:
- return mgb.opr.indexing_set_one_hot(inp, axis, idx, value)
diff --git a/official/vision/detection/layers/det/loss.py b/official/vision/detection/layers/det/loss.py
index 0feafc557ce6b28bed690e91bca7a83339175323..c355d52c3930973369f0c435bc01db6a6ba4adca 100644
--- a/official/vision/detection/layers/det/loss.py
+++ b/official/vision/detection/layers/det/loss.py
@@ -12,8 +12,6 @@ import numpy as np
from megengine.core import tensor, Tensor
-from official.vision.detection.layers import basic
-
def get_focal_loss(
score: Tensor,
@@ -51,28 +49,19 @@ def get_focal_loss(
Returns:
the calculated focal loss.
"""
- mask = 1 - (label == ignore_label)
- valid_label = label * mask
-
- score_shp = score.shape
- zero_mat = mge.zeros(
- F.concat([score_shp[0], score_shp[1], score_shp[2] + 1], axis=0),
- dtype=np.float32,
- )
- one_mat = mge.ones(
- F.concat([score_shp[0], score_shp[1], tensor(1)], axis=0), dtype=np.float32,
- )
-
- one_hot = basic.indexing_set_one_hot(
- zero_mat, 2, valid_label.astype(np.int32), one_mat
- )[:, :, 1:]
- pos_part = F.power(1 - score, gamma) * one_hot * F.log(score)
- neg_part = F.power(score, gamma) * (1 - one_hot) * F.log(1 - score)
- loss = -(alpha * pos_part + (1 - alpha) * neg_part).sum(axis=2) * mask
+ class_range = F.arange(1, score.shape[2] + 1)
+
+ label = F.add_axis(label, axis=2)
+ pos_part = (1 - score) ** gamma * F.log(score)
+ neg_part = score ** gamma * F.log(1 - score)
+
+ pos_loss = -(label == class_range) * pos_part * alpha
+ neg_loss = -(label != class_range) * (label != ignore_label) * neg_part * (1 - alpha)
+ loss = pos_loss + neg_loss
if norm_type == "fg":
- positive_mask = label > background
- return loss.sum() / F.maximum(positive_mask.sum(), 1)
+ fg_mask = (label != background) * (label != ignore_label)
+ return loss.sum() / F.maximum(fg_mask.sum(), 1)
elif norm_type == "none":
return loss.sum()
else:
@@ -117,8 +106,7 @@ def get_smooth_l1_loss(
gt_bbox = gt_bbox.reshape(-1, 4)
label = label.reshape(-1)
- valid_mask = 1 - (label == ignore_label)
- fg_mask = (1 - (label == background)) * valid_mask
+ fg_mask = (label != background) * (label != ignore_label)
losses = get_smooth_l1_base(pred_bbox, gt_bbox, sigma, is_fix=fix_smooth_l1)
if norm_type == "fg":
@@ -154,19 +142,16 @@ def get_smooth_l1_base(
cond_point = sigma
x = pred_bbox - gt_bbox
abs_x = F.abs(x)
- in_mask = abs_x < cond_point
- out_mask = 1 - in_mask
- in_loss = 0.5 * (x ** 2)
- out_loss = sigma * abs_x - 0.5 * (sigma ** 2)
- loss = in_loss * in_mask + out_loss * out_mask
+ in_loss = 0.5 * x ** 2
+ out_loss = sigma * abs_x - 0.5 * sigma ** 2
else:
sigma2 = sigma ** 2
cond_point = 1 / sigma2
x = pred_bbox - gt_bbox
abs_x = F.abs(x)
- in_mask = abs_x < cond_point
- out_mask = 1 - in_mask
- in_loss = 0.5 * (sigma * x) ** 2
+ in_loss = 0.5 * x ** 2 * sigma2
out_loss = abs_x - 0.5 / sigma2
- loss = in_loss * in_mask + out_loss * out_mask
+ in_mask = abs_x < cond_point
+ out_mask = 1 - in_mask
+ loss = in_loss * in_mask + out_loss * out_mask
return loss
diff --git a/official/vision/detection/layers/det/retinanet.py b/official/vision/detection/layers/det/retinanet.py
index 460a8568c145cd6e99a3e3c666038200f66ca39c..a6b7b49b8530561f1b2868e17038e980f018898d 100644
--- a/official/vision/detection/layers/det/retinanet.py
+++ b/official/vision/detection/layers/det/retinanet.py
@@ -28,7 +28,7 @@ class RetinaNetHead(M.Module):
num_classes = cfg.num_classes
num_convs = 4
prior_prob = cfg.cls_prior_prob
- num_anchors = [9, 9, 9, 9, 9]
+ num_anchors = [len(cfg.anchor_ratios) * len(cfg.anchor_scales)] * 5
assert (
len(set(num_anchors)) == 1
diff --git a/official/vision/detection/models/__init__.py b/official/vision/detection/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac9cdf6d1288df22ac8a3e71c04bb49b2c051911
--- /dev/null
+++ b/official/vision/detection/models/__init__.py
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from .retinanet import *
+
+_EXCLUDE = {}
+__all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")]
diff --git a/official/vision/detection/retinanet_res50_1x_800size.py b/official/vision/detection/models/retinanet.py
similarity index 92%
rename from official/vision/detection/retinanet_res50_1x_800size.py
rename to official/vision/detection/models/retinanet.py
index eaa02d43af01557b790cd0eb93f2981e25ff7349..255d447e99a2319d394c7f5f534a0f24dba2dab8 100644
--- a/official/vision/detection/retinanet_res50_1x_800size.py
+++ b/official/vision/detection/models/retinanet.py
@@ -10,7 +10,6 @@ import megengine as mge
import megengine.functional as F
import megengine.module as M
import numpy as np
-from megengine import hub
from official.vision.classification.resnet.model import resnet50
from official.vision.detection import layers
@@ -47,7 +46,7 @@ class RetinaNet(M.Module):
for p in bottom_up.layer1.parameters():
p.requires_grad = False
- # -------------------------- build the FPN -------------------------- #
+ # ----------------------- build the FPN ----------------------------- #
in_channels_p6p7 = 2048
out_channels = 256
self.backbone = layers.FPN(
@@ -61,7 +60,7 @@ class RetinaNet(M.Module):
backbone_shape = self.backbone.output_shape()
feature_shapes = [backbone_shape[f] for f in self.in_features]
- # -------------------------- build the RetinaNet Head -------------- #
+ # ----------------------- build the RetinaNet Head ------------------ #
self.head = layers.RetinaNetHead(cfg, feature_shapes)
self.inputs = {
@@ -199,13 +198,22 @@ class RetinaNetConfig:
self.resnet_norm = "FrozenBN"
self.backbone_freeze_at = 2
- # ------------------------ data cfg --------------------------- #
+ # ------------------------ data cfg -------------------------- #
+ self.train_dataset = dict(
+ name="coco",
+ root="train2017",
+ ann_file="instances_train2017.json"
+ )
+ self.test_dataset = dict(
+ name="coco",
+ root="val2017",
+ ann_file="instances_val2017.json"
+ )
self.train_image_short_size = 800
self.train_image_max_size = 1333
self.num_classes = 80
self.img_mean = np.array([103.530, 116.280, 123.675]) # BGR
self.img_std = np.array([57.375, 57.120, 58.395])
- # self.img_std = np.array([1.0, 1.0, 1.0])
self.reg_mean = None
self.reg_std = np.array([0.1, 0.1, 0.2, 0.2])
@@ -217,7 +225,7 @@ class RetinaNetConfig:
self.class_aware_box = False
self.cls_prior_prob = 0.01
- # ------------------------ losss cfg ------------------------- #
+ # ------------------------ loss cfg -------------------------- #
self.focal_loss_alpha = 0.25
self.focal_loss_gamma = 2
self.reg_loss_weight = 1.0 / 4.0
@@ -229,29 +237,14 @@ class RetinaNetConfig:
self.log_interval = 20
self.nr_images_epoch = 80000
self.max_epoch = 18
- self.warm_iters = 100
+ self.warm_iters = 500
self.lr_decay_rate = 0.1
self.lr_decay_sates = [12, 16, 17]
- # ------------------------ testing cfg ------------------------- #
+ # ------------------------ testing cfg ----------------------- #
self.test_image_short_size = 800
self.test_image_max_size = 1333
self.test_max_boxes_per_image = 100
self.test_vis_threshold = 0.3
self.test_cls_threshold = 0.05
self.test_nms = 0.5
-
-
-@hub.pretrained(
- "https://data.megengine.org.cn/models/weights/"
- "retinanet_d3f58dce_res50_1x_800size_36dot0.pkl"
-)
-def retinanet_res50_1x_800size(batch_size=1, **kwargs):
- r"""ResNet-18 model from
- `"RetinaNet" `_
- """
- return RetinaNet(RetinaNetConfig(), batch_size=batch_size, **kwargs)
-
-
-Net = RetinaNet
-Cfg = RetinaNetConfig
diff --git a/official/vision/detection/retinanet_res50_coco_1x_800size.py b/official/vision/detection/retinanet_res50_coco_1x_800size.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdb9cc4a009781889be7193d89a1dd1f5fce6b47
--- /dev/null
+++ b/official/vision/detection/retinanet_res50_coco_1x_800size.py
@@ -0,0 +1,43 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from megengine import hub
+
+from official.vision.detection import models
+
+
+class CustomRetinaNetConfig(models.RetinaNetConfig):
+ def __init__(self):
+ super().__init__()
+
+ # ------------------------ data cfg -------------------------- #
+ self.train_dataset = dict(
+ name="coco",
+ root="train2017",
+ ann_file="annotations/instances_train2017.json"
+ )
+ self.test_dataset = dict(
+ name="coco",
+ root="val2017",
+ ann_file="annotations/instances_val2017.json"
+ )
+
+
+@hub.pretrained(
+ "https://data.megengine.org.cn/models/weights/"
+ "retinanet_d3f58dce_res50_1x_800size_36dot0.pkl"
+)
+def retinanet_res50_coco_1x_800size(batch_size=1, **kwargs):
+ r"""ResNet-18 model from
+ `"RetinaNet" `_
+ """
+ return models.RetinaNet(RetinaNetConfig(), batch_size=batch_size, **kwargs)
+
+
+Net = models.RetinaNet
+Cfg = CustomRetinaNetConfig
diff --git a/official/vision/detection/retinanet_res50_objects365_1x_800size.py b/official/vision/detection/retinanet_res50_objects365_1x_800size.py
new file mode 100644
index 0000000000000000000000000000000000000000..028cebffc29f00698cf92f212af6dcc66c66f6cf
--- /dev/null
+++ b/official/vision/detection/retinanet_res50_objects365_1x_800size.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from megengine import hub
+
+from official.vision.detection import models
+
+
+class CustomRetinaNetConfig(models.RetinaNetConfig):
+ def __init__(self):
+ super().__init__()
+
+ # ------------------------ data cfg -------------------------- #
+ self.train_dataset = dict(
+ name="objects365",
+ root="train",
+ ann_file="annotations/objects365_train_20190423.json"
+ )
+ self.test_dataset = dict(
+ name="objects365",
+ root="val",
+ ann_file="annotations/objects365_val_20190423.json"
+ )
+
+ # ------------------------ training cfg ---------------------- #
+ self.nr_images_epoch = 400000
+
+
+def retinanet_objects365_res50_1x_800size(batch_size=1, **kwargs):
+ r"""ResNet-18 model from
+ `"RetinaNet" `_
+ """
+ return models.RetinaNet(RetinaNetConfig(), batch_size=batch_size, **kwargs)
+
+
+Net = models.RetinaNet
+Cfg = CustomRetinaNetConfig
diff --git a/official/vision/detection/tools/data_mapper.py b/official/vision/detection/tools/data_mapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f5d5445666d3e755cbd515016762acb03eab1be
--- /dev/null
+++ b/official/vision/detection/tools/data_mapper.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from megengine.data.dataset import COCO, Objects365
+
+data_mapper = dict(
+ coco=COCO,
+ objects365=Objects365,
+)
diff --git a/official/vision/detection/tools/test.py b/official/vision/detection/tools/test.py
index d57e3a0b69c49b968439b7f1fdf49d1a9efc2967..f9f66d4e4d9675da21b5f5dcf39a6d7835984010 100644
--- a/official/vision/detection/tools/test.py
+++ b/official/vision/detection/tools/test.py
@@ -19,9 +19,9 @@ import megengine as mge
import numpy as np
from megengine import jit
from megengine.data import DataLoader, SequentialSampler
-from megengine.data.dataset import COCO as COCODataset
from tqdm import tqdm
+from official.vision.detection.tools.data_mapper import data_mapper
from official.vision.detection.tools.nms import py_cpu_nms
logger = mge.get_logger(__name__)
@@ -119,9 +119,10 @@ class DetEvaluator:
return dtboxes_all
@staticmethod
- def format(results):
- all_results = []
+ def format(results, cfg):
+ dataset_class = data_mapper[cfg.test_dataset["name"]]
+ all_results = []
for record in results:
image_filename = record["image_id"]
boxes = record["det_res"]
@@ -133,8 +134,8 @@ class DetEvaluator:
elem["image_id"] = image_filename
elem["bbox"] = box[:4].tolist()
elem["score"] = box[4]
- elem["category_id"] = COCODataset.classes_originID[
- COCODataset.class_names[int(box[5]) + 1]
+ elem["category_id"] = dataset_class.classes_originID[
+ dataset_class.class_names[int(box[5])]
]
all_results.append(elem)
return all_results
@@ -156,7 +157,7 @@ class DetEvaluator:
for det in dets:
bb = det[:4].astype(int)
if is_show_label:
- cls_id = int(det[5] + 1)
+ cls_id = int(det[5])
score = det[4]
if cls_id == 0:
@@ -200,10 +201,10 @@ class DetEvaluator:
break
-def build_dataloader(rank, world_size, data_dir):
- val_dataset = COCODataset(
- os.path.join(data_dir, "val2017"),
- os.path.join(data_dir, "annotations/instances_val2017.json"),
+def build_dataloader(rank, world_size, data_dir, cfg):
+ val_dataset = data_mapper[cfg.test_dataset["name"]](
+ os.path.join(data_dir, cfg.test_dataset["name"], cfg.test_dataset["root"]),
+ os.path.join(data_dir, cfg.test_dataset["name"], cfg.test_dataset["ann_file"]),
order=["image", "info"],
)
val_sampler = SequentialSampler(val_dataset, 1, world_size=world_size, rank=rank)
@@ -236,7 +237,7 @@ def worker(
evaluator = DetEvaluator(model)
model.load_state_dict(mge.load(model_file)["state_dict"])
- loader = build_dataloader(worker_id, total_worker, data_dir)
+ loader = build_dataloader(worker_id, total_worker, data_dir, model.cfg)
for data_dict in loader:
data, im_info = DetEvaluator.process_inputs(
data_dict[0][0],
@@ -262,7 +263,7 @@ def make_parser():
parser.add_argument(
"-f", "--file", default="net.py", type=str, help="net description file"
)
- parser.add_argument("-d", "--dataset_dir", default="/data/datasets/coco", type=str)
+ parser.add_argument("-d", "--dataset_dir", default="/data/datasets", type=str)
parser.add_argument("-se", "--start_epoch", default=-1, type=int)
parser.add_argument("-ee", "--end_epoch", default=-1, type=int)
parser.add_argument("-m", "--model", default=None, type=str)
@@ -312,7 +313,12 @@ def main():
for p in procs:
p.join()
- all_results = DetEvaluator.format(results_list)
+ sys.path.insert(0, os.path.dirname(args.file))
+ current_network = importlib.import_module(
+ os.path.basename(args.file).split(".")[0]
+ )
+ cfg = current_network.Cfg()
+ all_results = DetEvaluator.format(results_list, cfg)
json_path = "log-of-{}/epoch_{}.json".format(
os.path.basename(args.file).split(".")[0], epoch_num
)
@@ -323,7 +329,9 @@ def main():
logger.info("Save to %s finished, start evaluation!", json_path)
eval_gt = COCO(
- os.path.join(args.dataset_dir, "annotations/instances_val2017.json")
+ os.path.join(
+ args.dataset_dir, cfg.test_dataset["name"], cfg.test_dataset["ann_file"]
+ )
)
eval_dt = eval_gt.loadRes(json_path)
cocoEval = COCOeval(eval_gt, eval_dt, iouType="bbox")
diff --git a/official/vision/detection/tools/train.py b/official/vision/detection/tools/train.py
index 6d5ad7b77a5b139992728f0bf4eea9342082fa55..c8a8c0a5460e59aecc6fc4243ffa5b4c6b67fe40 100644
--- a/official/vision/detection/tools/train.py
+++ b/official/vision/detection/tools/train.py
@@ -22,9 +22,10 @@ from megengine import jit
from megengine import optimizer as optim
from megengine.data import Collator, DataLoader, Infinite, RandomSampler
from megengine.data import transform as T
-from megengine.data.dataset import COCO
from tabulate import tabulate
+from official.vision.detection.tools.data_mapper import data_mapper
+
logger = mge.get_logger(__name__)
@@ -175,7 +176,7 @@ def make_parser():
"-b", "--batch_size", default=2, type=int, help="batchsize for training",
)
parser.add_argument(
- "-d", "--dataset_dir", default="/data/datasets/coco", type=str,
+ "-d", "--dataset_dir", default="/data/datasets", type=str,
)
return parser
@@ -232,9 +233,9 @@ def main():
def build_dataloader(batch_size, data_dir, cfg):
- train_dataset = COCO(
- os.path.join(data_dir, "train2017"),
- os.path.join(data_dir, "annotations/instances_train2017.json"),
+ train_dataset = data_mapper[cfg.train_dataset["name"]](
+ os.path.join(data_dir, cfg.train_dataset["name"], cfg.train_dataset["root"]),
+ os.path.join(data_dir, cfg.train_dataset["name"], cfg.train_dataset["ann_file"]),
remove_images_without_annotations=True,
order=["image", "boxes", "boxes_category", "info"],
)