未验证 提交 3cd9e857 编写于 作者: W wangguanzhong 提交者: GitHub

fix batch size > 1, test=dygraph (#2141)

上级 32edc345
...@@ -115,8 +115,8 @@ class PadBatchOp(BaseOperator): ...@@ -115,8 +115,8 @@ class PadBatchOp(BaseOperator):
gt_num_max = max(gt_num) gt_num_max = max(gt_num)
for i, data in enumerate(samples): for i, data in enumerate(samples):
gt_box_data = np.zeros([gt_num_max, 4], dtype=np.float32) gt_box_data = -np.ones([gt_num_max, 4], dtype=np.float32)
gt_class_data = np.zeros([gt_num_max], dtype=np.int32) gt_class_data = -np.ones([gt_num_max], dtype=np.int32)
is_crowd_data = np.ones([gt_num_max], dtype=np.int32) is_crowd_data = np.ones([gt_num_max], dtype=np.int32)
if pad_mask: if pad_mask:
......
...@@ -69,16 +69,15 @@ def label_box(anchors, gt_boxes, positive_overlap, negative_overlap, ...@@ -69,16 +69,15 @@ def label_box(anchors, gt_boxes, positive_overlap, negative_overlap,
return default_matches, default_match_labels return default_matches, default_match_labels
matched_vals, matches = paddle.topk(iou, k=1, axis=0) matched_vals, matches = paddle.topk(iou, k=1, axis=0)
match_labels = paddle.full(matches.shape, -1, dtype='int32') match_labels = paddle.full(matches.shape, -1, dtype='int32')
match_labels = paddle.where(matched_vals < negative_overlap, match_labels = paddle.where(matched_vals < negative_overlap,
paddle.zeros_like(match_labels), match_labels) paddle.zeros_like(match_labels), match_labels)
match_labels = paddle.where(matched_vals >= positive_overlap, match_labels = paddle.where(matched_vals >= positive_overlap,
paddle.ones_like(match_labels), match_labels) paddle.ones_like(match_labels), match_labels)
if allow_low_quality: if allow_low_quality:
highest_quality_foreach_gt = iou.max(axis=1, keepdim=True) highest_quality_foreach_gt = iou.max(axis=1, keepdim=True)
pred_inds_with_highest_quality = ( pred_inds_with_highest_quality = paddle.logical_and(
iou == highest_quality_foreach_gt).cast('int32').sum(0, iou > 0, iou == highest_quality_foreach_gt).cast('int32').sum(
keepdim=True) 0, keepdim=True)
match_labels = paddle.where(pred_inds_with_highest_quality > 0, match_labels = paddle.where(pred_inds_with_highest_quality > 0,
paddle.ones_like(match_labels), paddle.ones_like(match_labels),
match_labels) match_labels)
...@@ -151,7 +150,7 @@ def generate_proposal_target(rpn_rois, ...@@ -151,7 +150,7 @@ def generate_proposal_target(rpn_rois,
for i, rpn_roi in enumerate(rpn_rois): for i, rpn_roi in enumerate(rpn_rois):
max_overlap = max_overlaps[i] if is_cascade_rcnn else None max_overlap = max_overlaps[i] if is_cascade_rcnn else None
gt_bbox = gt_boxes[i] gt_bbox = gt_boxes[i]
gt_classes = gt_classes[i] gt_class = gt_classes[i]
if is_cascade_rcnn: if is_cascade_rcnn:
rpn_roi = filter_roi(rpn_roi, max_overlap) rpn_roi = filter_roi(rpn_roi, max_overlap)
bbox = paddle.concat([rpn_roi, gt_bbox]) bbox = paddle.concat([rpn_roi, gt_bbox])
...@@ -161,7 +160,7 @@ def generate_proposal_target(rpn_rois, ...@@ -161,7 +160,7 @@ def generate_proposal_target(rpn_rois,
bbox, gt_bbox, fg_thresh, bg_thresh, False) bbox, gt_bbox, fg_thresh, bg_thresh, False)
# Step2: sample bbox # Step2: sample bbox
sampled_inds, sampled_gt_classes = sample_bbox( sampled_inds, sampled_gt_classes = sample_bbox(
matches, match_labels, gt_classes, batch_size_per_im, fg_fraction, matches, match_labels, gt_class, batch_size_per_im, fg_fraction,
num_classes, use_random) num_classes, use_random)
# Step3: make output # Step3: make output
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册