未验证 提交 6d91289f 编写于 作者: N nihao 提交者: GitHub

Add SIoU and MobileOne block (#6312)

* Add SIoU and MobileOne block

* add paddle copyright

* mobileone block k>1 bugfix

* format code style
上级 00b656f2
......@@ -33,6 +33,7 @@ from . import cspresnet
from . import csp_darknet
from . import convnext
from . import vision_transformer
from . import mobileone
from .vgg import *
from .resnet import *
......@@ -54,4 +55,6 @@ from .esnet import *
from .cspresnet import *
from .csp_darknet import *
from .convnext import *
from .vision_transformer import *
\ No newline at end of file
from .vision_transformer import *
from .vision_transformer import *
from .mobileone import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is the paddle implementation of MobileOne block, see: https://arxiv.org/pdf/2206.04040.pdf.
Some codes are based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
Ths copyright of microsoft/Swin-Transformer is as follows:
MIT License [see LICENSE for details]
"""
import paddle
import paddle.nn as nn
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from paddle.nn.initializer import Normal
from ppdet.modeling.ops import get_act_fn
from ppdet.modeling.layers import ConvNormLayer
class MobileOneBlock(nn.Layer):
def __init__(
self,
ch_in,
ch_out,
stride,
kernel_size,
conv_num=1,
norm_type='bn',
norm_decay=0.,
norm_groups=32,
bias_on=False,
lr_scale=1.,
freeze_norm=False,
initializer=Normal(
mean=0., std=0.01),
skip_quant=False,
act='relu', ):
super(MobileOneBlock, self).__init__()
self.ch_in = ch_in
self.ch_out = ch_out
self.kernel_size = kernel_size
self.stride = stride
self.padding = (kernel_size - 1) // 2
self.k = conv_num
self.depth_conv = nn.LayerList()
self.point_conv = nn.LayerList()
for i in range(self.k):
if i > 0:
stride = 1
self.depth_conv.append(
ConvNormLayer(
ch_in,
ch_in,
kernel_size,
stride=stride,
groups=ch_in,
norm_type=norm_type,
norm_decay=norm_decay,
norm_groups=norm_groups,
bias_on=bias_on,
lr_scale=lr_scale,
freeze_norm=freeze_norm,
initializer=initializer,
skip_quant=skip_quant))
self.point_conv.append(
ConvNormLayer(
ch_in,
ch_out,
1,
stride=1,
groups=1,
norm_type=norm_type,
norm_decay=norm_decay,
norm_groups=norm_groups,
bias_on=bias_on,
lr_scale=lr_scale,
freeze_norm=freeze_norm,
initializer=initializer,
skip_quant=skip_quant))
self.rbr_1x1 = ConvNormLayer(
ch_in,
ch_in,
1,
stride=self.stride,
groups=ch_in,
norm_type=norm_type,
norm_decay=norm_decay,
norm_groups=norm_groups,
bias_on=bias_on,
lr_scale=lr_scale,
freeze_norm=freeze_norm,
initializer=initializer,
skip_quant=skip_quant)
self.rbr_identity_st1 = nn.BatchNorm2D(
num_features=ch_in,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(
0.0))) if ch_in == ch_out and self.stride == 1 else None
self.rbr_identity_st2 = nn.BatchNorm2D(
num_features=ch_out,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
self.act = get_act_fn(act) if act is None or isinstance(act, (
str, dict)) else act
def forward(self, x):
if hasattr(self, "conv1") and hasattr(self, "conv2"):
y = self.act(self.conv2(self.act(self.conv1(x))))
else:
if self.rbr_identity_st1 is None:
id_out_st1 = 0
else:
id_out_st1 = self.rbr_identity_st1(x)
x1_1 = x.clone()
for i in range(self.k):
x1_1 = self.depth_conv[i](x1_1)
x1_2 = self.rbr_1x1(x)
x1 = self.act(x1_1 + x1_2 + id_out_st1)
if self.rbr_identity_st2 is None:
id_out_st2 = 0
else:
id_out_st2 = self.rbr_identity_st2(x1)
x2_1 = x1.clone()
for i in range(self.k):
x2_1 = self.point_conv[i](x2_1)
y = self.act(x2_1 + id_out_st2)
return y
def convert_to_deploy(self):
if not hasattr(self, 'conv1'):
self.conv1 = nn.Conv2D(
in_channels=self.ch_in,
out_channels=self.ch_in,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
groups=self.ch_in)
if not hasattr(self, 'conv2'):
self.conv2 = nn.Conv2D(
in_channels=self.ch_in,
out_channels=self.ch_out,
kernel_size=1,
stride=1,
padding='SAME',
groups=1)
conv1_kernel, conv1_bias, conv2_kernel, conv2_bias = self.get_equivalent_kernel_bias(
)
self.conv1.weight.set_value(conv1_kernel)
self.conv1.bias.set_value(conv1_bias)
self.conv2.weight.set_value(conv2_kernel)
self.conv2.bias.set_value(conv2_bias)
self.__delattr__('depth_conv')
self.__delattr__('point_conv')
self.__delattr__('rbr_1x1')
if hasattr(self, 'rbr_identity_st1'):
self.__delattr__('rbr_identity_st1')
if hasattr(self, 'rbr_identity_st2'):
self.__delattr__('rbr_identity_st2')
if hasattr(self, 'id_tensor'):
self.__delattr__('id_tensor')
def get_equivalent_kernel_bias(self):
st1_kernel3x3, st1_bias3x3 = self._fuse_bn_tensor(self.depth_conv)
st1_kernel1x1, st1_bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
st1_kernelid, st1_biasid = self._fuse_bn_tensor(
self.rbr_identity_st1, kernel_size=self.kernel_size)
st2_kernel1x1, st2_bias1x1 = self._fuse_bn_tensor(self.point_conv)
st2_kernelid, st2_biasid = self._fuse_bn_tensor(
self.rbr_identity_st2, kernel_size=1)
conv1_kernel = st1_kernel3x3 + self._pad_1x1_to_3x3_tensor(
st1_kernel1x1) + st1_kernelid
conv1_bias = st1_bias3x3 + st1_bias1x1 + st1_biasid
conv2_kernel = st2_kernel1x1 + st2_kernelid
conv2_bias = st2_bias1x1 + st2_biasid
return conv1_kernel, conv1_bias, conv2_kernel, conv2_bias
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
if kernel1x1 is None:
return 0
else:
padding_size = (self.kernel_size - 1) // 2
return nn.functional.pad(
kernel1x1,
[padding_size, padding_size, padding_size, padding_size])
def _fuse_bn_tensor(self, branch, kernel_size=3):
if branch is None:
return 0, 0
if isinstance(branch, nn.LayerList):
kernel = 0
running_mean = 0
running_var = 0
gamma = 0
beta = 0
eps = 0
for block in branch:
kernel += block.conv.weight
running_mean += block.norm._mean
running_var += block.norm._variance
gamma += block.norm.weight
beta += block.norm.bias
eps += block.norm._epsilon
kernel /= len(branch)
running_mean /= len(branch)
running_var /= len(branch)
gamma /= len(branch)
beta /= len(branch)
eps /= len(branch)
elif isinstance(branch, ConvNormLayer):
kernel = branch.conv.weight
running_mean = branch.norm._mean
running_var = branch.norm._variance
gamma = branch.norm.weight
beta = branch.norm.bias
eps = branch.norm._epsilon
else:
assert isinstance(branch, nn.BatchNorm2D)
input_dim = self.ch_in if kernel_size == 1 else 1
kernel_value = paddle.zeros(
shape=[self.ch_in, input_dim, kernel_size, kernel_size],
dtype='float32')
if kernel_size > 1:
for i in range(self.ch_in):
kernel_value[i, i % input_dim, 1, 1] = 1
elif kernel_size == 1:
for i in range(self.ch_in):
kernel_value[i, i % input_dim, 0, 0] = 1
else:
raise ValueError("Invalid kernel size recieved!")
kernel = paddle.to_tensor(kernel_value, place=branch.weight.place)
running_mean = branch._mean
running_var = branch._variance
gamma = branch.weight
beta = branch.bias
eps = branch._epsilon
std = (running_var + eps).sqrt()
t = (gamma / std).reshape((-1, 1, 1, 1))
return kernel * t, beta - running_mean * gamma / std
......@@ -17,13 +17,13 @@ from __future__ import division
from __future__ import print_function
import numpy as np
import math
import paddle
from ppdet.core.workspace import register, serializable
from ..bbox_utils import bbox_iou
__all__ = ['IouLoss', 'GIoULoss', 'DIouLoss']
__all__ = ['IouLoss', 'GIoULoss', 'DIouLoss', 'SIoULoss']
@register
......@@ -208,3 +208,88 @@ class DIouLoss(GIoULoss):
diou = paddle.mean((1 - iouk + ciou_term + diou_term) * iou_weight)
return diou * self.loss_weight
@register
@serializable
class SIoULoss(GIoULoss):
"""
see https://arxiv.org/pdf/2205.12740.pdf
Args:
loss_weight (float): siou loss weight, default as 1
eps (float): epsilon to avoid divide by zero, default as 1e-10
theta (float): default as 4
reduction (str): Options are "none", "mean" and "sum". default as none
"""
def __init__(self, loss_weight=1., eps=1e-10, theta=4., reduction='none'):
super(SIoULoss, self).__init__(loss_weight=loss_weight, eps=eps)
self.loss_weight = loss_weight
self.eps = eps
self.theta = theta
self.reduction = reduction
def __call__(self, pbox, gbox):
x1, y1, x2, y2 = paddle.split(pbox, num_or_sections=4, axis=-1)
x1g, y1g, x2g, y2g = paddle.split(gbox, num_or_sections=4, axis=-1)
box1 = [x1, y1, x2, y2]
box2 = [x1g, y1g, x2g, y2g]
iou = bbox_iou(box1, box2)
cx = (x1 + x2) / 2
cy = (y1 + y2) / 2
w = x2 - x1 + self.eps
h = y2 - y1 + self.eps
cxg = (x1g + x2g) / 2
cyg = (y1g + y2g) / 2
wg = x2g - x1g + self.eps
hg = y2g - y1g + self.eps
x2 = paddle.maximum(x1, x2)
y2 = paddle.maximum(y1, y2)
# A or B
xc1 = paddle.minimum(x1, x1g)
yc1 = paddle.minimum(y1, y1g)
xc2 = paddle.maximum(x2, x2g)
yc2 = paddle.maximum(y2, y2g)
cw_out = xc2 - xc1
ch_out = yc2 - yc1
ch = paddle.maximum(cy, cyg) - paddle.minimum(cy, cyg)
cw = paddle.maximum(cx, cxg) - paddle.minimum(cx, cxg)
# angle cost
dist_intersection = paddle.sqrt((cx - cxg)**2 + (cy - cyg)**2)
sin_angle_alpha = ch / dist_intersection
sin_angle_beta = cw / dist_intersection
thred = paddle.pow(paddle.to_tensor(2), 0.5) / 2
thred.stop_gradient = True
sin_alpha = paddle.where(sin_angle_alpha > thred, sin_angle_beta,
sin_angle_alpha)
angle_cost = paddle.cos(paddle.asin(sin_alpha) * 2 - math.pi / 2)
# distance cost
gamma = 2 - angle_cost
# gamma.stop_gradient = True
beta_x = ((cxg - cx) / cw_out)**2
beta_y = ((cyg - cy) / ch_out)**2
dist_cost = 1 - paddle.exp(-gamma * beta_x) + 1 - paddle.exp(-gamma *
beta_y)
# shape cost
omega_w = paddle.abs(w - wg) / paddle.maximum(w, wg)
omega_h = paddle.abs(hg - h) / paddle.maximum(h, hg)
omega = (1 - paddle.exp(-omega_w))**self.theta + (
1 - paddle.exp(-omega_h))**self.theta
siou_loss = 1 - iou + (omega + dist_cost) / 2
if self.reduction == 'mean':
siou_loss = paddle.mean(siou_loss)
elif self.reduction == 'sum':
siou_loss = paddle.sum(siou_loss)
return siou_loss * self.loss_weight
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册