未验证 提交 70c98b36 编写于 作者: J Jianfeng Wang 提交者: GitHub

feat(detection): support models with res101 backbone (#48)

上级 8940da38
...@@ -70,16 +70,18 @@ export PYTHONPATH=/path/to/models:$PYTHONPATH ...@@ -70,16 +70,18 @@ export PYTHONPATH=/path/to/models:$PYTHONPATH
| ShuffleNetV2 x1.5 | 72.806 | 90.792 | | ShuffleNetV2 x1.5 | 72.806 | 90.792 |
| ShuffleNetV2 x2.0 | 75.074 | 92.278 | | ShuffleNetV2 x2.0 | 75.074 | 92.278 |
### 目标检测 ### 目标检测
目标检测同样是计算机视觉中的常见任务,我们提供了两个经典的目标检测模型[Retinanet](./official/vision/detection/model/retinanet)和[Faster R-CNN](./official/vision/detection/model/faster_rcnn),这两个模型在**COCO验证集**上的测试结果如下: 目标检测同样是计算机视觉中的常见任务,我们提供了两个经典的目标检测模型[Retinanet](./official/vision/detection/model/retinanet)和[Faster R-CNN](./official/vision/detection/model/faster_rcnn),这两个模型在**COCO验证集**上的测试结果如下:
| 模型 | mAP<br>@5-95 | | 模型 | mAP<br>@5-95 |
| :---: | :---: | | :---: | :---: |
| retinanet-res50-1x-800size | 36.4 | | retinanet-res50-coco-1x-800size | 36.4 |
| faster-rcnn-res50-1x-800size | 38.8 | | retinanet-res50-coco-1x-800size-syncbn | 37.1 |
| retinanet-res101-coco-2x-800size | 40.8 |
| faster-rcnn-res50-coco-1x-800size | 38.8 |
| faster-rcnn-res50-coco-1x-800size-syncbn | 39.3 |
| faster-rcnn-res101-coco-2x-800size | 43.0 |
### 图像分割 ### 图像分割
...@@ -117,7 +119,6 @@ export PYTHONPATH=/path/to/models:$PYTHONPATH ...@@ -117,7 +119,6 @@ export PYTHONPATH=/path/to/models:$PYTHONPATH
| chinese_L-12_H-768_A-12| [link](https://data.megengine.org.cn/models/weights/bert/chinese_L-12_H-768_A-12/vocab.txt) | [link](https://data.megengine.org.cn/models/weights/bert/chinese_L-12_H-768_A-12/bert_config.json) | chinese_L-12_H-768_A-12| [link](https://data.megengine.org.cn/models/weights/bert/chinese_L-12_H-768_A-12/vocab.txt) | [link](https://data.megengine.org.cn/models/weights/bert/chinese_L-12_H-768_A-12/bert_config.json)
| multi_cased_L-12_H-768_A-12| [link](https://data.megengine.org.cn/models/weights/bert/multi_cased_L-12_H-768_A-12/vocab.txt) | [link](https://data.megengine.org.cn/models/weights/bert/multi_cased_L-12_H-768_A-12/bert_config.json) | multi_cased_L-12_H-768_A-12| [link](https://data.megengine.org.cn/models/weights/bert/multi_cased_L-12_H-768_A-12/vocab.txt) | [link](https://data.megengine.org.cn/models/weights/bert/multi_cased_L-12_H-768_A-12/bert_config.json)
在glue_data/MRPC数据集中使用默认的超参数进行微调和评估,评估结果介于84%和88%之间。 在glue_data/MRPC数据集中使用默认的超参数进行微调和评估,评估结果介于84%和88%之间。
| Dataset | pretrained_bert | acc | | Dataset | pretrained_bert | acc |
......
...@@ -30,8 +30,10 @@ from official.vision.classification.shufflenet.model import ( ...@@ -30,8 +30,10 @@ from official.vision.classification.shufflenet.model import (
from official.vision.detection.configs import ( from official.vision.detection.configs import (
faster_rcnn_res50_coco_1x_800size, faster_rcnn_res50_coco_1x_800size,
faster_rcnn_res50_coco_1x_800size_syncbn, faster_rcnn_res50_coco_1x_800size_syncbn,
faster_rcnn_res101_coco_2x_800size,
retinanet_res50_coco_1x_800size, retinanet_res50_coco_1x_800size,
retinanet_res50_coco_1x_800size_syncbn, retinanet_res50_coco_1x_800size_syncbn,
retinanet_res101_coco_2x_800size,
) )
from official.vision.detection.models import FasterRCNN, RetinaNet from official.vision.detection.models import FasterRCNN, RetinaNet
from official.vision.detection.tools.utils import DetEvaluator from official.vision.detection.tools.utils import DetEvaluator
......
...@@ -10,10 +10,12 @@ ...@@ -10,10 +10,12 @@
| --- | :---: | :---: | :---: | :---: | | --- | :---: | :---: | :---: | :---: |
| retinanet-res50-coco-1x-800size | 36.4 | 2 | 2080Ti | 3.1(it/s) | | retinanet-res50-coco-1x-800size | 36.4 | 2 | 2080Ti | 3.1(it/s) |
| retinanet-res50-coco-1x-800size-syncbn | 37.1 | 2 | 2080Ti | 1.7(it/s) | | retinanet-res50-coco-1x-800size-syncbn | 37.1 | 2 | 2080Ti | 1.7(it/s) |
| retinanet-res101-coco-2x-800size | 40.8 | 2 | 2080Ti | 2.1(it/s) |
| faster-rcnn-res50-coco-1x-800size | 38.8 | 2 | 2080Ti | 3.3(it/s) | | faster-rcnn-res50-coco-1x-800size | 38.8 | 2 | 2080Ti | 3.3(it/s) |
| faster-rcnn-res50-coco-1x-800size-syncbn | 39.3 | 2 | 2080Ti | 1.8(it/s) | | faster-rcnn-res50-coco-1x-800size-syncbn | 39.3 | 2 | 2080Ti | 1.8(it/s) |
| faster-rcnn-res101-coco-2x-800size | 43.0 | 2 | 2080Ti | 2.3(it/s) |
* MegEngine v0.4.0 * MegEngine v0.5.1
## 如何使用 ## 如何使用
......
from .faster_rcnn_res50_coco_1x_800size import faster_rcnn_res50_coco_1x_800size from .faster_rcnn_res50_coco_1x_800size import faster_rcnn_res50_coco_1x_800size
from .faster_rcnn_res50_coco_1x_800size_syncbn import faster_rcnn_res50_coco_1x_800size_syncbn from .faster_rcnn_res50_coco_1x_800size_syncbn import faster_rcnn_res50_coco_1x_800size_syncbn
from .faster_rcnn_res101_coco_2x_800size import faster_rcnn_res101_coco_2x_800size
from .retinanet_res50_coco_1x_800size import retinanet_res50_coco_1x_800size from .retinanet_res50_coco_1x_800size import retinanet_res50_coco_1x_800size
from .retinanet_res50_coco_1x_800size_syncbn import retinanet_res50_coco_1x_800size_syncbn from .retinanet_res50_coco_1x_800size_syncbn import retinanet_res50_coco_1x_800size_syncbn
from .retinanet_res101_coco_2x_800size import retinanet_res101_coco_2x_800size
_EXCLUDE = {} _EXCLUDE = {}
__all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")] __all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")]
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine import hub
from official.vision.detection import models
class CustomFasterRCNNConfig(models.FasterRCNNConfig):
    """Faster R-CNN configuration using a ResNet-101 backbone.

    Every setting is inherited from ``models.FasterRCNNConfig`` except the
    backbone name and the two training-schedule fields overridden below.
    """

    def __init__(self):
        super().__init__()

        # ------------------------ model cfg ------------------------- #
        self.backbone = "resnet101"

        # ------------------------ training cfg ---------------------- #
        # NOTE(review): lr_decay_stages is presumably the list of epochs at
        # which the learning rate is decayed -- confirm against the trainer.
        self.max_epoch = 36
        self.lr_decay_stages = [24, 32, 34]
@hub.pretrained(
    "https://data.megengine.org.cn/models/weights/"
    "faster_rcnn_res101_coco_2x_800size_43dot0_ee249359.pkl"
)
def faster_rcnn_res101_coco_2x_800size(batch_size=1, **kwargs):
    r"""
    Build a Faster-RCNN FPN detector configured by ``CustomFasterRCNNConfig``.

    The ``hub.pretrained`` decorator associates this entry point with the
    COCO-trained weight file named in the URL above.

    :param batch_size: forwarded to ``models.FasterRCNN`` (defaults to 1).
    :param kwargs: extra keyword arguments forwarded to ``models.FasterRCNN``.
    :return: a ``models.FasterRCNN`` instance.

    References:
    `"Faster-RCNN" <https://arxiv.org/abs/1506.01497>`_
    `"FPN" <https://arxiv.org/abs/1612.03144>`_
    `"COCO" <https://arxiv.org/abs/1405.0312>`_
    """
    cfg = CustomFasterRCNNConfig()
    return models.FasterRCNN(cfg, batch_size=batch_size, **kwargs)
# Module-level aliases. The detection tool scripts import a config file as a
# module and build the network via ``current_network.Net(cfg, batch_size=1)``
# with ``cfg = current_network.Cfg()``, so both names must be exported here.
Net = models.FasterRCNN
Cfg = CustomFasterRCNNConfig
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine import hub
from official.vision.detection import models
class CustomRetinaNetConfig(models.RetinaNetConfig):
    """RetinaNet configuration using a ResNet-101 backbone.

    Every setting is inherited from ``models.RetinaNetConfig`` except the
    backbone name and the two training-schedule fields overridden below.
    """

    def __init__(self):
        super().__init__()

        # ------------------------ model cfg ------------------------- #
        self.backbone = "resnet101"

        # ------------------------ training cfg ---------------------- #
        # NOTE(review): lr_decay_stages is presumably the list of epochs at
        # which the learning rate is decayed -- confirm against the trainer.
        self.max_epoch = 36
        self.lr_decay_stages = [24, 32, 34]
@hub.pretrained(
    "https://data.megengine.org.cn/models/weights/"
    "retinanet_res101_coco_2x_800size_40dot8_661c3608.pkl"
)
def retinanet_res101_coco_2x_800size(batch_size=1, **kwargs):
    r"""
    Build a RetinaNet detector configured by ``CustomRetinaNetConfig``.

    The ``hub.pretrained`` decorator associates this entry point with the
    COCO-trained weight file named in the URL above.

    :param batch_size: forwarded to ``models.RetinaNet`` (defaults to 1).
    :param kwargs: extra keyword arguments forwarded to ``models.RetinaNet``.
    :return: a ``models.RetinaNet`` instance.

    References:
    `"RetinaNet" <https://arxiv.org/abs/1708.02002>`_
    `"FPN" <https://arxiv.org/abs/1612.03144>`_
    `"COCO" <https://arxiv.org/abs/1405.0312>`_
    """
    cfg = CustomRetinaNetConfig()
    return models.RetinaNet(cfg, batch_size=batch_size, **kwargs)
# Module-level aliases. The detection tool scripts import a config file as a
# module and build the network via ``current_network.Net(cfg, batch_size=1)``
# with ``cfg = current_network.Cfg()``, so both names must be exported here.
Net = models.RetinaNet
Cfg = CustomRetinaNetConfig
...@@ -48,7 +48,9 @@ def main(): ...@@ -48,7 +48,9 @@ def main():
sys.path.insert(0, os.path.dirname(args.file)) sys.path.insert(0, os.path.dirname(args.file))
current_network = importlib.import_module(os.path.basename(args.file).split(".")[0]) current_network = importlib.import_module(os.path.basename(args.file).split(".")[0])
model = current_network.Net(current_network.Cfg(), batch_size=1) cfg = current_network.Cfg()
cfg.backbone_pretrained = False
model = current_network.Net(cfg, batch_size=1)
model.eval() model.eval()
state_dict = mge.load(args.weight_file) state_dict = mge.load(args.weight_file)
if "state_dict" in state_dict: if "state_dict" in state_dict:
......
...@@ -37,9 +37,6 @@ def make_parser(): ...@@ -37,9 +37,6 @@ def make_parser():
parser.add_argument( parser.add_argument(
"-n", "--ngpus", default=1, type=int, help="total number of gpus for testing", "-n", "--ngpus", default=1, type=int, help="total number of gpus for testing",
) )
parser.add_argument(
"-b", "--batch_size", default=1, type=int, help="batchsize for testing",
)
parser.add_argument( parser.add_argument(
"-d", "--dataset_dir", default="/data/datasets", type=str, "-d", "--dataset_dir", default="/data/datasets", type=str,
) )
...@@ -56,6 +53,9 @@ def main(): ...@@ -56,6 +53,9 @@ def main():
parser = make_parser() parser = make_parser()
args = parser.parse_args() args = parser.parse_args()
sys.path.insert(0, os.path.dirname(args.file))
current_network = importlib.import_module(os.path.basename(args.file).split(".")[0])
if args.end_epoch == -1: if args.end_epoch == -1:
args.end_epoch = args.start_epoch args.end_epoch = args.start_epoch
...@@ -75,7 +75,7 @@ def main(): ...@@ -75,7 +75,7 @@ def main():
proc = Process( proc = Process(
target=worker, target=worker,
args=( args=(
args.file, current_network,
model_file, model_file,
args.dataset_dir, args.dataset_dir,
i, i,
...@@ -86,10 +86,6 @@ def main(): ...@@ -86,10 +86,6 @@ def main():
proc.start() proc.start()
procs.append(proc) procs.append(proc)
sys.path.insert(0, os.path.dirname(args.file))
current_network = importlib.import_module(
os.path.basename(args.file).split(".")[0]
)
cfg = current_network.Cfg() cfg = current_network.Cfg()
num_imgs = dict(coco=5000, objects365=30000) num_imgs = dict(coco=5000, objects365=30000)
...@@ -139,7 +135,7 @@ def main(): ...@@ -139,7 +135,7 @@ def main():
def worker( def worker(
net_file, model_file, data_dir, worker_id, total_worker, result_queue, current_network, model_file, data_dir, worker_id, total_worker, result_queue,
): ):
""" """
:param net_file: network description file :param net_file: network description file
...@@ -156,9 +152,9 @@ def worker( ...@@ -156,9 +152,9 @@ def worker(
pred = model(model.inputs) pred = model(model.inputs)
return pred return pred
sys.path.insert(0, os.path.dirname(net_file)) cfg = current_network.Cfg()
current_network = importlib.import_module(os.path.basename(net_file).split(".")[0]) cfg.backbone_pretrained = False
model = current_network.Net(current_network.Cfg(), batch_size=1) model = current_network.Net(cfg, batch_size=1)
model.eval() model.eval()
evaluator = DetEvaluator(model) evaluator = DetEvaluator(model)
state_dict = mge.load(model_file) state_dict = mge.load(model_file)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册