未验证 提交 4ee74daf 编写于 作者: W wangxinxin08 提交者: GitHub

fix expand shape problem (#3745)

上级 090aa07d
......@@ -24,28 +24,23 @@ __all__ = ['YOLOv3FPN', 'PPYOLOFPN', 'PPYOLOTinyFPN', 'PPYOLOPAN']
def add_coord(x, data_format):
b = x.shape[0]
if data_format == 'NCHW':
h = x.shape[2]
w = x.shape[3]
b, _, h, w = paddle.shape(x)
else:
h = x.shape[1]
w = x.shape[2]
b, h, w, _ = paddle.shape(x)
gx = paddle.arange(w, dtype='float32') / (w - 1.) * 2.0 - 1.
if data_format == 'NCHW':
gx = gx.reshape([1, 1, 1, w]).expand([b, 1, h, w])
else:
gx = gx.reshape([1, 1, w, 1]).expand([b, h, w, 1])
gx.stop_gradient = True
gx = paddle.arange(w, dtype=x.dtype) / ((w - 1.) * 2.0) - 1.
gy = paddle.arange(h, dtype=x.dtype) / ((h - 1.) * 2.0) - 1.
gy = paddle.arange(h, dtype='float32') / (h - 1.) * 2.0 - 1.
if data_format == 'NCHW':
gx = gx.reshape([1, 1, 1, w]).expand([b, 1, h, w])
gy = gy.reshape([1, 1, h, 1]).expand([b, 1, h, w])
else:
gx = gx.reshape([1, 1, w, 1]).expand([b, h, w, 1])
gy = gy.reshape([1, h, 1, 1]).expand([b, h, w, 1])
gy.stop_gradient = True
gx.stop_gradient = True
gy.stop_gradient = True
return gx, gy
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册