未验证 提交 cfeeec52 编写于 作者: W wangxinxin08 提交者: GitHub

fix trt problem (#2107)

上级 93e2a447
...@@ -86,18 +86,18 @@ class YOLOv3Head(nn.Layer): ...@@ -86,18 +86,18 @@ class YOLOv3Head(nn.Layer):
ioup, x = out[:, 0:na, :, :], out[:, na:, :, :] ioup, x = out[:, 0:na, :, :], out[:, na:, :, :]
b, c, h, w = x.shape b, c, h, w = x.shape
no = c // na no = c // na
x = x.reshape((b, na, no, h, w)) x = x.reshape((b, na, no, h * w))
ioup = ioup.reshape((b, na, 1, h, w)) ioup = ioup.reshape((b, na, 1, h * w))
obj = x[:, :, 4:5, :, :] obj = x[:, :, 4:5, :]
ioup = F.sigmoid(ioup) ioup = F.sigmoid(ioup)
obj = F.sigmoid(obj) obj = F.sigmoid(obj)
obj_t = (obj**(1 - self.iou_aware_factor)) * ( obj_t = (obj**(1 - self.iou_aware_factor)) * (
ioup**self.iou_aware_factor) ioup**self.iou_aware_factor)
obj_t = _de_sigmoid(obj_t) obj_t = _de_sigmoid(obj_t)
loc_t = x[:, :, :4, :, :] loc_t = x[:, :, :4, :]
cls_t = x[:, :, 5:, :, :] cls_t = x[:, :, 5:, :]
y_t = paddle.concat([loc_t, obj_t, cls_t], axis=2) y_t = paddle.concat([loc_t, obj_t, cls_t], axis=2)
y_t = y_t.reshape((b, -1, h, w)) y_t = y_t.reshape((b, c, h, w))
y.append(y_t) y.append(y_t)
return y return y
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册