Unverified commit c4298f80, authored by Liu Songtao and committed by GitHub

Merge pull request #1189 from FateScript/hubload

feat(model): support hub load
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
"""
Usage example:
import torch
model = torch.hub.load("Megvii-BaseDetection/YOLOX", "yolox_s")
"""
dependencies = ["torch"]
from yolox.models import (  # isort:skip  # noqa: F401, E402
    yolox_tiny,
    yolox_nano,
    yolox_s,
    yolox_m,
    yolox_l,
    yolox_x,
    yolov3,
)
......
@@ -2,6 +2,7 @@
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.
+from .build import *
from .darknet import CSPDarknet, Darknet
from .losses import IOUloss
from .yolo_fpn import YOLOFPN
......
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import torch
from torch import nn
from torch.hub import load_state_dict_from_url
__all__ = [
    "create_yolox_model",
    "yolox_nano",
    "yolox_tiny",
    "yolox_s",
    "yolox_m",
    "yolox_l",
    "yolox_x",
    "yolov3",
]
_CKPT_ROOT_URL = "https://github.com/Megvii-BaseDetection/YOLOX/releases/download"
_CKPT_FULL_PATH = {
    "yolox-nano": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_nano.pth",
    "yolox-tiny": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_tiny.pth",
    "yolox-s": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_s.pth",
    "yolox-m": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_m.pth",
    "yolox-l": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_l.pth",
    "yolox-x": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_x.pth",
    "yolov3": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_darknet.pth",
}
def create_yolox_model(
    name: str, pretrained: bool = True, num_classes: int = 80, device=None
) -> nn.Module:
    """Create and load a YOLOX model.

    Args:
        name (str): name of the model, e.g. "yolox-s", "yolox-tiny".
        pretrained (bool): whether to load pretrained weights into the model. Default to True.
        num_classes (int): number of model classes. Default to 80.
        device (str): device on which to place the model. Default to None.

    Returns:
        YOLOX model (nn.Module)
    """
    from yolox.exp import get_exp, Exp

    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    assert name in _CKPT_FULL_PATH, f"name should be one of the keys in {_CKPT_FULL_PATH.keys()}"
    exp: Exp = get_exp(exp_name=name)
    exp.num_classes = num_classes
    yolox_model = exp.get_model()
    if pretrained and num_classes == 80:
        weights_url = _CKPT_FULL_PATH[name]
        ckpt = load_state_dict_from_url(weights_url, map_location="cpu")
        if "model" in ckpt:
            ckpt = ckpt["model"]
        yolox_model.load_state_dict(ckpt)
    yolox_model.to(device)
    return yolox_model
def yolox_nano(pretrained=True, num_classes=80, device=None):
    return create_yolox_model("yolox-nano", pretrained, num_classes, device)

def yolox_tiny(pretrained=True, num_classes=80, device=None):
    return create_yolox_model("yolox-tiny", pretrained, num_classes, device)

def yolox_s(pretrained=True, num_classes=80, device=None):
    return create_yolox_model("yolox-s", pretrained, num_classes, device)

def yolox_m(pretrained=True, num_classes=80, device=None):
    return create_yolox_model("yolox-m", pretrained, num_classes, device)

def yolox_l(pretrained=True, num_classes=80, device=None):
    return create_yolox_model("yolox-l", pretrained, num_classes, device)

def yolox_x(pretrained=True, num_classes=80, device=None):
    return create_yolox_model("yolox-x", pretrained, num_classes, device)

def yolov3(pretrained=True, num_classes=80, device=None):
    return create_yolox_model("yolov3", pretrained, num_classes, device)
......
@@ -9,7 +9,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
-from yolox.utils import bboxes_iou
+from yolox.utils import bboxes_iou, meshgrid
from .losses import IOUloss
from .network_blocks import BaseConv, DWConv
......
@@ -220,7 +220,7 @@ class YOLOXHead(nn.Module):
        n_ch = 5 + self.num_classes
        hsize, wsize = output.shape[-2:]
        if grid.shape[2:4] != output.shape[2:4]:
-           yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing="ij")
+           yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
            grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
            self.grids[k] = grid
......
@@ -237,7 +237,7 @@ class YOLOXHead(nn.Module):
        grids = []
        strides = []
        for (hsize, wsize), stride in zip(self.hw, self.strides):
-           yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing="ij")
+           yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
            grid = torch.stack((xv, yv), 2).view(1, -1, 2)
            grids.append(grid)
            shape = grid.shape[:2]
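For a 640x640 input the three FPN levels use strides 8, 16 and 32, so the loop above yields 80*80 + 40*40 + 20*20 = 8400 anchor points. A standalone sketch of that grid construction (shapes only, outside the model, using the compat meshgrid introduced below):

import torch
from yolox.utils import meshgrid  # the compat wrapper added in this PR

hw, level_strides = [(80, 80), (40, 40), (20, 20)], [8, 16, 32]
grids, strides = [], []
for (hsize, wsize), stride in zip(hw, level_strides):
    yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
    grid = torch.stack((xv, yv), 2).view(1, -1, 2)
    grids.append(grid)
    strides.append(torch.full((1, grid.shape[1], 1), stride))

print(torch.cat(grids, dim=1).shape)    # torch.Size([1, 8400, 2])
print(torch.cat(strides, dim=1).shape)  # torch.Size([1, 8400, 1])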
......
@@ -321,7 +321,11 @@ class YOLOXHead(nn.Module):
                    labels,
                    imgs,
                )
-           except RuntimeError:
+           except RuntimeError as e:
+               # TODO: the string might change, consider a better way
+               if "CUDA out of memory. " not in str(e):
+                   raise  # the RuntimeError might not be caused by CUDA OOM
                logger.error(
                    "OOM RuntimeError is raised due to the huge memory cost during label assignment. \
                        CPU mode is applied in this batch. If you want to avoid this issue, \
......
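The new handler only swallows the CUDA out-of-memory case and re-raises everything else before the assignment is retried on the CPU. A minimal, self-contained sketch of that retry-on-CPU pattern (the helper and its arguments are hypothetical, not part of YOLOX):

import torch

def run_with_cpu_fallback(fn, *tensors):
    """Try fn on the given (possibly CUDA) tensors; on CUDA OOM, retry on CPU."""
    try:
        return fn(*tensors)
    except RuntimeError as e:
        # Matching the message by substring is fragile across PyTorch versions,
        # which is the TODO noted in the diff above.
        if "CUDA out of memory" not in str(e):
            raise
        torch.cuda.empty_cache()
        return fn(*(t.cpu() for t in tensors))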
......
@@ -5,6 +5,7 @@
from .allreduce_norm import *
from .boxes import *
from .checkpoint import load_ckpt, save_checkpoint
+from .compat import meshgrid
from .demo_utils import *
from .dist import *
from .ema import *
......
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import torch
_TORCH_VER = [int(x) for x in torch.__version__.split(".")[:2]]
__all__ = ["meshgrid"]
def meshgrid(*tensors):
    if _TORCH_VER >= [1, 10]:
        return torch.meshgrid(*tensors, indexing="ij")
    else:
        return torch.meshgrid(*tensors)
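torch.meshgrid gained an indexing keyword in PyTorch 1.10 and warns when it is omitted, while older releases reject the keyword altogether; the wrapper keeps the pre-1.10 "ij" behaviour on both. A small check, assuming the package above is importable:

import torch
from yolox.utils import meshgrid

yv, xv = meshgrid([torch.arange(2), torch.arange(3)])
grid = torch.stack((xv, yv), 2)  # shape (2, 3, 2); each cell holds (x, y)
print(grid.shape)   # torch.Size([2, 3, 2])
print(grid[1, 2])   # tensor([2, 1]) -> x=2, y=1 under "ij" indexing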