未验证 提交 8e392522 编写于 作者: T tripleMu 提交者: GitHub

Support YOLOv8 deploy (#456)

上级 550664d5
# Copyright (c) OpenMMLab. All rights reserved.
from .common import DeployC2f
from .focus import DeployFocus, GConvFocus, NcnnFocus
__all__ = ['DeployFocus', 'NcnnFocus', 'GConvFocus']
__all__ = ['DeployFocus', 'NcnnFocus', 'GConvFocus', 'DeployC2f']
import torch
import torch.nn as nn
from torch import Tensor
class DeployC2f(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, x: Tensor) -> Tensor:
x_main = self.main_conv(x)
x_main = [x_main, x_main[:,self.mid_channels:,...]]
x_main.extend(blocks(x_main[-1]) for blocks in self.blocks)
x_main.pop(1)
return self.final_conv(torch.cat(x_main, 1))
......@@ -11,7 +11,8 @@ from torch import Tensor
from mmyolo.models import RepVGGBlock
from mmyolo.models.dense_heads import (RTMDetHead, YOLOv5Head, YOLOv7Head,
YOLOXHead)
from ..backbone import DeployFocus, GConvFocus, NcnnFocus
from mmyolo.models.layers import CSPLayerWithTwoConv
from ..backbone import DeployC2f, DeployFocus, GConvFocus, NcnnFocus
from ..bbox_code import (rtmdet_bbox_decoder, yolov5_bbox_decoder,
yolox_bbox_decoder)
from ..nms import batched_nms, efficient_nms, onnx_nms
......@@ -49,7 +50,7 @@ class DeployModel(nn.Module):
for layer in self.baseModel.modules():
if isinstance(layer, RepVGGBlock):
layer.switch_to_deploy()
if isinstance(layer, Focus):
elif isinstance(layer, Focus):
# onnxruntime tensorrt8 tensorrt7
if self.backend in (1, 2, 3):
self.baseModel.backbone.stem = DeployFocus(layer)
......@@ -59,6 +60,8 @@ class DeployModel(nn.Module):
# switch focus to group conv
else:
self.baseModel.backbone.stem = GConvFocus(layer)
elif isinstance(layer, CSPLayerWithTwoConv):
setattr(layer, '__class__', DeployC2f)
def pred_by_feat(self,
cls_scores: List[Tensor],
......
......@@ -16,6 +16,7 @@ warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)
warnings.filterwarnings(action='ignore', category=torch.jit.ScriptWarning)
warnings.filterwarnings(action='ignore', category=UserWarning)
warnings.filterwarnings(action='ignore', category=FutureWarning)
warnings.filterwarnings(action='ignore', category=ResourceWarning)
def parse_args():
......
# Copyright (c) OpenMMLab. All rights reserved.
from projects.easydeploy.model import ORTWrapper, TRTWrapper # isort:skip
import os
import random
from argparse import ArgumentParser
......@@ -15,8 +16,6 @@ from mmengine.utils import ProgressBar, path
from mmyolo.utils import register_all_modules
from mmyolo.utils.misc import get_file_list
from projects.easydeploy.model import ORTWrapper, TRTWrapper # isort:skip
def parse_args():
parser = ArgumentParser()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册