提交 91594e50 编写于 作者: D dengkaipeng

fix pisa

上级 33f02d99
...@@ -20,25 +20,33 @@ import numpy as np ...@@ -20,25 +20,33 @@ import numpy as np
__all__ = ['get_isr_p_func'] __all__ = ['get_isr_p_func']
def get_isr_p_func(pos_iou_thresh=0.25, bias=0, k=2): def get_isr_p_func(max_box_num=50, pos_iou_thresh=0.5, bias=0, k=2):
def irs_p(x): def irs_p(x):
np.save("data", x)
x = np.array(x) x = np.array(x)
max_ious = x[:, :, 0] gt_label = x[:, :max_box_num]
gt_inds = x[:, :, 1].astype('int32') gt_score = x[:, max_box_num:2 * max_box_num]
cls = x[:, :, 2].astype('int32') remain = x[:, 2 * max_box_num:]
pn = remain.shape[1] // 3
max_ious = remain[:, :pn]
gt_inds = remain[:, pn:2 * pn].astype('int32')
cls = remain[:, 2 * pn:].astype('int32')
# # n_{max}: max gt box num in each class pos_mask = max_ious > pos_iou_thresh
# valid_gt = gt_box[:, :, 2] > 0. if not np.any(pos_mask):
# valid_gt_label = gt_label[valid_gt] return np.zeros([max_ious.shape[0], pn, 2]).astype('float32')
# max_l_num = np.bincount(valid_gt_label).max()
cls_target = np.zeros_like(max_ious)
cls_target_weights = np.zeros_like(max_ious)
for i in range(gt_label.shape[0]):
cls_target[i] = gt_label[i, gt_inds[i]]
cls_target_weights[i] = gt_score[i, gt_inds[i]]
# cls_target *= pos_mask.astype('float32')
# divide gt index in each sample # divide gt index in each sample
gt_inds = gt_inds + np.arange(gt_inds.shape[ gt_inds = gt_inds + np.arange(gt_inds.shape[
0])[:, np.newaxis] * gt_inds.shape[1] 0])[:, np.newaxis] * max_box_num
all_pos_weights = np.ones_like(max_ious) all_pos_weights = np.zeros_like(max_ious)
pos_mask = max_ious > pos_iou_thresh
cls = np.reshape(cls, list(max_ious.shape) + [-1]) cls = np.reshape(cls, list(max_ious.shape) + [-1])
max_ious = max_ious[pos_mask] max_ious = max_ious[pos_mask]
pos_weights = all_pos_weights[pos_mask] pos_weights = all_pos_weights[pos_mask]
...@@ -58,12 +66,12 @@ def get_isr_p_func(pos_iou_thresh=0.25, bias=0, k=2): ...@@ -58,12 +66,12 @@ def get_isr_p_func(pos_iou_thresh=0.25, bias=0, k=2):
l_max_iou_rank = np.argsort(-l_max_ious).argsort().astype('float32') l_max_iou_rank = np.argsort(-l_max_ious).argsort().astype('float32')
weight_factor = np.clip(max_l_num - l_max_iou_rank, 0., weight_factor = np.clip(max_l_num - l_max_iou_rank, 0.,
None) / max_l_num None) / max_l_num
weight_factor = np.power(bias + (1 - bias) * weight_factor, k) pos_weights[l_inds] = np.power(bias + (1 - bias) * weight_factor, k)
pos_weights[l_inds] *= weight_factor pos_weights = pos_weights / max(np.mean(pos_weights), 1e-6)
pos_weights = pos_weights / np.mean(pos_weights)
all_pos_weights[pos_mask] = pos_weights all_pos_weights[pos_mask] = pos_weights
cls_target_weights *= all_pos_weights
return all_pos_weights return np.stack([cls_target, cls_target_weights], axis=-1)
return irs_p return irs_p
......
...@@ -66,7 +66,7 @@ class YOLOv3Loss(object): ...@@ -66,7 +66,7 @@ class YOLOv3Loss(object):
anchor_masks, mask_anchors, num_classes, prefix_name): anchor_masks, mask_anchors, num_classes, prefix_name):
if self._use_fine_grained_loss: if self._use_fine_grained_loss:
return self._get_fine_grained_loss( return self._get_fine_grained_loss(
outputs, targets, gt_box, gt_label, self._batch_size, outputs, targets, gt_box, gt_label, gt_score, self._batch_size,
num_classes, mask_anchors, self._ignore_thresh) num_classes, mask_anchors, self._ignore_thresh)
else: else:
losses = [] losses = []
...@@ -93,7 +93,7 @@ class YOLOv3Loss(object): ...@@ -93,7 +93,7 @@ class YOLOv3Loss(object):
return {'loss': sum(losses)} return {'loss': sum(losses)}
def _get_fine_grained_loss(self, outputs, targets, gt_box, gt_label, def _get_fine_grained_loss(self, outputs, targets, gt_box, gt_label,
batch_size, num_classes, mask_anchors, gt_score, batch_size, num_classes, mask_anchors,
ignore_thresh): ignore_thresh):
""" """
Calculate fine grained YOLOv3 loss Calculate fine grained YOLOv3 loss
...@@ -123,6 +123,7 @@ class YOLOv3Loss(object): ...@@ -123,6 +123,7 @@ class YOLOv3Loss(object):
"YOLOv3 output layer number not equal target number" "YOLOv3 output layer number not equal target number"
loss_xys, loss_whs, loss_objs, loss_clss = [], [], [], [] loss_xys, loss_whs, loss_objs, loss_clss = [], [], [], []
loss_carls, loss_isrp_clss = [], []
if self._iou_loss is not None: if self._iou_loss is not None:
loss_ious = [] loss_ious = []
if self._iou_aware_loss is not None: if self._iou_aware_loss is not None:
...@@ -144,22 +145,54 @@ class YOLOv3Loss(object): ...@@ -144,22 +145,54 @@ class YOLOv3Loss(object):
sorted_iou, sorted_gt_inds = fluid.layers.argsort( sorted_iou, sorted_gt_inds = fluid.layers.argsort(
iou, axis=-1, descending=True) iou, axis=-1, descending=True)
max_iou = sorted_iou[:, :, 0:1] max_iou = sorted_iou[:, :, 0]
gt_inds = fluid.layers.cast( gt_inds = fluid.layers.cast(
sorted_gt_inds[:, :, 0:1], dtype='float32') sorted_gt_inds[:, :, 0], dtype='float32')
pred_cls = fluid.layers.argmax(cls, axis=-1) cls_score = fluid.layers.sigmoid(cls)
pred_cls = fluid.layers.reshape(pred_cls, [batch_size, -1, 1]) sorted_cls_score, sorted_pred_cls = fluid.layers.argsort(
cls_score, axis=-1, descending=True)
pred_cls = fluid.layers.reshape(sorted_pred_cls[:, :, :, :, 0],
[batch_size, -1])
pred_cls = fluid.layers.cast(pred_cls, dtype='float32') pred_cls = fluid.layers.cast(pred_cls, dtype='float32')
gt_label_fp32 = fluid.layers.cast(gt_label, dtype='float32')
isr_p_input = fluid.layers.concat( isr_p_input = fluid.layers.concat(
[max_iou, gt_inds, pred_cls], axis=-1) [gt_label_fp32, gt_score, max_iou, gt_inds, pred_cls], axis=-1)
isr_p = get_isr_p_func() isr_p = get_isr_p_func()
pos_weights = fluid.layers.zeros_like(max_iou) isr_p_output = fluid.layers.zeros_like(sorted_iou[:, :, :2])
fluid.layers.py_func(isr_p, isr_p_input, pos_weights) fluid.layers.py_func(isr_p, isr_p_input, isr_p_output)
tobj_shape = fluid.layers.shape(tobj) tobj_shape = fluid.layers.shape(tobj)
pos_weights = fluid.layers.reshape(pos_weights, ( isr_p_output = fluid.layers.reshape(isr_p_output, (
-1, an_num, tobj_shape[2], tobj_shape[3], 2))
cls_target = fluid.layers.cast(
isr_p_output[:, :, :, :, 0:1], dtype='int32')
cls_target = fluid.layers.one_hot(cls_target, num_classes)
cls_target_weights = isr_p_output[:, :, :, :, 1]
cls_target_weights.stop_gradient = True
loss_isrp_cls = fluid.layers.sigmoid_cross_entropy_with_logits(
cls, cls_target)
loss_isrp_cls = fluid.layers.elementwise_mul(
loss_isrp_cls, cls_target_weights, axis=0)
loss_isrp_cls = fluid.layers.reduce_sum(
loss_isrp_cls, dim=[1, 2, 3])
bias = 0.2
pos_cls_score = fluid.layers.reduce_sum(
cls_score * cls_target, dim=[-1])
pos_cls_score = fluid.layers.reshape(pos_cls_score, [
batch_size,
-1,
])
pos_mask = fluid.layers.cast(
sorted_iou[:, :, 0] > 0.5, dtype='float32')
carl_weights = bias + (1 - bias) * pos_cls_score * pos_mask
carl_weights *= fluid.layers.reduce_sum(
pos_mask) / fluid.layers.reduce_sum(carl_weights)
carl_weights = fluid.layers.reshape(carl_weights, (
-1, an_num, tobj_shape[2], tobj_shape[3])) -1, an_num, tobj_shape[2], tobj_shape[3]))
tobj = tobj * pos_weights
# isr_tobj = tobj * pos_weights # isr_tobj = tobj * pos_weights
# loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls, tcls) # loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls, tcls)
...@@ -172,18 +205,26 @@ class YOLOv3Loss(object): ...@@ -172,18 +205,26 @@ class YOLOv3Loss(object):
# new_loss_cls.stop_gradient = True # new_loss_cls.stop_gradient = True
# pos_loss_cls_ratio = orig_loss_cls / new_loss_cls # pos_loss_cls_ratio = orig_loss_cls / new_loss_cls
tscale_tobj = tscale * tobj loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(x,
loss_x = fluid.layers.sigmoid_cross_entropy_with_logits( tx) * tscale
x, tx) * tscale_tobj loss_y = fluid.layers.sigmoid_cross_entropy_with_logits(y,
loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3]) ty) * tscale
loss_y = fluid.layers.sigmoid_cross_entropy_with_logits( loss_xy = loss_x + loss_y
y, ty) * tscale_tobj
loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
# NOTE: we refined loss function of (w, h) as L1Loss # NOTE: we refined loss function of (w, h) as L1Loss
loss_w = fluid.layers.abs(w - tw) * tscale_tobj loss_w = fluid.layers.abs(w - tw) * tscale
loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3]) loss_h = fluid.layers.abs(h - th) * tscale
loss_h = fluid.layers.abs(h - th) * tscale_tobj loss_wh = loss_w + loss_h
loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3])
loss_carl = (loss_xy + loss_wh) * carl_weights
loss_carl = fluid.layers.reduce_sum(loss_carl, dim=[1, 2, 3])
# loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
# loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
# loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3])
# loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3])
loss_xy = fluid.layers.reduce_sum(loss_xy * tobj, dim=[1, 2, 3])
loss_wh = fluid.layers.reduce_sum(loss_wh * tobj, dim=[1, 2, 3])
if self._iou_loss is not None: if self._iou_loss is not None:
loss_iou = self._iou_loss(x, y, w, h, tx, ty, tw, th, anchors, loss_iou = self._iou_loss(x, y, w, h, tx, ty, tw, th, anchors,
downsample, self._batch_size) downsample, self._batch_size)
...@@ -200,6 +241,8 @@ class YOLOv3Loss(object): ...@@ -200,6 +241,8 @@ class YOLOv3Loss(object):
loss_iou_aware, dim=[1, 2, 3]) loss_iou_aware, dim=[1, 2, 3])
loss_iou_awares.append(fluid.layers.reduce_mean(loss_iou_aware)) loss_iou_awares.append(fluid.layers.reduce_mean(loss_iou_aware))
# tobj = tobj * pos_weights
loss_obj_pos, loss_obj_neg = self._calc_obj_loss( loss_obj_pos, loss_obj_neg = self._calc_obj_loss(
output, obj, tobj, iou, an_num, self._ignore_thresh, scale_x_y) output, obj, tobj, iou, an_num, self._ignore_thresh, scale_x_y)
...@@ -207,8 +250,10 @@ class YOLOv3Loss(object): ...@@ -207,8 +250,10 @@ class YOLOv3Loss(object):
loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0) loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0)
loss_cls = fluid.layers.reduce_sum(loss_cls, dim=[1, 2, 3, 4]) loss_cls = fluid.layers.reduce_sum(loss_cls, dim=[1, 2, 3, 4])
loss_xys.append(fluid.layers.reduce_mean(loss_x + loss_y)) loss_xys.append(fluid.layers.reduce_mean(loss_xy))
loss_whs.append(fluid.layers.reduce_mean(loss_w + loss_h)) loss_whs.append(fluid.layers.reduce_mean(loss_wh))
loss_isrp_clss.append(fluid.layers.reduce_mean(loss_isrp_cls))
loss_carls.append(fluid.layers.reduce_mean(loss_carl))
loss_objs.append( loss_objs.append(
fluid.layers.reduce_mean(loss_obj_pos + loss_obj_neg)) fluid.layers.reduce_mean(loss_obj_pos + loss_obj_neg))
loss_clss.append(fluid.layers.reduce_mean(loss_cls)) loss_clss.append(fluid.layers.reduce_mean(loss_cls))
...@@ -216,6 +261,8 @@ class YOLOv3Loss(object): ...@@ -216,6 +261,8 @@ class YOLOv3Loss(object):
losses_all = { losses_all = {
"loss_xy": fluid.layers.sum(loss_xys), "loss_xy": fluid.layers.sum(loss_xys),
"loss_wh": fluid.layers.sum(loss_whs), "loss_wh": fluid.layers.sum(loss_whs),
"loss_isrp_cls": fluid.layers.sum(loss_isrp_clss),
"loss_carl": fluid.layers.sum(loss_carls),
"loss_obj": fluid.layers.sum(loss_objs), "loss_obj": fluid.layers.sum(loss_objs),
"loss_cls": fluid.layers.sum(loss_clss), "loss_cls": fluid.layers.sum(loss_clss),
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册