提交 e685457f 编写于 作者: F Feng Wang

fix(model): compatible meshgrid and CUDA OOM error

上级 d18f5e82
......@@ -8,7 +8,7 @@ Usage example:
"""
dependencies = ["torch"]
from yolox.models import ( # noqa: F401, E402
from yolox.models import ( # isort:skip # noqa: F401, E402
yolox_tiny,
yolox_nano,
yolox_s,
......
......@@ -9,7 +9,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from yolox.utils import bboxes_iou
from yolox.utils import bboxes_iou, meshgrid
from .losses import IOUloss
from .network_blocks import BaseConv, DWConv
......@@ -220,7 +220,7 @@ class YOLOXHead(nn.Module):
n_ch = 5 + self.num_classes
hsize, wsize = output.shape[-2:]
if grid.shape[2:4] != output.shape[2:4]:
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing="ij")
yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
self.grids[k] = grid
......@@ -237,7 +237,7 @@ class YOLOXHead(nn.Module):
grids = []
strides = []
for (hsize, wsize), stride in zip(self.hw, self.strides):
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing="ij")
yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
grid = torch.stack((xv, yv), 2).view(1, -1, 2)
grids.append(grid)
shape = grid.shape[:2]
......@@ -321,7 +321,11 @@ class YOLOXHead(nn.Module):
labels,
imgs,
)
except RuntimeError:
except RuntimeError as e:
# TODO: the string might change, consider a better way
if "CUDA out of memory. " not in str(e):
raise # RuntimeError might not caused by CUDA OOM
logger.error(
"OOM RuntimeError is raised due to the huge memory cost during label assignment. \
CPU mode is applied in this batch. If you want to avoid this issue, \
......
......@@ -5,6 +5,7 @@
from .allreduce_norm import *
from .boxes import *
from .checkpoint import load_ckpt, save_checkpoint
from .compat import meshgrid
from .demo_utils import *
from .dist import *
from .ema import *
......
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import torch
_TORCH_VER = [int(x) for x in torch.__version__.split(".")[:2]]
__all__ = ["meshgrid"]
def meshgrid(*tensors):
if _TORCH_VER >= [1, 10]:
return torch.meshgrid(*tensors, indexing="ij")
else:
return torch.meshgrid(*tensors)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册