未验证 提交 20e300e2 编写于 作者: H huangjun12 提交者: GitHub

fix lrn bug in reshape size, test=develop (#30968)

上级 8ab29f4b
...@@ -484,17 +484,26 @@ def local_response_norm(x, ...@@ -484,17 +484,26 @@ def local_response_norm(x,
channel_last = True if data_format[-1] == "C" else False channel_last = True if data_format[-1] == "C" else False
from functools import reduce
sum_sizes = reduce(lambda x, y: x * y, sizes[1:])
div = paddle.unsqueeze(paddle.multiply(x, x), axis=1) div = paddle.unsqueeze(paddle.multiply(x, x), axis=1)
if not channel_last: if not channel_last:
pad4d_shape = [0, 0, size // 2, (size - 1) // 2] pad4d_shape = [0, 0, size // 2, (size - 1) // 2]
pool2d_shape = (size, 1) pool2d_shape = (size, 1)
reshape_shape = [sizes[0], 1, sizes[1], sizes[2], -1] reshape_shape = [
sizes[0], 1, sizes[1], sizes[2],
int(sum_sizes / (sizes[1] * sizes[2]))
]
pad5d_shape = [0, 0, 0, 0, size // 2, (size - 1) // 2] pad5d_shape = [0, 0, 0, 0, size // 2, (size - 1) // 2]
pool3d_shape = (size, 1, 1) pool3d_shape = (size, 1, 1)
else: else:
pad4d_shape = [size // 2, (size - 1) // 2, 0, 0] pad4d_shape = [size // 2, (size - 1) // 2, 0, 0]
pool2d_shape = (1, size) pool2d_shape = (1, size)
reshape_shape = [sizes[0], 1, sizes[1], -1, sizes[-1]] reshape_shape = [
sizes[0], 1, sizes[1], int(sum_sizes / (sizes[1] * sizes[-1])),
sizes[-1]
]
pad5d_shape = [size // 2, (size - 1) // 2, 0, 0, 0, 0] pad5d_shape = [size // 2, (size - 1) // 2, 0, 0, 0, 0]
pool3d_shape = (1, 1, size) pool3d_shape = (1, 1, size)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册