未验证 提交 70c98b36 编写于 作者: J Jianfeng Wang 提交者: GitHub

feat(detection): support models with res101 backbone (#48)

上级 8940da38
......@@ -70,16 +70,18 @@ export PYTHONPATH=/path/to/models:$PYTHONPATH
| ShuffleNetV2 x1.5 | 72.806 | 90.792 |
| ShuffleNetV2 x2.0 | 75.074 | 92.278 |
### 目标检测
目标检测同样是计算机视觉中的常见任务,我们提供了两个经典的目标检测模型[Retinanet](./official/vision/detection/model/retinanet)和[Faster R-CNN](./official/vision/detection/model/faster_rcnn),这两个模型在**COCO验证集**上的测试结果如下:
| 模型 | mAP<br>@5-95 |
| :---: | :---: |
| retinanet-res50-1x-800size | 36.4 |
| faster-rcnn-res50-1x-800size | 38.8 |
| retinanet-res50-coco-1x-800size | 36.4 |
| retinanet-res50-coco-1x-800size-syncbn | 37.1 |
| retinanet-res101-coco-2x-800size | 40.8 |
| faster-rcnn-res50-coco-1x-800size | 38.8 |
| faster-rcnn-res50-coco-1x-800size-syncbn | 39.3 |
| faster-rcnn-res101-coco-2x-800size | 43.0 |
### 图像分割
......@@ -117,7 +119,6 @@ export PYTHONPATH=/path/to/models:$PYTHONPATH
| chinese_L-12_H-768_A-12| [link](https://data.megengine.org.cn/models/weights/bert/chinese_L-12_H-768_A-12/vocab.txt) | [link](https://data.megengine.org.cn/models/weights/bert/chinese_L-12_H-768_A-12/bert_config.json)
| multi_cased_L-12_H-768_A-12| [link](https://data.megengine.org.cn/models/weights/bert/multi_cased_L-12_H-768_A-12/vocab.txt) | [link](https://data.megengine.org.cn/models/weights/bert/multi_cased_L-12_H-768_A-12/bert_config.json)
在glue_data/MRPC数据集中使用默认的超参数进行微调和评估,评估结果介于84%和88%之间。
| Dataset | pretrained_bert | acc |
......
......@@ -30,8 +30,10 @@ from official.vision.classification.shufflenet.model import (
from official.vision.detection.configs import (
faster_rcnn_res50_coco_1x_800size,
faster_rcnn_res50_coco_1x_800size_syncbn,
faster_rcnn_res101_coco_2x_800size,
retinanet_res50_coco_1x_800size,
retinanet_res50_coco_1x_800size_syncbn,
retinanet_res101_coco_2x_800size,
)
from official.vision.detection.models import FasterRCNN, RetinaNet
from official.vision.detection.tools.utils import DetEvaluator
......
......@@ -10,10 +10,12 @@
| --- | :---: | :---: | :---: | :---: |
| retinanet-res50-coco-1x-800size | 36.4 | 2 | 2080Ti | 3.1(it/s) |
| retinanet-res50-coco-1x-800size-syncbn | 37.1 | 2 | 2080Ti | 1.7(it/s) |
| retinanet-res101-coco-2x-800size | 40.8 | 2 | 2080Ti | 2.1(it/s) |
| faster-rcnn-res50-coco-1x-800size | 38.8 | 2 | 2080Ti | 3.3(it/s) |
| faster-rcnn-res50-coco-1x-800size-syncbn | 39.3 | 2 | 2080Ti | 1.8(it/s) |
| faster-rcnn-res101-coco-2x-800size | 43.0 | 2 | 2080Ti | 2.3(it/s) |
* MegEngine v0.4.0
* MegEngine v0.5.1
## 如何使用
......
from .faster_rcnn_res50_coco_1x_800size import faster_rcnn_res50_coco_1x_800size
from .faster_rcnn_res50_coco_1x_800size_syncbn import faster_rcnn_res50_coco_1x_800size_syncbn
from .faster_rcnn_res101_coco_2x_800size import faster_rcnn_res101_coco_2x_800size
from .retinanet_res50_coco_1x_800size import retinanet_res50_coco_1x_800size
from .retinanet_res50_coco_1x_800size_syncbn import retinanet_res50_coco_1x_800size_syncbn
from .retinanet_res101_coco_2x_800size import retinanet_res101_coco_2x_800size
# Names placed in this container are withheld from the package's public API.
_EXCLUDE = {}
# Export every module-level name that is neither underscore-prefixed nor
# explicitly excluded above.
__all__ = [
    name
    for name in globals()
    if not (name.startswith("_") or name in _EXCLUDE)
]
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine import hub
from official.vision.detection import models
class CustomFasterRCNNConfig(models.FasterRCNNConfig):
    """Faster R-CNN configuration using a ResNet-101 backbone.

    Starts from the defaults of ``models.FasterRCNNConfig`` and overrides the
    backbone name together with the training schedule.
    """

    def __init__(self):
        super().__init__()

        # Select the ResNet-101 backbone instead of the inherited default.
        self.backbone = "resnet101"

        # ------------------------ training cfg ---------------------- #
        # 36 epochs in total, with learning-rate decay stages at 24, 32 and 34.
        self.lr_decay_stages = [24, 32, 34]
        self.max_epoch = 36
@hub.pretrained(
    "https://data.megengine.org.cn/models/weights/"
    "faster_rcnn_res101_coco_2x_800size_43dot0_ee249359.pkl"
)
def faster_rcnn_res101_coco_2x_800size(batch_size=1, **kwargs):
    r"""
    Build a Faster-RCNN FPN detector configured by ``CustomFasterRCNNConfig``.

    The weights URL for this entry point is the one given to the
    ``hub.pretrained`` decorator.

    :param batch_size: forwarded to ``models.FasterRCNN``.
    :param kwargs: extra keyword arguments forwarded to ``models.FasterRCNN``.
    :return: the constructed ``models.FasterRCNN`` instance.

    References:
    `"Faster-RCNN" <https://arxiv.org/abs/1506.01497>`_
    `"FPN" <https://arxiv.org/abs/1612.03144>`_
    `"COCO" <https://arxiv.org/abs/1405.0312>`_
    """
    config = CustomFasterRCNNConfig()
    return models.FasterRCNN(config, batch_size=batch_size, **kwargs)
# Module-level aliases: the detection tools in this repository instantiate the
# model as ``current_network.Net(current_network.Cfg(), ...)``.
Cfg = CustomFasterRCNNConfig
Net = models.FasterRCNN
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine import hub
from official.vision.detection import models
class CustomRetinaNetConfig(models.RetinaNetConfig):
    """RetinaNet configuration using a ResNet-101 backbone.

    Starts from the defaults of ``models.RetinaNetConfig`` and overrides the
    backbone name together with the training schedule.
    """

    def __init__(self):
        super().__init__()

        # Select the ResNet-101 backbone instead of the inherited default.
        self.backbone = "resnet101"

        # ------------------------ training cfg ---------------------- #
        # 36 epochs in total, with learning-rate decay stages at 24, 32 and 34.
        self.lr_decay_stages = [24, 32, 34]
        self.max_epoch = 36
@hub.pretrained(
    "https://data.megengine.org.cn/models/weights/"
    "retinanet_res101_coco_2x_800size_40dot8_661c3608.pkl"
)
def retinanet_res101_coco_2x_800size(batch_size=1, **kwargs):
    r"""
    Build a RetinaNet detector configured by ``CustomRetinaNetConfig``.

    The weights URL for this entry point is the one given to the
    ``hub.pretrained`` decorator.

    :param batch_size: forwarded to ``models.RetinaNet``.
    :param kwargs: extra keyword arguments forwarded to ``models.RetinaNet``.
    :return: the constructed ``models.RetinaNet`` instance.

    References:
    `"RetinaNet" <https://arxiv.org/abs/1708.02002>`_
    `"FPN" <https://arxiv.org/abs/1612.03144>`_
    `"COCO" <https://arxiv.org/abs/1405.0312>`_
    """
    config = CustomRetinaNetConfig()
    return models.RetinaNet(config, batch_size=batch_size, **kwargs)
# Module-level aliases: the detection tools in this repository instantiate the
# model as ``current_network.Net(current_network.Cfg(), ...)``.
Cfg = CustomRetinaNetConfig
Net = models.RetinaNet
......@@ -48,7 +48,9 @@ def main():
sys.path.insert(0, os.path.dirname(args.file))
current_network = importlib.import_module(os.path.basename(args.file).split(".")[0])
model = current_network.Net(current_network.Cfg(), batch_size=1)
cfg = current_network.Cfg()
cfg.backbone_pretrained = False
model = current_network.Net(cfg, batch_size=1)
model.eval()
state_dict = mge.load(args.weight_file)
if "state_dict" in state_dict:
......
......@@ -37,9 +37,6 @@ def make_parser():
parser.add_argument(
"-n", "--ngpus", default=1, type=int, help="total number of gpus for testing",
)
parser.add_argument(
"-b", "--batch_size", default=1, type=int, help="batchsize for testing",
)
parser.add_argument(
"-d", "--dataset_dir", default="/data/datasets", type=str,
)
......@@ -56,6 +53,9 @@ def main():
parser = make_parser()
args = parser.parse_args()
sys.path.insert(0, os.path.dirname(args.file))
current_network = importlib.import_module(os.path.basename(args.file).split(".")[0])
if args.end_epoch == -1:
args.end_epoch = args.start_epoch
......@@ -75,7 +75,7 @@ def main():
proc = Process(
target=worker,
args=(
args.file,
current_network,
model_file,
args.dataset_dir,
i,
......@@ -86,10 +86,6 @@ def main():
proc.start()
procs.append(proc)
sys.path.insert(0, os.path.dirname(args.file))
current_network = importlib.import_module(
os.path.basename(args.file).split(".")[0]
)
cfg = current_network.Cfg()
num_imgs = dict(coco=5000, objects365=30000)
......@@ -139,7 +135,7 @@ def main():
def worker(
net_file, model_file, data_dir, worker_id, total_worker, result_queue,
current_network, model_file, data_dir, worker_id, total_worker, result_queue,
):
"""
:param current_network: imported network description module
......@@ -156,9 +152,9 @@ def worker(
pred = model(model.inputs)
return pred
sys.path.insert(0, os.path.dirname(net_file))
current_network = importlib.import_module(os.path.basename(net_file).split(".")[0])
model = current_network.Net(current_network.Cfg(), batch_size=1)
cfg = current_network.Cfg()
cfg.backbone_pretrained = False
model = current_network.Net(cfg, batch_size=1)
model.eval()
evaluator = DetEvaluator(model)
state_dict = mge.load(model_file)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册