From ba2aad26e6bc1e5c2dad76ca96692a0d63eccfac Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Mon, 31 Oct 2022 20:12:40 +0800 Subject: [PATCH] fix dynamic shape of reshape op when export model (#7230) --- ppdet/modeling/heads/gfl_head.py | 2 +- ppdet/modeling/heads/pico_head.py | 13 +++++++------ ppdet/modeling/heads/ppyoloe_head.py | 8 ++++---- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/ppdet/modeling/heads/gfl_head.py b/ppdet/modeling/heads/gfl_head.py index aa6dc478b..252d99843 100644 --- a/ppdet/modeling/heads/gfl_head.py +++ b/ppdet/modeling/heads/gfl_head.py @@ -260,7 +260,7 @@ class GFLHead(nn.Layer): center_points = paddle.stack([x, y], axis=-1) cls_score = cls_score.reshape([b, -1, self.cls_out_channels]) bbox_pred = self.distribution_project(bbox_pred) * stride - bbox_pred = bbox_pred.reshape([b, cell_h * cell_w, 4]) + bbox_pred = bbox_pred.reshape([-1, cell_h * cell_w, 4]) # NOTE: If keep_ratio=False and image shape value that # multiples of 32, distance2bbox not set max_shapes parameter diff --git a/ppdet/modeling/heads/pico_head.py b/ppdet/modeling/heads/pico_head.py index 07cc3c7ca..3d41db0d8 100644 --- a/ppdet/modeling/heads/pico_head.py +++ b/ppdet/modeling/heads/pico_head.py @@ -353,13 +353,13 @@ class PicoHead(OTAVFLHead): bbox_pred = bbox_pred.reshape([1, (self.reg_max + 1) * 4, -1]).transpose([0, 2, 1]) else: - b, _, h, w = fpn_feat.shape + _, _, h, w = fpn_feat.shape l = h * w cls_score_out = F.sigmoid( - cls_score.reshape([b, self.cls_out_channels, l])) + cls_score.reshape([-1, self.cls_out_channels, l])) bbox_pred = bbox_pred.transpose([0, 2, 3, 1]) bbox_pred = self.distribution_project(bbox_pred) - bbox_pred = bbox_pred.reshape([b, l, 4]) + bbox_pred = bbox_pred.reshape([-1, l, 4]) cls_logits_list.append(cls_score_out) bboxes_reg_list.append(bbox_pred) @@ -597,7 +597,7 @@ class PicoHeadV2(GFLHead): anchor_points, stride_tensor = self._generate_anchors(fpn_feats) cls_score_list, box_list = [], [] for i, (fpn_feat, stride) in enumerate(zip(fpn_feats, self.fpn_stride)): - b, _, h, w = fpn_feat.shape + _, _, h, w = fpn_feat.shape # task decomposition conv_cls_feat, se_feat = self.conv_feat(fpn_feat, i) cls_logit = self.head_cls_list[i](se_feat) @@ -620,10 +620,11 @@ class PicoHeadV2(GFLHead): [0, 2, 1])) else: l = h * w - cls_score_out = cls_score.reshape([b, self.cls_out_channels, l]) + cls_score_out = cls_score.reshape( + [-1, self.cls_out_channels, l]) bbox_pred = reg_pred.transpose([0, 2, 3, 1]) bbox_pred = self.distribution_project(bbox_pred) - bbox_pred = bbox_pred.reshape([b, l, 4]) + bbox_pred = bbox_pred.reshape([-1, l, 4]) cls_score_list.append(cls_score_out) box_list.append(bbox_pred) diff --git a/ppdet/modeling/heads/ppyoloe_head.py b/ppdet/modeling/heads/ppyoloe_head.py index cdcf2bce5..279412066 100644 --- a/ppdet/modeling/heads/ppyoloe_head.py +++ b/ppdet/modeling/heads/ppyoloe_head.py @@ -192,7 +192,7 @@ class PPYOLOEHead(nn.Layer): anchor_points, stride_tensor = self._generate_anchors(feats) cls_score_list, reg_dist_list = [], [] for i, feat in enumerate(feats): - b, _, h, w = feat.shape + _, _, h, w = feat.shape l = h * w avg_feat = F.adaptive_avg_pool2d(feat, (1, 1)) cls_logit = self.pred_cls[i](self.stem_cls[i](feat, avg_feat) + @@ -203,7 +203,7 @@ class PPYOLOEHead(nn.Layer): reg_dist = self.proj_conv(F.softmax(reg_dist, axis=1)).squeeze(1) # cls and reg cls_score = F.sigmoid(cls_logit) - cls_score_list.append(cls_score.reshape([b, self.num_classes, l])) + cls_score_list.append(cls_score.reshape([-1, self.num_classes, l])) reg_dist_list.append(reg_dist) cls_score_list = paddle.concat(cls_score_list, axis=-1) @@ -238,8 +238,8 @@ class PPYOLOEHead(nn.Layer): return loss def _bbox_decode(self, anchor_points, pred_dist): - b, l, _ = get_static_shape(pred_dist) - pred_dist = F.softmax(pred_dist.reshape([b, l, 4, self.reg_max + 1])) + _, l, _ = get_static_shape(pred_dist) + pred_dist = F.softmax(pred_dist.reshape([-1, l, 4, self.reg_max + 1])) pred_dist = self.proj_conv(pred_dist.transpose([0, 3, 1, 2])).squeeze(1) return batch_distance2bbox(anchor_points, pred_dist) -- GitLab