MegEngine Models
本仓库包含了采用MegEngine实现的各种主流深度学习模型。
official目录下提供了各种经典的图像分类、目标检测、图像分割以及自然语言模型的官方实现。每个模型同时提供了模型定义、推理以及训练的代码。
官方会一直维护official下的代码,保持适配MegEngine的最新API,提供最优的模型实现。同时,提供高质量的学习文档,帮助新手学习如何在MegEngine下训练自己的模型。
综述
对于每个模型,我们提供了至少四个脚本文件:模型定义(model.py
)、模型推理(inference.py
)、模型训练(train.py
)、模型测试(test.py
)。
每个模型目录下都对应有一个README
,介绍了模型的详细信息,并详细描述了训练和测试的流程。例如 ResNet README。
另外,official
下定义的模型可以通过megengine.hub
来直接加载,例如:
import megengine.hub
# 只加载网络结构
resnet18 = megengine.hub.load("megengine/models", "resnet18")
# 加载网络结构和预训练权重
resnet18 = megengine.hub.load("megengine/models", "resnet18", pretrained=True)
更多可以通过megengine.hub
接口加载的模型见hubconf.py。
安装和环境配置
在开始运行本仓库下的代码之前,用户需要通过以下步骤来配置本地环境:
- 克隆仓库
git clone https://github.com/MegEngine/Models.git
- 安装依赖包
pip3 install --user -r requirements.txt
- 添加目录到python环境变量中
export PYTHONPATH=/path/to/models:$PYTHONPATH
官方模型介绍
图像分类
图像分类是计算机视觉的基础任务。许多计算机视觉的其它任务(例如物体检测)都使用了基于图像分类的预训练模型。因此,我们提供了各种在ImageNet上预训练好的分类模型,包括ResNet系列, shufflenet系列等,这些模型在ImageNet验证集上的测试结果如下表:
模型 | top1 acc | top5 acc |
---|---|---|
ResNet18 | 70.312 | 89.430 |
ResNet34 | 73.960 | 91.630 |
ResNet50 | 76.254 | 93.056 |
ResNet101 | 77.944 | 93.844 |
ResNet152 | 78.582 | 94.130 |
ResNeXt50 32x4d | 77.592 | 93.644 |
ResNeXt101 32x8d | 79.520 | 94.586 |
ShuffleNetV2 x0.5 | 60.696 | 82.190 |
ShuffleNetV2 x1.0 | 69.372 | 88.764 |
ShuffleNetV2 x1.5 | 72.806 | 90.792 |
ShuffleNetV2 x2.0 | 75.074 | 92.278 |
目标检测
目标检测同样是计算机视觉中的常见任务,我们提供了两个经典的目标检测模型Retinanet和Faster R-CNN,这两个模型在COCO验证集上的测试结果如下:
模型 | mAP @5-95 |
---|---|
retinanet-res50-coco-1x-800size | 36.4 |
retinanet-res50-coco-1x-800size-syncbn | 37.1 |
retinanet-res101-coco-2x-800size | 40.8 |
retinanet-resx101-coco-2x-800size | 41.8 |
faster-rcnn-res50-coco-1x-800size | 38.8 |
faster-rcnn-res50-coco-1x-800size-syncbn | 39.3 |
faster-rcnn-res101-coco-2x-800size | 43.0 |
faster-rcnn-resx101-coco-2x-800size | 44.7 |
图像分割
我们也提供了经典的语义分割模型--Deeplabv3plus,这个模型在PASCAL VOC验证集上的测试结果如下:
模型 | Backbone | mIoU_single | mIoU_multi |
---|---|---|---|
Deeplabv3plus | Resnet101 | 79.0 | 79.8 |
人体关节点检测
我们提供了人体关节点检测的经典模型SimpleBaseline和高精度模型MSPN,使用在COCO val2017上人体检测AP为56的检测结果,提供的模型在COCO val2017上的关节点检测结果为:
Methods | Backbone | Input Size | AP | Ap .5 | AP .75 | AP (M) | AP (L) | AR | AR .5 | AR .75 | AR (M) | AR (L) |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SimpleBaseline | Res50 | 256x192 | 0.712 | 0.887 | 0.779 | 0.673 | 0.785 | 0.782 | 0.932 | 0.839 | 0.730 | 0.854 |
SimpleBaseline | Res101 | 256x192 | 0.722 | 0.891 | 0.795 | 0.687 | 0.795 | 0.794 | 0.936 | 0.855 | 0.745 | 0.863 |
SimpleBaseline | Res152 | 256x192 | 0.724 | 0.888 | 0.794 | 0.688 | 0.795 | 0.795 | 0.934 | 0.856 | 0.746 | 0.863 |
MSPN_4stage | MSPN | 256x192 | 0.752 | 0.900 | 0.819 | 0.716 | 0.825 | 0.819 | 0.943 | 0.875 | 0.770 | 0.887 |
自然语言处理
我们同样支持一些常见的自然语言处理模型,模型的权重来自Google的pre-trained models, 用户可以直接使用megengine.hub
轻松的调用预训练的bert模型。
另外,我们在bert中还提供了更加方便的脚本, 可以通过任务名直接获取到对应字典, 配置, 与预训练模型。
模型 | 字典 | 配置 |
---|---|---|
wwm_cased_L-24_H-1024_A-16 | link | link |
wwm_uncased_L-24_H-1024_A-16 | link | link |
cased_L-12_H-768_A-12 | link | link |
cased_L-24_H-1024_A-16 | link | link |
uncased_L-12_H-768_A-12 | link | link |
uncased_L-24_H-1024_A-16 | link | link |
chinese_L-12_H-768_A-12 | link | link |
multi_cased_L-12_H-768_A-12 | link | link |
在glue_data/MRPC数据集中使用默认的超参数进行微调和评估,评估结果介于84%和88%之间。
Dataset | pretrained_bert | acc |
---|---|---|
glue_data/MRPC | uncased_L-12_H-768_A-12 | 86.25% |