From 3cd9e8571aecbfa36ff765f432a2998652eb24b3 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Sat, 30 Jan 2021 11:24:37 +0800 Subject: [PATCH] fix batch size > 1, test=dygraph (#2141) --- dygraph/ppdet/data/transform/batch_operator.py | 4 ++-- dygraph/ppdet/modeling/proposal_generator/target.py | 11 +++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/dygraph/ppdet/data/transform/batch_operator.py b/dygraph/ppdet/data/transform/batch_operator.py index ff84b645a..8712d8151 100644 --- a/dygraph/ppdet/data/transform/batch_operator.py +++ b/dygraph/ppdet/data/transform/batch_operator.py @@ -115,8 +115,8 @@ class PadBatchOp(BaseOperator): gt_num_max = max(gt_num) for i, data in enumerate(samples): - gt_box_data = np.zeros([gt_num_max, 4], dtype=np.float32) - gt_class_data = np.zeros([gt_num_max], dtype=np.int32) + gt_box_data = -np.ones([gt_num_max, 4], dtype=np.float32) + gt_class_data = -np.ones([gt_num_max], dtype=np.int32) is_crowd_data = np.ones([gt_num_max], dtype=np.int32) if pad_mask: diff --git a/dygraph/ppdet/modeling/proposal_generator/target.py b/dygraph/ppdet/modeling/proposal_generator/target.py index aa2ddba1d..e5d4a1019 100644 --- a/dygraph/ppdet/modeling/proposal_generator/target.py +++ b/dygraph/ppdet/modeling/proposal_generator/target.py @@ -69,16 +69,15 @@ def label_box(anchors, gt_boxes, positive_overlap, negative_overlap, return default_matches, default_match_labels matched_vals, matches = paddle.topk(iou, k=1, axis=0) match_labels = paddle.full(matches.shape, -1, dtype='int32') - match_labels = paddle.where(matched_vals < negative_overlap, paddle.zeros_like(match_labels), match_labels) match_labels = paddle.where(matched_vals >= positive_overlap, paddle.ones_like(match_labels), match_labels) if allow_low_quality: highest_quality_foreach_gt = iou.max(axis=1, keepdim=True) - pred_inds_with_highest_quality = ( - iou == highest_quality_foreach_gt).cast('int32').sum(0, - keepdim=True) + pred_inds_with_highest_quality = paddle.logical_and( + iou > 0, iou == highest_quality_foreach_gt).cast('int32').sum( + 0, keepdim=True) match_labels = paddle.where(pred_inds_with_highest_quality > 0, paddle.ones_like(match_labels), match_labels) @@ -151,7 +150,7 @@ def generate_proposal_target(rpn_rois, for i, rpn_roi in enumerate(rpn_rois): max_overlap = max_overlaps[i] if is_cascade_rcnn else None gt_bbox = gt_boxes[i] - gt_classes = gt_classes[i] + gt_class = gt_classes[i] if is_cascade_rcnn: rpn_roi = filter_roi(rpn_roi, max_overlap) bbox = paddle.concat([rpn_roi, gt_bbox]) @@ -161,7 +160,7 @@ def generate_proposal_target(rpn_rois, bbox, gt_bbox, fg_thresh, bg_thresh, False) # Step2: sample bbox sampled_inds, sampled_gt_classes = sample_bbox( - matches, match_labels, gt_classes, batch_size_per_im, fg_fraction, + matches, match_labels, gt_class, batch_size_per_im, fg_fraction, num_classes, use_random) # Step3: make output -- GitLab