From 514908dd346063a0b4ce1c3825ed145fe0509bef Mon Sep 17 00:00:00 2001 From: Jianfeng Wang Date: Wed, 17 Jun 2020 16:10:01 +0800 Subject: [PATCH] feat(detection): support RetinaNet with Objects365 and SyncBN (#29) * feat(detection): support objects365 * feat(detection): support retinanet with SyncBN * feat(detection): add GroupedSampler --- .../vision/detection/layers/basic/norm.py | 8 ++- official/vision/detection/layers/det/loss.py | 4 +- .../vision/detection/layers/det/retinanet.py | 4 +- official/vision/detection/models/retinanet.py | 10 +-- .../retinanet_res50_coco_1x_800size.py | 24 ++----- .../retinanet_res50_coco_1x_800size_syncbn.py | 32 ++++++++++ .../retinanet_res50_objects365_1x_800size.py | 10 +-- official/vision/detection/tools/inference.py | 5 +- official/vision/detection/tools/test.py | 5 +- official/vision/detection/tools/train.py | 64 ++++++++++++++++++- 10 files changed, 128 insertions(+), 38 deletions(-) create mode 100644 official/vision/detection/retinanet_res50_coco_1x_800size_syncbn.py diff --git a/official/vision/detection/layers/basic/norm.py b/official/vision/detection/layers/basic/norm.py index 487dd17..5d0463a 100644 --- a/official/vision/detection/layers/basic/norm.py +++ b/official/vision/detection/layers/basic/norm.py @@ -58,7 +58,7 @@ class FrozenBatchNorm2d(M.Module): def get_norm(norm, out_channels=None): """ Args: - norm (str): currently support "BN" and "FrozenBN" + norm (str): currently support "BN", "SyncBN" and "FrozenBN" Returns: M.Module or None: the normalization layer @@ -66,7 +66,11 @@ def get_norm(norm, out_channels=None): if isinstance(norm, str): if len(norm) == 0: return None - norm = {"BN": M.BatchNorm2d, "FrozenBN": FrozenBatchNorm2d}[norm] + norm = { + "BN": M.BatchNorm2d, + "SyncBN": M.SyncBatchNorm, + "FrozenBN": FrozenBatchNorm2d + }[norm] if out_channels is not None: return norm(out_channels) else: diff --git a/official/vision/detection/layers/det/loss.py b/official/vision/detection/layers/det/loss.py index c355d52..1d48adb 100644 --- a/official/vision/detection/layers/det/loss.py +++ b/official/vision/detection/layers/det/loss.py @@ -151,7 +151,5 @@ def get_smooth_l1_base( abs_x = F.abs(x) in_loss = 0.5 * x ** 2 * sigma2 out_loss = abs_x - 0.5 / sigma2 - in_mask = abs_x < cond_point - out_mask = 1 - in_mask - loss = in_loss * in_mask + out_loss * out_mask + loss = F.where(abs_x < cond_point, in_loss, out_loss) return loss diff --git a/official/vision/detection/layers/det/retinanet.py b/official/vision/detection/layers/det/retinanet.py index a6b7b49..df6d195 100644 --- a/official/vision/detection/layers/det/retinanet.py +++ b/official/vision/detection/layers/det/retinanet.py @@ -28,7 +28,9 @@ class RetinaNetHead(M.Module): num_classes = cfg.num_classes num_convs = 4 prior_prob = cfg.cls_prior_prob - num_anchors = [len(cfg.anchor_ratios) * len(cfg.anchor_scales)] * 5 + num_anchors = [len(cfg.anchor_ratios) * len(cfg.anchor_scales)] * len( + input_shape + ) assert ( len(set(num_anchors)) == 1 diff --git a/official/vision/detection/models/retinanet.py b/official/vision/detection/models/retinanet.py index 255d447..1013819 100644 --- a/official/vision/detection/models/retinanet.py +++ b/official/vision/detection/models/retinanet.py @@ -53,7 +53,7 @@ class RetinaNet(M.Module): bottom_up=bottom_up, in_features=["res3", "res4", "res5"], out_channels=out_channels, - norm="", + norm=self.cfg.fpn_norm, top_block=layers.LastLevelP6P7(in_channels_p6p7, out_channels), ) @@ -97,7 +97,8 @@ class RetinaNet(M.Module): ] anchors_list = [ - self.anchor_gen(features[i], self.stride_list[i]) for i in range(5) + self.anchor_gen(features[i], self.stride_list[i]) + for i in range(len(features)) ] all_level_box_cls = F.sigmoid(F.concat(box_cls_list, axis=1)) @@ -196,18 +197,19 @@ class RetinaNet(M.Module): class RetinaNetConfig: def __init__(self): self.resnet_norm = "FrozenBN" + self.fpn_norm = "" self.backbone_freeze_at = 2 # ------------------------ data cfg -------------------------- # self.train_dataset = dict( name="coco", root="train2017", - ann_file="instances_train2017.json" + ann_file="annotations/instances_train2017.json", ) self.test_dataset = dict( name="coco", root="val2017", - ann_file="instances_val2017.json" + ann_file="annotations/instances_val2017.json", ) self.train_image_short_size = 800 self.train_image_max_size = 1333 diff --git a/official/vision/detection/retinanet_res50_coco_1x_800size.py b/official/vision/detection/retinanet_res50_coco_1x_800size.py index bdb9cc4..8324290 100644 --- a/official/vision/detection/retinanet_res50_coco_1x_800size.py +++ b/official/vision/detection/retinanet_res50_coco_1x_800size.py @@ -11,33 +11,17 @@ from megengine import hub from official.vision.detection import models -class CustomRetinaNetConfig(models.RetinaNetConfig): - def __init__(self): - super().__init__() - - # ------------------------ data cfg -------------------------- # - self.train_dataset = dict( - name="coco", - root="train2017", - ann_file="annotations/instances_train2017.json" - ) - self.test_dataset = dict( - name="coco", - root="val2017", - ann_file="annotations/instances_val2017.json" - ) - - @hub.pretrained( "https://data.megengine.org.cn/models/weights/" "retinanet_d3f58dce_res50_1x_800size_36dot0.pkl" ) def retinanet_res50_coco_1x_800size(batch_size=1, **kwargs): - r"""ResNet-18 model from + r""" + RetinaNet trained from COCO dataset. `"RetinaNet" `_ """ - return models.RetinaNet(RetinaNetConfig(), batch_size=batch_size, **kwargs) + return models.RetinaNet(models.RetinaNetConfig(), batch_size=batch_size, **kwargs) Net = models.RetinaNet -Cfg = CustomRetinaNetConfig +Cfg = models.RetinaNetConfig diff --git a/official/vision/detection/retinanet_res50_coco_1x_800size_syncbn.py b/official/vision/detection/retinanet_res50_coco_1x_800size_syncbn.py new file mode 100644 index 0000000..363a542 --- /dev/null +++ b/official/vision/detection/retinanet_res50_coco_1x_800size_syncbn.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from megengine import hub + +from official.vision.detection import models + + +class CustomRetinaNetConfig(models.RetinaNetConfig): + def __init__(self): + super().__init__() + + self.resnet_norm = "SyncBN" + self.fpn_norm = "SyncBN" + self.backbone_freeze_at = 0 + + +def retinanet_res50_coco_1x_800size_syncbn(batch_size=1, **kwargs): + r""" + RetinaNet with SyncBN trained from COCO dataset. + `"RetinaNet" `_ + """ + return models.RetinaNet(CustomRetinaNetConfig(), batch_size=batch_size, **kwargs) + + +Net = models.RetinaNet +Cfg = CustomRetinaNetConfig diff --git a/official/vision/detection/retinanet_res50_objects365_1x_800size.py b/official/vision/detection/retinanet_res50_objects365_1x_800size.py index 028cebf..951d09a 100644 --- a/official/vision/detection/retinanet_res50_objects365_1x_800size.py +++ b/official/vision/detection/retinanet_res50_objects365_1x_800size.py @@ -19,23 +19,25 @@ class CustomRetinaNetConfig(models.RetinaNetConfig): self.train_dataset = dict( name="objects365", root="train", - ann_file="annotations/objects365_train_20190423.json" + ann_file="annotations/objects365_train_20190423.json", ) self.test_dataset = dict( name="objects365", root="val", - ann_file="annotations/objects365_val_20190423.json" + ann_file="annotations/objects365_val_20190423.json", ) + self.num_classes = 365 # ------------------------ training cfg ---------------------- # self.nr_images_epoch = 400000 def retinanet_objects365_res50_1x_800size(batch_size=1, **kwargs): - r"""ResNet-18 model from + r""" + RetinaNet trained from Objects365 dataset. `"RetinaNet" `_ """ - return models.RetinaNet(RetinaNetConfig(), batch_size=batch_size, **kwargs) + return models.RetinaNet(CustomRetinaNetConfig(), batch_size=batch_size, **kwargs) Net = models.RetinaNet diff --git a/official/vision/detection/tools/inference.py b/official/vision/detection/tools/inference.py index b8db4cd..c1a3556 100644 --- a/official/vision/detection/tools/inference.py +++ b/official/vision/detection/tools/inference.py @@ -47,7 +47,10 @@ def main(): current_network = importlib.import_module(os.path.basename(args.file).split(".")[0]) model = current_network.Net(current_network.Cfg(), batch_size=1) model.eval() - model.load_state_dict(mge.load(args.model)["state_dict"]) + state_dict = mge.load(args.model) + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + model.load_state_dict(state_dict) evaluator = DetEvaluator(model) diff --git a/official/vision/detection/tools/test.py b/official/vision/detection/tools/test.py index f9f66d4..4fe87a5 100644 --- a/official/vision/detection/tools/test.py +++ b/official/vision/detection/tools/test.py @@ -235,7 +235,10 @@ def worker( model = current_network.Net(current_network.Cfg(), batch_size=1) model.eval() evaluator = DetEvaluator(model) - model.load_state_dict(mge.load(model_file)["state_dict"]) + state_dict = mge.load(model_file) + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + model.load_state_dict(state_dict) loader = build_dataloader(worker_id, total_worker, data_dir, model.cfg) for data_dict in loader: diff --git a/official/vision/detection/tools/train.py b/official/vision/detection/tools/train.py index c8a8c0a..f93093d 100644 --- a/official/vision/detection/tools/train.py +++ b/official/vision/detection/tools/train.py @@ -232,14 +232,35 @@ def main(): worker(0, 1, args) +def build_sampler(train_dataset, batch_size, aspect_grouping=[1]): + def _compute_aspect_ratios(dataset): + aspect_ratios = [] + for i in range(len(dataset)): + info = dataset.get_img_info(i) + aspect_ratios.append(info["height"] / info["width"]) + return aspect_ratios + + def _quantize(x, bins): + return list(map(lambda y: bisect.bisect_right(sorted(bins), y), x)) + + if len(aspect_grouping) == 0: + return Infinite(RandomSampler(train_dataset, batch_size, drop_last=True)) + + aspect_ratios = _compute_aspect_ratios(train_dataset) + group_ids = _quantize(aspect_ratios, aspect_grouping) + return Infinite(GroupedRandomSampler(train_dataset, batch_size, group_ids)) + + def build_dataloader(batch_size, data_dir, cfg): train_dataset = data_mapper[cfg.train_dataset["name"]]( os.path.join(data_dir, cfg.train_dataset["name"], cfg.train_dataset["root"]), - os.path.join(data_dir, cfg.train_dataset["name"], cfg.train_dataset["ann_file"]), + os.path.join( + data_dir, cfg.train_dataset["name"], cfg.train_dataset["ann_file"] + ), remove_images_without_annotations=True, order=["image", "boxes", "boxes_category", "info"], ) - train_sampler = Infinite(RandomSampler(train_dataset, batch_size, drop_last=True)) + train_sampler = build_sampler(train_dataset, batch_size) train_dataloader = DataLoader( train_dataset, sampler=train_sampler, @@ -259,6 +280,45 @@ def build_dataloader(batch_size, data_dir, cfg): return {"train": train_dataloader} +class GroupedRandomSampler(RandomSampler): + def __init__( + self, + dataset, + batch_size, + group_ids, + indices=None, + world_size=None, + rank=None, + seed=None, + ): + super().__init__(dataset, batch_size, False, indices, world_size, rank, seed) + self.group_ids = group_ids + assert len(group_ids) == len(dataset) + groups = np.unique(self.group_ids).tolist() + + # buffer the indices of each group until batch size is reached + self.buffer_per_group = {k: [] for k in groups} + + def batch(self): + indices = list(self.sample()) + if self.world_size > 1: + indices = self.scatter(indices) + + batch_index = [] + for ind in indices: + group_id = self.group_ids[ind] + group_buffer = self.buffer_per_group[group_id] + group_buffer.append(ind) + if len(group_buffer) == self.batch_size: + batch_index.append(group_buffer) + self.buffer_per_group[group_id] = [] + + return iter(batch_index) + + def __len__(self): + raise NotImplementedError("len() of GroupedRandomSampler is not well-defined.") + + class DetectionPadCollator(Collator): def __init__(self, pad_value: float = 0.0): super().__init__() -- GitLab