From cfeeec52ef388119f4fc8814eb3a8ecde1b440c1 Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Thu, 21 Jan 2021 20:39:52 -0600 Subject: [PATCH] fix trt problem (#2107) --- dygraph/ppdet/modeling/heads/yolo_head.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/dygraph/ppdet/modeling/heads/yolo_head.py b/dygraph/ppdet/modeling/heads/yolo_head.py index c88f26759..cd225d576 100644 --- a/dygraph/ppdet/modeling/heads/yolo_head.py +++ b/dygraph/ppdet/modeling/heads/yolo_head.py @@ -86,18 +86,18 @@ class YOLOv3Head(nn.Layer): ioup, x = out[:, 0:na, :, :], out[:, na:, :, :] b, c, h, w = x.shape no = c // na - x = x.reshape((b, na, no, h, w)) - ioup = ioup.reshape((b, na, 1, h, w)) - obj = x[:, :, 4:5, :, :] + x = x.reshape((b, na, no, h * w)) + ioup = ioup.reshape((b, na, 1, h * w)) + obj = x[:, :, 4:5, :] ioup = F.sigmoid(ioup) obj = F.sigmoid(obj) obj_t = (obj**(1 - self.iou_aware_factor)) * ( ioup**self.iou_aware_factor) obj_t = _de_sigmoid(obj_t) - loc_t = x[:, :, :4, :, :] - cls_t = x[:, :, 5:, :, :] + loc_t = x[:, :, :4, :] + cls_t = x[:, :, 5:, :] y_t = paddle.concat([loc_t, obj_t, cls_t], axis=2) - y_t = y_t.reshape((b, -1, h, w)) + y_t = y_t.reshape((b, c, h, w)) y.append(y_t) return y else: -- GitLab