diff --git a/dygraph/ppdet/data/transform/batch_operator.py b/dygraph/ppdet/data/transform/batch_operator.py index ff84b645a684b0769f914a05289e6804227312e8..8712d815195a6b6dbcf28a37c533e5333dec31dc 100644 --- a/dygraph/ppdet/data/transform/batch_operator.py +++ b/dygraph/ppdet/data/transform/batch_operator.py @@ -115,8 +115,8 @@ class PadBatchOp(BaseOperator): gt_num_max = max(gt_num) for i, data in enumerate(samples): - gt_box_data = np.zeros([gt_num_max, 4], dtype=np.float32) - gt_class_data = np.zeros([gt_num_max], dtype=np.int32) + gt_box_data = -np.ones([gt_num_max, 4], dtype=np.float32) + gt_class_data = -np.ones([gt_num_max], dtype=np.int32) is_crowd_data = np.ones([gt_num_max], dtype=np.int32) if pad_mask: diff --git a/dygraph/ppdet/modeling/proposal_generator/target.py b/dygraph/ppdet/modeling/proposal_generator/target.py index aa2ddba1d8eb5ba9c4cdf4449921bbe96d445f97..e5d4a10192ca4eb9d974c7b5c6c7c9f3c9b3067f 100644 --- a/dygraph/ppdet/modeling/proposal_generator/target.py +++ b/dygraph/ppdet/modeling/proposal_generator/target.py @@ -69,16 +69,15 @@ def label_box(anchors, gt_boxes, positive_overlap, negative_overlap, return default_matches, default_match_labels matched_vals, matches = paddle.topk(iou, k=1, axis=0) match_labels = paddle.full(matches.shape, -1, dtype='int32') - match_labels = paddle.where(matched_vals < negative_overlap, paddle.zeros_like(match_labels), match_labels) match_labels = paddle.where(matched_vals >= positive_overlap, paddle.ones_like(match_labels), match_labels) if allow_low_quality: highest_quality_foreach_gt = iou.max(axis=1, keepdim=True) - pred_inds_with_highest_quality = ( - iou == highest_quality_foreach_gt).cast('int32').sum(0, - keepdim=True) + pred_inds_with_highest_quality = paddle.logical_and( + iou > 0, iou == highest_quality_foreach_gt).cast('int32').sum( + 0, keepdim=True) match_labels = paddle.where(pred_inds_with_highest_quality > 0, paddle.ones_like(match_labels), match_labels) @@ -151,7 +150,7 @@ def generate_proposal_target(rpn_rois, for i, rpn_roi in enumerate(rpn_rois): max_overlap = max_overlaps[i] if is_cascade_rcnn else None gt_bbox = gt_boxes[i] - gt_classes = gt_classes[i] + gt_class = gt_classes[i] if is_cascade_rcnn: rpn_roi = filter_roi(rpn_roi, max_overlap) bbox = paddle.concat([rpn_roi, gt_bbox]) @@ -161,7 +160,7 @@ def generate_proposal_target(rpn_rois, bbox, gt_bbox, fg_thresh, bg_thresh, False) # Step2: sample bbox sampled_inds, sampled_gt_classes = sample_bbox( - matches, match_labels, gt_classes, batch_size_per_im, fg_fraction, + matches, match_labels, gt_class, batch_size_per_im, fg_fraction, num_classes, use_random) # Step3: make output