diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index 050b9bce619640bb1a35b60811ff399915b4b865..03ba78e12f6376ce0dfd924e4be7b35229d46e45 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -484,17 +484,26 @@ def local_response_norm(x, channel_last = True if data_format[-1] == "C" else False + from functools import reduce + sum_sizes = reduce(lambda x, y: x * y, sizes[1:]) + div = paddle.unsqueeze(paddle.multiply(x, x), axis=1) if not channel_last: pad4d_shape = [0, 0, size // 2, (size - 1) // 2] pool2d_shape = (size, 1) - reshape_shape = [sizes[0], 1, sizes[1], sizes[2], -1] + reshape_shape = [ + sizes[0], 1, sizes[1], sizes[2], + int(sum_sizes / (sizes[1] * sizes[2])) + ] pad5d_shape = [0, 0, 0, 0, size // 2, (size - 1) // 2] pool3d_shape = (size, 1, 1) else: pad4d_shape = [size // 2, (size - 1) // 2, 0, 0] pool2d_shape = (1, size) - reshape_shape = [sizes[0], 1, sizes[1], -1, sizes[-1]] + reshape_shape = [ + sizes[0], 1, sizes[1], int(sum_sizes / (sizes[1] * sizes[-1])), + sizes[-1] + ] pad5d_shape = [size // 2, (size - 1) // 2, 0, 0, 0, 0] pool3d_shape = (1, 1, size)