diff --git a/ppdet/modeling/necks/yolo_fpn.py b/ppdet/modeling/necks/yolo_fpn.py index e829a379f6ed3f231410162fce660b0dc1a42fbe..bc95848747f8a622f885c2a0673cfe1afc9bb1ef 100644 --- a/ppdet/modeling/necks/yolo_fpn.py +++ b/ppdet/modeling/necks/yolo_fpn.py @@ -24,11 +24,11 @@ __all__ = ['YOLOv3FPN', 'PPYOLOFPN', 'PPYOLOTinyFPN', 'PPYOLOPAN'] def add_coord(x, data_format): - shape = paddle.shape(x) + b = paddle.shape(x)[0] if data_format == 'NCHW': - b, h, w = shape[0], shape[2], shape[3] + h, w = x.shape[2], x.shape[3] else: - b, h, w = shape[0], shape[1], shape[2] + h, w = x.shape[1], x.shape[2] gx = paddle.arange(w, dtype=x.dtype) / ((w - 1.) * 2.0) - 1. gy = paddle.arange(h, dtype=x.dtype) / ((h - 1.) * 2.0) - 1.