From 0f148a62ad1d19126e00b3cb61943e3cd8bc90fd Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Wed, 10 Jun 2020 14:52:45 +0000 Subject: [PATCH] add isr_p --- ppdet/modeling/losses/pisa_utils.py | 87 +++++++++++++++++++++++++++++ ppdet/modeling/losses/yolo_loss.py | 69 +++++++++++++++++------ 2 files changed, 138 insertions(+), 18 deletions(-) create mode 100644 ppdet/modeling/losses/pisa_utils.py diff --git a/ppdet/modeling/losses/pisa_utils.py b/ppdet/modeling/losses/pisa_utils.py new file mode 100644 index 000000000..6d6345037 --- /dev/null +++ b/ppdet/modeling/losses/pisa_utils.py @@ -0,0 +1,87 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np + +__all__ = ['get_isr_p_func'] + + +def get_isr_p_func(pos_iou_thresh=0.25, bias=0, k=2): + def irs_p(x): + np.save("data", x) + x = np.array(x) + max_ious = x[:, :, 0] + gt_inds = x[:, :, 1].astype('int32') + cls = x[:, :, 2].astype('int32') + + # # n_{max}: max gt box num in each class + # valid_gt = gt_box[:, :, 2] > 0. + # valid_gt_label = gt_label[valid_gt] + # max_l_num = np.bincount(valid_gt_label).max() + + # divide gt index in each sample + gt_inds = gt_inds + np.arange(gt_inds.shape[ + 0])[:, np.newaxis] * gt_inds.shape[1] + + all_pos_weights = np.ones_like(max_ious) + pos_mask = max_ious > pos_iou_thresh + cls = np.reshape(cls, list(max_ious.shape) + [-1]) + max_ious = max_ious[pos_mask] + pos_weights = all_pos_weights[pos_mask] + gt_inds = gt_inds[pos_mask] + cls = cls[pos_mask] + max_l_num = np.bincount(cls.reshape(-1)).max() + for l in np.unique(cls): + l_inds = np.nonzero(cls == l)[0] + l_gt_inds = gt_inds[l_inds] + for t in np.unique(l_gt_inds): + t_inds = np.array(l_inds)[l_gt_inds == t] + t_max_ious = max_ious[t_inds] + t_max_iou_rank = np.argsort(-t_max_ious).argsort().astype( + 'float32') + max_ious[t_inds] += np.clip(t_max_iou_rank, 0., None) + l_max_ious = max_ious[l_inds] + l_max_iou_rank = np.argsort(-l_max_ious).argsort().astype('float32') + weight_factor = np.clip(max_l_num - l_max_iou_rank, 0., + None) / max_l_num + weight_factor = np.power(bias + (1 - bias) * weight_factor, k) + pos_weights[l_inds] *= weight_factor * 1.2 + pos_weights = pos_weights / np.mean(pos_weights) + all_pos_weights[pos_mask] = pos_weights + + return all_pos_weights + + return irs_p + + +if __name__ == "__main__": + import numpy as np + import paddle.fluid as fluid + x_np = np.load('./data.npy') + + x = fluid.data(name='x', shape=[8, 15552, 3], dtype='float32') + pos_weights = fluid.default_main_program().current_block().create_var( + name="pos_weights", dtype='float32', shape=[8, 15552]) + isr_p = get_isr_p_func() + fluid.layers.py_func(isr_p, x, pos_weights) + + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + ret = exe.run(fetch_list=[pos_weights.name], feed={'x': x_np}) + print(ret) + np.save("ret", ret[0]) diff --git a/ppdet/modeling/losses/yolo_loss.py b/ppdet/modeling/losses/yolo_loss.py index 732edb39d..5f31a498d 100644 --- a/ppdet/modeling/losses/yolo_loss.py +++ b/ppdet/modeling/losses/yolo_loss.py @@ -18,6 +18,7 @@ from __future__ import print_function from paddle import fluid from ppdet.core.workspace import register +from .pisa_utils import get_isr_p_func try: from collections.abc import Sequence except Exception: @@ -65,8 +66,8 @@ class YOLOv3Loss(object): anchor_masks, mask_anchors, num_classes, prefix_name): if self._use_fine_grained_loss: return self._get_fine_grained_loss( - outputs, targets, gt_box, self._batch_size, num_classes, - mask_anchors, self._ignore_thresh) + outputs, targets, gt_box, gt_label, self._batch_size, + num_classes, mask_anchors, self._ignore_thresh) else: losses = [] for i, output in enumerate(outputs): @@ -91,8 +92,9 @@ class YOLOv3Loss(object): return {'loss': sum(losses)} - def _get_fine_grained_loss(self, outputs, targets, gt_box, batch_size, - num_classes, mask_anchors, ignore_thresh): + def _get_fine_grained_loss(self, outputs, targets, gt_box, gt_label, + batch_size, num_classes, mask_anchors, + ignore_thresh): """ Calculate fine grained YOLOv3 loss @@ -135,6 +137,38 @@ class YOLOv3Loss(object): num_classes) tx, ty, tw, th, tscale, tobj, tcls = self._split_target(target) + scale_x_y = self.scale_x_y if not isinstance( + self.scale_x_y, Sequence) else self.scale_x_y[i] + iou = self._calc_iou(output, target, gt_box, anchors, batch_size, + num_classes, downsample, scale_x_y) + + # sorted_iou, sorted_gt_inds = fluid.layers.argsort(iou, axis=-1, descending=True) + # max_iou = sorted_iou[:, :, 0:1] + # gt_inds = fluid.layers.cast(sorted_gt_inds[:, :, 0:1], dtype='float32') + # pred_cls = fluid.layers.argmax(cls, axis=-1) + # pred_cls = fluid.layers.reshape(pred_cls, [batch_size, -1, 1]) + # pred_cls = fluid.layers.cast(pred_cls, dtype='float32') + # isr_p_input = fluid.layers.concat([max_iou, gt_inds, pred_cls], axis=-1) + # isr_p = get_isr_p_func() + # pos_weights = fluid.layers.zeros_like(max_iou) + # fluid.layers.py_func(isr_p, isr_p_input, pos_weights) + # + # tobj_shape = fluid.layers.shape(tobj) + # pos_weights = fluid.layers.reshape(pos_weights, (-1, an_num, tobj_shape[2], + # tobj_shape[3])) + # tobj = tobj * pos_weights + + # isr_tobj = tobj * pos_weights + # loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls, tcls) + # pos_mask = fluid.layers.cast(pos_weights > 0., dtype='flaot32') + # orig_loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj * pos_mask, axis=0) + # orig_loss_cls = fluid.layers.reduce_sum(loss_cls) + # orig_loss_cls.stop_gradient = True + # new_loss_cls = fluid.layers.elementwise_mul(loss_cls, isr_tobj * pos_mask, axis=0) + # new_loss_cls = fluid.layers.reduce_sum(loss_cls) + # new_loss_cls.stop_gradient = True + # pos_loss_cls_ratio = orig_loss_cls / new_loss_cls + tscale_tobj = tscale * tobj loss_x = fluid.layers.sigmoid_cross_entropy_with_logits( x, tx) * tscale_tobj @@ -163,11 +197,8 @@ class YOLOv3Loss(object): loss_iou_aware, dim=[1, 2, 3]) loss_iou_awares.append(fluid.layers.reduce_mean(loss_iou_aware)) - scale_x_y = self.scale_x_y if not isinstance( - self.scale_x_y, Sequence) else self.scale_x_y[i] loss_obj_pos, loss_obj_neg = self._calc_obj_loss( - output, obj, tobj, gt_box, self._batch_size, anchors, - num_classes, downsample, self._ignore_thresh, scale_x_y) + output, obj, tobj, iou, an_num, self._ignore_thresh, scale_x_y) loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls, tcls) loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0) @@ -276,11 +307,8 @@ class YOLOv3Loss(object): return (tx, ty, tw, th, tscale, tobj, tcls) - def _calc_obj_loss(self, output, obj, tobj, gt_box, batch_size, anchors, - num_classes, downsample, ignore_thresh, scale_x_y): - # A prediction bbox overlap any gt_bbox over ignore_thresh, - # objectness loss will be ignored, process as follows: - + def _calc_iou(self, output, target, gt_box, anchors, batch_size, + num_classes, downsample, scale_x_y): # 1. get pred bbox, which is same with YOLOv3 infer mode, use yolo_box here # NOTE: img_size is set as 1.0 to get noramlized pred bbox bbox, prob = fluid.layers.yolo_box( @@ -302,7 +330,6 @@ class YOLOv3Loss(object): else: preds = [bbox] gts = [gt_box] - probs = [prob] ious = [] for pred, gt in zip(preds, gts): @@ -322,10 +349,17 @@ class YOLOv3Loss(object): pred = fluid.layers.squeeze(pred, axes=[0]) gt = box_xywh2xyxy(fluid.layers.squeeze(gt, axes=[0])) ious.append(fluid.layers.iou_similarity(pred, gt)) - iou = fluid.layers.stack(ious, axis=0) - # 3. Get iou_mask by IoU between gt bbox and prediction bbox, - # Get obj_mask by tobj(holds gt_score), calculate objectness loss + + return iou + + def _calc_obj_loss(self, output, obj, tobj, iou, an_num, ignore_thresh, + scale_x_y): + # A prediction bbox overlap any gt_bbox over ignore_thresh, + # objectness loss will be ignored, process as follows: + + # Get iou_mask by IoU between gt bbox and prediction bbox, + # Get obj_mask by tobj(holds gt_score), calculate objectness loss max_iou = fluid.layers.reduce_max(iou, dim=-1) iou_mask = fluid.layers.cast(max_iou <= ignore_thresh, dtype="float32") @@ -334,7 +368,6 @@ class YOLOv3Loss(object): iou_mask = iou_mask * fluid.layers.cast( max_prob <= 0.25, dtype="float32") output_shape = fluid.layers.shape(output) - an_num = len(anchors) // 2 iou_mask = fluid.layers.reshape(iou_mask, (-1, an_num, output_shape[2], output_shape[3])) iou_mask.stop_gradient = True -- GitLab