From 4ee74daf80b062730e36f5ca49bd7876e0801526 Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Thu, 22 Jul 2021 15:33:09 +0800 Subject: [PATCH] fix expand shape problem (#3745) --- ppdet/modeling/necks/yolo_fpn.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/ppdet/modeling/necks/yolo_fpn.py b/ppdet/modeling/necks/yolo_fpn.py index e8913131e..041ccd2dc 100644 --- a/ppdet/modeling/necks/yolo_fpn.py +++ b/ppdet/modeling/necks/yolo_fpn.py @@ -24,28 +24,23 @@ __all__ = ['YOLOv3FPN', 'PPYOLOFPN', 'PPYOLOTinyFPN', 'PPYOLOPAN'] def add_coord(x, data_format): - b = x.shape[0] if data_format == 'NCHW': - h = x.shape[2] - w = x.shape[3] + b, _, h, w = paddle.shape(x) else: - h = x.shape[1] - w = x.shape[2] + b, h, w, _ = paddle.shape(x) - gx = paddle.arange(w, dtype='float32') / (w - 1.) * 2.0 - 1. - if data_format == 'NCHW': - gx = gx.reshape([1, 1, 1, w]).expand([b, 1, h, w]) - else: - gx = gx.reshape([1, 1, w, 1]).expand([b, h, w, 1]) - gx.stop_gradient = True + gx = paddle.arange(w, dtype=x.dtype) / ((w - 1.) * 2.0) - 1. + gy = paddle.arange(h, dtype=x.dtype) / ((h - 1.) * 2.0) - 1. - gy = paddle.arange(h, dtype='float32') / (h - 1.) * 2.0 - 1. if data_format == 'NCHW': + gx = gx.reshape([1, 1, 1, w]).expand([b, 1, h, w]) gy = gy.reshape([1, 1, h, 1]).expand([b, 1, h, w]) else: + gx = gx.reshape([1, 1, w, 1]).expand([b, h, w, 1]) gy = gy.reshape([1, h, 1, 1]).expand([b, h, w, 1]) - gy.stop_gradient = True + gx.stop_gradient = True + gy.stop_gradient = True return gx, gy -- GitLab