diff --git a/official/vision/classification/shufflenet/model.py b/official/vision/classification/shufflenet/model.py index 68d05d7e4ef0dda4411da527b47e955426093aa0..7622a7c17ba631002797430653d1d4da531a5215 100644 --- a/official/vision/classification/shufflenet/model.py +++ b/official/vision/classification/shufflenet/model.py @@ -110,7 +110,7 @@ class ShuffleV2Block(M.Module): class ShuffleNetV2(M.Module): - def __init__(self, input_size=224, num_classes=1000, model_size="1.5x"): + def __init__(self, num_classes=1000, model_size="1.5x"): super(ShuffleNetV2, self).__init__() self.stage_repeats = [4, 8, 4] diff --git a/official/vision/detection/README.md b/official/vision/detection/README.md index f66403c004800de4bf7b518701d5dcc40fed3ead..4b0481d230e2403cb44c1f155ba956cfae1ae417 100644 --- a/official/vision/detection/README.md +++ b/official/vision/detection/README.md @@ -2,23 +2,22 @@ ## 介绍 -本目录包含了采用MegEngine实现的经典[RetinaNet](https://arxiv.org/pdf/1708.02002>)网络结构, -同时提供了在COCO2017数据集上的完整训练和测试代码。 +本目录包含了采用MegEngine实现的经典[RetinaNet](https://arxiv.org/pdf/1708.02002>)网络结构,同时提供了在COCO2017数据集上的完整训练和测试代码。 网络的性能在COCO2017验证集上的测试结果如下: -| 模型 | mAP
@5-95 | batch
/gpu | gpu | speed
(8gpu) | speed
(1gpu)| -| --- | --- | --- | --- | --- | --- | -| retinanet-res50-1x-800size | 36.0 | 2 | 2080 | 2.27(it/s) | 3.7(it/s) | +| 模型 | mAP
@5-95 | batch
/gpu | gpu | speed
(8gpu) | speed
(1gpu) | +| --- | --- | --- | --- | --- | --- | +| retinanet-res50-coco-1x-800size | 36.0 | 2 | 2080ti | 2.27(it/s) | 3.7(it/s) | -* MegEngine v0.3.0 +* MegEngine v0.4.0 ## 如何使用 模型训练好之后,可以通过如下命令测试单张图片: ```bash -python3 tools/inference.py -f retinanet_res50_1x_800size.py \ +python3 tools/inference.py -f retinanet_res50_coco_1x_800size.py \ -i ../../assets/cat.jpg \ -m /path/to/retinanet_weights.pkl ``` @@ -35,8 +34,8 @@ python3 tools/inference.py -f retinanet_res50_1x_800size.py \ ## 如何训练 -1. 在开始训练前,请确保已经下载解压好[COCO数据集](http://cocodataset.org/#download), -并放在合适的数据目录下,准备好的数据集的目录结构如下所示(目前默认使用coco2017的数据集): +1. 在开始训练前,请确保已经下载解压好[COCO2017数据集](http://cocodataset.org/#download), +并放在合适的数据目录下,准备好的数据集的目录结构如下所示(目前默认使用COCO2017数据集): ``` /path/to/ @@ -46,14 +45,14 @@ python3 tools/inference.py -f retinanet_res50_1x_800size.py \ | |val2017 ``` -2. 准备预训练的`backbone`网络权重:可使用 megengine.hub 下载`megengine`官方提供的在ImageNet上训练的resnet50模型, 并存放在 `/path/to/pretrain.pkl`。 +2. 准备预训练的`backbone`网络权重:可使用 megengine.hub 下载`megengine`官方提供的在ImageNet上训练的ResNet-50模型, 并存放在 `/path/to/pretrain.pkl`。 3. 在开始运行本目录下的代码之前,请确保按照[README](../../../README.md)进行了正确的环境配置。 4. 开始训练: ```bash -python3 tools/train.py -f retinanet_res50_1x_800size.py \ +python3 tools/train.py -f retinanet_res50_coco_1x_800size.py \ -n 8 \ --batch_size 2 \ -w /path/to/pretrain.pkl @@ -65,7 +64,7 @@ python3 tools/train.py -f retinanet_res50_1x_800size.py \ - `-n`, 用于训练的devices(gpu)数量,默认使用所有可用的gpu. - `-w`, 预训练的backbone网络权重的路径。 - `--batch_size`,训练时采用的`batch size`, 默认2,表示每张卡训2张图。 -- `--dataset-dir`, coco数据集的根目录,默认`/data/datasets/coco`。 +- `--dataset-dir`, COCO2017数据集的上级目录,默认`/data/datasets`。 默认情况下模型会存在 `log-of-retinanet_res50_1x_800size`目录下。 @@ -74,10 +73,10 @@ python3 tools/train.py -f retinanet_res50_1x_800size.py \ 在训练的过程中,可以通过如下命令测试模型在`COCO2017`验证集的性能: ```bash -python3 tools/test.py -n 8 \ - -f retinanet_res50_1x_800size.py \ +python3 tools/test.py -f retinanet_res50_coco_1x_800size.py \ + -n 8 \ --model /path/to/retinanet_weights.pt \ - --dataset_dir /data/datasets/coco + --dataset_dir /data/datasets ``` `tools/test.py`的命令行选项如下: @@ -85,7 +84,7 @@ python3 tools/test.py -n 8 \ - `-f`, 所需要测试的网络结构描述文件。 - `-n`, 用于测试的devices(gpu)数量,默认1; - `--model`, 需要测试的模型;可以从顶部的表格中下载训练好的检测器权重, 也可以用自行训练好的权重。 -- `--dataset_dir`,coco数据集的根目录,默认`/data/datasets` +- `--dataset_dir`,COCO2017数据集的上级目录,默认`/data/datasets` ## 参考文献 diff --git a/official/vision/detection/layers/basic/functional.py b/official/vision/detection/layers/basic/functional.py index ae4f117b0706b18b3786e12f8e17ec37a54eedcc..8fdff3f5c5db77f2ed12310960a8ba4c981fd674 100644 --- a/official/vision/detection/layers/basic/functional.py +++ b/official/vision/detection/layers/basic/functional.py @@ -10,8 +10,7 @@ import megengine as mge import megengine.functional as F import numpy as np -from megengine import _internal as mgb -from megengine.core import Tensor, wrap_io_tensor +from megengine.core import Tensor def get_padded_array_np( @@ -86,8 +85,3 @@ def get_padded_tensor( else: raise Exception("Not supported tensor dim: %d" % ndim) return padded_array - - -@wrap_io_tensor -def indexing_set_one_hot(inp, axis, idx, value) -> Tensor: - return mgb.opr.indexing_set_one_hot(inp, axis, idx, value) diff --git a/official/vision/detection/layers/det/loss.py b/official/vision/detection/layers/det/loss.py index 0feafc557ce6b28bed690e91bca7a83339175323..c355d52c3930973369f0c435bc01db6a6ba4adca 100644 --- a/official/vision/detection/layers/det/loss.py +++ b/official/vision/detection/layers/det/loss.py @@ -12,8 +12,6 @@ import numpy as np from megengine.core import tensor, Tensor -from official.vision.detection.layers import basic - def get_focal_loss( score: Tensor, @@ -51,28 +49,19 @@ def get_focal_loss( Returns: the calculated focal loss. """ - mask = 1 - (label == ignore_label) - valid_label = label * mask - - score_shp = score.shape - zero_mat = mge.zeros( - F.concat([score_shp[0], score_shp[1], score_shp[2] + 1], axis=0), - dtype=np.float32, - ) - one_mat = mge.ones( - F.concat([score_shp[0], score_shp[1], tensor(1)], axis=0), dtype=np.float32, - ) - - one_hot = basic.indexing_set_one_hot( - zero_mat, 2, valid_label.astype(np.int32), one_mat - )[:, :, 1:] - pos_part = F.power(1 - score, gamma) * one_hot * F.log(score) - neg_part = F.power(score, gamma) * (1 - one_hot) * F.log(1 - score) - loss = -(alpha * pos_part + (1 - alpha) * neg_part).sum(axis=2) * mask + class_range = F.arange(1, score.shape[2] + 1) + + label = F.add_axis(label, axis=2) + pos_part = (1 - score) ** gamma * F.log(score) + neg_part = score ** gamma * F.log(1 - score) + + pos_loss = -(label == class_range) * pos_part * alpha + neg_loss = -(label != class_range) * (label != ignore_label) * neg_part * (1 - alpha) + loss = pos_loss + neg_loss if norm_type == "fg": - positive_mask = label > background - return loss.sum() / F.maximum(positive_mask.sum(), 1) + fg_mask = (label != background) * (label != ignore_label) + return loss.sum() / F.maximum(fg_mask.sum(), 1) elif norm_type == "none": return loss.sum() else: @@ -117,8 +106,7 @@ def get_smooth_l1_loss( gt_bbox = gt_bbox.reshape(-1, 4) label = label.reshape(-1) - valid_mask = 1 - (label == ignore_label) - fg_mask = (1 - (label == background)) * valid_mask + fg_mask = (label != background) * (label != ignore_label) losses = get_smooth_l1_base(pred_bbox, gt_bbox, sigma, is_fix=fix_smooth_l1) if norm_type == "fg": @@ -154,19 +142,16 @@ def get_smooth_l1_base( cond_point = sigma x = pred_bbox - gt_bbox abs_x = F.abs(x) - in_mask = abs_x < cond_point - out_mask = 1 - in_mask - in_loss = 0.5 * (x ** 2) - out_loss = sigma * abs_x - 0.5 * (sigma ** 2) - loss = in_loss * in_mask + out_loss * out_mask + in_loss = 0.5 * x ** 2 + out_loss = sigma * abs_x - 0.5 * sigma ** 2 else: sigma2 = sigma ** 2 cond_point = 1 / sigma2 x = pred_bbox - gt_bbox abs_x = F.abs(x) - in_mask = abs_x < cond_point - out_mask = 1 - in_mask - in_loss = 0.5 * (sigma * x) ** 2 + in_loss = 0.5 * x ** 2 * sigma2 out_loss = abs_x - 0.5 / sigma2 - loss = in_loss * in_mask + out_loss * out_mask + in_mask = abs_x < cond_point + out_mask = 1 - in_mask + loss = in_loss * in_mask + out_loss * out_mask return loss diff --git a/official/vision/detection/layers/det/retinanet.py b/official/vision/detection/layers/det/retinanet.py index 460a8568c145cd6e99a3e3c666038200f66ca39c..a6b7b49b8530561f1b2868e17038e980f018898d 100644 --- a/official/vision/detection/layers/det/retinanet.py +++ b/official/vision/detection/layers/det/retinanet.py @@ -28,7 +28,7 @@ class RetinaNetHead(M.Module): num_classes = cfg.num_classes num_convs = 4 prior_prob = cfg.cls_prior_prob - num_anchors = [9, 9, 9, 9, 9] + num_anchors = [len(cfg.anchor_ratios) * len(cfg.anchor_scales)] * 5 assert ( len(set(num_anchors)) == 1 diff --git a/official/vision/detection/models/__init__.py b/official/vision/detection/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ac9cdf6d1288df22ac8a3e71c04bb49b2c051911 --- /dev/null +++ b/official/vision/detection/models/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from .retinanet import * + +_EXCLUDE = {} +__all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")] diff --git a/official/vision/detection/retinanet_res50_1x_800size.py b/official/vision/detection/models/retinanet.py similarity index 92% rename from official/vision/detection/retinanet_res50_1x_800size.py rename to official/vision/detection/models/retinanet.py index eaa02d43af01557b790cd0eb93f2981e25ff7349..255d447e99a2319d394c7f5f534a0f24dba2dab8 100644 --- a/official/vision/detection/retinanet_res50_1x_800size.py +++ b/official/vision/detection/models/retinanet.py @@ -10,7 +10,6 @@ import megengine as mge import megengine.functional as F import megengine.module as M import numpy as np -from megengine import hub from official.vision.classification.resnet.model import resnet50 from official.vision.detection import layers @@ -47,7 +46,7 @@ class RetinaNet(M.Module): for p in bottom_up.layer1.parameters(): p.requires_grad = False - # -------------------------- build the FPN -------------------------- # + # ----------------------- build the FPN ----------------------------- # in_channels_p6p7 = 2048 out_channels = 256 self.backbone = layers.FPN( @@ -61,7 +60,7 @@ class RetinaNet(M.Module): backbone_shape = self.backbone.output_shape() feature_shapes = [backbone_shape[f] for f in self.in_features] - # -------------------------- build the RetinaNet Head -------------- # + # ----------------------- build the RetinaNet Head ------------------ # self.head = layers.RetinaNetHead(cfg, feature_shapes) self.inputs = { @@ -199,13 +198,22 @@ class RetinaNetConfig: self.resnet_norm = "FrozenBN" self.backbone_freeze_at = 2 - # ------------------------ data cfg --------------------------- # + # ------------------------ data cfg -------------------------- # + self.train_dataset = dict( + name="coco", + root="train2017", + ann_file="instances_train2017.json" + ) + self.test_dataset = dict( + name="coco", + root="val2017", + ann_file="instances_val2017.json" + ) self.train_image_short_size = 800 self.train_image_max_size = 1333 self.num_classes = 80 self.img_mean = np.array([103.530, 116.280, 123.675]) # BGR self.img_std = np.array([57.375, 57.120, 58.395]) - # self.img_std = np.array([1.0, 1.0, 1.0]) self.reg_mean = None self.reg_std = np.array([0.1, 0.1, 0.2, 0.2]) @@ -217,7 +225,7 @@ class RetinaNetConfig: self.class_aware_box = False self.cls_prior_prob = 0.01 - # ------------------------ losss cfg ------------------------- # + # ------------------------ loss cfg -------------------------- # self.focal_loss_alpha = 0.25 self.focal_loss_gamma = 2 self.reg_loss_weight = 1.0 / 4.0 @@ -229,29 +237,14 @@ class RetinaNetConfig: self.log_interval = 20 self.nr_images_epoch = 80000 self.max_epoch = 18 - self.warm_iters = 100 + self.warm_iters = 500 self.lr_decay_rate = 0.1 self.lr_decay_sates = [12, 16, 17] - # ------------------------ testing cfg ------------------------- # + # ------------------------ testing cfg ----------------------- # self.test_image_short_size = 800 self.test_image_max_size = 1333 self.test_max_boxes_per_image = 100 self.test_vis_threshold = 0.3 self.test_cls_threshold = 0.05 self.test_nms = 0.5 - - -@hub.pretrained( - "https://data.megengine.org.cn/models/weights/" - "retinanet_d3f58dce_res50_1x_800size_36dot0.pkl" -) -def retinanet_res50_1x_800size(batch_size=1, **kwargs): - r"""ResNet-18 model from - `"RetinaNet" `_ - """ - return RetinaNet(RetinaNetConfig(), batch_size=batch_size, **kwargs) - - -Net = RetinaNet -Cfg = RetinaNetConfig diff --git a/official/vision/detection/retinanet_res50_coco_1x_800size.py b/official/vision/detection/retinanet_res50_coco_1x_800size.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb9cc4a009781889be7193d89a1dd1f5fce6b47 --- /dev/null +++ b/official/vision/detection/retinanet_res50_coco_1x_800size.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from megengine import hub + +from official.vision.detection import models + + +class CustomRetinaNetConfig(models.RetinaNetConfig): + def __init__(self): + super().__init__() + + # ------------------------ data cfg -------------------------- # + self.train_dataset = dict( + name="coco", + root="train2017", + ann_file="annotations/instances_train2017.json" + ) + self.test_dataset = dict( + name="coco", + root="val2017", + ann_file="annotations/instances_val2017.json" + ) + + +@hub.pretrained( + "https://data.megengine.org.cn/models/weights/" + "retinanet_d3f58dce_res50_1x_800size_36dot0.pkl" +) +def retinanet_res50_coco_1x_800size(batch_size=1, **kwargs): + r"""ResNet-18 model from + `"RetinaNet" `_ + """ + return models.RetinaNet(RetinaNetConfig(), batch_size=batch_size, **kwargs) + + +Net = models.RetinaNet +Cfg = CustomRetinaNetConfig diff --git a/official/vision/detection/retinanet_res50_objects365_1x_800size.py b/official/vision/detection/retinanet_res50_objects365_1x_800size.py new file mode 100644 index 0000000000000000000000000000000000000000..028cebffc29f00698cf92f212af6dcc66c66f6cf --- /dev/null +++ b/official/vision/detection/retinanet_res50_objects365_1x_800size.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from megengine import hub + +from official.vision.detection import models + + +class CustomRetinaNetConfig(models.RetinaNetConfig): + def __init__(self): + super().__init__() + + # ------------------------ data cfg -------------------------- # + self.train_dataset = dict( + name="objects365", + root="train", + ann_file="annotations/objects365_train_20190423.json" + ) + self.test_dataset = dict( + name="objects365", + root="val", + ann_file="annotations/objects365_val_20190423.json" + ) + + # ------------------------ training cfg ---------------------- # + self.nr_images_epoch = 400000 + + +def retinanet_objects365_res50_1x_800size(batch_size=1, **kwargs): + r"""ResNet-18 model from + `"RetinaNet" `_ + """ + return models.RetinaNet(RetinaNetConfig(), batch_size=batch_size, **kwargs) + + +Net = models.RetinaNet +Cfg = CustomRetinaNetConfig diff --git a/official/vision/detection/tools/data_mapper.py b/official/vision/detection/tools/data_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..4f5d5445666d3e755cbd515016762acb03eab1be --- /dev/null +++ b/official/vision/detection/tools/data_mapper.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from megengine.data.dataset import COCO, Objects365 + +data_mapper = dict( + coco=COCO, + objects365=Objects365, +) diff --git a/official/vision/detection/tools/test.py b/official/vision/detection/tools/test.py index d57e3a0b69c49b968439b7f1fdf49d1a9efc2967..f9f66d4e4d9675da21b5f5dcf39a6d7835984010 100644 --- a/official/vision/detection/tools/test.py +++ b/official/vision/detection/tools/test.py @@ -19,9 +19,9 @@ import megengine as mge import numpy as np from megengine import jit from megengine.data import DataLoader, SequentialSampler -from megengine.data.dataset import COCO as COCODataset from tqdm import tqdm +from official.vision.detection.tools.data_mapper import data_mapper from official.vision.detection.tools.nms import py_cpu_nms logger = mge.get_logger(__name__) @@ -119,9 +119,10 @@ class DetEvaluator: return dtboxes_all @staticmethod - def format(results): - all_results = [] + def format(results, cfg): + dataset_class = data_mapper[cfg.test_dataset["name"]] + all_results = [] for record in results: image_filename = record["image_id"] boxes = record["det_res"] @@ -133,8 +134,8 @@ class DetEvaluator: elem["image_id"] = image_filename elem["bbox"] = box[:4].tolist() elem["score"] = box[4] - elem["category_id"] = COCODataset.classes_originID[ - COCODataset.class_names[int(box[5]) + 1] + elem["category_id"] = dataset_class.classes_originID[ + dataset_class.class_names[int(box[5])] ] all_results.append(elem) return all_results @@ -156,7 +157,7 @@ class DetEvaluator: for det in dets: bb = det[:4].astype(int) if is_show_label: - cls_id = int(det[5] + 1) + cls_id = int(det[5]) score = det[4] if cls_id == 0: @@ -200,10 +201,10 @@ class DetEvaluator: break -def build_dataloader(rank, world_size, data_dir): - val_dataset = COCODataset( - os.path.join(data_dir, "val2017"), - os.path.join(data_dir, "annotations/instances_val2017.json"), +def build_dataloader(rank, world_size, data_dir, cfg): + val_dataset = data_mapper[cfg.test_dataset["name"]]( + os.path.join(data_dir, cfg.test_dataset["name"], cfg.test_dataset["root"]), + os.path.join(data_dir, cfg.test_dataset["name"], cfg.test_dataset["ann_file"]), order=["image", "info"], ) val_sampler = SequentialSampler(val_dataset, 1, world_size=world_size, rank=rank) @@ -236,7 +237,7 @@ def worker( evaluator = DetEvaluator(model) model.load_state_dict(mge.load(model_file)["state_dict"]) - loader = build_dataloader(worker_id, total_worker, data_dir) + loader = build_dataloader(worker_id, total_worker, data_dir, model.cfg) for data_dict in loader: data, im_info = DetEvaluator.process_inputs( data_dict[0][0], @@ -262,7 +263,7 @@ def make_parser(): parser.add_argument( "-f", "--file", default="net.py", type=str, help="net description file" ) - parser.add_argument("-d", "--dataset_dir", default="/data/datasets/coco", type=str) + parser.add_argument("-d", "--dataset_dir", default="/data/datasets", type=str) parser.add_argument("-se", "--start_epoch", default=-1, type=int) parser.add_argument("-ee", "--end_epoch", default=-1, type=int) parser.add_argument("-m", "--model", default=None, type=str) @@ -312,7 +313,12 @@ def main(): for p in procs: p.join() - all_results = DetEvaluator.format(results_list) + sys.path.insert(0, os.path.dirname(args.file)) + current_network = importlib.import_module( + os.path.basename(args.file).split(".")[0] + ) + cfg = current_network.Cfg() + all_results = DetEvaluator.format(results_list, cfg) json_path = "log-of-{}/epoch_{}.json".format( os.path.basename(args.file).split(".")[0], epoch_num ) @@ -323,7 +329,9 @@ def main(): logger.info("Save to %s finished, start evaluation!", json_path) eval_gt = COCO( - os.path.join(args.dataset_dir, "annotations/instances_val2017.json") + os.path.join( + args.dataset_dir, cfg.test_dataset["name"], cfg.test_dataset["ann_file"] + ) ) eval_dt = eval_gt.loadRes(json_path) cocoEval = COCOeval(eval_gt, eval_dt, iouType="bbox") diff --git a/official/vision/detection/tools/train.py b/official/vision/detection/tools/train.py index 6d5ad7b77a5b139992728f0bf4eea9342082fa55..c8a8c0a5460e59aecc6fc4243ffa5b4c6b67fe40 100644 --- a/official/vision/detection/tools/train.py +++ b/official/vision/detection/tools/train.py @@ -22,9 +22,10 @@ from megengine import jit from megengine import optimizer as optim from megengine.data import Collator, DataLoader, Infinite, RandomSampler from megengine.data import transform as T -from megengine.data.dataset import COCO from tabulate import tabulate +from official.vision.detection.tools.data_mapper import data_mapper + logger = mge.get_logger(__name__) @@ -175,7 +176,7 @@ def make_parser(): "-b", "--batch_size", default=2, type=int, help="batchsize for training", ) parser.add_argument( - "-d", "--dataset_dir", default="/data/datasets/coco", type=str, + "-d", "--dataset_dir", default="/data/datasets", type=str, ) return parser @@ -232,9 +233,9 @@ def main(): def build_dataloader(batch_size, data_dir, cfg): - train_dataset = COCO( - os.path.join(data_dir, "train2017"), - os.path.join(data_dir, "annotations/instances_train2017.json"), + train_dataset = data_mapper[cfg.train_dataset["name"]]( + os.path.join(data_dir, cfg.train_dataset["name"], cfg.train_dataset["root"]), + os.path.join(data_dir, cfg.train_dataset["name"], cfg.train_dataset["ann_file"]), remove_images_without_annotations=True, order=["image", "boxes", "boxes_category", "info"], )