diff --git a/README.md b/README.md
index 92bfcaedad5fd55a4308af85838967727a4c7416..f828fa485cc2d9990aca23547062748df4b3f313 100644
--- a/README.md
+++ b/README.md
@@ -62,13 +62,13 @@ export PYTHONPATH=/path/to/models:$PYTHONPATH
 | ResNet34 | 73.960 | 91.630 |
 | ResNet50 | 76.254 | 93.056 |
 | ResNet101 | 77.944 | 93.844 |
-| ResNet152 | 78.582 | 94.130 |
+| ResNet152 | 78.582 | 94.130 |
 | ResNeXt50 32x4d | 77.592 | 93.644 |
 | ResNeXt101 32x8d| 79.520 | 94.586 |
-| ShuffleNetV2 x0.5 | 60.696 | 82.190 |
-| ShuffleNetV2 x1.0 | 69.372 | 88.764 |
-| ShuffleNetV2 x1.5 | 72.806 | 90.792 |
-| ShuffleNetV2 x2.0 | 75.074 | 92.278 |
+| ShuffleNetV2 x0.5 | 60.696 | 82.190 |
+| ShuffleNetV2 x1.0 | 69.372 | 88.764 |
+| ShuffleNetV2 x1.5 | 72.806 | 90.792 |
+| ShuffleNetV2 x2.0 | 75.074 | 92.278 |
 
 ### Object Detection
 
@@ -89,19 +89,7 @@
 | :--: |:--: |:--: |:--: |
 | Deeplabv3plus | Resnet101 | 79.0 | 79.8 |
 
-<<<<<<< HEAD
-<<<<<<< HEAD
-<<<<<<< HEAD
 ### Human Keypoint Detection
-=======
-### Human Keypoints
->>>>>>> update readme
-=======
-### Human Keypoints
->>>>>>> update readme
-=======
-### Human Keypoint Detection
->>>>>>> 3fdaf98eee3169f70ace463d54cd177ee1fcf68e
 
 We provide the classic human keypoint detection model [SimpleBaseline](https://arxiv.org/pdf/1804.06208.pdf) and the high-accuracy model [MSPN](https://arxiv.org/pdf/1901.00148.pdf). Using person detection results with an AP of 56 on COCO val2017, the provided models reach the following keypoint detection results on COCO val2017:
diff --git a/hubconf.py b/hubconf.py
index fc21310bed004f9b0c54a09b3ae1aa121e4d4ae4..cb2e9022e19c945dc7ca00f834df5f62d9d776d7 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -31,10 +31,23 @@ from official.nlp.bert.model import (
 from official.vision.detection.faster_rcnn_fpn_res50_coco_1x_800size import (
     faster_rcnn_fpn_res50_coco_1x_800size,
 )
+from official.vision.detection.faster_rcnn_fpn_res50_coco_1x_800size_syncbn import (
+    faster_rcnn_fpn_res50_coco_1x_800size_syncbn,
+)
 from official.vision.detection.retinanet_res50_coco_1x_800size import (
     retinanet_res50_coco_1x_800size,
 )
+from official.vision.detection.retinanet_res50_coco_1x_800size_syncbn import (
+    retinanet_res50_coco_1x_800size_syncbn,
+)
+# TODO: need pretrained weights
+# from official.vision.detection.retinanet_res50_objects365_1x_800size import (
+#     retinanet_res50_objects365_1x_800size,
+# )
+# from official.vision.detection.retinanet_res50_voc_1x_800size import (
+#     retinanet_res50_voc_1x_800size,
+# )
 from official.vision.detection.models import FasterRCNN, RetinaNet
 from official.vision.detection.tools.test import DetEvaluator
 
@@ -45,10 +58,10 @@ from official.vision.segmentation.deeplabv3plus import (
 )
 
 from official.vision.keypoints.models import (
-    simplebaseline_res50,
-    simplebaseline_res101,
-    simplebaseline_res152,
-    mspn_4stage
+    simplebaseline_res50,
+    simplebaseline_res101,
+    simplebaseline_res152,
+    mspn_4stage
 )
 
 from official.vision.keypoints.inference import KeypointEvaluator
diff --git a/official/vision/detection/faster_rcnn_fpn_res50_coco_1x_800size_syncbn.py b/official/vision/detection/faster_rcnn_fpn_res50_coco_1x_800size_syncbn.py
index c557d46c3bef891b312644ecfb521a0b7b5d2389..33dd7ff86ce96bc5e891632aa1e0af3710faac66 100644
--- a/official/vision/detection/faster_rcnn_fpn_res50_coco_1x_800size_syncbn.py
+++ b/official/vision/detection/faster_rcnn_fpn_res50_coco_1x_800size_syncbn.py
@@ -17,6 +17,7 @@ class CustomFasterRCNNFPNConfig(models.FasterRCNNConfig):
 
         self.resnet_norm = "SyncBN"
         self.fpn_norm = "SyncBN"
+        self.backbone_freeze_at = 0
 
 
 @hub.pretrained(
diff --git a/official/vision/detection/layers/det/loss.py b/official/vision/detection/layers/det/loss.py
index 3fda1e4b4295572a24391bcac75913bffdb7a144..b0fa4bf91f39e9ed7ab1f0ba333c634cebef0a35 100644
--- a/official/vision/detection/layers/det/loss.py
+++ b/official/vision/detection/layers/det/loss.py
@@ -50,8 +50,8 @@ def get_focal_loss(
     class_range = F.arange(1, score.shape[2] + 1)
 
     label = F.add_axis(label, axis=2)
-    pos_part = (1 - score) ** gamma * F.log(score)
-    neg_part = score ** gamma * F.log(1 - score)
+    pos_part = (1 - score) ** gamma * F.log(F.clamp(score, 1e-8))
+    neg_part = score ** gamma * F.log(F.clamp(1 - score, 1e-8))
 
     pos_loss = -(label == class_range) * pos_part * alpha
     neg_loss = -(label != class_range) * (label != ignore_label) * neg_part * (1 - alpha)
@@ -151,6 +151,8 @@ def get_smooth_l1_base(
         in_loss = 0.5 * x ** 2 * sigma2
         out_loss = abs_x - 0.5 / sigma2
 
+        # FIXME: F.where cannot handle 0-shape tensor yet
+        # loss = F.where(abs_x < cond_point, in_loss, out_loss)
         in_mask = abs_x < cond_point
         out_mask = 1 - in_mask
         loss = in_loss * in_mask + out_loss * out_mask
diff --git a/official/vision/detection/layers/det/rcnn.py b/official/vision/detection/layers/det/rcnn.py
index a52edcf8ffeed4cda75822275934e905d7eedfcc..6d0e5e184b35570fa3166f6c591e5e8d5621776b 100644
--- a/official/vision/detection/layers/det/rcnn.py
+++ b/official/vision/detection/layers/det/rcnn.py
@@ -19,8 +19,8 @@ class RCNN(M.Module):
         super().__init__()
         self.cfg = cfg
         self.box_coder = layers.BoxCoder(
-            reg_mean=cfg.bbox_normalize_means,
-            reg_std=cfg.bbox_normalize_stds
+            reg_mean=cfg.rcnn_reg_mean,
+            reg_std=cfg.rcnn_reg_std
         )
 
         # roi head
diff --git a/official/vision/detection/models/faster_rcnn_fpn.py b/official/vision/detection/models/faster_rcnn_fpn.py
index 7122e3d9ee9e69fb7bc8528eecff414c910b3510..eaaff9a7d51ecb41b7e972d7a0f849b07a214a9b 100644
--- a/official/vision/detection/models/faster_rcnn_fpn.py
+++ b/official/vision/detection/models/faster_rcnn_fpn.py
@@ -33,19 +33,19 @@ class FasterRCNN(M.Module):
         for p in bottom_up.layer1.parameters():
             p.requires_grad = False
 
-        # -------------------------- build the FPN -------------------------- #
+        # ----------------------- build the FPN ----------------------------- #
         out_channels = 256
         self.backbone = layers.FPN(
             bottom_up=bottom_up,
             in_features=["res2", "res3", "res4", "res5"],
             out_channels=out_channels,
-            norm="",
+            norm=cfg.fpn_norm,
             top_block=layers.FPNP6(),
             strides=[4, 8, 16, 32],
             channels=[256, 512, 1024, 2048],
         )
 
-        # -------------------------- build the RPN -------------------------- #
+        # ----------------------- build the RPN ----------------------------- #
         self.RPN = layers.RPN(cfg)
 
         # ----------------------- build the RCNN head ----------------------- #
@@ -122,24 +122,25 @@ class FasterRCNN(M.Module):
 
 
 class FasterRCNNConfig:
-
     def __init__(self):
         self.resnet_norm = "FrozenBN"
+        self.fpn_norm = ""
         self.backbone_freeze_at = 2
 
-        # ------------------------ data cfg --------------------------- #
+        # ------------------------ data cfg -------------------------- #
         self.train_dataset = dict(
             name="coco",
             root="train2017",
             ann_file="annotations/instances_train2017.json",
+            remove_images_without_annotations=True,
         )
         self.test_dataset = dict(
             name="coco",
             root="val2017",
             ann_file="annotations/instances_val2017.json",
+            remove_images_without_annotations=False,
        )
         self.num_classes = 80
-
         self.img_mean = np.array([103.530, 116.280, 123.675])  # BGR
         self.img_std = np.array([57.375, 57.120, 58.395])
@@ -150,9 +151,6 @@ class FasterRCNNConfig:
         self.anchor_offset = -0.5
         self.num_cell_anchors = len(self.anchor_aspect_ratios)
 
-        self.bbox_normalize_means = None
-        self.bbox_normalize_stds = np.array([0.1, 0.1, 0.2, 0.2])
-
         self.rpn_stride = np.array([4, 8, 16, 32, 64]).astype(np.float32)
         self.rpn_in_features = ["p2", "p3", "p4", "p5", "p6"]
         self.rpn_channel = 256
@@ -175,12 +173,15 @@ class FasterRCNNConfig:
         self.bg_threshold_high = 0.5
         self.bg_threshold_low = 0.0
 
+        self.rcnn_reg_mean = None
+        self.rcnn_reg_std = np.array([0.1, 0.1, 0.2, 0.2])
         self.rcnn_in_features = ["p2", "p3", "p4", "p5"]
         self.rcnn_stride = [4, 8, 16, 32]
 
         # ------------------------ loss cfg -------------------------- #
         self.rpn_smooth_l1_beta = 3
         self.rcnn_smooth_l1_beta = 1
+        self.num_losses = 5
 
         # ------------------------ training cfg ---------------------- #
         self.train_image_short_size = 800
@@ -188,7 +189,6 @@ class FasterRCNNConfig:
         self.train_prev_nms_top_n = 2000
         self.train_post_nms_top_n = 1000
 
-        self.num_losses = 5
         self.basic_lr = 0.02 / 16.0  # The basic learning rate for single-image
         self.momentum = 0.9
         self.weight_decay = 1e-4
@@ -197,15 +197,14 @@ class FasterRCNNConfig:
         self.max_epoch = 18
         self.warm_iters = 500
         self.lr_decay_rate = 0.1
-        self.lr_decay_sates = [12, 16, 17]
+        self.lr_decay_stages = [12, 16, 17]
 
-        # ------------------------ testing cfg ------------------------- #
+        # ------------------------ testing cfg ----------------------- #
         self.test_image_short_size = 800
         self.test_image_max_size = 1333
         self.test_prev_nms_top_n = 1000
         self.test_post_nms_top_n = 1000
         self.test_max_boxes_per_image = 100
-        self.test_vis_threshold = 0.3
         self.test_cls_threshold = 0.05
         self.test_nms = 0.5
diff --git a/official/vision/detection/models/retinanet.py b/official/vision/detection/models/retinanet.py
index eb5f07e3ac84f4d6ba563f94a578265717e51a46..f576c4bc67b69267f60a606b92aa6f27e066defa 100644
--- a/official/vision/detection/models/retinanet.py
+++ b/official/vision/detection/models/retinanet.py
@@ -36,7 +36,7 @@ class RetinaNet(M.Module):
         self.in_features = ["p3", "p4", "p5", "p6", "p7"]
 
         # ----------------------- build the backbone ------------------------ #
-        bottom_up = resnet50(norm=layers.get_norm(self.cfg.resnet_norm))
+        bottom_up = resnet50(norm=layers.get_norm(cfg.resnet_norm))
 
         # ------------ freeze the weights of resnet stage1 and stage 2 ------ #
         if self.cfg.backbone_freeze_at >= 1:
@@ -53,7 +53,7 @@ class RetinaNet(M.Module):
             bottom_up=bottom_up,
             in_features=["res3", "res4", "res5"],
             out_channels=out_channels,
-            norm=self.cfg.fpn_norm,
+            norm=cfg.fpn_norm,
             top_block=layers.LastLevelP6P7(in_channels_p6p7, out_channels),
         )
 
@@ -211,14 +211,14 @@ class RetinaNetConfig:
             name="coco",
             root="train2017",
             ann_file="annotations/instances_train2017.json",
+            remove_images_without_annotations=True,
         )
         self.test_dataset = dict(
             name="coco",
             root="val2017",
             ann_file="annotations/instances_val2017.json",
+            remove_images_without_annotations=False,
         )
-        self.train_image_short_size = 800
-        self.train_image_max_size = 1333
         self.num_classes = 80
         self.img_mean = np.array([103.530, 116.280, 123.675])  # BGR
         self.img_std = np.array([57.375, 57.120, 58.395])
@@ -240,6 +240,9 @@ class RetinaNetConfig:
         self.num_losses = 3
 
         # ------------------------ training cfg ---------------------- #
+        self.train_image_short_size = 800
+        self.train_image_max_size = 1333
+
         self.basic_lr = 0.01 / 16.0  # The basic learning rate for single-image
         self.momentum = 0.9
         self.weight_decay = 1e-4
@@ -248,7 +251,7 @@ class RetinaNetConfig:
         self.max_epoch = 18
         self.warm_iters = 500
         self.lr_decay_rate = 0.1
-        self.lr_decay_sates = [12, 16, 17]
+        self.lr_decay_stages = [12, 16, 17]
 
         # ------------------------ testing cfg ----------------------- #
         self.test_image_short_size = 800
diff --git a/official/vision/detection/retinanet_res50_objects365_1x_800size.py b/official/vision/detection/retinanet_res50_objects365_1x_800size.py
index 2b5397856dfcee1315dc20eba9f268730088e9ca..ae5b745d44ac6d5e6e5035a516bf874471425860 100644
--- a/official/vision/detection/retinanet_res50_objects365_1x_800size.py
+++ b/official/vision/detection/retinanet_res50_objects365_1x_800size.py
@@ -18,11 +18,13 @@ class CustomRetinaNetConfig(models.RetinaNetConfig):
             name="objects365",
             root="train",
             ann_file="annotations/objects365_train_20190423.json",
+            remove_images_without_annotations=True,
         )
         self.test_dataset = dict(
             name="objects365",
             root="val",
             ann_file="annotations/objects365_val_20190423.json",
+            remove_images_without_annotations=False,
         )
         self.num_classes = 365
@@ -30,7 +32,7 @@ class CustomRetinaNetConfig(models.RetinaNetConfig):
         self.nr_images_epoch = 400000
 
 
-def retinanet_objects365_res50_1x_800size(batch_size=1, **kwargs):
+def retinanet_res50_objects365_1x_800size(batch_size=1, **kwargs):
     r"""
     RetinaNet trained from Objects365 dataset.
     `"RetinaNet" <https://arxiv.org/abs/1708.02002>`_
diff --git a/official/vision/detection/retinanet_res50_voc_1x_800size.py b/official/vision/detection/retinanet_res50_voc_1x_800size.py
new file mode 100644
index 0000000000000000000000000000000000000000..725c8903306d9c57dfed23ee90f6ec54c4c99ec7
--- /dev/null
+++ b/official/vision/detection/retinanet_res50_voc_1x_800size.py
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from megengine import hub
+
+from official.vision.detection import models
+
+
+class CustomRetinaNetConfig(models.RetinaNetConfig):
+    def __init__(self):
+        super().__init__()
+
+        # ------------------------ data cfg -------------------------- #
+        self.train_dataset = dict(
+            name="voc",
+            root="VOCdevkit/VOC2012",
+            image_set="train",
+        )
+        self.test_dataset = dict(
+            name="voc",
+            root="VOCdevkit/VOC2012",
+            image_set="val",
+        )
+        self.num_classes = 20
+
+        # ------------------------ training cfg ---------------------- #
+        self.nr_images_epoch = 16000
+
+
+def retinanet_res50_voc_1x_800size(batch_size=1, **kwargs):
+    r"""
+    RetinaNet trained from VOC dataset.
+    `"RetinaNet" <https://arxiv.org/abs/1708.02002>`_
+    """
+    return models.RetinaNet(CustomRetinaNetConfig(), batch_size=batch_size, **kwargs)
+
+
+Net = models.RetinaNet
+Cfg = CustomRetinaNetConfig
diff --git a/official/vision/detection/tools/data_mapper.py b/official/vision/detection/tools/data_mapper.py
index 4f5d5445666d3e755cbd515016762acb03eab1be..222436f48a1c6d4fac7e0ebc6df25ad66d164107 100644
--- a/official/vision/detection/tools/data_mapper.py
+++ b/official/vision/detection/tools/data_mapper.py
@@ -6,9 +6,10 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from megengine.data.dataset import COCO, Objects365
+from megengine.data.dataset import COCO, Objects365, PascalVOC
 
 data_mapper = dict(
     coco=COCO,
     objects365=Objects365,
+    voc=PascalVOC,
 )
diff --git a/official/vision/detection/tools/train.py b/official/vision/detection/tools/train.py
index a65bdbc5e2f8211ab59934e6c7eff55859908b54..fb280c54ce3e6c6178c82a1486aa275ea73be891 100644
--- a/official/vision/detection/tools/train.py
+++ b/official/vision/detection/tools/train.py
@@ -8,6 +8,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import argparse
 import bisect
+import copy
 import functools
 import importlib
 import multiprocessing as mp
@@ -92,12 +93,21 @@ def worker(rank, world_size, args):
             * model.batch_size
             * (
                 model.cfg.lr_decay_rate
-                ** bisect.bisect_right(model.cfg.lr_decay_sates, epoch_id)
+                ** bisect.bisect_right(model.cfg.lr_decay_stages, epoch_id)
             )
         )
         tot_steps = model.cfg.nr_images_epoch // (model.batch_size * world_size)
-        train_one_epoch(model, train_loader, opt, tot_steps, rank, epoch_id, world_size)
+        train_one_epoch(
+            model,
+            train_loader,
+            opt,
+            tot_steps,
+            rank,
+            epoch_id,
+            world_size,
+            args.enable_sublinear,
+        )
         if rank == 0:
             save_path = "log-of-{}/epoch_{}.pkl".format(
                 os.path.basename(args.file).split(".")[0], epoch_id
@@ -115,7 +125,7 @@ def adjust_learning_rate(optimizer, epoch_id, step, model, world_size):
         * model.batch_size
         * (
             model.cfg.lr_decay_rate
-            ** bisect.bisect_right(model.cfg.lr_decay_sates, epoch_id)
+            ** bisect.bisect_right(model.cfg.lr_decay_stages, epoch_id)
         )
     )
     # Warm up
@@ -125,8 +135,19 @@ def adjust_learning_rate(optimizer, epoch_id, step, model, world_size):
         param_group["lr"] = base_lr * lr_factor
 
 
-def train_one_epoch(model, data_queue, opt, tot_steps, rank, epoch_id, world_size):
-    @jit.trace(symbolic=True, opt_level=2)
+def train_one_epoch(
+    model,
+    data_queue,
+    opt,
+    tot_steps,
+    rank,
+    epoch_id,
+    world_size,
+    enable_sublinear=False,
+):
+    sublinear_cfg = jit.SublinearMemoryConfig() if enable_sublinear else None
+
+    @jit.trace(symbolic=True, opt_level=2, sublinear_memory_config=sublinear_cfg)
     def propagate():
         loss_dict = model(model.inputs)
         opt.backward(loss_dict["total_loss"])
@@ -180,6 +201,7 @@ def make_parser():
     parser.add_argument(
         "-d", "--dataset_dir", default="/data/datasets", type=str,
     )
+    parser.add_argument("--enable_sublinear", action="store_true")
     return parser
 
 
@@ -234,6 +256,20 @@ def main():
         worker(0, 1, args)
 
 
+def build_dataset(data_dir, cfg):
+    data_cfg = copy.deepcopy(cfg.train_dataset)
+    data_name = data_cfg.pop("name")
+
+    data_cfg["root"] = os.path.join(data_dir, data_name, data_cfg["root"])
+
+    if "ann_file" in data_cfg:
+        data_cfg["ann_file"] = os.path.join(data_dir, data_name, data_cfg["ann_file"])
+
+    data_cfg["order"] = ["image", "boxes", "boxes_category", "info"]
+
+    return data_mapper[data_name](**data_cfg)
+
+
 def build_sampler(train_dataset, batch_size, aspect_grouping=[1]):
     def _compute_aspect_ratios(dataset):
         aspect_ratios = []
@@ -254,14 +290,7 @@ def build_sampler(train_dataset, batch_size, aspect_grouping=[1]):
 
 
 def build_dataloader(batch_size, data_dir, cfg):
-    train_dataset = data_mapper[cfg.train_dataset["name"]](
-        os.path.join(data_dir, cfg.train_dataset["name"], cfg.train_dataset["root"]),
-        os.path.join(
-            data_dir, cfg.train_dataset["name"], cfg.train_dataset["ann_file"]
-        ),
-        remove_images_without_annotations=True,
-        order=["image", "boxes", "boxes_category", "info"],
-    )
+    train_dataset = build_dataset(data_dir, cfg)
     train_sampler = build_sampler(train_dataset, batch_size)
     train_dataloader = DataLoader(
         train_dataset,
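A short note on the `F.clamp` guard added in `get_focal_loss`: when a predicted score saturates at exactly 0 or 1, `F.log(score)` or `F.log(1 - score)` evaluates to `-inf`, and multiplying that by a zero-valued mask produces NaN, which poisons the whole loss. Clamping the log argument keeps both terms finite. A minimal NumPy sketch of the same idea (the function and names here are illustrative, not repo API):

```python
import numpy as np

def focal_terms(score, gamma=2.0, eps=1e-8):
    """Pos/neg focal-loss parts with the log arguments clamped away from zero."""
    pos_part = (1.0 - score) ** gamma * np.log(np.clip(score, eps, None))
    neg_part = score ** gamma * np.log(np.clip(1.0 - score, eps, None))
    return pos_part, neg_part

# score == 1.0 would make log(1 - score) = -inf without the clip;
# with it, both terms stay finite over the full [0, 1] range.
pos, neg = focal_terms(np.array([0.0, 0.5, 1.0]))
assert np.isfinite(pos).all() and np.isfinite(neg).all()
```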
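Relatedly, the mask arithmetic in `get_smooth_l1_base` computes the same element-wise selection as the commented-out `F.where` call. A self-contained sketch, assuming the usual smooth-L1 parameterization where `sigma2 = sigma ** 2` and `cond_point = 1 / sigma2`:

```python
import numpy as np

def smooth_l1_base(x, sigma=1.0):
    """Smooth-L1 via mask arithmetic, mirroring the F.where workaround."""
    sigma2 = sigma ** 2
    cond_point = 1.0 / sigma2
    abs_x = np.abs(x)
    in_loss = 0.5 * x ** 2 * sigma2   # quadratic branch for small |x|
    out_loss = abs_x - 0.5 / sigma2   # linear branch for large |x|
    in_mask = (abs_x < cond_point).astype(x.dtype)
    # in_mask and (1 - in_mask) partition the elements, so the sum picks one branch:
    return in_loss * in_mask + out_loss * (1.0 - in_mask)

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
expected = np.where(np.abs(x) < 1.0, 0.5 * x ** 2, np.abs(x) - 0.5)
assert np.allclose(smooth_l1_base(x), expected)
```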
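The renamed `lr_decay_stages` list feeds `bisect.bisect_right`, which counts how many decay stages the current epoch has already passed; the base learning rate is then scaled by `lr_decay_rate` raised to that count, on top of linear scaling by the total batch size. A sketch of the schedule, with defaults mirroring the configs above (`scheduled_lr` is a hypothetical helper, and the batch/world sizes are illustrative):

```python
import bisect

def scheduled_lr(epoch_id, basic_lr=0.02 / 16.0, batch_size=2, world_size=8,
                 lr_decay_rate=0.1, lr_decay_stages=(12, 16, 17)):
    """Linear scaling by total batch size, staged decay by epoch."""
    passed = bisect.bisect_right(lr_decay_stages, epoch_id)
    return basic_lr * world_size * batch_size * lr_decay_rate ** passed

# Epochs 0-11 run at the scaled base rate; epoch 12 decays once,
# epoch 16 twice, and epoch 17 three times.
for epoch_id in (0, 11, 12, 16, 17):
    print(epoch_id, scheduled_lr(epoch_id))
```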
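Finally, the new `build_dataset` helper routes dataset-specific keys from the config dict straight into the mapped dataset class; this is what lets the VOC config pass `image_set` where COCO passes `ann_file`. A sketch of the argument flow, with stub functions standing in for the `megengine.data.dataset` classes:

```python
import copy
import os

# Stub stand-ins for the real dataset classes, to trace argument flow only.
def COCO(**kwargs): return ("COCO", kwargs)
def PascalVOC(**kwargs): return ("PascalVOC", kwargs)

data_mapper = dict(coco=COCO, voc=PascalVOC)

def build_dataset(data_dir, train_dataset_cfg):
    data_cfg = copy.deepcopy(train_dataset_cfg)
    data_name = data_cfg.pop("name")
    data_cfg["root"] = os.path.join(data_dir, data_name, data_cfg["root"])
    if "ann_file" in data_cfg:  # only COCO-style configs carry an annotation file
        data_cfg["ann_file"] = os.path.join(data_dir, data_name, data_cfg["ann_file"])
    data_cfg["order"] = ["image", "boxes", "boxes_category", "info"]
    return data_mapper[data_name](**data_cfg)

# VOC defines image_set instead of ann_file, so only its root gets prefixed.
print(build_dataset("/data/datasets",
                    dict(name="voc", root="VOCdevkit/VOC2012", image_set="train")))
```

Only configs that define `ann_file` get that path joined against the dataset directory, so the VOC entry passes through with just its `root` prefixed.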