# NOTE(review): `X` is never used in this module; this looks like an accidental
# IDE auto-import that makes the third-party `regex` package a hard dependency.
# Remove once confirmed nothing imports `X` from this module.
from regex import X
import numpy as np
import torch
import torch.nn as nn

from nets.CSPdarknet import CSPDarknet, Conv, MP, RCSPDark_Block, RCSPDark_Transition, autopad, SiLU


class SPPCSPC(nn.Module):
    """Spatial pyramid pooling block with a cross-stage-partial (CSP) split.

    One branch (cv1 -> cv3 -> cv4) feeds a stride-1 max-pool pyramid. The
    pyramid outputs are concatenated with that branch's input and reduced by
    cv5 -> cv6. The other branch is a single 1x1 conv (cv2). Both branches are
    concatenated and projected to ``c2`` channels by cv7.

    CSP reference: https://github.com/WongKinYiu/CrossStagePartialNetworks

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        n, shortcut, g: Accepted for signature compatibility; unused here.
        e: Expansion ratio; the hidden width is ``int(2 * c2 * e)``.
        k: Max-pool kernel sizes. Each pool uses stride 1 and padding ``k // 2``.
    """

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)):
        super(SPPCSPC, self).__init__()
        c_ = int(2 * c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(c_, c_, 3, 1)
        self.cv4 = Conv(c_, c_, 1, 1)
        self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
        # The pyramid output is the unpooled tensor plus one tensor per kernel
        # size, so the width scales with len(k). This equals 4 * c_ for the default k.
        self.cv5 = Conv((len(k) + 1) * c_, c_, 1, 1)
        self.cv6 = Conv(c_, c_, 3, 1)
        self.cv7 = Conv(2 * c_, c2, 1, 1)

    def forward(self, x):
        x1 = self.cv4(self.cv3(self.cv1(x)))
        y1 = self.cv6(self.cv5(torch.cat([x1] + [m(x1) for m in self.m], 1)))
        y2 = self.cv2(x)
        return self.cv7(torch.cat((y1, y2), dim=1))


class RepConv(nn.Module):
    """Re-parameterizable 3x3 convolution (RepVGG style).

    In training mode the output is ``act(dense3x3(x) + conv1x1(x) + identity(x))``.
    Each branch ends in BatchNorm, and the identity branch exists only when
    ``c1 == c2`` and ``s == 1``. After ``fuse_repvgg_block()``, or when
    constructed with ``deploy=True``, the block is a single biased 3x3 conv
    ``rbr_reparam`` followed by ``act``.

    Paper: https://arxiv.org/abs/2101.03697

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        k: Kernel size; must be 3.
        s: Stride.
        p: Padding; ``autopad(k, p)`` must resolve to 1.
        g: Number of conv groups.
        act: ``True`` selects ``LeakyReLU(0.1)``. An ``nn.Module`` is used as
            given. Anything else means no activation (``nn.Identity``).
        deploy: Build the fused single-conv form directly.
    """

    # NOTE(review): the default `act=SiLU()` instance is created once, when the
    # class is defined, and is shared by every RepConv built with the default.
    def __init__(self, c1, c2, k=3, s=1, p=None, g=1, act=SiLU(), deploy=False):
        super(RepConv, self).__init__()

        self.deploy = deploy
        self.groups = g
        self.in_channels = c1
        self.out_channels = c2

        # Fusion pads the 1x1 kernel to 3x3, so only k=3 with padding 1 is supported.
        assert k == 3
        assert autopad(k, p) == 1

        # Padding for the 1x1 branch that keeps its output aligned with the 3x3 branch (0 here).
        padding_11 = autopad(k, p) - k // 2

        self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

        if deploy:
            self.rbr_reparam = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=True)
        else:
            self.rbr_identity = (nn.BatchNorm2d(num_features=c1) if c2 == c1 and s == 1 else None)

            self.rbr_dense = nn.Sequential(
                nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False),
                nn.BatchNorm2d(num_features=c2),
            )

            self.rbr_1x1 = nn.Sequential(
                nn.Conv2d(c1, c2, 1, s, padding_11, groups=g, bias=False),
                nn.BatchNorm2d(num_features=c2),
            )

    def forward(self, inputs):
        # Fused (deploy) path: a single conv.
        if hasattr(self, "rbr_reparam"):
            return self.act(self.rbr_reparam(inputs))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.act(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)

    def get_equivalent_kernel_bias(self):
        """Return the (kernel, bias) of one 3x3 conv equivalent to all three branches."""
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return (
            kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid,
            bias3x3 + bias1x1 + biasid,
        )

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        """Zero-pad a 1x1 kernel to 3x3, placing the weight at the centre (0 if None)."""
        if kernel1x1 is None:
            return 0
        else:
            return nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        """Fold a branch's BatchNorm into a (kernel, bias) pair.

        ``branch`` is either ``Sequential(conv, bn)`` or a bare ``BatchNorm2d``
        (the identity branch), or None, in which case ``(0, 0)`` is returned.
        """
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch[0].weight
            running_mean = branch[1].running_mean
            running_var = branch[1].running_var
            gamma = branch[1].weight
            beta = branch[1].bias
            eps = branch[1].eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            # Identity branch: build (and cache) a 3x3 kernel that copies
            # channel i to output i, respecting the group layout.
            if not hasattr(self, "id_tensor"):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros(
                    (self.in_channels, input_dim, 3, 3), dtype=np.float32
                )
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def repvgg_convert(self):
        """Return the equivalent fused kernel and bias as detached NumPy arrays."""
        kernel, bias = self.get_equivalent_kernel_bias()
        return (
            kernel.detach().cpu().numpy(),
            bias.detach().cpu().numpy(),
        )

    def fuse_conv_bn(self, conv, bn):
        """Return a new biased ``nn.Conv2d`` equal to ``bn(conv(x))`` using BN running stats.

        The conv's own bias, if present, is folded in as well.
        """
        std = (bn.running_var + bn.eps).sqrt()
        scale = bn.weight / std
        bias = bn.bias - bn.running_mean * scale
        if conv.bias is not None:
            # bn(Wx + b) = scale * (Wx + b - mean) + beta, so b contributes b * scale.
            bias = bias + conv.bias * scale
        weights = conv.weight * scale.reshape(-1, 1, 1, 1)

        fused = nn.Conv2d(in_channels=conv.in_channels,
                          out_channels=conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,
                          groups=conv.groups,
                          bias=True,
                          padding_mode=conv.padding_mode)

        fused.weight = torch.nn.Parameter(weights)
        fused.bias = torch.nn.Parameter(bias)
        return fused

    def fuse_repvgg_block(self):
        """Merge the dense, 1x1 and identity branches into ``rbr_reparam`` in place.

        After this call ``self.deploy`` is True and ``rbr_dense``, ``rbr_1x1``
        and ``rbr_identity`` are set to None. Calling it again is a no-op.
        """
        if self.deploy:
            return
        print("RepConv.fuse_repvgg_block")

        self.rbr_dense = self.fuse_conv_bn(self.rbr_dense[0], self.rbr_dense[1])

        self.rbr_1x1 = self.fuse_conv_bn(self.rbr_1x1[0], self.rbr_1x1[1])
        rbr_1x1_bias = self.rbr_1x1.bias
        weight_1x1_expanded = torch.nn.functional.pad(self.rbr_1x1.weight, [1, 1, 1, 1])

        # Fuse self.rbr_identity
        if isinstance(self.rbr_identity, (nn.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
            identity_conv_1x1 = nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=self.groups,
                bias=False)
            # Identity 1x1 kernel with weight shape (out, in // groups, 1, 1):
            # output channel i copies input channel i, which is index
            # i % (in // groups) within its group. This mirrors _fuse_bn_tensor
            # and stays correct for grouped convs and single-channel layers.
            # The identity branch only exists when in == out channels.
            input_dim = self.in_channels // self.groups
            identity_kernel = torch.zeros_like(
                identity_conv_1x1.weight.data, device=self.rbr_1x1.weight.device
            )
            for i in range(self.in_channels):
                identity_kernel[i, i % input_dim, 0, 0] = 1.0
            identity_conv_1x1.weight.data = identity_kernel

            identity_conv_1x1 = self.fuse_conv_bn(identity_conv_1x1, self.rbr_identity)
            bias_identity_expanded = identity_conv_1x1.bias
            weight_identity_expanded = torch.nn.functional.pad(identity_conv_1x1.weight, [1, 1, 1, 1])
        else:
            # No identity branch: contribute zeros of matching shape.
            bias_identity_expanded = torch.nn.Parameter(torch.zeros_like(rbr_1x1_bias))
            weight_identity_expanded = torch.nn.Parameter(torch.zeros_like(weight_1x1_expanded))

        self.rbr_dense.weight = torch.nn.Parameter(self.rbr_dense.weight + weight_1x1_expanded + weight_identity_expanded)
        self.rbr_dense.bias = torch.nn.Parameter(self.rbr_dense.bias + rbr_1x1_bias + bias_identity_expanded)

        self.rbr_reparam = self.rbr_dense
        self.deploy = True

        if self.rbr_identity is not None:
            del self.rbr_identity
            self.rbr_identity = None

        if self.rbr_1x1 is not None:
            del self.rbr_1x1
            self.rbr_1x1 = None

        if self.rbr_dense is not None:
            del self.rbr_dense
            self.rbr_dense = None


def fuse_conv_and_bn(conv, bn):
    """Return a frozen, biased ``nn.Conv2d`` equivalent to ``bn(conv(x))`` (BN running stats).

    Reference: https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    The returned conv has ``requires_grad=False`` and lives on ``conv.weight.device``.
    """
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # Prepare filters: scale each output filter by gamma / sqrt(var + eps).
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

    # Prepare spatial bias.
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv


#---------------------------------------------------#
#   yolo_body
#---------------------------------------------------#
class YoloBody(nn.Module):
    """CSPDarknet backbone, PAN-style neck, and three 1x1 detection heads.

    Args:
        anchors_mask: Per-level anchor index lists. ``anchors_mask[0]`` is used
            for the P5 head, ``[1]`` for P4 and ``[2]`` for P3.
        num_classes: Number of object classes. Each head outputs
            ``len(anchors_mask[i]) * (5 + num_classes)`` channels.
        pretrained: Currently unused (NOTE(review): not forwarded to the backbone).
    """

    def __init__(self, anchors_mask, num_classes, pretrained=False):
        super(YoloBody, self).__init__()
        #-----------------------------------------------#
        #   Base channel width. All neck/head widths are multiples of it.
        #   NOTE(review): earlier comments said the input is 640x640x3; that is
        #   not enforced here.
        #-----------------------------------------------#
        base_channels = 32

        #---------------------------------------------------#
        #   Build the CSPDarknet backbone. It returns three feature maps.
        #   Judging from the layers that consume them, with base_channels=32:
        #     feat1: 512 channels   (highest resolution)
        #     feat2: 1024 channels  (1/2 of feat1's resolution)
        #     feat3: 1024 channels  (1/2 of feat2's resolution)
        #   NOTE(review): the exact strides (presumably 8/16/32) depend on
        #   CSPDarknet; confirm there.
        #---------------------------------------------------#
        self.backbone = CSPDarknet(base_channels)

        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")

        self.sppcspc = SPPCSPC(base_channels * 32, base_channels * 16)
        self.conv_for_P5 = Conv(base_channels * 16, base_channels * 8)
        self.conv_for_feat2 = Conv(base_channels * 32, base_channels * 8)
        self.conv3_for_upsample1 = RCSPDark_Block(base_channels * 16, base_channels * 4, base_channels * 8, ids=[-1, -2, -3, -4, -5, -6])

        self.conv_for_P4 = Conv(base_channels * 8, base_channels * 4)
        self.conv_for_feat1 = Conv(base_channels * 16, base_channels * 4)
        self.conv3_for_upsample2 = RCSPDark_Block(base_channels * 8, base_channels * 2, base_channels * 4, ids=[-1, -2, -3, -4, -5, -6])

        self.down_sample1 = RCSPDark_Transition(base_channels * 4, base_channels * 4)
        self.conv3_for_downsample1 = RCSPDark_Block(base_channels * 16, base_channels * 4, base_channels * 8, ids=[-1, -2, -3, -4, -5, -6])

        self.down_sample2 = RCSPDark_Transition(base_channels * 8, base_channels * 8)
        self.conv3_for_downsample2 = RCSPDark_Block(base_channels * 32, base_channels * 8, base_channels * 16, ids=[-1, -2, -3, -4, -5, -6])

        self.rep_conv_1 = RepConv(base_channels * 4, base_channels * 8, 3, 1)
        self.rep_conv_2 = RepConv(base_channels * 8, base_channels * 16, 3, 1)
        self.rep_conv_3 = RepConv(base_channels * 16, base_channels * 32, 3, 1)

        self.yolo_head_P3 = nn.Conv2d(base_channels * 8, len(anchors_mask[2]) * (5 + num_classes), 1)
        self.yolo_head_P4 = nn.Conv2d(base_channels * 16, len(anchors_mask[1]) * (5 + num_classes), 1)
        self.yolo_head_P5 = nn.Conv2d(base_channels * 32, len(anchors_mask[0]) * (5 + num_classes), 1)

    def fuse(self):
        """Fuse Conv2d + BatchNorm2d layers and RepConv branches in place for inference.

        Returns:
            self, to allow chaining.
        """
        print('Fusing layers... ')
        # Snapshot the module list first: fusing deletes and replaces
        # submodules, so we avoid mutating the tree during lazy traversal.
        for m in list(self.modules()):
            if isinstance(m, RepConv):
                m.fuse_repvgg_block()
            elif type(m) is Conv and hasattr(m, 'bn'):
                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                delattr(m, 'bn')  # remove batchnorm
                m.forward = m.fuseforward  # update forward
        return self

    def forward(self, x):
        #  backbone
        feat1, feat2, feat3 = self.backbone.forward(x)

        # Top-down path: P5 -> P4 -> P3.
        P5 = self.sppcspc(feat3)
        P5_conv = self.conv_for_P5(P5)
        P5_upsample = self.upsample(P5_conv)
        P4 = torch.cat([self.conv_for_feat2(feat2), P5_upsample], 1)
        P4 = self.conv3_for_upsample1(P4)

        P4_conv = self.conv_for_P4(P4)
        P4_upsample = self.upsample(P4_conv)
        P3 = torch.cat([self.conv_for_feat1(feat1), P4_upsample], 1)
        P3 = self.conv3_for_upsample2(P3)

        # Bottom-up path: P3 -> P4 -> P5.
        P3_downsample = self.down_sample1(P3)
        P4 = torch.cat([P3_downsample, P4], 1)
        P4 = self.conv3_for_downsample1(P4)

        P4_downsample = self.down_sample2(P4)
        P5 = torch.cat([P4_downsample, P5], 1)
        P5 = self.conv3_for_downsample2(P5)

        P3 = self.rep_conv_1(P3)
        P4 = self.rep_conv_2(P4)
        P5 = self.rep_conv_3(P5)
        #---------------------------------------------------#
        #   Third output (P3, highest resolution):
        #   (batch_size, len(anchors_mask[2]) * (5 + num_classes), H3, W3)
        #---------------------------------------------------#
        out2 = self.yolo_head_P3(P3)
        #---------------------------------------------------#
        #   Second output (P4):
        #   (batch_size, len(anchors_mask[1]) * (5 + num_classes), H3/2, W3/2)
        #---------------------------------------------------#
        out1 = self.yolo_head_P4(P4)
        #---------------------------------------------------#
        #   First output (P5, lowest resolution):
        #   (batch_size, len(anchors_mask[0]) * (5 + num_classes), H3/4, W3/4)
        #---------------------------------------------------#
        out0 = self.yolo_head_P5(P5)
        return out0, out1, out2