import torch
import torch.nn as nn
import math
import numpy as np

class YOLOLoss(nn.Module):
    def __init__(self, anchors, num_classes, input_shape, cuda, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]], label_smoothing = 0):
        super(YOLOLoss, self).__init__()
        #-----------------------------------------------------------#
        #   The anchors for the 13x13 feature layer are [142, 110], [192, 243], [459, 401]
        #   The anchors for the 26x26 feature layer are [36, 75], [76, 55], [72, 146]
        #   The anchors for the 52x52 feature layer are [12, 16], [19, 36], [40, 28]
        #-----------------------------------------------------------#
        self.anchors        = anchors
        self.num_classes    = num_classes
        self.bbox_attrs     = 5 + num_classes
        self.input_shape    = input_shape
        self.anchors_mask   = anchors_mask
        self.label_smoothing = label_smoothing

        self.ignore_threshold = 0.5
        self.cuda = cuda

    def clip_by_tensor(self, t, t_min, t_max):
        t = t.float()
        result = (t >= t_min).float() * t + (t < t_min).float() * t_min
        result = (result <= t_max).float() * result + (result > t_max).float() * t_max
        return result
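
    # Note: for the scalar bounds used in BCELoss below, clip_by_tensor produces
    # the same values as torch.clamp(t, min=t_min, max=t_max).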

    def MSELoss(self, pred, target):
        return torch.pow(pred - target, 2)

    def BCELoss(self, pred, target):
        epsilon = 1e-7
        pred    = self.clip_by_tensor(pred, epsilon, 1.0 - epsilon)
        output  = - target * torch.log(pred) - (1.0 - target) * torch.log(1.0 - pred)
        return output
        
    def box_ciou(self, b1, b2):
        """
        Inputs:
        ----------
        b1: tensor, shape=(batch, feat_w, feat_h, anchor_num, 4), xywh
        b2: tensor, shape=(batch, feat_w, feat_h, anchor_num, 4), xywh

        Returns:
        -------
        ciou: tensor, shape=(batch, feat_w, feat_h, anchor_num, 1)
        """
        #----------------------------------------------------#
        #   Compute the top-left and bottom-right corners of the predicted boxes
        #----------------------------------------------------#
        b1_xy       = b1[..., :2]
        b1_wh       = b1[..., 2:4]
        b1_wh_half  = b1_wh/2.
        b1_mins     = b1_xy - b1_wh_half
        b1_maxes    = b1_xy + b1_wh_half
        #----------------------------------------------------#
        #   Compute the top-left and bottom-right corners of the ground-truth boxes
        #----------------------------------------------------#
        b2_xy       = b2[..., :2]
        b2_wh       = b2[..., 2:4]
        b2_wh_half  = b2_wh/2.
        b2_mins     = b2_xy - b2_wh_half
        b2_maxes    = b2_xy + b2_wh_half

        #----------------------------------------------------#
        #   Compute the IoU of the ground-truth and predicted boxes
        #----------------------------------------------------#
        intersect_mins  = torch.max(b1_mins, b2_mins)
        intersect_maxes = torch.min(b1_maxes, b2_maxes)
        intersect_wh    = torch.max(intersect_maxes - intersect_mins, torch.zeros_like(intersect_maxes))
        intersect_area  = intersect_wh[..., 0] * intersect_wh[..., 1]
        b1_area         = b1_wh[..., 0] * b1_wh[..., 1]
        b2_area         = b2_wh[..., 0] * b2_wh[..., 1]
        union_area      = b1_area + b2_area - intersect_area
        iou             = intersect_area / torch.clamp(union_area,min = 1e-6)

        #----------------------------------------------------#
        #   Compute the distance between the box centers
        #----------------------------------------------------#
        center_distance = torch.sum(torch.pow((b1_xy - b2_xy), 2), axis=-1)
        
        #----------------------------------------------------#
        #   Find the top-left and bottom-right corners of the smallest box enclosing both boxes
        #----------------------------------------------------#
        enclose_mins    = torch.min(b1_mins, b2_mins)
        enclose_maxes   = torch.max(b1_maxes, b2_maxes)
        enclose_wh      = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(intersect_maxes))
        #----------------------------------------------------#
        #   Compute the (squared) diagonal distance of the enclosing box
        #----------------------------------------------------#
        enclose_diagonal = torch.sum(torch.pow(enclose_wh,2), axis=-1)
        ciou            = iou - 1.0 * (center_distance) / torch.clamp(enclose_diagonal,min = 1e-6)
        
        v       = (4 / (math.pi ** 2)) * torch.pow((torch.atan(b1_wh[..., 0] / torch.clamp(b1_wh[..., 1],min = 1e-6)) - torch.atan(b2_wh[..., 0] / torch.clamp(b2_wh[..., 1], min = 1e-6))), 2)
        alpha   = v / torch.clamp((1.0 - iou + v), min=1e-6)
        ciou    = ciou - alpha * v
        return ciou
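
    #------------------------------------------------------------------#
    #   For reference, box_ciou above computes
    #       CIoU = IoU - rho^2(b1, b2) / c^2 - alpha * v
    #   where rho^2 is the squared distance between the box centers,
    #   c^2 the squared diagonal of the smallest enclosing box,
    #   v the aspect-ratio consistency term and alpha = v / (1 - IoU + v).
    #------------------------------------------------------------------#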

    #---------------------------------------------------#
    #   Label smoothing
    #---------------------------------------------------#
    def smooth_labels(self, y_true, label_smoothing, num_classes):
        return y_true * (1.0 - label_smoothing) + label_smoothing / num_classes
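
    # A quick numeric sketch (illustrative values): with label_smoothing = 0.01
    # and num_classes = 20, a one-hot 1 becomes 1 * 0.99 + 0.01 / 20 = 0.9905
    # and a 0 becomes 0.0005.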

    def forward(self, l, input, targets=None):
        #----------------------------------------------------#
        #   l: which effective feature layer is being used
        #   The shape of input is  bs, 3*(5+num_classes), 13, 13
        #                          bs, 3*(5+num_classes), 26, 26
        #                          bs, 3*(5+num_classes), 52, 52
        #   targets: the ground-truth boxes, [batch_size, num_gt, 5]
        #----------------------------------------------------#
        #--------------------------------#
        #   Get the number of images and the height and width of the feature layer
        #--------------------------------#
        bs      = input.size(0)
        in_h    = input.size(2)
        in_w    = input.size(3)
        #-----------------------------------------------------------------------#
        #   Compute the stride:
        #   how many pixels on the original image each feature point corresponds to
        #
        #   For a 13x13 feature layer, one feature point corresponds to 32 pixels of the original image
        #   For a 26x26 feature layer, one feature point corresponds to 16 pixels of the original image
        #   For a 52x52 feature layer, one feature point corresponds to 8 pixels of the original image
        #   stride_h = stride_w = 32, 16, 8
        #-----------------------------------------------------------------------#
        stride_h = self.input_shape[0] / in_h
        stride_w = self.input_shape[1] / in_w
        #-------------------------------------------------#
        #   The scaled_anchors obtained here are relative to the feature layer
        #-------------------------------------------------#
        scaled_anchors  = [(a_w / stride_w, a_h / stride_h) for a_w, a_h in self.anchors]
        #-----------------------------------------------#
        #   There are three inputs in total; their shapes are
        #   bs, 3 * (5+num_classes), 13, 13 => bs, 3, 5 + num_classes, 13, 13 => batch_size, 3, 13, 13, 5 + num_classes
        #
        #   batch_size, 3, 13, 13, 5 + num_classes
        #   batch_size, 3, 26, 26, 5 + num_classes
        #   batch_size, 3, 52, 52, 5 + num_classes
        #-----------------------------------------------#
        prediction = input.view(bs, len(self.anchors_mask[l]), self.bbox_attrs, in_h, in_w).permute(0, 1, 3, 4, 2).contiguous()
        
        #-----------------------------------------------#
        #   Adjustment parameters for the centers of the prior boxes
        #-----------------------------------------------#
        x = torch.sigmoid(prediction[..., 0])
        y = torch.sigmoid(prediction[..., 1])
        #-----------------------------------------------#
        #   Adjustment parameters for the widths and heights of the prior boxes
        #-----------------------------------------------#
        w = prediction[..., 2]
        h = prediction[..., 3]
        #-----------------------------------------------#
        #   Objectness confidence: whether an object is present
        #-----------------------------------------------#
        conf = torch.sigmoid(prediction[..., 4])
        #-----------------------------------------------#
        #   Class confidences
        #-----------------------------------------------#
        pred_cls = torch.sigmoid(prediction[..., 5:])

        #-----------------------------------------------#
        #   Get the targets that the network is expected to predict
        #-----------------------------------------------#
        y_true, noobj_mask, box_loss_scale = self.get_target(l, targets, scaled_anchors, in_h, in_w)

        #---------------------------------------------------------------#
        #   Decode the predictions and measure how much they overlap the ground truth.
        #   If the overlap is too large, those feature points are ignored, because they
        #   already predict fairly accurately and are not suitable as negative samples.
        #----------------------------------------------------------------#
        noobj_mask, pred_boxes = self.get_ignore(l, x, y, h, w, targets, scaled_anchors, in_h, in_w, noobj_mask)

        if self.cuda:
            y_true          = y_true.cuda()
            noobj_mask      = noobj_mask.cuda()
            box_loss_scale  = box_loss_scale.cuda()
        #-----------------------------------------------------------#
        #   box_loss_scale is the area of the ground-truth box relative
        #   to the feature layer, between 0 and 1.
        #   The larger the ground-truth box, the smaller its weight;
        #   small boxes get a larger weight.
        #-----------------------------------------------------------#
        box_loss_scale = 2 - box_loss_scale
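        #   e.g. a ground-truth box covering the whole feature layer gets a scale
        #   of about 1, while a very small box gets a scale close to 2.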

        #---------------------------------------------------------------#
        #   Compute the CIoU between the predictions and the ground truth
        #----------------------------------------------------------------#
        ciou        = (1 - self.box_ciou(pred_boxes[y_true[..., 4] == 1], y_true[..., :4][y_true[..., 4] == 1])) * box_loss_scale[y_true[..., 4] == 1]
        loss_loc    = torch.sum(ciou)
        #-----------------------------------------------------------#
        #   Compute the confidence loss
        #-----------------------------------------------------------#
        loss_conf   = torch.sum(self.BCELoss(conf, y_true[..., 4]) * y_true[..., 4]) + \
                      torch.sum(self.BCELoss(conf, y_true[..., 4]) * noobj_mask)

        loss_cls    = torch.sum(self.BCELoss(pred_cls[y_true[..., 4] == 1], self.smooth_labels(y_true[..., 5:][y_true[..., 4] == 1], self.label_smoothing, self.num_classes)))

        loss        = loss_loc + loss_conf + loss_cls
        num_pos = torch.sum(y_true[..., 4])
        num_pos = torch.max(num_pos, torch.ones_like(num_pos))
        return loss, num_pos

    def calculate_iou(self, _box_a, _box_b):
        #-----------------------------------------------------------#
        #   Compute the top-left and bottom-right corners of the ground-truth boxes
        #-----------------------------------------------------------#
        b1_x1, b1_x2 = _box_a[:, 0] - _box_a[:, 2] / 2, _box_a[:, 0] + _box_a[:, 2] / 2
        b1_y1, b1_y2 = _box_a[:, 1] - _box_a[:, 3] / 2, _box_a[:, 1] + _box_a[:, 3] / 2
        #-----------------------------------------------------------#
        #   Compute the top-left and bottom-right corners of the predicted boxes obtained from the prior boxes
        #-----------------------------------------------------------#
        b2_x1, b2_x2 = _box_b[:, 0] - _box_b[:, 2] / 2, _box_b[:, 0] + _box_b[:, 2] / 2
        b2_y1, b2_y2 = _box_b[:, 1] - _box_b[:, 3] / 2, _box_b[:, 1] + _box_b[:, 3] / 2

        #-----------------------------------------------------------#
        #   Convert both the ground-truth and predicted boxes to (top-left, bottom-right) form
        #-----------------------------------------------------------#
        box_a = torch.zeros_like(_box_a)
        box_b = torch.zeros_like(_box_b)
        box_a[:, 0], box_a[:, 1], box_a[:, 2], box_a[:, 3] = b1_x1, b1_y1, b1_x2, b1_y2
        box_b[:, 0], box_b[:, 1], box_b[:, 2], box_b[:, 3] = b2_x1, b2_y1, b2_x2, b2_y2

        #-----------------------------------------------------------#
        #   A is the number of ground-truth boxes, B is the number of prior boxes
        #-----------------------------------------------------------#
        A = box_a.size(0)
        B = box_b.size(0)

        #-----------------------------------------------------------#
        #   Compute the intersection area
        #-----------------------------------------------------------#
        max_xy  = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
        min_xy  = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
        inter   = torch.clamp((max_xy - min_xy), min=0)
        inter   = inter[:, :, 0] * inter[:, :, 1]
        #-----------------------------------------------------------#
        #   Compute the areas of the predicted and ground-truth boxes
        #-----------------------------------------------------------#
        area_a = ((box_a[:, 2]-box_a[:, 0]) * (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
        area_b = ((box_b[:, 2]-box_b[:, 0]) * (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]
        #-----------------------------------------------------------#
        #   Compute the IoU
        #-----------------------------------------------------------#
        union = area_a + area_b - inter
        return inter / union  # [A,B]
    
    def get_target(self, l, targets, anchors, in_h, in_w):
        #-----------------------------------------------------#
        #   Count how many images there are in total
        #-----------------------------------------------------#
        bs              = len(targets)
        #-----------------------------------------------------#
        #   Used to select which prior boxes contain no object
        #-----------------------------------------------------#
        noobj_mask      = torch.ones(bs, len(self.anchors_mask[l]), in_h, in_w, requires_grad = False)
        #-----------------------------------------------------#
        #   Make the network pay more attention to small objects
        #-----------------------------------------------------#
        box_loss_scale  = torch.zeros(bs, len(self.anchors_mask[l]), in_h, in_w, requires_grad = False)
        #-----------------------------------------------------#
        #   batch_size, 3, 13, 13, 5 + num_classes
        #-----------------------------------------------------#
        y_true          = torch.zeros(bs, len(self.anchors_mask[l]), in_h, in_w, self.bbox_attrs, requires_grad = False)
        for b in range(bs):            
            if len(targets[b])==0:
                continue
            batch_target = torch.zeros_like(targets[b])
            #-------------------------------------------------------#
            #   Compute the centers of the positive samples on the feature layer
            #-------------------------------------------------------#
            batch_target[:, [0,2]] = targets[b][:, [0,2]] * in_w
            batch_target[:, [1,3]] = targets[b][:, [1,3]] * in_h
            batch_target[:, 4] = targets[b][:, 4]
            batch_target = batch_target.cpu()
            
            #-------------------------------------------------------#
            #   Convert the ground-truth boxes to the form
            #   num_true_box, 4
            #-------------------------------------------------------#
            gt_box          = torch.FloatTensor(torch.cat((torch.zeros((batch_target.size(0), 2)), batch_target[:, 2:4]), 1))
            #-------------------------------------------------------#
            #   Convert the prior boxes to the form
            #   9, 4
            #-------------------------------------------------------#
            anchor_shapes   = torch.FloatTensor(torch.cat((torch.zeros((len(anchors), 2)), torch.FloatTensor(anchors)), 1))
            #-------------------------------------------------------#
            #   Compute the IoU
            #   self.calculate_iou(gt_box, anchor_shapes) = [num_true_box, 9]:
            #   the overlap of every ground-truth box with the 9 prior boxes
            #   best_ns: the index of the best-matching prior box for each ground-truth box
            #-------------------------------------------------------#
            best_ns = torch.argmax(self.calculate_iou(gt_box, anchor_shapes), dim=-1)

            for t, best_n in enumerate(best_ns):
                if best_n not in self.anchors_mask[l]:
                    continue
                #----------------------------------------#
                #   Determine which of the current feature layer's prior boxes this is
                #----------------------------------------#
                k = self.anchors_mask[l].index(best_n)
                #----------------------------------------#
                #   Find which grid cell the ground-truth box falls into
                #----------------------------------------#
                i = torch.floor(batch_target[t, 0]).long()
                j = torch.floor(batch_target[t, 1]).long()
                #----------------------------------------#
                #   Get the class of the ground-truth box
                #----------------------------------------#
                c = batch_target[t, 4].long()
                
                #----------------------------------------#
                #   noobj_mask marks the feature points that contain no object
                #----------------------------------------#
                noobj_mask[b, k, j, i] = 0
                #----------------------------------------#
                #   The ground-truth box center and size on the feature layer (regression target for the CIoU loss)
                #----------------------------------------#
                y_true[b, k, j, i, 0] = batch_target[t, 0]
                y_true[b, k, j, i, 1] = batch_target[t, 1]
                y_true[b, k, j, i, 2] = batch_target[t, 2]
                y_true[b, k, j, i, 3] = batch_target[t, 3]
                y_true[b, k, j, i, 4] = 1
                y_true[b, k, j, i, c + 5] = 1
                #----------------------------------------#
                #   Used to obtain the relative scale of the box (xywh).
                #   Large objects get a smaller loss weight, small objects a larger one.
                #----------------------------------------#
                box_loss_scale[b, k, j, i] = batch_target[t, 2] * batch_target[t, 3] / in_w / in_h
        return y_true, noobj_mask, box_loss_scale

    def get_ignore(self, l, x, y, h, w, targets, scaled_anchors, in_h, in_w, noobj_mask):
        #-----------------------------------------------------#
        #   Count how many images there are in total
        #-----------------------------------------------------#
        bs = len(targets)

        FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
        LongTensor  = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor
        #-----------------------------------------------------#
        #   Generate the grid; the prior-box centers are the top-left corners of the grid cells
        #-----------------------------------------------------#
        grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_h, 1).repeat(
            int(bs * len(self.anchors_mask[l])), 1, 1).view(x.shape).type(FloatTensor)
        grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_w, 1).t().repeat(
            int(bs * len(self.anchors_mask[l])), 1, 1).view(y.shape).type(FloatTensor)

        # Generate the widths and heights of the prior boxes
        scaled_anchors_l = np.array(scaled_anchors)[self.anchors_mask[l]]
        anchor_w = FloatTensor(scaled_anchors_l).index_select(1, LongTensor([0]))
        anchor_h = FloatTensor(scaled_anchors_l).index_select(1, LongTensor([1]))
        
        anchor_w = anchor_w.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(w.shape)
        anchor_h = anchor_h.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(h.shape)
        #-------------------------------------------------------#
        #   Compute the adjusted centers and widths/heights of the prior boxes
        #-------------------------------------------------------#
        pred_boxes_x    = torch.unsqueeze(x + grid_x, -1)
        pred_boxes_y    = torch.unsqueeze(y + grid_y, -1)
        pred_boxes_w    = torch.unsqueeze(torch.exp(w) * anchor_w, -1)
        pred_boxes_h    = torch.unsqueeze(torch.exp(h) * anchor_h, -1)
        pred_boxes      = torch.cat([pred_boxes_x, pred_boxes_y, pred_boxes_w, pred_boxes_h], dim = -1)
        for b in range(bs):           
            #-------------------------------------------------------#
            #   Reshape the predictions to the form
            #   pred_boxes_for_ignore      num_anchors, 4
            #-------------------------------------------------------#
            pred_boxes_for_ignore = pred_boxes[b].view(-1, 4)
            #-------------------------------------------------------#
            #   Compute the ground-truth boxes and convert them to feature-layer scale
            #   gt_box      num_true_box, 4
            #-------------------------------------------------------#
            if len(targets[b]) > 0:
                batch_target = torch.zeros_like(targets[b])
                #-------------------------------------------------------#
                #   Compute the centers of the positive samples on the feature layer
                #-------------------------------------------------------#
                batch_target[:, [0,2]] = targets[b][:, [0,2]] * in_w
                batch_target[:, [1,3]] = targets[b][:, [1,3]] * in_h
                batch_target = batch_target[:, :4]
                #-------------------------------------------------------#
                #   Compute the IoU
                #   anch_ious       num_true_box, num_anchors
                #-------------------------------------------------------#
                anch_ious = self.calculate_iou(batch_target, pred_boxes_for_ignore)
                #-------------------------------------------------------#
                #   The maximum overlap of each predicted box with the ground-truth boxes
                #   anch_ious_max   num_anchors
                #-------------------------------------------------------#
                anch_ious_max, _    = torch.max(anch_ious, dim = 0)
                anch_ious_max       = anch_ious_max.view(pred_boxes[b].size()[:3])
                noobj_mask[b][anch_ious_max > self.ignore_threshold] = 0
        return noobj_mask, pred_boxes

def weights_init(net, init_type='normal', init_gain = 0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and classname.find('Conv') != -1:
            if init_type == 'normal':
                torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif classname.find('BatchNorm2d') != -1:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)
    print('initialize network with %s type' % init_type)
    net.apply(init_func)
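

if __name__ == "__main__":
    #--------------------------------------------------------------------------#
    #   Minimal smoke-test sketch. The anchors are the standard 9 YOLO anchors
    #   quoted in the comments above; num_classes = 20 and the dummy target
    #   boxes are arbitrary illustrative values, not values from any dataset.
    #--------------------------------------------------------------------------#
    anchors     = [[12, 16], [19, 36], [40, 28], [36, 75], [76, 55], [72, 146],
                   [142, 110], [192, 243], [459, 401]]
    num_classes = 20
    criterion   = YOLOLoss(anchors, num_classes, input_shape = [416, 416], cuda = False)

    #   Fake 13x13 head output and one normalized ground-truth box
    #   [x, y, w, h, class] per image.
    bs      = 2
    pred    = torch.randn(bs, 3 * (5 + num_classes), 13, 13)
    targets = [torch.tensor([[0.5, 0.5, 0.4, 0.35, 1.0]]) for _ in range(bs)]
    loss, num_pos = criterion(0, pred, targets)
    print(loss.item(), num_pos.item())

    #   weights_init can be applied to any module, e.g. a small conv stack.
    demo_net = nn.Sequential(nn.Conv2d(3, 16, 3, padding = 1), nn.BatchNorm2d(16))
    weights_init(demo_net, init_type = 'kaiming')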