From 2285e0c9c1e3de6b4d59d19016b71827e7ed78a9 Mon Sep 17 00:00:00 2001 From: JYChen Date: Tue, 25 Apr 2023 10:34:26 +0800 Subject: [PATCH] fix training error brought by 0-d getitem (#8140) * fix training error brought by 0-d getitem * fix other model --- ppdet/modeling/architectures/faster_rcnn.py | 27 +++++++++++-------- ppdet/modeling/heads/cascade_head.py | 2 +- ppdet/modeling/heads/s2anet_head.py | 2 +- ppdet/modeling/heads/tood_head.py | 10 +++++-- ppdet/modeling/proposal_generator/target.py | 6 ++--- .../transformers/deformable_transformer.py | 5 +++- 6 files changed, 33 insertions(+), 19 deletions(-) diff --git a/ppdet/modeling/architectures/faster_rcnn.py b/ppdet/modeling/architectures/faster_rcnn.py index 41c286fe0..93fd0f9c6 100644 --- a/ppdet/modeling/architectures/faster_rcnn.py +++ b/ppdet/modeling/architectures/faster_rcnn.py @@ -86,15 +86,16 @@ class FasterRCNN(BaseArch): preds, _ = self.bbox_head(body_feats, rois, rois_num, None) im_shape = self.inputs['im_shape'] scale_factor = self.inputs['scale_factor'] - bbox, bbox_num, nms_keep_idx = self.bbox_post_process(preds, (rois, rois_num), - im_shape, scale_factor) + bbox, bbox_num, nms_keep_idx = self.bbox_post_process( + preds, (rois, rois_num), im_shape, scale_factor) # rescale the prediction back to origin image bboxes, bbox_pred, bbox_num = self.bbox_post_process.get_pred( bbox, bbox_num, im_shape, scale_factor) if self.use_extra_data: - extra_data = {} # record the bbox output before nms, such like scores and nms_keep_idx + extra_data = { + } # record the bbox output before nms, such like scores and nms_keep_idx """extra_data:{ 'scores': predict scores, 'nms_keep_idx': bbox index before nms, @@ -102,12 +103,12 @@ class FasterRCNN(BaseArch): """ extra_data['scores'] = preds[1] # predict scores (probability) # Todo: get logits output - extra_data['nms_keep_idx'] = nms_keep_idx # bbox index before nms + extra_data[ + 'nms_keep_idx'] = nms_keep_idx # bbox index before nms return bbox_pred, bbox_num, extra_data else: return bbox_pred, bbox_num - def get_loss(self, ): rpn_loss, bbox_loss = self._forward() loss = {} @@ -120,7 +121,11 @@ class FasterRCNN(BaseArch): def get_pred(self): if self.use_extra_data: bbox_pred, bbox_num, extra_data = self._forward() - output = {'bbox': bbox_pred, 'bbox_num': bbox_num, 'extra_data': extra_data} + output = { + 'bbox': bbox_pred, + 'bbox_num': bbox_num, + 'extra_data': extra_data + } else: bbox_pred, bbox_num = self._forward() output = {'bbox': bbox_pred, 'bbox_num': bbox_num} @@ -131,7 +136,7 @@ class FasterRCNN(BaseArch): if self.neck is not None: body_feats = self.neck(body_feats) rois = [roi for roi in data['gt_bbox']] - rois_num = paddle.concat([paddle.shape(roi)[0] for roi in rois]) + rois_num = paddle.concat([paddle.shape(roi)[0:1] for roi in rois]) preds, _ = self.bbox_head(body_feats, rois, rois_num, None, cot=True) return preds @@ -142,13 +147,13 @@ class FasterRCNN(BaseArch): label_list = [] for step_id, data in enumerate(loader): - _, bbox_prob = self.target_bbox_forward(data) + _, bbox_prob = self.target_bbox_forward(data) batch_size = data['im_id'].shape[0] for i in range(batch_size): - num_bbox = data['gt_class'][i].shape[0] + num_bbox = data['gt_class'][i].shape[0] train_labels = data['gt_class'][i] train_labels_list.append(train_labels.numpy().squeeze(1)) - base_labels = bbox_prob.detach().numpy()[:,:-1] + base_labels = bbox_prob.detach().numpy()[:, :-1] label_list.append(base_labels) labels = np.concatenate(train_labels_list, 0) @@ -159,4 +164,4 @@ class FasterRCNN(BaseArch): this_class = probabilities[labels == i] average = np.mean(this_class, axis=0, keepdims=True) conditional.append(average) - return np.concatenate(conditional) \ No newline at end of file + return np.concatenate(conditional) diff --git a/ppdet/modeling/heads/cascade_head.py b/ppdet/modeling/heads/cascade_head.py index bb0beadbd..d6f21d20c 100644 --- a/ppdet/modeling/heads/cascade_head.py +++ b/ppdet/modeling/heads/cascade_head.py @@ -301,7 +301,7 @@ class CascadeHead(BBoxHead): keep = paddle.zeros([1], dtype='int32') clip_box = paddle.gather(clip_box, keep) rois.append(clip_box) - rois_num = paddle.concat([paddle.shape(r)[0] for r in rois]) + rois_num = paddle.concat([paddle.shape(r)[0:1] for r in rois]) return rois, rois_num def _get_pred_bbox(self, deltas, proposals, weights): diff --git a/ppdet/modeling/heads/s2anet_head.py b/ppdet/modeling/heads/s2anet_head.py index 8abddcff1..99fd13a9a 100644 --- a/ppdet/modeling/heads/s2anet_head.py +++ b/ppdet/modeling/heads/s2anet_head.py @@ -360,7 +360,7 @@ class S2ANetHead(nn.Layer): for i in range(bbox_num.shape[0]): expand_shape = paddle.expand(origin_shape[i:i + 1, :], [bbox_num[i], 2]) - scale_y, scale_x = scale_factor[i][0], scale_factor[i][1] + scale_y, scale_x = scale_factor[i, 0:1], scale_factor[i, 1:2] scale = paddle.concat([ scale_x, scale_y, scale_x, scale_y, scale_x, scale_y, scale_x, scale_y diff --git a/ppdet/modeling/heads/tood_head.py b/ppdet/modeling/heads/tood_head.py index f463ef239..be840984f 100644 --- a/ppdet/modeling/heads/tood_head.py +++ b/ppdet/modeling/heads/tood_head.py @@ -86,7 +86,10 @@ class TaskDecomposition(nn.Layer): normal_(self.la_conv2.weight, std=0.001) def forward(self, feat, avg_feat): - b, _, h, w = get_static_shape(feat) + feat_shape = get_static_shape(feat) + b = feat_shape[0:1] + h = feat_shape[2:3] + w = feat_shape[3:4] weight = F.relu(self.la_conv1(avg_feat)) weight = F.sigmoid(self.la_conv2(weight)).unsqueeze(-1) feat = paddle.reshape( @@ -204,7 +207,10 @@ class TOODHead(nn.Layer): constant_(self.reg_offset_conv2.bias) def _reg_grid_sample(self, feat, offset, anchor_points): - b, _, h, w = get_static_shape(feat) + feat_shape = get_static_shape(feat) + b = feat_shape[0:1] + h = feat_shape[2:3] + w = feat_shape[3:4] feat = paddle.reshape(feat, [-1, 1, h, w]) offset = paddle.reshape(offset, [-1, 2, h, w]).transpose([0, 2, 3, 1]) grid_shape = paddle.concat([w, h]).astype('float32') diff --git a/ppdet/modeling/proposal_generator/target.py b/ppdet/modeling/proposal_generator/target.py index f95f906a2..041b2c791 100644 --- a/ppdet/modeling/proposal_generator/target.py +++ b/ppdet/modeling/proposal_generator/target.py @@ -237,7 +237,7 @@ def generate_proposal_target(rpn_rois, tgt_bboxes.append(sampled_bbox) rois_with_gt.append(rois_per_image) tgt_gt_inds.append(sampled_gt_ind) - new_rois_num.append(paddle.shape(sampled_inds)[0]) + new_rois_num.append(paddle.shape(sampled_inds)[0:1]) new_rois_num = paddle.concat(new_rois_num) return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num @@ -380,7 +380,7 @@ def generate_mask_target(gt_segms, rois, labels_int32, sampled_gt_inds, mask_index.append(fg_inds) mask_rois.append(fg_rois) - mask_rois_num.append(paddle.shape(fg_rois)[0]) + mask_rois_num.append(paddle.shape(fg_rois)[0:1]) tgt_classes.append(fg_classes) tgt_masks.append(tgt_mask) tgt_weights.append(weight) @@ -672,7 +672,7 @@ def libra_generate_proposal_target(rpn_rois, rois_with_gt.append(rois_per_image) sampled_max_overlaps.append(sampled_overlap) tgt_gt_inds.append(sampled_gt_ind) - new_rois_num.append(paddle.shape(sampled_inds)[0]) + new_rois_num.append(paddle.shape(sampled_inds)[0:1]) new_rois_num = paddle.concat(new_rois_num) # rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num diff --git a/ppdet/modeling/transformers/deformable_transformer.py b/ppdet/modeling/transformers/deformable_transformer.py index fcb5a0aab..ab05704f4 100644 --- a/ppdet/modeling/transformers/deformable_transformer.py +++ b/ppdet/modeling/transformers/deformable_transformer.py @@ -486,7 +486,10 @@ class DeformableTransformer(nn.Layer): spatial_shapes = [] valid_ratios = [] for level, src in enumerate(srcs): - bs, _, h, w = paddle.shape(src) + src_shape = paddle.shape(src) + bs = src_shape[0:1] + h = src_shape[2:3] + w = src_shape[3:4] spatial_shapes.append(paddle.concat([h, w])) src = src.flatten(2).transpose([0, 2, 1]) src_flatten.append(src) -- GitLab